• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::loom::sync::Mutex;
2 use crate::sync::watch;
3 
4 /// A barrier enables multiple tasks to synchronize the beginning of some computation.
5 ///
6 /// ```
7 /// # #[tokio::main]
8 /// # async fn main() {
9 /// use tokio::sync::Barrier;
10 /// use std::sync::Arc;
11 ///
12 /// let mut handles = Vec::with_capacity(10);
13 /// let barrier = Arc::new(Barrier::new(10));
14 /// for _ in 0..10 {
15 ///     let c = barrier.clone();
16 ///     // The same messages will be printed together.
17 ///     // You will NOT see any interleaving.
18 ///     handles.push(tokio::spawn(async move {
19 ///         println!("before wait");
20 ///         let wait_result = c.wait().await;
21 ///         println!("after wait");
22 ///         wait_result
23 ///     }));
24 /// }
25 ///
26 /// // Will not resolve until all "after wait" messages have been printed
27 /// let mut num_leaders = 0;
28 /// for handle in handles {
29 ///     let wait_result = handle.await.unwrap();
30 ///     if wait_result.is_leader() {
31 ///         num_leaders += 1;
32 ///     }
33 /// }
34 ///
35 /// // Exactly one barrier will resolve as the "leader"
36 /// assert_eq!(num_leaders, 1);
37 /// # }
38 /// ```
39 #[derive(Debug)]
40 pub struct Barrier {
41     state: Mutex<BarrierState>,
42     wait: watch::Receiver<usize>,
43     n: usize,
44 }
45 
46 #[derive(Debug)]
47 struct BarrierState {
48     waker: watch::Sender<usize>,
49     arrived: usize,
50     generation: usize,
51 }
52 
53 impl Barrier {
54     /// Creates a new barrier that can block a given number of tasks.
55     ///
56     /// A barrier will block `n`-1 tasks which call [`Barrier::wait`] and then wake up all
57     /// tasks at once when the `n`th task calls `wait`.
new(mut n: usize) -> Barrier58     pub fn new(mut n: usize) -> Barrier {
59         let (waker, wait) = crate::sync::watch::channel(0);
60 
61         if n == 0 {
62             // if n is 0, it's not clear what behavior the user wants.
63             // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every
64             // .wait() immediately unblocks, so we adopt that here as well.
65             n = 1;
66         }
67 
68         Barrier {
69             state: Mutex::new(BarrierState {
70                 waker,
71                 arrived: 0,
72                 generation: 1,
73             }),
74             n,
75             wait,
76         }
77     }
78 
79     /// Does not resolve until all tasks have rendezvoused here.
80     ///
81     /// Barriers are re-usable after all tasks have rendezvoused once, and can
82     /// be used continuously.
83     ///
84     /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from
85     /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other tasks
86     /// will receive a result that will return `false` from `is_leader`.
wait(&self) -> BarrierWaitResult87     pub async fn wait(&self) -> BarrierWaitResult {
88         // NOTE: we are taking a _synchronous_ lock here.
89         // It is okay to do so because the critical section is fast and never yields, so it cannot
90         // deadlock even if another future is concurrently holding the lock.
91         // It is _desireable_ to do so as synchronous Mutexes are, at least in theory, faster than
92         // the asynchronous counter-parts, so we should use them where possible [citation needed].
93         // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across
94         // a yield point, and thus marks the returned future as !Send.
95         let generation = {
96             let mut state = self.state.lock();
97             let generation = state.generation;
98             state.arrived += 1;
99             if state.arrived == self.n {
100                 // we are the leader for this generation
101                 // wake everyone, increment the generation, and return
102                 state
103                     .waker
104                     .send(state.generation)
105                     .expect("there is at least one receiver");
106                 state.arrived = 0;
107                 state.generation += 1;
108                 return BarrierWaitResult(true);
109             }
110 
111             generation
112         };
113 
114         // we're going to have to wait for the last of the generation to arrive
115         let mut wait = self.wait.clone();
116 
117         loop {
118             let _ = wait.changed().await;
119 
120             // note that the first time through the loop, this _will_ yield a generation
121             // immediately, since we cloned a receiver that has never seen any values.
122             if *wait.borrow() >= generation {
123                 break;
124             }
125         }
126 
127         BarrierWaitResult(false)
128     }
129 }
130 
/// A `BarrierWaitResult` is returned by `wait` when all tasks in the `Barrier` have rendezvoused.
///
/// It wraps a single flag recording whether the task that received it was the "leader" of
/// its generation.
#[derive(Debug, Clone)]
pub struct BarrierWaitResult(bool);

impl BarrierWaitResult {
    /// Returns `true` if this task from wait is the "leader task".
    ///
    /// Only one task will have `true` returned from their result, all other tasks will have
    /// `false` returned.
    pub fn is_leader(&self) -> bool {
        self.0
    }
}
144