1 use std::sync::Arc; 2 3 use super::ConcurrencyLimit; 4 use tokio::sync::Semaphore; 5 use tower_layer::Layer; 6 7 /// Enforces a limit on the concurrent number of requests the underlying 8 /// service can handle. 9 #[derive(Debug, Clone)] 10 pub struct ConcurrencyLimitLayer { 11 max: usize, 12 } 13 14 impl ConcurrencyLimitLayer { 15 /// Create a new concurrency limit layer. new(max: usize) -> Self16 pub fn new(max: usize) -> Self { 17 ConcurrencyLimitLayer { max } 18 } 19 } 20 21 impl<S> Layer<S> for ConcurrencyLimitLayer { 22 type Service = ConcurrencyLimit<S>; 23 layer(&self, service: S) -> Self::Service24 fn layer(&self, service: S) -> Self::Service { 25 ConcurrencyLimit::new(service, self.max) 26 } 27 } 28 29 /// Enforces a limit on the concurrent number of requests the underlying 30 /// service can handle. 31 /// 32 /// Unlike [`ConcurrencyLimitLayer`], which enforces a per-service concurrency 33 /// limit, this layer accepts a owned semaphore (`Arc<Semaphore>`) which can be 34 /// shared across multiple services. 35 /// 36 /// Cloning this layer will not create a new semaphore. 37 #[derive(Debug, Clone)] 38 pub struct GlobalConcurrencyLimitLayer { 39 semaphore: Arc<Semaphore>, 40 } 41 42 impl GlobalConcurrencyLimitLayer { 43 /// Create a new `GlobalConcurrencyLimitLayer`. new(max: usize) -> Self44 pub fn new(max: usize) -> Self { 45 Self::with_semaphore(Arc::new(Semaphore::new(max))) 46 } 47 48 /// Create a new `GlobalConcurrencyLimitLayer` from a `Arc<Semaphore>` with_semaphore(semaphore: Arc<Semaphore>) -> Self49 pub fn with_semaphore(semaphore: Arc<Semaphore>) -> Self { 50 GlobalConcurrencyLimitLayer { semaphore } 51 } 52 } 53 54 impl<S> Layer<S> for GlobalConcurrencyLimitLayer { 55 type Service = ConcurrencyLimit<S>; 56 layer(&self, service: S) -> Self::Service57 fn layer(&self, service: S) -> Self::Service { 58 ConcurrencyLimit::with_semaphore(service, self.semaphore.clone()) 59 } 60 } 61