tower/limit/concurrency/service.rs

1use super::future::ResponseFuture;
2use tokio::sync::{OwnedSemaphorePermit, Semaphore};
3use tokio_util::sync::PollSemaphore;
4use tower_service::Service;
5
6use std::{
7    sync::Arc,
8    task::{ready, Context, Poll},
9};
10
/// Enforces a limit on the concurrent number of requests the underlying
/// service can handle.
///
/// Before a request can be sent, `poll_ready` must reserve a permit from a
/// [`Semaphore`]. `call` then moves that permit into the returned
/// [`ResponseFuture`]. Clones of this service share the same semaphore, and
/// therefore count against the same limit.
#[derive(Debug)]
pub struct ConcurrencyLimit<T> {
    /// The wrapped service that actually handles requests.
    inner: T,
    /// Pollable handle to the (possibly shared) semaphore that hands out
    /// concurrency permits.
    semaphore: PollSemaphore,
    /// The currently acquired semaphore permit, if there is sufficient
    /// concurrency to send a new request.
    ///
    /// The permit is acquired in `poll_ready`, and taken in `call` when sending
    /// a new request.
    permit: Option<OwnedSemaphorePermit>,
}
24
25impl<T> ConcurrencyLimit<T> {
26    /// Create a new concurrency limiter.
27    pub fn new(inner: T, max: usize) -> Self {
28        Self::with_semaphore(inner, Arc::new(Semaphore::new(max)))
29    }
30
31    /// Create a new concurrency limiter with a provided shared semaphore
32    pub fn with_semaphore(inner: T, semaphore: Arc<Semaphore>) -> Self {
33        ConcurrencyLimit {
34            inner,
35            semaphore: PollSemaphore::new(semaphore),
36            permit: None,
37        }
38    }
39
40    /// Get a reference to the inner service
41    pub fn get_ref(&self) -> &T {
42        &self.inner
43    }
44
45    /// Get a mutable reference to the inner service
46    pub fn get_mut(&mut self) -> &mut T {
47        &mut self.inner
48    }
49
50    /// Consume `self`, returning the inner service
51    pub fn into_inner(self) -> T {
52        self.inner
53    }
54}
55
56impl<S, Request> Service<Request> for ConcurrencyLimit<S>
57where
58    S: Service<Request>,
59{
60    type Response = S::Response;
61    type Error = S::Error;
62    type Future = ResponseFuture<S::Future>;
63
64    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
65        // If we haven't already acquired a permit from the semaphore, try to
66        // acquire one first.
67        if self.permit.is_none() {
68            self.permit = ready!(self.semaphore.poll_acquire(cx));
69            debug_assert!(
70                self.permit.is_some(),
71                "ConcurrencyLimit semaphore is never closed, so `poll_acquire` \
72                 should never fail",
73            );
74        }
75
76        // Once we've acquired a permit (or if we already had one), poll the
77        // inner service.
78        self.inner.poll_ready(cx)
79    }
80
81    fn call(&mut self, request: Request) -> Self::Future {
82        // Take the permit
83        let permit = self
84            .permit
85            .take()
86            .expect("max requests in-flight; poll_ready must be called first");
87
88        // Call the inner service
89        let future = self.inner.call(request);
90
91        ResponseFuture::new(future, permit)
92    }
93}
94
95impl<T: Clone> Clone for ConcurrencyLimit<T> {
96    fn clone(&self) -> Self {
97        // Since we hold an `OwnedSemaphorePermit`, we can't derive `Clone`.
98        // Instead, when cloning the service, create a new service with the
99        // same semaphore, but with the permit in the un-acquired state.
100        Self {
101            inner: self.inner.clone(),
102            semaphore: self.semaphore.clone(),
103            permit: None,
104        }
105    }
106}
107
#[cfg(feature = "load")]
impl<S> crate::load::Load for ConcurrencyLimit<S>
where
    S: crate::load::Load,
{
    type Metric = S::Metric;

    /// Reports the wrapped service's load unchanged; the limiter contributes
    /// no load metric of its own.
    fn load(&self) -> Self::Metric {
        <S as crate::load::Load>::load(&self.inner)
    }
}