ferron/listeners/
quic.rs

1// Copyright (c) 2018 The quinn Developers
2// Portions of this file are derived from Quinn (https://github.com/quinn-rs/quinn).
3//
4// Permission is hereby granted, free of charge, to any person obtaining a copy
5// of this software and associated documentation files (the "Software"), to deal
6// in the Software without restriction, including without limitation the rights
7// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8// copies of the Software, and to permit persons to whom the Software is
9// furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in all
12// copies or substantial portions of the Software.
13//
14
15use std::error::Error;
16#[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
17use std::fmt::Debug;
18#[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
19use std::future::Future;
20use std::io;
21use std::net::{IpAddr, Ipv6Addr, SocketAddr};
22#[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
23use std::pin::Pin;
24use std::sync::Arc;
25#[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
26use std::task::{Context, Poll};
27use std::time::Duration;
28#[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
29use std::time::Instant;
30
31use async_channel::{Receiver, Sender};
32use ferron_common::logging::LogMessage;
33#[cfg(feature = "runtime-monoio")]
34use monoio::time::Sleep;
35use quinn::crypto::rustls::QuicServerConfig;
36#[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
37use quinn::{AsyncTimer, AsyncUdpSocket, Runtime};
38use rustls::ServerConfig;
39#[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
40use send_wrapper::SendWrapper;
41use tokio_util::sync::CancellationToken;
42#[cfg(feature = "runtime-vibeio")]
43use vibeio::time::Sleep;
44
45use crate::listener_handler_communication::{Connection, ConnectionData};
46
/// Boxed error type used for listener setup failures; it is `Send + Sync` so it can be passed
/// between the listener thread and the thread that created it.
type ListenerError = Box<dyn Error + Sync + Send>;
48
/// Quinn timer implementation backed by Monoio's `Sleep` future.
#[cfg(feature = "runtime-monoio")]
#[derive(Debug)]
struct MonoioTimer {
  /// Heap-pinned Monoio sleep, wrapped in `SendWrapper` (presumably so the timer satisfies
  /// Quinn's `Send` requirement for timers — confirm against the `AsyncTimer` trait bounds).
  inner: SendWrapper<Pin<Box<Sleep>>>,
}
55
#[cfg(feature = "runtime-monoio")]
impl AsyncTimer for MonoioTimer {
  /// Re-arms the timer so that it fires at `t` (converted into Monoio's instant type).
  fn reset(mut self: Pin<&mut Self>, t: Instant) {
    // Going through `SendWrapper`'s `DerefMut` exposes the pinned Monoio sleep.
    let sleep: &mut Pin<Box<Sleep>> = &mut self.inner;
    sleep.as_mut().reset(t.into())
  }

  /// Polls the underlying Monoio sleep future.
  fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
    let sleep: &mut Pin<Box<Sleep>> = &mut self.inner;
    sleep.as_mut().poll(cx)
  }
}
66
/// Quinn timer implementation backed by `vibeio`'s `Sleep` future.
#[cfg(feature = "runtime-vibeio")]
struct CustomAsyncTimer {
  /// Heap-pinned `vibeio` sleep, wrapped in `SendWrapper` (presumably so the timer satisfies
  /// Quinn's `Send` requirement for timers — confirm against the `AsyncTimer` trait bounds).
  inner: SendWrapper<Pin<Box<Sleep>>>,
}
72
#[cfg(feature = "runtime-vibeio")]
impl AsyncTimer for CustomAsyncTimer {
  /// Re-arms the timer so that it fires at `t`.
  fn reset(mut self: Pin<&mut Self>, t: Instant) {
    // Going through `SendWrapper`'s `DerefMut` exposes the pinned `vibeio` sleep.
    let sleep: &mut Pin<Box<Sleep>> = &mut self.inner;
    sleep.as_mut().reset(t)
  }

