monoio/net/tcp/
stream.rs

1use std::{
2    cell::UnsafeCell,
3    future::Future,
4    io,
5    net::{SocketAddr, ToSocketAddrs},
6    time::Duration,
7};
8
9#[cfg(unix)]
10use {
11    libc::{shutdown, AF_INET, AF_INET6, SHUT_WR, SOCK_STREAM},
12    std::os::unix::prelude::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
13};
14#[cfg(windows)]
15use {
16    std::os::windows::prelude::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket},
17    windows_sys::Win32::Networking::WinSock::{
18        shutdown, AF_INET, AF_INET6, SD_SEND as SHUT_WR, SOCK_STREAM,
19    },
20};
21
22use crate::{
23    buf::{IoBuf, IoBufMut, IoVecBuf, IoVecBufMut},
24    driver::{op::Op, shared_fd::SharedFd},
25    io::{
26        as_fd::{AsReadFd, AsWriteFd, SharedFdWrapper},
27        operation_canceled, AsyncReadRent, AsyncWriteRent, CancelHandle, CancelableAsyncReadRent,
28        CancelableAsyncWriteRent, Split,
29    },
30    BufResult,
31};
32
/// Options that customize how an outgoing TCP connection is established.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct TcpConnectOpts {
    /// Whether TCP Fast Open is requested for the connection.
    pub tcp_fast_open: bool,
}

impl Default for TcpConnectOpts {
    /// Equivalent to [`TcpConnectOpts::new`].
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}

impl TcpConnectOpts {
    /// Create a `TcpConnectOpts` with TCP Fast Open disabled.
    #[inline]
    pub const fn new() -> Self {
        Self {
            tcp_fast_open: false,
        }
    }

