1use std::future::Future;
2#[cfg(feature = "sync")]
3use std::sync::Arc;
4
5#[cfg(feature = "sync")]
6use crossbeam_queue::SegQueue;
7
8#[cfg(any(all(target_os = "linux", feature = "iouring"), feature = "legacy"))]
9use crate::time::TimeDriver;
10#[cfg(all(target_os = "linux", feature = "iouring"))]
11use crate::IoUringDriver;
12#[cfg(feature = "legacy")]
13use crate::LegacyDriver;
14use crate::{
15 driver::Driver,
16 scheduler::{LocalScheduler, TaskQueue},
17 task::{
18 new_task,
19 waker_fn::{dummy_waker, set_poll, should_poll},
20 JoinHandle,
21 },
22 time::driver::Handle as TimeHandle,
23};
24
#[cfg(feature = "sync")]
thread_local! {
    /// Thread-local fallback `Context` built from constants: the default
    /// thread id, an empty task queue, empty cross-thread caches, no timer
    /// handle and a blocking handle that uses `BlockingStrategy::Panic`.
    ///
    /// NOTE(review): presumably consulted where no runtime context is
    /// installed in `CURRENT` — confirm at the use sites.
    pub(crate) static DEFAULT_CTX: Context = Context {
        tasks: Default::default(),
        thread_id: crate::utils::thread_id::DEFAULT_THREAD_ID,
        unpark_cache: Default::default(),
        waker_queue_cache: Default::default(),
        time_handle: None,
        blocking_handle: crate::blocking::BlockingHandle::Empty(
            crate::blocking::BlockingStrategy::Panic,
        ),
    };
}
36
// Scoped thread-local holding the `Context` of the runtime executing on this
// thread. `Runtime::block_on` installs it with `CURRENT.set(..)` for the
// duration of the call; `spawn` reads it with `CURRENT.with(..)` to queue tasks.
scoped_thread_local!(pub(crate) static CURRENT: Context);
38
/// Runtime state installed in `CURRENT` while `Runtime::block_on` runs.
pub(crate) struct Context {
    /// Spawned tasks waiting to run; `spawn` pushes here and
    /// `Runtime::block_on` pops and runs them.
    pub(crate) tasks: TaskQueue,

    /// Thread id recorded for this context (read from `BUILD_THREAD_ID` in
    /// `Context::new`, or `DEFAULT_THREAD_ID` for the default context).
    pub(crate) thread_id: usize,

    /// Per-context cache of unpark handles keyed by thread id, filled lazily
    /// by `unpark_thread` to avoid repeating the `get_unpark_handle` lookup.
    #[cfg(feature = "sync")]
    pub(crate) unpark_cache:
        std::cell::RefCell<rustc_hash::FxHashMap<usize, crate::driver::UnparkHandle>>,

    /// Per-context cache of other threads' waker queues keyed by thread id,
    /// filled lazily by `send_waker` to avoid repeating `get_waker_queue`.
    #[cfg(feature = "sync")]
    pub(crate) waker_queue_cache:
        std::cell::RefCell<rustc_hash::FxHashMap<usize, Arc<SegQueue<std::task::Waker>>>>,

    /// Handle to the timer driver, if one is attached. Both constructors in
    /// this file start it as `None`; NOTE(review): presumably set elsewhere
    /// when the timer is enabled — confirm.
    pub(crate) time_handle: Option<TimeHandle>,

