// ferron_common/util/send_async_io.rs

use std::mem::ManuallyDrop;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::thread::ThreadId;

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

7/// SendAsyncIo is a wrapper around an AsyncRead or AsyncWrite that ensures that all operations are performed on the same thread.
8pub struct SendAsyncIo<T> {
9  thread_id: ThreadId,
10  inner: T,
11}
12
13impl<T> SendAsyncIo<T> {
14  /// Creates a new SendAsyncIo wrapper around the given AsyncRead or AsyncWrite.
15  pub fn new(inner: T) -> Self {
16    SendAsyncIo {
17      thread_id: std::thread::current().id(),
18      inner,
19    }
20  }
21}
22
23impl<T: AsyncRead + Unpin> AsyncRead for SendAsyncIo<T> {
24  fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
25    if std::thread::current().id() != self.thread_id {
26      panic!("SendAsyncIo can only be used from the same thread it was created on");
27    }
28    Pin::new(&mut self.inner).poll_read(cx, buf)
29  }
30}
31
32impl<T: AsyncWrite + Unpin> AsyncWrite for SendAsyncIo<T> {
33  fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
34    if std::thread::current().id() != self.thread_id {
35      panic!("SendAsyncIo can only be used from the same thread it was created on");
36    }
37    Pin::new(&mut self.inner).poll_write(cx, buf)
38  }
39
40  fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
41    if std::thread::current().id() != self.thread_id {
42      panic!("SendAsyncIo can only be used from the same thread it was created on");
43    }
44    Pin::new(&mut self.inner).poll_flush(cx)
45  }
46
47  fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
48    if std::thread::current().id() != self.thread_id {
49      panic!("SendAsyncIo can only be used from the same thread it was created on");
50    }
51    Pin::new(&mut self.inner).poll_shutdown(cx)
52  }
53
54  fn poll_write_vectored(
55    mut self: Pin<&mut Self>,
56    cx: &mut Context<'_>,
57    bufs: &[std::io::IoSlice<'_>],
58  ) -> Poll<Result<usize, std::io::Error>> {
59    if std::thread::current().id() != self.thread_id {
60      panic!("SendAsyncIo can only be used from the same thread it was created on");
61    }
62    Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
63  }
64
65  fn is_write_vectored(&self) -> bool {
66    if std::thread::current().id() != self.thread_id {
67      panic!("SendAsyncIo can only be used from the same thread it was created on");
68    }
69    self.inner.is_write_vectored()
70  }
71}
72
73impl<T> Drop for SendAsyncIo<T> {
74  fn drop(&mut self) {
75    if std::thread::current().id() != self.thread_id {
76      panic!("SendAsyncIo can only be used from the same thread it was created on");
77    }
78  }
79}
80
81// Safety: SendAsyncIo would panic if used from a different thread, instead of having undefined behavior.
82unsafe impl<T> Send for SendAsyncIo<T> {}
83unsafe impl<T> Sync for SendAsyncIo<T> {}