    /// Enable or disable TCP Fast Open.
    ///
    /// Note: This option only works for linux 4.1+
    /// and macos/ios 9.0+.
    /// If it is enabled, the connection will be
    /// established on the first call to write.
    ///
    /// Like [`TcpConnectOpts::new`], this is a `const fn`, so a complete set
    /// of options can be built in a constant context.
    #[must_use]
    #[inline]
    pub const fn tcp_fast_open(mut self, fast_open: bool) -> Self {
        self.tcp_fast_open = fast_open;
        self
    }
}
/// TcpStream
///
/// A TCP stream made of two parts: the shared descriptor that I/O operations
/// are submitted against, and a metadata helper built from that same raw
/// descriptor for address and socket-option queries.
pub struct TcpStream {
    /// Shared descriptor handed to the `Op`s created by this type.
    pub(super) fd: SharedFd,
    /// Socket-level helper (addresses, nodelay, keepalive) created from the
    /// raw descriptor of `fd`.
    meta: StreamMeta,
}
74
/// TcpStream is safe to split into two parts.
// NOTE(review): `Split` is an `unsafe` trait imported from `crate::io`; its exact
// safety contract is not visible from this file — confirm TcpStream upholds it.
unsafe impl Split for TcpStream {}
77
78impl TcpStream {
79    pub(crate) fn from_shared_fd(fd: SharedFd) -> Self {
80        #[cfg(unix)]
81        let meta = StreamMeta::new(fd.raw_fd());
82        #[cfg(windows)]
83        let meta = StreamMeta::new(fd.raw_socket());
84        #[cfg(feature = "zero-copy")]
85        // enable SOCK_ZEROCOPY
86        meta.set_zero_copy();
87
88        Self { fd, meta }
89    }
90
91    /// Open a TCP connection to a remote host.
92    /// Note: This function may block the current thread while resolution is
93    /// performed.
94    // TODO(chihai): Fix it, maybe spawn_blocking like tokio.
95    pub async fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
96        // TODO(chihai): loop for all addrs
97        let addr = addr
98            .to_socket_addrs()?
99            .next()
100            .ok_or_else(|| io::Error::other("empty address"))?;
101
102        Self::connect_addr(addr).await
103    }
104
105    /// Establish a connection to the specified `addr`.
106    pub async fn connect_addr(addr: SocketAddr) -> io::Result<Self> {
107        const DEFAULT_OPTS: TcpConnectOpts = TcpConnectOpts {
108            tcp_fast_open: false,
109        };
110        Self::connect_addr_with_config(addr, &DEFAULT_OPTS).await
111    }
112
    /// Establish a connection to the specified `addr` with given config.
    ///
    /// A new stream socket of the matching address family is created, TCP Fast
    /// Open is configured when `opts.tcp_fast_open` is set, and a connect
    /// operation is submitted. On the legacy driver the function additionally
    /// waits for the socket to become writable (or marks it writable directly
    /// on macos/ios with fast open) and then checks the socket's pending error.
    ///
    /// # Errors
    ///
    /// Returns an error if the socket cannot be created or registered, if the
    /// connect operation fails, if waiting for writability fails, or if the
    /// socket reports a pending error after the connect.
    pub async fn connect_addr_with_config(
        addr: SocketAddr,
        opts: &TcpConnectOpts,
    ) -> io::Result<Self> {
        // Pick the address family that matches the target address.
        let domain = match addr {
            SocketAddr::V4(_) => AF_INET,
            SocketAddr::V6(_) => AF_INET6,
        };
        let socket = crate::net::new_socket(domain, SOCK_STREAM)?;
        // `tfo` is only reassigned on macos/ios, hence the `unused_mut` allowance.
        #[allow(unused_mut)]
        let mut tfo = opts.tcp_fast_open;

        if tfo {
            // linux/android: the outcome of this call is not inspected here.
            #[cfg(any(target_os = "linux", target_os = "android"))]
            super::tfo::try_set_tcp_fastopen_connect(&socket);
            #[cfg(any(target_os = "ios", target_os = "macos"))]
            // If fast open cannot be force-enabled, fall back to a regular connect.
            if super::tfo::set_tcp_fastopen_force_enable(&socket).is_err() {
                tfo = false;
            }
        }
        // Register the socket and submit the connect; `tfo` tells the op whether
        // fast open is in effect.
        let completion = Op::connect(SharedFd::new::<false>(socket)?, addr, tfo)?.await;
        completion.meta.result?;

        let stream = TcpStream::from_shared_fd(completion.data.fd);
        // On the legacy (epoll) driver, wait until the socket is write-ready.
        if crate::driver::op::is_legacy() {
            #[cfg(all(any(target_os = "ios", target_os = "macos"), feature = "legacy"))]
            if !tfo {
                stream.writable(true).await?;
            } else {
                // With fast open, do not wait: mark the socket writable as its
                // initial readiness state instead.
                crate::driver::CURRENT.with(|inner| match inner {
                    crate::driver::Inner::Legacy(inner) => {
                        let idx = stream.fd.registered_index().unwrap();
                        if let Some(mut readiness) =
                            unsafe { &mut *inner.get() }.io_dispatch.get(idx)
                        {
                            readiness.set_writable();
                        }
                    }
                    // This branch only runs when `is_legacy()` returned true.
                    #[allow(unreachable_patterns)]
                    _ => unreachable!("should never happens"),
                })
            }
            #[cfg(not(any(target_os = "ios", target_os = "macos")))]
            stream.writable(true).await?;

            // Read the pending socket error (getsockopt SO_ERROR) through a
            // temporary `std::net::TcpStream` that aliases the same descriptor.
            #[cfg(unix)]
            let sys_socket = unsafe { std::net::TcpStream::from_raw_fd(stream.fd.raw_fd()) };
            #[cfg(windows)]
            let sys_socket =
                unsafe { std::net::TcpStream::from_raw_socket(stream.fd.raw_socket()) };
            let err = sys_socket.take_error();
            // Release the temporary wrapper without closing the descriptor,
            // which is still in use by `stream`.
            #[cfg(unix)]
            let _ = sys_socket.into_raw_fd();
            #[cfg(windows)]
            let _ = sys_socket.into_raw_socket();
            // `err?` covers a failure to query the error; `Some(e)` is the
            // connect failure itself.
            if let Some(e) = err? {
                return Err(e);
            }
        }
        Ok(stream)
    }