    /// Handle describing how blocking work is dispatched; see `crate::blocking`.
    #[cfg(feature = "sync")]
    pub(crate) blocking_handle: crate::blocking::BlockingHandle,
}
63
64impl Context {
65 #[cfg(feature = "sync")]
66 pub(crate) fn new(blocking_handle: crate::blocking::BlockingHandle) -> Self {
67 let thread_id = crate::builder::BUILD_THREAD_ID.with(|id| *id);
68
69 Self {
70 thread_id,
71 unpark_cache: std::cell::RefCell::new(rustc_hash::FxHashMap::default()),
72 waker_queue_cache: std::cell::RefCell::new(rustc_hash::FxHashMap::default()),
73 tasks: TaskQueue::default(),
74 time_handle: None,
75 blocking_handle,
76 }
77 }
78
79 #[cfg(not(feature = "sync"))]
80 pub(crate) fn new() -> Self {
81 let thread_id = crate::builder::BUILD_THREAD_ID.with(|id| *id);
82
83 Self {
84 thread_id,
85 tasks: TaskQueue::default(),
86 time_handle: None,
87 }
88 }
89
90 #[allow(unused)]
91 #[cfg(feature = "sync")]
92 pub(crate) fn unpark_thread(&self, id: usize) {
93 use crate::driver::{thread::get_unpark_handle, unpark::Unpark};
94 if let Some(handle) = self.unpark_cache.borrow().get(&id) {
95 handle.unpark();
96 return;
97 }
98
99 if let Some(v) = get_unpark_handle(id) {
100 let w = v.clone();
102 self.unpark_cache.borrow_mut().insert(id, w);
103 v.unpark();
104 }
105 }
106
107 #[allow(unused)]
108 #[cfg(feature = "sync")]
109 pub(crate) fn send_waker(&self, id: usize, w: std::task::Waker) {
110 use crate::driver::thread::get_waker_queue;
111 if let Some(sender) = self.waker_queue_cache.borrow().get(&id) {
112 let _ = sender.push(w);
113 return;
114 }
115
116 if let Some(s) = get_waker_queue(id) {
117 let _ = s.push(w);
119 self.waker_queue_cache.borrow_mut().insert(id, s);
120 }
121 }
122}
123
/// An executor pairing a task context with a driver `D`.
///
/// Futures are driven on the calling thread with `Runtime::block_on`.
pub struct Runtime<D> {
    // Task queue and per-thread state; installed in `CURRENT` during `block_on`.
    pub(crate) context: Context,
    // Driver that `block_on` enters, submits to and parks on.
    pub(crate) driver: D,
}
129
impl<D> Runtime<D> {
    /// Assembles a runtime from an already-built context and driver.
    pub(crate) fn new(context: Context, driver: D) -> Self {
        Self { context, driver }
    }

    /// Runs `future` to completion on the current thread and returns its output.
    ///
    /// The whole call runs inside `self.driver.with(..)` with this runtime's
    /// context installed in `CURRENT`, so `spawn` called from inside the
    /// future queues tasks on this runtime. With the `sync` feature the
    /// future is first spawned as a task and awaited through its `JoinHandle`;
    /// without it, the future is polled directly.
    ///
    /// # Panics
    ///
    /// Panics with "Can not start a runtime inside a runtime" if `CURRENT` is
    /// already set on this thread (for example, when nested in another
    /// `block_on`).
    pub fn block_on<F>(&mut self, future: F) -> F::Output
    where
        F: Future,
        D: Driver,
    {
        assert!(
            !CURRENT.is_set(),
            "Can not start a runtime inside a runtime"
        );

        // The main future is polled with a dummy waker; readiness is tracked
        // through `set_poll`/`should_poll` instead (all from `task::waker_fn`).
        let waker = dummy_waker();
        let cx = &mut std::task::Context::from_waker(&waker);

        self.driver.with(|| {
            CURRENT.set(&self.context, || {
                // SAFETY(review): `future` is not required to be `'static`.
                // Soundness presumably relies on this closure returning only
                // once `join` is Ready, so the task never outlives what
                // `future` borrows — confirm against `new_task_holding`.
                #[cfg(feature = "sync")]
                let join = unsafe { spawn_without_static(future) };
                #[cfg(not(feature = "sync"))]
                let join = future;

                let mut join = std::pin::pin!(join);
                // Mark the main future as pollable so it is polled at least once.
                set_poll();
                loop {
                    loop {
                        // Bound this pass to twice the current queue length so
                        // a task that keeps re-queuing itself cannot keep the
                        // loop from reaching the main-future check and the
                        // driver `submit` below.
                        let mut max_round = self.context.tasks.len() * 2;
                        while let Some(t) = self.context.tasks.pop() {
                            t.run();
                            if max_round == 0 {
                                break;
                            } else {
                                max_round -= 1;
                            }
                        }

                        // Poll the main future only while it is flagged as
                        // pollable; presumably the flag is raised by the dummy
                        // waker's wake — confirm in `task::waker_fn`.
                        while should_poll() {
                            if let std::task::Poll::Ready(t) = join.as_mut().poll(cx) {
                                return t;
                            }
                        }

                        // Nothing runnable: leave the inner loop and park.
                        if self.context.tasks.is_empty() {
                            break;
                        }

                        // Tasks remain: hand pending work to the driver, then
                        // keep running tasks instead of parking.
                        let _ = self.driver.submit();
                    }

                    // Wait in the driver for the next event; errors are
                    // ignored, or only traced in debug builds with `debug`.
                    #[cfg(not(all(debug_assertions, feature = "debug")))]
                    let _ = self.driver.park();

                    #[cfg(all(debug_assertions, feature = "debug"))]
                    if let Err(e) = self.driver.park() {
                        trace!("park error: {:?}", e);
                    }
                }
            })
        })
    }
}
203
/// A runtime over either of the available drivers, so callers can hold one
/// type whichever driver was built.
///
/// The `Uring` variant (and its `L` parameter) only exists on Linux with the
/// `iouring` feature; `Legacy` is always present when this definition is
/// compiled (the `legacy` feature).
#[cfg(feature = "legacy")]
pub enum FusionRuntime<#[cfg(all(target_os = "linux", feature = "iouring"))] L, R> {
    /// Runtime on the io_uring side (driver type `L`).
    #[cfg(all(target_os = "linux", feature = "iouring"))]
    Uring(Runtime<L>),
    /// Runtime on the legacy side (driver type `R`).
    Legacy(Runtime<R>),
}
214
/// Runtime wrapper for Linux builds with `iouring` but without `legacy`;
/// `Uring` is its only variant.
#[cfg(all(target_os = "linux", feature = "iouring", not(feature = "legacy")))]
pub enum FusionRuntime<L> {
    /// Runtime on the io_uring side (driver type `L`).
    Uring(Runtime<L>),
}
222
#[cfg(all(target_os = "linux", feature = "iouring", feature = "legacy"))]
impl<L, R> FusionRuntime<L, R>
where
    L: Driver,
    R: Driver,
{
    /// Runs `future` to completion on whichever runtime this value wraps,
    /// first emitting an `info!` that names the selected driver.
    pub fn block_on<F>(&mut self, future: F) -> F::Output
    where
        F: Future,
    {
        match self {
            Self::Legacy(rt) => {
                info!("Monoio is running with legacy driver");
                rt.block_on(future)
            }
            Self::Uring(rt) => {
                info!("Monoio is running with io_uring driver");
                rt.block_on(future)
            }
        }
    }
}
246
#[cfg(all(feature = "legacy", not(all(target_os = "linux", feature = "iouring"))))]
impl<R> FusionRuntime<R>
where
    R: Driver,
{
    /// Runs `future` to completion on the wrapped legacy-driver runtime.
    pub fn block_on<F>(&mut self, future: F) -> F::Output
    where
        F: Future,
    {
        // Under this cfg `Legacy` is the only variant, so the pattern is irrefutable.
        let Self::Legacy(rt) = self;
        rt.block_on(future)
    }
}
262
#[cfg(all(not(feature = "legacy"), all(target_os = "linux", feature = "iouring")))]
impl<L> FusionRuntime<L>
where
    L: Driver,
{
    /// Runs `future` to completion on the wrapped io_uring-side runtime.
    pub fn block_on<F>(&mut self, future: F) -> F::Output
    where
        F: Future,
    {
        // Under this cfg `Uring` is the only variant, so the pattern is irrefutable.
        let Self::Uring(rt) = self;
        rt.block_on(future)
    }
}
278
// Conversions into the two-driver `FusionRuntime`: the variant is chosen by
// the source runtime's driver type, with or without the `TimeDriver` wrapper.
#[cfg(all(target_os = "linux", feature = "iouring", feature = "legacy"))]
impl From<Runtime<IoUringDriver>> for FusionRuntime<IoUringDriver, LegacyDriver> {
    fn from(runtime: Runtime<IoUringDriver>) -> Self {
        FusionRuntime::Uring(runtime)
    }
}

