• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use super::Rate;
2 use futures_core::ready;
3 use std::{
4     future::Future,
5     pin::Pin,
6     task::{Context, Poll},
7 };
8 use tokio::time::{Instant, Sleep};
9 use tower_service::Service;
10 
11 /// Enforces a rate limit on the number of requests the underlying
12 /// service can handle over a period of time.
13 #[derive(Debug)]
14 pub struct RateLimit<T> {
15     inner: T,
16     rate: Rate,
17     state: State,
18     sleep: Pin<Box<Sleep>>,
19 }
20 
21 #[derive(Debug)]
22 enum State {
23     // The service has hit its limit
24     Limited,
25     Ready { until: Instant, rem: u64 },
26 }
27 
28 impl<T> RateLimit<T> {
29     /// Create a new rate limiter
new(inner: T, rate: Rate) -> Self30     pub fn new(inner: T, rate: Rate) -> Self {
31         let until = Instant::now();
32         let state = State::Ready {
33             until,
34             rem: rate.num(),
35         };
36 
37         RateLimit {
38             inner,
39             rate,
40             state,
41             // The sleep won't actually be used with this duration, but
42             // we create it eagerly so that we can reset it in place rather than
43             // `Box::pin`ning a new `Sleep` every time we need one.
44             sleep: Box::pin(tokio::time::sleep_until(until)),
45         }
46     }
47 
48     /// Get a reference to the inner service
get_ref(&self) -> &T49     pub fn get_ref(&self) -> &T {
50         &self.inner
51     }
52 
53     /// Get a mutable reference to the inner service
get_mut(&mut self) -> &mut T54     pub fn get_mut(&mut self) -> &mut T {
55         &mut self.inner
56     }
57 
58     /// Consume `self`, returning the inner service
into_inner(self) -> T59     pub fn into_inner(self) -> T {
60         self.inner
61     }
62 }
63 
64 impl<S, Request> Service<Request> for RateLimit<S>
65 where
66     S: Service<Request>,
67 {
68     type Response = S::Response;
69     type Error = S::Error;
70     type Future = S::Future;
71 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>72     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
73         match self.state {
74             State::Ready { .. } => return Poll::Ready(ready!(self.inner.poll_ready(cx))),
75             State::Limited => {
76                 if Pin::new(&mut self.sleep).poll(cx).is_pending() {
77                     tracing::trace!("rate limit exceeded; sleeping.");
78                     return Poll::Pending;
79                 }
80             }
81         }
82 
83         self.state = State::Ready {
84             until: Instant::now() + self.rate.per(),
85             rem: self.rate.num(),
86         };
87 
88         Poll::Ready(ready!(self.inner.poll_ready(cx)))
89     }
90 
call(&mut self, request: Request) -> Self::Future91     fn call(&mut self, request: Request) -> Self::Future {
92         match self.state {
93             State::Ready { mut until, mut rem } => {
94                 let now = Instant::now();
95 
96                 // If the period has elapsed, reset it.
97                 if now >= until {
98                     until = now + self.rate.per();
99                     rem = self.rate.num();
100                 }
101 
102                 if rem > 1 {
103                     rem -= 1;
104                     self.state = State::Ready { until, rem };
105                 } else {
106                     // The service is disabled until further notice
107                     // Reset the sleep future in place, so that we don't have to
108                     // deallocate the existing box and allocate a new one.
109                     self.sleep.as_mut().reset(until);
110                     self.state = State::Limited;
111                 }
112 
113                 // Call the inner future
114                 self.inner.call(request)
115             }
116             State::Limited => panic!("service not ready; poll_ready must be called first"),
117         }
118     }
119 }
120 
#[cfg(feature = "load")]
#[cfg_attr(docsrs, doc(cfg(feature = "load")))]
impl<S> crate::load::Load for RateLimit<S>
where
    S: crate::load::Load,
{
    type Metric = S::Metric;

    /// Reports the wrapped service's load unchanged; the limiter adds no
    /// metric of its own.
    fn load(&self) -> Self::Metric {
        <S as crate::load::Load>::load(&self.inner)
    }
}
132