1 // Copyright (c) 2023 Huawei Device Co., Ltd. 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 use std::error::Error; 15 use std::fmt::{Debug, Display, Formatter}; 16 use std::future::Future; 17 use std::pin::Pin; 18 use std::sync::atomic::AtomicUsize; 19 use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; 20 use std::task::Poll::{Pending, Ready}; 21 use std::task::{Context, Poll}; 22 23 use crate::sync::wake_list::WakerList; 24 25 /// Maximum capacity of `Semaphore`. 26 const MAX_PERMITS: usize = usize::MAX >> 1; 27 /// The least significant bit that marks the number of permits. 28 const PERMIT_SHIFT: usize = 1; 29 /// The flag marks that Semaphore is closed. 30 const CLOSED: usize = 1; 31 32 pub(crate) struct SemaphoreInner { 33 permits: AtomicUsize, 34 waker_list: WakerList, 35 } 36 37 pub(crate) struct Permit<'a> { 38 semaphore: &'a SemaphoreInner, 39 waker_index: Option<usize>, 40 enqueue: bool, 41 } 42 43 /// Error returned by `Semaphore`. 44 #[derive(Debug, Eq, PartialEq)] 45 pub enum SemaphoreError { 46 /// The number of Permits is overflowed. 47 Overflow, 48 /// Semaphore doesn't have enough permits. 49 Empty, 50 /// Semaphore was closed. 51 Closed, 52 } 53 54 impl Display for SemaphoreError { fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result55 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 56 match self { 57 SemaphoreError::Overflow => write!(f, "permit overflow MAX_PERMITS : {MAX_PERMITS}"), 58 SemaphoreError::Empty => write!(f, "no permits available"), 59 SemaphoreError::Closed => write!(f, "semaphore has been closed"), 60 } 61 } 62 } 63 64 impl Error for SemaphoreError {} 65 66 impl SemaphoreInner { new(permits: usize) -> Result<SemaphoreInner, SemaphoreError>67 pub(crate) fn new(permits: usize) -> Result<SemaphoreInner, SemaphoreError> { 68 if permits >= MAX_PERMITS { 69 return Err(SemaphoreError::Overflow); 70 } 71 Ok(SemaphoreInner { 72 permits: AtomicUsize::new(permits << PERMIT_SHIFT), 73 waker_list: WakerList::new(), 74 }) 75 } 76 current_permits(&self) -> usize77 pub(crate) fn current_permits(&self) -> usize { 78 self.permits.load(Acquire) >> PERMIT_SHIFT 79 } 80 release(&self)81 pub(crate) fn release(&self) { 82 // Get the lock first to ensure the atomicity of the two operations. 83 let mut waker_list = self.waker_list.lock(); 84 if !waker_list.notify_one() { 85 let prev = self.permits.fetch_add(1 << PERMIT_SHIFT, Release); 86 assert!( 87 (prev >> PERMIT_SHIFT) < MAX_PERMITS, 88 "the number of permits will overflow the capacity after addition" 89 ); 90 } 91 } 92 release_notify(&self)93 pub(crate) fn release_notify(&self) { 94 // Get the lock first to ensure the atomicity of the two operations. 95 let mut waker_list = self.waker_list.lock(); 96 if !waker_list.notify_one() { 97 self.permits.store(1 << PERMIT_SHIFT, Release); 98 } 99 } 100 release_multi(&self, mut permits: usize)101 pub(crate) fn release_multi(&self, mut permits: usize) { 102 let mut waker_list = self.waker_list.lock(); 103 while permits > 0 && waker_list.notify_one() { 104 permits -= 1; 105 } 106 let prev = self.permits.fetch_add(permits << PERMIT_SHIFT, Release); 107 assert!( 108 (prev >> PERMIT_SHIFT) < MAX_PERMITS, 109 "the number of permits will overflow the capacity after addition" 110 ); 111 } 112 release_all(&self)113 pub(crate) fn release_all(&self) { 114 self.waker_list.notify_all(); 115 } 116 try_acquire(&self) -> Result<(), SemaphoreError>117 pub(crate) fn try_acquire(&self) -> Result<(), SemaphoreError> { 118 let mut curr = self.permits.load(Acquire); 119 loop { 120 if curr & CLOSED == CLOSED { 121 return Err(SemaphoreError::Closed); 122 } 123 124 if curr > 0 { 125 match self.permits.compare_exchange( 126 curr, 127 curr - (1 << PERMIT_SHIFT), 128 AcqRel, 129 Acquire, 130 ) { 131 Ok(_) => { 132 return Ok(()); 133 } 134 Err(actual) => { 135 curr = actual; 136 } 137 } 138 } else { 139 return Err(SemaphoreError::Empty); 140 } 141 } 142 } 143 is_closed(&self) -> bool144 pub(crate) fn is_closed(&self) -> bool { 145 self.permits.load(Acquire) & CLOSED == CLOSED 146 } 147 close(&self)148 pub(crate) fn close(&self) { 149 // Get the lock first to ensure the atomicity of the two operations. 150 let mut waker_list = self.waker_list.lock(); 151 self.permits.fetch_or(CLOSED, Release); 152 waker_list.notify_all(); 153 } 154 acquire(&self) -> Permit<'_>155 pub(crate) fn acquire(&self) -> Permit<'_> { 156 Permit::new(self) 157 } 158 poll_acquire( &self, cx: &mut Context<'_>, waker_index: &mut Option<usize>, enqueue: &mut bool, ) -> Poll<Result<(), SemaphoreError>>159 fn poll_acquire( 160 &self, 161 cx: &mut Context<'_>, 162 waker_index: &mut Option<usize>, 163 enqueue: &mut bool, 164 ) -> Poll<Result<(), SemaphoreError>> { 165 let mut curr = self.permits.load(Acquire); 166 if curr & CLOSED == CLOSED { 167 return Ready(Err(SemaphoreError::Closed)); 168 } else if *enqueue { 169 *enqueue = false; 170 return Ready(Ok(())); 171 } 172 let permit_num = 1 << PERMIT_SHIFT; 173 loop { 174 if curr & CLOSED == CLOSED { 175 return Ready(Err(SemaphoreError::Closed)); 176 } 177 if curr >= permit_num { 178 match self 179 .permits 180 .compare_exchange(curr, curr - permit_num, AcqRel, Acquire) 181 { 182 Ok(_) => { 183 if *enqueue { 184 self.release(); 185 return Pending; 186 } 187 return Ready(Ok(())); 188 } 189 Err(actual) => { 190 curr = actual; 191 } 192 } 193 } else if !(*enqueue) { 194 *waker_index = Some(self.waker_list.insert(cx.waker().clone())); 195 *enqueue = true; 196 curr = self.permits.load(Acquire); 197 } else { 198 return Pending; 199 } 200 } 201 } 202 } 203 204 impl Debug for SemaphoreInner { fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result205 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 206 f.debug_struct("Semaphore") 207 .field("permits", &self.current_permits()) 208 .finish() 209 } 210 } 211 212 impl<'a> Permit<'a> { new(semaphore: &'a SemaphoreInner) -> Permit213 fn new(semaphore: &'a SemaphoreInner) -> Permit { 214 Permit { 215 semaphore, 216 waker_index: None, 217 enqueue: false, 218 } 219 } 220 } 221 222 impl Future for Permit<'_> { 223 type Output = Result<(), SemaphoreError>; 224 poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>225 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 226 let (semaphore, waker_index, enqueue) = unsafe { 227 let me = self.get_unchecked_mut(); 228 (me.semaphore, &mut me.waker_index, &mut me.enqueue) 229 }; 230 231 semaphore.poll_acquire(cx, waker_index, enqueue) 232 } 233 } 234 235 impl Drop for Permit<'_> { drop(&mut self)236 fn drop(&mut self) { 237 if self.enqueue { 238 // if `enqueue` is true, `waker_index` must be `Some(_)`. 239 let _ = self.semaphore.waker_list.remove(self.waker_index.unwrap()); 240 } 241 } 242 } 243