1 // Copyright 2021 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 //! Provides a struct for registering signal handlers that get cleared on drop.
6
7 use std::{
8 convert::TryFrom,
9 fmt,
10 io::{Cursor, Write},
11 panic::catch_unwind,
12 result,
13 };
14
15 use libc::{c_int, c_void, STDERR_FILENO};
16 use remain::sorted;
17 use thiserror::Error;
18
19 use super::{
20 signal::{
21 clear_signal_handler, has_default_signal_handler, register_signal_handler, wait_for_signal,
22 Signal,
23 },
24 Error as ErrnoError,
25 };
26
/// Errors returned by the signal-handler helpers in this module, such as
/// [`ScopedSignalHandler::new`] and [`wait_for_interrupt`].
///
/// `#[sorted]` requires the variants below to stay in alphabetical order.
#[sorted]
#[derive(Error, Debug)]
pub enum Error {
    /// Already waiting for interrupt.
    #[error("already waiting for interrupt.")]
    AlreadyWaiting,
    /// Signal already has a handler.
    #[error("signal handler already set for {0:?}")]
    HandlerAlreadySet(Signal),
    /// Failed to check if signal has the default signal handler.
    #[error("failed to check the signal handler for {0:?}: {1}")]
    HasDefaultSignalHandler(Signal, ErrnoError),
    /// Failed to register a signal handler.
    #[error("failed to register a signal handler for {0:?}: {1}")]
    RegisterSignalHandler(Signal, ErrnoError),
    /// Sigaction failed.
    #[error("sigaction failed for {0:?}: {1}")]
    Sigaction(Signal, ErrnoError),
    /// Failed to wait for signal.
    #[error("wait_for_signal failed: {0}")]
    WaitForSignal(ErrnoError),
}

