1 // Copyright (c) 2023 Huawei Device Co., Ltd. 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 use std::cell::RefCell; 15 use std::future::Future; 16 use std::mem::MaybeUninit; 17 use std::pin::Pin; 18 use std::sync::atomic::AtomicUsize; 19 use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; 20 use std::task::Poll::{Pending, Ready}; 21 use std::task::{Context, Poll}; 22 23 use crate::sync::atomic_waker::AtomicWaker; 24 use crate::sync::error::{RecvError, SendError, TryRecvError, TrySendError}; 25 use crate::sync::mpsc::Container; 26 use crate::sync::wake_list::WakerList; 27 28 /// The offset of the index. 29 const INDEX_SHIFT: usize = 1; 30 /// The flag marks that Array is closed. 31 const CLOSED: usize = 0b01; 32 33 pub(crate) struct Node<T> { 34 index: AtomicUsize, 35 value: RefCell<MaybeUninit<T>>, 36 } 37 38 /// Bounded lockless queue. 39 pub(crate) struct Array<T> { 40 head: RefCell<usize>, 41 tail: AtomicUsize, 42 capacity: usize, 43 rx_waker: AtomicWaker, 44 waiters: WakerList, 45 data: Box<[Node<T>]>, 46 } 47 48 unsafe impl<T: Send> Send for Array<T> {} 49 unsafe impl<T: Send> Sync for Array<T> {} 50 51 pub(crate) enum SendPosition { 52 Pos(usize), 53 Full, 54 Closed, 55 } 56 57 impl<T> Array<T> { new(capacity: usize) -> Array<T>58 pub(crate) fn new(capacity: usize) -> Array<T> { 59 assert!(capacity > 0, "Capacity cannot be zero."); 60 let data = (0..capacity) 61 .map(|i| Node { 62 index: AtomicUsize::new(i), 63 value: RefCell::new(MaybeUninit::uninit()), 64 }) 65 .collect(); 66 Array { 67 head: RefCell::new(0), 68 tail: AtomicUsize::new(0), 69 capacity, 70 rx_waker: AtomicWaker::new(), 71 waiters: WakerList::new(), 72 data, 73 } 74 } 75 prepare_send(&self) -> SendPosition76 fn prepare_send(&self) -> SendPosition { 77 let mut tail = self.tail.load(Acquire); 78 loop { 79 if tail & CLOSED == CLOSED { 80 return SendPosition::Closed; 81 } 82 let index = (tail >> INDEX_SHIFT) % self.capacity; 83 let node = self.data.get(index).unwrap(); 84 let node_index = node.index.load(Acquire); 85 86 // Compare the index of the node with the tail to avoid senders in different 87 // cycles writing data to the same point at the same time. 88 if (tail >> INDEX_SHIFT) == node_index { 89 match self.tail.compare_exchange_weak( 90 tail, 91 tail.wrapping_add(1 << INDEX_SHIFT), 92 AcqRel, 93 Acquire, 94 ) { 95 Ok(_) => { 96 return SendPosition::Pos(index); 97 } 98 Err(actual) => { 99 tail = actual; 100 } 101 } 102 } else { 103 return SendPosition::Full; 104 } 105 } 106 } 107 write(&self, index: usize, value: T)108 pub(crate) fn write(&self, index: usize, value: T) { 109 let node = self.data.get(index).unwrap(); 110 node.value.borrow_mut().write(value); 111 112 // Mark that the node has data. 113 node.index.fetch_sub(1, Release); 114 self.rx_waker.wake(); 115 } 116 get_position(&self) -> SendPosition117 pub(crate) async fn get_position(&self) -> SendPosition { 118 Position { array: self }.await 119 } 120 try_send(&self, value: T) -> Result<(), TrySendError<T>>121 pub(crate) fn try_send(&self, value: T) -> Result<(), TrySendError<T>> { 122 match self.prepare_send() { 123 SendPosition::Pos(index) => { 124 self.write(index, value); 125 Ok(()) 126 } 127 SendPosition::Full => Err(TrySendError::Full(value)), 128 SendPosition::Closed => Err(TrySendError::Closed(value)), 129 } 130 } 131 send(&self, value: T) -> Result<(), SendError<T>>132 pub(crate) async fn send(&self, value: T) -> Result<(), SendError<T>> { 133 match self.get_position().await { 134 SendPosition::Pos(index) => { 135 self.write(index, value); 136 Ok(()) 137 } 138 SendPosition::Closed => Err(SendError(value)), 139 // If the array is full, the task will wait until it's available. 140 SendPosition::Full => unreachable!(), 141 } 142 } 143 try_recv(&self) -> Result<T, TryRecvError>144 pub(crate) fn try_recv(&self) -> Result<T, TryRecvError> { 145 let head = *self.head.borrow(); 146 let index = head % self.capacity; 147 let node = self.data.get(index).unwrap(); 148 let node_index = node.index.load(Acquire); 149 150 // Check whether the node has data. 151 if head == node_index.wrapping_add(1) { 152 let value = unsafe { node.value.as_ptr().read().assume_init() }; 153 // Adding one indicates that this point is empty, Adding <capacity> enables the 154 // corresponding tail node to write in. 155 node.index.fetch_add(self.capacity + 1, Release); 156 self.waiters.notify_one(); 157 self.head.replace(head + 1); 158 Ok(value) 159 } else { 160 let tail = self.tail.load(Acquire); 161 if tail & CLOSED == CLOSED { 162 Err(TryRecvError::Closed) 163 } else { 164 Err(TryRecvError::Empty) 165 } 166 } 167 } 168 poll_recv(&self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>>169 pub(crate) fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> { 170 match self.try_recv() { 171 Ok(val) => return Ready(Ok(val)), 172 Err(TryRecvError::Closed) => return Ready(Err(RecvError)), 173 Err(TryRecvError::Empty) => {} 174 } 175 176 self.rx_waker.register_by_ref(cx.waker()); 177 178 match self.try_recv() { 179 Ok(val) => Ready(Ok(val)), 180 Err(TryRecvError::Closed) => Ready(Err(RecvError)), 181 Err(TryRecvError::Empty) => Pending, 182 } 183 } 184 capacity(&self) -> usize185 pub(crate) fn capacity(&self) -> usize { 186 self.capacity 187 } 188 } 189 190 impl<T> Container for Array<T> { close(&self)191 fn close(&self) { 192 self.tail.fetch_or(CLOSED, Release); 193 self.waiters.notify_all(); 194 self.rx_waker.wake(); 195 } 196 is_close(&self) -> bool197 fn is_close(&self) -> bool { 198 self.tail.load(Acquire) & CLOSED == CLOSED 199 } 200 len(&self) -> usize201 fn len(&self) -> usize { 202 let head = *self.head.borrow(); 203 let tail = self.tail.load(Acquire) >> INDEX_SHIFT; 204 tail - head 205 } 206 } 207 208 impl<T> Drop for Array<T> { drop(&mut self)209 fn drop(&mut self) { 210 let len = self.len(); 211 if len == 0 { 212 return; 213 } 214 let head = *self.head.borrow(); 215 for i in 0..len { 216 let mut index = head + i; 217 if index > self.capacity { 218 index -= self.capacity; 219 } 220 let node = self.data.get_mut(index).unwrap(); 221 unsafe { 222 node.value.borrow_mut().as_mut_ptr().drop_in_place(); 223 } 224 } 225 } 226 } 227 228 struct Position<'a, T> { 229 array: &'a Array<T>, 230 } 231 232 impl<T> Future for Position<'_, T> { 233 type Output = SendPosition; 234 poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>235 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 236 match self.array.prepare_send() { 237 SendPosition::Pos(index) => return Ready(SendPosition::Pos(index)), 238 SendPosition::Closed => return Ready(SendPosition::Closed), 239 SendPosition::Full => {} 240 } 241 242 self.array.waiters.insert(cx.waker().clone()); 243 244 let tail = self.array.tail.load(Acquire); 245 let index = (tail >> INDEX_SHIFT) % self.array.capacity; 246 let node = self.array.data.get(index).unwrap(); 247 let node_index = node.index.load(Acquire); 248 if (tail >> INDEX_SHIFT) == node_index || tail & CLOSED == CLOSED { 249 self.array.waiters.notify_one(); 250 } 251 Pending 252 } 253 } 254