1use std::{
2 cell::UnsafeCell,
3 future::Future,
4 io,
5 net::{SocketAddr, ToSocketAddrs},
6 time::Duration,
7};
8
9#[cfg(unix)]
10use {
11 libc::{shutdown, AF_INET, AF_INET6, SHUT_WR, SOCK_STREAM},
12 std::os::unix::prelude::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
13};
14#[cfg(windows)]
15use {
16 std::os::windows::prelude::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket},
17 windows_sys::Win32::Networking::WinSock::{
18 shutdown, AF_INET, AF_INET6, SD_SEND as SHUT_WR, SOCK_STREAM,
19 },
20};
21
22use crate::{
23 buf::{IoBuf, IoBufMut, IoVecBuf, IoVecBufMut},
24 driver::{op::Op, shared_fd::SharedFd},
25 io::{
26 as_fd::{AsReadFd, AsWriteFd, SharedFdWrapper},
27 operation_canceled, AsyncReadRent, AsyncWriteRent, CancelHandle, CancelableAsyncReadRent,
28 CancelableAsyncWriteRent, Split,
29 },
30 BufResult,
31};
32
/// Options controlling how an outgoing TCP connection is established.
///
/// Construct with [`TcpConnectOpts::new`] / [`Default::default`] and adjust with the
/// builder-style setters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct TcpConnectOpts {
    /// Whether TCP Fast Open should be requested for the connection.
    pub tcp_fast_open: bool,
}

impl Default for TcpConnectOpts {
    /// Same as [`TcpConnectOpts::new`]: every option disabled.
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}

impl TcpConnectOpts {
    /// Creates options with every feature disabled (`tcp_fast_open == false`).
    ///
    /// This is a `const fn`, so it can be used to build compile-time defaults.
    #[inline]
    pub const fn new() -> Self {
        Self {
            tcp_fast_open: false,
        }
    }

