• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::fs::File;
6 use std::io::BufRead;
7 use std::io::BufReader;
8 use std::mem::zeroed;
9 use std::ptr::null;
10 use std::ptr::null_mut;
11 use std::sync::atomic::AtomicI32;
12 use std::sync::atomic::AtomicUsize;
13 use std::sync::atomic::Ordering;
14 use std::sync::Arc;
15 use std::sync::Mutex;
16 use std::sync::MutexGuard;
17 use std::sync::Once;
18 use std::thread::sleep;
19 use std::thread::spawn;
20 use std::time::Duration;
21 use std::time::Instant;
22 
23 use base::platform::gettid;
24 use base::platform::kill;
25 use base::platform::scoped_signal_handler::Error;
26 use base::platform::scoped_signal_handler::Result;
27 use base::platform::Error as ErrnoError;
28 use base::platform::Pid;
29 use base::sys::clear_signal_handler;
30 use base::sys::has_default_signal_handler;
31 use base::sys::wait_for_interrupt;
32 use base::sys::ScopedSignalHandler;
33 use base::sys::Signal;
34 use base::sys::SignalHandler;
35 use libc::sigaction;
36 
37 const TEST_SIGNAL: Signal = Signal::User1;
38 const TEST_SIGNALS: &[Signal] = &[Signal::User1, Signal::User2];
39 
40 static TEST_SIGNAL_COUNTER: AtomicUsize = AtomicUsize::new(0);
41 
/// Only allows one test case to execute at a time.
///
/// Signal handlers are shared by the whole process, so every test in this file holds the
/// returned guard for its full duration. The lock is released when the guard is dropped.
///
/// A test that panics while holding the guard poisons the lock. The poison is deliberately
/// ignored here: the mutex protects no data, and propagating it would make every later test fail
/// with an unrelated `PoisonError` instead of reporting its own result.
fn get_mutex() -> MutexGuard<'static, ()> {
    // `Mutex::new` is a `const fn`, so a plain immutable static is enough; no `static mut`,
    // `Once`, or `unsafe` is needed to initialize it.
    static TEST_LOCK: Mutex<()> = Mutex::new(());

    TEST_LOCK
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}
56 
reset_counter()57 fn reset_counter() {
58     TEST_SIGNAL_COUNTER.swap(0, Ordering::SeqCst);
59 }
60 
get_sigaction(signal: Signal) -> Result<sigaction>61 fn get_sigaction(signal: Signal) -> Result<sigaction> {
62     // Safe because sigaction is owned and expected to be initialized ot zeros.
63     let mut sigact: sigaction = unsafe { zeroed() };
64 
65     if unsafe { sigaction(signal.into(), null(), &mut sigact) } < 0 {
66         Err(Error::Sigaction(signal, ErrnoError::last()))
67     } else {
68         Ok(sigact)
69     }
70 }
71 
72 /// Safety:
73 /// This is only safe if the signal handler set in sigaction is safe.
restore_sigaction(signal: Signal, sigact: sigaction) -> Result<sigaction>74 unsafe fn restore_sigaction(signal: Signal, sigact: sigaction) -> Result<sigaction> {
75     if sigaction(signal.into(), &sigact, null_mut()) < 0 {
76         Err(Error::Sigaction(signal, ErrnoError::last()))
77     } else {
78         Ok(sigact)
79     }
80 }
81 
82 /// Safety:
83 /// Safe if the signal handler for Signal::User1 is safe.
send_test_signal()84 unsafe fn send_test_signal() {
85     kill(gettid(), Signal::User1.into()).unwrap()
86 }
87 
/// Panics unless `TEST_SIGNAL_COUNTER` currently equals the given `usize` expression.
///
/// The expected expression is evaluated first, then the counter is loaded once.
macro_rules! assert_counter_eq {
    ($compare_to:expr) => {{
        let wanted: usize = $compare_to;
        let observed: usize = TEST_SIGNAL_COUNTER.load(Ordering::SeqCst);
        assert!(
            observed == wanted,
            "wrong signal counter value: got {}; expected {}",
            observed,
            wanted
        );
    }};
}
100 
101 struct TestHandler;
102 
103 /// # Safety
104 /// Safe because handle_signal is async-signal safe.
105 unsafe impl SignalHandler for TestHandler {
handle_signal(signal: Signal)106     fn handle_signal(signal: Signal) {
107         if TEST_SIGNAL == signal {
108             TEST_SIGNAL_COUNTER.fetch_add(1, Ordering::SeqCst);
109         }
110     }
111 }
112 
/// Installs a scoped handler for `TEST_SIGNAL`, checks that one delivery is counted, and checks
/// that the default handler is back once the scoped handler is dropped.
#[test]
fn scopedsignalhandler_success() {
    // Signal handlers are shared for the process, so keep other test cases from running
    // concurrently for the whole duration of this one.
    let _serialize = get_mutex();

    reset_counter();
    assert_counter_eq!(0);

    assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
    let scoped = ScopedSignalHandler::new::<TestHandler>(&[TEST_SIGNAL]).unwrap();
    assert!(!has_default_signal_handler(TEST_SIGNAL.into()).unwrap());

    // SAFETY: the handler now installed for the test signal is `TestHandler`, which is safe.
    unsafe { send_test_signal() };

    // The handler may run on a different thread, so poll the counter a bounded number of times
    // (39 checks, 250 ms apart) before asserting on it.
    let mut polls_left = 39;
    while polls_left > 0 && TEST_SIGNAL_COUNTER.load(Ordering::SeqCst) == 0 {
        sleep(Duration::from_millis(250));
        polls_left -= 1;
    }

    assert_counter_eq!(1);

    // Dropping the scoped handler must put the default handler back.
    drop(scoped);
    assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
}
142 
/// Checks that requesting a second scoped handler for a signal that already has one fails with
/// `Error::HandlerAlreadySet` and leaves the first handler intact until it is dropped.
#[test]
fn scopedsignalhandler_handleralreadyset() {
    // Signal handlers are shared for the process, so keep other test cases from running
    // concurrently for the whole duration of this one.
    let _serialize = get_mutex();

    reset_counter();
    assert_counter_eq!(0);

    assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
    let first = ScopedSignalHandler::new::<TestHandler>(&[TEST_SIGNAL]).unwrap();
    assert!(!has_default_signal_handler(TEST_SIGNAL.into()).unwrap());

    // `TEST_SIGNALS` includes the signal claimed above, so this request must be rejected and
    // must name that signal in the error.
    let second_rejected = matches!(
        ScopedSignalHandler::new::<TestHandler>(TEST_SIGNALS),
        Err(Error::HandlerAlreadySet(Signal::User1))
    );
    assert!(second_rejected);

    // No signal was sent, so nothing may have been counted.
    assert_counter_eq!(0);
    drop(first);
    assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
}
167 
168 /// Stores the thread used by WaitForInterruptHandler.
169 static WAIT_FOR_INTERRUPT_THREAD_ID: AtomicI32 = AtomicI32::new(0);
170 /// Forwards SIGINT to the appropriate thread.
171 struct WaitForInterruptHandler;
172 
173 /// # Safety
174 /// Safe because handle_signal is async-signal safe.
175 unsafe impl SignalHandler for WaitForInterruptHandler {
handle_signal(_: Signal)176     fn handle_signal(_: Signal) {
177         let tid = WAIT_FOR_INTERRUPT_THREAD_ID.load(Ordering::SeqCst);
178         // If the thread ID is set and executed on the wrong thread, forward the signal.
179         if tid != 0 && gettid() != tid {
180             // Safe because the handler is safe and the target thread id is expecting the signal.
181             unsafe { kill(tid, Signal::Interrupt.into()) }.unwrap();
182         }
183     }
184 }
185 
186 /// Query /proc/${tid}/status for its State and check if it is either S (sleeping) or in
187 /// D (disk sleep).
thread_is_sleeping(tid: Pid) -> std::result::Result<bool, ErrnoError>188 fn thread_is_sleeping(tid: Pid) -> std::result::Result<bool, ErrnoError> {
189     const PREFIX: &str = "State:";
190     let mut status_reader = BufReader::new(File::open(format!("/proc/{}/status", tid))?);
191     let mut line = String::new();
192     loop {
193         let count = status_reader.read_line(&mut line)?;
194         if count == 0 {
195             return Err(ErrnoError::new(libc::EIO));
196         }
197         if let Some(stripped) = line.strip_prefix(PREFIX) {
198             return Ok(matches!(
199                 stripped.trim_start().chars().next(),
200                 Some('S') | Some('D')
201             ));
202         }
203         line.clear();
204     }
205 }
206 
207 /// Wait for a process to block either in a sleeping or disk sleep state.
wait_for_thread_to_sleep(tid: Pid, timeout: Duration) -> std::result::Result<(), ErrnoError>208 fn wait_for_thread_to_sleep(tid: Pid, timeout: Duration) -> std::result::Result<(), ErrnoError> {
209     let start = Instant::now();
210     loop {
211         if thread_is_sleeping(tid)? {
212             return Ok(());
213         }
214         if start.elapsed() > timeout {
215             return Err(ErrnoError::new(libc::EAGAIN));
216         }
217         sleep(Duration::from_millis(50));
218     }
219 }
220 
/// Checks that `wait_for_interrupt` returns successfully once a helper thread sends SIGINT to
/// the waiting thread.
#[test]
fn waitforinterrupt_success() {
    // Signal handlers are shared for the process, so keep other test cases from running
    // concurrently for the whole duration of this one.
    let _serialize = get_mutex();

    // Remember the current SIGINT action so it can be put back at the end, then clear it before
    // installing the scoped forwarding handler.
    let saved_sigint = get_sigaction(Signal::Interrupt).unwrap();
    clear_signal_handler(Signal::Interrupt.into()).unwrap();
    let scoped =
        ScopedSignalHandler::new::<WaitForInterruptHandler>(&[Signal::Interrupt]).unwrap();

    // Tell the handler which thread is going to wait for the interrupt.
    let waiter_tid = gettid();
    WAIT_FOR_INTERRUPT_THREAD_ID.store(waiter_tid, Ordering::SeqCst);

    let sender = spawn(move || -> std::result::Result<(), ErrnoError> {
        // Wait until the thread is ready to receive the signal.
        wait_for_thread_to_sleep(waiter_tid, Duration::from_secs(10)).unwrap();

        // SAFETY: the SIGINT handler installed above is safe.
        unsafe { kill(waiter_tid, Signal::Interrupt.into()) }
    });
    let wait_outcome = wait_for_interrupt();
    let sender_outcome = sender.join();

    // Clean up before inspecting the results so a failure below cannot skip the restore.
    drop(scoped);
    // SAFETY: we are restoring the previous SIGINT handler.
    unsafe { restore_sigaction(Signal::Interrupt, saved_sigint) }.unwrap();

    wait_outcome.unwrap();
    sender_outcome.unwrap().unwrap();
}
253