• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::future::Future;
6 use std::ptr;
7 use std::sync::atomic::AtomicI32;
8 use std::sync::atomic::Ordering;
9 use std::sync::Arc;
10 use std::task::Context;
11 use std::task::Poll;
12 
13 use futures::pin_mut;
14 use futures::task::waker_ref;
15 use futures::task::ArcWake;
16 
17 // Randomly generated values to indicate the state of the current thread.
18 const WAITING: i32 = 0x25de_74d1;
19 const WOKEN: i32 = 0x72d3_2c9f;
20 
21 const FUTEX_WAIT_PRIVATE: libc::c_int = libc::FUTEX_WAIT | libc::FUTEX_PRIVATE_FLAG;
22 const FUTEX_WAKE_PRIVATE: libc::c_int = libc::FUTEX_WAKE | libc::FUTEX_PRIVATE_FLAG;
23 
24 thread_local!(static PER_THREAD_WAKER: Arc<Waker> = Arc::new(Waker(AtomicI32::new(WAITING))));
25 
/// Waker used by `block_on`: wraps the futex word that holds either `WAITING` or `WOKEN`.
///
/// `#[repr(transparent)]` gives this struct the same layout as the wrapped `AtomicI32`.
#[repr(transparent)]
struct Waker(AtomicI32);
28 
29 impl ArcWake for Waker {
wake_by_ref(arc_self: &Arc<Self>)30     fn wake_by_ref(arc_self: &Arc<Self>) {
31         let state = arc_self.0.swap(WOKEN, Ordering::Release);
32         if state == WAITING {
33             // SAFETY:
34             // The thread hasn't already been woken up so wake it up now. Safe because this doesn't
35             // modify any memory and we check the return value.
36             let res = unsafe {
37                 libc::syscall(
38                     libc::SYS_futex,
39                     &arc_self.0,
40                     FUTEX_WAKE_PRIVATE,
41                     libc::INT_MAX,                        // val
42                     ptr::null::<*const libc::timespec>(), // timeout
43                     ptr::null::<*const libc::c_int>(),    // uaddr2
44                     0_i32,                                // val3
45                 )
46             };
47             if res < 0 {
48                 panic!(
49                     "unexpected error from FUTEX_WAKE_PRIVATE: {}",
50                     std::io::Error::last_os_error()
51                 );
52             }
53         }
54     }
55 }
56 
57 /// Run a future to completion on the current thread.
58 ///
59 /// This method will block the current thread until `f` completes. Useful when you need to call an
60 /// async fn from a non-async context.
block_on<F: Future>(f: F) -> F::Output61 pub fn block_on<F: Future>(f: F) -> F::Output {
62     pin_mut!(f);
63 
64     PER_THREAD_WAKER.with(|thread_waker| {
65         let waker = waker_ref(thread_waker);
66         let mut cx = Context::from_waker(&waker);
67 
68         loop {
69             if let Poll::Ready(t) = f.as_mut().poll(&mut cx) {
70                 return t;
71             }
72 
73             let state = thread_waker.0.swap(WAITING, Ordering::Acquire);
74             if state == WAITING {
75                 // SAFETY:
76                 // If we weren't already woken up then wait until we are. Safe because this doesn't
77                 // modify any memory and we check the return value.
78                 let res = unsafe {
79                     libc::syscall(
80                         libc::SYS_futex,
81                         &thread_waker.0,
82                         FUTEX_WAIT_PRIVATE,
83                         state,
84                         ptr::null::<*const libc::timespec>(), // timeout
85                         ptr::null::<*const libc::c_int>(),    // uaddr2
86                         0_i32,                                // val3
87                     )
88                 };
89 
90                 if res < 0 {
91                     let e = std::io::Error::last_os_error();
92                     match e.raw_os_error() {
93                         Some(libc::EAGAIN) | Some(libc::EINTR) => {}
94                         _ => panic!("unexpected error from FUTEX_WAIT_PRIVATE: {}", e),
95                     }
96                 }
97 
98                 // Clear the state to prevent unnecessary extra loop iterations and also to allow
99                 // nested usage of `block_on`.
100                 thread_waker.0.store(WAITING, Ordering::Release);
101             }
102         }
103     })
104 }
105 
#[cfg(test)]
mod test {
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::mpsc::channel;
    use std::sync::mpsc::Sender;
    use std::sync::Arc;
    use std::task::Context;
    use std::task::Poll;
    use std::task::Waker;
    use std::thread;
    use std::time::Duration;

    use super::*;
    use crate::sync::SpinLock;

    /// State shared between a `Timer` future and the thread that fires it.
    struct TimerState {
        /// Set to `true` by the background thread once the duration has elapsed.
        fired: bool,
        /// Waker registered by the latest pending poll, if any.
        waker: Option<Waker>,
    }

    /// Future that completes once its background thread marks it as fired.
    struct Timer {
        state: Arc<SpinLock<TimerState>>,
    }

    impl Future for Timer {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
            let mut guard = self.state.lock();
            if !guard.fired {
                // Not fired yet: remember whom to wake and report pending.
                guard.waker = Some(cx.waker().clone());
                return Poll::Pending;
            }

            Poll::Ready(())
        }
    }

    /// Returns a `Timer` that a newly spawned thread fires after sleeping for `dur`.
    ///
    /// When `notify` is provided, `()` is sent on it after the timer has been marked fired, any
    /// registered waker has been woken, and the state lock has been released.
    fn start_timer(dur: Duration, notify: Option<Sender<()>>) -> Timer {
        let state = Arc::new(SpinLock::new(TimerState {
            fired: false,
            waker: None,
        }));

        let shared = Arc::clone(&state);
        thread::spawn(move || {
            thread::sleep(dur);

            // Scope the guard so the lock is released before sending the notification.
            {
                let mut guard = shared.lock();
                guard.fired = true;
                if let Some(pending) = guard.waker.take() {
                    pending.wake();
                }
            }

            if let Some(sender) = notify {
                sender
                    .send(())
                    .expect("Failed to send completion notification");
            }
        });

        Timer { state }
    }

    #[test]
    fn it_works() {
        let timer = start_timer(Duration::from_millis(100), None);
        block_on(timer);
    }

    #[test]
    fn nested() {
        async fn inner() {
            let timer = start_timer(Duration::from_millis(100), None);
            block_on(timer);
        }

        let outer = inner();
        block_on(outer);
    }

    #[test]
    fn ready_before_poll() {
        let (tx, rx) = channel();
        let timer = start_timer(Duration::from_millis(50), Some(tx));

        rx.recv()
            .expect("Failed to receive completion notification");

        // We know the timer has already fired so the poll should complete immediately.
        block_on(timer);
    }
}
195