    /// Enables or disables TCP Fast Open and returns the updated options.
    #[must_use]
    #[inline]
    pub fn tcp_fast_open(mut self, fast_open: bool) -> Self {
        self.tcp_fast_open = fast_open;
        self
    }
}
/// An asynchronous TCP stream.
///
/// `fd` holds the socket as a [`SharedFd`]; `meta` is a socket2 view over the same raw
/// descriptor, used for address and socket-option queries.
pub struct TcpStream {
    pub(super) fd: SharedFd,
    // Fields drop in declaration order, so `meta` is dropped after `fd`. `StreamMeta`'s Drop
    // releases its view with `into_raw_*` rather than closing the descriptor.
    meta: StreamMeta,
}
74
// NOTE(review): `Split` is an unsafe marker trait. This impl asserts that read and write halves
// of the stream may be driven independently; the read ops (`recv`/`readv`) and write ops
// (`send`/`writev`) in this file each take their own clone of the `SharedFd` rather than
// sharing mutable state. Confirm against the `Split` contract.
unsafe impl Split for TcpStream {}
77
78impl TcpStream {
79 pub(crate) fn from_shared_fd(fd: SharedFd) -> Self {
80 #[cfg(unix)]
81 let meta = StreamMeta::new(fd.raw_fd());
82 #[cfg(windows)]
83 let meta = StreamMeta::new(fd.raw_socket());
84 #[cfg(feature = "zero-copy")]
85 meta.set_zero_copy();
87
88 Self { fd, meta }
89 }
90
91 pub async fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
96 let addrs = addr.to_socket_addrs()?;
97 let mut last_error = None;
98 for addr in addrs {
100 match Self::connect_addr(addr).await {
101 Ok(stream) => return Ok(stream),
102 Err(err) => last_error = Some(err),
103 }
104 }
105 Err(last_error.unwrap_or_else(|| io::Error::other("empty address")))
106 }
107
108 pub async fn connect_addr(addr: SocketAddr) -> io::Result<Self> {
110 const DEFAULT_OPTS: TcpConnectOpts = TcpConnectOpts {
111 tcp_fast_open: false,
112 };
113 Self::connect_addr_with_config(addr, &DEFAULT_OPTS).await
114 }
115
116 pub async fn connect_addr_with_config(
118 addr: SocketAddr,
119 opts: &TcpConnectOpts,
120 ) -> io::Result<Self> {
121 let domain = match addr {
122 SocketAddr::V4(_) => AF_INET,
123 SocketAddr::V6(_) => AF_INET6,
124 };
125 let socket = crate::net::new_socket(domain, SOCK_STREAM)?;
126 #[allow(unused_mut)]
127 let mut tfo = opts.tcp_fast_open;
128
129 if tfo {
130 #[cfg(any(target_os = "linux", target_os = "android"))]
131 super::tfo::try_set_tcp_fastopen_connect(&socket);
132 #[cfg(any(target_os = "ios", target_os = "macos"))]
133 if super::tfo::set_tcp_fastopen_force_enable(&socket).is_err() {
135 tfo = false;
136 }
137 }
138 let completion = Op::connect(SharedFd::new::<false>(socket)?, addr, tfo)?.await;
139 completion.meta.result?;
140
141 let stream = TcpStream::from_shared_fd(completion.data.fd);
142 if crate::driver::op::is_legacy() {
144 #[cfg(all(any(target_os = "ios", target_os = "macos"), feature = "legacy"))]
145 if !tfo {
146 stream.writable(true).await?;
147 } else {
148 crate::driver::CURRENT.with(|inner| match inner {
150 crate::driver::Inner::Legacy(inner) => {
151 let idx = stream.fd.registered_index().unwrap();
152 if let Some(mut readiness) =
153 unsafe { &mut *inner.get() }.io_dispatch.get(idx)
154 {
155 readiness.set_writable();
156 }
157 }
158 #[allow(unreachable_patterns)]
159 _ => unreachable!("should never happens"),
160 })
161 }
162 #[cfg(not(any(target_os = "ios", target_os = "macos")))]
163 stream.writable(true).await?;
164
165 #[cfg(unix)]
167 let sys_socket = unsafe { std::net::TcpStream::from_raw_fd(stream.fd.raw_fd()) };
168 #[cfg(windows)]
169 let sys_socket =
170 unsafe { std::net::TcpStream::from_raw_socket(stream.fd.raw_socket()) };
171 let err = sys_socket.take_error();
172 #[cfg(unix)]
173 let _ = sys_socket.into_raw_fd();
174 #[cfg(windows)]
175 let _ = sys_socket.into_raw_socket();
176 if let Some(e) = err? {
177 return Err(e);
178 }
179 }
180 Ok(stream)
181 }
182
183 #[inline]
185 pub fn local_addr(&self) -> io::Result<SocketAddr> {
186 self.meta.local_addr()
187 }
188
189 #[inline]
191 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
192 self.meta.peer_addr()
193 }
194
195 #[inline]
197 pub fn nodelay(&self) -> io::Result<bool> {
198 self.meta.no_delay()
199 }
200
201 #[inline]
203 pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
204 self.meta.set_no_delay(nodelay)
205 }
206
207 #[inline]
209 pub fn set_tcp_keepalive(
210 &self,
211 time: Option<Duration>,
212 interval: Option<Duration>,
213 retries: Option<u32>,
214 ) -> io::Result<()> {
215 self.meta.set_tcp_keepalive(time, interval, retries)
216 }
217
218 pub fn from_std(stream: std::net::TcpStream) -> io::Result<Self> {
220 #[cfg(unix)]
221 let fd = stream.as_raw_fd();
222 #[cfg(windows)]
223 let fd = stream.as_raw_socket();
224 match SharedFd::new::<false>(fd) {
225 Ok(shared) => {
226 #[cfg(unix)]
227 let _ = stream.into_raw_fd();
228 #[cfg(windows)]
229 let _ = stream.into_raw_socket();
230 Ok(Self::from_shared_fd(shared))
231 }
232 Err(e) => Err(e),
233 }
234 }
235
236 pub async fn readable(&self, relaxed: bool) -> io::Result<()> {
247 let op = Op::poll_read(&self.fd, relaxed).unwrap();
248 op.wait().await
249 }
250
251 pub async fn writable(&self, relaxed: bool) -> io::Result<()> {
262 let op = Op::poll_write(&self.fd, relaxed).unwrap();
263 op.wait().await
264 }
265}
266
267impl AsReadFd for TcpStream {
268 #[inline]
269 fn as_reader_fd(&mut self) -> &SharedFdWrapper {
270 SharedFdWrapper::new(&self.fd)
271 }
272}
273
274impl AsWriteFd for TcpStream {
275 #[inline]
276 fn as_writer_fd(&mut self) -> &SharedFdWrapper {
277 SharedFdWrapper::new(&self.fd)
278 }
279}
280
281#[cfg(unix)]
282impl IntoRawFd for TcpStream {
283 #[inline]
284 fn into_raw_fd(self) -> RawFd {
285 self.fd
286 .try_unwrap()
287 .expect("unexpected multiple reference to rawfd")
288 }
289}
290#[cfg(unix)]
291impl AsRawFd for TcpStream {
292 #[inline]
293 fn as_raw_fd(&self) -> RawFd {
294 self.fd.raw_fd()
295 }
296}
297
#[cfg(windows)]
impl IntoRawSocket for TcpStream {
    /// Consumes the stream and hands back the raw socket.
    ///
    /// # Panics
    ///
    /// Panics if `SharedFd::try_unwrap` fails, i.e. other references to the fd still exist.
    #[inline]
    fn into_raw_socket(self) -> RawSocket {
        let shared = self.fd;
        shared
            .try_unwrap()
            .expect("unexpected multiple reference to rawfd")
    }
}
307
#[cfg(windows)]
impl AsRawSocket for TcpStream {
    /// Borrows the raw socket without transferring ownership.
    #[inline]
    fn as_raw_socket(&self) -> RawSocket {
        let shared = &self.fd;
        shared.raw_socket()
    }
}
315
316impl std::fmt::Debug for TcpStream {
317 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318 f.debug_struct("TcpStream").field("fd", &self.fd).finish()
319 }
320}
321
322impl AsyncWriteRent for TcpStream {
323 #[inline]
324 fn write<T: IoBuf>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
325 let op = Op::send(self.fd.clone(), buf).unwrap();
327 op.result()
328 }
329
330 #[inline]
331 fn writev<T: IoVecBuf>(&mut self, buf_vec: T) -> impl Future<Output = BufResult<usize, T>> {
332 let op = Op::writev(self.fd.clone(), buf_vec).unwrap();
333 op.result()
334 }
335
336 #[inline]
337 fn flush(&mut self) -> impl Future<Output = std::io::Result<()>> {
338 std::future::ready(Ok(()))
340 }
341
342 fn shutdown(&mut self) -> impl Future<Output = std::io::Result<()>> {
343 #[cfg(unix)]
346 let fd = self.as_raw_fd();
347 #[cfg(windows)]
348 let fd = self.as_raw_socket() as _;
349 let res = match unsafe { shutdown(fd, SHUT_WR) } {
350 -1 => Err(io::Error::last_os_error()),
351 _ => Ok(()),
352 };
353 std::future::ready(res)
354 }
355}
356
357impl CancelableAsyncWriteRent for TcpStream {
358 #[inline]
359 async fn cancelable_write<T: IoBuf>(
360 &mut self,
361 buf: T,
362 c: CancelHandle,
363 ) -> crate::BufResult<usize, T> {
364 let fd = self.fd.clone();
365
366 if c.canceled() {
367 return (Err(operation_canceled()), buf);
368 }
369
370 let op = Op::send(fd, buf).unwrap();
371 let _guard = c.associate_op(op.op_canceller());
372 op.result().await
373 }
374
375 #[inline]
376 async fn cancelable_writev<T: IoVecBuf>(
377 &mut self,
378 buf_vec: T,
379 c: CancelHandle,
380 ) -> crate::BufResult<usize, T> {
381 let fd = self.fd.clone();
382
383 if c.canceled() {
384 return (Err(operation_canceled()), buf_vec);
385 }
386
387 let op = Op::writev(fd.clone(), buf_vec).unwrap();
388 let _guard = c.associate_op(op.op_canceller());
389 op.result().await
390 }
391
392 #[inline]
393 async fn cancelable_flush(&mut self, _c: CancelHandle) -> io::Result<()> {
394 Ok(())
396 }
397
398 fn cancelable_shutdown(&mut self, _c: CancelHandle) -> impl Future<Output = io::Result<()>> {
399 #[cfg(unix)]
402 let fd = self.as_raw_fd();
403 #[cfg(windows)]
404 let fd = self.as_raw_socket() as _;
405 let res = match unsafe { shutdown(fd, SHUT_WR) } {
406 -1 => Err(io::Error::last_os_error()),
407 _ => Ok(()),
408 };
409 std::future::ready(res)
410 }
411}
412
413impl AsyncReadRent for TcpStream {
414 #[inline]
415 fn read<T: IoBufMut>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
416 let op = Op::recv(self.fd.clone(), buf).unwrap();
418 op.result()
419 }
420
421 #[inline]
422 fn readv<T: IoVecBufMut>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
423 let op = Op::readv(self.fd.clone(), buf).unwrap();
425 op.result()
426 }
427}
428
429impl CancelableAsyncReadRent for TcpStream {
430 #[inline]
431 async fn cancelable_read<T: IoBufMut>(
432 &mut self,
433 buf: T,
434 c: CancelHandle,
435 ) -> crate::BufResult<usize, T> {
436 let fd = self.fd.clone();
437
438 if c.canceled() {
439 return (Err(operation_canceled()), buf);
440 }
441
442 let op = Op::recv(fd, buf).unwrap();
443 let _guard = c.associate_op(op.op_canceller());
444 op.result().await
445 }
446
447 #[inline]
448 async fn cancelable_readv<T: IoVecBufMut>(
449 &mut self,
450 buf: T,
451 c: CancelHandle,
452 ) -> crate::BufResult<usize, T> {
453 let fd = self.fd.clone();
454
455 if c.canceled() {
456 return (Err(operation_canceled()), buf);
457 }
458
459 let op = Op::readv(fd, buf).unwrap();
460 let _guard = c.associate_op(op.op_canceller());
461 op.result().await
462 }
463}
464
#[cfg(all(unix, feature = "legacy", feature = "tokio-compat"))]
impl tokio::io::AsyncRead for TcpStream {
    /// Attempts one `recv` into the unfilled part of `buf` on the legacy driver.
    ///
    /// On success, `n` bytes are marked initialized and filled in `buf`. When `n == 0`,
    /// nothing is filled, which tokio callers interpret as end of stream.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        unsafe {
            let slice = buf.unfilled_mut();
            // The kernel writes into this region, so derive the pointer from the mutable
            // borrow (`as_mut_ptr`) rather than from a shared reborrow (`as_ptr`).
            let raw_buf =
                crate::buf::RawBuf::new(slice.as_mut_ptr() as *const u8, slice.len());
            let mut recv = Op::recv_raw(&self.fd, raw_buf);
            let ret = ready!(crate::driver::op::PollLegacy::poll_legacy(&mut recv, cx));

