1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 //! Provides an async blocking pool whose tasks can be cancelled.
6
7 use std::collections::HashMap;
8 use std::future::Future;
9 use std::sync::Arc;
10 use std::time::Duration;
11 use std::time::Instant;
12
13 use once_cell::sync::Lazy;
14 use sync::Condvar;
15 use sync::Mutex;
16 use thiserror::Error as ThisError;
17
18 use crate::BlockingPool;
19
20 /// Global executor.
21 ///
22 /// This is convenient, though not preferred. Pros/cons:
23 /// + It avoids passing executor all the way to each call sites.
24 /// + The call site can assume that executor will never shutdown.
25 /// + Provides similar functionality as async_task with a few improvements
26 /// around ability to cancel.
27 /// - Globals are harder to reason about.
28 static EXECUTOR: Lazy<CancellableBlockingPool> =
29 Lazy::new(|| CancellableBlockingPool::new(256, Duration::from_secs(10)));
30
31 const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
32
/// Stages a pool passes through on its way from accepting work to being shut down.
///
/// The variants are declared in the order the pool moves through them. The derived
/// `PartialOrd` follows declaration order, which is what allows comparisons such as
/// `wind_down >= WindDownStates::Disarmed`; do not reorder the variants.
#[derive(Default, PartialEq, Eq, PartialOrd)]
enum WindDownStates {
    /// Initial state: nothing has asked the pool to wind down.
    #[default]
    Armed,
    /// `disarm` has been called.
    Disarmed,
    /// A shutdown call is currently in progress.
    ShuttingDown,
    /// A shutdown call has finished, whether or not it timed out.
    ShutDown,
}
41
42 #[derive(Default)]
43 struct State {
44 wind_down: WindDownStates,
45
46 /// Helps to generate unique id to associate `cancel` with task.
47 current_cancellable_id: u64,
48
49 /// A map of all the `cancel` routines of queued/in-flight tasks.
50 cancellables: HashMap<u64, Box<dyn Fn() + Send + 'static>>,
51 }
52
/// What to do when an operation runs into its timeout.
#[derive(Clone, Copy, Debug)]
pub enum TimeoutAction {
    /// Take no action when the timeout expires.
    None,
    /// Panic the thread when the timeout expires.
    Panic,
}
60
61 #[derive(ThisError, Debug, PartialEq, Eq)]
62 pub enum Error {
63 #[error("Timeout occurred while trying to join threads")]
64 Timedout,
65 #[error("Shutdown is in progress")]
66 ShutdownInProgress,
67 #[error("Already shut down")]
68 AlreadyShutdown,
69 }
70
71 struct Inner {
72 blocking_pool: BlockingPool,
73 state: Mutex<State>,
74
75 /// This condvar gets notified when `cancellables` is empty after removing an
76 /// entry.
77 cancellables_cv: Condvar,
78 }
79
80 impl Inner {
spawn<F, R>(self: &Arc<Self>, f: F) -> impl Future<Output = R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static,81 pub fn spawn<F, R>(self: &Arc<Self>, f: F) -> impl Future<Output = R>
82 where
83 F: FnOnce() -> R + Send + 'static,
84 R: Send + 'static,
85 {
86 self.blocking_pool.spawn(f)
87 }
88
89 /// Adds cancel to a cancellables and returns an `id` with which `cancel` can be
90 /// accessed/removed.
add_cancellable(&self, cancel: Box<dyn Fn() + Send + 'static>) -> u6491 fn add_cancellable(&self, cancel: Box<dyn Fn() + Send + 'static>) -> u64 {
92 let mut state = self.state.lock();
93 let id = state.current_cancellable_id;
94 state.current_cancellable_id += 1;
95 state.cancellables.insert(id, cancel);
96 id
97 }
98 }
99
100 /// A thread pool for running work that may block.
101 ///
102 /// This is a wrapper around `BlockingPool` with an ability to cancel queued tasks.
103 /// See [BlockingPool] for more info.
104 ///
105 /// # Examples
106 ///
107 /// Spawn a task to run in the `CancellableBlockingPool` and await on its result.
108 ///
109 /// ```edition2018
110 /// use cros_async::CancellableBlockingPool;
111 ///
112 /// # async fn do_it() {
113 /// let pool = CancellableBlockingPool::default();
114 /// let CANCELLED = 0;
115 ///
116 /// let res = pool.spawn(move || {
117 /// // Do some CPU-intensive or blocking work here.
118 ///
119 /// 42
120 /// }, move || CANCELLED).await;
121 ///
122 /// assert_eq!(res, 42);
123 /// # }
124 /// # futures::executor::block_on(do_it());
125 /// ```
126 #[derive(Clone)]
127 pub struct CancellableBlockingPool {
128 inner: Arc<Inner>,
129 }
130
131 impl CancellableBlockingPool {
132 const RETRY_COUNT: usize = 10;
133 const SLEEP_DURATION: Duration = Duration::from_millis(100);
134
135 /// Create a new `CancellableBlockingPool`.
136 ///
137 /// When we try to shutdown or drop `CancellableBlockingPool`, it may happen that a hung thread
138 /// might prevent `CancellableBlockingPool` pool from getting dropped. On failure to shutdown in
139 /// `watchdog_opts.timeout` duration, `CancellableBlockingPool` can take an action specified by
140 /// `watchdog_opts.action`.
141 ///
142 /// See also: [BlockingPool::new()](BlockingPool::new)
new(max_threads: usize, keepalive: Duration) -> CancellableBlockingPool143 pub fn new(max_threads: usize, keepalive: Duration) -> CancellableBlockingPool {
144 CancellableBlockingPool {
145 inner: Arc::new(Inner {
146 blocking_pool: BlockingPool::new(max_threads, keepalive),
147 state: Default::default(),
148 cancellables_cv: Condvar::new(),
149 }),
150 }
151 }
152
153 /// Like [Self::new] but with pre-allocating capacity for up to `max_threads`.
with_capacity(max_threads: usize, keepalive: Duration) -> CancellableBlockingPool154 pub fn with_capacity(max_threads: usize, keepalive: Duration) -> CancellableBlockingPool {
155 CancellableBlockingPool {
156 inner: Arc::new(Inner {
157 blocking_pool: BlockingPool::with_capacity(max_threads, keepalive),
158 state: Mutex::new(State::default()),
159 cancellables_cv: Condvar::new(),
160 }),
161 }
162 }
163
164 /// Spawn a task to run in the `CancellableBlockingPool`.
165 ///
166 /// Callers may `await` the returned `Task` to be notified when the work is completed.
167 /// Dropping the future will not cancel the task.
168 ///
169 /// `cancel` helps to cancel a queued or in-flight operation `f`.
170 /// `cancel` may be called more than once if `f` doesn't respond to `cancel`.
171 /// `cancel` is not called if `f` completes successfully. For example,
172 /// # Examples
173 ///
174 /// ```edition2018
175 /// use {cros_async::CancellableBlockingPool, std::sync::{Arc, Mutex, Condvar}};
176 ///
177 /// # async fn cancel_it() {
178 /// let pool = CancellableBlockingPool::default();
179 /// let cancelled: i32 = 1;
180 /// let success: i32 = 2;
181 ///
182 /// let shared = Arc::new((Mutex::new(0), Condvar::new()));
183 /// let shared2 = shared.clone();
184 /// let shared3 = shared.clone();
185 ///
186 /// let res = pool
187 /// .spawn(
188 /// move || {
189 /// let guard = shared.0.lock().unwrap();
190 /// let mut guard = shared.1.wait_while(guard, |state| *state == 0).unwrap();
191 /// if *guard != cancelled {
192 /// *guard = success;
193 /// }
194 /// },
195 /// move || {
196 /// *shared2.0.lock().unwrap() = cancelled;
197 /// shared2.1.notify_all();
198 /// },
199 /// )
200 /// .await;
201 /// pool.shutdown();
202 ///
203 /// assert_eq!(*shared3.0.lock().unwrap(), cancelled);
204 /// # }
205 /// ```
spawn<F, R, G>(&self, f: F, cancel: G) -> impl Future<Output = R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static, G: Fn() -> R + Send + 'static,206 pub fn spawn<F, R, G>(&self, f: F, cancel: G) -> impl Future<Output = R>
207 where
208 F: FnOnce() -> R + Send + 'static,
209 R: Send + 'static,
210 G: Fn() -> R + Send + 'static,
211 {
212 let inner = self.inner.clone();
213 let cancelled = Arc::new(Mutex::new(None));
214 let cancelled_spawn = cancelled.clone();
215 let id = inner.add_cancellable(Box::new(move || {
216 let mut c = cancelled.lock();
217 *c = Some(cancel());
218 }));
219
220 self.inner.spawn(move || {
221 if let Some(res) = cancelled_spawn.lock().take() {
222 return res;
223 }
224 let ret = f();
225 let mut state = inner.state.lock();
226 state.cancellables.remove(&id);
227 if state.cancellables.is_empty() {
228 inner.cancellables_cv.notify_one();
229 }
230 ret
231 })
232 }
233
234 /// Iterates over all the queued tasks and marks them as cancelled.
drain_cancellables(&self)235 fn drain_cancellables(&self) {
236 let mut state = self.inner.state.lock();
237 // Iterate a few times to try cancelling all the tasks.
238 for _ in 0..Self::RETRY_COUNT {
239 // Nothing left to do.
240 if state.cancellables.is_empty() {
241 return;
242 }
243
244 // We only cancel the task and do not remove it from the cancellables. It is runner's
245 // job to remove from state.cancellables.
246 for cancel in state.cancellables.values() {
247 cancel();
248 }
249 // Hold the state lock in a block before sleeping so that woken up threads can get to
250 // hold the lock.
251 // Wait for a while so that the threads get a chance complete task in flight.
252 let (state1, _cv_timeout) = self
253 .inner
254 .cancellables_cv
255 .wait_timeout(state, Self::SLEEP_DURATION);
256 state = state1;
257 }
258 }
259
260 /// Marks all the queued and in-flight tasks as cancelled. Any tasks queued after `disarm`ing
261 /// will be cancelled.
262 /// Does not wait for all the tasks to get cancelled.
disarm(&self)263 pub fn disarm(&self) {
264 {
265 let mut state = self.inner.state.lock();
266
267 if state.wind_down >= WindDownStates::Disarmed {
268 return;
269 }
270
271 // At this point any new incoming request will be cancelled when run.
272 state.wind_down = WindDownStates::Disarmed;
273 }
274 self.drain_cancellables();
275 }
276
277 /// Shut down the `CancellableBlockingPool`.
278 ///
279 /// This will block until all work that has been started by the worker threads is finished. Any
280 /// work that was added to the `CancellableBlockingPool` but not yet picked up by a worker
281 /// thread will not complete and `await`ing on the `Task` for that work will panic.
282 ///
shutdown(&self) -> Result<(), Error>283 pub fn shutdown(&self) -> Result<(), Error> {
284 self.shutdown_with_timeout(DEFAULT_SHUTDOWN_TIMEOUT)
285 }
286
shutdown_with_timeout(&self, timeout: Duration) -> Result<(), Error>287 fn shutdown_with_timeout(&self, timeout: Duration) -> Result<(), Error> {
288 self.disarm();
289 {
290 let mut state = self.inner.state.lock();
291 if state.wind_down == WindDownStates::ShuttingDown {
292 return Err(Error::ShutdownInProgress);
293 }
294 if state.wind_down == WindDownStates::ShutDown {
295 return Err(Error::AlreadyShutdown);
296 }
297 state.wind_down = WindDownStates::ShuttingDown;
298 }
299
300 let res = self
301 .inner
302 .blocking_pool
303 .shutdown(/* deadline: */ Some(Instant::now() + timeout));
304
305 self.inner.state.lock().wind_down = WindDownStates::ShutDown;
306 match res {
307 Ok(_) => Ok(()),
308 Err(_) => Err(Error::Timedout),
309 }
310 }
311 }
312
313 impl Default for CancellableBlockingPool {
default() -> CancellableBlockingPool314 fn default() -> CancellableBlockingPool {
315 CancellableBlockingPool::new(256, Duration::from_secs(10))
316 }
317 }
318
319 impl Drop for CancellableBlockingPool {
drop(&mut self)320 fn drop(&mut self) {
321 let _ = self.shutdown();
322 }
323 }
324
325 /// Spawn a task to run in the `CancellableBlockingPool` static executor.
326 ///
327 /// `cancel` in-flight operation. cancel is called on operation during `disarm` or during
328 /// `shutdown`. Cancel may be called multiple times if running task doesn't get cancelled on first
329 /// attempt.
330 ///
331 /// Callers may `await` the returned `Task` to be notified when the work is completed.
332 ///
333 /// See also: `spawn`.
unblock<F, R, G>(f: F, cancel: G) -> impl Future<Output = R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static, G: Fn() -> R + Send + 'static,334 pub fn unblock<F, R, G>(f: F, cancel: G) -> impl Future<Output = R>
335 where
336 F: FnOnce() -> R + Send + 'static,
337 R: Send + 'static,
338 G: Fn() -> R + Send + 'static,
339 {
340 EXECUTOR.spawn(f, cancel)
341 }
342
343 /// Marks all the queued and in-flight tasks as cancelled. Any tasks queued after `disarm`ing
344 /// will be cancelled.
345 /// Doesn't not wait for all the tasks to get cancelled.
unblock_disarm()346 pub fn unblock_disarm() {
347 EXECUTOR.disarm()
348 }
349
#[cfg(test)]
mod test {
    use std::sync::Arc;
    use std::sync::Barrier;
    use std::thread;
    use std::time::Duration;

