1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::any::Any;
6 use std::panic;
7 use std::panic::UnwindSafe;
8 use std::sync::mpsc::channel;
9 use std::sync::mpsc::Receiver;
10 use std::thread;
11 use std::thread::JoinHandle;
12 use std::time::Duration;
13
14 /// Spawns a thread that can be joined with a timeout.
spawn_with_timeout<F, T>(f: F) -> JoinHandleWithTimeout<T> where F: FnOnce() -> T, F: Send + UnwindSafe + 'static, T: Send + 'static,15 pub fn spawn_with_timeout<F, T>(f: F) -> JoinHandleWithTimeout<T>
16 where
17 F: FnOnce() -> T,
18 F: Send + UnwindSafe + 'static,
19 T: Send + 'static,
20 {
21 // Use a channel to signal completion to the join handle
22 let (tx, rx) = channel();
23 let handle = thread::spawn(move || {
24 let val = panic::catch_unwind(f);
25 tx.send(()).unwrap();
26 val
27 });
28 JoinHandleWithTimeout { handle, rx }
29 }
30
/// Handle to a thread started by [`spawn_with_timeout`]; wait for it with
/// [`JoinHandleWithTimeout::try_join`].
pub struct JoinHandleWithTimeout<T> {
    /// Completion signal: the spawned thread sends one `()` after its closure returns or panics.
    rx: Receiver<()>,
    /// The spawned thread; its output is the `catch_unwind` result of the user closure.
    handle: JoinHandle<thread::Result<T>>,
}
35
/// Error returned by [`JoinHandleWithTimeout::try_join`].
#[derive(Debug)]
pub enum JoinError {
    /// The thread panicked; holds the panic payload (usually a `&'static str` or `String`).
    Panic(Box<dyn Any>),
    /// The thread did not finish within the requested timeout.
    Timeout,
}

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            JoinError::Panic(payload) => {
                // Reborrow as `&dyn Any` so the downcast inspects the payload, not the `Box`.
                let payload: &dyn Any = &**payload;
                if let Some(msg) = payload.downcast_ref::<&str>() {
                    write!(f, "thread panicked: {}", msg)
                } else if let Some(msg) = payload.downcast_ref::<String>() {
                    write!(f, "thread panicked: {}", msg)
                } else {
                    write!(f, "thread panicked")
                }
            }
            JoinError::Timeout => write!(f, "timed out waiting for thread to finish"),
        }
    }
}

impl std::error::Error for JoinError {}
41
42 impl<T> JoinHandleWithTimeout<T> {
43 /// Tries to join the thread. Returns an error if the join takes more than `timeout_ms`.
try_join(self, timeout: Duration) -> Result<T, JoinError>44 pub fn try_join(self, timeout: Duration) -> Result<T, JoinError> {
45 if self.rx.recv_timeout(timeout).is_ok() {
46 self.handle.join().unwrap().map_err(|e| JoinError::Panic(e))
47 } else {
48 Err(JoinError::Timeout)
49 }
50 }
51 }
52