1use std::{future::Future, task::Poll};
4
5use threadpool::{Builder as ThreadPoolBuilder, ThreadPool as ThreadPoolImpl};
6
7use crate::{
8 task::{new_task, JoinHandle},
9 utils::thread_id::DEFAULT_THREAD_ID,
10};
11
12pub trait ThreadPool {
15 fn schedule_task(&self, task: BlockingTask);
18}
19
/// Error produced when waiting for a blocking task fails.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// propagated with `?` into `Box<dyn Error>`-style error types, and
/// `PartialEq`/`Eq` so callers can compare it directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinError {
    /// The task was dropped before it was run, so it never produced a value.
    Canceled,
}

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Canceled => f.write_str("blocking task was canceled"),
        }
    }
}

impl std::error::Error for JoinError {}
26
27pub struct BlockingTask {
30 task: Option<crate::task::Task<NoopScheduler>>,
31 blocking_vtable: &'static BlockingTaskVtable,
32}
33
34unsafe impl Send for BlockingTask {}
35
36struct BlockingTaskVtable {
37 pub(crate) drop: unsafe fn(&mut crate::task::Task<NoopScheduler>),
38}
39
40fn blocking_vtable<V>() -> &'static BlockingTaskVtable {
41 &BlockingTaskVtable {
42 drop: blocking_task_drop::<V>,
43 }
44}
45
46fn blocking_task_drop<V>(task: &mut crate::task::Task<NoopScheduler>) {
47 let mut opt: Option<Result<V, JoinError>> = Some(Err(JoinError::Canceled));
48 unsafe { task.finish((&mut opt) as *mut _ as *mut ()) };
49}
50
51impl Drop for BlockingTask {
52 fn drop(&mut self) {
53 if let Some(task) = self.task.as_mut() {
54 unsafe { (self.blocking_vtable.drop)(task) };
55 }
56 }
57}
58
59impl BlockingTask {
60 #[inline]
62 pub fn run(mut self) {
63 let task = self.task.take().unwrap();
64 task.run();
65 }
75}
76
/// What `spawn_blocking` does when no thread pool is attached.
#[derive(Debug, Clone, Copy)]
pub enum BlockingStrategy {
    /// `spawn_blocking` panics.
    Panic,
    /// `spawn_blocking` runs the task immediately on the calling thread.
    ExecuteLocal,
}
86
87pub fn spawn_blocking<F, R>(func: F) -> JoinHandle<Result<R, JoinError>>
95where
96 F: FnOnce() -> R + Send + 'static,
97 R: Send + 'static,
98{
99 let fut = BlockingFuture(Some(func));
100 let (task, join) = new_task(DEFAULT_THREAD_ID, fut, NoopScheduler);
101 crate::runtime::CURRENT.with(|inner| {
102 let handle = &inner.blocking_handle;
103 match handle {
104 BlockingHandle::Attached(shared) => shared.schedule_task(BlockingTask {
105 task: Some(task),
106 blocking_vtable: blocking_vtable::<R>(),
107 }),
108 BlockingHandle::Empty(BlockingStrategy::ExecuteLocal) => task.run(),
109 BlockingHandle::Empty(BlockingStrategy::Panic) => {
110 panic!("execute blocking task without thread pool attached")
116 }
117 }
118 });
119
120 join
121}
122
123#[derive(Clone)]
127pub struct DefaultThreadPool {
128 pool: ThreadPoolImpl,
129}
130
131impl DefaultThreadPool {
132 pub fn new(num_threads: usize) -> Self {
134 let pool = ThreadPoolBuilder::default()
135 .num_threads(num_threads)
136 .build();
137 Self { pool }
138 }
139}
140
141impl ThreadPool for DefaultThreadPool {
142 #[inline]
143 fn schedule_task(&self, task: BlockingTask) {
144 self.pool.execute(move || task.run());
145 }
146}
147
/// Scheduler used for blocking tasks; both of its `Schedule` methods are `unreachable!()`.
pub(crate) struct NoopScheduler;
149
150impl crate::task::Schedule for NoopScheduler {
151 fn schedule(&self, _task: crate::task::Task<Self>) {
152 unreachable!()
153 }
154
155 fn yield_now(&self, _task: crate::task::Task<Self>) {
156 unreachable!()
157 }
158}
159
160pub(crate) enum BlockingHandle {
161 Attached(Box<dyn crate::blocking::ThreadPool + Send + 'static>),
162 Empty(BlockingStrategy),
163}
164
165impl From<BlockingStrategy> for BlockingHandle {
166 fn from(value: BlockingStrategy) -> Self {
167 Self::Empty(value)
168 }
169}
170
/// Future wrapper around a closure; the slot is emptied when the closure is taken to run.
struct BlockingFuture<F>(Option<F>);
172
173impl<T> Unpin for BlockingFuture<T> {}
174
175impl<F, R> Future for BlockingFuture<F>
176where
177 F: FnOnce() -> R + Send + 'static,
178 R: Send + 'static,
179{
180 type Output = Result<R, JoinError>;
181
182 fn poll(
183 mut self: std::pin::Pin<&mut Self>,
184 _cx: &mut std::task::Context<'_>,
185 ) -> std::task::Poll<Self::Output> {
186 let me = &mut *self;
187 let func = me.0.take().expect("blocking task ran twice.");
188 Poll::Ready(Ok(func()))
189 }
190}
191
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use super::DefaultThreadPool;

    /// Pool that spawns a new OS thread for every scheduled task.
    struct NaiveThreadPool;

    impl super::ThreadPool for NaiveThreadPool {
        fn schedule_task(&self, task: super::BlockingTask) {
            std::thread::spawn(move || {
                task.run();
            });
        }
    }

    /// Pool that drops every task without running it.
    struct FakeThreadPool;

    impl super::ThreadPool for FakeThreadPool {
        fn schedule_task(&self, _task: super::BlockingTask) {}
    }

    /// A blocking sleep on a pool thread overlaps with an async sleep.
    #[test]
    fn hello_blocking() {
        let pool = Box::new(NaiveThreadPool);
        let mut rt = crate::RuntimeBuilder::<crate::FusionDriver>::new()
            .attach_thread_pool(pool)
            .enable_timer()
            .build()
            .unwrap();
        rt.block_on(async {
            let started = Instant::now();
            let blocking = crate::spawn_blocking(|| {
                std::thread::sleep(Duration::from_millis(400));
                "hello spawn_blocking!".to_string()
            });
            let timer = crate::time::sleep(Duration::from_millis(400));
            let (output, _) = crate::join!(blocking, timer);
            let elapsed = started.elapsed();
            // Both waits run concurrently, so the total stays below their sum.
            assert!(elapsed < Duration::from_millis(800));
            assert!(elapsed >= Duration::from_millis(400));
            assert_eq!(output.unwrap(), "hello spawn_blocking!");
        });
    }

    /// With the `Panic` strategy and no pool, `spawn_blocking` panics.
    #[test]
    #[should_panic]
    fn blocking_panic() {
        let mut rt = crate::RuntimeBuilder::<crate::FusionDriver>::new()
            .with_blocking_strategy(crate::blocking::BlockingStrategy::Panic)
            .enable_timer()
            .build()
            .unwrap();
        rt.block_on(async {
            let blocking = crate::spawn_blocking(|| 1);
            let _ = blocking.await;
        });
    }

    /// With `ExecuteLocal`, the blocking sleep runs before the async sleep starts.
    #[test]
    fn blocking_current() {
        let mut rt = crate::RuntimeBuilder::<crate::FusionDriver>::new()
            .with_blocking_strategy(crate::blocking::BlockingStrategy::ExecuteLocal)
            .enable_timer()
            .build()
            .unwrap();
        rt.block_on(async {
            let started = Instant::now();
            let blocking = crate::spawn_blocking(|| {
                std::thread::sleep(Duration::from_millis(100));
                "hello spawn_blocking!".to_string()
            });
            let timer = crate::time::sleep(Duration::from_millis(100));
            let (output, _) = crate::join!(blocking, timer);
            let elapsed = started.elapsed();
            // The two 100 ms waits happen one after the other on this thread.
            assert!(elapsed > Duration::from_millis(200));
            assert_eq!(output.unwrap(), "hello spawn_blocking!");
        });
    }

    /// A task dropped by the pool resolves to `JoinError::Canceled`.
    #[test]
    fn drop_task() {
        let pool = Box::new(FakeThreadPool);
        let mut rt = crate::RuntimeBuilder::<crate::FusionDriver>::new()
            .attach_thread_pool(pool)
            .enable_timer()
            .build()
            .unwrap();
        rt.block_on(async {
            let outcome = crate::spawn_blocking(|| 1).await;
            assert!(matches!(outcome, Err(super::JoinError::Canceled)));
        });
    }

    /// Six blocking sleeps on a six-thread `DefaultThreadPool` all complete.
    #[test]
    fn default_pool() {
        let pool = Box::new(DefaultThreadPool::new(6));
        let mut rt = crate::RuntimeBuilder::<crate::FusionDriver>::new()
            .attach_thread_pool(pool)
            .enable_timer()
            .build()
            .unwrap();
        // Expands to a closure that sleeps for 500 ms and then yields `$reply`.
        macro_rules! nap_then {
            ($reply:expr) => {
                || {
                    std::thread::sleep(Duration::from_millis(500));
                    $reply
                }
            };
        }
        rt.block_on(async {
            let started = Instant::now();
            let job1 = crate::spawn_blocking(nap_then!("hello spawn_blocking1!"));
            let job2 = crate::spawn_blocking(nap_then!("hello spawn_blocking2!"));
            let job3 = crate::spawn_blocking(nap_then!("hello spawn_blocking3!"));
            let job4 = crate::spawn_blocking(nap_then!("hello spawn_blocking4!"));
            let job5 = crate::spawn_blocking(nap_then!("hello spawn_blocking5!"));
            let job6 = crate::spawn_blocking(nap_then!("hello spawn_blocking6!"));
            let timer = crate::time::sleep(Duration::from_millis(500));
            let (out1, out2, out3, out4, out5, out6, _) =
                crate::join!(job1, job2, job3, job4, job5, job6, timer);
            let elapsed = started.elapsed();
            // 3000 ms is what six 500 ms sleeps would take back to back.
            assert!(elapsed < Duration::from_millis(3000));
            assert!(elapsed >= Duration::from_millis(500));
            assert_eq!(out1.unwrap(), "hello spawn_blocking1!");
            assert_eq!(out2.unwrap(), "hello spawn_blocking2!");
            assert_eq!(out3.unwrap(), "hello spawn_blocking3!");
            assert_eq!(out4.unwrap(), "hello spawn_blocking4!");
            assert_eq!(out5.unwrap(), "hello spawn_blocking5!");
            assert_eq!(out6.unwrap(), "hello spawn_blocking6!");
        });
    }
}