ferron_common/http_proxy/
mod.rs

1mod builder;
2mod load_balancer;
3mod proxy_client;
4mod request_parts;
5mod send_net_io;
6mod send_request;
7
8use std::collections::HashMap;
9use std::error::Error;
10use std::net::IpAddr;
11use std::pin::Pin;
12use std::str::FromStr;
13use std::sync::atomic::AtomicUsize;
14use std::sync::Arc;
15use std::task::{Context, Poll, Waker};
16use std::time::Duration;
17
18use async_trait::async_trait;
19use bytes::Bytes;
20use connpool::{Item, Pool};
21use futures_util::FutureExt;
22use http_body_util::combinators::BoxBody;
23use hyper::header::{self, HeaderName};
24use hyper::{Request, StatusCode, Uri};
25#[cfg(feature = "runtime-monoio")]
26use monoio::net::TcpStream;
27#[cfg(all(feature = "runtime-monoio", unix))]
28use monoio::net::UnixStream;
29use rustls::client::WebPkiServerVerifier;
30use rustls_pki_types::ServerName;
31use rustls_platform_verifier::BuilderVerifierExt;
32use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
33#[cfg(feature = "runtime-tokio")]
34use tokio::net::TcpStream;
35#[cfg(all(feature = "runtime-tokio", unix))]
36use tokio::net::UnixStream;
37use tokio::sync::RwLock;
38use tokio_rustls::TlsConnector;
39#[cfg(feature = "runtime-vibeio")]
40use vibeio::net::TcpStream;
41#[cfg(all(feature = "runtime-vibeio", unix))]
42use vibeio::net::UnixStream;
43
44use crate::config::ServerConfiguration;
45use crate::http_proxy::send_request::SendRequestWrapper;
46use crate::logging::ErrorLogger;
47use crate::modules::{ModuleHandlers, ResponseData, SocketData};
48use crate::observability::{Metric, MetricAttributeValue, MetricType, MetricValue, MetricsMultiSender};
49use crate::util::{NoServerVerifier, TtlCache};
50
51pub use self::builder::ReverseProxyBuilder;
52#[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
53use self::send_net_io::{SendTcpStreamPoll, SendTcpStreamPollDropGuard};
54#[cfg(all(any(feature = "runtime-vibeio", feature = "runtime-monoio"), unix))]
55use self::send_net_io::{SendUnixStreamPoll, SendUnixStreamPollDropGuard};
56use self::{
57  load_balancer::{determine_proxy_to, resolve_upstreams},
58  proxy_client::{http_proxy, http_proxy_handshake},
59  request_parts::construct_proxy_request_parts,
60};
61
// Shared per-upstream connection tracking for connection-count-aware load
// balancing. NOTE(review): the `Arc<()>` appears to serve as a live-connection
// counter via its strong count (clones are handed to in-flight proxy calls) —
// confirm the counting logic in `load_balancer`.
type ConnectionsTrackState = Arc<RwLock<HashMap<UpstreamInner, Arc<()>>>>;
63
// Internal counterpart of `LoadBalancerAlgorithm`: carries the mutable
// bookkeeping state each strategy needs at request time.
enum LoadBalancerAlgorithmInner {
  Random,
  // Shared counter cycled across requests to pick the next backend.
  RoundRobin(Arc<AtomicUsize>),
  // Per-upstream connection tracking used to find the least-loaded backend.
  LeastConnections(ConnectionsTrackState),
  // Per-upstream connection tracking; two random candidates are compared.
  TwoRandomChoices(ConnectionsTrackState),
}
70
/// Backend selection strategy used when multiple upstreams are configured.
///
/// Derives `Hash`/`Eq`/`Copy` so the algorithm can be used as part of the
/// `Connections::load_balancer_cache` key.
#[derive(Clone, Copy, Hash, PartialEq, Eq)]
pub enum LoadBalancerAlgorithm {
  /// Selects a backend randomly for each request.
  Random,
  /// Cycles through backends in order.
  RoundRobin,
  /// Selects the backend with the least active tracked connections.
  LeastConnections,
  /// Chooses two random backends and picks the less loaded one.
  TwoRandomChoices,
}
83
/// Proxy protocol version to prepend to upstream connections.
///
/// When set, the original client address is forwarded to the backend before
/// any HTTP bytes are written.
#[derive(Clone, Copy)]
pub enum ProxyHeader {
  /// HAProxy PROXY protocol v1 (human-readable text header).
  V1,
  /// HAProxy PROXY protocol v2 (binary header).
  V2,
}
92
/// A single concrete backend target.
#[derive(Clone, Eq, PartialEq, Hash)]
struct UpstreamInner {
  // Backend URL; only `http://` and `https://` schemes are accepted by the
  // request handler.
  proxy_to: String,
  // Optional Unix domain socket path; when set, the connection is made over
  // the Unix socket instead of TCP (Unix platforms only).
  proxy_unix: Option<String>,
}
98
/// Data needed to resolve an upstream dynamically through DNS SRV records.
#[derive(Clone)]
struct SrvUpstreamData {
  // Original upstream URL whose host part is looked up as an SRV name.
  to: String,
  // Handle to a secondary Tokio runtime on which the DNS lookup is spawned.
  secondary_runtime_handle: tokio::runtime::Handle,
  // Shared Hickory DNS resolver used for the SRV lookups.
  dns_resolver: Arc<hickory_resolver::TokioResolver>,
}
105
// Equality and hashing for `SrvUpstreamData` are deliberately keyed on the
// target URL only: the runtime handle and resolver carry no identity. Keeping
// `Hash` over exactly the same field as `Eq` upholds the `HashMap` key
// contract (equal values hash equally).
impl PartialEq for SrvUpstreamData {
  fn eq(&self, other: &Self) -> bool {
    self.to == other.to
  }
}

impl Eq for SrvUpstreamData {}

