ferron/
main.rs

1mod acme;
2mod config;
3mod handler;
4mod listener_handler_communication;
5mod listener_quic;
6mod listener_tcp;
7mod request_handler;
8mod runtime;
9mod tls_util;
10mod util;
11
12use std::collections::{HashMap, HashSet};
13use std::error::Error;
14use std::net::{IpAddr, Ipv6Addr, SocketAddr};
15use std::path::{Path, PathBuf};
16use std::str::FromStr;
17use std::sync::{Arc, LazyLock, Mutex};
18use std::thread;
19use std::time::Duration;
20
21use async_channel::{Receiver, Sender};
22use base64::Engine;
23use clap::{Arg, ArgAction, ArgMatches, Command};
24use config::adapters::ConfigurationAdapter;
25use config::processing::{load_modules, merge_duplicates, premerge_configuration, remove_and_add_global_configuration};
26use config::ServerConfigurations;
27use ferron_common::logging::LogMessage;
28use ferron_common::{get_entry, get_value, get_values};
29use ferron_load_modules::{get_dns_provider, obtain_module_loaders, obtain_observability_backend_loaders};
30use handler::create_http_handler;
31use human_panic::{setup_panic, Metadata};
32use instant_acme::{ChallengeType, ExternalAccountKey, LetsEncrypt};
33use listener_handler_communication::ConnectionData;
34use listener_quic::create_quic_listener;
35use listener_tcp::create_tcp_listener;
36use mimalloc::MiMalloc;
37use rustls::client::WebPkiServerVerifier;
38use rustls::crypto::aws_lc_rs::cipher_suite::*;
39use rustls::crypto::aws_lc_rs::default_provider;
40use rustls::crypto::aws_lc_rs::kx_group::*;
41use rustls::server::{ResolvesServerCert, WebPkiClientVerifier};
42use rustls::sign::CertifiedKey;
43use rustls::version::{TLS12, TLS13};
44use rustls::{ClientConfig, RootCertStore, ServerConfig};
45use rustls_native_certs::load_native_certs;
46use rustls_platform_verifier::BuilderVerifierExt;
47use shadow_rs::shadow;
48use tls_util::{load_certs, load_private_key, CustomSniResolver, OneCertifiedKeyResolver};
49use tokio_util::sync::CancellationToken;
50use xxhash_rust::xxh3::xxh3_128;
51
52use crate::acme::{
53  add_domain_to_cache, check_certificate_validity_or_install_cached, convert_on_demand_config, get_cached_domains,
54  provision_certificate, AcmeCache, AcmeConfig, AcmeOnDemandConfig, AcmeResolver, TlsAlpn01Resolver,
55  ACME_TLS_ALPN_NAME,
56};
57use crate::util::{is_localhost, match_hostname, NoServerVerifier};
58
59// Set the global allocator to use mimalloc for performance optimization
60#[global_allocator]
61static GLOBAL: MiMalloc = MiMalloc;
62
63shadow!(build);
64
65static LISTENER_HANDLER_CHANNEL: LazyLock<Arc<(Sender<ConnectionData>, Receiver<ConnectionData>)>> =
66  LazyLock::new(|| Arc::new(async_channel::unbounded()));
67#[allow(clippy::type_complexity)]
68static TCP_LISTENERS: LazyLock<Arc<Mutex<HashMap<SocketAddr, CancellationToken>>>> =
69  LazyLock::new(|| Arc::new(Mutex::new(HashMap::new())));
70#[allow(clippy::type_complexity)]
71static QUIC_LISTENERS: LazyLock<Arc<Mutex<HashMap<SocketAddr, (CancellationToken, Sender<Arc<ServerConfig>>)>>>> =
72  LazyLock::new(|| Arc::new(Mutex::new(HashMap::new())));
73static URING_ENABLED: LazyLock<Arc<Mutex<bool>>> = LazyLock::new(|| Arc::new(Mutex::new(true)));
74static LISTENER_LOGGING_CHANNEL: LazyLock<Arc<(Sender<LogMessage>, Receiver<LogMessage>)>> =
75  LazyLock::new(|| Arc::new(async_channel::unbounded()));
76
77/// Handles shutdown signals (SIGHUP and CTRL+C) and returns whether to continue running
78fn handle_shutdown_signals(runtime: &tokio::runtime::Runtime) -> bool {
79  runtime.block_on(async move {
80    let (continue_tx, continue_rx) = async_channel::unbounded::<bool>();
81    let cancel_token = tokio_util::sync::CancellationToken::new();
82
83    #[cfg(unix)]
84    {
85      let cancel_token_clone = cancel_token.clone();
86      let continue_tx_clone = continue_tx.clone();
87      tokio::spawn(async move {
88        if let Ok(mut signal) = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup()) {
89          tokio::select! {
90            _ = signal.recv() => {
91              continue_tx_clone.send(true).await.unwrap_or_default();
92            }
93            _ = cancel_token_clone.cancelled() => {}
94          }
95        }
96      });
97    }
98
99    let cancel_token_clone = cancel_token.clone();
100    tokio::spawn(async move {
101      tokio::select! {
102        result = tokio::signal::ctrl_c() => {
103          if result.is_ok() {
104            continue_tx.send(false).await.unwrap_or_default();
105          }
106        }
107        _ = cancel_token_clone.cancelled() => {}
108      }
109    });
110
111    let continue_running = continue_rx.recv().await.unwrap_or(false);
112    cancel_token.cancel();
113    continue_running
114  })
115}
116
117/// Function called before starting a server
118fn before_starting_server(
119  args: ArgMatches,
120  configuration_adapters: HashMap<String, Box<dyn ConfigurationAdapter + Send + Sync>>,
121) -> Result<(), Box<dyn Error + Send + Sync>> {
122  // Obtain the argument values
123  let configuration_path: &Path = args
124    .get_one::<PathBuf>("config")
125    .ok_or(anyhow::anyhow!("Cannot obtain the configuration path"))?
126    .as_path();
127  let configuration_adapter: &str = args
128    .get_one::<String>("config-adapter")
129    .map_or(determine_default_configuration_adapter(configuration_path), |s| {
130      s as &str
131    });
132
133  // Old handler shutdown channels and secondary runtime
134  let mut old_runtime: Option<(Vec<CancellationToken>, tokio::runtime::Runtime)> = None;
135
136  // Obtain the configuration adapter
137  let configuration_adapter = configuration_adapters
138    .get(configuration_adapter)
139    .ok_or(anyhow::anyhow!(
140      "The \"{}\" configuration adapter isn't supported",
141      configuration_adapter
142    ))?;
143
144  // Determine the available parallelism
145  let available_parallelism = thread::available_parallelism()?.get();
146
147  // First startup flag
148  let mut first_startup = true;
149
150  loop {
151    // Obtain the module loaders
152    let mut module_loaders = obtain_module_loaders();
153
154    // Obtain the observability backend loaders
155    let mut observability_backend_loaders = obtain_observability_backend_loaders();
156
157    // Create a secondary Tokio runtime
158    let secondary_runtime = tokio::runtime::Builder::new_multi_thread()
159      .worker_threads(match available_parallelism / 2 {
160        0 => 1,
161        non_zero => non_zero,
162      })
163      .thread_name("Secondary runtime")
164      .enable_all()
165      .build()?;
166
167    // Load the configuration
168    let configs_to_process = configuration_adapter.load_configuration(configuration_path)?;
169
170    // Process the configurations
171    let configs_to_process = merge_duplicates(configs_to_process);
172    let configs_to_process = remove_and_add_global_configuration(configs_to_process);
173    let configs_to_process = premerge_configuration(configs_to_process);
174    let (configs_to_process, first_module_error, unused_properties) = load_modules(
175      configs_to_process,
176      &mut module_loaders,
177      &mut observability_backend_loaders,
178      &secondary_runtime,
179    );
180
181    // Finalize the configurations
182    let server_configurations = Arc::new(ServerConfigurations::new(configs_to_process));
183
184    let global_configuration = server_configurations.find_global_configuration();
185    let global_configuration_clone = global_configuration.clone();
186
187    // Reference to the secondary Tokio runtime
188    let secondary_runtime_ref = &secondary_runtime;
189
190    // Mutable reference to the old runtime
191    let old_runtime_ref = &mut old_runtime;
192
193    // Execute the rest
194    let execute_rest = move || {
195      if let Some(first_module_error) = first_module_error {
196        // Error out if there was a module error
197        Err(first_module_error)?;
198      }
199
200      // Log unused properties
201      for unused_property in unused_properties {
202        for logging_tx in global_configuration
203          .as_ref()
204          .map_or(&vec![], |c| &c.observability.log_channels)
205        {
206          logging_tx
207            .send_blocking(LogMessage::new(
208              format!("Unused configuration property detected: \"{unused_property}\""),
209              true,
210            ))
211            .unwrap_or_default();
212        }
213      }
214
215      // Configure cryptography provider for Rustls
216      let mut crypto_provider = default_provider();
217
218      // Configure cipher suites
219      let cipher_suite: Vec<&config::ServerConfigurationValue> = global_configuration
220        .as_deref()
221        .map_or(vec![], |c| get_values!("tls_cipher_suite", c));
222      if !cipher_suite.is_empty() {
223        let mut cipher_suites = Vec::new();
224        let cipher_suite_iter = cipher_suite.iter();
225        for cipher_suite_config in cipher_suite_iter {
226          if let Some(cipher_suite) = cipher_suite_config.as_str() {
227            let cipher_suite_to_add = match cipher_suite {
228              "TLS_AES_128_GCM_SHA256" => TLS13_AES_128_GCM_SHA256,
229              "TLS_AES_256_GCM_SHA384" => TLS13_AES_256_GCM_SHA384,
230              "TLS_CHACHA20_POLY1305_SHA256" => TLS13_CHACHA20_POLY1305_SHA256,
231              "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
232              "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
233              "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
234              "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" => TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
235              "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384" => TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
236              "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
237              _ => Err(anyhow::anyhow!(format!(
238                "The \"{}\" cipher suite is not supported",
239                cipher_suite
240              )))?,
241            };
242            cipher_suites.push(cipher_suite_to_add);
243          }
244        }
245        crypto_provider.cipher_suites = cipher_suites;
246      }
247
248      // Configure ECDH curves
249      let ecdh_curves = global_configuration
250        .as_deref()
251        .map_or(vec![], |c| get_values!("tls_ecdh_curve", c));
252      if !ecdh_curves.is_empty() {
253        let mut kx_groups = Vec::new();
254        let ecdh_curves_iter = ecdh_curves.iter();
255        for ecdh_curve_config in ecdh_curves_iter {
256          if let Some(ecdh_curve) = ecdh_curve_config.as_str() {
257            let kx_group_to_add = match ecdh_curve {
258              "secp256r1" => SECP256R1,
259              "secp384r1" => SECP384R1,
260              "x25519" => X25519,
261              "x25519mklem768" => X25519MLKEM768,
262              "mklem768" => MLKEM768,
263              _ => Err(anyhow::anyhow!(format!(
264                "The \"{}\" ECDH curve is not supported",
265                ecdh_curve
266              )))?,
267            };
268            kx_groups.push(kx_group_to_add);
269          }
270        }
271        crypto_provider.kx_groups = kx_groups;
272      }
273
274      // Install a process-wide cryptography provider. If it fails, then error it out.
275      if crypto_provider.clone().install_default().is_err() && first_startup {
276        Err(anyhow::anyhow!("Cannot install a process-wide cryptography provider"))?;
277      }
278
279      let crypto_provider = Arc::new(crypto_provider);
280
281      // Build TLS configuration
282      let tls_config_builder_wants_versions = ServerConfig::builder_with_provider(crypto_provider.clone());
283
284      let min_tls_version_option = global_configuration
285        .as_deref()
286        .and_then(|c| get_value!("tls_min_version", c))
287        .and_then(|v| v.as_str());
288      let max_tls_version_option = global_configuration
289        .as_deref()
290        .and_then(|c| get_value!("tls_max_version", c))
291        .and_then(|v| v.as_str());
292
293      let tls_config_builder_wants_verifier = if min_tls_version_option.is_none() && max_tls_version_option.is_none() {
294        tls_config_builder_wants_versions.with_safe_default_protocol_versions()?
295      } else {
296        let tls_versions = [("TLSv1.2", &TLS12), ("TLSv1.3", &TLS13)];
297        let min_tls_version_index =
298          min_tls_version_option.map_or(Some(0), |v| tls_versions.iter().position(|p| p.0 == v));
299        let max_tls_version_index = max_tls_version_option.map_or(Some(tls_versions.len() - 1), |v| {
300          tls_versions.iter().position(|p| p.0 == v)
301        });
302        if let Some(min_tls_version_index) = min_tls_version_index {
303          if let Some(max_tls_version_index) = max_tls_version_index {
304            tls_config_builder_wants_versions.with_protocol_versions(
305              &tls_versions[min_tls_version_index..max_tls_version_index]
306                .iter()
307                .map(|p| p.1)
308                .collect::<Vec<_>>(),
309            )?
310          } else {
311            Err(anyhow::anyhow!("Invalid maximum TLS version"))?
312          }
313        } else {
314          Err(anyhow::anyhow!("Invalid minimum TLS version"))?
315        }
316      };
317
318      let tls_config_builder_wants_server_cert = if let Some(client_cert_path) = global_configuration
319        .as_deref()
320        .and_then(|c| get_value!("tls_client_certificate", c))
321        .and_then(|v| v.as_str())
322      {
323        let mut roots = RootCertStore::empty();
324        let client_certificate_cas = load_certs(client_cert_path)?;
325        for cert in client_certificate_cas {
326          roots.add(cert)?;
327        }
328        tls_config_builder_wants_verifier
329          .with_client_cert_verifier(WebPkiClientVerifier::builder(Arc::new(roots)).build()?)
330      } else if global_configuration
331        .as_deref()
332        .and_then(|c| get_value!("tls_client_certificate", c))
333        .and_then(|v| v.as_bool())
334        .unwrap_or(false)
335      {
336        let roots = (|| {
337          let certs_result = load_native_certs();
338          if !certs_result.errors.is_empty() {
339            return None;
340          }
341          let certs = certs_result.certs;
342
343          let mut roots = RootCertStore::empty();
344          for cert in certs {
345            if roots.add(cert).is_err() {
346              return None;
347            }
348          }
349
350          Some(roots)
351        })()
352        .unwrap_or(RootCertStore {
353          roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
354        });
355
356        tls_config_builder_wants_verifier
357          .with_client_cert_verifier(WebPkiClientVerifier::builder(Arc::new(roots)).build()?)
358      } else {
359        tls_config_builder_wants_verifier.with_no_client_auth()
360      };
361
362      let enable_proxy_protocol = global_configuration
363        .as_ref()
364        .and_then(|c| get_value!("protocol_proxy", c))
365        .and_then(|v| v.as_bool())
366        .unwrap_or(false);
367      let protocols = global_configuration
368        .as_ref()
369        .and_then(|c| get_entry!("protocols", c))
370        .map(|e| e.values.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
371        .unwrap_or(vec!["h1", "h2"]);
372
373      if enable_proxy_protocol && protocols.contains(&"h3") {
374        Err(anyhow::anyhow!("PROXY protocol isn't supported with HTTP/3"))?
375      }
376
377      let default_http_port = global_configuration
378        .as_deref()
379        .and_then(|c| get_entry!("default_http_port", c))
380        .and_then(|e| e.values.first())
381        .map_or(Some(80), |v| {
382          if v.is_null() {
383            None
384          } else {
385            Some(v.as_i128().unwrap_or(80) as u16)
386          }
387        });
388      let default_https_port = global_configuration
389        .as_deref()
390        .and_then(|c| get_entry!("default_https_port", c))
391        .and_then(|e| e.values.first())
392        .map_or(Some(443), |v| {
393          if v.is_null() {
394            None
395          } else {
396            Some(v.as_i128().unwrap_or(443) as u16)
397          }
398        });
399
400      let mut tls_ports: HashMap<u16, CustomSniResolver> = HashMap::new();
401      #[allow(clippy::type_complexity)]
402      let mut tls_port_locks: HashMap<u16, Arc<tokio::sync::RwLock<Vec<(String, Arc<dyn ResolvesServerCert>)>>>> =
403        HashMap::new();
404      let mut nonencrypted_ports = HashSet::new();
405      let mut certified_keys_to_preload: HashMap<u16, Vec<Arc<CertifiedKey>>> = HashMap::new();
406      let mut used_sni_hostnames = HashSet::new();
407      let mut automatic_tls_used_sni_hostnames = HashSet::new();
408      let mut acme_tls_alpn_01_resolvers: HashMap<u16, TlsAlpn01Resolver> = HashMap::new();
409      let mut acme_tls_alpn_01_resolver_locks: HashMap<
410        u16,
411        Arc<tokio::sync::RwLock<Vec<crate::acme::TlsAlpn01DataLock>>>,
412      > = HashMap::new();
413      let acme_http_01_resolvers: Arc<tokio::sync::RwLock<Vec<crate::acme::Http01DataLock>>> =
414        Arc::new(tokio::sync::RwLock::new(Vec::new()));
415      let acme_default_directory = dirs::data_local_dir().and_then(|mut p| {
416        p.push("ferron-acme");
417        p.into_os_string().into_string().ok()
418      });
419      let memory_acme_account_cache_data = Arc::new(tokio::sync::RwLock::new(HashMap::new()));
420      let mut acme_configs = Vec::new();
421      let mut acme_on_demand_configs = Vec::new();
422      let (acme_on_demand_tx, acme_on_demand_rx) = async_channel::unbounded();
423      let on_demand_tls_ask_endpoint = match global_configuration
424        .as_ref()
425        .and_then(|c| get_value!("auto_tls_on_demand_ask", c))
426        .and_then(|v| v.as_str())
427        .map(|u| u.parse::<hyper::Uri>())
428      {
429        Some(Ok(uri)) => Some(uri),
430        Some(Err(err)) => Err(anyhow::anyhow!(
431          "Failed to parse automatic TLS on demand asking endpoint URI: {}",
432          err
433        ))?,
434        None => None,
435      };
436      let on_demand_tls_ask_endpoint_verify = !global_configuration
437        .as_ref()
438        .and_then(|c| get_value!("auto_tls_on_demand_ask_no_verification", c))
439        .and_then(|v| v.as_bool())
440        .unwrap_or(false);
441
442      // Iterate server configurations (TLS configuration)
443      for server_configuration in &server_configurations.inner {
444        if server_configuration.filters.is_global_non_host()
445          || (server_configuration.filters.is_global() && server_configuration.entries.is_empty())
446        {
447          // Don't add listeners from an empty global configuration or non-host global configuration
448          continue;
449        }
450
451        let on_demand_tls = get_value!("auto_tls_on_demand", server_configuration)
452          .and_then(|v| v.as_bool())
453          .unwrap_or(false);
454
455        let https_port = server_configuration.filters.port.or(default_https_port);
456
457        let sni_hostname = server_configuration.filters.hostname.clone().or_else(|| {
458          // !!! UNTESTED, many clients don't send SNI hostname when accessing via IP address anyway
459          match server_configuration.filters.ip {
460            Some(IpAddr::V4(address)) => Some(address.to_string()),
461            Some(IpAddr::V6(address)) => Some(format!("[{address}]")),
462            _ => None,
463          }
464        });
465
466        let is_sni_hostname_used = !https_port.is_none_or(|p| {
467          !used_sni_hostnames.contains(&(p, sni_hostname.clone()))
468            && !automatic_tls_used_sni_hostnames.contains(&(p, sni_hostname.clone()))
469        });
470        let is_auto_tls_sni_hostname_used =
471          https_port.is_some_and(|p| automatic_tls_used_sni_hostnames.contains(&(p, sni_hostname.clone())));
472
473        let mut automatic_tls_port = None;
474        if server_configuration.filters.port.is_none() {
475          if get_value!("auto_tls", server_configuration)
476            .and_then(|v| v.as_bool())
477            .unwrap_or(!is_localhost(
478              server_configuration.filters.ip.as_ref(),
479              server_configuration.filters.hostname.as_deref(),
480            ))
481          {
482            automatic_tls_port = default_https_port;
483          }
484          if let Some(http_port) = default_http_port {
485            nonencrypted_ports.insert(http_port);
486          }
487        }
488
489        if get_value!("auto_tls", server_configuration)
490          .and_then(|v| v.as_bool())
491          .unwrap_or(false)
492        {
493          automatic_tls_port = https_port;
494        } else if let Some(tls_entry) = get_entry!("tls", server_configuration) {
495          if let Some(https_port) = https_port {
496            if tls_entry.values.len() == 2 {
497              if let Some(cert_path) = tls_entry.values[0].as_str() {
498                if let Some(key_path) = tls_entry.values[1].as_str() {
499                  automatic_tls_port = None;
500
501                  if !is_sni_hostname_used {
502                    let certs = match load_certs(cert_path) {
503                      Ok(certs) => certs,
504                      Err(err) => Err(anyhow::anyhow!(format!(
505                        "Cannot load the \"{}\" TLS certificate: {}",
506                        cert_path, err
507                      )))?,
508                    };
509                    let key = match load_private_key(key_path) {
510                      Ok(key) => key,
511                      Err(err) => Err(anyhow::anyhow!(format!(
512                        "Cannot load the \"{}\" private key: {}",
513                        key_path, err
514                      )))?,
515                    };
516                    let signing_key = match crypto_provider.key_provider.load_private_key(key) {
517                      Ok(key) => key,
518                      Err(err) => Err(anyhow::anyhow!(format!(
519                        "Cannot load the \"{}\" private key: {}",
520                        key_path, err
521                      )))?,
522                    };
523                    let certified_key = Arc::new(CertifiedKey::new(certs, signing_key));
524                    if let Some(certified_keys) = certified_keys_to_preload.get_mut(&https_port) {
525                      certified_keys.push(certified_key.clone());
526                    } else {
527                      certified_keys_to_preload.insert(https_port, vec![certified_key.clone()]);
528                    }
529                    let resolver = Arc::new(OneCertifiedKeyResolver::new(certified_key));
530                    if let Some(sni_resolver) = tls_ports.get_mut(&https_port) {
531                      if let Some(sni_hostname) = &sni_hostname {
532                        sni_resolver.load_host_resolver(sni_hostname, resolver);
533                      } else {
534                        sni_resolver.load_fallback_resolver(resolver);
535                      }
536                    } else {
537                      let sni_resolver_list = Arc::new(tokio::sync::RwLock::new(Vec::new()));
538                      tls_port_locks.insert(https_port, sni_resolver_list.clone());
539                      let mut sni_resolver = CustomSniResolver::with_resolvers(sni_resolver_list);
540                      if let Some(sni_hostname) = &sni_hostname {
541                        sni_resolver.load_host_resolver(sni_hostname, resolver);
542                      } else {
543                        sni_resolver.load_fallback_resolver(resolver);
544                      }
545                      tls_ports.insert(https_port, sni_resolver);
546                    }
547                    used_sni_hostnames.insert((https_port, sni_hostname.clone()));
548                  }
549                }
550              }
551            }
552          }
553        } else if let Some(http_port) = server_configuration.filters.port.or(default_http_port) {
554          nonencrypted_ports.insert(http_port);
555        }
556        if let Some(automatic_tls_port) = automatic_tls_port {
557          if !is_auto_tls_sni_hostname_used {
558            if sni_hostname.is_some() || on_demand_tls {
559              let is_wildcard_domain = sni_hostname.as_ref().is_some_and(|s| s.starts_with("*."));
560              let challenge_type_entry = get_entry!("auto_tls_challenge", server_configuration);
561              let challenge_type_str = challenge_type_entry
562                .and_then(|e| e.values.first())
563                .and_then(|v| v.as_str())
564                .unwrap_or("tls-alpn-01");
565              let challenge_params = challenge_type_entry
566                .and_then(|e| {
567                  let mut props_str = HashMap::new();
568                  for (prop_name, prop_value) in e.props.iter() {
569                    if let Some(prop_value) = prop_value.as_str() {
570                      props_str.insert(prop_name.to_string(), prop_value.to_string());
571                    }
572                  }
573                  if props_str.is_empty() {
574                    None
575                  } else {
576                    Some(props_str)
577                  }
578                })
579                .unwrap_or(HashMap::new());
580              if let Some(sni_hostname) = &sni_hostname {
581                if sni_hostname.parse::<IpAddr>().is_ok() {
582                  for logging_tx in global_configuration
583                    .as_ref()
584                    .map_or(&vec![], |c| &c.observability.log_channels)
585                  {
586                    logging_tx
587                    .send_blocking(LogMessage::new(
588                      format!("Ferron's automatic TLS functionality doesn't support IP address-based identifiers, skipping SNI host \"{sni_hostname}\"..."),
589                      true,
590                    ))
591                    .unwrap_or_default();
592                  }
593                  continue;
594                }
595              }
596              let challenge_type = match &*challenge_type_str.to_uppercase() {
597                "HTTP-01" => {
598                  if let Some(sni_hostname) = &sni_hostname {
599                    if is_wildcard_domain && !on_demand_tls {
600                      for logging_tx in global_configuration
601                        .as_ref()
602                        .map_or(&vec![], |c| &c.observability.log_channels)
603                      {
604                        logging_tx
605                        .send_blocking(LogMessage::new(
606                          format!("HTTP-01 ACME challenge doesn't support wildcard hostnames, skipping SNI host \"{sni_hostname}\"..."),
607                          true,
608                        ))
609                        .unwrap_or_default();
610                      }
611                      continue;
612                    }
613                  }
614                  ChallengeType::Http01
615                }
616                "TLS-ALPN-01" => {
617                  if let Some(sni_hostname) = &sni_hostname {
618                    if is_wildcard_domain && !on_demand_tls {
619                      for logging_tx in global_configuration
620                        .as_ref()
621                        .map_or(&vec![], |c| &c.observability.log_channels)
622                      {
623                        logging_tx
624                        .send_blocking(LogMessage::new(
625                          format!("TLS-ALPN-01 ACME challenge doesn't support wildcard hostnames, skipping SNI host \"{sni_hostname}\"..."),
626                          true,
627                        ))
628                        .unwrap_or_default();
629                      }
630                      continue;
631                    }
632                  }
633                  ChallengeType::TlsAlpn01
634                }
635                "DNS-01" => ChallengeType::Dns01,
636                unsupported => Err(anyhow::anyhow!("Unsupported ACME challenge type: {}", unsupported))?,
637              };
638              let dns_provider: Option<Arc<dyn ferron_common::dns::DnsProvider + Send + Sync>> =
639                if &*challenge_type_str.to_uppercase() != "DNS-01" {
640                  None
641                } else if let Some(provider_name) = challenge_params.get("provider") {
642                  Some(get_dns_provider(provider_name, &challenge_params)?)
643                } else {
644                  Err(anyhow::anyhow!("No DNS provider specified"))?
645                };
646              let acme_cache_path_option = get_value!("auto_tls_cache", server_configuration)
647                .map_or(acme_default_directory.as_deref(), |v| {
648                  if v.is_null() {
649                    None
650                  } else if let Some(v) = v.as_str() {
651                    Some(v)
652                  } else {
653                    acme_default_directory.as_deref()
654                  }
655                })
656                .map(|path| path.to_owned());
657              let rustls_client_config = (if get_value!("auto_tls_no_verification", server_configuration)
658                .and_then(|v| v.as_bool())
659                .unwrap_or(false)
660              {
661                ClientConfig::builder_with_provider(crypto_provider.clone())
662                  .with_safe_default_protocol_versions()?
663                  .dangerous()
664                  .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new()))
665              } else if let Ok(client_config) = BuilderVerifierExt::with_platform_verifier(
666                ClientConfig::builder_with_provider(crypto_provider.clone()).with_safe_default_protocol_versions()?,
667              ) {
668                client_config
669              } else {
670                ClientConfig::builder_with_provider(crypto_provider.clone())
671                  .with_safe_default_protocol_versions()?
672                  .with_webpki_verifier(
673                    WebPkiServerVerifier::builder(Arc::new(rustls::RootCertStore {
674                      roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
675                    }))
676                    .build()?,
677                  )
678              })
679              .with_no_client_auth();
680              if on_demand_tls {
681                if &*challenge_type_str.to_uppercase() == "TLS-ALPN-01" {
682                  // Add TLS-ALPN-01 resolver
683                  let sni_resolver_list = Arc::new(tokio::sync::RwLock::new(Vec::new()));
684                  acme_tls_alpn_01_resolver_locks.insert(automatic_tls_port, sni_resolver_list.clone());
685                  let sni_resolver = TlsAlpn01Resolver::with_resolvers(sni_resolver_list);
686                  acme_tls_alpn_01_resolvers.insert(automatic_tls_port, sni_resolver);
687                }
688
689                if let Some(sni_resolver) = tls_ports.get_mut(&automatic_tls_port) {
690                  sni_resolver.load_fallback_sender(acme_on_demand_tx.clone(), automatic_tls_port);
691                } else {
692                  let sni_resolver_list = Arc::new(tokio::sync::RwLock::new(Vec::new()));
693                  tls_port_locks.insert(automatic_tls_port, sni_resolver_list.clone());
694                  let mut sni_resolver = CustomSniResolver::with_resolvers(sni_resolver_list);
695                  sni_resolver.load_fallback_sender(acme_on_demand_tx.clone(), automatic_tls_port);
696                  tls_ports.insert(automatic_tls_port, sni_resolver);
697                }
698
699                let acme_on_demand_config = AcmeOnDemandConfig {
700                  rustls_client_config,
701                  challenge_type,
702                  contact: if let Some(contact) =
703                    get_value!("auto_tls_contact", server_configuration).and_then(|v| v.as_str())
704                  {
705                    vec![format!("mailto:{}", contact.to_string())]
706                  } else {
707                    vec![]
708                  },
709                  directory: if let Some(directory) =
710                    get_value!("auto_tls_directory", server_configuration).and_then(|v| v.as_str())
711                  {
712                    directory.to_string()
713                  } else if get_value!("auto_tls_letsencrypt_production", server_configuration)
714                    .and_then(|v| v.as_bool())
715                    .unwrap_or(true)
716                  {
717                    LetsEncrypt::Production.url().to_string()
718                  } else {
719                    LetsEncrypt::Staging.url().to_string()
720                  },
721                  eab_key: if let Some(eab_key_entry) = get_entry!("auto_tls_eab", server_configuration) {
722                    if let Some(eab_key_id) = eab_key_entry.values.first().and_then(|v| v.as_str()) {
723                      if let Some(eab_key_hmac) = eab_key_entry.values.get(1).and_then(|v| v.as_str()) {
724                        match base64::engine::general_purpose::URL_SAFE_NO_PAD
725                          .decode(eab_key_hmac.trim_end_matches('='))
726                        {
727                          Ok(decoded_key) => {
728                            Some(Arc::new(ExternalAccountKey::new(eab_key_id.to_string(), &decoded_key)))
729                          }
730                          Err(err) => Err(anyhow::anyhow!("Failed to decode EAB key HMAC: {}", err))?,
731                        }
732                      } else {
733                        None
734                      }
735                    } else {
736                      None
737                    }
738                  } else {
739                    None
740                  },
741                  profile: get_value!("auto_tls_profile", server_configuration)
742                    .and_then(|v| v.as_str().map(|s| s.to_string())),
743                  cache_path: if let Some(acme_cache_path) = acme_cache_path_option.clone() {
744                    match PathBuf::from_str(&acme_cache_path) {
745                      Ok(pathbuf) => Some(pathbuf),
746                      Err(_) => Err(anyhow::anyhow!("Invalid ACME cache path"))?,
747                    }
748                  } else {
749                    None
750                  },
751                  sni_resolver_lock: tls_port_locks
752                    .get(&automatic_tls_port)
753                    .cloned()
754                    .unwrap_or(Arc::new(tokio::sync::RwLock::new(Vec::new()))),
755                  tls_alpn_01_resolver_lock: acme_tls_alpn_01_resolver_locks
756                    .get(&automatic_tls_port)
757                    .cloned()
758                    .unwrap_or(Arc::new(tokio::sync::RwLock::new(Vec::new()))),
759                  http_01_resolver_lock: acme_http_01_resolvers.clone(),
760                  dns_provider,
761                  sni_hostname: sni_hostname.clone(),
762                  port: automatic_tls_port,
763                };
764                acme_on_demand_configs.push(acme_on_demand_config);
765                automatic_tls_used_sni_hostnames.insert((automatic_tls_port, sni_hostname));
766              } else if let Some(sni_hostname) = sni_hostname {
767                let (account_cache_path, cert_cache_path) =
768                  if let Some(acme_cache_path) = acme_cache_path_option.clone() {
769                    let mut pathbuf = match PathBuf::from_str(&acme_cache_path) {
770                      Ok(pathbuf) => pathbuf,
771                      Err(_) => Err(anyhow::anyhow!("Invalid ACME cache path"))?,
772                    };
773                    let base_pathbuf = pathbuf.clone();
774                    let append_hash = base64::engine::general_purpose::URL_SAFE_NO_PAD
775                      .encode(xxh3_128(format!("{automatic_tls_port}-{sni_hostname}").as_bytes()).to_be_bytes());
776                    pathbuf.push(append_hash);
777                    (Some(base_pathbuf), Some(pathbuf))
778                  } else {
779                    (None, None)
780                  };
781                let certified_key_lock = Arc::new(tokio::sync::RwLock::new(None));
782                let tls_alpn_01_data_lock = Arc::new(tokio::sync::RwLock::new(None));
783                let http_01_data_lock = Arc::new(tokio::sync::RwLock::new(None));
784                let acme_config = AcmeConfig {
785                  rustls_client_config,
786                  domains: vec![sni_hostname.clone()],
787                  challenge_type,
788                  contact: if let Some(contact) =
789                    get_value!("auto_tls_contact", server_configuration).and_then(|v| v.as_str())
790                  {
791                    vec![format!("mailto:{}", contact.to_string())]
792                  } else {
793                    vec![]
794                  },
795                  directory: if let Some(directory) =
796                    get_value!("auto_tls_directory", server_configuration).and_then(|v| v.as_str())
797                  {
798                    directory.to_string()
799                  } else if get_value!("auto_tls_letsencrypt_production", server_configuration)
800                    .and_then(|v| v.as_bool())
801                    .unwrap_or(true)
802                  {
803                    LetsEncrypt::Production.url().to_string()
804                  } else {
805                    LetsEncrypt::Staging.url().to_string()
806                  },
807                  eab_key: if let Some(eab_key_entry) = get_entry!("auto_tls_eab", server_configuration) {
808                    if let Some(eab_key_id) = eab_key_entry.values.first().and_then(|v| v.as_str()) {
809                      if let Some(eab_key_hmac) = eab_key_entry.values.get(1).and_then(|v| v.as_str()) {
810                        match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(eab_key_hmac) {
811                          Ok(decoded_key) => {
812                            Some(Arc::new(ExternalAccountKey::new(eab_key_id.to_string(), &decoded_key)))
813                          }
814                          Err(err) => Err(anyhow::anyhow!("Failed to decode EAB key HMAC: {}", err))?,
815                        }
816                      } else {
817                        None
818                      }
819                    } else {
820                      None
821                    }
822                  } else {
823                    None
824                  },
825                  profile: get_value!("auto_tls_profile", server_configuration)
826                    .and_then(|v| v.as_str().map(|s| s.to_string())),
827                  account_cache: if let Some(account_cache_path) = account_cache_path {
828                    AcmeCache::File(account_cache_path)
829                  } else {
830                    AcmeCache::Memory(memory_acme_account_cache_data.clone())
831                  },
832                  certificate_cache: if let Some(cert_cache_path) = cert_cache_path {
833                    AcmeCache::File(cert_cache_path)
834                  } else {
835                    AcmeCache::Memory(Arc::new(tokio::sync::RwLock::new(HashMap::new())))
836                  },
837                  certified_key_lock: certified_key_lock.clone(),
838                  tls_alpn_01_data_lock: tls_alpn_01_data_lock.clone(),
839                  http_01_data_lock: http_01_data_lock.clone(),
840                  dns_provider,
841                  renewal_info: None,
842                  account: None,
843                };
844                let acme_resolver = Arc::new(AcmeResolver::new(certified_key_lock));
845                acme_configs.push(acme_config);
846                match &*challenge_type_str.to_uppercase() {
847                  "HTTP-01" => {
848                    acme_http_01_resolvers.blocking_write().push(http_01_data_lock);
849                  }
850                  "TLS-ALPN-01" => {
851                    if let Some(sni_resolver) = acme_tls_alpn_01_resolvers.get_mut(&automatic_tls_port) {
852                      sni_resolver.load_resolver(tls_alpn_01_data_lock);
853                    } else {
854                      let sni_resolver_list = Arc::new(tokio::sync::RwLock::new(Vec::new()));
855                      acme_tls_alpn_01_resolver_locks.insert(automatic_tls_port, sni_resolver_list.clone());
856                      let sni_resolver = TlsAlpn01Resolver::with_resolvers(sni_resolver_list);
857                      sni_resolver.load_resolver(tls_alpn_01_data_lock);
858                      acme_tls_alpn_01_resolvers.insert(automatic_tls_port, sni_resolver);
859                    }
860                  }
861                  _ => (),
862                }
863                if let Some(sni_resolver) = tls_ports.get_mut(&automatic_tls_port) {
864                  sni_resolver.load_host_resolver(&sni_hostname, acme_resolver);
865                } else {
866                  let sni_resolver_list = Arc::new(tokio::sync::RwLock::new(Vec::new()));
867                  tls_port_locks.insert(automatic_tls_port, sni_resolver_list.clone());
868                  let mut sni_resolver = CustomSniResolver::with_resolvers(sni_resolver_list);
869                  sni_resolver.load_host_resolver(&sni_hostname, acme_resolver);
870                  tls_ports.insert(automatic_tls_port, sni_resolver);
871                }
872                automatic_tls_used_sni_hostnames.insert((automatic_tls_port, Some(sni_hostname)));
873              }
874            } else if !server_configuration.filters.is_global() && !server_configuration.filters.is_global_non_host() {
875              for logging_tx in global_configuration
876                .as_ref()
877                .map_or(&vec![], |c| &c.observability.log_channels)
878              {
879                logging_tx
880                  .send_blocking(LogMessage::new(
881                    "Skipping automatic TLS for a host without a SNI hostname...".to_string(),
882                    true,
883                  ))
884                  .unwrap_or_default();
885              }
886            }
887          }
888        }
889      }
890
891      // Shut down request handler threads and secondary runtime
892      if let Some((handler_shutdown_channels, secondary_runtime)) = old_runtime_ref.take() {
893        for shutdown in handler_shutdown_channels {
894          shutdown.cancel();
895        }
896        drop(secondary_runtime);
897      }
898
899      if !acme_configs.is_empty() || !acme_on_demand_configs.is_empty() {
900        // Spawn a task to handle ACME certificate provisioning, one certificate at a time
901
902        let global_configuration_clone = global_configuration.clone();
903        secondary_runtime_ref.spawn(async move {
904          for acme_config in &mut acme_configs {
905            // Install the certificates from the cache if they're valid
906            check_certificate_validity_or_install_cached(acme_config, None)
907              .await
908              .unwrap_or_default();
909          }
910
911          let mut existing_combinations = HashSet::new();
912          for acme_on_demand_config in &mut acme_on_demand_configs {
913            for cached_domain in get_cached_domains(acme_on_demand_config).await {
914              let mut acme_config = convert_on_demand_config(
915                acme_on_demand_config,
916                cached_domain.clone(),
917                memory_acme_account_cache_data.clone(),
918              )
919              .await;
920
921              existing_combinations.insert((cached_domain, acme_on_demand_config.port));
922
923              // Install the certificates from the cache if they're valid
924              check_certificate_validity_or_install_cached(&mut acme_config, None)
925                .await
926                .unwrap_or_default();
927
928              acme_configs.push(acme_config);
929            }
930          }
931
932          // Wrap ACME configurations in a mutex
933          let acme_configs_mutex = Arc::new(tokio::sync::Mutex::new(acme_configs));
934
935          let prevent_file_race_conditions_mutex = Arc::new(tokio::sync::Mutex::new(()));
936
937          if !acme_on_demand_configs.is_empty() {
938            // On-demand TLS
939            let acme_configs_mutex = acme_configs_mutex.clone();
940            let acme_on_demand_configs = Arc::new(acme_on_demand_configs);
941            let global_configuration_clone = global_configuration_clone.clone();
942            tokio::spawn(async move {
943              let mut existing_combinations = existing_combinations;
944              while let Ok(received_data) = acme_on_demand_rx.recv().await {
945                let on_demand_tls_ask_endpoint = on_demand_tls_ask_endpoint.clone();
946                if let Some(on_demand_tls_ask_endpoint) = on_demand_tls_ask_endpoint {
947                  let mut url_parts = on_demand_tls_ask_endpoint.into_parts();
948                  if let Some(path_and_query) = url_parts.path_and_query {
949                    let query = path_and_query.query();
950                    let query = if let Some(query) = query {
951                      format!("{}&domain={}", query, urlencoding::encode(&received_data.0))
952                    } else {
953                      format!("domain={}", urlencoding::encode(&received_data.0))
954                    };
955                    url_parts.path_and_query = Some(match format!("{}?{}", path_and_query.path(), query).parse() {
956                      Ok(parsed) => parsed,
957                      Err(err) => {
958                        for acme_logger in global_configuration_clone
959                          .as_ref()
960                          .map_or(&vec![], |c| &c.observability.log_channels)
961                        {
962                          acme_logger
963                            .send(LogMessage::new(
964                              format!("Error while formatting the URL for on-demand TLS request: {err}"),
965                              true,
966                            ))
967                            .await
968                            .unwrap_or_default();
969                        }
970                        continue;
971                      }
972                    });
973                  } else {
974                    url_parts.path_and_query = Some(
975                      match format!("/?domain={}", urlencoding::encode(&received_data.0)).parse() {
976                        Ok(parsed) => parsed,
977                        Err(err) => {
978                          for acme_logger in global_configuration_clone
979                            .as_ref()
980                            .map_or(&vec![], |c| &c.observability.log_channels)
981                          {
982                            acme_logger
983                              .send(LogMessage::new(
984                                format!("Error while formatting the URL for on-demand TLS request: {err}"),
985                                true,
986                              ))
987                              .await
988                              .unwrap_or_default();
989                          }
990                          continue;
991                        }
992                      },
993                    );
994                  }
995                  let endpoint_url = match hyper::Uri::from_parts(url_parts) {
996                    Ok(parsed) => parsed,
997                    Err(err) => {
998                      for acme_logger in global_configuration_clone
999                        .as_ref()
1000                        .map_or(&vec![], |c| &c.observability.log_channels)
1001                      {
1002                        acme_logger
1003                          .send(LogMessage::new(
1004                            format!("Error while formatting the URL for on-demand TLS request: {err}"),
1005                            true,
1006                          ))
1007                          .await
1008                          .unwrap_or_default();
1009                      }
1010                      continue;
1011                    }
1012                  };
1013                  let ask_closure = async {
1014                    let client = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
1015                      .build::<_, http_body_util::Empty<hyper::body::Bytes>>(
1016                      hyper_rustls::HttpsConnectorBuilder::new()
1017                        .with_tls_config(
1018                          (if !on_demand_tls_ask_endpoint_verify {
1019                            ClientConfig::builder_with_provider(crypto_provider.clone())
1020                              .with_safe_default_protocol_versions()?
1021                              .dangerous()
1022                              .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new()))
1023                          } else if let Ok(client_config) = BuilderVerifierExt::with_platform_verifier(
1024                            ClientConfig::builder_with_provider(crypto_provider.clone())
1025                              .with_safe_default_protocol_versions()?,
1026                          ) {
1027                            client_config
1028                          } else {
1029                            ClientConfig::builder_with_provider(crypto_provider.clone())
1030                              .with_safe_default_protocol_versions()?
1031                              .with_webpki_verifier(
1032                                WebPkiServerVerifier::builder(Arc::new(rustls::RootCertStore {
1033                                  roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
1034                                }))
1035                                .build()?,
1036                              )
1037                          })
1038                          .with_no_client_auth(),
1039                        )
1040                        .https_or_http()
1041                        .enable_http1()
1042                        .enable_http2()
1043                        .build(),
1044                    );
1045                    let request = hyper::Request::builder()
1046                      .method(hyper::Method::GET)
1047                      .uri(endpoint_url)
1048                      .body(http_body_util::Empty::<hyper::body::Bytes>::new())?;
1049                    let response = client.request(request).await?;
1050
1051                    Ok::<_, Box<dyn Error + Send + Sync>>(response.status().is_success())
1052                  };
1053                  match ask_closure.await {
1054                    Ok(true) => (),
1055                    Ok(false) => {
1056                      for acme_logger in global_configuration_clone
1057                        .as_ref()
1058                        .map_or(&vec![], |c| &c.observability.log_channels)
1059                      {
1060                        acme_logger
1061                          .send(LogMessage::new(
1062                            format!(
1063                              "The TLS certificate cannot be issued for \"{}\" hostname",
1064                              &received_data.0
1065                            ),
1066                            true,
1067                          ))
1068                          .await
1069                          .unwrap_or_default();
1070                      }
1071                      continue;
1072                    }
1073                    Err(err) => {
1074                      for acme_logger in global_configuration_clone
1075                        .as_ref()
1076                        .map_or(&vec![], |c| &c.observability.log_channels)
1077                      {
1078                        acme_logger
1079                          .send(LogMessage::new(
1080                            format!(
1081                              "Error while determining if the TLS certificate can be issued for \"{}\" hostname: {err}",
1082                              &received_data.0
1083                            ),
1084                            true,
1085                          ))
1086                          .await
1087                          .unwrap_or_default();
1088                      }
1089                      continue;
1090                    }
1091                  }
1092                }
1093                if existing_combinations.contains(&received_data) {
1094                  continue;
1095                } else {
1096                  existing_combinations.insert(received_data.clone());
1097                }
1098                let (sni_hostname, port) = received_data;
1099                let acme_configs_mutex = acme_configs_mutex.clone();
1100                let acme_on_demand_configs = acme_on_demand_configs.clone();
1101                let memory_acme_account_cache_data = memory_acme_account_cache_data.clone();
1102                let prevent_file_race_conditions_mutex = prevent_file_race_conditions_mutex.clone();
1103                tokio::spawn(async move {
1104                  for acme_on_demand_config in acme_on_demand_configs.iter() {
1105                    if match_hostname(acme_on_demand_config.sni_hostname.as_deref(), Some(&sni_hostname))
1106                      && acme_on_demand_config.port == port
1107                    {
1108                      let mutex_guard = prevent_file_race_conditions_mutex.lock().await;
1109                      add_domain_to_cache(acme_on_demand_config, &sni_hostname)
1110                        .await
1111                        .unwrap_or_default();
1112                      drop(mutex_guard);
1113
1114                      acme_configs_mutex.lock().await.push(
1115                        convert_on_demand_config(
1116                          acme_on_demand_config,
1117                          sni_hostname.clone(),
1118                          memory_acme_account_cache_data,
1119                        )
1120                        .await,
1121                      );
1122                      break;
1123                    }
1124                  }
1125                });
1126              }
1127            });
1128          }
1129
1130          loop {
1131            for acme_config in &mut *acme_configs_mutex.lock().await {
1132              if let Err(acme_error) = provision_certificate(acme_config).await {
1133                for acme_logger in global_configuration_clone
1134                  .as_ref()
1135                  .map_or(&vec![], |c| &c.observability.log_channels)
1136                {
1137                  acme_logger
1138                    .send(LogMessage::new(
1139                      format!("Error while obtaining a TLS certificate: {acme_error}"),
1140                      true,
1141                    ))
1142                    .await
1143                    .unwrap_or_default();
1144                }
1145              }
1146            }
1147            tokio::time::sleep(Duration::from_secs(10)).await;
1148          }
1149        });
1150      }
1151
1152      // If HTTP/1.1 isn't enabled, don't listen to non-encrypted ports
1153      if !protocols.contains(&"h1") {
1154        nonencrypted_ports.clear();
1155      }
1156
1157      for tls_port in tls_ports.keys() {
1158        if nonencrypted_ports.contains(tls_port) {
1159          nonencrypted_ports.remove(tls_port);
1160        }
1161      }
1162
1163      // Create TLS server configurations
1164      let mut quic_tls_configs = HashMap::new();
1165      let mut tls_configs = HashMap::new();
1166      let mut acme_tls_alpn_01_configs = HashMap::new();
1167      for (tls_port, sni_resolver) in tls_ports.into_iter() {
1168        let enable_ocsp_stapling = global_configuration
1169          .as_ref()
1170          .and_then(|c| get_value!("ocsp_stapling", c))
1171          .and_then(|v| v.as_bool())
1172          .unwrap_or(true);
1173        let resolver: Arc<dyn ResolvesServerCert> = if enable_ocsp_stapling {
1174          // The `ocsp_stapler` crate is dependent on Tokio, so we create a stapler in the Tokio runtime...
1175          // If this wasn't wrapped in a Tokio runtime, creation of an OCSP stapler would just cause a panic.
1176          let stapler =
1177            secondary_runtime_ref.block_on(async move { ocsp_stapler::Stapler::new(Arc::new(sni_resolver)) });
1178          if let Some(certified_keys_to_preload) = certified_keys_to_preload.get(&tls_port) {
1179            for certified_key in certified_keys_to_preload {
1180              stapler.preload(certified_key.clone());
1181            }
1182          }
1183          Arc::new(stapler)
1184        } else {
1185          Arc::new(sni_resolver)
1186        };
1187        let mut tls_config = tls_config_builder_wants_server_cert
1188          .clone()
1189          .with_cert_resolver(resolver);
1190        if protocols.contains(&"h3") {
1191          // TLS configuration used for the QUIC listener
1192          let mut quic_tls_config = tls_config.clone();
1193          quic_tls_config.max_early_data_size = u32::MAX;
1194          quic_tls_config.alpn_protocols.insert(0, b"h3-29".to_vec());
1195          quic_tls_config.alpn_protocols.insert(0, b"h3".to_vec());
1196          quic_tls_configs.insert(tls_port, Arc::new(quic_tls_config));
1197        }
1198        if protocols.contains(&"h1") {
1199          tls_config.alpn_protocols.insert(0, b"http/1.0".to_vec());
1200          tls_config.alpn_protocols.insert(0, b"http/1.1".to_vec());
1201        }
1202        if protocols.contains(&"h2") {
1203          tls_config.alpn_protocols.insert(0, b"h2".to_vec());
1204        }
1205        tls_configs.insert(tls_port, Arc::new(tls_config));
1206      }
1207      for (tls_port, sni_resolver) in acme_tls_alpn_01_resolvers.into_iter() {
1208        let mut tls_config = tls_config_builder_wants_server_cert
1209          .clone()
1210          .with_cert_resolver(Arc::new(sni_resolver));
1211        tls_config.alpn_protocols = vec![ACME_TLS_ALPN_NAME.to_vec()];
1212        acme_tls_alpn_01_configs.insert(tls_port, Arc::new(tls_config));
1213      }
1214
1215      // Process metrics initialization
1216      #[cfg(any(target_os = "linux", target_os = "android"))]
1217      if let Some(metrics_channels) = global_configuration
1218        .as_ref()
1219        .map(|c| &c.observability.metric_channels)
1220        .cloned()
1221      {
1222        secondary_runtime_ref.spawn(async move {
1223          use ferron_common::observability::{Metric, MetricAttributeValue, MetricType, MetricValue};
1224
1225          let mut previous_instant = std::time::Instant::now();
1226          let mut previous_cpu_user_time = 0.0;
1227          let mut previous_cpu_system_time = 0.0;
1228          let mut previous_rss = 0;
1229          let mut previous_vms = 0;
1230          loop {
1231            // Sleep for 1 second
1232            tokio::time::sleep(Duration::from_secs(1)).await;
1233
1234            if let Ok(stat) = procfs::process::Process::myself().and_then(|p| p.stat()) {
1235              let cpu_user_time = stat.utime as f64 / procfs::ticks_per_second() as f64;
1236              let cpu_system_time = stat.stime as f64 / procfs::ticks_per_second() as f64;
1237              let cpu_user_time_increase = cpu_user_time - previous_cpu_user_time;
1238              let cpu_system_time_increase = cpu_system_time - previous_cpu_system_time;
1239              previous_cpu_user_time = cpu_user_time;
1240              previous_cpu_system_time = cpu_system_time;
1241
1242              let rss = stat.rss * procfs::page_size();
1243              let rss_diff = rss as i64 - previous_rss as i64;
1244              let vms_diff = stat.vsize as i64 - previous_vms as i64;
1245              previous_rss = rss;
1246              previous_vms = stat.vsize;
1247
1248              let elapsed = previous_instant.elapsed().as_secs_f64();
1249              previous_instant = std::time::Instant::now();
1250
1251              let cpu_user_utilization = cpu_user_time_increase / (elapsed * available_parallelism as f64);
1252              let cpu_system_utilization = cpu_system_time_increase / (elapsed * available_parallelism as f64);
1253
1254              for metrics_sender in &metrics_channels {
1255                metrics_sender
1256                  .send(Metric::new(
1257                    "process.cpu.time",
1258                    vec![("cpu.mode", MetricAttributeValue::String("user".to_string()))],
1259                    MetricType::Counter,
1260                    MetricValue::F64(cpu_user_time_increase),
1261                    Some("s"),
1262                    Some("Total CPU seconds broken down by different states."),
1263                  ))
1264                  .await
1265                  .unwrap_or_default();
1266
1267                metrics_sender
1268                  .send(Metric::new(
1269                    "process.cpu.time",
1270                    vec![("cpu.mode", MetricAttributeValue::String("system".to_string()))],
1271                    MetricType::Counter,
1272                    MetricValue::F64(cpu_system_time_increase),
1273                    Some("s"),
1274                    Some("Total CPU seconds broken down by different states."),
1275                  ))
1276                  .await
1277                  .unwrap_or_default();
1278
1279                metrics_sender
1280                  .send(Metric::new(
1281                    "process.cpu.utilization",
1282                    vec![("cpu.mode", MetricAttributeValue::String("user".to_string()))],
1283                    MetricType::Gauge,
1284                    MetricValue::F64(cpu_user_utilization),
1285                    Some("1"),
1286                    Some("Difference in process.cpu.time since the last measurement, divided by the elapsed time and number of CPUs available to the process."),
1287                  ))
1288                  .await
1289                  .unwrap_or_default();
1290
1291                metrics_sender
1292                  .send(Metric::new(
1293                    "process.cpu.utilization",
1294                    vec![("cpu.mode", MetricAttributeValue::String("system".to_string()))],
1295                    MetricType::Gauge,
1296                    MetricValue::F64(cpu_system_utilization),
1297                    Some("1"),
1298                    Some("Difference in process.cpu.time since the last measurement, divided by the elapsed time and number of CPUs available to the process."),
1299                  ))
1300                  .await
1301                  .unwrap_or_default();
1302
1303                metrics_sender
1304                  .send(Metric::new(
1305                    "process.memory.usage",
1306                    vec![],
1307                    MetricType::UpDownCounter,
1308                    MetricValue::I64(rss_diff),
1309                    Some("By"),
1310                    Some("The amount of physical memory in use."),
1311                  ))
1312                  .await
1313                  .unwrap_or_default();
1314
1315                metrics_sender
1316                  .send(Metric::new(
1317                    "process.memory.virtual",
1318                    vec![],
1319                    MetricType::UpDownCounter,
1320                    MetricValue::I64(vms_diff),
1321                    Some("By"),
1322                    Some("The amount of committed virtual memory."),
1323                  ))
1324                  .await
1325                  .unwrap_or_default();
1326              }
1327            }
1328          }
1329        });
1330      }
1331
1332      let (listener_handler_tx, listener_handler_rx) = &**LISTENER_HANDLER_CHANNEL;
1333      let mut tcp_listeners = TCP_LISTENERS
1334        .lock()
1335        .map_err(|_| anyhow::anyhow!("Can't access the TCP listeners"))?;
1336      let mut quic_listeners = QUIC_LISTENERS
1337        .lock()
1338        .map_err(|_| anyhow::anyhow!("Can't access the QUIC listeners"))?;
1339      let mut listened_socket_addresses = Vec::new();
1340      let mut quic_listened_socket_addresses = Vec::new();
1341      let listen_ip_addr = match global_configuration
1342        .as_deref()
1343        .and_then(|c| get_value!("listen_ip", c))
1344        .and_then(|v| v.as_str())
1345        .map_or(Ok(IpAddr::V6(Ipv6Addr::UNSPECIFIED)), |a| a.parse())
1346      {
1347        Ok(addr) => addr,
1348        Err(_) => Err(anyhow::anyhow!("Invalid IP address to listen to"))?,
1349      };
1350      for (tcp_port, encrypted) in nonencrypted_ports
1351        .iter()
1352        .map(|p| (*p, false))
1353        .chain(tls_configs.keys().map(|p| (*p, true)))
1354      {
1355        let socket_address = SocketAddr::new(listen_ip_addr, tcp_port);
1356        listened_socket_addresses.push((socket_address, encrypted));
1357      }
1358      for (quic_port, quic_tls_config) in quic_tls_configs.into_iter() {
1359        let socket_address = SocketAddr::new(listen_ip_addr, quic_port);
1360        quic_listened_socket_addresses.push((socket_address, quic_tls_config));
1361      }
1362
1363      let enable_uring = global_configuration
1364        .as_deref()
1365        .and_then(|c| get_value!("io_uring", c))
1366        .and_then(|v| v.as_bool())
1367        .unwrap_or(true);
1368      let mut uring_enabled_locked = URING_ENABLED
1369        .lock()
1370        .map_err(|_| anyhow::anyhow!("Can't access the enabled `io_uring` option"))?;
1371      let mut tcp_listener_socketaddrs_to_remove = Vec::new();
1372      let mut quic_listener_socketaddrs_to_remove = Vec::new();
1373      for (key, value) in &*tcp_listeners {
1374        if enable_uring != *uring_enabled_locked
1375          || (!listened_socket_addresses.contains(&(*key, true)) && !listened_socket_addresses.contains(&(*key, false)))
1376        {
1377          // Shut down the TCP listener
1378          value.cancel();
1379
1380          // Push the the TCP listener address to remove
1381          tcp_listener_socketaddrs_to_remove.push(*key);
1382        }
1383      }
1384      for (key, value) in &*quic_listeners {
1385        let mut contains = false;
1386        for key2 in &quic_listened_socket_addresses {
1387          if key2.0 == *key {
1388            contains = true;
1389            break;
1390          }
1391        }
1392        if enable_uring != *uring_enabled_locked || !contains {
1393          // Shut down the QUIC listener
1394          value.0.cancel();
1395
1396          // Push the the QUIC listener address to remove
1397          quic_listener_socketaddrs_to_remove.push(*key);
1398        }
1399      }
1400      *uring_enabled_locked = enable_uring;
1401      drop(uring_enabled_locked);
1402
1403      for key_to_remove in tcp_listener_socketaddrs_to_remove {
1404        // Remove the TCP listener
1405        tcp_listeners.remove(&key_to_remove);
1406      }
1407
1408      for key_to_remove in quic_listener_socketaddrs_to_remove {
1409        // Remove the QUIC listener
1410        quic_listeners.remove(&key_to_remove);
1411      }
1412
1413      // Get a global logger for listeners
1414      let (global_logging_tx, global_logging_rx) = &**LISTENER_LOGGING_CHANNEL;
1415      let global_logger = if global_configuration
1416        .as_ref()
1417        .is_none_or(|c| c.observability.log_channels.is_empty())
1418      {
1419        None
1420      } else {
1421        let global_configuration_clone = global_configuration.clone();
1422        secondary_runtime_ref.spawn(async move {
1423          while let Ok(log_message) = global_logging_rx.recv().await {
1424            for logging_tx in global_configuration_clone
1425              .as_ref()
1426              .map_or(&vec![], |c| &c.observability.log_channels)
1427            {
1428              logging_tx.send(log_message.clone()).await.unwrap_or_default();
1429            }
1430          }
1431        });
1432        Some(global_logging_tx.clone())
1433      };
1434
1435      // Spawn request handler threads
1436      let mut handler_shutdown_channels = Vec::new();
1437      for _ in 0..available_parallelism {
1438        handler_shutdown_channels.push(create_http_handler(
1439          server_configurations.clone(),
1440          listener_handler_rx.clone(),
1441          enable_uring,
1442          tls_configs.clone(),
1443          !quic_listened_socket_addresses.is_empty(),
1444          acme_tls_alpn_01_configs.clone(),
1445          acme_http_01_resolvers.clone(),
1446          enable_proxy_protocol,
1447        )?);
1448      }
1449
1450      // Error out, if server is configured to listen to no port
1451      if listened_socket_addresses.is_empty() && quic_listened_socket_addresses.is_empty() {
1452        Err(anyhow::anyhow!("The server is configured to listen to no port"))?
1453      }
1454
1455      let tcp_send_buffer_size = global_configuration
1456        .as_deref()
1457        .and_then(|c| get_value!("tcp_send_buffer", c))
1458        .and_then(|v| v.as_i128())
1459        .map(|v| v as usize);
1460      let tcp_recv_buffer_size = global_configuration
1461        .as_deref()
1462        .and_then(|c| get_value!("tcp_recv_buffer", c))
1463        .and_then(|v| v.as_i128())
1464        .map(|v| v as usize);
1465      for (socket_address, encrypted) in listened_socket_addresses {
1466        if let std::collections::hash_map::Entry::Vacant(e) = tcp_listeners.entry(socket_address) {
1467          // Create a TCP listener
1468          e.insert(create_tcp_listener(
1469            socket_address,
1470            encrypted,
1471            listener_handler_tx.clone(),
1472            enable_uring,
1473            global_logger.clone(),
1474            first_startup,
1475            (tcp_send_buffer_size, tcp_recv_buffer_size),
1476          )?);
1477        }
1478      }
1479
1480      // Drop TCP listener mutex guard
1481      drop(tcp_listeners);
1482
1483      for (socket_address, tls_config) in quic_listened_socket_addresses {
1484        if let Some(quic_listener_entry) = quic_listeners.get(&socket_address) {
1485          // Replace the TLS configuration in the QUIC listener
1486          let (_, tls_quic_listener) = quic_listener_entry;
1487          tls_quic_listener.send_blocking(tls_config).unwrap_or_default();
1488        } else {
1489          // Create a QUIC listener
1490          quic_listeners.insert(
1491            socket_address,
1492            create_quic_listener(
1493              socket_address,
1494              tls_config,
1495              listener_handler_tx.clone(),
1496              enable_uring,
1497              global_logger.clone(),
1498              first_startup,
1499            )?,
1500          );
1501        }
1502      }
1503
1504      // Drop QUIC listener mutex guard
1505      drop(quic_listeners);
1506
1507      let shutdown_result = handle_shutdown_signals(secondary_runtime_ref);
1508
1509      Ok::<_, Box<dyn Error + Send + Sync>>((shutdown_result, handler_shutdown_channels))
1510    };
1511
1512    match execute_rest() {
1513      Ok((to_restart, handler_shutdown_channels)) => {
1514        if to_restart {
1515          old_runtime = Some((handler_shutdown_channels, secondary_runtime));
1516          first_startup = false;
1517          println!("Reloading the server configuration...");
1518        } else {
1519          for shutdown in handler_shutdown_channels {
1520            shutdown.cancel();
1521          }
1522          drop(secondary_runtime);
1523          break;
1524        }
1525      }
1526      Err(err) => {
1527        for logging_tx in global_configuration_clone
1528          .as_ref()
1529          .map_or(&vec![], |c| &c.observability.log_channels)
1530        {
1531          logging_tx
1532            .send_blocking(LogMessage::new(err.to_string(), true))
1533            .unwrap_or_default();
1534        }
1535        std::thread::sleep(Duration::from_millis(100));
1536        Err(err)?
1537      }
1538    }
1539  }
1540  Ok(())
1541}
1542
1543fn obtain_configuration_adapters() -> (
1544  HashMap<String, Box<dyn ConfigurationAdapter + Send + Sync>>,
1545  Vec<&'static str>,
1546) {
1547  // Configuration adapters
1548  let mut configuration_adapters: HashMap<String, Box<dyn ConfigurationAdapter + Send + Sync>> = HashMap::new();
1549  let mut all_adapters = Vec::new();
1550
1551  // Configuration adapter registration macro
1552  macro_rules! register_configuration_adapter {
1553    ($name:literal, $adapter:expr) => {
1554      configuration_adapters.insert($name.to_string(), Box::new($adapter));
1555      all_adapters.push($name);
1556    };
1557  }
1558
1559  // Register configuration adapters
1560  register_configuration_adapter!("kdl", config::adapters::kdl::KdlConfigurationAdapter::new());
1561  #[cfg(feature = "config-yaml-legacy")]
1562  register_configuration_adapter!(
1563    "yaml-legacy",
1564    config::adapters::yaml_legacy::YamlLegacyConfigurationAdapter::new()
1565  );
1566  #[cfg(feature = "config-docker-auto")]
1567  register_configuration_adapter!(
1568    "docker-auto",
1569    config::adapters::docker_auto::DockerAutoConfigurationAdapter::new()
1570  );
1571
1572  (configuration_adapters, all_adapters)
1573}
1574
/// Picks the configuration adapter used when `--config-adapter` is not given.
///
/// In builds with the `config-yaml-legacy` feature, a `path` whose extension is `yaml`
/// or `yml` (compared after lowercasing) selects `"yaml-legacy"`. Every other path, and
/// every path in builds without that feature, selects `"kdl"`.
fn determine_default_configuration_adapter(path: &Path) -> &'static str {
  // `cfg!` is evaluated at compile time, so builds without the feature always fall through
  if cfg!(feature = "config-yaml-legacy") {
    let extension = path.extension().and_then(|ext| ext.to_str()).map(str::to_lowercase);
    if matches!(extension.as_deref(), Some("yaml" | "yml")) {
      return "yaml-legacy";
    }
  }
  "kdl"
}
1594
1595/// Parses the command-line arguments
1596fn parse_arguments(all_adapters: Vec<&'static str>) -> ArgMatches {
1597  Command::new("Ferron")
1598    .about("A fast, memory-safe web server written in Rust")
1599    .arg(
1600      Arg::new("config")
1601        .long("config")
1602        .short('c')
1603        .help("The path to the server configuration file")
1604        .action(ArgAction::Set)
1605        .default_value("./ferron.kdl")
1606        .value_parser(PathBuf::from_str),
1607    )
1608    .arg(
1609      Arg::new("config-adapter")
1610        .long("config-adapter")
1611        .help("The configuration adapter to use")
1612        .action(ArgAction::Set)
1613        .required(false)
1614        .value_parser(all_adapters),
1615    )
1616    .arg(
1617      Arg::new("module-config")
1618        .long("module-config")
1619        .help("Prints the used compile-time module configuration (`ferron-build.yaml` or `ferron-build-override.yaml` in the Ferron source) and exits")
1620        .action(ArgAction::SetTrue)
1621    )
1622    .arg(
1623      Arg::new("version")
1624        .long("version")
1625        .short('V')
1626        .help("Print version and build information")
1627        .action(ArgAction::SetTrue)
1628    )
1629    .get_matches()
1630}
1631
1632/// The main entry point of the application
1633fn main() {
1634  // Set the panic handler
1635  setup_panic!(Metadata::new("Ferron", env!("CARGO_PKG_VERSION"))
1636    .homepage("https://ferron.sh")
1637    .support("- Send an email message to hello@ferron.sh"));
1638
1639  // Obtain the configuration adapters
1640  let (configuration_adapters, all_adapters) = obtain_configuration_adapters();
1641
1642  // Parse command-line arguments
1643  let args = parse_arguments(all_adapters);
1644
1645  if args.get_flag("module-config") {
1646    // Dump the used compile-time module configuration and exit
1647    println!("{}", ferron_load_modules::FERRON_BUILD_YAML);
1648    return;
1649  } else if args.get_flag("version") {
1650    // Print the server version and build information
1651    println!("Ferron {}", build::PKG_VERSION);
1652    println!("  Compiled on: {}", build::BUILD_TIME);
1653    println!("  Git commit: {}", build::COMMIT_HASH);
1654    println!("  Build target: {}", build::BUILD_TARGET);
1655    println!("  Rust version: {}", build::RUST_VERSION);
1656    println!("  Build host: {}", build::BUILD_OS);
1657    if shadow_rs::is_debug() {
1658      println!("WARNING: This is a debug build. It is not recommended for production use.");
1659    }
1660    return;
1661  }
1662
1663  // Start the server!
1664  match before_starting_server(args, configuration_adapters) {
1665    Ok(_) => (),
1666    Err(err) => {
1667      eprintln!("Error while running a server: {err}");
1668      std::process::exit(1);
1669    }
1670  };
1671}