/// Convenience alias for results whose error type is this module's [`Error`].
pub type Result<T> = result::Result<T, Error>;
51
/// The interface used by Scoped Signal handler.
///
/// Implementors are installed with [`ScopedSignalHandler::new`]. `handle_signal` is an
/// associated function (it takes no `self`), so any state a handler needs has to live outside
/// the handler type, e.g. in a `static` atomic.
///
/// # Safety
/// The implementation of handle_signal needs to be async signal-safe.
///
/// NOTE: panics are caught when possible because a panic inside ffi is undefined behavior.
pub unsafe trait SignalHandler {
    /// A function that is called to handle the passed signal.
    ///
    /// It runs from inside the process's signal handler (through `call_handler`), so it may
    /// only perform async-signal-safe operations.
    fn handle_signal(signal: Signal);
}
62
63 /// Wrap the handler with an extern "C" function.
call_handler<H: SignalHandler>(signum: c_int)64 extern "C" fn call_handler<H: SignalHandler>(signum: c_int) {
65 // Make an effort to surface an error.
66 if catch_unwind(|| H::handle_signal(Signal::try_from(signum).unwrap())).is_err() {
67 // Note the following cannot be used:
68 // eprintln! - uses std::io which has locks that may be held.
69 // format! - uses the allocator which enforces mutual exclusion.
70
71 // Get the debug representation of signum.
72 let signal: Signal;
73 let signal_debug: &dyn fmt::Debug = match Signal::try_from(signum) {
74 Ok(s) => {
75 signal = s;
76 &signal as &dyn fmt::Debug
77 }
78 Err(_) => &signum as &dyn fmt::Debug,
79 };
80
81 // Buffer the output, so a single call to write can be used.
82 // The message accounts for 29 chars, that leaves 35 for the string representation of the
83 // signal which is more than enough.
84 let mut buffer = [0u8; 64];
85 let mut cursor = Cursor::new(buffer.as_mut());
86 if writeln!(cursor, "signal handler got error for: {:?}", signal_debug).is_ok() {
87 let len = cursor.position() as usize;
88 // Safe in the sense that buffer is owned and the length is checked. This may print in
89 // the middle of an existing write, but that is considered better than dropping the
90 // error.
91 unsafe {
92 libc::write(
93 STDERR_FILENO,
94 cursor.get_ref().as_ptr() as *const c_void,
95 len,
96 )
97 };
98 } else {
99 // This should never happen, but write an error message just in case.
100 const ERROR_DROPPED: &str = "Error dropped by signal handler.";
101 let bytes = ERROR_DROPPED.as_bytes();
102 unsafe { libc::write(STDERR_FILENO, bytes.as_ptr() as *const c_void, bytes.len()) };
103 }
104 }
105 }
106
107 /// Represents a signal handler that is registered with a set of signals that unregistered when the
108 /// struct goes out of scope. Prefer a signalfd based solution before using this.
109 pub struct ScopedSignalHandler {
110 signals: Vec<Signal>,
111 }
112
113 impl ScopedSignalHandler {
114 /// Attempts to register `handler` with the provided `signals`. It will fail if there is already
115 /// an existing handler on any of `signals`.
116 ///
117 /// # Safety
118 /// This is safe if H::handle_signal is async-signal safe.
new<H: SignalHandler>(signals: &[Signal]) -> Result<Self>119 pub fn new<H: SignalHandler>(signals: &[Signal]) -> Result<Self> {
120 let mut scoped_handler = ScopedSignalHandler {
121 signals: Vec::with_capacity(signals.len()),
122 };
123 for &signal in signals {
124 if !has_default_signal_handler((signal).into())
125 .map_err(|err| Error::HasDefaultSignalHandler(signal, err))?
126 {
127 return Err(Error::HandlerAlreadySet(signal));
128 }
129 // Requires an async-safe callback.
130 unsafe {
131 register_signal_handler((signal).into(), call_handler::<H>)
132 .map_err(|err| Error::RegisterSignalHandler(signal, err))?
133 };
134 scoped_handler.signals.push(signal);
135 }
136 Ok(scoped_handler)
137 }
138 }
139
140 /// Clears the signal handler for any of the associated signals.
141 impl Drop for ScopedSignalHandler {
drop(&mut self)142 fn drop(&mut self) {
143 for signal in &self.signals {
144 if let Err(err) = clear_signal_handler((*signal).into()) {
145 eprintln!("Error: failed to clear signal handler: {:?}", err);
146 }
147 }
148 }
149 }
150
151 /// A signal handler that does nothing.
152 ///
153 /// This is useful in cases where wait_for_signal is used since it will never trigger if the signal
154 /// is blocked and the default handler may have undesired effects like terminating the process.
155 pub struct EmptySignalHandler;
156 /// # Safety
157 /// Safe because handle_signal is async-signal safe.
158 unsafe impl SignalHandler for EmptySignalHandler {
handle_signal(_: Signal)159 fn handle_signal(_: Signal) {}
160 }
161
162 /// Blocks until SIGINT is received, which often happens because Ctrl-C was pressed in an
163 /// interactive terminal.
164 ///
165 /// Note: if you are using a multi-threaded application you need to block SIGINT on all other
166 /// threads or they may receive the signal instead of the desired thread.
wait_for_interrupt() -> Result<()>167 pub fn wait_for_interrupt() -> Result<()> {
168 // Register a signal handler if there is not one already so the thread is not killed.
169 let ret = ScopedSignalHandler::new::<EmptySignalHandler>(&[Signal::Interrupt]);
170 if !matches!(&ret, Ok(_) | Err(Error::HandlerAlreadySet(_))) {
171 ret?;
172 }
173
174 match wait_for_signal(&[Signal::Interrupt.into()], None) {
175 Ok(_) => Ok(()),
176 Err(err) => Err(Error::WaitForSignal(err)),
177 }
178 }
179
#[cfg(test)]
mod tests {
    use super::*;

    use std::{
        fs::File,
        io::{BufRead, BufReader},
        mem::zeroed,
        ptr::{null, null_mut},
        sync::{
            atomic::{AtomicI32, AtomicUsize, Ordering},
            Mutex, MutexGuard, Once,
        },
        thread::{sleep, spawn},
        time::{Duration, Instant},
    };

    use libc::sigaction;

    use super::super::{gettid, kill, Pid};

    const TEST_SIGNAL: Signal = Signal::User1;
    const TEST_SIGNALS: &[Signal] = &[Signal::User1, Signal::User2];

