monoio/
runtime.rs

1use std::future::Future;
2
3#[cfg(any(all(target_os = "linux", feature = "iouring"), feature = "legacy"))]
4use crate::time::TimeDriver;
5#[cfg(all(target_os = "linux", feature = "iouring"))]
6use crate::IoUringDriver;
7#[cfg(feature = "legacy")]
8use crate::LegacyDriver;
9use crate::{
10    driver::Driver,
11    scheduler::{LocalScheduler, TaskQueue},
12    task::{
13        new_task,
14        waker_fn::{dummy_waker, set_poll, should_poll},
15        JoinHandle,
16    },
17    time::driver::Handle as TimeHandle,
18};
19
20#[cfg(feature = "sync")]
21thread_local! {
22    pub(crate) static DEFAULT_CTX: Context = Context {
23        thread_id: crate::utils::thread_id::DEFAULT_THREAD_ID,
24        unpark_cache: std::cell::RefCell::new(fxhash::FxHashMap::default()),
25        waker_sender_cache: std::cell::RefCell::new(fxhash::FxHashMap::default()),
26        tasks: Default::default(),
27        time_handle: None,
28        blocking_handle: crate::blocking::BlockingHandle::Empty(crate::blocking::BlockingStrategy::Panic),
29    };
30}
31
32scoped_thread_local!(pub(crate) static CURRENT: Context);
33
34pub(crate) struct Context {
35    /// Owned task set and local run queue
36    pub(crate) tasks: TaskQueue,
37
38    /// Thread id(not the kernel thread id but a generated unique number)
39    pub(crate) thread_id: usize,
40
41    /// Thread unpark handles
42    #[cfg(feature = "sync")]
43    pub(crate) unpark_cache:
44        std::cell::RefCell<fxhash::FxHashMap<usize, crate::driver::UnparkHandle>>,
45
46    /// Waker sender cache
47    #[cfg(feature = "sync")]
48    pub(crate) waker_sender_cache:
49        std::cell::RefCell<fxhash::FxHashMap<usize, flume::Sender<std::task::Waker>>>,
50
51    /// Time Handle
52    pub(crate) time_handle: Option<TimeHandle>,
53
54    /// Blocking Handle
55    #[cfg(feature = "sync")]
56    pub(crate) blocking_handle: crate::blocking::BlockingHandle,
57}
58
59impl Context {
60    #[cfg(feature = "sync")]
61    pub(crate) fn new(blocking_handle: crate::blocking::BlockingHandle) -> Self {
62        let thread_id = crate::builder::BUILD_THREAD_ID.with(|id| *id);
63
64        Self {
65            thread_id,
66            unpark_cache: std::cell::RefCell::new(fxhash::FxHashMap::default()),
67            waker_sender_cache: std::cell::RefCell::new(fxhash::FxHashMap::default()),
68            tasks: TaskQueue::default(),
69            time_handle: None,
70            blocking_handle,
71        }
72    }
73
74    #[cfg(not(feature = "sync"))]
75    pub(crate) fn new() -> Self {
76        let thread_id = crate::builder::BUILD_THREAD_ID.with(|id| *id);
77
78        Self {
79            thread_id,
80            tasks: TaskQueue::default(),
81            time_handle: None,
82        }
83    }
84
85    #[allow(unused)]
86    #[cfg(feature = "sync")]
87    pub(crate) fn unpark_thread(&self, id: usize) {
88        use crate::driver::{thread::get_unpark_handle, unpark::Unpark};
89        if let Some(handle) = self.unpark_cache.borrow().get(&id) {
90            handle.unpark();
91            return;
92        }
93
94        if let Some(v) = get_unpark_handle(id) {
95            // Write back to local cache
96            let w = v.clone();
97            self.unpark_cache.borrow_mut().insert(id, w);
98            v.unpark();
99        }
100    }
101
102    #[allow(unused)]
103    #[cfg(feature = "sync")]
104    pub(crate) fn send_waker(&self, id: usize, w: std::task::Waker) {
105        use crate::driver::thread::get_waker_sender;
106        if let Some(sender) = self.waker_sender_cache.borrow().get(&id) {
107            let _ = sender.send(w);
108            return;
109        }
110
111        if let Some(s) = get_waker_sender(id) {
112            // Write back to local cache
113            let _ = s.send(w);
114            self.waker_sender_cache.borrow_mut().insert(id, s);
115        }
116    }
117}
118
119/// Monoio runtime
120pub struct Runtime<D> {
121    pub(crate) context: Context,
122    pub(crate) driver: D,
123}
124
125impl<D> Runtime<D> {
126    pub(crate) fn new(context: Context, driver: D) -> Self {
127        Self { context, driver }
128    }
129
130    /// Block on
131    pub fn block_on<F>(&mut self, future: F) -> F::Output
132    where
133        F: Future,
134        D: Driver,
135    {
136        assert!(
137            !CURRENT.is_set(),
138            "Can not start a runtime inside a runtime"
139        );
140
141        let waker = dummy_waker();
142        let cx = &mut std::task::Context::from_waker(&waker);
143
144        self.driver.with(|| {
145            CURRENT.set(&self.context, || {
146                #[cfg(feature = "sync")]
147                let join = unsafe { spawn_without_static(future) };
148                #[cfg(not(feature = "sync"))]
149                let join = future;
150
151                let mut join = std::pin::pin!(join);
152                set_poll();
153                loop {
154                    loop {
155                        // Consume all tasks(with max round to prevent io starvation)
156                        let mut max_round = self.context.tasks.len() * 2;
157                        while let Some(t) = self.context.tasks.pop() {
158                            t.run();
159                            if max_round == 0 {
160                                // maybe there's a looping task
161                                break;
162                            } else {
163                                max_round -= 1;
164                            }
165                        }
166
167                        // Check main future
168                        while should_poll() {
169                            // check if ready
170                            if let std::task::Poll::Ready(t) = join.as_mut().poll(cx) {
171                                return t;
172                            }
173                        }
174
175                        if self.context.tasks.is_empty() {
176                            // No task to execute, we should wait for io blockingly
177                            // Hot path
178                            break;
179                        }
180
181                        // Cold path
182                        let _ = self.driver.submit();
183                    }
184
185                    // Wait and Process CQ(the error is ignored for not debug mode)
186                    #[cfg(not(all(debug_assertions, feature = "debug")))]
187                    let _ = self.driver.park();
188
189                    #[cfg(all(debug_assertions, feature = "debug"))]
190                    if let Err(e) = self.driver.park() {
191                        trace!("park error: {:?}", e);
192                    }
193                }
194            })
195        })
196    }
197}
198
/// Runtime wrapper holding either an io_uring driver based runtime or a
/// legacy driver based runtime.
///
/// When the target is not Linux or the `iouring` feature is disabled, only
/// the `Legacy` variant (and only the `R` type parameter) exists.
#[cfg(feature = "legacy")]
pub enum FusionRuntime<#[cfg(all(target_os = "linux", feature = "iouring"))] L, R> {
    /// Runtime backed by the io_uring driver.
    #[cfg(all(target_os = "linux", feature = "iouring"))]
    Uring(Runtime<L>),
    /// Runtime backed by the legacy driver.
    Legacy(Runtime<R>),
}
209
/// Runtime wrapper for Linux builds with `iouring` but without `legacy`.
///
/// In this configuration it can only hold an io_uring driver based runtime.
#[cfg(all(target_os = "linux", feature = "iouring", not(feature = "legacy")))]
pub enum FusionRuntime<L> {
    /// Runtime backed by the io_uring driver.
    Uring(Runtime<L>),
}
217
#[cfg(all(target_os = "linux", feature = "iouring", feature = "legacy"))]
impl<L, R> FusionRuntime<L, R>
where
    L: Driver,
    R: Driver,
{
    /// Runs `future` to completion on whichever runtime is wrapped, logging
    /// which driver is in use first.
    pub fn block_on<F>(&mut self, future: F) -> F::Output
    where
        F: Future,
    {
        match self {
            Self::Uring(rt) => {
                info!("Monoio is running with io_uring driver");
                rt.block_on(future)
            }
            Self::Legacy(rt) => {
                info!("Monoio is running with legacy driver");
                rt.block_on(future)
            }
        }
    }
}
241
#[cfg(all(feature = "legacy", not(all(target_os = "linux", feature = "iouring"))))]
impl<R> FusionRuntime<R>
where
    R: Driver,
{
    /// Runs `future` to completion on the wrapped legacy driver runtime.
    pub fn block_on<F>(&mut self, future: F) -> F::Output
    where
        F: Future,
    {
        // `Legacy` is the only variant in this configuration, so the pattern
        // is irrefutable.
        let FusionRuntime::Legacy(rt) = self;
        rt.block_on(future)
    }
}
257
#[cfg(all(not(feature = "legacy"), all(target_os = "linux", feature = "iouring")))]
impl<L> FusionRuntime<L>
where
    L: Driver,
{
    /// Runs `future` to completion on the wrapped io_uring driver runtime.
    pub fn block_on<F>(&mut self, future: F) -> F::Output
    where
        F: Future,
    {
        // `Uring` is the only variant in this configuration, so the pattern
        // is irrefutable.
        let FusionRuntime::Uring(rt) = self;
        rt.block_on(future)
    }
}
273
// L -> Fusion<L, R>
#[cfg(all(target_os = "linux", feature = "iouring", feature = "legacy"))]
impl From<Runtime<IoUringDriver>> for FusionRuntime<IoUringDriver, LegacyDriver> {
    /// Wraps an io_uring runtime in the `Uring` variant.
    fn from(runtime: Runtime<IoUringDriver>) -> Self {
        FusionRuntime::Uring(runtime)
    }
}

