1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, SystemTime};
4
5use ferron_common::logging::LogMessage;
6use hyper::Request;
7use hyper_util::client::legacy::Client;
8use hyper_util::rt::TokioExecutor;
9use rasn::prelude::*;
10use rasn_ocsp::{CertId, OcspRequest, OcspResponse, OcspResponseStatus, Request as OcspInnerRequest, TbsRequest};
11use rustls::client::WebPkiServerVerifier;
12use rustls::server::{ClientHello, ResolvesServerCert};
13use rustls::sign::CertifiedKey;
14use rustls_pki_types::CertificateDer;
15use rustls_platform_verifier::BuilderVerifierExt;
16use sha1::{Digest, Sha1};
17use sha2::Sha256;
18use tokio::sync::RwLock;
19use tokio_util::sync::CancellationToken;
20use x509_parser::prelude::*;
21
/// Shared cache of OCSP-stapled certificates, keyed by the DER bytes of the leaf certificate.
///
/// `Some(key)` holds a copy of the certified key with an OCSP response attached (written by
/// the background task after a successful fetch). `None` records that no OCSP request could
/// be made for that certificate, e.g. because the chain has no issuer certificate or the leaf
/// lists no OCSP responder URL.
type OcspCache = Arc<RwLock<HashMap<Vec<u8>, Option<Arc<CertifiedKey>>>>>;
23
/// Certificate resolver wrapper that attaches (staples) OCSP responses to the certificates
/// chosen by an inner resolver.
///
/// OCSP responses are fetched and refreshed by a background task spawned in
/// `OcspStapler::new`. The resolver reads stapled certificates from the shared cache and
/// submits certificates it has not seen yet to that task.
#[derive(Debug)]
pub struct OcspStapler {
    // Resolver that selects the certificate for each TLS handshake.
    inner: Arc<dyn ResolvesServerCert>,
    // Stapled certificates published by the background task.
    cache: OcspCache,
    // Channel used to submit certificates to the background task.
    sender: async_channel::Sender<CertifiedKey>,
    // Token that stops the background task when cancelled.
    cancel_token: CancellationToken,
}
31
32impl OcspStapler {
33 pub fn new(
34 inner: Arc<dyn ResolvesServerCert>,
35 runtime: &tokio::runtime::Runtime,
36 logging_tx: Vec<async_channel::Sender<LogMessage>>,
37 ) -> Self {
38 let (sender, receiver) = async_channel::unbounded();
39 let cache = Arc::new(RwLock::new(HashMap::new()));
40 let cancel_token = CancellationToken::new();
41
42 let stapler = Self {
43 inner,
44 cache,
45 sender,
46 cancel_token: cancel_token.clone(),
47 };
48
49 runtime.spawn(background_ocsp_task(
50 receiver,
51 stapler.cache.clone(),
52 cancel_token,
53 logging_tx,
54 ));
55
56 stapler
57 }
58
59 pub fn preload(&self, key: Arc<CertifiedKey>) {
60 if !key.cert.is_empty() {
61 let _ = self.sender.send_blocking((*key).clone());
63 }
64 }
65
66 pub async fn stop(&self) {
67 self.cancel_token.cancel();
68 }
69}
70
71impl ResolvesServerCert for OcspStapler {
72 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
73 let original_key = self.inner.resolve(client_hello)?;
74 if let Some(leaf) = original_key.cert.first() {
75 #[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
79 let cache = self.cache.blocking_read();
80 #[cfg(feature = "runtime-tokio")]
81 let cache = futures_executor::block_on(async { self.cache.read().await });
82
83 if let Some(cached_key_option) = cache.get(&leaf.to_vec()) {
84 if let Some(cached_key) = cached_key_option.as_ref() {
85 if cached_key.ocsp.is_some() {
89 return Some(cached_key.clone());
90 }
91 }
92 } else {
94 let _ = self.sender.send_blocking((*original_key).clone());
96 }
97 }
98 Some(original_key)
99 }
100}
101
102async fn background_ocsp_task(
103 receiver: async_channel::Receiver<CertifiedKey>,
104 cache: OcspCache,
105 cancel_token: CancellationToken,
106 logging_tx: Vec<async_channel::Sender<LogMessage>>,
107) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
108 let mut next_updates: HashMap<Vec<u8>, SystemTime> = HashMap::new();
110 let mut known_certs: HashMap<Vec<u8>, CertifiedKey> = HashMap::new();
112
113 let tls_config_builder =
115 match rustls::ClientConfig::builder_with_provider(rustls::crypto::aws_lc_rs::default_provider().into())
116 .with_safe_default_protocol_versions()
117 {
118 Ok(builder) => builder,
119 Err(e) => {
120 for tx in &logging_tx {
121 let _ = tx
122 .send(LogMessage::new(
123 format!("Failed to create TLS config builder for OCSP stapling: {e}"),
124 true,
125 ))
126 .await;
127 }
128 return Err(e.into());
129 }
130 };
131 let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
132 .with_tls_config(
133 (if let Ok(client_config) = BuilderVerifierExt::with_platform_verifier(tls_config_builder.clone()) {
134 client_config
135 } else {
136 tls_config_builder.with_webpki_verifier(
137 match WebPkiServerVerifier::builder(Arc::new(rustls::RootCertStore {
138 roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
139 }))
140 .build()
141 {
142 Ok(verifier) => verifier,
143 Err(e) => {
144 for tx in &logging_tx {
145 let _ = tx
146 .send(LogMessage::new(
147 format!("Failed to create TLS verifier for OCSP stapling: {e}"),
148 true,
149 ))
150 .await;
151 }
152 return Err(e.into());
153 }
154 },
155 )
156 })
157 .with_no_client_auth(),
158 )
159 .https_or_http()
160 .enable_http1()
161 .build();
162
163 let client =
164 Client::builder(TokioExecutor::new()).build::<_, http_body_util::Full<hyper::body::Bytes>>(https_connector);
165
166 loop {
167 let mut sleep_duration = Duration::from_secs(60); let now = SystemTime::now();
171 for next_update in next_updates.values() {
172 if let Ok(duration) = next_update.duration_since(now) {
173 if duration < sleep_duration {
174 sleep_duration = duration;
175 }
176 } else {
177 sleep_duration = Duration::from_secs(1);
179 }
180 }
181
182 let received_certified_key = tokio::select! {
183 _ = cancel_token.cancelled() => Err(anyhow::anyhow!("Cancelled"))?,
184 _ = tokio::time::sleep(sleep_duration) => None,
185 res = receiver.recv() => match res {
186 Ok(chain) => Some(chain),
187 Err(e) => Err(e)?, }
189 };
190
191 if let Some(certified_key) = received_certified_key {
192 let chain = &certified_key.cert;
193 if let Some(leaf) = chain.first() {
194 let key = leaf.to_vec();
195 if !known_certs.contains_key(&key) {
196 known_certs.insert(key.clone(), certified_key);
197 next_updates.insert(key, SystemTime::now());
199 }
200 }
201 }
202
203 let now = SystemTime::now();
205 let mut updates_to_fetch = Vec::new();
206 for (key, next_update) in &next_updates {
207 if *next_update <= now {
208 updates_to_fetch.push(key.clone());
209 }
210 }
211
212 for key in updates_to_fetch {
213 if let Some(certified_key) = known_certs.get(&key) {
214 match fetch_ocsp_response(&client, &certified_key.cert).await {
215 Ok(Some((response, next_update_time))) => {
216 let mut new_certified_key = certified_key.clone();
217 new_certified_key.ocsp = Some(response.clone());
218 cache
219 .write()
220 .await
221 .insert(certified_key.cert[0].to_vec(), Some(Arc::new(new_certified_key)));
222 next_updates.insert(key, next_update_time);
223 }
224 Ok(None) => {
225 cache.write().await.insert(certified_key.cert[0].to_vec(), None);
227 next_updates.remove(&key);
228 }
229 Err(e) => {
230 for tx in &logging_tx {
232 let _ = tx.send(LogMessage::new(format!("OCSP fetch failed: {e}"), true)).await;
233 }
234 next_updates.insert(key, now + Duration::from_secs(rand::random_range(100..=500)));
236 continue;
237 }
238 };
239 }
240 }
241 }
242}
243
244async fn fetch_ocsp_response(
245 client: &Client<
246 hyper_rustls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>,
247 http_body_util::Full<hyper::body::Bytes>,
248 >,
249 chain: &[CertificateDer<'_>],
250) -> anyhow::Result<Option<(Vec<u8>, SystemTime)>> {
251 let response = fetch_ocsp_response_inner(client, chain, true).await;
253
254 if response.is_ok() {
255 return response;
256 }
257
258 if let Ok(sha1_response) = fetch_ocsp_response_inner(client, chain, false).await {
260 return Ok(sha1_response);
261 }
262
263 response
265}
266
267async fn fetch_ocsp_response_inner(
268 client: &Client<
269 hyper_rustls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>,
270 http_body_util::Full<hyper::body::Bytes>,
271 >,
272 chain: &[CertificateDer<'_>],
273 use_sha256: bool,
274) -> anyhow::Result<Option<(Vec<u8>, SystemTime)>> {
275 if chain.len() < 2 {
276 return Ok(None);
278 }
279 let leaf = &chain[0];
280 let issuer = &chain[1];
281
282 let leaf_cert = X509Certificate::from_der(leaf)?.1;
283 let issuer_cert = X509Certificate::from_der(issuer)?.1;
284
285 let Some(ocsp_url) = extract_ocsp_url(&leaf_cert) else {
287 return Ok(None);
289 };
290
291 let req_der = create_ocsp_request(&leaf_cert, &issuer_cert, use_sha256)?;
293
294 let req = Request::builder()
295 .method("POST")
296 .uri(&ocsp_url)
297 .header("Content-Type", "application/ocsp-request")
298 .body(http_body_util::Full::new(hyper::body::Bytes::from(req_der)))?;
299
300 let res = client.request(req).await?;
301 if !res.status().is_success() {
302 return Err(anyhow::anyhow!(
303 "OCSP request failed with status: {} for URL: {ocsp_url}",
304 res.status()
305 ));
306 }
307
308 use http_body_util::BodyExt;
310 let body_bytes = res.collect().await?.to_bytes();
311 let response_der = body_bytes.to_vec();
312
313 let response: OcspResponse =
315 rasn::der::decode(&response_der).map_err(|e| anyhow::anyhow!("Failed to decode OCSP response: {}", e))?;
316
317 if response.status != OcspResponseStatus::Successful {
318 return Err(anyhow::anyhow!(
319 "OCSP response status unsuccessful: {:?}",
320 response.status
321 ));
322 }
323
324 let bytes = response.bytes.ok_or_else(|| anyhow::anyhow!("No response bytes"))?;
325 if bytes.r#type
326 != ObjectIdentifier::new(vec![1, 3, 6, 1, 5, 5, 7, 48, 1, 1])
327 .ok_or_else(|| anyhow::anyhow!("Invalid OCSP basic response OID"))?
328 {
329 return Err(anyhow::anyhow!("Unsupported OCSP response type"));
330 }
331
332 let basic_response: rasn_ocsp::BasicOcspResponse =
333 rasn::der::decode(&bytes.response).map_err(|e| anyhow::anyhow!("Failed to decode BasicOcspResponse: {}", e))?;
334
335 let mut min_next_update = None;
338
339 for single_res in basic_response.tbs_response_data.responses {
343 let next_update = single_res.next_update.map(SystemTime::from);
344
345 if let Some(mut nu) = next_update {
346 let nu_safety_margin = nu
348 .duration_since(SystemTime::from(single_res.this_update))
349 .map(|d| d / 4)
350 .unwrap_or_else(|_| Duration::from_secs(0))
351 .max(Duration::from_hours(1)); let nu_safety_margin = nu_safety_margin + (nu_safety_margin.mul_f64(rand::random_range::<f64, _>(0.0..0.5)));
355
356 if nu - nu_safety_margin > SystemTime::now() {
357 nu -= nu_safety_margin;
358 }
359
360 match min_next_update {
361 Some(min) => {
362 if nu < min {
363 min_next_update = Some(nu)
364 }
365 }
366 None => min_next_update = Some(nu),
367 }
368 }
369 }
370
371 let next_update = min_next_update.unwrap_or_else(|| SystemTime::now() + Duration::from_hours(12));
372
373 Ok(Some((response_der, next_update)))
374}
375
376fn extract_ocsp_url(cert: &X509Certificate) -> Option<String> {
377 for ext in cert.extensions() {
378 if let x509_parser::extensions::ParsedExtension::AuthorityInfoAccess(aia) = ext.parsed_extension() {
379 for access_desc in &aia.accessdescs {
380 if access_desc.access_method == x509_parser::oid_registry::OID_PKIX_ACCESS_DESCRIPTOR_OCSP {
381 if let x509_parser::extensions::GeneralName::URI(uri) = access_desc.access_location {
382 return Some(uri.to_string());
383 }
384 }
385 }
386 }
387 }
388 None
389}
390
391fn create_ocsp_request(leaf: &X509Certificate, issuer: &X509Certificate, use_sha256: bool) -> anyhow::Result<Vec<u8>> {
392 let issuer_name_hash = if use_sha256 {
394 let mut sha256 = Sha256::new();
395 sha256.update(issuer.subject().as_raw());
396 sha256.finalize().to_vec()
397 } else {
398 let mut sha1 = Sha1::new();
399 sha1.update(issuer.subject().as_raw());
400 sha1.finalize().to_vec()
401 };
402
403 let spki = issuer.public_key();
407 let pub_key_bytes = &spki.subject_public_key.data;
409 let issuer_key_hash = if use_sha256 {
410 let mut sha256 = Sha256::new();
411 sha256.update(pub_key_bytes);
412 sha256.finalize().to_vec()
413 } else {
414 let mut sha1 = Sha1::new();
415 sha1.update(pub_key_bytes);
416 sha1.finalize().to_vec()
417 };
418
419 let serial_number = &leaf.tbs_certificate.serial;
421 let serial_int = rasn::types::Integer::from(num_bigint::BigInt::from_biguint(
424 num_bigint::Sign::Plus,
425 serial_number.to_owned(),
426 ));
427
428 let cert_id = CertId {
429 hash_algorithm: rasn_pkix::AlgorithmIdentifier {
430 algorithm: if use_sha256 {
431 rasn::types::Oid::JOINT_ISO_ITU_T_COUNTRY_US_ORGANIZATION_GOV_CSOR_NIST_ALGORITHMS_HASH_SHA256.to_owned()
432 } else {
433 rasn::types::Oid::ISO_IDENTIFIED_ORGANISATION_OIW_SECSIG_ALGORITHM_SHA1.to_owned()
434 },
435 parameters: None,
436 },
437 issuer_name_hash: rasn::types::OctetString::from(issuer_name_hash),
438 issuer_key_hash: rasn::types::OctetString::from(issuer_key_hash),
439 serial_number: serial_int,
440 };
441
442 let req = OcspRequest {
443 tbs_request: TbsRequest {
444 version: rasn::types::Integer::from(0), requestor_name: None,
446 request_list: vec![OcspInnerRequest {
447 req_cert: cert_id,
448 single_request_extensions: None,
449 }],
450 request_extensions: None,
451 },
452 optional_signature: None,
453 };
454
455 rasn::der::encode(&req).map_err(|e| anyhow::anyhow!(e))
456}