monoio/net/tcp/
stream_poll.rs

//! This module provides a poll-io style interface for TcpStream.
2
3use std::{io, net::SocketAddr, time::Duration};
4
5#[cfg(unix)]
6use {
7    libc::{shutdown, SHUT_WR},
8    std::os::fd::AsRawFd,
9};
10#[cfg(windows)]
11use {
12    std::os::windows::io::AsRawSocket,
13    windows_sys::Win32::Networking::WinSock::{shutdown, SD_SEND as SHUT_WR},
14};
15
16use super::TcpStream;
17use crate::driver::op::Op;
18
/// A [`TcpStream`] wrapper with a poll-io style interface.
///
/// Wrapping a stream in this type lets it be used in a poll-like way: this
/// file implements `tokio::io::AsyncRead` and `tokio::io::AsyncWrite` for it.
/// Under the hood, it is based on a uring-based epoll.
#[derive(Debug)]
pub struct TcpStreamPoll(TcpStream);
24
25impl crate::io::IntoPollIo for TcpStream {
26    type PollIo = TcpStreamPoll;
27
28    #[inline]
29    fn try_into_poll_io(self) -> Result<Self::PollIo, (std::io::Error, Self)> {
30        self.try_into_poll_io()
31    }
32}
33
34impl TcpStream {
35    /// Convert to poll-io style TcpStreamPoll
36    #[inline]
37    pub fn try_into_poll_io(mut self) -> Result<TcpStreamPoll, (io::Error, TcpStream)> {
38        match self.fd.cvt_poll() {
39            Ok(_) => Ok(TcpStreamPoll(self)),
40            Err(e) => Err((e, self)),
41        }
42    }
43}
44
45impl crate::io::IntoCompIo for TcpStreamPoll {
46    type CompIo = TcpStream;
47
48    #[inline]
49    fn try_into_comp_io(self) -> Result<Self::CompIo, (std::io::Error, Self)> {
50        self.try_into_comp_io()
51    }
52}
53
54impl TcpStreamPoll {
55    /// Convert to normal TcpStream
56    #[inline]
57    pub fn try_into_comp_io(mut self) -> Result<TcpStream, (io::Error, TcpStreamPoll)> {
58        match self.0.fd.cvt_comp() {
59            Ok(_) => Ok(self.0),
60            Err(e) => Err((e, self)),
61        }
62    }
63}
64
65impl tokio::io::AsyncRead for TcpStreamPoll {
66    #[inline]
67    fn poll_read(
68        self: std::pin::Pin<&mut Self>,
69        cx: &mut std::task::Context<'_>,
70        buf: &mut tokio::io::ReadBuf<'_>,
71    ) -> std::task::Poll<io::Result<()>> {
72        unsafe {
73            let slice = buf.unfilled_mut();
74            let raw_buf = crate::buf::RawBuf::new(slice.as_ptr() as *const u8, slice.len());
75            let mut recv = Op::recv_raw(&self.0.fd, raw_buf);
76            let ret = ready!(crate::driver::op::PollLegacy::poll_io(&mut recv, cx));
77
78            std::task::Poll::Ready(ret.result.map(|n| {
79                let n = n.into_inner();
80                buf.assume_init(n as usize);
81                buf.advance(n as usize);
82            }))
83        }
84    }
85}
86
87impl tokio::io::AsyncWrite for TcpStreamPoll {
88    #[inline]
89    fn poll_write(
90        self: std::pin::Pin<&mut Self>,
91        cx: &mut std::task::Context<'_>,
92        buf: &[u8],
93    ) -> std::task::Poll<Result<usize, io::Error>> {
94        unsafe {
95            let raw_buf = crate::buf::RawBuf::new(buf.as_ptr(), buf.len());
96            let mut send = Op::send_raw(&self.0.fd, raw_buf);
97            let ret = ready!(crate::driver::op::PollLegacy::poll_io(&mut send, cx));
98
99            std::task::Poll::Ready(ret.result.map(|n| n.into_inner() as usize))
100        }
101    }
102
103    #[inline]
104    fn poll_flush(
105        self: std::pin::Pin<&mut Self>,
106        _cx: &mut std::task::Context<'_>,
107    ) -> std::task::Poll<Result<(), io::Error>> {
108        std::task::Poll::Ready(Ok(()))
109    }
110
111    #[inline]
112    fn poll_shutdown(
113        self: std::pin::Pin<&mut Self>,
114        _cx: &mut std::task::Context<'_>,
115    ) -> std::task::Poll<Result<(), io::Error>> {
116        #[cfg(unix)]
117        let fd = self.0.as_raw_fd();
118        #[cfg(windows)]
119        let fd = self.0.as_raw_socket() as _;
120        let res = match unsafe { shutdown(fd, SHUT_WR) } {
121            -1 => Err(io::Error::last_os_error()),
122            _ => Ok(()),
123        };
124        std::task::Poll::Ready(res)
125    }
126
127    #[inline]
128    fn poll_write_vectored(
129        self: std::pin::Pin<&mut Self>,
130        cx: &mut std::task::Context<'_>,
131        bufs: &[std::io::IoSlice<'_>],
132    ) -> std::task::Poll<Result<usize, io::Error>> {
133        unsafe {
134            let raw_buf = crate::buf::RawBufVectored::new(bufs.as_ptr() as _, bufs.len());
135            let mut writev = Op::writev_raw(&self.0.fd, raw_buf);
136            let ret = ready!(crate::driver::op::PollLegacy::poll_io(&mut writev, cx));
137
138            std::task::Poll::Ready(ret.result.map(|n| n.into_inner() as usize))
139        }
140    }
141
142    #[inline]
143    fn is_write_vectored(&self) -> bool {
144        true
145    }
146}
147
148impl TcpStreamPoll {
149    /// Return the local address that this stream is bound to.
150    #[inline]
151    pub fn local_addr(&self) -> io::Result<SocketAddr> {
152        self.0.local_addr()
153    }
154
155    /// Return the remote address that this stream is connected to.
156    #[inline]
157    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
158        self.0.peer_addr()
159    }
160
161    /// Get the value of the `TCP_NODELAY` option on this socket.
162    #[inline]
163    pub fn nodelay(&self) -> io::Result<bool> {
164        self.0.nodelay()
165    }
166
167    /// Set the value of the `TCP_NODELAY` option on this socket.
168    #[inline]
169    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
170        self.0.set_nodelay(nodelay)
171    }
172
173    /// Set the value of the `SO_KEEPALIVE` option on this socket.
174    #[inline]
175    pub fn set_tcp_keepalive(
176        &self,
177        time: Option<Duration>,
178        interval: Option<Duration>,
179        retries: Option<u32>,
180    ) -> io::Result<()> {
181        self.0.set_tcp_keepalive(time, interval, retries)
182    }
183}
184
185#[cfg(unix)]
186impl AsRawFd for TcpStreamPoll {
187    #[inline]
188    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
189        self.0.as_raw_fd()
190    }
191}
192
#[cfg(windows)]
impl AsRawSocket for TcpStreamPoll {
    /// Returns the raw socket handle of the wrapped `TcpStream`.
    // `#[inline]` matches the Unix `AsRawFd` impl and every other thin
    // delegation in this file.
    #[inline]
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        self.0.as_raw_socket()
    }
}