monoio/net/tcp/
stream.rs

1use std::{
2    cell::UnsafeCell,
3    future::Future,
4    io,
5    net::{SocketAddr, ToSocketAddrs},
6    time::Duration,
7};
8
9#[cfg(unix)]
10use {
11    libc::{shutdown, AF_INET, AF_INET6, SHUT_WR, SOCK_STREAM},
12    std::os::unix::prelude::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
13};
14#[cfg(windows)]
15use {
16    std::os::windows::prelude::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket},
17    windows_sys::Win32::Networking::WinSock::{
18        shutdown, AF_INET, AF_INET6, SD_SEND as SHUT_WR, SOCK_STREAM,
19    },
20};
21
22use crate::{
23    buf::{IoBuf, IoBufMut, IoVecBuf, IoVecBufMut},
24    driver::{op::Op, shared_fd::SharedFd},
25    io::{
26        as_fd::{AsReadFd, AsWriteFd, SharedFdWrapper},
27        operation_canceled, AsyncReadRent, AsyncWriteRent, CancelHandle, CancelableAsyncReadRent,
28        CancelableAsyncWriteRent, Split,
29    },
30    BufResult,
31};
32
/// Options that customize how a TCP connection is established.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct TcpConnectOpts {
    /// Whether TCP Fast Open should be attempted when connecting.
    pub tcp_fast_open: bool,
}

impl Default for TcpConnectOpts {
    /// Equivalent to [`TcpConnectOpts::new`].
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}

impl TcpConnectOpts {
    /// Create a `TcpConnectOpts` with every option at its default value
    /// (TCP Fast Open disabled).
    #[inline]
    pub const fn new() -> Self {
        Self {
            tcp_fast_open: false,
        }
    }

