ferron/util/
tls.rs

1use crate::util::match_hostname;
2use rustls::server::{ClientHello, ResolvesServerCert};
3use rustls::sign::CertifiedKey;
4use rustls_pki_types::pem::PemObject;
5use rustls_pki_types::{CertificateDer, PrivateKeyDer};
6use std::sync::Arc;
7
8/// Custom SNI resolver, consisting of multiple resolvers
9#[derive(Debug)]
10pub struct CustomSniResolver {
11  fallback_resolver: Option<Arc<dyn ResolvesServerCert>>,
12  #[allow(clippy::type_complexity)]
13  resolvers: Arc<tokio::sync::RwLock<Vec<(String, Arc<dyn ResolvesServerCert>)>>>,
14  fallback_sender: Option<(async_channel::Sender<(String, u16)>, u16)>,
15}
16
17impl CustomSniResolver {
18  /// Creates a custom SNI resolver
19  #[allow(dead_code)]
20  pub fn new() -> Self {
21    Self {
22      fallback_resolver: None,
23      resolvers: Arc::new(tokio::sync::RwLock::new(Vec::new())),
24      fallback_sender: None,
25    }
26  }
27
28  /// Creates a custom SNI resolver with provided resolvers lock
29  #[allow(clippy::type_complexity)]
30  pub fn with_resolvers(resolvers: Arc<tokio::sync::RwLock<Vec<(String, Arc<dyn ResolvesServerCert>)>>>) -> Self {
31    Self {
32      fallback_resolver: None,
33      resolvers,
34      fallback_sender: None,
35    }
36  }
37
38  /// Loads a fallback certificate resolver for a specific host
39  pub fn load_fallback_resolver(&mut self, fallback_resolver: Arc<dyn ResolvesServerCert>) {
40    self.fallback_resolver = Some(fallback_resolver);
41  }
42
43  /// Loads a host certificate resolver for a specific host
44  pub fn load_host_resolver(&mut self, host: &str, resolver: Arc<dyn ResolvesServerCert>) {
45    load_host_resolver(&mut self.resolvers.blocking_write(), host, resolver);
46  }
47
48  /// Loads a fallback sender used for sending SNI hostnames for a specific host
49  pub fn load_fallback_sender(&mut self, fallback_sender: async_channel::Sender<(String, u16)>, port: u16) {
50    self.fallback_sender = Some((fallback_sender, port));
51  }
52}
53
54impl ResolvesServerCert for CustomSniResolver {
55  fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
56    let hostname = client_hello.server_name().map(|hn| hn.strip_suffix('.').unwrap_or(hn));
57    if let Some(hostname) = hostname {
58      // If blocking_read() method is used when only Tokio is used, the program would panic on resolving a TLS certificate.
59      #[cfg(feature = "runtime-monoio")]
60      let resolvers = self.resolvers.blocking_read();
61      #[cfg(feature = "runtime-tokio")]
62      let resolvers = futures_executor::block_on(async { self.resolvers.read().await });
63
64      for (configured_hostname, resolver) in resolvers.iter() {
65        if match_hostname(Some(configured_hostname), Some(hostname)) {
66          return resolver.resolve(client_hello);
67        }
68      }
69    }
70    let hostname = hostname.map(|v| v.to_string());
71    self
72      .fallback_resolver
73      .as_ref()
74      .and_then(|r| r.resolve(client_hello))
75      .or_else(|| {
76        if let Some((sender, port)) = &self.fallback_sender {
77          if let Some(hostname) = hostname {
78            sender.send_blocking((hostname.to_string(), *port)).unwrap_or_default();
79          }
80        }
81        None
82      })
83  }
84}
85
86/// A certificate resolver resolving one certified key
87#[derive(Debug)]
88pub struct OneCertifiedKeyResolver {
89  certified_key: Arc<CertifiedKey>,
90}
91
92impl OneCertifiedKeyResolver {
93  /// Creates a certificate resolver with a certified key
94  pub fn new(certified_key: Arc<CertifiedKey>) -> Self {
95    Self { certified_key }
96  }
97}
98
99impl ResolvesServerCert for OneCertifiedKeyResolver {
100  fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
101    Some(self.certified_key.clone())
102  }
103}
104
105/// Loads a host certificate resolver for a specific host
106pub fn load_host_resolver(
107  resolvers: &mut Vec<(String, Arc<dyn ResolvesServerCert>)>,
108  host: &str,
109  resolver: Arc<dyn ResolvesServerCert>,
110) {
111  if !resolvers.iter().any(|(h, _)| h == host) {
112    resolvers.push((host.to_string(), resolver));
113  }
114  resolvers.sort_by(|a, b| {
115    (a.0.starts_with("*."))
116      .cmp(&(b.0.starts_with("*."))) // Take wildcard hostnames into account
117      .then_with(|| {
118        b.0
119          .trim_end_matches('.')
120          .chars()
121          .filter(|c| *c == '.')
122          .count()
123          .cmp(&a.0.trim_end_matches('.').chars().filter(|c| *c == '.').count())
124      }) // Take also amount of dots in hostnames (domain level) into account
125  });
126}
127/// Loads a public certificate from file
128pub fn load_certs(filename: &str) -> std::io::Result<Vec<CertificateDer<'static>>> {
129  let mut certfile = std::fs::File::open(filename)?;
130  CertificateDer::pem_reader_iter(&mut certfile)
131    .collect::<Result<Vec<_>, _>>()
132    .map_err(|e| match e {
133      rustls_pki_types::pem::Error::Io(err) => err,
134      err => std::io::Error::other(err),
135    })
136}
137
138/// Loads a private key from file
139pub fn load_private_key(filename: &str) -> std::io::Result<PrivateKeyDer<'static>> {
140  let mut keyfile = std::fs::File::open(filename)?;
141  match PrivateKeyDer::from_pem_reader(&mut keyfile) {
142    Ok(private_key) => Ok(private_key),
143    Err(rustls_pki_types::pem::Error::Io(err)) => Err(err),
144    Err(err) => Err(std::io::Error::other(err)),
145  }
146}