1use std::{
2 cell::UnsafeCell,
3 future::Future,
4 io,
5 net::{SocketAddr, ToSocketAddrs},
6 time::Duration,
7};
8
9#[cfg(unix)]
10use {
11 libc::{shutdown, AF_INET, AF_INET6, SHUT_WR, SOCK_STREAM},
12 std::os::unix::prelude::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
13};
14#[cfg(windows)]
15use {
16 std::os::windows::prelude::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket},
17 windows_sys::Win32::Networking::WinSock::{
18 shutdown, AF_INET, AF_INET6, SD_SEND as SHUT_WR, SOCK_STREAM,
19 },
20};
21
22use crate::{
23 buf::{IoBuf, IoBufMut, IoVecBuf, IoVecBufMut},
24 driver::{op::Op, shared_fd::SharedFd},
25 io::{
26 as_fd::{AsReadFd, AsWriteFd, SharedFdWrapper},
27 operation_canceled, AsyncReadRent, AsyncWriteRent, CancelHandle, CancelableAsyncReadRent,
28 CancelableAsyncWriteRent, Split,
29 },
30 BufResult,
31};
32
/// Options controlling how a [`TcpStream`] connection is established
/// (see `TcpStream::connect_addr_with_config`).
///
/// Marked `#[non_exhaustive]` so new options can be added later; build it with
/// [`TcpConnectOpts::new`] / [`Default`] and the builder methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct TcpConnectOpts {
    /// Request TCP Fast Open for the outgoing connection.
    ///
    /// Socket-level TFO setup is only performed on Linux/Android and iOS/macOS; the flag is
    /// also forwarded to the connect operation itself. Defaults to `false`.
    pub tcp_fast_open: bool,
}

impl Default for TcpConnectOpts {
    /// Same as [`TcpConnectOpts::new`].
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}

impl TcpConnectOpts {
    /// Creates the default option set (TCP Fast Open disabled).
    ///
    /// This is a `const fn`, so it can be used to initialise constants.
    #[inline]
    pub const fn new() -> Self {
        Self {
            tcp_fast_open: false,
        }
    }