    /// Number of times `TestHandler` has observed `TEST_SIGNAL`.
    static TEST_SIGNAL_COUNTER: AtomicUsize = AtomicUsize::new(0);

    /// Only allows one test case to execute at a time.
    ///
    /// A lock poisoned by a test that panicked while holding it is recovered rather than
    /// propagated. Otherwise one failing test would make every later test fail as well.
    fn get_mutex() -> MutexGuard<'static, ()> {
        static INIT: Once = Once::new();
        static mut VAL: Option<Mutex<()>> = None;

        INIT.call_once(|| {
            // Safe because the mutation is protected by the Once.
            unsafe { VAL = Some(Mutex::new(())) }
        });

        // Safe mutation only happens in the Once.
        unsafe { VAL.as_ref() }
            .unwrap()
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    fn reset_counter() {
        TEST_SIGNAL_COUNTER.store(0, Ordering::SeqCst);
    }

    fn get_sigaction(signal: Signal) -> Result<sigaction> {
        // Safe because sigaction is owned and expected to be initialized to zeros.
        let mut sigact: sigaction = unsafe { zeroed() };

        if unsafe { sigaction(signal.into(), null(), &mut sigact) } < 0 {
            Err(Error::Sigaction(signal, ErrnoError::last()))
        } else {
            Ok(sigact)
        }
    }

    /// # Safety
    /// This is only safe if the signal handler set in sigaction is safe.
    unsafe fn restore_sigaction(signal: Signal, sigact: sigaction) -> Result<sigaction> {
        if sigaction(signal.into(), &sigact, null_mut()) < 0 {
            Err(Error::Sigaction(signal, ErrnoError::last()))
        } else {
            Ok(sigact)
        }
    }

    /// # Safety
    /// Safe if the signal handler for `TEST_SIGNAL` is safe.
    unsafe fn send_test_signal() {
        kill(gettid(), TEST_SIGNAL.into()).unwrap()
    }

    macro_rules! assert_counter_eq {
        ($compare_to:expr) => {{
            let expected: usize = $compare_to;
            let got: usize = TEST_SIGNAL_COUNTER.load(Ordering::SeqCst);
            if got != expected {
                panic!(
                    "wrong signal counter value: got {}; expected {}",
                    got, expected
                );
            }
        }};
    }

    /// Counts deliveries of `TEST_SIGNAL` in `TEST_SIGNAL_COUNTER`.
    struct TestHandler;

    /// # Safety
    /// Safe because handle_signal is async-signal safe.
    unsafe impl SignalHandler for TestHandler {
        fn handle_signal(signal: Signal) {
            if TEST_SIGNAL == signal {
                TEST_SIGNAL_COUNTER.fetch_add(1, Ordering::SeqCst);
            }
        }
    }

    #[test]
    fn scopedsignalhandler_success() {
        // Prevent other test cases from running concurrently since the signal
        // handlers are shared for the process.
        let _guard = get_mutex();

        reset_counter();
        assert_counter_eq!(0);

        assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
        let handler = ScopedSignalHandler::new::<TestHandler>(&[TEST_SIGNAL]).unwrap();
        assert!(!has_default_signal_handler(TEST_SIGNAL.into()).unwrap());

        // Safe because TestHandler is async-signal safe.
        unsafe { send_test_signal() };

        // Give the handler time to run in case it is on a different thread.
        for _ in 1..40 {
            if TEST_SIGNAL_COUNTER.load(Ordering::SeqCst) > 0 {
                break;
            }
            sleep(Duration::from_millis(250));
        }

        assert_counter_eq!(1);

        drop(handler);
        assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
    }

    #[test]
    fn scopedsignalhandler_handleralreadyset() {
        // Prevent other test cases from running concurrently since the signal
        // handlers are shared for the process.
        let _guard = get_mutex();

        reset_counter();
        assert_counter_eq!(0);

        assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
        // Safe because TestHandler is async-signal safe.
        let handler = ScopedSignalHandler::new::<TestHandler>(&[TEST_SIGNAL]).unwrap();
        assert!(!has_default_signal_handler(TEST_SIGNAL.into()).unwrap());

        // Safe because TestHandler is async-signal safe.
        assert!(matches!(
            ScopedSignalHandler::new::<TestHandler>(TEST_SIGNALS),
            Err(Error::HandlerAlreadySet(Signal::User1))
        ));

        assert_counter_eq!(0);
        drop(handler);
        assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
    }

