1use std::{
2 collections::VecDeque,
3 fmt,
4 future::Future,
5 io,
6 io::IoSliceMut,
7 mem,
8 net::{SocketAddr, SocketAddrV6},
9 pin::Pin,
10 str,
11 sync::{Arc, Mutex},
12 task::{Context, Poll, Waker},
13};
14
15#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))]
16use crate::runtime::default_runtime;
17use crate::{
18 Instant,
19 runtime::{AsyncUdpSocket, Runtime},
20 udp_transmit,
21};
22use bytes::{Bytes, BytesMut};
23use pin_project_lite::pin_project;
24use proto::{
25 self as proto, ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent,
26 EndpointEvent, ServerConfig,
27};
28use rustc_hash::FxHashMap;
29#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring"),))]
30use socket2::{Domain, Protocol, Socket, Type};
31use tokio::sync::{Notify, futures::Notified, mpsc};
32use tracing::{Instrument, Span};
33use udp::{BATCH_SIZE, RecvMeta};
34
35use crate::{
36 ConnectionEvent, EndpointConfig, IO_LOOP_BOUND, RECV_TIME_BOUND, VarInt,
37 connection::Connecting, incoming::Incoming, work_limiter::WorkLimiter,
38};
39
40#[derive(Debug, Clone)]
47pub struct Endpoint {
48 pub(crate) inner: EndpointRef,
49 pub(crate) default_client_config: Option<ClientConfig>,
50 runtime: Arc<dyn Runtime>,
51}
52
53impl Endpoint {
54 #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))] pub fn client(addr: SocketAddr) -> io::Result<Self> {
73 let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))?;
74 if addr.is_ipv6() {
75 if let Err(e) = socket.set_only_v6(false) {
76 tracing::debug!(%e, "unable to make socket dual-stack");
77 }
78 }
79 socket.bind(&addr.into())?;
80 let runtime =
81 default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
82 Self::new_with_abstract_socket(
83 EndpointConfig::default(),
84 None,
85 runtime.wrap_udp_socket(socket.into())?,
86 runtime,
87 )
88 }
89
90 pub fn stats(&self) -> EndpointStats {
92 self.inner.state.lock().unwrap().stats
93 }
94
95 #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))] pub fn server(config: ServerConfig, addr: SocketAddr) -> io::Result<Self> {
103 let socket = std::net::UdpSocket::bind(addr)?;
104 let runtime =
105 default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
106 Self::new_with_abstract_socket(
107 EndpointConfig::default(),
108 Some(config),
109 runtime.wrap_udp_socket(socket)?,
110 runtime,
111 )
112 }
113
114 #[cfg(not(wasm_browser))]
116 pub fn new(
117 config: EndpointConfig,
118 server_config: Option<ServerConfig>,
119 socket: std::net::UdpSocket,
120 runtime: Arc<dyn Runtime>,
121 ) -> io::Result<Self> {
122 let socket = runtime.wrap_udp_socket(socket)?;
123 Self::new_with_abstract_socket(config, server_config, socket, runtime)
124 }
125
126 pub fn new_with_abstract_socket(
131 config: EndpointConfig,
132 server_config: Option<ServerConfig>,
133 socket: Arc<dyn AsyncUdpSocket>,
134 runtime: Arc<dyn Runtime>,
135 ) -> io::Result<Self> {
136 let addr = socket.local_addr()?;
137 let allow_mtud = !socket.may_fragment();
138 let rc = EndpointRef::new(
139 socket,
140 proto::Endpoint::new(
141 Arc::new(config),
142 server_config.map(Arc::new),
143 allow_mtud,
144 None,
145 ),
146 addr.is_ipv6(),
147 runtime.clone(),
148 );
149 let driver = EndpointDriver(rc.clone());
150 runtime.spawn(Box::pin(
151 async {
152 if let Err(e) = driver.await {
153 tracing::error!("I/O error: {}", e);
154 }
155 }
156 .instrument(Span::current()),
157 ));
158 Ok(Self {
159 inner: rc,
160 default_client_config: None,
161 runtime,
162 })
163 }
164
165 pub fn accept(&self) -> Accept<'_> {
172 Accept {
173 endpoint: self,
174 notify: self.inner.shared.incoming.notified(),
175 }
176 }
177
178 pub fn set_default_client_config(&mut self, config: ClientConfig) {
180 self.default_client_config = Some(config);
181 }
182
183 pub fn connect(&self, addr: SocketAddr, server_name: &str) -> Result<Connecting, ConnectError> {
192 let config = match &self.default_client_config {
193 Some(config) => config.clone(),
194 None => return Err(ConnectError::NoDefaultClientConfig),
195 };
196
197 self.connect_with(config, addr, server_name)
198 }
199
200 pub fn connect_with(
206 &self,
207 config: ClientConfig,
208 addr: SocketAddr,
209 server_name: &str,
210 ) -> Result<Connecting, ConnectError> {
211 let mut endpoint = self.inner.state.lock().unwrap();
212 if endpoint.driver_lost || endpoint.recv_state.connections.close.is_some() {
213 return Err(ConnectError::EndpointStopping);
214 }
215 if addr.is_ipv6() && !endpoint.ipv6 {
216 return Err(ConnectError::InvalidRemoteAddress(addr));
217 }
218 let addr = if endpoint.ipv6 {
219 SocketAddr::V6(ensure_ipv6(addr))
220 } else {
221 addr
222 };
223
224 let (ch, conn) = endpoint
225 .inner
226 .connect(self.runtime.now(), config, addr, server_name)?;
227
228 let socket = endpoint.socket.clone();
229 endpoint.stats.outgoing_handshakes += 1;
230 Ok(endpoint
231 .recv_state
232 .connections
233 .insert(ch, conn, socket, self.runtime.clone()))
234 }
235
236 #[cfg(not(wasm_browser))]
240 pub fn rebind(&self, socket: std::net::UdpSocket) -> io::Result<()> {
241 self.rebind_abstract(self.runtime.wrap_udp_socket(socket)?)
242 }
243
244 pub fn rebind_abstract(&self, socket: Arc<dyn AsyncUdpSocket>) -> io::Result<()> {
251 let addr = socket.local_addr()?;
252 let mut inner = self.inner.state.lock().unwrap();
253 inner.prev_socket = Some(mem::replace(&mut inner.socket, socket));
254 inner.ipv6 = addr.is_ipv6();
255
256 for sender in inner.recv_state.connections.senders.values() {
258 let _ = sender.send(ConnectionEvent::Rebind(inner.socket.clone()));
260 }
261 if let Some(driver) = inner.driver.take() {
262 driver.wake();
264 }
265
266 Ok(())
267 }
268
269 pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
273 self.inner
274 .state
275 .lock()
276 .unwrap()
277 .inner
278 .set_server_config(server_config.map(Arc::new))
279 }
280
281 pub fn local_addr(&self) -> io::Result<SocketAddr> {
283 self.inner.state.lock().unwrap().socket.local_addr()
284 }
285
286 pub fn open_connections(&self) -> usize {
288 self.inner.state.lock().unwrap().inner.open_connections()
289 }
290
291 pub fn close(&self, error_code: VarInt, reason: &[u8]) {
297 let reason = Bytes::copy_from_slice(reason);
298 let mut endpoint = self.inner.state.lock().unwrap();
299 endpoint.recv_state.connections.close = Some((error_code, reason.clone()));
300 for sender in endpoint.recv_state.connections.senders.values() {
301 let _ = sender.send(ConnectionEvent::Close {
303 error_code,
304 reason: reason.clone(),
305 });
306 }
307 self.inner.shared.incoming.notify_waiters();
308 }
309
310 pub async fn wait_idle(&self) {
321 loop {
322 {
323 let endpoint = &mut *self.inner.state.lock().unwrap();
324 if endpoint.recv_state.connections.is_empty() {
325 break;
326 }
327 self.inner.shared.idle.notified()
329 }
330 .await;
331 }
332 }
333}
334
/// Counters describing the handshakes an [`Endpoint`] has processed, as returned by
/// [`Endpoint::stats`].
///
/// The struct is `#[non_exhaustive]`, so new counters can be added without breaking callers.
#[derive(Debug, Default, Copy, Clone)]
#[non_exhaustive]
pub struct EndpointStats {
    /// Incoming connection attempts that were accepted.
    pub accepted_handshakes: u64,
    /// Connections this endpoint initiated through `connect`/`connect_with`.
    pub outgoing_handshakes: u64,
    /// Incoming connection attempts explicitly refused through the `refuse` path.
    ///
    /// Attempts refused automatically while the endpoint is closing are not counted.
    pub refused_handshakes: u64,
    /// Incoming connection attempts that were ignored through the `ignore` path.
    pub ignored_handshakes: u64,
}
348
349#[must_use = "endpoint drivers must be spawned for I/O to occur"]
360#[derive(Debug)]
361pub(crate) struct EndpointDriver(pub(crate) EndpointRef);
362
impl Future for EndpointDriver {
    type Output = Result<(), io::Error>;

