ferron/
listener_quic.rs

1// Copyright (c) 2018 The quinn Developers
2// Portions of this file are derived from Quinn (https://github.com/quinn-rs/quinn).
3//
4// Permission is hereby granted, free of charge, to any person obtaining a copy
5// of this software and associated documentation files (the "Software"), to deal
6// in the Software without restriction, including without limitation the rights
7// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8// copies of the Software, and to permit persons to whom the Software is
9// furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in all
12// copies or substantial portions of the Software.
13//
14
15use std::error::Error;
16#[cfg(feature = "runtime-monoio")]
17use std::fmt::{Debug, Formatter};
18#[cfg(feature = "runtime-monoio")]
19use std::future::Future;
20#[cfg(feature = "runtime-monoio")]
21use std::io;
22use std::net::{IpAddr, Ipv6Addr, SocketAddr};
23#[cfg(feature = "runtime-monoio")]
24use std::pin::Pin;
25use std::sync::Arc;
26#[cfg(feature = "runtime-monoio")]
27use std::task::{ready, Context, Poll};
28use std::time::Duration;
29#[cfg(feature = "runtime-monoio")]
30use std::time::Instant;
31
32use async_channel::{Receiver, Sender};
33#[cfg(feature = "runtime-monoio")]
34use async_io::Async;
35#[cfg(feature = "runtime-monoio")]
36use pin_project_lite::pin_project;
37use quinn::crypto::rustls::QuicServerConfig;
38#[cfg(feature = "runtime-monoio")]
39use quinn::{udp, AsyncTimer, AsyncUdpSocket, Runtime, UdpPoller};
40use rustls::ServerConfig;
41use tokio_util::sync::CancellationToken;
42
43use crate::listener_handler_communication::{Connection, ConnectionData};
44use crate::logging::LogMessage;
45
#[cfg(feature = "runtime-monoio")]
pin_project! {
  /// Adapter that turns a factory of single-use futures (`MakeFut`, producing `Fut`) into a
  /// [`UdpPoller`] that can be polled over and over again
  struct UdpPollHelper<MakeFut, Fut> {
    // Factory invoked whenever no future is currently stored.
    make_fut: MakeFut,
    // The future currently in flight, if any; pinned because `Fut` may be `!Unpin`.
    #[pin]
    fut: Option<Fut>,
  }
}
56
#[cfg(feature = "runtime-monoio")]
impl<MakeFut, Fut> UdpPollHelper<MakeFut, Fut> {
  /// Build a [`UdpPoller`] around `make_fut`.
  ///
  /// No future is created up front: the helper starts with an empty slot, asks `make_fut` for
  /// a future when [`poll_writable`](UdpPoller::poll_writable) is called, keeps that future
  /// until it yields [`Poll::Ready`], and requests a new one on the following call.
  fn new(make_fut: MakeFut) -> Self {
    UdpPollHelper { fut: None, make_fut }
  }
}
66
#[cfg(feature = "runtime-monoio")]
impl<MakeFut, Fut> UdpPoller for UdpPollHelper<MakeFut, Fut>
where
  MakeFut: Fn() -> Fut + Send + Sync + 'static,
  Fut: Future<Output = io::Result<()>> + Send + Sync + 'static,
{
  /// Polls the stored future, creating it first when the slot is empty, and clears the slot
  /// once the future has completed.
  fn poll_writable(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
    let mut projected = self.project();

    // Lazily create the future: the slot is empty on the first call and after every
    // completed poll.
    if projected.fut.is_none() {
      let fresh = (projected.make_fut)();
      projected.fut.set(Some(fresh));
    }

    // `Fut` may be `!Unpin`, so the only way to reach it once it lives inside the pinned
    // `Option` is through `as_pin_mut`. The slot was filled just above, hence the `expect`.
    let outcome = projected
      .fut
      .as_mut()
      .as_pin_mut()
      .expect("the future slot is always populated before it is polled")
      .poll(cx);

    // A finished future must not be polled again; drop it so the next call builds a new one.
    if matches!(outcome, Poll::Ready(_)) {
      projected.fut.set(None);
    }

    outcome
  }
}
91
#[cfg(feature = "runtime-monoio")]
impl<MakeFut, Fut> Debug for UdpPollHelper<MakeFut, Fut> {
  /// Formats the helper as an opaque struct; no field is printed, so neither `MakeFut` nor
  /// `Fut` has to implement [`Debug`].
  fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
    let mut builder = formatter.debug_struct("UdpPollHelper");
    builder.finish_non_exhaustive()
  }
}
98
/// A Quinn [`Runtime`] that spawns tasks on Monoio and relies on async_io for timers and UDP
/// sockets
#[cfg(feature = "runtime-monoio")]
#[derive(Debug)]
struct MonoioAsyncioRuntime;
103
#[cfg(feature = "runtime-monoio")]
impl Runtime for MonoioAsyncioRuntime {
  /// Creates a timer, backed by `async_io::Timer`, that is set to fire at instant `t`.
  fn new_timer(&self, t: Instant) -> Pin<Box<dyn AsyncTimer>> {
    let inner = async_io::Timer::at(t);
    Box::pin(Timer { inner })
  }

