ferron/
tls_util.rs

1use crate::util::match_hostname;
2use rustls::server::{ClientHello, ResolvesServerCert};
3use rustls::sign::CertifiedKey;
4use rustls_pki_types::{CertificateDer, PrivateKeyDer};
5use std::sync::Arc;
6
7/// Custom SNI resolver, consisting of multiple resolvers
8#[derive(Debug)]
9pub struct CustomSniResolver {
10  fallback_resolver: Option<Arc<dyn ResolvesServerCert>>,
11  #[allow(clippy::type_complexity)]
12  resolvers: Arc<tokio::sync::RwLock<Vec<(String, Arc<dyn ResolvesServerCert>)>>>,
13  fallback_sender: Option<(async_channel::Sender<(String, u16)>, u16)>,
14}
15
16impl CustomSniResolver {
17  /// Creates a custom SNI resolver
18  #[allow(dead_code)]
19  pub fn new() -> Self {
20    Self {
21      fallback_resolver: None,
22      resolvers: Arc::new(tokio::sync::RwLock::new(Vec::new())),
23      fallback_sender: None,
24    }
25  }
26
27  /// Creates a custom SNI resolver with provided resolvers lock
28  #[allow(clippy::type_complexity)]
29  pub fn with_resolvers(resolvers: Arc<tokio::sync::RwLock<Vec<(String, Arc<dyn ResolvesServerCert>)>>>) -> Self {
30    Self {
31      fallback_resolver: None,
32      resolvers,
33      fallback_sender: None,
34    }
35  }
36
37  /// Loads a fallback certificate resolver for a specific host
38  pub fn load_fallback_resolver(&mut self, fallback_resolver: Arc<dyn ResolvesServerCert>) {
39    self.fallback_resolver = Some(fallback_resolver);
40  }
41
42  /// Loads a host certificate resolver for a specific host
43  pub fn load_host_resolver(&mut self, host: &str, resolver: Arc<dyn ResolvesServerCert>) {
44    load_host_resolver(&mut self.resolvers.blocking_write(), host, resolver);
45  }
46
47  /// Loads a fallback sender used for sending SNI hostnames for a specific host
48  pub fn load_fallback_sender(&mut self, fallback_sender: async_channel::Sender<(String, u16)>, port: u16) {
49    self.fallback_sender = Some((fallback_sender, port));
50  }
51}
52
53impl ResolvesServerCert for CustomSniResolver {
54  fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
55    let hostname = client_hello.server_name().map(|hn| hn.strip_suffix('.').unwrap_or(hn));
56    if let Some(hostname) = hostname {
57      // If blocking_read() method is used when only Tokio is used, the program would panic on resolving a TLS certificate.
58      #[cfg(feature = "runtime-monoio")]
59      let resolvers = self.resolvers.blocking_read();
60      #[cfg(feature = "runtime-tokio")]
61      let resolvers = futures_executor::block_on(async { self.resolvers.read().await });
62
63      for (configured_hostname, resolver) in resolvers.iter() {
64        if match_hostname(Some(configured_hostname), Some(hostname)) {
65          return resolver.resolve(client_hello);
66        }
67      }
68    }
69    let hostname = hostname.map(|v| v.to_string());
70    self
71      .fallback_resolver
72      .as_ref()
73      .and_then(|r| r.resolve(client_hello))
74      .or_else(|| {
75        if let Some((sender, port)) = &self.fallback_sender {
76          if let Some(hostname) = hostname {
77            sender.send_blocking((hostname.to_string(), *port)).unwrap_or_default();
78          }
79        }
80        None
81      })
82  }
83}
84
85/// A certificate resolver resolving one certified key
86#[derive(Debug)]
87pub struct OneCertifiedKeyResolver {
88  certified_key: Arc<CertifiedKey>,
89}
90
91impl OneCertifiedKeyResolver {
92  /// Creates a certificate resolver with a certified key
93  pub fn new(certified_key: Arc<CertifiedKey>) -> Self {
94    Self { certified_key }
95  }
96}
97
98impl ResolvesServerCert for OneCertifiedKeyResolver {
99  fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
100    Some(self.certified_key.clone())
101  }
102}
103
104/// Loads a host certificate resolver for a specific host
105pub fn load_host_resolver(
106  resolvers: &mut Vec<(String, Arc<dyn ResolvesServerCert>)>,
107  host: &str,
108  resolver: Arc<dyn ResolvesServerCert>,
109) {
110  if !resolvers.iter().any(|(h, _)| h == host) {
111    resolvers.push((host.to_string(), resolver));
112  }
113  resolvers.sort_by(|a, b| {
114    (a.0.starts_with("*."))
115      .cmp(&(b.0.starts_with("*."))) // Take wildcard hostnames into account
116      .then_with(|| {
117        b.0
118          .trim_end_matches('.')
119          .chars()
120          .filter(|c| *c == '.')
121          .count()
122          .cmp(&a.0.trim_end_matches('.').chars().filter(|c| *c == '.').count())
123      }) // Take also amount of dots in hostnames (domain level) into account
124  });
125}
126/// Loads a public certificate from file
127pub fn load_certs(filename: &str) -> std::io::Result<Vec<CertificateDer<'static>>> {
128  let certfile = std::fs::File::open(filename)?;
129  let mut reader = std::io::BufReader::new(certfile);
130  rustls_pemfile::certs(&mut reader).collect()
131}
132
133/// Loads a private key from file
134pub fn load_private_key(filename: &str) -> std::io::Result<PrivateKeyDer<'static>> {
135  let keyfile = std::fs::File::open(filename)?;
136  let mut reader = std::io::BufReader::new(keyfile);
137  match rustls_pemfile::private_key(&mut reader) {
138    Ok(Some(private_key)) => Ok(private_key),
139    Ok(None) => Err(std::io::Error::new(
140      std::io::ErrorKind::InvalidData,
141      "Invalid private key",
142    )),
143    Err(err) => Err(err),
144  }
145}