1use crate::util::HostnameRadixTree;
2use rustls::server::{ClientHello, ResolvesServerCert};
3use rustls::sign::CertifiedKey;
4use rustls_pki_types::pem::PemObject;
5use rustls_pki_types::{CertificateDer, PrivateKeyDer};
6use std::sync::Arc;
7
8pub type SniResolverLock = Arc<tokio::sync::RwLock<HostnameRadixTree<Arc<dyn ResolvesServerCert>>>>;
10
/// TLS certificate resolver that dispatches on the SNI hostname.
///
/// Per-host resolvers are looked up in a shared [`SniResolverLock`]. When no host entry
/// matches (or the client sent no SNI), an optional fallback resolver is tried. If that also
/// yields nothing, the hostname can be reported on an optional channel.
#[derive(Debug)]
pub struct CustomSniResolver {
    /// Resolver tried when the SNI hostname has no entry in `resolvers`, or when no SNI was sent.
    fallback_resolver: Option<Arc<dyn ResolvesServerCert>>,
    /// Shared hostname → resolver tree, consulted first.
    resolvers: SniResolverLock,
    /// Optional notification channel plus the port reported with it.
    /// `(hostname, port)` is sent when a handshake carrying an SNI hostname found no
    /// certificate.
    fallback_sender: Option<(async_channel::Sender<(String, u16)>, u16)>,
}
18
19impl CustomSniResolver {
20 #[allow(dead_code)]
22 pub fn new() -> Self {
23 Self {
24 fallback_resolver: None,
25 resolvers: Arc::new(tokio::sync::RwLock::new(HostnameRadixTree::new())),
26 fallback_sender: None,
27 }
28 }
29
30 pub fn with_resolvers(resolvers: SniResolverLock) -> Self {
32 Self {
33 fallback_resolver: None,
34 resolvers,
35 fallback_sender: None,
36 }
37 }
38
39 pub fn load_fallback_resolver(&mut self, fallback_resolver: Arc<dyn ResolvesServerCert>) {
41 self.fallback_resolver = Some(fallback_resolver);
42 }
43
44 pub fn load_host_resolver(&mut self, host: &str, resolver: Arc<dyn ResolvesServerCert>) {
46 self.resolvers.blocking_write().insert(host.to_string(), resolver);
47 }
48
49 pub fn load_fallback_sender(&mut self, fallback_sender: async_channel::Sender<(String, u16)>, port: u16) {
51 self.fallback_sender = Some((fallback_sender, port));
52 }
53}
54
55impl ResolvesServerCert for CustomSniResolver {
56 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
57 let hostname = client_hello.server_name().map(|hn| hn.strip_suffix('.').unwrap_or(hn));
58 if let Some(hostname) = hostname {
59 #[cfg(feature = "runtime-monoio")]
61 let resolvers = self.resolvers.blocking_read();
62 #[cfg(feature = "runtime-tokio")]
63 let resolvers = futures_executor::block_on(async { self.resolvers.read().await });
64
65 if let Some(resolver) = resolvers.get(hostname).cloned() {
66 return resolver.resolve(client_hello);
67 }
68 }
69 let hostname = hostname.map(|v| v.to_string());
70 self
71 .fallback_resolver
72 .as_ref()
73 .and_then(|r| r.resolve(client_hello))
74 .or_else(|| {
75 if let Some((sender, port)) = &self.fallback_sender {
76 if let Some(hostname) = hostname {
77 sender.send_blocking((hostname.to_string(), *port)).unwrap_or_default();
78 }
79 }
80 None
81 })
82 }
83}
84
85#[derive(Debug)]
87pub struct OneCertifiedKeyResolver {
88 certified_key: Arc<CertifiedKey>,
89}
90
91impl OneCertifiedKeyResolver {
92 pub fn new(certified_key: Arc<CertifiedKey>) -> Self {
94 Self { certified_key }
95 }
96}
97
98impl ResolvesServerCert for OneCertifiedKeyResolver {
99 fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
100 Some(self.certified_key.clone())
101 }
102}
103
104pub fn load_certs(filename: &str) -> std::io::Result<Vec<CertificateDer<'static>>> {
106 let mut certfile = std::fs::File::open(filename)?;
107 CertificateDer::pem_reader_iter(&mut certfile)
108 .collect::<Result<Vec<_>, _>>()
109 .map_err(|e| match e {
110 rustls_pki_types::pem::Error::Io(err) => err,
111 err => std::io::Error::other(err),
112 })
113}
114
115pub fn load_private_key(filename: &str) -> std::io::Result<PrivateKeyDer<'static>> {
117 let mut keyfile = std::fs::File::open(filename)?;
118 match PrivateKeyDer::from_pem_reader(&mut keyfile) {
119 Ok(private_key) => Ok(private_key),
120 Err(rustls_pki_types::pem::Error::Io(err)) => Err(err),
121 Err(err) => Err(std::io::Error::other(err)),
122 }
123}