1mod builder;
2mod load_balancer;
3mod proxy_client;
4mod request_parts;
5mod send_net_io;
6mod send_request;
7
8use std::collections::HashMap;
9use std::error::Error;
10use std::net::IpAddr;
11use std::pin::Pin;
12use std::str::FromStr;
13use std::sync::atomic::AtomicUsize;
14use std::sync::Arc;
15use std::task::{Context, Poll, Waker};
16use std::time::Duration;
17
18use async_trait::async_trait;
19use bytes::Bytes;
20use connpool::{Item, Pool};
21use futures_util::FutureExt;
22use http_body_util::combinators::BoxBody;
23use hyper::header::{self, HeaderName};
24use hyper::{Request, StatusCode, Uri};
25#[cfg(feature = "runtime-monoio")]
26use monoio::net::TcpStream;
27#[cfg(all(feature = "runtime-monoio", unix))]
28use monoio::net::UnixStream;
29use rustls::client::WebPkiServerVerifier;
30use rustls_pki_types::ServerName;
31use rustls_platform_verifier::BuilderVerifierExt;
32use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
33#[cfg(feature = "runtime-tokio")]
34use tokio::net::TcpStream;
35#[cfg(all(feature = "runtime-tokio", unix))]
36use tokio::net::UnixStream;
37use tokio::sync::RwLock;
38use tokio_rustls::TlsConnector;
39#[cfg(feature = "runtime-vibeio")]
40use vibeio::net::TcpStream;
41#[cfg(all(feature = "runtime-vibeio", unix))]
42use vibeio::net::UnixStream;
43
44use crate::config::ServerConfiguration;
45use crate::http_proxy::send_request::SendRequestWrapper;
46use crate::logging::ErrorLogger;
47use crate::modules::{ModuleHandlers, ResponseData, SocketData};
48use crate::observability::{Metric, MetricAttributeValue, MetricType, MetricValue, MetricsMultiSender};
49use crate::util::{NoServerVerifier, TtlCache};
50
51pub use self::builder::ReverseProxyBuilder;
52#[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
53use self::send_net_io::{SendTcpStreamPoll, SendTcpStreamPollDropGuard};
54#[cfg(all(any(feature = "runtime-vibeio", feature = "runtime-monoio"), unix))]
55use self::send_net_io::{SendUnixStreamPoll, SendUnixStreamPollDropGuard};
56use self::{
57 load_balancer::{determine_proxy_to, resolve_upstreams},
58 proxy_client::{http_proxy, http_proxy_handshake},
59 request_parts::construct_proxy_request_parts,
60};
61
/// Per-backend connection tracking shared by the connection-aware load balancers.
///
/// Each backend maps to an `Arc<()>` token. The request handler clones the backend's token for
/// every request it proxies and hands the clone to `http_proxy`, so the token's strong count
/// presumably serves as the backend's in-flight connection count (the reading side lives in
/// `load_balancer` — confirm there).
type ConnectionsTrackState = Arc<RwLock<HashMap<UpstreamInner, Arc<()>>>>;

/// Runtime state of the configured load balancing algorithm.
enum LoadBalancerAlgorithmInner {
    /// Random selection; this variant carries no state.
    Random,
    /// Round-robin selection; the shared atomic is presumably the rotating position.
    RoundRobin(Arc<AtomicUsize>),
    /// Least-connections selection, backed by the shared connection tracking map.
    LeastConnections(ConnectionsTrackState),
    /// Two-random-choices selection, backed by the shared connection tracking map.
    TwoRandomChoices(ConnectionsTrackState),
}
70
/// Load balancing algorithm used to pick a backend among the configured upstreams.
///
/// `Connections::get_builder` defaults to `TwoRandomChoices`. The selection logic itself lives
/// in `load_balancer::determine_proxy_to`; the variant descriptions below are inferred from the
/// names and should be confirmed there.
#[derive(Clone, Copy, Hash, PartialEq, Eq)]
pub enum LoadBalancerAlgorithm {
    /// Presumably picks a backend uniformly at random.
    Random,
    /// Presumably cycles through the backends in order.
    RoundRobin,
    /// Presumably prefers the backend with the fewest tracked connections.
    LeastConnections,
    /// Presumably samples two backends and prefers the one with fewer tracked connections.
    TwoRandomChoices,
}

