1use crate::{metadata::MetadataValue, Status};
2use bytes::{Buf, BufMut, BytesMut};
3#[cfg(feature = "gzip")]
4use flate2::read::{GzDecoder, GzEncoder};
5#[cfg(feature = "deflate")]
6use flate2::read::{ZlibDecoder, ZlibEncoder};
7use std::{borrow::Cow, fmt};
8#[cfg(feature = "zstd")]
9use zstd::stream::read::{Decoder, Encoder};
10
/// Header name read by `CompressionEncoding::from_encoding_header` to learn
/// which compression was applied to the incoming content.
pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
/// Header name read by `CompressionEncoding::from_accept_encoding_header` and
/// attached to the error status produced for an unsupported encoding.
pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
13
14#[derive(Debug, Default, Clone, Copy)]
18pub struct EnabledCompressionEncodings {
19 inner: [Option<CompressionEncoding>; 3],
20}
21
22impl EnabledCompressionEncodings {
23 pub fn enable(&mut self, encoding: CompressionEncoding) {
27 for e in self.inner.iter_mut() {
28 match e {
29 Some(e) if *e == encoding => return,
30 None => {
31 *e = Some(encoding);
32 return;
33 }
34 _ => continue,
35 }
36 }
37 }
38
39 pub fn pop(&mut self) -> Option<CompressionEncoding> {
41 self.inner
42 .iter_mut()
43 .rev()
44 .find(|entry| entry.is_some())?
45 .take()
46 }
47
48 pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> {
49 let mut value = BytesMut::new();
50 for encoding in self.inner.into_iter().flatten() {
51 value.put_slice(encoding.as_str().as_bytes());
52 value.put_u8(b',');
53 }
54
55 if value.is_empty() {
56 return None;
57 }
58
59 value.put_slice(b"identity");
60 Some(http::HeaderValue::from_maybe_shared(value).unwrap())
61 }
62
63 pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool {
65 self.inner.contains(&Some(encoding))
66 }
67
68 pub fn is_empty(&self) -> bool {
70 self.inner.iter().all(|e| e.is_none())
71 }
72}
73
74#[derive(Clone, Copy, Debug, PartialEq, Eq)]
75pub(crate) struct CompressionSettings {
76 pub(crate) encoding: CompressionEncoding,
77 pub(crate) buffer_growth_interval: usize,
80}
81
/// Compression algorithms known to this module.
///
/// Each variant is only compiled in when the Cargo feature carrying the same
/// lowercase name is enabled, so the set of variants depends on the build.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CompressionEncoding {
    /// The encoding named `gzip`.
    #[cfg(feature = "gzip")]
    #[allow(missing_docs)]
    Gzip,
    /// The encoding named `deflate`.
    #[cfg(feature = "deflate")]
    #[allow(missing_docs)]
    Deflate,
    /// The encoding named `zstd`.
    #[cfg(feature = "zstd")]
    #[allow(missing_docs)]
    Zstd,
}
96
97impl CompressionEncoding {
98 pub(crate) const ENCODINGS: &'static [CompressionEncoding] = &[
99 #[cfg(feature = "gzip")]
100 CompressionEncoding::Gzip,
101 #[cfg(feature = "deflate")]
102 CompressionEncoding::Deflate,
103 #[cfg(feature = "zstd")]
104 CompressionEncoding::Zstd,
105 ];
106
107 pub(crate) fn from_accept_encoding_header(
109 map: &http::HeaderMap,
110 enabled_encodings: EnabledCompressionEncodings,
111 ) -> Option<Self> {
112 if enabled_encodings.is_empty() {
113 return None;
114 }
115
116 let header_value = map.get(ACCEPT_ENCODING_HEADER)?;
117 let header_value_str = header_value.to_str().ok()?;
118
119 split_by_comma(header_value_str).find_map(|value| match value {
120 #[cfg(feature = "gzip")]
121 "gzip" => Some(CompressionEncoding::Gzip),
122 #[cfg(feature = "deflate")]
123 "deflate" => Some(CompressionEncoding::Deflate),
124 #[cfg(feature = "zstd")]
125 "zstd" => Some(CompressionEncoding::Zstd),
126 _ => None,
127 })
128 }
129
130 pub(crate) fn from_encoding_header(
132 map: &http::HeaderMap,
133 enabled_encodings: EnabledCompressionEncodings,
134 ) -> Result<Option<Self>, Status> {
135 let Some(header_value) = map.get(ENCODING_HEADER) else {
136 return Ok(None);
137 };
138
139 match header_value.as_bytes() {
140 #[cfg(feature = "gzip")]
141 b"gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => {
142 Ok(Some(CompressionEncoding::Gzip))
143 }
144 #[cfg(feature = "deflate")]
145 b"deflate" if enabled_encodings.is_enabled(CompressionEncoding::Deflate) => {
146 Ok(Some(CompressionEncoding::Deflate))
147 }
148 #[cfg(feature = "zstd")]
149 b"zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => {
150 Ok(Some(CompressionEncoding::Zstd))
151 }
152 b"identity" => Ok(None),
153 other => {
154 let other = match std::str::from_utf8(other) {
155 Ok(s) => Cow::Borrowed(s),
156 Err(_) => Cow::Owned(format!("{other:?}")),
157 };
158
159 let mut status = Status::unimplemented(format!(
160 "Content is compressed with `{other}` which isn't supported"
161 ));
162
163 let header_value = enabled_encodings
164 .into_accept_encoding_header_value()
165 .map(MetadataValue::unchecked_from_header_value)
166 .unwrap_or_else(|| MetadataValue::from_static("identity"));
167 status
168 .metadata_mut()
169 .insert(ACCEPT_ENCODING_HEADER, header_value);
170
171 Err(status)
172 }
173 }
174 }
175
176 pub(crate) fn as_str(self) -> &'static str {
177 match self {
178 #[cfg(feature = "gzip")]
179 CompressionEncoding::Gzip => "gzip",
180 #[cfg(feature = "deflate")]
181 CompressionEncoding::Deflate => "deflate",
182 #[cfg(feature = "zstd")]
183 CompressionEncoding::Zstd => "zstd",
184 }
185 }
186
187 #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
188 pub(crate) fn into_header_value(self) -> http::HeaderValue {
189 http::HeaderValue::from_static(self.as_str())
190 }
191}
192
193impl fmt::Display for CompressionEncoding {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 f.write_str(self.as_str())
196 }
197}
198
/// Breaks `s` at every `,` and yields each piece with leading and trailing
/// whitespace removed.
///
/// Empty pieces are kept, so an empty input yields a single empty string.
fn split_by_comma(s: &str) -> impl Iterator<Item = &str> {
    s.split(',').map(str::trim)
}
202
203#[allow(unused_variables, unreachable_code)]
206pub(crate) fn compress(
207 settings: CompressionSettings,
208 decompressed_buf: &mut BytesMut,
209 out_buf: &mut BytesMut,
210 len: usize,
211) -> Result<(), std::io::Error> {
212 let buffer_growth_interval = settings.buffer_growth_interval;
213 let capacity = ((len / buffer_growth_interval) + 1) * buffer_growth_interval;
214 out_buf.reserve(capacity);
215
216 #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
217 let mut out_writer = out_buf.writer();
218
219 match settings.encoding {
220 #[cfg(feature = "gzip")]
221 CompressionEncoding::Gzip => {
222 let mut gzip_encoder = GzEncoder::new(
223 &decompressed_buf[0..len],
224 flate2::Compression::new(6),
226 );
227 std::io::copy(&mut gzip_encoder, &mut out_writer)?;
228 }
229 #[cfg(feature = "deflate")]
230 CompressionEncoding::Deflate => {
231 let mut deflate_encoder = ZlibEncoder::new(
232 &decompressed_buf[0..len],
233 flate2::Compression::new(6),
235 );
236 std::io::copy(&mut deflate_encoder, &mut out_writer)?;
237 }
238 #[cfg(feature = "zstd")]
239 CompressionEncoding::Zstd => {
240 let mut zstd_encoder = Encoder::new(
241 &decompressed_buf[0..len],
242 zstd::DEFAULT_COMPRESSION_LEVEL,
244 )?;
245 std::io::copy(&mut zstd_encoder, &mut out_writer)?;
246 }
247 }
248
249 decompressed_buf.advance(len);
250
251 Ok(())
252}
253
254#[allow(unused_variables, unreachable_code)]
256pub(crate) fn decompress(
257 settings: CompressionSettings,
258 compressed_buf: &mut BytesMut,
259 mut out_buf: bytes::buf::Limit<&mut BytesMut>,
260 len: usize,
261) -> Result<(), std::io::Error> {
262 let buffer_growth_interval = settings.buffer_growth_interval;
263 let estimate_decompressed_len = len * 2;
264 let capacity = std::cmp::min(
265 bytes::buf::Limit::limit(&out_buf),
266 ((estimate_decompressed_len / buffer_growth_interval) + 1) * buffer_growth_interval,
267 );
268
269 out_buf.get_mut().reserve(capacity);
270
271 #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
272 let mut out_writer = out_buf.writer();
273
274 match settings.encoding {
275 #[cfg(feature = "gzip")]
276 CompressionEncoding::Gzip => {
277 let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]);
278 std::io::copy(&mut gzip_decoder, &mut out_writer)?;
279 }
280 #[cfg(feature = "deflate")]
281 CompressionEncoding::Deflate => {
282 let mut deflate_decoder = ZlibDecoder::new(&compressed_buf[0..len]);
283 std::io::copy(&mut deflate_decoder, &mut out_writer)?;
284 }
285 #[cfg(feature = "zstd")]
286 CompressionEncoding::Zstd => {
287 let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?;
288 std::io::copy(&mut zstd_decoder, &mut out_writer)?;
289 }
290 }
291
292 compressed_buf.advance(len);
293
294 Ok(())
295}
296
/// Per-message override for compression.
///
/// NOTE(review): how the variants are consumed is not visible in this file;
/// the descriptions below follow the variant names and should be confirmed
/// against the call sites.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SingleMessageCompressionOverride {
    /// Keep whatever compression is otherwise configured. This is the default
    /// variant.
    #[default]
    Inherit,
    /// Do not compress the message.
    Disable,
}
309
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
    use http::HeaderValue;

    /// Builds the accept-encoding header value for an explicit slot layout,
    /// panicking when the layout produces no header.
    #[cfg(any(feature = "gzip", feature = "zstd"))]
    fn accept_header(inner: [Option<CompressionEncoding>; 3]) -> HeaderValue {
        EnabledCompressionEncodings { inner }
            .into_accept_encoding_header_value()
            .unwrap()
    }

    /// With nothing enabled there is no header value at all.
    #[test]
    fn convert_none_into_header_value() {
        let header = EnabledCompressionEncodings::default().into_accept_encoding_header_value();

        assert!(header.is_none());
    }

    /// A single gzip entry renders the same regardless of the slot it sits in.
    #[test]
    #[cfg(feature = "gzip")]
    fn convert_gzip_into_header_value() {
        let expected = HeaderValue::from_static("gzip,identity");
        let gzip = Some(CompressionEncoding::Gzip);

        for inner in [[gzip, None, None], [None, None, gzip]] {
            assert_eq!(accept_header(inner), expected);
        }
    }

    /// A single zstd entry renders the same regardless of the slot it sits in.
    #[test]
    #[cfg(feature = "zstd")]
    fn convert_zstd_into_header_value() {
        let expected = HeaderValue::from_static("zstd,identity");
        let zstd = Some(CompressionEncoding::Zstd);

        for inner in [[zstd, None, None], [None, None, zstd]] {
            assert_eq!(accept_header(inner), expected);
        }
    }

    /// With every slot filled, the header lists the encodings in slot order.
    #[test]
    #[cfg(all(feature = "gzip", feature = "deflate", feature = "zstd"))]
    fn convert_compression_encodings_into_header_value() {
        use CompressionEncoding::{Deflate, Gzip, Zstd};

        assert_eq!(
            accept_header([Some(Gzip), Some(Deflate), Some(Zstd)]),
            HeaderValue::from_static("gzip,deflate,zstd,identity"),
        );

        assert_eq!(
            accept_header([Some(Zstd), Some(Deflate), Some(Gzip)]),
            HeaderValue::from_static("zstd,deflate,gzip,identity"),
        );
    }
}