tokio_util/task/join_queue.rs
1use super::AbortOnDropHandle;
2use std::{
3 collections::VecDeque,
4 future::Future,
5 pin::Pin,
6 task::{Context, Poll},
7};
8use tokio::{
9 runtime::Handle,
10 task::{AbortHandle, Id, JoinError, JoinHandle},
11};
12
13/// A FIFO queue for of tasks spawned on a Tokio runtime.
14///
15/// A [`JoinQueue`] can be used to await the completion of the tasks in FIFO
16/// order. That is, if tasks are spawned in the order A, B, C, then
17/// awaiting the next completed task will always return A first, then B,
18/// then C, regardless of the order in which the tasks actually complete.
19///
20/// All of the tasks must have the same return type `T`.
21///
22/// When the [`JoinQueue`] is dropped, all tasks in the [`JoinQueue`] are
23/// immediately aborted.
24pub struct JoinQueue<T>(VecDeque<AbortOnDropHandle<T>>);
25
26impl<T> JoinQueue<T> {
27 /// Create a new empty [`JoinQueue`].
28 pub const fn new() -> Self {
29 Self(VecDeque::new())
30 }
31
32 /// Creates an empty [`JoinQueue`] with space for at least `capacity` tasks.
33 pub fn with_capacity(capacity: usize) -> Self {
34 Self(VecDeque::with_capacity(capacity))
35 }
36
37 /// Returns the number of tasks currently in the [`JoinQueue`].
38 ///
39 /// This includes both tasks that are currently running and tasks that have
40 /// completed but not yet been removed from the queue because outputting of
41 /// them waits for FIFO order.
42 pub fn len(&self) -> usize {
43 self.0.len()
44 }
45
46 /// Returns whether the [`JoinQueue`] is empty.
47 pub fn is_empty(&self) -> bool {
48 self.0.is_empty()
49 }
50
51 /// Spawn the provided task on the [`JoinQueue`], returning an [`AbortHandle`]
52 /// that can be used to remotely cancel the task.
53 ///
54 /// The provided future will start running in the background immediately
55 /// when this method is called, even if you don't await anything on this
56 /// [`JoinQueue`].
57 ///
58 /// # Panics
59 ///
60 /// This method panics if called outside of a Tokio runtime.
61 ///
62 /// [`AbortHandle`]: tokio::task::AbortHandle
63 #[track_caller]
64 pub fn spawn<F>(&mut self, task: F) -> AbortHandle
65 where
66 F: Future<Output = T> + Send + 'static,
67 T: Send + 'static,
68 {
69 self.push_back(tokio::spawn(task))
70 }
71
72 /// Spawn the provided task on the provided runtime and store it in this
73 /// [`JoinQueue`] returning an [`AbortHandle`] that can be used to remotely
74 /// cancel the task.
75 ///
76 /// The provided future will start running in the background immediately
77 /// when this method is called, even if you don't await anything on this
78 /// [`JoinQueue`].
79 ///
80 /// [`AbortHandle`]: tokio::task::AbortHandle
81 #[track_caller]
82 pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) -> AbortHandle
83 where
84 F: Future<Output = T> + Send + 'static,
85 T: Send + 'static,
86 {
87 self.push_back(handle.spawn(task))
88 }
89
90 /// Spawn the provided task on the current [`LocalSet`] or [`LocalRuntime`]
91 /// and store it in this [`JoinQueue`], returning an [`AbortHandle`] that
92 /// can be used to remotely cancel the task.
93 ///
94 /// The provided future will start running in the background immediately
95 /// when this method is called, even if you don't await anything on this
96 /// [`JoinQueue`].
97 ///
98 /// # Panics
99 ///
100 /// This method panics if it is called outside of a `LocalSet` or `LocalRuntime`.
101 ///
102 /// [`LocalSet`]: tokio::task::LocalSet
103 /// [`LocalRuntime`]: tokio::runtime::LocalRuntime
104 /// [`AbortHandle`]: tokio::task::AbortHandle
105 #[track_caller]
106 pub fn spawn_local<F>(&mut self, task: F) -> AbortHandle
107 where
108 F: Future<Output = T> + 'static,
109 T: 'static,
110 {
111 self.push_back(tokio::task::spawn_local(task))
112 }
113
114 /// Spawn the blocking code on the blocking threadpool and store
115 /// it in this [`JoinQueue`], returning an [`AbortHandle`] that can be
116 /// used to remotely cancel the task.
117 ///
118 /// # Panics
119 ///
120 /// This method panics if called outside of a Tokio runtime.
121 ///
122 /// [`AbortHandle`]: tokio::task::AbortHandle
123 #[track_caller]
124 pub fn spawn_blocking<F>(&mut self, f: F) -> AbortHandle
125 where
126 F: FnOnce() -> T + Send + 'static,
127 T: Send + 'static,
128 {
129 self.push_back(tokio::task::spawn_blocking(f))
130 }
131
132 /// Spawn the blocking code on the blocking threadpool of the
133 /// provided runtime and store it in this [`JoinQueue`], returning an
134 /// [`AbortHandle`] that can be used to remotely cancel the task.
135 ///
136 /// [`AbortHandle`]: tokio::task::AbortHandle
137 #[track_caller]
138 pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle) -> AbortHandle
139 where
140 F: FnOnce() -> T + Send + 'static,
141 T: Send + 'static,
142 {
143 self.push_back(handle.spawn_blocking(f))
144 }
145
146 fn push_back(&mut self, jh: JoinHandle<T>) -> AbortHandle {
147 let jh = AbortOnDropHandle::new(jh);
148 let abort_handle = jh.abort_handle();
149 self.0.push_back(jh);
150 abort_handle
151 }
152
153 /// Waits until the next task in FIFO order completes and returns its output.
154 ///
155 /// Returns `None` if the queue is empty.
156 ///
157 /// # Cancel Safety
158 ///
159 /// This method is cancel safe. If `join_next` is used as the event in a `tokio::select!`
160 /// statement and some other branch completes first, it is guaranteed that no tasks were
161 /// removed from this [`JoinQueue`].
162 pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
163 std::future::poll_fn(|cx| self.poll_join_next(cx)).await
164 }
165
166 /// Waits until the next task in FIFO order completes and returns its output,
167 /// along with the [task ID] of the completed task.
168 ///
169 /// Returns `None` if the queue is empty.
170 ///
171 /// When this method returns an error, then the id of the task that failed can be accessed
172 /// using the [`JoinError::id`] method.
173 ///
174 /// # Cancel Safety
175 ///
176 /// This method is cancel safe. If `join_next_with_id` is used as the event in a `tokio::select!`
177 /// statement and some other branch completes first, it is guaranteed that no tasks were
178 /// removed from this [`JoinQueue`].
179 ///
180 /// [task ID]: tokio::task::Id
181 /// [`JoinError::id`]: fn@tokio::task::JoinError::id
182 pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
183 std::future::poll_fn(|cx| self.poll_join_next_with_id(cx)).await
184 }
185
186 /// Tries to poll an `AbortOnDropHandle` without blocking or yielding.
187 ///
188 /// Note that on success the handle will panic on subsequent polls
189 /// since it becomes consumed.
190 fn try_poll_handle(jh: &mut AbortOnDropHandle<T>) -> Option<Result<T, JoinError>> {
191 let waker = futures_util::task::noop_waker();
192 let mut cx = Context::from_waker(&waker);
193
194 // Since this function is not async and cannot be forced to yield, we should
195 // disable budgeting when we want to check for the `JoinHandle` readiness.
196 let jh = std::pin::pin!(tokio::task::coop::unconstrained(jh));
197 if let Poll::Ready(res) = jh.poll(&mut cx) {
198 Some(res)
199 } else {
200 None
201 }
202 }
203
204 /// Tries to join the next task in FIFO order if it has completed.
205 ///
206 /// Returns `None` if the queue is empty or if the next task is not yet ready.
207 pub fn try_join_next(&mut self) -> Option<Result<T, JoinError>> {
208 let jh = self.0.front_mut()?;
209 let res = Self::try_poll_handle(jh)?;
210 // Use `detach` to avoid calling `abort` on a task that has already completed.
211 // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
212 // we only need to drop the `JoinHandle` for cleanup.
213 drop(self.0.pop_front().unwrap().detach());
214 Some(res)
215 }
216
217 /// Tries to join the next task in FIFO order if it has completed and return its output,
218 /// along with its [task ID].
219 ///
220 /// Returns `None` if the queue is empty or if the next task is not yet ready.
221 ///
222 /// When this method returns an error, then the id of the task that failed can be accessed
223 /// using the [`JoinError::id`] method.
224 ///
225 /// [task ID]: tokio::task::Id
226 /// [`JoinError::id`]: fn@tokio::task::JoinError::id
227 pub fn try_join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
228 let jh = self.0.front_mut()?;
229 let res = Self::try_poll_handle(jh)?;
230 // Use `detach` to avoid calling `abort` on a task that has already completed.
231 // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
232 // we only need to drop the `JoinHandle` for cleanup.
233 let jh = self.0.pop_front().unwrap().detach();
234 let id = jh.id();
235 drop(jh);
236 Some(res.map(|output| (id, output)))
237 }
238
239 /// Aborts all tasks and waits for them to finish shutting down.
240 ///
241 /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
242 /// a loop until it returns `None`.
243 ///
244 /// This method ignores any panics in the tasks shutting down. When this call returns, the
245 /// [`JoinQueue`] will be empty.
246 ///
247 /// [`abort_all`]: fn@Self::abort_all
248 /// [`join_next`]: fn@Self::join_next
249 pub async fn shutdown(&mut self) {
250 self.abort_all();
251 while self.join_next().await.is_some() {}
252 }
253
254 /// Awaits the completion of all tasks in this [`JoinQueue`], returning a vector of their results.
255 ///
256 /// The results will be stored in the order they were spawned, not the order they completed.
257 /// This is a convenience method that is equivalent to calling [`join_next`] in
258 /// a loop. If any tasks on the [`JoinQueue`] fail with an [`JoinError`], then this call
259 /// to `join_all` will panic and all remaining tasks on the [`JoinQueue`] are
260 /// cancelled. To handle errors in any other way, manually call [`join_next`]
261 /// in a loop.
262 ///
263 /// # Cancel Safety
264 ///
265 /// This method is not cancel safe as it calls `join_next` in a loop. If you need
266 /// cancel safety, manually call `join_next` in a loop with `Vec` accumulator.
267 ///
268 /// [`join_next`]: fn@Self::join_next
269 /// [`JoinError::id`]: fn@tokio::task::JoinError::id
270 pub async fn join_all(mut self) -> Vec<T> {
271 let mut output = Vec::with_capacity(self.len());
272
273 while let Some(res) = self.join_next().await {
274 match res {
275 Ok(t) => output.push(t),
276 Err(err) if err.is_panic() => std::panic::resume_unwind(err.into_panic()),
277 Err(err) => panic!("{err}"),
278 }
279 }
280 output
281 }
282
283 /// Aborts all tasks on this [`JoinQueue`].
284 ///
285 /// This does not remove the tasks from the [`JoinQueue`]. To wait for the tasks to complete
286 /// cancellation, you should call `join_next` in a loop until the [`JoinQueue`] is empty.
287 pub fn abort_all(&mut self) {
288 self.0.iter().for_each(|jh| jh.abort());
289 }
290
291 /// Removes all tasks from this [`JoinQueue`] without aborting them.
292 ///
293 /// The tasks removed by this call will continue to run in the background even if the [`JoinQueue`]
294 /// is dropped.
295 pub fn detach_all(&mut self) {
296 self.0.drain(..).for_each(|jh| drop(jh.detach()));
297 }
298
299 /// Polls for the next task in [`JoinQueue`] to complete.
300 ///
301 /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue.
302 ///
303 /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
304 /// to receive a wakeup when a task in the [`JoinQueue`] completes. Note that on multiple calls to
305 /// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
306 /// scheduled to receive a wakeup.
307 ///
308 /// # Returns
309 ///
310 /// This function returns:
311 ///
312 /// * `Poll::Pending` if the [`JoinQueue`] is not empty but there is no task whose output is
313 /// available right now.
314 /// * `Poll::Ready(Some(Ok(value)))` if the next task in this [`JoinQueue`] has completed.
315 /// The `value` is the return value that task.
316 /// * `Poll::Ready(Some(Err(err)))` if the next task in this [`JoinQueue`] has panicked or been
317 /// aborted. The `err` is the `JoinError` from the panicked/aborted task.
318 /// * `Poll::Ready(None)` if the [`JoinQueue`] is empty.
319 pub fn poll_join_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<T, JoinError>>> {
320 let jh = match self.0.front_mut() {
321 None => return Poll::Ready(None),
322 Some(jh) => jh,
323 };
324 if let Poll::Ready(res) = Pin::new(jh).poll(cx) {
325 // Use `detach` to avoid calling `abort` on a task that has already completed.
326 // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
327 // we only need to drop the `JoinHandle` for cleanup.
328 drop(self.0.pop_front().unwrap().detach());
329 Poll::Ready(Some(res))
330 } else {
331 Poll::Pending
332 }
333 }
334
335 /// Polls for the next task in [`JoinQueue`] to complete.
336 ///
337 /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue.
338 ///
339 /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
340 /// to receive a wakeup when a task in the [`JoinQueue`] completes. Note that on multiple calls to
341 /// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
342 /// scheduled to receive a wakeup.
343 ///
344 /// # Returns
345 ///
346 /// This function returns:
347 ///
348 /// * `Poll::Pending` if the [`JoinQueue`] is not empty but there is no task whose output is
349 /// available right now.
350 /// * `Poll::Ready(Some(Ok((id, value))))` if the next task in this [`JoinQueue`] has completed.
351 /// The `value` is the return value that task, and `id` is its [task ID].
352 /// * `Poll::Ready(Some(Err(err)))` if the next task in this [`JoinQueue`] has panicked or been
353 /// aborted. The `err` is the `JoinError` from the panicked/aborted task.
354 /// * `Poll::Ready(None)` if the [`JoinQueue`] is empty.
355 ///
356 /// [task ID]: tokio::task::Id
357 pub fn poll_join_next_with_id(
358 &mut self,
359 cx: &mut Context<'_>,
360 ) -> Poll<Option<Result<(Id, T), JoinError>>> {
361 let jh = match self.0.front_mut() {
362 None => return Poll::Ready(None),
363 Some(jh) => jh,
364 };
365 if let Poll::Ready(res) = Pin::new(jh).poll(cx) {
366 // Use `detach` to avoid calling `abort` on a task that has already completed.
367 // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
368 // we only need to drop the `JoinHandle` for cleanup.
369 let jh = self.0.pop_front().unwrap().detach();
370 let id = jh.id();
371 drop(jh);
372 // If the task succeeded, add the task ID to the output. Otherwise, the
373 // `JoinError` will already have the task's ID.
374 Poll::Ready(Some(res.map(|output| (id, output))))
375 } else {
376 Poll::Pending
377 }
378 }
379}
380
381impl<T> std::fmt::Debug for JoinQueue<T> {
382 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
383 f.debug_list()
384 .entries(self.0.iter().map(|jh| JoinHandle::id(jh.as_ref())))
385 .finish()
386 }
387}
388
389impl<T> Default for JoinQueue<T> {
390 fn default() -> Self {
391 Self::new()
392 }
393}
394
395/// Collect an iterator of futures into a [`JoinQueue`].
396///
397/// This is equivalent to calling [`JoinQueue::spawn`] on each element of the iterator.
398impl<T, F> std::iter::FromIterator<F> for JoinQueue<T>
399where
400 F: Future<Output = T> + Send + 'static,
401 T: Send + 'static,
402{
403 fn from_iter<I: IntoIterator<Item = F>>(iter: I) -> Self {
404 let mut set = Self::new();
405 iter.into_iter().for_each(|task| {
406 set.spawn(task);
407 });
408 set
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Marker type that deliberately lacks a [`std::fmt::Debug`] impl.
    struct NotDebug;

    /// Compiles only when `T` implements [`std::fmt::Debug`].
    fn require_debug<T: std::fmt::Debug>() {}

    /// `JoinQueue<T>` must be `Debug` even when `T` is not, because its
    /// `Debug` output lists only task IDs.
    #[test]
    fn assert_debug() {
        require_debug::<JoinQueue<NotDebug>>();
    }
}