monoio/net/unix/
stream_poll.rs

//! This module provides a poll-io style interface for UnixStream.
2
3use std::{io, os::fd::AsRawFd};
4
5use super::{SocketAddr, UnixStream};
6use crate::driver::op::Op;
7
8/// A UnixStream with poll-io style interface.
9/// Using this struct, you can use UnixStream in a poll-like way.
10/// Underlying, it is based on a uring-based epoll.
11#[derive(Debug)]
12pub struct UnixStreamPoll(UnixStream);
13
14impl crate::io::IntoPollIo for UnixStream {
15    type PollIo = UnixStreamPoll;
16
17    #[inline]
18    fn try_into_poll_io(self) -> Result<Self::PollIo, (std::io::Error, Self)> {
19        self.try_into_poll_io()
20    }
21}
22
23impl UnixStream {
24    /// Convert to poll-io style UnixStreamPoll
25    #[inline]
26    pub fn try_into_poll_io(mut self) -> Result<UnixStreamPoll, (io::Error, UnixStream)> {
27        match self.fd.cvt_poll() {
28            Ok(_) => Ok(UnixStreamPoll(self)),
29            Err(e) => Err((e, self)),
30        }
31    }
32}
33
34impl crate::io::IntoCompIo for UnixStreamPoll {
35    type CompIo = UnixStream;
36
37    #[inline]
38    fn try_into_comp_io(self) -> Result<Self::CompIo, (std::io::Error, Self)> {
39        self.try_into_comp_io()
40    }
41}
42
43impl UnixStreamPoll {
44    /// Convert to normal UnixStream
45    #[inline]
46    pub fn try_into_comp_io(mut self) -> Result<UnixStream, (io::Error, UnixStreamPoll)> {
47        match self.0.fd.cvt_comp() {
48            Ok(_) => Ok(self.0),
49            Err(e) => Err((e, self)),
50        }
51    }
52}
53
54impl tokio::io::AsyncRead for UnixStreamPoll {
55    #[inline]
56    fn poll_read(
57        self: std::pin::Pin<&mut Self>,
58        cx: &mut std::task::Context<'_>,
59        buf: &mut tokio::io::ReadBuf<'_>,
60    ) -> std::task::Poll<io::Result<()>> {
61        unsafe {
62            let slice = buf.unfilled_mut();
63            let raw_buf = crate::buf::RawBuf::new(slice.as_ptr() as *const u8, slice.len());
64            let mut recv = Op::recv_raw(&self.0.fd, raw_buf);
65            let ret = ready!(crate::driver::op::PollLegacy::poll_io(&mut recv, cx));
66
67            std::task::Poll::Ready(ret.result.map(|n| {
68                let n = n.into_inner();
69                buf.assume_init(n as usize);
70                buf.advance(n as usize);
71            }))
72        }
73    }
74}
75
76impl tokio::io::AsyncWrite for UnixStreamPoll {
77    #[inline]
78    fn poll_write(
79        self: std::pin::Pin<&mut Self>,
80        cx: &mut std::task::Context<'_>,
81        buf: &[u8],
82    ) -> std::task::Poll<Result<usize, io::Error>> {
83        unsafe {
84            let raw_buf = crate::buf::RawBuf::new(buf.as_ptr(), buf.len());
85            let mut send = Op::send_raw(&self.0.fd, raw_buf);
86            let ret = ready!(crate::driver::op::PollLegacy::poll_io(&mut send, cx));
87
88            std::task::Poll::Ready(ret.result.map(|n| n.into_inner() as usize))
89        }
90    }
91
92    #[inline]
93    fn poll_flush(
94        self: std::pin::Pin<&mut Self>,
95        _cx: &mut std::task::Context<'_>,
96    ) -> std::task::Poll<Result<(), io::Error>> {
97        std::task::Poll::Ready(Ok(()))
98    }
99
100    #[inline]
101    fn poll_shutdown(
102        self: std::pin::Pin<&mut Self>,
103        _cx: &mut std::task::Context<'_>,
104    ) -> std::task::Poll<Result<(), io::Error>> {
105        let fd = self.0.as_raw_fd();
106        let res = match unsafe { libc::shutdown(fd, libc::SHUT_WR) } {
107            -1 => Err(io::Error::last_os_error()),
108            _ => Ok(()),
109        };
110        std::task::Poll::Ready(res)
111    }
112
113    #[inline]
114    fn poll_write_vectored(
115        self: std::pin::Pin<&mut Self>,
116        cx: &mut std::task::Context<'_>,
117        bufs: &[std::io::IoSlice<'_>],
118    ) -> std::task::Poll<Result<usize, io::Error>> {
119        unsafe {
120            let raw_buf =
121                crate::buf::RawBufVectored::new(bufs.as_ptr() as *const libc::iovec, bufs.len());
122            let mut writev = Op::writev_raw(&self.0.fd, raw_buf);
123            let ret = ready!(crate::driver::op::PollLegacy::poll_io(&mut writev, cx));
124
125            std::task::Poll::Ready(ret.result.map(|n| n.into_inner() as usize))
126        }
127    }
128
129    #[inline]
130    fn is_write_vectored(&self) -> bool {
131        true
132    }
133}
134
135impl UnixStreamPoll {
136    /// Returns the socket address of the local half of this connection.
137    #[inline]
138    pub fn local_addr(&self) -> io::Result<SocketAddr> {
139        self.0.local_addr()
140    }
141
142    /// Returns the socket address of the remote half of this connection.
143    #[inline]
144    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
145        self.0.peer_addr()
146    }
147}