• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::{
6     cell::UnsafeCell,
7     hint, mem,
8     sync::{
9         atomic::{AtomicUsize, Ordering},
10         Arc,
11     },
12 };
13 
14 use super::super::sync::{
15     mu::{MutexGuard, MutexReadGuard, RawMutex},
16     waiter::{Kind as WaiterKind, Waiter, WaiterAdapter, WaiterList, WaitingFor},
17 };
18 
19 const SPINLOCK: usize = 1 << 0;
20 const HAS_WAITERS: usize = 1 << 1;
21 
22 /// A primitive to wait for an event to occur without consuming CPU time.
23 ///
24 /// Condition variables are used in combination with a `Mutex` when a thread wants to wait for some
25 /// condition to become true. The condition must always be verified while holding the `Mutex` lock.
26 /// It is an error to use a `Condvar` with more than one `Mutex` while there are threads waiting on
27 /// the `Condvar`.
28 ///
29 /// # Examples
30 ///
31 /// ```edition2018
32 /// use std::sync::Arc;
33 /// use std::thread;
34 /// use std::sync::mpsc::channel;
35 ///
36 /// use cros_async::{
37 ///     block_on,
38 ///     sync::{Condvar, Mutex},
39 /// };
40 ///
41 /// const N: usize = 13;
42 ///
43 /// // Spawn a few threads to increment a shared variable (non-atomically), and
44 /// // let all threads waiting on the Condvar know once the increments are done.
45 /// let data = Arc::new(Mutex::new(0));
46 /// let cv = Arc::new(Condvar::new());
47 ///
48 /// for _ in 0..N {
49 ///     let (data, cv) = (data.clone(), cv.clone());
50 ///     thread::spawn(move || {
51 ///         let mut data = block_on(data.lock());
52 ///         *data += 1;
53 ///         if *data == N {
54 ///             cv.notify_all();
55 ///         }
56 ///     });
57 /// }
58 ///
59 /// let mut val = block_on(data.lock());
60 /// while *val != N {
61 ///     val = block_on(cv.wait(val));
62 /// }
63 /// ```
64 #[repr(align(128))]
65 pub struct Condvar {
66     state: AtomicUsize,
67     waiters: UnsafeCell<WaiterList>,
68     mu: UnsafeCell<usize>,
69 }
70 
71 impl Condvar {
72     /// Creates a new condition variable ready to be waited on and notified.
new() -> Condvar73     pub fn new() -> Condvar {
74         Condvar {
75             state: AtomicUsize::new(0),
76             waiters: UnsafeCell::new(WaiterList::new(WaiterAdapter::new())),
77             mu: UnsafeCell::new(0),
78         }
79     }
80 
81     /// Block the current thread until this `Condvar` is notified by another thread.
82     ///
83     /// This method will atomically unlock the `Mutex` held by `guard` and then block the current
84     /// thread. Any call to `notify_one` or `notify_all` after the `Mutex` is unlocked may wake up
85     /// the thread.
86     ///
87     /// To allow for more efficient scheduling, this call may return even when the programmer
88     /// doesn't expect the thread to be woken. Therefore, calls to `wait()` should be used inside a
89     /// loop that checks the predicate before continuing.
90     ///
91     /// Callers that are not in an async context may wish to use the `block_on` method to block the
92     /// thread until the `Condvar` is notified.
93     ///
94     /// # Panics
95     ///
96     /// This method will panic if used with more than one `Mutex` at the same time.
97     ///
98     /// # Examples
99     ///
100     /// ```
101     /// # use std::sync::Arc;
102     /// # use std::thread;
103     ///
104     /// # use cros_async::{
105     /// #     block_on,
106     /// #     sync::{Condvar, Mutex},
107     /// # };
108     ///
109     /// # let mu = Arc::new(Mutex::new(false));
110     /// # let cv = Arc::new(Condvar::new());
111     /// # let (mu2, cv2) = (mu.clone(), cv.clone());
112     ///
113     /// # let t = thread::spawn(move || {
114     /// #     *block_on(mu2.lock()) = true;
115     /// #     cv2.notify_all();
116     /// # });
117     ///
118     /// let mut ready = block_on(mu.lock());
119     /// while !*ready {
120     ///     ready = block_on(cv.wait(ready));
121     /// }
122     ///
123     /// # t.join().expect("failed to join thread");
124     /// ```
125     // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code
126     // that doesn't compile.
127     #[allow(clippy::needless_lifetimes)]
wait<'g, T>(&self, guard: MutexGuard<'g, T>) -> MutexGuard<'g, T>128     pub async fn wait<'g, T>(&self, guard: MutexGuard<'g, T>) -> MutexGuard<'g, T> {
129         let waiter = Arc::new(Waiter::new(
130             WaiterKind::Exclusive,
131             cancel_waiter,
132             self as *const Condvar as usize,
133             WaitingFor::Condvar,
134         ));
135 
136         self.add_waiter(waiter.clone(), guard.as_raw_mutex());
137 
138         // Get a reference to the mutex and then drop the lock.
139         let mu = guard.into_inner();
140 
141         // Wait to be woken up.
142         waiter.wait().await;
143 
144         // Now re-acquire the lock.
145         mu.lock_from_cv().await
146     }
147 
148     /// Like `wait()` but takes and returns a `MutexReadGuard` instead.
149     // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code
150     // that doesn't compile.
151     #[allow(clippy::needless_lifetimes)]
wait_read<'g, T>(&self, guard: MutexReadGuard<'g, T>) -> MutexReadGuard<'g, T>152     pub async fn wait_read<'g, T>(&self, guard: MutexReadGuard<'g, T>) -> MutexReadGuard<'g, T> {
153         let waiter = Arc::new(Waiter::new(
154             WaiterKind::Shared,
155             cancel_waiter,
156             self as *const Condvar as usize,
157             WaitingFor::Condvar,
158         ));
159 
160         self.add_waiter(waiter.clone(), guard.as_raw_mutex());
161 
162         // Get a reference to the mutex and then drop the lock.
163         let mu = guard.into_inner();
164 
165         // Wait to be woken up.
166         waiter.wait().await;
167 
168         // Now re-acquire the lock.
169         mu.read_lock_from_cv().await
170     }
171 
add_waiter(&self, waiter: Arc<Waiter>, raw_mutex: &RawMutex)172     fn add_waiter(&self, waiter: Arc<Waiter>, raw_mutex: &RawMutex) {
173         // Acquire the spin lock.
174         let mut oldstate = self.state.load(Ordering::Relaxed);
175         while (oldstate & SPINLOCK) != 0
176             || self
177                 .state
178                 .compare_exchange_weak(
179                     oldstate,
180                     oldstate | SPINLOCK | HAS_WAITERS,
181                     Ordering::Acquire,
182                     Ordering::Relaxed,
183                 )
184                 .is_err()
185         {
186             hint::spin_loop();
187             oldstate = self.state.load(Ordering::Relaxed);
188         }
189 
190         // Safe because the spin lock guarantees exclusive access and the reference does not escape
191         // this function.
192         let mu = unsafe { &mut *self.mu.get() };
193         let muptr = raw_mutex as *const RawMutex as usize;
194 
195         match *mu {
196             0 => *mu = muptr,
197             p if p == muptr => {}
198             _ => panic!("Attempting to use Condvar with more than one Mutex at the same time"),
199         }
200 
201         // Safe because the spin lock guarantees exclusive access.
202         unsafe { (*self.waiters.get()).push_back(waiter) };
203 
204         // Release the spin lock. Use a direct store here because no other thread can modify
205         // `self.state` while we hold the spin lock. Keep the `HAS_WAITERS` bit that we set earlier
206         // because we just added a waiter.
207         self.state.store(HAS_WAITERS, Ordering::Release);
208     }
209 
210     /// Notify at most one thread currently waiting on the `Condvar`.
211     ///
212     /// If there is a thread currently waiting on the `Condvar` it will be woken up from its call to
213     /// `wait`.
214     ///
215     /// Unlike more traditional condition variable interfaces, this method requires a reference to
216     /// the `Mutex` associated with this `Condvar`. This is because it is inherently racy to call
217     /// `notify_one` or `notify_all` without first acquiring the `Mutex` lock. Additionally, taking
218     /// a reference to the `Mutex` here allows us to make some optimizations that can improve
219     /// performance by reducing unnecessary wakeups.
notify_one(&self)220     pub fn notify_one(&self) {
221         let mut oldstate = self.state.load(Ordering::Relaxed);
222         if (oldstate & HAS_WAITERS) == 0 {
223             // No waiters.
224             return;
225         }
226 
227         while (oldstate & SPINLOCK) != 0
228             || self
229                 .state
230                 .compare_exchange_weak(
231                     oldstate,
232                     oldstate | SPINLOCK,
233                     Ordering::Acquire,
234                     Ordering::Relaxed,
235                 )
236                 .is_err()
237         {
238             hint::spin_loop();
239             oldstate = self.state.load(Ordering::Relaxed);
240         }
241 
242         // Safe because the spin lock guarantees exclusive access and the reference does not escape
243         // this function.
244         let waiters = unsafe { &mut *self.waiters.get() };
245         let wake_list = get_wake_list(waiters);
246 
247         let newstate = if waiters.is_empty() {
248             // Also clear the mutex associated with this Condvar since there are no longer any
249             // waiters.  Safe because the spin lock guarantees exclusive access.
250             unsafe { *self.mu.get() = 0 };
251 
252             // We are releasing the spin lock and there are no more waiters so we can clear all bits
253             // in `self.state`.
254             0
255         } else {
256             // There are still waiters so we need to keep the HAS_WAITERS bit in the state.
257             HAS_WAITERS
258         };
259 
260         // Release the spin lock.
261         self.state.store(newstate, Ordering::Release);
262 
263         // Now wake any waiters in the wake list.
264         for w in wake_list {
265             w.wake();
266         }
267     }
268 
269     /// Notify all threads currently waiting on the `Condvar`.
270     ///
271     /// All threads currently waiting on the `Condvar` will be woken up from their call to `wait`.
272     ///
273     /// Unlike more traditional condition variable interfaces, this method requires a reference to
274     /// the `Mutex` associated with this `Condvar`. This is because it is inherently racy to call
275     /// `notify_one` or `notify_all` without first acquiring the `Mutex` lock. Additionally, taking
276     /// a reference to the `Mutex` here allows us to make some optimizations that can improve
277     /// performance by reducing unnecessary wakeups.
notify_all(&self)278     pub fn notify_all(&self) {
279         let mut oldstate = self.state.load(Ordering::Relaxed);
280         if (oldstate & HAS_WAITERS) == 0 {
281             // No waiters.
282             return;
283         }
284 
285         while (oldstate & SPINLOCK) != 0
286             || self
287                 .state
288                 .compare_exchange_weak(
289                     oldstate,
290                     oldstate | SPINLOCK,
291                     Ordering::Acquire,
292                     Ordering::Relaxed,
293                 )
294                 .is_err()
295         {
296             hint::spin_loop();
297             oldstate = self.state.load(Ordering::Relaxed);
298         }
299 
300         // Safe because the spin lock guarantees exclusive access to `self.waiters`.
301         let wake_list = unsafe { (*self.waiters.get()).take() };
302 
303         // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe
304         // because we the spin lock guarantees exclusive access.
305         unsafe { *self.mu.get() = 0 };
306 
307         // Mark any waiters left as no longer waiting for the Condvar.
308         for w in &wake_list {
309             w.set_waiting_for(WaitingFor::None);
310         }
311 
312         // Release the spin lock.  We can clear all bits in the state since we took all the waiters.
313         self.state.store(0, Ordering::Release);
314 
315         // Now wake any waiters in the wake list.
316         for w in wake_list {
317             w.wake();
318         }
319     }
320 
cancel_waiter(&self, waiter: &Waiter, wake_next: bool)321     fn cancel_waiter(&self, waiter: &Waiter, wake_next: bool) {
322         let mut oldstate = self.state.load(Ordering::Relaxed);
323         while oldstate & SPINLOCK != 0
324             || self
325                 .state
326                 .compare_exchange_weak(
327                     oldstate,
328                     oldstate | SPINLOCK,
329                     Ordering::Acquire,
330                     Ordering::Relaxed,
331                 )
332                 .is_err()
333         {
334             hint::spin_loop();
335             oldstate = self.state.load(Ordering::Relaxed);
336         }
337 
338         // Safe because the spin lock provides exclusive access and the reference does not escape
339         // this function.
340         let waiters = unsafe { &mut *self.waiters.get() };
341 
342         let waiting_for = waiter.is_waiting_for();
343         // Don't drop the old waiter now as we're still holding the spin lock.
344         let old_waiter = if waiter.is_linked() && waiting_for == WaitingFor::Condvar {
345             // Safe because we know that the waiter is still linked and is waiting for the Condvar,
346             // which guarantees that it is still in `self.waiters`.
347             let mut cursor = unsafe { waiters.cursor_mut_from_ptr(waiter as *const Waiter) };
348             cursor.remove()
349         } else {
350             None
351         };
352 
353         let wake_list = if wake_next || waiting_for == WaitingFor::None {
354             // Either the waiter was already woken or it's been removed from the condvar's waiter
355             // list and is going to be woken. Either way, we need to wake up another thread.
356             get_wake_list(waiters)
357         } else {
358             WaiterList::new(WaiterAdapter::new())
359         };
360 
361         let set_on_release = if waiters.is_empty() {
362             // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe
363             // because we the spin lock guarantees exclusive access.
364             unsafe { *self.mu.get() = 0 };
365 
366             0
367         } else {
368             HAS_WAITERS
369         };
370 
371         self.state.store(set_on_release, Ordering::Release);
372 
373         // Now wake any waiters still left in the wake list.
374         for w in wake_list {
375             w.wake();
376         }
377 
378         mem::drop(old_waiter);
379     }
380 }
381 
382 unsafe impl Send for Condvar {}
383 unsafe impl Sync for Condvar {}
384 
385 impl Default for Condvar {
default() -> Self386     fn default() -> Self {
387         Self::new()
388     }
389 }
390 
391 // Scan `waiters` and return all waiters that should be woken up.
392 //
393 // If the first waiter is trying to acquire a shared lock, then all waiters in the list that are
394 // waiting for a shared lock are also woken up. In addition one writer is woken up, if possible.
395 //
396 // If the first waiter is trying to acquire an exclusive lock, then only that waiter is returned and
397 // the rest of the list is not scanned.
get_wake_list(waiters: &mut WaiterList) -> WaiterList398 fn get_wake_list(waiters: &mut WaiterList) -> WaiterList {
399     let mut to_wake = WaiterList::new(WaiterAdapter::new());
400     let mut cursor = waiters.front_mut();
401 
402     let mut waking_readers = false;
403     let mut all_readers = true;
404     while let Some(w) = cursor.get() {
405         match w.kind() {
406             WaiterKind::Exclusive if !waking_readers => {
407                 // This is the first waiter and it's a writer. No need to check the other waiters.
408                 // Also mark the waiter as having been removed from the Condvar's waiter list.
409                 let waiter = cursor.remove().unwrap();
410                 waiter.set_waiting_for(WaitingFor::None);
411                 to_wake.push_back(waiter);
412                 break;
413             }
414 
415             WaiterKind::Shared => {
416                 // This is a reader and the first waiter in the list was not a writer so wake up all
417                 // the readers in the wait list.
418                 let waiter = cursor.remove().unwrap();
419                 waiter.set_waiting_for(WaitingFor::None);
420                 to_wake.push_back(waiter);
421                 waking_readers = true;
422             }
423 
424             WaiterKind::Exclusive => {
425                 debug_assert!(waking_readers);
426                 if all_readers {
427                     // We are waking readers but we need to ensure that at least one writer is woken
428                     // up. Since we haven't yet woken up a writer, wake up this one.
429                     let waiter = cursor.remove().unwrap();
430                     waiter.set_waiting_for(WaitingFor::None);
431                     to_wake.push_back(waiter);
432                     all_readers = false;
433                 } else {
434                     // We are waking readers and have already woken one writer. Skip this one.
435                     cursor.move_next();
436                 }
437             }
438         }
439     }
440 
441     to_wake
442 }
443 
cancel_waiter(cv: usize, waiter: &Waiter, wake_next: bool)444 fn cancel_waiter(cv: usize, waiter: &Waiter, wake_next: bool) {
445     let condvar = cv as *const Condvar;
446 
447     // Safe because the thread that owns the waiter being canceled must also own a reference to the
448     // Condvar, which guarantees that this pointer is valid.
449     unsafe { (*condvar).cancel_waiter(waiter, wake_next) }
450 }
451 
452 #[cfg(test)]
453 mod test {
454     use super::*;
455 
456     use std::{
457         future::Future,
458         mem, ptr,
459         rc::Rc,
460         sync::{
461             mpsc::{channel, Sender},
462             Arc,
463         },
464         task::{Context, Poll},
465         thread::{
466             JoinHandle, {self},
467         },
468         time::Duration,
469     };
470 
471     use futures::{
472         channel::oneshot,
473         select,
474         task::{waker_ref, ArcWake},
475         FutureExt,
476     };
477     use futures_executor::{LocalPool, LocalSpawner, ThreadPool};
478     use futures_util::task::LocalSpawnExt;
479 
480     use super::super::super::{block_on, sync::Mutex};
481 
482     // Dummy waker used when we want to manually drive futures.
483     struct TestWaker;
484     impl ArcWake for TestWaker {
wake_by_ref(_arc_self: &Arc<Self>)485         fn wake_by_ref(_arc_self: &Arc<Self>) {}
486     }
487 
488     #[test]
smoke()489     fn smoke() {
490         let cv = Condvar::new();
491         cv.notify_one();
492         cv.notify_all();
493     }
494 
495     #[test]
notify_one()496     fn notify_one() {
497         let mu = Arc::new(Mutex::new(()));
498         let cv = Arc::new(Condvar::new());
499 
500         let mu2 = mu.clone();
501         let cv2 = cv.clone();
502 
503         let guard = block_on(mu.lock());
504         thread::spawn(move || {
505             let _g = block_on(mu2.lock());
506             cv2.notify_one();
507         });
508 
509         let guard = block_on(cv.wait(guard));
510         mem::drop(guard);
511     }
512 
513     #[test]
multi_mutex()514     fn multi_mutex() {
515         const NUM_THREADS: usize = 5;
516 
517         let mu = Arc::new(Mutex::new(false));
518         let cv = Arc::new(Condvar::new());
519 
520         let mut threads = Vec::with_capacity(NUM_THREADS);
521         for _ in 0..NUM_THREADS {
522             let mu = mu.clone();
523             let cv = cv.clone();
524 
525             threads.push(thread::spawn(move || {
526                 let mut ready = block_on(mu.lock());
527                 while !*ready {
528                     ready = block_on(cv.wait(ready));
529                 }
530             }));
531         }
532 
533         let mut g = block_on(mu.lock());
534         *g = true;
535         mem::drop(g);
536         cv.notify_all();
537 
538         threads
539             .into_iter()
540             .try_for_each(JoinHandle::join)
541             .expect("Failed to join threads");
542 
543         // Now use the Condvar with a different mutex.
544         let alt_mu = Arc::new(Mutex::new(None));
545         let alt_mu2 = alt_mu.clone();
546         let cv2 = cv.clone();
547         let handle = thread::spawn(move || {
548             let mut g = block_on(alt_mu2.lock());
549             while g.is_none() {
550                 g = block_on(cv2.wait(g));
551             }
552         });
553 
554         let mut alt_g = block_on(alt_mu.lock());
555         *alt_g = Some(());
556         mem::drop(alt_g);
557         cv.notify_all();
558 
559         handle
560             .join()
561             .expect("Failed to join thread alternate mutex");
562     }
563 
564     #[test]
notify_one_single_thread_async()565     fn notify_one_single_thread_async() {
566         async fn notify(mu: Rc<Mutex<()>>, cv: Rc<Condvar>) {
567             let _g = mu.lock().await;
568             cv.notify_one();
569         }
570 
571         async fn wait(mu: Rc<Mutex<()>>, cv: Rc<Condvar>, spawner: LocalSpawner) {
572             let mu2 = Rc::clone(&mu);
573             let cv2 = Rc::clone(&cv);
574 
575             let g = mu.lock().await;
576             // Has to be spawned _after_ acquiring the lock to prevent a race
577             // where the notify happens before the waiter has acquired the lock.
578             spawner
579                 .spawn_local(notify(mu2, cv2))
580                 .expect("Failed to spawn `notify` task");
581             let _g = cv.wait(g).await;
582         }
583 
584         let mut ex = LocalPool::new();
585         let spawner = ex.spawner();
586 
587         let mu = Rc::new(Mutex::new(()));
588         let cv = Rc::new(Condvar::new());
589 
590         spawner
591             .spawn_local(wait(mu, cv, spawner.clone()))
592             .expect("Failed to spawn `wait` task");
593 
594         ex.run();
595     }
596 
597     #[test]
notify_one_multi_thread_async()598     fn notify_one_multi_thread_async() {
599         async fn notify(mu: Arc<Mutex<()>>, cv: Arc<Condvar>) {
600             let _g = mu.lock().await;
601             cv.notify_one();
602         }
603 
604         async fn wait(mu: Arc<Mutex<()>>, cv: Arc<Condvar>, tx: Sender<()>, pool: ThreadPool) {
605             let mu2 = Arc::clone(&mu);
606             let cv2 = Arc::clone(&cv);
607 
608             let g = mu.lock().await;
609             // Has to be spawned _after_ acquiring the lock to prevent a race
610             // where the notify happens before the waiter has acquired the lock.
611             pool.spawn_ok(notify(mu2, cv2));
612             let _g = cv.wait(g).await;
613 
614             tx.send(()).expect("Failed to send completion notification");
615         }
616 
617         let ex = ThreadPool::new().expect("Failed to create ThreadPool");
618 
619         let mu = Arc::new(Mutex::new(()));
620         let cv = Arc::new(Condvar::new());
621 
622         let (tx, rx) = channel();
623         ex.spawn_ok(wait(mu, cv, tx, ex.clone()));
624 
625         rx.recv_timeout(Duration::from_secs(5))
626             .expect("Failed to receive completion notification");
627     }
628 
629     #[test]
notify_one_with_cancel()630     fn notify_one_with_cancel() {
631         const TASKS: usize = 17;
632         const OBSERVERS: usize = 7;
633         const ITERATIONS: usize = 103;
634 
635         async fn observe(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) {
636             let mut count = mu.read_lock().await;
637             while *count == 0 {
638                 count = cv.wait_read(count).await;
639             }
640             let _ = unsafe { ptr::read_volatile(&*count as *const usize) };
641         }
642 
643         async fn decrement(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) {
644             let mut count = mu.lock().await;
645             while *count == 0 {
646                 count = cv.wait(count).await;
647             }
648             *count -= 1;
649         }
650 
651         async fn increment(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>, done: Sender<()>) {
652             for _ in 0..TASKS * OBSERVERS * ITERATIONS {
653                 *mu.lock().await += 1;
654                 cv.notify_one();
655             }
656 
657             done.send(()).expect("Failed to send completion message");
658         }
659 
660         async fn observe_either(
661             mu: Arc<Mutex<usize>>,
662             cv: Arc<Condvar>,
663             alt_mu: Arc<Mutex<usize>>,
664             alt_cv: Arc<Condvar>,
665             done: Sender<()>,
666         ) {
667             for _ in 0..ITERATIONS {
668                 select! {
669                     () = observe(&mu, &cv).fuse() => {},
670                     () = observe(&alt_mu, &alt_cv).fuse() => {},
671                 }
672             }
673 
674             done.send(()).expect("Failed to send completion message");
675         }
676 
677         async fn decrement_either(
678             mu: Arc<Mutex<usize>>,
679             cv: Arc<Condvar>,
680             alt_mu: Arc<Mutex<usize>>,
681             alt_cv: Arc<Condvar>,
682             done: Sender<()>,
683         ) {
684             for _ in 0..ITERATIONS {
685                 select! {
686                     () = decrement(&mu, &cv).fuse() => {},
687                     () = decrement(&alt_mu, &alt_cv).fuse() => {},
688                 }
689             }
690 
691             done.send(()).expect("Failed to send completion message");
692         }
693 
694         let ex = ThreadPool::new().expect("Failed to create ThreadPool");
695 
696         let mu = Arc::new(Mutex::new(0usize));
697         let alt_mu = Arc::new(Mutex::new(0usize));
698 
699         let cv = Arc::new(Condvar::new());
700         let alt_cv = Arc::new(Condvar::new());
701 
702         let (tx, rx) = channel();
703         for _ in 0..TASKS {
704             ex.spawn_ok(decrement_either(
705                 Arc::clone(&mu),
706                 Arc::clone(&cv),
707                 Arc::clone(&alt_mu),
708                 Arc::clone(&alt_cv),
709                 tx.clone(),
710             ));
711         }
712 
713         for _ in 0..OBSERVERS {
714             ex.spawn_ok(observe_either(
715                 Arc::clone(&mu),
716                 Arc::clone(&cv),
717                 Arc::clone(&alt_mu),
718                 Arc::clone(&alt_cv),
719                 tx.clone(),
720             ));
721         }
722 
723         ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone()));
724         ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx));
725 
726         for _ in 0..TASKS + OBSERVERS + 2 {
727             if let Err(e) = rx.recv_timeout(Duration::from_secs(20)) {
728                 panic!("Error while waiting for threads to complete: {}", e);
729             }
730         }
731 
732         assert_eq!(
733             *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()),
734             (TASKS * OBSERVERS * ITERATIONS * 2) - (TASKS * ITERATIONS)
735         );
736         assert_eq!(cv.state.load(Ordering::Relaxed), 0);
737         assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0);
738     }
739 
740     #[test]
notify_all_with_cancel()741     fn notify_all_with_cancel() {
742         const TASKS: usize = 17;
743         const ITERATIONS: usize = 103;
744 
745         async fn decrement(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) {
746             let mut count = mu.lock().await;
747             while *count == 0 {
748                 count = cv.wait(count).await;
749             }
750             *count -= 1;
751         }
752 
753         async fn increment(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>, done: Sender<()>) {
754             for _ in 0..TASKS * ITERATIONS {
755                 *mu.lock().await += 1;
756                 cv.notify_all();
757             }
758 
759             done.send(()).expect("Failed to send completion message");
760         }
761 
762         async fn decrement_either(
763             mu: Arc<Mutex<usize>>,
764             cv: Arc<Condvar>,
765             alt_mu: Arc<Mutex<usize>>,
766             alt_cv: Arc<Condvar>,
767             done: Sender<()>,
768         ) {
769             for _ in 0..ITERATIONS {
770                 select! {
771                     () = decrement(&mu, &cv).fuse() => {},
772                     () = decrement(&alt_mu, &alt_cv).fuse() => {},
773                 }
774             }
775 
776             done.send(()).expect("Failed to send completion message");
777         }
778 
779         let ex = ThreadPool::new().expect("Failed to create ThreadPool");
780 
781         let mu = Arc::new(Mutex::new(0usize));
782         let alt_mu = Arc::new(Mutex::new(0usize));
783 
784         let cv = Arc::new(Condvar::new());
785         let alt_cv = Arc::new(Condvar::new());
786 
787         let (tx, rx) = channel();
788         for _ in 0..TASKS {
789             ex.spawn_ok(decrement_either(
790                 Arc::clone(&mu),
791                 Arc::clone(&cv),
792                 Arc::clone(&alt_mu),
793                 Arc::clone(&alt_cv),
794                 tx.clone(),
795             ));
796         }
797 
798         ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone()));
799         ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx));
800 
801         for _ in 0..TASKS + 2 {
802             if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) {
803                 panic!("Error while waiting for threads to complete: {}", e);
804             }
805         }
806 
807         assert_eq!(
808             *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()),
809             TASKS * ITERATIONS
810         );
811         assert_eq!(cv.state.load(Ordering::Relaxed), 0);
812         assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0);
813     }
814     #[test]
notify_all()815     fn notify_all() {
816         const THREADS: usize = 13;
817 
818         let mu = Arc::new(Mutex::new(0));
819         let cv = Arc::new(Condvar::new());
820         let (tx, rx) = channel();
821 
822         let mut threads = Vec::with_capacity(THREADS);
823         for _ in 0..THREADS {
824             let mu2 = mu.clone();
825             let cv2 = cv.clone();
826             let tx2 = tx.clone();
827 
828             threads.push(thread::spawn(move || {
829                 let mut count = block_on(mu2.lock());
830                 *count += 1;
831                 if *count == THREADS {
832                     tx2.send(()).unwrap();
833                 }
834 
835                 while *count != 0 {
836                     count = block_on(cv2.wait(count));
837                 }
838             }));
839         }
840 
841         mem::drop(tx);
842 
843         // Wait till all threads have started.
844         rx.recv_timeout(Duration::from_secs(5)).unwrap();
845 
846         let mut count = block_on(mu.lock());
847         *count = 0;
848         mem::drop(count);
849         cv.notify_all();
850 
851         for t in threads {
852             t.join().unwrap();
853         }
854     }
855 
856     #[test]
notify_all_single_thread_async()857     fn notify_all_single_thread_async() {
858         const TASKS: usize = 13;
859 
860         async fn reset(mu: Rc<Mutex<usize>>, cv: Rc<Condvar>) {
861             let mut count = mu.lock().await;
862             *count = 0;
863             cv.notify_all();
864         }
865 
866         async fn watcher(mu: Rc<Mutex<usize>>, cv: Rc<Condvar>, spawner: LocalSpawner) {
867             let mut count = mu.lock().await;
868             *count += 1;
869             if *count == TASKS {
870                 spawner
871                     .spawn_local(reset(mu.clone(), cv.clone()))
872                     .expect("Failed to spawn reset task");
873             }
874 
875             while *count != 0 {
876                 count = cv.wait(count).await;
877             }
878         }
879 
880         let mut ex = LocalPool::new();
881         let spawner = ex.spawner();
882 
883         let mu = Rc::new(Mutex::new(0));
884         let cv = Rc::new(Condvar::new());
885 
886         for _ in 0..TASKS {
887             spawner
888                 .spawn_local(watcher(mu.clone(), cv.clone(), spawner.clone()))
889                 .expect("Failed to spawn watcher task");
890         }
891 
892         ex.run();
893     }
894 
895     #[test]
notify_all_multi_thread_async()896     fn notify_all_multi_thread_async() {
897         const TASKS: usize = 13;
898 
899         async fn reset(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
900             let mut count = mu.lock().await;
901             *count = 0;
902             cv.notify_all();
903         }
904 
905         async fn watcher(
906             mu: Arc<Mutex<usize>>,
907             cv: Arc<Condvar>,
908             pool: ThreadPool,
909             tx: Sender<()>,
910         ) {
911             let mut count = mu.lock().await;
912             *count += 1;
913             if *count == TASKS {
914                 pool.spawn_ok(reset(mu.clone(), cv.clone()));
915             }
916 
917             while *count != 0 {
918                 count = cv.wait(count).await;
919             }
920 
921             tx.send(()).expect("Failed to send completion notification");
922         }
923 
924         let pool = ThreadPool::new().expect("Failed to create ThreadPool");
925 
926         let mu = Arc::new(Mutex::new(0));
927         let cv = Arc::new(Condvar::new());
928 
929         let (tx, rx) = channel();
930         for _ in 0..TASKS {
931             pool.spawn_ok(watcher(mu.clone(), cv.clone(), pool.clone(), tx.clone()));
932         }
933 
934         for _ in 0..TASKS {
935             rx.recv_timeout(Duration::from_secs(5))
936                 .expect("Failed to receive completion notification");
937         }
938     }
939 
940     #[test]
wake_all_readers()941     fn wake_all_readers() {
942         async fn read(mu: Arc<Mutex<bool>>, cv: Arc<Condvar>) {
943             let mut ready = mu.read_lock().await;
944             while !*ready {
945                 ready = cv.wait_read(ready).await;
946             }
947         }
948 
949         let mu = Arc::new(Mutex::new(false));
950         let cv = Arc::new(Condvar::new());
951         let mut readers = [
952             Box::pin(read(mu.clone(), cv.clone())),
953             Box::pin(read(mu.clone(), cv.clone())),
954             Box::pin(read(mu.clone(), cv.clone())),
955             Box::pin(read(mu.clone(), cv.clone())),
956         ];
957 
958         let arc_waker = Arc::new(TestWaker);
959         let waker = waker_ref(&arc_waker);
960         let mut cx = Context::from_waker(&waker);
961 
962         // First have all the readers wait on the Condvar.
963         for r in &mut readers {
964             if let Poll::Ready(()) = r.as_mut().poll(&mut cx) {
965                 panic!("reader unexpectedly ready");
966             }
967         }
968 
969         assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);
970 
971         // Now make the condition true and notify the condvar. Even though we will call notify_one,
972         // all the readers should be woken up.
973         *block_on(mu.lock()) = true;
974         cv.notify_one();
975 
976         assert_eq!(cv.state.load(Ordering::Relaxed), 0);
977 
978         // All readers should now be able to complete.
979         for r in &mut readers {
980             if r.as_mut().poll(&mut cx).is_pending() {
981                 panic!("reader unable to complete");
982             }
983         }
984     }
985 
986     #[test]
cancel_before_notify()987     fn cancel_before_notify() {
988         async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
989             let mut count = mu.lock().await;
990 
991             while *count == 0 {
992                 count = cv.wait(count).await;
993             }
994 
995             *count -= 1;
996         }
997 
998         let mu = Arc::new(Mutex::new(0));
999         let cv = Arc::new(Condvar::new());
1000 
1001         let arc_waker = Arc::new(TestWaker);
1002         let waker = waker_ref(&arc_waker);
1003         let mut cx = Context::from_waker(&waker);
1004 
1005         let mut fut1 = Box::pin(dec(mu.clone(), cv.clone()));
1006         let mut fut2 = Box::pin(dec(mu.clone(), cv.clone()));
1007 
1008         if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) {
1009             panic!("future unexpectedly ready");
1010         }
1011         if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
1012             panic!("future unexpectedly ready");
1013         }
1014         assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);
1015 
1016         *block_on(mu.lock()) = 2;
1017         // Drop fut1 before notifying the cv.
1018         mem::drop(fut1);
1019         cv.notify_one();
1020 
1021         // fut2 should now be ready to complete.
1022         assert_eq!(cv.state.load(Ordering::Relaxed), 0);
1023 
1024         if fut2.as_mut().poll(&mut cx).is_pending() {
1025             panic!("future unable to complete");
1026         }
1027 
1028         assert_eq!(*block_on(mu.lock()), 1);
1029     }
1030 
1031     #[test]
cancel_after_notify_one()1032     fn cancel_after_notify_one() {
1033         async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
1034             let mut count = mu.lock().await;
1035 
1036             while *count == 0 {
1037                 count = cv.wait(count).await;
1038             }
1039 
1040             *count -= 1;
1041         }
1042 
1043         let mu = Arc::new(Mutex::new(0));
1044         let cv = Arc::new(Condvar::new());
1045 
1046         let arc_waker = Arc::new(TestWaker);
1047         let waker = waker_ref(&arc_waker);
1048         let mut cx = Context::from_waker(&waker);
1049 
1050         let mut fut1 = Box::pin(dec(mu.clone(), cv.clone()));
1051         let mut fut2 = Box::pin(dec(mu.clone(), cv.clone()));
1052 
1053         if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) {
1054             panic!("future unexpectedly ready");
1055         }
1056         if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
1057             panic!("future unexpectedly ready");
1058         }
1059         assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);
1060 
1061         *block_on(mu.lock()) = 2;
1062         cv.notify_one();
1063 
1064         // fut1 should now be ready to complete. Drop it before polling. This should wake up fut2.
1065         mem::drop(fut1);
1066         assert_eq!(cv.state.load(Ordering::Relaxed), 0);
1067 
1068         if fut2.as_mut().poll(&mut cx).is_pending() {
1069             panic!("future unable to complete");
1070         }
1071 
1072         assert_eq!(*block_on(mu.lock()), 1);
1073     }
1074 
1075     #[test]
cancel_after_notify_all()1076     fn cancel_after_notify_all() {
1077         async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
1078             let mut count = mu.lock().await;
1079 
1080             while *count == 0 {
1081                 count = cv.wait(count).await;
1082             }
1083 
1084             *count -= 1;
1085         }
1086 
1087         let mu = Arc::new(Mutex::new(0));
1088         let cv = Arc::new(Condvar::new());
1089 
1090         let arc_waker = Arc::new(TestWaker);
1091         let waker = waker_ref(&arc_waker);
1092         let mut cx = Context::from_waker(&waker);
1093 
1094         let mut fut1 = Box::pin(dec(mu.clone(), cv.clone()));
1095         let mut fut2 = Box::pin(dec(mu.clone(), cv.clone()));
1096 
1097         if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) {
1098             panic!("future unexpectedly ready");
1099         }
1100         if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
1101             panic!("future unexpectedly ready");
1102         }
1103         assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);
1104 
1105         let mut count = block_on(mu.lock());
1106         *count = 2;
1107 
1108         // Notify the cv while holding the lock. This should wake up both waiters.
1109         cv.notify_all();
1110         assert_eq!(cv.state.load(Ordering::Relaxed), 0);
1111 
1112         mem::drop(count);
1113 
1114         mem::drop(fut1);
1115 
1116         if fut2.as_mut().poll(&mut cx).is_pending() {
1117             panic!("future unable to complete");
1118         }
1119 
1120         assert_eq!(*block_on(mu.lock()), 1);
1121     }
1122 
1123     #[test]
timed_wait()1124     fn timed_wait() {
1125         async fn wait_deadline(
1126             mu: Arc<Mutex<usize>>,
1127             cv: Arc<Condvar>,
1128             timeout: oneshot::Receiver<()>,
1129         ) {
1130             let mut count = mu.lock().await;
1131 
1132             if *count == 0 {
1133                 let mut rx = timeout.fuse();
1134 
1135                 while *count == 0 {
1136                     select! {
1137                         res = rx => {
1138                             if let Err(e) = res {
1139                                 panic!("Error while receiving timeout notification: {}", e);
1140                             }
1141 
1142                             return;
1143                         },
1144                         c = cv.wait(count).fuse() => count = c,
1145                     }
1146                 }
1147             }
1148 
1149             *count += 1;
1150         }
1151 
1152         let mu = Arc::new(Mutex::new(0));
1153         let cv = Arc::new(Condvar::new());
1154 
1155         let arc_waker = Arc::new(TestWaker);
1156         let waker = waker_ref(&arc_waker);
1157         let mut cx = Context::from_waker(&waker);
1158 
1159         let (tx, rx) = oneshot::channel();
1160         let mut wait = Box::pin(wait_deadline(mu.clone(), cv.clone(), rx));
1161 
1162         if let Poll::Ready(()) = wait.as_mut().poll(&mut cx) {
1163             panic!("wait_deadline unexpectedly ready");
1164         }
1165 
1166         assert_eq!(cv.state.load(Ordering::Relaxed), HAS_WAITERS);
1167 
1168         // Signal the channel, which should cancel the wait.
1169         tx.send(()).expect("Failed to send wakeup");
1170 
1171         // Wait for the timer to run out.
1172         if wait.as_mut().poll(&mut cx).is_pending() {
1173             panic!("wait_deadline unable to complete in time");
1174         }
1175 
1176         assert_eq!(cv.state.load(Ordering::Relaxed), 0);
1177         assert_eq!(*block_on(mu.lock()), 0);
1178     }
1179 }
1180