ferron/util/
proxy_protocol.rs

1// Copyright 2021 Axum Server Contributors
2// Portions of this file are derived from `hyper-server` (https://github.com/warlock-labs/postel/tree/6d93b4251766d97120b96ecee6d198b3406da7da).
3//
4// Permission is hereby granted, free of charge, to any person obtaining a copy
5// of this software and associated documentation files (the "Software"), to deal
6// in the Software without restriction, including without limitation the rights
7// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8// copies of the Software, and to permit persons to whom the Software is
9// furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in all
12// copies or substantial portions of the Software.
13
14use std::net::{IpAddr, SocketAddr};
15
16use ppp::HeaderResult;
17use tokio::io::{AsyncRead, AsyncReadExt};
18
/// The length of the v1 protocol prefix (`PROXY`) in bytes.
const V1_PREFIX_LEN: usize = 5;
/// The maximum length of a v1 header in bytes, including the terminating CRLF.
const V1_MAX_LENGTH: usize = 107;
/// The terminator bytes of a v1 header.
const V1_TERMINATOR: &[u8] = b"\r\n";
/// The length of the v2 protocol signature in bytes.
const V2_PREFIX_LEN: usize = 12;
/// The length of the fixed part of a v2 header in bytes: the signature, the version/command
/// byte, the address family/protocol byte and the 2-byte big-endian length field.
const V2_MINIMUM_LEN: usize = 16;
/// The index of the start of the big-endian u16 length in the v2 header.
const V2_LENGTH_INDEX: usize = 14;
/// The length of the read buffer used to read the PROXY protocol header.
const READ_BUFFER_LEN: usize = 512;

