//! Thread pool for blocking operations
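//!
//! Worker threads are spawned on demand, up to `thread_cap`, and exit after
//! sitting idle for `keep_alive` with no work to run.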

use crate::loom::sync::{Arc, Condvar, Mutex};
use crate::loom::thread;
use crate::runtime::blocking::schedule::NoopSchedule;
use crate::runtime::blocking::shutdown;
use crate::runtime::builder::ThreadNameFn;
use crate::runtime::context;
use crate::runtime::task::{self, JoinHandle};
use crate::runtime::{Builder, Callback, Handle};

use std::collections::{HashMap, VecDeque};
use std::fmt;
use std::time::Duration;

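/// A thread pool dedicated to running blocking operations.
///
/// A minimal lifecycle sketch (crate-internal API; `builder` is assumed to be
/// an already-configured `runtime::Builder`):
///
/// ```ignore
/// let mut pool = BlockingPool::new(&builder, 512);
/// let spawner = pool.spawner().clone();
/// // ... hand `spawner` to the runtime, run blocking tasks ...
/// pool.shutdown(Some(Duration::from_secs(5)));
/// ```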
pub(crate) struct BlockingPool {
    spawner: Spawner,
    shutdown_rx: shutdown::Receiver,
}

#[derive(Clone)]
pub(crate) struct Spawner {
    inner: Arc<Inner>,
}

struct Inner {
    /// State shared between worker threads.
    shared: Mutex<Shared>,

    /// Pool threads wait on this.
    condvar: Condvar,

    /// Spawned threads use this name.
    thread_name: ThreadNameFn,

    /// Spawned thread stack size.
    stack_size: Option<usize>,

    /// Call after a thread starts.
    after_start: Option<Callback>,

    /// Call before a thread stops.
    before_stop: Option<Callback>,

    /// Maximum number of threads.
    thread_cap: usize,

    /// Customizable wait timeout.
    keep_alive: Duration,
}
51
52 struct Shared {
53 queue: VecDeque<Task>,
54 num_th: usize,
55 num_idle: u32,
56 num_notify: u32,
57 shutdown: bool,
58 shutdown_tx: Option<shutdown::Sender>,
59 /// Prior to shutdown, we clean up JoinHandles by having each timed-out
60 /// thread join on the previous timed-out thread. This is not strictly
61 /// necessary but helps avoid Valgrind false positives, see
62 /// <https://github.com/tokio-rs/tokio/commit/646fbae76535e397ef79dbcaacb945d4c829f666>
63 /// for more information.
64 last_exiting_thread: Option<thread::JoinHandle<()>>,
65 /// This holds the JoinHandles for all running threads; on shutdown, the thread
66 /// calling shutdown handles joining on these.
67 worker_threads: HashMap<usize, thread::JoinHandle<()>>,
68 /// This is a counter used to iterate worker_threads in a consistent order (for loom's
69 /// benefit).
70 worker_thread_index: usize,
71 }
72
73 type Task = task::UnownedTask<NoopSchedule>;
74
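/// Default wait timeout before an idle worker thread exits.
///
/// Callers can override this through the runtime builder; a brief sketch
/// using tokio's public `Builder::thread_keep_alive` (shown for context, not
/// part of this module):
///
/// ```ignore
/// let rt = tokio::runtime::Builder::new_multi_thread()
///     .thread_keep_alive(std::time::Duration::from_secs(30))
///     .build()
///     .unwrap();
/// ```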
const KEEP_ALIVE: Duration = Duration::from_secs(10);

/// Runs the provided function on an executor dedicated to blocking operations.
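///
/// A minimal usage sketch (hypothetical caller; this must run from within a
/// runtime context, otherwise `context::current` panics):
///
/// ```ignore
/// let handle = spawn_blocking(|| {
///     // Blocking work runs off the async worker threads; the path here
///     // is purely illustrative.
///     std::fs::read_to_string("config.toml")
/// });
/// ```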
pub(crate) fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
{
    let rt = context::current();
    rt.spawn_blocking(func)
}

// ===== impl BlockingPool =====

impl BlockingPool {
    pub(crate) fn new(builder: &Builder, thread_cap: usize) -> BlockingPool {
        let (shutdown_tx, shutdown_rx) = shutdown::channel();
        let keep_alive = builder.keep_alive.unwrap_or(KEEP_ALIVE);

        BlockingPool {
            spawner: Spawner {
                inner: Arc::new(Inner {
                    shared: Mutex::new(Shared {
                        queue: VecDeque::new(),
                        num_th: 0,
                        num_idle: 0,
                        num_notify: 0,
                        shutdown: false,
                        shutdown_tx: Some(shutdown_tx),
                        last_exiting_thread: None,
                        worker_threads: HashMap::new(),
                        worker_thread_index: 0,
                    }),
                    condvar: Condvar::new(),
                    thread_name: builder.thread_name.clone(),
                    stack_size: builder.thread_stack_size,
                    after_start: builder.after_start.clone(),
                    before_stop: builder.before_stop.clone(),
                    thread_cap,
                    keep_alive,
                }),
            },
            shutdown_rx,
        }
    }

    pub(crate) fn spawner(&self) -> &Spawner {
        &self.spawner
    }

    pub(crate) fn shutdown(&mut self, timeout: Option<Duration>) {
        let mut shared = self.spawner.inner.shared.lock();

        // This function can be called multiple times: first by an explicit
        // call to `shutdown`, then again by the drop handler. The flag
        // ensures the pool is only shut down once.
        if shared.shutdown {
            return;
        }

        shared.shutdown = true;
        shared.shutdown_tx = None;
        self.spawner.inner.condvar.notify_all();

        let last_exited_thread = std::mem::take(&mut shared.last_exiting_thread);
        let workers = std::mem::take(&mut shared.worker_threads);

        drop(shared);

        if self.shutdown_rx.wait(timeout) {
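            // `wait` reports completion only once every worker has dropped
            // its `shutdown_tx` clone within the timeout. Join the threads
            // only in that case; otherwise some workers may still be running
            // and joining them could block indefinitely.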
            let _ = last_exited_thread.map(|th| th.join());

            // Loom requires that execution be deterministic, so sort by thread ID before joining.
            // (HashMaps use a randomly-seeded hash function, so the order is nondeterministic.)
            let mut workers: Vec<(usize, thread::JoinHandle<()>)> = workers.into_iter().collect();
            workers.sort_by_key(|(id, _)| *id);

            for (_id, handle) in workers.into_iter() {
                let _ = handle.join();
            }
        }
    }
}

impl Drop for BlockingPool {
    fn drop(&mut self) {
        self.shutdown(None);
    }
}

impl fmt::Debug for BlockingPool {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("BlockingPool").finish()
    }
}

// ===== impl Spawner =====

impl Spawner {
    pub(crate) fn spawn(&self, task: Task, rt: &Handle) -> Result<(), ()> {
        let shutdown_tx = {
            let mut shared = self.inner.shared.lock();

            if shared.shutdown {
                // Shut down the task.
                task.shutdown();

                // No need to even push this task; it would never get picked up.
                return Err(());
            }

            shared.queue.push_back(task);

            if shared.num_idle == 0 {
                // No threads are able to process the task.

                if shared.num_th == self.inner.thread_cap {
                    // At max number of threads
                    None
                } else {
                    shared.num_th += 1;
                    assert!(shared.shutdown_tx.is_some());
                    shared.shutdown_tx.clone()
                }
            } else {
                // Notify an idle worker thread. The notification counter
                // tracks exactly how many notifications are needed; since
                // condvars may generate spurious wakeups, the counter keeps
                // the pool in a consistent state.
                shared.num_idle -= 1;
                shared.num_notify += 1;
                self.inner.condvar.notify_one();
                None
            }
        };

        if let Some(shutdown_tx) = shutdown_tx {
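            // Hold the lock while spawning so the `JoinHandle` is registered
            // in `worker_threads` before the new thread can acquire the lock
            // and, on keep-alive timeout, try to remove itself.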
            let mut shared = self.inner.shared.lock();

            let id = shared.worker_thread_index;
            shared.worker_thread_index += 1;

            let handle = self.spawn_thread(shutdown_tx, rt, id);

            shared.worker_threads.insert(id, handle);
        }

        Ok(())
    }

    fn spawn_thread(
        &self,
        shutdown_tx: shutdown::Sender,
        rt: &Handle,
        id: usize,
    ) -> thread::JoinHandle<()> {
        let mut builder = thread::Builder::new().name((self.inner.thread_name)());

        if let Some(stack_size) = self.inner.stack_size {
            builder = builder.stack_size(stack_size);
        }

        let rt = rt.clone();

        builder
            .spawn(move || {
                // Enter the runtime context so work run on this thread is
                // associated with `rt`; the clone keeps `rt` usable for the
                // call below.
                let _enter = crate::runtime::context::enter(rt.clone());
                rt.blocking_spawner.inner.run(id);
                drop(shutdown_tx);
            })
            .unwrap()
    }
}

impl Inner {
    fn run(&self, worker_thread_id: usize) {
        if let Some(f) = &self.after_start {
            f()
        }

        let mut shared = self.shared.lock();
        let mut join_on_thread = None;

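        // Each worker alternates between two states: BUSY (draining the
        // queue, releasing the lock around each task) and IDLE (waiting on
        // the condvar until notified, shut down, or timed out).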
        'main: loop {
            // BUSY
            while let Some(task) = shared.queue.pop_front() {
                drop(shared);
                task.run();

                shared = self.shared.lock();
            }

            // IDLE
            shared.num_idle += 1;

            while !shared.shutdown {
                let lock_result = self.condvar.wait_timeout(shared, self.keep_alive).unwrap();

                shared = lock_result.0;
                let timeout_result = lock_result.1;

                if shared.num_notify != 0 {
                    // We have received a legitimate wakeup,
                    // acknowledge it by decrementing the counter
                    // and transition to the BUSY state.
                    shared.num_notify -= 1;
                    break;
                }

                // Even if the condvar "timed out", if the pool is entering the
                // shutdown phase, we want to perform the cleanup logic.
                if !shared.shutdown && timeout_result.timed_out() {
                    // We'll join the prior timed-out thread's JoinHandle after
                    // dropping the lock. This isn't done when shutting down,
                    // because the thread calling shutdown will handle joining
                    // everything.
                    let my_handle = shared.worker_threads.remove(&worker_thread_id);
                    join_on_thread = std::mem::replace(&mut shared.last_exiting_thread, my_handle);

                    break 'main;
                }

                // Spurious wakeup detected, go back to sleep.
            }

            if shared.shutdown {
                // Drain the queue
                while let Some(task) = shared.queue.pop_front() {
                    drop(shared);
                    task.shutdown();

                    shared = self.shared.lock();
                }

                // Work was produced, and we "took" it (by decrementing
                // num_notify). This means that num_idle was decremented once
                // for our wakeup. But, since we are exiting, we need to
                // "undo" that, as we'll stay idle.
                shared.num_idle += 1;
                // NOTE: Technically we should also do num_notify++ and notify
                // again, but since we're shutting down anyway, that won't be
                // necessary.
                break;
            }
        }

        // Thread exit
        shared.num_th -= 1;

        // num_idle should now be tracked exactly; panic
        // with a descriptive message if it is not the case.
        shared.num_idle = shared
            .num_idle
            .checked_sub(1)
            .expect("num_idle underflowed on thread exit");

        if shared.shutdown && shared.num_th == 0 {
            self.condvar.notify_one();
        }

        drop(shared);

        if let Some(f) = &self.before_stop {
            f()
        }

        if let Some(handle) = join_on_thread {
            let _ = handle.join();
        }
    }
}

impl fmt::Debug for Spawner {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("blocking::Spawner").finish()
    }
}