tower/limit/concurrency/
service.rs1use super::future::ResponseFuture;
2use tokio::sync::{OwnedSemaphorePermit, Semaphore};
3use tokio_util::sync::PollSemaphore;
4use tower_service::Service;
5
6use std::{
7 sync::Arc,
8 task::{ready, Context, Poll},
9};
10
11#[derive(Debug)]
14pub struct ConcurrencyLimit<T> {
15 inner: T,
16 semaphore: PollSemaphore,
17 permit: Option<OwnedSemaphorePermit>,
23}
24
25impl<T> ConcurrencyLimit<T> {
26 pub fn new(inner: T, max: usize) -> Self {
28 Self::with_semaphore(inner, Arc::new(Semaphore::new(max)))
29 }
30
31 pub fn with_semaphore(inner: T, semaphore: Arc<Semaphore>) -> Self {
33 ConcurrencyLimit {
34 inner,
35 semaphore: PollSemaphore::new(semaphore),
36 permit: None,
37 }
38 }
39
40 pub fn get_ref(&self) -> &T {
42 &self.inner
43 }
44
45 pub fn get_mut(&mut self) -> &mut T {
47 &mut self.inner
48 }
49
50 pub fn into_inner(self) -> T {
52 self.inner
53 }
54}
55
56impl<S, Request> Service<Request> for ConcurrencyLimit<S>
57where
58 S: Service<Request>,
59{
60 type Response = S::Response;
61 type Error = S::Error;
62 type Future = ResponseFuture<S::Future>;
63
64 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
65 if self.permit.is_none() {
68 self.permit = ready!(self.semaphore.poll_acquire(cx));
69 debug_assert!(
70 self.permit.is_some(),
71 "ConcurrencyLimit semaphore is never closed, so `poll_acquire` \
72 should never fail",
73 );
74 }
75
76 self.inner.poll_ready(cx)
79 }
80
81 fn call(&mut self, request: Request) -> Self::Future {
82 let permit = self
84 .permit
85 .take()
86 .expect("max requests in-flight; poll_ready must be called first");
87
88 let future = self.inner.call(request);
90
91 ResponseFuture::new(future, permit)
92 }
93}
94
95impl<T: Clone> Clone for ConcurrencyLimit<T> {
96 fn clone(&self) -> Self {
97 Self {
101 inner: self.inner.clone(),
102 semaphore: self.semaphore.clone(),
103 permit: None,
104 }
105 }
106}
107
#[cfg(feature = "load")]
impl<S> crate::load::Load for ConcurrencyLimit<S>
where
    S: crate::load::Load,
{
    type Metric = S::Metric;

    /// Reports the wrapped service's load unchanged; the concurrency limit
    /// adds nothing of its own to the metric.
    fn load(&self) -> Self::Metric {
        <S as crate::load::Load>::load(&self.inner)
    }
}