monoio/
blocking.rs

1//! Blocking tasks related.
2
3use std::{future::Future, task::Poll};
4
5use threadpool::{Builder as ThreadPoolBuilder, ThreadPool as ThreadPoolImpl};
6
7use crate::{
8    task::{new_task, JoinHandle},
9    utils::thread_id::DEFAULT_THREAD_ID,
10};
11
12/// Users may implement a ThreadPool and attach it to runtime.
13/// We also provide an implementation based on threadpool crate, you can use DefaultThreadPool.
14pub trait ThreadPool {
15    /// Monoio runtime will call `schedule_task` on `spawn_blocking`.
16    /// ThreadPool impl must execute it now or later.
17    fn schedule_task(&self, task: BlockingTask);
18}
19
/// Error returned when awaiting the join handle of a blocking task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinError {
    /// The task was dropped before it ran, e.g. because the thread pool discarded
    /// it instead of calling `BlockingTask::run`.
    Canceled,
}

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            JoinError::Canceled => f.write_str("blocking task was canceled"),
        }
    }
}

impl std::error::Error for JoinError {}
26
27/// BlockingTask is contrusted by monoio, ThreadPool impl
28/// will execute it with `.run()`.
29pub struct BlockingTask {
30    task: Option<crate::task::Task<NoopScheduler>>,
31    blocking_vtable: &'static BlockingTaskVtable,
32}
33
34unsafe impl Send for BlockingTask {}
35
36struct BlockingTaskVtable {
37    pub(crate) drop: unsafe fn(&mut crate::task::Task<NoopScheduler>),
38}
39
40fn blocking_vtable<V>() -> &'static BlockingTaskVtable {
41    &BlockingTaskVtable {
42        drop: blocking_task_drop::<V>,
43    }
44}
45
46fn blocking_task_drop<V>(task: &mut crate::task::Task<NoopScheduler>) {
47    let mut opt: Option<Result<V, JoinError>> = Some(Err(JoinError::Canceled));
48    unsafe { task.finish((&mut opt) as *mut _ as *mut ()) };
49}
50
51impl Drop for BlockingTask {
52    fn drop(&mut self) {
53        if let Some(task) = self.task.as_mut() {
54            unsafe { (self.blocking_vtable.drop)(task) };
55        }
56    }
57}
58
59impl BlockingTask {
60    /// Run task.
61    #[inline]
62    pub fn run(mut self) {
63        let task = self.task.take().unwrap();
64        task.run();
65        // // if we are within a runtime, just run it.
66        // if crate::runtime::CURRENT.is_set() {
67        //     task.run();
68        //     return;
69        // }
70        // // if we are on a standalone thread, we will use thread local ctx as Context.
71        // crate::runtime::DEFAULT_CTX.with(|ctx| {
72        //     crate::runtime::CURRENT.set(ctx, || task.run());
73        // });
74    }
75}
76
/// What `spawn_blocking` does when no `ThreadPool` is attached to the runtime.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BlockingStrategy {
    /// Panic inside `spawn_blocking`.
    Panic,
    /// Run the closure immediately on the current thread. Other tasks on that
    /// thread cannot make progress until it returns.
    ExecuteLocal,
}
86
87/// `spawn_blocking` is used for executing a task(without async) with heavy computation or blocking
88/// io.
89///
90/// To used it, users may initialize a thread pool and attach it on creating runtime.
91/// Users can also set `BlockingStrategy` for a runtime when there is no thread pool.
92/// WARNING: DO NOT USE THIS FOR ASYNC TASK! Async tasks will not be executed but only built the
93/// future!
94pub fn spawn_blocking<F, R>(func: F) -> JoinHandle<Result<R, JoinError>>
95where
96    F: FnOnce() -> R + Send + 'static,
97    R: Send + 'static,
98{
99    let fut = BlockingFuture(Some(func));
100    let (task, join) = new_task(DEFAULT_THREAD_ID, fut, NoopScheduler);
101    crate::runtime::CURRENT.with(|inner| {
102        let handle = &inner.blocking_handle;
103        match handle {
104            BlockingHandle::Attached(shared) => shared.schedule_task(BlockingTask {
105                task: Some(task),
106                blocking_vtable: blocking_vtable::<R>(),
107            }),
108            BlockingHandle::Empty(BlockingStrategy::ExecuteLocal) => task.run(),
109            BlockingHandle::Empty(BlockingStrategy::Panic) => {
110                // For users: if you see this panic, you have 2 choices:
111                // 1. attach a shared thread pool to execute blocking tasks
112                // 2. set runtime blocking strategy to `BlockingStrategy::ExecuteLocal`
113                // Note: solution 2 will execute blocking task on current thread and may block other
114                // tasks This may cause other tasks high latency.
115                panic!("execute blocking task without thread pool attached")
116            }
117        }
118    });
119
120    join
121}
122
123/// DefaultThreadPool is a simple wrapped `threadpool::ThreadPool` that implement
124/// `monoio::blocking::ThreadPool`. You may use this implementation, or you can use your own thread
125/// pool implementation.
126#[derive(Clone)]
127pub struct DefaultThreadPool {
128    pool: ThreadPoolImpl,
129}
130
131impl DefaultThreadPool {
132    /// Create a new DefaultThreadPool.
133    pub fn new(num_threads: usize) -> Self {
134        let pool = ThreadPoolBuilder::default()
135            .num_threads(num_threads)
136            .build();
137        Self { pool }
138    }
139}
140
141impl ThreadPool for DefaultThreadPool {
142    #[inline]
143    fn schedule_task(&self, task: BlockingTask) {
144        self.pool.execute(move || task.run());
145    }
146}
147
148pub(crate) struct NoopScheduler;
149
150impl crate::task::Schedule for NoopScheduler {
151    fn schedule(&self, _task: crate::task::Task<Self>) {
152        unreachable!()
153    }
154
155    fn yield_now(&self, _task: crate::task::Task<Self>) {
156        unreachable!()
157    }
158}
159
160pub(crate) enum BlockingHandle {
161    Attached(Box<dyn crate::blocking::ThreadPool + Send + 'static>),
162    Empty(BlockingStrategy),
163}
164
165impl From<BlockingStrategy> for BlockingHandle {
166    fn from(value: BlockingStrategy) -> Self {
167        Self::Empty(value)
168    }
169}
170
171struct BlockingFuture<F>(Option<F>);
172
173impl<T> Unpin for BlockingFuture<T> {}
174
175impl<F, R> Future for BlockingFuture<F>
176where
177    F: FnOnce() -> R + Send + 'static,
178    R: Send + 'static,
179{
180    type Output = Result<R, JoinError>;
181
182    fn poll(
183        mut self: std::pin::Pin<&mut Self>,
184        _cx: &mut std::task::Context<'_>,
185    ) -> std::task::Poll<Self::Output> {
186        let me = &mut *self;
187        let func = me.0.take().expect("blocking task ran twice.");
188        Poll::Ready(Ok(func()))
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::DefaultThreadPool;

    /// Test pool that spawns a fresh OS thread for every task.
    struct NaiveThreadPool;

    impl super::ThreadPool for NaiveThreadPool {
        fn schedule_task(&self, task: super::BlockingTask) {
            std::thread::spawn(move || {
                task.run();
            });
        }
    }

    /// Test pool that drops every task without running it, to exercise cancellation.
    struct FakeThreadPool;

    impl super::ThreadPool for FakeThreadPool {
        fn schedule_task(&self, _task: super::BlockingTask) {}
    }

    /// A blocking task on a pool thread overlaps with an async sleep on the runtime.
    #[test]
    fn hello_blocking() {
        let shared_pool = Box::new(NaiveThreadPool);
        let mut rt = crate::RuntimeBuilder::<crate::FusionDriver>::new()
            .attach_thread_pool(shared_pool)
            .enable_timer()
            .build()
            .unwrap();
        rt.block_on(async {
            let begin = std::time::Instant::now();
            let join = crate::spawn_blocking(|| {
                // Simulate a heavy computation.
                std::thread::sleep(std::time::Duration::from_millis(400));
                "hello spawn_blocking!".to_string()
            });
            let sleep_async = crate::time::sleep(std::time::Duration::from_millis(400));
            let (result, _) = crate::join!(join, sleep_async);
            let eps = begin.elapsed();
            assert!(eps < std::time::Duration::from_millis(800));
            assert!(eps >= std::time::Duration::from_millis(400));
            assert_eq!(result.unwrap(), "hello spawn_blocking!");
        });
    }

    /// Without a pool, the `Panic` strategy panics with its documented message.
    #[test]
    #[should_panic(expected = "execute blocking task without thread pool attached")]
    fn blocking_panic() {
        let mut rt = crate::RuntimeBuilder::<crate::FusionDriver>::new()
            .with_blocking_strategy(crate::blocking::BlockingStrategy::Panic)
            .enable_timer()
            .build()
            .unwrap();
        rt.block_on(async {
            let join = crate::spawn_blocking(|| 1);
            let _ = join.await;
        });
    }

    /// With `ExecuteLocal`, the closure runs synchronously inside `spawn_blocking`,
    /// so the async sleep created afterwards cannot overlap with it.
    #[test]
    fn blocking_current() {
        let mut rt = crate::RuntimeBuilder::<crate::FusionDriver>::new()
            .with_blocking_strategy(crate::blocking::BlockingStrategy::ExecuteLocal)
            .enable_timer()
            .build()
            .unwrap();
        rt.block_on(async {
            let begin = std::time::Instant::now();
            let join = crate::spawn_blocking(|| {
                // Simulate a heavy computation.
                std::thread::sleep(std::time::Duration::from_millis(100));
                "hello spawn_blocking!".to_string()
            });
            let sleep_async = crate::time::sleep(std::time::Duration::from_millis(100));
            let (result, _) = crate::join!(join, sleep_async);
            let eps = begin.elapsed();
            // 100ms blocking + 100ms sleep started afterwards: the bound is inclusive.
            assert!(eps >= std::time::Duration::from_millis(200));
            assert_eq!(result.unwrap(), "hello spawn_blocking!");
        });
    }

    /// A task dropped by the pool resolves its join handle to `Canceled`.
    #[test]
    fn drop_task() {
        let shared_pool = Box::new(FakeThreadPool);
        let mut rt = crate::RuntimeBuilder::<crate::FusionDriver>::new()
            .attach_thread_pool(shared_pool)
            .enable_timer()
            .build()
            .unwrap();
        rt.block_on(async {
            let ret = crate::spawn_blocking(|| 1).await;
            assert!(matches!(ret, Err(super::JoinError::Canceled)));
        });
    }

    /// Six blocking tasks on a six-thread `DefaultThreadPool` run concurrently.
    #[test]
    fn default_pool() {
        let shared_pool = Box::new(DefaultThreadPool::new(6));
        let mut rt = crate::RuntimeBuilder::<crate::FusionDriver>::new()
            .attach_thread_pool(shared_pool)
            .enable_timer()
            .build()
            .unwrap();
        macro_rules! thread_sleep {
            ($s:expr) => {
                || {
                    // Simulate a heavy computation.
                    std::thread::sleep(std::time::Duration::from_millis(500));
                    $s
                }
            };
        }
        rt.block_on(async {
            let begin = std::time::Instant::now();
            let join1 = crate::spawn_blocking(thread_sleep!("hello spawn_blocking1!"));
            let join2 = crate::spawn_blocking(thread_sleep!("hello spawn_blocking2!"));
            let join3 = crate::spawn_blocking(thread_sleep!("hello spawn_blocking3!"));
            let join4 = crate::spawn_blocking(thread_sleep!("hello spawn_blocking4!"));
            let join5 = crate::spawn_blocking(thread_sleep!("hello spawn_blocking5!"));
            let join6 = crate::spawn_blocking(thread_sleep!("hello spawn_blocking6!"));
            let sleep_async = crate::time::sleep(std::time::Duration::from_millis(500));
            let (result1, result2, result3, result4, result5, result6, _) =
                crate::join!(join1, join2, join3, join4, join5, join6, sleep_async);
            let eps = begin.elapsed();
            assert!(eps < std::time::Duration::from_millis(3000));
            assert!(eps >= std::time::Duration::from_millis(500));
            assert_eq!(result1.unwrap(), "hello spawn_blocking1!");
            assert_eq!(result2.unwrap(), "hello spawn_blocking2!");
            assert_eq!(result3.unwrap(), "hello spawn_blocking3!");
            assert_eq!(result4.unwrap(), "hello spawn_blocking4!");
            assert_eq!(result5.unwrap(), "hello spawn_blocking5!");
            assert_eq!(result6.unwrap(), "hello spawn_blocking6!");
        });
    }
}