// ferron_common/http_proxy/send_net_io/monoio/tcp_stream_poll.rs

use std::mem::ManuallyDrop;
2#[cfg(unix)]
3use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, RawFd};
4#[cfg(windows)]
5use std::os::windows::io::{AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, IntoRawSocket, RawSocket};
6use std::pin::Pin;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::Arc;
9use std::task::{Context, Poll};
10use std::thread::ThreadId;
11
12use monoio::io::IntoPollIo;
13use monoio::net::tcp::stream_poll::TcpStreamPoll;
14use monoio::net::TcpStream;
15use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16
/// A `Send` wrapper around monoio's poll-based [`TcpStreamPoll`].
///
/// monoio streams are registered with a per-thread driver; this wrapper
/// records the owning thread and the raw socket handle so the stream can be
/// rebuilt from that handle when it is next used on a different thread (see
/// `populate_if_different_thread_or_marked_dropped`).
pub struct SendTcpStreamPoll {
  // Thread whose driver currently owns `inner`.
  thread_id: ThreadId,
  // The live poll-based stream; `None` only transiently while it is being
  // replaced, or after `Drop` has taken it.
  inner: Option<TcpStreamPoll>,
  // The previous stream after a cross-thread rebuild, wrapped in
  // `ManuallyDrop` so dropping it cannot close the shared socket handle.
  prev_inner: Option<ManuallyDrop<TcpStreamPoll>>,
  // Cached `is_write_vectored` answer, returned when queried from a thread
  // other than `thread_id` (see the `AsyncWrite` impl).
  is_write_vectored: bool,
  // Raw socket handle captured at construction; used to revive the stream on
  // another thread.
  #[cfg(unix)]
  inner_fd: RawFd,
  #[cfg(windows)]
  inner_socket: RawSocket,
  // Whether `get_drop_guard` has already been called (allowed exactly once).
  obtained_dropped: bool,
  // Flag shared with the drop guard; whichever side drops first sets it so
  // the other side skips releasing the socket a second time.
  marked_dropped: Arc<AtomicBool>,
}
30
31impl SendTcpStreamPoll {
32 #[inline]
34 #[allow(dead_code)]
35 pub fn new(inner: TcpStreamPoll) -> Self {
36 #[cfg(unix)]
37 let inner_fd = inner.as_raw_fd();
38 #[cfg(windows)]
39 let inner_socket = inner.as_raw_socket();
40 let is_write_vectored = inner.is_write_vectored();
41 SendTcpStreamPoll {
42 thread_id: std::thread::current().id(),
43 inner: Some(inner),
44 prev_inner: None,
45 is_write_vectored,
46 #[cfg(unix)]
47 inner_fd,
48 #[cfg(windows)]
49 inner_socket,
50 obtained_dropped: false,
51 marked_dropped: Arc::new(AtomicBool::new(false)),
52 }
53 }
54
55 #[inline]
57 pub fn new_comp_io(inner: TcpStream) -> Result<Self, std::io::Error> {
58 #[cfg(unix)]
59 let inner_fd = inner.as_raw_fd();
60 #[cfg(windows)]
61 let inner_socket = inner.as_raw_socket();
62 let inner = inner.into_poll_io()?;
63 let is_write_vectored = inner.is_write_vectored();
64 Ok(SendTcpStreamPoll {
65 thread_id: std::thread::current().id(),
66 inner: Some(inner),
67 prev_inner: None,
68 is_write_vectored,
69 #[cfg(unix)]
70 inner_fd,
71 #[cfg(windows)]
72 inner_socket,
73 obtained_dropped: false,
74 marked_dropped: Arc::new(AtomicBool::new(false)),
75 })
76 }
77}
78
impl SendTcpStreamPoll {
  /// Creates a guard that bitwise-duplicates the inner stream so that
  /// whichever of the wrapper or the guard is dropped first releases the
  /// underlying socket, and the survivor leaks its copy instead of
  /// double-closing it.
  ///
  /// # Panics
  ///
  /// Panics if called more than once on the same wrapper.
  ///
  /// # Safety
  ///
  /// The returned guard holds a bitwise copy of the same `TcpStreamPoll` as
  /// `self.inner`; soundness depends on the `marked_dropped` protocol in the
  /// two `Drop` impls ensuring exactly one copy ever runs its destructor.
  /// The caller must not defeat that protocol.
  #[inline]
  pub unsafe fn get_drop_guard(&mut self) -> SendTcpStreamPollDropGuard {
    if self.obtained_dropped {
      panic!("the TcpStreamPoll's get_drop_guard method can be used only once");
    }
    self.obtained_dropped = true;
    let inner = if let Some(inner) = self.inner.as_ref() {
      let mut inner_data = std::mem::MaybeUninit::uninit();
      // Bitwise-duplicate the stream: after this both `self.inner` and the
      // guard own the same bytes. The copy is wrapped in `ManuallyDrop` so it
      // is only ever destroyed deliberately (see the guard's Drop impl).
      std::ptr::copy(inner as *const _, inner_data.as_mut_ptr(), 1);
      Some(ManuallyDrop::new(inner_data.assume_init()))
    } else {
      None
    };
    SendTcpStreamPollDropGuard {
      inner,
      marked_dropped: self.marked_dropped.clone(),
    }
  }

