1use crate::loom::sync::{Arc, Condvar, Mutex};
4use crate::loom::thread;
5use crate::runtime::blocking::schedule::BlockingSchedule;
6use crate::runtime::blocking::{shutdown, BlockingTask};
7use crate::runtime::builder::ThreadNameFn;
8use crate::runtime::task::{self, JoinHandle};
9use crate::runtime::{Builder, Callback, Handle, BOX_FUTURE_THRESHOLD};
10use crate::util::metric_atomics::MetricAtomicUsize;
11use crate::util::trace::{blocking_task, SpawnMeta};
12
13use std::collections::{HashMap, VecDeque};
14use std::fmt;
15use std::io;
16use std::sync::atomic::Ordering;
17use std::time::Duration;
18
/// Pool of OS worker threads (spawned on demand, up to `thread_cap`) that run
/// blocking tasks submitted through its `Spawner`.
pub(crate) struct BlockingPool {
    /// Handle used to queue tasks onto this pool.
    spawner: Spawner,
    /// Waited on by `shutdown`. Each worker thread holds a clone of the matching
    /// `shutdown::Sender` and drops it when it finishes; presumably the wait
    /// completes once every sender is gone — TODO confirm against `shutdown`.
    shutdown_rx: shutdown::Receiver,
}
23
/// Cloneable handle for submitting tasks to a blocking pool; all clones share
/// the same `Inner` state.
#[derive(Clone)]
pub(crate) struct Spawner {
    inner: Arc<Inner>,
}
28
/// Counters describing the pool. Reads use `Relaxed` loads; in this file every
/// update happens while the `Inner::shared` mutex is held.
#[derive(Default)]
pub(crate) struct SpawnerMetrics {
    /// Number of live worker threads.
    num_threads: MetricAtomicUsize,
    /// Number of workers waiting for work, minus wakeups already posted to
    /// them through `Shared::num_notify`.
    num_idle_threads: MetricAtomicUsize,
    /// Number of tasks sitting in `Shared::queue`.
    queue_depth: MetricAtomicUsize,
}
35
36impl SpawnerMetrics {
37    fn num_threads(&self) -> usize {
38        self.num_threads.load(Ordering::Relaxed)
39    }
40
41    fn num_idle_threads(&self) -> usize {
42        self.num_idle_threads.load(Ordering::Relaxed)
43    }
44
45    cfg_unstable_metrics! {
46        fn queue_depth(&self) -> usize {
47            self.queue_depth.load(Ordering::Relaxed)
48        }
49    }
50
51    fn inc_num_threads(&self) {
52        self.num_threads.increment();
53    }
54
55    fn dec_num_threads(&self) {
56        self.num_threads.decrement();
57    }
58
59    fn inc_num_idle_threads(&self) {
60        self.num_idle_threads.increment();
61    }
62
63    fn dec_num_idle_threads(&self) -> usize {
64        self.num_idle_threads.decrement()
65    }
66
67    fn inc_queue_depth(&self) {
68        self.queue_depth.increment();
69    }
70
71    fn dec_queue_depth(&self) {
72        self.queue_depth.decrement();
73    }
74}
75
/// State shared by every `Spawner` clone and by all worker threads of a pool.
struct Inner {
    /// Mutable pool state (queue, shutdown flag, thread handles) behind the
    /// pool mutex.
    shared: Mutex<Shared>,

    /// Wakes waiting workers: `notify_one` when work is handed to an idle
    /// worker (and when the last worker exits during shutdown), `notify_all`
    /// when shutdown begins.
    condvar: Condvar,

    /// Called to produce the name of each new worker thread.
    thread_name: ThreadNameFn,

    /// Stack size for worker threads, if one was configured.
    stack_size: Option<usize>,

    /// Callback run at the start of each worker thread, before it takes the lock.
    after_start: Option<Callback>,

    /// Callback run at the end of each worker thread, after it releases the lock.
    before_stop: Option<Callback>,

    /// Maximum number of worker threads the pool will spawn.
    thread_cap: usize,

    /// How long an idle worker waits for work before exiting.
    keep_alive: Duration,

