• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #![cfg_attr(not(feature = "full"), allow(dead_code))]
2 
3 use crate::loom::sync::atomic::AtomicUsize;
4 use crate::loom::sync::{Arc, Condvar, Mutex};
5 
6 use std::sync::atomic::Ordering::SeqCst;
7 use std::time::Duration;
8 
9 #[derive(Debug)]
10 pub(crate) struct ParkThread {
11     inner: Arc<Inner>,
12 }
13 
14 /// Unblocks a thread that was blocked by `ParkThread`.
15 #[derive(Clone, Debug)]
16 pub(crate) struct UnparkThread {
17     inner: Arc<Inner>,
18 }
19 
/// Parking state shared between a `ParkThread` and its `UnparkThread` handles.
#[derive(Debug)]
struct Inner {
    /// Current parking state: one of `EMPTY`, `PARKED` or `NOTIFIED`.
    state: AtomicUsize,
    /// Held by the parking thread while it registers as `PARKED` and while it
    /// waits on `condvar`; `unpark` briefly acquires it before notifying.
    mutex: Mutex<()>,
    /// Signalled by `unpark` (one waiter) and by `shutdown` (all waiters).
    condvar: Condvar,
}
26 
/// No thread is parked and no notification is pending (the initial state).
const EMPTY: usize = 0;
/// A thread has registered itself as parked and must be woken through the
/// condition variable.
const PARKED: usize = 1;
/// An unpark is pending; the next park consumes it and returns right away.
const NOTIFIED: usize = 2;
30 
31 tokio_thread_local! {
32     static CURRENT_PARKER: ParkThread = ParkThread::new();
33 }
34 
// Loom-only instrumentation (admittedly a hack): counts how often the current
// thread parks through `ParkThread::park` / `ParkThread::park_timeout`, and is
// read back by `current_thread_park_count`.
#[cfg(loom)]
tokio_thread_local! {
    pub(crate) static CURRENT_THREAD_PARK_COUNT: AtomicUsize = AtomicUsize::new(0);
}
40 
41 // ==== impl ParkThread ====
42 
43 impl ParkThread {
new() -> Self44     pub(crate) fn new() -> Self {
45         Self {
46             inner: Arc::new(Inner {
47                 state: AtomicUsize::new(EMPTY),
48                 mutex: Mutex::new(()),
49                 condvar: Condvar::new(),
50             }),
51         }
52     }
53 
unpark(&self) -> UnparkThread54     pub(crate) fn unpark(&self) -> UnparkThread {
55         let inner = self.inner.clone();
56         UnparkThread { inner }
57     }
58 
park(&mut self)59     pub(crate) fn park(&mut self) {
60         #[cfg(loom)]
61         CURRENT_THREAD_PARK_COUNT.with(|count| count.fetch_add(1, SeqCst));
62         self.inner.park();
63     }
64 
park_timeout(&mut self, duration: Duration)65     pub(crate) fn park_timeout(&mut self, duration: Duration) {
66         #[cfg(loom)]
67         CURRENT_THREAD_PARK_COUNT.with(|count| count.fetch_add(1, SeqCst));
68 
69         // Wasm doesn't have threads, so just sleep.
70         #[cfg(not(target_family = "wasm"))]
71         self.inner.park_timeout(duration);
72         #[cfg(target_family = "wasm")]
73         std::thread::sleep(duration);
74     }
75 
shutdown(&mut self)76     pub(crate) fn shutdown(&mut self) {
77         self.inner.shutdown();
78     }
79 }
80 
81 // ==== impl Inner ====
82 
83 impl Inner {
park(&self)84     fn park(&self) {
85         // If we were previously notified then we consume this notification and
86         // return quickly.
87         if self
88             .state
89             .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst)
90             .is_ok()
91         {
92             return;
93         }
94 
95         // Otherwise we need to coordinate going to sleep
96         let mut m = self.mutex.lock();
97 
98         match self.state.compare_exchange(EMPTY, PARKED, SeqCst, SeqCst) {
99             Ok(_) => {}
100             Err(NOTIFIED) => {
101                 // We must read here, even though we know it will be `NOTIFIED`.
102                 // This is because `unpark` may have been called again since we read
103                 // `NOTIFIED` in the `compare_exchange` above. We must perform an
104                 // acquire operation that synchronizes with that `unpark` to observe
105                 // any writes it made before the call to unpark. To do that we must
106                 // read from the write it made to `state`.
107                 let old = self.state.swap(EMPTY, SeqCst);
108                 debug_assert_eq!(old, NOTIFIED, "park state changed unexpectedly");
109 
110                 return;
111             }
112             Err(actual) => panic!("inconsistent park state; actual = {actual}"),
113         }
114 
115         loop {
116             m = self.condvar.wait(m).unwrap();
117 
118             if self
119                 .state
120                 .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst)
121                 .is_ok()
122             {
123                 // got a notification
124                 return;
125             }
126 
127             // spurious wakeup, go back to sleep
128         }
129     }
130 
131     /// Parks the current thread for at most `dur`.
park_timeout(&self, dur: Duration)132     fn park_timeout(&self, dur: Duration) {
133         // Like `park` above we have a fast path for an already-notified thread,
134         // and afterwards we start coordinating for a sleep. Return quickly.
135         if self
136             .state
137             .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst)
138             .is_ok()
139         {
140             return;
141         }
142 
143         if dur == Duration::from_millis(0) {
144             return;
145         }
146 
147         let m = self.mutex.lock();
148 
149         match self.state.compare_exchange(EMPTY, PARKED, SeqCst, SeqCst) {
150             Ok(_) => {}
151             Err(NOTIFIED) => {
152                 // We must read again here, see `park`.
153                 let old = self.state.swap(EMPTY, SeqCst);
154                 debug_assert_eq!(old, NOTIFIED, "park state changed unexpectedly");
155 
156                 return;
157             }
158             Err(actual) => panic!("inconsistent park_timeout state; actual = {actual}"),
159         }
160 
161         // Wait with a timeout, and if we spuriously wake up or otherwise wake up
162         // from a notification, we just want to unconditionally set the state back to
163         // empty, either consuming a notification or un-flagging ourselves as
164         // parked.
165         let (_m, _result) = self.condvar.wait_timeout(m, dur).unwrap();
166 
167         match self.state.swap(EMPTY, SeqCst) {
168             NOTIFIED => {} // got a notification, hurray!
169             PARKED => {}   // no notification, alas
170             n => panic!("inconsistent park_timeout state: {n}"),
171         }
172     }
173 
unpark(&self)174     fn unpark(&self) {
175         // To ensure the unparked thread will observe any writes we made before
176         // this call, we must perform a release operation that `park` can
177         // synchronize with. To do that we must write `NOTIFIED` even if `state`
178         // is already `NOTIFIED`. That is why this must be a swap rather than a
179         // compare-and-swap that returns if it reads `NOTIFIED` on failure.
180         match self.state.swap(NOTIFIED, SeqCst) {
181             EMPTY => return,    // no one was waiting
182             NOTIFIED => return, // already unparked
183             PARKED => {}        // gotta go wake someone up
184             _ => panic!("inconsistent state in unpark"),
185         }
186 
187         // There is a period between when the parked thread sets `state` to
188         // `PARKED` (or last checked `state` in the case of a spurious wake
189         // up) and when it actually waits on `cvar`. If we were to notify
190         // during this period it would be ignored and then when the parked
191         // thread went to sleep it would never wake up. Fortunately, it has
192         // `lock` locked at this stage so we can acquire `lock` to wait until
193         // it is ready to receive the notification.
194         //
195         // Releasing `lock` before the call to `notify_one` means that when the
196         // parked thread wakes it doesn't get woken only to have to wait for us
197         // to release `lock`.
198         drop(self.mutex.lock());
199 
200         self.condvar.notify_one();
201     }
202 
shutdown(&self)203     fn shutdown(&self) {
204         self.condvar.notify_all();
205     }
206 }
207 
208 impl Default for ParkThread {
default() -> Self209     fn default() -> Self {
210         Self::new()
211     }
212 }
213 
214 // ===== impl UnparkThread =====
215 
216 impl UnparkThread {
unpark(&self)217     pub(crate) fn unpark(&self) {
218         self.inner.unpark();
219     }
220 }
221 
222 use crate::loom::thread::AccessError;
223 use std::future::Future;
224 use std::marker::PhantomData;
225 use std::rc::Rc;
226 use std::task::{RawWaker, RawWakerVTable, Waker};
227 
/// Blocks the current thread using a condition variable.
///
/// Rather than owning a `ParkThread`, this parks through the thread-local
/// `CURRENT_PARKER`. The `PhantomData<Rc<()>>` marker makes the type `!Send`
/// and `!Sync`, so it cannot be moved to another thread.
#[derive(Debug)]
pub(crate) struct CachedParkThread {
    _anchor: PhantomData<Rc<()>>,
}
233 
234 impl CachedParkThread {
235     /// Creates a new `ParkThread` handle for the current thread.
236     ///
237     /// This type cannot be moved to other threads, so it should be created on
238     /// the thread that the caller intends to park.
new() -> CachedParkThread239     pub(crate) fn new() -> CachedParkThread {
240         CachedParkThread {
241             _anchor: PhantomData,
242         }
243     }
244 
waker(&self) -> Result<Waker, AccessError>245     pub(crate) fn waker(&self) -> Result<Waker, AccessError> {
246         self.unpark().map(UnparkThread::into_waker)
247     }
248 
unpark(&self) -> Result<UnparkThread, AccessError>249     fn unpark(&self) -> Result<UnparkThread, AccessError> {
250         self.with_current(ParkThread::unpark)
251     }
252 
park(&mut self)253     pub(crate) fn park(&mut self) {
254         self.with_current(|park_thread| park_thread.inner.park())
255             .unwrap();
256     }
257 
park_timeout(&mut self, duration: Duration)258     pub(crate) fn park_timeout(&mut self, duration: Duration) {
259         self.with_current(|park_thread| park_thread.inner.park_timeout(duration))
260             .unwrap();
261     }
262 
263     /// Gets a reference to the `ParkThread` handle for this thread.
with_current<F, R>(&self, f: F) -> Result<R, AccessError> where F: FnOnce(&ParkThread) -> R,264     fn with_current<F, R>(&self, f: F) -> Result<R, AccessError>
265     where
266         F: FnOnce(&ParkThread) -> R,
267     {
268         CURRENT_PARKER.try_with(|inner| f(inner))
269     }
270 
block_on<F: Future>(&mut self, f: F) -> Result<F::Output, AccessError>271     pub(crate) fn block_on<F: Future>(&mut self, f: F) -> Result<F::Output, AccessError> {
272         use std::task::Context;
273         use std::task::Poll::Ready;
274 
275         let waker = self.waker()?;
276         let mut cx = Context::from_waker(&waker);
277 
278         pin!(f);
279 
280         loop {
281             if let Ready(v) = crate::runtime::coop::budget(|| f.as_mut().poll(&mut cx)) {
282                 return Ok(v);
283             }
284 
285             self.park();
286         }
287     }
288 }
289 
290 impl UnparkThread {
into_waker(self) -> Waker291     pub(crate) fn into_waker(self) -> Waker {
292         unsafe {
293             let raw = unparker_to_raw_waker(self.inner);
294             Waker::from_raw(raw)
295         }
296     }
297 }
298 
299 impl Inner {
300     #[allow(clippy::wrong_self_convention)]
into_raw(this: Arc<Inner>) -> *const ()301     fn into_raw(this: Arc<Inner>) -> *const () {
302         Arc::into_raw(this) as *const ()
303     }
304 
from_raw(ptr: *const ()) -> Arc<Inner>305     unsafe fn from_raw(ptr: *const ()) -> Arc<Inner> {
306         Arc::from_raw(ptr as *const Inner)
307     }
308 }
309 
unparker_to_raw_waker(unparker: Arc<Inner>) -> RawWaker310 unsafe fn unparker_to_raw_waker(unparker: Arc<Inner>) -> RawWaker {
311     RawWaker::new(
312         Inner::into_raw(unparker),
313         &RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker),
314     )
315 }
316 
clone(raw: *const ()) -> RawWaker317 unsafe fn clone(raw: *const ()) -> RawWaker {
318     Arc::increment_strong_count(raw as *const Inner);
319     unparker_to_raw_waker(Inner::from_raw(raw))
320 }
321 
drop_waker(raw: *const ())322 unsafe fn drop_waker(raw: *const ()) {
323     drop(Inner::from_raw(raw));
324 }
325 
wake(raw: *const ())326 unsafe fn wake(raw: *const ()) {
327     let unparker = Inner::from_raw(raw);
328     unparker.unpark();
329 }
330 
wake_by_ref(raw: *const ())331 unsafe fn wake_by_ref(raw: *const ()) {
332     let raw = raw as *const Inner;
333     (*raw).unpark();
334 }
335 
/// Loom-only: number of times the current thread has parked through
/// `ParkThread::park` / `ParkThread::park_timeout`.
#[cfg(loom)]
pub(crate) fn current_thread_park_count() -> usize {
    CURRENT_THREAD_PARK_COUNT.with(|parks| parks.load(SeqCst))
}
340