// tonic/codec/compression.rs

1use crate::{metadata::MetadataValue, Status};
2use bytes::{Buf, BufMut, BytesMut};
3#[cfg(feature = "gzip")]
4use flate2::read::{GzDecoder, GzEncoder};
5#[cfg(feature = "deflate")]
6use flate2::read::{ZlibDecoder, ZlibEncoder};
7use std::{borrow::Cow, fmt};
8#[cfg(feature = "zstd")]
9use zstd::stream::read::{Decoder, Encoder};
10
11pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
12pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
13
14/// Struct used to configure which encodings are enabled on a server or channel.
15///
16/// Represents an ordered list of compression encodings that are enabled.
17#[derive(Debug, Default, Clone, Copy)]
18pub struct EnabledCompressionEncodings {
19    inner: [Option<CompressionEncoding>; 3],
20}
21
22impl EnabledCompressionEncodings {
23    /// Enable a [`CompressionEncoding`].
24    ///
25    /// Adds the new encoding to the end of the encoding list.
26    pub fn enable(&mut self, encoding: CompressionEncoding) {
27        for e in self.inner.iter_mut() {
28            match e {
29                Some(e) if *e == encoding => return,
30                None => {
31                    *e = Some(encoding);
32                    return;
33                }
34                _ => continue,
35            }
36        }
37    }
38
39    /// Remove the last [`CompressionEncoding`].
40    pub fn pop(&mut self) -> Option<CompressionEncoding> {
41        self.inner
42            .iter_mut()
43            .rev()
44            .find(|entry| entry.is_some())?
45            .take()
46    }
47
48    pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> {
49        let mut value = BytesMut::new();
50        for encoding in self.inner.into_iter().flatten() {
51            value.put_slice(encoding.as_str().as_bytes());
52            value.put_u8(b',');
53        }
54
55        if value.is_empty() {
56            return None;
57        }
58
59        value.put_slice(b"identity");
60        Some(http::HeaderValue::from_maybe_shared(value).unwrap())
61    }
62
63    /// Check if a [`CompressionEncoding`] is enabled.
64    pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool {
65        self.inner.contains(&Some(encoding))
66    }
67
68    /// Check if any [`CompressionEncoding`]s are enabled.
69    pub fn is_empty(&self) -> bool {
70        self.inner.iter().all(|e| e.is_none())
71    }
72}
73
74#[derive(Clone, Copy, Debug, PartialEq, Eq)]
75pub(crate) struct CompressionSettings {
76    pub(crate) encoding: CompressionEncoding,
77    /// buffer_growth_interval controls memory growth for internal buffers to balance resizing cost against memory waste.
78    /// The default buffer growth interval is 8 kilobytes.
79    pub(crate) buffer_growth_interval: usize,
80}
81
/// The compression encodings Tonic supports.
///
/// Each variant is only present when its matching cargo feature is enabled.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum CompressionEncoding {
    /// `gzip` compression; requires the `gzip` feature.
    #[cfg(feature = "gzip")]
    Gzip,
    /// `deflate` compression (zlib-wrapped, via `flate2`'s zlib codec); requires the `deflate`
    /// feature.
    #[cfg(feature = "deflate")]
    Deflate,
    /// `zstd` compression; requires the `zstd` feature.
    #[cfg(feature = "zstd")]
    Zstd,
}
96
97impl CompressionEncoding {
98    pub(crate) const ENCODINGS: &'static [CompressionEncoding] = &[
99        #[cfg(feature = "gzip")]
100        CompressionEncoding::Gzip,
101        #[cfg(feature = "deflate")]
102        CompressionEncoding::Deflate,
103        #[cfg(feature = "zstd")]
104        CompressionEncoding::Zstd,
105    ];
106
107    /// Based on the `grpc-accept-encoding` header, pick an encoding to use.
108    pub(crate) fn from_accept_encoding_header(
109        map: &http::HeaderMap,
110        enabled_encodings: EnabledCompressionEncodings,
111    ) -> Option<Self> {
112        if enabled_encodings.is_empty() {
113            return None;
114        }
115
116        let header_value = map.get(ACCEPT_ENCODING_HEADER)?;
117        let header_value_str = header_value.to_str().ok()?;
118
119        split_by_comma(header_value_str).find_map(|value| match value {
120            #[cfg(feature = "gzip")]
121            "gzip" => Some(CompressionEncoding::Gzip),
122            #[cfg(feature = "deflate")]
123            "deflate" => Some(CompressionEncoding::Deflate),
124            #[cfg(feature = "zstd")]
125            "zstd" => Some(CompressionEncoding::Zstd),
126            _ => None,
127        })
128    }
129
130    /// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported.
131    pub(crate) fn from_encoding_header(
132        map: &http::HeaderMap,
133        enabled_encodings: EnabledCompressionEncodings,
134    ) -> Result<Option<Self>, Status> {
135        let Some(header_value) = map.get(ENCODING_HEADER) else {
136            return Ok(None);
137        };
138
139        match header_value.as_bytes() {
140            #[cfg(feature = "gzip")]
141            b"gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => {
142                Ok(Some(CompressionEncoding::Gzip))
143            }
144            #[cfg(feature = "deflate")]
145            b"deflate" if enabled_encodings.is_enabled(CompressionEncoding::Deflate) => {
146                Ok(Some(CompressionEncoding::Deflate))
147            }
148            #[cfg(feature = "zstd")]
149            b"zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => {
150                Ok(Some(CompressionEncoding::Zstd))
151            }
152            b"identity" => Ok(None),
153            other => {
154                let other = match std::str::from_utf8(other) {
155                    Ok(s) => Cow::Borrowed(s),
156                    Err(_) => Cow::Owned(format!("{other:?}")),
157                };
158
159                let mut status = Status::unimplemented(format!(
160                    "Content is compressed with `{other}` which isn't supported"
161                ));
162
163                let header_value = enabled_encodings
164                    .into_accept_encoding_header_value()
165                    .map(MetadataValue::unchecked_from_header_value)
166                    .unwrap_or_else(|| MetadataValue::from_static("identity"));
167                status
168                    .metadata_mut()
169                    .insert(ACCEPT_ENCODING_HEADER, header_value);
170
171                Err(status)
172            }
173        }
174    }
175
176    pub(crate) fn as_str(self) -> &'static str {
177        match self {
178            #[cfg(feature = "gzip")]
179            CompressionEncoding::Gzip => "gzip",
180            #[cfg(feature = "deflate")]
181            CompressionEncoding::Deflate => "deflate",
182            #[cfg(feature = "zstd")]
183            CompressionEncoding::Zstd => "zstd",
184        }
185    }
186
187    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
188    pub(crate) fn into_header_value(self) -> http::HeaderValue {
189        http::HeaderValue::from_static(self.as_str())
190    }
191}
192
193impl fmt::Display for CompressionEncoding {
194    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195        f.write_str(self.as_str())
196    }
197}
198
/// Split a comma-separated header value into entries, trimming whitespace around each one.
///
/// Empty entries (e.g. the middle of `a,,b`, or an empty input) are yielded as `""`.
fn split_by_comma(s: &str) -> impl Iterator<Item = &str> {
    s.split(',').map(str::trim)
}
202
203/// Compress `len` bytes from `decompressed_buf` into `out_buf`.
204/// buffer_size_increment is a hint to control the growth of out_buf versus the cost of resizing it.
205#[allow(unused_variables, unreachable_code)]
206pub(crate) fn compress(
207    settings: CompressionSettings,
208    decompressed_buf: &mut BytesMut,
209    out_buf: &mut BytesMut,
210    len: usize,
211) -> Result<(), std::io::Error> {
212    let buffer_growth_interval = settings.buffer_growth_interval;
213    let capacity = ((len / buffer_growth_interval) + 1) * buffer_growth_interval;
214    out_buf.reserve(capacity);
215
216    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
217    let mut out_writer = out_buf.writer();
218
219    match settings.encoding {
220        #[cfg(feature = "gzip")]
221        CompressionEncoding::Gzip => {
222            let mut gzip_encoder = GzEncoder::new(
223                &decompressed_buf[0..len],
224                // FIXME: support customizing the compression level
225                flate2::Compression::new(6),
226            );
227            std::io::copy(&mut gzip_encoder, &mut out_writer)?;
228        }
229        #[cfg(feature = "deflate")]
230        CompressionEncoding::Deflate => {
231            let mut deflate_encoder = ZlibEncoder::new(
232                &decompressed_buf[0..len],
233                // FIXME: support customizing the compression level
234                flate2::Compression::new(6),
235            );
236            std::io::copy(&mut deflate_encoder, &mut out_writer)?;
237        }
238        #[cfg(feature = "zstd")]
239        CompressionEncoding::Zstd => {
240            let mut zstd_encoder = Encoder::new(
241                &decompressed_buf[0..len],
242                // FIXME: support customizing the compression level
243                zstd::DEFAULT_COMPRESSION_LEVEL,
244            )?;
245            std::io::copy(&mut zstd_encoder, &mut out_writer)?;
246        }
247    }
248
249    decompressed_buf.advance(len);
250
251    Ok(())
252}
253
254/// Decompress `len` bytes from `compressed_buf` into `out_buf`.
255#[allow(unused_variables, unreachable_code)]
256pub(crate) fn decompress(
257    settings: CompressionSettings,
258    compressed_buf: &mut BytesMut,
259    mut out_buf: bytes::buf::Limit<&mut BytesMut>,
260    len: usize,
261) -> Result<(), std::io::Error> {
262    let buffer_growth_interval = settings.buffer_growth_interval;
263    let estimate_decompressed_len = len * 2;
264    let capacity = std::cmp::min(
265        bytes::buf::Limit::limit(&out_buf),
266        ((estimate_decompressed_len / buffer_growth_interval) + 1) * buffer_growth_interval,
267    );
268
269    out_buf.get_mut().reserve(capacity);
270
271    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
272    let mut out_writer = out_buf.writer();
273
274    match settings.encoding {
275        #[cfg(feature = "gzip")]
276        CompressionEncoding::Gzip => {
277            let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]);
278            std::io::copy(&mut gzip_decoder, &mut out_writer)?;
279        }
280        #[cfg(feature = "deflate")]
281        CompressionEncoding::Deflate => {
282            let mut deflate_decoder = ZlibDecoder::new(&compressed_buf[0..len]);
283            std::io::copy(&mut deflate_decoder, &mut out_writer)?;
284        }
285        #[cfg(feature = "zstd")]
286        CompressionEncoding::Zstd => {
287            let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?;
288            std::io::copy(&mut zstd_decoder, &mut out_writer)?;
289        }
290    }
291
292    compressed_buf.advance(len);
293
294    Ok(())
295}
296
/// Controls compression behavior for individual messages within a stream.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum SingleMessageCompressionOverride {
    /// Inherit whatever compression is already configured. If the stream is compressed this
    /// message will also be compressed.
    ///
    /// This is the default.
    #[default]
    Inherit,
    /// Don't compress this message, even if compression is enabled on the stream.
    Disable,
}
309
#[cfg(test)]
mod tests {
    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
    use http::HeaderValue;

