1 // Copyright 2020 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::{
6 cell::UnsafeCell,
7 hint, mem,
8 sync::{
9 atomic::{AtomicUsize, Ordering},
10 Arc,
11 },
12 };
13
14 use super::super::sync::{
15 mu::{MutexGuard, MutexReadGuard, RawMutex},
16 waiter::{Kind as WaiterKind, Waiter, WaiterAdapter, WaiterList, WaitingFor},
17 };
18
/// Bit in `Condvar::state` that acts as a spin lock guarding `Condvar::waiters` and `Condvar::mu`.
const SPINLOCK: usize = 0b01;
/// Bit in `Condvar::state` that is set while the waiter queue may be non-empty.
const HAS_WAITERS: usize = 0b10;
21
/// A primitive to wait for an event to occur without consuming CPU time.
///
/// Condition variables are used in combination with a `Mutex` when a thread wants to wait for some
/// condition to become true. The condition must always be verified while holding the `Mutex` lock.
/// It is an error to use a `Condvar` with more than one `Mutex` while there are threads waiting on
/// the `Condvar`; `wait` and `wait_read` panic if this happens.
///
/// # Examples
///
/// ```edition2018
/// use std::sync::Arc;
/// use std::thread;
///
/// use cros_async::{
///     block_on,
///     sync::{Condvar, Mutex},
/// };
///
/// const N: usize = 13;
///
/// // Spawn a few threads to increment a shared variable (non-atomically), and
/// // let all threads waiting on the Condvar know once the increments are done.
/// let data = Arc::new(Mutex::new(0));
/// let cv = Arc::new(Condvar::new());
///
/// for _ in 0..N {
///     let (data, cv) = (data.clone(), cv.clone());
///     thread::spawn(move || {
///         let mut data = block_on(data.lock());
///         *data += 1;
///         if *data == N {
///             cv.notify_all();
///         }
///     });
/// }
///
/// let mut val = block_on(data.lock());
/// while *val != N {
///     val = block_on(cv.wait(val));
/// }
/// ```
// Over-aligned, presumably so that `state` (which every operation spins on) does not share a
// cache line with unrelated data. NOTE(review): confirm the rationale for 128 bytes.
#[repr(align(128))]
pub struct Condvar {
    // Combination of the `SPINLOCK` and `HAS_WAITERS` bits.
    state: AtomicUsize,
    // FIFO queue of pending waiters (pushed at the back, woken from the front). Only accessed
    // while the `SPINLOCK` bit of `state` is held.
    waiters: UnsafeCell<WaiterList>,
    // Address of the `RawMutex` the queued waiters are associated with, or 0 once the queue has
    // been emptied. Only accessed while the `SPINLOCK` bit of `state` is held.
    mu: UnsafeCell<usize>,
}
70
71 impl Condvar {
72 /// Creates a new condition variable ready to be waited on and notified.
new() -> Condvar73 pub fn new() -> Condvar {
74 Condvar {
75 state: AtomicUsize::new(0),
76 waiters: UnsafeCell::new(WaiterList::new(WaiterAdapter::new())),
77 mu: UnsafeCell::new(0),
78 }
79 }
80
81 /// Block the current thread until this `Condvar` is notified by another thread.
82 ///
83 /// This method will atomically unlock the `Mutex` held by `guard` and then block the current
84 /// thread. Any call to `notify_one` or `notify_all` after the `Mutex` is unlocked may wake up
85 /// the thread.
86 ///
87 /// To allow for more efficient scheduling, this call may return even when the programmer
88 /// doesn't expect the thread to be woken. Therefore, calls to `wait()` should be used inside a
89 /// loop that checks the predicate before continuing.
90 ///
91 /// Callers that are not in an async context may wish to use the `block_on` method to block the
92 /// thread until the `Condvar` is notified.
93 ///
94 /// # Panics
95 ///
96 /// This method will panic if used with more than one `Mutex` at the same time.
97 ///
98 /// # Examples
99 ///
100 /// ```
101 /// # use std::sync::Arc;
102 /// # use std::thread;
103 ///
104 /// # use cros_async::{
105 /// # block_on,
106 /// # sync::{Condvar, Mutex},
107 /// # };
108 ///
109 /// # let mu = Arc::new(Mutex::new(false));
110 /// # let cv = Arc::new(Condvar::new());
111 /// # let (mu2, cv2) = (mu.clone(), cv.clone());
112 ///
113 /// # let t = thread::spawn(move || {
114 /// # *block_on(mu2.lock()) = true;
115 /// # cv2.notify_all();
116 /// # });
117 ///
118 /// let mut ready = block_on(mu.lock());
119 /// while !*ready {
120 /// ready = block_on(cv.wait(ready));
121 /// }
122 ///
123 /// # t.join().expect("failed to join thread");
124 /// ```
125 // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code
126 // that doesn't compile.
127 #[allow(clippy::needless_lifetimes)]
wait<'g, T>(&self, guard: MutexGuard<'g, T>) -> MutexGuard<'g, T>128 pub async fn wait<'g, T>(&self, guard: MutexGuard<'g, T>) -> MutexGuard<'g, T> {
129 let waiter = Arc::new(Waiter::new(
130 WaiterKind::Exclusive,
131 cancel_waiter,
132 self as *const Condvar as usize,
133 WaitingFor::Condvar,
134 ));
135
136 self.add_waiter(waiter.clone(), guard.as_raw_mutex());
137
138 // Get a reference to the mutex and then drop the lock.
139 let mu = guard.into_inner();
140
141 // Wait to be woken up.
142 waiter.wait().await;
143
144 // Now re-acquire the lock.
145 mu.lock_from_cv().await
146 }
147
148 /// Like `wait()` but takes and returns a `MutexReadGuard` instead.
149 // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code
150 // that doesn't compile.
151 #[allow(clippy::needless_lifetimes)]
wait_read<'g, T>(&self, guard: MutexReadGuard<'g, T>) -> MutexReadGuard<'g, T>152 pub async fn wait_read<'g, T>(&self, guard: MutexReadGuard<'g, T>) -> MutexReadGuard<'g, T> {
153 let waiter = Arc::new(Waiter::new(
154 WaiterKind::Shared,
155 cancel_waiter,
156 self as *const Condvar as usize,
157 WaitingFor::Condvar,
158 ));
159
160 self.add_waiter(waiter.clone(), guard.as_raw_mutex());
161
162 // Get a reference to the mutex and then drop the lock.
163 let mu = guard.into_inner();
164
165 // Wait to be woken up.
166 waiter.wait().await;
167
168 // Now re-acquire the lock.
169 mu.read_lock_from_cv().await
170 }
171
add_waiter(&self, waiter: Arc<Waiter>, raw_mutex: &RawMutex)172 fn add_waiter(&self, waiter: Arc<Waiter>, raw_mutex: &RawMutex) {
173 // Acquire the spin lock.
174 let mut oldstate = self.state.load(Ordering::Relaxed);
175 while (oldstate & SPINLOCK) != 0
176 || self
177 .state
178 .compare_exchange_weak(
179 oldstate,
180 oldstate | SPINLOCK | HAS_WAITERS,
181 Ordering::Acquire,
182 Ordering::Relaxed,
183 )
184 .is_err()
185 {
186 hint::spin_loop();
187 oldstate = self.state.load(Ordering::Relaxed);
188 }
189
190 // Safe because the spin lock guarantees exclusive access and the reference does not escape
191 // this function.
192 let mu = unsafe { &mut *self.mu.get() };
193 let muptr = raw_mutex as *const RawMutex as usize;
194
195 match *mu {
196 0 => *mu = muptr,
197 p if p == muptr => {}
198 _ => panic!("Attempting to use Condvar with more than one Mutex at the same time"),
199 }
200
201 // Safe because the spin lock guarantees exclusive access.
202 unsafe { (*self.waiters.get()).push_back(waiter) };
203
204 // Release the spin lock. Use a direct store here because no other thread can modify
205 // `self.state` while we hold the spin lock. Keep the `HAS_WAITERS` bit that we set earlier
206 // because we just added a waiter.
207 self.state.store(HAS_WAITERS, Ordering::Release);
208 }
209
210 /// Notify at most one thread currently waiting on the `Condvar`.
211 ///
212 /// If there is a thread currently waiting on the `Condvar` it will be woken up from its call to
213 /// `wait`.
214 ///
215 /// Unlike more traditional condition variable interfaces, this method requires a reference to
216 /// the `Mutex` associated with this `Condvar`. This is because it is inherently racy to call
217 /// `notify_one` or `notify_all` without first acquiring the `Mutex` lock. Additionally, taking
218 /// a reference to the `Mutex` here allows us to make some optimizations that can improve
219 /// performance by reducing unnecessary wakeups.
notify_one(&self)220 pub fn notify_one(&self) {
221 let mut oldstate = self.state.load(Ordering::Relaxed);
222 if (oldstate & HAS_WAITERS) == 0 {
223 // No waiters.
224 return;
225 }
226
227 while (oldstate & SPINLOCK) != 0
228 || self
229 .state
230 .compare_exchange_weak(
231 oldstate,
232 oldstate | SPINLOCK,
233 Ordering::Acquire,
234 Ordering::Relaxed,
235 )
236 .is_err()
237 {
238 hint::spin_loop();
239 oldstate = self.state.load(Ordering::Relaxed);
240 }
241
242 // Safe because the spin lock guarantees exclusive access and the reference does not escape
243 // this function.
244 let waiters = unsafe { &mut *self.waiters.get() };
245 let wake_list = get_wake_list(waiters);
246
247 let newstate = if waiters.is_empty() {
248 // Also clear the mutex associated with this Condvar since there are no longer any
249 // waiters. Safe because the spin lock guarantees exclusive access.
250 unsafe { *self.mu.get() = 0 };
251
252 // We are releasing the spin lock and there are no more waiters so we can clear all bits
253 // in `self.state`.
254 0
255 } else {
256 // There are still waiters so we need to keep the HAS_WAITERS bit in the state.
257 HAS_WAITERS
258 };
259
260 // Release the spin lock.
261 self.state.store(newstate, Ordering::Release);
262
263 // Now wake any waiters in the wake list.
264 for w in wake_list {
265 w.wake();
266 }
267 }
268
269 /// Notify all threads currently waiting on the `Condvar`.
270 ///
271 /// All threads currently waiting on the `Condvar` will be woken up from their call to `wait`.
272 ///
273 /// Unlike more traditional condition variable interfaces, this method requires a reference to
274 /// the `Mutex` associated with this `Condvar`. This is because it is inherently racy to call
275 /// `notify_one` or `notify_all` without first acquiring the `Mutex` lock. Additionally, taking
276 /// a reference to the `Mutex` here allows us to make some optimizations that can improve
277 /// performance by reducing unnecessary wakeups.
notify_all(&self)278 pub fn notify_all(&self) {
279 let mut oldstate = self.state.load(Ordering::Relaxed);
280 if (oldstate & HAS_WAITERS) == 0 {
281 // No waiters.
282 return;
283 }
284
285 while (oldstate & SPINLOCK) != 0
286 || self
287 .state
288 .compare_exchange_weak(
289 oldstate,
290 oldstate | SPINLOCK,
291 Ordering::Acquire,
292 Ordering::Relaxed,
293 )
294 .is_err()
295 {
296 hint::spin_loop();
297 oldstate = self.state.load(Ordering::Relaxed);
298 }
299
300 // Safe because the spin lock guarantees exclusive access to `self.waiters`.
301 let wake_list = unsafe { (*self.waiters.get()).take() };
302
303 // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe
304 // because we the spin lock guarantees exclusive access.
305 unsafe { *self.mu.get() = 0 };
306
307 // Mark any waiters left as no longer waiting for the Condvar.
308 for w in &wake_list {
309 w.set_waiting_for(WaitingFor::None);
310 }
311
312 // Release the spin lock. We can clear all bits in the state since we took all the waiters.
313 self.state.store(0, Ordering::Release);
314
315 // Now wake any waiters in the wake list.
316 for w in wake_list {
317 w.wake();
318 }
319 }
320
cancel_waiter(&self, waiter: &Waiter, wake_next: bool)321 fn cancel_waiter(&self, waiter: &Waiter, wake_next: bool) {
322 let mut oldstate = self.state.load(Ordering::Relaxed);
323 while oldstate & SPINLOCK != 0
324 || self
325 .state
326 .compare_exchange_weak(
327 oldstate,
328 oldstate | SPINLOCK,
329 Ordering::Acquire,
330 Ordering::Relaxed,
331 )
332 .is_err()
333 {
334 hint::spin_loop();
335 oldstate = self.state.load(Ordering::Relaxed);
336 }
337
338 // Safe because the spin lock provides exclusive access and the reference does not escape
339 // this function.
340 let waiters = unsafe { &mut *self.waiters.get() };
341
342 let waiting_for = waiter.is_waiting_for();
343 // Don't drop the old waiter now as we're still holding the spin lock.
344 let old_waiter = if waiter.is_linked() && waiting_for == WaitingFor::Condvar {
345 // Safe because we know that the waiter is still linked and is waiting for the Condvar,
346 // which guarantees that it is still in `self.waiters`.
347 let mut cursor = unsafe { waiters.cursor_mut_from_ptr(waiter as *const Waiter) };
348 cursor.remove()
349 } else {
350 None
351 };
352
353 let wake_list = if wake_next || waiting_for == WaitingFor::None {
354 // Either the waiter was already woken or it's been removed from the condvar's waiter
355 // list and is going to be woken. Either way, we need to wake up another thread.
356 get_wake_list(waiters)
357 } else {
358 WaiterList::new(WaiterAdapter::new())
359 };
360
361 let set_on_release = if waiters.is_empty() {
362 // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe
363 // because we the spin lock guarantees exclusive access.
364 unsafe { *self.mu.get() = 0 };
365
366 0
367 } else {
368 HAS_WAITERS
369 };
370
371 self.state.store(set_on_release, Ordering::Release);
372
373 // Now wake any waiters still left in the wake list.
374 for w in wake_list {
375 w.wake();
376 }
377
378 mem::drop(old_waiter);
379 }
380 }
381
// SAFETY: the non-`Sync` parts of `Condvar` are the `UnsafeCell`s holding `waiters` and `mu`.
// Every access to them in this module happens only after the `SPINLOCK` bit of `state` has been
// acquired with an `Acquire` CAS, and before it is released with a `Release` store, which
// serializes those accesses across threads.
// NOTE(review): `Send` also relies on the `Arc<Waiter>` entries stored in `waiters` being safe to
// hand to another thread — confirm against `Waiter`'s definition.
unsafe impl Send for Condvar {}
unsafe impl Sync for Condvar {}
384
385 impl Default for Condvar {
default() -> Self386 fn default() -> Self {
387 Self::new()
388 }
389 }
390
391 // Scan `waiters` and return all waiters that should be woken up.
392 //
393 // If the first waiter is trying to acquire a shared lock, then all waiters in the list that are
394 // waiting for a shared lock are also woken up. In addition one writer is woken up, if possible.
395 //
396 // If the first waiter is trying to acquire an exclusive lock, then only that waiter is returned and
397 // the rest of the list is not scanned.
get_wake_list(waiters: &mut WaiterList) -> WaiterList398 fn get_wake_list(waiters: &mut WaiterList) -> WaiterList {
399 let mut to_wake = WaiterList::new(WaiterAdapter::new());
400 let mut cursor = waiters.front_mut();
401
402 let mut waking_readers = false;
403 let mut all_readers = true;
404 while let Some(w) = cursor.get() {
405 match w.kind() {
406 WaiterKind::Exclusive if !waking_readers => {
407 // This is the first waiter and it's a writer. No need to check the other waiters.
408 // Also mark the waiter as having been removed from the Condvar's waiter list.
409 let waiter = cursor.remove().unwrap();
410 waiter.set_waiting_for(WaitingFor::None);
411 to_wake.push_back(waiter);
412 break;
413 }
414
415 WaiterKind::Shared => {
416 // This is a reader and the first waiter in the list was not a writer so wake up all
417 // the readers in the wait list.
418 let waiter = cursor.remove().unwrap();
419 waiter.set_waiting_for(WaitingFor::None);
420 to_wake.push_back(waiter);
421 waking_readers = true;
422 }
423
424 WaiterKind::Exclusive => {
425 debug_assert!(waking_readers);
426 if all_readers {
427 // We are waking readers but we need to ensure that at least one writer is woken
428 // up. Since we haven't yet woken up a writer, wake up this one.
429 let waiter = cursor.remove().unwrap();
430 waiter.set_waiting_for(WaitingFor::None);
431 to_wake.push_back(waiter);
432 all_readers = false;
433 } else {
434 // We are waking readers and have already woken one writer. Skip this one.
435 cursor.move_next();
436 }
437 }
438 }
439 }
440
441 to_wake
442 }
443
cancel_waiter(cv: usize, waiter: &Waiter, wake_next: bool)444 fn cancel_waiter(cv: usize, waiter: &Waiter, wake_next: bool) {
445 let condvar = cv as *const Condvar;
446
447 // Safe because the thread that owns the waiter being canceled must also own a reference to the
448 // Condvar, which guarantees that this pointer is valid.
449 unsafe { (*condvar).cancel_waiter(waiter, wake_next) }
450 }
451
// Tests for `Condvar`. Several of them drive futures by hand with a no-op waker (`TestWaker`) so
// that they control exactly when each waiter is polled, cancelled, or notified.
#[cfg(test)]
mod test {
    use super::*;

