• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::{
6     future::Future,
7     ptr,
8     sync::{
9         atomic::{AtomicI32, Ordering},
10         Arc,
11     },
12     task::{Context, Poll},
13 };
14 
15 use futures::{
16     pin_mut,
17     task::{waker_ref, ArcWake},
18 };
19 
// Arbitrary, randomly generated sentinel values for the per-thread wake flag. `WAITING` means no
// wakeup has been requested since the flag was last reset; `WOKEN` means one has.
const WAITING: i32 = 0x25DE_74D1;
const WOKEN: i32 = 0x72D3_2C9F;

24 const FUTEX_WAIT_PRIVATE: libc::c_int = libc::FUTEX_WAIT | libc::FUTEX_PRIVATE_FLAG;
25 const FUTEX_WAKE_PRIVATE: libc::c_int = libc::FUTEX_WAKE | libc::FUTEX_PRIVATE_FLAG;
26 
27 thread_local!(static PER_THREAD_WAKER: Arc<Waker> = Arc::new(Waker(AtomicI32::new(WAITING))));
28 
/// Wake flag shared between a thread blocked in `block_on` and the wakers handed to the future
/// it polls. `repr(transparent)` guarantees the struct has exactly the layout of its
/// `AtomicI32` field.
#[repr(transparent)]
struct Waker(
    /// Holds either `WAITING` or `WOKEN`; this atomic is the futex word.
    AtomicI32,
);

32 extern "C" {
33     #[cfg_attr(target_os = "android", link_name = "__errno")]
34     #[cfg_attr(target_os = "linux", link_name = "__errno_location")]
errno_location() -> *mut libc::c_int35     fn errno_location() -> *mut libc::c_int;
36 }
37 
38 impl ArcWake for Waker {
wake_by_ref(arc_self: &Arc<Self>)39     fn wake_by_ref(arc_self: &Arc<Self>) {
40         let state = arc_self.0.swap(WOKEN, Ordering::Release);
41         if state == WAITING {
42             // The thread hasn't already been woken up so wake it up now. Safe because this doesn't
43             // modify any memory and we check the return value.
44             let res = unsafe {
45                 libc::syscall(
46                     libc::SYS_futex,
47                     &arc_self.0,
48                     FUTEX_WAKE_PRIVATE,
49                     libc::INT_MAX,                        // val
50                     ptr::null() as *const libc::timespec, // timeout
51                     ptr::null() as *const libc::c_int,    // uaddr2
52                     0_i32,                                // val3
53                 )
54             };
55             if res < 0 {
56                 panic!("unexpected error from FUTEX_WAKE_PRIVATE: {}", unsafe {
57                     *errno_location()
58                 });
59             }
60         }
61     }
62 }
63 
64 /// Run a future to completion on the current thread.
65 ///
66 /// This method will block the current thread until `f` completes. Useful when you need to call an
67 /// async fn from a non-async context.
block_on<F: Future>(f: F) -> F::Output68 pub fn block_on<F: Future>(f: F) -> F::Output {
69     pin_mut!(f);
70 
71     PER_THREAD_WAKER.with(|thread_waker| {
72         let waker = waker_ref(thread_waker);
73         let mut cx = Context::from_waker(&waker);
74 
75         loop {
76             if let Poll::Ready(t) = f.as_mut().poll(&mut cx) {
77                 return t;
78             }
79 
80             let state = thread_waker.0.swap(WAITING, Ordering::Acquire);
81             if state == WAITING {
82                 // If we weren't already woken up then wait until we are. Safe because this doesn't
83                 // modify any memory and we check the return value.
84                 let res = unsafe {
85                     libc::syscall(
86                         libc::SYS_futex,
87                         &thread_waker.0,
88                         FUTEX_WAIT_PRIVATE,
89                         state,
90                         ptr::null() as *const libc::timespec, // timeout
91                         ptr::null() as *const libc::c_int,    // uaddr2
92                         0_i32,                                // val3
93                     )
94                 };
95 
96                 if res < 0 {
97                     // Safe because libc guarantees that this is a valid pointer.
98                     match unsafe { *errno_location() } {
99                         libc::EAGAIN | libc::EINTR => {}
100                         e => panic!("unexpected error from FUTEX_WAIT_PRIVATE: {}", e),
101                     }
102                 }
103 
104                 // Clear the state to prevent unnecessary extra loop iterations and also to allow
105                 // nested usage of `block_on`.
106                 thread_waker.0.store(WAITING, Ordering::Release);
107             }
108         }
109     })
110 }
111 
#[cfg(test)]
mod test {
    use super::*;

    use std::{
        future::Future,
        pin::Pin,
        sync::{
            mpsc::{channel, Sender},
            Arc,
        },
        task::{Context, Poll, Waker},
        thread,
        time::Duration,
    };

    use super::super::super::sync::SpinLock;

    /// State shared between a `Timer` future and the thread that fires it.
    struct TimerState {
        /// Set by the timer thread once its sleep has elapsed.
        fired: bool,
        /// Waker registered by the most recent pending poll, if any.
        waker: Option<Waker>,
    }

    /// A future that completes once its background thread marks it as fired.
    struct Timer {
        state: Arc<SpinLock<TimerState>>,
    }

    impl Future for Timer {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
            let mut guard = self.state.lock();
            if !guard.fired {
                guard.waker = Some(cx.waker().clone());
                return Poll::Pending;
            }

            Poll::Ready(())
        }
    }

    /// Spawns a thread that fires the returned `Timer` after `dur`. If `notify` is provided, the
    /// thread then sends a unit value on it.
    fn start_timer(dur: Duration, notify: Option<Sender<()>>) -> Timer {
        let shared = Arc::new(SpinLock::new(TimerState {
            fired: false,
            waker: None,
        }));

        let for_thread = Arc::clone(&shared);
        thread::spawn(move || {
            thread::sleep(dur);

            {
                let mut guard = for_thread.lock();
                guard.fired = true;
                if let Some(w) = guard.waker.take() {
                    w.wake();
                }
            }

            if let Some(tx) = notify {
                tx.send(()).expect("Failed to send completion notification");
            }
        });

        Timer { state: shared }
    }

    #[test]
    fn it_works() {
        block_on(start_timer(Duration::from_millis(100), None));
    }

    #[test]
    fn nested() {
        async fn inner() {
            block_on(start_timer(Duration::from_millis(100), None));
        }

        block_on(inner());
    }

    #[test]
    fn ready_before_poll() {
        let (tx, rx) = channel();

        let timer = start_timer(Duration::from_millis(50), Some(tx));

        rx.recv()
            .expect("Failed to receive completion notification");

        // We know the timer has already fired so the poll should complete immediately.
        block_on(timer);
    }
}