ferron/listeners/
quic.rs

1// Copyright (c) 2018 The quinn Developers
2// Portions of this file are derived from Quinn (https://github.com/quinn-rs/quinn).
3//
4// Permission is hereby granted, free of charge, to any person obtaining a copy
5// of this software and associated documentation files (the "Software"), to deal
6// in the Software without restriction, including without limitation the rights
7// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8// copies of the Software, and to permit persons to whom the Software is
9// furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in all
12// copies or substantial portions of the Software.
13//
14
15use std::error::Error;
16#[cfg(feature = "runtime-monoio")]
17use std::fmt::Debug;
18#[cfg(feature = "runtime-monoio")]
19use std::future::Future;
20#[cfg(feature = "runtime-monoio")]
21use std::io;
22use std::net::{IpAddr, Ipv6Addr, SocketAddr};
23#[cfg(feature = "runtime-monoio")]
24use std::pin::Pin;
25use std::sync::Arc;
26#[cfg(feature = "runtime-monoio")]
27use std::task::{Context, Poll};
28use std::time::Duration;
29#[cfg(feature = "runtime-monoio")]
30use std::time::Instant;
31
32use async_channel::{Receiver, Sender};
33use ferron_common::logging::LogMessage;
34#[cfg(feature = "runtime-monoio")]
35use monoio::time::Sleep;
36use quinn::crypto::rustls::QuicServerConfig;
37#[cfg(feature = "runtime-monoio")]
38use quinn::{AsyncTimer, AsyncUdpSocket, Runtime};
39use rustls::ServerConfig;
40#[cfg(feature = "runtime-monoio")]
41use send_wrapper::SendWrapper;
42use tokio_util::sync::CancellationToken;
43
44use crate::listener_handler_communication::{Connection, ConnectionData};
45
/// Adapter that exposes a Monoio sleep future as a Quinn [`AsyncTimer`].
///
/// The pinned sleep is wrapped in a [`SendWrapper`] so the timer type can be
/// treated as `Send`; per `send_wrapper` semantics, dereferencing it from a
/// thread other than the one that created it panics.
#[derive(Debug)]
#[cfg(feature = "runtime-monoio")]
struct MonoioTimer {
  /// The pinned Monoio `Sleep` future that drives this timer.
  inner: SendWrapper<Pin<Box<Sleep>>>,
}
52
#[cfg(feature = "runtime-monoio")]
impl AsyncTimer for MonoioTimer {
  /// Re-arms the underlying Monoio sleep so that it completes at `t`.
  fn reset(mut self: Pin<&mut Self>, t: Instant) {
    // Deref-coerce through the `SendWrapper` to reach the pinned sleep future.
    let sleep: &mut Pin<Box<Sleep>> = &mut self.inner;
    sleep.as_mut().reset(t.into());
  }

  /// Polls the underlying Monoio sleep; ready once its deadline has passed.
  fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
    let sleep: &mut Pin<Box<Sleep>> = &mut self.inner;
    Future::poll(sleep.as_mut(), cx)
  }
}
63
/// A Quinn runtime that bridges Tokio and Monoio.
///
/// Timers and spawned tasks go to Tokio when a Tokio runtime handle is
/// available on the current thread, and fall back to Monoio otherwise.
/// UDP sockets, however, are always wrapped by Quinn's Tokio runtime
/// (`quinn::TokioRuntime`). Presumably this needs a Tokio runtime context when
/// the endpoint is created — TODO confirm against Quinn/Tokio.
#[derive(Debug)]
#[cfg(feature = "runtime-monoio")]
struct EnterTokioRuntime;
68
#[cfg(feature = "runtime-monoio")]
impl Runtime for EnterTokioRuntime {
  /// Creates a timer that fires at `t`.
  ///
  /// Uses a Tokio sleep when a Tokio runtime handle is available on this
  /// thread, and a Monoio-backed [`MonoioTimer`] otherwise.
  fn new_timer(&self, t: Instant) -> Pin<Box<dyn AsyncTimer>> {
    match tokio::runtime::Handle::try_current() {
      Ok(_) => Box::pin(tokio::time::sleep_until(t.into())),
      Err(_) => {
        let sleep = Box::pin(monoio::time::sleep_until(t.into()));
        Box::pin(MonoioTimer {
          inner: SendWrapper::new(sleep),
        })
      }
    }
  }

