monoio/net/unix/
stream.rs

1use std::{
2    future::Future,
3    io::{self},
4    os::unix::prelude::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
5    path::Path,
6};
7
8use super::{
9    socket_addr::{local_addr, pair, peer_addr, socket_addr, SocketAddr},
10    ucred::UCred,
11};
12use crate::{
13    buf::{IoBuf, IoBufMut, IoVecBuf, IoVecBufMut},
14    driver::{op::Op, shared_fd::SharedFd},
15    io::{
16        as_fd::{AsReadFd, AsWriteFd, SharedFdWrapper},
17        operation_canceled, AsyncReadRent, AsyncWriteRent, CancelHandle, CancelableAsyncReadRent,
18        CancelableAsyncWriteRent, Split,
19    },
20    net::new_socket,
21    BufResult,
22};
23
/// A Unix domain stream socket.
///
/// Wraps the [`SharedFd`] of the underlying socket; every operation on this
/// type is issued against that descriptor.
pub struct UnixStream {
    // Shared handle to the socket's file descriptor, visible to the parent
    // module.
    pub(super) fd: SharedFd,
}
28
/// `UnixStream` is safe to split into two parts, so it opts in to the unsafe
/// [`Split`] marker trait.
// NOTE(review): the invariants `Split` requires are declared outside this
// file; confirm them against the trait's own documentation.
unsafe impl Split for UnixStream {}
31
32impl UnixStream {
33    pub(crate) fn from_shared_fd(fd: SharedFd) -> Self {
34        Self { fd }
35    }
36
37    /// Connect UnixStream to a path.
38    pub async fn connect<P: AsRef<Path>>(path: P) -> io::Result<Self> {
39        let (addr, addr_len) = socket_addr(path.as_ref())?;
40        Self::inner_connect(addr, addr_len).await
41    }
42
43    /// Connects the socket to an address.
44    pub async fn connect_addr(addr: SocketAddr) -> io::Result<Self> {
45        let (addr, addr_len) = addr.into_parts();
46        Self::inner_connect(addr, addr_len).await
47    }
48
49    #[inline(always)]
50    async fn inner_connect(
51        sockaddr: libc::sockaddr_un,
52        socklen: libc::socklen_t,
53    ) -> io::Result<Self> {
54        let socket = new_socket(libc::AF_UNIX, libc::SOCK_STREAM)?;
55        let op = Op::connect_unix(SharedFd::new::<false>(socket)?, sockaddr, socklen)?;
56        let completion = op.await;
57        completion.meta.result?;
58
59        let stream = Self::from_shared_fd(completion.data.fd);
60        if crate::driver::op::is_legacy() {
61            stream.writable(true).await?;
62        }
63        // getsockopt
64        let sys_socket = unsafe { std::os::unix::net::UnixStream::from_raw_fd(stream.fd.raw_fd()) };
65        let err = sys_socket.take_error();
66        let _ = sys_socket.into_raw_fd();
67        if let Some(e) = err? {
68            return Err(e);
69        }
70        Ok(stream)
71    }
72
73    /// Creates an unnamed pair of connected sockets.
74    ///
75    /// Returns two `UnixStream`s which are connected to each other.
76    pub fn pair() -> io::Result<(Self, Self)> {
77        let (a, b) = pair(libc::SOCK_STREAM)?;
78        Ok((Self::from_std(a)?, Self::from_std(b)?))
79    }
80
81    /// Returns effective credentials of the process which called `connect` or
82    /// `pair`.
83    pub fn peer_cred(&self) -> io::Result<UCred> {
84        super::ucred::get_peer_cred(self)
85    }
86
87    /// Creates new `UnixStream` from a `std::os::unix::net::UnixStream`.
88    pub fn from_std(stream: std::os::unix::net::UnixStream) -> io::Result<Self> {
89        match SharedFd::new::<false>(stream.as_raw_fd()) {
90            Ok(shared) => {
91                let _ = stream.into_raw_fd();
92                Ok(Self::from_shared_fd(shared))
93            }
94            Err(e) => Err(e),
95        }
96    }
97
98    /// Returns the socket address of the local half of this connection.
99    pub fn local_addr(&self) -> io::Result<SocketAddr> {
100        local_addr(self.as_raw_fd())
101    }
102
103    /// Returns the socket address of the remote half of this connection.
104    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
105        peer_addr(self.as_raw_fd())
106    }
107
108    /// Wait for read readiness.
109    /// Note: Do not use it before every io. It is different from other runtimes!
110    ///
111    /// Everytime call to this method may pay a syscall cost.
112    /// In uring impl, it will push a PollAdd op; in epoll impl, it will use use
113    /// inner readiness state; if !relaxed, it will call syscall poll after that.
114    ///
115    /// If relaxed, on legacy driver it may return false positive result.
116    /// If you want to do io by your own, you must maintain io readiness and wait
117    /// for io ready with relaxed=false.
118    pub async fn readable(&self, relaxed: bool) -> io::Result<()> {
119        let op = Op::poll_read(&self.fd, relaxed).unwrap();
120        op.wait().await
121    }
122
123    /// Wait for write readiness.
124    /// Note: Do not use it before every io. It is different from other runtimes!
125    ///
126    /// Everytime call to this method may pay a syscall cost.
127    /// In uring impl, it will push a PollAdd op; in epoll impl, it will use use
128    /// inner readiness state; if !relaxed, it will call syscall poll after that.
129    ///
130    /// If relaxed, on legacy driver it may return false positive result.
131    /// If you want to do io by your own, you must maintain io readiness and wait
132    /// for io ready with relaxed=false.
133    pub async fn writable(&self, relaxed: bool) -> io::Result<()> {
134        let op = Op::poll_write(&self.fd, relaxed).unwrap();
135        op.wait().await
136    }
137}
138
139impl AsReadFd for UnixStream {
140    #[inline]
141    fn as_reader_fd(&mut self) -> &SharedFdWrapper {
142        SharedFdWrapper::new(&self.fd)
143    }
144}
145
146impl AsWriteFd for UnixStream {
147    #[inline]
148    fn as_writer_fd(&mut self) -> &SharedFdWrapper {
149        SharedFdWrapper::new(&self.fd)
150    }
151}
152
153impl IntoRawFd for UnixStream {
154    #[inline]
155    fn into_raw_fd(self) -> RawFd {
156        self.fd
157            .try_unwrap()
158            .expect("unexpected multiple reference to rawfd")
159    }
160}
161
162impl AsRawFd for UnixStream {
163    #[inline]
164    fn as_raw_fd(&self) -> RawFd {
165        self.fd.raw_fd()
166    }
167}
168
169impl std::fmt::Debug for UnixStream {
170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        f.debug_struct("UnixStream").field("fd", &self.fd).finish()
172    }
173}
174
175impl AsyncWriteRent for UnixStream {
176    #[inline]
177    fn write<T: IoBuf>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
178        // Submit the write operation
179        let op = Op::send(self.fd.clone(), buf).unwrap();
180        op.result()
181    }
182
183    #[inline]
184    fn writev<T: IoVecBuf>(&mut self, buf_vec: T) -> impl Future<Output = BufResult<usize, T>> {
185        let op = Op::writev(self.fd.clone(), buf_vec).unwrap();
186        op.result()
187    }
188
189    #[inline]
190    async fn flush(&mut self) -> std::io::Result<()> {
191        // Unix stream does not need flush.
192        Ok(())
193    }
194
195    fn shutdown(&mut self) -> impl Future<Output = std::io::Result<()>> {
196        // We could use shutdown op here, which requires kernel 5.11+.
197        // However, for simplicity, we just close the socket using direct syscall.
198        let fd = self.as_raw_fd();
199        async move {
200            match unsafe { libc::shutdown(fd, libc::SHUT_WR) } {
201                -1 => Err(io::Error::last_os_error()),
202                _ => Ok(()),
203            }
204        }
205    }
206}
207
208impl CancelableAsyncWriteRent for UnixStream {
209    #[inline]
210    async fn cancelable_write<T: IoBuf>(
211        &mut self,
212        buf: T,
213        c: CancelHandle,
214    ) -> crate::BufResult<usize, T> {
215        let fd = self.fd.clone();
216
217        if c.canceled() {
218            return (Err(operation_canceled()), buf);
219        }
220
221        let op = Op::send(fd, buf).unwrap();
222        let _guard = c.associate_op(op.op_canceller());
223        op.result().await
224    }
225
226    #[inline]
227    async fn cancelable_writev<T: IoVecBuf>(
228        &mut self,
229        buf_vec: T,
230        c: CancelHandle,
231    ) -> crate::BufResult<usize, T> {
232        let fd = self.fd.clone();
233
234        if c.canceled() {
235            return (Err(operation_canceled()), buf_vec);
236        }
237
238        let op = Op::writev(fd.clone(), buf_vec).unwrap();
239        let _guard = c.associate_op(op.op_canceller());
240        op.result().await
241    }
242
243    #[inline]
244    async fn cancelable_flush(&mut self, _c: CancelHandle) -> io::Result<()> {
245        // Unix stream does not need flush.
246        Ok(())
247    }
248
249    async fn cancelable_shutdown(&mut self, _c: CancelHandle) -> io::Result<()> {
250        // We could use shutdown op here, which requires kernel 5.11+.
251        // However, for simplicity, we just close the socket using direct syscall.
252        let fd = self.as_raw_fd();
253        match unsafe { libc::shutdown(fd, libc::SHUT_WR) } {
254            -1 => Err(io::Error::last_os_error()),
255            _ => Ok(()),
256        }
257    }
258}
259
260impl AsyncReadRent for UnixStream {
261    #[inline]
262    fn read<T: IoBufMut>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
263        // Submit the read operation
264        let op = Op::recv(self.fd.clone(), buf).unwrap();
265        op.result()
266    }
267
268    #[inline]
269    fn readv<T: IoVecBufMut>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
270        // Submit the read operation
271        let op = Op::readv(self.fd.clone(), buf).unwrap();
272        op.result()
273    }
274}
275
276impl CancelableAsyncReadRent for UnixStream {
277    #[inline]
278    async fn cancelable_read<T: IoBufMut>(
279        &mut self,
280        buf: T,
281        c: CancelHandle,
282    ) -> crate::BufResult<usize, T> {
283        let fd = self.fd.clone();
284
285        if c.canceled() {
286            return (Err(operation_canceled()), buf);
287        }
288
289        let op = Op::recv(fd, buf).unwrap();
290        let _guard = c.associate_op(op.op_canceller());
291        op.result().await
292    }
293
294    #[inline]
295    async fn cancelable_readv<T: IoVecBufMut>(
296        &mut self,
297        buf: T,
298        c: CancelHandle,
299    ) -> crate::BufResult<usize, T> {
300        let fd = self.fd.clone();
301
302        if c.canceled() {
303            return (Err(operation_canceled()), buf);
304        }
305
306        let op = Op::readv(fd, buf).unwrap();
307        let _guard = c.associate_op(op.op_canceller());
308        op.result().await
309    }
310}
311
#[cfg(all(unix, feature = "legacy", feature = "tokio-compat"))]
impl tokio::io::AsyncRead for UnixStream {
    /// Polls a raw recv op that targets the unfilled part of `buf`.
    ///
    /// When the op is ready and succeeded with `n` bytes, the first `n`
    /// unfilled bytes of `buf` are marked initialized and the filled cursor is
    /// advanced by `n`.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        // SAFETY: `raw_buf` covers exactly the unfilled region of `buf`, which
        // stays mutably borrowed for the whole call.
        // NOTE(review): marking `n` bytes as initialized relies on the recv op
        // reporting only bytes it actually wrote; confirm against `recv_raw`.
        unsafe {
            let slice = buf.unfilled_mut();
            // The receive writes into this memory, so the pointer has to come
            // from `as_mut_ptr`: a pointer obtained through `as_ptr` must not
            // be written through.
            let raw_buf = crate::buf::RawBuf::new(slice.as_mut_ptr() as *const u8, slice.len());
            let mut recv = Op::recv_raw(&self.fd, raw_buf);
            let ret = ready!(crate::driver::op::PollLegacy::poll_legacy(&mut recv, cx));

            std::task::Poll::Ready(ret.result.map(|n| {
                let n = n.into_inner();
                buf.assume_init(n as usize);
                buf.advance(n as usize);
            }))
        }
    }
}
333
#[cfg(all(unix, feature = "legacy", feature = "tokio-compat"))]
impl tokio::io::AsyncWrite for UnixStream {
    /// Polls a raw send op over `buf` and, once ready, reports the op's
    /// result converted to a byte count.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, io::Error>> {
        unsafe {
            let (ptr, len) = (buf.as_ptr(), buf.len());
            let mut send_op = Op::send_raw(&self.fd, crate::buf::RawBuf::new(ptr, len));
            let completion = ready!(crate::driver::op::PollLegacy::poll_legacy(&mut send_op, cx));

            let written = match completion.result {
                Ok(n) => Ok(n.into_inner() as usize),
                Err(e) => Err(e),
            };
            std::task::Poll::Ready(written)
        }
    }

    /// Always ready: a Unix stream has nothing to flush.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        let flushed = Ok(());
        std::task::Poll::Ready(flushed)
    }

    /// Shuts down the write half of the socket (`SHUT_WR`) with a direct
    /// syscall and reports the outcome immediately.
    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        let raw = self.as_raw_fd();
        let rc = unsafe { libc::shutdown(raw, libc::SHUT_WR) };
        std::task::Poll::Ready(if rc == -1 {
            Err(io::Error::last_os_error())
        } else {
            Ok(())
        })
    }
}
368}