• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::metadata::GRPC_TIMEOUT_HEADER;
2 use http::{HeaderMap, HeaderValue, Request};
3 use pin_project::pin_project;
4 use std::{
5     fmt,
6     future::Future,
7     pin::Pin,
8     task::{ready, Context, Poll},
9     time::Duration,
10 };
11 use tokio::time::Sleep;
12 use tower_service::Service;
13 
/// Middleware that bounds how long a request to the wrapped service may take.
///
/// The bound comes from the client's `grpc-timeout` header and/or the server-configured
/// `server_timeout`; when both are present the shorter one is used.
#[derive(Debug, Clone)]
pub(crate) struct GrpcTimeout<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Server-side upper bound on request duration, if one was configured.
    server_timeout: Option<Duration>,
}

impl<S> GrpcTimeout<S> {
    /// Wraps `inner`, optionally capping every request at `server_timeout`.
    pub(crate) fn new(inner: S, server_timeout: Option<Duration>) -> Self {
        GrpcTimeout {
            server_timeout,
            inner,
        }
    }
}
28 
29 impl<S, ReqBody> Service<Request<ReqBody>> for GrpcTimeout<S>
30 where
31     S: Service<Request<ReqBody>>,
32     S::Error: Into<crate::Error>,
33 {
34     type Response = S::Response;
35     type Error = crate::Error;
36     type Future = ResponseFuture<S::Future>;
37 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>38     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
39         self.inner.poll_ready(cx).map_err(Into::into)
40     }
41 
call(&mut self, req: Request<ReqBody>) -> Self::Future42     fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
43         let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| {
44             tracing::trace!("Error parsing `grpc-timeout` header {:?}", e);
45             None
46         });
47 
48         // Use the shorter of the two durations, if either are set
49         let timeout_duration = match (client_timeout, self.server_timeout) {
50             (None, None) => None,
51             (Some(dur), None) => Some(dur),
52             (None, Some(dur)) => Some(dur),
53             (Some(header), Some(server)) => {
54                 let shorter_duration = std::cmp::min(header, server);
55                 Some(shorter_duration)
56             }
57         };
58 
59         ResponseFuture {
60             inner: self.inner.call(req),
61             sleep: timeout_duration
62                 .map(tokio::time::sleep)
63                 .map(Some)
64                 .unwrap_or(None),
65         }
66     }
67 }
68 
69 #[pin_project]
70 pub(crate) struct ResponseFuture<F> {
71     #[pin]
72     inner: F,
73     #[pin]
74     sleep: Option<Sleep>,
75 }
76 
77 impl<F, Res, E> Future for ResponseFuture<F>
78 where
79     F: Future<Output = Result<Res, E>>,
80     E: Into<crate::Error>,
81 {
82     type Output = Result<Res, crate::Error>;
83 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>84     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
85         let this = self.project();
86 
87         if let Poll::Ready(result) = this.inner.poll(cx) {
88             return Poll::Ready(result.map_err(Into::into));
89         }
90 
91         if let Some(sleep) = this.sleep.as_pin_mut() {
92             ready!(sleep.poll(cx));
93             return Poll::Ready(Err(TimeoutExpired(()).into()));
94         }
95 
96         Poll::Pending
97     }
98 }
99 
/// Seconds per minute; scales the `M` timeout unit.
const SECONDS_IN_MINUTE: u64 = 60;
/// Seconds per hour; scales the `H` timeout unit.
const SECONDS_IN_HOUR: u64 = 60 * SECONDS_IN_MINUTE;
102 
103 /// Tries to parse the `grpc-timeout` header if it is present. If we fail to parse, returns
104 /// the value we attempted to parse.
105 ///
106 /// Follows the [gRPC over HTTP2 spec](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md).
try_parse_grpc_timeout( headers: &HeaderMap<HeaderValue>, ) -> Result<Option<Duration>, &HeaderValue>107 fn try_parse_grpc_timeout(
108     headers: &HeaderMap<HeaderValue>,
109 ) -> Result<Option<Duration>, &HeaderValue> {
110     match headers.get(GRPC_TIMEOUT_HEADER) {
111         Some(val) => {
112             let (timeout_value, timeout_unit) = val
113                 .to_str()
114                 .map_err(|_| val)
115                 .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })?
116                 // `HeaderValue::to_str` only returns `Ok` if the header contains ASCII so this
117                 // `split_at` will never panic from trying to split in the middle of a character.
118                 // See https://docs.rs/http/0.2.4/http/header/struct.HeaderValue.html#method.to_str
119                 //
120                 // `len - 1` also wont panic since we just checked `s.is_empty`.
121                 .split_at(val.len() - 1);
122 
123             // gRPC spec specifies `TimeoutValue` will be at most 8 digits
124             // Caping this at 8 digits also prevents integer overflow from ever occurring
125             if timeout_value.len() > 8 {
126                 return Err(val);
127             }
128 
129             let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?;
130 
131             let duration = match timeout_unit {
132                 // Hours
133                 "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR),
134                 // Minutes
135                 "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE),
136                 // Seconds
137                 "S" => Duration::from_secs(timeout_value),
138                 // Milliseconds
139                 "m" => Duration::from_millis(timeout_value),
140                 // Microseconds
141                 "u" => Duration::from_micros(timeout_value),
142                 // Nanoseconds
143                 "n" => Duration::from_nanos(timeout_value),
144                 _ => return Err(val),
145             };
146 
147             Ok(Some(duration))
148         }
149         None => Ok(None),
150     }
151 }
152 
/// Error returned if a request didn't complete within the configured timeout.
///
/// Timeouts can be configured either with [`Endpoint::timeout`], [`Server::timeout`], or by
/// setting the [`grpc-timeout` metadata value][spec].
///
/// [`Endpoint::timeout`]: crate::transport::channel::Endpoint::timeout
/// [`Server::timeout`]: crate::transport::server::Server::timeout
/// [spec]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
#[derive(Debug)]
pub struct TimeoutExpired(());

