• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! Thread pool for blocking operations
2 
3 use crate::loom::sync::{Arc, Condvar, Mutex};
4 use crate::loom::thread;
5 use crate::runtime::blocking::schedule::BlockingSchedule;
6 use crate::runtime::blocking::{shutdown, BlockingTask};
7 use crate::runtime::builder::ThreadNameFn;
8 use crate::runtime::task::{self, JoinHandle};
9 use crate::runtime::{Builder, Callback, Handle};
10 
11 use std::collections::{HashMap, VecDeque};
12 use std::fmt;
13 use std::io;
14 use std::sync::atomic::{AtomicUsize, Ordering};
15 use std::time::Duration;
16 
/// Owner of the blocking thread pool.
///
/// Holds the [`Spawner`] used to submit tasks together with the receiving
/// half of the shutdown channel that `shutdown` waits on.
pub(crate) struct BlockingPool {
    /// Handle through which tasks are submitted to the pool.
    spawner: Spawner,
    /// Receiving half of the shutdown channel created in `new`.
    shutdown_rx: shutdown::Receiver,
}
21 
/// Cloneable handle for submitting tasks to the blocking pool.
///
/// Every clone points at the same `Inner` state through an `Arc`.
#[derive(Clone)]
pub(crate) struct Spawner {
    /// Pool state shared by all clones of this spawner.
    inner: Arc<Inner>,
}
26 
/// Counters describing the state of the blocking pool.
///
/// All counters start at zero (via `Default`).
#[derive(Default)]
pub(crate) struct SpawnerMetrics {
    /// Number of worker threads currently accounted for.
    num_threads: AtomicUsize,
    /// Number of worker threads currently counted as idle.
    num_idle_threads: AtomicUsize,
    /// Number of tasks that have been queued and not yet popped.
    queue_depth: AtomicUsize,
}
33 
34 impl SpawnerMetrics {
num_threads(&self) -> usize35     fn num_threads(&self) -> usize {
36         self.num_threads.load(Ordering::Relaxed)
37     }
38 
num_idle_threads(&self) -> usize39     fn num_idle_threads(&self) -> usize {
40         self.num_idle_threads.load(Ordering::Relaxed)
41     }
42 
43     cfg_metrics! {
44         fn queue_depth(&self) -> usize {
45             self.queue_depth.load(Ordering::Relaxed)
46         }
47     }
48 
inc_num_threads(&self)49     fn inc_num_threads(&self) {
50         self.num_threads.fetch_add(1, Ordering::Relaxed);
51     }
52 
dec_num_threads(&self)53     fn dec_num_threads(&self) {
54         self.num_threads.fetch_sub(1, Ordering::Relaxed);
55     }
56 
inc_num_idle_threads(&self)57     fn inc_num_idle_threads(&self) {
58         self.num_idle_threads.fetch_add(1, Ordering::Relaxed);
59     }
60 
dec_num_idle_threads(&self) -> usize61     fn dec_num_idle_threads(&self) -> usize {
62         self.num_idle_threads.fetch_sub(1, Ordering::Relaxed)
63     }
64 
inc_queue_depth(&self)65     fn inc_queue_depth(&self) {
66         self.queue_depth.fetch_add(1, Ordering::Relaxed);
67     }
68 
dec_queue_depth(&self)69     fn dec_queue_depth(&self) {
70         self.queue_depth.fetch_sub(1, Ordering::Relaxed);
71     }
72 }
73 
/// State shared by the spawner handles and every worker thread of the pool.
struct Inner {
    /// State shared between worker threads.
    shared: Mutex<Shared>,

    /// Pool threads wait on this.
    condvar: Condvar,

    /// Spawned threads use this name.
    thread_name: ThreadNameFn,

    /// Spawned thread stack size.
    stack_size: Option<usize>,

    /// Call after a thread starts.
    after_start: Option<Callback>,

    /// Call before a thread stops.
    before_stop: Option<Callback>,

    /// Maximum number of threads.
    thread_cap: usize,

    /// Customizable wait timeout: how long an idle worker waits on the
    /// condvar before it is considered timed out.
    keep_alive: Duration,