// TL -> Fusion<TL, TR>
#[cfg(all(target_os = "linux", feature = "iouring", feature = "legacy"))]
impl From<Runtime<TimeDriver<IoUringDriver>>>
    for FusionRuntime<TimeDriver<IoUringDriver>, TimeDriver<LegacyDriver>>
{
    /// Wraps a timer-enabled io_uring runtime in the `Uring` variant.
    fn from(runtime: Runtime<TimeDriver<IoUringDriver>>) -> Self {
        FusionRuntime::Uring(runtime)
    }
}

// R -> Fusion<L, R>
#[cfg(all(target_os = "linux", feature = "iouring", feature = "legacy"))]
impl From<Runtime<LegacyDriver>> for FusionRuntime<IoUringDriver, LegacyDriver> {
    /// Wraps a legacy runtime in the `Legacy` variant.
    fn from(runtime: Runtime<LegacyDriver>) -> Self {
        FusionRuntime::Legacy(runtime)
    }
}

// TR -> Fusion<TL, TR>
#[cfg(all(target_os = "linux", feature = "iouring", feature = "legacy"))]
impl From<Runtime<TimeDriver<LegacyDriver>>>
    for FusionRuntime<TimeDriver<IoUringDriver>, TimeDriver<LegacyDriver>>
{
    /// Wraps a timer-enabled legacy runtime in the `Legacy` variant.
    fn from(runtime: Runtime<TimeDriver<LegacyDriver>>) -> Self {
        FusionRuntime::Legacy(runtime)
    }
}
309
// R -> Fusion<R>
#[cfg(all(feature = "legacy", not(all(target_os = "linux", feature = "iouring"))))]
impl From<Runtime<LegacyDriver>> for FusionRuntime<LegacyDriver> {
    /// Wraps a legacy runtime in the `Legacy` variant.
    fn from(runtime: Runtime<LegacyDriver>) -> Self {
        FusionRuntime::Legacy(runtime)
    }
}