    use futures::executor::block_on;
    use sync::Condvar;
    use sync::Mutex;

    use crate::blocking::Error;
    use crate::CancellableBlockingPool;

    /// Short timeout so that tests with a deliberately hung task finish quickly.
    const TEST_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(100);

    /// A task that is still queued when the pool is disarmed must resolve to the value
    /// returned by its `cancel` routine.
    #[test]
    fn disarm_with_pending_work() {
        // Create a pool with only one thread.
        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));

        let mu = Arc::new(Mutex::new(false));
        let cv = Arc::new(Condvar::new());
        let blocker_is_running = Arc::new(Barrier::new(2));

        // First spawn a task that blocks the only worker of the pool.
        let task_mu = mu.clone();
        let task_cv = cv.clone();
        let task_blocker_is_running = blocker_is_running.clone();
        let _ = pool.spawn(
            move || {
                task_blocker_is_running.wait();
                let mut ready = task_mu.lock();
                while !*ready {
                    ready = task_cv.wait(ready);
                }
            },
            move || {},
        );

        // Wait for the worker to start running the blocking task.
        blocker_is_running.wait();

        // This task will never run `|| 5` because we will disarm the pool first.
        let unfinished = pool.spawn(|| 5, || 0);

        // Disarming should cancel the task.
        pool.disarm();

        // Release the blocking task. This will allow a worker to pick up the task that
        // has to be cancelled.
        *mu.lock() = true;
        cv.notify_all();

        // We expect the cancelled value to be returned.
        assert_eq!(block_on(unfinished), 0);

        // Now the pool is empty and can be shutdown without blocking.
        pool.shutdown_with_timeout(TEST_SHUTDOWN_TIMEOUT).unwrap();
    }

    /// Despite its name this test expects no panic: shutting down while a task is hung
    /// must report `Error::Timedout`.
    #[test]
    fn shutdown_with_blocked_work_should_panic() {
        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));

        let running = Arc::new((Mutex::new(false), Condvar::new()));
        let running1 = running.clone();
        let _ = pool.spawn(
            move || {
                *running1.0.lock() = true;
                running1.1.notify_one();
                thread::sleep(Duration::from_secs(10000));
            },
            move || {},
        );

        // Make sure the task occupies the worker before shutting down.
        let mut is_running = running.0.lock();
        while !*is_running {
            is_running = running.1.wait(is_running);
        }

        assert_eq!(
            pool.shutdown_with_timeout(TEST_SHUTDOWN_TIMEOUT),
            Err(Error::Timedout)
        );
    }

    /// Shutting down a pool that has already been shut down reports `AlreadyShutdown`.
    #[test]
    fn multiple_shutdown_returns_error() {
        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));
        let _ = pool.shutdown();
        assert_eq!(
            pool.shutdown_with_timeout(TEST_SHUTDOWN_TIMEOUT),
            Err(Error::AlreadyShutdown)
        );
    }

    /// A shutdown issued while another one is running reports `ShutdownInProgress`.
    #[test]
    fn shutdown_in_progress() {
        let pool = CancellableBlockingPool::new(1, Duration::from_secs(10));

        let running = Arc::new((Mutex::new(false), Condvar::new()));
        let running1 = running.clone();
        let _ = pool.spawn(
            move || {
                *running1.0.lock() = true;
                running1.1.notify_one();
                thread::sleep(Duration::from_secs(10000));
            },
            move || {},
        );

        // Make sure the task occupies the worker before shutting down.
        let mut is_running = running.0.lock();
        while !*is_running {
            is_running = running.1.wait(is_running);
        }

        let pool_clone = pool.clone();
        let contender = thread::spawn(move || {
            // Wait until the shutdown issued by the main thread has reached the
            // underlying pool, then race a second shutdown against it.
            while !pool_clone.inner.blocking_pool.shutting_down() {
                thread::yield_now();
            }
            assert_eq!(
                pool_clone.shutdown_with_timeout(TEST_SHUTDOWN_TIMEOUT),
                Err(Error::ShutdownInProgress)
            );
        });
        assert_eq!(
            pool.shutdown_with_timeout(TEST_SHUTDOWN_TIMEOUT),
            Err(Error::Timedout)
        );

        // A panic in a spawned thread does not fail the test on its own; joining
        // surfaces a failed assertion of the contender.
        contender
            .join()
            .expect("concurrent shutdown did not report ShutdownInProgress");
    }
}
481