#![doc = include_str!("../docs/error_handling.md")] use crate::{ extract::FromRequestParts, http::Request, response::{IntoResponse, Response}, }; use std::{ convert::Infallible, fmt, future::Future, marker::PhantomData, task::{Context, Poll}, }; use tower::ServiceExt; use tower_layer::Layer; use tower_service::Service; /// [`Layer`] that applies [`HandleError`] which is a [`Service`] adapter /// that handles errors by converting them into responses. /// /// See [module docs](self) for more details on axum's error handling model. pub struct HandleErrorLayer { f: F, _extractor: PhantomData T>, } impl HandleErrorLayer { /// Create a new `HandleErrorLayer`. pub fn new(f: F) -> Self { Self { f, _extractor: PhantomData, } } } impl Clone for HandleErrorLayer where F: Clone, { fn clone(&self) -> Self { Self { f: self.f.clone(), _extractor: PhantomData, } } } impl fmt::Debug for HandleErrorLayer { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("HandleErrorLayer") .field("f", &format_args!("{}", std::any::type_name::())) .finish() } } impl Layer for HandleErrorLayer where F: Clone, { type Service = HandleError; fn layer(&self, inner: S) -> Self::Service { HandleError::new(inner, self.f.clone()) } } /// A [`Service`] adapter that handles errors by converting them into responses. /// /// See [module docs](self) for more details on axum's error handling model. pub struct HandleError { inner: S, f: F, _extractor: PhantomData T>, } impl HandleError { /// Create a new `HandleError`. pub fn new(inner: S, f: F) -> Self { Self { inner, f, _extractor: PhantomData, } } } impl Clone for HandleError where S: Clone, F: Clone, { fn clone(&self) -> Self { Self { inner: self.inner.clone(), f: self.f.clone(), _extractor: PhantomData, } } } impl fmt::Debug for HandleError where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("HandleError") .field("inner", &self.inner) .field("f", &format_args!("{}", std::any::type_name::())) .finish() } } impl Service> for HandleError where S: Service> + Clone + Send + 'static, S::Response: IntoResponse + Send, S::Error: Send, S::Future: Send, F: FnOnce(S::Error) -> Fut + Clone + Send + 'static, Fut: Future + Send, Res: IntoResponse, B: Send + 'static, { type Response = Response; type Error = Infallible; type Future = future::HandleErrorFuture; fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, req: Request) -> Self::Future { let f = self.f.clone(); let clone = self.inner.clone(); let inner = std::mem::replace(&mut self.inner, clone); let future = Box::pin(async move { match inner.oneshot(req).await { Ok(res) => Ok(res.into_response()), Err(err) => Ok(f(err).await.into_response()), } }); future::HandleErrorFuture { future } } } #[allow(unused_macros)] macro_rules! impl_service { ( $($ty:ident),* $(,)? ) => { impl Service> for HandleError where S: Service> + Clone + Send + 'static, S::Response: IntoResponse + Send, S::Error: Send, S::Future: Send, F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static, Fut: Future + Send, Res: IntoResponse, $( $ty: FromRequestParts<()> + Send,)* B: Send + 'static, { type Response = Response; type Error = Infallible; type Future = future::HandleErrorFuture; fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } #[allow(non_snake_case)] fn call(&mut self, req: Request) -> Self::Future { let f = self.f.clone(); let clone = self.inner.clone(); let inner = std::mem::replace(&mut self.inner, clone); let future = Box::pin(async move { let (mut parts, body) = req.into_parts(); $( let $ty = match $ty::from_request_parts(&mut parts, &()).await { Ok(value) => value, Err(rejection) => return Ok(rejection.into_response()), }; )* let req = Request::from_parts(parts, body); match inner.oneshot(req).await { Ok(res) => Ok(res.into_response()), Err(err) => Ok(f($($ty),*, err).await.into_response()), } }); future::HandleErrorFuture { future } } } } } impl_service!(T1); impl_service!(T1, T2); impl_service!(T1, T2, T3); impl_service!(T1, T2, T3, T4); impl_service!(T1, T2, T3, T4, T5); impl_service!(T1, T2, T3, T4, T5, T6); impl_service!(T1, T2, T3, T4, T5, T6, T7); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); pub mod future { //! Future types. use crate::response::Response; use pin_project_lite::pin_project; use std::{ convert::Infallible, future::Future, pin::Pin, task::{Context, Poll}, }; pin_project! { /// Response future for [`HandleError`]. pub struct HandleErrorFuture { #[pin] pub(super) future: Pin> + Send + 'static >>, } } impl Future for HandleErrorFuture { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.project().future.poll(cx) } } } #[test] fn traits() { use crate::test_helpers::*; assert_send::>(); assert_sync::>(); }