// Copyright (c) 2023 Huawei Device Co., Ltd.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13 
14 use std::cell::RefCell;
15 use std::future::Future;
16 use std::mem::MaybeUninit;
17 use std::pin::Pin;
18 use std::sync::atomic::AtomicUsize;
19 use std::sync::atomic::Ordering::{AcqRel, Acquire, Release};
20 use std::task::Poll::{Pending, Ready};
21 use std::task::{Context, Poll};
22 
23 use crate::sync::atomic_waker::AtomicWaker;
24 use crate::sync::error::{RecvError, SendError, TryRecvError, TrySendError};
25 use crate::sync::mpsc::Container;
26 use crate::sync::wake_list::WakerList;
27 
/// The offset of the index: the low bit of `tail` is reserved for the
/// `CLOSED` flag, so the logical send counter lives in the bits above it.
const INDEX_SHIFT: usize = 1;
/// The flag marks that Array is closed (stored in the low bit of `tail`).
const CLOSED: usize = 0b01;
32 
/// A single slot of the bounded ring buffer.
pub(crate) struct Node<T> {
    // State stamp of the slot. A sender may write when the stamp equals the
    // (shifted) tail counter; `try_recv` treats the slot as full when
    // `head == stamp + 1` (the sender decrements the stamp after writing).
    index: AtomicUsize,
    // The stored element. Only initialized while the stamp marks the slot as
    // holding data; read out with `assume_init` by the receiver.
    value: RefCell<MaybeUninit<T>>,
}
37 
/// Bounded lockless queue.
pub(crate) struct Array<T> {
    // Count of elements received so far (monotonically increasing, never
    // reduced modulo `capacity`). Only the receiver side mutates it, hence
    // the single-threaded `RefCell` rather than an atomic.
    head: RefCell<usize>,
    // Count of elements sent so far, shifted left by `INDEX_SHIFT`; the low
    // bit is the `CLOSED` flag. Advanced by senders via CAS.
    tail: AtomicUsize,
    // Fixed number of slots in `data`; asserted non-zero in `new`.
    capacity: usize,
    // Waker of the receiving task, registered in `poll_recv`.
    rx_waker: AtomicWaker,
    // Wakers of sender tasks parked while the queue is full.
    waiters: WakerList,
    // The ring buffer of slots; indexed by `counter % capacity`.
    data: Box<[Node<T>]>,
}
47 
// SAFETY: the interior-mutable fields are never touched concurrently by two
// parties: `head` and each slot's `value` are accessed by whichever side the
// slot's `index` stamp currently hands ownership to (senders claim slots via
// CAS on `tail`; the receiver recycles them). `T: Send` is required because
// values cross threads. NOTE(review): soundness also presumes a single
// receiver at a time (mpsc) — confirm against the channel wrappers.
unsafe impl<T: Send> Send for Array<T> {}
unsafe impl<T: Send> Sync for Array<T> {}
50 
/// Outcome of a sender's attempt to claim a write slot.
pub(crate) enum SendPosition {
    /// A slot at this ring index was claimed; the caller must `write` to it.
    Pos(usize),
    /// Every slot is occupied; the sender has to wait or fail.
    Full,
    /// The queue has been closed; no more sends are possible.
    Closed,
}
56 
impl<T> Array<T> {
    /// Creates a bounded queue with `capacity` slots.
    ///
    /// Each slot's stamp starts at its own position, so a slot is writable
    /// exactly when its stamp matches the shifted tail counter.
    ///
    /// # Panics
    /// Panics if `capacity` is zero.
    pub(crate) fn new(capacity: usize) -> Array<T> {
        assert!(capacity > 0, "Capacity cannot be zero.");
        let data = (0..capacity)
            .map(|i| Node {
                index: AtomicUsize::new(i),
                value: RefCell::new(MaybeUninit::uninit()),
            })
            .collect();
        Array {
            head: RefCell::new(0),
            tail: AtomicUsize::new(0),
            capacity,
            rx_waker: AtomicWaker::new(),
            waiters: WakerList::new(),
            data,
        }
    }

