Skip to main content

tower/load/
peak_ewma.rs

1//! A `Load` implementation that measures load using the PeakEWMA response latency.
2
3#[cfg(feature = "discover")]
4use crate::discover::{Change, Discover};
5#[cfg(feature = "discover")]
6use futures_core::Stream;
7#[cfg(feature = "discover")]
8use pin_project_lite::pin_project;
9#[cfg(feature = "discover")]
10use std::{pin::Pin, task::ready};
11
12use super::completion::{CompleteOnResponse, TrackCompletion, TrackCompletionFuture};
13use super::Load;
14use std::task::{Context, Poll};
15use std::{
16    sync::{Arc, Mutex},
17    time::Duration,
18};
19use tokio::time::Instant;
20use tower_service::Service;
21use tracing::trace;
22
/// Measures the load of the underlying service using Peak-EWMA load measurement.
///
/// [`PeakEwma`] implements [`Load`] with the [`Cost`] metric that estimates the amount of
/// pending work to an endpoint. Work is calculated by multiplying the
/// exponentially-weighted moving average (EWMA) of response latencies by the number of
/// pending requests. The Peak-EWMA algorithm is designed to be especially sensitive to
/// worst-case latencies. Over time, the peak latency value decays towards the moving
/// average of latencies to the endpoint.
///
/// When no latency information has been measured for an endpoint, the default RTT
/// supplied at construction is used as the initial estimate, to prevent the endpoint
/// from being overloaded before a meaningful baseline can be established.
///
/// ## Note
///
/// This is derived from [Finagle][finagle], which is distributed under the Apache V2
/// license. Copyright 2017, Twitter Inc.
///
/// [finagle]:
/// https://github.com/twitter/finagle/blob/9cc08d15216497bb03a1cafda96b7266cfbbcff1/finagle-core/src/main/scala/com/twitter/finagle/loadbalancer/PeakEwma.scala
#[derive(Debug)]
pub struct PeakEwma<S, C = CompleteOnResponse> {
    // The wrapped inner service.
    service: S,
    // Decay period of the RTT estimate, in nanoseconds.
    decay_ns: f64,
    // RTT estimate shared with every in-flight `Handle`. The `Arc` strong count is
    // read by `load()` to derive the number of pending requests.
    rtt_estimate: Arc<Mutex<RttEstimate>>,
    // Cloned into every `TrackCompletionFuture` created by `call`.
    completion: C,
}
50
#[cfg(feature = "discover")]
pin_project! {
    /// Wraps a `D`-typed stream of discovered services with `PeakEwma`.
    ///
    /// Every service inserted by the inner stream is wrapped in a [`PeakEwma`] built
    /// from the `default_rtt`, `decay_ns` and `completion` values stored here.
    #[cfg_attr(docsrs, doc(cfg(feature = "discover")))]
    #[derive(Debug)]
    pub struct PeakEwmaDiscover<D, C = CompleteOnResponse> {
        // The inner discovery stream; pinned because it is polled in place.
        #[pin]
        discover: D,
        // Decay period handed to each new `PeakEwma`, in nanoseconds.
        decay_ns: f64,
        // Initial RTT estimate handed to each new `PeakEwma`.
        default_rtt: Duration,
        // Cloned into each new `PeakEwma`.
        completion: C,
    }
}
64
/// Represents the relative cost of communicating with a service.
///
/// The underlying value estimates the amount of pending work to a service: the Peak-EWMA
/// latency estimate (in nanoseconds) multiplied by one more than the number of pending
/// requests, i.e. `estimate * (pending + 1)`.
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
pub struct Cost(f64);
71
/// Tracks an in-flight request and updates the RTT-estimate on Drop.
///
/// Each handle holds a clone of the owning [`PeakEwma`]'s shared estimate, so the
/// number of live handles is visible through the `Arc` strong count.
#[derive(Debug)]
pub struct Handle {
    // The instant at which the handle was created, i.e. when the request was sent.
    sent_at: Instant,
    // Decay period used when folding the observed RTT into the estimate, in nanoseconds.
    decay_ns: f64,
    // The estimate shared with the `PeakEwma` that created this handle.
    rtt_estimate: Arc<Mutex<RttEstimate>>,
}
79
/// Holds the current RTT estimate and the last time this value was updated.
#[derive(Clone, Debug)]
pub struct RttEstimate {
    // The instant of the most recent update (or of construction).
    update_at: Instant,
    // The current round-trip time estimate, in nanoseconds.
    rtt_ns: f64,
}
86
// Number of nanoseconds in one millisecond; used to render nanosecond values as
// milliseconds in trace output and to express durations in tests.
const NANOS_PER_MILLI: f64 = 1_000_000.0;
88
89// ===== impl PeakEwma =====
90
91impl<S, C> PeakEwma<S, C> {
92    /// Wraps an `S`-typed service so that its load is tracked by the EWMA of its peak latency.
93    pub fn new(service: S, default_rtt: Duration, decay_ns: f64, completion: C) -> Self {
94        debug_assert!(decay_ns > 0.0, "decay_ns must be positive");
95        Self {
96            service,
97            decay_ns,
98            rtt_estimate: Arc::new(Mutex::new(RttEstimate::new(nanos(default_rtt)))),
99            completion,
100        }
101    }
102
103    fn handle(&self) -> Handle {
104        Handle {
105            decay_ns: self.decay_ns,
106            sent_at: Instant::now(),
107            rtt_estimate: self.rtt_estimate.clone(),
108        }
109    }
110
111    /// Returns the current [`RttEstimate`] of the service.
112    ///
113    /// # Panics
114    ///
115    /// This value is stored in a mutex. If the mutex has become poisoned, this will panic.
116    pub fn rtt_estimate(&self) -> RttEstimate {
117        self.rtt_estimate
118            .lock()
119            .expect("mutex should not be poisoned")
120            .clone()
121    }
122}
123
124impl<S, C, Request> Service<Request> for PeakEwma<S, C>
125where
126    S: Service<Request>,
127    C: TrackCompletion<Handle, S::Response>,
128{
129    type Response = C::Output;
130    type Error = S::Error;
131    type Future = TrackCompletionFuture<S::Future, C, Handle>;
132
133    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
134        self.service.poll_ready(cx)
135    }
136
137    fn call(&mut self, req: Request) -> Self::Future {
138        TrackCompletionFuture::new(
139            self.completion.clone(),
140            self.handle(),
141            self.service.call(req),
142        )
143    }
144}
145
146impl<S, C> Load for PeakEwma<S, C> {
147    type Metric = Cost;
148
149    fn load(&self) -> Self::Metric {
150        let pending = Arc::strong_count(&self.rtt_estimate) as u32 - 1;
151
152        // Update the RTT estimate to account for decay since the last update.
153        // If an estimate has not been established, a default is provided
154        let estimate = self.update_estimate();
155
156        let cost = Cost(estimate * f64::from(pending + 1));
157        trace!(
158            "load estimate={:.0}ms pending={} cost={:?}",
159            estimate / NANOS_PER_MILLI,
160            pending,
161            cost,
162        );
163        cost
164    }
165}
166
167impl<S, C> PeakEwma<S, C> {
168    fn update_estimate(&self) -> f64 {
169        let mut rtt = self.rtt_estimate.lock().expect("peak ewma prior_estimate");
170        rtt.decay(self.decay_ns)
171    }
172}
173
174// ===== impl PeakEwmaDiscover =====
175
#[cfg(feature = "discover")]
impl<D, C> PeakEwmaDiscover<D, C> {
    /// Wraps a `D`-typed [`Discover`] so that services have a [`PeakEwma`] load metric.
    ///
    /// `default_rtt` becomes the initial RTT estimate of every newly added service.
    ///
    /// The `decay` value determines over what time period an RTT estimate should
    /// decay; it is stored internally as nanoseconds.
    pub fn new<Request>(discover: D, default_rtt: Duration, decay: Duration, completion: C) -> Self
    where
        D: Discover,
        D::Service: Service<Request>,
        C: TrackCompletion<Handle, <D::Service as Service<Request>>::Response>,
    {
        let decay_ns = nanos(decay);
        Self {
            discover,
            decay_ns,
            default_rtt,
            completion,
        }
    }
}
199
#[cfg(feature = "discover")]
impl<D, C> Stream for PeakEwmaDiscover<D, C>
where
    D: Discover,
    C: Clone,
{
    type Item = Result<Change<D::Key, PeakEwma<D::Service, C>>, D::Error>;

    /// Forwards changes from the inner discovery stream, wrapping each inserted service
    /// in a [`PeakEwma`]. Removals pass through unchanged, errors are propagated, and
    /// the stream ends when the inner stream ends.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        let discovered = ready!(this.discover.poll_discover(cx)).transpose()?;

        let change = match discovered {
            Some(Change::Insert(key, svc)) => {
                let wrapped = PeakEwma::new(
                    svc,
                    *this.default_rtt,
                    *this.decay_ns,
                    this.completion.clone(),
                );
                Change::Insert(key, wrapped)
            }
            Some(Change::Remove(key)) => Change::Remove(key),
            None => return Poll::Ready(None),
        };

        Poll::Ready(Some(Ok(change)))
    }
}
227
228// ===== impl RttEstimate =====
229
230impl RttEstimate {
231    /// Returns the [`Instant`] that this estimate was last updated.
232    pub fn updated_at(&self) -> Instant {
233        self.update_at
234    }
235
236    /// Returns the round-trip time estimate, in nanoseconds.
237    pub fn rtt_ns(&self) -> f64 {
238        self.rtt_ns
239    }
240
241    fn new(rtt_ns: f64) -> Self {
242        debug_assert!(0.0 < rtt_ns, "rtt must be positive");
243        Self {
244            rtt_ns,
245            update_at: Instant::now(),
246        }
247    }
248
249    /// Decays the RTT estimate with a decay period of `decay_ns`.
250    fn decay(&mut self, decay_ns: f64) -> f64 {
251        // Updates with a 0 duration so that the estimate decays towards 0.
252        let now = Instant::now();
253        self.update(now, now, decay_ns)
254    }
255
256    /// Updates the Peak-EWMA RTT estimate.
257    ///
258    /// The elapsed time from `sent_at` to `recv_at` is added
259    fn update(&mut self, sent_at: Instant, recv_at: Instant, decay_ns: f64) -> f64 {
260        debug_assert!(
261            sent_at <= recv_at,
262            "recv_at={:?} after sent_at={:?}",
263            recv_at,
264            sent_at
265        );
266        let rtt = nanos(recv_at.saturating_duration_since(sent_at));
267
268        let now = Instant::now();
269        debug_assert!(
270            self.update_at <= now,
271            "update_at={:?} in the future",
272            self.update_at
273        );
274
275        self.rtt_ns = if self.rtt_ns < rtt {
276            // For Peak-EWMA, always use the worst-case (peak) value as the estimate for
277            // subsequent requests.
278            trace!(
279                "update peak rtt={}ms prior={}ms",
280                rtt / NANOS_PER_MILLI,
281                self.rtt_ns / NANOS_PER_MILLI,
282            );
283            rtt
284        } else {
285            // When an RTT is observed that is less than the estimated RTT, we decay the
286            // prior estimate according to how much time has elapsed since the last
287            // update. The inverse of the decay is used to scale the estimate towards the
288            // observed RTT value.
289            let elapsed = nanos(now.saturating_duration_since(self.update_at));
290            let decay = (-elapsed / decay_ns).exp();
291            let recency = 1.0 - decay;
292            let next_estimate = (self.rtt_ns * decay) + (rtt * recency);
293            trace!(
294                "update rtt={:03.0}ms decay={:06.0}ns; next={:03.0}ms",
295                rtt / NANOS_PER_MILLI,
296                self.rtt_ns - next_estimate,
297                next_estimate / NANOS_PER_MILLI,
298            );
299            next_estimate
300        };
301        self.update_at = now;
302
303        self.rtt_ns
304    }
305}
306
307// ===== impl Handle =====
308
309impl Drop for Handle {
310    fn drop(&mut self) {
311        let recv_at = Instant::now();
312
313        if let Ok(mut rtt) = self.rtt_estimate.lock() {
314            rtt.update(self.sent_at, recv_at, self.decay_ns);
315        }
316    }
317}
318
319// ===== impl Cost =====
320
// Converts a `Duration` into a count of nanoseconds represented as an `f64`.
//
// The whole-second part is converted with a saturating multiplication, so the largest
// representable value is roughly 585 years; the conversion to `f64` is lossy for large
// values. That is expected to be far more than any request latency.
fn nanos(d: Duration) -> f64 {
    const NANOS_PER_SEC: u64 = 1_000_000_000;
    let whole_secs_ns = d.as_secs().saturating_mul(NANOS_PER_SEC) as f64;
    let subsec_ns = f64::from(d.subsec_nanos());
    subsec_ns + whole_secs_ns
}
331
#[cfg(test)]
mod tests {
    use std::{future, time::Duration};
    use tokio::time;
    use tokio_test::{assert_ready, assert_ready_ok, task};

