ferron/util/
proxy_protocol.rs1use std::net::{IpAddr, SocketAddr};
15
16use ppp::HeaderResult;
17use tokio::io::{AsyncRead, AsyncReadExt};
18
/// Number of leading bytes compared against `ppp::v1::PROTOCOL_PREFIX` to detect a v1 header.
const V1_PREFIX_LEN: usize = 5;
/// Maximum length of a v1 header line, CRLF included (107 bytes per the PROXY protocol spec).
const V1_MAX_LENGTH: usize = 107;
/// Byte pair that ends a v1 header line.
const V1_TERMINATOR: &[u8] = &[b'\r', b'\n'];
/// Number of leading bytes compared against `ppp::v2::PROTOCOL_PREFIX` to detect a v2 header.
const V2_PREFIX_LEN: usize = 12;
/// Size of the fixed part of a v2 header: the signature plus four more bytes, the last two of
/// which hold the length of the variable part (see `V2_LENGTH_INDEX`).
const V2_MINIMUM_LEN: usize = V2_PREFIX_LEN + 4;
/// Offset of the big-endian `u16` that holds the number of header bytes following the fixed part.
const V2_LENGTH_INDEX: usize = V2_MINIMUM_LEN - 2;
/// Size of the stack buffer used for headers; larger v2 headers spill into a heap `Vec`.
const READ_BUFFER_LEN: usize = 512;

34pub async fn read_proxy_header<I>(mut stream: I) -> Result<(I, Option<SocketAddr>, Option<SocketAddr>), std::io::Error>
36where
37 I: AsyncRead + Unpin,
38{
39 let mut buffer = [0; READ_BUFFER_LEN];
41 let mut dynamic_buffer = None;
43
44 stream.read_exact(&mut buffer[..V1_PREFIX_LEN]).await?;
46
47 if &buffer[..V1_PREFIX_LEN] == ppp::v1::PROTOCOL_PREFIX.as_bytes() {
48 read_v1_header(&mut stream, &mut buffer).await?;
49 } else {
50 stream.read_exact(&mut buffer[V1_PREFIX_LEN..V2_MINIMUM_LEN]).await?;
51 if &buffer[..V2_PREFIX_LEN] == ppp::v2::PROTOCOL_PREFIX {
52 dynamic_buffer = read_v2_header(&mut stream, &mut buffer).await?;
53 } else {
54 return Err(std::io::Error::new(
55 std::io::ErrorKind::InvalidData,
56 "No valid Proxy Protocol header detected",
57 ));
58 }
59 }
60
61 let buffer = dynamic_buffer.as_deref().unwrap_or(&buffer[..]);
63
64 let header = HeaderResult::parse(buffer);
66 match header {
67 HeaderResult::V1(Ok(header)) => {
68 let (client_address, server_address) = match header.addresses {
69 ppp::v1::Addresses::Tcp4(ip) => (
70 SocketAddr::new(IpAddr::V4(ip.source_address), ip.source_port),
71 SocketAddr::new(IpAddr::V4(ip.destination_address), ip.destination_port),
72 ),
73 ppp::v1::Addresses::Tcp6(ip) => (
74 SocketAddr::new(IpAddr::V6(ip.source_address), ip.source_port),
75 SocketAddr::new(IpAddr::V6(ip.destination_address), ip.destination_port),
76 ),
77 ppp::v1::Addresses::Unknown => {
78 return Ok((stream, None, None));
80 }
81 };
82
83 Ok((stream, Some(client_address), Some(server_address)))
84 }
85 HeaderResult::V2(Ok(header)) => {
86 let (client_address, server_address) = match header.addresses {
87 ppp::v2::Addresses::IPv4(ip) => (
88 SocketAddr::new(IpAddr::V4(ip.source_address), ip.source_port),
89 SocketAddr::new(IpAddr::V4(ip.destination_address), ip.destination_port),
90 ),
91 ppp::v2::Addresses::IPv6(ip) => (
92 SocketAddr::new(IpAddr::V6(ip.source_address), ip.source_port),
93 SocketAddr::new(IpAddr::V6(ip.destination_address), ip.destination_port),
94 ),
95 ppp::v2::Addresses::Unix(unix) => {
96 return Err(std::io::Error::new(
97 std::io::ErrorKind::InvalidData,
98 format!("Unix socket addresses are not supported. Addresses: {unix:?}"),
99 ));
100 }
101 ppp::v2::Addresses::Unspecified => {
102 return Ok((stream, None, None));
104 }
105 };
106
107 Ok((stream, Some(client_address), Some(server_address)))
108 }
109 HeaderResult::V1(Err(_error)) => Err(std::io::Error::new(
110 std::io::ErrorKind::InvalidData,
111 "No valid V1 Proxy Protocol header received",
112 )),
113 HeaderResult::V2(Err(_error)) => Err(std::io::Error::new(
114 std::io::ErrorKind::InvalidData,
115 "No valid V2 Proxy Protocol header received",
116 )),
117 }
118}
119
120async fn read_v2_header<I>(mut stream: I, buffer: &mut [u8; READ_BUFFER_LEN]) -> Result<Option<Vec<u8>>, std::io::Error>
121where
122 I: AsyncRead + Unpin,
123{
124 let length = u16::from_be_bytes([buffer[V2_LENGTH_INDEX], buffer[V2_LENGTH_INDEX + 1]]) as usize;
125 let full_length = V2_MINIMUM_LEN + length;
126
127 if full_length > READ_BUFFER_LEN {
129 let mut dynamic_buffer = Vec::with_capacity(full_length);
130 dynamic_buffer.extend_from_slice(&buffer[..V2_MINIMUM_LEN]);
131
132 stream
134 .read_exact(&mut dynamic_buffer[V2_MINIMUM_LEN..full_length])
135 .await?;
136
137 Ok(Some(dynamic_buffer))
138 } else {
139 stream.read_exact(&mut buffer[V2_MINIMUM_LEN..full_length]).await?;
141
142 Ok(None)
143 }
144}
145
146async fn read_v1_header<I>(mut stream: I, buffer: &mut [u8; READ_BUFFER_LEN]) -> Result<(), std::io::Error>
147where
148 I: AsyncRead + Unpin,
149{
150 let mut end_found = false;
152 for i in V1_PREFIX_LEN..V1_MAX_LENGTH {
153 buffer[i] = stream.read_u8().await?;
154
155 if [buffer[i - 1], buffer[i]] == V1_TERMINATOR {
156 end_found = true;
157 break;
158 }
159 }
160 if !end_found {
161 return Err(std::io::Error::new(
162 std::io::ErrorKind::InvalidData,
163 "No valid Proxy Protocol header detected",
164 ));
165 }
166
167 Ok(())
168}