• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use super::{
2     future::ResponseFuture,
3     message::Message,
4     worker::{Handle, Worker},
5 };
6 
7 use futures_core::ready;
8 use std::sync::Arc;
9 use std::task::{Context, Poll};
10 use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore};
11 use tokio_util::sync::PollSemaphore;
12 use tower_service::Service;
13 
14 /// Adds an mpsc buffer in front of an inner service.
15 ///
16 /// See the module documentation for more details.
17 #[derive(Debug)]
18 pub struct Buffer<T, Request>
19 where
20     T: Service<Request>,
21 {
22     // Note: this actually _is_ bounded, but rather than using Tokio's bounded
23     // channel, we use Tokio's semaphore separately to implement the bound.
24     tx: mpsc::UnboundedSender<Message<Request, T::Future>>,
25     // When the buffer's channel is full, we want to exert backpressure in
26     // `poll_ready`, so that callers such as load balancers could choose to call
27     // another service rather than waiting for buffer capacity.
28     //
29     // Unfortunately, this can't be done easily using Tokio's bounded MPSC
30     // channel, because it doesn't expose a polling-based interface, only an
31     // `async fn ready`, which borrows the sender. Therefore, we implement our
32     // own bounded MPSC on top of the unbounded channel, using a semaphore to
33     // limit how many items are in the channel.
34     semaphore: PollSemaphore,
35     // The current semaphore permit, if one has been acquired.
36     //
37     // This is acquired in `poll_ready` and taken in `call`.
38     permit: Option<OwnedSemaphorePermit>,
39     handle: Handle,
40 }
41 
42 impl<T, Request> Buffer<T, Request>
43 where
44     T: Service<Request>,
45     T::Error: Into<crate::BoxError>,
46 {
47     /// Creates a new [`Buffer`] wrapping `service`.
48     ///
49     /// `bound` gives the maximal number of requests that can be queued for the service before
50     /// backpressure is applied to callers.
51     ///
52     /// The default Tokio executor is used to run the given service, which means that this method
53     /// must be called while on the Tokio runtime.
54     ///
55     /// # A note on choosing a `bound`
56     ///
57     /// When [`Buffer`]'s implementation of [`poll_ready`] returns [`Poll::Ready`], it reserves a
58     /// slot in the channel for the forthcoming [`call`]. However, if this call doesn't arrive,
59     /// this reserved slot may be held up for a long time. As a result, it's advisable to set
60     /// `bound` to be at least the maximum number of concurrent requests the [`Buffer`] will see.
61     /// If you do not, all the slots in the buffer may be held up by futures that have just called
62     /// [`poll_ready`] but will not issue a [`call`], which prevents other senders from issuing new
63     /// requests.
64     ///
65     /// [`Poll::Ready`]: std::task::Poll::Ready
66     /// [`call`]: crate::Service::call
67     /// [`poll_ready`]: crate::Service::poll_ready
new(service: T, bound: usize) -> Self where T: Send + 'static, T::Future: Send, T::Error: Send + Sync, Request: Send + 'static,68     pub fn new(service: T, bound: usize) -> Self
69     where
70         T: Send + 'static,
71         T::Future: Send,
72         T::Error: Send + Sync,
73         Request: Send + 'static,
74     {
75         let (service, worker) = Self::pair(service, bound);
76         tokio::spawn(worker);
77         service
78     }
79 
80     /// Creates a new [`Buffer`] wrapping `service`, but returns the background worker.
81     ///
82     /// This is useful if you do not want to spawn directly onto the tokio runtime
83     /// but instead want to use your own executor. This will return the [`Buffer`] and
84     /// the background `Worker` that you can then spawn.
pair(service: T, bound: usize) -> (Buffer<T, Request>, Worker<T, Request>) where T: Send + 'static, T::Error: Send + Sync, Request: Send + 'static,85     pub fn pair(service: T, bound: usize) -> (Buffer<T, Request>, Worker<T, Request>)
86     where
87         T: Send + 'static,
88         T::Error: Send + Sync,
89         Request: Send + 'static,
90     {
91         let (tx, rx) = mpsc::unbounded_channel();
92         let semaphore = Arc::new(Semaphore::new(bound));
93         let (handle, worker) = Worker::new(service, rx, &semaphore);
94         let buffer = Buffer {
95             tx,
96             handle,
97             semaphore: PollSemaphore::new(semaphore),
98             permit: None,
99         };
100         (buffer, worker)
101     }
102 
get_worker_error(&self) -> crate::BoxError103     fn get_worker_error(&self) -> crate::BoxError {
104         self.handle.get_error_on_closed()
105     }
106 }
107 
108 impl<T, Request> Service<Request> for Buffer<T, Request>
109 where
110     T: Service<Request>,
111     T::Error: Into<crate::BoxError>,
112 {
113     type Response = T::Response;
114     type Error = crate::BoxError;
115     type Future = ResponseFuture<T::Future>;
116 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>117     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
118         // First, check if the worker is still alive.
119         if self.tx.is_closed() {
120             // If the inner service has errored, then we error here.
121             return Poll::Ready(Err(self.get_worker_error()));
122         }
123 
124         // Then, check if we've already acquired a permit.
125         if self.permit.is_some() {
126             // We've already reserved capacity to send a request. We're ready!
127             return Poll::Ready(Ok(()));
128         }
129 
130         // Finally, if we haven't already acquired a permit, poll the semaphore
131         // to acquire one. If we acquire a permit, then there's enough buffer
132         // capacity to send a new request. Otherwise, we need to wait for
133         // capacity.
134         let permit =
135             ready!(self.semaphore.poll_acquire(cx)).ok_or_else(|| self.get_worker_error())?;
136         self.permit = Some(permit);
137 
138         Poll::Ready(Ok(()))
139     }
140 
call(&mut self, request: Request) -> Self::Future141     fn call(&mut self, request: Request) -> Self::Future {
142         tracing::trace!("sending request to buffer worker");
143         let _permit = self
144             .permit
145             .take()
146             .expect("buffer full; poll_ready must be called first");
147 
148         // get the current Span so that we can explicitly propagate it to the worker
149         // if we didn't do this, events on the worker related to this span wouldn't be counted
150         // towards that span since the worker would have no way of entering it.
151         let span = tracing::Span::current();
152 
153         // If we've made it here, then a semaphore permit has already been
154         // acquired, so we can freely allocate a oneshot.
155         let (tx, rx) = oneshot::channel();
156 
157         match self.tx.send(Message {
158             request,
159             span,
160             tx,
161             _permit,
162         }) {
163             Err(_) => ResponseFuture::failed(self.get_worker_error()),
164             Ok(_) => ResponseFuture::new(rx),
165         }
166     }
167 }
168 
169 impl<T, Request> Clone for Buffer<T, Request>
170 where
171     T: Service<Request>,
172 {
clone(&self) -> Self173     fn clone(&self) -> Self {
174         Self {
175             tx: self.tx.clone(),
176             handle: self.handle.clone(),
177             semaphore: self.semaphore.clone(),
178             // The new clone hasn't acquired a permit yet. It will when it's
179             // next polled ready.
180             permit: None,
181         }
182     }
183 }
184