• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! Thread pool for blocking operations
2 
3 use crate::loom::sync::{Arc, Condvar, Mutex};
4 use crate::loom::thread;
5 use crate::runtime::blocking::schedule::NoopSchedule;
6 use crate::runtime::blocking::shutdown;
7 use crate::runtime::builder::ThreadNameFn;
8 use crate::runtime::context;
9 use crate::runtime::task::{self, JoinHandle};
10 use crate::runtime::{Builder, Callback, Handle};
11 
12 use std::collections::{HashMap, VecDeque};
13 use std::fmt;
14 use std::time::Duration;
15 
/// Owning handle for the pool of threads that run blocking operations.
///
/// Pairs the [`Spawner`] used to submit work with the receiving half of the
/// shutdown channel that `shutdown` waits on.
pub(crate) struct BlockingPool {
    /// Handle through which tasks are submitted to the pool.
    spawner: Spawner,
    /// Receiving half of the shutdown channel; `shutdown` waits on it before
    /// joining the worker threads.
    shutdown_rx: shutdown::Receiver,
}
20 
/// Cloneable handle used to submit tasks to the blocking pool.
///
/// Every clone points at the same `Inner` state through an `Arc`.
#[derive(Clone)]
pub(crate) struct Spawner {
    /// Pool configuration and state shared by all clones and worker threads.
    inner: Arc<Inner>,
}
25 
/// Configuration and synchronized state behind a [`Spawner`].
struct Inner {
    /// State shared between worker threads.
    shared: Mutex<Shared>,

    /// Pool threads wait on this.
    condvar: Condvar,

    /// Spawned threads use this name.
    thread_name: ThreadNameFn,

    /// Spawned thread stack size.
    stack_size: Option<usize>,

    /// Call after a thread starts.
    after_start: Option<Callback>,

    /// Call before a thread stops.
    before_stop: Option<Callback>,

    /// Maximum number of threads.
    thread_cap: usize,

    /// Customizable wait timeout: how long an idle worker waits on `condvar`
    /// for new work before it exits.
    keep_alive: Duration,
}
51 
/// Mutable pool state; only accessed while holding `Inner::shared`.
struct Shared {
    /// Tasks waiting to be picked up by a worker thread.
    queue: VecDeque<Task>,
    /// Number of worker threads accounted for; incremented when `spawn`
    /// decides to start a thread and decremented when a worker exits.
    num_th: usize,
    /// Number of workers counted as idle. A worker increments it when it runs
    /// out of work; `spawn` decrements it when it notifies a worker.
    num_idle: u32,
    /// Notifications issued by `spawn` that no worker has consumed yet. Lets a
    /// woken worker tell a legitimate wakeup from a spurious one.
    num_notify: u32,
    /// Set once `BlockingPool::shutdown` has been called.
    shutdown: bool,
    /// Sender half of the shutdown channel; cloned into every worker thread
    /// and cleared when shutdown starts.
    shutdown_tx: Option<shutdown::Sender>,
    /// Prior to shutdown, we clean up JoinHandles by having each timed-out
    /// thread join on the previous timed-out thread. This is not strictly
    /// necessary but helps avoid Valgrind false positives, see
    /// <https://github.com/tokio-rs/tokio/commit/646fbae76535e397ef79dbcaacb945d4c829f666>
    /// for more information.
    last_exiting_thread: Option<thread::JoinHandle<()>>,
    /// This holds the JoinHandles for all running threads; on shutdown, the thread
    /// calling shutdown handles joining on these.
    worker_threads: HashMap<usize, thread::JoinHandle<()>>,
    /// This is a counter used to iterate worker_threads in a consistent order (for loom's
    /// benefit). It also supplies the key under which each new worker is stored.
    worker_thread_index: usize,
}
72 
/// A unit of blocking work held in the pool's queue.
type Task = task::UnownedTask<NoopSchedule>;