  /// Hands `future` over to `monoio::spawn`; the returned handle is not kept.
  fn spawn(&self, future: Pin<Box<dyn Future<Output = ()> + Send>>) {
    let _detached = monoio::spawn(future);
  }

  /// Wraps a standard UDP socket into the async_io-based [`UdpSocket`] adapter.
  fn wrap_udp_socket(&self, sock: std::net::UdpSocket) -> io::Result<Arc<dyn AsyncUdpSocket>> {
    let wrapped = UdpSocket::new(sock)?;
    Ok(Arc::new(wrapped))
  }
}
120
#[cfg(feature = "runtime-monoio")]
pin_project! {
  /// Wrapper around `async_io::Timer` on which Quinn's [`AsyncTimer`] trait is implemented
  struct Timer {
    // The wrapped timer; pinned so that it can be polled as a future.
    #[pin]
    inner: async_io::Timer,
  }
}
128
#[cfg(feature = "runtime-monoio")]
impl AsyncTimer for Timer {
  /// Re-arms the wrapped timer so that it fires at instant `t`.
  fn reset(self: Pin<&mut Self>, t: Instant) {
    self.project().inner.set_at(t);
  }

  /// Polls the wrapped timer, discarding the value it resolves to.
  fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
    let inner = self.project().inner;
    // `Future::poll` is named explicitly so the call cannot be confused with
    // `AsyncTimer::poll`.
    match Future::poll(inner, cx) {
      Poll::Ready(_) => Poll::Ready(()),
      Poll::Pending => Poll::Pending,
    }
  }
}
139
#[cfg(feature = "runtime-monoio")]
impl Debug for Timer {
  /// Delegates to the `Debug` output of the wrapped `async_io::Timer`.
  fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
    Debug::fmt(&self.inner, formatter)
  }
}
146
/// UDP socket adapter on which Quinn's [`AsyncUdpSocket`] trait is implemented
#[derive(Debug)]
#[cfg(feature = "runtime-monoio")]
struct UdpSocket {
  // The standard socket registered with async_io, used for readiness notifications.
  io: Async<std::net::UdpSocket>,
  // Quinn's per-socket UDP state, used to perform the actual send and receive calls.
  inner: udp::UdpSocketState,
}
153
#[cfg(feature = "runtime-monoio")]
impl UdpSocket {
  /// Wraps `sock` for use with Quinn.
  ///
  /// The Quinn UDP state is initialised from a borrow of the socket first; afterwards the
  /// socket itself is moved into the async_io wrapper.
  ///
  /// # Errors
  ///
  /// Returns the I/O error of whichever of the two steps fails.
  fn new(sock: std::net::UdpSocket) -> io::Result<Self> {
    let inner = udp::UdpSocketState::new((&sock).into())?;
    let io = Async::new_nonblocking(sock)?;
    Ok(Self { io, inner })
  }
}
163
#[cfg(feature = "runtime-monoio")]
impl AsyncUdpSocket for UdpSocket {
  /// Creates a poller that resolves once the socket is writable.
  fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn UdpPoller>> {
    // Every invocation clones the `Arc`, so each produced future owns its own handle to the
    // socket.
    let make_writable = move || {
      let socket = Arc::clone(&self);
      async move { socket.io.writable().await }
    };
    Box::pin(UdpPollHelper::new(make_writable))
  }

