• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::loom::sync::Mutex;
2 use crate::sync::watch;
3 #[cfg(all(tokio_unstable, feature = "tracing"))]
4 use crate::util::trace;
5 
6 /// A barrier enables multiple tasks to synchronize the beginning of some computation.
7 ///
8 /// ```
9 /// # #[tokio::main]
10 /// # async fn main() {
11 /// use tokio::sync::Barrier;
12 /// use std::sync::Arc;
13 ///
14 /// let mut handles = Vec::with_capacity(10);
15 /// let barrier = Arc::new(Barrier::new(10));
16 /// for _ in 0..10 {
17 ///     let c = barrier.clone();
18 ///     // The same messages will be printed together.
19 ///     // You will NOT see any interleaving.
20 ///     handles.push(tokio::spawn(async move {
21 ///         println!("before wait");
22 ///         let wait_result = c.wait().await;
23 ///         println!("after wait");
24 ///         wait_result
25 ///     }));
26 /// }
27 ///
28 /// // Will not resolve until all "after wait" messages have been printed
29 /// let mut num_leaders = 0;
30 /// for handle in handles {
31 ///     let wait_result = handle.await.unwrap();
32 ///     if wait_result.is_leader() {
33 ///         num_leaders += 1;
34 ///     }
35 /// }
36 ///
37 /// // Exactly one barrier will resolve as the "leader"
38 /// assert_eq!(num_leaders, 1);
39 /// # }
40 /// ```
41 #[derive(Debug)]
42 pub struct Barrier {
43     state: Mutex<BarrierState>,
44     wait: watch::Receiver<usize>,
45     n: usize,
46     #[cfg(all(tokio_unstable, feature = "tracing"))]
47     resource_span: tracing::Span,
48 }
49 
50 #[derive(Debug)]
51 struct BarrierState {
52     waker: watch::Sender<usize>,
53     arrived: usize,
54     generation: usize,
55 }
56 
57 impl Barrier {
58     /// Creates a new barrier that can block a given number of tasks.
59     ///
60     /// A barrier will block `n`-1 tasks which call [`Barrier::wait`] and then wake up all
61     /// tasks at once when the `n`th task calls `wait`.
62     #[track_caller]
new(mut n: usize) -> Barrier63     pub fn new(mut n: usize) -> Barrier {
64         let (waker, wait) = crate::sync::watch::channel(0);
65 
66         if n == 0 {
67             // if n is 0, it's not clear what behavior the user wants.
68             // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every
69             // .wait() immediately unblocks, so we adopt that here as well.
70             n = 1;
71         }
72 
73         #[cfg(all(tokio_unstable, feature = "tracing"))]
74         let resource_span = {
75             let location = std::panic::Location::caller();
76             let resource_span = tracing::trace_span!(
77                 "runtime.resource",
78                 concrete_type = "Barrier",
79                 kind = "Sync",
80                 loc.file = location.file(),
81                 loc.line = location.line(),
82                 loc.col = location.column(),
83             );
84 
85             resource_span.in_scope(|| {
86                 tracing::trace!(
87                     target: "runtime::resource::state_update",
88                     size = n,
89                 );
90 
91                 tracing::trace!(
92                     target: "runtime::resource::state_update",
93                     arrived = 0,
94                 )
95             });
96             resource_span
97         };
98 
99         Barrier {
100             state: Mutex::new(BarrierState {
101                 waker,
102                 arrived: 0,
103                 generation: 1,
104             }),
105             n,
106             wait,
107             #[cfg(all(tokio_unstable, feature = "tracing"))]
108             resource_span,
109         }
110     }
111 
112     /// Does not resolve until all tasks have rendezvoused here.
113     ///
114     /// Barriers are re-usable after all tasks have rendezvoused once, and can
115     /// be used continuously.
116     ///
117     /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from
118     /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other tasks
119     /// will receive a result that will return `false` from `is_leader`.
wait(&self) -> BarrierWaitResult120     pub async fn wait(&self) -> BarrierWaitResult {
121         #[cfg(all(tokio_unstable, feature = "tracing"))]
122         return trace::async_op(
123             || self.wait_internal(),
124             self.resource_span.clone(),
125             "Barrier::wait",
126             "poll",
127             false,
128         )
129         .await;
130 
131         #[cfg(any(not(tokio_unstable), not(feature = "tracing")))]
132         return self.wait_internal().await;
133     }
wait_internal(&self) -> BarrierWaitResult134     async fn wait_internal(&self) -> BarrierWaitResult {
135         crate::trace::async_trace_leaf().await;
136 
137         // NOTE: we are taking a _synchronous_ lock here.
138         // It is okay to do so because the critical section is fast and never yields, so it cannot
139         // deadlock even if another future is concurrently holding the lock.
140         // It is _desirable_ to do so as synchronous Mutexes are, at least in theory, faster than
141         // the asynchronous counter-parts, so we should use them where possible [citation needed].
142         // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across
143         // a yield point, and thus marks the returned future as !Send.
144         let generation = {
145             let mut state = self.state.lock();
146             let generation = state.generation;
147             state.arrived += 1;
148             #[cfg(all(tokio_unstable, feature = "tracing"))]
149             tracing::trace!(
150                 target: "runtime::resource::state_update",
151                 arrived = 1,
152                 arrived.op = "add",
153             );
154             #[cfg(all(tokio_unstable, feature = "tracing"))]
155             tracing::trace!(
156                 target: "runtime::resource::async_op::state_update",
157                 arrived = true,
158             );
159             if state.arrived == self.n {
160                 #[cfg(all(tokio_unstable, feature = "tracing"))]
161                 tracing::trace!(
162                     target: "runtime::resource::async_op::state_update",
163                     is_leader = true,
164                 );
165                 // we are the leader for this generation
166                 // wake everyone, increment the generation, and return
167                 state
168                     .waker
169                     .send(state.generation)
170                     .expect("there is at least one receiver");
171                 state.arrived = 0;
172                 state.generation += 1;
173                 return BarrierWaitResult(true);
174             }
175 
176             generation
177         };
178 
179         // we're going to have to wait for the last of the generation to arrive
180         let mut wait = self.wait.clone();
181 
182         loop {
183             let _ = wait.changed().await;
184 
185             // note that the first time through the loop, this _will_ yield a generation
186             // immediately, since we cloned a receiver that has never seen any values.
187             if *wait.borrow() >= generation {
188                 break;
189             }
190         }
191 
192         BarrierWaitResult(false)
193     }
194 }
195 
/// The outcome handed back by `wait` once every task in the `Barrier` has rendezvoused.
///
/// Of all the results produced for a single rendezvous, exactly one identifies its task as
/// the leader.
#[derive(Debug, Clone)]
pub struct BarrierWaitResult(bool);

impl BarrierWaitResult {
    /// Tells whether the task holding this result is the "leader task" of its rendezvous.
    ///
    /// For any one rendezvous, a single task sees `true` here and every other task sees
    /// `false`.
    pub fn is_leader(&self) -> bool {
        let Self(leader) = self;
        *leader
    }
}
209