1use crate::util::match_hostname;
2use rustls::server::{ClientHello, ResolvesServerCert};
3use rustls::sign::CertifiedKey;
4use rustls_pki_types::pem::PemObject;
5use rustls_pki_types::{CertificateDer, PrivateKeyDer};
6use std::sync::Arc;
7
8#[derive(Debug)]
10pub struct CustomSniResolver {
11 fallback_resolver: Option<Arc<dyn ResolvesServerCert>>,
12 #[allow(clippy::type_complexity)]
13 resolvers: Arc<tokio::sync::RwLock<Vec<(String, Arc<dyn ResolvesServerCert>)>>>,
14 fallback_sender: Option<(async_channel::Sender<(String, u16)>, u16)>,
15}
16
17impl CustomSniResolver {
18 #[allow(dead_code)]
20 pub fn new() -> Self {
21 Self {
22 fallback_resolver: None,
23 resolvers: Arc::new(tokio::sync::RwLock::new(Vec::new())),
24 fallback_sender: None,
25 }
26 }
27
28 #[allow(clippy::type_complexity)]
30 pub fn with_resolvers(resolvers: Arc<tokio::sync::RwLock<Vec<(String, Arc<dyn ResolvesServerCert>)>>>) -> Self {
31 Self {
32 fallback_resolver: None,
33 resolvers,
34 fallback_sender: None,
35 }
36 }
37
38 pub fn load_fallback_resolver(&mut self, fallback_resolver: Arc<dyn ResolvesServerCert>) {
40 self.fallback_resolver = Some(fallback_resolver);
41 }
42
43 pub fn load_host_resolver(&mut self, host: &str, resolver: Arc<dyn ResolvesServerCert>) {
45 load_host_resolver(&mut self.resolvers.blocking_write(), host, resolver);
46 }
47
48 pub fn load_fallback_sender(&mut self, fallback_sender: async_channel::Sender<(String, u16)>, port: u16) {
50 self.fallback_sender = Some((fallback_sender, port));
51 }
52}
53
54impl ResolvesServerCert for CustomSniResolver {
55 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
56 let hostname = client_hello.server_name().map(|hn| hn.strip_suffix('.').unwrap_or(hn));
57 if let Some(hostname) = hostname {
58 #[cfg(feature = "runtime-monoio")]
60 let resolvers = self.resolvers.blocking_read();
61 #[cfg(feature = "runtime-tokio")]
62 let resolvers = futures_executor::block_on(async { self.resolvers.read().await });
63
64 for (configured_hostname, resolver) in resolvers.iter() {
65 if match_hostname(Some(configured_hostname), Some(hostname)) {
66 return resolver.resolve(client_hello);
67 }
68 }
69 }
70 let hostname = hostname.map(|v| v.to_string());
71 self
72 .fallback_resolver
73 .as_ref()
74 .and_then(|r| r.resolve(client_hello))
75 .or_else(|| {
76 if let Some((sender, port)) = &self.fallback_sender {
77 if let Some(hostname) = hostname {
78 sender.send_blocking((hostname.to_string(), *port)).unwrap_or_default();
79 }
80 }
81 None
82 })
83 }
84}
85
86#[derive(Debug)]
88pub struct OneCertifiedKeyResolver {
89 certified_key: Arc<CertifiedKey>,
90}
91
92impl OneCertifiedKeyResolver {
93 pub fn new(certified_key: Arc<CertifiedKey>) -> Self {
95 Self { certified_key }
96 }
97}
98
99impl ResolvesServerCert for OneCertifiedKeyResolver {
100 fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
101 Some(self.certified_key.clone())
102 }
103}
104
105pub fn load_host_resolver(
107 resolvers: &mut Vec<(String, Arc<dyn ResolvesServerCert>)>,
108 host: &str,
109 resolver: Arc<dyn ResolvesServerCert>,
110) {
111 if !resolvers.iter().any(|(h, _)| h == host) {
112 resolvers.push((host.to_string(), resolver));
113 }
114 resolvers.sort_by(|a, b| {
115 (a.0.starts_with("*."))
116 .cmp(&(b.0.starts_with("*."))) .then_with(|| {
118 b.0
119 .trim_end_matches('.')
120 .chars()
121 .filter(|c| *c == '.')
122 .count()
123 .cmp(&a.0.trim_end_matches('.').chars().filter(|c| *c == '.').count())
124 }) });
126}
127pub fn load_certs(filename: &str) -> std::io::Result<Vec<CertificateDer<'static>>> {
129 let mut certfile = std::fs::File::open(filename)?;
130 CertificateDer::pem_reader_iter(&mut certfile)
131 .collect::<Result<Vec<_>, _>>()
132 .map_err(|e| match e {
133 rustls_pki_types::pem::Error::Io(err) => err,
134 err => std::io::Error::other(err),
135 })
136}
137
138pub fn load_private_key(filename: &str) -> std::io::Result<PrivateKeyDer<'static>> {
140 let mut keyfile = std::fs::File::open(filename)?;
141 match PrivateKeyDer::from_pem_reader(&mut keyfile) {
142 Ok(private_key) => Ok(private_key),
143 Err(rustls_pki_types::pem::Error::Io(err)) => Err(err),
144 Err(err) => Err(std::io::Error::other(err)),
145 }
146}