    /// Performs one round of endpoint work: receive datagrams, then dispatch queued
    /// connection-to-endpoint events.
    ///
    /// Resolves to `Ok(())` once `ref_count` is zero and every connection has been removed.
    /// Resolves to `Err` if receiving on the current socket fails with any error other than
    /// `ConnectionReset`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let mut endpoint = self.0.state.lock().unwrap();
        // (Re)register our waker if it is missing. `rebind_abstract` and `EndpointRef::drop`
        // take it when they need to wake the driver.
        if endpoint.driver.is_none() {
            endpoint.driver = Some(cx.waker().clone());
        }

        let now = endpoint.runtime.now();
        // Set when either stage stopped early because of its work bound, i.e. more work may be
        // pending.
        let mut keep_going = false;
        keep_going |= endpoint.drive_recv(cx, now)?;
        keep_going |= endpoint.handle_events(cx, &self.0.shared);

        // Wake `Accept` futures when there are queued incoming attempts for them.
        if !endpoint.recv_state.incoming.is_empty() {
            self.0.shared.incoming.notify_waiters();
        }

        if endpoint.ref_count == 0 && endpoint.recv_state.connections.is_empty() {
            Poll::Ready(Ok(()))
        } else {
            drop(endpoint);
            // If work remains, reschedule ourselves. Waking after releasing the lock reduces
            // lock contention when the runtime is multithreaded.
            if keep_going {
                cx.waker().wake_by_ref();
            }
            Poll::Pending
        }
    }
}
395
396impl Drop for EndpointDriver {
397 fn drop(&mut self) {
398 let mut endpoint = self.0.state.lock().unwrap();
399 endpoint.driver_lost = true;
400 self.0.shared.incoming.notify_waiters();
401 endpoint.recv_state.connections.senders.clear();
404 }
405}
406
407#[derive(Debug)]
408pub(crate) struct EndpointInner {
409 pub(crate) state: Mutex<State>,
410 pub(crate) shared: Shared,
411}
412
413impl EndpointInner {
414 pub(crate) fn accept(
415 &self,
416 incoming: proto::Incoming,
417 server_config: Option<Arc<ServerConfig>>,
418 ) -> Result<Connecting, ConnectionError> {
419 let mut state = self.state.lock().unwrap();
420 let mut response_buffer = Vec::new();
421 let now = state.runtime.now();
422 match state
423 .inner
424 .accept(incoming, now, &mut response_buffer, server_config)
425 {
426 Ok((handle, conn)) => {
427 state.stats.accepted_handshakes += 1;
428 let socket = state.socket.clone();
429 let runtime = state.runtime.clone();
430 Ok(state
431 .recv_state
432 .connections
433 .insert(handle, conn, socket, runtime))
434 }
435 Err(error) => {
436 if let Some(transmit) = error.response {
437 respond(transmit, &response_buffer, &*state.socket);
438 }
439 Err(error.cause)
440 }
441 }
442 }
443
444 pub(crate) fn refuse(&self, incoming: proto::Incoming) {
445 let mut state = self.state.lock().unwrap();
446 state.stats.refused_handshakes += 1;
447 let mut response_buffer = Vec::new();
448 let transmit = state.inner.refuse(incoming, &mut response_buffer);
449 respond(transmit, &response_buffer, &*state.socket);
450 }
451
452 pub(crate) fn retry(&self, incoming: proto::Incoming) -> Result<(), proto::RetryError> {
453 let mut state = self.state.lock().unwrap();
454 let mut response_buffer = Vec::new();
455 let transmit = state.inner.retry(incoming, &mut response_buffer)?;
456 respond(transmit, &response_buffer, &*state.socket);
457 Ok(())
458 }
459
460 pub(crate) fn ignore(&self, incoming: proto::Incoming) {
461 let mut state = self.state.lock().unwrap();
462 state.stats.ignored_handshakes += 1;
463 state.inner.ignore(incoming);
464 }
465}
466
467#[derive(Debug)]
468pub(crate) struct State {
469 socket: Arc<dyn AsyncUdpSocket>,
470 prev_socket: Option<Arc<dyn AsyncUdpSocket>>,
473 inner: proto::Endpoint,
474 recv_state: RecvState,
475 driver: Option<Waker>,
476 ipv6: bool,
477 events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>,
478 ref_count: usize,
480 driver_lost: bool,
481 runtime: Arc<dyn Runtime>,
482 stats: EndpointStats,
483}
484
485#[derive(Debug)]
486pub(crate) struct Shared {
487 incoming: Notify,
488 idle: Notify,
489}
490
impl State {
    /// Receives and processes datagrams within one cycle of the receive work limiter.
    ///
    /// The previous socket (if any) is polled first, then the current one.
    ///
    /// Returns `Ok(true)` when polling the current socket stopped because the limiter's budget
    /// ran out, meaning the driver should run again. An error on the previous socket only
    /// causes that socket to be discarded. An error on the current socket is returned.
    fn drive_recv(&mut self, cx: &mut Context, now: Instant) -> Result<bool, io::Error> {
        let get_time = || self.runtime.now();
        self.recv_state.recv_limiter.start_cycle(get_time);
        if let Some(socket) = &self.prev_socket {
            // Only errors matter for the old socket; its `PollProgress` is deliberately ignored.
            let poll_res =
                self.recv_state
                    .poll_socket(cx, &mut self.inner, &**socket, &*self.runtime, now);
            if poll_res.is_err() {
                self.prev_socket = None;
            }
        };
        let poll_res =
            self.recv_state
                .poll_socket(cx, &mut self.inner, &*self.socket, &*self.runtime, now);
        // End the limiter cycle before propagating any error from the current socket.
        self.recv_state.recv_limiter.finish_cycle(get_time);
        let poll_res = poll_res?;
        if poll_res.received_connection_packet {
            // Connection traffic has arrived on the current socket, so the socket replaced by a
            // rebind is no longer needed.
            self.prev_socket = None;
        }
        Ok(poll_res.keep_going)
    }

