1use crate::util::match_hostname;
2use rustls::server::{ClientHello, ResolvesServerCert};
3use rustls::sign::CertifiedKey;
4use rustls_pki_types::{CertificateDer, PrivateKeyDer};
5use std::sync::Arc;
6
/// TLS server certificate resolver that picks a certificate by the client's SNI host name.
///
/// Host-specific resolvers are looked up in `resolvers`. `fallback_resolver` is consulted when
/// no SNI name was sent or no host entry matches. `fallback_sender` (if set) is notified when
/// the fallback also produces no certificate.
#[derive(Debug)]
pub struct CustomSniResolver {
    /// Resolver used when no host entry in `resolvers` matches the SNI name (or none was sent).
    fallback_resolver: Option<Arc<dyn ResolvesServerCert>>,
    /// Shared table of `(host pattern, resolver)` pairs, consulted in stored order; the
    /// first pattern accepted by `match_hostname` wins.
    // The nested Arc/RwLock/Vec/tuple type trips `clippy::type_complexity`; it is spelled out
    // because `with_resolvers` accepts exactly this shared type.
    #[allow(clippy::type_complexity)]
    resolvers: Arc<tokio::sync::RwLock<Vec<(String, Arc<dyn ResolvesServerCert>)>>>,
    /// Optional channel plus port: when neither the host table nor the fallback resolver yields
    /// a certificate, the SNI host name (if present) and this port are sent on the channel.
    fallback_sender: Option<(async_channel::Sender<(String, u16)>, u16)>,
}
15
16impl CustomSniResolver {
17 #[allow(dead_code)]
19 pub fn new() -> Self {
20 Self {
21 fallback_resolver: None,
22 resolvers: Arc::new(tokio::sync::RwLock::new(Vec::new())),
23 fallback_sender: None,
24 }
25 }
26
27 #[allow(clippy::type_complexity)]
29 pub fn with_resolvers(resolvers: Arc<tokio::sync::RwLock<Vec<(String, Arc<dyn ResolvesServerCert>)>>>) -> Self {
30 Self {
31 fallback_resolver: None,
32 resolvers,
33 fallback_sender: None,
34 }
35 }
36
37 pub fn load_fallback_resolver(&mut self, fallback_resolver: Arc<dyn ResolvesServerCert>) {
39 self.fallback_resolver = Some(fallback_resolver);
40 }
41
42 pub fn load_host_resolver(&mut self, host: &str, resolver: Arc<dyn ResolvesServerCert>) {
44 load_host_resolver(&mut self.resolvers.blocking_write(), host, resolver);
45 }
46
47 pub fn load_fallback_sender(&mut self, fallback_sender: async_channel::Sender<(String, u16)>, port: u16) {
49 self.fallback_sender = Some((fallback_sender, port));
50 }
51}
52
53impl ResolvesServerCert for CustomSniResolver {
54 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
55 let hostname = client_hello.server_name().map(|hn| hn.strip_suffix('.').unwrap_or(hn));
56 if let Some(hostname) = hostname {
57 #[cfg(feature = "runtime-monoio")]
59 let resolvers = self.resolvers.blocking_read();
60 #[cfg(feature = "runtime-tokio")]
61 let resolvers = futures_executor::block_on(async { self.resolvers.read().await });
62
63 for (configured_hostname, resolver) in resolvers.iter() {
64 if match_hostname(Some(configured_hostname), Some(hostname)) {
65 return resolver.resolve(client_hello);
66 }
67 }
68 }
69 let hostname = hostname.map(|v| v.to_string());
70 self
71 .fallback_resolver
72 .as_ref()
73 .and_then(|r| r.resolve(client_hello))
74 .or_else(|| {
75 if let Some((sender, port)) = &self.fallback_sender {
76 if let Some(hostname) = hostname {
77 sender.send_blocking((hostname.to_string(), *port)).unwrap_or_default();
78 }
79 }
80 None
81 })
82 }
83}
84
85#[derive(Debug)]
87pub struct OneCertifiedKeyResolver {
88 certified_key: Arc<CertifiedKey>,
89}
90
91impl OneCertifiedKeyResolver {
92 pub fn new(certified_key: Arc<CertifiedKey>) -> Self {
94 Self { certified_key }
95 }
96}
97
98impl ResolvesServerCert for OneCertifiedKeyResolver {
99 fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
100 Some(self.certified_key.clone())
101 }
102}
103
104pub fn load_host_resolver(
106 resolvers: &mut Vec<(String, Arc<dyn ResolvesServerCert>)>,
107 host: &str,
108 resolver: Arc<dyn ResolvesServerCert>,
109) {
110 if !resolvers.iter().any(|(h, _)| h == host) {
111 resolvers.push((host.to_string(), resolver));
112 }
113 resolvers.sort_by(|a, b| {
114 (a.0.starts_with("*."))
115 .cmp(&(b.0.starts_with("*."))) .then_with(|| {
117 b.0
118 .trim_end_matches('.')
119 .chars()
120 .filter(|c| *c == '.')
121 .count()
122 .cmp(&a.0.trim_end_matches('.').chars().filter(|c| *c == '.').count())
123 }) });
125}
126pub fn load_certs(filename: &str) -> std::io::Result<Vec<CertificateDer<'static>>> {
128 let certfile = std::fs::File::open(filename)?;
129 let mut reader = std::io::BufReader::new(certfile);
130 rustls_pemfile::certs(&mut reader).collect()
131}
132
133pub fn load_private_key(filename: &str) -> std::io::Result<PrivateKeyDer<'static>> {
135 let keyfile = std::fs::File::open(filename)?;
136 let mut reader = std::io::BufReader::new(keyfile);
137 match rustls_pemfile::private_key(&mut reader) {
138 Ok(Some(private_key)) => Ok(private_key),
139 Ok(None) => Err(std::io::Error::new(
140 std::io::ErrorKind::InvalidData,
141 "Invalid private key",
142 )),
143 Err(err) => Err(err),
144 }
145}