// ferron_common/http_proxy/request_parts.rs

1use std::error::Error;
2use std::str::FromStr;
3
4use hyper::header::{self, HeaderName};
5use hyper::{HeaderMap, Uri, Version};
6
7use crate::config::ServerConfiguration;
8use crate::get_value;
9use crate::modules::SocketData;
10use crate::util::replace_header_placeholders;
11
12/// Constructs a proxy request based on the original request.
13#[inline]
14#[allow(clippy::too_many_arguments)]
15pub(super) fn construct_proxy_request_parts(
16  mut request_parts: hyper::http::request::Parts,
17  config: &ServerConfiguration,
18  socket_data: &SocketData,
19  proxy_request_url: &Uri,
20  headers_to_add: &[(HeaderName, String)],
21  headers_to_replace: &[(HeaderName, String)],
22  headers_to_remove: &[HeaderName],
23  rewrite_host: bool,
24) -> Result<hyper::http::request::Parts, Box<dyn Error + Send + Sync>> {
25  let headers_to_add = HeaderMap::from_iter(headers_to_add.iter().cloned().filter_map(|(name, value)| {
26    replace_header_placeholders(&value, &request_parts, Some(socket_data))
27      .parse()
28      .ok()
29      .map(|v| (name, v))
30  }));
31  let headers_to_replace = HeaderMap::from_iter(headers_to_replace.iter().cloned().filter_map(|(name, value)| {
32    replace_header_placeholders(&value, &request_parts, Some(socket_data))
33      .parse()
34      .ok()
35      .map(|v| (name, v))
36  }));
37  let headers_to_remove = headers_to_remove.to_vec();
38
39  let authority = proxy_request_url.authority().cloned();
40
41  let request_path = request_parts.uri.path();
42
43  let path = match request_path.as_bytes().first() {
44    Some(b'/') => {
45      let mut proxy_request_path = proxy_request_url.path();
46      while proxy_request_path.as_bytes().last().copied() == Some(b'/') {
47        proxy_request_path = &proxy_request_path[..(proxy_request_path.len() - 1)];
48      }
49      format!("{proxy_request_path}{request_path}")
50    }
51    _ => request_path.to_string(),
52  };
53
54  request_parts.uri = Uri::from_str(&format!(
55    "{}{}",
56    path,
57    match request_parts.uri.query() {
58      Some(query) => format!("?{query}"),
59      None => "".to_string(),
60    }
61  ))?;
62
63  let original_host = request_parts.headers.get(header::HOST).cloned();
64
65  if rewrite_host || proxy_request_url.scheme_str() == Some("https") {
66    match authority {
67      Some(authority) => {
68        request_parts
69          .headers
70          .insert(header::HOST, authority.to_string().parse()?);
71      }
72      None => {
73        request_parts.headers.remove(header::HOST);
74      }
75    }
76  }
77
78  if let Some(connection_header) = request_parts.headers.get(&header::CONNECTION) {
79    let connection_str = String::from_utf8_lossy(connection_header.as_bytes());
80    if connection_str
81      .to_lowercase()
82      .split(",")
83      .map(|c| c.trim())
84      .all(|c| c != "keep-alive" && c != "upgrade" && c != "close")
85    {
86      request_parts
87        .headers
88        .insert(header::CONNECTION, format!("keep-alive, {connection_str}").parse()?);
89    }
90  } else {
91    request_parts.headers.insert(header::CONNECTION, "keep-alive".parse()?);
92  }
93
94  let trust_x_forwarded_for = get_value!("trust_x_forwarded_for", config)
95    .and_then(|v| v.as_bool())
96    .unwrap_or(false);
97
98  let remote_addr_str = socket_data.remote_addr.ip().to_canonical().to_string();
99  request_parts.headers.insert(
100    HeaderName::from_static("x-forwarded-for"),
101    (if let Some(ref forwarded_for) = request_parts
102      .headers
103      .get(HeaderName::from_static("x-forwarded-for"))
104      .and_then(|h| h.to_str().ok())
105    {
106      if trust_x_forwarded_for {
107        format!("{forwarded_for}, {remote_addr_str}")
108      } else {
109        remote_addr_str
110      }
111    } else {
112      remote_addr_str
113    })
114    .parse()?,
115  );
116
117  if !trust_x_forwarded_for
118    || !request_parts
119      .headers
120      .contains_key(HeaderName::from_static("x-forwarded-proto"))
121  {
122    if socket_data.encrypted {
123      request_parts
124        .headers
125        .insert(HeaderName::from_static("x-forwarded-proto"), "https".parse()?);
126    } else {
127      request_parts
128        .headers
129        .insert(HeaderName::from_static("x-forwarded-proto"), "http".parse()?);
130    }
131  }
132
133  if !trust_x_forwarded_for
134    || !request_parts
135      .headers
136      .contains_key(HeaderName::from_static("x-forwarded-host"))
137  {
138    if let Some(original_host) = original_host {
139      request_parts
140        .headers
141        .insert(HeaderName::from_static("x-forwarded-host"), original_host);
142    }
143  }
144
145  let mut forwarded_header_value = None;
146  if let Some(forwarded_header_value_obtained) = request_parts
147    .headers
148    .get(HeaderName::from_static("x-forwarded-for"))
149    .and_then(|h| h.to_str().ok())
150  {
151    let mut forwarded_header_value_new = Vec::new();
152    let mut is_first = true;
153
154    for ip in forwarded_header_value_obtained
155      .split(',')
156      .map(|s| s.trim())
157      .filter(|s| !s.is_empty())
158    {
159      let escape_determinants: &'static [char] = &[
160        '(', ')', ',', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '{', '}', '"', '\'', '\r', '\n', '\t',
161      ];
162
163      let forwarded_for = if ip.parse::<std::net::Ipv4Addr>().is_ok() {
164        ip.to_string()
165      } else if ip.parse::<std::net::Ipv6Addr>().is_ok() {
166        format!("\"[{ip}]\"")
167      } else if ip.contains(escape_determinants) {
168        format!("\"{}\"", ip.escape_default())
169      } else {
170        ip.to_string()
171      };
172
173      let (forwarded_host, forwarded_proto) = if is_first {
174        (
175          request_parts
176            .headers
177            .get(HeaderName::from_static("x-forwarded-host"))
178            .and_then(|h| h.to_str().ok()),
179          request_parts
180            .headers
181            .get(HeaderName::from_static("x-forwarded-proto"))
182            .and_then(|h| h.to_str().ok()),
183        )
184      } else {
185        (None, None)
186      };
187
188      let mut forwarded_entry = Vec::new();
189      forwarded_entry.push(format!("for={}", forwarded_for));
190      if let Some(forwarded_proto) = forwarded_proto {
191        forwarded_entry.push(format!(
192          "proto={}",
193          if forwarded_proto.contains(escape_determinants) {
194            format!("\"{}\"", forwarded_proto.escape_default())
195          } else {
196            forwarded_proto.to_string()
197          }
198        ));
199      }
200      if let Some(forwarded_host) = forwarded_host {
201        forwarded_entry.push(format!(
202          "host={}",
203          if forwarded_host.contains(escape_determinants) {
204            format!("\"{}\"", forwarded_host.escape_default())
205          } else {
206            forwarded_host.to_string()
207          }
208        ));
209      }
210      forwarded_header_value_new.push(forwarded_entry.join(";"));
211
212      is_first = false;
213    }
214
215    forwarded_header_value = Some(forwarded_header_value_new.join(", "));
216  }
217  if let Some(forwarded_header_value) = forwarded_header_value {
218    request_parts
219      .headers
220      .insert(header::FORWARDED, forwarded_header_value.parse()?);
221  } else {
222    request_parts.headers.remove(header::FORWARDED);
223  }
224
225  for (header_name_option, header_value) in headers_to_add {
226    if let Some(header_name) = header_name_option {
227      if !request_parts.headers.contains_key(&header_name) {
228        request_parts.headers.insert(header_name, header_value);
229      }
230    }
231  }
232
233  for (header_name_option, header_value) in headers_to_replace {
234    if let Some(header_name) = header_name_option {
235      request_parts.headers.insert(header_name, header_value);
236    }
237  }
238
239  for header_to_remove in headers_to_remove.into_iter().rev() {
240    if request_parts.headers.contains_key(&header_to_remove) {
241      while request_parts.headers.remove(&header_to_remove).is_some() {}
242    }
243  }
244
245  request_parts.version = Version::default();
246
247  Ok(request_parts)
248}