    /// Metrics about the pool.
    metrics: SpawnerMetrics,
}
102 
/// Mutable pool state; only accessed while holding the `Inner::shared` mutex.
struct Shared {
    /// Tasks waiting to be picked up by a worker thread.
    queue: VecDeque<Task>,
    /// Number of wakeups issued to idle workers that have not been
    /// acknowledged yet. A woken worker decrements it, which lets it tell a
    /// legitimate wakeup from a spurious one.
    num_notify: u32,
    /// Set to `true` once shutdown of the pool has started.
    shutdown: bool,
    /// Sending half of the shutdown channel. A clone is moved into every
    /// spawned worker thread; the field is cleared when shutdown starts.
    shutdown_tx: Option<shutdown::Sender>,
    /// Prior to shutdown, we clean up JoinHandles by having each timed-out
    /// thread join on the previous timed-out thread. This is not strictly
    /// necessary but helps avoid Valgrind false positives, see
    /// <https://github.com/tokio-rs/tokio/commit/646fbae76535e397ef79dbcaacb945d4c829f666>
    /// for more information.
    last_exiting_thread: Option<thread::JoinHandle<()>>,
    /// This holds the JoinHandles for all running threads; on shutdown, the thread
    /// calling shutdown handles joining on these.
    worker_threads: HashMap<usize, thread::JoinHandle<()>>,
    /// This is a counter used to iterate worker_threads in a consistent order (for loom's
    /// benefit).
    worker_thread_index: usize,
}
121 
/// A unit of work queued on the blocking pool.
pub(crate) struct Task {
    /// The underlying runtime task to run or shut down.
    task: task::UnownedTask<BlockingSchedule>,
    /// Whether the task must still be run when it is drained during shutdown.
    mandatory: Mandatory,
}
126 
/// Controls what happens to a queued task that is drained during shutdown.
#[derive(PartialEq, Eq)]
pub(crate) enum Mandatory {
    /// The task is still run when drained during shutdown.
    #[cfg_attr(not(fs), allow(dead_code))]
    Mandatory,
    /// The task is shut down without running when drained during shutdown.
    NonMandatory,
}
133 
/// Reasons a task could not be handed over to the blocking pool.
///
/// Implements `Debug` so that results carrying this error can be formatted
/// with `{:?}` and used with `unwrap`/`expect`.
#[derive(Debug)]
pub(crate) enum SpawnError {
    /// Pool is shutting down and the task was not scheduled
    ShuttingDown,
    /// There are no worker threads available to take the task
    /// and the OS failed to spawn a new one
    NoThreads(io::Error),
}