#[cfg(all(target_os = "linux", feature = "iouring", feature = "legacy"))]
impl From<Runtime<LegacyDriver>> for FusionRuntime<IoUringDriver, LegacyDriver> {
    fn from(runtime: Runtime<LegacyDriver>) -> Self {
        FusionRuntime::Legacy(runtime)
    }
}

#[cfg(all(target_os = "linux", feature = "iouring", feature = "legacy"))]
impl From<Runtime<TimeDriver<IoUringDriver>>>
    for FusionRuntime<TimeDriver<IoUringDriver>, TimeDriver<LegacyDriver>>
{
    fn from(runtime: Runtime<TimeDriver<IoUringDriver>>) -> Self {
        FusionRuntime::Uring(runtime)
    }
}

#[cfg(all(target_os = "linux", feature = "iouring", feature = "legacy"))]
impl From<Runtime<TimeDriver<LegacyDriver>>>
    for FusionRuntime<TimeDriver<IoUringDriver>, TimeDriver<LegacyDriver>>
{
    fn from(runtime: Runtime<TimeDriver<LegacyDriver>>) -> Self {
        FusionRuntime::Legacy(runtime)
    }
}
314
// Conversions into the legacy-only `FusionRuntime`, with or without the
// `TimeDriver` wrapper.
#[cfg(all(feature = "legacy", not(all(target_os = "linux", feature = "iouring"))))]
impl From<Runtime<LegacyDriver>> for FusionRuntime<LegacyDriver> {
    fn from(runtime: Runtime<LegacyDriver>) -> Self {
        FusionRuntime::Legacy(runtime)
    }
}

