// Copyright 2020 The Chromium OS Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. use std::cell::UnsafeCell; use std::mem; use std::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering}; use std::sync::Arc; use crate::sync::mu::{MutexGuard, MutexReadGuard, RawMutex}; use crate::sync::waiter::{Kind as WaiterKind, Waiter, WaiterAdapter, WaiterList, WaitingFor}; const SPINLOCK: usize = 1 << 0; const HAS_WAITERS: usize = 1 << 1; /// A primitive to wait for an event to occur without consuming CPU time. /// /// Condition variables are used in combination with a `Mutex` when a thread wants to wait for some /// condition to become true. The condition must always be verified while holding the `Mutex` lock. /// It is an error to use a `Condvar` with more than one `Mutex` while there are threads waiting on /// the `Condvar`. /// /// # Examples /// /// ```edition2018 /// use std::sync::Arc; /// use std::thread; /// use std::sync::mpsc::channel; /// /// use libchromeos::sync::{block_on, Condvar, Mutex}; /// /// const N: usize = 13; /// /// // Spawn a few threads to increment a shared variable (non-atomically), and /// // let all threads waiting on the Condvar know once the increments are done. /// let data = Arc::new(Mutex::new(0)); /// let cv = Arc::new(Condvar::new()); /// /// for _ in 0..N { /// let (data, cv) = (data.clone(), cv.clone()); /// thread::spawn(move || { /// let mut data = block_on(data.lock()); /// *data += 1; /// if *data == N { /// cv.notify_all(); /// } /// }); /// } /// /// let mut val = block_on(data.lock()); /// while *val != N { /// val = block_on(cv.wait(val)); /// } /// ``` pub struct Condvar { state: AtomicUsize, waiters: UnsafeCell, mu: UnsafeCell, } impl Condvar { /// Creates a new condition variable ready to be waited on and notified. pub fn new() -> Condvar { Condvar { state: AtomicUsize::new(0), waiters: UnsafeCell::new(WaiterList::new(WaiterAdapter::new())), mu: UnsafeCell::new(0), } } /// Block the current thread until this `Condvar` is notified by another thread. /// /// This method will atomically unlock the `Mutex` held by `guard` and then block the current /// thread. Any call to `notify_one` or `notify_all` after the `Mutex` is unlocked may wake up /// the thread. /// /// To allow for more efficient scheduling, this call may return even when the programmer /// doesn't expect the thread to be woken. Therefore, calls to `wait()` should be used inside a /// loop that checks the predicate before continuing. /// /// Callers that are not in an async context may wish to use the `block_on` method to block the /// thread until the `Condvar` is notified. /// /// # Panics /// /// This method will panic if used with more than one `Mutex` at the same time. /// /// # Examples /// /// ``` /// # use std::sync::Arc; /// # use std::thread; /// /// # use libchromeos::sync::{block_on, Condvar, Mutex}; /// /// # let mu = Arc::new(Mutex::new(false)); /// # let cv = Arc::new(Condvar::new()); /// # let (mu2, cv2) = (mu.clone(), cv.clone()); /// /// # let t = thread::spawn(move || { /// # *block_on(mu2.lock()) = true; /// # cv2.notify_all(); /// # }); /// /// let mut ready = block_on(mu.lock()); /// while !*ready { /// ready = block_on(cv.wait(ready)); /// } /// /// # t.join().expect("failed to join thread"); /// ``` // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code // that doesn't compile. #[allow(clippy::needless_lifetimes)] pub async fn wait<'g, T>(&self, guard: MutexGuard<'g, T>) -> MutexGuard<'g, T> { let waiter = Arc::new(Waiter::new( WaiterKind::Exclusive, cancel_waiter, self as *const Condvar as usize, WaitingFor::Condvar, )); self.add_waiter(waiter.clone(), guard.as_raw_mutex()); // Get a reference to the mutex and then drop the lock. let mu = guard.into_inner(); // Wait to be woken up. waiter.wait().await; // Now re-acquire the lock. mu.lock_from_cv().await } /// Like `wait()` but takes and returns a `MutexReadGuard` instead. // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code // that doesn't compile. #[allow(clippy::needless_lifetimes)] pub async fn wait_read<'g, T>(&self, guard: MutexReadGuard<'g, T>) -> MutexReadGuard<'g, T> { let waiter = Arc::new(Waiter::new( WaiterKind::Shared, cancel_waiter, self as *const Condvar as usize, WaitingFor::Condvar, )); self.add_waiter(waiter.clone(), guard.as_raw_mutex()); // Get a reference to the mutex and then drop the lock. let mu = guard.into_inner(); // Wait to be woken up. waiter.wait().await; // Now re-acquire the lock. mu.read_lock_from_cv().await } fn add_waiter(&self, waiter: Arc, raw_mutex: &RawMutex) { // Acquire the spin lock. let mut oldstate = self.state.load(Ordering::Relaxed); while (oldstate & SPINLOCK) != 0 || self.state.compare_and_swap( oldstate, oldstate | SPINLOCK | HAS_WAITERS, Ordering::Acquire, ) != oldstate { spin_loop_hint(); oldstate = self.state.load(Ordering::Relaxed); } // Safe because the spin lock guarantees exclusive access and the reference does not escape // this function. let mu = unsafe { &mut *self.mu.get() }; let muptr = raw_mutex as *const RawMutex as usize; match *mu { 0 => *mu = muptr, p if p == muptr => {} _ => panic!("Attempting to use Condvar with more than one Mutex at the same time"), } // Safe because the spin lock guarantees exclusive access. unsafe { (*self.waiters.get()).push_back(waiter) }; // Release the spin lock. Use a direct store here because no other thread can modify // `self.state` while we hold the spin lock. Keep the `HAS_WAITERS` bit that we set earlier // because we just added a waiter. self.state.store(HAS_WAITERS, Ordering::Release); } /// Notify at most one thread currently waiting on the `Condvar`. /// /// If there is a thread currently waiting on the `Condvar` it will be woken up from its call to /// `wait`. /// /// Unlike more traditional condition variable interfaces, this method requires a reference to /// the `Mutex` associated with this `Condvar`. This is because it is inherently racy to call /// `notify_one` or `notify_all` without first acquiring the `Mutex` lock. Additionally, taking /// a reference to the `Mutex` here allows us to make some optimizations that can improve /// performance by reducing unnecessary wakeups. pub fn notify_one(&self) { let mut oldstate = self.state.load(Ordering::Relaxed); if (oldstate & HAS_WAITERS) == 0 { // No waiters. return; } while (oldstate & SPINLOCK) != 0 || self .state .compare_and_swap(oldstate, oldstate | SPINLOCK, Ordering::Acquire) != oldstate { spin_loop_hint(); oldstate = self.state.load(Ordering::Relaxed); } // Safe because the spin lock guarantees exclusive access and the reference does not escape // this function. let waiters = unsafe { &mut *self.waiters.get() }; let (mut wake_list, all_readers) = get_wake_list(waiters); // Safe because the spin lock guarantees exclusive access. let muptr = unsafe { (*self.mu.get()) as *const RawMutex }; let newstate = if waiters.is_empty() { // Also clear the mutex associated with this Condvar since there are no longer any // waiters. Safe because the spin lock guarantees exclusive access. unsafe { *self.mu.get() = 0 }; // We are releasing the spin lock and there are no more waiters so we can clear all bits // in `self.state`. 0 } else { // There are still waiters so we need to keep the HAS_WAITERS bit in the state. HAS_WAITERS }; // Try to transfer waiters before releasing the spin lock. if !wake_list.is_empty() { // Safe because there was a waiter in the queue and the thread that owns the waiter also // owns a reference to the Mutex, guaranteeing that the pointer is valid. unsafe { (*muptr).transfer_waiters(&mut wake_list, all_readers) }; } // Release the spin lock. self.state.store(newstate, Ordering::Release); // Now wake any waiters still left in the wake list. for w in wake_list { w.wake(); } } /// Notify all threads currently waiting on the `Condvar`. /// /// All threads currently waiting on the `Condvar` will be woken up from their call to `wait`. /// /// Unlike more traditional condition variable interfaces, this method requires a reference to /// the `Mutex` associated with this `Condvar`. This is because it is inherently racy to call /// `notify_one` or `notify_all` without first acquiring the `Mutex` lock. Additionally, taking /// a reference to the `Mutex` here allows us to make some optimizations that can improve /// performance by reducing unnecessary wakeups. pub fn notify_all(&self) { let mut oldstate = self.state.load(Ordering::Relaxed); if (oldstate & HAS_WAITERS) == 0 { // No waiters. return; } while (oldstate & SPINLOCK) != 0 || self .state .compare_and_swap(oldstate, oldstate | SPINLOCK, Ordering::Acquire) != oldstate { spin_loop_hint(); oldstate = self.state.load(Ordering::Relaxed); } // Safe because the spin lock guarantees exclusive access to `self.waiters`. let mut wake_list = unsafe { (*self.waiters.get()).take() }; // Safe because the spin lock guarantees exclusive access. let muptr = unsafe { (*self.mu.get()) as *const RawMutex }; // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe // because we the spin lock guarantees exclusive access. unsafe { *self.mu.get() = 0 }; // Try to transfer waiters before releasing the spin lock. if !wake_list.is_empty() { // Safe because there was a waiter in the queue and the thread that owns the waiter also // owns a reference to the Mutex, guaranteeing that the pointer is valid. unsafe { (*muptr).transfer_waiters(&mut wake_list, false) }; } // Mark any waiters left as no longer waiting for the Condvar. for w in &wake_list { w.set_waiting_for(WaitingFor::None); } // Release the spin lock. We can clear all bits in the state since we took all the waiters. self.state.store(0, Ordering::Release); // Now wake any waiters still left in the wake list. for w in wake_list { w.wake(); } } fn cancel_waiter(&self, waiter: &Waiter, wake_next: bool) -> bool { let mut oldstate = self.state.load(Ordering::Relaxed); while oldstate & SPINLOCK != 0 || self .state .compare_exchange_weak( oldstate, oldstate | SPINLOCK, Ordering::Acquire, Ordering::Relaxed, ) .is_err() { spin_loop_hint(); oldstate = self.state.load(Ordering::Relaxed); } // Safe because the spin lock provides exclusive access and the reference does not escape // this function. let waiters = unsafe { &mut *self.waiters.get() }; let waiting_for = waiter.is_waiting_for(); if waiting_for == WaitingFor::Mutex { // The waiter was moved to the mutex's list. Retry the cancel. let set_on_release = if waiters.is_empty() { // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe // because we the spin lock guarantees exclusive access. unsafe { *self.mu.get() = 0 }; 0 } else { HAS_WAITERS }; self.state.store(set_on_release, Ordering::Release); false } else { // Don't drop the old waiter now as we're still holding the spin lock. let old_waiter = if waiter.is_linked() && waiting_for == WaitingFor::Condvar { // Safe because we know that the waiter is still linked and is waiting for the Condvar, // which guarantees that it is still in `self.waiters`. let mut cursor = unsafe { waiters.cursor_mut_from_ptr(waiter as *const Waiter) }; cursor.remove() } else { None }; let (mut wake_list, all_readers) = if wake_next || waiting_for == WaitingFor::None { // Either the waiter was already woken or it's been removed from the condvar's waiter // list and is going to be woken. Either way, we need to wake up another thread. get_wake_list(waiters) } else { (WaiterList::new(WaiterAdapter::new()), false) }; // Safe because the spin lock guarantees exclusive access. let muptr = unsafe { (*self.mu.get()) as *const RawMutex }; // Try to transfer waiters before releasing the spin lock. if !wake_list.is_empty() { // Safe because there was a waiter in the queue and the thread that owns the waiter also // owns a reference to the Mutex, guaranteeing that the pointer is valid. unsafe { (*muptr).transfer_waiters(&mut wake_list, all_readers) }; } let set_on_release = if waiters.is_empty() { // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe // because we the spin lock guarantees exclusive access. unsafe { *self.mu.get() = 0 }; 0 } else { HAS_WAITERS }; self.state.store(set_on_release, Ordering::Release); // Now wake any waiters still left in the wake list. for w in wake_list { w.wake(); } mem::drop(old_waiter); true } } } unsafe impl Send for Condvar {} unsafe impl Sync for Condvar {} impl Default for Condvar { fn default() -> Self { Self::new() } } // Scan `waiters` and return all waiters that should be woken up. If all waiters in the returned // wait list are readers then the returned bool will be true. // // If the first waiter is trying to acquire a shared lock, then all waiters in the list that are // waiting for a shared lock are also woken up. In addition one writer is woken up, if possible. // // If the first waiter is trying to acquire an exclusive lock, then only that waiter is returned and // the rest of the list is not scanned. fn get_wake_list(waiters: &mut WaiterList) -> (WaiterList, bool) { let mut to_wake = WaiterList::new(WaiterAdapter::new()); let mut cursor = waiters.front_mut(); let mut waking_readers = false; let mut all_readers = true; while let Some(w) = cursor.get() { match w.kind() { WaiterKind::Exclusive if !waking_readers => { // This is the first waiter and it's a writer. No need to check the other waiters. // Also mark the waiter as having been removed from the Condvar's waiter list. let waiter = cursor.remove().unwrap(); waiter.set_waiting_for(WaitingFor::None); to_wake.push_back(waiter); all_readers = false; break; } WaiterKind::Shared => { // This is a reader and the first waiter in the list was not a writer so wake up all // the readers in the wait list. let waiter = cursor.remove().unwrap(); waiter.set_waiting_for(WaitingFor::None); to_wake.push_back(waiter); waking_readers = true; } WaiterKind::Exclusive => { debug_assert!(waking_readers); if all_readers { // We are waking readers but we need to ensure that at least one writer is woken // up. Since we haven't yet woken up a writer, wake up this one. let waiter = cursor.remove().unwrap(); waiter.set_waiting_for(WaitingFor::None); to_wake.push_back(waiter); all_readers = false; } else { // We are waking readers and have already woken one writer. Skip this one. cursor.move_next(); } } } } (to_wake, all_readers) } fn cancel_waiter(cv: usize, waiter: &Waiter, wake_next: bool) -> bool { let condvar = cv as *const Condvar; // Safe because the thread that owns the waiter being canceled must also own a reference to the // Condvar, which guarantees that this pointer is valid. unsafe { (*condvar).cancel_waiter(waiter, wake_next) } } #[cfg(test)] mod test { use super::*; use std::future::Future; use std::mem; use std::ptr; use std::rc::Rc; use std::sync::mpsc::{channel, Sender}; use std::sync::Arc; use std::task::{Context, Poll}; use std::thread::{self, JoinHandle}; use std::time::Duration; use futures::channel::oneshot; use futures::task::{waker_ref, ArcWake}; use futures::{select, FutureExt}; use futures_executor::{LocalPool, LocalSpawner, ThreadPool}; use futures_util::task::LocalSpawnExt; use crate::sync::{block_on, Mutex}; // Dummy waker used when we want to manually drive futures. struct TestWaker; impl ArcWake for TestWaker { fn wake_by_ref(_arc_self: &Arc) {} } #[test] fn smoke() { let cv = Condvar::new(); cv.notify_one(); cv.notify_all(); } #[test] fn notify_one() { let mu = Arc::new(Mutex::new(())); let cv = Arc::new(Condvar::new()); let mu2 = mu.clone(); let cv2 = cv.clone(); let guard = block_on(mu.lock()); thread::spawn(move || { let _g = block_on(mu2.lock()); cv2.notify_one(); }); let guard = block_on(cv.wait(guard)); mem::drop(guard); } #[test] fn multi_mutex() { const NUM_THREADS: usize = 5; let mu = Arc::new(Mutex::new(false)); let cv = Arc::new(Condvar::new()); let mut threads = Vec::with_capacity(NUM_THREADS); for _ in 0..NUM_THREADS { let mu = mu.clone(); let cv = cv.clone(); threads.push(thread::spawn(move || { let mut ready = block_on(mu.lock()); while !*ready { ready = block_on(cv.wait(ready)); } })); } let mut g = block_on(mu.lock()); *g = true; mem::drop(g); cv.notify_all(); threads .into_iter() .map(JoinHandle::join) .collect::>() .expect("Failed to join threads"); // Now use the Condvar with a different mutex. let alt_mu = Arc::new(Mutex::new(None)); let alt_mu2 = alt_mu.clone(); let cv2 = cv.clone(); let handle = thread::spawn(move || { let mut g = block_on(alt_mu2.lock()); while g.is_none() { g = block_on(cv2.wait(g)); } }); let mut alt_g = block_on(alt_mu.lock()); *alt_g = Some(()); mem::drop(alt_g); cv.notify_all(); handle .join() .expect("Failed to join thread alternate mutex"); } #[test] fn notify_one_single_thread_async() { async fn notify(mu: Rc>, cv: Rc) { let _g = mu.lock().await; cv.notify_one(); } async fn wait(mu: Rc>, cv: Rc, spawner: LocalSpawner) { let mu2 = Rc::clone(&mu); let cv2 = Rc::clone(&cv); let g = mu.lock().await; // Has to be spawned _after_ acquiring the lock to prevent a race // where the notify happens before the waiter has acquired the lock. spawner .spawn_local(notify(mu2, cv2)) .expect("Failed to spawn `notify` task"); let _g = cv.wait(g).await; } let mut ex = LocalPool::new(); let spawner = ex.spawner(); let mu = Rc::new(Mutex::new(())); let cv = Rc::new(Condvar::new()); spawner .spawn_local(wait(mu, cv, spawner.clone())) .expect("Failed to spawn `wait` task"); ex.run(); } #[test] fn notify_one_multi_thread_async() { async fn notify(mu: Arc>, cv: Arc) { let _g = mu.lock().await; cv.notify_one(); } async fn wait(mu: Arc>, cv: Arc, tx: Sender<()>, pool: ThreadPool) { let mu2 = Arc::clone(&mu); let cv2 = Arc::clone(&cv); let g = mu.lock().await; // Has to be spawned _after_ acquiring the lock to prevent a race // where the notify happens before the waiter has acquired the lock. pool.spawn_ok(notify(mu2, cv2)); let _g = cv.wait(g).await; tx.send(()).expect("Failed to send completion notification"); } let ex = ThreadPool::new().expect("Failed to create ThreadPool"); let mu = Arc::new(Mutex::new(())); let cv = Arc::new(Condvar::new()); let (tx, rx) = channel(); ex.spawn_ok(wait(mu, cv, tx, ex.clone())); rx.recv_timeout(Duration::from_secs(5)) .expect("Failed to receive completion notification"); } #[test] fn notify_one_with_cancel() { const TASKS: usize = 17; const OBSERVERS: usize = 7; const ITERATIONS: usize = 103; async fn observe(mu: &Arc>, cv: &Arc) { let mut count = mu.read_lock().await; while *count == 0 { count = cv.wait_read(count).await; } let _ = unsafe { ptr::read_volatile(&*count as *const usize) }; } async fn decrement(mu: &Arc>, cv: &Arc) { let mut count = mu.lock().await; while *count == 0 { count = cv.wait(count).await; } *count -= 1; } async fn increment(mu: Arc>, cv: Arc, done: Sender<()>) { for _ in 0..TASKS * OBSERVERS * ITERATIONS { *mu.lock().await += 1; cv.notify_one(); } done.send(()).expect("Failed to send completion message"); } async fn observe_either( mu: Arc>, cv: Arc, alt_mu: Arc>, alt_cv: Arc, done: Sender<()>, ) { for _ in 0..ITERATIONS { select! { () = observe(&mu, &cv).fuse() => {}, () = observe(&alt_mu, &alt_cv).fuse() => {}, } } done.send(()).expect("Failed to send completion message"); } async fn decrement_either( mu: Arc>, cv: Arc, alt_mu: Arc>, alt_cv: Arc, done: Sender<()>, ) { for _ in 0..ITERATIONS { select! { () = decrement(&mu, &cv).fuse() => {}, () = decrement(&alt_mu, &alt_cv).fuse() => {}, } } done.send(()).expect("Failed to send completion message"); } let ex = ThreadPool::new().expect("Failed to create ThreadPool"); let mu = Arc::new(Mutex::new(0usize)); let alt_mu = Arc::new(Mutex::new(0usize)); let cv = Arc::new(Condvar::new()); let alt_cv = Arc::new(Condvar::new()); let (tx, rx) = channel(); for _ in 0..TASKS { ex.spawn_ok(decrement_either( Arc::clone(&mu), Arc::clone(&cv), Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx.clone(), )); } for _ in 0..OBSERVERS { ex.spawn_ok(observe_either( Arc::clone(&mu), Arc::clone(&cv), Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx.clone(), )); } ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone())); ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx)); for _ in 0..TASKS + OBSERVERS + 2 { if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) { panic!("Error while waiting for threads to complete: {}", e); } } assert_eq!( *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()), (TASKS * OBSERVERS * ITERATIONS * 2) - (TASKS * ITERATIONS) ); assert_eq!(cv.state.load(Ordering::Relaxed), 0); assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0); } #[test] fn notify_all_with_cancel() { const TASKS: usize = 17; const ITERATIONS: usize = 103; async fn decrement(mu: &Arc>, cv: &Arc) { let mut count = mu.lock().await; while *count == 0 { count = cv.wait(count).await; } *count -= 1; } async fn increment(mu: Arc>, cv: Arc, done: Sender<()>) { for _ in 0..TASKS * ITERATIONS { *mu.lock().await += 1; cv.notify_all(); } done.send(()).expect("Failed to send completion message"); } async fn decrement_either( mu: Arc>, cv: Arc, alt_mu: Arc>, alt_cv: Arc, done: Sender<()>, ) { for _ in 0..ITERATIONS { select! { () = decrement(&mu, &cv).fuse() => {}, () = decrement(&alt_mu, &alt_cv).fuse() => {}, } } done.send(()).expect("Failed to send completion message"); } let ex = ThreadPool::new().expect("Failed to create ThreadPool"); let mu = Arc::new(Mutex::new(0usize)); let alt_mu = Arc::new(Mutex::new(0usize)); let cv = Arc::new(Condvar::new()); let alt_cv = Arc::new(Condvar::new()); let (tx, rx) = channel(); for _ in 0..TASKS { ex.spawn_ok(decrement_either( Arc::clone(&mu), Arc::clone(&cv), Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx.clone(), )); } ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone())); ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx)); for _ in 0..TASKS + 2 { if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) { panic!("Error while waiting for threads to complete: {}", e); } } assert_eq!( *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()), TASKS * ITERATIONS ); assert_eq!(cv.state.load(Ordering::Relaxed), 0); assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0); } #[test] fn notify_all() { const THREADS: usize = 13; let mu = Arc::new(Mutex::new(0)); let cv = Arc::new(Condvar::new()); let (tx, rx) = channel(); let mut threads = Vec::with_capacity(THREADS); for _ in 0..THREADS { let mu2 = mu.clone(); let cv2 = cv.clone(); let tx2 = tx.clone(); threads.push(thread::spawn(move || { let mut count = block_on(mu2.lock()); *count += 1; if *count == THREADS { tx2.send(()).unwrap(); } while *count != 0 { count = block_on(cv2.wait(count)); } })); } mem::drop(tx); // Wait till all threads have started. rx.recv_timeout(Duration::from_secs(5)).unwrap(); let mut count = block_on(mu.lock()); *count = 0; mem::drop(count); cv.notify_all(); for t in threads { t.join().unwrap(); } } #[test] fn notify_all_single_thread_async() { const TASKS: usize = 13; async fn reset(mu: Rc>, cv: Rc) { let mut count = mu.lock().await; *count = 0; cv.notify_all(); } async fn watcher(mu: Rc>, cv: Rc, spawner: LocalSpawner) { let mut count = mu.lock().await; *count += 1; if *count == TASKS { spawner .spawn_local(reset(mu.clone(), cv.clone())) .expect("Failed to spawn reset task"); } while *count != 0 { count = cv.wait(count).await; } } let mut ex = LocalPool::new(); let spawner = ex.spawner(); let mu = Rc::new(Mutex::new(0)); let cv = Rc::new(Condvar::new()); for _ in 0..TASKS { spawner .spawn_local(watcher(mu.clone(), cv.clone(), spawner.clone())) .expect("Failed to spawn watcher task"); } ex.run(); } #[test] fn notify_all_multi_thread_async() { const TASKS: usize = 13; async fn reset(mu: Arc>, cv: Arc) { let mut count = mu.lock().await; *count = 0; cv.notify_all(); } async fn watcher( mu: Arc>, cv: Arc, pool: ThreadPool, tx: Sender<()>, ) { let mut count = mu.lock().await; *count += 1; if *count == TASKS { pool.spawn_ok(reset(mu.clone(), cv.clone())); } while *count != 0 { count = cv.wait(count).await; } tx.send(()).expect("Failed to send completion notification"); } let pool = ThreadPool::new().expect("Failed to create ThreadPool"); let mu = Arc::new(Mutex::new(0)); let cv = Arc::new(Condvar::new()); let (tx, rx) = channel(); for _ in 0..TASKS { pool.spawn_ok(watcher(mu.clone(), cv.clone(), pool.clone(), tx.clone())); } for _ in 0..TASKS { rx.recv_timeout(Duration::from_secs(5)) .expect("Failed to receive completion notification"); } } #[test] fn wake_all_readers() { async fn read(mu: Arc>, cv: Arc) { let mut ready = mu.read_lock().await; while !*ready { ready = cv.wait_read(ready).await; } } let mu = Arc::new(Mutex::new(false)); let cv = Arc::new(Condvar::new()); let mut readers = [ Box::pin(read(mu.clone(), cv.clone())), Box::pin(read(mu.clone(), cv.clone())), Box::pin(read(mu.clone(), cv.clone())), Box::pin(read(mu.clone(), cv.clone())), ]; let arc_waker = Arc::new(TestWaker); let waker = waker_ref(&arc_waker); let mut cx = Context::from_waker(&waker); // First have all the readers wait on the Condvar. for r in &mut readers { if let Poll::Ready(()) = r.as_mut().poll(&mut cx) { panic!("reader unexpectedly ready"); } } assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); // Now make the condition true and notify the condvar. Even though we will call notify_one, // all the readers should be woken up. *block_on(mu.lock()) = true; cv.notify_one(); assert_eq!(cv.state.load(Ordering::Relaxed), 0); // All readers should now be able to complete. for r in &mut readers { if let Poll::Pending = r.as_mut().poll(&mut cx) { panic!("reader unable to complete"); } } } #[test] fn cancel_before_notify() { async fn dec(mu: Arc>, cv: Arc) { let mut count = mu.lock().await; while *count == 0 { count = cv.wait(count).await; } *count -= 1; } let mu = Arc::new(Mutex::new(0)); let cv = Arc::new(Condvar::new()); let arc_waker = Arc::new(TestWaker); let waker = waker_ref(&arc_waker); let mut cx = Context::from_waker(&waker); let mut fut1 = Box::pin(dec(mu.clone(), cv.clone())); let mut fut2 = Box::pin(dec(mu.clone(), cv.clone())); if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) { panic!("future unexpectedly ready"); } if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { panic!("future unexpectedly ready"); } assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); *block_on(mu.lock()) = 2; // Drop fut1 before notifying the cv. mem::drop(fut1); cv.notify_one(); // fut2 should now be ready to complete. assert_eq!(cv.state.load(Ordering::Relaxed), 0); if let Poll::Pending = fut2.as_mut().poll(&mut cx) { panic!("future unable to complete"); } assert_eq!(*block_on(mu.lock()), 1); } #[test] fn cancel_after_notify() { async fn dec(mu: Arc>, cv: Arc) { let mut count = mu.lock().await; while *count == 0 { count = cv.wait(count).await; } *count -= 1; } let mu = Arc::new(Mutex::new(0)); let cv = Arc::new(Condvar::new()); let arc_waker = Arc::new(TestWaker); let waker = waker_ref(&arc_waker); let mut cx = Context::from_waker(&waker); let mut fut1 = Box::pin(dec(mu.clone(), cv.clone())); let mut fut2 = Box::pin(dec(mu.clone(), cv.clone())); if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) { panic!("future unexpectedly ready"); } if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { panic!("future unexpectedly ready"); } assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); *block_on(mu.lock()) = 2; cv.notify_one(); // fut1 should now be ready to complete. Drop it before polling. This should wake up fut2. mem::drop(fut1); assert_eq!(cv.state.load(Ordering::Relaxed), 0); if let Poll::Pending = fut2.as_mut().poll(&mut cx) { panic!("future unable to complete"); } assert_eq!(*block_on(mu.lock()), 1); } #[test] fn cancel_after_transfer() { async fn dec(mu: Arc>, cv: Arc) { let mut count = mu.lock().await; while *count == 0 { count = cv.wait(count).await; } *count -= 1; } let mu = Arc::new(Mutex::new(0)); let cv = Arc::new(Condvar::new()); let arc_waker = Arc::new(TestWaker); let waker = waker_ref(&arc_waker); let mut cx = Context::from_waker(&waker); let mut fut1 = Box::pin(dec(mu.clone(), cv.clone())); let mut fut2 = Box::pin(dec(mu.clone(), cv.clone())); if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) { panic!("future unexpectedly ready"); } if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { panic!("future unexpectedly ready"); } assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); let mut count = block_on(mu.lock()); *count = 2; // Notify the cv while holding the lock. Only transfer one waiter. cv.notify_one(); assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); // Drop the lock and then the future. This should not cause fut2 to become runnable as it // should still be in the Condvar's wait queue. mem::drop(count); mem::drop(fut1); if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { panic!("future unexpectedly ready"); } // Now wake up fut2. Since the lock isn't held, it should wake up immediately. cv.notify_one(); if let Poll::Pending = fut2.as_mut().poll(&mut cx) { panic!("future unable to complete"); } assert_eq!(*block_on(mu.lock()), 1); } #[test] fn cancel_after_transfer_and_wake() { async fn dec(mu: Arc>, cv: Arc) { let mut count = mu.lock().await; while *count == 0 { count = cv.wait(count).await; } *count -= 1; } let mu = Arc::new(Mutex::new(0)); let cv = Arc::new(Condvar::new()); let arc_waker = Arc::new(TestWaker); let waker = waker_ref(&arc_waker); let mut cx = Context::from_waker(&waker); let mut fut1 = Box::pin(dec(mu.clone(), cv.clone())); let mut fut2 = Box::pin(dec(mu.clone(), cv.clone())); if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) { panic!("future unexpectedly ready"); } if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { panic!("future unexpectedly ready"); } assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); let mut count = block_on(mu.lock()); *count = 2; // Notify the cv while holding the lock. This should transfer both waiters to the mutex's // wait queue. cv.notify_all(); assert_eq!(cv.state.load(Ordering::Relaxed), 0); mem::drop(count); mem::drop(fut1); if let Poll::Pending = fut2.as_mut().poll(&mut cx) { panic!("future unable to complete"); } assert_eq!(*block_on(mu.lock()), 1); } #[test] fn timed_wait() { async fn wait_deadline( mu: Arc>, cv: Arc, timeout: oneshot::Receiver<()>, ) { let mut count = mu.lock().await; if *count == 0 { let mut rx = timeout.fuse(); while *count == 0 { select! { res = rx => { if let Err(e) = res { panic!("Error while receiving timeout notification: {}", e); } return; }, c = cv.wait(count).fuse() => count = c, } } } *count += 1; } let mu = Arc::new(Mutex::new(0)); let cv = Arc::new(Condvar::new()); let arc_waker = Arc::new(TestWaker); let waker = waker_ref(&arc_waker); let mut cx = Context::from_waker(&waker); let (tx, rx) = oneshot::channel(); let mut wait = Box::pin(wait_deadline(mu.clone(), cv.clone(), rx)); if let Poll::Ready(()) = wait.as_mut().poll(&mut cx) { panic!("wait_deadline unexpectedly ready"); } assert_eq!(cv.state.load(Ordering::Relaxed), HAS_WAITERS); // Signal the channel, which should cancel the wait. tx.send(()).expect("Failed to send wakeup"); // Wait for the timer to run out. if let Poll::Pending = wait.as_mut().poll(&mut cx) { panic!("wait_deadline unable to complete in time"); } assert_eq!(cv.state.load(Ordering::Relaxed), 0); assert_eq!(*block_on(mu.lock()), 0); } }