ferron_common/http_proxy/
builder.rs1use std::sync::atomic::AtomicUsize;
2use std::sync::Arc;
3use std::time::Duration;
4use std::{collections::HashMap, net::IpAddr};
5
6use hickory_resolver::{config::ResolverConfig, name_server::TokioConnectionProvider};
7use hyper::header::HeaderName;
8use tokio::sync::RwLock;
9
10use super::{Connections, LoadBalancerAlgorithm, LoadBalancerAlgorithmInner, ProxyHeader, ProxyToKey, ReverseProxy};
11use crate::{
12 http_proxy::{SrvUpstreamData, Upstream, UpstreamInner},
13 util::TtlCache,
14};
15
16pub struct ReverseProxyBuilder<'a> {
18 pub(super) connections: &'a mut Connections,
19 #[allow(clippy::type_complexity)]
20 pub(super) upstreams: Vec<(Upstream, Option<usize>, Option<Duration>)>,
21 pub(super) lb_algorithm: LoadBalancerAlgorithm,
22 pub(super) lb_health_check_window: Duration,
23 pub(super) lb_health_check_max_fails: u64,
24 pub(super) lb_health_check: bool,
25 pub(super) lb_retry_connection: bool,
26 pub(super) proxy_no_verification: bool,
27 pub(super) proxy_intercept_errors: bool,
28 pub(super) proxy_http2_only: bool,
29 pub(super) proxy_http2: bool,
30 pub(super) proxy_keepalive: bool,
31 pub(super) proxy_proxy_header: Option<ProxyHeader>,
32 pub(super) proxy_request_header: Vec<(HeaderName, String)>,
33 pub(super) proxy_request_header_replace: Vec<(HeaderName, String)>,
34 pub(super) proxy_request_header_remove: Vec<HeaderName>,
35 pub(super) rewrite_host: bool,
36}
37
38impl<'a> ReverseProxyBuilder<'a> {
39 pub fn upstream(
46 mut self,
47 proxy_to: String,
48 proxy_unix: Option<String>,
49 local_limit: Option<usize>,
50 keepalive_idle_timeout: Option<Duration>,
51 ) -> Self {
52 self.upstreams.push((
53 Upstream::Static(UpstreamInner { proxy_to, proxy_unix }),
54 local_limit,
55 keepalive_idle_timeout,
56 ));
57 self
58 }
59
60 pub fn upstream_srv(
66 mut self,
67 to: String,
68 local_limit: Option<usize>,
69 keepalive_idle_timeout: Option<Duration>,
70 secondary_runtime_handle: tokio::runtime::Handle,
71 dns_servers: Vec<IpAddr>,
72 ) -> Self {
73 let dns_resolver = secondary_runtime_handle.block_on(async {
74 if !dns_servers.is_empty() {
75 hickory_resolver::Resolver::builder_with_config(
76 ResolverConfig::from_parts(
77 None,
78 vec![],
79 hickory_resolver::config::NameServerConfigGroup::from_ips_clear(&dns_servers, 53, true),
80 ),
81 TokioConnectionProvider::default(),
82 )
83 .build()
84 } else {
85 hickory_resolver::Resolver::builder_tokio()
86 .unwrap_or(hickory_resolver::Resolver::builder_with_config(
87 ResolverConfig::default(),
88 TokioConnectionProvider::default(),
89 ))
90 .build()
91 }
92 });
93 self.upstreams.push((
94 Upstream::Srv(SrvUpstreamData {
95 to,
96 secondary_runtime_handle,
97 dns_resolver: Arc::new(dns_resolver),
98 }),
99 local_limit,
100 keepalive_idle_timeout,
101 ));
102 self
103 }
104
105 pub fn lb_algorithm(mut self, algorithm: LoadBalancerAlgorithm) -> Self {
107 self.lb_algorithm = algorithm;
108 self
109 }
110
111 pub fn lb_health_check_window(mut self, window: Duration) -> Self {
113 self.lb_health_check_window = window;
114 self
115 }
116
117 pub fn lb_health_check_max_fails(mut self, max_fails: u64) -> Self {
119 self.lb_health_check_max_fails = max_fails;
120 self
121 }
122
123 pub fn lb_health_check(mut self, enable: bool) -> Self {
125 self.lb_health_check = enable;
126 self
127 }
128
129 pub fn proxy_no_verification(mut self, no_verification: bool) -> Self {
131 self.proxy_no_verification = no_verification;
132 self
133 }
134
135 pub fn proxy_intercept_errors(mut self, intercept_errors: bool) -> Self {
137 self.proxy_intercept_errors = intercept_errors;
138 self
139 }
140
141 pub fn lb_retry_connection(mut self, retry: bool) -> Self {
143 self.lb_retry_connection = retry;
144 self
145 }
146
147 pub fn proxy_http2_only(mut self, http2_only: bool) -> Self {
149 self.proxy_http2_only = http2_only;
150 self
151 }
152
153 pub fn proxy_http2(mut self, http2: bool) -> Self {
155 self.proxy_http2 = http2;
156 self
157 }
158
159 pub fn proxy_keepalive(mut self, keepalive: bool) -> Self {
161 self.proxy_keepalive = keepalive;
162 self
163 }
164
165 pub fn proxy_proxy_header(mut self, proxy_header: Option<ProxyHeader>) -> Self {
167 self.proxy_proxy_header = proxy_header;
168 self
169 }
170
171 pub fn proxy_request_header(mut self, header_name: HeaderName, header_value: String) -> Self {
173 self.proxy_request_header.push((header_name, header_value));
174 self
175 }
176
177 pub fn proxy_request_header_replace(mut self, header_name: HeaderName, header_value: String) -> Self {
179 self.proxy_request_header_replace.push((header_name, header_value));
180 self
181 }
182
183 pub fn proxy_request_header_remove(mut self, header_name: HeaderName) -> Self {
185 self.proxy_request_header_remove.push(header_name);
186 self
187 }
188
189 pub fn rewrite_host(mut self, rewrite_host: bool) -> Self {
191 self.rewrite_host = rewrite_host;
192 self
193 }
194
195 pub fn build(mut self) -> ReverseProxy {
197 let connections = self.connections.connections.clone();
198 #[cfg(unix)]
199 let unix_connections = self.connections.unix_connections.clone();
200
201 let proxy_to = self
202 .upstreams
203 .drain(..)
204 .map(|(upstream, local_limit, keepalive_idle_timeout)| {
205 let is_unix_socket = match &upstream {
206 Upstream::Static(inner) => Some(inner.proxy_unix.is_some()),
207 Upstream::Srv(_) => Some(false), };
209 (
210 upstream,
211 is_unix_socket.and_then(|is_unix_socket| {
212 apply_local_limit(
213 local_limit,
214 is_unix_socket,
215 &connections,
216 #[cfg(unix)]
217 &unix_connections,
218 )
219 }),
220 keepalive_idle_timeout,
221 )
222 })
223 .collect::<Vec<ProxyToKey>>();
224
225 let proxy_to = Arc::new(proxy_to);
226 let load_balancer_algorithm = if let Some(algorithm) = self
227 .connections
228 .load_balancer_cache
229 .get(&(self.lb_algorithm, proxy_to.clone()))
230 {
231 algorithm.clone()
232 } else {
233 let new_algorithm = Arc::new(build_load_balancer_algorithm(self.lb_algorithm));
234 self
235 .connections
236 .load_balancer_cache
237 .insert((self.lb_algorithm, proxy_to.clone()), new_algorithm.clone());
238 new_algorithm
239 };
240 let failed_backends = if let Some(failed) = self.connections.failed_backend_cache.get(&(
241 self.lb_health_check_window,
242 self.lb_health_check_max_fails,
243 proxy_to.clone(),
244 )) {
245 failed.clone()
246 } else {
247 let new_failed = Arc::new(RwLock::new(TtlCache::new(self.lb_health_check_window)));
248 self.connections.failed_backend_cache.insert(
249 (
250 self.lb_health_check_window,
251 self.lb_health_check_max_fails,
252 proxy_to.clone(),
253 ),
254 new_failed.clone(),
255 );
256 new_failed
257 };
258 ReverseProxy {
259 failed_backends,
260 load_balancer_algorithm,
261 proxy_to,
262 health_check_max_fails: self.lb_health_check_max_fails,
263 enable_health_check: self.lb_health_check,
264 disable_certificate_verification: self.proxy_no_verification,
265 proxy_intercept_errors: self.proxy_intercept_errors,
266 retry_connection: self.lb_retry_connection,
267 proxy_http2_only: self.proxy_http2_only,
268 proxy_http2: self.proxy_http2,
269 proxy_keepalive: self.proxy_keepalive,
270 proxy_header: self.proxy_proxy_header,
271 headers_to_add: Arc::new(self.proxy_request_header.drain(..).collect()),
272 headers_to_replace: Arc::new(self.proxy_request_header_replace.drain(..).collect()),
273 headers_to_remove: Arc::new(self.proxy_request_header_remove.drain(..).collect()),
274 rewrite_host: self.rewrite_host,
275 connections,
276 #[cfg(unix)]
277 unix_connections,
278 }
279 }
280}
281
282fn build_load_balancer_algorithm(algorithm: LoadBalancerAlgorithm) -> LoadBalancerAlgorithmInner {
283 match algorithm {
284 LoadBalancerAlgorithm::TwoRandomChoices => {
285 LoadBalancerAlgorithmInner::TwoRandomChoices(Arc::new(RwLock::new(HashMap::new())))
286 }
287 LoadBalancerAlgorithm::LeastConnections => {
288 LoadBalancerAlgorithmInner::LeastConnections(Arc::new(RwLock::new(HashMap::new())))
289 }
290 LoadBalancerAlgorithm::RoundRobin => LoadBalancerAlgorithmInner::RoundRobin(Arc::new(AtomicUsize::new(0))),
291 LoadBalancerAlgorithm::Random => LoadBalancerAlgorithmInner::Random,
292 }
293}
294
295fn apply_local_limit(
296 local_limit: Option<usize>,
297 is_unix_socket: bool,
298 connections: &super::ConnectionPool,
299 #[cfg(unix)] unix_connections: &super::ConnectionPool,
300) -> Option<usize> {
301 #[allow(clippy::bind_instead_of_map)]
302 local_limit.and_then(|limit| {
303 if is_unix_socket {
304 #[cfg(unix)]
305 {
306 Some(unix_connections.set_local_limit(limit))
307 }
308 #[cfg(not(unix))]
309 {
310 None
311 }
312 } else {
313 Some(connections.set_local_limit(limit))
314 }
315 })
316}