#[cfg(all(feature = "legacy", not(all(target_os = "linux", feature = "iouring"))))]
impl From<Runtime<TimeDriver<LegacyDriver>>> for FusionRuntime<TimeDriver<LegacyDriver>> {
    fn from(runtime: Runtime<TimeDriver<LegacyDriver>>) -> Self {
        FusionRuntime::Legacy(runtime)
    }
}
330
// Conversions into the io_uring-only `FusionRuntime`, with or without the
// `TimeDriver` wrapper.
#[cfg(all(target_os = "linux", feature = "iouring", not(feature = "legacy")))]
impl From<Runtime<IoUringDriver>> for FusionRuntime<IoUringDriver> {
    fn from(runtime: Runtime<IoUringDriver>) -> Self {
        FusionRuntime::Uring(runtime)
    }
}

#[cfg(all(target_os = "linux", feature = "iouring", not(feature = "legacy")))]
impl From<Runtime<TimeDriver<IoUringDriver>>> for FusionRuntime<TimeDriver<IoUringDriver>> {
    fn from(runtime: Runtime<TimeDriver<IoUringDriver>>) -> Self {
        FusionRuntime::Uring(runtime)
    }
}
346
347pub fn spawn<T>(future: T) -> JoinHandle<T::Output>
374where
375 T: Future + 'static,
376 T::Output: 'static,
377{
378 let (task, join) = new_task(
379 crate::utils::thread_id::get_current_thread_id(),
380 future,
381 LocalScheduler,
382 );
383
384 CURRENT.with(|ctx| {
385 ctx.tasks.push(task);
386 });
387 join
388}
389
/// Like [`spawn`], but accepts a future that is not `'static`.
///
/// The future is wrapped with `crate::task::new_task_holding`, tagged with
/// the current thread id, and pushed onto the task queue of the context in
/// `CURRENT`.
///
/// # Safety
///
/// The caller must ensure the spawned task does not outlive any data that
/// `future` borrows. NOTE(review): the visible caller, `Runtime::block_on`,
/// appears to uphold this by driving the returned handle to completion
/// before returning — confirm the exact contract of `new_task_holding`.
#[cfg(feature = "sync")]
unsafe fn spawn_without_static<T>(future: T) -> JoinHandle<T::Output>
where
    T: Future,
{
    use crate::task::new_task_holding;

    let owner_id = crate::utils::thread_id::get_current_thread_id();
    let (task, handle) = new_task_holding(owner_id, future, LocalScheduler);

    CURRENT.with(|ctx| ctx.tasks.push(task));
    handle
}
407
#[cfg(test)]
mod tests {
    /// Two runtimes on different threads exchange a value through oneshot
    /// channels, exercising cross-thread wakeups.
    #[cfg(all(feature = "sync", target_os = "linux", feature = "iouring"))]
    #[test]
    fn across_thread() {
        use futures::channel::oneshot;

        use crate::driver::IoUringDriver;

        let (tx1, rx1) = oneshot::channel::<u8>();
        let (tx2, rx2) = oneshot::channel::<u8>();

        // Echo thread: waits on rx1 and sends the same value back on tx2.
        let echo = std::thread::spawn(move || {
            let mut rt = crate::RuntimeBuilder::<IoUringDriver>::new()
                .build()
                .unwrap();
            rt.block_on(async move {
                let n = rx1.await.expect("unable to receive rx1");
                assert!(tx2.send(n).is_ok());
            });
        });

        let mut rt = crate::RuntimeBuilder::<IoUringDriver>::new()
            .build()
            .unwrap();
        rt.block_on(async move {
            assert!(tx1.send(24).is_ok());
            assert_eq!(rx2.await.expect("unable to receive rx2"), 24);
        });

        // Join so a panic on the echo thread fails this test instead of being lost.
        echo.join().expect("echo thread panicked");
    }

    /// A 200ms sleep on a timer-enabled runtime should take roughly 200ms.
    #[cfg(all(target_os = "linux", feature = "iouring"))]
    #[test]
    fn timer() {
        use crate::driver::IoUringDriver;
        let mut rt = crate::RuntimeBuilder::<IoUringDriver>::new()
            .enable_timer()
            .build()
            .unwrap();
        let instant = std::time::Instant::now();
        rt.block_on(async {
            crate::time::sleep(std::time::Duration::from_millis(200)).await;
        });
        // Use the total elapsed time: `subsec_millis` drops whole seconds, so
        // e.g. a 1.2s sleep would have been reported as 200ms and passed.
        let elapsed_ms = instant.elapsed().as_millis();
        assert!(
            elapsed_ms.abs_diff(200) < 50,
            "expected ~200ms sleep, got {elapsed_ms}ms"
        );
    }
}