  /// Polls the underlying `vibeio` sleep future.
  fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
    let sleep: &mut Pin<Box<Sleep>> = &mut self.inner;
    sleep.as_mut().poll(cx)
  }
}
83
#[cfg(feature = "runtime-vibeio")]
impl Debug for CustomAsyncTimer {
  /// Prints only the type name; the wrapped sleep future is not shown.
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    f.write_str("CustomAsyncTimer")
  }
}
90
/// A Quinn runtime adapter used by the Monoio and `vibeio` builds.
///
/// Timers and spawned tasks use Tokio when the calling thread is inside a Tokio runtime. Otherwise
/// they fall back to Monoio or `vibeio`, whichever runtime feature is enabled. UDP sockets are
/// always wrapped with Quinn's Tokio runtime.
#[derive(Debug)]
#[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
struct EnterTokioRuntime;
95
#[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
impl Runtime for EnterTokioRuntime {
  /// Creates a timer that fires at `t`.
  ///
  /// Inside a Tokio runtime this is a Tokio sleep. Elsewhere the timer is backed by the enabled
  /// non-Tokio runtime (Monoio or `vibeio`), wrapped in a `SendWrapper`.
  fn new_timer(&self, t: Instant) -> Pin<Box<dyn AsyncTimer>> {
    if tokio::runtime::Handle::try_current().is_ok() {
      return Box::pin(tokio::time::sleep_until(t.into()));
    }

    #[cfg(feature = "runtime-monoio")]
    let fallback = Box::pin(MonoioTimer {
      inner: SendWrapper::new(Box::pin(monoio::time::sleep_until(t.into()))),
    });
    #[cfg(feature = "runtime-vibeio")]
    let fallback = Box::pin(CustomAsyncTimer {
      inner: SendWrapper::new(Box::pin(vibeio::time::sleep_until(t))),
    });
    fallback
  }

  /// Spawns `future` on the current Tokio runtime if there is one, otherwise on the enabled
  /// non-Tokio runtime. The join handle returned by the spawn call is dropped.
  fn spawn(&self, future: Pin<Box<dyn Future<Output = ()> + Send>>) {
    match tokio::runtime::Handle::try_current() {
      Ok(handle) => {
        handle.spawn(future);
      }
      Err(_) => {
        #[cfg(feature = "runtime-monoio")]
        monoio::spawn(future);
        #[cfg(feature = "runtime-vibeio")]
        vibeio::spawn(future);
      }
    }
  }