  /// Spawns `future` as a detached task.
  ///
  /// The task goes to the current Tokio runtime if one is reachable,
  /// otherwise to Monoio.
  fn spawn(&self, future: Pin<Box<dyn Future<Output = ()> + Send>>) {
    match tokio::runtime::Handle::try_current() {
      Ok(handle) => {
        handle.spawn(future);
      }
      Err(_) => {
        monoio::spawn(future);
      }
    }
  }

  /// Wraps a standard UDP socket using Quinn's Tokio runtime, regardless of
  /// which runtime drives timers and tasks.
  fn wrap_udp_socket(&self, sock: std::net::UdpSocket) -> io::Result<Arc<dyn AsyncUdpSocket>> {
    quinn::TokioRuntime.wrap_udp_socket(sock)
  }
}
93
94/// Creates a QUIC listener
95#[allow(clippy::type_complexity)]
96pub fn create_quic_listener(
97  address: SocketAddr,
98  tls_config: Arc<ServerConfig>,
99  tx: Sender<ConnectionData>,
100  logging_tx: Option<Sender<LogMessage>>,
101  first_startup: bool,
102) -> Result<(CancellationToken, Sender<Arc<ServerConfig>>), Box<dyn Error + Send + Sync>> {
103  let shutdown_tx = CancellationToken::new();
104  let shutdown_rx = shutdown_tx.clone();
105  let (rustls_config_tx, rustls_config_rx) = async_channel::unbounded();
106  let (listen_error_tx, listen_error_rx) = async_channel::unbounded();
107  std::thread::Builder::new()
108    .name(format!("QUIC listener for {address}"))
109    .spawn(move || {
110      let mut rt = match crate::runtime::Runtime::new_runtime_tokio_only() {
111        Ok(rt) => rt,
112        Err(error) => {
113          listen_error_tx
114            .send_blocking(Some(
115              anyhow::anyhow!("Can't create async runtime: {error}").into_boxed_dyn_error(),
116            ))
117            .unwrap_or_default();
118          return;
119        }
120      };
121      rt.run(async move {
122        if let Err(error) = quic_listener_fn(
123          address,
124          tls_config,
125          tx,
126          &listen_error_tx,
127          logging_tx,
128          first_startup,
129          shutdown_rx,
130          rustls_config_rx,
131        )
132        .await
133        {
134          listen_error_tx.send(Some(error)).await.unwrap_or_default();
135        }
136      });
137    })?;
138
139  if let Some(error) = listen_error_rx.recv_blocking()? {
140    Err(error)?;
141  }
142
143  Ok((shutdown_tx, rustls_config_tx))
144}
145
146/// QUIC listener function
147#[allow(clippy::too_many_arguments)]
148async fn quic_listener_fn(
149  address: SocketAddr,
150  tls_config: Arc<ServerConfig>,
151  tx: Sender<ConnectionData>,
152  listen_error_tx: &Sender<Option<Box<dyn Error + Send + Sync>>>,
153  logging_tx: Option<Sender<LogMessage>>,
154  first_startup: bool,
155  shutdown_rx: CancellationToken,
156  rustls_config_rx: Receiver<Arc<ServerConfig>>,
157) -> Result<(), Box<dyn Error + Send + Sync>> {
158  let quic_server_config = Arc::new(match QuicServerConfig::try_from(tls_config) {
159    Ok(config) => config,
160    Err(err) => Err(anyhow::anyhow!("Cannot prepare the QUIC server configuration: {}", err))?,
161  });
162  let server_config = quinn::ServerConfig::with_crypto(quic_server_config);
163  let udp_port = address.port();
164  let mut udp_socket_result;
165  let mut tries: u64 = 0;
166  loop {
167    udp_socket_result = (|| {
168      // Create a new socket
169      let listener_socket2 = socket2::Socket::new(
170        if address.is_ipv6() {
171          socket2::Domain::IPV6
172        } else {
173          socket2::Domain::IPV4
174        },
175        socket2::Type::DGRAM,
176        Some(socket2::Protocol::UDP),
177      )?;
178
179      // Set socket options
180      if address.is_ipv6() {
181        listener_socket2.set_only_v6(false).unwrap_or_default();
182      }
183
184      // Bind the socket to the address
185      listener_socket2.bind(&address.into())?;
186
187      // Wrap the socket into a UdpSocket
188      let listener_socket: std::net::UdpSocket = listener_socket2.into();
189      Ok::<_, std::io::Error>(listener_socket)
190    })();
191    if first_startup || udp_socket_result.is_ok() {
192      break;
193    }
194    tries += 1;
195    let duration = Duration::from_millis(1000);
196    if tries >= 10 {
197      println!("HTTP/3 port is used at try #{tries}, skipping...");
198      listen_error_tx.send(None).await.unwrap_or_default();
199      break;
200    }
201    println!("HTTP/3 port is used at try #{tries}, retrying in {duration:?}...");
202    if shutdown_rx.is_cancelled() {
203      break;
204    }
205    crate::runtime::sleep(duration).await;
206  }
207  let udp_socket = match udp_socket_result {
208    Ok(socket) => socket,
209    Err(err) => Err(anyhow::anyhow!("Cannot listen to HTTP/3 port: {}", err))?,
210  };
211  let endpoint = match quinn::Endpoint::new(quinn::EndpointConfig::default(), Some(server_config), udp_socket, {
212    #[cfg(feature = "runtime-monoio")]
213    let runtime = Arc::new(EnterTokioRuntime);
214    #[cfg(feature = "runtime-tokio")]
215    let runtime = Arc::new(quinn::TokioRuntime);
216
217    runtime
218  }) {
219    Ok(endpoint) => endpoint,
220    Err(err) => Err(anyhow::anyhow!("Cannot listen to HTTP/3 port: {}", err))?,
221  };
222  println!("HTTP/3 server is listening on {address}...");
223  listen_error_tx.send(None).await.unwrap_or_default();
224
225  loop {
226    let rustls_receive_future = async {
227      if let Ok(rustls_server_config) = rustls_config_rx.recv().await {
228        rustls_server_config
229      } else {
230        futures_util::future::pending().await
231      }
232    };
233
234    let new_conn = crate::runtime::select! {
235      result = endpoint.accept() => {
236          match result {
237              Some(conn) => conn,
238              None => {
239                  if let Some(logging_tx) = &logging_tx {
240                      logging_tx
241                          .send(LogMessage::new(
242                              "HTTP/3 connections can't be accepted anymore".to_string(),
243                              true,
244                          ))
245                          .await
246                          .unwrap_or_default();
247                  }
248                  break;
249              }
250          }
251      }
252      tls_config = rustls_receive_future => {
253          let quic_server_config = Arc::new(match QuicServerConfig::try_from(tls_config) {
254              Ok(config) => config,
255              Err(_) => continue,
256          });
257          let server_config = quinn::ServerConfig::with_crypto(quic_server_config);
258          endpoint.set_server_config(Some(server_config));
259          continue;
260      }
261      _ = shutdown_rx.cancelled() => {
262          break;
263      }
264    };
265    let remote_address = new_conn.remote_address();
266    let local_address = SocketAddr::new(
267      new_conn.local_ip().unwrap_or(IpAddr::V6(Ipv6Addr::UNSPECIFIED)),
268      udp_port,
269    );
270    let quic_data = ConnectionData {
271      connection: Connection::Quic(new_conn),
272      client_address: remote_address,
273      server_address: local_address,
274    };
275    let quic_tx = tx.clone();
276    tokio::spawn(async move {
277      quic_tx.send(quic_data).await.unwrap_or_default();
278    });
279  }
280
281  endpoint.wait_idle().await;
282
283  Ok(())
284}