• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//! Bounded channel based on a preallocated array.
//!
//! This flavor has a fixed, positive capacity.
//!
//! The implementation is based on Dmitry Vyukov's bounded MPMC queue.
//!
//! Source:
//!   - <http://www.1024cores.net/home/lock-free-algorithms/queues/bounded-mpmc-queue>
//!   - <https://docs.google.com/document/d/1yIAYmbvL3JxOKOjuCyon7JhW4cSv1wy5hC0ApeGMV9s/pub>
10 
11 use std::cell::UnsafeCell;
12 use std::mem::{self, MaybeUninit};
13 use std::ptr;
14 use std::sync::atomic::{self, AtomicUsize, Ordering};
15 use std::time::Instant;
16 
17 use crossbeam_utils::{Backoff, CachePadded};
18 
19 use crate::context::Context;
20 use crate::err::{RecvTimeoutError, SendTimeoutError, TryRecvError, TrySendError};
21 use crate::select::{Operation, SelectHandle, Selected, Token};
22 use crate::waker::SyncWaker;
23 
/// A slot in a channel.
///
/// The `stamp` tells producers and consumers whose turn it is to touch `msg`:
/// a sender may write when `stamp == tail`, and a receiver may read when
/// `stamp == head + 1`.
struct Slot<T> {
    /// The current stamp.
    stamp: AtomicUsize,

    /// The message in this slot.
    ///
    /// Only initialized between a completed `write` and the matching `read`.
    msg: UnsafeCell<MaybeUninit<T>>,
}
32 
/// The token type for the array flavor.
///
/// Filled in by `start_send`/`start_recv` and consumed by `write`/`read`.
/// A null `slot` means the channel was found disconnected.
#[derive(Debug)]
pub(crate) struct ArrayToken {
    /// Slot to read from or write to (type-erased `*const Slot<T>`), or null.
    slot: *const u8,