    /// Attempts to claim a write slot by advancing `tail` with a CAS.
    ///
    /// Returns `Pos(index)` on success, `Closed` if the closed bit is set,
    /// or `Full` when the slot at the tail position has not yet been
    /// recycled by the receiver.
    fn prepare_send(&self) -> SendPosition {
        let mut tail = self.tail.load(Acquire);
        loop {
            if tail & CLOSED == CLOSED {
                return SendPosition::Closed;
            }
            let index = (tail >> INDEX_SHIFT) % self.capacity;
            let node = self.data.get(index).unwrap();
            let node_index = node.index.load(Acquire);

            // Compare the index of the node with the tail to avoid senders in different
            // cycles writing data to the same point at the same time.
            if (tail >> INDEX_SHIFT) == node_index {
                match self.tail.compare_exchange_weak(
                    tail,
                    tail.wrapping_add(1 << INDEX_SHIFT),
                    AcqRel,
                    Acquire,
                ) {
                    Ok(_) => {
                        return SendPosition::Pos(index);
                    }
                    Err(actual) => {
                        // Another sender advanced the tail (or it was
                        // closed); retry with the freshly observed value.
                        tail = actual;
                    }
                }
            } else {
                // The slot still carries the stamp of a previous cycle:
                // its data has not been consumed yet, so the queue is full.
                return SendPosition::Full;
            }
        }
    }

    /// Writes `value` into a slot previously claimed via `prepare_send`,
    /// publishes it, and wakes the receiver.
    pub(crate) fn write(&self, index: usize, value: T) {
        let node = self.data.get(index).unwrap();
        node.value.borrow_mut().write(value);

        // Mark that the node has data. Decrementing turns the stamp into
        // `tail - 1`, which `try_recv` recognizes via `head == stamp + 1`.
        // `Release` orders the value write above before the stamp update.
        node.index.fetch_sub(1, Release);
        self.rx_waker.wake();
    }

    /// Resolves to a claimed write position (or `Closed`), waiting while the
    /// queue is full. Never resolves to `Full`.
    pub(crate) async fn get_position(&self) -> SendPosition {
        Position { array: self }.await
    }

