1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::fs::File;
6 use std::io::BufRead;
7 use std::io::BufReader;
8 use std::mem::zeroed;
9 use std::ptr::null;
10 use std::ptr::null_mut;
11 use std::sync::atomic::AtomicI32;
12 use std::sync::atomic::AtomicUsize;
13 use std::sync::atomic::Ordering;
14 use std::sync::Arc;
15 use std::sync::Mutex;
16 use std::sync::MutexGuard;
17 use std::sync::Once;
18 use std::thread::sleep;
19 use std::thread::spawn;
20 use std::time::Duration;
21 use std::time::Instant;
22
23 use base::platform::gettid;
24 use base::platform::kill;
25 use base::platform::scoped_signal_handler::Error;
26 use base::platform::scoped_signal_handler::Result;
27 use base::platform::Error as ErrnoError;
28 use base::platform::Pid;
29 use base::sys::clear_signal_handler;
30 use base::sys::has_default_signal_handler;
31 use base::sys::wait_for_interrupt;
32 use base::sys::ScopedSignalHandler;
33 use base::sys::Signal;
34 use base::sys::SignalHandler;
35 use libc::sigaction;
36
37 const TEST_SIGNAL: Signal = Signal::User1;
38 const TEST_SIGNALS: &[Signal] = &[Signal::User1, Signal::User2];
39
40 static TEST_SIGNAL_COUNTER: AtomicUsize = AtomicUsize::new(0);
41
/// Only allows one test case to execute at a time.
///
/// Signal dispositions are process-wide, so the tests in this file hold the
/// returned guard for their whole duration.
///
/// The mutex protects no data (`()`), so a poisoned lock (left behind when an
/// earlier test panicked while holding it) is recovered instead of propagated.
/// Otherwise one failing test would make every later test fail with an
/// unrelated `PoisonError`.
fn get_mutex() -> MutexGuard<'static, ()> {
    // `Mutex::new` is a `const fn`, so the lock can live in a plain `static`.
    // This avoids `Once`, `Arc`, and taking a reference to a `static mut`.
    static TEST_MUTEX: Mutex<()> = Mutex::new(());

    TEST_MUTEX
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}
56
reset_counter()57 fn reset_counter() {
58 TEST_SIGNAL_COUNTER.swap(0, Ordering::SeqCst);
59 }
60
get_sigaction(signal: Signal) -> Result<sigaction>61 fn get_sigaction(signal: Signal) -> Result<sigaction> {
62 // Safe because sigaction is owned and expected to be initialized ot zeros.
63 let mut sigact: sigaction = unsafe { zeroed() };
64
65 if unsafe { sigaction(signal.into(), null(), &mut sigact) } < 0 {
66 Err(Error::Sigaction(signal, ErrnoError::last()))
67 } else {
68 Ok(sigact)
69 }
70 }
71
72 /// Safety:
73 /// This is only safe if the signal handler set in sigaction is safe.
restore_sigaction(signal: Signal, sigact: sigaction) -> Result<sigaction>74 unsafe fn restore_sigaction(signal: Signal, sigact: sigaction) -> Result<sigaction> {
75 if sigaction(signal.into(), &sigact, null_mut()) < 0 {
76 Err(Error::Sigaction(signal, ErrnoError::last()))
77 } else {
78 Ok(sigact)
79 }
80 }
81
82 /// Safety:
83 /// Safe if the signal handler for Signal::User1 is safe.
send_test_signal()84 unsafe fn send_test_signal() {
85 kill(gettid(), Signal::User1.into()).unwrap()
86 }
87
/// Panics unless `TEST_SIGNAL_COUNTER` currently equals the given `usize` expression.
macro_rules! assert_counter_eq {
    ($compare_to:expr) => {{
        let want: usize = $compare_to;
        let have: usize = TEST_SIGNAL_COUNTER.load(Ordering::SeqCst);
        assert!(
            have == want,
            "wrong signal counter value: got {}; expected {}",
            have,
            want
        );
    }};
}
100
101 struct TestHandler;
102
103 /// # Safety
104 /// Safe because handle_signal is async-signal safe.
105 unsafe impl SignalHandler for TestHandler {
handle_signal(signal: Signal)106 fn handle_signal(signal: Signal) {
107 if TEST_SIGNAL == signal {
108 TEST_SIGNAL_COUNTER.fetch_add(1, Ordering::SeqCst);
109 }
110 }
111 }
112
#[test]
fn scopedsignalhandler_success() {
    // Signal handlers are process-wide state, so run this test exclusively.
    let _serial = get_mutex();

    reset_counter();
    assert_counter_eq!(0);

    // Installing the scoped handler must replace the default disposition.
    assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
    let scoped = ScopedSignalHandler::new::<TestHandler>(&[TEST_SIGNAL]).unwrap();
    assert!(!has_default_signal_handler(TEST_SIGNAL.into()).unwrap());

    // Safe because TestHandler is async-signal safe.
    unsafe { send_test_signal() };

    // The handler may run on a different thread, so poll the counter up to
    // 39 times, 250ms apart, before checking it.
    let mut polls_left = 39;
    while polls_left > 0 && TEST_SIGNAL_COUNTER.load(Ordering::SeqCst) == 0 {
        sleep(Duration::from_millis(250));
        polls_left -= 1;
    }

    assert_counter_eq!(1);

    // Dropping the scoped handler must restore the default disposition.
    drop(scoped);
    assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
}
142
#[test]
fn scopedsignalhandler_handleralreadyset() {
    // Signal handlers are process-wide state, so run this test exclusively.
    let _serial = get_mutex();

    reset_counter();
    assert_counter_eq!(0);

    // Install the first handler and confirm it replaced the default disposition.
    assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
    let first = ScopedSignalHandler::new::<TestHandler>(&[TEST_SIGNAL]).unwrap();
    assert!(!has_default_signal_handler(TEST_SIGNAL.into()).unwrap());

    // A second registration that overlaps TEST_SIGNAL must be rejected. The
    // result is scoped so it is released before the checks below.
    {
        let duplicate = ScopedSignalHandler::new::<TestHandler>(TEST_SIGNALS);
        assert!(matches!(
            duplicate,
            Err(Error::HandlerAlreadySet(Signal::User1))
        ));
    }

    // No signal was raised, and dropping the first handler restores the default.
    assert_counter_eq!(0);
    drop(first);
    assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
}
167
168 /// Stores the thread used by WaitForInterruptHandler.
169 static WAIT_FOR_INTERRUPT_THREAD_ID: AtomicI32 = AtomicI32::new(0);
170 /// Forwards SIGINT to the appropriate thread.
171 struct WaitForInterruptHandler;
172
173 /// # Safety
174 /// Safe because handle_signal is async-signal safe.
175 unsafe impl SignalHandler for WaitForInterruptHandler {
handle_signal(_: Signal)176 fn handle_signal(_: Signal) {
177 let tid = WAIT_FOR_INTERRUPT_THREAD_ID.load(Ordering::SeqCst);
178 // If the thread ID is set and executed on the wrong thread, forward the signal.
179 if tid != 0 && gettid() != tid {
180 // Safe because the handler is safe and the target thread id is expecting the signal.
181 unsafe { kill(tid, Signal::Interrupt.into()) }.unwrap();
182 }
183 }
184 }
185
186 /// Query /proc/${tid}/status for its State and check if it is either S (sleeping) or in
187 /// D (disk sleep).
thread_is_sleeping(tid: Pid) -> std::result::Result<bool, ErrnoError>188 fn thread_is_sleeping(tid: Pid) -> std::result::Result<bool, ErrnoError> {
189 const PREFIX: &str = "State:";
190 let mut status_reader = BufReader::new(File::open(format!("/proc/{}/status", tid))?);
191 let mut line = String::new();
192 loop {
193 let count = status_reader.read_line(&mut line)?;
194 if count == 0 {
195 return Err(ErrnoError::new(libc::EIO));
196 }
197 if let Some(stripped) = line.strip_prefix(PREFIX) {
198 return Ok(matches!(
199 stripped.trim_start().chars().next(),
200 Some('S') | Some('D')
201 ));
202 }
203 line.clear();
204 }
205 }
206
207 /// Wait for a process to block either in a sleeping or disk sleep state.
wait_for_thread_to_sleep(tid: Pid, timeout: Duration) -> std::result::Result<(), ErrnoError>208 fn wait_for_thread_to_sleep(tid: Pid, timeout: Duration) -> std::result::Result<(), ErrnoError> {
209 let start = Instant::now();
210 loop {
211 if thread_is_sleeping(tid)? {
212 return Ok(());
213 }
214 if start.elapsed() > timeout {
215 return Err(ErrnoError::new(libc::EAGAIN));
216 }
217 sleep(Duration::from_millis(50));
218 }
219 }
220
/// Checks that `wait_for_interrupt` returns once SIGINT is sent to the waiting
/// thread.
///
/// A helper thread waits until this thread shows up as sleeping in
/// `/proc/<tid>/status`, then sends it SIGINT. The previous SIGINT disposition
/// is saved up front and restored before any result is unwrapped.
#[test]
fn waitforinterrupt_success() {
    // Prevent other test cases from running concurrently since the signal
    // handlers are shared for the process.
    let _guard = get_mutex();

    // Save the current SIGINT disposition so it can be put back at the end.
    let to_restore = get_sigaction(Signal::Interrupt).unwrap();
    // Presumably needed because ScopedSignalHandler rejects a signal that already
    // has a non-default handler (see HandlerAlreadySet); confirm.
    clear_signal_handler(Signal::Interrupt.into()).unwrap();
    // Safe because WaitForInterruptHandler is async-signal safe.
    let handler =
        ScopedSignalHandler::new::<WaitForInterruptHandler>(&[Signal::Interrupt]).unwrap();

    // Record this thread so the handler can forward SIGINT here if it runs on a
    // different thread.
    let tid = gettid();
    WAIT_FOR_INTERRUPT_THREAD_ID.store(tid, Ordering::SeqCst);

    let join_handle = spawn(move || -> std::result::Result<(), ErrnoError> {
        // Wait until the thread is ready to receive the signal.
        // NOTE(review): if this times out, the unwrap panics before `kill` runs,
        // so no SIGINT is sent and the main thread presumably stays blocked in
        // `wait_for_interrupt()`. Confirm whether the test can hang here.
        wait_for_thread_to_sleep(tid, Duration::from_secs(10)).unwrap();

        // Safe because the SIGINT handler is safe.
        unsafe { kill(tid, Signal::Interrupt.into()) }
    });
    let wait_ret = wait_for_interrupt();
    let join_ret = join_handle.join();

    // Tear down before unwrapping any result, so a failure does not leave the
    // test handler installed for SIGINT.
    drop(handler);
    // Safe because we are restoring the previous SIGINT handler.
    unsafe { restore_sigaction(Signal::Interrupt, to_restore) }.unwrap();

    wait_ret.unwrap();
    join_ret.unwrap().unwrap();
}
253