1#[cfg(feature = "discover")]
4use crate::discover::{Change, Discover};
5#[cfg(feature = "discover")]
6use futures_core::Stream;
7#[cfg(feature = "discover")]
8use pin_project_lite::pin_project;
9#[cfg(feature = "discover")]
10use std::{pin::Pin, task::ready};
11
12use super::completion::{CompleteOnResponse, TrackCompletion, TrackCompletionFuture};
13use super::Load;
14use std::task::{Context, Poll};
15use std::{
16 sync::{Arc, Mutex},
17 time::Duration,
18};
19use tokio::time::Instant;
20use tower_service::Service;
21use tracing::trace;
22
23#[derive(Debug)]
44pub struct PeakEwma<S, C = CompleteOnResponse> {
45 service: S,
46 decay_ns: f64,
47 rtt_estimate: Arc<Mutex<RttEstimate>>,
48 completion: C,
49}
50
#[cfg(feature = "discover")]
pin_project! {
    #[cfg_attr(docsrs, doc(cfg(feature = "discover")))]
    /// Wraps a [`Discover`] stream so that every service it inserts is
    /// wrapped in a [`PeakEwma`] load-measurement layer.
    #[derive(Debug)]
    pub struct PeakEwmaDiscover<D, C = CompleteOnResponse> {
        // Underlying discovery stream; structurally pinned because it is polled in place.
        #[pin]
        discover: D,
        // Decay window in nanoseconds, passed to each new `PeakEwma`.
        decay_ns: f64,
        // Initial RTT estimate for each newly discovered service.
        default_rtt: Duration,
        // Completion strategy, cloned into each new `PeakEwma`.
        completion: C,
    }
}
64
/// The load metric reported by [`PeakEwma`]: an RTT estimate in nanoseconds
/// scaled by the number of requests in flight plus one.
///
/// Derives `PartialOrd` so costs can be compared against each other.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub struct Cost(f64);
71
72#[derive(Debug)]
74pub struct Handle {
75 sent_at: Instant,
76 decay_ns: f64,
77 rtt_estimate: Arc<Mutex<RttEstimate>>,
78}
79
80#[derive(Clone, Debug)]
82pub struct RttEstimate {
83 update_at: Instant,
84 rtt_ns: f64,
85}
86
/// Nanoseconds per millisecond; used to render nanosecond estimates as ms in traces.
const NANOS_PER_MILLI: f64 = 1e6;
88
89impl<S, C> PeakEwma<S, C> {
92 pub fn new(service: S, default_rtt: Duration, decay_ns: f64, completion: C) -> Self {
94 debug_assert!(decay_ns > 0.0, "decay_ns must be positive");
95 Self {
96 service,
97 decay_ns,
98 rtt_estimate: Arc::new(Mutex::new(RttEstimate::new(nanos(default_rtt)))),
99 completion,
100 }
101 }
102
103 fn handle(&self) -> Handle {
104 Handle {
105 decay_ns: self.decay_ns,
106 sent_at: Instant::now(),
107 rtt_estimate: self.rtt_estimate.clone(),
108 }
109 }
110
111 pub fn rtt_estimate(&self) -> RttEstimate {
117 self.rtt_estimate
118 .lock()
119 .expect("mutex should not be poisoned")
120 .clone()
121 }
122}
123
124impl<S, C, Request> Service<Request> for PeakEwma<S, C>
125where
126 S: Service<Request>,
127 C: TrackCompletion<Handle, S::Response>,
128{
129 type Response = C::Output;
130 type Error = S::Error;
131 type Future = TrackCompletionFuture<S::Future, C, Handle>;
132
133 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
134 self.service.poll_ready(cx)
135 }
136
137 fn call(&mut self, req: Request) -> Self::Future {
138 TrackCompletionFuture::new(
139 self.completion.clone(),
140 self.handle(),
141 self.service.call(req),
142 )
143 }
144}
145
146impl<S, C> Load for PeakEwma<S, C> {
147 type Metric = Cost;
148
149 fn load(&self) -> Self::Metric {
150 let pending = Arc::strong_count(&self.rtt_estimate) as u32 - 1;
151
152 let estimate = self.update_estimate();
155
156 let cost = Cost(estimate * f64::from(pending + 1));
157 trace!(
158 "load estimate={:.0}ms pending={} cost={:?}",
159 estimate / NANOS_PER_MILLI,
160 pending,
161 cost,
162 );
163 cost
164 }
165}
166
167impl<S, C> PeakEwma<S, C> {
168 fn update_estimate(&self) -> f64 {
169 let mut rtt = self.rtt_estimate.lock().expect("peak ewma prior_estimate");
170 rtt.decay(self.decay_ns)
171 }
172}
173
#[cfg(feature = "discover")]
impl<D, C> PeakEwmaDiscover<D, C> {
    /// Wraps `discover` so that every service it inserts is measured with [`PeakEwma`].
    ///
    /// `default_rtt` seeds each new service's RTT estimate. `decay` is the decay
    /// window and is converted to nanoseconds once, here. `completion` is cloned
    /// into each wrapped service.
    pub fn new<Request>(discover: D, default_rtt: Duration, decay: Duration, completion: C) -> Self
    where
        D: Discover,
        D::Service: Service<Request>,
        C: TrackCompletion<Handle, <D::Service as Service<Request>>::Response>,
    {
        let decay_ns = nanos(decay);
        Self {
            discover,
            decay_ns,
            default_rtt,
            completion,
        }
    }
}
199
#[cfg(feature = "discover")]
impl<D, C> Stream for PeakEwmaDiscover<D, C>
where
    D: Discover,
    C: Clone,
{
    type Item = Result<Change<D::Key, PeakEwma<D::Service, C>>, D::Error>;

    /// Forwards discovery changes.
    ///
    /// Removals pass through untouched, inserted services are wrapped in
    /// [`PeakEwma`], and errors and end-of-stream are propagated as-is.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        let change = match ready!(this.discover.poll_discover(cx)) {
            None => return Poll::Ready(None),
            Some(Err(error)) => return Poll::Ready(Some(Err(error))),
            Some(Ok(Change::Remove(key))) => Change::Remove(key),
            Some(Ok(Change::Insert(key, service))) => Change::Insert(
                key,
                PeakEwma::new(
                    service,
                    *this.default_rtt,
                    *this.decay_ns,
                    this.completion.clone(),
                ),
            ),
        };

        Poll::Ready(Some(Ok(change)))
    }
}
227
228impl RttEstimate {
231 pub fn updated_at(&self) -> Instant {
233 self.update_at
234 }
235
236 pub fn rtt_ns(&self) -> f64 {
238 self.rtt_ns
239 }
240
241 fn new(rtt_ns: f64) -> Self {
242 debug_assert!(0.0 < rtt_ns, "rtt must be positive");
243 Self {
244 rtt_ns,
245 update_at: Instant::now(),
246 }
247 }
248
249 fn decay(&mut self, decay_ns: f64) -> f64 {
251 let now = Instant::now();
253 self.update(now, now, decay_ns)
254 }
255
256 fn update(&mut self, sent_at: Instant, recv_at: Instant, decay_ns: f64) -> f64 {
260 debug_assert!(
261 sent_at <= recv_at,
262 "recv_at={:?} after sent_at={:?}",
263 recv_at,
264 sent_at
265 );
266 let rtt = nanos(recv_at.saturating_duration_since(sent_at));
267
268 let now = Instant::now();
269 debug_assert!(
270 self.update_at <= now,
271 "update_at={:?} in the future",
272 self.update_at
273 );
274
275 self.rtt_ns = if self.rtt_ns < rtt {
276 trace!(
279 "update peak rtt={}ms prior={}ms",
280 rtt / NANOS_PER_MILLI,
281 self.rtt_ns / NANOS_PER_MILLI,
282 );
283 rtt
284 } else {
285 let elapsed = nanos(now.saturating_duration_since(self.update_at));
290 let decay = (-elapsed / decay_ns).exp();
291 let recency = 1.0 - decay;
292 let next_estimate = (self.rtt_ns * decay) + (rtt * recency);
293 trace!(
294 "update rtt={:03.0}ms decay={:06.0}ns; next={:03.0}ms",
295 rtt / NANOS_PER_MILLI,
296 self.rtt_ns - next_estimate,
297 next_estimate / NANOS_PER_MILLI,
298 );
299 next_estimate
300 };
301 self.update_at = now;
302
303 self.rtt_ns
304 }
305}
306
307impl Drop for Handle {
310 fn drop(&mut self) {
311 let recv_at = Instant::now();
312
313 if let Ok(mut rtt) = self.rtt_estimate.lock() {
314 rtt.update(self.sent_at, recv_at, self.decay_ns);
315 }
316 }
317}
318
/// Converts `d` to nanoseconds as an `f64`.
///
/// The whole-seconds part is converted with a saturating multiply. Durations
/// beyond `u64::MAX` nanoseconds are therefore clamped to `u64::MAX`, and the
/// sub-second nanoseconds are then added on top.
fn nanos(d: Duration) -> f64 {
    const NANOS_PER_SEC: u64 = 1_000_000_000;
    let whole_secs_ns = d.as_secs().saturating_mul(NANOS_PER_SEC) as f64;
    let fraction_ns = f64::from(d.subsec_nanos());
    whole_secs_ns + fraction_ns
}
331
#[cfg(test)]
mod tests {
    use std::{future, time::Duration};
    use tokio::time;
    use tokio_test::{assert_ready, assert_ready_ok, task};

