1use std::future::Future;
2
3#[cfg(any(all(target_os = "linux", feature = "iouring"), feature = "legacy"))]
4use crate::time::TimeDriver;
5#[cfg(all(target_os = "linux", feature = "iouring"))]
6use crate::IoUringDriver;
7#[cfg(feature = "legacy")]
8use crate::LegacyDriver;
9use crate::{
10 driver::Driver,
11 scheduler::{LocalScheduler, TaskQueue},
12 task::{
13 new_task,
14 waker_fn::{dummy_waker, set_poll, should_poll},
15 JoinHandle,
16 },
17 time::driver::Handle as TimeHandle,
18};
19
#[cfg(feature = "sync")]
thread_local! {
    /// Per-thread fallback [`Context`], only compiled with the `sync` feature.
    ///
    /// It carries `DEFAULT_THREAD_ID`, an empty task queue, empty caches, no
    /// time handle and an `Empty` blocking handle using the `Panic` strategy.
    // NOTE(review): presumably used by code paths that run while `CURRENT` is
    // not set (i.e. outside of `block_on`) — confirm against the callers.
    pub(crate) static DEFAULT_CTX: Context = Context {
        tasks: TaskQueue::default(),
        thread_id: crate::utils::thread_id::DEFAULT_THREAD_ID,
        unpark_cache: std::cell::RefCell::new(fxhash::FxHashMap::default()),
        waker_sender_cache: std::cell::RefCell::new(fxhash::FxHashMap::default()),
        time_handle: None,
        blocking_handle: crate::blocking::BlockingHandle::Empty(
            crate::blocking::BlockingStrategy::Panic,
        ),
    };
}
31
32scoped_thread_local!(pub(crate) static CURRENT: Context);
33
34pub(crate) struct Context {
35 pub(crate) tasks: TaskQueue,
37
38 pub(crate) thread_id: usize,
40
41 #[cfg(feature = "sync")]
43 pub(crate) unpark_cache:
44 std::cell::RefCell<fxhash::FxHashMap<usize, crate::driver::UnparkHandle>>,
45
46 #[cfg(feature = "sync")]
48 pub(crate) waker_sender_cache:
49 std::cell::RefCell<fxhash::FxHashMap<usize, flume::Sender<std::task::Waker>>>,
50
51 pub(crate) time_handle: Option<TimeHandle>,
53
54 #[cfg(feature = "sync")]
56 pub(crate) blocking_handle: crate::blocking::BlockingHandle,
57}
58
59impl Context {
60 #[cfg(feature = "sync")]
61 pub(crate) fn new(blocking_handle: crate::blocking::BlockingHandle) -> Self {
62 let thread_id = crate::builder::BUILD_THREAD_ID.with(|id| *id);
63
64 Self {
65 thread_id,
66 unpark_cache: std::cell::RefCell::new(fxhash::FxHashMap::default()),
67 waker_sender_cache: std::cell::RefCell::new(fxhash::FxHashMap::default()),
68 tasks: TaskQueue::default(),
69 time_handle: None,
70 blocking_handle,
71 }
72 }
73
74 #[cfg(not(feature = "sync"))]
75 pub(crate) fn new() -> Self {
76 let thread_id = crate::builder::BUILD_THREAD_ID.with(|id| *id);
77
78 Self {
79 thread_id,
80 tasks: TaskQueue::default(),
81 time_handle: None,
82 }
83 }
84
85 #[allow(unused)]
86 #[cfg(feature = "sync")]
87 pub(crate) fn unpark_thread(&self, id: usize) {
88 use crate::driver::{thread::get_unpark_handle, unpark::Unpark};
89 if let Some(handle) = self.unpark_cache.borrow().get(&id) {
90 handle.unpark();
91 return;
92 }
93
94 if let Some(v) = get_unpark_handle(id) {
95 let w = v.clone();
97 self.unpark_cache.borrow_mut().insert(id, w);
98 v.unpark();
99 }
100 }
101
102 #[allow(unused)]
103 #[cfg(feature = "sync")]
104 pub(crate) fn send_waker(&self, id: usize, w: std::task::Waker) {
105 use crate::driver::thread::get_waker_sender;
106 if let Some(sender) = self.waker_sender_cache.borrow().get(&id) {
107 let _ = sender.send(w);
108 return;
109 }
110
111 if let Some(s) = get_waker_sender(id) {
112 let _ = s.send(w);
114 self.waker_sender_cache.borrow_mut().insert(id, s);
115 }
116 }
117}
118
119pub struct Runtime<D> {
121 pub(crate) context: Context,
122 pub(crate) driver: D,
123}
124
125impl<D> Runtime<D> {
126 pub(crate) fn new(context: Context, driver: D) -> Self {
127 Self { context, driver }
128 }
129
130 pub fn block_on<F>(&mut self, future: F) -> F::Output
132 where
133 F: Future,
134 D: Driver,
135 {
136 assert!(
137 !CURRENT.is_set(),
138 "Can not start a runtime inside a runtime"
139 );
140
141 let waker = dummy_waker();
142 let cx = &mut std::task::Context::from_waker(&waker);
143
144 self.driver.with(|| {
145 CURRENT.set(&self.context, || {
146 #[cfg(feature = "sync")]
147 let join = unsafe { spawn_without_static(future) };
148 #[cfg(not(feature = "sync"))]
149 let join = future;
150
151 let mut join = std::pin::pin!(join);
152 set_poll();
153 loop {
154 loop {
155 let mut max_round = self.context.tasks.len() * 2;
157 while let Some(t) = self.context.tasks.pop() {
158 t.run();
159 if max_round == 0 {
160 break;
162 } else {
163 max_round -= 1;
164 }
165 }
166
167 while should_poll() {
169 if let std::task::Poll::Ready(t) = join.as_mut().poll(cx) {
171 return t;
172 }
173 }
174
175 if self.context.tasks.is_empty() {
176 break;
179 }
180
181 let _ = self.driver.submit();
183 }
184
185 #[cfg(not(all(debug_assertions, feature = "debug")))]
187 let _ = self.driver.park();
188
189 #[cfg(all(debug_assertions, feature = "debug"))]
190 if let Err(e) = self.driver.park() {
191 trace!("park error: {:?}", e);
192 }
193 }
194 })
195 })
196 }
197}
198
/// A runtime backed by one of the compiled-in drivers.
///
/// This definition is used whenever the `legacy` feature is enabled. The `L`
/// parameter and the `Uring` variant only exist on Linux with the `iouring`
/// feature.
#[cfg(feature = "legacy")]
pub enum FusionRuntime<#[cfg(all(target_os = "linux", feature = "iouring"))] L, R> {
    /// Runtime using the driver `L`.
    #[cfg(all(target_os = "linux", feature = "iouring"))]
    Uring(Runtime<L>),
    /// Runtime using the driver `R`.
    Legacy(Runtime<R>),
}
209
/// A runtime backed by one of the compiled-in drivers.
///
/// This definition is used on Linux with `iouring` enabled and `legacy`
/// disabled, so `Uring` is the only variant.
#[cfg(all(target_os = "linux", feature = "iouring", not(feature = "legacy")))]
pub enum FusionRuntime<L> {
    /// Runtime using the driver `L`.
    Uring(Runtime<L>),
}
217
#[cfg(all(target_os = "linux", feature = "iouring", feature = "legacy"))]
impl<L: Driver, R: Driver> FusionRuntime<L, R> {
    /// Runs `future` to completion on whichever runtime this value wraps and
    /// returns its output. The selected driver is logged before delegating.
    pub fn block_on<F>(&mut self, future: F) -> F::Output
    where
        F: Future,
    {
        match self {
            Self::Uring(uring_rt) => {
                info!("Monoio is running with io_uring driver");
                uring_rt.block_on(future)
            }
            Self::Legacy(legacy_rt) => {
                info!("Monoio is running with legacy driver");
                legacy_rt.block_on(future)
            }
        }
    }
}
241
#[cfg(all(feature = "legacy", not(all(target_os = "linux", feature = "iouring"))))]
impl<R: Driver> FusionRuntime<R> {
    /// Runs `future` to completion on the wrapped legacy runtime and returns
    /// its output.
    pub fn block_on<F>(&mut self, future: F) -> F::Output
    where
        F: Future,
    {
        // `Legacy` is the only variant under this cfg, so the pattern is
        // irrefutable.
        let Self::Legacy(legacy_rt) = self;
        legacy_rt.block_on(future)
    }
}
257
#[cfg(all(target_os = "linux", feature = "iouring", not(feature = "legacy")))]
impl<L: Driver> FusionRuntime<L> {
    /// Runs `future` to completion on the wrapped io_uring runtime and returns
    /// its output.
    pub fn block_on<F>(&mut self, future: F) -> F::Output
    where
        F: Future,
    {
        // `Uring` is the only variant under this cfg, so the pattern is
        // irrefutable.
        let Self::Uring(uring_rt) = self;
        uring_rt.block_on(future)
    }
}
273
/// Wraps an io_uring runtime (no timer) as the `Uring` variant.
#[cfg(all(target_os = "linux", feature = "iouring", feature = "legacy"))]
impl From<Runtime<IoUringDriver>> for FusionRuntime<IoUringDriver, LegacyDriver> {
    fn from(runtime: Runtime<IoUringDriver>) -> Self {
        FusionRuntime::Uring(runtime)
    }
}
281
/// Wraps a timer-enabled io_uring runtime as the `Uring` variant.
#[cfg(all(target_os = "linux", feature = "iouring", feature = "legacy"))]
impl From<Runtime<TimeDriver<IoUringDriver>>>
    for FusionRuntime<TimeDriver<IoUringDriver>, TimeDriver<LegacyDriver>>
{
    fn from(runtime: Runtime<TimeDriver<IoUringDriver>>) -> Self {
        FusionRuntime::Uring(runtime)
    }
}
291
/// Wraps a legacy runtime (no timer) as the `Legacy` variant.
#[cfg(all(target_os = "linux", feature = "iouring", feature = "legacy"))]
impl From<Runtime<LegacyDriver>> for FusionRuntime<IoUringDriver, LegacyDriver> {
    fn from(runtime: Runtime<LegacyDriver>) -> Self {
        FusionRuntime::Legacy(runtime)
    }
}
299
/// Wraps a timer-enabled legacy runtime as the `Legacy` variant.
#[cfg(all(target_os = "linux", feature = "iouring", feature = "legacy"))]
impl From<Runtime<TimeDriver<LegacyDriver>>>
    for FusionRuntime<TimeDriver<IoUringDriver>, TimeDriver<LegacyDriver>>
{
    fn from(runtime: Runtime<TimeDriver<LegacyDriver>>) -> Self {
        FusionRuntime::Legacy(runtime)
    }
}
309
/// Wraps a legacy runtime (no timer) as the `Legacy` variant when io_uring is
/// not compiled in.
#[cfg(all(feature = "legacy", not(all(target_os = "linux", feature = "iouring"))))]
impl From<Runtime<LegacyDriver>> for FusionRuntime<LegacyDriver> {
    fn from(runtime: Runtime<LegacyDriver>) -> Self {
        FusionRuntime::Legacy(runtime)
    }
}
317
/// Wraps a timer-enabled legacy runtime as the `Legacy` variant when io_uring
/// is not compiled in.
#[cfg(all(feature = "legacy", not(all(target_os = "linux", feature = "iouring"))))]
impl From<Runtime<TimeDriver<LegacyDriver>>> for FusionRuntime<TimeDriver<LegacyDriver>> {
    fn from(runtime: Runtime<TimeDriver<LegacyDriver>>) -> Self {
        FusionRuntime::Legacy(runtime)
    }
}
325
/// Wraps an io_uring runtime (no timer) as the `Uring` variant when the legacy
/// driver is not compiled in.
#[cfg(all(target_os = "linux", feature = "iouring", not(feature = "legacy")))]
impl From<Runtime<IoUringDriver>> for FusionRuntime<IoUringDriver> {
    fn from(runtime: Runtime<IoUringDriver>) -> Self {
        FusionRuntime::Uring(runtime)
    }
}
333
/// Wraps a timer-enabled io_uring runtime as the `Uring` variant when the
/// legacy driver is not compiled in.
#[cfg(all(target_os = "linux", feature = "iouring", not(feature = "legacy")))]
impl From<Runtime<TimeDriver<IoUringDriver>>> for FusionRuntime<TimeDriver<IoUringDriver>> {
    fn from(runtime: Runtime<TimeDriver<IoUringDriver>>) -> Self {
        FusionRuntime::Uring(runtime)
    }
}
341
342pub fn spawn<T>(future: T) -> JoinHandle<T::Output>
369where
370 T: Future + 'static,
371 T::Output: 'static,
372{
373 let (task, join) = new_task(
374 crate::utils::thread_id::get_current_thread_id(),
375 future,
376 LocalScheduler,
377 );
378
379 CURRENT.with(|ctx| {
380 ctx.tasks.push(task);
381 });
382 join
383}
384
/// Same as [`spawn`] but without the `'static` bounds; the task is created
/// with `new_task_holding` instead of `new_task`.
///
/// # Safety
///
/// NOTE(review): the exact contract is not visible here. Presumably the caller
/// must guarantee that the spawned task is finished or dropped before anything
/// borrowed by `future` goes away, since the compiler no longer enforces it.
/// Confirm against `new_task_holding`.
#[cfg(feature = "sync")]
unsafe fn spawn_without_static<T>(future: T) -> JoinHandle<T::Output>
where
    T: Future,
{
    use crate::task::new_task_holding;

    let owner_id = crate::utils::thread_id::get_current_thread_id();
    let (task, handle) = new_task_holding(owner_id, future, LocalScheduler);

    CURRENT.with(|ctx| {
        ctx.tasks.push(task);
    });

    handle
}
402
#[cfg(test)]
mod tests {
    /// Two runtimes on different threads pass a value back and forth through
    /// oneshot channels: the main thread sends `24` on `tx1`, the helper
    /// thread forwards it on `tx2`, and the main thread checks what arrives.
    #[cfg(all(feature = "sync", target_os = "linux", feature = "iouring"))]
    #[test]
    fn across_thread() {
        use futures::channel::oneshot;

        use crate::driver::IoUringDriver;

        let (tx1, rx1) = oneshot::channel::<u8>();
        let (tx2, rx2) = oneshot::channel::<u8>();

        std::thread::spawn(move || {
            let mut rt = crate::RuntimeBuilder::<IoUringDriver>::new()
                .build()
                .unwrap();
            rt.block_on(async move {
                let n = rx1.await.expect("unable to receive rx1");
                assert!(tx2.send(n).is_ok());
            });
        });

        let mut rt = crate::RuntimeBuilder::<IoUringDriver>::new()
            .build()
            .unwrap();
        rt.block_on(async move {
            assert!(tx1.send(24).is_ok());
            assert_eq!(rx2.await.expect("unable to receive rx2"), 24);
        });
    }

    /// A 200 ms sleep on a timer-enabled runtime must take 200 ms ± 50 ms of
    /// wall-clock time.
    #[cfg(all(target_os = "linux", feature = "iouring"))]
    #[test]
    fn timer() {
        use crate::driver::IoUringDriver;
        let mut rt = crate::RuntimeBuilder::<IoUringDriver>::new()
            .enable_timer()
            .build()
            .unwrap();
        let instant = std::time::Instant::now();
        rt.block_on(async {
            crate::time::sleep(std::time::Duration::from_millis(200)).await;
        });
        // Use the whole elapsed time. `subsec_millis()` drops full seconds, so
        // a sleep that took e.g. 1.2 s would wrongly pass the window check.
        let elapsed_ms = instant.elapsed().as_millis();
        assert!(
            elapsed_ms.abs_diff(200) < 50,
            "sleep(200ms) took {elapsed_ms}ms"
        );
    }
}