• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #![doc = include_str!("../docs/error_handling.md")]
2 
3 use crate::{
4     extract::FromRequestParts,
5     http::Request,
6     response::{IntoResponse, Response},
7 };
8 use std::{
9     convert::Infallible,
10     fmt,
11     future::Future,
12     marker::PhantomData,
13     task::{Context, Poll},
14 };
15 use tower::ServiceExt;
16 use tower_layer::Layer;
17 use tower_service::Service;
18 
/// [`Layer`] that applies [`HandleError`] which is a [`Service`] adapter
/// that handles errors by converting them into responses.
///
/// `F` is the error-handling callback. `T` is a marker for the tuple of
/// extractor types handed to that callback ahead of the error (`()` when the
/// callback takes only the error).
///
/// See [module docs](self) for more details on axum's error handling model.
pub struct HandleErrorLayer<F, T> {
    f: F,
    // `fn() -> T` instead of `T`: no `T` value is ever stored, so the marker
    // must not make the layer's `Send`/`Sync`-ness depend on `T`.
    _extractor: PhantomData<fn() -> T>,
}

impl<F, T> HandleErrorLayer<F, T> {
    /// Create a new `HandleErrorLayer`.
    ///
    /// `f` is the callback that will be cloned into every [`HandleError`]
    /// produced by this layer.
    pub fn new(f: F) -> Self {
        Self {
            f,
            _extractor: PhantomData,
        }
    }
}

// Implemented by hand so that only `F` has to be `Clone`; a derive would
// also demand `T: Clone` even though no `T` is stored.
impl<F, T> Clone for HandleErrorLayer<F, T>
where
    F: Clone,
{
    fn clone(&self) -> Self {
        Self {
            f: self.f.clone(),
            _extractor: PhantomData,
        }
    }
}

// `F` is not required to implement `Debug`, so the type name of the callback
// is printed in place of its value.
impl<F, T> fmt::Debug for HandleErrorLayer<F, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("HandleErrorLayer")
            .field("f", &format_args!("{}", std::any::type_name::<F>()))
            .finish()
    }
}
57 
58 impl<S, F, T> Layer<S> for HandleErrorLayer<F, T>
59 where
60     F: Clone,
61 {
62     type Service = HandleError<S, F, T>;
63 
layer(&self, inner: S) -> Self::Service64     fn layer(&self, inner: S) -> Self::Service {
65         HandleError::new(inner, self.f.clone())
66     }
67 }
68 
/// A [`Service`] adapter that handles errors by converting them into responses.
///
/// `S` is the wrapped service, `F` the error-handling callback, and `T` a
/// marker for the tuple of extractor types passed to the callback ahead of
/// the error (`()` when the callback takes only the error).
///
/// See [module docs](self) for more details on axum's error handling model.
pub struct HandleError<S, F, T> {
    inner: S,
    f: F,
    // `fn() -> T` instead of `T`: no `T` value is ever stored, so the marker
    // must not make the adapter's `Send`/`Sync`-ness depend on `T`.
    _extractor: PhantomData<fn() -> T>,
}

impl<S, F, T> HandleError<S, F, T> {
    /// Create a new `HandleError`.
    ///
    /// `inner` is the service whose errors are handled and `f` is the
    /// callback that turns such an error into a response.
    pub fn new(inner: S, f: F) -> Self {
        Self {
            inner,
            f,
            _extractor: PhantomData,
        }
    }
}

// Implemented by hand so that only `S` and `F` have to be `Clone`; a derive
// would also demand `T: Clone` even though no `T` is stored.
impl<S, F, T> Clone for HandleError<S, F, T>
where
    S: Clone,
    F: Clone,
{
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            f: self.f.clone(),
            _extractor: PhantomData,
        }
    }
}

