use super::encode::BUFFER_SIZE;
use crate::{metadata::MetadataValue, Status};
use bytes::{Buf, BytesMut};
#[cfg(feature = "gzip")]
use flate2::read::{GzDecoder, GzEncoder};
use std::fmt;
#[cfg(feature = "zstd")]
use zstd::stream::read::{Decoder, Encoder};

/// Name of the header that names the compression applied to message payloads.
pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
/// Name of the header listing the compression encodings a peer is willing to accept.
pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";

/// Struct used to configure which encodings are enabled on a server or channel.
///
/// The `Default` value has every encoding disabled; switch individual encodings on with
/// [`enable`](Self::enable).
#[derive(Debug, Default, Clone, Copy)]
pub struct EnabledCompressionEncodings {
    #[cfg(feature = "gzip")]
    pub(crate) gzip: bool,
    #[cfg(feature = "zstd")]
    pub(crate) zstd: bool,
}

impl EnabledCompressionEncodings {
    /// Check if a [`CompressionEncoding`] is enabled.
    pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool {
        match encoding {
            #[cfg(feature = "gzip")]
            CompressionEncoding::Gzip => self.is_gzip_enabled(),
            #[cfg(feature = "zstd")]
            CompressionEncoding::Zstd => self.is_zstd_enabled(),
        }
    }

    /// Enable a [`CompressionEncoding`]. Enabling an encoding twice has no further effect.
    pub fn enable(&mut self, encoding: CompressionEncoding) {
        match encoding {
            #[cfg(feature = "gzip")]
            CompressionEncoding::Gzip => {
                self.gzip = true;
            }
            #[cfg(feature = "zstd")]
            CompressionEncoding::Zstd => {
                self.zstd = true;
            }
        }
    }

    /// Build the `grpc-accept-encoding` value that advertises the enabled encodings.
    ///
    /// `identity` is always listed last as the uncompressed fallback. Returns `None` when
    /// no encoding is enabled.
    pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> {
        let advertised = match (self.is_gzip_enabled(), self.is_zstd_enabled()) {
            (false, false) => return None,
            (true, true) => "gzip,zstd,identity",
            (true, false) => "gzip,identity",
            (false, true) => "zstd,identity",
        };
        Some(http::HeaderValue::from_static(advertised))
    }

    /// Whether gzip has been enabled.
    #[cfg(feature = "gzip")]
    const fn is_gzip_enabled(&self) -> bool {
        self.gzip
    }

    /// Gzip support is compiled out, so it can never be enabled.
    #[cfg(not(feature = "gzip"))]
    const fn is_gzip_enabled(&self) -> bool {
        false
    }

    /// Whether zstd has been enabled.
    #[cfg(feature = "zstd")]
    const fn is_zstd_enabled(&self) -> bool {
        self.zstd
    }

    /// Zstd support is compiled out, so it can never be enabled.
    #[cfg(not(feature = "zstd"))]
    const fn is_zstd_enabled(&self) -> bool {
        false
    }
}