impl From<SpawnError> for io::Error {
    /// Converts a spawn failure into an `io::Error`.
    ///
    /// `NoThreads` yields the wrapped OS error unchanged; `ShuttingDown`
    /// becomes an `ErrorKind::Other` error with a fixed message.
    fn from(e: SpawnError) -> Self {
        match e {
            SpawnError::NoThreads(os_error) => os_error,
            SpawnError::ShuttingDown => {
                io::Error::new(io::ErrorKind::Other, "blocking pool shutting down")
            }
        }
    }
}
152 
153 impl Task {
new(task: task::UnownedTask<BlockingSchedule>, mandatory: Mandatory) -> Task154     pub(crate) fn new(task: task::UnownedTask<BlockingSchedule>, mandatory: Mandatory) -> Task {
155         Task { task, mandatory }
156     }
157 
run(self)158     fn run(self) {
159         self.task.run();
160     }
161 
shutdown_or_run_if_mandatory(self)162     fn shutdown_or_run_if_mandatory(self) {
163         match self.mandatory {
164             Mandatory::NonMandatory => self.task.shutdown(),
165             Mandatory::Mandatory => self.task.run(),
166         }
167     }
168 }
169 
/// Default idle wait timeout for worker threads, used when the runtime
/// builder does not provide its own `keep_alive` value.
const KEEP_ALIVE: Duration = Duration::from_secs(10);
171 
172 /// Runs the provided function on an executor dedicated to blocking operations.
173 /// Tasks will be scheduled as non-mandatory, meaning they may not get executed
174 /// in case of runtime shutdown.
175 #[track_caller]
176 #[cfg_attr(target_os = "wasi", allow(dead_code))]
spawn_blocking<F, R>(func: F) -> JoinHandle<R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static,177 pub(crate) fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
178 where
179     F: FnOnce() -> R + Send + 'static,
180     R: Send + 'static,
181 {
182     let rt = Handle::current();
183     rt.spawn_blocking(func)
184 }
185 
186 cfg_fs! {
187     #[cfg_attr(any(
188         all(loom, not(test)), // the function is covered by loom tests
189         test
190     ), allow(dead_code))]
191     /// Runs the provided function on an executor dedicated to blocking
192     /// operations. Tasks will be scheduled as mandatory, meaning they are
193     /// guaranteed to run unless a shutdown is already taking place. In case a
194     /// shutdown is already taking place, `None` will be returned.
195     pub(crate) fn spawn_mandatory_blocking<F, R>(func: F) -> Option<JoinHandle<R>>
196     where
197         F: FnOnce() -> R + Send + 'static,
198         R: Send + 'static,
199     {
200         let rt = Handle::current();
201         rt.inner.blocking_spawner().spawn_mandatory_blocking(&rt, func)
202     }
203 }
204 
205 // ===== impl BlockingPool =====
206 
207 impl BlockingPool {
new(builder: &Builder, thread_cap: usize) -> BlockingPool208     pub(crate) fn new(builder: &Builder, thread_cap: usize) -> BlockingPool {
209         let (shutdown_tx, shutdown_rx) = shutdown::channel();
210         let keep_alive = builder.keep_alive.unwrap_or(KEEP_ALIVE);
211 
212         BlockingPool {
213             spawner: Spawner {
214                 inner: Arc::new(Inner {
215                     shared: Mutex::new(Shared {
216                         queue: VecDeque::new(),
217                         num_notify: 0,
218                         shutdown: false,
219                         shutdown_tx: Some(shutdown_tx),
220                         last_exiting_thread: None,
221                         worker_threads: HashMap::new(),
222                         worker_thread_index: 0,
223                     }),
224                     condvar: Condvar::new(),
225                     thread_name: builder.thread_name.clone(),
226                     stack_size: builder.thread_stack_size,
227                     after_start: builder.after_start.clone(),
228                     before_stop: builder.before_stop.clone(),
229                     thread_cap,
230                     keep_alive,
231                     metrics: Default::default(),
232                 }),
233             },
234             shutdown_rx,
235         }
236     }
237 
    /// Returns a reference to the spawner used to submit tasks to this pool.
    pub(crate) fn spawner(&self) -> &Spawner {
        &self.spawner
    }
241 
shutdown(&mut self, timeout: Option<Duration>)242     pub(crate) fn shutdown(&mut self, timeout: Option<Duration>) {
243         let mut shared = self.spawner.inner.shared.lock();
244 
245         // The function can be called multiple times. First, by explicitly
246         // calling `shutdown` then by the drop handler calling `shutdown`. This
247         // prevents shutting down twice.
248         if shared.shutdown {
249             return;
250         }
251 
252         shared.shutdown = true;
253         shared.shutdown_tx = None;
254         self.spawner.inner.condvar.notify_all();
255 
256         let last_exited_thread = std::mem::take(&mut shared.last_exiting_thread);
257         let workers = std::mem::take(&mut shared.worker_threads);
258 
259         drop(shared);
260 
261         if self.shutdown_rx.wait(timeout) {
262             let _ = last_exited_thread.map(|th| th.join());
263 
264             // Loom requires that execution be deterministic, so sort by thread ID before joining.
265             // (HashMaps use a randomly-seeded hash function, so the order is nondeterministic)
266             let mut workers: Vec<(usize, thread::JoinHandle<()>)> = workers.into_iter().collect();
267             workers.sort_by_key(|(id, _)| *id);
268 
269             for (_id, handle) in workers.into_iter() {
270                 let _ = handle.join();
271             }
272         }
273     }
274 }
275 
impl Drop for BlockingPool {
    /// Shuts the pool down without a timeout when it is dropped.
    fn drop(&mut self) {
        self.shutdown(None);
    }
}
281 
impl fmt::Debug for BlockingPool {
    /// Formats the pool as an opaque `BlockingPool` struct with no fields.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("BlockingPool").finish()
    }
}
287 
288 // ===== impl Spawner =====
289 
290 impl Spawner {
291     #[track_caller]
spawn_blocking<F, R>(&self, rt: &Handle, func: F) -> JoinHandle<R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static,292     pub(crate) fn spawn_blocking<F, R>(&self, rt: &Handle, func: F) -> JoinHandle<R>
293     where
294         F: FnOnce() -> R + Send + 'static,
295         R: Send + 'static,
296     {
297         let (join_handle, spawn_result) =
298             if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
299                 self.spawn_blocking_inner(Box::new(func), Mandatory::NonMandatory, None, rt)
300             } else {
301                 self.spawn_blocking_inner(func, Mandatory::NonMandatory, None, rt)
302             };
303 
304         match spawn_result {
305             Ok(()) => join_handle,
306             // Compat: do not panic here, return the join_handle even though it will never resolve
307             Err(SpawnError::ShuttingDown) => join_handle,
308             Err(SpawnError::NoThreads(e)) => {
309                 panic!("OS can't spawn worker thread: {}", e)
310             }
311         }
312     }
313 
314     cfg_fs! {
315         #[track_caller]
316         #[cfg_attr(any(
317             all(loom, not(test)), // the function is covered by loom tests
318             test
319         ), allow(dead_code))]
320         pub(crate) fn spawn_mandatory_blocking<F, R>(&self, rt: &Handle, func: F) -> Option<JoinHandle<R>>
321         where
322             F: FnOnce() -> R + Send + 'static,
323             R: Send + 'static,
324         {
325             let (join_handle, spawn_result) = if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
326                 self.spawn_blocking_inner(
327                     Box::new(func),
328                     Mandatory::Mandatory,
329                     None,
330                     rt,
331                 )
332             } else {
333                 self.spawn_blocking_inner(
334                     func,
335                     Mandatory::Mandatory,
336                     None,
337                     rt,
338                 )
339             };
340 
341             if spawn_result.is_ok() {
342                 Some(join_handle)
343             } else {
344                 None
345             }
346         }
347     }
348 
    /// Wraps `func` in a blocking task, queues it on the pool and returns the
    /// join handle together with the outcome of the spawn attempt.
    ///
    /// The join handle is returned even when spawning failed, so the caller
    /// decides how to treat the error. `name` is only used for the tracing
    /// span, which exists when both `tokio_unstable` and the `tracing`
    /// feature are enabled; otherwise it is ignored.
    #[track_caller]
    pub(crate) fn spawn_blocking_inner<F, R>(
        &self,
        func: F,
        is_mandatory: Mandatory,
        name: Option<&str>,
        rt: &Handle,
    ) -> (JoinHandle<R>, Result<(), SpawnError>)
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let fut = BlockingTask::new(func);
        let id = task::Id::next();
        // With tracing enabled, shadow `fut` with an instrumented version
        // whose span records the task name, id, closure type and the
        // caller's source location.
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let fut = {
            use tracing::Instrument;
            let location = std::panic::Location::caller();
            let span = tracing::trace_span!(
                target: "tokio::task::blocking",
                "runtime.spawn",
                kind = %"blocking",
                task.name = %name.unwrap_or_default(),
                task.id = id.as_u64(),
                "fn" = %std::any::type_name::<F>(),
                loc.file = location.file(),
                loc.line = location.line(),
                loc.col = location.column(),
            );
            fut.instrument(span)
        };

        // Without tracing, `name` would otherwise be unused.
        #[cfg(not(all(tokio_unstable, feature = "tracing")))]
        let _ = name;

        let (task, handle) = task::unowned(fut, BlockingSchedule::new(rt), id);

        let spawned = self.spawn_task(Task::new(task, is_mandatory), rt);
        (handle, spawned)
    }