  /// Attempts to send `transmit` through Quinn's UDP state.
  fn try_send(&self, transmit: &udp::Transmit) -> io::Result<()> {
    let socket_ref = (&self.io).into();
    self.inner.send(socket_ref, transmit)
  }

  /// Waits for the socket to become readable, then receives datagrams into `bufs`/`meta`.
  fn poll_recv(
    &self,
    cx: &mut Context,
    bufs: &mut [io::IoSliceMut<'_>],
    meta: &mut [udp::RecvMeta],
  ) -> Poll<io::Result<usize>> {
    loop {
      // Readiness errors are returned to the caller; `Pending` is propagated by `ready!`.
      if let Err(err) = ready!(self.io.poll_readable(cx)) {
        return Poll::Ready(Err(err));
      }
      // NOTE(review): every error from `recv` is discarded and readiness is awaited again;
      // confirm that this cannot spin on an error that persists.
      match self.inner.recv((&self.io).into(), bufs, meta) {
        Ok(received) => return Poll::Ready(Ok(received)),
        Err(_) => continue,
      }
    }
  }

  /// Returns the local address of the underlying standard socket.
  fn local_addr(&self) -> io::Result<std::net::SocketAddr> {
    let socket: &std::net::UdpSocket = self.io.get_ref();
    socket.local_addr()
  }

  /// Forwards to `UdpSocketState::may_fragment`.
  fn may_fragment(&self) -> bool {
    self.inner.may_fragment()
  }

  /// Forwards to `UdpSocketState::max_gso_segments`.
  fn max_transmit_segments(&self) -> usize {
    self.inner.max_gso_segments()
  }

  /// Forwards to `UdpSocketState::gro_segments`.
  fn max_receive_segments(&self) -> usize {
    self.inner.gro_segments()
  }
}
207
208/// Creates a QUIC listener
209#[allow(clippy::type_complexity)]
210pub fn create_quic_listener(
211  address: SocketAddr,
212  tls_config: Arc<ServerConfig>,
213  tx: Sender<ConnectionData>,
214  enable_uring: bool,
215  logging_tx: Option<Sender<LogMessage>>,
216  first_startup: bool,
217) -> Result<(CancellationToken, Sender<Arc<ServerConfig>>), Box<dyn Error + Send + Sync>> {
218  let shutdown_tx = CancellationToken::new();
219  let shutdown_rx = shutdown_tx.clone();
220  let (rustls_config_tx, rustls_config_rx) = async_channel::unbounded();
221  let (listen_error_tx, listen_error_rx) = async_channel::unbounded();
222  std::thread::Builder::new()
223    .name(format!("QUIC listener for {address}"))
224    .spawn(move || {
225      crate::runtime::new_runtime(
226        async move {
227          if let Err(error) = quic_listener_fn(
228            address,
229            tls_config,
230            tx,
231            &listen_error_tx,
232            logging_tx,
233            first_startup,
234            shutdown_rx,
235            rustls_config_rx,
236          )
237          .await
238          {
239            listen_error_tx.send(Some(error)).await.unwrap_or_default();
240          }
241        },
242        enable_uring,
243      )
244      .unwrap();
245    })?;
246
247  if let Some(error) = listen_error_rx.recv_blocking()? {
248    Err(error)?;
249  }
250
251  Ok((shutdown_tx, rustls_config_tx))
252}
253
254/// QUIC listener function
255#[allow(clippy::too_many_arguments)]
256async fn quic_listener_fn(
257  address: SocketAddr,
258  tls_config: Arc<ServerConfig>,
259  tx: Sender<ConnectionData>,
260  listen_error_tx: &Sender<Option<Box<dyn Error + Send + Sync>>>,
261  logging_tx: Option<Sender<LogMessage>>,
262  first_startup: bool,
263  shutdown_rx: CancellationToken,
264  rustls_config_rx: Receiver<Arc<ServerConfig>>,
265) -> Result<(), Box<dyn Error + Send + Sync>> {
266  let quic_server_config = Arc::new(match QuicServerConfig::try_from(tls_config) {
267    Ok(config) => config,
268    Err(err) => Err(anyhow::anyhow!(format!(
269      "Cannot prepare the QUIC server configuration: {}",
270      err
271    )))?,
272  });
273  let server_config = quinn::ServerConfig::with_crypto(quic_server_config);
274  let udp_port = address.port();
275  let mut udp_socket_result;
276  let mut tries: u64 = 0;
277  loop {
278    udp_socket_result = std::net::UdpSocket::bind(address);
279    if first_startup || udp_socket_result.is_ok() {
280      break;
281    }
282    tries += 1;
283    let duration = Duration::from_millis(1000);
284    if tries >= 10 {
285      println!("HTTP/3 port is used at try #{tries}, skipping...");
286      listen_error_tx.send(None).await.unwrap_or_default();
287      break;
288    }
289    println!("HTTP/3 port is used at try #{tries}, retrying in {duration:?}...");
290    if shutdown_rx.is_cancelled() {
291      break;
292    }
293    crate::runtime::sleep(duration).await;
294  }
295  let udp_socket = match udp_socket_result {
296    Ok(socket) => socket,
297    Err(err) => Err(anyhow::anyhow!(format!("Cannot listen to HTTP/3 port: {}", err)))?,
298  };
299  let endpoint = match quinn::Endpoint::new(quinn::EndpointConfig::default(), Some(server_config), udp_socket, {
300    #[cfg(feature = "runtime-monoio")]
301    let runtime = Arc::new(MonoioAsyncioRuntime);
302    #[cfg(feature = "runtime-tokio")]
303    let runtime = Arc::new(quinn::TokioRuntime);
304
305    runtime
306  }) {
307    Ok(endpoint) => endpoint,
308    Err(err) => Err(anyhow::anyhow!(format!("Cannot listen to HTTP/3 port: {}", err)))?,
309  };
310  println!("HTTP/3 server is listening on {address}...");
311  listen_error_tx.send(None).await.unwrap_or_default();
312
313  loop {
314    let rustls_receive_future = async {
315      if let Ok(rustls_server_config) = rustls_config_rx.recv().await {
316        rustls_server_config
317      } else {
318        futures_util::future::pending().await
319      }
320    };
321
322    let new_conn = crate::runtime::select! {
323      result = endpoint.accept() => {
324          match result {
325              Some(conn) => conn,
326              None => {
327                  if let Some(logging_tx) = &logging_tx {
328                      logging_tx
329                          .send(LogMessage::new(
330                              "HTTP/3 connections can't be accepted anymore".to_string(),
331                              true,
332                          ))
333                          .await
334                          .unwrap_or_default();
335                  }
336                  break;
337              }
338          }
339      }
340      tls_config = rustls_receive_future => {
341          let quic_server_config = Arc::new(match QuicServerConfig::try_from(tls_config) {
342              Ok(config) => config,
343              Err(_) => continue,
344          });
345          let server_config = quinn::ServerConfig::with_crypto(quic_server_config);
346          endpoint.set_server_config(Some(server_config));
347          continue;
348      }
349      _ = shutdown_rx.cancelled() => {
350          break;
351      }
352    };
353    let remote_address = new_conn.remote_address();
354    let local_address = SocketAddr::new(
355      new_conn.local_ip().unwrap_or(IpAddr::V6(Ipv6Addr::UNSPECIFIED)),
356      udp_port,
357    );
358    let quic_data = ConnectionData {
359      connection: Connection::Quic(new_conn),
360      client_address: remote_address,
361      server_address: local_address,
362    };
363    let quic_tx = tx.clone();
364    crate::runtime::spawn(async move {
365      quic_tx.send(quic_data).await.unwrap_or_default();
366    });
367  }
368
369  endpoint.wait_idle().await;
370
371  Ok(())
372}