1 use crate::metadata::GRPC_TIMEOUT_HEADER;
2 use http::{HeaderMap, HeaderValue, Request};
3 use pin_project::pin_project;
4 use std::{
5 fmt,
6 future::Future,
7 pin::Pin,
8 task::{ready, Context, Poll},
9 time::Duration,
10 };
11 use tokio::time::Sleep;
12 use tower_service::Service;
13
/// Service middleware that bounds how long each request may take.
///
/// The effective deadline for a request is the shorter of `server_timeout` and the
/// client-supplied `grpc-timeout` header, whichever of the two are present.
#[derive(Debug, Clone)]
pub(crate) struct GrpcTimeout<S> {
    // The wrapped service that actually handles requests.
    inner: S,
    // Server-side upper bound on request duration; `None` means no server limit.
    server_timeout: Option<Duration>,
}

impl<S> GrpcTimeout<S> {
    /// Wraps `inner`, using `server_timeout` (when set) as a per-request deadline.
    pub(crate) fn new(inner: S, server_timeout: Option<Duration>) -> Self {
        GrpcTimeout {
            server_timeout,
            inner,
        }
    }
}
28
29 impl<S, ReqBody> Service<Request<ReqBody>> for GrpcTimeout<S>
30 where
31 S: Service<Request<ReqBody>>,
32 S::Error: Into<crate::Error>,
33 {
34 type Response = S::Response;
35 type Error = crate::Error;
36 type Future = ResponseFuture<S::Future>;
37
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>38 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
39 self.inner.poll_ready(cx).map_err(Into::into)
40 }
41
call(&mut self, req: Request<ReqBody>) -> Self::Future42 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
43 let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| {
44 tracing::trace!("Error parsing `grpc-timeout` header {:?}", e);
45 None
46 });
47
48 // Use the shorter of the two durations, if either are set
49 let timeout_duration = match (client_timeout, self.server_timeout) {
50 (None, None) => None,
51 (Some(dur), None) => Some(dur),
52 (None, Some(dur)) => Some(dur),
53 (Some(header), Some(server)) => {
54 let shorter_duration = std::cmp::min(header, server);
55 Some(shorter_duration)
56 }
57 };
58
59 ResponseFuture {
60 inner: self.inner.call(req),
61 sleep: timeout_duration
62 .map(tokio::time::sleep)
63 .map(Some)
64 .unwrap_or(None),
65 }
66 }
67 }
68
/// Response future returned by [`GrpcTimeout`].
///
/// Resolves with the inner future's output (errors converted into `crate::Error`). If the
/// inner future is still pending when the `sleep` timer completes, it instead resolves
/// with a [`TimeoutExpired`] error.
#[pin_project]
pub(crate) struct ResponseFuture<F> {
    // The inner service's response future.
    #[pin]
    inner: F,
    // Deadline timer; `None` when neither the client header nor the server configured a timeout.
    #[pin]
    sleep: Option<Sleep>,
}
76
77 impl<F, Res, E> Future for ResponseFuture<F>
78 where
79 F: Future<Output = Result<Res, E>>,
80 E: Into<crate::Error>,
81 {
82 type Output = Result<Res, crate::Error>;
83
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>84 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
85 let this = self.project();
86
87 if let Poll::Ready(result) = this.inner.poll(cx) {
88 return Poll::Ready(result.map_err(Into::into));
89 }
90
91 if let Some(sleep) = this.sleep.as_pin_mut() {
92 ready!(sleep.poll(cx));
93 return Poll::Ready(Err(TimeoutExpired(()).into()));
94 }
95
96 Poll::Pending
97 }
98 }
99
100 const SECONDS_IN_HOUR: u64 = 60 * 60;
101 const SECONDS_IN_MINUTE: u64 = 60;
102
103 /// Tries to parse the `grpc-timeout` header if it is present. If we fail to parse, returns
104 /// the value we attempted to parse.
105 ///
106 /// Follows the [gRPC over HTTP2 spec](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md).
try_parse_grpc_timeout( headers: &HeaderMap<HeaderValue>, ) -> Result<Option<Duration>, &HeaderValue>107 fn try_parse_grpc_timeout(
108 headers: &HeaderMap<HeaderValue>,
109 ) -> Result<Option<Duration>, &HeaderValue> {
110 match headers.get(GRPC_TIMEOUT_HEADER) {
111 Some(val) => {
112 let (timeout_value, timeout_unit) = val
113 .to_str()
114 .map_err(|_| val)
115 .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })?
116 // `HeaderValue::to_str` only returns `Ok` if the header contains ASCII so this
117 // `split_at` will never panic from trying to split in the middle of a character.
118 // See https://docs.rs/http/0.2.4/http/header/struct.HeaderValue.html#method.to_str
119 //
120 // `len - 1` also wont panic since we just checked `s.is_empty`.
121 .split_at(val.len() - 1);
122
123 // gRPC spec specifies `TimeoutValue` will be at most 8 digits
124 // Caping this at 8 digits also prevents integer overflow from ever occurring
125 if timeout_value.len() > 8 {
126 return Err(val);
127 }
128
129 let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?;
130
131 let duration = match timeout_unit {
132 // Hours
133 "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR),
134 // Minutes
135 "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE),
136 // Seconds
137 "S" => Duration::from_secs(timeout_value),
138 // Milliseconds
139 "m" => Duration::from_millis(timeout_value),
140 // Microseconds
141 "u" => Duration::from_micros(timeout_value),
142 // Nanoseconds
143 "n" => Duration::from_nanos(timeout_value),
144 _ => return Err(val),
145 };
146
147 Ok(Some(duration))
148 }
149 None => Ok(None),
150 }
151 }
152
/// Error returned if a request didn't complete within the configured timeout.
///
/// Timeouts can be configured either with [`Endpoint::timeout`], [`Server::timeout`], or by
/// setting the [`grpc-timeout` metadata value][spec].
///
/// [`Endpoint::timeout`]: crate::transport::channel::Endpoint::timeout
/// [`Server::timeout`]: crate::transport::server::Server::timeout
/// [spec]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
#[derive(Debug)]
pub struct TimeoutExpired(());

