ferron_common/http_proxy/send_net_io/monoio/
tcp_stream_poll.rs

1use std::mem::ManuallyDrop;
2#[cfg(unix)]
3use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, RawFd};
4#[cfg(windows)]
5use std::os::windows::io::{AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, IntoRawSocket, RawSocket};
6use std::pin::Pin;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::Arc;
9use std::task::{Context, Poll};
10use std::thread::ThreadId;
11
12use monoio::io::IntoPollIo;
13use monoio::net::tcp::stream_poll::TcpStreamPoll;
14use monoio::net::TcpStream;
15use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16
/// `SendTcpStreamPoll` is a wrapper around Monoio's `TcpStreamPoll` that is declared `Send`.
///
/// Besides the stream itself, the wrapper remembers the thread the stream was created on and the
/// raw descriptor/socket, so the inner stream can be re-created from that raw handle when the
/// wrapper is polled on another thread or after it has been marked as dropped.
pub struct SendTcpStreamPoll {
  /// Thread on which `inner` was created or last re-created.
  thread_id: ThreadId,
  /// The stream all I/O is forwarded to.
  inner: Option<TcpStreamPoll>,
  /// Former value of `inner` after a re-creation; wrapped in `ManuallyDrop`, so it is never dropped.
  prev_inner: Option<ManuallyDrop<TcpStreamPoll>>,
  /// Cached `is_write_vectored()` answer of `inner`, returned when queried from another thread.
  is_write_vectored: bool,
  /// Raw file descriptor captured when the wrapper was constructed.
  #[cfg(unix)]
  inner_fd: RawFd,
  /// Raw socket captured when the wrapper was constructed.
  #[cfg(windows)]
  inner_socket: RawSocket,
  /// Whether `get_drop_guard` has already been called.
  obtained_dropped: bool,
  /// Flag shared with the drop guard: set by whichever `Drop` runs first, cleared on re-creation.
  marked_dropped: Arc<AtomicBool>,
}
30
31impl SendTcpStreamPoll {
32  /// Creates a new SendTcpStreamPoll wrapper around the given TcpStreamPoll.
33  #[inline]
34  #[allow(dead_code)]
35  pub fn new(inner: TcpStreamPoll) -> Self {
36    #[cfg(unix)]
37    let inner_fd = inner.as_raw_fd();
38    #[cfg(windows)]
39    let inner_socket = inner.as_raw_socket();
40    let is_write_vectored = inner.is_write_vectored();
41    SendTcpStreamPoll {
42      thread_id: std::thread::current().id(),
43      inner: Some(inner),
44      prev_inner: None,
45      is_write_vectored,
46      #[cfg(unix)]
47      inner_fd,
48      #[cfg(windows)]
49      inner_socket,
50      obtained_dropped: false,
51      marked_dropped: Arc::new(AtomicBool::new(false)),
52    }
53  }
54
55  /// Creates a new SendTcpStreamPoll wrapper around the given TcpStream.
56  #[inline]
57  pub fn new_comp_io(inner: TcpStream) -> Result<Self, std::io::Error> {
58    #[cfg(unix)]
59    let inner_fd = inner.as_raw_fd();
60    #[cfg(windows)]
61    let inner_socket = inner.as_raw_socket();
62    let inner = inner.into_poll_io()?;
63    let is_write_vectored = inner.is_write_vectored();
64    Ok(SendTcpStreamPoll {
65      thread_id: std::thread::current().id(),
66      inner: Some(inner),
67      prev_inner: None,
68      is_write_vectored,
69      #[cfg(unix)]
70      inner_fd,
71      #[cfg(windows)]
72      inner_socket,
73      obtained_dropped: false,
74      marked_dropped: Arc::new(AtomicBool::new(false)),
75    })
76  }
77}
78
79impl SendTcpStreamPoll {
80  /// Obtains a drop guard for the TcpStreamPoll.
81  ///
82  /// # Safety
83  ///
84  /// This method is unsafe because it allows the caller to drop the inner TcpStreamPoll without marking it as dropped.
85  ///
86  #[inline]
87  pub unsafe fn get_drop_guard(&mut self) -> SendTcpStreamPollDropGuard {
88    if self.obtained_dropped {
89      panic!("the TcpStreamPoll's get_drop_guard method can be used only once");
90    }
91    self.obtained_dropped = true;
92    let inner = if let Some(inner) = self.inner.as_ref() {
93      // Copy the inner TcpStreamPoll
94      let mut inner_data = std::mem::MaybeUninit::uninit();
95      std::ptr::copy(inner as *const _, inner_data.as_mut_ptr(), 1);
96      Some(ManuallyDrop::new(inner_data.assume_init()))
97    } else {
98      None
99    };
100    SendTcpStreamPollDropGuard {
101      inner,
102      marked_dropped: self.marked_dropped.clone(),
103    }
104  }
105
106  #[inline]
107  fn populate_if_different_thread_or_marked_dropped(&mut self, dropped: bool) {
108    let current_thread_id = std::thread::current().id();
109    let marked_dropped = !dropped && self.marked_dropped.swap(false, Ordering::Relaxed) && self.prev_inner.is_none();
110    if marked_dropped || current_thread_id != self.thread_id {
111      if !self.obtained_dropped {
112        panic!("the TcpStreamPoll can be used only once if drop guard is not obtained")
113      }
114      if self.prev_inner.is_some() {
115        panic!("the TcpStreamPoll can be moved only once across threads or if it is marked as dropped");
116      }
117      // Safety: The inner TcpStreamPoll is manually dropped, so it's safe to use the raw fd/socket
118      #[cfg(unix)]
119      let std_tcp_stream = unsafe { std::net::TcpStream::from_raw_fd(self.inner_fd) };
120      #[cfg(windows)]
121      let std_tcp_stream = unsafe { std::net::TcpStream::from_raw_socket(self.inner_socket) };
122      let _ = std_tcp_stream.set_nonblocking(monoio::utils::is_legacy());
123      let tcp_stream_poll = TcpStream::from_std(std_tcp_stream)
124        .expect("failed to create TcpStream")
125        .try_into_poll_io()
126        .expect("failed to create TcpStreamPoll");
127      self.is_write_vectored = tcp_stream_poll.is_write_vectored();
128      self.prev_inner = self.inner.take().map(ManuallyDrop::new);
129      self.inner = Some(tcp_stream_poll);
130      self.thread_id = current_thread_id;
131    }
132  }
133}
134
135impl AsyncRead for SendTcpStreamPoll {
136  #[inline]
137  fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
138    self.populate_if_different_thread_or_marked_dropped(false);
139    Pin::new(self.inner.as_mut().expect("inner element not present")).poll_read(cx, buf)
140  }
141}
142
143impl AsyncWrite for SendTcpStreamPoll {
144  #[inline]
145  fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
146    self.populate_if_different_thread_or_marked_dropped(false);
147    Pin::new(self.inner.as_mut().expect("inner element not present")).poll_write(cx, buf)
148  }
149
150  #[inline]
151  fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
152    self.populate_if_different_thread_or_marked_dropped(false);
153    Pin::new(self.inner.as_mut().expect("inner element not present")).poll_flush(cx)
154  }
155
156  #[inline]
157  fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
158    self.populate_if_different_thread_or_marked_dropped(false);
159    Pin::new(self.inner.as_mut().expect("inner element not present")).poll_shutdown(cx)
160  }
161
162  #[inline]
163  fn poll_write_vectored(
164    mut self: Pin<&mut Self>,
165    cx: &mut Context<'_>,
166    bufs: &[std::io::IoSlice<'_>],
167  ) -> Poll<Result<usize, std::io::Error>> {
168    self.populate_if_different_thread_or_marked_dropped(false);
169    Pin::new(self.inner.as_mut().expect("inner element not present")).poll_write_vectored(cx, bufs)
170  }
171
172  #[inline]
173  fn is_write_vectored(&self) -> bool {
174    if std::thread::current().id() != self.thread_id {
175      return self.is_write_vectored;
176    }
177    self
178      .inner
179      .as_ref()
180      .expect("inner element not present")
181      .is_write_vectored()
182  }
183}
184
185#[cfg(unix)]
186impl AsRawFd for SendTcpStreamPoll {
187  #[inline]
188  fn as_raw_fd(&self) -> RawFd {
189    self.inner_fd
190  }
191}
192
193#[cfg(unix)]
194impl AsFd for SendTcpStreamPoll {
195  #[inline]
196  fn as_fd(&self) -> BorrowedFd<'_> {
197    // Safety: inner_fd is valid, as it is taken from the inner value
198    unsafe { BorrowedFd::borrow_raw(self.inner_fd) }
199  }
200}
201
#[cfg(windows)]
impl AsRawSocket for SendTcpStreamPoll {
  /// Returns the raw socket stored in `inner_socket`; the inner stream is not consulted.
  #[inline]
  fn as_raw_socket(&self) -> RawSocket {
    let Self { inner_socket, .. } = self;
    *inner_socket
  }
}
209
#[cfg(windows)]
impl AsSocket for SendTcpStreamPoll {
  /// Borrows the stored raw socket for the lifetime of `&self`.
  #[inline]
  fn as_socket(&self) -> BorrowedSocket<'_> {
    let raw_socket = self.inner_socket;
    // SAFETY: `inner_socket` was taken from the inner stream, so it is treated as a valid open
    // socket for as long as this wrapper is borrowed.
    unsafe { BorrowedSocket::borrow_raw(raw_socket) }
  }
}
218
219impl Drop for SendTcpStreamPoll {
220  fn drop(&mut self) {
221    if !self.marked_dropped.swap(true, Ordering::Relaxed) {
222      self.populate_if_different_thread_or_marked_dropped(true);
223    } else {
224      let _ = ManuallyDrop::new(self.inner.take());
225    }
226  }
227}
228
// Safety: As far as we read from Monoio's source, inner Rc in SharedFd is cloned only during async operations.
// NOTE(review): this argument rests on Monoio internals that are not visible in this file, so it
// should be re-checked whenever the Monoio dependency is upgraded.
unsafe impl Send for SendTcpStreamPoll {}
231
/// Drop guard for SendTcpStreamPoll.
///
/// When the guard is dropped and the shared flag has not been set yet, the guard sets the flag and
/// releases its copy of the inner stream; otherwise it does nothing.
pub struct SendTcpStreamPollDropGuard {
  /// Bitwise copy of the wrapper's inner stream, kept in `ManuallyDrop` so it is never dropped implicitly.
  inner: Option<ManuallyDrop<TcpStreamPoll>>,
  /// Flag shared with the `SendTcpStreamPoll` this guard was obtained from.
  marked_dropped: Arc<AtomicBool>,
}
237
238impl Drop for SendTcpStreamPollDropGuard {
239  fn drop(&mut self) {
240    if let Some(inner) = self.inner.take() {
241      if !self.marked_dropped.swap(true, Ordering::Relaxed) {
242        // Drop if not marked as dropped
243        let inner_comp_io = ManuallyDrop::into_inner(inner)
244          .try_into_comp_io()
245          .expect("failed to convert inner TcpStreamPoll to comp_io");
246
247        #[cfg(unix)]
248        let _ = inner_comp_io.into_raw_fd();
249        #[cfg(windows)]
250        let _ = inner_comp_io.into_raw_socket();
251      }
252    }
253  }
254}