    /// Attempts to send `value` without waiting.
    ///
    /// # Errors
    /// Returns `TrySendError::Full(value)` when the queue is full and
    /// `TrySendError::Closed(value)` when it is closed, handing the value
    /// back to the caller in both cases.
    pub(crate) fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
        match self.prepare_send() {
            SendPosition::Pos(index) => {
                self.write(index, value);
                Ok(())
            }
            SendPosition::Full => Err(TrySendError::Full(value)),
            SendPosition::Closed => Err(TrySendError::Closed(value)),
        }
    }

    /// Sends `value`, asynchronously waiting while the queue is full.
    ///
    /// # Errors
    /// Returns `SendError(value)` if the queue has been closed.
    pub(crate) async fn send(&self, value: T) -> Result<(), SendError<T>> {
        match self.get_position().await {
            SendPosition::Pos(index) => {
                self.write(index, value);
                Ok(())
            }
            SendPosition::Closed => Err(SendError(value)),
            // If the array is full, the task will wait until it's available.
            SendPosition::Full => unreachable!(),
        }
    }

    /// Attempts to receive a value without waiting.
    ///
    /// Mutates `head`, so it must only ever be called from the receiver side.
    ///
    /// # Errors
    /// `TryRecvError::Empty` when no slot is ready; `TryRecvError::Closed`
    /// when no slot is ready *and* the queue has been closed.
    pub(crate) fn try_recv(&self) -> Result<T, TryRecvError> {
        let head = *self.head.borrow();
        let index = head % self.capacity;
        let node = self.data.get(index).unwrap();
        let node_index = node.index.load(Acquire);

        // Check whether the node has data.
        if head == node_index.wrapping_add(1) {
            // SAFETY: the stamp check above proves a sender has fully
            // initialized this slot (its `fetch_sub(1, Release)` in `write`
            // happens-before our `Acquire` load), and the slot is read out
            // exactly once before being recycled below.
            let value = unsafe { node.value.as_ptr().read().assume_init() };
            // Adding one indicates that this point is empty, Adding <capacity> enables the
            // corresponding tail node to write in.
            node.index.fetch_add(self.capacity + 1, Release);
            self.waiters.notify_one();
            self.head.replace(head + 1);
            Ok(value)
        } else {
            let tail = self.tail.load(Acquire);
            if tail & CLOSED == CLOSED {
                Err(TryRecvError::Closed)
            } else {
                Err(TryRecvError::Empty)
            }
        }
    }

    /// Polls for a value, registering the task's waker when none is ready.
    ///
    /// The second `try_recv` after registration closes the race window in
    /// which a sender writes between the first check and the waker being
    /// installed (that sender's `wake` would otherwise be lost).
    pub(crate) fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
        match self.try_recv() {
            Ok(val) => return Ready(Ok(val)),
            Err(TryRecvError::Closed) => return Ready(Err(RecvError)),
            Err(TryRecvError::Empty) => {}
        }

        self.rx_waker.register_by_ref(cx.waker());

        match self.try_recv() {
            Ok(val) => Ready(Ok(val)),
            Err(TryRecvError::Closed) => Ready(Err(RecvError)),
            Err(TryRecvError::Empty) => Pending,
        }
    }

    /// Returns the fixed capacity of the queue.
    pub(crate) fn capacity(&self) -> usize {
        self.capacity
    }
}
189 
190 impl<T> Container for Array<T> {
close(&self)191     fn close(&self) {
192         self.tail.fetch_or(CLOSED, Release);
193         self.waiters.notify_all();
194         self.rx_waker.wake();
195     }
196 
is_close(&self) -> bool197     fn is_close(&self) -> bool {
198         self.tail.load(Acquire) & CLOSED == CLOSED
199     }
200 
len(&self) -> usize201     fn len(&self) -> usize {
202         let head = *self.head.borrow();
203         let tail = self.tail.load(Acquire) >> INDEX_SHIFT;
204         tail - head
205     }
206 }
207 
208 impl<T> Drop for Array<T> {
drop(&mut self)209     fn drop(&mut self) {
210         let len = self.len();
211         if len == 0 {
212             return;
213         }
214         let head = *self.head.borrow();
215         for i in 0..len {
216             let mut index = head + i;
217             if index > self.capacity {
218                 index -= self.capacity;
219             }
220             let node = self.data.get_mut(index).unwrap();
221             unsafe {
222                 node.value.borrow_mut().as_mut_ptr().drop_in_place();
223             }
224         }
225     }
226 }
227 
/// Future that resolves to a claimed write position in the `Array` (or
/// `SendPosition::Closed`), parking the sending task while the queue is full.
struct Position<'a, T> {
    array: &'a Array<T>,
}
231 
impl<T> Future for Position<'_, T> {
    type Output = SendPosition;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Fast path: try to claim a slot (or observe closure) immediately.
        match self.array.prepare_send() {
            SendPosition::Pos(index) => return Ready(SendPosition::Pos(index)),
            SendPosition::Closed => return Ready(SendPosition::Closed),
            SendPosition::Full => {}
        }

        // The queue was full: park this task's waker before re-checking, so
        // a wakeup issued after this point cannot be lost.
        self.array.waiters.insert(cx.waker().clone());

        // Re-check the tail slot. If it became writable (or the queue was
        // closed) between `prepare_send` above and the insertion, the
        // corresponding notification may already have fired and missed us,
        // so hand one wakeup back to the list to avoid a lost-wakeup hang.
        let tail = self.array.tail.load(Acquire);
        let index = (tail >> INDEX_SHIFT) % self.array.capacity;
        let node = self.array.data.get(index).unwrap();
        let node_index = node.index.load(Acquire);
        if (tail >> INDEX_SHIFT) == node_index || tail & CLOSED == CLOSED {
            self.array.waiters.notify_one();
        }
        Pending
    }
}
254