• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::ffi::c_void;
6 use std::future::Future;
7 use std::marker::PhantomData;
8 use std::marker::PhantomPinned;
9 use std::pin::Pin;
10 use std::ptr::null_mut;
11 use std::sync::MutexGuard;
12 use std::task::Context;
13 use std::task::Poll;
14 use std::task::Waker;
15 
16 use base::error;
17 use base::warn;
18 use base::AsRawDescriptor;
19 use base::Descriptor;
20 use sync::Mutex;
21 use winapi::shared::ntdef::FALSE;
22 use winapi::um::handleapi::INVALID_HANDLE_VALUE;
23 use winapi::um::threadpoollegacyapiset::UnregisterWaitEx;
24 use winapi::um::winbase::RegisterWaitForSingleObject;
25 use winapi::um::winbase::INFINITE;
26 use winapi::um::winnt::BOOLEAN;
27 use winapi::um::winnt::PVOID;
28 use winapi::um::winnt::WT_EXECUTEONLYONCE;
29 
30 use crate::sys::windows::handle_source::Error;
31 use crate::sys::windows::handle_source::Result;
32 
33 /// Inner state shared between the future struct & the kernel invoked waiter callback.
34 struct WaitForHandleInner {
35     wait_state: WaitState,
36     wait_object: Descriptor,
37     waker: Option<Waker>,
38 }
39 impl WaitForHandleInner {
new() -> WaitForHandleInner40     fn new() -> WaitForHandleInner {
41         WaitForHandleInner {
42             wait_state: WaitState::New,
43             wait_object: Descriptor(null_mut::<c_void>()),
44             waker: None,
45         }
46     }
47 }
48 
/// Lifecycle of a `WaitForHandle` future; read and written by both `poll` and the wait callback.
#[derive(Clone, Copy, Eq, PartialEq)]
enum WaitState {
    /// No wait has been registered yet.
    New,
    /// A wait is registered and a waker is stored; waiting for the callback to run.
    Sleeping,
    /// The callback ran because the handle was signaled.
    Woken,
    /// The callback ran without the wait having fired.
    Aborted,
    /// `poll` resolved successfully and unregistered the wait.
    Finished,
    /// The callback found the shared state in an unexpected condition.
    Failed,
}
59 
60 /// Waits for an object with a handle to be readable.
61 pub struct WaitForHandle<'a, T: AsRawDescriptor> {
62     handle: Descriptor,
63     inner: Mutex<WaitForHandleInner>,
64     _marker: PhantomData<&'a T>,
65     _pinned_marker: PhantomPinned,
66 }
67 
68 impl<'a, T> WaitForHandle<'a, T>
69 where
70     T: AsRawDescriptor,
71 {
new(source: &'a T) -> WaitForHandle<'a, T>72     pub fn new(source: &'a T) -> WaitForHandle<'a, T> {
73         WaitForHandle {
74             handle: Descriptor(source.as_raw_descriptor()),
75             inner: Mutex::new(WaitForHandleInner::new()),
76             _marker: PhantomData,
77             _pinned_marker: PhantomPinned,
78         }
79     }
80 }
81 
82 impl<'a, T> Future for WaitForHandle<'a, T>
83 where
84     T: AsRawDescriptor,
85 {
86     type Output = Result<()>;
87 
poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output>88     fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
89         let inner_for_callback = &self.inner as *const _ as *mut Mutex<WaitForHandleInner>;
90         let mut inner = self.inner.lock();
91         match inner.wait_state {
92             WaitState::New => {
93                 // SAFETY:
94                 // Safe because:
95                 //      a) the callback only runs when WaitForHandle is alive (we cancel it on
96                 //         drop).
97                 //      b) inner & its children are owned by WaitForHandle.
98                 let err = unsafe {
99                     RegisterWaitForSingleObject(
100                         &mut inner.wait_object as *mut _ as *mut *mut c_void,
101                         self.handle.0,
102                         Some(wait_for_handle_waker),
103                         inner_for_callback as *mut c_void,
104                         INFINITE,
105                         WT_EXECUTEONLYONCE,
106                     )
107                 };
108                 if err == 0 {
109                     return Poll::Ready(Err(Error::HandleWaitFailed(base::Error::last())));
110                 }
111 
112                 inner.wait_state = WaitState::Sleeping;
113                 inner.waker = Some(cx.waker().clone());
114                 Poll::Pending
115             }
116             WaitState::Sleeping => {
117                 // In case we are polled with a different waker which won't be woken by the existing
118                 // waker, we'll have to update to the new waker.
119                 if inner
120                     .waker
121                     .as_ref()
122                     .map(|w| !w.will_wake(cx.waker()))
123                     .unwrap_or(true)
124                 {
125                     inner.waker = Some(cx.waker().clone());
126                 }
127                 Poll::Pending
128             }
129             WaitState::Woken => {
130                 inner.wait_state = WaitState::Finished;
131 
132                 // SAFETY:
133                 // Safe because:
134                 // a) we know a wait was registered and hasn't been unregistered yet.
135                 // b) the callback is not queued because we set WT_EXECUTEONLYONCE, and we know
136                 //    it has already completed.
137                 unsafe { unregister_wait(inner.wait_object) }
138 
139                 Poll::Ready(Ok(()))
140             }
141             WaitState::Aborted => Poll::Ready(Err(Error::OperationAborted)),
142             WaitState::Finished => panic!("polled an already completed WaitForHandle future."),
143             WaitState::Failed => {
144                 panic!("WaitForHandle future's waiter callback hit unexpected behavior.")
145             }
146         }
147     }
148 }
149 
150 impl<'a, T> Drop for WaitForHandle<'a, T>
151 where
152     T: AsRawDescriptor,
153 {
drop(&mut self)154     fn drop(&mut self) {
155         // We cannot hold the lock over the call to unregister_wait, otherwise we could deadlock
156         // with the callback trying to access the same data. It is sufficient to just verify
157         // (without mutual exclusion beyond the data access itself) that we have exited the New
158         // state before attempting to unregister. This works because once we have exited New, we
159         // cannot ever re-enter that state, and we know for sure that inner.wait_object is a valid
160         // wait object.
161         let (current_state, wait_object) = {
162             let inner = self.inner.lock();
163             (inner.wait_state, inner.wait_object)
164         };
165 
166         if current_state != WaitState::New && current_state != WaitState::Finished {
167             // SAFETY:
168             // Safe because self.descriptor is valid in any state except New or Finished.
169             //
170             // Note: this method call is critical for supplying the safety guarantee relied upon by
171             // wait_for_handle_waker. Upon return, it ensures that wait_for_handle_waker is not
172             // running and won't be scheduled again, which makes it safe to drop
173             // self.inner_for_callback (wait_for_handle_waker has a non owning pointer
174             // to self.inner_for_callback).
175             unsafe { unregister_wait(wait_object) }
176         }
177     }
178 }
179 
180 /// Safe portion of the RegisterWaitForSingleObject callback.
process_wait_state_change( mut state: MutexGuard<WaitForHandleInner>, wait_fired: bool, ) -> Option<Waker>181 fn process_wait_state_change(
182     mut state: MutexGuard<WaitForHandleInner>,
183     wait_fired: bool,
184 ) -> Option<Waker> {
185     let mut waker = None;
186     state.wait_state = match state.wait_state {
187         WaitState::Sleeping => {
188             let new_state = if wait_fired {
189                 WaitState::Woken
190             } else {
191                 // This should never happen.
192                 error!("wait_for_handle_waker did not wake due to wait firing.");
193                 WaitState::Aborted
194             };
195 
196             match state.waker.take() {
197                 Some(w) => {
198                     waker = Some(w);
199                     new_state
200                 }
201                 None => {
202                     error!("wait_for_handler_waker called, but no waker available.");
203                     WaitState::Failed
204                 }
205             }
206         }
207         _ => {
208             error!("wait_for_handle_waker called with state != sleeping.");
209             WaitState::Failed
210         }
211     };
212     waker
213 }
214 
215 /// # Safety
216 /// a) inner_ptr is valid whenever this function can be called. This is guaranteed by WaitForHandle,
217 ///    which cannot be dropped until this function has finished running & is no longer queued for
218 ///    execution because the Drop impl calls UnregisterWaitEx, which blocks on that condition.
wait_for_handle_waker(inner_ptr: PVOID, timer_or_wait_fired: BOOLEAN)219 unsafe extern "system" fn wait_for_handle_waker(inner_ptr: PVOID, timer_or_wait_fired: BOOLEAN) {
220     let inner = inner_ptr as *const Mutex<WaitForHandleInner>;
221     let inner_locked = (*inner).lock();
222     let waker = process_wait_state_change(
223         inner_locked,
224         /* wait_fired= */ timer_or_wait_fired == FALSE,
225     );
226 
227     // We wake *after* releasing the lock to avoid waking up a thread that then will go back to
228     // sleep because the lock it needs is currently held.
229     if let Some(w) = waker {
230         w.wake()
231     }
232 }
233 
234 /// # Safety
235 /// a) desc must be a valid wait handle from RegisterWaitForSingleObject.
unregister_wait(desc: Descriptor)236 unsafe fn unregister_wait(desc: Descriptor) {
237     if UnregisterWaitEx(desc.0, INVALID_HANDLE_VALUE) == 0 {
238         warn!(
239             "WaitForHandle: failed to clean up RegisterWaitForSingleObject wait handle: {}",
240             base::Error::last()
241         )
242     }
243 }
244 
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::Weak;
    use std::time::Duration;

    use base::thread::spawn_with_timeout;
    use base::Event;
    use futures::pin_mut;

    use super::*;
    use crate::waker::new_waker;
    use crate::waker::WeakWake;
    use crate::EventAsync;
    use crate::Executor;

    /// Wake target that ignores wake-ups, for polling a future by hand.
    struct FakeWaker;

    impl WeakWake for FakeWaker {
        fn wake_by_ref(_weak_self: &Weak<Self>) {
            // Intentionally a no-op.
        }
    }

    /// Polls a wait on a never-signaled event once, then drops the still-pending future.
    #[test]
    fn test_unsignaled_event() {
        async fn wait_on_unsignaled_event(evt: EventAsync) {
            evt.next_val().await.unwrap();
            panic!("await should never terminate");
        }

        let fake_waker = Arc::new(FakeWaker);
        let waker = new_waker(Arc::downgrade(&fake_waker));
        let mut cx = Context::from_waker(&waker);

        let ex = Executor::new().unwrap();
        let async_evt = EventAsync::new(Event::new().unwrap(), &ex).unwrap();

        let fut = wait_on_unsignaled_event(async_evt);
        pin_mut!(fut);

        // Reaching Pending means a wait has been registered.
        assert!(fut.poll(&mut cx).is_pending());

        // Success criterion: dropping the pending future at scope exit does not crash.
    }

    /// Waiting on an already-signaled event must complete well within the timeout.
    #[test]
    fn test_signaled_event() {
        async fn wait_on_signaled_event(evt: EventAsync) {
            evt.next_val().await.unwrap();
        }

        spawn_with_timeout(|| {
            let ex = Executor::new().unwrap();
            let evt = Event::new().unwrap();
            evt.signal().unwrap();

            let fut = wait_on_signaled_event(EventAsync::new(evt, &ex).unwrap());
            pin_mut!(fut);

            ex.run_until(fut).unwrap();
        })
        .try_join(Duration::from_secs(5))
        .expect("async wait never returned from signaled event.");
    }
}
314