    /// Builder-style setter enabling or disabling TCP Fast Open.
    #[must_use]
    #[inline]
    pub fn tcp_fast_open(mut self, fast_open: bool) -> Self {
        self.tcp_fast_open = fast_open;
        self
    }
}
69pub struct TcpStream {
71 pub(super) fd: SharedFd,
72 meta: StreamMeta,
73}
74
75unsafe impl Split for TcpStream {}
77
78impl TcpStream {
79 pub(crate) fn from_shared_fd(fd: SharedFd) -> Self {
80 #[cfg(unix)]
81 let meta = StreamMeta::new(fd.raw_fd());
82 #[cfg(windows)]
83 let meta = StreamMeta::new(fd.raw_socket());
84 #[cfg(feature = "zero-copy")]
85 meta.set_zero_copy();
87
88 Self { fd, meta }
89 }
90
91 pub async fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
96 let addr = addr
98 .to_socket_addrs()?
99 .next()
100 .ok_or_else(|| io::Error::other("empty address"))?;
101
102 Self::connect_addr(addr).await
103 }
104
105 pub async fn connect_addr(addr: SocketAddr) -> io::Result<Self> {
107 const DEFAULT_OPTS: TcpConnectOpts = TcpConnectOpts {
108 tcp_fast_open: false,
109 };
110 Self::connect_addr_with_config(addr, &DEFAULT_OPTS).await
111 }
112
113 pub async fn connect_addr_with_config(
115 addr: SocketAddr,
116 opts: &TcpConnectOpts,
117 ) -> io::Result<Self> {
118 let domain = match addr {
119 SocketAddr::V4(_) => AF_INET,
120 SocketAddr::V6(_) => AF_INET6,
121 };
122 let socket = crate::net::new_socket(domain, SOCK_STREAM)?;
123 #[allow(unused_mut)]
124 let mut tfo = opts.tcp_fast_open;
125
126 if tfo {
127 #[cfg(any(target_os = "linux", target_os = "android"))]
128 super::tfo::try_set_tcp_fastopen_connect(&socket);
129 #[cfg(any(target_os = "ios", target_os = "macos"))]
130 if super::tfo::set_tcp_fastopen_force_enable(&socket).is_err() {
132 tfo = false;
133 }
134 }
135 let completion = Op::connect(SharedFd::new::<false>(socket)?, addr, tfo)?.await;
136 completion.meta.result?;
137
138 let stream = TcpStream::from_shared_fd(completion.data.fd);
139 if crate::driver::op::is_legacy() {
141 #[cfg(all(any(target_os = "ios", target_os = "macos"), feature = "legacy"))]
142 if !tfo {
143 stream.writable(true).await?;
144 } else {
145 crate::driver::CURRENT.with(|inner| match inner {
147 crate::driver::Inner::Legacy(inner) => {
148 let idx = stream.fd.registered_index().unwrap();
149 if let Some(mut readiness) =
150 unsafe { &mut *inner.get() }.io_dispatch.get(idx)
151 {
152 readiness.set_writable();
153 }
154 }
155 #[allow(unreachable_patterns)]
156 _ => unreachable!("should never happens"),
157 })
158 }
159 #[cfg(not(any(target_os = "ios", target_os = "macos")))]
160 stream.writable(true).await?;
161
162 #[cfg(unix)]
164 let sys_socket = unsafe { std::net::TcpStream::from_raw_fd(stream.fd.raw_fd()) };
165 #[cfg(windows)]
166 let sys_socket =
167 unsafe { std::net::TcpStream::from_raw_socket(stream.fd.raw_socket()) };
168 let err = sys_socket.take_error();
169 #[cfg(unix)]
170 let _ = sys_socket.into_raw_fd();
171 #[cfg(windows)]
172 let _ = sys_socket.into_raw_socket();
173 if let Some(e) = err? {
174 return Err(e);
175 }
176 }
177 Ok(stream)
178 }
179
180 #[inline]
182 pub fn local_addr(&self) -> io::Result<SocketAddr> {
183 self.meta.local_addr()
184 }
185
186 #[inline]
188 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
189 self.meta.peer_addr()
190 }
191
192 #[inline]
194 pub fn nodelay(&self) -> io::Result<bool> {
195 self.meta.no_delay()
196 }
197
198 #[inline]
200 pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
201 self.meta.set_no_delay(nodelay)
202 }
203
204 #[inline]
206 pub fn set_tcp_keepalive(
207 &self,
208 time: Option<Duration>,
209 interval: Option<Duration>,
210 retries: Option<u32>,
211 ) -> io::Result<()> {
212 self.meta.set_tcp_keepalive(time, interval, retries)
213 }
214
215 pub fn from_std(stream: std::net::TcpStream) -> io::Result<Self> {
217 #[cfg(unix)]
218 let fd = stream.as_raw_fd();
219 #[cfg(windows)]
220 let fd = stream.as_raw_socket();
221 match SharedFd::new::<false>(fd) {
222 Ok(shared) => {
223 #[cfg(unix)]
224 let _ = stream.into_raw_fd();
225 #[cfg(windows)]
226 let _ = stream.into_raw_socket();
227 Ok(Self::from_shared_fd(shared))
228 }
229 Err(e) => Err(e),
230 }
231 }
232
233 pub async fn readable(&self, relaxed: bool) -> io::Result<()> {
244 let op = Op::poll_read(&self.fd, relaxed).unwrap();
245 op.wait().await
246 }
247
248 pub async fn writable(&self, relaxed: bool) -> io::Result<()> {
259 let op = Op::poll_write(&self.fd, relaxed).unwrap();
260 op.wait().await
261 }
262}
263
264impl AsReadFd for TcpStream {
265 #[inline]
266 fn as_reader_fd(&mut self) -> &SharedFdWrapper {
267 SharedFdWrapper::new(&self.fd)
268 }
269}
270
271impl AsWriteFd for TcpStream {
272 #[inline]
273 fn as_writer_fd(&mut self) -> &SharedFdWrapper {
274 SharedFdWrapper::new(&self.fd)
275 }
276}
277
278#[cfg(unix)]
279impl IntoRawFd for TcpStream {
280 #[inline]
281 fn into_raw_fd(self) -> RawFd {
282 self.fd
283 .try_unwrap()
284 .expect("unexpected multiple reference to rawfd")
285 }
286}
287#[cfg(unix)]
288impl AsRawFd for TcpStream {
289 #[inline]
290 fn as_raw_fd(&self) -> RawFd {
291 self.fd.raw_fd()
292 }
293}
294
#[cfg(windows)]
impl IntoRawSocket for TcpStream {
    /// Gives up ownership of the socket handle without closing it.
    ///
    /// # Panics
    /// Panics if the shared fd is still referenced elsewhere (other clones alive).
    #[inline]
    fn into_raw_socket(self) -> RawSocket {
        let TcpStream { fd, .. } = self;
        fd.try_unwrap()
            .expect("unexpected multiple reference to rawfd")
    }
}

