// ferron_common/http_proxy/builder.rs

1use std::sync::atomic::AtomicUsize;
2use std::sync::Arc;
3use std::time::Duration;
4use std::{collections::HashMap, net::IpAddr};
5
6use hickory_resolver::{config::ResolverConfig, name_server::TokioConnectionProvider};
7use hyper::header::HeaderName;
8use tokio::sync::RwLock;
9
10use super::{Connections, LoadBalancerAlgorithm, LoadBalancerAlgorithmInner, ProxyHeader, ProxyToKey, ReverseProxy};
11use crate::{
12  http_proxy::{SrvUpstreamData, Upstream, UpstreamInner},
13  util::TtlCache,
14};
15
16/// Builder for configuring and constructing a [`ReverseProxy`].
17pub struct ReverseProxyBuilder<'a> {
18  pub(super) connections: &'a mut Connections,
19  #[allow(clippy::type_complexity)]
20  pub(super) upstreams: Vec<(Upstream, Option<usize>, Option<Duration>)>,
21  pub(super) lb_algorithm: LoadBalancerAlgorithm,
22  pub(super) lb_health_check_window: Duration,
23  pub(super) lb_health_check_max_fails: u64,
24  pub(super) lb_health_check: bool,
25  pub(super) lb_retry_connection: bool,
26  pub(super) proxy_no_verification: bool,
27  pub(super) proxy_intercept_errors: bool,
28  pub(super) proxy_http2_only: bool,
29  pub(super) proxy_http2: bool,
30  pub(super) proxy_keepalive: bool,
31  pub(super) proxy_proxy_header: Option<ProxyHeader>,
32  pub(super) proxy_request_header: Vec<(HeaderName, String)>,
33  pub(super) proxy_request_header_replace: Vec<(HeaderName, String)>,
34  pub(super) proxy_request_header_remove: Vec<HeaderName>,
35  pub(super) rewrite_host: bool,
36}
37
38impl<'a> ReverseProxyBuilder<'a> {
39  /// Adds an upstream backend target.
40  ///
41  /// `proxy_to` is the backend URL (for example `http://127.0.0.1:8080`).
42  /// `proxy_unix` can be used to target a Unix socket path.
43  /// `local_limit` controls per-upstream connection limit.
44  /// `keepalive_idle_timeout` sets pooled connection idle timeout.
45  pub fn upstream(
46    mut self,
47    proxy_to: String,
48    proxy_unix: Option<String>,
49    local_limit: Option<usize>,
50    keepalive_idle_timeout: Option<Duration>,
51  ) -> Self {
52    self.upstreams.push((
53      Upstream::Static(UpstreamInner { proxy_to, proxy_unix }),
54      local_limit,
55      keepalive_idle_timeout,
56    ));
57    self
58  }
59
60  /// Adds a dynamic (SRV-based) upstream backend target.
61  ///
62  /// `to` is the backend URL (for example `http://_http._tcp.example.com`).
63  /// `local_limit` controls per-upstream connection limit.
64  /// `keepalive_idle_timeout` sets pooled connection idle timeout.
65  pub fn upstream_srv(
66    mut self,
67    to: String,
68    local_limit: Option<usize>,
69    keepalive_idle_timeout: Option<Duration>,
70    secondary_runtime_handle: tokio::runtime::Handle,
71    dns_servers: Vec<IpAddr>,
72  ) -> Self {
73    let dns_resolver = secondary_runtime_handle.block_on(async {
74      if !dns_servers.is_empty() {
75        hickory_resolver::Resolver::builder_with_config(
76          ResolverConfig::from_parts(
77            None,
78            vec![],
79            hickory_resolver::config::NameServerConfigGroup::from_ips_clear(&dns_servers, 53, true),
80          ),
81          TokioConnectionProvider::default(),
82        )
83        .build()
84      } else {
85        hickory_resolver::Resolver::builder_tokio()
86          .unwrap_or(hickory_resolver::Resolver::builder_with_config(
87            ResolverConfig::default(),
88            TokioConnectionProvider::default(),
89          ))
90          .build()
91      }
92    });
93    self.upstreams.push((
94      Upstream::Srv(SrvUpstreamData {
95        to,
96        secondary_runtime_handle,
97        dns_resolver: Arc::new(dns_resolver),
98      }),
99      local_limit,
100      keepalive_idle_timeout,
101    ));
102    self
103  }
104
105  /// Sets load balancing algorithm.
106  pub fn lb_algorithm(mut self, algorithm: LoadBalancerAlgorithm) -> Self {
107    self.lb_algorithm = algorithm;
108    self
109  }
110
111  /// Sets health-check TTL window for failed backend counters.
112  pub fn lb_health_check_window(mut self, window: Duration) -> Self {
113    self.lb_health_check_window = window;
114    self
115  }
116
117  /// Sets maximum consecutive failed checks before a backend is considered unhealthy.
118  pub fn lb_health_check_max_fails(mut self, max_fails: u64) -> Self {
119    self.lb_health_check_max_fails = max_fails;
120    self
121  }
122
123  /// Enables or disables backend health checks.
124  pub fn lb_health_check(mut self, enable: bool) -> Self {
125    self.lb_health_check = enable;
126    self
127  }
128
129  /// Disables certificate verification for upstream TLS connections.
130  pub fn proxy_no_verification(mut self, no_verification: bool) -> Self {
131    self.proxy_no_verification = no_verification;
132    self
133  }
134
135  /// Intercepts upstream errors and converts them to proxy-generated responses.
136  pub fn proxy_intercept_errors(mut self, intercept_errors: bool) -> Self {
137    self.proxy_intercept_errors = intercept_errors;
138    self
139  }
140
141  /// Enables retrying a different backend when connection setup fails.
142  pub fn lb_retry_connection(mut self, retry: bool) -> Self {
143    self.lb_retry_connection = retry;
144    self
145  }
146
147  /// Forces HTTP/2-only upstream connections.
148  pub fn proxy_http2_only(mut self, http2_only: bool) -> Self {
149    self.proxy_http2_only = http2_only;
150    self
151  }
152
153  /// Enables HTTP/2 support for upstream connections.
154  pub fn proxy_http2(mut self, http2: bool) -> Self {
155    self.proxy_http2 = http2;
156    self
157  }
158
159  /// Enables connection pooling and keepalive reuse.
160  pub fn proxy_keepalive(mut self, keepalive: bool) -> Self {
161    self.proxy_keepalive = keepalive;
162    self
163  }
164
165  /// Sets PROXY protocol header mode for upstream connections.
166  pub fn proxy_proxy_header(mut self, proxy_header: Option<ProxyHeader>) -> Self {
167    self.proxy_proxy_header = proxy_header;
168    self
169  }
170
171  /// Adds a request header to upstream requests.
172  pub fn proxy_request_header(mut self, header_name: HeaderName, header_value: String) -> Self {
173    self.proxy_request_header.push((header_name, header_value));
174    self
175  }
176
177  /// Replaces a request header on upstream requests.
178  pub fn proxy_request_header_replace(mut self, header_name: HeaderName, header_value: String) -> Self {
179    self.proxy_request_header_replace.push((header_name, header_value));
180    self
181  }
182
183  /// Removes a request header from upstream requests.
184  pub fn proxy_request_header_remove(mut self, header_name: HeaderName) -> Self {
185    self.proxy_request_header_remove.push(header_name);
186    self
187  }
188
189  /// Enables or disables `Host` header rewriting for non-HTTPS upstream requests.
190  pub fn rewrite_host(mut self, rewrite_host: bool) -> Self {
191    self.rewrite_host = rewrite_host;
192    self
193  }
194
195  /// Builds a [`ReverseProxy`] from the configured options.
196  pub fn build(mut self) -> ReverseProxy {
197    let connections = self.connections.connections.clone();
198    #[cfg(unix)]
199    let unix_connections = self.connections.unix_connections.clone();
200
201    let proxy_to = self
202      .upstreams
203      .drain(..)
204      .map(|(upstream, local_limit, keepalive_idle_timeout)| {
205        let is_unix_socket = match &upstream {
206          Upstream::Static(inner) => Some(inner.proxy_unix.is_some()),
207          Upstream::Srv(_) => Some(false), // SRV records lead to A/AAAA lookups, so they cannot be Unix sockets
208        };
209        (
210          upstream,
211          is_unix_socket.and_then(|is_unix_socket| {
212            apply_local_limit(
213              local_limit,
214              is_unix_socket,
215              &connections,
216              #[cfg(unix)]
217              &unix_connections,
218            )
219          }),
220          keepalive_idle_timeout,
221        )
222      })
223      .collect::<Vec<ProxyToKey>>();
224
225    let proxy_to = Arc::new(proxy_to);
226    let load_balancer_algorithm = if let Some(algorithm) = self
227      .connections
228      .load_balancer_cache
229      .get(&(self.lb_algorithm, proxy_to.clone()))
230    {
231      algorithm.clone()
232    } else {
233      let new_algorithm = Arc::new(build_load_balancer_algorithm(self.lb_algorithm));
234      self
235        .connections
236        .load_balancer_cache
237        .insert((self.lb_algorithm, proxy_to.clone()), new_algorithm.clone());
238      new_algorithm
239    };
240    let failed_backends = if let Some(failed) = self.connections.failed_backend_cache.get(&(
241      self.lb_health_check_window,
242      self.lb_health_check_max_fails,
243      proxy_to.clone(),
244    )) {
245      failed.clone()
246    } else {
247      let new_failed = Arc::new(RwLock::new(TtlCache::new(self.lb_health_check_window)));
248      self.connections.failed_backend_cache.insert(
249        (
250          self.lb_health_check_window,
251          self.lb_health_check_max_fails,
252          proxy_to.clone(),
253        ),
254        new_failed.clone(),
255      );
256      new_failed
257    };
258    ReverseProxy {
259      failed_backends,
260      load_balancer_algorithm,
261      proxy_to,
262      health_check_max_fails: self.lb_health_check_max_fails,
263      enable_health_check: self.lb_health_check,
264      disable_certificate_verification: self.proxy_no_verification,
265      proxy_intercept_errors: self.proxy_intercept_errors,
266      retry_connection: self.lb_retry_connection,
267      proxy_http2_only: self.proxy_http2_only,
268      proxy_http2: self.proxy_http2,
269      proxy_keepalive: self.proxy_keepalive,
270      proxy_header: self.proxy_proxy_header,
271      headers_to_add: Arc::new(self.proxy_request_header.drain(..).collect()),
272      headers_to_replace: Arc::new(self.proxy_request_header_replace.drain(..).collect()),
273      headers_to_remove: Arc::new(self.proxy_request_header_remove.drain(..).collect()),
274      rewrite_host: self.rewrite_host,
275      connections,
276      #[cfg(unix)]
277      unix_connections,
278    }
279  }
280}
281
282fn build_load_balancer_algorithm(algorithm: LoadBalancerAlgorithm) -> LoadBalancerAlgorithmInner {
283  match algorithm {
284    LoadBalancerAlgorithm::TwoRandomChoices => {
285      LoadBalancerAlgorithmInner::TwoRandomChoices(Arc::new(RwLock::new(HashMap::new())))
286    }
287    LoadBalancerAlgorithm::LeastConnections => {
288      LoadBalancerAlgorithmInner::LeastConnections(Arc::new(RwLock::new(HashMap::new())))
289    }
290    LoadBalancerAlgorithm::RoundRobin => LoadBalancerAlgorithmInner::RoundRobin(Arc::new(AtomicUsize::new(0))),
291    LoadBalancerAlgorithm::Random => LoadBalancerAlgorithmInner::Random,
292  }
293}
294
295fn apply_local_limit(
296  local_limit: Option<usize>,
297  is_unix_socket: bool,
298  connections: &super::ConnectionPool,
299  #[cfg(unix)] unix_connections: &super::ConnectionPool,
300) -> Option<usize> {
301  #[allow(clippy::bind_instead_of_map)]
302  local_limit.and_then(|limit| {
303    if is_unix_socket {
304      #[cfg(unix)]
305      {
306        Some(unix_connections.set_local_limit(limit))
307      }
308      #[cfg(not(unix))]
309      {
310        None
311      }
312    } else {
313      Some(connections.set_local_limit(limit))
314    }
315  })
316}