1 //! Synchronization primitive allowing multiple threads to synchronize the 2 //! beginning of some computation. 3 //! 4 //! Implementation adapted from the 'Barrier' type of the standard library. See: 5 //! <https://doc.rust-lang.org/std/sync/struct.Barrier.html> 6 //! 7 //! Copyright 2014 The Rust Project Developers. See the COPYRIGHT 8 //! file at the top-level directory of this distribution and at 9 //! <http://rust-lang.org/COPYRIGHT>. 10 //! 11 //! Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or 12 //! <http://www.apache.org/licenses/LICENSE-2.0>> or the MIT license 13 //! <LICENSE-MIT or <http://opensource.org/licenses/MIT>>, at your 14 //! option. This file may not be copied, modified, or distributed 15 //! except according to those terms. 16 17 use crate::{mutex::Mutex, RelaxStrategy, Spin}; 18 19 /// A primitive that synchronizes the execution of multiple threads. 20 /// 21 /// # Example 22 /// 23 /// ``` 24 /// use spin; 25 /// use std::sync::Arc; 26 /// use std::thread; 27 /// 28 /// let mut handles = Vec::with_capacity(10); 29 /// let barrier = Arc::new(spin::Barrier::new(10)); 30 /// for _ in 0..10 { 31 /// let c = barrier.clone(); 32 /// // The same messages will be printed together. 33 /// // You will NOT see any interleaving. 34 /// handles.push(thread::spawn(move|| { 35 /// println!("before wait"); 36 /// c.wait(); 37 /// println!("after wait"); 38 /// })); 39 /// } 40 /// // Wait for other threads to finish. 41 /// for handle in handles { 42 /// handle.join().unwrap(); 43 /// } 44 /// ``` 45 pub struct Barrier<R = Spin> { 46 lock: Mutex<BarrierState, R>, 47 num_threads: usize, 48 } 49 50 // The inner state of a double barrier 51 struct BarrierState { 52 count: usize, 53 generation_id: usize, 54 } 55 56 /// A `BarrierWaitResult` is returned by [`wait`] when all threads in the [`Barrier`] 57 /// have rendezvoused. 58 /// 59 /// [`wait`]: struct.Barrier.html#method.wait 60 /// [`Barrier`]: struct.Barrier.html 61 /// 62 /// # Examples 63 /// 64 /// ``` 65 /// use spin; 66 /// 67 /// let barrier = spin::Barrier::new(1); 68 /// let barrier_wait_result = barrier.wait(); 69 /// ``` 70 pub struct BarrierWaitResult(bool); 71 72 impl<R: RelaxStrategy> Barrier<R> { 73 /// Blocks the current thread until all threads have rendezvoused here. 74 /// 75 /// Barriers are re-usable after all threads have rendezvoused once, and can 76 /// be used continuously. 77 /// 78 /// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that 79 /// returns `true` from [`is_leader`] when returning from this function, and 80 /// all other threads will receive a result that will return `false` from 81 /// [`is_leader`]. 82 /// 83 /// [`BarrierWaitResult`]: struct.BarrierWaitResult.html 84 /// [`is_leader`]: struct.BarrierWaitResult.html#method.is_leader 85 /// 86 /// # Examples 87 /// 88 /// ``` 89 /// use spin; 90 /// use std::sync::Arc; 91 /// use std::thread; 92 /// 93 /// let mut handles = Vec::with_capacity(10); 94 /// let barrier = Arc::new(spin::Barrier::new(10)); 95 /// for _ in 0..10 { 96 /// let c = barrier.clone(); 97 /// // The same messages will be printed together. 98 /// // You will NOT see any interleaving. 99 /// handles.push(thread::spawn(move|| { 100 /// println!("before wait"); 101 /// c.wait(); 102 /// println!("after wait"); 103 /// })); 104 /// } 105 /// // Wait for other threads to finish. 106 /// for handle in handles { 107 /// handle.join().unwrap(); 108 /// } 109 /// ``` wait(&self) -> BarrierWaitResult110 pub fn wait(&self) -> BarrierWaitResult { 111 let mut lock = self.lock.lock(); 112 lock.count += 1; 113 114 if lock.count < self.num_threads { 115 // not the leader 116 let local_gen = lock.generation_id; 117 118 while local_gen == lock.generation_id && 119 lock.count < self.num_threads { 120 drop(lock); 121 R::relax(); 122 lock = self.lock.lock(); 123 } 124 BarrierWaitResult(false) 125 } else { 126 // this thread is the leader, 127 // and is responsible for incrementing the generation 128 lock.count = 0; 129 lock.generation_id = lock.generation_id.wrapping_add(1); 130 BarrierWaitResult(true) 131 } 132 } 133 } 134 135 impl<R> Barrier<R> { 136 /// Creates a new barrier that can block a given number of threads. 137 /// 138 /// A barrier will block `n`-1 threads which call [`wait`] and then wake up 139 /// all threads at once when the `n`th thread calls [`wait`]. A Barrier created 140 /// with n = 0 will behave identically to one created with n = 1. 141 /// 142 /// [`wait`]: #method.wait 143 /// 144 /// # Examples 145 /// 146 /// ``` 147 /// use spin; 148 /// 149 /// let barrier = spin::Barrier::new(10); 150 /// ``` new(n: usize) -> Self151 pub const fn new(n: usize) -> Self { 152 Self { 153 lock: Mutex::new(BarrierState { 154 count: 0, 155 generation_id: 0, 156 }), 157 num_threads: n, 158 } 159 } 160 } 161 162 impl BarrierWaitResult { 163 /// Returns whether this thread from [`wait`] is the "leader thread". 164 /// 165 /// Only one thread will have `true` returned from their result, all other 166 /// threads will have `false` returned. 167 /// 168 /// [`wait`]: struct.Barrier.html#method.wait 169 /// 170 /// # Examples 171 /// 172 /// ``` 173 /// use spin; 174 /// 175 /// let barrier = spin::Barrier::new(1); 176 /// let barrier_wait_result = barrier.wait(); 177 /// println!("{:?}", barrier_wait_result.is_leader()); 178 /// ``` is_leader(&self) -> bool179 pub fn is_leader(&self) -> bool { self.0 } 180 } 181 182 #[cfg(test)] 183 mod tests { 184 use std::prelude::v1::*; 185 186 use std::sync::mpsc::{channel, TryRecvError}; 187 use std::sync::Arc; 188 use std::thread; 189 190 type Barrier = super::Barrier; 191 use_barrier(n: usize, barrier: Arc<Barrier>)192 fn use_barrier(n: usize, barrier: Arc<Barrier>) { 193 let (tx, rx) = channel(); 194 195 for _ in 0..n - 1 { 196 let c = barrier.clone(); 197 let tx = tx.clone(); 198 thread::spawn(move|| { 199 tx.send(c.wait().is_leader()).unwrap(); 200 }); 201 } 202 203 // At this point, all spawned threads should be blocked, 204 // so we shouldn't get anything from the port 205 assert!(match rx.try_recv() { 206 Err(TryRecvError::Empty) => true, 207 _ => false, 208 }); 209 210 let mut leader_found = barrier.wait().is_leader(); 211 212 // Now, the barrier is cleared and we should get data. 213 for _ in 0..n - 1 { 214 if rx.recv().unwrap() { 215 assert!(!leader_found); 216 leader_found = true; 217 } 218 } 219 assert!(leader_found); 220 } 221 222 #[test] test_barrier()223 fn test_barrier() { 224 const N: usize = 10; 225 226 let barrier = Arc::new(Barrier::new(N)); 227 228 use_barrier(N, barrier.clone()); 229 230 // use barrier twice to ensure it is reusable 231 use_barrier(N, barrier.clone()); 232 } 233 } 234