  /// Rebuilds the inner stream from the saved raw handle when the wrapper is
  /// used on a thread other than `thread_id`, or when the drop guard marked
  /// it dropped and it is now being revived.
  ///
  /// `dropped` is true when called from this wrapper's own `Drop`, in which
  /// case the `marked_dropped` flag is left untouched.
  #[inline]
  fn populate_if_different_thread_or_marked_dropped(&mut self, dropped: bool) {
    let current_thread_id = std::thread::current().id();
    // Revival case: the guard released the socket registration, and no
    // earlier rebuild is still pending in `prev_inner`.
    let marked_dropped = !dropped && self.marked_dropped.swap(false, Ordering::Relaxed) && self.prev_inner.is_none();
    if marked_dropped || current_thread_id != self.thread_id {
      if !self.obtained_dropped {
        panic!("the TcpStreamPoll can be used only once if drop guard is not obtained")
      }
      if self.prev_inner.is_some() {
        panic!("the TcpStreamPoll can be moved only once across threads or if it is marked as dropped");
      }
      // SAFETY(review): this creates a second owner of the same fd/socket;
      // the old stream is parked in `prev_inner` under `ManuallyDrop` below
      // so only the new owner will ever close the handle — confirm.
      #[cfg(unix)]
      let std_tcp_stream = unsafe { std::net::TcpStream::from_raw_fd(self.inner_fd) };
      #[cfg(windows)]
      let std_tcp_stream = unsafe { std::net::TcpStream::from_raw_socket(self.inner_socket) };
      // Presumably legacy (non-uring) monoio drivers need a nonblocking
      // socket; best-effort, errors deliberately ignored.
      let _ = std_tcp_stream.set_nonblocking(monoio::utils::is_legacy());
      let tcp_stream_poll = TcpStream::from_std(std_tcp_stream)
        .expect("failed to create TcpStream")
        .try_into_poll_io()
        .expect("failed to create TcpStreamPoll");
      self.is_write_vectored = tcp_stream_poll.is_write_vectored();
      // Keep (and leak) the superseded stream so its Drop cannot close the
      // shared handle out from under the replacement.
      self.prev_inner = self.inner.take().map(ManuallyDrop::new);
      self.inner = Some(tcp_stream_poll);
      self.thread_id = current_thread_id;
    }
  }
}
134
135impl AsyncRead for SendTcpStreamPoll {
136 #[inline]
137 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
138 self.populate_if_different_thread_or_marked_dropped(false);
139 Pin::new(self.inner.as_mut().expect("inner element not present")).poll_read(cx, buf)
140 }
141}
142
143impl AsyncWrite for SendTcpStreamPoll {
144 #[inline]
145 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
146 self.populate_if_different_thread_or_marked_dropped(false);
147 Pin::new(self.inner.as_mut().expect("inner element not present")).poll_write(cx, buf)
148 }
149
150 #[inline]
151 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
152 self.populate_if_different_thread_or_marked_dropped(false);
153 Pin::new(self.inner.as_mut().expect("inner element not present")).poll_flush(cx)
154 }
155
156 #[inline]
157 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
158 self.populate_if_different_thread_or_marked_dropped(false);
159 Pin::new(self.inner.as_mut().expect("inner element not present")).poll_shutdown(cx)
160 }
161
162 #[inline]
163 fn poll_write_vectored(
164 mut self: Pin<&mut Self>,
165 cx: &mut Context<'_>,
166 bufs: &[std::io::IoSlice<'_>],
167 ) -> Poll<Result<usize, std::io::Error>> {
168 self.populate_if_different_thread_or_marked_dropped(false);
169 Pin::new(self.inner.as_mut().expect("inner element not present")).poll_write_vectored(cx, bufs)
170 }
171
172 #[inline]
173 fn is_write_vectored(&self) -> bool {
174 if std::thread::current().id() != self.thread_id {
175 return self.is_write_vectored;
176 }
177 self
178 .inner
179 .as_ref()
180 .expect("inner element not present")
181 .is_write_vectored()
182 }
183}
184
#[cfg(unix)]
impl AsRawFd for SendTcpStreamPoll {
  /// Returns the raw fd captured at construction; rebuilds on other threads
  /// reuse this same fd, so the value stays current.
  #[inline]
  fn as_raw_fd(&self) -> RawFd {
    self.inner_fd
  }
}
192
#[cfg(unix)]
impl AsFd for SendTcpStreamPoll {
  #[inline]
  fn as_fd(&self) -> BorrowedFd<'_> {
    // SAFETY: `inner_fd` is the handle of the stream owned by `self`, so it
    // remains open at least for the lifetime of the returned borrow.
    unsafe { BorrowedFd::borrow_raw(self.inner_fd) }
  }
}
201
#[cfg(windows)]
impl AsRawSocket for SendTcpStreamPoll {
  /// Returns the raw socket captured at construction; rebuilds on other
  /// threads reuse this same socket, so the value stays current.
  #[inline]
  fn as_raw_socket(&self) -> RawSocket {
    self.inner_socket
  }
}
209
#[cfg(windows)]
impl AsSocket for SendTcpStreamPoll {
  #[inline]
  fn as_socket(&self) -> BorrowedSocket<'_> {
    // SAFETY: `inner_socket` is the handle of the stream owned by `self`, so
    // it remains open at least for the lifetime of the returned borrow.
    unsafe { BorrowedSocket::borrow_raw(self.inner_socket) }
  }
}
218
impl Drop for SendTcpStreamPoll {
  // Coordinates with the drop guard via `marked_dropped`: exactly one of the
  // two ever lets a copy of the inner stream run its destructor.
  fn drop(&mut self) {
    if !self.marked_dropped.swap(true, Ordering::Relaxed) {
      // We drop first: rebuild on the current thread if needed (presumably so
      // the stream is released against the correct thread-local driver —
      // confirm), then let `inner` drop normally and release the socket.
      self.populate_if_different_thread_or_marked_dropped(true);
    } else {
      // The guard already released the socket; leak our copy of the stream so
      // its destructor cannot close the handle a second time.
      let _ = ManuallyDrop::new(self.inner.take());
    }
  }
}
228
// SAFETY(review): monoio's poll streams are not `Send`; this impl relies on
// `populate_if_different_thread_or_marked_dropped` rebuilding the stream from
// its raw handle before any use on a new thread, and on the drop-guard
// protocol preventing double-release — confirm this covers all access paths.
unsafe impl Send for SendTcpStreamPoll {}
231
/// Guard returned by [`SendTcpStreamPoll::get_drop_guard`]; holds a bitwise
/// copy of the wrapper's inner stream and releases the underlying socket on
/// drop unless the wrapper was dropped first.
pub struct SendTcpStreamPollDropGuard {
  // Bitwise copy of the wrapper's stream, wrapped in `ManuallyDrop` so it is
  // only destroyed deliberately in this guard's Drop impl.
  inner: Option<ManuallyDrop<TcpStreamPoll>>,
  // Flag shared with the originating `SendTcpStreamPoll` (first dropper wins).
  marked_dropped: Arc<AtomicBool>,
}
237
impl Drop for SendTcpStreamPollDropGuard {
  fn drop(&mut self) {
    if let Some(inner) = self.inner.take() {
      // Only act if the wrapper has not already released the socket.
      if !self.marked_dropped.swap(true, Ordering::Relaxed) {
        // Convert back to completion mode (presumably to deregister from the
        // poll driver — confirm) before discarding the handle.
        // NOTE(review): `expect` here panics inside Drop, which aborts if we
        // are already unwinding — confirm `try_into_comp_io` cannot fail in
        // practice.
        let inner_comp_io = ManuallyDrop::into_inner(inner)
          .try_into_comp_io()
          .expect("failed to convert inner TcpStreamPoll to comp_io");

        // Extract the raw handle without closing it; the wrapper may revive
        // the stream from this same handle later (see
        // `populate_if_different_thread_or_marked_dropped`).
        #[cfg(unix)]
        let _ = inner_comp_io.into_raw_fd();
        #[cfg(windows)]
        let _ = inner_comp_io.into_raw_socket();
      }
    }
  }
}