    /// Stores the thread used by WaitForInterruptHandler (0 means unset).
    static WAIT_FOR_INTERRUPT_THREAD_ID: AtomicI32 = AtomicI32::new(0);
    /// Forwards SIGINT to the appropriate thread.
    struct WaitForInterruptHandler;

    /// # Safety
    /// Safe because handle_signal is async-signal safe.
    unsafe impl SignalHandler for WaitForInterruptHandler {
        fn handle_signal(_: Signal) {
            let tid = WAIT_FOR_INTERRUPT_THREAD_ID.load(Ordering::SeqCst);
            // If the thread ID is set and executed on the wrong thread, forward the signal.
            if tid != 0 && gettid() != tid {
                // Safe because the handler is safe and the target thread id is expecting the signal.
                unsafe { kill(tid, Signal::Interrupt.into()) }.unwrap();
            }
        }
    }

    /// Query /proc/${tid}/status for its State and check if it is either S (sleeping) or in
    /// D (disk sleep).
    fn thread_is_sleeping(tid: Pid) -> result::Result<bool, ErrnoError> {
        const PREFIX: &str = "State:";
        let mut status_reader = BufReader::new(File::open(format!("/proc/{}/status", tid))?);
        let mut line = String::new();
        loop {
            let count = status_reader.read_line(&mut line)?;
            if count == 0 {
                return Err(ErrnoError::new(libc::EIO));
            }
            if let Some(stripped) = line.strip_prefix(PREFIX) {
                return Ok(matches!(
                    stripped.trim_start().chars().next(),
                    Some('S') | Some('D')
                ));
            }
            line.clear();
        }
    }

    /// Wait for a thread to block either in a sleeping or disk sleep state, polling every 50ms
    /// and failing with EAGAIN once `timeout` has elapsed.
    fn wait_for_thread_to_sleep(tid: Pid, timeout: Duration) -> result::Result<(), ErrnoError> {
        let start = Instant::now();
        loop {
            if thread_is_sleeping(tid)? {
                return Ok(());
            }
            if start.elapsed() > timeout {
                return Err(ErrnoError::new(libc::EAGAIN));
            }
            sleep(Duration::from_millis(50));
        }
    }

    #[test]
    fn waitforinterrupt_success() {
        // Prevent other test cases from running concurrently since the signal
        // handlers are shared for the process.
        let _guard = get_mutex();

        let to_restore = get_sigaction(Signal::Interrupt).unwrap();
        clear_signal_handler(Signal::Interrupt.into()).unwrap();
        // Safe because WaitForInterruptHandler is async-signal safe.
        let handler =
            ScopedSignalHandler::new::<WaitForInterruptHandler>(&[Signal::Interrupt]).unwrap();

        let tid = gettid();
        WAIT_FOR_INTERRUPT_THREAD_ID.store(tid, Ordering::SeqCst);

        let join_handle = spawn(move || -> result::Result<(), ErrnoError> {
            // Wait until the thread is ready to receive the signal.
            let ready = wait_for_thread_to_sleep(tid, Duration::from_secs(10));

            // Send the signal even if waiting failed. Panicking here instead would leave the
            // main thread blocked in wait_for_interrupt forever, hanging the test instead of
            // failing it. Any wait error is still reported through the join result.
            // Safe because the SIGINT handler is safe.
            unsafe { kill(tid, Signal::Interrupt.into()) }?;
            ready
        });
        let wait_ret = wait_for_interrupt();
        let join_ret = join_handle.join();

        drop(handler);
        WAIT_FOR_INTERRUPT_THREAD_ID.store(0, Ordering::SeqCst);
        // Safe because we are restoring the previous SIGINT handler.
        unsafe { restore_sigaction(Signal::Interrupt, to_restore) }.unwrap();

        wait_ret.unwrap();
        join_ret.unwrap().unwrap();
    }
}
418