1 use crate::future::Future;
2 use crate::runtime::task::{Header, RawTask, Schedule};
3
4 use std::marker::PhantomData;
5 use std::mem::ManuallyDrop;
6 use std::ops;
7 use std::ptr::NonNull;
8 use std::task::{RawWaker, RawWakerVTable, Waker};
9
10 pub(super) struct WakerRef<'a, S: 'static> {
11 waker: ManuallyDrop<Waker>,
12 _p: PhantomData<(&'a Header, S)>,
13 }
14
15 /// Returns a `WakerRef` which avoids having to preemptively increase the
16 /// refcount if there is no need to do so.
waker_ref<T, S>(header: &NonNull<Header>) -> WakerRef<'_, S> where T: Future, S: Schedule,17 pub(super) fn waker_ref<T, S>(header: &NonNull<Header>) -> WakerRef<'_, S>
18 where
19 T: Future,
20 S: Schedule,
21 {
22 // `Waker::will_wake` uses the VTABLE pointer as part of the check. This
23 // means that `will_wake` will always return false when using the current
24 // task's waker. (discussion at rust-lang/rust#66281).
25 //
26 // To fix this, we use a single vtable. Since we pass in a reference at this
27 // point and not an *owned* waker, we must ensure that `drop` is never
28 // called on this waker instance. This is done by wrapping it with
29 // `ManuallyDrop` and then never calling drop.
30 let waker = unsafe { ManuallyDrop::new(Waker::from_raw(raw_waker(*header))) };
31
32 WakerRef {
33 waker,
34 _p: PhantomData,
35 }
36 }
37
38 impl<S> ops::Deref for WakerRef<'_, S> {
39 type Target = Waker;
40
deref(&self) -> &Waker41 fn deref(&self) -> &Waker {
42 &self.waker
43 }
44 }
45
46 cfg_trace! {
47 macro_rules! trace {
48 ($header:expr, $op:expr) => {
49 if let Some(id) = Header::get_tracing_id(&$header) {
50 tracing::trace!(
51 target: "tokio::task::waker",
52 op = $op,
53 task.id = id.into_u64(),
54 );
55 }
56 }
57 }
58 }
59
60 cfg_not_trace! {
61 macro_rules! trace {
62 ($header:expr, $op:expr) => {
63 // noop
64 let _ = &$header;
65 }
66 }
67 }
68
clone_waker(ptr: *const ()) -> RawWaker69 unsafe fn clone_waker(ptr: *const ()) -> RawWaker {
70 let header = NonNull::new_unchecked(ptr as *mut Header);
71 trace!(header, "waker.clone");
72 header.as_ref().state.ref_inc();
73 raw_waker(header)
74 }
75
drop_waker(ptr: *const ())76 unsafe fn drop_waker(ptr: *const ()) {
77 let ptr = NonNull::new_unchecked(ptr as *mut Header);
78 trace!(ptr, "waker.drop");
79 let raw = RawTask::from_raw(ptr);
80 raw.drop_reference();
81 }
82
wake_by_val(ptr: *const ())83 unsafe fn wake_by_val(ptr: *const ()) {
84 let ptr = NonNull::new_unchecked(ptr as *mut Header);
85 trace!(ptr, "waker.wake");
86 let raw = RawTask::from_raw(ptr);
87 raw.wake_by_val();
88 }
89
90 // Wake without consuming the waker
wake_by_ref(ptr: *const ())91 unsafe fn wake_by_ref(ptr: *const ()) {
92 let ptr = NonNull::new_unchecked(ptr as *mut Header);
93 trace!(ptr, "waker.wake_by_ref");
94 let raw = RawTask::from_raw(ptr);
95 raw.wake_by_ref();
96 }
97
98 static WAKER_VTABLE: RawWakerVTable =
99 RawWakerVTable::new(clone_waker, wake_by_val, wake_by_ref, drop_waker);
100
raw_waker(header: NonNull<Header>) -> RawWaker101 fn raw_waker(header: NonNull<Header>) -> RawWaker {
102 let ptr = header.as_ptr() as *const ();
103 RawWaker::new(ptr, &WAKER_VTABLE)
104 }
105