// quinn/endpoint.rs

use std::{
    collections::VecDeque,
    fmt,
    future::Future,
    io,
    io::IoSliceMut,
    mem,
    net::{SocketAddr, SocketAddrV6},
    pin::Pin,
    str,
    sync::{Arc, Mutex},
    task::{Context, Poll, Waker},
};

#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))]
use crate::runtime::default_runtime;
use crate::{
    Instant,
    runtime::{AsyncUdpSocket, Runtime},
    udp_transmit,
};
use bytes::{Bytes, BytesMut};
use pin_project_lite::pin_project;
use proto::{
    self as proto, ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent,
    EndpointEvent, ServerConfig,
};
use rustc_hash::FxHashMap;
#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))]
use socket2::{Domain, Protocol, Socket, Type};
use tokio::sync::{Notify, futures::Notified, mpsc};
use tracing::{Instrument, Span};
use udp::{BATCH_SIZE, RecvMeta};

use crate::{
    ConnectionEvent, EndpointConfig, IO_LOOP_BOUND, RECV_TIME_BOUND, VarInt,
    connection::Connecting, incoming::Incoming, work_limiter::WorkLimiter,
};

/// A QUIC endpoint.
///
/// An endpoint corresponds to a single UDP socket, may host many connections, and may act as both
/// client and server for different connections.
///
/// May be cloned to obtain another handle to the same endpoint.
#[derive(Debug, Clone)]
pub struct Endpoint {
    pub(crate) inner: EndpointRef,
    pub(crate) default_client_config: Option<ClientConfig>,
    runtime: Arc<dyn Runtime>,
}

impl Endpoint {
    /// Helper to construct an endpoint for use with outgoing connections only
    ///
    /// Note that `addr` is the *local* address to bind to, which should usually be a wildcard
    /// address like `0.0.0.0:0` or `[::]:0`; these allow communication with any reachable IPv4 or
    /// IPv6 address, respectively, from an OS-assigned port.
    ///
    /// If an IPv6 address is provided, attempts to make the socket dual-stack so as to allow
    /// communication with both IPv4 and IPv6 addresses. As such, calling `Endpoint::client` with
    /// the address `[::]:0` is a reasonable default to maximize the ability to connect to other
    /// addresses. For example:
    ///
    /// ```
    /// quinn::Endpoint::client((std::net::Ipv6Addr::UNSPECIFIED, 0).into());
    /// ```
    ///
    /// Some environments may not allow creation of dual-stack sockets, in which case an IPv6
    /// client will only be able to connect to IPv6 servers. An IPv4 client is never dual-stack.
    #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))] // `EndpointConfig::default()` is only available with these
    pub fn client(addr: SocketAddr) -> io::Result<Self> {
        let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))?;
        if addr.is_ipv6() {
            if let Err(e) = socket.set_only_v6(false) {
                tracing::debug!(%e, "unable to make socket dual-stack");
            }
        }
        socket.bind(&addr.into())?;
        let runtime =
            default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
        Self::new_with_abstract_socket(
            EndpointConfig::default(),
            None,
            runtime.wrap_udp_socket(socket.into())?,
            runtime,
        )
    }

    /// Returns relevant stats from this Endpoint
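    ///
    /// A minimal usage sketch (the wrapper function is illustrative only):
    ///
    /// ```
    /// # fn example(endpoint: quinn::Endpoint) {
    /// let stats = endpoint.stats();
    /// println!("accepted handshakes so far: {}", stats.accepted_handshakes);
    /// # }
    /// ```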
    pub fn stats(&self) -> EndpointStats {
        self.inner.state.lock().unwrap().stats
    }

    /// Helper to construct an endpoint for use with both incoming and outgoing connections
    ///
    /// Platform defaults for dual-stack sockets vary. For example, any socket bound to a wildcard
    /// IPv6 address on Windows will not by default be able to communicate with IPv4
    /// addresses. Portable applications should bind an address that matches the family within
    /// which they wish to communicate.
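    ///
    /// A minimal sketch, assuming a [`ServerConfig`] has already been built and that the port
    /// below is a placeholder:
    ///
    /// ```
    /// # fn example(server_config: quinn::ServerConfig) -> std::io::Result<()> {
    /// let endpoint = quinn::Endpoint::server(server_config, "[::]:4433".parse().unwrap())?;
    /// # drop(endpoint);
    /// # Ok(())
    /// # }
    /// ```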
    #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))] // `EndpointConfig::default()` is only available with these
    pub fn server(config: ServerConfig, addr: SocketAddr) -> io::Result<Self> {
        let socket = std::net::UdpSocket::bind(addr)?;
        let runtime =
            default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
        Self::new_with_abstract_socket(
            EndpointConfig::default(),
            Some(config),
            runtime.wrap_udp_socket(socket)?,
            runtime,
        )
    }

    /// Construct an endpoint with arbitrary configuration and socket
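    ///
    /// A minimal sketch, assuming the Tokio-based runtime shipped with this crate is enabled:
    ///
    /// ```
    /// # fn example() -> std::io::Result<()> {
    /// use std::sync::Arc;
    ///
    /// let socket = std::net::UdpSocket::bind("[::]:0")?;
    /// let endpoint = quinn::Endpoint::new(
    ///     quinn::EndpointConfig::default(),
    ///     None, // no server config: this endpoint only initiates connections
    ///     socket,
    ///     Arc::new(quinn::TokioRuntime),
    /// )?;
    /// # drop(endpoint);
    /// # Ok(())
    /// # }
    /// ```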
    #[cfg(not(wasm_browser))]
    pub fn new(
        config: EndpointConfig,
        server_config: Option<ServerConfig>,
        socket: std::net::UdpSocket,
        runtime: Arc<dyn Runtime>,
    ) -> io::Result<Self> {
        let socket = runtime.wrap_udp_socket(socket)?;
        Self::new_with_abstract_socket(config, server_config, socket, runtime)
    }

    /// Construct an endpoint with arbitrary configuration and pre-constructed abstract socket
    ///
    /// Useful when `socket` has additional state (e.g. sidechannels) attached for which shared
    /// ownership is needed.
    pub fn new_with_abstract_socket(
        config: EndpointConfig,
        server_config: Option<ServerConfig>,
        socket: Arc<dyn AsyncUdpSocket>,
        runtime: Arc<dyn Runtime>,
    ) -> io::Result<Self> {
        let addr = socket.local_addr()?;
        let allow_mtud = !socket.may_fragment();
        let rc = EndpointRef::new(
            socket,
            proto::Endpoint::new(
                Arc::new(config),
                server_config.map(Arc::new),
                allow_mtud,
                None,
            ),
            addr.is_ipv6(),
            runtime.clone(),
        );
        let driver = EndpointDriver(rc.clone());
        runtime.spawn(Box::pin(
            async {
                if let Err(e) = driver.await {
                    tracing::error!("I/O error: {}", e);
                }
            }
            .instrument(Span::current()),
        ));
        Ok(Self {
            inner: rc,
            default_client_config: None,
            runtime,
        })
    }

    /// Get the next incoming connection attempt from a client
    ///
    /// Yields [`Incoming`]s, or `None` if the endpoint is [`close`](Self::close)d. [`Incoming`]
    /// can be `await`ed to obtain the final [`Connection`](crate::Connection), or used to e.g.
    /// filter connection attempts or force address validation, or converted into an intermediate
    /// `Connecting` future which can be used to e.g. send 0.5-RTT data.
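    ///
    /// A typical accept loop, sketched (connection handling elided):
    ///
    /// ```
    /// # async fn example(endpoint: quinn::Endpoint) {
    /// // Yields `None` once the endpoint has been closed
    /// while let Some(incoming) = endpoint.accept().await {
    ///     // Awaiting `incoming` drives the handshake to completion
    ///     match incoming.await {
    ///         Ok(_connection) => { /* typically: spawn a task to drive the connection */ }
    ///         Err(_) => { /* handshake failed or was refused */ }
    ///     }
    /// }
    /// # }
    /// ```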
    pub fn accept(&self) -> Accept<'_> {
        Accept {
            endpoint: self,
            notify: self.inner.shared.incoming.notified(),
        }
    }

    /// Set the client configuration used by `connect`
    pub fn set_default_client_config(&mut self, config: ClientConfig) {
        self.default_client_config = Some(config);
    }

    /// Connect to a remote endpoint
    ///
    /// `server_name` must be covered by the certificate presented by the server. This prevents a
    /// connection from being intercepted by an attacker with a valid certificate for some other
    /// server.
    ///
    /// May fail immediately due to configuration errors, or in the future if the connection could
    /// not be established.
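    ///
    /// A minimal sketch, assuming a default client configuration has been set via
    /// [`set_default_client_config()`](Self::set_default_client_config); the address and server
    /// name are placeholders:
    ///
    /// ```
    /// # async fn example(endpoint: quinn::Endpoint) -> Result<(), Box<dyn std::error::Error>> {
    /// let connection = endpoint
    ///     .connect("192.0.2.7:4433".parse()?, "example.com")?
    ///     .await?;
    /// // `connection` can now be used to open streams, send datagrams, etc.
    /// # drop(connection);
    /// # Ok(())
    /// # }
    /// ```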
    pub fn connect(&self, addr: SocketAddr, server_name: &str) -> Result<Connecting, ConnectError> {
        let config = match &self.default_client_config {
            Some(config) => config.clone(),
            None => return Err(ConnectError::NoDefaultClientConfig),
        };

        self.connect_with(config, addr, server_name)
    }

    /// Connect to a remote endpoint using a custom configuration.
    ///
    /// See [`connect()`] for details.
    ///
    /// [`connect()`]: Endpoint::connect
    pub fn connect_with(
        &self,
        config: ClientConfig,
        addr: SocketAddr,
        server_name: &str,
    ) -> Result<Connecting, ConnectError> {
        let mut endpoint = self.inner.state.lock().unwrap();
        if endpoint.driver_lost || endpoint.recv_state.connections.close.is_some() {
            return Err(ConnectError::EndpointStopping);
        }
        if addr.is_ipv6() && !endpoint.ipv6 {
            return Err(ConnectError::InvalidRemoteAddress(addr));
        }
        let addr = if endpoint.ipv6 {
            SocketAddr::V6(ensure_ipv6(addr))
        } else {
            addr
        };

        let (ch, conn) = endpoint
            .inner
            .connect(self.runtime.now(), config, addr, server_name)?;

        let socket = endpoint.socket.clone();
        endpoint.stats.outgoing_handshakes += 1;
        Ok(endpoint
            .recv_state
            .connections
            .insert(ch, conn, socket, self.runtime.clone()))
    }

    /// Switch to a new UDP socket
    ///
    /// See [`Endpoint::rebind_abstract()`] for details.
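    ///
    /// A minimal sketch of migrating to a freshly bound socket (port chosen by the OS):
    ///
    /// ```
    /// # fn example(endpoint: quinn::Endpoint) -> std::io::Result<()> {
    /// let new_socket = std::net::UdpSocket::bind("[::]:0")?;
    /// endpoint.rebind(new_socket)?;
    /// # Ok(())
    /// # }
    /// ```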
    #[cfg(not(wasm_browser))]
    pub fn rebind(&self, socket: std::net::UdpSocket) -> io::Result<()> {
        self.rebind_abstract(self.runtime.wrap_udp_socket(socket)?)
    }

    /// Switch to a new UDP socket
    ///
    /// Allows the endpoint's address to be updated live, affecting all active connections. Incoming
    /// connections and connections to servers unreachable from the new address will be lost.
    ///
    /// On error, the old UDP socket is retained.
    pub fn rebind_abstract(&self, socket: Arc<dyn AsyncUdpSocket>) -> io::Result<()> {
        let addr = socket.local_addr()?;
        let mut inner = self.inner.state.lock().unwrap();
        inner.prev_socket = Some(mem::replace(&mut inner.socket, socket));
        inner.ipv6 = addr.is_ipv6();

        // Update connection socket references
        for sender in inner.recv_state.connections.senders.values() {
            // Ignoring errors from dropped connections
            let _ = sender.send(ConnectionEvent::Rebind(inner.socket.clone()));
        }
        if let Some(driver) = inner.driver.take() {
            // Ensure the driver can register for wake-ups from the new socket
            driver.wake();
        }

        Ok(())
    }

    /// Replace the server configuration, affecting new incoming connections only
    ///
    /// Useful for e.g. refreshing TLS certificates without disrupting existing connections.
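    ///
    /// A minimal sketch, assuming a refreshed [`ServerConfig`] has been built elsewhere:
    ///
    /// ```
    /// # fn example(endpoint: quinn::Endpoint, refreshed: quinn::ServerConfig) {
    /// // Only connections accepted after this call use the refreshed configuration
    /// endpoint.set_server_config(Some(refreshed));
    /// # }
    /// ```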
    pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
        self.inner
            .state
            .lock()
            .unwrap()
            .inner
            .set_server_config(server_config.map(Arc::new))
    }

    /// Get the local `SocketAddr` the underlying socket is bound to
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner.state.lock().unwrap().socket.local_addr()
    }

    /// Get the number of connections that are currently open
    pub fn open_connections(&self) -> usize {
        self.inner.state.lock().unwrap().inner.open_connections()
    }

    /// Close all of this endpoint's connections immediately and cease accepting new connections.
    ///
    /// See [`Connection::close()`] for details.
    ///
    /// [`Connection::close()`]: crate::Connection::close
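    ///
    /// For example, a shutdown sketch where the error code `0` and the reason string are
    /// arbitrary application-defined values:
    ///
    /// ```
    /// # fn example(endpoint: quinn::Endpoint) {
    /// endpoint.close(0u32.into(), b"server shutting down");
    /// # }
    /// ```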
    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
        let reason = Bytes::copy_from_slice(reason);
        let mut endpoint = self.inner.state.lock().unwrap();
        endpoint.recv_state.connections.close = Some((error_code, reason.clone()));
        for sender in endpoint.recv_state.connections.senders.values() {
            // Ignoring errors from dropped connections
            let _ = sender.send(ConnectionEvent::Close {
                error_code,
                reason: reason.clone(),
            });
        }
        self.inner.shared.incoming.notify_waiters();
    }

    /// Wait for all connections on the endpoint to be cleanly shut down
    ///
    /// Waiting for this condition before exiting ensures that a good-faith effort is made to notify
    /// peers of recent connection closes, whereas exiting immediately could force them to wait out
    /// the idle timeout period.
    ///
    /// Does not proactively close existing connections or cause incoming connections to be
    /// rejected. Consider calling [`close()`] if that is desired.
    ///
    /// [`close()`]: Endpoint::close
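    ///
    /// A typical graceful-shutdown sequence, sketched:
    ///
    /// ```
    /// # async fn example(endpoint: quinn::Endpoint) {
    /// // Ask peers to go away, then make a best effort to deliver the close packets
    /// endpoint.close(0u32.into(), b"done");
    /// endpoint.wait_idle().await;
    /// # }
    /// ```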
    pub async fn wait_idle(&self) {
        loop {
            {
                let endpoint = &mut *self.inner.state.lock().unwrap();
                if endpoint.recv_state.connections.is_empty() {
                    break;
                }
                // Construct future while lock is held to avoid race
                self.inner.shared.idle.notified()
            }
            .await;
        }
    }
}

/// Statistics on [Endpoint] activity
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone)]
pub struct EndpointStats {
    /// Cumulative number of QUIC handshakes accepted by this [Endpoint]
    pub accepted_handshakes: u64,
    /// Cumulative number of QUIC handshakes sent from this [Endpoint]
    pub outgoing_handshakes: u64,
    /// Cumulative number of QUIC handshakes refused on this [Endpoint]
    pub refused_handshakes: u64,
    /// Cumulative number of QUIC handshakes ignored on this [Endpoint]
    pub ignored_handshakes: u64,
}

/// A future that drives IO on an endpoint
///
/// This task functions as the switch point between the UDP socket object and the
/// `Endpoint` responsible for routing datagrams to their owning `Connection`.
/// In order to do so, it also facilitates the exchange of different types of events
/// flowing between the `Endpoint` and the tasks managing `Connection`s. As such,
/// running this task is necessary to keep the endpoint's connections running.
///
/// `EndpointDriver` futures terminate when all clones of the `Endpoint` have been dropped, or when
/// an I/O error occurs.
#[must_use = "endpoint drivers must be spawned for I/O to occur"]
#[derive(Debug)]
pub(crate) struct EndpointDriver(pub(crate) EndpointRef);

impl Future for EndpointDriver {
    type Output = Result<(), io::Error>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let mut endpoint = self.0.state.lock().unwrap();
        if endpoint.driver.is_none() {
            endpoint.driver = Some(cx.waker().clone());
        }

        let now = endpoint.runtime.now();
        let mut keep_going = false;
        keep_going |= endpoint.drive_recv(cx, now)?;
        keep_going |= endpoint.handle_events(cx, &self.0.shared);

        if !endpoint.recv_state.incoming.is_empty() {
            self.0.shared.incoming.notify_waiters();
        }

        if endpoint.ref_count == 0 && endpoint.recv_state.connections.is_empty() {
            Poll::Ready(Ok(()))
        } else {
            drop(endpoint);
            // If there is more work to do, schedule the endpoint task again.
            // `wake_by_ref()` is called outside the lock to minimize
            // lock contention on a multithreaded runtime.
            if keep_going {
                cx.waker().wake_by_ref();
            }
            Poll::Pending
        }
    }
}

impl Drop for EndpointDriver {
    fn drop(&mut self) {
        let mut endpoint = self.0.state.lock().unwrap();
        endpoint.driver_lost = true;
        self.0.shared.incoming.notify_waiters();
        // Drop all outgoing channels, signaling the termination of the endpoint to the associated
        // connections.
        endpoint.recv_state.connections.senders.clear();
    }
}

#[derive(Debug)]
pub(crate) struct EndpointInner {
    pub(crate) state: Mutex<State>,
    pub(crate) shared: Shared,
}

impl EndpointInner {
    pub(crate) fn accept(
        &self,
        incoming: proto::Incoming,
        server_config: Option<Arc<ServerConfig>>,
    ) -> Result<Connecting, ConnectionError> {
        let mut state = self.state.lock().unwrap();
        let mut response_buffer = Vec::new();
        let now = state.runtime.now();
        match state
            .inner
            .accept(incoming, now, &mut response_buffer, server_config)
        {
            Ok((handle, conn)) => {
                state.stats.accepted_handshakes += 1;
                let socket = state.socket.clone();
                let runtime = state.runtime.clone();
                Ok(state
                    .recv_state
                    .connections
                    .insert(handle, conn, socket, runtime))
            }
            Err(error) => {
                if let Some(transmit) = error.response {
                    respond(transmit, &response_buffer, &*state.socket);
                }
                Err(error.cause)
            }
        }
    }

    pub(crate) fn refuse(&self, incoming: proto::Incoming) {
        let mut state = self.state.lock().unwrap();
        state.stats.refused_handshakes += 1;
        let mut response_buffer = Vec::new();
        let transmit = state.inner.refuse(incoming, &mut response_buffer);
        respond(transmit, &response_buffer, &*state.socket);
    }

    pub(crate) fn retry(&self, incoming: proto::Incoming) -> Result<(), proto::RetryError> {
        let mut state = self.state.lock().unwrap();
        let mut response_buffer = Vec::new();
        let transmit = state.inner.retry(incoming, &mut response_buffer)?;
        respond(transmit, &response_buffer, &*state.socket);
        Ok(())
    }

    pub(crate) fn ignore(&self, incoming: proto::Incoming) {
        let mut state = self.state.lock().unwrap();
        state.stats.ignored_handshakes += 1;
        state.inner.ignore(incoming);
    }
}

#[derive(Debug)]
pub(crate) struct State {
    socket: Arc<dyn AsyncUdpSocket>,
    /// During an active migration, `prev_socket` receives traffic
    /// until the first packet arrives on the new socket.
    prev_socket: Option<Arc<dyn AsyncUdpSocket>>,
    inner: proto::Endpoint,
    recv_state: RecvState,
    driver: Option<Waker>,
    ipv6: bool,
    events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>,
    /// Number of live handles that can be used to initiate or handle I/O; excludes the driver
    ref_count: usize,
    driver_lost: bool,
    runtime: Arc<dyn Runtime>,
    stats: EndpointStats,
}

#[derive(Debug)]
pub(crate) struct Shared {
    incoming: Notify,
    idle: Notify,
}

impl State {
    fn drive_recv(&mut self, cx: &mut Context, now: Instant) -> Result<bool, io::Error> {
        let get_time = || self.runtime.now();
        self.recv_state.recv_limiter.start_cycle(get_time);
        if let Some(socket) = &self.prev_socket {
            // We don't care about the `PollProgress` from old sockets.
            let poll_res =
                self.recv_state
                    .poll_socket(cx, &mut self.inner, &**socket, &*self.runtime, now);
            if poll_res.is_err() {
                self.prev_socket = None;
            }
        };
        let poll_res =
            self.recv_state
                .poll_socket(cx, &mut self.inner, &*self.socket, &*self.runtime, now);
        self.recv_state.recv_limiter.finish_cycle(get_time);
        let poll_res = poll_res?;
        if poll_res.received_connection_packet {
            // Traffic has arrived on self.socket, therefore there is no need for the abandoned
            // one anymore. TODO: Account for multiple outgoing connections.
            self.prev_socket = None;
        }
        Ok(poll_res.keep_going)
    }

    fn handle_events(&mut self, cx: &mut Context, shared: &Shared) -> bool {
        for _ in 0..IO_LOOP_BOUND {
            let (ch, event) = match self.events.poll_recv(cx) {
                Poll::Ready(Some(x)) => x,
                Poll::Ready(None) => unreachable!("EndpointInner owns one sender"),
                Poll::Pending => {
                    return false;
                }
            };

            if event.is_drained() {
                self.recv_state.connections.senders.remove(&ch);
                if self.recv_state.connections.is_empty() {
                    shared.idle.notify_waiters();
                }
            }
            let Some(event) = self.inner.handle_event(ch, event) else {
                continue;
            };
            // Ignoring errors from dropped connections that haven't yet been cleaned up
            let _ = self
                .recv_state
                .connections
                .senders
                .get_mut(&ch)
                .unwrap()
                .send(ConnectionEvent::Proto(event));
        }

        true
    }
}

impl Drop for State {
    fn drop(&mut self) {
        for incoming in self.recv_state.incoming.drain(..) {
            self.inner.ignore(incoming);
        }
    }
}

fn respond(transmit: proto::Transmit, response_buffer: &[u8], socket: &dyn AsyncUdpSocket) {
    // Send if there's kernel buffer space; otherwise, drop it
    //
    // As an endpoint-generated packet, we know this is an
    // immediate, stateless response to an unconnected peer,
    // one of:
    //
    // - A version negotiation response due to an unknown version
    // - A `CLOSE` due to a malformed or unwanted connection attempt
    // - A stateless reset due to an unrecognized connection
    // - A `Retry` packet due to a connection attempt when
    //   `use_retry` is set
    //
    // In each case, a well-behaved peer can be trusted to retry a
    // few times, which is guaranteed to produce the same response
    // from us. Repeated failures might at worst cause a peer's new
    // connection attempt to time out, which is acceptable if we're
    // under such heavy load that there's never room for this code
    // to transmit. This is morally equivalent to the packet getting
    // lost due to congestion further along the link, which
    // similarly relies on peer retries for recovery.
    _ = socket.try_send(&udp_transmit(&transmit, &response_buffer[..transmit.size]));
}

#[inline]
fn proto_ecn(ecn: udp::EcnCodepoint) -> proto::EcnCodepoint {
    match ecn {
        udp::EcnCodepoint::Ect0 => proto::EcnCodepoint::Ect0,
        udp::EcnCodepoint::Ect1 => proto::EcnCodepoint::Ect1,
        udp::EcnCodepoint::Ce => proto::EcnCodepoint::Ce,
    }
}

#[derive(Debug)]
struct ConnectionSet {
    /// Senders for communicating with the endpoint's connections
    senders: FxHashMap<ConnectionHandle, mpsc::UnboundedSender<ConnectionEvent>>,
    /// Stored to give out clones to new ConnectionInners
    sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
    /// Set if the endpoint has been manually closed
    close: Option<(VarInt, Bytes)>,
}

impl ConnectionSet {
    fn insert(
        &mut self,
        handle: ConnectionHandle,
        conn: proto::Connection,
        socket: Arc<dyn AsyncUdpSocket>,
        runtime: Arc<dyn Runtime>,
    ) -> Connecting {
        let (send, recv) = mpsc::unbounded_channel();
        if let Some((error_code, ref reason)) = self.close {
            send.send(ConnectionEvent::Close {
                error_code,
                reason: reason.clone(),
            })
            .unwrap();
        }
        self.senders.insert(handle, send);
        Connecting::new(handle, conn, self.sender.clone(), recv, socket, runtime)
    }

    fn is_empty(&self) -> bool {
        self.senders.is_empty()
    }
}

fn ensure_ipv6(x: SocketAddr) -> SocketAddrV6 {
    match x {
        SocketAddr::V6(x) => x,
        SocketAddr::V4(x) => SocketAddrV6::new(x.ip().to_ipv6_mapped(), x.port(), 0, 0),
    }
}

pin_project! {
    /// Future produced by [`Endpoint::accept`]
    pub struct Accept<'a> {
        endpoint: &'a Endpoint,
        #[pin]
        notify: Notified<'a>,
    }
}

impl Future for Accept<'_> {
    type Output = Option<Incoming>;
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();
        let mut endpoint = this.endpoint.inner.state.lock().unwrap();
        if endpoint.driver_lost {
            return Poll::Ready(None);
        }
        if let Some(incoming) = endpoint.recv_state.incoming.pop_front() {
            // Release the mutex lock on endpoint so cloning it doesn't deadlock
            drop(endpoint);
            let incoming = Incoming::new(incoming, this.endpoint.inner.clone());
            return Poll::Ready(Some(incoming));
        }
        if endpoint.recv_state.connections.close.is_some() {
            return Poll::Ready(None);
        }
        loop {
            match this.notify.as_mut().poll(ctx) {
                // `state` lock ensures we didn't race with readiness
                Poll::Pending => return Poll::Pending,
                // Spurious wakeup, get a new future
                Poll::Ready(()) => this
                    .notify
                    .set(this.endpoint.inner.shared.incoming.notified()),
            }
        }
    }
}

#[derive(Debug)]
pub(crate) struct EndpointRef(Arc<EndpointInner>);

impl EndpointRef {
    pub(crate) fn new(
        socket: Arc<dyn AsyncUdpSocket>,
        inner: proto::Endpoint,
        ipv6: bool,
        runtime: Arc<dyn Runtime>,
    ) -> Self {
        let (sender, events) = mpsc::unbounded_channel();
        let recv_state = RecvState::new(sender, socket.max_receive_segments(), &inner);
        Self(Arc::new(EndpointInner {
            shared: Shared {
                incoming: Notify::new(),
                idle: Notify::new(),
            },
            state: Mutex::new(State {
                socket,
                prev_socket: None,
                inner,
                ipv6,
                events,
                driver: None,
                ref_count: 0,
                driver_lost: false,
                recv_state,
                runtime,
                stats: EndpointStats::default(),
            }),
        }))
    }
}

impl Clone for EndpointRef {
    fn clone(&self) -> Self {
        self.0.state.lock().unwrap().ref_count += 1;
        Self(self.0.clone())
    }
}

impl Drop for EndpointRef {
    fn drop(&mut self) {
        let endpoint = &mut *self.0.state.lock().unwrap();
        if let Some(x) = endpoint.ref_count.checked_sub(1) {
            endpoint.ref_count = x;
            if x == 0 {
                // If the driver is about to be on its own, ensure it can shut down if the last
                // connection is gone.
                if let Some(task) = endpoint.driver.take() {
                    task.wake();
                }
            }
        }
    }
}

impl std::ops::Deref for EndpointRef {
    type Target = EndpointInner;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// State directly involved in handling incoming packets
struct RecvState {
    incoming: VecDeque<proto::Incoming>,
    connections: ConnectionSet,
    recv_buf: Box<[u8]>,
    recv_limiter: WorkLimiter,
}

impl RecvState {
    fn new(
        sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
        max_receive_segments: usize,
        endpoint: &proto::Endpoint,
    ) -> Self {
        let recv_buf = vec![
            0;
            endpoint.config().get_max_udp_payload_size().min(64 * 1024) as usize
                * max_receive_segments
                * BATCH_SIZE
        ];
        Self {
            connections: ConnectionSet {
                senders: FxHashMap::default(),
                sender,
                close: None,
            },
            incoming: VecDeque::new(),
            recv_buf: recv_buf.into(),
            recv_limiter: WorkLimiter::new(RECV_TIME_BOUND),
        }
    }

    fn poll_socket(
        &mut self,
        cx: &mut Context,
        endpoint: &mut proto::Endpoint,
        socket: &dyn AsyncUdpSocket,
        runtime: &dyn Runtime,
        now: Instant,
    ) -> Result<PollProgress, io::Error> {
        let mut received_connection_packet = false;
        let mut metas = [RecvMeta::default(); BATCH_SIZE];
        let mut iovs: [IoSliceMut; BATCH_SIZE] = {
            let mut bufs = self
                .recv_buf
                .chunks_mut(self.recv_buf.len() / BATCH_SIZE)
                .map(IoSliceMut::new);

            // expect() safe as self.recv_buf is chunked into BATCH_SIZE items
            // and iovs will be of size BATCH_SIZE, thus from_fn is called
            // exactly BATCH_SIZE times.
            std::array::from_fn(|_| bufs.next().expect("BATCH_SIZE elements"))
        };
        loop {
            match socket.poll_recv(cx, &mut iovs, &mut metas) {
                Poll::Ready(Ok(msgs)) => {
                    self.recv_limiter.record_work(msgs);
                    for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) {
                        let mut data: BytesMut = buf[0..meta.len].into();
                        while !data.is_empty() {
                            let buf = data.split_to(meta.stride.min(data.len()));
                            let mut response_buffer = Vec::new();
                            match endpoint.handle(
                                now,
                                meta.addr,
                                meta.dst_ip,
                                meta.ecn.map(proto_ecn),
                                buf,
                                &mut response_buffer,
                            ) {
                                Some(DatagramEvent::NewConnection(incoming)) => {
                                    if self.connections.close.is_none() {
                                        self.incoming.push_back(incoming);
                                    } else {
                                        let transmit =
                                            endpoint.refuse(incoming, &mut response_buffer);
                                        respond(transmit, &response_buffer, socket);
                                    }
                                }
                                Some(DatagramEvent::ConnectionEvent(handle, event)) => {
                                    // Ignoring errors from dropped connections that haven't yet been cleaned up
                                    received_connection_packet = true;
                                    let _ = self
                                        .connections
                                        .senders
                                        .get_mut(&handle)
                                        .unwrap()
                                        .send(ConnectionEvent::Proto(event));
                                }
                                Some(DatagramEvent::Response(transmit)) => {
                                    respond(transmit, &response_buffer, socket);
                                }
                                None => {}
                            }
                        }
                    }
                }
                Poll::Pending => {
                    return Ok(PollProgress {
                        received_connection_packet,
                        keep_going: false,
                    });
                }
                // Ignore ECONNRESET as it's undefined in QUIC and may be injected by an
                // attacker
                Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionReset => {
                    continue;
                }
                Poll::Ready(Err(e)) => {
                    return Err(e);
                }
            }
            if !self.recv_limiter.allow_work(|| runtime.now()) {
                return Ok(PollProgress {
                    received_connection_packet,
                    keep_going: true,
                });
            }
        }
    }
}

impl fmt::Debug for RecvState {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_struct("RecvState")
            .field("incoming", &self.incoming)
            .field("connections", &self.connections)
            // recv_buf too large
            .field("recv_limiter", &self.recv_limiter)
            .finish_non_exhaustive()
    }
}

#[derive(Default)]
struct PollProgress {
    /// Whether a datagram was routed to an existing connection
    received_connection_packet: bool,
    /// Whether datagram handling was interrupted early by the work limiter for fairness
    keep_going: bool,
}