    use super::*;

    /// Build the accept-encoding header for an explicit slot layout.
    fn accept_header(inner: [Option<CompressionEncoding>; 3]) -> Option<http::HeaderValue> {
        EnabledCompressionEncodings { inner }.into_accept_encoding_header_value()
    }

    #[test]
    fn convert_none_into_header_value() {
        assert!(accept_header([None, None, None]).is_none());
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn convert_gzip_into_header_value() {
        const GZIP: HeaderValue = HeaderValue::from_static("gzip,identity");
        let gzip = Some(CompressionEncoding::Gzip);

        // The slot an encoding occupies must not affect the rendered header.
        for layout in [[gzip, None, None], [None, None, gzip]] {
            assert_eq!(accept_header(layout).unwrap(), GZIP);
        }
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn convert_zstd_into_header_value() {
        const ZSTD: HeaderValue = HeaderValue::from_static("zstd,identity");
        let zstd = Some(CompressionEncoding::Zstd);

        for layout in [[zstd, None, None], [None, None, zstd]] {
            assert_eq!(accept_header(layout).unwrap(), ZSTD);
        }
    }

    #[test]
    #[cfg(all(feature = "gzip", feature = "deflate", feature = "zstd"))]
    fn convert_compression_encodings_into_header_value() {
        let gzip = Some(CompressionEncoding::Gzip);
        let deflate = Some(CompressionEncoding::Deflate);
        let zstd = Some(CompressionEncoding::Zstd);

        // Enabled order is preserved, with `identity` always appended last.
        let cases = [
            ([gzip, deflate, zstd], "gzip,deflate,zstd,identity"),
            ([zstd, deflate, gzip], "zstd,deflate,gzip,identity"),
        ];
        for (layout, expected) in cases {
            assert_eq!(
                accept_header(layout).unwrap(),
                HeaderValue::from_static(expected),
            );
        }
    }
}