• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 //! Provides a struct for registering signal handlers that get cleared on drop.
6 
7 use std::{
8     convert::TryFrom,
9     fmt,
10     io::{Cursor, Write},
11     panic::catch_unwind,
12     result,
13 };
14 
15 use libc::{c_int, c_void, STDERR_FILENO};
16 use remain::sorted;
17 use thiserror::Error;
18 
19 use super::{
20     signal::{
21         clear_signal_handler, has_default_signal_handler, register_signal_handler, wait_for_signal,
22         Signal,
23     },
24     Error as ErrnoError,
25 };
26 
/// Errors produced while checking, installing, or waiting on signal handlers in this module.
#[sorted]
#[derive(Error, Debug)]
pub enum Error {
    /// Already waiting for interrupt.
    // NOTE(review): nothing in this file constructs this variant; presumably used elsewhere.
    #[error("already waiting for interrupt.")]
    AlreadyWaiting,
    /// Signal already has a handler, i.e. its current disposition is not the default one.
    #[error("signal handler already set for {0:?}")]
    HandlerAlreadySet(Signal),
    /// Failed to check if signal has the default signal handler.
    #[error("failed to check the signal handler for {0:?}: {1}")]
    HasDefaultSignalHandler(Signal, ErrnoError),
    /// Failed to register a signal handler.
    #[error("failed to register a signal handler for {0:?}: {1}")]
    RegisterSignalHandler(Signal, ErrnoError),
    /// Sigaction failed.
    // NOTE(review): within this file only the test helpers construct this variant.
    #[error("sigaction failed for {0:?}: {1}")]
    Sigaction(Signal, ErrnoError),
    /// Failed to wait for signal.
    #[error("wait_for_signal failed: {0}")]
    WaitForSignal(ErrnoError),
}

