monoio/net/tcp/listener.rs

1use std::{
2    cell::UnsafeCell,
3    io,
4    net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
5};
6
7#[cfg(unix)]
8use {
9    libc::{sockaddr_in, sockaddr_in6, AF_INET, AF_INET6},
10    std::os::unix::prelude::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
11};
12#[cfg(windows)]
13use {
14    std::os::windows::prelude::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket},
15    windows_sys::Win32::Networking::WinSock::{
16        AF_INET, AF_INET6, SOCKADDR_IN as sockaddr_in, SOCKADDR_IN6 as sockaddr_in6,
17    },
18};
19
20use super::stream::TcpStream;
21use crate::{
22    driver::{op::Op, shared_fd::SharedFd},
23    io::{stream::Stream, CancelHandle},
24    net::ListenerOpts,
25};
26
27/// TcpListener
28pub struct TcpListener {
29    fd: SharedFd,
30    sys_listener: Option<std::net::TcpListener>,
31    meta: UnsafeCell<ListenerMeta>,
32}
33
34impl TcpListener {
35    #[allow(unreachable_code, clippy::diverging_sub_expression, unused_variables)]
36    pub(crate) fn from_shared_fd(fd: SharedFd) -> Self {
37        #[cfg(unix)]
38        let sys_listener = unsafe { std::net::TcpListener::from_raw_fd(fd.raw_fd()) };
39        #[cfg(windows)]
40        let sys_listener = unsafe { std::net::TcpListener::from_raw_socket(fd.raw_socket()) };
41        Self {
42            fd,
43            sys_listener: Some(sys_listener),
44            meta: UnsafeCell::new(ListenerMeta::default()),
45        }
46    }
47
48    /// Bind to address with config
49    pub fn bind_with_config<A: ToSocketAddrs>(addr: A, opts: &ListenerOpts) -> io::Result<Self> {
50        let addr = addr
51            .to_socket_addrs()?
52            .next()
53            .ok_or_else(|| io::Error::other("empty address"))?;
54
55        let domain = if addr.is_ipv6() {
56            socket2::Domain::IPV6
57        } else {
58            socket2::Domain::IPV4
59        };
60        let sys_listener =
61            socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
62
63        #[cfg(feature = "legacy")]
64        Self::set_non_blocking(&sys_listener)?;
65
66        let addr = socket2::SockAddr::from(addr);
67        #[cfg(unix)]
68        if opts.reuse_port {
69            sys_listener.set_reuse_port(true)?;
70        }
71        if opts.reuse_addr {
72            sys_listener.set_reuse_address(true)?;
73        }
74        if let Some(send_buf_size) = opts.send_buf_size {
75            sys_listener.set_send_buffer_size(send_buf_size)?;
76        }
77        if let Some(recv_buf_size) = opts.recv_buf_size {
78            sys_listener.set_recv_buffer_size(recv_buf_size)?;
79        }
80        if opts.tcp_fast_open {
81            #[cfg(any(target_os = "linux", target_os = "android"))]
82            super::tfo::set_tcp_fastopen(&sys_listener, opts.backlog)?;
83            #[cfg(any(target_os = "ios", target_os = "macos"))]
84            let _ = super::tfo::set_tcp_fastopen_force_enable(&sys_listener);
85        }
86        sys_listener.bind(&addr)?;
87        sys_listener.listen(opts.backlog)?;
88
89        #[cfg(any(target_os = "ios", target_os = "macos"))]
90        if opts.tcp_fast_open {
91            super::tfo::set_tcp_fastopen(&sys_listener)?;
92        }
93
94        #[cfg(unix)]
95        let fd = sys_listener.into_raw_fd();
96
97        #[cfg(windows)]
98        let fd = sys_listener.into_raw_socket();
99
100        Ok(Self::from_shared_fd(SharedFd::new::<false>(fd)?))
101    }
102
103    /// Bind to address
104    pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
105        const DEFAULT_CFG: ListenerOpts = ListenerOpts::new();
106        Self::bind_with_config(addr, &DEFAULT_CFG)
107    }
108
109    /// Accept
110    pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
111        let op = Op::accept(&self.fd)?;
112
113        // Await the completion of the event
114        let completion = op.await;
115
116        // Convert fd
117        let fd = completion.meta.result?;
118
119        // Construct stream
120        let stream = TcpStream::from_shared_fd(SharedFd::new::<false>(fd.into_inner() as _)?);
121
122        // Construct SocketAddr
123        let storage = completion.data.addr.0.as_ptr();
124        let addr = unsafe {
125            match (*storage).ss_family as _ {
126                AF_INET => {
127                    // Safety: if the ss_family field is AF_INET then storage must be a sockaddr_in.
128                    let addr: &sockaddr_in = &*(storage as *const sockaddr_in);
129                    #[cfg(unix)]
130                    let ip = Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes());
131                    #[cfg(windows)]
132                    let ip = Ipv4Addr::from(addr.sin_addr.S_un.S_addr.to_ne_bytes());
133                    let port = u16::from_be(addr.sin_port);
134                    SocketAddr::V4(SocketAddrV4::new(ip, port))
135                }
136                AF_INET6 => {
137                    // Safety: if the ss_family field is AF_INET6 then storage must be a
138                    // sockaddr_in6.
139                    let addr: &sockaddr_in6 = &*(storage as *const sockaddr_in6);
140                    #[cfg(unix)]
141                    let ip = Ipv6Addr::from(addr.sin6_addr.s6_addr);
142                    #[cfg(windows)]
143                    let ip = Ipv6Addr::from(addr.sin6_addr.u.Byte);
144                    let port = u16::from_be(addr.sin6_port);
145                    #[cfg(unix)]
146                    let scope_id = addr.sin6_scope_id;
147                    #[cfg(windows)]
148                    let scope_id = addr.Anonymous.sin6_scope_id;
149                    SocketAddr::V6(SocketAddrV6::new(ip, port, addr.sin6_flowinfo, scope_id))
150                }
151                _ => {
152                    return Err(io::ErrorKind::InvalidInput.into());
153                }
154            }
155        };
156
157        Ok((stream, addr))
158    }
159
160    /// Cancelable accept
161    pub async fn cancelable_accept(&self, c: CancelHandle) -> io::Result<(TcpStream, SocketAddr)> {
162        use crate::io::operation_canceled;
163
164        if c.canceled() {
165            return Err(operation_canceled());
166        }
167        let op = Op::accept(&self.fd)?;
168        let _guard = c.associate_op(op.op_canceller());
169
170        // Await the completion of the event
171        let completion = op.await;
172
173        // Convert fd
174        let fd = completion.meta.result?;
175
176        // Construct stream
177        let stream = TcpStream::from_shared_fd(SharedFd::new::<false>(fd.into_inner() as _)?);
178
179        // Construct SocketAddr
180        let storage = completion.data.addr.0.as_ptr();
181        let addr = unsafe {
182            match (*storage).ss_family as _ {
183                AF_INET => {
184                    // Safety: if the ss_family field is AF_INET then storage must be a sockaddr_in.
185                    let addr: &sockaddr_in = &*(storage as *const sockaddr_in);
186                    #[cfg(unix)]
187                    let ip = Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes());
188                    #[cfg(windows)]
189                    let ip = Ipv4Addr::from(addr.sin_addr.S_un.S_addr.to_ne_bytes());
190                    let port = u16::from_be(addr.sin_port);
191                    SocketAddr::V4(SocketAddrV4::new(ip, port))
192                }
193                AF_INET6 => {
194                    // Safety: if the ss_family field is AF_INET6 then storage must be a
195                    // sockaddr_in6.
196                    let addr: &sockaddr_in6 = &*(storage as *const sockaddr_in6);
197                    #[cfg(unix)]
198                    let ip = Ipv6Addr::from(addr.sin6_addr.s6_addr);
199                    #[cfg(windows)]
200                    let ip = Ipv6Addr::from(addr.sin6_addr.u.Byte);
201                    let port = u16::from_be(addr.sin6_port);
202                    #[cfg(unix)]
203                    let scope_id = addr.sin6_scope_id;
204                    #[cfg(windows)]
205                    let scope_id = addr.Anonymous.sin6_scope_id;
206                    SocketAddr::V6(SocketAddrV6::new(ip, port, addr.sin6_flowinfo, scope_id))
207                }
208                _ => {
209                    return Err(io::ErrorKind::InvalidInput.into());
210                }
211            }
212        };
213
214        Ok((stream, addr))
215    }
216
217    /// Returns the local address that this listener is bound to.
218    pub fn local_addr(&self) -> io::Result<SocketAddr> {
219        let meta = self.meta.get();
220        if let Some(addr) = unsafe { &*meta }.local_addr {
221            return Ok(addr);
222        }
223        self.sys_listener
224            .as_ref()
225            .unwrap()
226            .local_addr()
227            .inspect(|&addr| {
228                unsafe { &mut *meta }.local_addr = Some(addr);
229            })
230    }
231
232    #[cfg(feature = "legacy")]
233    fn set_non_blocking(_socket: &socket2::Socket) -> io::Result<()> {
234        crate::driver::CURRENT.with(|x| match x {
235            // TODO: windows ioring support
236            #[cfg(all(target_os = "linux", feature = "iouring"))]
237            crate::driver::Inner::Uring(_) => Ok(()),
238            crate::driver::Inner::Legacy(_) => _socket.set_nonblocking(true),
239        })
240    }
241
242    /// Wait for read readiness.
243    /// Note: Do not use it before every io. It is different from other runtimes!
244    ///
245    /// Everytime call to this method may pay a syscall cost.
246    /// In uring impl, it will push a PollAdd op; in epoll impl, it will use use
247    /// inner readiness state; if !relaxed, it will call syscall poll after that.
248    ///
249    /// If relaxed, on legacy driver it may return false positive result.
250    /// If you want to do io by your own, you must maintain io readiness and wait
251    /// for io ready with relaxed=false.
252    pub async fn readable(&self, relaxed: bool) -> io::Result<()> {
253        let op = Op::poll_read(&self.fd, relaxed).unwrap();
254        op.wait().await
255    }
256
257    /// Creates new `TcpListener` from a `std::net::TcpListener`.
258    pub fn from_std(stdl: std::net::TcpListener) -> io::Result<Self> {
259        #[cfg(unix)]
260        let fd = stdl.as_raw_fd();
261        #[cfg(windows)]
262        let fd = stdl.as_raw_socket();
263        match SharedFd::new::<false>(fd) {
264            Ok(shared) => {
265                #[cfg(unix)]
266                let _ = stdl.into_raw_fd();
267                #[cfg(windows)]
268                let _ = stdl.into_raw_socket();
269                Ok(Self::from_shared_fd(shared))
270            }
271            Err(e) => Err(e),
272        }
273    }
274}
275
276impl Stream for TcpListener {
277    type Item = io::Result<(TcpStream, SocketAddr)>;
278
279    #[inline]
280    async fn next(&mut self) -> Option<Self::Item> {
281        Some(self.accept().await)
282    }
283}
284
285impl std::fmt::Debug for TcpListener {
286    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
287        f.debug_struct("TcpListener").field("fd", &self.fd).finish()
288    }
289}
290
291#[cfg(unix)]
292impl AsRawFd for TcpListener {
293    #[inline]
294    fn as_raw_fd(&self) -> RawFd {
295        self.fd.raw_fd()
296    }
297}
298
#[cfg(windows)]
impl AsRawSocket for TcpListener {
    /// Returns the raw handle of the listening socket without transferring ownership.
    #[inline]
    fn as_raw_socket(&self) -> RawSocket {
        let shared = &self.fd;
        shared.raw_socket()
    }
}
306
307impl Drop for TcpListener {
308    #[inline]
309    fn drop(&mut self) {
310        let listener = self.sys_listener.take().unwrap();
311        #[cfg(unix)]
312        let _ = listener.into_raw_fd();
313        #[cfg(windows)]
314        let _ = listener.into_raw_socket();
315    }
316}
317
/// Lazily-filled metadata cached by a `TcpListener`.
#[derive(Clone, Default, Debug)]
struct ListenerMeta {
    /// The bound local address, stored after the first successful lookup.
    local_addr: Option<SocketAddr>,
}