  /// Wraps the UDP socket using Quinn's Tokio runtime, regardless of the current runtime.
  fn wrap_udp_socket(&self, sock: std::net::UdpSocket) -> io::Result<Arc<dyn AsyncUdpSocket>> {
    let tokio_runtime = quinn::TokioRuntime;
    tokio_runtime.wrap_udp_socket(sock)
  }
}
129
130#[inline]
131fn build_quic_server_config(tls_config: Arc<ServerConfig>) -> Result<quinn::ServerConfig, ListenerError> {
132  let quic_server_config = QuicServerConfig::try_from(tls_config)
133    .map_err(|err| anyhow::anyhow!("Cannot prepare the QUIC server configuration: {err}"))?;
134  Ok(quinn::ServerConfig::with_crypto(Arc::new(quic_server_config)))
135}
136
137#[inline]
138fn bind_udp_socket(address: SocketAddr) -> io::Result<std::net::UdpSocket> {
139  // Create a new socket
140  let listener_socket2 = socket2::Socket::new(
141    if address.is_ipv6() {
142      socket2::Domain::IPV6
143    } else {
144      socket2::Domain::IPV4
145    },
146    socket2::Type::DGRAM,
147    Some(socket2::Protocol::UDP),
148  )?;
149
150  // Set socket options
151  if address.is_ipv6() {
152    listener_socket2.set_only_v6(false).unwrap_or_default();
153  }
154
155  // Bind the socket to the address
156  listener_socket2.bind(&address.into())?;
157
158  // Wrap the socket into a UdpSocket
159  Ok(listener_socket2.into())
160}
161
162#[inline]
163async fn log_accept_closed(logging_tx: &Option<Sender<LogMessage>>) {
164  if let Some(logging_tx) = logging_tx {
165    logging_tx
166      .send(LogMessage::new(
167        "HTTP/3 connections can't be accepted anymore".to_string(),
168        true,
169      ))
170      .await
171      .unwrap_or_default();
172  }
173}
174
175/// Creates a QUIC listener
176#[allow(clippy::type_complexity)]
177pub fn create_quic_listener(
178  address: SocketAddr,
179  tls_config: Arc<ServerConfig>,
180  tx: Sender<ConnectionData>,
181  logging_tx: Option<Sender<LogMessage>>,
182  first_startup: bool,
183) -> Result<(CancellationToken, Sender<Arc<ServerConfig>>), Box<dyn Error + Send + Sync>> {
184  let shutdown_tx = CancellationToken::new();
185  let shutdown_rx = shutdown_tx.clone();
186  let (rustls_config_tx, rustls_config_rx) = async_channel::unbounded();
187  let (listen_error_tx, listen_error_rx) = async_channel::unbounded();
188  std::thread::Builder::new()
189    .name(format!("QUIC listener for {address}"))
190    .spawn(move || {
191      let mut rt = match crate::runtime::Runtime::new_runtime_tokio_only() {
192        Ok(rt) => rt,
193        Err(error) => {
194          listen_error_tx
195            .send_blocking(Some(
196              anyhow::anyhow!("Can't create async runtime: {error}").into_boxed_dyn_error(),
197            ))
198            .unwrap_or_default();
199          return;
200        }
201      };
202      rt.run(async move {
203        if let Err(error) = quic_listener_fn(
204          address,
205          tls_config,
206          tx,
207          &listen_error_tx,
208          logging_tx,
209          first_startup,
210          shutdown_rx,
211          rustls_config_rx,
212        )
213        .await
214        {
215          listen_error_tx.send(Some(error)).await.unwrap_or_default();
216        }
217      });
218    })?;
219
220  if let Some(error) = listen_error_rx.recv_blocking()? {
221    Err(error)?;
222  }
223
224  Ok((shutdown_tx, rustls_config_tx))
225}
226
227/// QUIC listener function
228#[allow(clippy::too_many_arguments)]
229async fn quic_listener_fn(
230  address: SocketAddr,
231  tls_config: Arc<ServerConfig>,
232  tx: Sender<ConnectionData>,
233  listen_error_tx: &Sender<Option<ListenerError>>,
234  logging_tx: Option<Sender<LogMessage>>,
235  first_startup: bool,
236  shutdown_rx: CancellationToken,
237  rustls_config_rx: Receiver<Arc<ServerConfig>>,
238) -> Result<(), ListenerError> {
239  let server_config = build_quic_server_config(tls_config)?;
240  let udp_port = address.port();
241  let mut udp_socket_result;
242  let mut tries: u64 = 0;
243  loop {
244    udp_socket_result = bind_udp_socket(address);
245    if first_startup || udp_socket_result.is_ok() {
246      break;
247    }
248    tries += 1;
249    let duration = Duration::from_millis(1000);
250    if tries >= 10 {
251      println!("HTTP/3 port is used at try #{tries}, skipping...");
252      listen_error_tx.send(None).await.unwrap_or_default();
253      break;
254    }
255    println!("HTTP/3 port is used at try #{tries}, retrying in {duration:?}...");
256    if shutdown_rx.is_cancelled() {
257      break;
258    }
259    tokio::time::sleep(duration).await;
260  }
261  let udp_socket = match udp_socket_result {
262    Ok(socket) => socket,
263    Err(err) => Err(anyhow::anyhow!("Cannot listen to HTTP/3 port: {err}"))?,
264  };
265  let endpoint = match quinn::Endpoint::new(quinn::EndpointConfig::default(), Some(server_config), udp_socket, {
266    #[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
267    let runtime = Arc::new(EnterTokioRuntime);
268    #[cfg(feature = "runtime-tokio")]
269    let runtime = Arc::new(quinn::TokioRuntime);
270
271    runtime
272  }) {
273    Ok(endpoint) => endpoint,
274    Err(err) => Err(anyhow::anyhow!("Cannot listen to HTTP/3 port: {err}"))?,
275  };
276  println!("HTTP/3 server is listening on {address}...");
277  listen_error_tx.send(None).await.unwrap_or_default();
278
279  loop {
280    let rustls_receive_future = async { rustls_config_rx.recv().await.ok() };
281
282    let connection = tokio::select! {
283      result = endpoint.accept() => {
284        match result {
285          Some(conn) => conn,
286          None => {
287            log_accept_closed(&logging_tx).await;
288            break;
289          }
290        }
291      }
292      tls_config = rustls_receive_future => {
293        let Some(tls_config) = tls_config else {
294          futures_util::future::pending::<()>().await;
295          unreachable!();
296        };
297
298        if let Ok(server_config) = build_quic_server_config(tls_config) {
299          endpoint.set_server_config(Some(server_config));
300        }
301        continue;
302      }
303      _ = shutdown_rx.cancelled() => {
304        break;
305      }
306    };
307    let remote_address = connection.remote_address();
308    let local_address = SocketAddr::new(
309      connection.local_ip().unwrap_or(IpAddr::V6(Ipv6Addr::UNSPECIFIED)),
310      udp_port,
311    );
312    let quic_data = ConnectionData {
313      connection: Connection::Quic(connection),
314      client_address: remote_address,
315      server_address: local_address,
316    };
317    let quic_tx = tx.clone();
318    tokio::spawn(async move {
319      quic_tx.send(quic_data).await.unwrap_or_default();
320    });
321  }
322
323  endpoint.wait_idle().await;
324
325  Ok(())
326}