1 use super::Rate; 2 use futures_core::ready; 3 use std::{ 4 future::Future, 5 pin::Pin, 6 task::{Context, Poll}, 7 }; 8 use tokio::time::{Instant, Sleep}; 9 use tower_service::Service; 10 11 /// Enforces a rate limit on the number of requests the underlying 12 /// service can handle over a period of time. 13 #[derive(Debug)] 14 pub struct RateLimit<T> { 15 inner: T, 16 rate: Rate, 17 state: State, 18 sleep: Pin<Box<Sleep>>, 19 } 20 21 #[derive(Debug)] 22 enum State { 23 // The service has hit its limit 24 Limited, 25 Ready { until: Instant, rem: u64 }, 26 } 27 28 impl<T> RateLimit<T> { 29 /// Create a new rate limiter new(inner: T, rate: Rate) -> Self30 pub fn new(inner: T, rate: Rate) -> Self { 31 let until = Instant::now(); 32 let state = State::Ready { 33 until, 34 rem: rate.num(), 35 }; 36 37 RateLimit { 38 inner, 39 rate, 40 state, 41 // The sleep won't actually be used with this duration, but 42 // we create it eagerly so that we can reset it in place rather than 43 // `Box::pin`ning a new `Sleep` every time we need one. 44 sleep: Box::pin(tokio::time::sleep_until(until)), 45 } 46 } 47 48 /// Get a reference to the inner service get_ref(&self) -> &T49 pub fn get_ref(&self) -> &T { 50 &self.inner 51 } 52 53 /// Get a mutable reference to the inner service get_mut(&mut self) -> &mut T54 pub fn get_mut(&mut self) -> &mut T { 55 &mut self.inner 56 } 57 58 /// Consume `self`, returning the inner service into_inner(self) -> T59 pub fn into_inner(self) -> T { 60 self.inner 61 } 62 } 63 64 impl<S, Request> Service<Request> for RateLimit<S> 65 where 66 S: Service<Request>, 67 { 68 type Response = S::Response; 69 type Error = S::Error; 70 type Future = S::Future; 71 poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>72 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 73 match self.state { 74 State::Ready { .. } => return Poll::Ready(ready!(self.inner.poll_ready(cx))), 75 State::Limited => { 76 if Pin::new(&mut self.sleep).poll(cx).is_pending() { 77 tracing::trace!("rate limit exceeded; sleeping."); 78 return Poll::Pending; 79 } 80 } 81 } 82 83 self.state = State::Ready { 84 until: Instant::now() + self.rate.per(), 85 rem: self.rate.num(), 86 }; 87 88 Poll::Ready(ready!(self.inner.poll_ready(cx))) 89 } 90 call(&mut self, request: Request) -> Self::Future91 fn call(&mut self, request: Request) -> Self::Future { 92 match self.state { 93 State::Ready { mut until, mut rem } => { 94 let now = Instant::now(); 95 96 // If the period has elapsed, reset it. 97 if now >= until { 98 until = now + self.rate.per(); 99 rem = self.rate.num(); 100 } 101 102 if rem > 1 { 103 rem -= 1; 104 self.state = State::Ready { until, rem }; 105 } else { 106 // The service is disabled until further notice 107 // Reset the sleep future in place, so that we don't have to 108 // deallocate the existing box and allocate a new one. 109 self.sleep.as_mut().reset(until); 110 self.state = State::Limited; 111 } 112 113 // Call the inner future 114 self.inner.call(request) 115 } 116 State::Limited => panic!("service not ready; poll_ready must be called first"), 117 } 118 } 119 } 120 121 #[cfg(feature = "load")] 122 #[cfg_attr(docsrs, doc(cfg(feature = "load")))] 123 impl<S> crate::load::Load for RateLimit<S> 124 where 125 S: crate::load::Load, 126 { 127 type Metric = S::Metric; load(&self) -> Self::Metric128 fn load(&self) -> Self::Metric { 129 self.inner.load() 130 } 131 } 132