// `F` is not required to implement `Debug`, so the type name of the callback
// is printed in place of its value; the inner service is printed normally.
impl<S, F, T> fmt::Debug for HandleError<S, F, T>
where
    S: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("HandleError")
            .field("inner", &self.inner)
            .field("f", &format_args!("{}", std::any::type_name::<F>()))
            .finish()
    }
}
114 
115 impl<S, F, B, Fut, Res> Service<Request<B>> for HandleError<S, F, ()>
116 where
117     S: Service<Request<B>> + Clone + Send + 'static,
118     S::Response: IntoResponse + Send,
119     S::Error: Send,
120     S::Future: Send,
121     F: FnOnce(S::Error) -> Fut + Clone + Send + 'static,
122     Fut: Future<Output = Res> + Send,
123     Res: IntoResponse,
124     B: Send + 'static,
125 {
126     type Response = Response;
127     type Error = Infallible;
128     type Future = future::HandleErrorFuture;
129 
poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>>130     fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
131         Poll::Ready(Ok(()))
132     }
133 
call(&mut self, req: Request<B>) -> Self::Future134     fn call(&mut self, req: Request<B>) -> Self::Future {
135         let f = self.f.clone();
136 
137         let clone = self.inner.clone();
138         let inner = std::mem::replace(&mut self.inner, clone);
139 
140         let future = Box::pin(async move {
141             match inner.oneshot(req).await {
142                 Ok(res) => Ok(res.into_response()),
143                 Err(err) => Ok(f(err).await.into_response()),
144             }
145         });
146 
147         future::HandleErrorFuture { future }
148     }
149 }
150 
// Generates a `Service` impl for `HandleError<S, F, (T1, ..., Tn)>`, i.e. for
// error handlers that take `n` extractors followed by the inner service's
// error. Each `$ty` must implement `FromRequestParts<()>`.
#[allow(unused_macros)]
macro_rules! impl_service {
    ( $($ty:ident),* $(,)? ) => {
        impl<S, F, B, Res, Fut, $($ty,)*> Service<Request<B>>
            for HandleError<S, F, ($($ty,)*)>
        where
            S: Service<Request<B>> + Clone + Send + 'static,
            S::Response: IntoResponse + Send,
            S::Error: Send,
            S::Future: Send,
            F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static,
            Fut: Future<Output = Res> + Send,
            Res: IntoResponse,
            $( $ty: FromRequestParts<()> + Send,)*
            B: Send + 'static,
        {
            type Response = Response;
            type Error = Infallible;

            type Future = future::HandleErrorFuture;

            // Always ready: the inner service is driven through `oneshot`
            // inside the future returned by `call`.
            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            // The type idents (`T1`, `T2`, ...) are reused as local variable
            // names for the extracted values, hence `non_snake_case`.
            #[allow(non_snake_case)]
            fn call(&mut self, req: Request<B>) -> Self::Future {
                let f = self.f.clone();

                // Move the current inner service into the future and leave a
                // fresh clone behind in `self` for subsequent calls.
                let clone = self.inner.clone();
                let inner = std::mem::replace(&mut self.inner, clone);

                let future = Box::pin(async move {
                    let (mut parts, body) = req.into_parts();

                    // Run the extractors, in declaration order, before the
                    // inner service is called. The first rejection becomes
                    // the response and the inner service is never invoked.
                    $(
                        let $ty = match $ty::from_request_parts(&mut parts, &()).await {
                            Ok(value) => value,
                            Err(rejection) => return Ok(rejection.into_response()),
                        };
                    )*

                    // Reassemble the request from the (possibly mutated)
                    // parts and the untouched body.
                    let req = Request::from_parts(parts, body);

                    match inner.oneshot(req).await {
                        Ok(res) => Ok(res.into_response()),
                        Err(err) => Ok(f($($ty),*, err).await.into_response()),
                    }
                });

                future::HandleErrorFuture { future }
            }
        }
    }
}
206 
207 impl_service!(T1);
208 impl_service!(T1, T2);
209 impl_service!(T1, T2, T3);
210 impl_service!(T1, T2, T3, T4);
211 impl_service!(T1, T2, T3, T4, T5);
212 impl_service!(T1, T2, T3, T4, T5, T6);
213 impl_service!(T1, T2, T3, T4, T5, T6, T7);
214 impl_service!(T1, T2, T3, T4, T5, T6, T7, T8);
215 impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
216 impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
217 impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
218 impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
219 impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
220 impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
221 impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
222 impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
223 
224 pub mod future {
225     //! Future types.
226 
227     use crate::response::Response;
228     use pin_project_lite::pin_project;
229     use std::{
230         convert::Infallible,
231         future::Future,
232         pin::Pin,
233         task::{Context, Poll},
234     };
235 
236     pin_project! {
237         /// Response future for [`HandleError`].
238         pub struct HandleErrorFuture {
239             #[pin]
240             pub(super) future: Pin<Box<dyn Future<Output = Result<Response, Infallible>>
241                 + Send
242                 + 'static
243             >>,
244         }
245     }
246 
247     impl Future for HandleErrorFuture {
248         type Output = Result<Response, Infallible>;
249 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>250         fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
251             self.project().future.poll(cx)
252         }
253     }
254 }
255 
/// Compile-time check that `HandleError` is `Send` and `Sync` even when its
/// extractor marker type is `NotSendSync` (presumably a helper type that is
/// neither `Send` nor `Sync` — confirm in `crate::test_helpers`).
#[test]
fn traits() {
    use crate::test_helpers::*;

    assert_send::<HandleError<(), (), NotSendSync>>();
    assert_sync::<HandleError<(), (), NotSendSync>>();
}
263