use crate::job::{JobFifo, JobRef, StackJob};
use crate::latch::{AsCoreLatch, CoreLatch, CountLatch, Latch, LatchRef, LockLatch, SpinLatch};
use crate::log::Event::*;
use crate::log::Logger;
use crate::sleep::Sleep;
use crate::unwind;
use crate::{
    ErrorKind, ExitHandler, PanicHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder,
    Yield,
};
use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use std::cell::Cell;
use std::collections::hash_map::DefaultHasher;
use std::fmt;
use std::hash::Hasher;
use std::io;
use std::mem;
use std::ptr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, Once};
use std::thread;
use std::usize;

/// Thread builder used for customization via
/// [`ThreadPoolBuilder::spawn_handler`](struct.ThreadPoolBuilder.html#method.spawn_handler).
pub struct ThreadBuilder {
    name: Option<String>,
    stack_size: Option<usize>,
    worker: Worker<JobRef>,
    stealer: Stealer<JobRef>,
    registry: Arc<Registry>,
    index: usize,
}

impl ThreadBuilder {
    /// Gets the index of this thread in the pool, within `0..num_threads`.
    pub fn index(&self) -> usize {
        self.index
    }

    /// Gets the string that was specified by `ThreadPoolBuilder::name()`.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Gets the value that was specified by `ThreadPoolBuilder::stack_size()`.
    pub fn stack_size(&self) -> Option<usize> {
        self.stack_size
    }

    /// Executes the main loop for this thread. This will not return until the
    /// thread pool is dropped.
    pub fn run(self) {
        unsafe { main_loop(self) }
    }
}
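
// For orientation: a user-provided `spawn_handler` receives each
// `ThreadBuilder` and is expected to call `run()` on it from whatever thread
// it creates. A minimal sketch (illustrative only, using the public
// `ThreadPoolBuilder::spawn_handler` API; it mirrors `DefaultSpawn` below):
//
//     let pool = rayon::ThreadPoolBuilder::new()
//         .spawn_handler(|thread| {
//             std::thread::spawn(move || thread.run());
//             Ok(())
//         })
//         .build()?;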

impl fmt::Debug for ThreadBuilder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ThreadBuilder")
            .field("pool", &self.registry.id())
            .field("index", &self.index)
            .field("name", &self.name)
            .field("stack_size", &self.stack_size)
            .finish()
    }
}

/// Generalized trait for spawning a thread in the `Registry`.
///
/// This trait is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.
pub trait ThreadSpawn {
    private_decl! {}

    /// Spawn a thread with the `ThreadBuilder` parameters, and then
    /// call `ThreadBuilder::run()`.
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()>;
}

/// Spawns a thread in the "normal" way with `std::thread::Builder`.
///
/// This type is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.
#[derive(Debug, Default)]
pub struct DefaultSpawn;

impl ThreadSpawn for DefaultSpawn {
    private_impl! {}

    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        let mut b = thread::Builder::new();
        if let Some(name) = thread.name() {
            b = b.name(name.to_owned());
        }
        if let Some(stack_size) = thread.stack_size() {
            b = b.stack_size(stack_size);
        }
        b.spawn(|| thread.run())?;
        Ok(())
    }
}

/// Spawns a thread with a user's custom callback.
///
/// This type is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.
#[derive(Debug)]
pub struct CustomSpawn<F>(F);

impl<F> CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    pub(super) fn new(spawn: F) -> Self {
        CustomSpawn(spawn)
    }
}

impl<F> ThreadSpawn for CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    private_impl! {}

    #[inline]
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        (self.0)(thread)
    }
}

pub(super) struct Registry {
    logger: Logger,
    thread_infos: Vec<ThreadInfo>,
    sleep: Sleep,
    injected_jobs: Injector<JobRef>,
    broadcasts: Mutex<Vec<Worker<JobRef>>>,
    panic_handler: Option<Box<PanicHandler>>,
    start_handler: Option<Box<StartHandler>>,
    exit_handler: Option<Box<ExitHandler>>,

    // When this latch reaches 0, it means that all work on this
    // registry must be complete. This is ensured in the following ways:
    //
    // - if this is the global registry, there is a ref-count that never
    //   gets released.
    // - if this is a user-created thread-pool, then so long as the thread-pool
    //   exists, it holds a reference.
    // - when we inject a "blocking job" into the registry with `ThreadPool::install()`,
    //   no adjustment is needed; the `ThreadPool` holds the reference, and since we won't
    //   return until the blocking job is complete, that ref will continue to be held.
    // - when `join()` or `scope()` is invoked, similarly, no adjustments are needed.
    //   These are always owned by some other job (e.g., one injected by `ThreadPool::install()`)
    //   and that job will keep the pool alive.
    terminate_count: AtomicUsize,
}

/// ////////////////////////////////////////////////////////////////////////
/// Initialization

static mut THE_REGISTRY: Option<Arc<Registry>> = None;
static THE_REGISTRY_SET: Once = Once::new();

/// Starts the worker threads (if that has not already happened). If
/// initialization has not already occurred, use the default
/// configuration.
pub(super) fn global_registry() -> &'static Arc<Registry> {
    set_global_registry(default_global_registry)
        .or_else(|err| unsafe { THE_REGISTRY.as_ref().ok_or(err) })
        .expect("The global thread pool has not been initialized.")
}

/// Starts the worker threads (if that has not already happened) with
/// the given builder.
pub(super) fn init_global_registry<S>(
    builder: ThreadPoolBuilder<S>,
) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    S: ThreadSpawn,
{
    set_global_registry(|| Registry::new(builder))
}

/// Starts the worker threads (if that has not already happened)
/// by creating a registry with the given callback.
fn set_global_registry<F>(registry: F) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    F: FnOnce() -> Result<Arc<Registry>, ThreadPoolBuildError>,
{
    let mut result = Err(ThreadPoolBuildError::new(
        ErrorKind::GlobalPoolAlreadyInitialized,
    ));

    THE_REGISTRY_SET.call_once(|| {
        result = registry()
            .map(|registry: Arc<Registry>| unsafe { &*THE_REGISTRY.get_or_insert(registry) })
    });

    result
}
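
// For reference, both initialization paths above are reached through the
// public API; a sketch (illustrative only):
//
//     // explicit: goes through `init_global_registry`
//     rayon::ThreadPoolBuilder::new().num_threads(4).build_global()?;
//
//     // implicit: the first use of the pool (e.g. `rayon::join`) reaches
//     // `global_registry`, which falls back to the default configuration.
//     rayon::join(|| 1, || 2);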

fn default_global_registry() -> Result<Arc<Registry>, ThreadPoolBuildError> {
    let result = Registry::new(ThreadPoolBuilder::new());

    // If we're running in an environment that doesn't support threads at all, we can fall back to
    // using the current thread alone. This is crude, and probably won't work for non-blocking
    // calls like `spawn` or `spawn_broadcast`, but a lot of stuff does work fine.
    //
    // Notably, this allows current WebAssembly targets to work even though their threading support
    // is stubbed out, and we won't have to change anything if they do add real threading.
    let unsupported = matches!(&result, Err(e) if e.is_unsupported());
    if unsupported && WorkerThread::current().is_null() {
        let builder = ThreadPoolBuilder::new()
            .num_threads(1)
            .spawn_handler(|thread| {
                // Rather than starting a new thread, we're just taking over the current thread
                // *without* running the main loop, so we can still return from here.
                // The WorkerThread is leaked, but we never shut down the global pool anyway.
                let worker_thread = Box::leak(Box::new(WorkerThread::from(thread)));
                let registry = &*worker_thread.registry;
                let index = worker_thread.index;

                unsafe {
                    WorkerThread::set_current(worker_thread);

                    // let registry know we are ready to do work
                    Latch::set(&registry.thread_infos[index].primed);
                }

                Ok(())
            });

        let fallback_result = Registry::new(builder);
        if fallback_result.is_ok() {
            return fallback_result;
        }
    }

    result
}

struct Terminator<'a>(&'a Arc<Registry>);

impl<'a> Drop for Terminator<'a> {
    fn drop(&mut self) {
        self.0.terminate()
    }
}

impl Registry {
    pub(super) fn new<S>(
        mut builder: ThreadPoolBuilder<S>,
    ) -> Result<Arc<Self>, ThreadPoolBuildError>
    where
        S: ThreadSpawn,
    {
        // Soft-limit the number of threads that we can actually support.
        let n_threads = Ord::min(builder.get_num_threads(), crate::max_num_threads());

        let breadth_first = builder.get_breadth_first();

        let (workers, stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = if breadth_first {
                    Worker::new_fifo()
                } else {
                    Worker::new_lifo()
                };

                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();

        let (broadcasts, broadcast_stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = Worker::new_fifo();
                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();

        let logger = Logger::new(n_threads);
        let registry = Arc::new(Registry {
            logger: logger.clone(),
            thread_infos: stealers.into_iter().map(ThreadInfo::new).collect(),
            sleep: Sleep::new(logger, n_threads),
            injected_jobs: Injector::new(),
            broadcasts: Mutex::new(broadcasts),
            terminate_count: AtomicUsize::new(1),
            panic_handler: builder.take_panic_handler(),
            start_handler: builder.take_start_handler(),
            exit_handler: builder.take_exit_handler(),
        });

        // If we return early or panic, make sure to terminate existing threads.
        let t1000 = Terminator(&registry);

        for (index, (worker, stealer)) in workers.into_iter().zip(broadcast_stealers).enumerate() {
            let thread = ThreadBuilder {
                name: builder.get_thread_name(index),
                stack_size: builder.get_stack_size(),
                registry: Arc::clone(&registry),
                worker,
                stealer,
                index,
            };
            if let Err(e) = builder.get_spawn_handler().spawn(thread) {
                return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e)));
            }
        }

        // Returning normally now, without termination.
        mem::forget(t1000);

        Ok(registry)
    }

    pub(super) fn current() -> Arc<Registry> {
        unsafe {
            let worker_thread = WorkerThread::current();
            let registry = if worker_thread.is_null() {
                global_registry()
            } else {
                &(*worker_thread).registry
            };
            Arc::clone(registry)
        }
    }

    /// Returns the number of threads in the current registry.  This
    /// is better than `Registry::current().num_threads()` because it
    /// avoids incrementing the `Arc`.
    pub(super) fn current_num_threads() -> usize {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                global_registry().num_threads()
            } else {
                (*worker_thread).registry.num_threads()
            }
        }
    }

    /// Returns the current `WorkerThread` if it's part of this `Registry`.
    pub(super) fn current_thread(&self) -> Option<&WorkerThread> {
        unsafe {
            let worker = WorkerThread::current().as_ref()?;
            if worker.registry().id() == self.id() {
                Some(worker)
            } else {
                None
            }
        }
    }

    /// Returns an opaque identifier for this registry.
    pub(super) fn id(&self) -> RegistryId {
        // We can rely on `self` not to change since we only ever create
        // registries that are boxed up in an `Arc` (see `new()` above).
        RegistryId {
            addr: self as *const Self as usize,
        }
    }

    #[inline]
    pub(super) fn log(&self, event: impl FnOnce() -> crate::log::Event) {
        self.logger.log(event)
    }

    pub(super) fn num_threads(&self) -> usize {
        self.thread_infos.len()
    }

    pub(super) fn catch_unwind(&self, f: impl FnOnce()) {
        if let Err(err) = unwind::halt_unwinding(f) {
            // If there is no handler, or if that handler itself panics, then we abort.
            let abort_guard = unwind::AbortIfPanic;
            if let Some(ref handler) = self.panic_handler {
                handler(err);
                mem::forget(abort_guard);
            }
        }
    }

    /// Waits for the worker threads to get up and running.  This is
    /// meant to be used for benchmarking purposes, primarily, so that
    /// you can get more consistent numbers by having everything
    /// "ready to go".
    pub(super) fn wait_until_primed(&self) {
        for info in &self.thread_infos {
            info.primed.wait();
        }
    }

    /// Waits for the worker threads to stop. This is used for testing
    /// -- so we can check that termination actually works.
    #[cfg(test)]
    pub(super) fn wait_until_stopped(&self) {
        for info in &self.thread_infos {
            info.stopped.wait();
        }
    }

    /// ////////////////////////////////////////////////////////////////////////
    /// MAIN LOOP
    ///
    /// So long as all of the worker threads are hanging out in their
    /// top-level loop, there is no work to be done.

    /// Push a job into the given `registry`. If we are running on a
    /// worker thread for the registry, this will push onto the
    /// deque. Else, it will inject from the outside (which is slower).
    pub(super) fn inject_or_push(&self, job_ref: JobRef) {
        let worker_thread = WorkerThread::current();
        unsafe {
            if !worker_thread.is_null() && (*worker_thread).registry().id() == self.id() {
                (*worker_thread).push(job_ref);
            } else {
                self.inject(job_ref);
            }
        }
    }

    /// Push a job into the "external jobs" queue; it will be taken by
    /// whatever worker has nothing to do. Use this if you know that
    /// you are not on a worker of this registry.
    pub(super) fn inject(&self, injected_job: JobRef) {
        self.log(|| JobsInjected { count: 1 });

        // It should not be possible for `state.terminate` to be true
        // here. It is only set to true when the user creates (and
        // drops) a `ThreadPool`; and, in that case, they cannot be
        // calling `inject()` later, since they dropped their
        // `ThreadPool`.
        debug_assert_ne!(
            self.terminate_count.load(Ordering::Acquire),
            0,
            "inject() sees state.terminate as true"
        );

        let queue_was_empty = self.injected_jobs.is_empty();

        self.injected_jobs.push(injected_job);
        self.sleep.new_injected_jobs(usize::MAX, 1, queue_was_empty);
    }

    fn has_injected_job(&self) -> bool {
        !self.injected_jobs.is_empty()
    }

    fn pop_injected_job(&self, worker_index: usize) -> Option<JobRef> {
        loop {
            match self.injected_jobs.steal() {
                Steal::Success(job) => {
                    self.log(|| JobUninjected {
                        worker: worker_index,
                    });
                    return Some(job);
                }
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

    /// Push a job into each thread's own "external jobs" queue; it will be
    /// executed only on that thread, when it has nothing else to do locally,
    /// before it tries to steal other work.
    ///
    /// **Panics** if not given exactly as many jobs as there are threads.
    pub(super) fn inject_broadcast(&self, injected_jobs: impl ExactSizeIterator<Item = JobRef>) {
        assert_eq!(self.num_threads(), injected_jobs.len());
        self.log(|| JobBroadcast {
            count: self.num_threads(),
        });
        {
            let broadcasts = self.broadcasts.lock().unwrap();

            // It should not be possible for `state.terminate` to be true
            // here. It is only set to true when the user creates (and
            // drops) a `ThreadPool`; and, in that case, they cannot be
            // calling `inject_broadcast()` later, since they dropped their
            // `ThreadPool`.
            debug_assert_ne!(
                self.terminate_count.load(Ordering::Acquire),
                0,
                "inject_broadcast() sees state.terminate as true"
            );

            assert_eq!(broadcasts.len(), injected_jobs.len());
            for (worker, job_ref) in broadcasts.iter().zip(injected_jobs) {
                worker.push(job_ref);
            }
        }
        for i in 0..self.num_threads() {
            self.sleep.notify_worker_latch_is_set(i);
        }
    }
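
    // For orientation, `inject_broadcast` is the mechanism behind the public
    // `rayon::broadcast` / `rayon::spawn_broadcast` entry points; a sketch of
    // the user-facing side (illustrative only):
    //
    //     // runs the closure once on every worker thread in the global pool
    //     let indices: Vec<usize> = rayon::broadcast(|ctx| ctx.index());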

    /// If already in a worker-thread of this registry, just execute `op`.
    /// Otherwise, inject `op` in this thread-pool. Either way, block until `op`
    /// completes and return its return value. If `op` panics, that panic will
    /// be propagated as well.  The second argument indicates `true` if injection
    /// was performed, `false` if executed directly.
    pub(super) fn in_worker<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                self.in_worker_cold(op)
            } else if (*worker_thread).registry().id() != self.id() {
                self.in_worker_cross(&*worker_thread, op)
            } else {
                // Perfectly valid to give them a `&T`: this is the
                // current thread, so we know the data structure won't be
                // invalidated until we return.
                op(&*worker_thread, false)
            }
        }
    }

    #[cold]
    unsafe fn in_worker_cold<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        thread_local!(static LOCK_LATCH: LockLatch = LockLatch::new());

        LOCK_LATCH.with(|l| {
            // This thread isn't a member of *any* thread pool, so just block.
            debug_assert!(WorkerThread::current().is_null());
            let job = StackJob::new(
                |injected| {
                    let worker_thread = WorkerThread::current();
                    assert!(injected && !worker_thread.is_null());
                    op(&*worker_thread, true)
                },
                LatchRef::new(l),
            );
            self.inject(job.as_job_ref());
            job.latch.wait_and_reset(); // Make sure we can use the same latch again next time.

            // flush accumulated logs as we exit the thread
            self.logger.log(|| Flush);

            job.into_result()
        })
    }

    #[cold]
    unsafe fn in_worker_cross<OP, R>(&self, current_thread: &WorkerThread, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        // This thread is a member of a different pool, so let it process
        // other work while waiting for this `op` to complete.
        debug_assert!(current_thread.registry().id() != self.id());
        let latch = SpinLatch::cross(current_thread);
        let job = StackJob::new(
            |injected| {
                let worker_thread = WorkerThread::current();
                assert!(injected && !worker_thread.is_null());
                op(&*worker_thread, true)
            },
            latch,
        );
        self.inject(job.as_job_ref());
        current_thread.wait_until(&job.latch);
        job.into_result()
    }

    /// Increments the terminate counter. This increment should be
    /// balanced by a call to `terminate`, which will decrement. This
    /// is used when spawning asynchronous work, which needs to
    /// prevent the registry from terminating so long as it is active.
    ///
    /// Note that blocking functions such as `join` and `scope` do not
    /// need to concern themselves with this fn; their context is
    /// responsible for ensuring the current thread-pool will not
    /// terminate until they return.
    ///
    /// The global thread-pool always has an outstanding reference
    /// (the initial one). Custom thread-pools have one outstanding
    /// reference that is dropped when the `ThreadPool` is dropped:
    /// since installing the thread-pool blocks until any joins/scopes
    /// complete, this ensures that joins/scopes are covered.
    ///
    /// The exception is `::spawn()`, which can create a job outside
    /// of any blocking scope. In that case, the job itself holds a
    /// terminate count and is responsible for invoking `terminate()`
    /// when finished.
    pub(super) fn increment_terminate_count(&self) {
        let previous = self.terminate_count.fetch_add(1, Ordering::AcqRel);
        debug_assert!(previous != 0, "registry ref count incremented from zero");
        assert!(
            previous != std::usize::MAX,
            "overflow in registry ref count"
        );
    }
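
    // A sketch of the pairing described above, as it plays out for an
    // asynchronous `spawn` (illustrative only; the real call sites live in
    // the spawn implementation):
    //
    //     registry.increment_terminate_count(); // job created outside any scope
    //     // ... the job executes on some worker ...
    //     registry.terminate();                 // balanced decrement when done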

    /// Signals that the thread-pool which owns this registry has been
    /// dropped. The worker threads will gradually terminate, once any
    /// extant work is completed.
    pub(super) fn terminate(&self) {
        if self.terminate_count.fetch_sub(1, Ordering::AcqRel) == 1 {
            for (i, thread_info) in self.thread_infos.iter().enumerate() {
                unsafe { CountLatch::set_and_tickle_one(&thread_info.terminate, self, i) };
            }
        }
    }

    /// Notify the worker that the latch they are sleeping on has been "set".
    pub(super) fn notify_worker_latch_is_set(&self, target_worker_index: usize) {
        self.sleep.notify_worker_latch_is_set(target_worker_index);
    }
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(super) struct RegistryId {
    addr: usize,
}

struct ThreadInfo {
    /// Latch set once thread has started and we are entering into the
    /// main loop. Used to wait for worker threads to become primed,
    /// primarily of interest for benchmarking.
    primed: LockLatch,

    /// Latch is set once worker thread has completed. Used to wait
    /// until workers have stopped; only used for tests.
    stopped: LockLatch,

    /// The latch used to signal that termination has been requested.
    /// This latch is *set* by the `terminate` method on the
    /// `Registry`, once the registry's main "terminate" counter
    /// reaches zero.
    ///
    /// NB. We use a `CountLatch` here because it has no lifetimes and is
    /// meant for async use, but the count never gets higher than one.
    terminate: CountLatch,

    /// the "stealer" half of the worker's deque
    stealer: Stealer<JobRef>,
}

impl ThreadInfo {
    fn new(stealer: Stealer<JobRef>) -> ThreadInfo {
        ThreadInfo {
            primed: LockLatch::new(),
            stopped: LockLatch::new(),
            terminate: CountLatch::new(),
            stealer,
        }
    }
}

/// ////////////////////////////////////////////////////////////////////////
/// WorkerThread identifiers

pub(super) struct WorkerThread {
    /// the "worker" half of our local deque
    worker: Worker<JobRef>,

    /// the "stealer" half of the worker's broadcast deque
    stealer: Stealer<JobRef>,

    /// local queue used for `spawn_fifo` indirection
    fifo: JobFifo,

    index: usize,

    /// A weak random number generator.
    rng: XorShift64Star,

    registry: Arc<Registry>,
}

// This is a bit sketchy, but basically: the WorkerThread is
// allocated on the stack of the worker on entry and stored into this
// thread local variable. So it will remain valid at least until the
// worker is fully unwound. Using an unsafe pointer avoids the need
// for a RefCell<T> etc.
thread_local! {
    static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const { Cell::new(ptr::null()) };
}

impl From<ThreadBuilder> for WorkerThread {
    fn from(thread: ThreadBuilder) -> Self {
        Self {
            worker: thread.worker,
            stealer: thread.stealer,
            fifo: JobFifo::new(),
            index: thread.index,
            rng: XorShift64Star::new(),
            registry: thread.registry,
        }
    }
}

impl Drop for WorkerThread {
    fn drop(&mut self) {
        // Undo `set_current`
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().eq(&(self as *const _)));
            t.set(ptr::null());
        });
    }
}

impl WorkerThread {
    /// Gets a pointer to the `WorkerThread` for the current thread; returns
    /// null if this is not a worker thread. This pointer is valid
    /// anywhere on the current thread.
    #[inline]
    pub(super) fn current() -> *const WorkerThread {
        WORKER_THREAD_STATE.with(Cell::get)
    }

    /// Sets `self` as the worker thread for the current thread.
    /// This is done during worker thread startup.
    unsafe fn set_current(thread: *const WorkerThread) {
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().is_null());
            t.set(thread);
        });
    }

    /// Returns the registry that owns this worker thread.
    #[inline]
    pub(super) fn registry(&self) -> &Arc<Registry> {
        &self.registry
    }

    #[inline]
    pub(super) fn log(&self, event: impl FnOnce() -> crate::log::Event) {
        self.registry.logger.log(event)
    }

    /// Our index amongst the worker threads (ranges from `0..self.num_threads()`).
    #[inline]
    pub(super) fn index(&self) -> usize {
        self.index
    }

    #[inline]
    pub(super) unsafe fn push(&self, job: JobRef) {
        self.log(|| JobPushed { worker: self.index });
        let queue_was_empty = self.worker.is_empty();
        self.worker.push(job);
        self.registry
            .sleep
            .new_internal_jobs(self.index, 1, queue_was_empty);
    }

    #[inline]
    pub(super) unsafe fn push_fifo(&self, job: JobRef) {
        self.push(self.fifo.push(job));
    }

    #[inline]
    pub(super) fn local_deque_is_empty(&self) -> bool {
        self.worker.is_empty()
    }

    /// Attempts to obtain a "local" job -- typically this means
    /// popping from the top of the stack, though if we are configured
    /// for breadth-first execution, it would mean dequeuing from the
    /// bottom.
    #[inline]
    pub(super) fn take_local_job(&self) -> Option<JobRef> {
        let popped_job = self.worker.pop();

        if popped_job.is_some() {
            self.log(|| JobPopped { worker: self.index });
            return popped_job;
        }

        loop {
            match self.stealer.steal() {
                Steal::Success(job) => return Some(job),
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

    fn has_injected_job(&self) -> bool {
        !self.stealer.is_empty() || self.registry.has_injected_job()
    }

    /// Wait until the latch is set. Try to keep busy by popping and
    /// stealing tasks as necessary.
    #[inline]
    pub(super) unsafe fn wait_until<L: AsCoreLatch + ?Sized>(&self, latch: &L) {
        let latch = latch.as_core_latch();
        if !latch.probe() {
            self.wait_until_cold(latch);
        }
    }

    #[cold]
    unsafe fn wait_until_cold(&self, latch: &CoreLatch) {
        // the code below should swallow all panics and hence never
        // unwind; but if something goes wrong, we want to abort,
        // because otherwise other code in rayon may assume that the
        // latch has been signaled, and that can lead to random memory
        // accesses, which would be *very bad*
        let abort_guard = unwind::AbortIfPanic;

        let mut idle_state = self.registry.sleep.start_looking(self.index, latch);
        while !latch.probe() {
            if let Some(job) = self.find_work() {
                self.registry.sleep.work_found(idle_state);
                self.execute(job);
                idle_state = self.registry.sleep.start_looking(self.index, latch);
            } else {
                self.registry
                    .sleep
                    .no_work_found(&mut idle_state, latch, || self.has_injected_job())
            }
        }

        // If we were sleepy, we are not anymore. We "found work" --
        // whatever the surrounding thread was doing before it had to
        // wait.
        self.registry.sleep.work_found(idle_state);

        self.log(|| ThreadSawLatchSet {
            worker: self.index,
            latch_addr: latch.addr(),
        });
        mem::forget(abort_guard); // successful execution, do not abort
    }

    fn find_work(&self) -> Option<JobRef> {
        // Try to find some work to do. We give preference first
        // to things in our local deque, then in other workers'
        // deques, and finally to injected jobs from the
        // outside. The idea is to finish what we started before
        // we take on something new.
        self.take_local_job()
            .or_else(|| self.steal())
            .or_else(|| self.registry.pop_injected_job(self.index))
    }

    pub(super) fn yield_now(&self) -> Yield {
        match self.find_work() {
            Some(job) => unsafe {
                self.execute(job);
                Yield::Executed
            },
            None => Yield::Idle,
        }
    }

    pub(super) fn yield_local(&self) -> Yield {
        match self.take_local_job() {
            Some(job) => unsafe {
                self.execute(job);
                Yield::Executed
            },
            None => Yield::Idle,
        }
    }

    #[inline]
    pub(super) unsafe fn execute(&self, job: JobRef) {
        job.execute();
    }

    /// Try to steal a single job and return it.
    ///
    /// This should only be done as a last resort, when there is no
    /// local work to do.
    fn steal(&self) -> Option<JobRef> {
        // we only steal when we don't have any work to do locally
        debug_assert!(self.local_deque_is_empty());

        // otherwise, try to steal
        let thread_infos = &self.registry.thread_infos.as_slice();
        let num_threads = thread_infos.len();
        if num_threads <= 1 {
            return None;
        }

        loop {
            let mut retry = false;
            let start = self.rng.next_usize(num_threads);
            let job = (start..num_threads)
                .chain(0..start)
                .filter(move |&i| i != self.index)
                .find_map(|victim_index| {
                    let victim = &thread_infos[victim_index];
                    match victim.stealer.steal() {
                        Steal::Success(job) => {
                            self.log(|| JobStolen {
                                worker: self.index,
                                victim: victim_index,
                            });
                            Some(job)
                        }
                        Steal::Empty => None,
                        Steal::Retry => {
                            retry = true;
                            None
                        }
                    }
                });
            if job.is_some() || !retry {
                return job;
            }
        }
    }
}

/// ////////////////////////////////////////////////////////////////////////

unsafe fn main_loop(thread: ThreadBuilder) {
    let worker_thread = &WorkerThread::from(thread);
    WorkerThread::set_current(worker_thread);
    let registry = &*worker_thread.registry;
    let index = worker_thread.index;

    // let registry know we are ready to do work
    Latch::set(&registry.thread_infos[index].primed);

    // Worker threads should not panic. If they do, just abort, as the
    // internal state of the threadpool is corrupted. Note that if
    // **user code** panics, we should catch that and redirect.
    let abort_guard = unwind::AbortIfPanic;

    // Inform a user callback that we started a thread.
    if let Some(ref handler) = registry.start_handler {
        registry.catch_unwind(|| handler(index));
    }

    let my_terminate_latch = &registry.thread_infos[index].terminate;
    worker_thread.log(|| ThreadStart {
        worker: index,
        terminate_addr: my_terminate_latch.as_core_latch().addr(),
    });
    worker_thread.wait_until(my_terminate_latch);

    // Should not be any work left in our queue.
    debug_assert!(worker_thread.take_local_job().is_none());

    // let registry know we are done
    Latch::set(&registry.thread_infos[index].stopped);

    // Normal termination, do not abort.
    mem::forget(abort_guard);

    worker_thread.log(|| ThreadTerminate { worker: index });

    // Inform a user callback that we exited a thread.
    if let Some(ref handler) = registry.exit_handler {
        registry.catch_unwind(|| handler(index));
        // We're already exiting the thread, there's nothing else to do.
    }
}

/// If already in a worker-thread, just execute `op`.  Otherwise,
/// execute `op` in the default thread-pool. Either way, block until
/// `op` completes and return its return value. If `op` panics, that
/// panic will be propagated as well.  The second argument indicates
/// `true` if injection was performed, `false` if executed directly.
pub(super) fn in_worker<OP, R>(op: OP) -> R
where
    OP: FnOnce(&WorkerThread, bool) -> R + Send,
    R: Send,
{
    unsafe {
        let owner_thread = WorkerThread::current();
        if !owner_thread.is_null() {
            // Perfectly valid to give them a `&T`: this is the
            // current thread, so we know the data structure won't be
            // invalidated until we return.
            op(&*owner_thread, false)
        } else {
            global_registry().in_worker(op)
        }
    }
}

/// [xorshift*] is a fast pseudorandom number generator which will
/// even tolerate weak seeding, as long as it's not zero.
///
/// [xorshift*]: https://en.wikipedia.org/wiki/Xorshift#xorshift*
struct XorShift64Star {
    state: Cell<u64>,
}

impl XorShift64Star {
    fn new() -> Self {
        // Any non-zero seed will do -- this uses the hash of a global counter.
        let mut seed = 0;
        while seed == 0 {
            let mut hasher = DefaultHasher::new();
            static COUNTER: AtomicUsize = AtomicUsize::new(0);
            hasher.write_usize(COUNTER.fetch_add(1, Ordering::Relaxed));
            seed = hasher.finish();
        }

        XorShift64Star {
            state: Cell::new(seed),
        }
    }

    fn next(&self) -> u64 {
        let mut x = self.state.get();
        debug_assert_ne!(x, 0);
        x ^= x >> 12;
        x ^= x << 25;
        x ^= x >> 27;
        self.state.set(x);
        x.wrapping_mul(0x2545_f491_4f6c_dd1d)
    }

    /// Return a value from `0..n`.
    fn next_usize(&self, n: usize) -> usize {
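        // `%` has a slight bias toward smaller values unless `n` evenly
        // divides 2^64, but that is harmless for picking steal victims.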
        (self.next() % n as u64) as usize
    }
}