    use std::{
        future::Future,
        mem, ptr,
        rc::Rc,
        sync::{
            mpsc::{channel, Sender},
            Arc,
        },
        task::{Context, Poll},
        thread::{
            JoinHandle, {self},
        },
        time::Duration,
    };

    use futures::{
        channel::oneshot,
        select,
        task::{waker_ref, ArcWake},
        FutureExt,
    };
    use futures_executor::{LocalPool, LocalSpawner, ThreadPool};
    use futures_util::task::LocalSpawnExt;

    use super::super::super::{block_on, sync::Mutex};

    // Dummy waker used when we want to manually drive futures.
    struct TestWaker;
    impl ArcWake for TestWaker {
        fn wake_by_ref(_arc_self: &Arc<Self>) {}
    }

    // Notifying a Condvar that has no waiters must be a harmless no-op.
    #[test]
    fn smoke() {
        let cv = Condvar::new();
        cv.notify_one();
        cv.notify_all();
    }

    #[test]
    fn notify_one() {
        let mu = Arc::new(Mutex::new(()));
        let cv = Arc::new(Condvar::new());

        let mu2 = mu.clone();
        let cv2 = cv.clone();

        // The notifier has to take `mu` first, so it can only notify once `cv.wait` below has
        // queued this thread and released the lock.
        let guard = block_on(mu.lock());
        thread::spawn(move || {
            let _g = block_on(mu2.lock());
            cv2.notify_one();
        });

        let guard = block_on(cv.wait(guard));
        mem::drop(guard);
    }