    /// Stamp to store into the slot after reading or writing.
    stamp: usize,
}
42 
43 impl Default for ArrayToken {
44     #[inline]
default() -> Self45     fn default() -> Self {
46         ArrayToken {
47             slot: ptr::null(),
48             stamp: 0,
49         }
50     }
51 }
52 
53 /// Bounded channel based on a preallocated array.
54 pub(crate) struct Channel<T> {
55     /// The head of the channel.
56     ///
57     /// This value is a "stamp" consisting of an index into the buffer, a mark bit, and a lap, but
58     /// packed into a single `usize`. The lower bits represent the index, while the upper bits
59     /// represent the lap. The mark bit in the head is always zero.
60     ///
61     /// Messages are popped from the head of the channel.
62     head: CachePadded<AtomicUsize>,
63 
64     /// The tail of the channel.
65     ///
66     /// This value is a "stamp" consisting of an index into the buffer, a mark bit, and a lap, but
67     /// packed into a single `usize`. The lower bits represent the index, while the upper bits
68     /// represent the lap. The mark bit indicates that the channel is disconnected.
69     ///
70     /// Messages are pushed into the tail of the channel.
71     tail: CachePadded<AtomicUsize>,
72 
73     /// The buffer holding slots.
74     buffer: Box<[Slot<T>]>,
75 
76     /// The channel capacity.
77     cap: usize,
78 
79     /// A stamp with the value of `{ lap: 1, mark: 0, index: 0 }`.
80     one_lap: usize,
81 
82     /// If this bit is set in the tail, that means the channel is disconnected.
83     mark_bit: usize,
84 
85     /// Senders waiting while the channel is full.
86     senders: SyncWaker,
87 
88     /// Receivers waiting while the channel is empty and not disconnected.
89     receivers: SyncWaker,
90 }
91 
92 impl<T> Channel<T> {
93     /// Creates a bounded channel of capacity `cap`.
with_capacity(cap: usize) -> Self94     pub(crate) fn with_capacity(cap: usize) -> Self {
95         assert!(cap > 0, "capacity must be positive");
96 
97         // Compute constants `mark_bit` and `one_lap`.
98         let mark_bit = (cap + 1).next_power_of_two();
99         let one_lap = mark_bit * 2;
100 
101         // Head is initialized to `{ lap: 0, mark: 0, index: 0 }`.
102         let head = 0;
103         // Tail is initialized to `{ lap: 0, mark: 0, index: 0 }`.
104         let tail = 0;
105 
106         // Allocate a buffer of `cap` slots initialized
107         // with stamps.
108         let buffer: Box<[Slot<T>]> = (0..cap)
109             .map(|i| {
110                 // Set the stamp to `{ lap: 0, mark: 0, index: i }`.
111                 Slot {
112                     stamp: AtomicUsize::new(i),
113                     msg: UnsafeCell::new(MaybeUninit::uninit()),
114                 }
115             })
116             .collect();
117 
118         Channel {
119             buffer,
120             cap,
121             one_lap,
122             mark_bit,
123             head: CachePadded::new(AtomicUsize::new(head)),
124             tail: CachePadded::new(AtomicUsize::new(tail)),
125             senders: SyncWaker::new(),
126             receivers: SyncWaker::new(),
127         }
128     }
129 
130     /// Returns a receiver handle to the channel.
receiver(&self) -> Receiver<'_, T>131     pub(crate) fn receiver(&self) -> Receiver<'_, T> {
132         Receiver(self)
133     }
134 
135     /// Returns a sender handle to the channel.
sender(&self) -> Sender<'_, T>136     pub(crate) fn sender(&self) -> Sender<'_, T> {
137         Sender(self)
138     }
139 
140     /// Attempts to reserve a slot for sending a message.
start_send(&self, token: &mut Token) -> bool141     fn start_send(&self, token: &mut Token) -> bool {
142         let backoff = Backoff::new();
143         let mut tail = self.tail.load(Ordering::Relaxed);
144 
145         loop {
146             // Check if the channel is disconnected.
147             if tail & self.mark_bit != 0 {
148                 token.array.slot = ptr::null();
149                 token.array.stamp = 0;
150                 return true;
151             }
152 
153             // Deconstruct the tail.
154             let index = tail & (self.mark_bit - 1);
155             let lap = tail & !(self.one_lap - 1);
156 
157             // Inspect the corresponding slot.
158             debug_assert!(index < self.buffer.len());
159             let slot = unsafe { self.buffer.get_unchecked(index) };
160             let stamp = slot.stamp.load(Ordering::Acquire);
161 
162             // If the tail and the stamp match, we may attempt to push.
163             if tail == stamp {
164                 let new_tail = if index + 1 < self.cap {
165                     // Same lap, incremented index.
166                     // Set to `{ lap: lap, mark: 0, index: index + 1 }`.
167                     tail + 1
168                 } else {
169                     // One lap forward, index wraps around to zero.
170                     // Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`.
171                     lap.wrapping_add(self.one_lap)
172                 };
173 
174                 // Try moving the tail.
175                 match self.tail.compare_exchange_weak(
176                     tail,
177                     new_tail,
178                     Ordering::SeqCst,
179                     Ordering::Relaxed,
180                 ) {
181                     Ok(_) => {
182                         // Prepare the token for the follow-up call to `write`.
183                         token.array.slot = slot as *const Slot<T> as *const u8;
184                         token.array.stamp = tail + 1;
185                         return true;
186                     }
187                     Err(t) => {
188                         tail = t;
189                         backoff.spin();
190                     }
191                 }
192             } else if stamp.wrapping_add(self.one_lap) == tail + 1 {
193                 atomic::fence(Ordering::SeqCst);
194                 let head = self.head.load(Ordering::Relaxed);
195 
196                 // If the head lags one lap behind the tail as well...
197                 if head.wrapping_add(self.one_lap) == tail {
198                     // ...then the channel is full.
199                     return false;
200                 }
201 
202                 backoff.spin();
203                 tail = self.tail.load(Ordering::Relaxed);
204             } else {
205                 // Snooze because we need to wait for the stamp to get updated.
206                 backoff.snooze();
207                 tail = self.tail.load(Ordering::Relaxed);
208             }
209         }
210     }
211 
212     /// Writes a message into the channel.
write(&self, token: &mut Token, msg: T) -> Result<(), T>213     pub(crate) unsafe fn write(&self, token: &mut Token, msg: T) -> Result<(), T> {
214         // If there is no slot, the channel is disconnected.
215         if token.array.slot.is_null() {
216             return Err(msg);
217         }
218 
219         let slot: &Slot<T> = &*token.array.slot.cast::<Slot<T>>();
220 
221         // Write the message into the slot and update the stamp.
222         slot.msg.get().write(MaybeUninit::new(msg));
223         slot.stamp.store(token.array.stamp, Ordering::Release);
224 
225         // Wake a sleeping receiver.
226         self.receivers.notify();
227         Ok(())
228     }
229 
230     /// Attempts to reserve a slot for receiving a message.
start_recv(&self, token: &mut Token) -> bool231     fn start_recv(&self, token: &mut Token) -> bool {
232         let backoff = Backoff::new();
233         let mut head = self.head.load(Ordering::Relaxed);
234 
235         loop {
236             // Deconstruct the head.
237             let index = head & (self.mark_bit - 1);
238             let lap = head & !(self.one_lap - 1);
239 
240             // Inspect the corresponding slot.
241             debug_assert!(index < self.buffer.len());
242             let slot = unsafe { self.buffer.get_unchecked(index) };
243             let stamp = slot.stamp.load(Ordering::Acquire);
244 
245             // If the the stamp is ahead of the head by 1, we may attempt to pop.
246             if head + 1 == stamp {
247                 let new = if index + 1 < self.cap {
248                     // Same lap, incremented index.
249                     // Set to `{ lap: lap, mark: 0, index: index + 1 }`.
250                     head + 1
251                 } else {
252                     // One lap forward, index wraps around to zero.
253                     // Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`.
254                     lap.wrapping_add(self.one_lap)
255                 };
256 
257                 // Try moving the head.
258                 match self.head.compare_exchange_weak(
259                     head,
260                     new,
261                     Ordering::SeqCst,
262                     Ordering::Relaxed,
263                 ) {
264                     Ok(_) => {
265                         // Prepare the token for the follow-up call to `read`.
266                         token.array.slot = slot as *const Slot<T> as *const u8;
267                         token.array.stamp = head.wrapping_add(self.one_lap);
268                         return true;
269                     }
270                     Err(h) => {
271                         head = h;
272                         backoff.spin();
273                     }
274                 }
275             } else if stamp == head {
276                 atomic::fence(Ordering::SeqCst);
277                 let tail = self.tail.load(Ordering::Relaxed);
278 
279                 // If the tail equals the head, that means the channel is empty.
280                 if (tail & !self.mark_bit) == head {
281                     // If the channel is disconnected...
282                     if tail & self.mark_bit != 0 {
283                         // ...then receive an error.
284                         token.array.slot = ptr::null();
285                         token.array.stamp = 0;
286                         return true;
287                     } else {
288                         // Otherwise, the receive operation is not ready.
289                         return false;
290                     }
291                 }
292 
293                 backoff.spin();
294                 head = self.head.load(Ordering::Relaxed);
295             } else {
296                 // Snooze because we need to wait for the stamp to get updated.
297                 backoff.snooze();
298                 head = self.head.load(Ordering::Relaxed);
299             }
300         }
301     }
302 
303     /// Reads a message from the channel.
read(&self, token: &mut Token) -> Result<T, ()>304     pub(crate) unsafe fn read(&self, token: &mut Token) -> Result<T, ()> {
305         if token.array.slot.is_null() {
306             // The channel is disconnected.
307             return Err(());
308         }
309 
310         let slot: &Slot<T> = &*token.array.slot.cast::<Slot<T>>();
311 
312         // Read the message from the slot and update the stamp.
313         let msg = slot.msg.get().read().assume_init();
314         slot.stamp.store(token.array.stamp, Ordering::Release);
315 
316         // Wake a sleeping sender.
317         self.senders.notify();
318         Ok(msg)
319     }
320 
321     /// Attempts to send a message into the channel.
try_send(&self, msg: T) -> Result<(), TrySendError<T>>322     pub(crate) fn try_send(&self, msg: T) -> Result<(), TrySendError<T>> {
323         let token = &mut Token::default();
324         if self.start_send(token) {
325             unsafe { self.write(token, msg).map_err(TrySendError::Disconnected) }
326         } else {
327             Err(TrySendError::Full(msg))
328         }
329     }
330 
331     /// Sends a message into the channel.
send( &self, msg: T, deadline: Option<Instant>, ) -> Result<(), SendTimeoutError<T>>332     pub(crate) fn send(
333         &self,
334         msg: T,
335         deadline: Option<Instant>,
336     ) -> Result<(), SendTimeoutError<T>> {
337         let token = &mut Token::default();
338         loop {
339             // Try sending a message several times.
340             let backoff = Backoff::new();
341             loop {
342                 if self.start_send(token) {
343                     let res = unsafe { self.write(token, msg) };
344                     return res.map_err(SendTimeoutError::Disconnected);
345                 }
346 
347                 if backoff.is_completed() {
348                     break;
349                 } else {
350                     backoff.snooze();
351                 }
352             }
353 
354             if let Some(d) = deadline {
355                 if Instant::now() >= d {
356                     return Err(SendTimeoutError::Timeout(msg));
357                 }
358             }
359 
360             Context::with(|cx| {
361                 // Prepare for blocking until a receiver wakes us up.
362                 let oper = Operation::hook(token);
363                 self.senders.register(oper, cx);
364 
365                 // Has the channel become ready just now?
366                 if !self.is_full() || self.is_disconnected() {
367                     let _ = cx.try_select(Selected::Aborted);
368                 }
369 
370                 // Block the current thread.
371                 let sel = cx.wait_until(deadline);
372 
373                 match sel {
374                     Selected::Waiting => unreachable!(),
375                     Selected::Aborted | Selected::Disconnected => {
376                         self.senders.unregister(oper).unwrap();
377                     }
378                     Selected::Operation(_) => {}
379                 }
380             });
381         }
382     }
383 
384     /// Attempts to receive a message without blocking.
try_recv(&self) -> Result<T, TryRecvError>385     pub(crate) fn try_recv(&self) -> Result<T, TryRecvError> {
386         let token = &mut Token::default();
387 
388         if self.start_recv(token) {
389             unsafe { self.read(token).map_err(|_| TryRecvError::Disconnected) }
390         } else {
391             Err(TryRecvError::Empty)
392         }
393     }
394 
395     /// Receives a message from the channel.
recv(&self, deadline: Option<Instant>) -> Result<T, RecvTimeoutError>396     pub(crate) fn recv(&self, deadline: Option<Instant>) -> Result<T, RecvTimeoutError> {
397         let token = &mut Token::default();
398         loop {
399             // Try receiving a message several times.
400             let backoff = Backoff::new();
401             loop {
402                 if self.start_recv(token) {
403                     let res = unsafe { self.read(token) };
404                     return res.map_err(|_| RecvTimeoutError::Disconnected);
405                 }
406 
407                 if backoff.is_completed() {
408                     break;
409                 } else {
410                     backoff.snooze();
411                 }
412             }
413 
414             if let Some(d) = deadline {
415                 if Instant::now() >= d {
416                     return Err(RecvTimeoutError::Timeout);
417                 }
418             }
419 
420             Context::with(|cx| {
421                 // Prepare for blocking until a sender wakes us up.
422                 let oper = Operation::hook(token);
423                 self.receivers.register(oper, cx);
424 
425                 // Has the channel become ready just now?
426                 if !self.is_empty() || self.is_disconnected() {
427                     let _ = cx.try_select(Selected::Aborted);
428                 }
429 
430                 // Block the current thread.
431                 let sel = cx.wait_until(deadline);
432 
433                 match sel {
434                     Selected::Waiting => unreachable!(),
435                     Selected::Aborted | Selected::Disconnected => {
436                         self.receivers.unregister(oper).unwrap();
437                         // If the channel was disconnected, we still have to check for remaining
438                         // messages.
439                     }
440                     Selected::Operation(_) => {}
441                 }
442             });
443         }
444     }
445 
446     /// Returns the current number of messages inside the channel.
len(&self) -> usize447     pub(crate) fn len(&self) -> usize {
448         loop {
449             // Load the tail, then load the head.
450             let tail = self.tail.load(Ordering::SeqCst);
451             let head = self.head.load(Ordering::SeqCst);
452 
453             // If the tail didn't change, we've got consistent values to work with.
454             if self.tail.load(Ordering::SeqCst) == tail {
455                 let hix = head & (self.mark_bit - 1);
456                 let tix = tail & (self.mark_bit - 1);
457 
458                 return if hix < tix {
459                     tix - hix
460                 } else if hix > tix {
461                     self.cap - hix + tix
462                 } else if (tail & !self.mark_bit) == head {
463                     0
464                 } else {
465                     self.cap
466                 };
467             }
468         }
469     }
470 
471     /// Returns the capacity of the channel.
capacity(&self) -> Option<usize>472     pub(crate) fn capacity(&self) -> Option<usize> {
473         Some(self.cap)
474     }
475 
476     /// Disconnects the channel and wakes up all blocked senders and receivers.
477     ///
478     /// Returns `true` if this call disconnected the channel.
disconnect(&self) -> bool479     pub(crate) fn disconnect(&self) -> bool {
480         let tail = self.tail.fetch_or(self.mark_bit, Ordering::SeqCst);
481 
482         if tail & self.mark_bit == 0 {
483             self.senders.disconnect();
484             self.receivers.disconnect();
485             true
486         } else {
487             false
488         }
489     }
490 
491     /// Returns `true` if the channel is disconnected.
is_disconnected(&self) -> bool492     pub(crate) fn is_disconnected(&self) -> bool {
493         self.tail.load(Ordering::SeqCst) & self.mark_bit != 0
494     }
495 
496     /// Returns `true` if the channel is empty.
is_empty(&self) -> bool497     pub(crate) fn is_empty(&self) -> bool {
498         let head = self.head.load(Ordering::SeqCst);
499         let tail = self.tail.load(Ordering::SeqCst);
500 
501         // Is the tail equal to the head?
502         //
503         // Note: If the head changes just before we load the tail, that means there was a moment
504         // when the channel was not empty, so it is safe to just return `false`.
505         (tail & !self.mark_bit) == head
506     }
507 
508     /// Returns `true` if the channel is full.
is_full(&self) -> bool509     pub(crate) fn is_full(&self) -> bool {
510         let tail = self.tail.load(Ordering::SeqCst);
511         let head = self.head.load(Ordering::SeqCst);
512 
513         // Is the head lagging one lap behind tail?
514         //
515         // Note: If the tail changes just before we load the head, that means there was a moment
516         // when the channel was not full, so it is safe to just return `false`.
517         head.wrapping_add(self.one_lap) == tail & !self.mark_bit
518     }
519 }
520 
521 impl<T> Drop for Channel<T> {
drop(&mut self)522     fn drop(&mut self) {
523         if mem::needs_drop::<T>() {
524             // Get the index of the head.
525             let head = *self.head.get_mut();
526             let tail = *self.tail.get_mut();
527 
528             let hix = head & (self.mark_bit - 1);
529             let tix = tail & (self.mark_bit - 1);
530 
531             let len = if hix < tix {
532                 tix - hix
533             } else if hix > tix {
534                 self.cap - hix + tix
535             } else if (tail & !self.mark_bit) == head {
536                 0
537             } else {
538                 self.cap
539             };
540 
541             // Loop over all slots that hold a message and drop them.
542             for i in 0..len {
543                 // Compute the index of the next slot holding a message.
544                 let index = if hix + i < self.cap {
545                     hix + i
546                 } else {
547                     hix + i - self.cap
548                 };
549 
550                 unsafe {
551                     debug_assert!(index < self.buffer.len());
552                     let slot = self.buffer.get_unchecked_mut(index);
553                     (*slot.msg.get()).assume_init_drop();
554                 }
555             }
556         }
557     }
558 }
559 
560 /// Receiver handle to a channel.
561 pub(crate) struct Receiver<'a, T>(&'a Channel<T>);
562 
563 /// Sender handle to a channel.
564 pub(crate) struct Sender<'a, T>(&'a Channel<T>);
565 
566 impl<T> SelectHandle for Receiver<'_, T> {
try_select(&self, token: &mut Token) -> bool567     fn try_select(&self, token: &mut Token) -> bool {
568         self.0.start_recv(token)
569     }
570 
deadline(&self) -> Option<Instant>571     fn deadline(&self) -> Option<Instant> {
572         None
573     }
574 
register(&self, oper: Operation, cx: &Context) -> bool575     fn register(&self, oper: Operation, cx: &Context) -> bool {
576         self.0.receivers.register(oper, cx);
577         self.is_ready()
578     }
579 
unregister(&self, oper: Operation)580     fn unregister(&self, oper: Operation) {
581         self.0.receivers.unregister(oper);
582     }
583 
accept(&self, token: &mut Token, _cx: &Context) -> bool584     fn accept(&self, token: &mut Token, _cx: &Context) -> bool {
585         self.try_select(token)
586     }
587 
is_ready(&self) -> bool588     fn is_ready(&self) -> bool {
589         !self.0.is_empty() || self.0.is_disconnected()
590     }
591 
watch(&self, oper: Operation, cx: &Context) -> bool592     fn watch(&self, oper: Operation, cx: &Context) -> bool {
593         self.0.receivers.watch(oper, cx);
594         self.is_ready()
595     }
596 
unwatch(&self, oper: Operation)597     fn unwatch(&self, oper: Operation) {
598         self.0.receivers.unwatch(oper);
599     }
600 }
601 
602 impl<T> SelectHandle for Sender<'_, T> {
try_select(&self, token: &mut Token) -> bool603     fn try_select(&self, token: &mut Token) -> bool {
604         self.0.start_send(token)
605     }
606 
deadline(&self) -> Option<Instant>607     fn deadline(&self) -> Option<Instant> {
608         None
609     }
610 
register(&self, oper: Operation, cx: &Context) -> bool611     fn register(&self, oper: Operation, cx: &Context) -> bool {
612         self.0.senders.register(oper, cx);
613         self.is_ready()
614     }
615 
unregister(&self, oper: Operation)616     fn unregister(&self, oper: Operation) {
617         self.0.senders.unregister(oper);
618     }
619 
accept(&self, token: &mut Token, _cx: &Context) -> bool620     fn accept(&self, token: &mut Token, _cx: &Context) -> bool {
621         self.try_select(token)
622     }
623 
is_ready(&self) -> bool624     fn is_ready(&self) -> bool {
625         !self.0.is_full() || self.0.is_disconnected()
626     }
627 
watch(&self, oper: Operation, cx: &Context) -> bool628     fn watch(&self, oper: Operation, cx: &Context) -> bool {
629         self.0.senders.watch(oper, cx);
630         self.is_ready()
631     }
632 
unwatch(&self, oper: Operation)633     fn unwatch(&self, oper: Operation) {
634         self.0.senders.unwatch(oper);
635     }
636 }
637