ferron/util/
tls.rs

1use crate::util::HostnameRadixTree;
2use rustls::server::{ClientHello, ResolvesServerCert};
3use rustls::sign::CertifiedKey;
4use rustls_pki_types::pem::PemObject;
5use rustls_pki_types::{CertificateDer, PrivateKeyDer};
6use std::sync::Arc;
7
8/// The type for the SNI resolver lock, which is a vector of tuples containing the hostname and the corresponding certificate resolver.
9pub type SniResolverLock = Arc<tokio::sync::RwLock<HostnameRadixTree<Arc<dyn ResolvesServerCert>>>>;
10
11/// Custom SNI resolver, consisting of multiple resolvers
12#[derive(Debug)]
13pub struct CustomSniResolver {
14  fallback_resolver: Option<Arc<dyn ResolvesServerCert>>,
15  resolvers: SniResolverLock,
16  fallback_sender: Option<(async_channel::Sender<(String, u16)>, u16)>,
17}
18
19impl CustomSniResolver {
20  /// Creates a custom SNI resolver
21  #[allow(dead_code)]
22  pub fn new() -> Self {
23    Self {
24      fallback_resolver: None,
25      resolvers: Arc::new(tokio::sync::RwLock::new(HostnameRadixTree::new())),
26      fallback_sender: None,
27    }
28  }
29
30  /// Creates a custom SNI resolver with provided resolvers lock
31  pub fn with_resolvers(resolvers: SniResolverLock) -> Self {
32    Self {
33      fallback_resolver: None,
34      resolvers,
35      fallback_sender: None,
36    }
37  }
38
39  /// Loads a fallback certificate resolver for a specific host
40  pub fn load_fallback_resolver(&mut self, fallback_resolver: Arc<dyn ResolvesServerCert>) {
41    self.fallback_resolver = Some(fallback_resolver);
42  }
43
44  /// Loads a host certificate resolver for a specific host
45  pub fn load_host_resolver(&mut self, host: &str, resolver: Arc<dyn ResolvesServerCert>) {
46    self.resolvers.blocking_write().insert(host.to_string(), resolver);
47  }
48
49  /// Loads a fallback sender used for sending SNI hostnames for a specific host
50  pub fn load_fallback_sender(&mut self, fallback_sender: async_channel::Sender<(String, u16)>, port: u16) {
51    self.fallback_sender = Some((fallback_sender, port));
52  }
53}
54
55impl ResolvesServerCert for CustomSniResolver {
56  fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
57    let hostname = client_hello.server_name().map(|hn| hn.strip_suffix('.').unwrap_or(hn));
58    if let Some(hostname) = hostname {
59      // If blocking_read() method is used when only Tokio is used, the program would panic on resolving a TLS certificate.
60      #[cfg(feature = "runtime-monoio")]
61      let resolvers = self.resolvers.blocking_read();
62      #[cfg(feature = "runtime-tokio")]
63      let resolvers = futures_executor::block_on(async { self.resolvers.read().await });
64
65      if let Some(resolver) = resolvers.get(hostname).cloned() {
66        return resolver.resolve(client_hello);
67      }
68    }
69    let hostname = hostname.map(|v| v.to_string());
70    self
71      .fallback_resolver
72      .as_ref()
73      .and_then(|r| r.resolve(client_hello))
74      .or_else(|| {
75        if let Some((sender, port)) = &self.fallback_sender {
76          if let Some(hostname) = hostname {
77            sender.send_blocking((hostname.to_string(), *port)).unwrap_or_default();
78          }
79        }
80        None
81      })
82  }
83}
84
85/// A certificate resolver resolving one certified key
86#[derive(Debug)]
87pub struct OneCertifiedKeyResolver {
88  certified_key: Arc<CertifiedKey>,
89}
90
91impl OneCertifiedKeyResolver {
92  /// Creates a certificate resolver with a certified key
93  pub fn new(certified_key: Arc<CertifiedKey>) -> Self {
94    Self { certified_key }
95  }
96}
97
98impl ResolvesServerCert for OneCertifiedKeyResolver {
99  fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
100    Some(self.certified_key.clone())
101  }
102}
103
104/// Loads a public certificate from file
105pub fn load_certs(filename: &str) -> std::io::Result<Vec<CertificateDer<'static>>> {
106  let mut certfile = std::fs::File::open(filename)?;
107  CertificateDer::pem_reader_iter(&mut certfile)
108    .collect::<Result<Vec<_>, _>>()
109    .map_err(|e| match e {
110      rustls_pki_types::pem::Error::Io(err) => err,
111      err => std::io::Error::other(err),
112    })
113}
114
115/// Loads a private key from file
116pub fn load_private_key(filename: &str) -> std::io::Result<PrivateKeyDer<'static>> {
117  let mut keyfile = std::fs::File::open(filename)?;
118  match PrivateKeyDer::from_pem_reader(&mut keyfile) {
119    Ok(private_key) => Ok(private_key),
120    Err(rustls_pki_types::pem::Error::Io(err)) => Err(err),
121    Err(err) => Err(std::io::Error::other(err)),
122  }
123}