#[cfg(windows)]
impl AsRawSocket for TcpStream {
    /// Borrows the raw socket handle; ownership stays with the stream.
    #[inline]
    fn as_raw_socket(&self) -> RawSocket {
        let shared = &self.fd;
        shared.raw_socket()
    }
}
312
313impl std::fmt::Debug for TcpStream {
314 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315 f.debug_struct("TcpStream").field("fd", &self.fd).finish()
316 }
317}
318
319impl AsyncWriteRent for TcpStream {
320 #[inline]
321 fn write<T: IoBuf>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
322 let op = Op::send(self.fd.clone(), buf).unwrap();
324 op.result()
325 }
326
327 #[inline]
328 fn writev<T: IoVecBuf>(&mut self, buf_vec: T) -> impl Future<Output = BufResult<usize, T>> {
329 let op = Op::writev(self.fd.clone(), buf_vec).unwrap();
330 op.result()
331 }
332
333 #[inline]
334 fn flush(&mut self) -> impl Future<Output = std::io::Result<()>> {
335 std::future::ready(Ok(()))
337 }
338
339 fn shutdown(&mut self) -> impl Future<Output = std::io::Result<()>> {
340 #[cfg(unix)]
343 let fd = self.as_raw_fd();
344 #[cfg(windows)]
345 let fd = self.as_raw_socket() as _;
346 let res = match unsafe { shutdown(fd, SHUT_WR) } {
347 -1 => Err(io::Error::last_os_error()),
348 _ => Ok(()),
349 };
350 std::future::ready(res)
351 }
352}
353
354impl CancelableAsyncWriteRent for TcpStream {
355 #[inline]
356 async fn cancelable_write<T: IoBuf>(
357 &mut self,
358 buf: T,
359 c: CancelHandle,
360 ) -> crate::BufResult<usize, T> {
361 let fd = self.fd.clone();
362
363 if c.canceled() {
364 return (Err(operation_canceled()), buf);
365 }
366
367 let op = Op::send(fd, buf).unwrap();
368 let _guard = c.associate_op(op.op_canceller());
369 op.result().await
370 }
371
372 #[inline]
373 async fn cancelable_writev<T: IoVecBuf>(
374 &mut self,
375 buf_vec: T,
376 c: CancelHandle,
377 ) -> crate::BufResult<usize, T> {
378 let fd = self.fd.clone();
379
380 if c.canceled() {
381 return (Err(operation_canceled()), buf_vec);
382 }
383
384 let op = Op::writev(fd.clone(), buf_vec).unwrap();
385 let _guard = c.associate_op(op.op_canceller());
386 op.result().await
387 }
388
389 #[inline]
390 async fn cancelable_flush(&mut self, _c: CancelHandle) -> io::Result<()> {
391 Ok(())
393 }
394
395 fn cancelable_shutdown(&mut self, _c: CancelHandle) -> impl Future<Output = io::Result<()>> {
396 #[cfg(unix)]
399 let fd = self.as_raw_fd();
400 #[cfg(windows)]
401 let fd = self.as_raw_socket() as _;
402 let res = match unsafe { shutdown(fd, SHUT_WR) } {
403 -1 => Err(io::Error::last_os_error()),
404 _ => Ok(()),
405 };
406 std::future::ready(res)
407 }
408}
409
410impl AsyncReadRent for TcpStream {
411 #[inline]
412 fn read<T: IoBufMut>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
413 let op = Op::recv(self.fd.clone(), buf).unwrap();
415 op.result()
416 }
417
418 #[inline]
419 fn readv<T: IoVecBufMut>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
420 let op = Op::readv(self.fd.clone(), buf).unwrap();
422 op.result()
423 }
424}
425
426impl CancelableAsyncReadRent for TcpStream {
427 #[inline]
428 async fn cancelable_read<T: IoBufMut>(
429 &mut self,
430 buf: T,
431 c: CancelHandle,
432 ) -> crate::BufResult<usize, T> {
433 let fd = self.fd.clone();
434
435 if c.canceled() {
436 return (Err(operation_canceled()), buf);
437 }
438
439 let op = Op::recv(fd, buf).unwrap();
440 let _guard = c.associate_op(op.op_canceller());
441 op.result().await
442 }
443
444 #[inline]
445 async fn cancelable_readv<T: IoVecBufMut>(
446 &mut self,
447 buf: T,
448 c: CancelHandle,
449 ) -> crate::BufResult<usize, T> {
450 let fd = self.fd.clone();
451
452 if c.canceled() {
453 return (Err(operation_canceled()), buf);
454 }
455
456 let op = Op::readv(fd, buf).unwrap();
457 let _guard = c.associate_op(op.op_canceller());
458 op.result().await
459 }
460}
461
#[cfg(all(unix, feature = "legacy", feature = "tokio-compat"))]
impl tokio::io::AsyncRead for TcpStream {
    /// Tokio-compat read: issues a raw `recv` into the unfilled part of `buf` via the legacy
    /// driver, then marks the received bytes as initialized and filled.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        // SAFETY: NOTE(review): the raw buffer points into `buf`'s unfilled region, which is
        // borrowed for this call only; this relies on the legacy driver performing the recv
        // synchronously inside `poll_legacy` so no pointer outlives the call — confirm.
        // Only the `n` bytes reported as received are marked initialized.
        unsafe {
            let slice = buf.unfilled_mut();
            // The kernel writes through this pointer, so derive it with `as_mut_ptr` (mutable
            // provenance) rather than `as_ptr`, which forbids writes through the result.
            let raw_buf = crate::buf::RawBuf::new(slice.as_mut_ptr() as *const u8, slice.len());
            let mut recv = Op::recv_raw(&self.fd, raw_buf);
            let ret = ready!(crate::driver::op::PollLegacy::poll_legacy(&mut recv, cx));

            std::task::Poll::Ready(ret.result.map(|n| {
                let n = n.into_inner();
                buf.assume_init(n as usize);
                buf.advance(n as usize);
            }))
        }
    }
}
483
#[cfg(all(unix, feature = "legacy", feature = "tokio-compat"))]
impl tokio::io::AsyncWrite for TcpStream {
    /// Tokio-compat write: one raw `send` of `buf` through the legacy driver.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, io::Error>> {
        // SAFETY: NOTE(review): `buf` is borrowed for this call only; this relies on the legacy
        // driver performing the send synchronously inside `poll_legacy` — confirm.
        unsafe {
            let raw = crate::buf::RawBuf::new(buf.as_ptr(), buf.len());
            let mut op = Op::send_raw(&self.fd, raw);
            let completion = ready!(crate::driver::op::PollLegacy::poll_legacy(&mut op, cx));
            let written = completion.result.map(|n| n.into_inner() as usize);
            std::task::Poll::Ready(written)
        }
    }

    /// Nothing is buffered in user space, so flushing completes immediately.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    /// Shuts down the write direction with a synchronous `shutdown(SHUT_WR)`.
    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        let fd = self.as_raw_fd();
        // SAFETY: `fd` is the socket owned by this stream, alive for the call.
        let rc = unsafe { libc::shutdown(fd, libc::SHUT_WR) };
        let outcome = if rc == -1 {
            Err(io::Error::last_os_error())
        } else {
            Ok(())
        };
        std::task::Poll::Ready(outcome)
    }

    /// Vectored write; `IoSlice` is ABI-compatible with `iovec` on unix, so the slice is
    /// handed to the driver as an iovec array.
    fn poll_write_vectored(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> std::task::Poll<Result<usize, io::Error>> {
        // SAFETY: NOTE(review): `bufs` is borrowed for this call only; same synchronous
        // legacy-driver assumption as `poll_write` — confirm.
        unsafe {
            let iovecs = bufs.as_ptr() as *const libc::iovec;
            let raw = crate::buf::RawBufVectored::new(iovecs, bufs.len());
            let mut op = Op::writev_raw(&self.fd, raw);
            let completion = ready!(crate::driver::op::PollLegacy::poll_legacy(&mut op, cx));
            let written = completion.result.map(|n| n.into_inner() as usize);
            std::task::Poll::Ready(written)
        }
    }

    fn is_write_vectored(&self) -> bool {
        true
    }
}
538
539struct StreamMeta {
540 socket: Option<socket2::Socket>,
541 meta: UnsafeCell<Meta>,
542}
543
/// Cached socket addresses; each field stays `None` until first successfully looked up.
#[derive(Clone, Debug, Default)]
struct Meta {
    /// Cached result of the local-address lookup.
    local_addr: Option<SocketAddr>,
    /// Cached result of the peer-address lookup.
    peer_addr: Option<SocketAddr>,
}
549
550impl StreamMeta {
551 #[cfg(unix)]
552 fn new(fd: RawFd) -> Self {
553 Self {
554 socket: unsafe { Some(socket2::Socket::from_raw_fd(fd)) },
555 meta: Default::default(),
556 }
557 }
558
559 #[cfg(windows)]
562 fn new(fd: RawSocket) -> Self {
563 Self {
564 socket: unsafe { Some(socket2::Socket::from_raw_socket(fd)) },
565 meta: Default::default(),
566 }
567 }
568
569 fn local_addr(&self) -> io::Result<SocketAddr> {
570 let meta = unsafe { &mut *self.meta.get() };
571 if let Some(addr) = meta.local_addr {
572 return Ok(addr);
573 }
574
575 let ret = self
576 .socket
577 .as_ref()
578 .unwrap()
579 .local_addr()
580 .map(|addr| addr.as_socket().expect("tcp socket is expected"));
581 if let Ok(addr) = ret {
582 meta.local_addr = Some(addr);
583 }
584 ret
585 }
586
587 fn peer_addr(&self) -> io::Result<SocketAddr> {
588 let meta = unsafe { &mut *self.meta.get() };
589 if let Some(addr) = meta.peer_addr {
590 return Ok(addr);
591 }
592
593 let ret = self
594 .socket
595 .as_ref()
596 .unwrap()
597 .peer_addr()
598 .map(|addr| addr.as_socket().expect("tcp socket is expected"));
599 if let Ok(addr) = ret {
600 meta.peer_addr = Some(addr);
601 }
602 ret
603 }
604
605 fn no_delay(&self) -> io::Result<bool> {
606 self.socket.as_ref().unwrap().nodelay()
607 }
608
609 fn set_no_delay(&self, no_delay: bool) -> io::Result<()> {
610 self.socket.as_ref().unwrap().set_nodelay(no_delay)
611 }
612
613 #[allow(unused_variables)]
614 fn set_tcp_keepalive(
615 &self,
616 time: Option<Duration>,
617 interval: Option<Duration>,
618 retries: Option<u32>,
619 ) -> io::Result<()> {
620 let mut t = socket2::TcpKeepalive::new();
621 if let Some(time) = time {
622 t = t.with_time(time)
623 }
624 if let Some(interval) = interval {
625 t = t.with_interval(interval)
626 }
627 #[cfg(unix)]
628 if let Some(retries) = retries {
629 t = t.with_retries(retries)
630 }
631 self.socket.as_ref().unwrap().set_tcp_keepalive(&t)
632 }
633
634 #[cfg(feature = "zero-copy")]
635 fn set_zero_copy(&self) {
636 #[cfg(target_os = "linux")]
637 unsafe {
638 let fd = self.socket.as_ref().unwrap().as_raw_fd();
639 let v: libc::c_int = 1;
640 libc::setsockopt(
641 fd,
642 libc::SOL_SOCKET,
643 libc::SO_ZEROCOPY,
644 &v as *const _ as *const _,
645 std::mem::size_of::<libc::c_int>() as _,
646 );
647 }
648 }
649}
650
651impl Drop for StreamMeta {
652 fn drop(&mut self) {
653 let socket = self.socket.take().unwrap();
654 #[cfg(unix)]
655 let _ = socket.into_raw_fd();
656 #[cfg(windows)]
657 let _ = socket.into_raw_socket();
658 }
659}