• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::Status;
2 use http::Response;
3 use pin_project::pin_project;
4 use std::{
5     future::Future,
6     pin::Pin,
7     task::{ready, Context, Poll},
8 };
9 use tower::Service;
10 
11 /// Middleware that attempts to recover from service errors by turning them into a response built
12 /// from the `Status`.
13 #[derive(Debug, Clone)]
14 pub(crate) struct RecoverError<S> {
15     inner: S,
16 }
17 
18 impl<S> RecoverError<S> {
new(inner: S) -> Self19     pub(crate) fn new(inner: S) -> Self {
20         Self { inner }
21     }
22 }
23 
24 impl<S, R, ResBody> Service<R> for RecoverError<S>
25 where
26     S: Service<R, Response = Response<ResBody>>,
27     S::Error: Into<crate::Error>,
28 {
29     type Response = Response<MaybeEmptyBody<ResBody>>;
30     type Error = crate::Error;
31     type Future = ResponseFuture<S::Future>;
32 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>33     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
34         self.inner.poll_ready(cx).map_err(Into::into)
35     }
36 
call(&mut self, req: R) -> Self::Future37     fn call(&mut self, req: R) -> Self::Future {
38         ResponseFuture {
39             inner: self.inner.call(req),
40         }
41     }
42 }
43 
44 #[pin_project]
45 pub(crate) struct ResponseFuture<F> {
46     #[pin]
47     inner: F,
48 }
49 
50 impl<F, E, ResBody> Future for ResponseFuture<F>
51 where
52     F: Future<Output = Result<Response<ResBody>, E>>,
53     E: Into<crate::Error>,
54 {
55     type Output = Result<Response<MaybeEmptyBody<ResBody>>, crate::Error>;
56 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>57     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
58         let result: Result<Response<_>, crate::Error> =
59             ready!(self.project().inner.poll(cx)).map_err(Into::into);
60 
61         match result {
62             Ok(response) => {
63                 let response = response.map(MaybeEmptyBody::full);
64                 Poll::Ready(Ok(response))
65             }
66             Err(err) => match Status::try_from_error(err) {
67                 Ok(status) => {
68                     let mut res = Response::new(MaybeEmptyBody::empty());
69                     status.add_header(res.headers_mut()).unwrap();
70                     Poll::Ready(Ok(res))
71                 }
72                 Err(err) => Poll::Ready(Err(err)),
73             },
74         }
75     }
76 }
77 
78 #[pin_project]
79 pub(crate) struct MaybeEmptyBody<B> {
80     #[pin]
81     inner: Option<B>,
82 }
83 
84 impl<B> MaybeEmptyBody<B> {
full(inner: B) -> Self85     fn full(inner: B) -> Self {
86         Self { inner: Some(inner) }
87     }
88 
empty() -> Self89     fn empty() -> Self {
90         Self { inner: None }
91     }
92 }
93 
94 impl<B> http_body::Body for MaybeEmptyBody<B>
95 where
96     B: http_body::Body + Send,
97 {
98     type Data = B::Data;
99     type Error = B::Error;
100 
poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Data, Self::Error>>>101     fn poll_data(
102         self: Pin<&mut Self>,
103         cx: &mut Context<'_>,
104     ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
105         match self.project().inner.as_pin_mut() {
106             Some(b) => b.poll_data(cx),
107             None => Poll::Ready(None),
108         }
109     }
110 
poll_trailers( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>>111     fn poll_trailers(
112         self: Pin<&mut Self>,
113         cx: &mut Context<'_>,
114     ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
115         match self.project().inner.as_pin_mut() {
116             Some(b) => b.poll_trailers(cx),
117             None => Poll::Ready(Ok(None)),
118         }
119     }
120 
is_end_stream(&self) -> bool121     fn is_end_stream(&self) -> bool {
122         match &self.inner {
123             Some(b) => b.is_end_stream(),
124             None => true,
125         }
126     }
127 }
128