1 use crate::loom::sync::Mutex; 2 use crate::sync::watch; 3 4 /// A barrier enables multiple tasks to synchronize the beginning of some computation. 5 /// 6 /// ``` 7 /// # #[tokio::main] 8 /// # async fn main() { 9 /// use tokio::sync::Barrier; 10 /// use std::sync::Arc; 11 /// 12 /// let mut handles = Vec::with_capacity(10); 13 /// let barrier = Arc::new(Barrier::new(10)); 14 /// for _ in 0..10 { 15 /// let c = barrier.clone(); 16 /// // The same messages will be printed together. 17 /// // You will NOT see any interleaving. 18 /// handles.push(tokio::spawn(async move { 19 /// println!("before wait"); 20 /// let wait_result = c.wait().await; 21 /// println!("after wait"); 22 /// wait_result 23 /// })); 24 /// } 25 /// 26 /// // Will not resolve until all "after wait" messages have been printed 27 /// let mut num_leaders = 0; 28 /// for handle in handles { 29 /// let wait_result = handle.await.unwrap(); 30 /// if wait_result.is_leader() { 31 /// num_leaders += 1; 32 /// } 33 /// } 34 /// 35 /// // Exactly one barrier will resolve as the "leader" 36 /// assert_eq!(num_leaders, 1); 37 /// # } 38 /// ``` 39 #[derive(Debug)] 40 pub struct Barrier { 41 state: Mutex<BarrierState>, 42 wait: watch::Receiver<usize>, 43 n: usize, 44 } 45 46 #[derive(Debug)] 47 struct BarrierState { 48 waker: watch::Sender<usize>, 49 arrived: usize, 50 generation: usize, 51 } 52 53 impl Barrier { 54 /// Creates a new barrier that can block a given number of tasks. 55 /// 56 /// A barrier will block `n`-1 tasks which call [`Barrier::wait`] and then wake up all 57 /// tasks at once when the `n`th task calls `wait`. new(mut n: usize) -> Barrier58 pub fn new(mut n: usize) -> Barrier { 59 let (waker, wait) = crate::sync::watch::channel(0); 60 61 if n == 0 { 62 // if n is 0, it's not clear what behavior the user wants. 63 // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every 64 // .wait() immediately unblocks, so we adopt that here as well. 65 n = 1; 66 } 67 68 Barrier { 69 state: Mutex::new(BarrierState { 70 waker, 71 arrived: 0, 72 generation: 1, 73 }), 74 n, 75 wait, 76 } 77 } 78 79 /// Does not resolve until all tasks have rendezvoused here. 80 /// 81 /// Barriers are re-usable after all tasks have rendezvoused once, and can 82 /// be used continuously. 83 /// 84 /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from 85 /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other tasks 86 /// will receive a result that will return `false` from `is_leader`. wait(&self) -> BarrierWaitResult87 pub async fn wait(&self) -> BarrierWaitResult { 88 // NOTE: we are taking a _synchronous_ lock here. 89 // It is okay to do so because the critical section is fast and never yields, so it cannot 90 // deadlock even if another future is concurrently holding the lock. 91 // It is _desireable_ to do so as synchronous Mutexes are, at least in theory, faster than 92 // the asynchronous counter-parts, so we should use them where possible [citation needed]. 93 // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across 94 // a yield point, and thus marks the returned future as !Send. 95 let generation = { 96 let mut state = self.state.lock(); 97 let generation = state.generation; 98 state.arrived += 1; 99 if state.arrived == self.n { 100 // we are the leader for this generation 101 // wake everyone, increment the generation, and return 102 state 103 .waker 104 .send(state.generation) 105 .expect("there is at least one receiver"); 106 state.arrived = 0; 107 state.generation += 1; 108 return BarrierWaitResult(true); 109 } 110 111 generation 112 }; 113 114 // we're going to have to wait for the last of the generation to arrive 115 let mut wait = self.wait.clone(); 116 117 loop { 118 let _ = wait.changed().await; 119 120 // note that the first time through the loop, this _will_ yield a generation 121 // immediately, since we cloned a receiver that has never seen any values. 122 if *wait.borrow() >= generation { 123 break; 124 } 125 } 126 127 BarrierWaitResult(false) 128 } 129 } 130 131 /// A `BarrierWaitResult` is returned by `wait` when all tasks in the `Barrier` have rendezvoused. 132 #[derive(Debug, Clone)] 133 pub struct BarrierWaitResult(bool); 134 135 impl BarrierWaitResult { 136 /// Returns `true` if this task from wait is the "leader task". 137 /// 138 /// Only one task will have `true` returned from their result, all other tasks will have 139 /// `false` returned. is_leader(&self) -> bool140 pub fn is_leader(&self) -> bool { 141 self.0 142 } 143 } 144