/// The compression encodings Tonic supports.
#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[non_exhaustive] pub enum CompressionEncoding { #[allow(missing_docs)] #[cfg(feature = "gzip")] #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))] Gzip, #[allow(missing_docs)] #[cfg(feature = "zstd")] #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))] Zstd, } impl CompressionEncoding { /// Based on the `grpc-accept-encoding` header, pick an encoding to use. pub(crate) fn from_accept_encoding_header( map: &http::HeaderMap, enabled_encodings: EnabledCompressionEncodings, ) -> Option { if !enabled_encodings.is_gzip_enabled() && !enabled_encodings.is_zstd_enabled() { return None; } let header_value = map.get(ACCEPT_ENCODING_HEADER)?; let header_value_str = header_value.to_str().ok()?; split_by_comma(header_value_str).find_map(|value| match value { #[cfg(feature = "gzip")] "gzip" => Some(CompressionEncoding::Gzip), #[cfg(feature = "zstd")] "zstd" => Some(CompressionEncoding::Zstd), _ => None, }) } /// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported. 
pub(crate) fn from_encoding_header( map: &http::HeaderMap, enabled_encodings: EnabledCompressionEncodings, ) -> Result, Status> { let header_value = if let Some(value) = map.get(ENCODING_HEADER) { value } else { return Ok(None); }; let header_value_str = if let Ok(value) = header_value.to_str() { value } else { return Ok(None); }; match header_value_str { #[cfg(feature = "gzip")] "gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => { Ok(Some(CompressionEncoding::Gzip)) } #[cfg(feature = "zstd")] "zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => { Ok(Some(CompressionEncoding::Zstd)) } "identity" => Ok(None), other => { let mut status = Status::unimplemented(format!( "Content is compressed with `{}` which isn't supported", other )); let header_value = enabled_encodings .into_accept_encoding_header_value() .map(MetadataValue::unchecked_from_header_value) .unwrap_or_else(|| MetadataValue::from_static("identity")); status .metadata_mut() .insert(ACCEPT_ENCODING_HEADER, header_value); Err(status) } } } #[allow(missing_docs)] #[cfg(any(feature = "gzip", feature = "zstd"))] pub(crate) fn as_str(&self) -> &'static str { match self { #[cfg(feature = "gzip")] CompressionEncoding::Gzip => "gzip", #[cfg(feature = "zstd")] CompressionEncoding::Zstd => "zstd", } } #[cfg(any(feature = "gzip", feature = "zstd"))] pub(crate) fn into_header_value(self) -> http::HeaderValue { http::HeaderValue::from_static(self.as_str()) } pub(crate) fn encodings() -> &'static [Self] { &[ #[cfg(feature = "gzip")] CompressionEncoding::Gzip, #[cfg(feature = "zstd")] CompressionEncoding::Zstd, ] } } impl fmt::Display for CompressionEncoding { #[allow(unused_variables)] fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { #[cfg(feature = "gzip")] CompressionEncoding::Gzip => write!(f, "gzip"), #[cfg(feature = "zstd")] CompressionEncoding::Zstd => write!(f, "zstd"), } } } fn split_by_comma(s: &str) -> impl Iterator { s.trim().split(',').map(|s| 
s.trim()) } /// Compress `len` bytes from `decompressed_buf` into `out_buf`. #[allow(unused_variables, unreachable_code)] pub(crate) fn compress( encoding: CompressionEncoding, decompressed_buf: &mut BytesMut, out_buf: &mut BytesMut, len: usize, ) -> Result<(), std::io::Error> { let capacity = ((len / BUFFER_SIZE) + 1) * BUFFER_SIZE; out_buf.reserve(capacity); #[cfg(any(feature = "gzip", feature = "zstd"))] let mut out_writer = bytes::BufMut::writer(out_buf); match encoding { #[cfg(feature = "gzip")] CompressionEncoding::Gzip => { let mut gzip_encoder = GzEncoder::new( &decompressed_buf[0..len], // FIXME: support customizing the compression level flate2::Compression::new(6), ); std::io::copy(&mut gzip_encoder, &mut out_writer)?; } #[cfg(feature = "zstd")] CompressionEncoding::Zstd => { let mut zstd_encoder = Encoder::new( &decompressed_buf[0..len], // FIXME: support customizing the compression level zstd::DEFAULT_COMPRESSION_LEVEL, )?; std::io::copy(&mut zstd_encoder, &mut out_writer)?; } } decompressed_buf.advance(len); Ok(()) } /// Decompress `len` bytes from `compressed_buf` into `out_buf`. 
#[allow(unused_variables, unreachable_code)] pub(crate) fn decompress( encoding: CompressionEncoding, compressed_buf: &mut BytesMut, out_buf: &mut BytesMut, len: usize, ) -> Result<(), std::io::Error> { let estimate_decompressed_len = len * 2; let capacity = ((estimate_decompressed_len / BUFFER_SIZE) + 1) * BUFFER_SIZE; out_buf.reserve(capacity); #[cfg(any(feature = "gzip", feature = "zstd"))] let mut out_writer = bytes::BufMut::writer(out_buf); match encoding { #[cfg(feature = "gzip")] CompressionEncoding::Gzip => { let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]); std::io::copy(&mut gzip_decoder, &mut out_writer)?; } #[cfg(feature = "zstd")] CompressionEncoding::Zstd => { let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?; std::io::copy(&mut zstd_decoder, &mut out_writer)?; } } compressed_buf.advance(len); Ok(()) } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum SingleMessageCompressionOverride { /// Inherit whatever compression is already configured. If the stream is compressed this /// message will also be configured. /// /// This is the default. Inherit, /// Don't compress this message, even if compression is enabled on the stream. Disable, } impl Default for SingleMessageCompressionOverride { fn default() -> Self { Self::Inherit } }