/// Result type used throughout this module, with [`Error`] as the error type.
pub type Result<T> = result::Result<T, Error>;
51 
/// The interface used by Scoped Signal handler.
///
/// Implementors are installed with `ScopedSignalHandler::new::<H>()`, which registers an
/// `extern "C"` wrapper that forwards the raw signal number to `H::handle_signal`.
///
/// # Safety
/// The implementation of handle_signal needs to be async signal-safe.
///
/// NOTE: panics are caught when possible because a panic inside ffi is undefined behavior.
pub unsafe trait SignalHandler {
    /// A function that is called to handle the passed signal.
    ///
    /// This runs in signal-handler context, so it must avoid anything that takes locks or
    /// allocates (e.g. `eprintln!`, `format!`).
    fn handle_signal(signal: Signal);
}
62 
63 /// Wrap the handler with an extern "C" function.
call_handler<H: SignalHandler>(signum: c_int)64 extern "C" fn call_handler<H: SignalHandler>(signum: c_int) {
65     // Make an effort to surface an error.
66     if catch_unwind(|| H::handle_signal(Signal::try_from(signum).unwrap())).is_err() {
67         // Note the following cannot be used:
68         // eprintln! - uses std::io which has locks that may be held.
69         // format! - uses the allocator which enforces mutual exclusion.
70 
71         // Get the debug representation of signum.
72         let signal: Signal;
73         let signal_debug: &dyn fmt::Debug = match Signal::try_from(signum) {
74             Ok(s) => {
75                 signal = s;
76                 &signal as &dyn fmt::Debug
77             }
78             Err(_) => &signum as &dyn fmt::Debug,
79         };
80 
81         // Buffer the output, so a single call to write can be used.
82         // The message accounts for 29 chars, that leaves 35 for the string representation of the
83         // signal which is more than enough.
84         let mut buffer = [0u8; 64];
85         let mut cursor = Cursor::new(buffer.as_mut());
86         if writeln!(cursor, "signal handler got error for: {:?}", signal_debug).is_ok() {
87             let len = cursor.position() as usize;
88             // Safe in the sense that buffer is owned and the length is checked. This may print in
89             // the middle of an existing write, but that is considered better than dropping the
90             // error.
91             unsafe {
92                 libc::write(
93                     STDERR_FILENO,
94                     cursor.get_ref().as_ptr() as *const c_void,
95                     len,
96                 )
97             };
98         } else {
99             // This should never happen, but write an error message just in case.
100             const ERROR_DROPPED: &str = "Error dropped by signal handler.";
101             let bytes = ERROR_DROPPED.as_bytes();
102             unsafe { libc::write(STDERR_FILENO, bytes.as_ptr() as *const c_void, bytes.len()) };
103         }
104     }
105 }
106 
107 /// Represents a signal handler that is registered with a set of signals that unregistered when the
108 /// struct goes out of scope. Prefer a signalfd based solution before using this.
109 pub struct ScopedSignalHandler {
110     signals: Vec<Signal>,
111 }
112 
113 impl ScopedSignalHandler {
114     /// Attempts to register `handler` with the provided `signals`. It will fail if there is already
115     /// an existing handler on any of `signals`.
116     ///
117     /// # Safety
118     /// This is safe if H::handle_signal is async-signal safe.
new<H: SignalHandler>(signals: &[Signal]) -> Result<Self>119     pub fn new<H: SignalHandler>(signals: &[Signal]) -> Result<Self> {
120         let mut scoped_handler = ScopedSignalHandler {
121             signals: Vec::with_capacity(signals.len()),
122         };
123         for &signal in signals {
124             if !has_default_signal_handler((signal).into())
125                 .map_err(|err| Error::HasDefaultSignalHandler(signal, err))?
126             {
127                 return Err(Error::HandlerAlreadySet(signal));
128             }
129             // Requires an async-safe callback.
130             unsafe {
131                 register_signal_handler((signal).into(), call_handler::<H>)
132                     .map_err(|err| Error::RegisterSignalHandler(signal, err))?
133             };
134             scoped_handler.signals.push(signal);
135         }
136         Ok(scoped_handler)
137     }
138 }
139 
140 /// Clears the signal handler for any of the associated signals.
141 impl Drop for ScopedSignalHandler {
drop(&mut self)142     fn drop(&mut self) {
143         for signal in &self.signals {
144             if let Err(err) = clear_signal_handler((*signal).into()) {
145                 eprintln!("Error: failed to clear signal handler: {:?}", err);
146             }
147         }
148     }
149 }
150 
151 /// A signal handler that does nothing.
152 ///
153 /// This is useful in cases where wait_for_signal is used since it will never trigger if the signal
154 /// is blocked and the default handler may have undesired effects like terminating the process.
155 pub struct EmptySignalHandler;
156 /// # Safety
157 /// Safe because handle_signal is async-signal safe.
158 unsafe impl SignalHandler for EmptySignalHandler {
handle_signal(_: Signal)159     fn handle_signal(_: Signal) {}
160 }
161 
162 /// Blocks until SIGINT is received, which often happens because Ctrl-C was pressed in an
163 /// interactive terminal.
164 ///
165 /// Note: if you are using a multi-threaded application you need to block SIGINT on all other
166 /// threads or they may receive the signal instead of the desired thread.
wait_for_interrupt() -> Result<()>167 pub fn wait_for_interrupt() -> Result<()> {
168     // Register a signal handler if there is not one already so the thread is not killed.
169     let ret = ScopedSignalHandler::new::<EmptySignalHandler>(&[Signal::Interrupt]);
170     if !matches!(&ret, Ok(_) | Err(Error::HandlerAlreadySet(_))) {
171         ret?;
172     }
173 
174     match wait_for_signal(&[Signal::Interrupt.into()], None) {
175         Ok(_) => Ok(()),
176         Err(err) => Err(Error::WaitForSignal(err)),
177     }
178 }
179 
#[cfg(test)]
mod tests {
    use super::*;

    use std::{
        fs::File,
        io::{BufRead, BufReader},
        mem::zeroed,
        ptr::{null, null_mut},
        sync::{
            atomic::{AtomicI32, AtomicUsize, Ordering},
            Arc, Mutex, MutexGuard, Once,
        },
        thread::{sleep, spawn},
        time::{Duration, Instant},
    };

    use libc::sigaction;

    use super::super::{gettid, kill, Pid};

    const TEST_SIGNAL: Signal = Signal::User1;
    const TEST_SIGNALS: &[Signal] = &[Signal::User1, Signal::User2];

    /// Number of times `TestHandler` has observed `TEST_SIGNAL`.
    static TEST_SIGNAL_COUNTER: AtomicUsize = AtomicUsize::new(0);