// Compile-time guards for the indexing done by the header readers. A complete v1 header and
// the fixed v2 preamble must both fit in the stack buffer, the v1 prefix must lie inside the
// v2 preamble, and the v2 length field must be the last two bytes of that preamble.
const _: () = assert!(V1_MAX_LENGTH <= READ_BUFFER_LEN);
const _: () = assert!(V2_MINIMUM_LEN <= READ_BUFFER_LEN);
const _: () = assert!(V1_PREFIX_LEN <= V2_MINIMUM_LEN);
const _: () = assert!(V2_LENGTH_INDEX + 2 == V2_MINIMUM_LEN);
33
34/// Reads the PROXY protocol header from the given `AsyncRead`.
35pub async fn read_proxy_header<I>(mut stream: I) -> Result<(I, Option<SocketAddr>, Option<SocketAddr>), std::io::Error>
36where
37  I: AsyncRead + Unpin,
38{
39  // Mutable buffer for storing stream data
40  let mut buffer = [0; READ_BUFFER_LEN];
41  // Dynamic in case v2 header is too long
42  let mut dynamic_buffer = None;
43
44  // Read prefix to check for v1, v2, or kill
45  stream.read_exact(&mut buffer[..V1_PREFIX_LEN]).await?;
46
47  if &buffer[..V1_PREFIX_LEN] == ppp::v1::PROTOCOL_PREFIX.as_bytes() {
48    read_v1_header(&mut stream, &mut buffer).await?;
49  } else {
50    stream.read_exact(&mut buffer[V1_PREFIX_LEN..V2_MINIMUM_LEN]).await?;
51    if &buffer[..V2_PREFIX_LEN] == ppp::v2::PROTOCOL_PREFIX {
52      dynamic_buffer = read_v2_header(&mut stream, &mut buffer).await?;
53    } else {
54      return Err(std::io::Error::new(
55        std::io::ErrorKind::InvalidData,
56        "No valid Proxy Protocol header detected",
57      ));
58    }
59  }
60
61  // Choose which buffer to parse
62  let buffer = dynamic_buffer.as_deref().unwrap_or(&buffer[..]);
63
64  // Parse the header
65  let header = HeaderResult::parse(buffer);
66  match header {
67    HeaderResult::V1(Ok(header)) => {
68      let (client_address, server_address) = match header.addresses {
69        ppp::v1::Addresses::Tcp4(ip) => (
70          SocketAddr::new(IpAddr::V4(ip.source_address), ip.source_port),
71          SocketAddr::new(IpAddr::V4(ip.destination_address), ip.destination_port),
72        ),
73        ppp::v1::Addresses::Tcp6(ip) => (
74          SocketAddr::new(IpAddr::V6(ip.source_address), ip.source_port),
75          SocketAddr::new(IpAddr::V6(ip.destination_address), ip.destination_port),
76        ),
77        ppp::v1::Addresses::Unknown => {
78          // Return client address as `None` so that "unknown" is used in the http header
79          return Ok((stream, None, None));
80        }
81      };
82
83      Ok((stream, Some(client_address), Some(server_address)))
84    }
85    HeaderResult::V2(Ok(header)) => {
86      let (client_address, server_address) = match header.addresses {
87        ppp::v2::Addresses::IPv4(ip) => (
88          SocketAddr::new(IpAddr::V4(ip.source_address), ip.source_port),
89          SocketAddr::new(IpAddr::V4(ip.destination_address), ip.destination_port),
90        ),
91        ppp::v2::Addresses::IPv6(ip) => (
92          SocketAddr::new(IpAddr::V6(ip.source_address), ip.source_port),
93          SocketAddr::new(IpAddr::V6(ip.destination_address), ip.destination_port),
94        ),
95        ppp::v2::Addresses::Unix(unix) => {
96          return Err(std::io::Error::new(
97            std::io::ErrorKind::InvalidData,
98            format!("Unix socket addresses are not supported. Addresses: {unix:?}"),
99          ));
100        }
101        ppp::v2::Addresses::Unspecified => {
102          // Return client address as `None` so that "unknown" is used in the http header
103          return Ok((stream, None, None));
104        }
105      };
106
107      Ok((stream, Some(client_address), Some(server_address)))
108    }
109    HeaderResult::V1(Err(_error)) => Err(std::io::Error::new(
110      std::io::ErrorKind::InvalidData,
111      "No valid V1 Proxy Protocol header received",
112    )),
113    HeaderResult::V2(Err(_error)) => Err(std::io::Error::new(
114      std::io::ErrorKind::InvalidData,
115      "No valid V2 Proxy Protocol header received",
116    )),
117  }
118}
119
120async fn read_v2_header<I>(mut stream: I, buffer: &mut [u8; READ_BUFFER_LEN]) -> Result<Option<Vec<u8>>, std::io::Error>
121where
122  I: AsyncRead + Unpin,
123{
124  let length = u16::from_be_bytes([buffer[V2_LENGTH_INDEX], buffer[V2_LENGTH_INDEX + 1]]) as usize;
125  let full_length = V2_MINIMUM_LEN + length;
126
127  // Switch to dynamic buffer if header is too long; v2 has no maximum length
128  if full_length > READ_BUFFER_LEN {
129    let mut dynamic_buffer = Vec::with_capacity(full_length);
130    dynamic_buffer.extend_from_slice(&buffer[..V2_MINIMUM_LEN]);
131
132    // Read the remaining header length
133    stream
134      .read_exact(&mut dynamic_buffer[V2_MINIMUM_LEN..full_length])
135      .await?;
136
137    Ok(Some(dynamic_buffer))
138  } else {
139    // Read the remaining header length
140    stream.read_exact(&mut buffer[V2_MINIMUM_LEN..full_length]).await?;
141
142    Ok(None)
143  }
144}
145
146async fn read_v1_header<I>(mut stream: I, buffer: &mut [u8; READ_BUFFER_LEN]) -> Result<(), std::io::Error>
147where
148  I: AsyncRead + Unpin,
149{
150  // Read one byte at a time until terminator found
151  let mut end_found = false;
152  for i in V1_PREFIX_LEN..V1_MAX_LENGTH {
153    buffer[i] = stream.read_u8().await?;
154
155    if [buffer[i - 1], buffer[i]] == V1_TERMINATOR {
156      end_found = true;
157      break;
158    }
159  }
160  if !end_found {
161    return Err(std::io::Error::new(
162      std::io::ErrorKind::InvalidData,
163      "No valid Proxy Protocol header detected",
164    ));
165  }
166
167  Ok(())
168}