impl std::hash::Hash for SrvUpstreamData {
  fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
    self.to.hash(state);
  }
}
119
/// An upstream definition: either a fixed target or one resolved at request
/// time through DNS SRV records.
#[derive(Clone, Eq, PartialEq, Hash)]
enum Upstream {
  Static(UpstreamInner),
  Srv(SrvUpstreamData),
}
125
impl Upstream {
  /// Resolves this upstream into zero or more concrete backend targets.
  ///
  /// `Static` upstreams resolve to themselves. `Srv` upstreams perform a DNS
  /// SRV lookup on the secondary runtime, drop backends marked unhealthy in
  /// `failed_backends` (more than `health_check_max_fails` failures), keep
  /// only the records with the highest priority (lowest numeric value), and
  /// pick at most one target by weighted random selection. Any parse or
  /// lookup failure yields an empty vector.
  async fn resolve(
    &self,
    failed_backends: Arc<RwLock<TtlCache<UpstreamInner, u64>>>,
    health_check_max_fails: u64,
  ) -> Vec<UpstreamInner> {
    match self {
      Upstream::Static(inner) => vec![inner.clone()],
      Upstream::Srv(srv_data) => {
        let to = srv_data.to.clone();
        let resolver = srv_data.dns_resolver.clone();
        let failed_backends = failed_backends.clone();
        // The lookup runs on the secondary runtime so DNS work doesn't block
        // the request-serving runtime.
        srv_data
          .secondary_runtime_handle
          .spawn(async move {
            let to_url = match Uri::from_str(&to) {
              Ok(uri) => uri,
              Err(_) => return vec![],
            };
            // The SRV query name is the host part of the configured URL.
            let to = match to_url.host() {
              Some(host) => host.to_string(),
              None => return vec![],
            };

            let srv_records = match resolver.srv_lookup(&to).await {
              Ok(records) => records,
              Err(_) => return vec![],
            };

            let failed_backends = failed_backends.read().await;
            // Build candidate upstreams by substituting each SRV record's
            // target:port as the URL authority, skipping unhealthy backends.
            let srv_upstreams = srv_records
              .into_iter()
              .filter_map(|record| {
                let mut to_url_parts = to_url.clone().into_parts();
                to_url_parts.authority = Some(format!("{}:{}", record.target(), record.port()).parse().ok()?);
                let upstream_inner = UpstreamInner {
                  proxy_to: Uri::from_parts(to_url_parts).ok()?.to_string(),
                  proxy_unix: None,
                };
                if failed_backends
                  .get(&upstream_inner)
                  .is_some_and(|fails| fails > health_check_max_fails)
                {
                  // Backend is unhealthy, skip it
                  None
                } else {
                  Some((upstream_inner, record.weight(), record.priority()))
                }
              })
              .collect::<Vec<_>>();
            // Per SRV semantics, only the most-preferred priority class
            // (numerically lowest) is eligible for selection.
            let highest_priority = srv_upstreams
              .iter()
              .map(|(_, _, priority)| *priority)
              .min()
              .unwrap_or(0);
            let filtered_srv_upstreams = srv_upstreams
              .into_iter()
              .filter(|(_, _, priority)| *priority == highest_priority)
              .map(|(upstream, weight, _)| (upstream, weight))
              .collect::<Vec<_>>();
            // Weighted random selection across the remaining candidates.
            let cumulative_weight: u64 = filtered_srv_upstreams.iter().map(|(_, weight)| *weight as u64).sum();
            let mut random_weight = if cumulative_weight == 0 {
              // Prevent empty range sampling panics
              0
            } else {
              rand::random_range(0..cumulative_weight)
            };
            // NOTE(review): the `<=` comparison lets a zero-weight record be
            // selected when `random_weight` lands exactly on the boundary
            // (and makes the all-zero-weight case pick the first record).
            // RFC 2782 does give zero-weight targets a small selection
            // chance, so this may be intentional — confirm.
            for upstream in filtered_srv_upstreams {
              let weight = upstream.1;
              if random_weight <= weight as u64 {
                return vec![upstream.0];
              }
              random_weight -= weight as u64;
            }
            vec![]
          })
          .await
          .unwrap_or(vec![])
      }
    }
  }
}
208
// Configured upstream entry: the upstream plus an optional local
// connection-limit index and an optional keep-alive idle timeout.
type ProxyToKey = (Upstream, Option<usize>, Option<Duration>);
// Same shape as `ProxyToKey`, but with the upstream already resolved.
type ProxyToKeyInner = (UpstreamInner, Option<usize>, Option<Duration>);