    /// Only allows one test case to execute at a time.
    ///
    /// A poisoned lock is recovered instead of propagated. Such a lock is left behind by a
    /// test that panicked while holding it. Recovering it keeps one failing test from
    /// turning every later test into a spurious failure.
    fn get_mutex() -> MutexGuard<'static, ()> {
        static INIT: Once = Once::new();
        static mut VAL: Option<Arc<Mutex<()>>> = None;

        INIT.call_once(|| {
            let val = Some(Arc::new(Mutex::new(())));
            // Safe because the mutation is protected by the Once.
            unsafe { VAL = val }
        });

        // Safe mutation only happens in the Once.
        unsafe { VAL.as_ref() }
            .unwrap()
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Resets the handler invocation counter to zero.
    fn reset_counter() {
        TEST_SIGNAL_COUNTER.store(0, Ordering::SeqCst);
    }

    /// Returns the current `sigaction` for `signal` without modifying it.
    fn get_sigaction(signal: Signal) -> Result<sigaction> {
        // Safe because sigaction is owned and expected to be initialized to zeros.
        let mut sigact: sigaction = unsafe { zeroed() };

        if unsafe { sigaction(signal.into(), null(), &mut sigact) } < 0 {
            Err(Error::Sigaction(signal, ErrnoError::last()))
        } else {
            Ok(sigact)
        }
    }

    /// Reinstalls a previously saved `sigaction` for `signal`.
    ///
    /// # Safety
    /// This is only safe if the signal handler set in sigaction is safe.
    unsafe fn restore_sigaction(signal: Signal, sigact: sigaction) -> Result<sigaction> {
        if sigaction(signal.into(), &sigact, null_mut()) < 0 {
            Err(Error::Sigaction(signal, ErrnoError::last()))
        } else {
            Ok(sigact)
        }
    }

    /// Sends `TEST_SIGNAL` to the calling thread.
    ///
    /// # Safety
    /// Safe if the signal handler for `TEST_SIGNAL` is safe.
    unsafe fn send_test_signal() {
        kill(gettid(), TEST_SIGNAL.into()).unwrap()
    }

    macro_rules! assert_counter_eq {
        ($compare_to:expr) => {{
            let expected: usize = $compare_to;
            let got: usize = TEST_SIGNAL_COUNTER.load(Ordering::SeqCst);
            if got != expected {
                panic!(
                    "wrong signal counter value: got {}; expected {}",
                    got, expected
                );
            }
        }};
    }

    /// Counts deliveries of `TEST_SIGNAL` in `TEST_SIGNAL_COUNTER`.
    struct TestHandler;

    /// # Safety
    /// Safe because handle_signal is async-signal safe.
    unsafe impl SignalHandler for TestHandler {
        fn handle_signal(signal: Signal) {
            if TEST_SIGNAL == signal {
                TEST_SIGNAL_COUNTER.fetch_add(1, Ordering::SeqCst);
            }
        }
    }

    #[test]
    fn scopedsignalhandler_success() {
        // Prevent other test cases from running concurrently since the signal
        // handlers are shared for the process.
        let _guard = get_mutex();

        reset_counter();
        assert_counter_eq!(0);

        assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
        let handler = ScopedSignalHandler::new::<TestHandler>(&[TEST_SIGNAL]).unwrap();
        assert!(!has_default_signal_handler(TEST_SIGNAL.into()).unwrap());

        // Safe because test_handler is safe.
        unsafe { send_test_signal() };

        // Give the handler time to run in case it is on a different thread.
        for _ in 1..40 {
            if TEST_SIGNAL_COUNTER.load(Ordering::SeqCst) > 0 {
                break;
            }
            sleep(Duration::from_millis(250));
        }

        assert_counter_eq!(1);

        drop(handler);
        assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
    }

    #[test]
    fn scopedsignalhandler_handleralreadyset() {
        // Prevent other test cases from running concurrently since the signal
        // handlers are shared for the process.
        let _guard = get_mutex();

        reset_counter();
        assert_counter_eq!(0);

        assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
        // Safe because TestHandler is async-signal safe.
        let handler = ScopedSignalHandler::new::<TestHandler>(&[TEST_SIGNAL]).unwrap();
        assert!(!has_default_signal_handler(TEST_SIGNAL.into()).unwrap());

        // Safe because TestHandler is async-signal safe.
        assert!(matches!(
            ScopedSignalHandler::new::<TestHandler>(TEST_SIGNALS),
            Err(Error::HandlerAlreadySet(Signal::User1))
        ));

        assert_counter_eq!(0);
        drop(handler);
        assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
    }

