ferron_common/http_proxy/proxy_client.rs

1use std::cell::UnsafeCell;
2use std::error::Error;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
7use bytes::Bytes;
8use http_body_util::combinators::BoxBody;
9use http_body_util::BodyExt;
10use hyper::body::Body;
11use hyper::{Request, Response, StatusCode};
12#[cfg(feature = "runtime-tokio")]
13use hyper_util::rt::{TokioExecutor, TokioIo};
14#[cfg(feature = "runtime-monoio")]
15use monoio_compat::hyper::{MonoioExecutor, MonoioIo};
16use tokio::io::{AsyncRead, AsyncWrite};
17#[cfg(feature = "runtime-vibeio")]
18use vibeio_hyper::{VibeioExecutor, VibeioIo};
19
20use super::ConnectionPoolItem;
21#[cfg(any(feature = "runtime-monoio", feature = "runtime-vibeio"))]
22use super::DropGuard;
23use crate::http_proxy::send_request::{SendRequest, SendRequestWrapper};
24use crate::logging::ErrorLogger;
25use crate::modules::ResponseData;
26
27/// A tracked response body.
28struct TrackedBody<B> {
29  inner: B,
30  _tracker: Option<Arc<()>>,
31  _tracker_pool: Option<Arc<UnsafeCell<ConnectionPoolItem>>>,
32}
33
34impl<B> TrackedBody<B> {
35  fn new(inner: B, tracker: Option<Arc<()>>, tracker_pool: Option<Arc<UnsafeCell<ConnectionPoolItem>>>) -> Self {
36    Self {
37      inner,
38      _tracker: tracker,
39      _tracker_pool: tracker_pool,
40    }
41  }
42}
43
44impl<B> Body for TrackedBody<B>
45where
46  B: Body + Unpin,
47{
48  type Data = B::Data;
49  type Error = B::Error;
50
51  #[inline]
52  fn poll_frame(
53    mut self: Pin<&mut Self>,
54    cx: &mut Context<'_>,
55  ) -> Poll<Option<Result<hyper::body::Frame<Self::Data>, Self::Error>>> {
56    Pin::new(&mut self.inner).poll_frame(cx)
57  }
58
59  #[inline]
60  fn is_end_stream(&self) -> bool {
61    self.inner.is_end_stream()
62  }
63
64  #[inline]
65  fn size_hint(&self) -> hyper::body::SizeHint {
66    self.inner.size_hint()
67  }
68}
69
70// Safety: after construction, the value inside `UnsafeCell` is never mutated.
71// All accesses after sharing are read-only, so sharing across threads is safe.
72unsafe impl<B> Send for TrackedBody<B> where B: Send {}
73unsafe impl<B> Sync for TrackedBody<B> where B: Sync {}
74
75/// Establishes a new HTTP connection to a backend server.
76pub(super) async fn http_proxy_handshake(
77  stream: impl AsyncRead + AsyncWrite + Send + Unpin + 'static,
78  use_http2: bool,
79  #[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))] drop_guard: DropGuard,
80) -> Result<SendRequest, Box<dyn Error + Send + Sync>> {
81  #[cfg(feature = "runtime-vibeio")]
82  let io = VibeioIo::new(stream);
83  #[cfg(feature = "runtime-monoio")]
84  let io = MonoioIo::new(stream);
85  #[cfg(feature = "runtime-tokio")]
86  let io = TokioIo::new(stream);
87
88  Ok(if use_http2 {
89    #[cfg(feature = "runtime-vibeio")]
90    let executor = VibeioExecutor;
91    #[cfg(feature = "runtime-monoio")]
92    let executor = MonoioExecutor;
93    #[cfg(feature = "runtime-tokio")]
94    let executor = TokioExecutor::new();
95
96    let (sender, conn) = hyper::client::conn::http2::handshake(executor, io).await?;
97
98    crate::runtime::spawn(async move {
99      conn.await.unwrap_or_default();
100      #[cfg(feature = "runtime-monoio")]
101      drop(drop_guard);
102    });
103
104    SendRequest::Http2(sender)
105  } else {
106    let (sender, conn) = hyper::client::conn::http1::handshake(io).await?;
107
108    let conn_with_upgrades = conn.with_upgrades();
109    crate::runtime::spawn(async move {
110      conn_with_upgrades.await.unwrap_or_default();
111      #[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
112      drop(drop_guard);
113    });
114
115    SendRequest::Http1(sender)
116  })
117}
118
119/// Forwards an HTTP request to a backend server.
120pub(super) async fn http_proxy(
121  mut sender: SendRequest,
122  connection_pool_item: ConnectionPoolItem,
123  proxy_request: Request<BoxBody<Bytes, std::io::Error>>,
124  error_logger: &ErrorLogger,
125  proxy_intercept_errors: bool,
126  tracked_connection: Option<Arc<()>>,
127  enable_keepalive: bool,
128) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
129  let (proxy_request_parts, proxy_request_body) = proxy_request.into_parts();
130  #[cfg(feature = "runtime-vibeio")]
131  let mut proxy_request_cloned = Request::from_parts(
132    proxy_request_parts.clone(),
133    http_body_util::Empty::<bytes::Bytes>::new(),
134  );
135  #[cfg(not(feature = "runtime-vibeio"))]
136  let proxy_request_cloned = Request::from_parts(proxy_request_parts.clone(), ());
137  let proxy_request = Request::from_parts(proxy_request_parts, proxy_request_body);
138
139  let send_request_result = sender.send_request(proxy_request).await;
140  #[allow(clippy::arc_with_non_send_sync)]
141  let connection_pool_item = Arc::new(UnsafeCell::new(connection_pool_item));
142
143  let proxy_response = match send_request_result {
144    Ok(response) => response,
145    Err(err) => {
146      error_logger.log(&format!("Bad gateway: {err}")).await;
147      return Ok(ResponseData {
148        request: None,
149        response: None,
150        response_status: Some(StatusCode::BAD_GATEWAY),
151        response_headers: None,
152        new_remote_address: None,
153      });
154    }
155  };
156
157  let status_code = proxy_response.status();
158
159  let (proxy_response_parts, proxy_response_body) = proxy_response.into_parts();
160  if proxy_response_parts.status == StatusCode::SWITCHING_PROTOCOLS {
161    let proxy_response_cloned = Response::from_parts(proxy_response_parts.clone(), ());
162    match hyper::upgrade::on(proxy_response_cloned).await {
163      Ok(upgraded_backend) => {
164        let error_logger = error_logger.clone();
165        let connection_pool_item = connection_pool_item.clone();
166        #[cfg(feature = "runtime-vibeio")]
167        let upgrade_on = vibeio_http::prepare_upgrade(&mut proxy_request_cloned);
168        crate::runtime::spawn(async move {
169          #[cfg(feature = "runtime-vibeio")]
170          let upgrade_on = (if let Some(upgraded_request) = upgrade_on {
171            upgraded_request.await
172          } else {
173            None
174          })
175          .ok_or(std::io::Error::other("vibeio HTTP upgrade failure"));
176          #[cfg(not(feature = "runtime-vibeio"))]
177          let upgrade_on = hyper::upgrade::on(proxy_request_cloned).await;
178          match upgrade_on {
179            Ok(upgraded_proxy) => {
180              #[cfg(feature = "runtime-vibeio")]
181              let mut upgraded_backend = VibeioIo::new(upgraded_backend);
182              #[cfg(feature = "runtime-monoio")]
183              let mut upgraded_backend = MonoioIo::new(upgraded_backend);
184              #[cfg(feature = "runtime-tokio")]
185              let mut upgraded_backend = TokioIo::new(upgraded_backend);
186
187              #[cfg(feature = "runtime-vibeio")]
188              let mut upgraded_proxy = upgraded_proxy;
189              #[cfg(feature = "runtime-monoio")]
190              let mut upgraded_proxy = MonoioIo::new(upgraded_proxy);
191              #[cfg(feature = "runtime-tokio")]
192              let mut upgraded_proxy = TokioIo::new(upgraded_proxy);
193
194              crate::runtime::spawn(async move {
195                tokio::io::copy_bidirectional(&mut upgraded_backend, &mut upgraded_proxy)
196                  .await
197                  .unwrap_or_default();
198                drop(connection_pool_item);
199              });
200            }
201            Err(err) => {
202              error_logger.log(&format!("HTTP upgrade error: {err}")).await;
203            }
204          }
205        });
206      }
207      Err(err) => {
208        error_logger.log(&format!("HTTP upgrade error: {err}")).await;
209      }
210    }
211  }
212  let proxy_response = Response::from_parts(proxy_response_parts, proxy_response_body);
213
214  let response = if proxy_intercept_errors && status_code.as_u16() >= 400 {
215    ResponseData {
216      request: None,
217      response: None,
218      response_status: Some(status_code),
219      response_headers: None,
220      new_remote_address: None,
221    }
222  } else {
223    let (response_parts, response_body) = proxy_response.into_parts();
224    let boxed_body = TrackedBody::new(
225      response_body.map_err(|e| std::io::Error::other(e.to_string())),
226      tracked_connection,
227      if enable_keepalive && !sender.is_closed() {
228        None
229      } else {
230        Some(connection_pool_item.clone())
231      },
232    )
233    .boxed();
234    ResponseData {
235      request: None,
236      response: Some(Response::from_parts(response_parts, boxed_body)),
237      response_status: None,
238      response_headers: None,
239      new_remote_address: None,
240    }
241  };
242
243  if enable_keepalive && !sender.is_closed() {
244    let connection_pool_item = unsafe { &mut *connection_pool_item.get() };
245    connection_pool_item
246      .inner_mut()
247      .replace(SendRequestWrapper::new(sender));
248  }
249
250  drop(connection_pool_item);
251
252  Ok(response)
253}