    /// Thread and queue counters.
    metrics: SpawnerMetrics,
}
104
/// Pool state guarded by `Inner::shared`.
struct Shared {
    /// Pending tasks; pushed at the back, popped from the front.
    queue: VecDeque<Task>,
    /// Wakeups posted by `Spawner::spawn_task` that no worker has consumed yet.
    /// A waiting worker only returns to work when this is non-zero, which lets
    /// it ignore spurious condvar wakeups.
    num_notify: u32,
    /// Set once by `BlockingPool::shutdown`; afterwards new tasks are not queued.
    shutdown: bool,
    /// The pool's own shutdown sender. Each worker thread receives a clone; the
    /// pool's copy is set to `None` when shutdown begins.
    shutdown_tx: Option<shutdown::Sender>,
    /// Handle of the most recent worker that exited because its keep-alive
    /// expired. The next such worker joins it; `BlockingPool::shutdown` takes
    /// whatever is left (and joins it if its shutdown wait succeeds).
    last_exiting_thread: Option<thread::JoinHandle<()>>,
    /// Handles of worker threads keyed by worker id. A worker removes its own
    /// entry when its keep-alive expires.
    worker_threads: HashMap<usize, thread::JoinHandle<()>>,
    /// Id assigned to the next spawned worker thread.
    worker_thread_index: usize,
}
123
/// A blocking task as stored in the pool queue.
pub(crate) struct Task {
    /// The task itself; run or shut down by a worker thread.
    task: task::UnownedTask<BlockingSchedule>,
    /// Whether the task must still run if it is drained from the queue during
    /// shutdown.
    mandatory: Mandatory,
}
128
/// Whether a queued task must be executed even when the pool is shutting down.
#[derive(PartialEq, Eq)]
pub(crate) enum Mandatory {
    /// Still run when drained from the queue during shutdown. In this file it
    /// is only constructed inside `cfg_fs!` blocks, presumably why `dead_code`
    /// is allowed when the `fs` feature is off.
    #[cfg_attr(not(feature = "fs"), allow(dead_code))]
    Mandatory,
    /// Shut down (not run) if still queued when the pool shuts down.
    NonMandatory,
}
135
/// Reasons `Spawner::spawn_task` can fail.
pub(crate) enum SpawnError {
    /// The pool has started shutting down; the task was shut down, not queued.
    ShuttingDown,
    /// Spawning a new worker thread failed with this error, and the failure
    /// was not a temporary one that existing workers could absorb.
    NoThreads(io::Error),
}
143
144impl From<SpawnError> for io::Error {
145    fn from(e: SpawnError) -> Self {
146        match e {
147            SpawnError::ShuttingDown => {
148                io::Error::new(io::ErrorKind::Other, "blocking pool shutting down")
149            }
150            SpawnError::NoThreads(e) => e,
151        }
152    }
153}
154
155impl Task {
156    pub(crate) fn new(task: task::UnownedTask<BlockingSchedule>, mandatory: Mandatory) -> Task {
157        Task { task, mandatory }
158    }
159
160    fn run(self) {
161        self.task.run();
162    }
163
164    fn shutdown_or_run_if_mandatory(self) {
165        match self.mandatory {
166            Mandatory::NonMandatory => self.task.shutdown(),
167            Mandatory::Mandatory => self.task.run(),
168        }
169    }
170}
171
/// Idle timeout for worker threads when the runtime builder sets no `keep_alive`.
const KEEP_ALIVE: Duration = Duration::from_millis(10_000);
173
174#[track_caller]
178#[cfg_attr(target_os = "wasi", allow(dead_code))]
179pub(crate) fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
180where
181    F: FnOnce() -> R + Send + 'static,
182    R: Send + 'static,
183{
184    let rt = Handle::current();
185    rt.spawn_blocking(func)
186}
187
188cfg_fs! {
189    #[cfg_attr(any(
190        all(loom, not(test)), test
192    ), allow(dead_code))]
193    pub(crate) fn spawn_mandatory_blocking<F, R>(func: F) -> Option<JoinHandle<R>>
198    where
199        F: FnOnce() -> R + Send + 'static,
200        R: Send + 'static,
201    {
202        let rt = Handle::current();
203        rt.inner.blocking_spawner().spawn_mandatory_blocking(&rt, func)
204    }
205}
206
207impl BlockingPool {
210    pub(crate) fn new(builder: &Builder, thread_cap: usize) -> BlockingPool {
211        let (shutdown_tx, shutdown_rx) = shutdown::channel();
212        let keep_alive = builder.keep_alive.unwrap_or(KEEP_ALIVE);
213
214        BlockingPool {
215            spawner: Spawner {
216                inner: Arc::new(Inner {
217                    shared: Mutex::new(Shared {
218                        queue: VecDeque::new(),
219                        num_notify: 0,
220                        shutdown: false,
221                        shutdown_tx: Some(shutdown_tx),
222                        last_exiting_thread: None,
223                        worker_threads: HashMap::new(),
224                        worker_thread_index: 0,
225                    }),
226                    condvar: Condvar::new(),
227                    thread_name: builder.thread_name.clone(),
228                    stack_size: builder.thread_stack_size,
229                    after_start: builder.after_start.clone(),
230                    before_stop: builder.before_stop.clone(),
231                    thread_cap,
232                    keep_alive,
233                    metrics: SpawnerMetrics::default(),
234                }),
235            },
236            shutdown_rx,
237        }
238    }
239
240    pub(crate) fn spawner(&self) -> &Spawner {
241        &self.spawner
242    }
243
244    pub(crate) fn shutdown(&mut self, timeout: Option<Duration>) {
245        let mut shared = self.spawner.inner.shared.lock();
246
247        if shared.shutdown {
251            return;
252        }
253
254        shared.shutdown = true;
255        shared.shutdown_tx = None;
256        self.spawner.inner.condvar.notify_all();
257
258        let last_exited_thread = std::mem::take(&mut shared.last_exiting_thread);
259        let workers = std::mem::take(&mut shared.worker_threads);
260
261        drop(shared);
262
263        if self.shutdown_rx.wait(timeout) {
264            let _ = last_exited_thread.map(thread::JoinHandle::join);
265
266            #[cfg(loom)]
269            let workers: Vec<(usize, thread::JoinHandle<()>)> = {
270                let mut workers: Vec<_> = workers.into_iter().collect();
271                workers.sort_by_key(|(id, _)| *id);
272                workers
273            };
274
275            for (_id, handle) in workers {
276                let _ = handle.join();
277            }
278        }
279    }
280}
281
282impl Drop for BlockingPool {
283    fn drop(&mut self) {
284        self.shutdown(None);
285    }
286}
287
288impl fmt::Debug for BlockingPool {
289    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
290        fmt.debug_struct("BlockingPool").finish()
291    }
292}
293
294impl Spawner {
297    #[track_caller]
298    pub(crate) fn spawn_blocking<F, R>(&self, rt: &Handle, func: F) -> JoinHandle<R>
299    where
300        F: FnOnce() -> R + Send + 'static,
301        R: Send + 'static,
302    {
303        let fn_size = std::mem::size_of::<F>();
304        let (join_handle, spawn_result) = if fn_size > BOX_FUTURE_THRESHOLD {
305            self.spawn_blocking_inner(
306                Box::new(func),
307                Mandatory::NonMandatory,
308                SpawnMeta::new_unnamed(fn_size),
309                rt,
310            )
311        } else {
312            self.spawn_blocking_inner(
313                func,
314                Mandatory::NonMandatory,
315                SpawnMeta::new_unnamed(fn_size),
316                rt,
317            )
318        };
319
320        match spawn_result {
321            Ok(()) => join_handle,
322            Err(SpawnError::ShuttingDown) => join_handle,
324            Err(SpawnError::NoThreads(e)) => {
325                panic!("OS can't spawn worker thread: {e}")
326            }
327        }
328    }
329
330    cfg_fs! {
331        #[track_caller]
332        #[cfg_attr(any(
333            all(loom, not(test)), test
335        ), allow(dead_code))]
336        pub(crate) fn spawn_mandatory_blocking<F, R>(&self, rt: &Handle, func: F) -> Option<JoinHandle<R>>
337        where
338            F: FnOnce() -> R + Send + 'static,
339            R: Send + 'static,
340        {
341            let fn_size = std::mem::size_of::<F>();
342            let (join_handle, spawn_result) = if fn_size > BOX_FUTURE_THRESHOLD {
343                self.spawn_blocking_inner(
344                    Box::new(func),
345                    Mandatory::Mandatory,
346                    SpawnMeta::new_unnamed(fn_size),
347                    rt,
348                )
349            } else {
350                self.spawn_blocking_inner(
351                    func,
352                    Mandatory::Mandatory,
353                    SpawnMeta::new_unnamed(fn_size),
354                    rt,
355                )
356            };
357
358            if spawn_result.is_ok() {
359                Some(join_handle)
360            } else {
361                None
362            }
363        }
364    }
365
366    #[track_caller]
367    pub(crate) fn spawn_blocking_inner<F, R>(
368        &self,
369        func: F,
370        is_mandatory: Mandatory,
371        spawn_meta: SpawnMeta<'_>,
372        rt: &Handle,
373    ) -> (JoinHandle<R>, Result<(), SpawnError>)
374    where
375        F: FnOnce() -> R + Send + 'static,
376        R: Send + 'static,
377    {
378        let id = task::Id::next();
379        let fut =
380            blocking_task::<F, BlockingTask<F>>(BlockingTask::new(func), spawn_meta, id.as_u64());
381
382        let (task, handle) = task::unowned(
383            fut,
384            BlockingSchedule::new(rt),
385            id,
386            task::SpawnLocation::capture(),
387        );
388
389        let spawned = self.spawn_task(Task::new(task, is_mandatory), rt);
390        (handle, spawned)
391    }
392
393    fn spawn_task(&self, task: Task, rt: &Handle) -> Result<(), SpawnError> {
394        let mut shared = self.inner.shared.lock();
395
396        if shared.shutdown {
397            task.task.shutdown();
401
402            return Err(SpawnError::ShuttingDown);
404        }
405
406        shared.queue.push_back(task);
407        self.inner.metrics.inc_queue_depth();
408
409        if self.inner.metrics.num_idle_threads() == 0 {
410            if self.inner.metrics.num_threads() == self.inner.thread_cap {
413                } else {
415                assert!(shared.shutdown_tx.is_some());
416                let shutdown_tx = shared.shutdown_tx.clone();
417
418                if let Some(shutdown_tx) = shutdown_tx {
419                    let id = shared.worker_thread_index;
420
421                    match self.spawn_thread(shutdown_tx, rt, id) {
422                        Ok(handle) => {
423                            self.inner.metrics.inc_num_threads();
424                            shared.worker_thread_index += 1;
425                            shared.worker_threads.insert(id, handle);
426                        }
427                        Err(ref e)
428                            if is_temporary_os_thread_error(e)
429                                && self.inner.metrics.num_threads() > 0 =>
430                        {
431                            }
435                        Err(e) => {
436                            return Err(SpawnError::NoThreads(e));
439                        }
440                    }
441                }
442            }
443        } else {
444            self.inner.metrics.dec_num_idle_threads();
450            shared.num_notify += 1;
451            self.inner.condvar.notify_one();
452        }
453
454        Ok(())
455    }
456
457    fn spawn_thread(
458        &self,
459        shutdown_tx: shutdown::Sender,
460        rt: &Handle,
461        id: usize,
462    ) -> io::Result<thread::JoinHandle<()>> {
463        let mut builder = thread::Builder::new().name((self.inner.thread_name)());
464
465        if let Some(stack_size) = self.inner.stack_size {
466            builder = builder.stack_size(stack_size);
467        }
468
469        let rt = rt.clone();
470
471        builder.spawn(move || {
472            let _enter = rt.enter();
474            rt.inner.blocking_spawner().inner.run(id);
475            drop(shutdown_tx);
476        })
477    }
478}
479
480cfg_unstable_metrics! {
481    impl Spawner {
482        pub(crate) fn num_threads(&self) -> usize {
483            self.inner.metrics.num_threads()
484        }
485
486        pub(crate) fn num_idle_threads(&self) -> usize {
487            self.inner.metrics.num_idle_threads()
488        }
489
490        pub(crate) fn queue_depth(&self) -> usize {
491            self.inner.metrics.queue_depth()
492        }
493    }
494}
495
/// Reports whether a thread-spawn failure should be treated as transient.
/// Only `io::ErrorKind::WouldBlock` qualifies.
#[inline]
fn is_temporary_os_thread_error(error: &io::Error) -> bool {
    error.kind() == io::ErrorKind::WouldBlock
}
501
502impl Inner {
503    fn run(&self, worker_thread_id: usize) {
504        if let Some(f) = &self.after_start {
505            f();
506        }
507
508        let mut shared = self.shared.lock();
509        let mut join_on_thread = None;
510
511        'main: loop {
512            while let Some(task) = shared.queue.pop_front() {
514                self.metrics.dec_queue_depth();
515                drop(shared);
516                task.run();
517
518                shared = self.shared.lock();
519            }
520
521            self.metrics.inc_num_idle_threads();
523
524            while !shared.shutdown {
525                let lock_result = self.condvar.wait_timeout(shared, self.keep_alive).unwrap();
526
527                shared = lock_result.0;
528                let timeout_result = lock_result.1;
529
530                if shared.num_notify != 0 {
531                    shared.num_notify -= 1;
535                    break;
536                }
537
538                if !shared.shutdown && timeout_result.timed_out() {
541                    let my_handle = shared.worker_threads.remove(&worker_thread_id);
545                    join_on_thread = std::mem::replace(&mut shared.last_exiting_thread, my_handle);
546
547                    break 'main;
548                }
549
550                }
552
553            if shared.shutdown {
554                while let Some(task) = shared.queue.pop_front() {
556                    self.metrics.dec_queue_depth();
557                    drop(shared);
558
559                    task.shutdown_or_run_if_mandatory();
560
561                    shared = self.shared.lock();
562                }
563
564                self.metrics.inc_num_idle_threads();
568                break;
571            }
572        }
573
574        self.metrics.dec_num_threads();
576
577        let prev_idle = self.metrics.dec_num_idle_threads();
581        assert!(
582            prev_idle >= self.metrics.num_idle_threads(),
583            "num_idle_threads underflowed on thread exit"
584        );
585
586        if shared.shutdown && self.metrics.num_threads() == 0 {
587            self.condvar.notify_one();
588        }
589
590        drop(shared);
591
592        if let Some(f) = &self.before_stop {
593            f();
594        }
595
596        if let Some(handle) = join_on_thread {
597            let _ = handle.join();
598        }
599    }
600}
601
602impl fmt::Debug for Spawner {
603    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
604        fmt.debug_struct("blocking::Spawner").finish()
605    }
606}