monoio/net/tcp/
stream_poll.rs1use std::{io, net::SocketAddr, time::Duration};
4
5#[cfg(unix)]
6use {
7 libc::{shutdown, SHUT_WR},
8 std::os::fd::AsRawFd,
9};
10#[cfg(windows)]
11use {
12 std::os::windows::io::AsRawSocket,
13 windows_sys::Win32::Networking::WinSock::{shutdown, SD_SEND as SHUT_WR},
14};
15
16use super::TcpStream;
17use crate::driver::op::Op;
18
19#[derive(Debug)]
23pub struct TcpStreamPoll(TcpStream);
24
25impl crate::io::IntoPollIo for TcpStream {
26 type PollIo = TcpStreamPoll;
27
28 #[inline]
29 fn try_into_poll_io(self) -> Result<Self::PollIo, (std::io::Error, Self)> {
30 self.try_into_poll_io()
31 }
32}
33
34impl TcpStream {
35 #[inline]
37 pub fn try_into_poll_io(mut self) -> Result<TcpStreamPoll, (io::Error, TcpStream)> {
38 match self.fd.cvt_poll() {
39 Ok(_) => Ok(TcpStreamPoll(self)),
40 Err(e) => Err((e, self)),
41 }
42 }
43}
44
45impl crate::io::IntoCompIo for TcpStreamPoll {
46 type CompIo = TcpStream;
47
48 #[inline]
49 fn try_into_comp_io(self) -> Result<Self::CompIo, (std::io::Error, Self)> {
50 self.try_into_comp_io()
51 }
52}
53
54impl TcpStreamPoll {
55 #[inline]
57 pub fn try_into_comp_io(mut self) -> Result<TcpStream, (io::Error, TcpStreamPoll)> {
58 match self.0.fd.cvt_comp() {
59 Ok(_) => Ok(self.0),
60 Err(e) => Err((e, self)),
61 }
62 }
63}
64
65impl tokio::io::AsyncRead for TcpStreamPoll {
66 #[inline]
67 fn poll_read(
68 self: std::pin::Pin<&mut Self>,
69 cx: &mut std::task::Context<'_>,
70 buf: &mut tokio::io::ReadBuf<'_>,
71 ) -> std::task::Poll<io::Result<()>> {
72 unsafe {
73 let slice = buf.unfilled_mut();
74 let raw_buf = crate::buf::RawBuf::new(slice.as_ptr() as *const u8, slice.len());
75 let mut recv = Op::recv_raw(&self.0.fd, raw_buf);
76 let ret = ready!(crate::driver::op::PollLegacy::poll_io(&mut recv, cx));
77
78 std::task::Poll::Ready(ret.result.map(|n| {
79 let n = n.into_inner();
80 buf.assume_init(n as usize);
81 buf.advance(n as usize);
82 }))
83 }
84 }
85}
86
87impl tokio::io::AsyncWrite for TcpStreamPoll {
88 #[inline]
89 fn poll_write(
90 self: std::pin::Pin<&mut Self>,
91 cx: &mut std::task::Context<'_>,
92 buf: &[u8],
93 ) -> std::task::Poll<Result<usize, io::Error>> {
94 unsafe {
95 let raw_buf = crate::buf::RawBuf::new(buf.as_ptr(), buf.len());
96 let mut send = Op::send_raw(&self.0.fd, raw_buf);
97 let ret = ready!(crate::driver::op::PollLegacy::poll_io(&mut send, cx));
98
99 std::task::Poll::Ready(ret.result.map(|n| n.into_inner() as usize))
100 }
101 }
102
103 #[inline]
104 fn poll_flush(
105 self: std::pin::Pin<&mut Self>,
106 _cx: &mut std::task::Context<'_>,
107 ) -> std::task::Poll<Result<(), io::Error>> {
108 std::task::Poll::Ready(Ok(()))
109 }
110
111 #[inline]
112 fn poll_shutdown(
113 self: std::pin::Pin<&mut Self>,
114 _cx: &mut std::task::Context<'_>,
115 ) -> std::task::Poll<Result<(), io::Error>> {
116 #[cfg(unix)]
117 let fd = self.0.as_raw_fd();
118 #[cfg(windows)]
119 let fd = self.0.as_raw_socket() as _;
120 let res = match unsafe { shutdown(fd, SHUT_WR) } {
121 -1 => Err(io::Error::last_os_error()),
122 _ => Ok(()),
123 };
124 std::task::Poll::Ready(res)
125 }
126
127 #[inline]
128 fn poll_write_vectored(
129 self: std::pin::Pin<&mut Self>,
130 cx: &mut std::task::Context<'_>,
131 bufs: &[std::io::IoSlice<'_>],
132 ) -> std::task::Poll<Result<usize, io::Error>> {
133 unsafe {
134 let raw_buf = crate::buf::RawBufVectored::new(bufs.as_ptr() as _, bufs.len());
135 let mut writev = Op::writev_raw(&self.0.fd, raw_buf);
136 let ret = ready!(crate::driver::op::PollLegacy::poll_io(&mut writev, cx));
137
138 std::task::Poll::Ready(ret.result.map(|n| n.into_inner() as usize))
139 }
140 }
141
142 #[inline]
143 fn is_write_vectored(&self) -> bool {
144 true
145 }
146}
147
148impl TcpStreamPoll {
149 #[inline]
151 pub fn local_addr(&self) -> io::Result<SocketAddr> {
152 self.0.local_addr()
153 }
154
155 #[inline]
157 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
158 self.0.peer_addr()
159 }
160
161 #[inline]
163 pub fn nodelay(&self) -> io::Result<bool> {
164 self.0.nodelay()
165 }
166
167 #[inline]
169 pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
170 self.0.set_nodelay(nodelay)
171 }
172
173 #[inline]
175 pub fn set_tcp_keepalive(
176 &self,
177 time: Option<Duration>,
178 interval: Option<Duration>,
179 retries: Option<u32>,
180 ) -> io::Result<()> {
181 self.0.set_tcp_keepalive(time, interval, retries)
182 }
183}
184
185#[cfg(unix)]
186impl AsRawFd for TcpStreamPoll {
187 #[inline]
188 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
189 self.0.as_raw_fd()
190 }
191}
192
#[cfg(windows)]
impl AsRawSocket for TcpStreamPoll {
    /// Exposes the raw socket handle of the wrapped `TcpStream`. Ownership is not
    /// transferred; the handle stays owned by `self`.
    // `#[inline]` matches the Unix `AsRawFd` counterpart and the other trivial
    // forwarders in this file, so the cross-crate call can be inlined.
    #[inline]
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        self.0.as_raw_socket()
    }
}
198}