impl fmt::Display for TimeoutExpired {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Timeout expired")
    }
}

// `std::error::Error` only requires `Debug` and `Display`; the default methods
// (e.g. `source` returning `None`) are sufficient here.
impl std::error::Error for TimeoutExpired {}
172
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen};
    use quickcheck_macros::quickcheck;

    /// Builds a header map containing `grpc-timeout: val` (or no such header when `val` is
    /// `None`), runs the parser on it, and clones any error value so it outlives the map.
    fn setup_map_try_parse(val: Option<&str>) -> Result<Option<Duration>, HeaderValue> {
        let mut hm = HeaderMap::new();
        if let Some(v) = val {
            let hv = HeaderValue::from_str(v).unwrap();
            hm.insert(GRPC_TIMEOUT_HEADER, hv);
        };

        try_parse_grpc_timeout(&hm).map_err(|e| e.clone())
    }

    #[test]
    fn test_hours() {
        let parsed_duration = setup_map_try_parse(Some("3H")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(3 * 60 * 60), parsed_duration);
    }

    #[test]
    fn test_minutes() {
        let parsed_duration = setup_map_try_parse(Some("1M")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(60), parsed_duration);
    }

    #[test]
    fn test_seconds() {
        let parsed_duration = setup_map_try_parse(Some("42S")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(42), parsed_duration);
    }

    #[test]
    fn test_milliseconds() {
        let parsed_duration = setup_map_try_parse(Some("13m")).unwrap().unwrap();
        assert_eq!(Duration::from_millis(13), parsed_duration);
    }

    #[test]
    fn test_microseconds() {
        let parsed_duration = setup_map_try_parse(Some("2u")).unwrap().unwrap();
        assert_eq!(Duration::from_micros(2), parsed_duration);
    }

    #[test]
    fn test_nanoseconds() {
        let parsed_duration = setup_map_try_parse(Some("82n")).unwrap().unwrap();
        assert_eq!(Duration::from_nanos(82), parsed_duration);
    }

    #[test]
    fn test_header_not_present() {
        let parsed_duration = setup_map_try_parse(None).unwrap();
        assert!(parsed_duration.is_none());
    }

    #[test]
    fn test_max_digits() {
        // Eight digits is the longest TimeoutValue the spec allows; this must not overflow.
        let parsed_duration = setup_map_try_parse(Some("99999999H")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(99_999_999 * 60 * 60), parsed_duration);
    }

    #[test]
    fn test_empty_value() {
        // An empty header value has neither a TimeoutValue nor a TimeoutUnit.
        assert!(setup_map_try_parse(Some("")).is_err());
    }

    #[test]
    fn test_unit_without_value() {
        // A TimeoutUnit with no TimeoutValue is rejected.
        assert!(setup_map_try_parse(Some("S")).is_err());
    }

    #[test]
    #[should_panic(expected = "82f")]
    fn test_invalid_unit() {
        // "f" is not a valid TimeoutUnit
        setup_map_try_parse(Some("82f")).unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "123456789H")]
    fn test_too_many_digits() {
        // gRPC spec states TimeoutValue will be at most 8 digits
        setup_map_try_parse(Some("123456789H")).unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "oneH")]
    fn test_invalid_digits() {
        // "one" is not an ASCII-digit TimeoutValue
        setup_map_try_parse(Some("oneH")).unwrap().unwrap();
    }

    #[quickcheck]
    fn fuzz(header_value: HeaderValueGen) -> bool {
        let header_value = header_value.0;

        // this just shouldn't panic
        let _ = setup_map_try_parse(Some(&header_value));

        true
    }

    /// Newtype to implement `Arbitrary` for generating `String`s that are valid `HeaderValue`s.
    #[derive(Clone, Debug)]
    struct HeaderValueGen(String);

    impl Arbitrary for HeaderValueGen {
        fn arbitrary(g: &mut Gen) -> Self {
            let max = g.choose(&(1..70).collect::<Vec<_>>()).copied().unwrap();
            Self(gen_string(g, 0, max))
        }
    }

    // Adapted from https://github.com/hyperium/http/blob/master/tests/header_map_fuzz.rs,
    // with digits added to the alphabet so the numeric parsing path is exercised too.
    // Generates exactly `max - min` characters.
    fn gen_string(g: &mut Gen, min: usize, max: usize) -> String {
        let bytes: Vec<_> = (min..max)
            .map(|_| {
                // Chars to pick from
                g.choose(b"ABCDEFGHIJKLMNOPQRSTUVabcdefghilpqrstuvwxyz----0123456789")
                    .copied()
                    .unwrap()
            })
            .collect();

        String::from_utf8(bytes).unwrap()
    }
}
288