impl fmt::Display for TimeoutExpired {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Timeout expired")
    }
}

// `std::error::Error` only requires `Debug` + `Display`; there is no underlying cause to
// report, so the default `source` (returning `None`) is kept.
impl std::error::Error for TimeoutExpired {}
172 
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen};
    use quickcheck_macros::quickcheck;

    /// Inserts `val` (if any) as the `grpc-timeout` header and parses it.
    ///
    /// The error value is cloned out so the result no longer borrows the local header map.
    fn setup_map_try_parse(val: Option<&str>) -> Result<Option<Duration>, HeaderValue> {
        let mut hm = HeaderMap::new();
        if let Some(v) = val {
            let hv = HeaderValue::from_str(v).unwrap();
            hm.insert(GRPC_TIMEOUT_HEADER, hv);
        };

        try_parse_grpc_timeout(&hm).map_err(|e| e.clone())
    }

    #[test]
    fn test_hours() {
        let parsed_duration = setup_map_try_parse(Some("3H")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(3 * 60 * 60), parsed_duration);
    }

    #[test]
    fn test_minutes() {
        let parsed_duration = setup_map_try_parse(Some("1M")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(60), parsed_duration);
    }

    #[test]
    fn test_seconds() {
        let parsed_duration = setup_map_try_parse(Some("42S")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(42), parsed_duration);
    }

    #[test]
    fn test_milliseconds() {
        let parsed_duration = setup_map_try_parse(Some("13m")).unwrap().unwrap();
        assert_eq!(Duration::from_millis(13), parsed_duration);
    }

    #[test]
    fn test_microseconds() {
        let parsed_duration = setup_map_try_parse(Some("2u")).unwrap().unwrap();
        assert_eq!(Duration::from_micros(2), parsed_duration);
    }

    #[test]
    fn test_nanoseconds() {
        let parsed_duration = setup_map_try_parse(Some("82n")).unwrap().unwrap();
        assert_eq!(Duration::from_nanos(82), parsed_duration);
    }

    #[test]
    fn test_header_not_present() {
        let parsed_duration = setup_map_try_parse(None).unwrap();
        assert!(parsed_duration.is_none());
    }

    #[test]
    fn test_max_digits() {
        // Exactly 8 digits is the largest `TimeoutValue` the spec allows.
        let parsed_duration = setup_map_try_parse(Some("99999999H")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(99_999_999 * 60 * 60), parsed_duration);
    }

    #[test]
    fn test_empty_value() {
        // A present-but-empty header is malformed, not "no timeout".
        assert!(setup_map_try_parse(Some("")).is_err());
    }

    #[test]
    fn test_unit_without_digits() {
        // A `TimeoutUnit` with no `TimeoutValue` in front of it is malformed.
        assert!(setup_map_try_parse(Some("S")).is_err());
    }

    #[test]
    #[should_panic(expected = "82f")]
    fn test_invalid_unit() {
        // "f" is not a valid TimeoutUnit
        setup_map_try_parse(Some("82f")).unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "123456789H")]
    fn test_too_many_digits() {
        // gRPC spec states TimeoutValue will be at most 8 digits
        setup_map_try_parse(Some("123456789H")).unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "oneH")]
    fn test_invalid_digits() {
        // gRPC spec states TimeoutValue must consist of ASCII digits only
        setup_map_try_parse(Some("oneH")).unwrap().unwrap();
    }

    #[quickcheck]
    fn fuzz(header_value: HeaderValueGen) -> bool {
        let header_value = header_value.0;

        // this just shouldn't panic
        let _ = setup_map_try_parse(Some(&header_value));

        true
    }

    /// Newtype to implement `Arbitrary` for generating `String`s that are valid `HeaderValue`s.
    #[derive(Clone, Debug)]
    struct HeaderValueGen(String);

    impl Arbitrary for HeaderValueGen {
        fn arbitrary(g: &mut Gen) -> Self {
            let max = g.choose(&(1..70).collect::<Vec<_>>()).copied().unwrap();
            Self(gen_string(g, 0, max))
        }
    }

    // copied from https://github.com/hyperium/http/blob/master/tests/header_map_fuzz.rs
    fn gen_string(g: &mut Gen, min: usize, max: usize) -> String {
        let bytes: Vec<_> = (min..max)
            .map(|_| {
                // Chars to pick from
                g.choose(b"ABCDEFGHIJKLMNOPQRSTUVabcdefghilpqrstuvwxyz----")
                    .copied()
                    .unwrap()
            })
            .collect();

        String::from_utf8(bytes).unwrap()
    }
}
288