    use super::*;

    // A trivial service that is always ready and responds immediately with `Ok(())`.
    struct Svc;
    impl Service<()> for Svc {
        type Response = ();
        type Error = ();
        type Future = future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, (): ()) -> Self::Future {
            future::ready(Ok(()))
        }
    }

    /// The default RTT estimate decays, so that new nodes are considered if the
    /// default RTT is too high.
    #[tokio::test]
    async fn default_decay() {
        // Pause the clock so that time only moves through `time::advance`.
        time::pause();

        // 10ms default RTT with a 1s decay period.
        let svc = PeakEwma::new(
            Svc,
            Duration::from_millis(10),
            NANOS_PER_MILLI * 1_000.0,
            CompleteOnResponse,
        );
        // With no elapsed time and nothing pending, the load is exactly the default RTT.
        let Cost(load) = svc.load();
        assert_eq!(load, 10.0 * NANOS_PER_MILLI);

        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(9.0 * NANOS_PER_MILLI < load && load < 10.0 * NANOS_PER_MILLI);

        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(8.0 * NANOS_PER_MILLI < load && load < 9.0 * NANOS_PER_MILLI);
    }

    // The load compounds while requests are in flight, reflects the observed peak RTT
    // once responses complete, and then decays as further time elapses.
    #[tokio::test]
    async fn compound_decay() {
        // Pause the clock so that time only moves through `time::advance`.
        time::pause();

        // 20ms default RTT with a 1s decay period.
        let mut svc = PeakEwma::new(
            Svc,
            Duration::from_millis(20),
            NANOS_PER_MILLI * 1_000.0,
            CompleteOnResponse,
        );
        assert_eq!(svc.load(), Cost(20.0 * NANOS_PER_MILLI));

        // Each unpolled response future keeps its `Handle` alive, so it counts as pending.
        time::advance(Duration::from_millis(100)).await;
        let mut rsp0 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(20.0 * NANOS_PER_MILLI));

        time::advance(Duration::from_millis(100)).await;
        let mut rsp1 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(40.0 * NANOS_PER_MILLI));

        // Completing the first request records a 200ms RTT; one request is still pending.
        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp0.poll());
        assert_eq!(svc.load(), Cost(400_000_000.0));

        // Completing the second request also records a 200ms RTT; nothing is pending.
        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp1.poll());
        assert_eq!(svc.load(), Cost(200_000_000.0));

        // Check that values decay as time elapses
        time::advance(Duration::from_secs(1)).await;
        assert!(svc.load() < Cost(100_000_000.0));

        time::advance(Duration::from_secs(10)).await;
        assert!(svc.load() < Cost(100_000.0));
    }

    // `nanos` converts durations to nanoseconds, including at the saturating upper bound.
    #[test]
    fn nanos() {
        assert_eq!(super::nanos(Duration::new(0, 0)), 0.0);
        assert_eq!(super::nanos(Duration::new(0, 123)), 123.0);
        assert_eq!(super::nanos(Duration::new(1, 23)), 1_000_000_023.0);
        assert_eq!(
            super::nanos(Duration::new(u64::MAX, 999_999_999)),
            18446744074709553000.0
        );
    }
}