• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::extract::FromRequestParts;
2 use async_trait::async_trait;
3 use axum_core::response::{IntoResponse, IntoResponseParts, Response, ResponseParts};
4 use headers::HeaderMapExt;
5 use http::request::Parts;
6 use std::convert::Infallible;
7 
/// Extractor and response that works with typed header values from [`headers`].
///
/// # As extractor
///
/// In general, it's recommended to extract only the needed headers via `TypedHeader` rather than
/// extracting all headers with the `HeaderMap` extractor.
///
/// ```rust,no_run
/// use axum::{
///     TypedHeader,
///     headers::UserAgent,
///     routing::get,
///     Router,
/// };
///
/// async fn users_teams_show(
///     TypedHeader(user_agent): TypedHeader<UserAgent>,
/// ) {
///     // ...
/// }
///
/// let app = Router::new().route("/users/:user_id/team/:team_id", get(users_teams_show));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// Extraction fails with a [`TypedHeaderRejection`] in two cases:
///
/// - the request has no header of that name
///   ([`TypedHeaderRejectionReason::Missing`]);
/// - the header is present but cannot be decoded into `T`
///   ([`TypedHeaderRejectionReason::Error`]).
///
/// Either way the rejection is turned into a `400 Bad Request` response.
///
/// # As response
///
/// ```rust
/// use axum::{
///     TypedHeader,
///     response::IntoResponse,
///     headers::ContentType,
/// };
///
/// async fn handler() -> (TypedHeader<ContentType>, &'static str) {
///     (
///         TypedHeader(ContentType::text_utf8()),
///         "Hello, World!",
///     )
/// }
/// ```
#[cfg(feature = "headers")]
#[derive(Debug, Clone, Copy)]
#[must_use]
pub struct TypedHeader<T>(pub T);
55 
56 #[async_trait]
57 impl<T, S> FromRequestParts<S> for TypedHeader<T>
58 where
59     T: headers::Header,
60     S: Send + Sync,
61 {
62     type Rejection = TypedHeaderRejection;
63 
from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection>64     async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
65         let mut values = parts.headers.get_all(T::name()).iter();
66         let is_missing = values.size_hint() == (0, Some(0));
67         T::decode(&mut values)
68             .map(Self)
69             .map_err(|err| TypedHeaderRejection {
70                 name: T::name(),
71                 reason: if is_missing {
72                     // Report a more precise rejection for the missing header case.
73                     TypedHeaderRejectionReason::Missing
74                 } else {
75                     TypedHeaderRejectionReason::Error(err)
76                 },
77             })
78     }
79 }
80 
81 axum_core::__impl_deref!(TypedHeader);
82 
83 impl<T> IntoResponseParts for TypedHeader<T>
84 where
85     T: headers::Header,
86 {
87     type Error = Infallible;
88 
into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error>89     fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
90         res.headers_mut().typed_insert(self.0);
91         Ok(res)
92     }
93 }
94 
95 impl<T> IntoResponse for TypedHeader<T>
96 where
97     T: headers::Header,
98 {
into_response(self) -> Response99     fn into_response(self) -> Response {
100         let mut res = ().into_response();
101         res.headers_mut().typed_insert(self.0);
102         res
103     }
104 }
105 
/// Rejection used for [`TypedHeader`].
///
/// When returned from a handler it becomes a `400 Bad Request` response whose
/// body is this value's [`Display`](std::fmt::Display) output.
#[cfg(feature = "headers")]
#[derive(Debug)]
pub struct TypedHeaderRejection {
    /// Name of the header the extractor attempted to decode.
    name: &'static http::header::HeaderName,
    /// Why the extraction failed.
    reason: TypedHeaderRejectionReason,
}
113 
114 impl TypedHeaderRejection {
115     /// Name of the header that caused the rejection
name(&self) -> &http::header::HeaderName116     pub fn name(&self) -> &http::header::HeaderName {
117         self.name
118     }
119 
120     /// Reason why the header extraction has failed
reason(&self) -> &TypedHeaderRejectionReason121     pub fn reason(&self) -> &TypedHeaderRejectionReason {
122         &self.reason
123     }
124 }
125 
/// Additional information regarding a [`TypedHeaderRejection`].
#[cfg(feature = "headers")]
#[derive(Debug)]
#[non_exhaustive]
pub enum TypedHeaderRejectionReason {
    /// The header was missing from the HTTP request.
    Missing,
    /// An error occurred when parsing the header from the HTTP request.
    Error(headers::Error),
}
136 
137 impl IntoResponse for TypedHeaderRejection {
into_response(self) -> Response138     fn into_response(self) -> Response {
139         (http::StatusCode::BAD_REQUEST, self.to_string()).into_response()
140     }
141 }
142 
143 impl std::fmt::Display for TypedHeaderRejection {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result144     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145         match &self.reason {
146             TypedHeaderRejectionReason::Missing => {
147                 write!(f, "Header of type `{}` was missing", self.name)
148             }
149             TypedHeaderRejectionReason::Error(err) => {
150                 write!(f, "{} ({})", err, self.name)
151             }
152         }
153     }
154 }
155 
156 impl std::error::Error for TypedHeaderRejection {
source(&self) -> Option<&(dyn std::error::Error + 'static)>157     fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
158         match &self.reason {
159             TypedHeaderRejectionReason::Error(err) => Some(err),
160             TypedHeaderRejectionReason::Missing => None,
161         }
162     }
163 }
164 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{response::IntoResponse, routing::get, test_helpers::*, Router};

    #[crate::test]
    async fn typed_header() {
        async fn handle(
            TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
            TypedHeader(cookies): TypedHeader<headers::Cookie>,
        ) -> impl IntoResponse {
            let user_agent = user_agent.as_str();
            let cookies = cookies.iter().collect::<Vec<_>>();
            format!("User-Agent={user_agent:?}, Cookie={cookies:?}")
        }

        let app = Router::new().route("/", get(handle));

        let client = TestClient::new(app);

        let res = client
            .get("/")
            .header("user-agent", "foobar")
            .header("cookie", "a=1; b=2")
            .header("cookie", "c=3")
            .send()
            .await;
        let body = res.text().await;
        assert_eq!(
            body,
            r#"User-Agent="foobar", Cookie=[("a", "1"), ("b", "2"), ("c", "3")]"#
        );

        let res = client.get("/").header("user-agent", "foobar").send().await;
        let body = res.text().await;
        assert_eq!(body, r#"User-Agent="foobar", Cookie=[]"#);

        // Missing header: rejected with 400 and the `Missing` message.
        let res = client.get("/").header("cookie", "a=1").send().await;
        assert_eq!(res.status(), http::StatusCode::BAD_REQUEST);
        let body = res.text().await;
        assert_eq!(body, "Header of type `user-agent` was missing");
    }

    #[crate::test]
    async fn typed_header_invalid_value() {
        async fn handle(_: TypedHeader<headers::IfModifiedSince>) {}

        let app = Router::new().route("/", get(handle));

        let client = TestClient::new(app);

        // Header present but not a valid HTTP date: must be reported as a
        // decode error, not as a missing header.
        let res = client
            .get("/")
            .header("if-modified-since", "not a date")
            .send()
            .await;
        assert_eq!(res.status(), http::StatusCode::BAD_REQUEST);
        let body = res.text().await;
        assert_eq!(body, "invalid HTTP header (if-modified-since)");
    }
}
207