use crate::codec::compression::{ CompressionEncoding, EnabledCompressionEncodings, SingleMessageCompressionOverride, }; use crate::{ body::BoxBody, codec::{encode_server, Codec, Streaming}, server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService}, Code, Request, Status, }; use http_body::Body; use std::fmt; use tokio_stream::{Stream, StreamExt}; macro_rules! t { ($result:expr) => { match $result { Ok(value) => value, Err(status) => return status.to_http(), } }; } /// A gRPC Server handler. /// /// This will wrap some inner [`Codec`] and provide utilities to handle /// inbound unary, client side streaming, server side streaming, and /// bi-directional streaming. /// /// Each request handler method accepts some service that implements the /// corresponding service trait and a http request that contains some body that /// implements some [`Body`]. pub struct Grpc { codec: T, /// Which compression encodings does the server accept for requests? accept_compression_encodings: EnabledCompressionEncodings, /// Which compression encodings might the server use for responses. send_compression_encodings: EnabledCompressionEncodings, /// Limits the maximum size of a decoded message. max_decoding_message_size: Option, /// Limits the maximum size of an encoded message. max_encoding_message_size: Option, } impl Grpc where T: Codec, { /// Creates a new gRPC server with the provided [`Codec`]. pub fn new(codec: T) -> Self { Self { codec, accept_compression_encodings: EnabledCompressionEncodings::default(), send_compression_encodings: EnabledCompressionEncodings::default(), max_decoding_message_size: None, max_encoding_message_size: None, } } /// Enable accepting compressed requests. /// /// If a request with an unsupported encoding is received the server will respond with /// [`Code::UnUnimplemented`](crate::Code). /// /// # Example /// /// The most common way of using this is through a server generated by tonic-build: /// /// ```rust /// # enum CompressionEncoding { Gzip } /// # struct Svc; /// # struct ExampleServer(T); /// # impl ExampleServer { /// # fn new(svc: T) -> Self { Self(svc) } /// # fn accept_compressed(self, _: CompressionEncoding) -> Self { self } /// # } /// # #[tonic::async_trait] /// # trait Example {} /// /// #[tonic::async_trait] /// impl Example for Svc { /// // ... /// } /// /// let service = ExampleServer::new(Svc).accept_compressed(CompressionEncoding::Gzip); /// ``` pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { self.accept_compression_encodings.enable(encoding); self } /// Enable sending compressed responses. /// /// Requires the client to also support receiving compressed responses. /// /// # Example /// /// The most common way of using this is through a server generated by tonic-build: /// /// ```rust /// # enum CompressionEncoding { Gzip } /// # struct Svc; /// # struct ExampleServer(T); /// # impl ExampleServer { /// # fn new(svc: T) -> Self { Self(svc) } /// # fn send_compressed(self, _: CompressionEncoding) -> Self { self } /// # } /// # #[tonic::async_trait] /// # trait Example {} /// /// #[tonic::async_trait] /// impl Example for Svc { /// // ... /// } /// /// let service = ExampleServer::new(Svc).send_compressed(CompressionEncoding::Gzip); /// ``` pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { self.send_compression_encodings.enable(encoding); self } /// Limits the maximum size of a decoded message. /// /// # Example /// /// The most common way of using this is through a server generated by tonic-build: /// /// ```rust /// # struct Svc; /// # struct ExampleServer(T); /// # impl ExampleServer { /// # fn new(svc: T) -> Self { Self(svc) } /// # fn max_decoding_message_size(self, _: usize) -> Self { self } /// # } /// # #[tonic::async_trait] /// # trait Example {} /// /// #[tonic::async_trait] /// impl Example for Svc { /// // ... /// } /// /// // Set the limit to 2MB, Defaults to 4MB. /// let limit = 2 * 1024 * 1024; /// let service = ExampleServer::new(Svc).max_decoding_message_size(limit); /// ``` pub fn max_decoding_message_size(mut self, limit: usize) -> Self { self.max_decoding_message_size = Some(limit); self } /// Limits the maximum size of a encoded message. /// /// # Example /// /// The most common way of using this is through a server generated by tonic-build: /// /// ```rust /// # struct Svc; /// # struct ExampleServer(T); /// # impl ExampleServer { /// # fn new(svc: T) -> Self { Self(svc) } /// # fn max_encoding_message_size(self, _: usize) -> Self { self } /// # } /// # #[tonic::async_trait] /// # trait Example {} /// /// #[tonic::async_trait] /// impl Example for Svc { /// // ... /// } /// /// // Set the limit to 2MB, Defaults to 4MB. /// let limit = 2 * 1024 * 1024; /// let service = ExampleServer::new(Svc).max_encoding_message_size(limit); /// ``` pub fn max_encoding_message_size(mut self, limit: usize) -> Self { self.max_encoding_message_size = Some(limit); self } #[doc(hidden)] pub fn apply_compression_config( self, accept_encodings: EnabledCompressionEncodings, send_encodings: EnabledCompressionEncodings, ) -> Self { let mut this = self; for &encoding in CompressionEncoding::encodings() { if accept_encodings.is_enabled(encoding) { this = this.accept_compressed(encoding); } if send_encodings.is_enabled(encoding) { this = this.send_compressed(encoding); } } this } #[doc(hidden)] pub fn apply_max_message_size_config( self, max_decoding_message_size: Option, max_encoding_message_size: Option, ) -> Self { let mut this = self; if let Some(limit) = max_decoding_message_size { this = this.max_decoding_message_size(limit); } if let Some(limit) = max_encoding_message_size { this = this.max_encoding_message_size(limit); } this } /// Handle a single unary gRPC request. pub async fn unary( &mut self, mut service: S, req: http::Request, ) -> http::Response where S: UnaryService, B: Body + Send + 'static, B::Error: Into + Send, { let accept_encoding = CompressionEncoding::from_accept_encoding_header( req.headers(), self.send_compression_encodings, ); let request = match self.map_request_unary(req).await { Ok(r) => r, Err(status) => { return self.map_response::>>( Err(status), accept_encoding, SingleMessageCompressionOverride::default(), self.max_encoding_message_size, ); } }; let response = service .call(request) .await .map(|r| r.map(|m| tokio_stream::once(Ok(m)))); let compression_override = compression_override_from_response(&response); self.map_response( response, accept_encoding, compression_override, self.max_encoding_message_size, ) } /// Handle a server side streaming request. pub async fn server_streaming( &mut self, mut service: S, req: http::Request, ) -> http::Response where S: ServerStreamingService, S::ResponseStream: Send + 'static, B: Body + Send + 'static, B::Error: Into + Send, { let accept_encoding = CompressionEncoding::from_accept_encoding_header( req.headers(), self.send_compression_encodings, ); let request = match self.map_request_unary(req).await { Ok(r) => r, Err(status) => { return self.map_response::( Err(status), accept_encoding, SingleMessageCompressionOverride::default(), self.max_encoding_message_size, ); } }; let response = service.call(request).await; self.map_response( response, accept_encoding, // disabling compression of individual stream items must be done on // the items themselves SingleMessageCompressionOverride::default(), self.max_encoding_message_size, ) } /// Handle a client side streaming gRPC request. pub async fn client_streaming( &mut self, mut service: S, req: http::Request, ) -> http::Response where S: ClientStreamingService, B: Body + Send + 'static, B::Error: Into + Send + 'static, { let accept_encoding = CompressionEncoding::from_accept_encoding_header( req.headers(), self.send_compression_encodings, ); let request = t!(self.map_request_streaming(req)); let response = service .call(request) .await .map(|r| r.map(|m| tokio_stream::once(Ok(m)))); let compression_override = compression_override_from_response(&response); self.map_response( response, accept_encoding, compression_override, self.max_encoding_message_size, ) } /// Handle a bi-directional streaming gRPC request. pub async fn streaming( &mut self, mut service: S, req: http::Request, ) -> http::Response where S: StreamingService + Send, S::ResponseStream: Send + 'static, B: Body + Send + 'static, B::Error: Into + Send, { let accept_encoding = CompressionEncoding::from_accept_encoding_header( req.headers(), self.send_compression_encodings, ); let request = t!(self.map_request_streaming(req)); let response = service.call(request).await; self.map_response( response, accept_encoding, SingleMessageCompressionOverride::default(), self.max_encoding_message_size, ) } async fn map_request_unary( &mut self, request: http::Request, ) -> Result, Status> where B: Body + Send + 'static, B::Error: Into + Send, { let request_compression_encoding = self.request_encoding_if_supported(&request)?; let (parts, body) = request.into_parts(); let stream = Streaming::new_request( self.codec.decoder(), body, request_compression_encoding, self.max_decoding_message_size, ); tokio::pin!(stream); let message = stream .try_next() .await? .ok_or_else(|| Status::new(Code::Internal, "Missing request message."))?; let mut req = Request::from_http_parts(parts, message); if let Some(trailers) = stream.trailers().await? { req.metadata_mut().merge(trailers); } Ok(req) } fn map_request_streaming( &mut self, request: http::Request, ) -> Result>, Status> where B: Body + Send + 'static, B::Error: Into + Send, { let encoding = self.request_encoding_if_supported(&request)?; let request = request.map(|body| { Streaming::new_request( self.codec.decoder(), body, encoding, self.max_decoding_message_size, ) }); Ok(Request::from_http(request)) } fn map_response( &mut self, response: Result, Status>, accept_encoding: Option, compression_override: SingleMessageCompressionOverride, max_message_size: Option, ) -> http::Response where B: Stream> + Send + 'static, { let response = match response { Ok(r) => r, Err(status) => return status.to_http(), }; let (mut parts, body) = response.into_http().into_parts(); // Set the content type parts.headers.insert( http::header::CONTENT_TYPE, http::header::HeaderValue::from_static("application/grpc"), ); #[cfg(any(feature = "gzip", feature = "zstd"))] if let Some(encoding) = accept_encoding { // Set the content encoding parts.headers.insert( crate::codec::compression::ENCODING_HEADER, encoding.into_header_value(), ); } let body = encode_server( self.codec.encoder(), body, accept_encoding, compression_override, max_message_size, ); http::Response::from_parts(parts, BoxBody::new(body)) } fn request_encoding_if_supported( &self, request: &http::Request, ) -> Result, Status> { CompressionEncoding::from_encoding_header( request.headers(), self.accept_compression_encodings, ) } } impl fmt::Debug for Grpc { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut f = f.debug_struct("Grpc"); f.field("codec", &self.codec); f.field( "accept_compression_encodings", &self.accept_compression_encodings, ); f.field( "send_compression_encodings", &self.send_compression_encodings, ); f.finish() } } fn compression_override_from_response( res: &Result, E>, ) -> SingleMessageCompressionOverride { res.as_ref() .ok() .and_then(|response| { response .extensions() .get::() .copied() }) .unwrap_or_default() }