    // Uses one Condvar with one mutex, lets all waiters drain, then reuses it with a different
    // mutex. This must not panic, because the association is cleared once the queue is empty.
    #[test]
    fn multi_mutex() {
        const NUM_THREADS: usize = 5;

        let mu = Arc::new(Mutex::new(false));
        let cv = Arc::new(Condvar::new());

        let mut threads = Vec::with_capacity(NUM_THREADS);
        for _ in 0..NUM_THREADS {
            let mu = mu.clone();
            let cv = cv.clone();

            threads.push(thread::spawn(move || {
                let mut ready = block_on(mu.lock());
                while !*ready {
                    ready = block_on(cv.wait(ready));
                }
            }));
        }

        let mut g = block_on(mu.lock());
        *g = true;
        mem::drop(g);
        cv.notify_all();

        threads
            .into_iter()
            .try_for_each(JoinHandle::join)
            .expect("Failed to join threads");

        // Now use the Condvar with a different mutex.
        let alt_mu = Arc::new(Mutex::new(None));
        let alt_mu2 = alt_mu.clone();
        let cv2 = cv.clone();
        let handle = thread::spawn(move || {
            let mut g = block_on(alt_mu2.lock());
            while g.is_none() {
                g = block_on(cv2.wait(g));
            }
        });

        let mut alt_g = block_on(alt_mu.lock());
        *alt_g = Some(());
        mem::drop(alt_g);
        cv.notify_all();

        handle
            .join()
            .expect("Failed to join thread alternate mutex");
    }

