1mod concurrent_limited_multimap;
51
52use std::hash::Hash;
53use std::sync::{Arc, RwLock};
54
55use slab::Slab;
56use tokio::sync::{OwnedSemaphorePermit, Semaphore};
57
58use self::concurrent_limited_multimap::ConcurrentLimitedMultimap;
59
60pub struct Pool<K, I> {
62 inner: Arc<ConcurrentLimitedMultimap<K, I, ahash::RandomState>>,
63 local_limits: RwLock<Slab<Arc<Semaphore>>>,
64 semaphore: Option<Arc<Semaphore>>,
65}
66
67impl<K, I> Pool<K, I>
68where
69 K: Eq + std::hash::Hash,
70{
71 pub fn new(capacity: usize) -> Self {
73 Self {
74 inner: Arc::new(ConcurrentLimitedMultimap::with_hasher(
75 capacity,
76 ahash::RandomState::new(),
77 )),
78 semaphore: Some(Arc::new(Semaphore::new(capacity))),
79 local_limits: RwLock::new(Slab::new()),
80 }
81 }
82
83 pub fn new_unbounded() -> Self {
85 Self {
86 inner: Arc::new(ConcurrentLimitedMultimap::with_hasher_unbounded(
87 ahash::RandomState::new(),
88 )),
89 semaphore: None,
90 local_limits: RwLock::new(Slab::new()),
91 }
92 }
93
94 pub fn set_local_limit(&self, limit: usize) -> usize {
96 let mut local_limits = self.local_limits.write().expect("local limits lock poisoned");
97 local_limits.insert(Arc::new(Semaphore::new(limit)))
98 }
99
100 pub async fn pull(&self, key: K) -> Item<K, I> {
103 self.pull_with_wait_local_limit(key, None).await
104 }
105
106 pub async fn pull_with_local_limit(&self, key: K, local_limit_index: Option<usize>) -> Option<Item<K, I>> {
109 let local_guard = if let Some(index) = local_limit_index {
110 let local_limits = self.local_limits.read().expect("local limits lock poisoned");
111 if let Some(semaphore) = local_limits.get(index) {
112 let semaphore = semaphore.clone();
113 drop(local_limits);
114 Some(semaphore.try_acquire_owned().ok()?)
115 } else {
116 None
117 }
118 } else {
119 None
120 };
121 let guard = if let Some(semaphore) = &self.semaphore {
122 Some(semaphore.clone().acquire_owned().await.expect("semaphore closed"))
123 } else {
124 None
125 };
126
127 let key = Arc::new(key);
128 let inner_value = self.inner.remove(key.clone());
129 Some(Item {
130 pool_inner: self.inner.clone(),
131 key: Some(key),
132 inner: inner_value,
133 _guard: guard,
134 _local_guard: local_guard,
135 })
136 }
137
138 #[allow(clippy::await_holding_lock)]
141 pub async fn pull_with_wait_local_limit(&self, key: K, local_limit_index: Option<usize>) -> Item<K, I> {
142 let local_guard = if let Some(index) = local_limit_index {
143 let local_limits = self.local_limits.read().expect("local limits lock poisoned");
144 if let Some(semaphore) = local_limits.get(index) {
145 let semaphore = semaphore.clone();
146 drop(local_limits); Some(semaphore.acquire_owned().await.expect("semaphore closed"))
148 } else {
149 None
150 }
151 } else {
152 None
153 };
154 let guard = if let Some(semaphore) = &self.semaphore {
155 Some(semaphore.clone().acquire_owned().await.expect("semaphore closed"))
156 } else {
157 None
158 };
159
160 let key = Arc::new(key);
161 let inner_value = self.inner.remove(key.clone());
162 Item {
163 pool_inner: self.inner.clone(),
164 key: Some(key),
165 inner: inner_value,
166 _guard: guard,
167 _local_guard: local_guard,
168 }
169 }
170
171 pub fn try_pull(&self, key: K) -> Option<Item<K, I>> {
174 self.try_pull_with_local_limit(key, None)
175 }
176
177 pub fn try_pull_with_local_limit(&self, key: K, local_limit_index: Option<usize>) -> Option<Item<K, I>> {
180 let local_guard = if let Some(index) = local_limit_index {
181 let local_limits = self.local_limits.read().expect("local limits lock poisoned");
182 if let Some(semaphore) = local_limits.get(index) {
183 let semaphore = semaphore.clone();
184 drop(local_limits);
185 Some(semaphore.try_acquire_owned().ok()?)
186 } else {
187 None
188 }
189 } else {
190 None
191 };
192 let guard = if let Some(semaphore) = &self.semaphore {
193 Some(semaphore.clone().try_acquire_owned().ok()?)
194 } else {
195 None
196 };
197
198 let key = Arc::new(key);
199 let inner_value = self.inner.remove(key.clone());
200 Some(Item {
201 pool_inner: self.inner.clone(),
202 key: Some(key),
203 inner: inner_value,
204 _guard: guard,
205 _local_guard: local_guard,
206 })
207 }
208}
209
/// A value checked out of a [`Pool`], together with the semaphore permits acquired for it.
///
/// When the item is dropped, any value it still holds is inserted back into the pool under
/// the same key. `Drop::drop` runs before the fields are dropped, so the value is back in the
/// pool before the permits are released.
// The `K: Eq + Hash` bound must stay on the struct: the `Drop` impl for `Item` needs it, and
// a `Drop` impl may not add bounds that the struct itself does not declare (E0367).
pub struct Item<K: Eq + Hash, I> {
    // Handle to the pool's store, used to return the value on drop.
    pool_inner: Arc<ConcurrentLimitedMultimap<K, I, ahash::RandomState>>,
    // Key the value was checked out under; always `Some` until it is taken in `Drop`.
    key: Option<Arc<K>>,
    // The checked-out value; `None` if the store had nothing for the key or it was taken.
    inner: Option<I>,
    // Global permit; `Some` only for bounded pools. Held purely for its drop side effect.
    _guard: Option<OwnedSemaphorePermit>,
    // Local-limit permit; `Some` only when a registered local limit was used.
    _local_guard: Option<OwnedSemaphorePermit>,
}
218
219impl<K: Eq + Hash, I> Item<K, I> {
220 pub fn take(mut self) -> Option<I> {
222 self.inner.take()
223 }
224
225 pub fn inner(&self) -> &Option<I> {
227 &self.inner
228 }
229
230 pub fn inner_mut(&mut self) -> &mut Option<I> {
232 &mut self.inner
233 }
234}
235
236impl<K: Eq + Hash, I> Drop for Item<K, I> {
237 fn drop(&mut self) {
238 if let Some(inner) = self.inner.take() {
239 self.pool_inner.insert(self.key.take().expect("key not set"), inner);
240 }
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_pool_new() {
        let pool = Pool::<String, u32>::new(10);
        assert_eq!(pool.semaphore.as_ref().unwrap().available_permits(), 10);
    }

    #[tokio::test]
    async fn test_pool_new_unbounded_has_no_global_limit() {
        let pool = Pool::<String, u32>::new_unbounded();
        assert!(pool.semaphore.is_none());
        let _item1 = pool.try_pull("key1".to_string()).expect("no global limit");
        let _item2 = pool.try_pull("key2".to_string()).expect("no global limit");
    }

    #[tokio::test]
    async fn test_pool_pull_and_take() {
        let pool = Pool::<String, u32>::new(1);
        let item = pool.pull("key1".to_string()).await;
        assert!(item.take().is_none());
    }

    #[tokio::test]
    async fn test_pool_pull_and_replace() {
        let pool = Pool::<String, u32>::new(1);
        let mut item = pool.pull("key1".to_string()).await;
        *item.inner_mut() = Some(42);
        assert_eq!(item.inner(), &Some(42));
    }

    #[tokio::test]
    async fn test_value_returned_to_pool_on_drop() {
        let pool = Pool::<String, u32>::new(1);
        {
            let mut item = pool.pull("key1".to_string()).await;
            *item.inner_mut() = Some(7);
        }
        let item = pool.pull("key1".to_string()).await;
        assert_eq!(item.inner(), &Some(7));
    }

    #[tokio::test]
    async fn test_take_does_not_return_value_to_pool() {
        let pool = Pool::<String, u32>::new(1);
        let mut item = pool.pull("key1".to_string()).await;
        *item.inner_mut() = Some(7);
        assert_eq!(item.take(), Some(7));
        assert!(pool.pull("key1".to_string()).await.inner().is_none());
    }

    #[tokio::test]
    async fn test_pool_eviction_behavior() {
        let pool = Pool::<String, u32>::new(2);
        {
            let mut item1 = pool.pull("key1".to_string()).await;
            item1.inner_mut().replace(1);
        }
        {
            let mut item2 = pool.pull("key2".to_string()).await;
            item2.inner_mut().replace(2);
        }
        {
            let _item1 = pool.pull("key1".to_string()).await;
        }
        {
            let mut item3 = pool.pull("key3".to_string()).await;
            item3.inner_mut().replace(3);
        }
        let mut num_entries = 0;
        if pool.pull("key1".to_string()).await.inner().is_some() {
            num_entries += 1;
        }
        if pool.pull("key2".to_string()).await.inner().is_some() {
            num_entries += 1;
        }
        if pool.pull("key3".to_string()).await.inner().is_some() {
            num_entries += 1;
        }
        assert_eq!(num_entries, 2);
    }

    #[tokio::test]
    async fn test_pool_semaphore_limit() {
        let pool = Pool::<String, u32>::new(1);
        let item1 = pool.pull("key1".to_string()).await;
        let semaphore_permits = pool.semaphore.as_ref().unwrap().available_permits();
        assert_eq!(semaphore_permits, 0);
        drop(item1);
        assert_eq!(pool.semaphore.as_ref().unwrap().available_permits(), 1);
    }

    #[tokio::test]
    async fn test_try_pull_respects_global_limit() {
        let pool = Pool::<String, u32>::new(1);
        let item1 = pool.try_pull("key1".to_string());
        assert!(item1.is_some());
        assert!(pool.try_pull("key2".to_string()).is_none());
        drop(item1);
        assert!(pool.try_pull("key2".to_string()).is_some());
    }

    #[tokio::test]
    async fn test_set_and_get_local_limit() {
        let pool = Pool::<String, u32>::new(10);
        let index = pool.set_local_limit(2);
        let local_limits = pool.local_limits.read().expect("lock poisoned");
        assert!(local_limits.get(index).is_some());
        assert_eq!(local_limits[index].available_permits(), 2);
    }

    #[tokio::test]
    async fn test_local_limits_are_independent() {
        let pool = Pool::<String, u32>::new(10);
        let first = pool.set_local_limit(1);
        let second = pool.set_local_limit(1);
        assert_ne!(first, second);
        let _item1 = pool
            .try_pull_with_local_limit("key1".to_string(), Some(first))
            .expect("first local limit has a free permit");
        assert!(pool
            .try_pull_with_local_limit("key2".to_string(), Some(second))
            .is_some());
    }

    #[tokio::test]
    async fn test_pull_with_local_limit_success() {
        let pool = Pool::<String, u32>::new(10);
        let index = pool.set_local_limit(2);
        let item = pool.pull_with_local_limit("key1".to_string(), Some(index)).await;
        assert!(item.is_some());
    }

    #[tokio::test]
    async fn test_pull_with_local_limit_exhausted() {
        let pool = Pool::<String, u32>::new(10);
        let index = pool.set_local_limit(1);
        let _item1 = pool.pull_with_local_limit("key1".to_string(), Some(index)).await;
        let item2 = pool.pull_with_local_limit("key2".to_string(), Some(index)).await;
        assert!(item2.is_none());
    }

    #[tokio::test]
    async fn test_pull_with_local_limit_after_release() {
        let pool = Pool::<String, u32>::new(10);
        let index = pool.set_local_limit(1);
        let item1 = pool.pull_with_local_limit("key1".to_string(), Some(index)).await;
        assert!(item1.is_some());
        drop(item1);
        let item2 = pool.pull_with_local_limit("key2".to_string(), Some(index)).await;
        assert!(item2.is_some());
    }

    #[tokio::test]
    async fn test_pull_with_wait_local_limit_holds_local_permit() {
        let pool = Pool::<String, u32>::new(10);
        let index = pool.set_local_limit(1);
        let item1 = pool
            .pull_with_wait_local_limit("key1".to_string(), Some(index))
            .await;
        assert!(pool
            .try_pull_with_local_limit("key2".to_string(), Some(index))
            .is_none());
        drop(item1);
        assert!(pool
            .try_pull_with_local_limit("key2".to_string(), Some(index))
            .is_some());
    }

    #[tokio::test]
    async fn test_pull_with_invalid_local_limit_index() {
        let pool = Pool::<String, u32>::new(10);
        let item = pool.pull_with_local_limit("key1".to_string(), Some(999)).await;
        assert!(item.is_some());
    }
}