• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::error::Error;
15 use std::fmt::{Debug, Display, Formatter};
16 use std::future::Future;
17 use std::pin::Pin;
18 use std::sync::atomic::AtomicUsize;
19 use std::sync::atomic::Ordering::{AcqRel, Acquire, Release};
20 use std::task::Poll::{Pending, Ready};
21 use std::task::{Context, Poll};
22 
23 use crate::sync::wake_list::WakerList;
24 
/// Upper bound on the number of permits a `Semaphore` can store.
///
/// The permit count is kept shifted left by `PERMIT_SHIFT` bits inside a single `usize`, so only
/// `usize::BITS - PERMIT_SHIFT` bits remain available for the count itself.
const MAX_PERMITS: usize = usize::MAX / 2;
/// Number of low bits reserved for flags; the permit count is stored above them.
const PERMIT_SHIFT: usize = 1;
/// Flag stored in bit 0 of the permit word once the semaphore has been closed.
const CLOSED: usize = 1 << 0;
31 
32 pub(crate) struct SemaphoreInner {
33     permits: AtomicUsize,
34     waker_list: WakerList,
35 }
36 
37 pub(crate) struct Permit<'a> {
38     semaphore: &'a SemaphoreInner,
39     waker_index: Option<usize>,
40     enqueue: bool,
41 }
42 
/// Error returned by `Semaphore`.
///
/// The enum carries no data, so it is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum SemaphoreError {
    /// The requested number of permits exceeds the supported maximum.
    Overflow,
    /// The semaphore currently has no permits available.
    Empty,
    /// The semaphore has been closed.
    Closed,
}
53 
54 impl Display for SemaphoreError {
fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result55     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
56         match self {
57             SemaphoreError::Overflow => write!(f, "permit overflow MAX_PERMITS : {MAX_PERMITS}"),
58             SemaphoreError::Empty => write!(f, "no permits available"),
59             SemaphoreError::Closed => write!(f, "semaphore has been closed"),
60         }
61     }
62 }
63 
64 impl Error for SemaphoreError {}
65 
66 impl SemaphoreInner {
new(permits: usize) -> Result<SemaphoreInner, SemaphoreError>67     pub(crate) fn new(permits: usize) -> Result<SemaphoreInner, SemaphoreError> {
68         if permits >= MAX_PERMITS {
69             return Err(SemaphoreError::Overflow);
70         }
71         Ok(SemaphoreInner {
72             permits: AtomicUsize::new(permits << PERMIT_SHIFT),
73             waker_list: WakerList::new(),
74         })
75     }
76 
current_permits(&self) -> usize77     pub(crate) fn current_permits(&self) -> usize {
78         self.permits.load(Acquire) >> PERMIT_SHIFT
79     }
80 
release(&self)81     pub(crate) fn release(&self) {
82         // Get the lock first to ensure the atomicity of the two operations.
83         let mut waker_list = self.waker_list.lock();
84         if !waker_list.notify_one() {
85             let prev = self.permits.fetch_add(1 << PERMIT_SHIFT, Release);
86             assert!(
87                 (prev >> PERMIT_SHIFT) < MAX_PERMITS,
88                 "the number of permits will overflow the capacity after addition"
89             );
90         }
91     }
92 
release_notify(&self)93     pub(crate) fn release_notify(&self) {
94         // Get the lock first to ensure the atomicity of the two operations.
95         let mut waker_list = self.waker_list.lock();
96         if !waker_list.notify_one() {
97             self.permits.store(1 << PERMIT_SHIFT, Release);
98         }
99     }
100 
release_multi(&self, mut permits: usize)101     pub(crate) fn release_multi(&self, mut permits: usize) {
102         let mut waker_list = self.waker_list.lock();
103         while permits > 0 && waker_list.notify_one() {
104             permits -= 1;
105         }
106         let prev = self.permits.fetch_add(permits << PERMIT_SHIFT, Release);
107         assert!(
108             (prev >> PERMIT_SHIFT) < MAX_PERMITS,
109             "the number of permits will overflow the capacity after addition"
110         );
111     }
112 
release_all(&self)113     pub(crate) fn release_all(&self) {
114         self.waker_list.notify_all();
115     }
116 
try_acquire(&self) -> Result<(), SemaphoreError>117     pub(crate) fn try_acquire(&self) -> Result<(), SemaphoreError> {
118         let mut curr = self.permits.load(Acquire);
119         loop {
120             if curr & CLOSED == CLOSED {
121                 return Err(SemaphoreError::Closed);
122             }
123 
124             if curr > 0 {
125                 match self.permits.compare_exchange(
126                     curr,
127                     curr - (1 << PERMIT_SHIFT),
128                     AcqRel,
129                     Acquire,
130                 ) {
131                     Ok(_) => {
132                         return Ok(());
133                     }
134                     Err(actual) => {
135                         curr = actual;
136                     }
137                 }
138             } else {
139                 return Err(SemaphoreError::Empty);
140             }
141         }
142     }
143 
is_closed(&self) -> bool144     pub(crate) fn is_closed(&self) -> bool {
145         self.permits.load(Acquire) & CLOSED == CLOSED
146     }
147 
close(&self)148     pub(crate) fn close(&self) {
149         // Get the lock first to ensure the atomicity of the two operations.
150         let mut waker_list = self.waker_list.lock();
151         self.permits.fetch_or(CLOSED, Release);
152         waker_list.notify_all();
153     }
154 
acquire(&self) -> Permit<'_>155     pub(crate) fn acquire(&self) -> Permit<'_> {
156         Permit::new(self)
157     }
158 
poll_acquire( &self, cx: &mut Context<'_>, waker_index: &mut Option<usize>, enqueue: &mut bool, ) -> Poll<Result<(), SemaphoreError>>159     fn poll_acquire(
160         &self,
161         cx: &mut Context<'_>,
162         waker_index: &mut Option<usize>,
163         enqueue: &mut bool,
164     ) -> Poll<Result<(), SemaphoreError>> {
165         let mut curr = self.permits.load(Acquire);
166         if curr & CLOSED == CLOSED {
167             return Ready(Err(SemaphoreError::Closed));
168         } else if *enqueue {
169             *enqueue = false;
170             return Ready(Ok(()));
171         }
172         let permit_num = 1 << PERMIT_SHIFT;
173         loop {
174             if curr & CLOSED == CLOSED {
175                 return Ready(Err(SemaphoreError::Closed));
176             }
177             if curr >= permit_num {
178                 match self
179                     .permits
180                     .compare_exchange(curr, curr - permit_num, AcqRel, Acquire)
181                 {
182                     Ok(_) => {
183                         if *enqueue {
184                             self.release();
185                             return Pending;
186                         }
187                         return Ready(Ok(()));
188                     }
189                     Err(actual) => {
190                         curr = actual;
191                     }
192                 }
193             } else if !(*enqueue) {
194                 *waker_index = Some(self.waker_list.insert(cx.waker().clone()));
195                 *enqueue = true;
196                 curr = self.permits.load(Acquire);
197             } else {
198                 return Pending;
199             }
200         }
201     }
202 }
203 
204 impl Debug for SemaphoreInner {
fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result205     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
206         f.debug_struct("Semaphore")
207             .field("permits", &self.current_permits())
208             .finish()
209     }
210 }
211 
212 impl<'a> Permit<'a> {
new(semaphore: &'a SemaphoreInner) -> Permit213     fn new(semaphore: &'a SemaphoreInner) -> Permit {
214         Permit {
215             semaphore,
216             waker_index: None,
217             enqueue: false,
218         }
219     }
220 }
221 
222 impl Future for Permit<'_> {
223     type Output = Result<(), SemaphoreError>;
224 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>225     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
226         let (semaphore, waker_index, enqueue) = unsafe {
227             let me = self.get_unchecked_mut();
228             (me.semaphore, &mut me.waker_index, &mut me.enqueue)
229         };
230 
231         semaphore.poll_acquire(cx, waker_index, enqueue)
232     }
233 }
234 
235 impl Drop for Permit<'_> {
drop(&mut self)236     fn drop(&mut self) {
237         if self.enqueue {
238             // if `enqueue` is true, `waker_index` must be `Some(_)`.
239             let _ = self.semaphore.waker_list.remove(self.waker_index.unwrap());
240         }
241     }
242 }
243