/// PROXY protocol header version written to a new backend connection before any HTTP traffic.
#[derive(Clone, Copy)]
pub enum ProxyHeader {
    /// Version 1: a text line of the form `PROXY TCP4|TCP6 <src> <dst> <sport> <dport>\r\n`.
    V1,
    /// Version 2: the binary header, built with the `ppp` crate.
    V2,
}
92
/// A concrete backend the proxy can connect to.
///
/// Also used as the key for failure tracking, connection tracking and connection pooling.
#[derive(Clone, Eq, PartialEq, Hash)]
struct UpstreamInner {
    /// Backend URL; only the `http` and `https` schemes are accepted when proxying.
    proxy_to: String,
    /// Optional Unix socket path. When set, the connection is made over this socket (and pooled
    /// separately) instead of over TCP to the URL's host and port.
    proxy_unix: Option<String>,
}
98
99#[derive(Clone)]
100struct SrvUpstreamData {
101 to: String,
102 secondary_runtime_handle: tokio::runtime::Handle,
103 dns_resolver: Arc<hickory_resolver::TokioResolver>,
104}
105
106impl PartialEq for SrvUpstreamData {
107 fn eq(&self, other: &Self) -> bool {
108 self.to == other.to
109 }
110}
111
112impl Eq for SrvUpstreamData {}
113
114impl std::hash::Hash for SrvUpstreamData {
115 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
116 self.to.hash(state);
117 }
118}
119
/// A configured upstream, before it is resolved into concrete backends.
#[derive(Clone, Eq, PartialEq, Hash)]
enum Upstream {
    /// A single, fixed backend.
    Static(UpstreamInner),
    /// Backends discovered at request time through DNS SRV records (see `Upstream::resolve`).
    Srv(SrvUpstreamData),
}
125
126impl Upstream {
127 async fn resolve(
128 &self,
129 failed_backends: Arc<RwLock<TtlCache<UpstreamInner, u64>>>,
130 health_check_max_fails: u64,
131 ) -> Vec<UpstreamInner> {
132 match self {
133 Upstream::Static(inner) => vec![inner.clone()],
134 Upstream::Srv(srv_data) => {
135 let to = srv_data.to.clone();
136 let resolver = srv_data.dns_resolver.clone();
137 let failed_backends = failed_backends.clone();
138 srv_data
139 .secondary_runtime_handle
140 .spawn(async move {
141 let to_url = match Uri::from_str(&to) {
142 Ok(uri) => uri,
143 Err(_) => return vec![],
144 };
145 let to = match to_url.host() {
146 Some(host) => host.to_string(),
147 None => return vec![],
148 };
149
150 let srv_records = match resolver.srv_lookup(&to).await {
151 Ok(records) => records,
152 Err(_) => return vec![],
153 };
154
155 let failed_backends = failed_backends.read().await;
156 let srv_upstreams = srv_records
157 .into_iter()
158 .filter_map(|record| {
159 let mut to_url_parts = to_url.clone().into_parts();
160 to_url_parts.authority = Some(format!("{}:{}", record.target(), record.port()).parse().ok()?);
161 let upstream_inner = UpstreamInner {
162 proxy_to: Uri::from_parts(to_url_parts).ok()?.to_string(),
163 proxy_unix: None,
164 };
165 if failed_backends
166 .get(&upstream_inner)
167 .is_some_and(|fails| fails > health_check_max_fails)
168 {
169 None
171 } else {
172 Some((upstream_inner, record.weight(), record.priority()))
173 }
174 })
175 .collect::<Vec<_>>();
176 let highest_priority = srv_upstreams
177 .iter()
178 .map(|(_, _, priority)| *priority)
179 .min()
180 .unwrap_or(0);
181 let filtered_srv_upstreams = srv_upstreams
182 .into_iter()
183 .filter(|(_, _, priority)| *priority == highest_priority)
184 .map(|(upstream, weight, _)| (upstream, weight))
185 .collect::<Vec<_>>();
186 let cumulative_weight: u64 = filtered_srv_upstreams.iter().map(|(_, weight)| *weight as u64).sum();
187 let mut random_weight = if cumulative_weight == 0 {
188 0
190 } else {
191 rand::random_range(0..cumulative_weight)
192 };
193 for upstream in filtered_srv_upstreams {
194 let weight = upstream.1;
195 if random_weight <= weight as u64 {
196 return vec![upstream.0];
197 }
198 random_weight -= weight as u64;
199 }
200 vec![]
201 })
202 .await
203 .unwrap_or(vec![])
204 }
205 }
206 }
207}
208
/// A configured upstream with two per-upstream settings. The `Option<usize>` is passed to the
/// pool as `local_limit_index`; the `Option<Duration>` is passed as `keepalive_idle_timeout`.
type ProxyToKey = (Upstream, Option<usize>, Option<Duration>);
/// Resolved counterpart of [`ProxyToKey`]. The request handler destructures what
/// `determine_proxy_to` yields as `(upstream, local_limit_index, keepalive_idle_timeout)`.
type ProxyToKeyInner = (UpstreamInner, Option<usize>, Option<Duration>);

/// Pool of reusable backend connections.
///
/// The pool is keyed by backend and, when a PROXY protocol header is sent, by the client IP as
/// well. That header binds a connection to one client, so such connections must not be shared.
type ConnectionPool = Arc<Pool<(UpstreamInner, Option<IpAddr>), SendRequestWrapper>>;
/// An entry pulled from a [`ConnectionPool`].
type ConnectionPoolItem = Item<(UpstreamInner, Option<IpAddr>), SendRequestWrapper>;
214
/// Drop guard obtained from a completion-I/O stream wrapper (see `Connection::get_drop_guard`).
///
/// The wrapped guards are presumably held only for their `Drop` behavior, which would explain
/// the `#[allow(unused)]` (the guard types are defined in `send_net_io` — confirm there).
#[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
#[allow(unused)]
enum DropGuard {
    Tcp(SendTcpStreamPollDropGuard),
    #[cfg(unix)]
    Unix(SendUnixStreamPollDropGuard),
}