    /// Stores the thread used by WaitForInterruptHandler; 0 means "not set".
    static WAIT_FOR_INTERRUPT_THREAD_ID: AtomicI32 = AtomicI32::new(0);
    /// Forwards SIGINT to the appropriate thread.
    struct WaitForInterruptHandler;

    /// # Safety
    /// Safe because handle_signal is async-signal safe.
    unsafe impl SignalHandler for WaitForInterruptHandler {
        fn handle_signal(_: Signal) {
            let tid = WAIT_FOR_INTERRUPT_THREAD_ID.load(Ordering::SeqCst);
            // If the thread ID is set and executed on the wrong thread, forward the signal.
            if tid != 0 && gettid() != tid {
                // Safe because the handler is safe and the target thread id is expecting the signal.
                unsafe { kill(tid, Signal::Interrupt.into()) }.unwrap();
            }
        }
    }

    /// Query /proc/${tid}/status for its State and check if it is either S (sleeping) or in
    /// D (disk sleep).
    fn thread_is_sleeping(tid: Pid) -> result::Result<bool, ErrnoError> {
        const PREFIX: &str = "State:";
        let mut status_reader = BufReader::new(File::open(format!("/proc/{}/status", tid))?);
        let mut line = String::new();
        loop {
            let count = status_reader.read_line(&mut line)?;
            if count == 0 {
                return Err(ErrnoError::new(libc::EIO));
            }
            if let Some(stripped) = line.strip_prefix(PREFIX) {
                return Ok(matches!(
                    stripped.trim_start().chars().next(),
                    Some('S') | Some('D')
                ));
            }
            line.clear();
        }
    }

    /// Wait for a process to block either in a sleeping or disk sleep state.
    fn wait_for_thread_to_sleep(tid: Pid, timeout: Duration) -> result::Result<(), ErrnoError> {
        let start = Instant::now();
        loop {
            if thread_is_sleeping(tid)? {
                return Ok(());
            }
            if start.elapsed() > timeout {
                return Err(ErrnoError::new(libc::EAGAIN));
            }
            sleep(Duration::from_millis(50));
        }
    }

    #[test]
    fn waitforinterrupt_success() {
        // Prevent other test cases from running concurrently since the signal
        // handlers are shared for the process.
        let _guard = get_mutex();

        let to_restore = get_sigaction(Signal::Interrupt).unwrap();
        clear_signal_handler(Signal::Interrupt.into()).unwrap();
        // Safe because WaitForInterruptHandler is async-signal safe.
        let handler =
            ScopedSignalHandler::new::<WaitForInterruptHandler>(&[Signal::Interrupt]).unwrap();

        let tid = gettid();
        WAIT_FOR_INTERRUPT_THREAD_ID.store(tid, Ordering::SeqCst);

        let join_handle = spawn(move || -> result::Result<(), ErrnoError> {
            // Wait until the thread is ready to receive the signal. The signal is sent even if
            // that wait fails. Otherwise the main thread would block in wait_for_interrupt()
            // forever, and the test would hang instead of reporting the error.
            let sleep_ret = wait_for_thread_to_sleep(tid, Duration::from_secs(10));

            // Safe because the SIGINT handler is safe.
            let kill_ret = unsafe { kill(tid, Signal::Interrupt.into()) };
            kill_ret.and(sleep_ret)
        });
        let wait_ret = wait_for_interrupt();
        let join_ret = join_handle.join();

        drop(handler);
        // Clear the forwarding target so a stale thread ID is never used by a later handler.
        WAIT_FOR_INTERRUPT_THREAD_ID.store(0, Ordering::SeqCst);
        // Safe because we are restoring the previous SIGINT handler.
        unsafe { restore_sigaction(Signal::Interrupt, to_restore) }.unwrap();

        wait_ret.unwrap();
        join_ret.unwrap().unwrap();
    }
}
418