    #[test]
    fn notify_one_single_thread_async() {
        async fn notify(mu: Rc<Mutex<()>>, cv: Rc<Condvar>) {
            let _g = mu.lock().await;
            cv.notify_one();
        }

        async fn wait(mu: Rc<Mutex<()>>, cv: Rc<Condvar>, spawner: LocalSpawner) {
            let mu2 = Rc::clone(&mu);
            let cv2 = Rc::clone(&cv);

            let g = mu.lock().await;
            // Has to be spawned _after_ acquiring the lock to prevent a race
            // where the notify happens before the waiter has acquired the lock.
            spawner
                .spawn_local(notify(mu2, cv2))
                .expect("Failed to spawn `notify` task");
            let _g = cv.wait(g).await;
        }

        let mut ex = LocalPool::new();
        let spawner = ex.spawner();

        let mu = Rc::new(Mutex::new(()));
        let cv = Rc::new(Condvar::new());

        spawner
            .spawn_local(wait(mu, cv, spawner.clone()))
            .expect("Failed to spawn `wait` task");

        ex.run();
    }

    #[test]
    fn notify_one_multi_thread_async() {
        async fn notify(mu: Arc<Mutex<()>>, cv: Arc<Condvar>) {
            let _g = mu.lock().await;
            cv.notify_one();
        }

        async fn wait(mu: Arc<Mutex<()>>, cv: Arc<Condvar>, tx: Sender<()>, pool: ThreadPool) {
            let mu2 = Arc::clone(&mu);
            let cv2 = Arc::clone(&cv);

            let g = mu.lock().await;
            // Has to be spawned _after_ acquiring the lock to prevent a race
            // where the notify happens before the waiter has acquired the lock.
            pool.spawn_ok(notify(mu2, cv2));
            let _g = cv.wait(g).await;

            tx.send(()).expect("Failed to send completion notification");
        }

        let ex = ThreadPool::new().expect("Failed to create ThreadPool");

        let mu = Arc::new(Mutex::new(()));
        let cv = Arc::new(Condvar::new());

        let (tx, rx) = channel();
        ex.spawn_ok(wait(mu, cv, tx, ex.clone()));

        rx.recv_timeout(Duration::from_secs(5))
            .expect("Failed to receive completion notification");
    }