/// A newly opened backend connection, over TCP or (on Unix) a Unix domain socket.
///
/// With the `runtime-vibeio`/`runtime-monoio` features the runtime's streams are wrapped in
/// `SendTcpStreamPoll`/`SendUnixStreamPoll`. Otherwise the imported `TcpStream`/`UnixStream`
/// are stored directly.
enum Connection {
    #[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
    Tcp(SendTcpStreamPoll),
    #[cfg(not(any(feature = "runtime-vibeio", feature = "runtime-monoio")))]
    Tcp(TcpStream),
    #[cfg(all(any(feature = "runtime-vibeio", feature = "runtime-monoio"), unix))]
    Unix(SendUnixStreamPoll),
    #[cfg(all(not(any(feature = "runtime-vibeio", feature = "runtime-monoio")), unix))]
    Unix(UnixStream),
}
233
#[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
impl Connection {
    /// Obtains a [`DropGuard`] from the wrapped completion-I/O stream.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of `SendTcpStreamPoll::get_drop_guard` /
    /// `SendUnixStreamPoll::get_drop_guard`. Those are defined in `send_net_io` and are not
    /// visible here. NOTE(review): spell out the concrete invariant here once confirmed.
    unsafe fn get_drop_guard(&mut self) -> DropGuard {
        match self {
            Self::Tcp(tcp_stream) => DropGuard::Tcp(tcp_stream.get_drop_guard()),
            #[cfg(unix)]
            Self::Unix(unix_stream) => DropGuard::Unix(unix_stream.get_drop_guard()),
        }
    }
}
244
245impl AsyncRead for Connection {
246 fn poll_read(
247 mut self: Pin<&mut Self>,
248 cx: &mut Context<'_>,
249 buf: &mut tokio::io::ReadBuf,
250 ) -> Poll<Result<(), std::io::Error>> {
251 match &mut *self {
252 Connection::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
253 #[cfg(unix)]
254 Connection::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
255 }
256 }
257}
258
259impl AsyncWrite for Connection {
260 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
261 match &mut *self {
262 Connection::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
263 #[cfg(unix)]
264 Connection::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
265 }
266 }
267
268 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
269 match &mut *self {
270 Connection::Tcp(stream) => Pin::new(stream).poll_flush(cx),
271 #[cfg(unix)]
272 Connection::Unix(stream) => Pin::new(stream).poll_flush(cx),
273 }
274 }
275
276 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
277 match &mut *self {
278 Connection::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
279 #[cfg(unix)]
280 Connection::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
281 }
282 }
283
284 fn is_write_vectored(&self) -> bool {
285 match self {
286 Connection::Tcp(stream) => stream.is_write_vectored(),
287 #[cfg(unix)]
288 Connection::Unix(stream) => stream.is_write_vectored(),
289 }
290 }
291
292 fn poll_write_vectored(
293 mut self: Pin<&mut Self>,
294 cx: &mut Context<'_>,
295 bufs: &[std::io::IoSlice<'_>],
296 ) -> Poll<Result<usize, std::io::Error>> {
297 match &mut *self {
298 Connection::Tcp(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
299 #[cfg(unix)]
300 Connection::Unix(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
301 }
302 }
303}
304
305pub struct Connections {
307 #[allow(clippy::type_complexity)]
308 load_balancer_cache: HashMap<
309 (
310 LoadBalancerAlgorithm,
311 Arc<Vec<(Upstream, Option<usize>, Option<Duration>)>>,
312 ),
313 Arc<LoadBalancerAlgorithmInner>,
314 >,
315 #[allow(clippy::type_complexity)]
316 failed_backend_cache: HashMap<
317 (Duration, u64, Arc<Vec<(Upstream, Option<usize>, Option<Duration>)>>),
318 Arc<RwLock<TtlCache<UpstreamInner, u64>>>,
319 >,
320 connections: ConnectionPool,
321 #[cfg(unix)]
322 unix_connections: ConnectionPool,
323}
324
325impl Connections {
326 pub fn new() -> Self {
328 Self {
329 load_balancer_cache: HashMap::new(),
330 failed_backend_cache: HashMap::new(),
331 connections: Arc::new(Pool::new_unbounded()),
332 #[cfg(unix)]
333 unix_connections: Arc::new(Pool::new_unbounded()),
334 }
335 }
336
337 pub fn with_global_limit(global_limit: usize) -> Self {
341 Self {
342 load_balancer_cache: HashMap::new(),
343 failed_backend_cache: HashMap::new(),
344 connections: Arc::new(Pool::new(global_limit)),
345 #[cfg(unix)]
346 unix_connections: Arc::new(Pool::new_unbounded()),
347 }
348 }
349
350 pub fn get_builder<'a>(&'a mut self) -> ReverseProxyBuilder<'a> {
352 ReverseProxyBuilder {
353 connections: self,
354 upstreams: Vec::new(),
355 lb_algorithm: LoadBalancerAlgorithm::TwoRandomChoices,
356 lb_health_check_window: Duration::from_millis(5000),
357 lb_health_check_max_fails: 3,
358 lb_health_check: false,
359 proxy_no_verification: false,
360 proxy_intercept_errors: false,
361 lb_retry_connection: true,
362 proxy_http2_only: false,
363 proxy_http2: false,
364 proxy_keepalive: true,
365 proxy_proxy_header: None,
366 proxy_request_header: Vec::new(),
367 proxy_request_header_replace: Vec::new(),
368 proxy_request_header_remove: Vec::new(),
369 rewrite_host: false,
370 }
371 }
372}
373
374impl Default for Connections {
375 fn default() -> Self {
376 Self::new()
377 }
378}
379
/// A configured reverse proxy (presumably produced by `ReverseProxyBuilder`).
///
/// Per-request handlers are created with [`ReverseProxy::get_handler`]; they share the state
/// held here through `Arc`s.
pub struct ReverseProxy {
    /// Failure counts per backend, shared with every handler.
    #[allow(clippy::type_complexity)]
    failed_backends: Arc<RwLock<TtlCache<UpstreamInner, u64>>>,
    /// Load balancer state.
    load_balancer_algorithm: Arc<LoadBalancerAlgorithmInner>,
    /// Configured upstreams.
    proxy_to: Arc<Vec<ProxyToKey>>,
    /// Failure-count threshold used when filtering out unhealthy backends.
    health_check_max_fails: u64,
    /// Whether connection failures are recorded in `failed_backends`.
    enable_health_check: bool,
    /// Skip TLS certificate verification for `https` backends.
    disable_certificate_verification: bool,
    /// Passed through to `http_proxy`.
    proxy_intercept_errors: bool,
    /// Try another backend after a connection failure when candidates remain.
    retry_connection: bool,
    /// Speak only HTTP/2 to backends (`h2`-only ALPN over TLS).
    proxy_http2_only: bool,
    /// Offer HTTP/2 via ALPN to TLS backends (not for upgrade requests).
    proxy_http2: bool,
    /// Allow backend connections to be kept alive for reuse.
    proxy_keepalive: bool,
    /// PROXY protocol header to send on new backend connections, if any.
    proxy_header: Option<ProxyHeader>,
    /// Request headers to add (passed to `construct_proxy_request_parts`).
    headers_to_add: Arc<Vec<(HeaderName, String)>>,
    /// Request headers to replace (passed to `construct_proxy_request_parts`).
    headers_to_replace: Arc<Vec<(HeaderName, String)>>,
    /// Request headers to remove (passed to `construct_proxy_request_parts`).
    headers_to_remove: Arc<Vec<HeaderName>>,
    /// Passed to `construct_proxy_request_parts` (presumably rewrites the Host header).
    rewrite_host: bool,
    /// Shared pool of TCP backend connections.
    connections: ConnectionPool,
    /// Shared pool of Unix-socket backend connections.
    #[cfg(unix)]
    unix_connections: ConnectionPool,
}
403
404impl ReverseProxy {
405 pub fn get_handler(&self) -> ReverseProxyHandler {
407 ReverseProxyHandler {
408 failed_backends: self.failed_backends.clone(),
409 load_balancer_algorithm: self.load_balancer_algorithm.clone(),
410 proxy_to: self.proxy_to.clone(),
411 health_check_max_fails: self.health_check_max_fails,
412 selected_backends_metrics: None,
413 unhealthy_backends_metrics: None,
414 connection_reused: false,
415 enable_health_check: self.enable_health_check,
416 disable_certificate_verification: self.disable_certificate_verification,
417 proxy_intercept_errors: self.proxy_intercept_errors,
418 retry_connection: self.retry_connection,
419 proxy_http2_only: self.proxy_http2_only,
420 proxy_http2: self.proxy_http2,
421 proxy_keepalive: self.proxy_keepalive,
422 proxy_header: self.proxy_header,
423 headers_to_add: self.headers_to_add.clone(),
424 headers_to_replace: self.headers_to_replace.clone(),
425 headers_to_remove: self.headers_to_remove.clone(),
426 rewrite_host: self.rewrite_host,
427 connections: self.connections.clone(),
428 #[cfg(unix)]
429 unix_connections: self.unix_connections.clone(),
430 }
431 }
432}
433
/// Per-request reverse proxy handler created by [`ReverseProxy::get_handler`].
pub struct ReverseProxyHandler {
    /// Failure counts per backend, shared across handlers.
    #[allow(clippy::type_complexity)]
    failed_backends: Arc<RwLock<TtlCache<UpstreamInner, u64>>>,
    /// Load balancer state.
    load_balancer_algorithm: Arc<LoadBalancerAlgorithmInner>,
    /// Configured upstreams.
    proxy_to: Arc<Vec<ProxyToKey>>,
    /// Failure-count threshold used when filtering out unhealthy backends.
    health_check_max_fails: u64,
    /// Backends selected during this request. `Some` only while metrics collection is active
    /// (set by `metric_data_before_handler`).
    selected_backends_metrics: Option<Vec<UpstreamInner>>,
    /// Backends marked as failed during this request. `Some` only while metrics collection is
    /// active.
    unhealthy_backends_metrics: Option<Vec<UpstreamInner>>,
    /// Whether the request was sent over a pooled connection; reported in the
    /// `ferron.proxy.requests` metric.
    connection_reused: bool,
    /// Whether connection failures are recorded in `failed_backends`.
    enable_health_check: bool,
    /// Skip TLS certificate verification for `https` backends.
    disable_certificate_verification: bool,
    /// Passed through to `http_proxy`.
    proxy_intercept_errors: bool,
    /// Try another backend after a connection failure when candidates remain.
    retry_connection: bool,
    /// Speak only HTTP/2 to backends.
    proxy_http2_only: bool,
    /// Offer HTTP/2 via ALPN to TLS backends (not for upgrade requests).
    proxy_http2: bool,
    /// Allow backend connections to be kept alive for reuse.
    proxy_keepalive: bool,
    /// PROXY protocol header to send on new backend connections, if any.
    proxy_header: Option<ProxyHeader>,
    /// Request headers to add.
    headers_to_add: Arc<Vec<(HeaderName, String)>>,
    /// Request headers to replace.
    headers_to_replace: Arc<Vec<(HeaderName, String)>>,
    /// Request headers to remove.
    headers_to_remove: Arc<Vec<HeaderName>>,
    /// Passed to `construct_proxy_request_parts` (presumably rewrites the Host header).
    rewrite_host: bool,
    /// Shared pool of TCP backend connections.
    connections: ConnectionPool,
    /// Shared pool of Unix-socket backend connections.
    #[cfg(unix)]
    unix_connections: ConnectionPool,
}
460
461impl ReverseProxyHandler {
462 #[inline]
463 fn status_response(status_code: StatusCode) -> ResponseData {
464 ResponseData {
465 request: None,
466 response: None,
467 response_status: Some(status_code),
468 response_headers: None,
469 new_remote_address: None,
470 }
471 }
472
473 async fn mark_backend_failure(&mut self, upstream: &UpstreamInner) {
474 if !self.enable_health_check {
475 return;
476 }
477 if let Some(unhealthy_backends_metrics) = self.unhealthy_backends_metrics.as_mut() {
478 unhealthy_backends_metrics.push(upstream.clone());
479 }
480 let mut failed_backends_write = self.failed_backends.write().await;
481 let failed_attempts = failed_backends_write.get(upstream);
482 failed_backends_write.insert(upstream.clone(), failed_attempts.map_or(1, |x| x + 1));
483 }
484
485 async fn retry_or_respond(
486 &self,
487 error_logger: &ErrorLogger,
488 err: &dyn std::fmt::Display,
489 retry_connection: bool,
490 has_more_backends: bool,
491 status_code: StatusCode,
492 log_prefix: &str,
493 ) -> Option<ResponseData> {
494 if retry_connection && has_more_backends {
495 error_logger
496 .log(&format!("Failed to connect to backend, trying another backend: {err}"))
497 .await;
498 None
499 } else {
500 error_logger.log(&format!("{log_prefix}: {err}")).await;
501 Some(Self::status_response(status_code))
502 }
503 }
504
505 #[inline]
506 fn io_error_status(err: &std::io::Error) -> (StatusCode, &'static str) {
507 match err.kind() {
508 std::io::ErrorKind::ConnectionRefused | std::io::ErrorKind::NotFound | std::io::ErrorKind::HostUnreachable => {
509 (StatusCode::SERVICE_UNAVAILABLE, "Service unavailable")
510 }
511 std::io::ErrorKind::TimedOut => (StatusCode::GATEWAY_TIMEOUT, "Gateway timeout"),
512 _ => (StatusCode::BAD_GATEWAY, "Bad gateway"),
513 }
514 }
515}
516
517#[async_trait(?Send)]
518impl ModuleHandlers for ReverseProxyHandler {
519 async fn request_handler(
536 &mut self,
537 request: Request<BoxBody<Bytes, std::io::Error>>,
538 config: &ServerConfiguration,
539 socket_data: &SocketData,
540 error_logger: &ErrorLogger,
541 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
542 let enable_health_check = self.enable_health_check;
543 let health_check_max_fails = self.health_check_max_fails;
544 let disable_certificate_verification = self.disable_certificate_verification;
545 let proxy_intercept_errors = self.proxy_intercept_errors;
546 if self.proxy_to.is_empty() {
547 return Ok(ResponseData {
549 request: Some(request),
550 response: None,
551 response_status: None,
552 response_headers: None,
553 new_remote_address: None,
554 });
555 }
556 let mut proxy_to_vector = resolve_upstreams(
557 &self.proxy_to,
558 self.failed_backends.clone(),
559 self.health_check_max_fails,
560 )
561 .await;
562 let load_balancer_algorithm = self.load_balancer_algorithm.clone();
563 let connection_track = match &*load_balancer_algorithm {
564 LoadBalancerAlgorithmInner::LeastConnections(connection_track) => Some(connection_track),
565 LoadBalancerAlgorithmInner::TwoRandomChoices(connection_track) => Some(connection_track),
566 _ => None,
567 };
568 let retry_connection = self.retry_connection;
569 let (request_parts, request_body) = request.into_parts();
570 let mut request_parts = Some(request_parts);
571
572 loop {
573 if let Some((upstream, local_limit_index, keepalive_idle_timeout)) = determine_proxy_to(
574 &mut proxy_to_vector,
575 &self.failed_backends,
576 enable_health_check,
577 health_check_max_fails,
578 &load_balancer_algorithm,
579 )
580 .await
581 {
582 if let Some(selected_backends_metrics) = self.selected_backends_metrics.as_mut() {
583 selected_backends_metrics.push(upstream.clone());
584 }
585 let UpstreamInner { proxy_to, proxy_unix } = &upstream;
586 let proxy_request_url = proxy_to.parse::<hyper::Uri>()?;
587 let scheme_str = proxy_request_url.scheme_str();
588 let mut encrypted = false;
589
590 match scheme_str {
591 Some("http") => {
592 encrypted = false;
593 }
594 Some("https") => {
595 encrypted = true;
596 }
597 _ => Err(anyhow::anyhow!("Only HTTP and HTTPS reverse proxy URLs are supported."))?,
598 };
599
600 let host = match proxy_request_url.host() {
601 Some(host) => host,
602 None => Err(anyhow::anyhow!("The reverse proxy URL doesn't include the host"))?,
603 };
604
605 let port = proxy_request_url.port_u16().unwrap_or(match scheme_str {
606 Some("http") => 80,
607 Some("https") => 443,
608 _ => 80,
609 });
610
611 let addr = format!("{host}:{port}");
612
613 let request_parts_option = if proxy_to_vector.is_empty() {
614 request_parts.take()
615 } else {
616 request_parts.clone()
617 };
618 let request_parts = request_parts_option.ok_or(anyhow::anyhow!("Request parts not found"))?;
619 let proxy_request_parts = construct_proxy_request_parts(
620 request_parts,
621 config,
622 socket_data,
623 &proxy_request_url,
624 &self.headers_to_add,
625 &self.headers_to_replace,
626 &self.headers_to_remove,
627 self.rewrite_host,
628 )?;
629
630 let tracked_connection = if let Some(connection_track) = connection_track {
631 let connection_track_read = connection_track.read().await;
632 Some(if let Some(connection_count) = connection_track_read.get(&upstream) {
633 connection_count.clone()
634 } else {
635 let tracked_connection = Arc::new(());
636 drop(connection_track_read);
637 connection_track
638 .write()
639 .await
640 .insert(upstream.clone(), tracked_connection.clone());
641 tracked_connection
642 })
643 } else {
644 None
645 };
646
647 let proxy_header = self.proxy_header;
648
649 let is_http_upgrade = proxy_request_parts.headers.contains_key(header::UPGRADE);
650 let enable_http2_only_config = self.proxy_http2_only;
651 let enable_http2_config = self.proxy_http2;
652
653 let enable_keepalive =
654 (enable_http2_only_config || !enable_http2_config || !is_http_upgrade) && self.proxy_keepalive;
655 let connection_pool_item = {
656 #[cfg(unix)]
657 let connections = if proxy_unix.is_some() {
658 &self.unix_connections
659 } else {
660 &self.connections
661 };
662 #[cfg(not(unix))]
663 let connections = &self.connections;
664 let sender;
665 let mut send_request_items = Vec::new();
666 let proxy_client_ip = match proxy_header {
667 Some(ProxyHeader::V1) | Some(ProxyHeader::V2) => Some(socket_data.remote_addr.ip().to_canonical()),
668 _ => None,
669 };
670 loop {
671 let mut send_request_item = if send_request_items.is_empty() {
672 connections
673 .pull_with_wait_local_limit((upstream.clone(), proxy_client_ip), local_limit_index)
674 .await
675 } else if let Poll::Ready(send_request_item_option) = connections
676 .pull_with_wait_local_limit((upstream.clone(), proxy_client_ip), local_limit_index)
677 .boxed_local()
678 .poll_unpin(&mut Context::from_waker(Waker::noop()))
679 {
680 send_request_item_option
681 } else {
682 let send_request_items_taken = send_request_items;
683 send_request_items = Vec::new();
684 let fetch_nonready_send_request_fut = async {
685 let result = futures_util::future::select_ok(send_request_items_taken).await;
686 if let Ok((item, send_request_items_smaller)) = result {
687 send_request_items = send_request_items_smaller;
688 item
689 } else {
690 futures_util::future::pending().await
691 }
692 };
693 crate::runtime::select! {
694 item = connections
695 .pull_with_wait_local_limit((upstream.clone(), proxy_client_ip), local_limit_index)
696 => {
697 item
698 },
699 item = fetch_nonready_send_request_fut => {
700 item
701 }
702 }
703 };
704 if let Some(send_request) = send_request_item.inner_mut() {
705 match send_request.get(keepalive_idle_timeout) {
706 (Some(send_request), true) => {
707 send_request_items.clear();
709 self.connection_reused = true;
710 let _ = send_request_item.inner_mut().take();
711 let proxy_request = Request::from_parts(proxy_request_parts, request_body);
712 let result = http_proxy(
713 send_request,
714 send_request_item,
715 proxy_request,
716 error_logger,
717 proxy_intercept_errors,
718 tracked_connection,
719 true,
720 )
721 .await;
722 return result;
723 }
724 (None, true) => {
725 send_request_items.push(Box::pin(async move {
727 let inner_item = send_request_item.inner_mut();
728 if let Some(inner_item_2) = inner_item {
729 if !inner_item_2.wait_ready(keepalive_idle_timeout).await {
730 inner_item.take();
732 return Err(());
733 }
734 let _ = inner_item;
735 Ok(send_request_item)
736 } else {
737 Err(())
738 }
739 }));
740 continue;
741 }
742 (_, false) => {
743 let _ = send_request_item.inner_mut().take();
745 continue;
746 }
747 }
748 }
749 send_request_items.clear();
750 sender = send_request_item;
751 break;
752 }
753 sender
754 };
755
756 let stream = if let Some(proxy_unix_str) = &proxy_unix {
757 #[cfg(not(unix))]
758 {
759 let _ = proxy_unix_str; Err(anyhow::anyhow!("Unix sockets are not supported on this platform"))?
761 }
762
763 #[cfg(unix)]
764 {
765 let stream = match UnixStream::connect(proxy_unix_str).await {
766 Ok(stream) => stream,
767 Err(err) => {
768 self.mark_backend_failure(&upstream).await;
769 let (status_code, log_prefix) = Self::io_error_status(&err);
770 if let Some(response) = self
771 .retry_or_respond(
772 error_logger,
773 &err,
774 retry_connection,
775 !proxy_to_vector.is_empty(),
776 status_code,
777 log_prefix,
778 )
779 .await
780 {
781 return Ok(response);
782 }
783 continue;
784 }
785 };
786
787 #[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
788 let stream = match SendUnixStreamPoll::new_comp_io(stream) {
789 Ok(stream) => stream,
790 Err(err) => {
791 self.mark_backend_failure(&upstream).await;
792 if let Some(response) = self
793 .retry_or_respond(
794 error_logger,
795 &err,
796 retry_connection,
797 !proxy_to_vector.is_empty(),
798 StatusCode::BAD_GATEWAY,
799 "Bad gateway",
800 )
801 .await
802 {
803 return Ok(response);
804 }
805 continue;
806 }
807 };
808
809 Connection::Unix(stream)
810 }
811 } else {
812 let stream = match TcpStream::connect(&addr).await {
813 Ok(stream) => stream,
814 Err(err) => {
815 self.mark_backend_failure(&upstream).await;
816 let (status_code, log_prefix) = Self::io_error_status(&err);
817 if let Some(response) = self
818 .retry_or_respond(
819 error_logger,
820 &err,
821 retry_connection,
822 !proxy_to_vector.is_empty(),
823 status_code,
824 log_prefix,
825 )
826 .await
827 {
828 return Ok(response);
829 }
830 continue;
831 }
832 };
833
834 if let Err(err) = stream.set_nodelay(true) {
835 self.mark_backend_failure(&upstream).await;
836 if let Some(response) = self
837 .retry_or_respond(
838 error_logger,
839 &err,
840 retry_connection,
841 !proxy_to_vector.is_empty(),
842 StatusCode::BAD_GATEWAY,
843 "Bad gateway",
844 )
845 .await
846 {
847 return Ok(response);
848 }
849 continue;
850 };
851
852 #[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
853 let stream = match SendTcpStreamPoll::new_comp_io(stream) {
854 Ok(stream) => stream,
855 Err(err) => {
856 self.mark_backend_failure(&upstream).await;
857 if let Some(response) = self
858 .retry_or_respond(
859 error_logger,
860 &err,
861 retry_connection,
862 !proxy_to_vector.is_empty(),
863 StatusCode::BAD_GATEWAY,
864 "Bad gateway",
865 )
866 .await
867 {
868 return Ok(response);
869 }
870 continue;
871 }
872 };
873
874 Connection::Tcp(stream)
875 };
876
877 let proxy_header_to_write = match proxy_header {
878 Some(ProxyHeader::V1) => {
879 let is_ipv4 = socket_data.local_addr.ip().to_canonical().is_ipv4()
880 && socket_data.remote_addr.ip().to_canonical().is_ipv4();
881 let local_addr = if is_ipv4 {
882 match socket_data.local_addr.ip().to_canonical() {
883 IpAddr::V4(ip) => ip.to_string(),
884 IpAddr::V6(ip) => ip
885 .to_ipv4_mapped()
886 .ok_or(anyhow::anyhow!("Connection IP address type mismatch"))?
887 .to_string(),
888 }
889 } else {
890 match socket_data.local_addr.ip().to_canonical() {
891 IpAddr::V4(ip) => ip
892 .to_ipv6_mapped()
893 .segments()
894 .iter()
895 .map(|seg| format!("{:04x}", seg))
896 .collect::<Vec<_>>()
897 .join(":"),
898 IpAddr::V6(ip) => ip
899 .segments()
900 .iter()
901 .map(|seg| format!("{:04x}", seg))
902 .collect::<Vec<_>>()
903 .join(":"),
904 }
905 };
906 let remote_addr = if is_ipv4 {
907 match socket_data.remote_addr.ip().to_canonical() {
908 IpAddr::V4(ip) => ip.to_string(),
909 IpAddr::V6(ip) => ip
910 .to_ipv4_mapped()
911 .ok_or(anyhow::anyhow!("Connection IP address type mismatch"))?
912 .to_string(),
913 }
914 } else {
915 match socket_data.remote_addr.ip().to_canonical() {
916 IpAddr::V4(ip) => ip
917 .to_ipv6_mapped()
918 .segments()
919 .iter()
920 .map(|seg| format!("{:04x}", seg))
921 .collect::<Vec<_>>()
922 .join(":"),
923 IpAddr::V6(ip) => ip
924 .segments()
925 .iter()
926 .map(|seg| format!("{:04x}", seg))
927 .collect::<Vec<_>>()
928 .join(":"),
929 }
930 };
931 let local_port = socket_data.local_addr.port();
932 let remote_port = socket_data.remote_addr.port();
933 let header = format!(
934 "PROXY {} {} {} {} {}\r\n",
935 if is_ipv4 { "TCP4" } else { "TCP6" },
936 remote_addr,
937 local_addr,
938 remote_port,
939 local_port,
940 );
941 Some(header.into_bytes())
942 }
943 Some(ProxyHeader::V2) => {
944 let is_ipv4 = socket_data.local_addr.ip().to_canonical().is_ipv4()
945 && socket_data.remote_addr.ip().to_canonical().is_ipv4();
946 let addresses = if is_ipv4 {
947 ppp::v2::Addresses::IPv4(ppp::v2::IPv4::new(
948 match socket_data.remote_addr.ip().to_canonical() {
949 IpAddr::V4(ip) => ip,
950 IpAddr::V6(ip) => ip
951 .to_ipv4_mapped()
952 .ok_or(anyhow::anyhow!("Connection IP address type mismatch"))?,
953 },
954 match socket_data.local_addr.ip().to_canonical() {
955 IpAddr::V4(ip) => ip,
956 IpAddr::V6(ip) => ip
957 .to_ipv4_mapped()
958 .ok_or(anyhow::anyhow!("Connection IP address type mismatch"))?,
959 },
960 socket_data.remote_addr.port(),
961 socket_data.local_addr.port(),
962 ))
963 } else {
964 ppp::v2::Addresses::IPv6(ppp::v2::IPv6::new(
965 match socket_data.remote_addr.ip().to_canonical() {
966 IpAddr::V4(ip) => ip.to_ipv6_mapped(),
967 IpAddr::V6(ip) => ip,
968 },
969 match socket_data.local_addr.ip().to_canonical() {
970 IpAddr::V4(ip) => ip.to_ipv6_mapped(),
971 IpAddr::V6(ip) => ip,
972 },
973 socket_data.remote_addr.port(),
974 socket_data.local_addr.port(),
975 ))
976 };
977 let header_builder = ppp::v2::Builder::with_addresses(
978 ppp::v2::Version::Two | ppp::v2::Command::Proxy,
979 ppp::v2::Protocol::Stream,
980 addresses,
981 );
982 Some(header_builder.build()?)
983 }
984 _ => None,
985 };
986
987 let mut stream = stream; if let Some(proxy_header_to_write) = proxy_header_to_write {
990 if let Err(err) = stream.write_all(&proxy_header_to_write).await {
991 self.mark_backend_failure(&upstream).await;
992 if let Some(response) = self
993 .retry_or_respond(
994 error_logger,
995 &err,
996 retry_connection,
997 !proxy_to_vector.is_empty(),
998 StatusCode::BAD_GATEWAY,
999 "Bad gateway",
1000 )
1001 .await
1002 {
1003 return Ok(response);
1004 }
1005 continue;
1006 }
1007 }
1008
1009 #[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
1013 let drop_guard = unsafe { stream.get_drop_guard() };
1014
1015 let sender = if !encrypted {
1016 let sender = match http_proxy_handshake(
1017 stream,
1018 enable_http2_only_config,
1019 #[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
1020 drop_guard,
1021 )
1022 .await
1023 {
1024 Ok(sender) => sender,
1025 Err(err) => {
1026 self.mark_backend_failure(&upstream).await;
1027 if let Some(response) = self
1028 .retry_or_respond(
1029 error_logger,
1030 &err,
1031 retry_connection,
1032 !proxy_to_vector.is_empty(),
1033 StatusCode::BAD_GATEWAY,
1034 "Bad gateway",
1035 )
1036 .await
1037 {
1038 return Ok(response);
1039 }
1040 continue;
1041 }
1042 };
1043
1044 sender
1045 } else {
1046 let enable_http2_config = enable_http2_only_config || (enable_http2_config && !is_http_upgrade);
1047 let mut tls_client_config = (if disable_certificate_verification {
1048 rustls::ClientConfig::builder()
1049 .dangerous()
1050 .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new()))
1051 } else if let Ok(client_config) = BuilderVerifierExt::with_platform_verifier(rustls::ClientConfig::builder())
1052 {
1053 client_config
1054 } else {
1055 rustls::ClientConfig::builder().with_webpki_verifier(
1056 WebPkiServerVerifier::builder(Arc::new(rustls::RootCertStore {
1057 roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
1058 }))
1059 .build()?,
1060 )
1061 })
1062 .with_no_client_auth();
1063 if enable_http2_only_config {
1064 tls_client_config.alpn_protocols = vec![b"h2".to_vec()];
1065 } else if enable_http2_config {
1066 tls_client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()];
1067 } else {
1068 tls_client_config.alpn_protocols = vec![b"http/1.1".to_vec(), b"http/1.0".to_vec()];
1069 }
1070 let connector = TlsConnector::from(Arc::new(tls_client_config));
1071 let domain = ServerName::try_from(host)?.to_owned();
1072
1073 let tls_stream = match connector.connect(domain, stream).await {
1074 Ok(stream) => stream,
1075 Err(err) => {
1076 self.mark_backend_failure(&upstream).await;
1077 if let Some(response) = self
1078 .retry_or_respond(
1079 error_logger,
1080 &err,
1081 retry_connection,
1082 !proxy_to_vector.is_empty(),
1083 StatusCode::BAD_GATEWAY,
1084 "Bad gateway",
1085 )
1086 .await
1087 {
1088 return Ok(response);
1089 }
1090 continue;
1091 }
1092 };
1093
1094 let enable_http2 = enable_http2_config && tls_stream.get_ref().1.alpn_protocol() == Some(b"h2");
1096
1097 let sender = match http_proxy_handshake(
1098 tls_stream,
1099 enable_http2,
1100 #[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
1101 drop_guard,
1102 )
1103 .await
1104 {
1105 Ok(sender) => sender,
1106 Err(err) => {
1107 self.mark_backend_failure(&upstream).await;
1108 if let Some(response) = self
1109 .retry_or_respond(
1110 error_logger,
1111 &err,
1112 retry_connection,
1113 !proxy_to_vector.is_empty(),
1114 StatusCode::BAD_GATEWAY,
1115 "Bad gateway",
1116 )
1117 .await
1118 {
1119 return Ok(response);
1120 }
1121 continue;
1122 }
1123 };
1124
1125 sender
1126 };
1127
1128 let proxy_request = Request::from_parts(proxy_request_parts, request_body);
1129
1130 return http_proxy(
1131 sender,
1132 connection_pool_item,
1133 proxy_request,
1134 error_logger,
1135 proxy_intercept_errors,
1136 tracked_connection,
1137 enable_keepalive,
1138 )
1139 .await;
1140 } else {
1141 let request_parts = request_parts.ok_or(anyhow::anyhow!("Request parts are missing"))?;
1142 error_logger.log("No upstreams available").await;
1143 return Ok(ResponseData {
1144 request: Some(Request::from_parts(request_parts, request_body)),
1145 response: None,
1146 response_status: Some(StatusCode::SERVICE_UNAVAILABLE), response_headers: None,
1148 new_remote_address: None,
1149 });
1150 }
1151 }
1152 }
1153
1154 async fn metric_data_before_handler(
1155 &mut self,
1156 _request: &Request<BoxBody<Bytes, std::io::Error>>,
1157 _socket_data: &SocketData,
1158 _metrics_sender: &MetricsMultiSender,
1159 ) {
1160 self.selected_backends_metrics = Some(Vec::new());
1161 self.unhealthy_backends_metrics = Some(Vec::new());
1162 }
1163
1164 async fn metric_data_after_handler(&mut self, metrics_sender: &MetricsMultiSender) {
1165 if let Some(selected_backends_metrics) = self.selected_backends_metrics.take() {
1166 for selected_backend in selected_backends_metrics {
1167 let mut attributes = Vec::new();
1168 attributes.push((
1169 "ferron.proxy.backend_url",
1170 MetricAttributeValue::String(selected_backend.proxy_to),
1171 ));
1172 if let Some(backend_unix) = selected_backend.proxy_unix {
1173 attributes.push((
1174 "ferron.proxy.backend_unix_path",
1175 MetricAttributeValue::String(backend_unix),
1176 ));
1177 }
1178 metrics_sender
1179 .send(Metric::new(
1180 "ferron.proxy.backends.selected",
1181 attributes,
1182 MetricType::Counter,
1183 MetricValue::U64(1),
1184 Some("{backend}"),
1185 Some("Number of times a backend server was selected."),
1186 ))
1187 .await;
1188 }
1189 }
1190 if let Some(unhealthy_backends_metrics) = self.unhealthy_backends_metrics.take() {
1191 for unhealthy_backend in unhealthy_backends_metrics {
1192 let mut attributes = Vec::new();
1193 attributes.push((
1194 "ferron.proxy.backend_url",
1195 MetricAttributeValue::String(unhealthy_backend.proxy_to),
1196 ));
1197 if let Some(backend_unix) = unhealthy_backend.proxy_unix {
1198 attributes.push((
1199 "ferron.proxy.backend_unix_path",
1200 MetricAttributeValue::String(backend_unix),
1201 ));
1202 }
1203 metrics_sender
1204 .send(Metric::new(
1205 "ferron.proxy.backends.unhealthy",
1206 attributes,
1207 MetricType::Counter,
1208 MetricValue::U64(1),
1209 Some("{backend}"),
1210 Some("Number of health check failures for a backend server."),
1211 ))
1212 .await;
1213 }
1214 }
1215 metrics_sender
1216 .send(Metric::new(
1217 "ferron.proxy.requests",
1218 vec![(
1219 "ferron.proxy.connection_reused",
1220 MetricAttributeValue::Bool(self.connection_reused),
1221 )],
1222 MetricType::Counter,
1223 MetricValue::U64(1),
1224 Some("{request}"),
1225 Some("Number of reverse proxy requests."),
1226 ))
1227 .await;
1228 }
1229}