// TR -> Fusion<TR>
#[cfg(all(feature = "legacy", not(all(target_os = "linux", feature = "iouring"))))]
impl From<Runtime<TimeDriver<LegacyDriver>>> for FusionRuntime<TimeDriver<LegacyDriver>> {
    /// Wraps a timer-enabled legacy runtime in the `Legacy` variant.
    fn from(runtime: Runtime<TimeDriver<LegacyDriver>>) -> Self {
        FusionRuntime::Legacy(runtime)
    }
}
325
// L -> Fusion<L>
#[cfg(all(target_os = "linux", feature = "iouring", not(feature = "legacy")))]
impl From<Runtime<IoUringDriver>> for FusionRuntime<IoUringDriver> {
    /// Wraps an io_uring runtime in the `Uring` variant.
    fn from(runtime: Runtime<IoUringDriver>) -> Self {
        FusionRuntime::Uring(runtime)
    }
}

// TL -> Fusion<TL>
#[cfg(all(target_os = "linux", feature = "iouring", not(feature = "legacy")))]
impl From<Runtime<TimeDriver<IoUringDriver>>> for FusionRuntime<TimeDriver<IoUringDriver>> {
    /// Wraps a timer-enabled io_uring runtime in the `Uring` variant.
    fn from(runtime: Runtime<TimeDriver<IoUringDriver>>) -> Self {
        FusionRuntime::Uring(runtime)
    }
}
341
342/// Spawns a new asynchronous task, returning a [`JoinHandle`] for it.
343///
344/// Spawning a task enables the task to execute concurrently to other tasks.
345/// There is no guarantee that a spawned task will execute to completion. When a
346/// runtime is shutdown, all outstanding tasks are dropped, regardless of the
347/// lifecycle of that task.
348///
349///
350/// [`JoinHandle`]: super::task::JoinHandle
351///
352/// # Examples
353///
354/// In this example, a server is started and `spawn` is used to start a new task
355/// that processes each received connection.
356///
357/// ```no_run
358/// #[monoio::main]
359/// async fn main() {
360///     let handle = monoio::spawn(async {
361///         println!("hello from a background task");
362///     });
363///
364///     // Let the task complete
365///     handle.await;
366/// }
367/// ```
368pub fn spawn<T>(future: T) -> JoinHandle<T::Output>
369where
370    T: Future + 'static,
371    T::Output: 'static,
372{
373    let (task, join) = new_task(
374        crate::utils::thread_id::get_current_thread_id(),
375        future,
376        LocalScheduler,
377    );
378
379    CURRENT.with(|ctx| {
380        ctx.tasks.push(task);
381    });
382    join
383}
384
/// Spawns `future` onto the current runtime without requiring it to be
/// `'static`, returning a [`JoinHandle`] for its output.
///
/// This is the non-`'static` counterpart of [`spawn`]: the task is built with
/// `new_task_holding` instead of `new_task`, then pushed onto the current
/// context's local run queue.
///
/// # Safety
///
/// Because the `'static` bound is lifted, the compiler no longer checks that
/// the task outlives nothing it borrows. The caller must guarantee that the
/// spawned task (and therefore `future`) is completed or dropped before any
/// data borrowed by `future` goes out of scope.
#[cfg(feature = "sync")]
unsafe fn spawn_without_static<T>(future: T) -> JoinHandle<T::Output>
where
    T: Future,
{
    use crate::task::new_task_holding;
    let (task, join) = new_task_holding(
        crate::utils::thread_id::get_current_thread_id(),
        future,
        LocalScheduler,
    );

    CURRENT.with(|ctx| {
        ctx.tasks.push(task);
    });
    join
}
402
#[cfg(test)]
mod tests {
    /// Two runtimes on different threads hand a value back and forth through
    /// oneshot channels, exercising cross-thread wakeups.
    #[cfg(all(feature = "sync", target_os = "linux", feature = "iouring"))]
    #[test]
    fn across_thread() {
        use futures::channel::oneshot;

        use crate::driver::IoUringDriver;

        let (tx1, rx1) = oneshot::channel::<u8>();
        let (tx2, rx2) = oneshot::channel::<u8>();

        let peer = std::thread::spawn(move || {
            let mut rt = crate::RuntimeBuilder::<IoUringDriver>::new()
                .build()
                .unwrap();
            rt.block_on(async move {
                let n = rx1.await.expect("unable to receive rx1");
                assert!(tx2.send(n).is_ok());
            });
        });

        let mut rt = crate::RuntimeBuilder::<IoUringDriver>::new()
            .build()
            .unwrap();
        rt.block_on(async move {
            assert!(tx1.send(24).is_ok());
            assert_eq!(rx2.await.expect("unable to receive rx2"), 24);
        });

        // Surface a panic in the peer runtime thread as a failure of this test.
        peer.join().expect("peer runtime thread panicked");
    }

    /// A 200ms sleep on a timer-enabled runtime takes roughly 200ms of
    /// wall-clock time.
    #[cfg(all(target_os = "linux", feature = "iouring"))]
    #[test]
    fn timer() {
        use crate::driver::IoUringDriver;
        let mut rt = crate::RuntimeBuilder::<IoUringDriver>::new()
            .enable_timer()
            .build()
            .unwrap();
        let instant = std::time::Instant::now();
        rt.block_on(async {
            crate::time::sleep(std::time::Duration::from_millis(200)).await;
        });
        // Use the total elapsed time. `subsec_millis` only yields the
        // sub-second remainder and would accept e.g. 1.2s as 200ms.
        let elapsed_ms = instant.elapsed().as_millis();
        assert!(
            elapsed_ms.abs_diff(200) < 50,
            "slept for {elapsed_ms}ms, expected about 200ms"
        );
    }
}