ferron_common/http_proxy/send_net_io/monoio/unix_stream_poll.rs

1use std::mem::ManuallyDrop;
2use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, RawFd};
3use std::pin::Pin;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6use std::task::{Context, Poll};
7use std::thread::ThreadId;
8
9use monoio::io::IntoPollIo;
10use monoio::net::unix::stream_poll::UnixStreamPoll;
11use monoio::net::UnixStream;
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13
14/// SendUnixStream is a wrapper around Monoio's UnixStream.
15pub struct SendUnixStreamPoll {
16  thread_id: ThreadId,
17  inner: Option<UnixStreamPoll>,
18  prev_inner: Option<ManuallyDrop<UnixStreamPoll>>,
19  is_write_vectored: bool,
20  inner_fd: RawFd,
21  obtained_dropped: bool,
22  marked_dropped: Arc<AtomicBool>,
23}
24
25impl SendUnixStreamPoll {
26  /// Creates a new SendUnixStreamPoll wrapper around the given UnixStream.
27  #[inline]
28  pub fn new_comp_io(inner: UnixStream) -> Result<Self, std::io::Error> {
29    let inner_fd = inner.as_raw_fd();
30    let inner = inner.into_poll_io()?;
31    let is_write_vectored = inner.is_write_vectored();
32    Ok(SendUnixStreamPoll {
33      thread_id: std::thread::current().id(),
34      inner: Some(inner),
35      prev_inner: None,
36      is_write_vectored,
37      inner_fd,
38      obtained_dropped: false,
39      marked_dropped: Arc::new(AtomicBool::new(false)),
40    })
41  }
42}
43
44impl SendUnixStreamPoll {
45  /// Obtains a drop guard for the UnixStreamPoll.
46  ///
47  /// # Safety
48  ///
49  /// This method is unsafe because it allows the caller to drop the inner UnixStreamPoll without marking it as dropped.
50  ///
51  #[inline]
52  pub unsafe fn get_drop_guard(&mut self) -> SendUnixStreamPollDropGuard {
53    if self.obtained_dropped {
54      panic!("the UnixStreamPoll's get_drop_guard method can be used only once");
55    }
56    self.obtained_dropped = true;
57    let inner = if let Some(inner) = self.inner.as_ref() {
58      // Copy the inner UnixStreamPoll
59      let mut inner_data = std::mem::MaybeUninit::uninit();
60      std::ptr::copy(inner as *const _, inner_data.as_mut_ptr(), 1);
61      Some(ManuallyDrop::new(inner_data.assume_init()))
62    } else {
63      None
64    };
65    SendUnixStreamPollDropGuard {
66      inner,
67      marked_dropped: self.marked_dropped.clone(),
68    }
69  }
70
71  #[inline]
72  fn populate_if_different_thread_or_marked_dropped(&mut self, dropped: bool) {
73    let current_thread_id = std::thread::current().id();
74    let marked_dropped = !dropped && self.marked_dropped.swap(false, Ordering::Relaxed) && self.prev_inner.is_none();
75    if marked_dropped || current_thread_id != self.thread_id {
76      if !self.obtained_dropped {
77        panic!("the UnixStreamPoll can be used only once if drop guard is not obtained")
78      }
79      if self.prev_inner.is_some() {
80        panic!("the UnixStreamPoll can be moved only once across threads or if it is marked as dropped");
81      }
82      // Safety: The inner UnixStreamPoll is manually dropped, so it's safe to use the raw fd/socket
83      let std_unix_stream = unsafe { std::os::unix::net::UnixStream::from_raw_fd(self.inner_fd) };
84      let _ = std_unix_stream.set_nonblocking(monoio::utils::is_legacy());
85      let unix_stream_poll = UnixStream::from_std(std_unix_stream)
86        .expect("failed to create UnixStream")
87        .try_into_poll_io()
88        .expect("failed to create UnixStreamPoll");
89      self.is_write_vectored = unix_stream_poll.is_write_vectored();
90      self.prev_inner = self.inner.take().map(ManuallyDrop::new);
91      self.inner = Some(unix_stream_poll);
92      self.thread_id = current_thread_id;
93    }
94  }
95}
96
97impl AsyncRead for SendUnixStreamPoll {
98  #[inline]
99  fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
100    self.populate_if_different_thread_or_marked_dropped(false);
101    Pin::new(self.inner.as_mut().expect("inner element not present")).poll_read(cx, buf)
102  }
103}
104
105impl AsyncWrite for SendUnixStreamPoll {
106  #[inline]
107  fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
108    self.populate_if_different_thread_or_marked_dropped(false);
109    Pin::new(self.inner.as_mut().expect("inner element not present")).poll_write(cx, buf)
110  }
111
112  #[inline]
113  fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
114    self.populate_if_different_thread_or_marked_dropped(false);
115    Pin::new(self.inner.as_mut().expect("inner element not present")).poll_flush(cx)
116  }
117
118  #[inline]
119  fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
120    self.populate_if_different_thread_or_marked_dropped(false);
121    Pin::new(self.inner.as_mut().expect("inner element not present")).poll_shutdown(cx)
122  }
123
124  #[inline]
125  fn poll_write_vectored(
126    mut self: Pin<&mut Self>,
127    cx: &mut Context<'_>,
128    bufs: &[std::io::IoSlice<'_>],
129  ) -> Poll<Result<usize, std::io::Error>> {
130    self.populate_if_different_thread_or_marked_dropped(false);
131    Pin::new(self.inner.as_mut().expect("inner element not present")).poll_write_vectored(cx, bufs)
132  }
133
134  #[inline]
135  fn is_write_vectored(&self) -> bool {
136    if std::thread::current().id() != self.thread_id {
137      return self.is_write_vectored;
138    }
139    self
140      .inner
141      .as_ref()
142      .expect("inner element not present")
143      .is_write_vectored()
144  }
145}
146
147#[cfg(unix)]
148impl AsRawFd for SendUnixStreamPoll {
149  #[inline]
150  fn as_raw_fd(&self) -> RawFd {
151    self.inner_fd
152  }
153}
154
155#[cfg(unix)]
156impl AsFd for SendUnixStreamPoll {
157  #[inline]
158  fn as_fd(&self) -> BorrowedFd<'_> {
159    // Safety: inner_fd is valid, as it is taken from the inner value
160    unsafe { BorrowedFd::borrow_raw(self.inner_fd) }
161  }
162}
163
164impl Drop for SendUnixStreamPoll {
165  fn drop(&mut self) {
166    if !self.marked_dropped.swap(true, Ordering::Relaxed) {
167      self.populate_if_different_thread_or_marked_dropped(true);
168    } else {
169      let _ = ManuallyDrop::new(self.inner.take());
170    }
171  }
172}
173
// SAFETY: As far as we read from Monoio's source, the inner `Rc` in `SharedFd` is cloned only
// during async operations. Every poll method and `Drop` first call
// `populate_if_different_thread_or_marked_dropped`, which re-creates the inner stream on the
// current thread when the thread has changed. `is_write_vectored` returns a cached value off the
// owning thread, and the fd accessors only read `inner_fd`.
unsafe impl Send for SendUnixStreamPoll {}
176
177/// Drop guard for SendUnixStreamPoll.
178pub struct SendUnixStreamPollDropGuard {
179  inner: Option<ManuallyDrop<UnixStreamPoll>>,
180  marked_dropped: Arc<AtomicBool>,
181}
182
183impl Drop for SendUnixStreamPollDropGuard {
184  fn drop(&mut self) {
185    if let Some(inner) = self.inner.take() {
186      if !self.marked_dropped.swap(true, Ordering::Relaxed) {
187        // Drop if not marked as dropped
188        let inner_comp_io = ManuallyDrop::into_inner(inner)
189          .try_into_comp_io()
190          .expect("failed to convert inner UnixStreamPoll to comp_io");
191
192        let _ = inner_comp_io.into_raw_fd();
193      }
194    }
195  }
196}