use super::compression::{decompress, CompressionEncoding}; use super::{DecodeBuf, Decoder, DEFAULT_MAX_RECV_MESSAGE_SIZE, HEADER_SIZE}; use crate::{body::BoxBody, metadata::MetadataMap, Code, Status}; use bytes::{Buf, BufMut, BytesMut}; use http::StatusCode; use http_body::Body; use std::{ fmt, future, pin::Pin, task::ready, task::{Context, Poll}, }; use tokio_stream::Stream; use tracing::{debug, trace}; const BUFFER_SIZE: usize = 8 * 1024; /// Streaming requests and responses. /// /// This will wrap some inner [`Body`] and [`Decoder`] and provide an interface /// to fetch the message stream and trailing metadata pub struct Streaming { decoder: Box + Send + 'static>, inner: StreamingInner, } struct StreamingInner { body: BoxBody, state: State, direction: Direction, buf: BytesMut, trailers: Option, decompress_buf: BytesMut, encoding: Option, max_message_size: Option, } impl Unpin for Streaming {} #[derive(Debug, Clone, Copy)] enum State { ReadHeader, ReadBody { compression: Option, len: usize, }, Error, } #[derive(Debug, PartialEq, Eq)] enum Direction { Request, Response(StatusCode), EmptyResponse, } impl Streaming { pub(crate) fn new_response( decoder: D, body: B, status_code: StatusCode, encoding: Option, max_message_size: Option, ) -> Self where B: Body + Send + 'static, B::Error: Into, D: Decoder + Send + 'static, { Self::new( decoder, body, Direction::Response(status_code), encoding, max_message_size, ) } pub(crate) fn new_empty(decoder: D, body: B) -> Self where B: Body + Send + 'static, B::Error: Into, D: Decoder + Send + 'static, { Self::new(decoder, body, Direction::EmptyResponse, None, None) } #[doc(hidden)] pub fn new_request( decoder: D, body: B, encoding: Option, max_message_size: Option, ) -> Self where B: Body + Send + 'static, B::Error: Into, D: Decoder + Send + 'static, { Self::new( decoder, body, Direction::Request, encoding, max_message_size, ) } fn new( decoder: D, body: B, direction: Direction, encoding: Option, max_message_size: Option, ) -> Self where B: Body + Send + 'static, B::Error: Into, D: Decoder + Send + 'static, { Self { decoder: Box::new(decoder), inner: StreamingInner { body: body .map_data(|mut buf| buf.copy_to_bytes(buf.remaining())) .map_err(|err| Status::map_error(err.into())) .boxed_unsync(), state: State::ReadHeader, direction, buf: BytesMut::with_capacity(BUFFER_SIZE), trailers: None, decompress_buf: BytesMut::new(), encoding, max_message_size, }, } } } impl StreamingInner { fn decode_chunk(&mut self) -> Result>, Status> { if let State::ReadHeader = self.state { if self.buf.remaining() < HEADER_SIZE { return Ok(None); } let compression_encoding = match self.buf.get_u8() { 0 => None, 1 => { { if self.encoding.is_some() { self.encoding } else { // https://grpc.github.io/grpc/core/md_doc_compression.html // An ill-constructed message with its Compressed-Flag bit set but lacking a grpc-encoding // entry different from identity in its metadata MUST fail with INTERNAL status, // its associated description indicating the invalid Compressed-Flag condition. return Err(Status::new(Code::Internal, "protocol error: received message with compressed-flag but no grpc-encoding was specified")); } } } f => { trace!("unexpected compression flag"); let message = if let Direction::Response(status) = self.direction { format!( "protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1) while receiving response with status: {}", f, status ) } else { format!("protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1), while sending request", f) }; return Err(Status::new(Code::Internal, message)); } }; let len = self.buf.get_u32() as usize; let limit = self .max_message_size .unwrap_or(DEFAULT_MAX_RECV_MESSAGE_SIZE); if len > limit { return Err(Status::new( Code::OutOfRange, format!( "Error, message length too large: found {} bytes, the limit is: {} bytes", len, limit ), )); } self.buf.reserve(len); self.state = State::ReadBody { compression: compression_encoding, len, } } if let State::ReadBody { len, compression } = self.state { // if we haven't read enough of the message then return and keep // reading if self.buf.remaining() < len || self.buf.len() < len { return Ok(None); } let decode_buf = if let Some(encoding) = compression { self.decompress_buf.clear(); if let Err(err) = decompress(encoding, &mut self.buf, &mut self.decompress_buf, len) { let message = if let Direction::Response(status) = self.direction { format!( "Error decompressing: {}, while receiving response with status: {}", err, status ) } else { format!("Error decompressing: {}, while sending request", err) }; return Err(Status::new(Code::Internal, message)); } let decompressed_len = self.decompress_buf.len(); DecodeBuf::new(&mut self.decompress_buf, decompressed_len) } else { DecodeBuf::new(&mut self.buf, len) }; return Ok(Some(decode_buf)); } Ok(None) } // Returns Some(()) if data was found or None if the loop in `poll_next` should break fn poll_data(&mut self, cx: &mut Context<'_>) -> Poll, Status>> { let chunk = match ready!(Pin::new(&mut self.body).poll_data(cx)) { Some(Ok(d)) => Some(d), Some(Err(status)) => { if self.direction == Direction::Request && status.code() == Code::Cancelled { return Poll::Ready(Ok(None)); } let _ = std::mem::replace(&mut self.state, State::Error); debug!("decoder inner stream error: {:?}", status); return Poll::Ready(Err(status)); } None => None, }; Poll::Ready(if let Some(data) = chunk { self.buf.put(data); Ok(Some(())) } else { // FIXME: improve buf usage. if self.buf.has_remaining() { trace!("unexpected EOF decoding stream, state: {:?}", self.state); Err(Status::new( Code::Internal, "Unexpected EOF decoding stream.".to_string(), )) } else { Ok(None) } }) } fn poll_response(&mut self, cx: &mut Context<'_>) -> Poll> { if let Direction::Response(status) = self.direction { match ready!(Pin::new(&mut self.body).poll_trailers(cx)) { Ok(trailer) => { if let Err(e) = crate::status::infer_grpc_status(trailer.as_ref(), status) { if let Some(e) = e { return Poll::Ready(Err(e)); } else { return Poll::Ready(Ok(())); } } else { self.trailers = trailer.map(MetadataMap::from_headers); } } Err(status) => { debug!("decoder inner trailers error: {:?}", status); return Poll::Ready(Err(status)); } } } Poll::Ready(Ok(())) } } impl Streaming { /// Fetch the next message from this stream. /// /// # Return value /// /// - `Result::Err(val)` means a gRPC error was sent by the sender instead /// of a valid response message. Refer to [`Status::code`] and /// [`Status::message`] to examine possible error causes. /// /// - `Result::Ok(None)` means the stream was closed by the sender and no /// more messages will be delivered. Further attempts to call /// [`Streaming::message`] will result in the same return value. /// /// - `Result::Ok(Some(val))` means the sender streamed a valid response /// message `val`. /// /// ```rust /// # use tonic::{Streaming, Status, codec::Decoder}; /// # use std::fmt::Debug; /// # async fn next_message_ex(mut request: Streaming) -> Result<(), Status> /// # where T: Debug, /// # D: Decoder + Send + 'static, /// # { /// if let Some(next_message) = request.message().await? { /// println!("{:?}", next_message); /// } /// # Ok(()) /// # } /// ``` pub async fn message(&mut self) -> Result, Status> { match future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await { Some(Ok(m)) => Ok(Some(m)), Some(Err(e)) => Err(e), None => Ok(None), } } /// Fetch the trailing metadata. /// /// This will drain the stream of all its messages to receive the trailing /// metadata. If [`Streaming::message`] returns `None` then this function /// will not need to poll for trailers since the body was totally consumed. /// /// ```rust /// # use tonic::{Streaming, Status}; /// # async fn trailers_ex(mut request: Streaming) -> Result<(), Status> { /// if let Some(metadata) = request.trailers().await? { /// println!("{:?}", metadata); /// } /// # Ok(()) /// # } /// ``` pub async fn trailers(&mut self) -> Result, Status> { // Shortcut to see if we already pulled the trailers in the stream step // we need to do that so that the stream can error on trailing grpc-status if let Some(trailers) = self.inner.trailers.take() { return Ok(Some(trailers)); } // To fetch the trailers we must clear the body and drop it. while self.message().await?.is_some() {} // Since we call poll_trailers internally on poll_next we need to // check if it got cached again. if let Some(trailers) = self.inner.trailers.take() { return Ok(Some(trailers)); } // Trailers were not caught during poll_next and thus lets poll for // them manually. let map = future::poll_fn(|cx| Pin::new(&mut self.inner.body).poll_trailers(cx)) .await .map_err(|e| Status::from_error(Box::new(e))); map.map(|x| x.map(MetadataMap::from_headers)) } fn decode_chunk(&mut self) -> Result, Status> { match self.inner.decode_chunk()? { Some(mut decode_buf) => match self.decoder.decode(&mut decode_buf)? { Some(msg) => { self.inner.state = State::ReadHeader; Ok(Some(msg)) } None => Ok(None), }, None => Ok(None), } } } impl Stream for Streaming { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { if let State::Error = &self.inner.state { return Poll::Ready(None); } // FIXME: implement the ability to poll trailers when we _know_ that // the consumer of this stream will only poll for the first message. // This means we skip the poll_trailers step. if let Some(item) = self.decode_chunk()? { return Poll::Ready(Some(Ok(item))); } match ready!(self.inner.poll_data(cx))? { Some(()) => (), None => break, } } Poll::Ready(match ready!(self.inner.poll_response(cx)) { Ok(()) => None, Err(err) => Some(Err(err)), }) } } impl fmt::Debug for Streaming { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Streaming").finish() } } #[cfg(test)] static_assertions::assert_impl_all!(Streaming<()>: Send);