// Pool of reusable upstream HTTP senders, keyed by the resolved backend and
// the client IP advertised via the PROXY protocol (if any), so connections
// carrying different PROXY headers are never mixed.
type ConnectionPool = Arc<Pool<(UpstreamInner, Option<IpAddr>), SendRequestWrapper>>;
type ConnectionPoolItem = Item<(UpstreamInner, Option<IpAddr>), SendRequestWrapper>;
214
// Guard for a completion-I/O stream, used only with the vibeio/monoio
// runtimes. NOTE(review): the guard's exact semantics are defined in
// `send_net_io` — presumably it ties the stream's backing resources to the
// guard's lifetime; confirm there.
#[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
#[allow(unused)]
enum DropGuard {
  Tcp(SendTcpStreamPollDropGuard),
  #[cfg(unix)]
  Unix(SendUnixStreamPollDropGuard),
}
222
// A raw upstream connection: TCP, or a Unix domain socket on Unix platforms.
// The concrete stream type depends on the selected async runtime: the
// completion-based runtimes (vibeio/monoio) wrap their streams in
// `Send*StreamPoll` adapters, while tokio uses its native stream types.
enum Connection {
  #[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
  Tcp(SendTcpStreamPoll),
  #[cfg(not(any(feature = "runtime-vibeio", feature = "runtime-monoio")))]
  Tcp(TcpStream),
  #[cfg(all(any(feature = "runtime-vibeio", feature = "runtime-monoio"), unix))]
  Unix(SendUnixStreamPoll),
  #[cfg(all(not(any(feature = "runtime-vibeio", feature = "runtime-monoio")), unix))]
  Unix(UnixStream),
}
233
#[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
impl Connection {
  /// Obtains a drop guard for the underlying completion-I/O stream.
  ///
  /// # Safety
  ///
  /// Delegates to the `unsafe` `get_drop_guard` of the wrapped stream; the
  /// caller must uphold the contract documented in `send_net_io` for that
  /// method (not visible from this file — confirm there before calling).
  unsafe fn get_drop_guard(&mut self) -> DropGuard {
    match self {
      Connection::Tcp(stream) => DropGuard::Tcp(stream.get_drop_guard()),
      #[cfg(unix)]
      Connection::Unix(stream) => DropGuard::Unix(stream.get_drop_guard()),
    }
  }
}
244
// `AsyncRead` delegates directly to whichever stream variant is held; no
// buffering or transformation happens at this layer.
impl AsyncRead for Connection {
  fn poll_read(
    mut self: Pin<&mut Self>,
    cx: &mut Context<'_>,
    buf: &mut tokio::io::ReadBuf,
  ) -> Poll<Result<(), std::io::Error>> {
    match &mut *self {
      Connection::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
      #[cfg(unix)]
      Connection::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
    }
  }
}
258
// `AsyncWrite` delegates directly to whichever stream variant is held,
// including the vectored-write fast path so it is not silently lost behind
// the wrapper.
impl AsyncWrite for Connection {
  fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
    match &mut *self {
      Connection::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
      #[cfg(unix)]
      Connection::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
    }
  }

  fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
    match &mut *self {
      Connection::Tcp(stream) => Pin::new(stream).poll_flush(cx),
      #[cfg(unix)]
      Connection::Unix(stream) => Pin::new(stream).poll_flush(cx),
    }
  }

  fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
    match &mut *self {
      Connection::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
      #[cfg(unix)]
      Connection::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
    }
  }

  // Reports the inner stream's capability so callers can use vectored writes.
  fn is_write_vectored(&self) -> bool {
    match self {
      Connection::Tcp(stream) => stream.is_write_vectored(),
      #[cfg(unix)]
      Connection::Unix(stream) => stream.is_write_vectored(),
    }
  }

  fn poll_write_vectored(
    mut self: Pin<&mut Self>,
    cx: &mut Context<'_>,
    bufs: &[std::io::IoSlice<'_>],
  ) -> Poll<Result<usize, std::io::Error>> {
    match &mut *self {
      Connection::Tcp(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
      #[cfg(unix)]
      Connection::Unix(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
    }
  }
}
304
/// Connection pool for reverse proxy
///
/// Also caches load balancer state and failed-backend tracking so that
/// multiple builders sharing a configuration reuse the same bookkeeping.
/// NOTE(review): the caches are presumably populated by
/// `ReverseProxyBuilder` — confirm in `builder`.
pub struct Connections {
  // Memoized load balancer state, keyed by (algorithm, upstream list).
  #[allow(clippy::type_complexity)]
  load_balancer_cache: HashMap<
    (
      LoadBalancerAlgorithm,
      Arc<Vec<(Upstream, Option<usize>, Option<Duration>)>>,
    ),
    Arc<LoadBalancerAlgorithmInner>,
  >,
  // Memoized failed-backend TTL caches, keyed by
  // (health-check window, max fails, upstream list).
  #[allow(clippy::type_complexity)]
  failed_backend_cache: HashMap<
    (Duration, u64, Arc<Vec<(Upstream, Option<usize>, Option<Duration>)>>),
    Arc<RwLock<TtlCache<UpstreamInner, u64>>>,
  >,
  // Pooled TCP connections to upstreams.
  connections: ConnectionPool,
  // Pooled Unix socket connections (never subject to the global limit).
  #[cfg(unix)]
  unix_connections: ConnectionPool,
}
324
325impl Connections {
326  /// Creates a connection pool without a global connection limit.
327  pub fn new() -> Self {
328    Self {
329      load_balancer_cache: HashMap::new(),
330      failed_backend_cache: HashMap::new(),
331      connections: Arc::new(Pool::new_unbounded()),
332      #[cfg(unix)]
333      unix_connections: Arc::new(Pool::new_unbounded()),
334    }
335  }
336
337  /// Creates a connection pool with a global TCP connection limit.
338  ///
339  /// Unix socket connections remain unbounded.
340  pub fn with_global_limit(global_limit: usize) -> Self {
341    Self {
342      load_balancer_cache: HashMap::new(),
343      failed_backend_cache: HashMap::new(),
344      connections: Arc::new(Pool::new(global_limit)),
345      #[cfg(unix)]
346      unix_connections: Arc::new(Pool::new_unbounded()),
347    }
348  }
349
350  /// Starts a reverse proxy builder using this connection pool.
351  pub fn get_builder<'a>(&'a mut self) -> ReverseProxyBuilder<'a> {
352    ReverseProxyBuilder {
353      connections: self,
354      upstreams: Vec::new(),
355      lb_algorithm: LoadBalancerAlgorithm::TwoRandomChoices,
356      lb_health_check_window: Duration::from_millis(5000),
357      lb_health_check_max_fails: 3,
358      lb_health_check: false,
359      proxy_no_verification: false,
360      proxy_intercept_errors: false,
361      lb_retry_connection: true,
362      proxy_http2_only: false,
363      proxy_http2: false,
364      proxy_keepalive: true,
365      proxy_proxy_header: None,
366      proxy_request_header: Vec::new(),
367      proxy_request_header_replace: Vec::new(),
368      proxy_request_header_remove: Vec::new(),
369      rewrite_host: false,
370    }
371  }
372}
373
374impl Default for Connections {
375  fn default() -> Self {
376    Self::new()
377  }
378}
379
/// A reverse proxy
///
/// Holds the shared, fully-configured proxy state; per-request handlers are
/// created from it via [`ReverseProxy::get_handler`].
pub struct ReverseProxy {
  // Shared TTL cache of failure counts per backend, used for health checking.
  #[allow(clippy::type_complexity)]
  failed_backends: Arc<RwLock<TtlCache<UpstreamInner, u64>>>,
  // Load balancer bookkeeping shared across all handlers.
  load_balancer_algorithm: Arc<LoadBalancerAlgorithmInner>,
  // Configured upstreams with their per-upstream limits and idle timeouts.
  proxy_to: Arc<Vec<ProxyToKey>>,
  // Failure count above which a backend is considered unhealthy.
  health_check_max_fails: u64,
  enable_health_check: bool,
  // When true, upstream TLS certificates are not verified.
  disable_certificate_verification: bool,
  proxy_intercept_errors: bool,
  // When true, connection failures fall through to the next backend.
  retry_connection: bool,
  proxy_http2_only: bool,
  proxy_http2: bool,
  proxy_keepalive: bool,
  // Optional PROXY protocol header to prepend on upstream connections.
  proxy_header: Option<ProxyHeader>,
  // Header transformations applied to each proxied request.
  headers_to_add: Arc<Vec<(HeaderName, String)>>,
  headers_to_replace: Arc<Vec<(HeaderName, String)>>,
  headers_to_remove: Arc<Vec<HeaderName>>,
  // When true, the Host header is rewritten to the backend's host.
  rewrite_host: bool,
  connections: ConnectionPool,
  #[cfg(unix)]
  unix_connections: ConnectionPool,
}
403
impl ReverseProxy {
  /// Creates a request handler instance with shared proxy state.
  ///
  /// Shared state (failure cache, load balancer, pools, header rules) is
  /// cloned by reference (`Arc`); per-request fields (metrics accumulators,
  /// `connection_reused`) start fresh for each handler.
  pub fn get_handler(&self) -> ReverseProxyHandler {
    ReverseProxyHandler {
      failed_backends: self.failed_backends.clone(),
      load_balancer_algorithm: self.load_balancer_algorithm.clone(),
      proxy_to: self.proxy_to.clone(),
      health_check_max_fails: self.health_check_max_fails,
      selected_backends_metrics: None,
      unhealthy_backends_metrics: None,
      connection_reused: false,
      enable_health_check: self.enable_health_check,
      disable_certificate_verification: self.disable_certificate_verification,
      proxy_intercept_errors: self.proxy_intercept_errors,
      retry_connection: self.retry_connection,
      proxy_http2_only: self.proxy_http2_only,
      proxy_http2: self.proxy_http2,
      proxy_keepalive: self.proxy_keepalive,
      proxy_header: self.proxy_header,
      headers_to_add: self.headers_to_add.clone(),
      headers_to_replace: self.headers_to_replace.clone(),
      headers_to_remove: self.headers_to_remove.clone(),
      rewrite_host: self.rewrite_host,
      connections: self.connections.clone(),
      #[cfg(unix)]
      unix_connections: self.unix_connections.clone(),
    }
  }
}
433
/// Handlers for the reverse proxy module
///
/// One instance serves one request; shared state is `Arc`-cloned from the
/// owning [`ReverseProxy`], while the metrics and `connection_reused` fields
/// are per-request.
pub struct ReverseProxyHandler {
  // Shared TTL cache of failure counts per backend, used for health checking.
  #[allow(clippy::type_complexity)]
  failed_backends: Arc<RwLock<TtlCache<UpstreamInner, u64>>>,
  load_balancer_algorithm: Arc<LoadBalancerAlgorithmInner>,
  proxy_to: Arc<Vec<ProxyToKey>>,
  health_check_max_fails: u64,
  // Per-request: backends selected while handling this request (for metrics).
  selected_backends_metrics: Option<Vec<UpstreamInner>>,
  // Per-request: backends that failed while handling this request.
  unhealthy_backends_metrics: Option<Vec<UpstreamInner>>,
  // Per-request: whether a pooled connection was reused.
  connection_reused: bool,
  enable_health_check: bool,
  disable_certificate_verification: bool,
  proxy_intercept_errors: bool,
  // When true, connection failures fall through to the next backend.
  retry_connection: bool,
  proxy_http2_only: bool,
  proxy_http2: bool,
  proxy_keepalive: bool,
  proxy_header: Option<ProxyHeader>,
  headers_to_add: Arc<Vec<(HeaderName, String)>>,
  headers_to_replace: Arc<Vec<(HeaderName, String)>>,
  headers_to_remove: Arc<Vec<HeaderName>>,
  rewrite_host: bool,
  connections: ConnectionPool,
  #[cfg(unix)]
  unix_connections: ConnectionPool,
}
460
461impl ReverseProxyHandler {
462  #[inline]
463  fn status_response(status_code: StatusCode) -> ResponseData {
464    ResponseData {
465      request: None,
466      response: None,
467      response_status: Some(status_code),
468      response_headers: None,
469      new_remote_address: None,
470    }
471  }
472
473  async fn mark_backend_failure(&mut self, upstream: &UpstreamInner) {
474    if !self.enable_health_check {
475      return;
476    }
477    if let Some(unhealthy_backends_metrics) = self.unhealthy_backends_metrics.as_mut() {
478      unhealthy_backends_metrics.push(upstream.clone());
479    }
480    let mut failed_backends_write = self.failed_backends.write().await;
481    let failed_attempts = failed_backends_write.get(upstream);
482    failed_backends_write.insert(upstream.clone(), failed_attempts.map_or(1, |x| x + 1));
483  }
484
485  async fn retry_or_respond(
486    &self,
487    error_logger: &ErrorLogger,
488    err: &dyn std::fmt::Display,
489    retry_connection: bool,
490    has_more_backends: bool,
491    status_code: StatusCode,
492    log_prefix: &str,
493  ) -> Option<ResponseData> {
494    if retry_connection && has_more_backends {
495      error_logger
496        .log(&format!("Failed to connect to backend, trying another backend: {err}"))
497        .await;
498      None
499    } else {
500      error_logger.log(&format!("{log_prefix}: {err}")).await;
501      Some(Self::status_response(status_code))
502    }
503  }
504
505  #[inline]
506  fn io_error_status(err: &std::io::Error) -> (StatusCode, &'static str) {
507    match err.kind() {
508      std::io::ErrorKind::ConnectionRefused | std::io::ErrorKind::NotFound | std::io::ErrorKind::HostUnreachable => {
509        (StatusCode::SERVICE_UNAVAILABLE, "Service unavailable")
510      }
511      std::io::ErrorKind::TimedOut => (StatusCode::GATEWAY_TIMEOUT, "Gateway timeout"),
512      _ => (StatusCode::BAD_GATEWAY, "Bad gateway"),
513    }
514  }
515}
516
517#[async_trait(?Send)]
518impl ModuleHandlers for ReverseProxyHandler {
519  /// Handles incoming HTTP requests and proxies them to the configured backend server(s)
520  ///
521  /// This handler:
522  /// 1. Determines which backend server to proxy to (supports load balancing)
523  /// 2. Transforms the request by:
524  ///    - Converting the URL to match the backend format
525  ///    - Setting appropriate headers (Host, X-Forwarded-*)
526  /// 3. Establishes a connection to the backend (HTTP or HTTPS)
527  /// 4. Forwards the request and returns the response
528  ///
529  /// The handler supports:
530  /// - Load balancing across multiple backends
531  /// - Connection pooling/reuse
532  /// - Health checking (marking failed backends)
533  /// - TLS/SSL for secure connections
534  /// - HTTP protocol upgrades (e.g., WebSockets)
535  async fn request_handler(
536    &mut self,
537    request: Request<BoxBody<Bytes, std::io::Error>>,
538    config: &ServerConfiguration,
539    socket_data: &SocketData,
540    error_logger: &ErrorLogger,
541  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
542    let enable_health_check = self.enable_health_check;
543    let health_check_max_fails = self.health_check_max_fails;
544    let disable_certificate_verification = self.disable_certificate_verification;
545    let proxy_intercept_errors = self.proxy_intercept_errors;
546    if self.proxy_to.is_empty() {
547      // No upstreams configured...
548      return Ok(ResponseData {
549        request: Some(request),
550        response: None,
551        response_status: None,
552        response_headers: None,
553        new_remote_address: None,
554      });
555    }
556    let mut proxy_to_vector = resolve_upstreams(
557      &self.proxy_to,
558      self.failed_backends.clone(),
559      self.health_check_max_fails,
560    )
561    .await;
562    let load_balancer_algorithm = self.load_balancer_algorithm.clone();
563    let connection_track = match &*load_balancer_algorithm {
564      LoadBalancerAlgorithmInner::LeastConnections(connection_track) => Some(connection_track),
565      LoadBalancerAlgorithmInner::TwoRandomChoices(connection_track) => Some(connection_track),
566      _ => None,
567    };
568    let retry_connection = self.retry_connection;
569    let (request_parts, request_body) = request.into_parts();
570    let mut request_parts = Some(request_parts);
571
572    loop {
573      if let Some((upstream, local_limit_index, keepalive_idle_timeout)) = determine_proxy_to(
574        &mut proxy_to_vector,
575        &self.failed_backends,
576        enable_health_check,
577        health_check_max_fails,
578        &load_balancer_algorithm,
579      )
580      .await
581      {
582        if let Some(selected_backends_metrics) = self.selected_backends_metrics.as_mut() {
583          selected_backends_metrics.push(upstream.clone());
584        }
585        let UpstreamInner { proxy_to, proxy_unix } = &upstream;
586        let proxy_request_url = proxy_to.parse::<hyper::Uri>()?;
587        let scheme_str = proxy_request_url.scheme_str();
588        let mut encrypted = false;
589
590        match scheme_str {
591          Some("http") => {
592            encrypted = false;
593          }
594          Some("https") => {
595            encrypted = true;
596          }
597          _ => Err(anyhow::anyhow!("Only HTTP and HTTPS reverse proxy URLs are supported."))?,
598        };
599
600        let host = match proxy_request_url.host() {
601          Some(host) => host,
602          None => Err(anyhow::anyhow!("The reverse proxy URL doesn't include the host"))?,
603        };
604
605        let port = proxy_request_url.port_u16().unwrap_or(match scheme_str {
606          Some("http") => 80,
607          Some("https") => 443,
608          _ => 80,
609        });
610
611        let addr = format!("{host}:{port}");
612
613        let request_parts_option = if proxy_to_vector.is_empty() {
614          request_parts.take()
615        } else {
616          request_parts.clone()
617        };
618        let request_parts = request_parts_option.ok_or(anyhow::anyhow!("Request parts not found"))?;
619        let proxy_request_parts = construct_proxy_request_parts(
620          request_parts,
621          config,
622          socket_data,
623          &proxy_request_url,
624          &self.headers_to_add,
625          &self.headers_to_replace,
626          &self.headers_to_remove,
627          self.rewrite_host,
628        )?;
629
630        let tracked_connection = if let Some(connection_track) = connection_track {
631          let connection_track_read = connection_track.read().await;
632          Some(if let Some(connection_count) = connection_track_read.get(&upstream) {
633            connection_count.clone()
634          } else {
635            let tracked_connection = Arc::new(());
636            drop(connection_track_read);
637            connection_track
638              .write()
639              .await
640              .insert(upstream.clone(), tracked_connection.clone());
641            tracked_connection
642          })
643        } else {
644          None
645        };
646
647        let proxy_header = self.proxy_header;
648
649        let is_http_upgrade = proxy_request_parts.headers.contains_key(header::UPGRADE);
650        let enable_http2_only_config = self.proxy_http2_only;
651        let enable_http2_config = self.proxy_http2;
652
653        let enable_keepalive =
654          (enable_http2_only_config || !enable_http2_config || !is_http_upgrade) && self.proxy_keepalive;
655        let connection_pool_item = {
656          #[cfg(unix)]
657          let connections = if proxy_unix.is_some() {
658            &self.unix_connections
659          } else {
660            &self.connections
661          };
662          #[cfg(not(unix))]
663          let connections = &self.connections;
664          let sender;
665          let mut send_request_items = Vec::new();
666          let proxy_client_ip = match proxy_header {
667            Some(ProxyHeader::V1) | Some(ProxyHeader::V2) => Some(socket_data.remote_addr.ip().to_canonical()),
668            _ => None,
669          };
670          loop {
671            let mut send_request_item = if send_request_items.is_empty() {
672              connections
673                .pull_with_wait_local_limit((upstream.clone(), proxy_client_ip), local_limit_index)
674                .await
675            } else if let Poll::Ready(send_request_item_option) = connections
676              .pull_with_wait_local_limit((upstream.clone(), proxy_client_ip), local_limit_index)
677              .boxed_local()
678              .poll_unpin(&mut Context::from_waker(Waker::noop()))
679            {
680              send_request_item_option
681            } else {
682              let send_request_items_taken = send_request_items;
683              send_request_items = Vec::new();
684              let fetch_nonready_send_request_fut = async {
685                let result = futures_util::future::select_ok(send_request_items_taken).await;
686                if let Ok((item, send_request_items_smaller)) = result {
687                  send_request_items = send_request_items_smaller;
688                  item
689                } else {
690                  futures_util::future::pending().await
691                }
692              };
693              crate::runtime::select! {
694                item = connections
695                  .pull_with_wait_local_limit((upstream.clone(), proxy_client_ip), local_limit_index)
696                => {
697                  item
698                },
699                item = fetch_nonready_send_request_fut => {
700                  item
701                }
702              }
703            };
704            if let Some(send_request) = send_request_item.inner_mut() {
705              match send_request.get(keepalive_idle_timeout) {
706                (Some(send_request), true) => {
707                  // Connection ready, send a request to it
708                  send_request_items.clear();
709                  self.connection_reused = true;
710                  let _ = send_request_item.inner_mut().take();
711                  let proxy_request = Request::from_parts(proxy_request_parts, request_body);
712                  let result = http_proxy(
713                    send_request,
714                    send_request_item,
715                    proxy_request,
716                    error_logger,
717                    proxy_intercept_errors,
718                    tracked_connection,
719                    true,
720                  )
721                  .await;
722                  return result;
723                }
724                (None, true) => {
725                  // Connection not ready
726                  send_request_items.push(Box::pin(async move {
727                    let inner_item = send_request_item.inner_mut();
728                    if let Some(inner_item_2) = inner_item {
729                      if !inner_item_2.wait_ready(keepalive_idle_timeout).await {
730                        // Connection closed or timed out
731                        inner_item.take();
732                        return Err(());
733                      }
734                      let _ = inner_item;
735                      Ok(send_request_item)
736                    } else {
737                      Err(())
738                    }
739                  }));
740                  continue;
741                }
742                (_, false) => {
743                  // Connection closed
744                  let _ = send_request_item.inner_mut().take();
745                  continue;
746                }
747              }
748            }
749            send_request_items.clear();
750            sender = send_request_item;
751            break;
752          }
753          sender
754        };
755
756        let stream = if let Some(proxy_unix_str) = &proxy_unix {
757          #[cfg(not(unix))]
758          {
759            let _ = proxy_unix_str; // Discard the variable to avoid unused variable warning
760            Err(anyhow::anyhow!("Unix sockets are not supported on this platform"))?
761          }
762
763          #[cfg(unix)]
764          {
765            let stream = match UnixStream::connect(proxy_unix_str).await {
766              Ok(stream) => stream,
767              Err(err) => {
768                self.mark_backend_failure(&upstream).await;
769                let (status_code, log_prefix) = Self::io_error_status(&err);
770                if let Some(response) = self
771                  .retry_or_respond(
772                    error_logger,
773                    &err,
774                    retry_connection,
775                    !proxy_to_vector.is_empty(),
776                    status_code,
777                    log_prefix,
778                  )
779                  .await
780                {
781                  return Ok(response);
782                }
783                continue;
784              }
785            };
786
787            #[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
788            let stream = match SendUnixStreamPoll::new_comp_io(stream) {
789              Ok(stream) => stream,
790              Err(err) => {
791                self.mark_backend_failure(&upstream).await;
792                if let Some(response) = self
793                  .retry_or_respond(
794                    error_logger,
795                    &err,
796                    retry_connection,
797                    !proxy_to_vector.is_empty(),
798                    StatusCode::BAD_GATEWAY,
799                    "Bad gateway",
800                  )
801                  .await
802                {
803                  return Ok(response);
804                }
805                continue;
806              }
807            };
808
809            Connection::Unix(stream)
810          }
811        } else {
812          let stream = match TcpStream::connect(&addr).await {
813            Ok(stream) => stream,
814            Err(err) => {
815              self.mark_backend_failure(&upstream).await;
816              let (status_code, log_prefix) = Self::io_error_status(&err);
817              if let Some(response) = self
818                .retry_or_respond(
819                  error_logger,
820                  &err,
821                  retry_connection,
822                  !proxy_to_vector.is_empty(),
823                  status_code,
824                  log_prefix,
825                )
826                .await
827              {
828                return Ok(response);
829              }
830              continue;
831            }
832          };
833
834          if let Err(err) = stream.set_nodelay(true) {
835            self.mark_backend_failure(&upstream).await;
836            if let Some(response) = self
837              .retry_or_respond(
838                error_logger,
839                &err,
840                retry_connection,
841                !proxy_to_vector.is_empty(),
842                StatusCode::BAD_GATEWAY,
843                "Bad gateway",
844              )
845              .await
846            {
847              return Ok(response);
848            }
849            continue;
850          };
851
852          #[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
853          let stream = match SendTcpStreamPoll::new_comp_io(stream) {
854            Ok(stream) => stream,
855            Err(err) => {
856              self.mark_backend_failure(&upstream).await;
857              if let Some(response) = self
858                .retry_or_respond(
859                  error_logger,
860                  &err,
861                  retry_connection,
862                  !proxy_to_vector.is_empty(),
863                  StatusCode::BAD_GATEWAY,
864                  "Bad gateway",
865                )
866                .await
867              {
868                return Ok(response);
869              }
870              continue;
871            }
872          };
873
874          Connection::Tcp(stream)
875        };
876
877        let proxy_header_to_write = match proxy_header {
878          Some(ProxyHeader::V1) => {
879            let is_ipv4 = socket_data.local_addr.ip().to_canonical().is_ipv4()
880              && socket_data.remote_addr.ip().to_canonical().is_ipv4();
881            let local_addr = if is_ipv4 {
882              match socket_data.local_addr.ip().to_canonical() {
883                IpAddr::V4(ip) => ip.to_string(),
884                IpAddr::V6(ip) => ip
885                  .to_ipv4_mapped()
886                  .ok_or(anyhow::anyhow!("Connection IP address type mismatch"))?
887                  .to_string(),
888              }
889            } else {
890              match socket_data.local_addr.ip().to_canonical() {
891                IpAddr::V4(ip) => ip
892                  .to_ipv6_mapped()
893                  .segments()
894                  .iter()
895                  .map(|seg| format!("{:04x}", seg))
896                  .collect::<Vec<_>>()
897                  .join(":"),
898                IpAddr::V6(ip) => ip
899                  .segments()
900                  .iter()
901                  .map(|seg| format!("{:04x}", seg))
902                  .collect::<Vec<_>>()
903                  .join(":"),
904              }
905            };
906            let remote_addr = if is_ipv4 {
907              match socket_data.remote_addr.ip().to_canonical() {
908                IpAddr::V4(ip) => ip.to_string(),
909                IpAddr::V6(ip) => ip
910                  .to_ipv4_mapped()
911                  .ok_or(anyhow::anyhow!("Connection IP address type mismatch"))?
912                  .to_string(),
913              }
914            } else {
915              match socket_data.remote_addr.ip().to_canonical() {
916                IpAddr::V4(ip) => ip
917                  .to_ipv6_mapped()
918                  .segments()
919                  .iter()
920                  .map(|seg| format!("{:04x}", seg))
921                  .collect::<Vec<_>>()
922                  .join(":"),
923                IpAddr::V6(ip) => ip
924                  .segments()
925                  .iter()
926                  .map(|seg| format!("{:04x}", seg))
927                  .collect::<Vec<_>>()
928                  .join(":"),
929              }
930            };
931            let local_port = socket_data.local_addr.port();
932            let remote_port = socket_data.remote_addr.port();
933            let header = format!(
934              "PROXY {} {} {} {} {}\r\n",
935              if is_ipv4 { "TCP4" } else { "TCP6" },
936              remote_addr,
937              local_addr,
938              remote_port,
939              local_port,
940            );
941            Some(header.into_bytes())
942          }
943          Some(ProxyHeader::V2) => {
944            let is_ipv4 = socket_data.local_addr.ip().to_canonical().is_ipv4()
945              && socket_data.remote_addr.ip().to_canonical().is_ipv4();
946            let addresses = if is_ipv4 {
947              ppp::v2::Addresses::IPv4(ppp::v2::IPv4::new(
948                match socket_data.remote_addr.ip().to_canonical() {
949                  IpAddr::V4(ip) => ip,
950                  IpAddr::V6(ip) => ip
951                    .to_ipv4_mapped()
952                    .ok_or(anyhow::anyhow!("Connection IP address type mismatch"))?,
953                },
954                match socket_data.local_addr.ip().to_canonical() {
955                  IpAddr::V4(ip) => ip,
956                  IpAddr::V6(ip) => ip
957                    .to_ipv4_mapped()
958                    .ok_or(anyhow::anyhow!("Connection IP address type mismatch"))?,
959                },
960                socket_data.remote_addr.port(),
961                socket_data.local_addr.port(),
962              ))
963            } else {
964              ppp::v2::Addresses::IPv6(ppp::v2::IPv6::new(
965                match socket_data.remote_addr.ip().to_canonical() {
966                  IpAddr::V4(ip) => ip.to_ipv6_mapped(),
967                  IpAddr::V6(ip) => ip,
968                },
969                match socket_data.local_addr.ip().to_canonical() {
970                  IpAddr::V4(ip) => ip.to_ipv6_mapped(),
971                  IpAddr::V6(ip) => ip,
972                },
973                socket_data.remote_addr.port(),
974                socket_data.local_addr.port(),
975              ))
976            };
977            let header_builder = ppp::v2::Builder::with_addresses(
978              ppp::v2::Version::Two | ppp::v2::Command::Proxy,
979              ppp::v2::Protocol::Stream,
980              addresses,
981            );
982            Some(header_builder.build()?)
983          }
984          _ => None,
985        };
986
987        let mut stream = stream; // Make the stream a mutable variable (to be able to write PROXY protocol header to it).
988
989        if let Some(proxy_header_to_write) = proxy_header_to_write {
990          if let Err(err) = stream.write_all(&proxy_header_to_write).await {
991            self.mark_backend_failure(&upstream).await;
992            if let Some(response) = self
993              .retry_or_respond(
994                error_logger,
995                &err,
996                retry_connection,
997                !proxy_to_vector.is_empty(),
998                StatusCode::BAD_GATEWAY,
999                "Bad gateway",
1000              )
1001              .await
1002            {
1003              return Ok(response);
1004            }
1005            continue;
1006          }
1007        }
1008
1009        // Safety: the drop guard is dropped when the connection future is completed,
1010        // and after the underlying connection is moved across threads,
1011        // see the "http_proxy_handshake" function.
1012        #[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
1013        let drop_guard = unsafe { stream.get_drop_guard() };
1014
1015        let sender = if !encrypted {
1016          let sender = match http_proxy_handshake(
1017            stream,
1018            enable_http2_only_config,
1019            #[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
1020            drop_guard,
1021          )
1022          .await
1023          {
1024            Ok(sender) => sender,
1025            Err(err) => {
1026              self.mark_backend_failure(&upstream).await;
1027              if let Some(response) = self
1028                .retry_or_respond(
1029                  error_logger,
1030                  &err,
1031                  retry_connection,
1032                  !proxy_to_vector.is_empty(),
1033                  StatusCode::BAD_GATEWAY,
1034                  "Bad gateway",
1035                )
1036                .await
1037              {
1038                return Ok(response);
1039              }
1040              continue;
1041            }
1042          };
1043
1044          sender
1045        } else {
1046          let enable_http2_config = enable_http2_only_config || (enable_http2_config && !is_http_upgrade);
1047          let mut tls_client_config = (if disable_certificate_verification {
1048            rustls::ClientConfig::builder()
1049              .dangerous()
1050              .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new()))
1051          } else if let Ok(client_config) = BuilderVerifierExt::with_platform_verifier(rustls::ClientConfig::builder())
1052          {
1053            client_config
1054          } else {
1055            rustls::ClientConfig::builder().with_webpki_verifier(
1056              WebPkiServerVerifier::builder(Arc::new(rustls::RootCertStore {
1057                roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
1058              }))
1059              .build()?,
1060            )
1061          })
1062          .with_no_client_auth();
1063          if enable_http2_only_config {
1064            tls_client_config.alpn_protocols = vec![b"h2".to_vec()];
1065          } else if enable_http2_config {
1066            tls_client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()];
1067          } else {
1068            tls_client_config.alpn_protocols = vec![b"http/1.1".to_vec(), b"http/1.0".to_vec()];
1069          }
1070          let connector = TlsConnector::from(Arc::new(tls_client_config));
1071          let domain = ServerName::try_from(host)?.to_owned();
1072
1073          let tls_stream = match connector.connect(domain, stream).await {
1074            Ok(stream) => stream,
1075            Err(err) => {
1076              self.mark_backend_failure(&upstream).await;
1077              if let Some(response) = self
1078                .retry_or_respond(
1079                  error_logger,
1080                  &err,
1081                  retry_connection,
1082                  !proxy_to_vector.is_empty(),
1083                  StatusCode::BAD_GATEWAY,
1084                  "Bad gateway",
1085                )
1086                .await
1087              {
1088                return Ok(response);
1089              }
1090              continue;
1091            }
1092          };
1093
1094          // Enable HTTP/2 when the ALPN protocol is "h2"
1095          let enable_http2 = enable_http2_config && tls_stream.get_ref().1.alpn_protocol() == Some(b"h2");
1096
1097          let sender = match http_proxy_handshake(
1098            tls_stream,
1099            enable_http2,
1100            #[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
1101            drop_guard,
1102          )
1103          .await
1104          {
1105            Ok(sender) => sender,
1106            Err(err) => {
1107              self.mark_backend_failure(&upstream).await;
1108              if let Some(response) = self
1109                .retry_or_respond(
1110                  error_logger,
1111                  &err,
1112                  retry_connection,
1113                  !proxy_to_vector.is_empty(),
1114                  StatusCode::BAD_GATEWAY,
1115                  "Bad gateway",
1116                )
1117                .await
1118              {
1119                return Ok(response);
1120              }
1121              continue;
1122            }
1123          };
1124
1125          sender
1126        };
1127
1128        let proxy_request = Request::from_parts(proxy_request_parts, request_body);
1129
1130        return http_proxy(
1131          sender,
1132          connection_pool_item,
1133          proxy_request,
1134          error_logger,
1135          proxy_intercept_errors,
1136          tracked_connection,
1137          enable_keepalive,
1138        )
1139        .await;
1140      } else {
1141        let request_parts = request_parts.ok_or(anyhow::anyhow!("Request parts are missing"))?;
1142        error_logger.log("No upstreams available").await;
1143        return Ok(ResponseData {
1144          request: Some(Request::from_parts(request_parts, request_body)),
1145          response: None,
1146          response_status: Some(StatusCode::SERVICE_UNAVAILABLE), // No upstreams available
1147          response_headers: None,
1148          new_remote_address: None,
1149        });
1150      }
1151    }
1152  }
1153
1154  async fn metric_data_before_handler(
1155    &mut self,
1156    _request: &Request<BoxBody<Bytes, std::io::Error>>,
1157    _socket_data: &SocketData,
1158    _metrics_sender: &MetricsMultiSender,
1159  ) {
1160    self.selected_backends_metrics = Some(Vec::new());
1161    self.unhealthy_backends_metrics = Some(Vec::new());
1162  }
1163
1164  async fn metric_data_after_handler(&mut self, metrics_sender: &MetricsMultiSender) {
1165    if let Some(selected_backends_metrics) = self.selected_backends_metrics.take() {
1166      for selected_backend in selected_backends_metrics {
1167        let mut attributes = Vec::new();
1168        attributes.push((
1169          "ferron.proxy.backend_url",
1170          MetricAttributeValue::String(selected_backend.proxy_to),
1171        ));
1172        if let Some(backend_unix) = selected_backend.proxy_unix {
1173          attributes.push((
1174            "ferron.proxy.backend_unix_path",
1175            MetricAttributeValue::String(backend_unix),
1176          ));
1177        }
1178        metrics_sender
1179          .send(Metric::new(
1180            "ferron.proxy.backends.selected",
1181            attributes,
1182            MetricType::Counter,
1183            MetricValue::U64(1),
1184            Some("{backend}"),
1185            Some("Number of times a backend server was selected."),
1186          ))
1187          .await;
1188      }
1189    }
1190    if let Some(unhealthy_backends_metrics) = self.unhealthy_backends_metrics.take() {
1191      for unhealthy_backend in unhealthy_backends_metrics {
1192        let mut attributes = Vec::new();
1193        attributes.push((
1194          "ferron.proxy.backend_url",
1195          MetricAttributeValue::String(unhealthy_backend.proxy_to),
1196        ));
1197        if let Some(backend_unix) = unhealthy_backend.proxy_unix {
1198          attributes.push((
1199            "ferron.proxy.backend_unix_path",
1200            MetricAttributeValue::String(backend_unix),
1201          ));
1202        }
1203        metrics_sender
1204          .send(Metric::new(
1205            "ferron.proxy.backends.unhealthy",
1206            attributes,
1207            MetricType::Counter,
1208            MetricValue::U64(1),
1209            Some("{backend}"),
1210            Some("Number of health check failures for a backend server."),
1211          ))
1212          .await;
1213      }
1214    }
1215    metrics_sender
1216      .send(Metric::new(
1217        "ferron.proxy.requests",
1218        vec![(
1219          "ferron.proxy.connection_reused",
1220          MetricAttributeValue::Bool(self.connection_reused),
1221        )],
1222        MetricType::Counter,
1223        MetricValue::U64(1),
1224        Some("{request}"),
1225        Some("Number of reverse proxy requests."),
1226      ))
1227      .await;
1228  }
1229}