            std::task::Poll::Ready(ret.result.map(|n| {
                let n = n.into_inner();
                // SAFETY: `recv` reported `n` bytes written into the unfilled region.
                buf.assume_init(n as usize);
                buf.advance(n as usize);
            }))
        }
    }
}
486
#[cfg(all(unix, feature = "legacy", feature = "tokio-compat"))]
impl tokio::io::AsyncWrite for TcpStream {
    /// Attempts one `send` of `buf` on the legacy driver and returns the bytes written.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, io::Error>> {
        // NOTE(review): `raw_buf` borrows `buf` only as a raw pointer. This assumes the legacy
        // driver performs the syscall inside `poll_legacy` and does not keep the pointer after
        // `send` is dropped at the end of this call — confirm.
        unsafe {
            let raw_buf = crate::buf::RawBuf::new(buf.as_ptr(), buf.len());
            let mut send = Op::send_raw(&self.fd, raw_buf);
            let ret = ready!(crate::driver::op::PollLegacy::poll_legacy(&mut send, cx));

            std::task::Poll::Ready(ret.result.map(|n| n.into_inner() as usize))
        }
    }

    /// No-op: there is no internal write buffer to flush.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    /// Shuts down the write half with `libc::shutdown(fd, SHUT_WR)`; completes immediately.
    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        let fd = self.as_raw_fd();
        let res = match unsafe { libc::shutdown(fd, libc::SHUT_WR) } {
            -1 => Err(io::Error::last_os_error()),
            _ => Ok(()),
        };
        std::task::Poll::Ready(res)
    }

    /// Attempts one vectored write of `bufs` on the legacy driver.
    fn poll_write_vectored(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> std::task::Poll<Result<usize, io::Error>> {
        // The cast below relies on std's guarantee that `IoSlice` is ABI-compatible with
        // `iovec` on Unix.
        unsafe {
            let raw_buf =
                crate::buf::RawBufVectored::new(bufs.as_ptr() as *const libc::iovec, bufs.len());
            let mut writev = Op::writev_raw(&self.fd, raw_buf);
            let ret = ready!(crate::driver::op::PollLegacy::poll_legacy(&mut writev, cx));

            std::task::Poll::Ready(ret.result.map(|n| n.into_inner() as usize))
        }
    }

    /// Vectored writes are implemented natively by `poll_write_vectored`.
    fn is_write_vectored(&self) -> bool {
        true
    }
}
541
/// Side data for a `TcpStream`: a socket2 view over the stream's raw descriptor plus
/// cached addresses.
struct StreamMeta {
    // `Option` so `Drop` can `take()` the view and release it without closing the descriptor.
    socket: Option<socket2::Socket>,
    // Lazily filled address cache; mutated through `&self` in `local_addr`/`peer_addr`.
    meta: UnsafeCell<Meta>,
}
546
/// Cached socket addresses; `None` until the first successful lookup.
#[derive(Debug, Default, Clone)]
struct Meta {
    local_addr: Option<SocketAddr>,
    peer_addr: Option<SocketAddr>,
}
552
553impl StreamMeta {
554 #[cfg(unix)]
555 fn new(fd: RawFd) -> Self {
556 Self {
557 socket: unsafe { Some(socket2::Socket::from_raw_fd(fd)) },
558 meta: Default::default(),
559 }
560 }
561
562 #[cfg(windows)]
565 fn new(fd: RawSocket) -> Self {
566 Self {
567 socket: unsafe { Some(socket2::Socket::from_raw_socket(fd)) },
568 meta: Default::default(),
569 }
570 }
571
572 fn local_addr(&self) -> io::Result<SocketAddr> {
573 let meta = unsafe { &mut *self.meta.get() };
574 if let Some(addr) = meta.local_addr {
575 return Ok(addr);
576 }
577
578 let ret = self
579 .socket
580 .as_ref()
581 .unwrap()
582 .local_addr()
583 .map(|addr| addr.as_socket().expect("tcp socket is expected"));
584 if let Ok(addr) = ret {
585 meta.local_addr = Some(addr);
586 }
587 ret
588 }
589
590 fn peer_addr(&self) -> io::Result<SocketAddr> {
591 let meta = unsafe { &mut *self.meta.get() };
592 if let Some(addr) = meta.peer_addr {
593 return Ok(addr);
594 }
595
596 let ret = self
597 .socket
598 .as_ref()
599 .unwrap()
600 .peer_addr()
601 .map(|addr| addr.as_socket().expect("tcp socket is expected"));
602 if let Ok(addr) = ret {
603 meta.peer_addr = Some(addr);
604 }
605 ret
606 }
607
608 fn no_delay(&self) -> io::Result<bool> {
609 self.socket.as_ref().unwrap().tcp_nodelay()
610 }
611
612 fn set_no_delay(&self, no_delay: bool) -> io::Result<()> {
613 self.socket.as_ref().unwrap().set_tcp_nodelay(no_delay)
614 }
615
616 #[allow(unused_variables)]
617 fn set_tcp_keepalive(
618 &self,
619 time: Option<Duration>,
620 interval: Option<Duration>,
621 retries: Option<u32>,
622 ) -> io::Result<()> {
623 let mut t = socket2::TcpKeepalive::new();
624 if let Some(time) = time {
625 t = t.with_time(time)
626 }
627 if let Some(interval) = interval {
628 t = t.with_interval(interval)
629 }
630 #[cfg(unix)]
631 if let Some(retries) = retries {
632 t = t.with_retries(retries)
633 }
634 self.socket.as_ref().unwrap().set_tcp_keepalive(&t)
635 }
636
637 #[cfg(feature = "zero-copy")]
638 fn set_zero_copy(&self) {
639 #[cfg(target_os = "linux")]
640 unsafe {
641 let fd = self.socket.as_ref().unwrap().as_raw_fd();
642 let v: libc::c_int = 1;
643 libc::setsockopt(
644 fd,
645 libc::SOL_SOCKET,
646 libc::SO_ZEROCOPY,
647 &v as *const _ as *const _,
648 std::mem::size_of::<libc::c_int>() as _,
649 );
650 }
651 }
652}
653
654impl Drop for StreamMeta {
655 fn drop(&mut self) {
656 let socket = self.socket.take().unwrap();
657 #[cfg(unix)]
658 let _ = socket.into_raw_fd();
659 #[cfg(windows)]
660 let _ = socket.into_raw_socket();
661 }
662}