179
    /// Return the local address that this stream is bound to.
    ///
    /// The address is cached after the first successful lookup, so later calls
    /// return the cached value without querying the socket again.
    #[inline]
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.meta.local_addr()
    }
185
    /// Return the remote address that this stream is connected to.
    ///
    /// The address is cached after the first successful lookup, so later calls
    /// return the cached value without querying the socket again.
    #[inline]
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        self.meta.peer_addr()
    }
191
    /// Get the value of the `TCP_NODELAY` option on this socket.
    ///
    /// The value is read from the socket on every call; it is not cached.
    #[inline]
    pub fn nodelay(&self) -> io::Result<bool> {
        self.meta.no_delay()
    }
197
    /// Set the value of the `TCP_NODELAY` option on this socket.
    ///
    /// Returns the error reported by the socket if the option cannot be set.
    #[inline]
    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
        self.meta.set_no_delay(nodelay)
    }
203
    /// Set the value of the `SO_KEEPALIVE` option on this socket.
    ///
    /// Each argument that is `Some` is applied to the keepalive parameters;
    /// an argument that is `None` leaves that parameter at the
    /// `socket2::TcpKeepalive::new()` default.
    ///
    /// * `time` - passed to `TcpKeepalive::with_time`.
    /// * `interval` - passed to `TcpKeepalive::with_interval`.
    /// * `retries` - passed to `TcpKeepalive::with_retries`; only applied on
    ///   unix targets and ignored elsewhere.
    #[inline]
    pub fn set_tcp_keepalive(
        &self,
        time: Option<Duration>,
        interval: Option<Duration>,
        retries: Option<u32>,
    ) -> io::Result<()> {
        self.meta.set_tcp_keepalive(time, interval, retries)
    }