389 
spawn_task(&self, task: Task, rt: &Handle) -> Result<(), SpawnError>390     fn spawn_task(&self, task: Task, rt: &Handle) -> Result<(), SpawnError> {
391         let mut shared = self.inner.shared.lock();
392 
393         if shared.shutdown {
394             // Shutdown the task: it's fine to shutdown this task (even if
395             // mandatory) because it was scheduled after the shutdown of the
396             // runtime began.
397             task.task.shutdown();
398 
399             // no need to even push this task; it would never get picked up
400             return Err(SpawnError::ShuttingDown);
401         }
402 
403         shared.queue.push_back(task);
404         self.inner.metrics.inc_queue_depth();
405 
406         if self.inner.metrics.num_idle_threads() == 0 {
407             // No threads are able to process the task.
408 
409             if self.inner.metrics.num_threads() == self.inner.thread_cap {
410                 // At max number of threads
411             } else {
412                 assert!(shared.shutdown_tx.is_some());
413                 let shutdown_tx = shared.shutdown_tx.clone();
414 
415                 if let Some(shutdown_tx) = shutdown_tx {
416                     let id = shared.worker_thread_index;
417 
418                     match self.spawn_thread(shutdown_tx, rt, id) {
419                         Ok(handle) => {
420                             self.inner.metrics.inc_num_threads();
421                             shared.worker_thread_index += 1;
422                             shared.worker_threads.insert(id, handle);
423                         }
424                         Err(ref e)
425                             if is_temporary_os_thread_error(e)
426                                 && self.inner.metrics.num_threads() > 0 =>
427                         {
428                             // OS temporarily failed to spawn a new thread.
429                             // The task will be picked up eventually by a currently
430                             // busy thread.
431                         }
432                         Err(e) => {
433                             // The OS refused to spawn the thread and there is no thread
434                             // to pick up the task that has just been pushed to the queue.
435                             return Err(SpawnError::NoThreads(e));
436                         }
437                     }
438                 }
439             }
440         } else {
441             // Notify an idle worker thread. The notification counter
442             // is used to count the needed amount of notifications
443             // exactly. Thread libraries may generate spurious
444             // wakeups, this counter is used to keep us in a
445             // consistent state.
446             self.inner.metrics.dec_num_idle_threads();
447             shared.num_notify += 1;
448             self.inner.condvar.notify_one();
449         }
450 
451         Ok(())
452     }
453 
spawn_thread( &self, shutdown_tx: shutdown::Sender, rt: &Handle, id: usize, ) -> std::io::Result<thread::JoinHandle<()>>454     fn spawn_thread(
455         &self,
456         shutdown_tx: shutdown::Sender,
457         rt: &Handle,
458         id: usize,
459     ) -> std::io::Result<thread::JoinHandle<()>> {
460         let mut builder = thread::Builder::new().name((self.inner.thread_name)());
461 
462         if let Some(stack_size) = self.inner.stack_size {
463             builder = builder.stack_size(stack_size);
464         }
465 
466         let rt = rt.clone();
467 
468         builder.spawn(move || {
469             // Only the reference should be moved into the closure
470             let _enter = rt.enter();
471             rt.inner.blocking_spawner().inner.run(id);
472             drop(shutdown_tx);
473         })
474     }
475 }
476 
cfg_metrics! {
    impl Spawner {
        /// Returns the pool's current worker thread count.
        pub(crate) fn num_threads(&self) -> usize {
            self.inner.metrics.num_threads()
        }

        /// Returns the pool's current idle worker thread count.
        pub(crate) fn num_idle_threads(&self) -> usize {
            self.inner.metrics.num_idle_threads()
        }

        /// Returns the pool's current number of queued tasks.
        pub(crate) fn queue_depth(&self) -> usize {
            self.inner.metrics.queue_depth()
        }
    }
}
492 
// Reports whether a thread-spawn failure is considered temporary, which is
// the case exactly when the error kind is `WouldBlock`.
#[inline]
fn is_temporary_os_thread_error(error: &std::io::Error) -> bool {
    error.kind() == std::io::ErrorKind::WouldBlock
}
498 
impl Inner {
    /// Worker thread main loop.
    ///
    /// Calls the `after_start` callback, then alternates between two states:
    ///
    /// * BUSY: pops tasks from the queue and runs each one with the lock
    ///   released.
    /// * IDLE: waits on the condvar for at most `keep_alive`.
    ///
    /// The loop ends when the idle wait times out (and the pool is not
    /// shutting down) or when shutdown is observed, in which case the queue
    /// is drained first. On exit the thread counters are updated, the
    /// `before_stop` callback is called, and the previously timed-out thread
    /// (if any) is joined.
    ///
    /// `worker_thread_id` is the key of this thread's own `JoinHandle` in
    /// `Shared::worker_threads`.
    fn run(&self, worker_thread_id: usize) {
        if let Some(f) = &self.after_start {
            f()
        }

        let mut shared = self.shared.lock();
        // JoinHandle of the previously timed-out thread, joined at the very
        // end once the lock has been released.
        let mut join_on_thread = None;

        'main: loop {
            // BUSY
            while let Some(task) = shared.queue.pop_front() {
                self.metrics.dec_queue_depth();
                // Release the lock while the task runs.
                drop(shared);
                task.run();

                shared = self.shared.lock();
            }

            // IDLE
            self.metrics.inc_num_idle_threads();

            while !shared.shutdown {
                // The wait hands the guard over and returns it together with
                // the timeout status.
                let lock_result = self.condvar.wait_timeout(shared, self.keep_alive).unwrap();

                shared = lock_result.0;
                let timeout_result = lock_result.1;

                if shared.num_notify != 0 {
                    // We have received a legitimate wakeup,
                    // acknowledge it by decrementing the counter
                    // and transition to the BUSY state.
                    shared.num_notify -= 1;
                    break;
                }

                // Even if the condvar "timed out", if the pool is entering the
                // shutdown phase, we want to perform the cleanup logic.
                if !shared.shutdown && timeout_result.timed_out() {
                    // We'll join the prior timed-out thread's JoinHandle after dropping the lock.
                    // This isn't done when shutting down, because the thread calling shutdown will
                    // handle joining everything.
                    let my_handle = shared.worker_threads.remove(&worker_thread_id);
                    join_on_thread = std::mem::replace(&mut shared.last_exiting_thread, my_handle);

                    break 'main;
                }

                // Spurious wakeup detected, go back to sleep.
            }

            if shared.shutdown {
                // Drain the queue
                while let Some(task) = shared.queue.pop_front() {
                    self.metrics.dec_queue_depth();
                    drop(shared);

                    // Mandatory tasks are still run; the rest are shut down.
                    task.shutdown_or_run_if_mandatory();

                    shared = self.shared.lock();
                }

                // Work was produced, and we "took" it (by decrementing num_notify).
                // This means that num_idle was decremented once for our wakeup.
                // But, since we are exiting, we need to "undo" that, as we'll stay idle.
                self.metrics.inc_num_idle_threads();
                // NOTE: Technically we should also do num_notify++ and notify again,
                // but since we're shutting down anyway, that won't be necessary.
                break;
            }
        }

        // Thread exit
        self.metrics.dec_num_threads();

        // num_idle should now be tracked exactly, panic
        // with a descriptive message if it is not the
        // case.
        // `dec_num_idle_threads` returns the value before the decrement, so a
        // current value larger than that means the counter wrapped around.
        let prev_idle = self.metrics.dec_num_idle_threads();
        if prev_idle < self.metrics.num_idle_threads() {
            panic!("num_idle_threads underflowed on thread exit")
        }

        // During shutdown, the last thread to exit notifies the condvar.
        if shared.shutdown && self.metrics.num_threads() == 0 {
            self.condvar.notify_one();
        }

        drop(shared);

        if let Some(f) = &self.before_stop {
            f()
        }

        if let Some(handle) = join_on_thread {
            let _ = handle.join();
        }
    }
}
597 
impl fmt::Debug for Spawner {
    /// Formats the spawner as an opaque `blocking::Spawner` struct with no fields.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("blocking::Spawner").finish()
    }
}
603