tokio_util/task/
join_queue.rs

1use super::AbortOnDropHandle;
2use std::{
3    collections::VecDeque,
4    future::Future,
5    pin::Pin,
6    task::{Context, Poll},
7};
8use tokio::{
9    runtime::Handle,
10    task::{AbortHandle, Id, JoinError, JoinHandle},
11};
12
13/// A FIFO queue for of tasks spawned on a Tokio runtime.
14///
15/// A [`JoinQueue`] can be used to await the completion of the tasks in FIFO
16/// order. That is, if tasks are spawned in the order A, B, C, then
17/// awaiting the next completed task will always return A first, then B,
18/// then C, regardless of the order in which the tasks actually complete.
19///
20/// All of the tasks must have the same return type `T`.
21///
22/// When the [`JoinQueue`] is dropped, all tasks in the [`JoinQueue`] are
23/// immediately aborted.
24pub struct JoinQueue<T>(VecDeque<AbortOnDropHandle<T>>);
25
26impl<T> JoinQueue<T> {
27    /// Create a new empty [`JoinQueue`].
28    pub const fn new() -> Self {
29        Self(VecDeque::new())
30    }
31
32    /// Creates an empty [`JoinQueue`] with space for at least `capacity` tasks.
33    pub fn with_capacity(capacity: usize) -> Self {
34        Self(VecDeque::with_capacity(capacity))
35    }
36
37    /// Returns the number of tasks currently in the [`JoinQueue`].
38    ///
39    /// This includes both tasks that are currently running and tasks that have
40    /// completed but not yet been removed from the queue because outputting of
41    /// them waits for FIFO order.
42    pub fn len(&self) -> usize {
43        self.0.len()
44    }
45
46    /// Returns whether the [`JoinQueue`] is empty.
47    pub fn is_empty(&self) -> bool {
48        self.0.is_empty()
49    }
50
51    /// Spawn the provided task on the [`JoinQueue`], returning an [`AbortHandle`]
52    /// that can be used to remotely cancel the task.
53    ///
54    /// The provided future will start running in the background immediately
55    /// when this method is called, even if you don't await anything on this
56    /// [`JoinQueue`].
57    ///
58    /// # Panics
59    ///
60    /// This method panics if called outside of a Tokio runtime.
61    ///
62    /// [`AbortHandle`]: tokio::task::AbortHandle
63    #[track_caller]
64    pub fn spawn<F>(&mut self, task: F) -> AbortHandle
65    where
66        F: Future<Output = T> + Send + 'static,
67        T: Send + 'static,
68    {
69        self.push_back(tokio::spawn(task))
70    }
71
72    /// Spawn the provided task on the provided runtime and store it in this
73    /// [`JoinQueue`] returning an [`AbortHandle`] that can be used to remotely
74    /// cancel the task.
75    ///
76    /// The provided future will start running in the background immediately
77    /// when this method is called, even if you don't await anything on this
78    /// [`JoinQueue`].
79    ///
80    /// [`AbortHandle`]: tokio::task::AbortHandle
81    #[track_caller]
82    pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) -> AbortHandle
83    where
84        F: Future<Output = T> + Send + 'static,
85        T: Send + 'static,
86    {
87        self.push_back(handle.spawn(task))
88    }
89
90    /// Spawn the provided task on the current [`LocalSet`] or [`LocalRuntime`]
91    /// and store it in this [`JoinQueue`], returning an [`AbortHandle`] that
92    /// can be used to remotely cancel the task.
93    ///
94    /// The provided future will start running in the background immediately
95    /// when this method is called, even if you don't await anything on this
96    /// [`JoinQueue`].
97    ///
98    /// # Panics
99    ///
100    /// This method panics if it is called outside of a `LocalSet` or `LocalRuntime`.
101    ///
102    /// [`LocalSet`]: tokio::task::LocalSet
103    /// [`LocalRuntime`]: tokio::runtime::LocalRuntime
104    /// [`AbortHandle`]: tokio::task::AbortHandle
105    #[track_caller]
106    pub fn spawn_local<F>(&mut self, task: F) -> AbortHandle
107    where
108        F: Future<Output = T> + 'static,
109        T: 'static,
110    {
111        self.push_back(tokio::task::spawn_local(task))
112    }
113
114    /// Spawn the blocking code on the blocking threadpool and store
115    /// it in this [`JoinQueue`], returning an [`AbortHandle`] that can be
116    /// used to remotely cancel the task.
117    ///
118    /// # Panics
119    ///
120    /// This method panics if called outside of a Tokio runtime.
121    ///
122    /// [`AbortHandle`]: tokio::task::AbortHandle
123    #[track_caller]
124    pub fn spawn_blocking<F>(&mut self, f: F) -> AbortHandle
125    where
126        F: FnOnce() -> T + Send + 'static,
127        T: Send + 'static,
128    {
129        self.push_back(tokio::task::spawn_blocking(f))
130    }
131
132    /// Spawn the blocking code on the blocking threadpool of the
133    /// provided runtime and store it in this [`JoinQueue`], returning an
134    /// [`AbortHandle`] that can be used to remotely cancel the task.
135    ///
136    /// [`AbortHandle`]: tokio::task::AbortHandle
137    #[track_caller]
138    pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle) -> AbortHandle
139    where
140        F: FnOnce() -> T + Send + 'static,
141        T: Send + 'static,
142    {
143        self.push_back(handle.spawn_blocking(f))
144    }
145
146    fn push_back(&mut self, jh: JoinHandle<T>) -> AbortHandle {
147        let jh = AbortOnDropHandle::new(jh);
148        let abort_handle = jh.abort_handle();
149        self.0.push_back(jh);
150        abort_handle
151    }
152
153    /// Waits until the next task in FIFO order completes and returns its output.
154    ///
155    /// Returns `None` if the queue is empty.
156    ///
157    /// # Cancel Safety
158    ///
159    /// This method is cancel safe. If `join_next` is used as the event in a `tokio::select!`
160    /// statement and some other branch completes first, it is guaranteed that no tasks were
161    /// removed from this [`JoinQueue`].
162    pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
163        std::future::poll_fn(|cx| self.poll_join_next(cx)).await
164    }
165
166    /// Waits until the next task in FIFO order completes and returns its output,
167    /// along with the [task ID] of the completed task.
168    ///
169    /// Returns `None` if the queue is empty.
170    ///
171    /// When this method returns an error, then the id of the task that failed can be accessed
172    /// using the [`JoinError::id`] method.
173    ///
174    /// # Cancel Safety
175    ///
176    /// This method is cancel safe. If `join_next_with_id` is used as the event in a `tokio::select!`
177    /// statement and some other branch completes first, it is guaranteed that no tasks were
178    /// removed from this [`JoinQueue`].
179    ///
180    /// [task ID]: tokio::task::Id
181    /// [`JoinError::id`]: fn@tokio::task::JoinError::id
182    pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
183        std::future::poll_fn(|cx| self.poll_join_next_with_id(cx)).await
184    }
185
186    /// Tries to poll an `AbortOnDropHandle` without blocking or yielding.
187    ///
188    /// Note that on success the handle will panic on subsequent polls
189    /// since it becomes consumed.
190    fn try_poll_handle(jh: &mut AbortOnDropHandle<T>) -> Option<Result<T, JoinError>> {
191        let waker = futures_util::task::noop_waker();
192        let mut cx = Context::from_waker(&waker);
193
194        // Since this function is not async and cannot be forced to yield, we should
195        // disable budgeting when we want to check for the `JoinHandle` readiness.
196        let jh = std::pin::pin!(tokio::task::coop::unconstrained(jh));
197        if let Poll::Ready(res) = jh.poll(&mut cx) {
198            Some(res)
199        } else {
200            None
201        }
202    }
203
204    /// Tries to join the next task in FIFO order if it has completed.
205    ///
206    /// Returns `None` if the queue is empty or if the next task is not yet ready.
207    pub fn try_join_next(&mut self) -> Option<Result<T, JoinError>> {
208        let jh = self.0.front_mut()?;
209        let res = Self::try_poll_handle(jh)?;
210        // Use `detach` to avoid calling `abort` on a task that has already completed.
211        // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
212        // we only need to drop the `JoinHandle` for cleanup.
213        drop(self.0.pop_front().unwrap().detach());
214        Some(res)
215    }
216
217    /// Tries to join the next task in FIFO order if it has completed and return its output,
218    /// along with its [task ID].
219    ///
220    /// Returns `None` if the queue is empty or if the next task is not yet ready.
221    ///
222    /// When this method returns an error, then the id of the task that failed can be accessed
223    /// using the [`JoinError::id`] method.
224    ///
225    /// [task ID]: tokio::task::Id
226    /// [`JoinError::id`]: fn@tokio::task::JoinError::id
227    pub fn try_join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
228        let jh = self.0.front_mut()?;
229        let res = Self::try_poll_handle(jh)?;
230        // Use `detach` to avoid calling `abort` on a task that has already completed.
231        // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
232        // we only need to drop the `JoinHandle` for cleanup.
233        let jh = self.0.pop_front().unwrap().detach();
234        let id = jh.id();
235        drop(jh);
236        Some(res.map(|output| (id, output)))
237    }
238
239    /// Aborts all tasks and waits for them to finish shutting down.
240    ///
241    /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
242    /// a loop until it returns `None`.
243    ///
244    /// This method ignores any panics in the tasks shutting down. When this call returns, the
245    /// [`JoinQueue`] will be empty.
246    ///
247    /// [`abort_all`]: fn@Self::abort_all
248    /// [`join_next`]: fn@Self::join_next
249    pub async fn shutdown(&mut self) {
250        self.abort_all();
251        while self.join_next().await.is_some() {}
252    }
253
254    /// Awaits the completion of all tasks in this [`JoinQueue`], returning a vector of their results.
255    ///
256    /// The results will be stored in the order they were spawned, not the order they completed.
257    /// This is a convenience method that is equivalent to calling [`join_next`] in
258    /// a loop. If any tasks on the [`JoinQueue`] fail with an [`JoinError`], then this call
259    /// to `join_all` will panic and all remaining tasks on the [`JoinQueue`] are
260    /// cancelled. To handle errors in any other way, manually call [`join_next`]
261    /// in a loop.
262    ///
263    /// # Cancel Safety
264    ///
265    /// This method is not cancel safe as it calls `join_next` in a loop. If you need
266    /// cancel safety, manually call `join_next` in a loop with `Vec` accumulator.
267    ///
268    /// [`join_next`]: fn@Self::join_next
269    /// [`JoinError::id`]: fn@tokio::task::JoinError::id
270    pub async fn join_all(mut self) -> Vec<T> {
271        let mut output = Vec::with_capacity(self.len());
272
273        while let Some(res) = self.join_next().await {
274            match res {
275                Ok(t) => output.push(t),
276                Err(err) if err.is_panic() => std::panic::resume_unwind(err.into_panic()),
277                Err(err) => panic!("{err}"),
278            }
279        }
280        output
281    }
282
283    /// Aborts all tasks on this [`JoinQueue`].
284    ///
285    /// This does not remove the tasks from the [`JoinQueue`]. To wait for the tasks to complete
286    /// cancellation, you should call `join_next` in a loop until the [`JoinQueue`] is empty.
287    pub fn abort_all(&mut self) {
288        self.0.iter().for_each(|jh| jh.abort());
289    }
290
291    /// Removes all tasks from this [`JoinQueue`] without aborting them.
292    ///
293    /// The tasks removed by this call will continue to run in the background even if the [`JoinQueue`]
294    /// is dropped.
295    pub fn detach_all(&mut self) {
296        self.0.drain(..).for_each(|jh| drop(jh.detach()));
297    }
298
299    /// Polls for the next task in [`JoinQueue`] to complete.
300    ///
301    /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue.
302    ///
303    /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
304    /// to receive a wakeup when a task in the [`JoinQueue`] completes. Note that on multiple calls to
305    /// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
306    /// scheduled to receive a wakeup.
307    ///
308    /// # Returns
309    ///
310    /// This function returns:
311    ///
312    ///  * `Poll::Pending` if the [`JoinQueue`] is not empty but there is no task whose output is
313    ///    available right now.
314    ///  * `Poll::Ready(Some(Ok(value)))` if the next task in this [`JoinQueue`] has completed.
315    ///    The `value` is the return value that task.
316    ///  * `Poll::Ready(Some(Err(err)))` if the next task in this [`JoinQueue`] has panicked or been
317    ///    aborted. The `err` is the `JoinError` from the panicked/aborted task.
318    ///  * `Poll::Ready(None)` if the [`JoinQueue`] is empty.
319    pub fn poll_join_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<T, JoinError>>> {
320        let jh = match self.0.front_mut() {
321            None => return Poll::Ready(None),
322            Some(jh) => jh,
323        };
324        if let Poll::Ready(res) = Pin::new(jh).poll(cx) {
325            // Use `detach` to avoid calling `abort` on a task that has already completed.
326            // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
327            // we only need to drop the `JoinHandle` for cleanup.
328            drop(self.0.pop_front().unwrap().detach());
329            Poll::Ready(Some(res))
330        } else {
331            Poll::Pending
332        }
333    }
334
335    /// Polls for the next task in [`JoinQueue`] to complete.
336    ///
337    /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue.
338    ///
339    /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
340    /// to receive a wakeup when a task in the [`JoinQueue`] completes. Note that on multiple calls to
341    /// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
342    /// scheduled to receive a wakeup.
343    ///
344    /// # Returns
345    ///
346    /// This function returns:
347    ///
348    ///  * `Poll::Pending` if the [`JoinQueue`] is not empty but there is no task whose output is
349    ///    available right now.
350    ///  * `Poll::Ready(Some(Ok((id, value))))` if the next task in this [`JoinQueue`] has completed.
351    ///    The `value` is the return value that task, and `id` is its [task ID].
352    ///  * `Poll::Ready(Some(Err(err)))` if the next task in this [`JoinQueue`] has panicked or been
353    ///    aborted. The `err` is the `JoinError` from the panicked/aborted task.
354    ///  * `Poll::Ready(None)` if the [`JoinQueue`] is empty.
355    ///
356    /// [task ID]: tokio::task::Id
357    pub fn poll_join_next_with_id(
358        &mut self,
359        cx: &mut Context<'_>,
360    ) -> Poll<Option<Result<(Id, T), JoinError>>> {
361        let jh = match self.0.front_mut() {
362            None => return Poll::Ready(None),
363            Some(jh) => jh,
364        };
365        if let Poll::Ready(res) = Pin::new(jh).poll(cx) {
366            // Use `detach` to avoid calling `abort` on a task that has already completed.
367            // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
368            // we only need to drop the `JoinHandle` for cleanup.
369            let jh = self.0.pop_front().unwrap().detach();
370            let id = jh.id();
371            drop(jh);
372            // If the task succeeded, add the task ID to the output. Otherwise, the
373            // `JoinError` will already have the task's ID.
374            Poll::Ready(Some(res.map(|output| (id, output))))
375        } else {
376            Poll::Pending
377        }
378    }
379}
380
381impl<T> std::fmt::Debug for JoinQueue<T> {
382    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
383        f.debug_list()
384            .entries(self.0.iter().map(|jh| JoinHandle::id(jh.as_ref())))
385            .finish()
386    }
387}
388
389impl<T> Default for JoinQueue<T> {
390    fn default() -> Self {
391        Self::new()
392    }
393}
394
395/// Collect an iterator of futures into a [`JoinQueue`].
396///
397/// This is equivalent to calling [`JoinQueue::spawn`] on each element of the iterator.
398impl<T, F> std::iter::FromIterator<F> for JoinQueue<T>
399where
400    F: Future<Output = T> + Send + 'static,
401    T: Send + 'static,
402{
403    fn from_iter<I: IntoIterator<Item = F>>(iter: I) -> Self {
404        let mut set = Self::new();
405        iter.into_iter().for_each(|task| {
406            set.spawn(task);
407        });
408        set
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Marker type that intentionally has no [`std::fmt::Debug`] implementation.
    struct NoDebugImpl;

    /// Compiles only when `T` implements [`std::fmt::Debug`].
    fn require_debug<T: std::fmt::Debug>() {}

    /// `JoinQueue<T>` must be `Debug` even when `T` is not, since only task IDs are printed.
    #[test]
    fn assert_debug() {
        require_debug::<JoinQueue<NoDebugImpl>>();
    }
}