    /// Processes up to `IO_LOOP_BOUND` events that connections have sent to the endpoint.
    ///
    /// Returns `false` once the event channel has nothing ready. Returns `true` if the bound was
    /// reached, in which case more events may be pending and the driver should run again.
    fn handle_events(&mut self, cx: &mut Context, shared: &Shared) -> bool {
        for _ in 0..IO_LOOP_BOUND {
            let (ch, event) = match self.events.poll_recv(cx) {
                Poll::Ready(Some(x)) => x,
                // `ConnectionSet::sender` keeps a sender for this channel alive as long as `self`
                // exists, so the channel cannot report closure here.
                Poll::Ready(None) => unreachable!("EndpointInner owns one sender"),
                Poll::Pending => {
                    return false;
                }
            };

            if event.is_drained() {
                // The connection is done: stop routing events to it, and wake `wait_idle`
                // callers if it was the last one.
                self.recv_state.connections.senders.remove(&ch);
                if self.recv_state.connections.is_empty() {
                    shared.idle.notify_waiters();
                }
            }
            let Some(event) = self.inner.handle_event(ch, event) else {
                continue;
            };
            // Send errors (the connection's receiver is already gone) are ignored.
            // NOTE(review): the `unwrap` assumes `handle_event` never yields an event for a
            // handle whose sender was just removed above. Confirm against `proto`.
            let _ = self
                .recv_state
                .connections
                .senders
                .get_mut(&ch)
                .unwrap()
                .send(ConnectionEvent::Proto(event));
        }

        true
    }
}
549
550impl Drop for State {
551 fn drop(&mut self) {
552 for incoming in self.recv_state.incoming.drain(..) {
553 self.inner.ignore(incoming);
554 }
555 }
556}
557
558fn respond(transmit: proto::Transmit, response_buffer: &[u8], socket: &dyn AsyncUdpSocket) {
559 _ = socket.try_send(&udp_transmit(&transmit, &response_buffer[..transmit.size]));
580}
581
582#[inline]
583fn proto_ecn(ecn: udp::EcnCodepoint) -> proto::EcnCodepoint {
584 match ecn {
585 udp::EcnCodepoint::Ect0 => proto::EcnCodepoint::Ect0,
586 udp::EcnCodepoint::Ect1 => proto::EcnCodepoint::Ect1,
587 udp::EcnCodepoint::Ce => proto::EcnCodepoint::Ce,
588 }
589}
590
591#[derive(Debug)]
592struct ConnectionSet {
593 senders: FxHashMap<ConnectionHandle, mpsc::UnboundedSender<ConnectionEvent>>,
595 sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
597 close: Option<(VarInt, Bytes)>,
599}
600
601impl ConnectionSet {
602 fn insert(
603 &mut self,
604 handle: ConnectionHandle,
605 conn: proto::Connection,
606 socket: Arc<dyn AsyncUdpSocket>,
607 runtime: Arc<dyn Runtime>,
608 ) -> Connecting {
609 let (send, recv) = mpsc::unbounded_channel();
610 if let Some((error_code, ref reason)) = self.close {
611 send.send(ConnectionEvent::Close {
612 error_code,
613 reason: reason.clone(),
614 })
615 .unwrap();
616 }
617 self.senders.insert(handle, send);
618 Connecting::new(handle, conn, self.sender.clone(), recv, socket, runtime)
619 }
620
621 fn is_empty(&self) -> bool {
622 self.senders.is_empty()
623 }
624}
625
/// Returns `x` as an IPv6 socket address.
///
/// IPv6 addresses are returned unchanged. IPv4 addresses become IPv4-mapped IPv6 addresses
/// (`::ffff:a.b.c.d`) with the same port, and with flow info and scope id set to zero.
fn ensure_ipv6(x: SocketAddr) -> SocketAddrV6 {
    match x {
        SocketAddr::V4(v4) => {
            let mapped = v4.ip().to_ipv6_mapped();
            SocketAddrV6::new(mapped, v4.port(), 0, 0)
        }
        SocketAddr::V6(v6) => v6,
    }
}
632
pin_project! {
    /// Future returned by [`Endpoint::accept`].
    ///
    /// Resolves to the next incoming connection attempt. Resolves to `None` when the driver is
    /// gone, or when the endpoint has been closed and no attempts are queued.
    pub struct Accept<'a> {
        // Endpoint whose incoming queue is polled.
        endpoint: &'a Endpoint,
        // Pending wakeup registration on `Shared::incoming`; replaced with a fresh one after
        // every wakeup.
        #[pin]
        notify: Notified<'a>,
    }
}
641
impl Future for Accept<'_> {
    type Output = Option<Incoming>;
    /// Pops the next queued incoming attempt, or waits for a wakeup on `Shared::incoming`.
    ///
    /// Resolves to `None` if the driver has been lost, or if the endpoint has been closed and no
    /// attempt is queued.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();
        let mut endpoint = this.endpoint.inner.state.lock().unwrap();
        if endpoint.driver_lost {
            return Poll::Ready(None);
        }
        if let Some(incoming) = endpoint.recv_state.incoming.pop_front() {
            // Release the lock first: cloning the `EndpointRef` below locks `state` again, and
            // the mutex is not reentrant.
            drop(endpoint);
            let incoming = Incoming::new(incoming, this.endpoint.inner.clone());
            return Poll::Ready(Some(incoming));
        }
        if endpoint.recv_state.connections.close.is_some() {
            return Poll::Ready(None);
        }
        loop {
            match this.notify.as_mut().poll(ctx) {
                // `state` is still locked here, and the code that notifies `incoming` holds
                // that lock, so no notification can slip in between the checks above and this
                // registration.
                Poll::Pending => return Poll::Pending,
                // Woken (possibly spuriously): arm a fresh `Notified` and poll it again.
                Poll::Ready(()) => this
                    .notify
                    .set(this.endpoint.inner.shared.incoming.notified()),
            }
        }
    }
}
671
672#[derive(Debug)]
673pub(crate) struct EndpointRef(Arc<EndpointInner>);
674
675impl EndpointRef {
676 pub(crate) fn new(
677 socket: Arc<dyn AsyncUdpSocket>,
678 inner: proto::Endpoint,
679 ipv6: bool,
680 runtime: Arc<dyn Runtime>,
681 ) -> Self {
682 let (sender, events) = mpsc::unbounded_channel();
683 let recv_state = RecvState::new(sender, socket.max_receive_segments(), &inner);
684 Self(Arc::new(EndpointInner {
685 shared: Shared {
686 incoming: Notify::new(),
687 idle: Notify::new(),
688 },
689 state: Mutex::new(State {
690 socket,
691 prev_socket: None,
692 inner,
693 ipv6,
694 events,
695 driver: None,
696 ref_count: 0,
697 driver_lost: false,
698 recv_state,
699 runtime,
700 stats: EndpointStats::default(),
701 }),
702 }))
703 }
704}
705
706impl Clone for EndpointRef {
707 fn clone(&self) -> Self {
708 self.0.state.lock().unwrap().ref_count += 1;
709 Self(self.0.clone())
710 }
711}
712
713impl Drop for EndpointRef {
714 fn drop(&mut self) {
715 let endpoint = &mut *self.0.state.lock().unwrap();
716 if let Some(x) = endpoint.ref_count.checked_sub(1) {
717 endpoint.ref_count = x;
718 if x == 0 {
719 if let Some(task) = endpoint.driver.take() {
722 task.wake();
723 }
724 }
725 }
726 }
727}
728
729impl std::ops::Deref for EndpointRef {
730 type Target = EndpointInner;
731 fn deref(&self) -> &Self::Target {
732 &self.0
733 }
734}
735
736struct RecvState {
738 incoming: VecDeque<proto::Incoming>,
739 connections: ConnectionSet,
740 recv_buf: Box<[u8]>,
741 recv_limiter: WorkLimiter,
742}
743
impl RecvState {
    /// Creates the receive state.
    ///
    /// `sender` is stored in the [`ConnectionSet`] and later cloned into each connection. The
    /// receive buffer has `BATCH_SIZE` slots. Each slot fits `max_receive_segments` datagrams
    /// of the endpoint's configured maximum UDP payload size, capped at 64 KiB.
    fn new(
        sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
        max_receive_segments: usize,
        endpoint: &proto::Endpoint,
    ) -> Self {
        // Total size: min(max UDP payload, 64 KiB) * max_receive_segments * BATCH_SIZE bytes.
        let recv_buf = vec![
            0;
            endpoint.config().get_max_udp_payload_size().min(64 * 1024) as usize
                * max_receive_segments
                * BATCH_SIZE
        ];
        Self {
            connections: ConnectionSet {
                senders: FxHashMap::default(),
                sender,
                close: None,
            },
            incoming: VecDeque::new(),
            recv_buf: recv_buf.into(),
            recv_limiter: WorkLimiter::new(RECV_TIME_BOUND),
        }
    }

