ferron/setup/
ocsp.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, SystemTime};
4
5use ferron_common::logging::LogMessage;
6use hyper::Request;
7use hyper_util::client::legacy::Client;
8use hyper_util::rt::TokioExecutor;
9use rasn::prelude::*;
10use rasn_ocsp::{CertId, OcspRequest, OcspResponse, OcspResponseStatus, Request as OcspInnerRequest, TbsRequest};
11use rustls::client::WebPkiServerVerifier;
12use rustls::server::{ClientHello, ResolvesServerCert};
13use rustls::sign::CertifiedKey;
14use rustls_pki_types::CertificateDer;
15use rustls_platform_verifier::BuilderVerifierExt;
16use sha1::{Digest, Sha1};
17use sha2::Sha256;
18use tokio::sync::RwLock;
19use tokio_util::sync::CancellationToken;
20use x509_parser::prelude::*;
21
22type OcspCache = Arc<RwLock<HashMap<Vec<u8>, Option<Arc<CertifiedKey>>>>>;
23
24#[derive(Debug)]
25pub struct OcspStapler {
26  inner: Arc<dyn ResolvesServerCert>,
27  cache: OcspCache,
28  sender: async_channel::Sender<CertifiedKey>,
29  cancel_token: CancellationToken,
30}
31
32impl OcspStapler {
33  pub fn new(
34    inner: Arc<dyn ResolvesServerCert>,
35    runtime: &tokio::runtime::Runtime,
36    logging_tx: Vec<async_channel::Sender<LogMessage>>,
37  ) -> Self {
38    let (sender, receiver) = async_channel::unbounded();
39    let cache = Arc::new(RwLock::new(HashMap::new()));
40    let cancel_token = CancellationToken::new();
41
42    let stapler = Self {
43      inner,
44      cache,
45      sender,
46      cancel_token: cancel_token.clone(),
47    };
48
49    runtime.spawn(background_ocsp_task(
50      receiver,
51      stapler.cache.clone(),
52      cancel_token,
53      logging_tx,
54    ));
55
56    stapler
57  }
58
59  pub fn preload(&self, key: Arc<CertifiedKey>) {
60    if !key.cert.is_empty() {
61      // Add to cache immediately (even without OCSP) to track it, or just trigger fetch
62      let _ = self.sender.send_blocking((*key).clone());
63    }
64  }
65
66  pub async fn stop(&self) {
67    self.cancel_token.cancel();
68  }
69}
70
71impl ResolvesServerCert for OcspStapler {
72  fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
73    let original_key = self.inner.resolve(client_hello)?;
74    if let Some(leaf) = original_key.cert.first() {
75      // Check cache
76      //
77      // If blocking_read() method is used when only Tokio is used, the program would panic on resolving a TLS certificate.
78      #[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
79      let cache = self.cache.blocking_read();
80      #[cfg(feature = "runtime-tokio")]
81      let cache = futures_executor::block_on(async { self.cache.read().await });
82
83      if let Some(cached_key_option) = cache.get(&leaf.to_vec()) {
84        if let Some(cached_key) = cached_key_option.as_ref() {
85          // If cached key has OCSP, return it.
86          // Note: We might want to check if it's expired here, but the background task handles cleanup/refresh.
87          // For simplicity, we return what's in cache.
88          if cached_key.ocsp.is_some() {
89            return Some(cached_key.clone());
90          }
91        }
92        // If cached key has no OCSP, don't trigger fetch.
93      } else {
94        // Not in cache or no OCSP yet. Trigger fetch.
95        let _ = self.sender.send_blocking((*original_key).clone());
96      }
97    }
98    Some(original_key)
99  }
100}
101
102async fn background_ocsp_task(
103  receiver: async_channel::Receiver<CertifiedKey>,
104  cache: OcspCache,
105  cancel_token: CancellationToken,
106  logging_tx: Vec<async_channel::Sender<LogMessage>>,
107) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
108  // Track next update times
109  let mut next_updates: HashMap<Vec<u8>, SystemTime> = HashMap::new();
110  // Track known cert chains
111  let mut known_certs: HashMap<Vec<u8>, CertifiedKey> = HashMap::new();
112
113  // Create HTTP client
114  let tls_config_builder =
115    match rustls::ClientConfig::builder_with_provider(rustls::crypto::aws_lc_rs::default_provider().into())
116      .with_safe_default_protocol_versions()
117    {
118      Ok(builder) => builder,
119      Err(e) => {
120        for tx in &logging_tx {
121          let _ = tx
122            .send(LogMessage::new(
123              format!("Failed to create TLS config builder for OCSP stapling: {e}"),
124              true,
125            ))
126            .await;
127        }
128        return Err(e.into());
129      }
130    };
131  let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
132    .with_tls_config(
133      (if let Ok(client_config) = BuilderVerifierExt::with_platform_verifier(tls_config_builder.clone()) {
134        client_config
135      } else {
136        tls_config_builder.with_webpki_verifier(
137          match WebPkiServerVerifier::builder(Arc::new(rustls::RootCertStore {
138            roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
139          }))
140          .build()
141          {
142            Ok(verifier) => verifier,
143            Err(e) => {
144              for tx in &logging_tx {
145                let _ = tx
146                  .send(LogMessage::new(
147                    format!("Failed to create TLS verifier for OCSP stapling: {e}"),
148                    true,
149                  ))
150                  .await;
151              }
152              return Err(e.into());
153            }
154          },
155        )
156      })
157      .with_no_client_auth(),
158    )
159    .https_or_http()
160    .enable_http1()
161    .build();
162
163  let client =
164    Client::builder(TokioExecutor::new()).build::<_, http_body_util::Full<hyper::body::Bytes>>(https_connector);
165
166  loop {
167    let mut sleep_duration = Duration::from_secs(60); // Default check interval
168
169    // Calculate time to next update
170    let now = SystemTime::now();
171    for next_update in next_updates.values() {
172      if let Ok(duration) = next_update.duration_since(now) {
173        if duration < sleep_duration {
174          sleep_duration = duration;
175        }
176      } else {
177        // Already expired, refresh immediately (or very soon)
178        sleep_duration = Duration::from_secs(1);
179      }
180    }
181
182    let received_certified_key = tokio::select! {
183      _ = cancel_token.cancelled() => Err(anyhow::anyhow!("Cancelled"))?,
184      _ = tokio::time::sleep(sleep_duration) => None,
185      res = receiver.recv() => match res {
186        Ok(chain) => Some(chain),
187        Err(e) => Err(e)?, // Channel closed
188      }
189    };
190
191    if let Some(certified_key) = received_certified_key {
192      let chain = &certified_key.cert;
193      if let Some(leaf) = chain.first() {
194        let key = leaf.to_vec();
195        if !known_certs.contains_key(&key) {
196          known_certs.insert(key.clone(), certified_key);
197          // Trigger immediate update for new cert
198          next_updates.insert(key, SystemTime::now());
199        }
200      }
201    }
202
203    // Process updates
204    let now = SystemTime::now();
205    let mut updates_to_fetch = Vec::new();
206    for (key, next_update) in &next_updates {
207      if *next_update <= now {
208        updates_to_fetch.push(key.clone());
209      }
210    }
211
212    for key in updates_to_fetch {
213      if let Some(certified_key) = known_certs.get(&key) {
214        match fetch_ocsp_response(&client, &certified_key.cert).await {
215          Ok(Some((response, next_update_time))) => {
216            let mut new_certified_key = certified_key.clone();
217            new_certified_key.ocsp = Some(response.clone());
218            cache
219              .write()
220              .await
221              .insert(certified_key.cert[0].to_vec(), Some(Arc::new(new_certified_key)));
222            next_updates.insert(key, next_update_time);
223          }
224          Ok(None) => {
225            // Don't retry OCSP stapling
226            cache.write().await.insert(certified_key.cert[0].to_vec(), None);
227            next_updates.remove(&key);
228          }
229          Err(e) => {
230            // Log error
231            for tx in &logging_tx {
232              let _ = tx.send(LogMessage::new(format!("OCSP fetch failed: {e}"), true)).await;
233            }
234            // Retry later; with some randomness to avoid refresh storm.
235            next_updates.insert(key, now + Duration::from_secs(rand::random_range(100..=500)));
236            continue;
237          }
238        };
239      }
240    }
241  }
242}
243
244async fn fetch_ocsp_response(
245  client: &Client<
246    hyper_rustls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>,
247    http_body_util::Full<hyper::body::Bytes>,
248  >,
249  chain: &[CertificateDer<'_>],
250) -> anyhow::Result<Option<(Vec<u8>, SystemTime)>> {
251  // Try SHA-256 first
252  let response = fetch_ocsp_response_inner(client, chain, true).await;
253
254  if response.is_ok() {
255    return response;
256  }
257
258  // SHA-1 fallback
259  if let Ok(sha1_response) = fetch_ocsp_response_inner(client, chain, false).await {
260    return Ok(sha1_response);
261  }
262
263  // If both fail, return the error from SHA-256
264  response
265}
266
267async fn fetch_ocsp_response_inner(
268  client: &Client<
269    hyper_rustls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>,
270    http_body_util::Full<hyper::body::Bytes>,
271  >,
272  chain: &[CertificateDer<'_>],
273  use_sha256: bool,
274) -> anyhow::Result<Option<(Vec<u8>, SystemTime)>> {
275  if chain.len() < 2 {
276    // Certificate chain too short, don't bother with OCSP
277    return Ok(None);
278  }
279  let leaf = &chain[0];
280  let issuer = &chain[1];
281
282  let leaf_cert = X509Certificate::from_der(leaf)?.1;
283  let issuer_cert = X509Certificate::from_der(issuer)?.1;
284
285  // Extract OCSP URL
286  let Some(ocsp_url) = extract_ocsp_url(&leaf_cert) else {
287    // No OCSP URL found
288    return Ok(None);
289  };
290
291  // Create Request
292  let req_der = create_ocsp_request(&leaf_cert, &issuer_cert, use_sha256)?;
293
294  let req = Request::builder()
295    .method("POST")
296    .uri(&ocsp_url)
297    .header("Content-Type", "application/ocsp-request")
298    .body(http_body_util::Full::new(hyper::body::Bytes::from(req_der)))?;
299
300  let res = client.request(req).await?;
301  if !res.status().is_success() {
302    return Err(anyhow::anyhow!(
303      "OCSP request failed with status: {} for URL: {ocsp_url}",
304      res.status()
305    ));
306  }
307
308  // Read response
309  use http_body_util::BodyExt;
310  let body_bytes = res.collect().await?.to_bytes();
311  let response_der = body_bytes.to_vec();
312
313  // Parse response to get next update
314  let response: OcspResponse =
315    rasn::der::decode(&response_der).map_err(|e| anyhow::anyhow!("Failed to decode OCSP response: {}", e))?;
316
317  if response.status != OcspResponseStatus::Successful {
318    return Err(anyhow::anyhow!(
319      "OCSP response status unsuccessful: {:?}",
320      response.status
321    ));
322  }
323
324  let bytes = response.bytes.ok_or_else(|| anyhow::anyhow!("No response bytes"))?;
325  if bytes.r#type
326    != ObjectIdentifier::new(vec![1, 3, 6, 1, 5, 5, 7, 48, 1, 1])
327      .ok_or_else(|| anyhow::anyhow!("Invalid OCSP basic response OID"))?
328  {
329    return Err(anyhow::anyhow!("Unsupported OCSP response type"));
330  }
331
332  let basic_response: rasn_ocsp::BasicOcspResponse =
333    rasn::der::decode(&bytes.response).map_err(|e| anyhow::anyhow!("Failed to decode BasicOcspResponse: {}", e))?;
334
335  // Check validities of all single responses.
336  // For simplicity, take the earliest next_update.
337  let mut min_next_update = None;
338
339  // Need to adjust for data types. `rasn_ocsp` uses `rasn::types::UtcTime` or `GeneralizedTime`.
340  // We need to convert to SystemTime.
341
342  for single_res in basic_response.tbs_response_data.responses {
343    let next_update = single_res.next_update.map(SystemTime::from);
344
345    if let Some(mut nu) = next_update {
346      // Next update with safety margin.
347      let nu_safety_margin = nu
348        .duration_since(SystemTime::from(single_res.this_update))
349        .map(|d| d / 4)
350        .unwrap_or_else(|_| Duration::from_secs(0))
351        .max(Duration::from_hours(1)); // Minimum 1h
352
353      // Add randomness to avoid refresh storm.
354      let nu_safety_margin = nu_safety_margin + (nu_safety_margin.mul_f64(rand::random_range::<f64, _>(0.0..0.5)));
355
356      if nu - nu_safety_margin > SystemTime::now() {
357        nu -= nu_safety_margin;
358      }
359
360      match min_next_update {
361        Some(min) => {
362          if nu < min {
363            min_next_update = Some(nu)
364          }
365        }
366        None => min_next_update = Some(nu),
367      }
368    }
369  }
370
371  let next_update = min_next_update.unwrap_or_else(|| SystemTime::now() + Duration::from_hours(12));
372
373  Ok(Some((response_der, next_update)))
374}
375
376fn extract_ocsp_url(cert: &X509Certificate) -> Option<String> {
377  for ext in cert.extensions() {
378    if let x509_parser::extensions::ParsedExtension::AuthorityInfoAccess(aia) = ext.parsed_extension() {
379      for access_desc in &aia.accessdescs {
380        if access_desc.access_method == x509_parser::oid_registry::OID_PKIX_ACCESS_DESCRIPTOR_OCSP {
381          if let x509_parser::extensions::GeneralName::URI(uri) = access_desc.access_location {
382            return Some(uri.to_string());
383          }
384        }
385      }
386    }
387  }
388  None
389}
390
391fn create_ocsp_request(leaf: &X509Certificate, issuer: &X509Certificate, use_sha256: bool) -> anyhow::Result<Vec<u8>> {
392  // 1. Hash Issuer DN
393  let issuer_name_hash = if use_sha256 {
394    let mut sha256 = Sha256::new();
395    sha256.update(issuer.subject().as_raw());
396    sha256.finalize().to_vec()
397  } else {
398    let mut sha1 = Sha1::new();
399    sha1.update(issuer.subject().as_raw());
400    sha1.finalize().to_vec()
401  };
402
403  // 2. Hash Issuer Key
404  // x509-parser gives SubjectPublicKeyInfo.
405  // RFC 6960: hash of the value (excluding tag and length) of the subject public key field.
406  let spki = issuer.public_key();
407  // spki.subject_public_key is BitString. We want the bytes.
408  let pub_key_bytes = &spki.subject_public_key.data;
409  let issuer_key_hash = if use_sha256 {
410    let mut sha256 = Sha256::new();
411    sha256.update(pub_key_bytes);
412    sha256.finalize().to_vec()
413  } else {
414    let mut sha1 = Sha1::new();
415    sha1.update(pub_key_bytes);
416    sha1.finalize().to_vec()
417  };
418
419  // 3. Serial Number
420  let serial_number = &leaf.tbs_certificate.serial;
421  // Need to convert x509_parser serial (BigUint) to rasn Integer.
422  // x509_parser serial is `BigUint`. rasn `Integer` is BigInt.
423  let serial_int = rasn::types::Integer::from(num_bigint::BigInt::from_biguint(
424    num_bigint::Sign::Plus,
425    serial_number.to_owned(),
426  ));
427
428  let cert_id = CertId {
429    hash_algorithm: rasn_pkix::AlgorithmIdentifier {
430      algorithm: if use_sha256 {
431        rasn::types::Oid::JOINT_ISO_ITU_T_COUNTRY_US_ORGANIZATION_GOV_CSOR_NIST_ALGORITHMS_HASH_SHA256.to_owned()
432      } else {
433        rasn::types::Oid::ISO_IDENTIFIED_ORGANISATION_OIW_SECSIG_ALGORITHM_SHA1.to_owned()
434      },
435      parameters: None,
436    },
437    issuer_name_hash: rasn::types::OctetString::from(issuer_name_hash),
438    issuer_key_hash: rasn::types::OctetString::from(issuer_key_hash),
439    serial_number: serial_int,
440  };
441
442  let req = OcspRequest {
443    tbs_request: TbsRequest {
444      version: rasn::types::Integer::from(0), // v1(0)
445      requestor_name: None,
446      request_list: vec![OcspInnerRequest {
447        req_cert: cert_id,
448        single_request_extensions: None,
449      }],
450      request_extensions: None,
451    },
452    optional_signature: None,
453  };
454
455  rasn::der::encode(&req).map_err(|e| anyhow::anyhow!(e))
456}