// ferron/util/send_async_io.rs

1use std::pin::Pin;
2use std::task::{Context, Poll};
3use std::thread::ThreadId;
4
5use monoio_compat::{AsyncRead, AsyncWrite};
6use tokio::io::ReadBuf;
7
8/// SendAsyncIo is a wrapper around an AsyncRead or AsyncWrite that ensures that all operations are performed on the same thread.
9pub struct SendAsyncIo<T> {
10  thread_id: ThreadId,
11  inner: T,
12}
13
14impl<T> SendAsyncIo<T> {
15  /// Creates a new SendAsyncIo wrapper around the given AsyncRead or AsyncWrite.
16  pub fn new(inner: T) -> Self {
17    SendAsyncIo {
18      thread_id: std::thread::current().id(),
19      inner,
20    }
21  }
22}
23
24impl<T: AsyncRead + Unpin> AsyncRead for SendAsyncIo<T> {
25  fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
26    if std::thread::current().id() != self.thread_id {
27      panic!("SendAsyncIo can only be used from the same thread it was created on");
28    }
29    Pin::new(&mut self.inner).poll_read(cx, buf)
30  }
31}
32
33impl<T: AsyncWrite + Unpin> AsyncWrite for SendAsyncIo<T> {
34  fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
35    if std::thread::current().id() != self.thread_id {
36      panic!("SendAsyncIo can only be used from the same thread it was created on");
37    }
38    Pin::new(&mut self.inner).poll_write(cx, buf)
39  }
40
41  fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
42    if std::thread::current().id() != self.thread_id {
43      panic!("SendAsyncIo can only be used from the same thread it was created on");
44    }
45    Pin::new(&mut self.inner).poll_flush(cx)
46  }
47
48  fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
49    if std::thread::current().id() != self.thread_id {
50      panic!("SendAsyncIo can only be used from the same thread it was created on");
51    }
52    Pin::new(&mut self.inner).poll_shutdown(cx)
53  }
54
55  fn poll_write_vectored(
56    mut self: Pin<&mut Self>,
57    cx: &mut Context<'_>,
58    bufs: &[std::io::IoSlice<'_>],
59  ) -> Poll<Result<usize, std::io::Error>> {
60    if std::thread::current().id() != self.thread_id {
61      panic!("SendAsyncIo can only be used from the same thread it was created on");
62    }
63    Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
64  }
65
66  fn is_write_vectored(&self) -> bool {
67    if std::thread::current().id() != self.thread_id {
68      panic!("SendAsyncIo can only be used from the same thread it was created on");
69    }
70    self.inner.is_write_vectored()
71  }
72}
73
74impl<T> Drop for SendAsyncIo<T> {
75  fn drop(&mut self) {
76    if std::thread::current().id() != self.thread_id {
77      panic!("SendAsyncIo can only be used from the same thread it was created on");
78    }
79  }
80}
81
82// Safety: SendAsyncIo would panic if used from a different thread, instead of having undefined behavior.
83unsafe impl<T> Send for SendAsyncIo<T> {}