    use super::*;

    /// A trivial service that is always ready and resolves every request
    /// immediately with `Ok(())`.
    struct Svc;
    impl Service<()> for Svc {
        type Response = ();
        type Error = ();
        type Future = future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, (): ()) -> Self::Future {
            future::ready(Ok(()))
        }
    }

    /// With no traffic, the default RTT estimate decays over time.
    #[tokio::test]
    async fn default_decay() {
        // Freeze tokio's clock so elapsed time is driven only by `time::advance`.
        time::pause();

        let svc = PeakEwma::new(
            Svc,
            Duration::from_millis(10),
            NANOS_PER_MILLI * 1_000.0,
            CompleteOnResponse,
        );
        // No time has elapsed, so the cost equals the 10ms default exactly.
        let Cost(load) = svc.load();
        assert_eq!(load, 10.0 * NANOS_PER_MILLI);

        // Each 100ms step lowers the estimate further below the default.
        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(9.0 * NANOS_PER_MILLI < load && load < 10.0 * NANOS_PER_MILLI);

        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(8.0 * NANOS_PER_MILLI < load && load < 9.0 * NANOS_PER_MILLI);
    }

    /// Outstanding requests multiply the cost, a slow response raises the
    /// estimate to its peak, and the estimate decays again once traffic stops.
    #[tokio::test]
    async fn compound_decay() {
        time::pause();

        let mut svc = PeakEwma::new(
            Svc,
            Duration::from_millis(20),
            NANOS_PER_MILLI * 1_000.0,
            CompleteOnResponse,
        );
        assert_eq!(svc.load(), Cost(20.0 * NANOS_PER_MILLI));

        // One request in flight: the decayed estimate is multiplied by 2.
        time::advance(Duration::from_millis(100)).await;
        let mut rsp0 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(20.0 * NANOS_PER_MILLI));

        // Two requests in flight: the decayed estimate is multiplied by 3.
        time::advance(Duration::from_millis(100)).await;
        let mut rsp1 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(40.0 * NANOS_PER_MILLI));

        // rsp0 completes 200ms after it was sent. That exceeds the estimate and
        // becomes the new peak. With rsp1 still pending, cost is 200ms x 2.
        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp0.poll());
        assert_eq!(svc.load(), Cost(400_000_000.0));

        // rsp1 also observes 200ms. Nothing is pending now, so cost is 200ms x 1.
        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp1.poll());
        assert_eq!(svc.load(), Cost(200_000_000.0));

        // Without further traffic the estimate decays toward zero.
        time::advance(Duration::from_secs(1)).await;
        assert!(svc.load() < Cost(100_000_000.0));

        time::advance(Duration::from_secs(10)).await;
        assert!(svc.load() < Cost(100_000.0));
    }

    /// `nanos` is exact for small durations. For huge durations the
    /// whole-seconds part saturates at `u64::MAX` nanoseconds.
    #[test]
    fn nanos() {
        assert_eq!(super::nanos(Duration::new(0, 0)), 0.0);
        assert_eq!(super::nanos(Duration::new(0, 123)), 123.0);
        assert_eq!(super::nanos(Duration::new(1, 23)), 1_000_000_023.0);
        assert_eq!(
            super::nanos(Duration::new(u64::MAX, 999_999_999)),
            18446744074709553000.0
        );
    }
}