214
215    /// Creates new `TcpStream` from a `std::net::TcpStream`.
216    pub fn from_std(stream: std::net::TcpStream) -> io::Result<Self> {
217        #[cfg(unix)]
218        let fd = stream.as_raw_fd();
219        #[cfg(windows)]
220        let fd = stream.as_raw_socket();
221        match SharedFd::new::<false>(fd) {
222            Ok(shared) => {
223                #[cfg(unix)]
224                let _ = stream.into_raw_fd();
225                #[cfg(windows)]
226                let _ = stream.into_raw_socket();
227                Ok(Self::from_shared_fd(shared))
228            }
229            Err(e) => Err(e),
230        }
231    }
232
233    /// Wait for read readiness.
234    /// Note: Do not use it before every io. It is different from other runtimes!
235    ///
236    /// Everytime call to this method may pay a syscall cost.
237    /// In uring impl, it will push a PollAdd op; in epoll impl, it will use use
238    /// inner readiness state; if !relaxed, it will call syscall poll after that.
239    ///
240    /// If relaxed, on legacy driver it may return false positive result.
241    /// If you want to do io by your own, you must maintain io readiness and wait
242    /// for io ready with relaxed=false.
243    pub async fn readable(&self, relaxed: bool) -> io::Result<()> {
244        let op = Op::poll_read(&self.fd, relaxed).unwrap();
245        op.wait().await
246    }
247
248    /// Wait for write readiness.
249    /// Note: Do not use it before every io. It is different from other runtimes!
250    ///
251    /// Everytime call to this method may pay a syscall cost.
252    /// In uring impl, it will push a PollAdd op; in epoll impl, it will use use
253    /// inner readiness state; if !relaxed, it will call syscall poll after that.
254    ///
255    /// If relaxed, on legacy driver it may return false positive result.
256    /// If you want to do io by your own, you must maintain io readiness and wait
257    /// for io ready with relaxed=false.
258    pub async fn writable(&self, relaxed: bool) -> io::Result<()> {
259        let op = Op::poll_write(&self.fd, relaxed).unwrap();
260        op.wait().await
261    }
262}
263
impl AsReadFd for TcpStream {
    /// Expose the stream's shared descriptor, wrapped as a `SharedFdWrapper`,
    /// as the descriptor to read from.
    #[inline]
    fn as_reader_fd(&mut self) -> &SharedFdWrapper {
        SharedFdWrapper::new(&self.fd)
    }
}
270
impl AsWriteFd for TcpStream {
    /// Expose the stream's shared descriptor, wrapped as a `SharedFdWrapper`,
    /// as the descriptor to write to. It is the same descriptor used for reads.
    #[inline]
    fn as_writer_fd(&mut self) -> &SharedFdWrapper {
        SharedFdWrapper::new(&self.fd)
    }
}
277
#[cfg(unix)]
impl IntoRawFd for TcpStream {
    /// Consume the stream and return its raw file descriptor.
    ///
    /// # Panics
    ///
    /// Panics if `try_unwrap` on the shared descriptor fails.
    // NOTE(review): presumably that happens while other clones of the `SharedFd`
    // are still alive — confirm against `SharedFd::try_unwrap`.
    #[inline]
    fn into_raw_fd(self) -> RawFd {
        self.fd
            .try_unwrap()
            .expect("unexpected multiple reference to rawfd")
    }
}
#[cfg(unix)]
impl AsRawFd for TcpStream {
    /// Return the raw file descriptor of the stream without giving up
    /// ownership of it.
    #[inline]
    fn as_raw_fd(&self) -> RawFd {
        self.fd.raw_fd()
    }
}
294
#[cfg(windows)]
impl IntoRawSocket for TcpStream {
    /// Consume the stream and return its raw socket.
    ///
    /// # Panics
    ///
    /// Panics if `try_unwrap` on the shared descriptor fails.
    // NOTE(review): presumably that happens while other clones of the `SharedFd`
    // are still alive — confirm against `SharedFd::try_unwrap`.
    #[inline]
    fn into_raw_socket(self) -> RawSocket {
        self.fd
            .try_unwrap()
            .expect("unexpected multiple reference to rawfd")
    }
}
304
#[cfg(windows)]
impl AsRawSocket for TcpStream {
    /// Return the raw socket of the stream without giving up ownership of it.
    #[inline]
    fn as_raw_socket(&self) -> RawSocket {
        self.fd.raw_socket()
    }
}
312
313impl std::fmt::Debug for TcpStream {
314    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315        f.debug_struct("TcpStream").field("fd", &self.fd).finish()
316    }
317}
318
319impl AsyncWriteRent for TcpStream {
320    #[inline]
321    fn write<T: IoBuf>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
322        // Submit the write operation
323        let op = Op::send(self.fd.clone(), buf).unwrap();
324        op.result()
325    }
326
327    #[inline]
328    fn writev<T: IoVecBuf>(&mut self, buf_vec: T) -> impl Future<Output = BufResult<usize, T>> {
329        let op = Op::writev(self.fd.clone(), buf_vec).unwrap();
330        op.result()
331    }
332
333    #[inline]
334    fn flush(&mut self) -> impl Future<Output = std::io::Result<()>> {
335        // Tcp stream does not need flush.
336        std::future::ready(Ok(()))
337    }
338
339    fn shutdown(&mut self) -> impl Future<Output = std::io::Result<()>> {
340        // We could use shutdown op here, which requires kernel 5.11+.
341        // However, for simplicity, we just close the socket using direct syscall.
342        #[cfg(unix)]
343        let fd = self.as_raw_fd();
344        #[cfg(windows)]
345        let fd = self.as_raw_socket() as _;
346        let res = match unsafe { shutdown(fd, SHUT_WR) } {
347            -1 => Err(io::Error::last_os_error()),
348            _ => Ok(()),
349        };
350        std::future::ready(res)
351    }
352}
353
354impl CancelableAsyncWriteRent for TcpStream {
355    #[inline]
356    async fn cancelable_write<T: IoBuf>(
357        &mut self,
358        buf: T,
359        c: CancelHandle,
360    ) -> crate::BufResult<usize, T> {
361        let fd = self.fd.clone();
362
363        if c.canceled() {
364            return (Err(operation_canceled()), buf);
365        }
366
367        let op = Op::send(fd, buf).unwrap();
368        let _guard = c.associate_op(op.op_canceller());
369        op.result().await
370    }
371
372    #[inline]
373    async fn cancelable_writev<T: IoVecBuf>(
374        &mut self,
375        buf_vec: T,
376        c: CancelHandle,
377    ) -> crate::BufResult<usize, T> {
378        let fd = self.fd.clone();
379
380        if c.canceled() {
381            return (Err(operation_canceled()), buf_vec);
382        }
383
384        let op = Op::writev(fd.clone(), buf_vec).unwrap();
385        let _guard = c.associate_op(op.op_canceller());
386        op.result().await
387    }
388
389    #[inline]
390    async fn cancelable_flush(&mut self, _c: CancelHandle) -> io::Result<()> {
391        // Tcp stream does not need flush.
392        Ok(())
393    }
394
395    fn cancelable_shutdown(&mut self, _c: CancelHandle) -> impl Future<Output = io::Result<()>> {
396        // We could use shutdown op here, which requires kernel 5.11+.
397        // However, for simplicity, we just close the socket using direct syscall.
398        #[cfg(unix)]
399        let fd = self.as_raw_fd();
400        #[cfg(windows)]
401        let fd = self.as_raw_socket() as _;
402        let res = match unsafe { shutdown(fd, SHUT_WR) } {
403            -1 => Err(io::Error::last_os_error()),
404            _ => Ok(()),
405        };
406        std::future::ready(res)
407    }
408}
409
410impl AsyncReadRent for TcpStream {
411    #[inline]
412    fn read<T: IoBufMut>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
413        // Submit the read operation
414        let op = Op::recv(self.fd.clone(), buf).unwrap();
415        op.result()
416    }
417
418    #[inline]
419    fn readv<T: IoVecBufMut>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
420        // Submit the read operation
421        let op = Op::readv(self.fd.clone(), buf).unwrap();
422        op.result()
423    }
424}
425
426impl CancelableAsyncReadRent for TcpStream {
427    #[inline]
428    async fn cancelable_read<T: IoBufMut>(
429        &mut self,
430        buf: T,
431        c: CancelHandle,
432    ) -> crate::BufResult<usize, T> {
433        let fd = self.fd.clone();
434
435        if c.canceled() {
436            return (Err(operation_canceled()), buf);
437        }
438
439        let op = Op::recv(fd, buf).unwrap();
440        let _guard = c.associate_op(op.op_canceller());
441        op.result().await
442    }
443
444    #[inline]
445    async fn cancelable_readv<T: IoVecBufMut>(
446        &mut self,
447        buf: T,
448        c: CancelHandle,
449    ) -> crate::BufResult<usize, T> {
450        let fd = self.fd.clone();
451
452        if c.canceled() {
453            return (Err(operation_canceled()), buf);
454        }
455
456        let op = Op::readv(fd, buf).unwrap();
457        let _guard = c.associate_op(op.op_canceller());
458        op.result().await
459    }
460}
461
#[cfg(all(unix, feature = "legacy", feature = "tokio-compat"))]
impl tokio::io::AsyncRead for TcpStream {
    /// Poll a receive into the unfilled part of `buf`.
    ///
    /// When the receive completes successfully, the received bytes are marked
    /// as initialized and the filled region of `buf` is advanced by that count.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        unsafe {
            let slice = buf.unfilled_mut();
            // The receive writes into this memory, so the pointer has to come from
            // `as_mut_ptr`: a pointer obtained through `as_ptr` must never be written to.
            let raw_buf = crate::buf::RawBuf::new(slice.as_mut_ptr() as *const u8, slice.len());
            let mut recv = Op::recv_raw(&self.fd, raw_buf);
            let ret = ready!(crate::driver::op::PollLegacy::poll_legacy(&mut recv, cx));

