//! gRPC interceptors which are a kind of middleware. //! //! See [`Interceptor`] for more details. use crate::{ body::{boxed, BoxBody}, request::SanitizeHeaders, Status, }; use bytes::Bytes; use pin_project::pin_project; use std::{ fmt, future::Future, pin::Pin, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; /// A gRPC interceptor. /// /// gRPC interceptors are similar to middleware but have less flexibility. An interceptor allows /// you to do two main things, one is to add/remove/check items in the `MetadataMap` of each /// request. Two, cancel a request with a `Status`. /// /// Any function that satisfies the bound `FnMut(Request<()>) -> Result, Status>` can be /// used as an `Interceptor`. /// /// An interceptor can be used on both the server and client side through the `tonic-build` crate's /// generated structs. /// /// See the [interceptor example][example] for more details. /// /// If you need more powerful middleware, [tower] is the recommended approach. You can find /// examples of how to use tower with tonic [here][tower-example]. /// /// Additionally, interceptors is not the recommended way to add logging to your service. For that /// a [tower] middleware is more appropriate since it can also act on the response. For example /// tower-http's [`Trace`](https://docs.rs/tower-http/latest/tower_http/trace/index.html) /// middleware supports gRPC out of the box. /// /// [tower]: https://crates.io/crates/tower /// [example]: https://github.com/hyperium/tonic/tree/master/examples/src/interceptor /// [tower-example]: https://github.com/hyperium/tonic/tree/master/examples/src/tower pub trait Interceptor { /// Intercept a request before it is sent, optionally cancelling it. fn call(&mut self, request: crate::Request<()>) -> Result, Status>; } impl Interceptor for F where F: FnMut(crate::Request<()>) -> Result, Status>, { fn call(&mut self, request: crate::Request<()>) -> Result, Status> { self(request) } } /// Create a new interceptor layer. /// /// See [`Interceptor`] for more details. pub fn interceptor(f: F) -> InterceptorLayer where F: Interceptor, { InterceptorLayer { f } } /// A gRPC interceptor that can be used as a [`Layer`], /// created by calling [`interceptor`]. /// /// See [`Interceptor`] for more details. #[derive(Debug, Clone, Copy)] pub struct InterceptorLayer { f: F, } impl Layer for InterceptorLayer where F: Interceptor + Clone, { type Service = InterceptedService; fn layer(&self, service: S) -> Self::Service { InterceptedService::new(service, self.f.clone()) } } /// A service wrapped in an interceptor middleware. /// /// See [`Interceptor`] for more details. #[derive(Clone, Copy)] pub struct InterceptedService { inner: S, f: F, } impl InterceptedService { /// Create a new `InterceptedService` that wraps `S` and intercepts each request with the /// function `F`. pub fn new(service: S, f: F) -> Self where F: Interceptor, { Self { inner: service, f } } } impl fmt::Debug for InterceptedService where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("InterceptedService") .field("inner", &self.inner) .field("f", &format_args!("{}", std::any::type_name::())) .finish() } } impl Service> for InterceptedService where ResBody: Default + http_body::Body + Send + 'static, F: Interceptor, S: Service, Response = http::Response>, S::Error: Into, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { type Response = http::Response; type Error = S::Error; type Future = ResponseFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: http::Request) -> Self::Future { // It is bad practice to modify the body (i.e. Message) of the request via an interceptor. // To avoid exposing the body of the request to the interceptor function, we first remove it // here, allow the interceptor to modify the metadata and extensions, and then recreate the // HTTP request with the body. Tonic requests do not preserve the URI, HTTP version, and // HTTP method of the HTTP request, so we extract them here and then add them back in below. let uri = req.uri().clone(); let method = req.method().clone(); let version = req.version(); let req = crate::Request::from_http(req); let (metadata, extensions, msg) = req.into_parts(); match self .f .call(crate::Request::from_parts(metadata, extensions, ())) { Ok(req) => { let (metadata, extensions, _) = req.into_parts(); let req = crate::Request::from_parts(metadata, extensions, msg); let req = req.into_http(uri, method, version, SanitizeHeaders::No); ResponseFuture::future(self.inner.call(req)) } Err(status) => ResponseFuture::status(status), } } } // required to use `InterceptedService` with `Router` impl crate::server::NamedService for InterceptedService where S: crate::server::NamedService, { const NAME: &'static str = S::NAME; } /// Response future for [`InterceptedService`]. #[pin_project] #[derive(Debug)] pub struct ResponseFuture { #[pin] kind: Kind, } impl ResponseFuture { fn future(future: F) -> Self { Self { kind: Kind::Future(future), } } fn status(status: Status) -> Self { Self { kind: Kind::Status(Some(status)), } } } #[pin_project(project = KindProj)] #[derive(Debug)] enum Kind { Future(#[pin] F), Status(Option), } impl Future for ResponseFuture where F: Future, E>>, E: Into, B: Default + http_body::Body + Send + 'static, B::Error: Into, { type Output = Result, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.project().kind.project() { KindProj::Future(future) => future .poll(cx) .map(|result| result.map(|res| res.map(boxed))), KindProj::Status(status) => { let response = status .take() .unwrap() .to_http() .map(|_| B::default()) .map(boxed); Poll::Ready(Ok(response)) } } } } #[cfg(test)] mod tests { #[allow(unused_imports)] use super::*; use http::header::HeaderMap; use std::{ pin::Pin, task::{Context, Poll}, }; use tower::ServiceExt; #[derive(Debug, Default)] struct TestBody; impl http_body::Body for TestBody { type Data = Bytes; type Error = Status; fn poll_data( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll>> { Poll::Ready(None) } fn poll_trailers( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll, Self::Error>> { Poll::Ready(Ok(None)) } } #[tokio::test] async fn doesnt_remove_headers_from_requests() { let svc = tower::service_fn(|request: http::Request| async move { assert_eq!( request .headers() .get("user-agent") .expect("missing in leaf service"), "test-tonic" ); Ok::<_, Status>(http::Response::new(TestBody)) }); let svc = InterceptedService::new(svc, |request: crate::Request<()>| { assert_eq!( request .metadata() .get("user-agent") .expect("missing in interceptor"), "test-tonic" ); Ok(request) }); let request = http::Request::builder() .header("user-agent", "test-tonic") .body(TestBody) .unwrap(); svc.oneshot(request).await.unwrap(); } #[tokio::test] async fn handles_intercepted_status_as_response() { let message = "Blocked by the interceptor"; let expected = Status::permission_denied(message).to_http(); let svc = tower::service_fn(|_: http::Request| async { Ok::<_, Status>(http::Response::new(TestBody)) }); let svc = InterceptedService::new(svc, |_: crate::Request<()>| { Err(Status::permission_denied(message)) }); let request = http::Request::builder().body(TestBody).unwrap(); let response = svc.oneshot(request).await.unwrap(); assert_eq!(expected.status(), response.status()); assert_eq!(expected.version(), response.version()); assert_eq!(expected.headers(), response.headers()); } #[tokio::test] async fn doesnt_change_http_method() { let svc = tower::service_fn(|request: http::Request| async move { assert_eq!(request.method(), http::Method::OPTIONS); Ok::<_, hyper::Error>(hyper::Response::new(hyper::Body::empty())) }); let svc = InterceptedService::new(svc, Ok); let request = http::Request::builder() .method(http::Method::OPTIONS) .body(hyper::Body::empty()) .unwrap(); svc.oneshot(request).await.unwrap(); } }