• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use futures_core::{ready, Stream};
2 use std::fmt;
3 use std::pin::Pin;
4 use std::sync::Arc;
5 use std::task::{Context, Poll};
6 use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore, TryAcquireError};
7 
8 use super::ReusableBoxFuture;
9 
10 /// A wrapper around [`Semaphore`] that provides a `poll_acquire` method.
11 ///
12 /// [`Semaphore`]: tokio::sync::Semaphore
13 pub struct PollSemaphore {
14     semaphore: Arc<Semaphore>,
15     permit_fut: Option<(
16         u32, // The number of permits requested.
17         ReusableBoxFuture<'static, Result<OwnedSemaphorePermit, AcquireError>>,
18     )>,
19 }
20 
21 impl PollSemaphore {
22     /// Create a new `PollSemaphore`.
new(semaphore: Arc<Semaphore>) -> Self23     pub fn new(semaphore: Arc<Semaphore>) -> Self {
24         Self {
25             semaphore,
26             permit_fut: None,
27         }
28     }
29 
30     /// Closes the semaphore.
close(&self)31     pub fn close(&self) {
32         self.semaphore.close()
33     }
34 
35     /// Obtain a clone of the inner semaphore.
clone_inner(&self) -> Arc<Semaphore>36     pub fn clone_inner(&self) -> Arc<Semaphore> {
37         self.semaphore.clone()
38     }
39 
40     /// Get back the inner semaphore.
into_inner(self) -> Arc<Semaphore>41     pub fn into_inner(self) -> Arc<Semaphore> {
42         self.semaphore
43     }
44 
45     /// Poll to acquire a permit from the semaphore.
46     ///
47     /// This can return the following values:
48     ///
49     ///  - `Poll::Pending` if a permit is not currently available.
50     ///  - `Poll::Ready(Some(permit))` if a permit was acquired.
51     ///  - `Poll::Ready(None)` if the semaphore has been closed.
52     ///
53     /// When this method returns `Poll::Pending`, the current task is scheduled
54     /// to receive a wakeup when a permit becomes available, or when the
55     /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only
56     /// the `Waker` from the `Context` passed to the most recent call is
57     /// scheduled to receive a wakeup.
poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>>58     pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> {
59         self.poll_acquire_many(cx, 1)
60     }
61 
62     /// Poll to acquire many permits from the semaphore.
63     ///
64     /// This can return the following values:
65     ///
66     ///  - `Poll::Pending` if a permit is not currently available.
67     ///  - `Poll::Ready(Some(permit))` if a permit was acquired.
68     ///  - `Poll::Ready(None)` if the semaphore has been closed.
69     ///
70     /// When this method returns `Poll::Pending`, the current task is scheduled
71     /// to receive a wakeup when the permits become available, or when the
72     /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only
73     /// the `Waker` from the `Context` passed to the most recent call is
74     /// scheduled to receive a wakeup.
poll_acquire_many( &mut self, cx: &mut Context<'_>, permits: u32, ) -> Poll<Option<OwnedSemaphorePermit>>75     pub fn poll_acquire_many(
76         &mut self,
77         cx: &mut Context<'_>,
78         permits: u32,
79     ) -> Poll<Option<OwnedSemaphorePermit>> {
80         let permit_future = match self.permit_fut.as_mut() {
81             Some((prev_permits, fut)) if *prev_permits == permits => fut,
82             Some((old_permits, fut_box)) => {
83                 // We're requesting a different number of permits, so replace the future
84                 // and record the new amount.
85                 let fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
86                 fut_box.set(fut);
87                 *old_permits = permits;
88                 fut_box
89             }
90             None => {
91                 // avoid allocations completely if we can grab a permit immediately
92                 match Arc::clone(&self.semaphore).try_acquire_many_owned(permits) {
93                     Ok(permit) => return Poll::Ready(Some(permit)),
94                     Err(TryAcquireError::Closed) => return Poll::Ready(None),
95                     Err(TryAcquireError::NoPermits) => {}
96                 }
97 
98                 let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
99                 &mut self
100                     .permit_fut
101                     .get_or_insert((permits, ReusableBoxFuture::new(next_fut)))
102                     .1
103             }
104         };
105 
106         let result = ready!(permit_future.poll(cx));
107 
108         // Assume we'll request the same amount of permits in a subsequent call.
109         let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
110         permit_future.set(next_fut);
111 
112         match result {
113             Ok(permit) => Poll::Ready(Some(permit)),
114             Err(_closed) => {
115                 self.permit_fut = None;
116                 Poll::Ready(None)
117             }
118         }
119     }
120 
121     /// Returns the current number of available permits.
122     ///
123     /// This is equivalent to the [`Semaphore::available_permits`] method on the
124     /// `tokio::sync::Semaphore` type.
125     ///
126     /// [`Semaphore::available_permits`]: tokio::sync::Semaphore::available_permits
available_permits(&self) -> usize127     pub fn available_permits(&self) -> usize {
128         self.semaphore.available_permits()
129     }
130 
131     /// Adds `n` new permits to the semaphore.
132     ///
133     /// The maximum number of permits is [`Semaphore::MAX_PERMITS`], and this function
134     /// will panic if the limit is exceeded.
135     ///
136     /// This is equivalent to the [`Semaphore::add_permits`] method on the
137     /// `tokio::sync::Semaphore` type.
138     ///
139     /// [`Semaphore::add_permits`]: tokio::sync::Semaphore::add_permits
add_permits(&self, n: usize)140     pub fn add_permits(&self, n: usize) {
141         self.semaphore.add_permits(n);
142     }
143 }
144 
145 impl Stream for PollSemaphore {
146     type Item = OwnedSemaphorePermit;
147 
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>>148     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> {
149         Pin::into_inner(self).poll_acquire(cx)
150     }
151 }
152 
153 impl Clone for PollSemaphore {
clone(&self) -> PollSemaphore154     fn clone(&self) -> PollSemaphore {
155         PollSemaphore::new(self.clone_inner())
156     }
157 }
158 
159 impl fmt::Debug for PollSemaphore {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result160     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161         f.debug_struct("PollSemaphore")
162             .field("semaphore", &self.semaphore)
163             .finish()
164     }
165 }
166 
167 impl AsRef<Semaphore> for PollSemaphore {
as_ref(&self) -> &Semaphore168     fn as_ref(&self) -> &Semaphore {
169         &self.semaphore
170     }
171 }
172