    // Stress test: `select!` repeatedly cancels whichever wait loses the race, exercising
    // `cancel_waiter` on waiters that may or may not already have been notified.
    #[test]
    fn notify_one_with_cancel() {
        const TASKS: usize = 17;
        const OBSERVERS: usize = 7;
        const ITERATIONS: usize = 103;

        async fn observe(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) {
            let mut count = mu.read_lock().await;
            while *count == 0 {
                count = cv.wait_read(count).await;
            }
            // SAFETY: `count` derefs to a live `usize` borrowed through the read guard, so the
            // pointer is valid and aligned. The volatile read presumably just keeps the compiler
            // from optimizing the access away.
            let _ = unsafe { ptr::read_volatile(&*count as *const usize) };
        }

        async fn decrement(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) {
            let mut count = mu.lock().await;
            while *count == 0 {
                count = cv.wait(count).await;
            }
            *count -= 1;
        }

        async fn increment(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>, done: Sender<()>) {
            for _ in 0..TASKS * OBSERVERS * ITERATIONS {
                *mu.lock().await += 1;
                cv.notify_one();
            }

            done.send(()).expect("Failed to send completion message");
        }

        async fn observe_either(
            mu: Arc<Mutex<usize>>,
            cv: Arc<Condvar>,
            alt_mu: Arc<Mutex<usize>>,
            alt_cv: Arc<Condvar>,
            done: Sender<()>,
        ) {
            for _ in 0..ITERATIONS {
                select! {
                    () = observe(&mu, &cv).fuse() => {},
                    () = observe(&alt_mu, &alt_cv).fuse() => {},
                }
            }

            done.send(()).expect("Failed to send completion message");
        }

        async fn decrement_either(
            mu: Arc<Mutex<usize>>,
            cv: Arc<Condvar>,
            alt_mu: Arc<Mutex<usize>>,
            alt_cv: Arc<Condvar>,
            done: Sender<()>,
        ) {
            for _ in 0..ITERATIONS {
                select! {
                    () = decrement(&mu, &cv).fuse() => {},
                    () = decrement(&alt_mu, &alt_cv).fuse() => {},
                }
            }

            done.send(()).expect("Failed to send completion message");
        }

        let ex = ThreadPool::new().expect("Failed to create ThreadPool");

        let mu = Arc::new(Mutex::new(0usize));
        let alt_mu = Arc::new(Mutex::new(0usize));

        let cv = Arc::new(Condvar::new());
        let alt_cv = Arc::new(Condvar::new());

        let (tx, rx) = channel();
        for _ in 0..TASKS {
            ex.spawn_ok(decrement_either(
                Arc::clone(&mu),
                Arc::clone(&cv),
                Arc::clone(&alt_mu),
                Arc::clone(&alt_cv),
                tx.clone(),
            ));
        }

        for _ in 0..OBSERVERS {
            ex.spawn_ok(observe_either(
                Arc::clone(&mu),
                Arc::clone(&cv),
                Arc::clone(&alt_mu),
                Arc::clone(&alt_cv),
                tx.clone(),
            ));
        }

        ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone()));
        ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx));

        for _ in 0..TASKS + OBSERVERS + 2 {
            if let Err(e) = rx.recv_timeout(Duration::from_secs(20)) {
                panic!("Error while waiting for threads to complete: {}", e);
            }
        }

        // Each increment task adds TASKS * OBSERVERS * ITERATIONS to its counter. Each of the
        // TASKS decrement tasks completes exactly ITERATIONS decrements spread across the two
        // counters. Observers do not modify either counter.
        assert_eq!(
            *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()),
            (TASKS * OBSERVERS * ITERATIONS * 2) - (TASKS * ITERATIONS)
        );
        // With every task finished, neither Condvar may be left with waiters or a held spin lock.
        assert_eq!(cv.state.load(Ordering::Relaxed), 0);
        assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0);
    }

    // Like `notify_one_with_cancel`, but wakes with `notify_all` and has no shared-lock waiters.
    #[test]
    fn notify_all_with_cancel() {
        const TASKS: usize = 17;
        const ITERATIONS: usize = 103;

        async fn decrement(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) {
            let mut count = mu.lock().await;
            while *count == 0 {
                count = cv.wait(count).await;
            }
            *count -= 1;
        }

        async fn increment(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>, done: Sender<()>) {
            for _ in 0..TASKS * ITERATIONS {
                *mu.lock().await += 1;
                cv.notify_all();
            }

            done.send(()).expect("Failed to send completion message");
        }

        async fn decrement_either(
            mu: Arc<Mutex<usize>>,
            cv: Arc<Condvar>,
            alt_mu: Arc<Mutex<usize>>,
            alt_cv: Arc<Condvar>,
            done: Sender<()>,
        ) {
            for _ in 0..ITERATIONS {
                select! {
                    () = decrement(&mu, &cv).fuse() => {},
                    () = decrement(&alt_mu, &alt_cv).fuse() => {},
                }
            }

            done.send(()).expect("Failed to send completion message");
        }

        let ex = ThreadPool::new().expect("Failed to create ThreadPool");

        let mu = Arc::new(Mutex::new(0usize));
        let alt_mu = Arc::new(Mutex::new(0usize));

        let cv = Arc::new(Condvar::new());
        let alt_cv = Arc::new(Condvar::new());

        let (tx, rx) = channel();
        for _ in 0..TASKS {
            ex.spawn_ok(decrement_either(
                Arc::clone(&mu),
                Arc::clone(&cv),
                Arc::clone(&alt_mu),
                Arc::clone(&alt_cv),
                tx.clone(),
            ));
        }

        ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone()));
        ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx));

        for _ in 0..TASKS + 2 {
            if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) {
                panic!("Error while waiting for threads to complete: {}", e);
            }
        }

        // Two counters are each incremented TASKS * ITERATIONS times; TASKS * ITERATIONS
        // decrements happen in total across both.
        assert_eq!(
            *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()),
            TASKS * ITERATIONS
        );
        assert_eq!(cv.state.load(Ordering::Relaxed), 0);
        assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn notify_all() {
        const THREADS: usize = 13;

        let mu = Arc::new(Mutex::new(0));
        let cv = Arc::new(Condvar::new());
        let (tx, rx) = channel();

        let mut threads = Vec::with_capacity(THREADS);
        for _ in 0..THREADS {
            let mu2 = mu.clone();
            let cv2 = cv.clone();
            let tx2 = tx.clone();

            threads.push(thread::spawn(move || {
                let mut count = block_on(mu2.lock());
                *count += 1;
                if *count == THREADS {
                    tx2.send(()).unwrap();
                }

                while *count != 0 {
                    count = block_on(cv2.wait(count));
                }
            }));
        }

        mem::drop(tx);

        // Wait till all threads have started.
        rx.recv_timeout(Duration::from_secs(5)).unwrap();

        let mut count = block_on(mu.lock());
        *count = 0;
        mem::drop(count);
        cv.notify_all();

        for t in threads {
            t.join().unwrap();
        }
    }

    #[test]
    fn notify_all_single_thread_async() {
        const TASKS: usize = 13;

        async fn reset(mu: Rc<Mutex<usize>>, cv: Rc<Condvar>) {
            let mut count = mu.lock().await;
            *count = 0;
            cv.notify_all();
        }

        async fn watcher(mu: Rc<Mutex<usize>>, cv: Rc<Condvar>, spawner: LocalSpawner) {
            let mut count = mu.lock().await;
            *count += 1;
            // The last watcher to arrive spawns the task that resets the count and wakes everyone.
            if *count == TASKS {
                spawner
                    .spawn_local(reset(mu.clone(), cv.clone()))
                    .expect("Failed to spawn reset task");
            }

            while *count != 0 {
                count = cv.wait(count).await;
            }
        }

        let mut ex = LocalPool::new();
        let spawner = ex.spawner();

        let mu = Rc::new(Mutex::new(0));
        let cv = Rc::new(Condvar::new());

        for _ in 0..TASKS {
            spawner
                .spawn_local(watcher(mu.clone(), cv.clone(), spawner.clone()))
                .expect("Failed to spawn watcher task");
        }

        ex.run();
    }

    #[test]
    fn notify_all_multi_thread_async() {
        const TASKS: usize = 13;

        async fn reset(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
            let mut count = mu.lock().await;
            *count = 0;
            cv.notify_all();
        }

        async fn watcher(
            mu: Arc<Mutex<usize>>,
            cv: Arc<Condvar>,
            pool: ThreadPool,
            tx: Sender<()>,
        ) {
            let mut count = mu.lock().await;
            *count += 1;
            // The last watcher to arrive spawns the task that resets the count and wakes everyone.
            if *count == TASKS {
                pool.spawn_ok(reset(mu.clone(), cv.clone()));
            }

            while *count != 0 {
                count = cv.wait(count).await;
            }

            tx.send(()).expect("Failed to send completion notification");
        }

        let pool = ThreadPool::new().expect("Failed to create ThreadPool");

        let mu = Arc::new(Mutex::new(0));
        let cv = Arc::new(Condvar::new());

        let (tx, rx) = channel();
        for _ in 0..TASKS {
            pool.spawn_ok(watcher(mu.clone(), cv.clone(), pool.clone(), tx.clone()));
        }

        for _ in 0..TASKS {
            rx.recv_timeout(Duration::from_secs(5))
                .expect("Failed to receive completion notification");
        }
    }

    #[test]
    fn wake_all_readers() {
        async fn read(mu: Arc<Mutex<bool>>, cv: Arc<Condvar>) {
            let mut ready = mu.read_lock().await;
            while !*ready {
                ready = cv.wait_read(ready).await;
            }
        }

        let mu = Arc::new(Mutex::new(false));
        let cv = Arc::new(Condvar::new());
        let mut readers = [
            Box::pin(read(mu.clone(), cv.clone())),
            Box::pin(read(mu.clone(), cv.clone())),
            Box::pin(read(mu.clone(), cv.clone())),
            Box::pin(read(mu.clone(), cv.clone())),
        ];

        let arc_waker = Arc::new(TestWaker);
        let waker = waker_ref(&arc_waker);
        let mut cx = Context::from_waker(&waker);

        // First have all the readers wait on the Condvar.
        for r in &mut readers {
            if let Poll::Ready(()) = r.as_mut().poll(&mut cx) {
                panic!("reader unexpectedly ready");
            }
        }

        assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);

        // Now make the condition true and notify the condvar. Even though we will call notify_one,
        // all the readers should be woken up.
        *block_on(mu.lock()) = true;
        cv.notify_one();

        assert_eq!(cv.state.load(Ordering::Relaxed), 0);

        // All readers should now be able to complete.
        for r in &mut readers {
            if r.as_mut().poll(&mut cx).is_pending() {
                panic!("reader unable to complete");
            }
        }
    }

    #[test]
    fn cancel_before_notify() {
        async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
            let mut count = mu.lock().await;

            while *count == 0 {
                count = cv.wait(count).await;
            }

            *count -= 1;
        }

        let mu = Arc::new(Mutex::new(0));
        let cv = Arc::new(Condvar::new());

        let arc_waker = Arc::new(TestWaker);
        let waker = waker_ref(&arc_waker);
        let mut cx = Context::from_waker(&waker);

        let mut fut1 = Box::pin(dec(mu.clone(), cv.clone()));
        let mut fut2 = Box::pin(dec(mu.clone(), cv.clone()));

        if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) {
            panic!("future unexpectedly ready");
        }
        if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
            panic!("future unexpectedly ready");
        }
        assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);

        *block_on(mu.lock()) = 2;
        // Drop fut1 before notifying the cv. This unlinks it from the queue, so the notification
        // below goes to fut2.
        mem::drop(fut1);
        cv.notify_one();

        // fut2 should now be ready to complete.
        assert_eq!(cv.state.load(Ordering::Relaxed), 0);

        if fut2.as_mut().poll(&mut cx).is_pending() {
            panic!("future unable to complete");
        }

        assert_eq!(*block_on(mu.lock()), 1);
    }

    #[test]
    fn cancel_after_notify_one() {
        async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
            let mut count = mu.lock().await;

            while *count == 0 {
                count = cv.wait(count).await;
            }

            *count -= 1;
        }

        let mu = Arc::new(Mutex::new(0));
        let cv = Arc::new(Condvar::new());

        let arc_waker = Arc::new(TestWaker);
        let waker = waker_ref(&arc_waker);
        let mut cx = Context::from_waker(&waker);

        let mut fut1 = Box::pin(dec(mu.clone(), cv.clone()));
        let mut fut2 = Box::pin(dec(mu.clone(), cv.clone()));

        if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) {
            panic!("future unexpectedly ready");
        }
        if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
            panic!("future unexpectedly ready");
        }
        assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);

        *block_on(mu.lock()) = 2;
        cv.notify_one();

        // fut1 should now be ready to complete. Drop it before polling. This should wake up fut2.
        mem::drop(fut1);
        assert_eq!(cv.state.load(Ordering::Relaxed), 0);

        if fut2.as_mut().poll(&mut cx).is_pending() {
            panic!("future unable to complete");
        }

        assert_eq!(*block_on(mu.lock()), 1);
    }

    #[test]
    fn cancel_after_notify_all() {
        async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
            let mut count = mu.lock().await;

            while *count == 0 {
                count = cv.wait(count).await;
            }

            *count -= 1;
        }

        let mu = Arc::new(Mutex::new(0));
        let cv = Arc::new(Condvar::new());

        let arc_waker = Arc::new(TestWaker);
        let waker = waker_ref(&arc_waker);
        let mut cx = Context::from_waker(&waker);

        let mut fut1 = Box::pin(dec(mu.clone(), cv.clone()));
        let mut fut2 = Box::pin(dec(mu.clone(), cv.clone()));

        if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) {
            panic!("future unexpectedly ready");
        }
        if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
            panic!("future unexpectedly ready");
        }
        assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);

        let mut count = block_on(mu.lock());
        *count = 2;

        // Notify the cv while holding the lock. This should wake up both waiters.
        cv.notify_all();
        assert_eq!(cv.state.load(Ordering::Relaxed), 0);

        mem::drop(count);

        mem::drop(fut1);

        if fut2.as_mut().poll(&mut cx).is_pending() {
            panic!("future unable to complete");
        }

        assert_eq!(*block_on(mu.lock()), 1);
    }

    // A wait raced against another future with `select!` must be cleanly cancelled (removed from
    // the Condvar) when the other branch wins.
    #[test]
    fn timed_wait() {
        async fn wait_deadline(
            mu: Arc<Mutex<usize>>,
            cv: Arc<Condvar>,
            timeout: oneshot::Receiver<()>,
        ) {
            let mut count = mu.lock().await;

            if *count == 0 {
                let mut rx = timeout.fuse();

                while *count == 0 {
                    select! {
                        res = rx => {
                            if let Err(e) = res {
                                panic!("Error while receiving timeout notification: {}", e);
                            }

                            return;
                        },
                        c = cv.wait(count).fuse() => count = c,
                    }
                }
            }

            *count += 1;
        }

        let mu = Arc::new(Mutex::new(0));
        let cv = Arc::new(Condvar::new());

        let arc_waker = Arc::new(TestWaker);
        let waker = waker_ref(&arc_waker);
        let mut cx = Context::from_waker(&waker);

        let (tx, rx) = oneshot::channel();
        let mut wait = Box::pin(wait_deadline(mu.clone(), cv.clone(), rx));

        if let Poll::Ready(()) = wait.as_mut().poll(&mut cx) {
            panic!("wait_deadline unexpectedly ready");
        }

        assert_eq!(cv.state.load(Ordering::Relaxed), HAS_WAITERS);

        // Signal the channel, which should cancel the wait.
        tx.send(()).expect("Failed to send wakeup");

        // Poll again: the oneshot has fired, so `select!` takes the timeout branch, dropping
        // (and thereby cancelling) the pending `cv.wait`.
        if wait.as_mut().poll(&mut cx).is_pending() {
            panic!("wait_deadline unable to complete in time");
        }

        assert_eq!(cv.state.load(Ordering::Relaxed), 0);
        assert_eq!(*block_on(mu.lock()), 0);
    }
}
1180