    /// Reads batches of datagrams from `socket` and feeds them to `endpoint`.
    ///
    /// Reading stops when the socket has nothing ready or the receive work limiter says to
    /// stop. Each resulting datagram event is handled as follows:
    /// - New connection attempts are queued in `self.incoming`. If the endpoint has been
    ///   closed, they are refused instead.
    /// - Packets for existing connections are forwarded to that connection's channel.
    /// - Stateless responses are sent straight back on `socket`.
    ///
    /// `ConnectionReset` errors are skipped; any other receive error is returned.
    fn poll_socket(
        &mut self,
        cx: &mut Context,
        endpoint: &mut proto::Endpoint,
        socket: &dyn AsyncUdpSocket,
        runtime: &dyn Runtime,
        now: Instant,
    ) -> Result<PollProgress, io::Error> {
        let mut received_connection_packet = false;
        let mut metas = [RecvMeta::default(); BATCH_SIZE];
        // Split `recv_buf` into `BATCH_SIZE` equal, non-overlapping slices, one per batch slot.
        // NOTE(review): `chunks_mut` panics on a zero chunk size. This presumes `recv_buf` is
        // never empty (non-zero payload size and segment count) — confirm.
        let mut iovs: [IoSliceMut; BATCH_SIZE] = {
            let mut bufs = self
                .recv_buf
                .chunks_mut(self.recv_buf.len() / BATCH_SIZE)
                .map(IoSliceMut::new);

            // `recv_buf`'s length is a multiple of BATCH_SIZE (see `new`). Chunking therefore
            // yields exactly BATCH_SIZE slices, and `from_fn` calls the closure exactly that
            // many times.
            std::array::from_fn(|_| bufs.next().expect("BATCH_SIZE elements"))
        };
        loop {
            match socket.poll_recv(cx, &mut iovs, &mut metas) {
                Poll::Ready(Ok(msgs)) => {
                    self.recv_limiter.record_work(msgs);
                    for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) {
                        let mut data: BytesMut = buf[0..meta.len].into();
                        // A slot may hold several datagrams of `meta.stride` bytes each (the
                        // last may be shorter). They are passed to the endpoint one at a time.
                        // NOTE(review): this loops forever if `meta.stride` is 0 while `data`
                        // is non-empty. Presumably the socket layer never reports that — confirm.
                        while !data.is_empty() {
                            let buf = data.split_to(meta.stride.min(data.len()));
                            let mut response_buffer = Vec::new();
                            match endpoint.handle(
                                now,
                                meta.addr,
                                meta.dst_ip,
                                meta.ecn.map(proto_ecn),
                                buf,
                                &mut response_buffer,
                            ) {
                                Some(DatagramEvent::NewConnection(incoming)) => {
                                    // Queue the attempt for `Endpoint::accept`. Once the
                                    // endpoint is closed, refuse it right away instead; note
                                    // that this path does not update `EndpointStats`.
                                    if self.connections.close.is_none() {
                                        self.incoming.push_back(incoming);
                                    } else {
                                        let transmit =
                                            endpoint.refuse(incoming, &mut response_buffer);
                                        respond(transmit, &response_buffer, socket);
                                    }
                                }
                                Some(DatagramEvent::ConnectionEvent(handle, event)) => {
                                    // Send errors (the connection's receiver is already gone)
                                    // are ignored.
                                    received_connection_packet = true;
                                    let _ = self
                                        .connections
                                        .senders
                                        .get_mut(&handle)
                                        .unwrap()
                                        .send(ConnectionEvent::Proto(event));
                                }
                                Some(DatagramEvent::Response(transmit)) => {
                                    respond(transmit, &response_buffer, socket);
                                }
                                None => {}
                            }
                        }
                    }
                }
                Poll::Pending => {
                    return Ok(PollProgress {
                        received_connection_packet,
                        keep_going: false,
                    });
                }
                // `ConnectionReset` is skipped and reading continues (presumably because a reset
                // reported on a UDP socket carries no meaning for QUIC and could be spoofed).
                // This path bypasses the limiter check below.
                Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionReset => {
                    continue;
                }
                Poll::Ready(Err(e)) => {
                    return Err(e);
                }
            }
            // Stop once the limiter's budget is spent; `keep_going` asks the driver to resume.
            if !self.recv_limiter.allow_work(|| runtime.now()) {
                return Ok(PollProgress {
                    received_connection_packet,
                    keep_going: true,
                });
            }
        }
    }
}
857
858impl fmt::Debug for RecvState {
859 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
860 f.debug_struct("RecvState")
861 .field("incoming", &self.incoming)
862 .field("connections", &self.connections)
863 .field("recv_limiter", &self.recv_limiter)
865 .finish_non_exhaustive()
866 }
867}
868
/// Outcome of one `RecvState::poll_socket` call.
#[derive(Default)]
struct PollProgress {
    /// Whether at least one datagram was routed to an existing connection.
    received_connection_packet: bool,
    /// Whether the receive limiter stopped the loop before the socket ran out of data, meaning
    /// the driver should be polled again.
    keep_going: bool,
}