            std::task::Poll::Ready(ret.result.map(|n| {
                let n = n.into_inner();
                // The first `n` unfilled bytes now hold received data.
                buf.assume_init(n as usize);
                buf.advance(n as usize);
            }))
        }
    }
}
483
#[cfg(all(unix, feature = "legacy", feature = "tokio-compat"))]
impl tokio::io::AsyncWrite for TcpStream {
    /// Poll a send of `buf` and report the resulting byte count.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, io::Error>> {
        unsafe {
            // Describe the caller's slice by raw pointer and length for the op.
            let raw_buf = crate::buf::RawBuf::new(buf.as_ptr(), buf.len());
            let mut send = Op::send_raw(&self.fd, raw_buf);
            // NOTE(review): `ready!` presumably returns `Poll::Pending` early when the
            // op is not finished — confirm against the macro definition.
            let ret = ready!(crate::driver::op::PollLegacy::poll_legacy(&mut send, cx));

            std::task::Poll::Ready(ret.result.map(|n| n.into_inner() as usize))
        }
    }

    /// Flushing is a no-op; always ready with `Ok(())`.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    /// Shut down the write side with a direct `shutdown(SHUT_WR)` syscall.
    ///
    /// The call is synchronous, so the result is always `Poll::Ready`.
    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        let fd = self.as_raw_fd();
        // -1 signals failure; the error is taken from the last OS error.
        let res = match unsafe { libc::shutdown(fd, libc::SHUT_WR) } {
            -1 => Err(io::Error::last_os_error()),
            _ => Ok(()),
        };
        std::task::Poll::Ready(res)
    }

    /// Poll a vectored write of `bufs` and report the resulting byte count.
    fn poll_write_vectored(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> std::task::Poll<Result<usize, io::Error>> {
        unsafe {
            // The `IoSlice` array is reinterpreted as an array of `libc::iovec`;
            // std documents `IoSlice` as ABI compatible with `iovec` on Unix.
            let raw_buf =
                crate::buf::RawBufVectored::new(bufs.as_ptr() as *const libc::iovec, bufs.len());
            let mut writev = Op::writev_raw(&self.fd, raw_buf);
            // NOTE(review): `ready!` presumably returns `Poll::Pending` early when the
            // op is not finished — confirm against the macro definition.
            let ret = ready!(crate::driver::op::PollLegacy::poll_legacy(&mut writev, cx));

            std::task::Poll::Ready(ret.result.map(|n| n.into_inner() as usize))
        }
    }

    /// Vectored writes are implemented by `poll_write_vectored`.
    fn is_write_vectored(&self) -> bool {
        true
    }
}
538
/// Socket-level helper for a stream: a `socket2::Socket` view of the stream's
/// raw descriptor plus a cache of its addresses.
struct StreamMeta {
    /// Socket created from the raw descriptor. It is `Some` from construction
    /// and only taken out in `Drop`, where it is released without being closed.
    socket: Option<socket2::Socket>,
    /// Cached addresses, updated through `&self` via the `UnsafeCell`.
    meta: UnsafeCell<Meta>,
}
543
/// Address cache for a stream; both entries start out as `None`.
#[derive(Debug, Default, Clone)]
struct Meta {
    /// Local address, filled in after the first successful lookup.
    local_addr: Option<SocketAddr>,
    /// Peer address, filled in after the first successful lookup.
    peer_addr: Option<SocketAddr>,
}
549
550impl StreamMeta {
551    #[cfg(unix)]
552    fn new(fd: RawFd) -> Self {
553        Self {
554            socket: unsafe { Some(socket2::Socket::from_raw_fd(fd)) },
555            meta: Default::default(),
556        }
557    }
558
559    /// When operating files, we should use RawHandle;
560    /// When operating sockets, we should use RawSocket;
561    #[cfg(windows)]
562    fn new(fd: RawSocket) -> Self {
563        Self {
564            socket: unsafe { Some(socket2::Socket::from_raw_socket(fd)) },
565            meta: Default::default(),
566        }
567    }
568
569    fn local_addr(&self) -> io::Result<SocketAddr> {
570        let meta = unsafe { &mut *self.meta.get() };
571        if let Some(addr) = meta.local_addr {
572            return Ok(addr);
573        }
574
575        let ret = self
576            .socket
577            .as_ref()
578            .unwrap()
579            .local_addr()
580            .map(|addr| addr.as_socket().expect("tcp socket is expected"));
581        if let Ok(addr) = ret {
582            meta.local_addr = Some(addr);
583        }
584        ret
585    }
586
587    fn peer_addr(&self) -> io::Result<SocketAddr> {
588        let meta = unsafe { &mut *self.meta.get() };
589        if let Some(addr) = meta.peer_addr {
590            return Ok(addr);
591        }
592
593        let ret = self
594            .socket
595            .as_ref()
596            .unwrap()
597            .peer_addr()
598            .map(|addr| addr.as_socket().expect("tcp socket is expected"));
599        if let Ok(addr) = ret {
600            meta.peer_addr = Some(addr);
601        }
602        ret
603    }
604
    /// Query the socket's nodelay setting; the value is not cached.
    fn no_delay(&self) -> io::Result<bool> {
        self.socket.as_ref().unwrap().nodelay()
    }
608
    /// Set the socket's nodelay setting to `no_delay`.
    fn set_no_delay(&self, no_delay: bool) -> io::Result<()> {
        self.socket.as_ref().unwrap().set_nodelay(no_delay)
    }
612
613    #[allow(unused_variables)]
614    fn set_tcp_keepalive(
615        &self,
616        time: Option<Duration>,
617        interval: Option<Duration>,
618        retries: Option<u32>,
619    ) -> io::Result<()> {
620        let mut t = socket2::TcpKeepalive::new();
621        if let Some(time) = time {
622            t = t.with_time(time)
623        }
624        if let Some(interval) = interval {
625            t = t.with_interval(interval)
626        }
627        #[cfg(unix)]
628        if let Some(retries) = retries {
629            t = t.with_retries(retries)
630        }
631        self.socket.as_ref().unwrap().set_tcp_keepalive(&t)
632    }
633
    /// Request the `SO_ZEROCOPY` socket option.
    ///
    /// Only does something on linux; on other targets the body is empty.
    /// The return value of `setsockopt` is ignored, so a failure to enable the
    /// option is silent.
    #[cfg(feature = "zero-copy")]
    fn set_zero_copy(&self) {
        #[cfg(target_os = "linux")]
        unsafe {
            let fd = self.socket.as_ref().unwrap().as_raw_fd();
            // Option value 1 turns the option on.
            let v: libc::c_int = 1;
            libc::setsockopt(
                fd,
                libc::SOL_SOCKET,
                libc::SO_ZEROCOPY,
                &v as *const _ as *const _,
                std::mem::size_of::<libc::c_int>() as _,
            );
        }
    }
649}
650
651impl Drop for StreamMeta {
652    fn drop(&mut self) {
653        let socket = self.socket.take().unwrap();
654        #[cfg(unix)]
655        let _ = socket.into_raw_fd();
656        #[cfg(windows)]
657        let _ = socket.into_raw_socket();
658    }
659}