/// Default idle timeout for worker threads, used when the runtime builder does
/// not provide its own `keep_alive` value.
const KEEP_ALIVE: Duration = Duration::from_secs(10);
76 
77 /// Runs the provided function on an executor dedicated to blocking operations.
spawn_blocking<F, R>(func: F) -> JoinHandle<R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static,78 pub(crate) fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
79 where
80     F: FnOnce() -> R + Send + 'static,
81     R: Send + 'static,
82 {
83     let rt = context::current();
84     rt.spawn_blocking(func)
85 }
86 
87 // ===== impl BlockingPool =====
88 
89 impl BlockingPool {
new(builder: &Builder, thread_cap: usize) -> BlockingPool90     pub(crate) fn new(builder: &Builder, thread_cap: usize) -> BlockingPool {
91         let (shutdown_tx, shutdown_rx) = shutdown::channel();
92         let keep_alive = builder.keep_alive.unwrap_or(KEEP_ALIVE);
93 
94         BlockingPool {
95             spawner: Spawner {
96                 inner: Arc::new(Inner {
97                     shared: Mutex::new(Shared {
98                         queue: VecDeque::new(),
99                         num_th: 0,
100                         num_idle: 0,
101                         num_notify: 0,
102                         shutdown: false,
103                         shutdown_tx: Some(shutdown_tx),
104                         last_exiting_thread: None,
105                         worker_threads: HashMap::new(),
106                         worker_thread_index: 0,
107                     }),
108                     condvar: Condvar::new(),
109                     thread_name: builder.thread_name.clone(),
110                     stack_size: builder.thread_stack_size,
111                     after_start: builder.after_start.clone(),
112                     before_stop: builder.before_stop.clone(),
113                     thread_cap,
114                     keep_alive,
115                 }),
116             },
117             shutdown_rx,
118         }
119     }
120 
spawner(&self) -> &Spawner121     pub(crate) fn spawner(&self) -> &Spawner {
122         &self.spawner
123     }
124 
shutdown(&mut self, timeout: Option<Duration>)125     pub(crate) fn shutdown(&mut self, timeout: Option<Duration>) {
126         let mut shared = self.spawner.inner.shared.lock();
127 
128         // The function can be called multiple times. First, by explicitly
129         // calling `shutdown` then by the drop handler calling `shutdown`. This
130         // prevents shutting down twice.
131         if shared.shutdown {
132             return;
133         }
134 
135         shared.shutdown = true;
136         shared.shutdown_tx = None;
137         self.spawner.inner.condvar.notify_all();
138 
139         let last_exited_thread = std::mem::take(&mut shared.last_exiting_thread);
140         let workers = std::mem::take(&mut shared.worker_threads);
141 
142         drop(shared);
143 
144         if self.shutdown_rx.wait(timeout) {
145             let _ = last_exited_thread.map(|th| th.join());
146 
147             // Loom requires that execution be deterministic, so sort by thread ID before joining.
148             // (HashMaps use a randomly-seeded hash function, so the order is nondeterministic)
149             let mut workers: Vec<(usize, thread::JoinHandle<()>)> = workers.into_iter().collect();
150             workers.sort_by_key(|(id, _)| *id);
151 
152             for (_id, handle) in workers.into_iter() {
153                 let _ = handle.join();
154             }
155         }
156     }
157 }
158 
159 impl Drop for BlockingPool {
drop(&mut self)160     fn drop(&mut self) {
161         self.shutdown(None);
162     }
163 }
164 
165 impl fmt::Debug for BlockingPool {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result166     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
167         fmt.debug_struct("BlockingPool").finish()
168     }
169 }
170 
171 // ===== impl Spawner =====
172 
173 impl Spawner {
spawn(&self, task: Task, rt: &Handle) -> Result<(), ()>174     pub(crate) fn spawn(&self, task: Task, rt: &Handle) -> Result<(), ()> {
175         let shutdown_tx = {
176             let mut shared = self.inner.shared.lock();
177 
178             if shared.shutdown {
179                 // Shutdown the task
180                 task.shutdown();
181 
182                 // no need to even push this task; it would never get picked up
183                 return Err(());
184             }
185 
186             shared.queue.push_back(task);
187 
188             if shared.num_idle == 0 {
189                 // No threads are able to process the task.
190 
191                 if shared.num_th == self.inner.thread_cap {
192                     // At max number of threads
193                     None
194                 } else {
195                     shared.num_th += 1;
196                     assert!(shared.shutdown_tx.is_some());
197                     shared.shutdown_tx.clone()
198                 }
199             } else {
200                 // Notify an idle worker thread. The notification counter
201                 // is used to count the needed amount of notifications
202                 // exactly. Thread libraries may generate spurious
203                 // wakeups, this counter is used to keep us in a
204                 // consistent state.
205                 shared.num_idle -= 1;
206                 shared.num_notify += 1;
207                 self.inner.condvar.notify_one();
208                 None
209             }
210         };
211 
212         if let Some(shutdown_tx) = shutdown_tx {
213             let mut shared = self.inner.shared.lock();
214 
215             let id = shared.worker_thread_index;
216             shared.worker_thread_index += 1;
217 
218             let handle = self.spawn_thread(shutdown_tx, rt, id);
219 
220             shared.worker_threads.insert(id, handle);
221         }
222 
223         Ok(())
224     }
225 
spawn_thread( &self, shutdown_tx: shutdown::Sender, rt: &Handle, id: usize, ) -> thread::JoinHandle<()>226     fn spawn_thread(
227         &self,
228         shutdown_tx: shutdown::Sender,
229         rt: &Handle,
230         id: usize,
231     ) -> thread::JoinHandle<()> {
232         let mut builder = thread::Builder::new().name((self.inner.thread_name)());
233 
234         if let Some(stack_size) = self.inner.stack_size {
235             builder = builder.stack_size(stack_size);
236         }
237 
238         let rt = rt.clone();
239 
240         builder
241             .spawn(move || {
242                 // Only the reference should be moved into the closure
243                 let _enter = crate::runtime::context::enter(rt.clone());
244                 rt.blocking_spawner.inner.run(id);
245                 drop(shutdown_tx);
246             })
247             .unwrap()
248     }
249 }
250 
impl Inner {
    /// Worker thread main loop.
    ///
    /// Alternates between a BUSY phase, which runs queued tasks with the lock
    /// released, and an IDLE phase, which waits on the condvar for up to
    /// `keep_alive`. The loop ends when the pool shuts down or when an idle
    /// wait times out with no pending notification.
    ///
    /// `worker_thread_id` is the key under which this thread's `JoinHandle` is
    /// stored in `Shared::worker_threads`. The `after_start` and `before_stop`
    /// callbacks, when set, are called at the start and end of this function,
    /// both without the shared lock held.
    fn run(&self, worker_thread_id: usize) {
        if let Some(f) = &self.after_start {
            f()
        }

        let mut shared = self.shared.lock();
        // Handle of the previously timed-out thread, joined after the lock is released.
        let mut join_on_thread = None;

        'main: loop {
            // BUSY
            while let Some(task) = shared.queue.pop_front() {
                // Run the task without holding the lock.
                drop(shared);
                task.run();

                shared = self.shared.lock();
            }

            // IDLE
            shared.num_idle += 1;

            while !shared.shutdown {
                let lock_result = self.condvar.wait_timeout(shared, self.keep_alive).unwrap();

                shared = lock_result.0;
                let timeout_result = lock_result.1;

                if shared.num_notify != 0 {
                    // We have received a legitimate wakeup,
                    // acknowledge it by decrementing the counter
                    // and transition to the BUSY state.
                    shared.num_notify -= 1;
                    break;
                }

                // Even if the condvar "timed out", if the pool is entering the
                // shutdown phase, we want to perform the cleanup logic.
                if !shared.shutdown && timeout_result.timed_out() {
                    // We'll join the prior timed-out thread's JoinHandle after dropping the lock.
                    // This isn't done when shutting down, because the thread calling shutdown will
                    // handle joining everything.
                    let my_handle = shared.worker_threads.remove(&worker_thread_id);
                    join_on_thread = std::mem::replace(&mut shared.last_exiting_thread, my_handle);

                    break 'main;
                }

                // Spurious wakeup detected, go back to sleep.
            }

            if shared.shutdown {
                // Drain the queue, shutting each remaining task down instead of running it.
                while let Some(task) = shared.queue.pop_front() {
                    drop(shared);
                    task.shutdown();

                    shared = self.shared.lock();
                }

                // Work was produced, and we "took" it (by decrementing num_notify).
                // This means that num_idle was decremented once for our wakeup.
                // But, since we are exiting, we need to "undo" that, as we'll stay idle.
                shared.num_idle += 1;
                // NOTE: Technically we should also do num_notify++ and notify again,
                // but since we're shutting down anyway, that won't be necessary.
                break;
            }
        }

        // Thread exit
        shared.num_th -= 1;

        // num_idle should now be tracked exactly, panic
        // with a descriptive message if it is not the
        // case.
        shared.num_idle = shared
            .num_idle
            .checked_sub(1)
            .expect("num_idle underflowed on thread exit");

        // The last worker to exit during shutdown issues one more notification.
        if shared.shutdown && shared.num_th == 0 {
            self.condvar.notify_one();
        }

        drop(shared);

        if let Some(f) = &self.before_stop {
            f()
        }

        // Join the previously timed-out thread, if any, now that the lock is released.
        if let Some(handle) = join_on_thread {
            let _ = handle.join();
        }
    }
}
346 
347 impl fmt::Debug for Spawner {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result348     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
349         fmt.debug_struct("blocking::Spawner").finish()
350     }
351 }
352