ferron_common/http_proxy/
proxy_client.rs1use std::cell::UnsafeCell;
2use std::error::Error;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
7use bytes::Bytes;
8use http_body_util::combinators::BoxBody;
9use http_body_util::BodyExt;
10use hyper::body::Body;
11use hyper::{Request, Response, StatusCode};
12#[cfg(feature = "runtime-tokio")]
13use hyper_util::rt::{TokioExecutor, TokioIo};
14#[cfg(feature = "runtime-monoio")]
15use monoio_compat::hyper::{MonoioExecutor, MonoioIo};
16use tokio::io::{AsyncRead, AsyncWrite};
17#[cfg(feature = "runtime-vibeio")]
18use vibeio_hyper::{VibeioExecutor, VibeioIo};
19
20use super::ConnectionPoolItem;
21#[cfg(any(feature = "runtime-monoio", feature = "runtime-vibeio"))]
22use super::DropGuard;
23use crate::http_proxy::send_request::{SendRequest, SendRequestWrapper};
24use crate::logging::ErrorLogger;
25use crate::modules::ResponseData;
26
/// Body wrapper that delegates to `inner` while owning optional trackers.
///
/// The tracker fields are never read in this file. They are held so that they
/// are released only when this body (and therefore the response) is dropped.
struct TrackedBody<B> {
    /// The wrapped body; every `Body` method is forwarded to it.
    inner: B,
    /// Optional connection tracker, kept alive for the lifetime of the body.
    _tracker: Option<Arc<()>>,
    /// Optional shared connection pool item, kept alive for the lifetime of the body.
    _tracker_pool: Option<Arc<UnsafeCell<ConnectionPoolItem>>>,
}
33
34impl<B> TrackedBody<B> {
35 fn new(inner: B, tracker: Option<Arc<()>>, tracker_pool: Option<Arc<UnsafeCell<ConnectionPoolItem>>>) -> Self {
36 Self {
37 inner,
38 _tracker: tracker,
39 _tracker_pool: tracker_pool,
40 }
41 }
42}
43
44impl<B> Body for TrackedBody<B>
45where
46 B: Body + Unpin,
47{
48 type Data = B::Data;
49 type Error = B::Error;
50
51 #[inline]
52 fn poll_frame(
53 mut self: Pin<&mut Self>,
54 cx: &mut Context<'_>,
55 ) -> Poll<Option<Result<hyper::body::Frame<Self::Data>, Self::Error>>> {
56 Pin::new(&mut self.inner).poll_frame(cx)
57 }
58
59 #[inline]
60 fn is_end_stream(&self) -> bool {
61 self.inner.is_end_stream()
62 }
63
64 #[inline]
65 fn size_hint(&self) -> hyper::body::SizeHint {
66 self.inner.size_hint()
67 }
68}
69
70unsafe impl<B> Send for TrackedBody<B> where B: Send {}
73unsafe impl<B> Sync for TrackedBody<B> where B: Sync {}
74
75pub(super) async fn http_proxy_handshake(
77 stream: impl AsyncRead + AsyncWrite + Send + Unpin + 'static,
78 use_http2: bool,
79 #[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))] drop_guard: DropGuard,
80) -> Result<SendRequest, Box<dyn Error + Send + Sync>> {
81 #[cfg(feature = "runtime-vibeio")]
82 let io = VibeioIo::new(stream);
83 #[cfg(feature = "runtime-monoio")]
84 let io = MonoioIo::new(stream);
85 #[cfg(feature = "runtime-tokio")]
86 let io = TokioIo::new(stream);
87
88 Ok(if use_http2 {
89 #[cfg(feature = "runtime-vibeio")]
90 let executor = VibeioExecutor;
91 #[cfg(feature = "runtime-monoio")]
92 let executor = MonoioExecutor;
93 #[cfg(feature = "runtime-tokio")]
94 let executor = TokioExecutor::new();
95
96 let (sender, conn) = hyper::client::conn::http2::handshake(executor, io).await?;
97
98 crate::runtime::spawn(async move {
99 conn.await.unwrap_or_default();
100 #[cfg(feature = "runtime-monoio")]
101 drop(drop_guard);
102 });
103
104 SendRequest::Http2(sender)
105 } else {
106 let (sender, conn) = hyper::client::conn::http1::handshake(io).await?;
107
108 let conn_with_upgrades = conn.with_upgrades();
109 crate::runtime::spawn(async move {
110 conn_with_upgrades.await.unwrap_or_default();
111 #[cfg(any(feature = "runtime-vibeio", feature = "runtime-monoio"))]
112 drop(drop_guard);
113 });
114
115 SendRequest::Http1(sender)
116 })
117}
118
119pub(super) async fn http_proxy(
121 mut sender: SendRequest,
122 connection_pool_item: ConnectionPoolItem,
123 proxy_request: Request<BoxBody<Bytes, std::io::Error>>,
124 error_logger: &ErrorLogger,
125 proxy_intercept_errors: bool,
126 tracked_connection: Option<Arc<()>>,
127 enable_keepalive: bool,
128) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
129 let (proxy_request_parts, proxy_request_body) = proxy_request.into_parts();
130 #[cfg(feature = "runtime-vibeio")]
131 let mut proxy_request_cloned = Request::from_parts(
132 proxy_request_parts.clone(),
133 http_body_util::Empty::<bytes::Bytes>::new(),
134 );
135 #[cfg(not(feature = "runtime-vibeio"))]
136 let proxy_request_cloned = Request::from_parts(proxy_request_parts.clone(), ());
137 let proxy_request = Request::from_parts(proxy_request_parts, proxy_request_body);
138
139 let send_request_result = sender.send_request(proxy_request).await;
140 #[allow(clippy::arc_with_non_send_sync)]
141 let connection_pool_item = Arc::new(UnsafeCell::new(connection_pool_item));
142
143 let proxy_response = match send_request_result {
144 Ok(response) => response,
145 Err(err) => {
146 error_logger.log(&format!("Bad gateway: {err}")).await;
147 return Ok(ResponseData {
148 request: None,
149 response: None,
150 response_status: Some(StatusCode::BAD_GATEWAY),
151 response_headers: None,
152 new_remote_address: None,
153 });
154 }
155 };
156
157 let status_code = proxy_response.status();
158
159 let (proxy_response_parts, proxy_response_body) = proxy_response.into_parts();
160 if proxy_response_parts.status == StatusCode::SWITCHING_PROTOCOLS {
161 let proxy_response_cloned = Response::from_parts(proxy_response_parts.clone(), ());
162 match hyper::upgrade::on(proxy_response_cloned).await {
163 Ok(upgraded_backend) => {
164 let error_logger = error_logger.clone();
165 let connection_pool_item = connection_pool_item.clone();
166 #[cfg(feature = "runtime-vibeio")]
167 let upgrade_on = vibeio_http::prepare_upgrade(&mut proxy_request_cloned);
168 crate::runtime::spawn(async move {
169 #[cfg(feature = "runtime-vibeio")]
170 let upgrade_on = (if let Some(upgraded_request) = upgrade_on {
171 upgraded_request.await
172 } else {
173 None
174 })
175 .ok_or(std::io::Error::other("vibeio HTTP upgrade failure"));
176 #[cfg(not(feature = "runtime-vibeio"))]
177 let upgrade_on = hyper::upgrade::on(proxy_request_cloned).await;
178 match upgrade_on {
179 Ok(upgraded_proxy) => {
180 #[cfg(feature = "runtime-vibeio")]
181 let mut upgraded_backend = VibeioIo::new(upgraded_backend);
182 #[cfg(feature = "runtime-monoio")]
183 let mut upgraded_backend = MonoioIo::new(upgraded_backend);
184 #[cfg(feature = "runtime-tokio")]
185 let mut upgraded_backend = TokioIo::new(upgraded_backend);
186
187 #[cfg(feature = "runtime-vibeio")]
188 let mut upgraded_proxy = upgraded_proxy;
189 #[cfg(feature = "runtime-monoio")]
190 let mut upgraded_proxy = MonoioIo::new(upgraded_proxy);
191 #[cfg(feature = "runtime-tokio")]
192 let mut upgraded_proxy = TokioIo::new(upgraded_proxy);
193
194 crate::runtime::spawn(async move {
195 tokio::io::copy_bidirectional(&mut upgraded_backend, &mut upgraded_proxy)
196 .await
197 .unwrap_or_default();
198 drop(connection_pool_item);
199 });
200 }
201 Err(err) => {
202 error_logger.log(&format!("HTTP upgrade error: {err}")).await;
203 }
204 }
205 });
206 }
207 Err(err) => {
208 error_logger.log(&format!("HTTP upgrade error: {err}")).await;
209 }
210 }
211 }
212 let proxy_response = Response::from_parts(proxy_response_parts, proxy_response_body);
213
214 let response = if proxy_intercept_errors && status_code.as_u16() >= 400 {
215 ResponseData {
216 request: None,
217 response: None,
218 response_status: Some(status_code),
219 response_headers: None,
220 new_remote_address: None,
221 }
222 } else {
223 let (response_parts, response_body) = proxy_response.into_parts();
224 let boxed_body = TrackedBody::new(
225 response_body.map_err(|e| std::io::Error::other(e.to_string())),
226 tracked_connection,
227 if enable_keepalive && !sender.is_closed() {
228 None
229 } else {
230 Some(connection_pool_item.clone())
231 },
232 )
233 .boxed();
234 ResponseData {
235 request: None,
236 response: Some(Response::from_parts(response_parts, boxed_body)),
237 response_status: None,
238 response_headers: None,
239 new_remote_address: None,
240 }
241 };
242
243 if enable_keepalive && !sender.is_closed() {
244 let connection_pool_item = unsafe { &mut *connection_pool_item.get() };
245 connection_pool_item
246 .inner_mut()
247 .replace(SendRequestWrapper::new(sender));
248 }
249
250 drop(connection_pool_item);
251
252 Ok(response)
253}