• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::future::Future;
6 use std::ptr;
7 use std::sync::atomic::{AtomicI32, Ordering};
8 use std::sync::Arc;
9 use std::task::{Context, Poll};
10 
11 use futures::pin_mut;
12 use futures::task::{waker_ref, ArcWake};
13 
// States of a thread's waker word. The values are arbitrary, randomly generated constants; the
// code only ever compares them for equality, so all that matters is that they are distinct.

/// The owning thread has no pending wakeup and may be (or be about to be) parked on the futex.
const WAITING: i32 = 0x25DE_74D1;
/// A wakeup has been delivered and not yet consumed by `block_on`.
const WOKEN: i32 = 0x72D3_2C9F;
17 
18 const FUTEX_WAIT_PRIVATE: libc::c_int = libc::FUTEX_WAIT | libc::FUTEX_PRIVATE_FLAG;
19 const FUTEX_WAKE_PRIVATE: libc::c_int = libc::FUTEX_WAKE | libc::FUTEX_PRIVATE_FLAG;
20 
21 thread_local!(static PER_THREAD_WAKER: Arc<Waker> = Arc::new(Waker(AtomicI32::new(WAITING))));
22 
23 #[repr(transparent)]
24 struct Waker(AtomicI32);
25 
26 extern {
27     #[cfg_attr(target_os = "android", link_name = "__errno")]
28     #[cfg_attr(target_os = "linux", link_name = "__errno_location")]
errno_location() -> *mut libc::c_int29     fn errno_location() -> *mut libc::c_int;
30 }
31 
32 impl ArcWake for Waker {
wake_by_ref(arc_self: &Arc<Self>)33     fn wake_by_ref(arc_self: &Arc<Self>) {
34         let state = arc_self.0.swap(WOKEN, Ordering::Release);
35         if state == WAITING {
36             // The thread hasn't already been woken up so wake it up now. Safe because this doesn't
37             // modify any memory and we check the return value.
38             let res = unsafe {
39                 libc::syscall(
40                     libc::SYS_futex,
41                     &arc_self.0,
42                     FUTEX_WAKE_PRIVATE,
43                     libc::INT_MAX,                        // val
44                     ptr::null() as *const libc::timespec, // timeout
45                     ptr::null() as *const libc::c_int,    // uaddr2
46                     0 as libc::c_int,                     // val3
47                 )
48             };
49             if res < 0 {
50                 panic!("unexpected error from FUTEX_WAKE_PRIVATE: {}", unsafe {
51                     *errno_location()
52                 });
53             }
54         }
55     }
56 }
57 
58 /// Run a future to completion on the current thread.
59 ///
60 /// This method will block the current thread until `f` completes. Useful when you need to call an
61 /// async fn from a non-async context.
block_on<F: Future>(f: F) -> F::Output62 pub fn block_on<F: Future>(f: F) -> F::Output {
63     pin_mut!(f);
64 
65     PER_THREAD_WAKER.with(|thread_waker| {
66         let waker = waker_ref(thread_waker);
67         let mut cx = Context::from_waker(&waker);
68 
69         loop {
70             if let Poll::Ready(t) = f.as_mut().poll(&mut cx) {
71                 return t;
72             }
73 
74             let state = thread_waker.0.swap(WAITING, Ordering::Acquire);
75             if state == WAITING {
76                 // If we weren't already woken up then wait until we are. Safe because this doesn't
77                 // modify any memory and we check the return value.
78                 let res = unsafe {
79                     libc::syscall(
80                         libc::SYS_futex,
81                         &thread_waker.0,
82                         FUTEX_WAIT_PRIVATE,
83                         state,
84                         ptr::null() as *const libc::timespec, // timeout
85                         ptr::null() as *const libc::c_int,    // uaddr2
86                         0 as libc::c_int,                     // val3
87                     )
88                 };
89 
90                 if res < 0 {
91                     // Safe because libc guarantees that this is a valid pointer.
92                     match unsafe { *errno_location() } {
93                         libc::EAGAIN | libc::EINTR => {}
94                         e => panic!("unexpected error from FUTEX_WAIT_PRIVATE: {}", e),
95                     }
96                 }
97 
98                 // Clear the state to prevent unnecessary extra loop iterations and also to allow
99                 // nested usage of `block_on`.
100                 thread_waker.0.store(WAITING, Ordering::Release);
101             }
102         }
103     })
104 }
105 
#[cfg(test)]
mod test {
    use super::*;

    use std::future::Future;
    use std::pin::Pin;
    use std::sync::mpsc::{channel, Sender};
    use std::sync::Arc;
    use std::task::{Context, Poll, Waker};
    use std::thread;
    use std::time::Duration;

    use crate::sync::SpinLock;

    /// State shared between a `Timer` future and the thread that fires it.
    struct TimerState {
        fired: bool,
        waker: Option<Waker>,
    }

    /// Future that completes once its background thread has set `fired`.
    struct Timer {
        state: Arc<SpinLock<TimerState>>,
    }

    impl Future for Timer {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
            let mut state = self.state.lock();
            if state.fired {
                return Poll::Ready(());
            }

            state.waker = Some(cx.waker().clone());
            Poll::Pending
        }
    }

    /// Spawns a thread that fires the returned `Timer` after `dur`. If `notify` is set, the thread
    /// then sends a completion message on it.
    fn start_timer(dur: Duration, notify: Option<Sender<()>>) -> Timer {
        let state = Arc::new(SpinLock::new(TimerState {
            fired: false,
            waker: None,
        }));

        let thread_state = Arc::clone(&state);
        thread::spawn(move || {
            thread::sleep(dur);
            let mut ts = thread_state.lock();
            ts.fired = true;
            if let Some(waker) = ts.waker.take() {
                waker.wake();
            }
            drop(ts);

            if let Some(tx) = notify {
                tx.send(()).expect("Failed to send completion notification");
            }
        });

        Timer { state }
    }

    /// Future that wakes itself from inside `poll` once, then completes on the following poll.
    struct SelfWake {
        woken: bool,
    }

    impl Future for SelfWake {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
            if self.woken {
                return Poll::Ready(());
            }

            self.woken = true;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    #[test]
    fn it_works() {
        block_on(start_timer(Duration::from_millis(100), None));
    }

    #[test]
    fn nested() {
        async fn inner() {
            block_on(start_timer(Duration::from_millis(100), None));
        }

        block_on(inner());
    }

    #[test]
    fn ready_before_poll() {
        let (tx, rx) = channel();

        let timer = start_timer(Duration::from_millis(50), Some(tx));

        rx.recv().expect("Failed to receive completion notification");

        // We know the timer has already fired so the poll should complete immediately.
        block_on(timer);
    }

    #[test]
    fn woken_during_poll() {
        // The wakeup is delivered before `block_on` tries to sleep. It must take the
        // already-woken path and poll again instead of blocking in the futex wait forever.
        block_on(SelfWake { woken: false });
    }
}
193