// tower/limit/rate/service.rs

1use super::Rate;
2use std::{
3    future::Future,
4    pin::Pin,
5    task::{Context, Poll},
6};
7use tokio::time::{Instant, Sleep};
8use tower_service::Service;
9
/// Enforces a rate limit on the number of requests the underlying
/// service can handle over a period of time.
#[derive(Debug)]
pub struct RateLimit<T> {
    /// The wrapped service that admitted requests are forwarded to.
    inner: T,
    /// The configured budget: `rate.num()` requests per `rate.per()` window.
    rate: Rate,
    /// Whether requests are currently admitted and, if so, how many remain
    /// before the current window ends.
    state: State,
    /// Timer that is re-armed to the end of the current window when the
    /// budget runs out, and polled while the limiter is limited.
    sleep: Pin<Box<Sleep>>,
}

20#[derive(Debug)]
21enum State {
22    // The service has hit its limit
23    Limited,
24    Ready { until: Instant, rem: u64 },
25}
26
27impl<T> RateLimit<T> {
28    /// Create a new rate limiter
29    pub fn new(inner: T, rate: Rate) -> Self {
30        let until = Instant::now();
31        let state = State::Ready {
32            until,
33            rem: rate.num(),
34        };
35
36        RateLimit {
37            inner,
38            rate,
39            state,
40            // The sleep won't actually be used with this duration, but
41            // we create it eagerly so that we can reset it in place rather than
42            // `Box::pin`ning a new `Sleep` every time we need one.
43            sleep: Box::pin(tokio::time::sleep_until(until)),
44        }
45    }
46
47    /// Get a reference to the inner service
48    pub fn get_ref(&self) -> &T {
49        &self.inner
50    }
51
52    /// Get a mutable reference to the inner service
53    pub fn get_mut(&mut self) -> &mut T {
54        &mut self.inner
55    }
56
57    /// Consume `self`, returning the inner service
58    pub fn into_inner(self) -> T {
59        self.inner
60    }
61}
62
63impl<S, Request> Service<Request> for RateLimit<S>
64where
65    S: Service<Request>,
66{
67    type Response = S::Response;
68    type Error = S::Error;
69    type Future = S::Future;
70
71    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
72        match self.state {
73            State::Ready { .. } => return self.inner.poll_ready(cx),
74            State::Limited => {
75                if Pin::new(&mut self.sleep).poll(cx).is_pending() {
76                    tracing::trace!("rate limit exceeded; sleeping.");
77                    return Poll::Pending;
78                }
79            }
80        }
81
82        self.state = State::Ready {
83            until: Instant::now() + self.rate.per(),
84            rem: self.rate.num(),
85        };
86
87        self.inner.poll_ready(cx)
88    }
89
90    fn call(&mut self, request: Request) -> Self::Future {
91        match self.state {
92            State::Ready { mut until, mut rem } => {
93                let now = Instant::now();
94
95                // If the period has elapsed, reset it.
96                if now >= until {
97                    until = now + self.rate.per();
98                    rem = self.rate.num();
99                }
100
101                if rem > 1 {
102                    rem -= 1;
103                    self.state = State::Ready { until, rem };
104                } else {
105                    // The service is disabled until further notice
106                    // Reset the sleep future in place, so that we don't have to
107                    // deallocate the existing box and allocate a new one.
108                    self.sleep.as_mut().reset(until);
109                    self.state = State::Limited;
110                }
111
112                // Call the inner future
113                self.inner.call(request)
114            }
115            State::Limited => panic!("service not ready; poll_ready must be called first"),
116        }
117    }
118}
119
#[cfg(feature = "load")]
impl<S> crate::load::Load for RateLimit<S>
where
    S: crate::load::Load,
{
    type Metric = S::Metric;

    /// Reports the wrapped service's load unchanged; the rate limiter itself
    /// adds nothing to the metric.
    fn load(&self) -> Self::Metric {
        <S as crate::load::Load>::load(&self.inner)
    }
}
129}