    /// Enable or disable TCP Fast Open.
    ///
    /// Note: This option only works for linux 4.1+ and macos/ios 9.0+.
    /// If it is enabled, the connection will be established on the first
    /// call to write.
    ///
    /// This is a `const fn`, so it can be chained after
    /// [`TcpConnectOpts::new`] when building options in a constant context.
    #[must_use]
    #[inline]
    pub const fn tcp_fast_open(mut self, fast_open: bool) -> Self {
        self.tcp_fast_open = fast_open;
        self
    }
}
/// An asynchronous TCP stream.
///
/// Instances are created with [`TcpStream::connect`],
/// [`TcpStream::connect_addr`], [`TcpStream::connect_addr_with_config`] or
/// [`TcpStream::from_std`].
pub struct TcpStream {
    // Shared descriptor that the read/write/poll operations are issued against.
    pub(super) fd: SharedFd,
    // Socket-level helper used for address queries and socket options.
    meta: StreamMeta,
}
74
/// TcpStream is safe to split into two parts.
// NOTE(review): the precise safety contract is defined by the `Split` trait,
// which is declared outside this file — confirm it there.
unsafe impl Split for TcpStream {}
77
78impl TcpStream {
79    pub(crate) fn from_shared_fd(fd: SharedFd) -> Self {
80        #[cfg(unix)]
81        let meta = StreamMeta::new(fd.raw_fd());
82        #[cfg(windows)]
83        let meta = StreamMeta::new(fd.raw_socket());
84        #[cfg(feature = "zero-copy")]
85        // enable SOCK_ZEROCOPY
86        meta.set_zero_copy();
87
88        Self { fd, meta }
89    }
90
91    /// Open a TCP connection to a remote host.
92    /// Note: This function may block the current thread while resolution is
93    /// performed.
94    // TODO(chihai): Fix it, maybe spawn_blocking like tokio.
95    pub async fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
96        let addrs = addr.to_socket_addrs()?;
97        let mut last_error = None;
98        // The logic is similar to what std::net::TcpStream::connect does
99        for addr in addrs {
100            match Self::connect_addr(addr).await {
101                Ok(stream) => return Ok(stream),
102                Err(err) => last_error = Some(err),
103            }
104        }
105        Err(last_error.unwrap_or_else(|| io::Error::other("empty address")))
106    }
107
108    /// Establish a connection to the specified `addr`.
109    pub async fn connect_addr(addr: SocketAddr) -> io::Result<Self> {
110        const DEFAULT_OPTS: TcpConnectOpts = TcpConnectOpts {
111            tcp_fast_open: false,
112        };
113        Self::connect_addr_with_config(addr, &DEFAULT_OPTS).await
114    }
115
116    /// Establish a connection to the specified `addr` with given config.
117    pub async fn connect_addr_with_config(
118        addr: SocketAddr,
119        opts: &TcpConnectOpts,
120    ) -> io::Result<Self> {
121        let domain = match addr {
122            SocketAddr::V4(_) => AF_INET,
123            SocketAddr::V6(_) => AF_INET6,
124        };
125        let socket = crate::net::new_socket(domain, SOCK_STREAM)?;
126        #[allow(unused_mut)]
127        let mut tfo = opts.tcp_fast_open;
128
129        if tfo {
130            #[cfg(any(target_os = "linux", target_os = "android"))]
131            super::tfo::try_set_tcp_fastopen_connect(&socket);
132            #[cfg(any(target_os = "ios", target_os = "macos"))]
133            // if we cannot set force tcp fastopen, we will not use it.
134            if super::tfo::set_tcp_fastopen_force_enable(&socket).is_err() {
135                tfo = false;
136            }
137        }
138        let completion = Op::connect(SharedFd::new::<false>(socket)?, addr, tfo)?.await;
139        completion.meta.result?;
140
141        let stream = TcpStream::from_shared_fd(completion.data.fd);
142        // wait write ready on epoll branch
143        if crate::driver::op::is_legacy() {
144            #[cfg(all(any(target_os = "ios", target_os = "macos"), feature = "legacy"))]
145            if !tfo {
146                stream.writable(true).await?;
147            } else {
148                // set writable as init state
149                crate::driver::CURRENT.with(|inner| match inner {
150                    crate::driver::Inner::Legacy(inner) => {
151                        let idx = stream.fd.registered_index().unwrap();
152                        if let Some(mut readiness) =
153                            unsafe { &mut *inner.get() }.io_dispatch.get(idx)
154                        {
155                            readiness.set_writable();
156                        }
157                    }
158                    #[allow(unreachable_patterns)]
159                    _ => unreachable!("should never happens"),
160                })
161            }
162            #[cfg(not(any(target_os = "ios", target_os = "macos")))]
163            stream.writable(true).await?;
164
165            // getsockopt libc::SO_ERROR
166            #[cfg(unix)]
167            let sys_socket = unsafe { std::net::TcpStream::from_raw_fd(stream.fd.raw_fd()) };
168            #[cfg(windows)]
169            let sys_socket =
170                unsafe { std::net::TcpStream::from_raw_socket(stream.fd.raw_socket()) };
171            let err = sys_socket.take_error();
172            #[cfg(unix)]
173            let _ = sys_socket.into_raw_fd();
174            #[cfg(windows)]
175            let _ = sys_socket.into_raw_socket();
176            if let Some(e) = err? {
177                return Err(e);
178            }
179        }
180        Ok(stream)
181    }
182
183    /// Return the local address that this stream is bound to.
184    #[inline]
185    pub fn local_addr(&self) -> io::Result<SocketAddr> {
186        self.meta.local_addr()
187    }
188
189    /// Return the remote address that this stream is connected to.
190    #[inline]
191    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
192        self.meta.peer_addr()
193    }
194
195    /// Get the value of the `TCP_NODELAY` option on this socket.
196    #[inline]
197    pub fn nodelay(&self) -> io::Result<bool> {
198        self.meta.no_delay()
199    }
200
201    /// Set the value of the `TCP_NODELAY` option on this socket.
202    #[inline]
203    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
204        self.meta.set_no_delay(nodelay)
205    }
206
207    /// Set the value of the `SO_KEEPALIVE` option on this socket.
208    #[inline]
209    pub fn set_tcp_keepalive(
210        &self,
211        time: Option<Duration>,
212        interval: Option<Duration>,
213        retries: Option<u32>,
214    ) -> io::Result<()> {
215        self.meta.set_tcp_keepalive(time, interval, retries)
216    }
217
218    /// Creates new `TcpStream` from a `std::net::TcpStream`.
219    pub fn from_std(stream: std::net::TcpStream) -> io::Result<Self> {
220        #[cfg(unix)]
221        let fd = stream.as_raw_fd();
222        #[cfg(windows)]
223        let fd = stream.as_raw_socket();
224        match SharedFd::new::<false>(fd) {
225            Ok(shared) => {
226                #[cfg(unix)]
227                let _ = stream.into_raw_fd();
228                #[cfg(windows)]
229                let _ = stream.into_raw_socket();
230                Ok(Self::from_shared_fd(shared))
231            }
232            Err(e) => Err(e),
233        }
234    }
235
236    /// Wait for read readiness.
237    /// Note: Do not use it before every io. It is different from other runtimes!
238    ///
239    /// Everytime call to this method may pay a syscall cost.
240    /// In uring impl, it will push a PollAdd op; in epoll impl, it will use use
241    /// inner readiness state; if !relaxed, it will call syscall poll after that.
242    ///
243    /// If relaxed, on legacy driver it may return false positive result.
244    /// If you want to do io by your own, you must maintain io readiness and wait
245    /// for io ready with relaxed=false.
246    pub async fn readable(&self, relaxed: bool) -> io::Result<()> {
247        let op = Op::poll_read(&self.fd, relaxed).unwrap();
248        op.wait().await
249    }
250
251    /// Wait for write readiness.
252    /// Note: Do not use it before every io. It is different from other runtimes!
253    ///
254    /// Everytime call to this method may pay a syscall cost.
255    /// In uring impl, it will push a PollAdd op; in epoll impl, it will use use
256    /// inner readiness state; if !relaxed, it will call syscall poll after that.
257    ///
258    /// If relaxed, on legacy driver it may return false positive result.
259    /// If you want to do io by your own, you must maintain io readiness and wait
260    /// for io ready with relaxed=false.
261    pub async fn writable(&self, relaxed: bool) -> io::Result<()> {
262        let op = Op::poll_write(&self.fd, relaxed).unwrap();
263        op.wait().await
264    }
265}
266
267impl AsReadFd for TcpStream {
268    #[inline]
269    fn as_reader_fd(&mut self) -> &SharedFdWrapper {
270        SharedFdWrapper::new(&self.fd)
271    }
272}
273
274impl AsWriteFd for TcpStream {
275    #[inline]
276    fn as_writer_fd(&mut self) -> &SharedFdWrapper {
277        SharedFdWrapper::new(&self.fd)
278    }
279}
280
281#[cfg(unix)]
282impl IntoRawFd for TcpStream {
283    #[inline]
284    fn into_raw_fd(self) -> RawFd {
285        self.fd
286            .try_unwrap()
287            .expect("unexpected multiple reference to rawfd")
288    }
289}
290#[cfg(unix)]
291impl AsRawFd for TcpStream {
292    #[inline]
293    fn as_raw_fd(&self) -> RawFd {
294        self.fd.raw_fd()
295    }
296}
297
#[cfg(windows)]
impl IntoRawSocket for TcpStream {
    /// Consume the stream and hand out its raw socket.
    ///
    /// # Panics
    ///
    /// Panics if the shared descriptor cannot be unwrapped.
    #[inline]
    fn into_raw_socket(self) -> RawSocket {
        let unwrapped = self.fd.try_unwrap();
        unwrapped.expect("unexpected multiple reference to rawfd")
    }
}
307
#[cfg(windows)]
impl AsRawSocket for TcpStream {
    /// Return the raw socket of the stream; the stream is not consumed.
    #[inline]
    fn as_raw_socket(&self) -> RawSocket {
        let shared = &self.fd;
        shared.raw_socket()
    }
}
315
316impl std::fmt::Debug for TcpStream {
317    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318        f.debug_struct("TcpStream").field("fd", &self.fd).finish()
319    }
320}
321
322impl AsyncWriteRent for TcpStream {
323    #[inline]
324    fn write<T: IoBuf>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
325        // Submit the write operation
326        let op = Op::send(self.fd.clone(), buf).unwrap();
327        op.result()
328    }
329
330    #[inline]
331    fn writev<T: IoVecBuf>(&mut self, buf_vec: T) -> impl Future<Output = BufResult<usize, T>> {
332        let op = Op::writev(self.fd.clone(), buf_vec).unwrap();
333        op.result()
334    }
335
336    #[inline]
337    fn flush(&mut self) -> impl Future<Output = std::io::Result<()>> {
338        // Tcp stream does not need flush.
339        std::future::ready(Ok(()))
340    }
341
342    fn shutdown(&mut self) -> impl Future<Output = std::io::Result<()>> {
343        // We could use shutdown op here, which requires kernel 5.11+.
344        // However, for simplicity, we just close the socket using direct syscall.
345        #[cfg(unix)]
346        let fd = self.as_raw_fd();
347        #[cfg(windows)]
348        let fd = self.as_raw_socket() as _;
349        let res = match unsafe { shutdown(fd, SHUT_WR) } {
350            -1 => Err(io::Error::last_os_error()),
351            _ => Ok(()),
352        };
353        std::future::ready(res)
354    }
355}
356
357impl CancelableAsyncWriteRent for TcpStream {
358    #[inline]
359    async fn cancelable_write<T: IoBuf>(
360        &mut self,
361        buf: T,
362        c: CancelHandle,
363    ) -> crate::BufResult<usize, T> {
364        let fd = self.fd.clone();
365
366        if c.canceled() {
367            return (Err(operation_canceled()), buf);
368        }
369
370        let op = Op::send(fd, buf).unwrap();
371        let _guard = c.associate_op(op.op_canceller());
372        op.result().await
373    }
374
375    #[inline]
376    async fn cancelable_writev<T: IoVecBuf>(
377        &mut self,
378        buf_vec: T,
379        c: CancelHandle,
380    ) -> crate::BufResult<usize, T> {
381        let fd = self.fd.clone();
382
383        if c.canceled() {
384            return (Err(operation_canceled()), buf_vec);
385        }
386
387        let op = Op::writev(fd.clone(), buf_vec).unwrap();
388        let _guard = c.associate_op(op.op_canceller());
389        op.result().await
390    }
391
392    #[inline]
393    async fn cancelable_flush(&mut self, _c: CancelHandle) -> io::Result<()> {
394        // Tcp stream does not need flush.
395        Ok(())
396    }
397
398    fn cancelable_shutdown(&mut self, _c: CancelHandle) -> impl Future<Output = io::Result<()>> {
399        // We could use shutdown op here, which requires kernel 5.11+.
400        // However, for simplicity, we just close the socket using direct syscall.
401        #[cfg(unix)]
402        let fd = self.as_raw_fd();
403        #[cfg(windows)]
404        let fd = self.as_raw_socket() as _;
405        let res = match unsafe { shutdown(fd, SHUT_WR) } {
406            -1 => Err(io::Error::last_os_error()),
407            _ => Ok(()),
408        };
409        std::future::ready(res)
410    }
411}
412
413impl AsyncReadRent for TcpStream {
414    #[inline]
415    fn read<T: IoBufMut>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
416        // Submit the read operation
417        let op = Op::recv(self.fd.clone(), buf).unwrap();
418        op.result()
419    }
420
421    #[inline]
422    fn readv<T: IoVecBufMut>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
423        // Submit the read operation
424        let op = Op::readv(self.fd.clone(), buf).unwrap();
425        op.result()
426    }
427}
428
429impl CancelableAsyncReadRent for TcpStream {
430    #[inline]
431    async fn cancelable_read<T: IoBufMut>(
432        &mut self,
433        buf: T,
434        c: CancelHandle,
435    ) -> crate::BufResult<usize, T> {
436        let fd = self.fd.clone();
437
438        if c.canceled() {
439            return (Err(operation_canceled()), buf);
440        }
441
442        let op = Op::recv(fd, buf).unwrap();
443        let _guard = c.associate_op(op.op_canceller());
444        op.result().await
445    }
446
447    #[inline]
448    async fn cancelable_readv<T: IoVecBufMut>(
449        &mut self,
450        buf: T,
451        c: CancelHandle,
452    ) -> crate::BufResult<usize, T> {
453        let fd = self.fd.clone();
454
455        if c.canceled() {
456            return (Err(operation_canceled()), buf);
457        }
458
459        let op = Op::readv(fd, buf).unwrap();
460        let _guard = c.associate_op(op.op_canceller());
461        op.result().await
462    }
463}
464
#[cfg(all(unix, feature = "legacy", feature = "tokio-compat"))]
impl tokio::io::AsyncRead for TcpStream {
    /// Poll a receive into the unfilled part of `buf`.
    ///
    /// On success the received bytes are marked as initialized and the filled
    /// cursor of `buf` is advanced by the number of bytes received.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        // SAFETY: the raw buffer covers exactly the unfilled region of `buf`,
        // which stays borrowed for the duration of this call, and only the `n`
        // bytes reported by the receive are assumed initialized afterwards.
        unsafe {
            let slice = buf.unfilled_mut();
            let len = slice.len();
            // The receive writes into this memory, so the pointer has to come
            // from a mutable borrow (`as_mut_ptr`), not from `as_ptr`.
            let ptr = slice.as_mut_ptr().cast::<u8>() as *const u8;
            let raw_buf = crate::buf::RawBuf::new(ptr, len);
            let mut recv = Op::recv_raw(&self.fd, raw_buf);
            let ret = ready!(crate::driver::op::PollLegacy::poll_legacy(&mut recv, cx));

            std::task::Poll::Ready(ret.result.map(|n| {
                let n = n.into_inner();
                buf.assume_init(n as usize);
                buf.advance(n as usize);
            }))
        }
    }
}
486
#[cfg(all(unix, feature = "legacy", feature = "tokio-compat"))]
impl tokio::io::AsyncWrite for TcpStream {
    /// Poll a send of `buf`, yielding the number of bytes written.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, io::Error>> {
        unsafe {
            let (ptr, len) = (buf.as_ptr(), buf.len());
            let mut op = Op::send_raw(&self.fd, crate::buf::RawBuf::new(ptr, len));
            let done = ready!(crate::driver::op::PollLegacy::poll_legacy(&mut op, cx));

            let written = done.result.map(|n| n.into_inner() as usize);
            std::task::Poll::Ready(written)
        }
    }

    /// Always ready: a tcp stream has nothing to flush.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        let flushed: Result<(), io::Error> = Ok(());
        std::task::Poll::Ready(flushed)
    }

    /// Shut down the write direction of the socket with a direct syscall; the
    /// result is ready immediately.
    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        let raw = self.as_raw_fd();
        let rc = unsafe { libc::shutdown(raw, libc::SHUT_WR) };
        let outcome = if rc == -1 {
            Err(io::Error::last_os_error())
        } else {
            Ok(())
        };
        std::task::Poll::Ready(outcome)
    }

    /// Poll a vectored write of `bufs`, yielding the number of bytes written.
    fn poll_write_vectored(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> std::task::Poll<Result<usize, io::Error>> {
        unsafe {
            // The slice of `IoSlice` is handed over as an array of `iovec`.
            let iovecs = bufs.as_ptr() as *const libc::iovec;
            let raw_vec = crate::buf::RawBufVectored::new(iovecs, bufs.len());
            let mut op = Op::writev_raw(&self.fd, raw_vec);
            let done = ready!(crate::driver::op::PollLegacy::poll_legacy(&mut op, cx));

            let written = done.result.map(|n| n.into_inner() as usize);
            std::task::Poll::Ready(written)
        }
    }

    /// Vectored writes are supported via [`poll_write_vectored`](Self::poll_write_vectored).
    fn is_write_vectored(&self) -> bool {
        true
    }
}
541
/// Socket-level helper of a stream: address queries (with caching) and socket
/// options.
struct StreamMeta {
    /// Socket view of the stream's descriptor. It is `Some` from construction
    /// on; `Drop` takes it out and converts it back into a raw descriptor so
    /// that dropping this struct does not close the socket.
    socket: Option<socket2::Socket>,
    /// Cache of the local/peer addresses, filled on first successful lookup.
    meta: UnsafeCell<Meta>,
}
546
/// Cached addresses of a stream; each field is `None` until it has been looked
/// up successfully.
#[derive(Debug, Default, Clone)]
struct Meta {
    /// Cached result of the local address lookup.
    local_addr: Option<SocketAddr>,
    /// Cached result of the peer address lookup.
    peer_addr: Option<SocketAddr>,
}
552
553impl StreamMeta {
554    #[cfg(unix)]
555    fn new(fd: RawFd) -> Self {
556        Self {
557            socket: unsafe { Some(socket2::Socket::from_raw_fd(fd)) },
558            meta: Default::default(),
559        }
560    }
561
562    /// When operating files, we should use RawHandle;
563    /// When operating sockets, we should use RawSocket;
564    #[cfg(windows)]
565    fn new(fd: RawSocket) -> Self {
566        Self {
567            socket: unsafe { Some(socket2::Socket::from_raw_socket(fd)) },
568            meta: Default::default(),
569        }
570    }
571
572    fn local_addr(&self) -> io::Result<SocketAddr> {
573        let meta = unsafe { &mut *self.meta.get() };
574        if let Some(addr) = meta.local_addr {
575            return Ok(addr);
576        }
577
578        let ret = self
579            .socket
580            .as_ref()
581            .unwrap()
582            .local_addr()
583            .map(|addr| addr.as_socket().expect("tcp socket is expected"));
584        if let Ok(addr) = ret {
585            meta.local_addr = Some(addr);
586        }
587        ret
588    }
589
590    fn peer_addr(&self) -> io::Result<SocketAddr> {
591        let meta = unsafe { &mut *self.meta.get() };
592        if let Some(addr) = meta.peer_addr {
593            return Ok(addr);
594        }
595
596        let ret = self
597            .socket
598            .as_ref()
599            .unwrap()
600            .peer_addr()
601            .map(|addr| addr.as_socket().expect("tcp socket is expected"));
602        if let Ok(addr) = ret {
603            meta.peer_addr = Some(addr);
604        }
605        ret
606    }
607
608    fn no_delay(&self) -> io::Result<bool> {
609        self.socket.as_ref().unwrap().tcp_nodelay()
610    }
611
612    fn set_no_delay(&self, no_delay: bool) -> io::Result<()> {
613        self.socket.as_ref().unwrap().set_tcp_nodelay(no_delay)
614    }
615
616    #[allow(unused_variables)]
617    fn set_tcp_keepalive(
618        &self,
619        time: Option<Duration>,
620        interval: Option<Duration>,
621        retries: Option<u32>,
622    ) -> io::Result<()> {
623        let mut t = socket2::TcpKeepalive::new();
624        if let Some(time) = time {
625            t = t.with_time(time)
626        }
627        if let Some(interval) = interval {
628            t = t.with_interval(interval)
629        }
630        #[cfg(unix)]
631        if let Some(retries) = retries {
632            t = t.with_retries(retries)
633        }
634        self.socket.as_ref().unwrap().set_tcp_keepalive(&t)
635    }
636
637    #[cfg(feature = "zero-copy")]
638    fn set_zero_copy(&self) {
639        #[cfg(target_os = "linux")]
640        unsafe {
641            let fd = self.socket.as_ref().unwrap().as_raw_fd();
642            let v: libc::c_int = 1;
643            libc::setsockopt(
644                fd,
645                libc::SOL_SOCKET,
646                libc::SO_ZEROCOPY,
647                &v as *const _ as *const _,
648                std::mem::size_of::<libc::c_int>() as _,
649            );
650        }
651    }
652}
653
654impl Drop for StreamMeta {
655    fn drop(&mut self) {
656        let socket = self.socket.take().unwrap();
657        #[cfg(unix)]
658        let _ = socket.into_raw_fd();
659        #[cfg(windows)]
660        let _ = socket.into_raw_socket();
661    }
662}