1 use crate::loom::sync::Mutex; 2 use crate::sync::watch; 3 #[cfg(all(tokio_unstable, feature = "tracing"))] 4 use crate::util::trace; 5 6 /// A barrier enables multiple tasks to synchronize the beginning of some computation. 7 /// 8 /// ``` 9 /// # #[tokio::main] 10 /// # async fn main() { 11 /// use tokio::sync::Barrier; 12 /// use std::sync::Arc; 13 /// 14 /// let mut handles = Vec::with_capacity(10); 15 /// let barrier = Arc::new(Barrier::new(10)); 16 /// for _ in 0..10 { 17 /// let c = barrier.clone(); 18 /// // The same messages will be printed together. 19 /// // You will NOT see any interleaving. 20 /// handles.push(tokio::spawn(async move { 21 /// println!("before wait"); 22 /// let wait_result = c.wait().await; 23 /// println!("after wait"); 24 /// wait_result 25 /// })); 26 /// } 27 /// 28 /// // Will not resolve until all "after wait" messages have been printed 29 /// let mut num_leaders = 0; 30 /// for handle in handles { 31 /// let wait_result = handle.await.unwrap(); 32 /// if wait_result.is_leader() { 33 /// num_leaders += 1; 34 /// } 35 /// } 36 /// 37 /// // Exactly one barrier will resolve as the "leader" 38 /// assert_eq!(num_leaders, 1); 39 /// # } 40 /// ``` 41 #[derive(Debug)] 42 pub struct Barrier { 43 state: Mutex<BarrierState>, 44 wait: watch::Receiver<usize>, 45 n: usize, 46 #[cfg(all(tokio_unstable, feature = "tracing"))] 47 resource_span: tracing::Span, 48 } 49 50 #[derive(Debug)] 51 struct BarrierState { 52 waker: watch::Sender<usize>, 53 arrived: usize, 54 generation: usize, 55 } 56 57 impl Barrier { 58 /// Creates a new barrier that can block a given number of tasks. 59 /// 60 /// A barrier will block `n`-1 tasks which call [`Barrier::wait`] and then wake up all 61 /// tasks at once when the `n`th task calls `wait`. 62 #[track_caller] new(mut n: usize) -> Barrier63 pub fn new(mut n: usize) -> Barrier { 64 let (waker, wait) = crate::sync::watch::channel(0); 65 66 if n == 0 { 67 // if n is 0, it's not clear what behavior the user wants. 68 // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every 69 // .wait() immediately unblocks, so we adopt that here as well. 70 n = 1; 71 } 72 73 #[cfg(all(tokio_unstable, feature = "tracing"))] 74 let resource_span = { 75 let location = std::panic::Location::caller(); 76 let resource_span = tracing::trace_span!( 77 "runtime.resource", 78 concrete_type = "Barrier", 79 kind = "Sync", 80 loc.file = location.file(), 81 loc.line = location.line(), 82 loc.col = location.column(), 83 ); 84 85 resource_span.in_scope(|| { 86 tracing::trace!( 87 target: "runtime::resource::state_update", 88 size = n, 89 ); 90 91 tracing::trace!( 92 target: "runtime::resource::state_update", 93 arrived = 0, 94 ) 95 }); 96 resource_span 97 }; 98 99 Barrier { 100 state: Mutex::new(BarrierState { 101 waker, 102 arrived: 0, 103 generation: 1, 104 }), 105 n, 106 wait, 107 #[cfg(all(tokio_unstable, feature = "tracing"))] 108 resource_span, 109 } 110 } 111 112 /// Does not resolve until all tasks have rendezvoused here. 113 /// 114 /// Barriers are re-usable after all tasks have rendezvoused once, and can 115 /// be used continuously. 116 /// 117 /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from 118 /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other tasks 119 /// will receive a result that will return `false` from `is_leader`. wait(&self) -> BarrierWaitResult120 pub async fn wait(&self) -> BarrierWaitResult { 121 #[cfg(all(tokio_unstable, feature = "tracing"))] 122 return trace::async_op( 123 || self.wait_internal(), 124 self.resource_span.clone(), 125 "Barrier::wait", 126 "poll", 127 false, 128 ) 129 .await; 130 131 #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] 132 return self.wait_internal().await; 133 } wait_internal(&self) -> BarrierWaitResult134 async fn wait_internal(&self) -> BarrierWaitResult { 135 crate::trace::async_trace_leaf().await; 136 137 // NOTE: we are taking a _synchronous_ lock here. 138 // It is okay to do so because the critical section is fast and never yields, so it cannot 139 // deadlock even if another future is concurrently holding the lock. 140 // It is _desirable_ to do so as synchronous Mutexes are, at least in theory, faster than 141 // the asynchronous counter-parts, so we should use them where possible [citation needed]. 142 // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across 143 // a yield point, and thus marks the returned future as !Send. 144 let generation = { 145 let mut state = self.state.lock(); 146 let generation = state.generation; 147 state.arrived += 1; 148 #[cfg(all(tokio_unstable, feature = "tracing"))] 149 tracing::trace!( 150 target: "runtime::resource::state_update", 151 arrived = 1, 152 arrived.op = "add", 153 ); 154 #[cfg(all(tokio_unstable, feature = "tracing"))] 155 tracing::trace!( 156 target: "runtime::resource::async_op::state_update", 157 arrived = true, 158 ); 159 if state.arrived == self.n { 160 #[cfg(all(tokio_unstable, feature = "tracing"))] 161 tracing::trace!( 162 target: "runtime::resource::async_op::state_update", 163 is_leader = true, 164 ); 165 // we are the leader for this generation 166 // wake everyone, increment the generation, and return 167 state 168 .waker 169 .send(state.generation) 170 .expect("there is at least one receiver"); 171 state.arrived = 0; 172 state.generation += 1; 173 return BarrierWaitResult(true); 174 } 175 176 generation 177 }; 178 179 // we're going to have to wait for the last of the generation to arrive 180 let mut wait = self.wait.clone(); 181 182 loop { 183 let _ = wait.changed().await; 184 185 // note that the first time through the loop, this _will_ yield a generation 186 // immediately, since we cloned a receiver that has never seen any values. 187 if *wait.borrow() >= generation { 188 break; 189 } 190 } 191 192 BarrierWaitResult(false) 193 } 194 } 195 196 /// A `BarrierWaitResult` is returned by `wait` when all tasks in the `Barrier` have rendezvoused. 197 #[derive(Debug, Clone)] 198 pub struct BarrierWaitResult(bool); 199 200 impl BarrierWaitResult { 201 /// Returns `true` if this task from wait is the "leader task". 202 /// 203 /// Only one task will have `true` returned from their result, all other tasks will have 204 /// `false` returned. is_leader(&self) -> bool205 pub fn is_leader(&self) -> bool { 206 self.0 207 } 208 } 209