• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::job::{ArcJob, StackJob};
2 use crate::latch::{CountLatch, LatchRef};
3 use crate::registry::{Registry, WorkerThread};
4 use std::fmt;
5 use std::marker::PhantomData;
6 use std::sync::Arc;
7 
8 mod test;
9 
10 /// Executes `op` within every thread in the current threadpool. If this is
11 /// called from a non-Rayon thread, it will execute in the global threadpool.
12 /// Any attempts to use `join`, `scope`, or parallel iterators will then operate
13 /// within that threadpool. When the call has completed on each thread, returns
14 /// a vector containing all of their return values.
15 ///
16 /// For more information, see the [`ThreadPool::broadcast()`][m] method.
17 ///
18 /// [m]: struct.ThreadPool.html#method.broadcast
broadcast<OP, R>(op: OP) -> Vec<R> where OP: Fn(BroadcastContext<'_>) -> R + Sync, R: Send,19 pub fn broadcast<OP, R>(op: OP) -> Vec<R>
20 where
21     OP: Fn(BroadcastContext<'_>) -> R + Sync,
22     R: Send,
23 {
24     // We assert that current registry has not terminated.
25     unsafe { broadcast_in(op, &Registry::current()) }
26 }
27 
28 /// Spawns an asynchronous task on every thread in this thread-pool. This task
29 /// will run in the implicit, global scope, which means that it may outlast the
30 /// current stack frame -- therefore, it cannot capture any references onto the
31 /// stack (you will likely need a `move` closure).
32 ///
33 /// For more information, see the [`ThreadPool::spawn_broadcast()`][m] method.
34 ///
35 /// [m]: struct.ThreadPool.html#method.spawn_broadcast
spawn_broadcast<OP>(op: OP) where OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,36 pub fn spawn_broadcast<OP>(op: OP)
37 where
38     OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
39 {
40     // We assert that current registry has not terminated.
41     unsafe { spawn_broadcast_in(op, &Registry::current()) }
42 }
43 
44 /// Provides context to a closure called by `broadcast`.
45 pub struct BroadcastContext<'a> {
46     worker: &'a WorkerThread,
47 
48     /// Make sure to prevent auto-traits like `Send` and `Sync`.
49     _marker: PhantomData<&'a mut dyn Fn()>,
50 }
51 
52 impl<'a> BroadcastContext<'a> {
with<R>(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R53     pub(super) fn with<R>(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R {
54         let worker_thread = WorkerThread::current();
55         assert!(!worker_thread.is_null());
56         f(BroadcastContext {
57             worker: unsafe { &*worker_thread },
58             _marker: PhantomData,
59         })
60     }
61 
62     /// Our index amongst the broadcast threads (ranges from `0..self.num_threads()`).
63     #[inline]
index(&self) -> usize64     pub fn index(&self) -> usize {
65         self.worker.index()
66     }
67 
68     /// The number of threads receiving the broadcast in the thread pool.
69     ///
70     /// # Future compatibility note
71     ///
72     /// Future versions of Rayon might vary the number of threads over time, but
73     /// this method will always return the number of threads which are actually
74     /// receiving your particular `broadcast` call.
75     #[inline]
num_threads(&self) -> usize76     pub fn num_threads(&self) -> usize {
77         self.worker.registry().num_threads()
78     }
79 }
80 
81 impl<'a> fmt::Debug for BroadcastContext<'a> {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result82     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
83         fmt.debug_struct("BroadcastContext")
84             .field("index", &self.index())
85             .field("num_threads", &self.num_threads())
86             .field("pool_id", &self.worker.registry().id())
87             .finish()
88     }
89 }
90 
91 /// Execute `op` on every thread in the pool. It will be executed on each
92 /// thread when they have nothing else to do locally, before they try to
93 /// steal work from other threads. This function will not return until all
94 /// threads have completed the `op`.
95 ///
96 /// Unsafe because `registry` must not yet have terminated.
broadcast_in<OP, R>(op: OP, registry: &Arc<Registry>) -> Vec<R> where OP: Fn(BroadcastContext<'_>) -> R + Sync, R: Send,97 pub(super) unsafe fn broadcast_in<OP, R>(op: OP, registry: &Arc<Registry>) -> Vec<R>
98 where
99     OP: Fn(BroadcastContext<'_>) -> R + Sync,
100     R: Send,
101 {
102     let f = move |injected: bool| {
103         debug_assert!(injected);
104         BroadcastContext::with(&op)
105     };
106 
107     let n_threads = registry.num_threads();
108     let current_thread = WorkerThread::current().as_ref();
109     let latch = CountLatch::with_count(n_threads, current_thread);
110     let jobs: Vec<_> = (0..n_threads)
111         .map(|_| StackJob::new(&f, LatchRef::new(&latch)))
112         .collect();
113     let job_refs = jobs.iter().map(|job| job.as_job_ref());
114 
115     registry.inject_broadcast(job_refs);
116 
117     // Wait for all jobs to complete, then collect the results, maybe propagating a panic.
118     latch.wait(current_thread);
119     jobs.into_iter().map(|job| job.into_result()).collect()
120 }
121 
122 /// Execute `op` on every thread in the pool. It will be executed on each
123 /// thread when they have nothing else to do locally, before they try to
124 /// steal work from other threads. This function returns immediately after
125 /// injecting the jobs.
126 ///
127 /// Unsafe because `registry` must not yet have terminated.
spawn_broadcast_in<OP>(op: OP, registry: &Arc<Registry>) where OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,128 pub(super) unsafe fn spawn_broadcast_in<OP>(op: OP, registry: &Arc<Registry>)
129 where
130     OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
131 {
132     let job = ArcJob::new({
133         let registry = Arc::clone(registry);
134         move || {
135             registry.catch_unwind(|| BroadcastContext::with(&op));
136             registry.terminate(); // (*) permit registry to terminate now
137         }
138     });
139 
140     let n_threads = registry.num_threads();
141     let job_refs = (0..n_threads).map(|_| {
142         // Ensure that registry cannot terminate until this job has executed
143         // on each thread. This ref is decremented at the (*) above.
144         registry.increment_terminate_count();
145 
146         ArcJob::as_static_job_ref(&job)
147     });
148 
149     registry.inject_broadcast(job_refs);
150 }
151