tower/limit/rate/
service.rs1use super::Rate;
2use std::{
3 future::Future,
4 pin::Pin,
5 task::{Context, Poll},
6};
7use tokio::time::{Instant, Sleep};
8use tower_service::Service;
9
/// A middleware that limits how many requests are forwarded to the wrapped
/// service per time window, as configured by a [`Rate`].
///
/// Callers must wait for `poll_ready` to report readiness before each `call`;
/// once a window's allowance is used up, `poll_ready` stays pending until the
/// window ends.
#[derive(Debug)]
pub struct RateLimit<T> {
    // The wrapped service that requests are forwarded to.
    inner: T,
    // Allowance configuration: `rate.num()` requests per `rate.per()`.
    rate: Rate,
    // Whether requests may currently be forwarded (see `State`).
    state: State,
    // Timer armed to the end of an exhausted window. It is reset in place
    // rather than reallocated each time the limit is hit.
    sleep: Pin<Box<Sleep>>,
}
19
/// Tracks where the limiter stands within its current time window.
#[derive(Debug)]
enum State {
    /// Requests may be forwarded. `rem` is how many more requests the current
    /// window admits, and `until` is the instant at which that window ends.
    Ready { until: Instant, rem: u64 },
    /// The current window's allowance is spent; nothing may be forwarded until
    /// the sleep timer fires.
    Limited,
}
26
27impl<T> RateLimit<T> {
28 pub fn new(inner: T, rate: Rate) -> Self {
30 let until = Instant::now();
31 let state = State::Ready {
32 until,
33 rem: rate.num(),
34 };
35
36 RateLimit {
37 inner,
38 rate,
39 state,
40 sleep: Box::pin(tokio::time::sleep_until(until)),
44 }
45 }
46
47 pub fn get_ref(&self) -> &T {
49 &self.inner
50 }
51
52 pub fn get_mut(&mut self) -> &mut T {
54 &mut self.inner
55 }
56
57 pub fn into_inner(self) -> T {
59 self.inner
60 }
61}
62
63impl<S, Request> Service<Request> for RateLimit<S>
64where
65 S: Service<Request>,
66{
67 type Response = S::Response;
68 type Error = S::Error;
69 type Future = S::Future;
70
71 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
72 match self.state {
73 State::Ready { .. } => return self.inner.poll_ready(cx),
74 State::Limited => {
75 if Pin::new(&mut self.sleep).poll(cx).is_pending() {
76 tracing::trace!("rate limit exceeded; sleeping.");
77 return Poll::Pending;
78 }
79 }
80 }
81
82 self.state = State::Ready {
83 until: Instant::now() + self.rate.per(),
84 rem: self.rate.num(),
85 };
86
87 self.inner.poll_ready(cx)
88 }
89
90 fn call(&mut self, request: Request) -> Self::Future {
91 match self.state {
92 State::Ready { mut until, mut rem } => {
93 let now = Instant::now();
94
95 if now >= until {
97 until = now + self.rate.per();
98 rem = self.rate.num();
99 }
100
101 if rem > 1 {
102 rem -= 1;
103 self.state = State::Ready { until, rem };
104 } else {
105 self.sleep.as_mut().reset(until);
109 self.state = State::Limited;
110 }
111
112 self.inner.call(request)
114 }
115 State::Limited => panic!("service not ready; poll_ready must be called first"),
116 }
117 }
118}
119
#[cfg(feature = "load")]
impl<S> crate::load::Load for RateLimit<S>
where
    S: crate::load::Load,
{
    type Metric = S::Metric;

    /// Returns the wrapped service's load metric unchanged.
    fn load(&self) -> Self::Metric {
        <S as crate::load::Load>::load(&self.inner)
    }
}