• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use core::cell::UnsafeCell;
2 use core::fmt;
3 use core::sync::atomic::{AtomicUsize, Ordering};
4 use core::task::Waker;
5 
6 use crate::raw::TaskVTable;
7 use crate::state::*;
8 use crate::utils::abort_on_panic;
9 
10 /// The header of a task.
11 ///
12 /// This header is stored in memory at the beginning of the heap-allocated task.
13 pub(crate) struct Header {
14     /// Current state of the task.
15     ///
16     /// Contains flags representing the current state and the reference count.
17     pub(crate) state: AtomicUsize,
18 
19     /// The task that is blocked on the `Task` handle.
20     ///
21     /// This waker needs to be woken up once the task completes or is closed.
22     pub(crate) awaiter: UnsafeCell<Option<Waker>>,
23 
24     /// The virtual table.
25     ///
26     /// In addition to the actual waker virtual table, it also contains pointers to several other
27     /// methods necessary for bookkeeping the heap-allocated task.
28     pub(crate) vtable: &'static TaskVTable,
29 }
30 
impl Header {
    /// Notifies the awaiter blocked on this task.
    ///
    /// The awaiter is obtained through [`Header::take`]. Nothing is woken in these cases:
    /// - there is no awaiter;
    /// - a notification or registration is already in progress;
    /// - the awaiter would wake the same task as `current` (checked with [`Waker::will_wake`]).
    #[inline]
    pub(crate) fn notify(&self, current: Option<&Waker>) {
        if let Some(w) = self.take(current) {
            // `wake` runs foreign code. `abort_on_panic` (crate::utils) presumably turns a panic
            // there into an abort instead of unwinding through task bookkeeping — confirm in utils.
            abort_on_panic(|| w.wake());
        }
    }

    /// Takes the awaiter blocked on this task.
    ///
    /// Returns the stored waker so the caller can wake it, or `None` when:
    /// - another `take` already holds the `NOTIFYING` bit;
    /// - `register` holds the `REGISTERING` bit. The `NOTIFYING` bit set here is then left for
    ///   `register` to observe and clear, and `register` wakes the newly stored waker itself;
    /// - no awaiter is stored;
    /// - the stored awaiter would wake the same task as `current`, in which case it is dropped.
    ///
    /// Whenever this call gains access to the `awaiter` slot, it leaves the slot empty and clears
    /// both the `NOTIFYING` and `AWAITER` bits.
    #[inline]
    pub(crate) fn take(&self, current: Option<&Waker>) -> Option<Waker> {
        // Set the bit indicating that the task is notifying its awaiter.
        // The returned *previous* state tells us whether this call won access to `awaiter`.
        let state = self.state.fetch_or(NOTIFYING, Ordering::AcqRel);

        // If the task was not notifying or registering an awaiter...
        if state & (NOTIFYING | REGISTERING) == 0 {
            // Take the waker out.
            // SAFETY (reviewer note): neither bit was set before our `fetch_or`, so this call is
            // the one that set `NOTIFYING`. Within this file, no other `take` enters this branch,
            // and `register` cannot set `REGISTERING` while `NOTIFYING` is set.
            let waker = unsafe { (*self.awaiter.get()).take() };

            // Unset the bit indicating that the task is notifying its awaiter.
            // `AWAITER` is cleared as well because the slot is now empty. `Release` publishes
            // the emptied slot to the next thread that acquires `state`.
            self.state
                .fetch_and(!NOTIFYING & !AWAITER, Ordering::Release);

            // Finally, notify the waker if it's different from the current waker.
            if let Some(w) = waker {
                match current {
                    None => return Some(w),
                    Some(c) if !w.will_wake(c) => return Some(w),
                    // Same task as `current`: drop it instead of returning it. Dropping a waker
                    // runs foreign code, hence the panic guard.
                    Some(_) => abort_on_panic(|| drop(w)),
                }
            }
        }

        None
    }

    /// Registers a new awaiter blocked on this task.
    ///
    /// This method is called when `Task` is polled and it has not yet completed.
    ///
    /// The protocol implemented below:
    /// 1. If a notification is in progress (`NOTIFYING` is set), wake `waker` immediately and
    ///    return without storing it.
    /// 2. Otherwise set `REGISTERING`, which makes concurrent `take` calls back off, and store a
    ///    clone of `waker` in `awaiter`.
    /// 3. Clear `REGISTERING`, along with any `NOTIFYING` set in the meantime:
    ///    - If a `take` arrived during registration, remove the stored waker, clear `AWAITER`,
    ///      and wake the waker here.
    ///    - Otherwise set `AWAITER`.
    #[inline]
    pub(crate) fn register(&self, waker: &Waker) {
        // Load the state and synchronize with it.
        // A read-modify-write is used rather than a plain load: an RMW always reads the latest
        // value in `state`'s modification order.
        let mut state = self.state.fetch_or(0, Ordering::Acquire);

        loop {
            // There can't be two concurrent registrations because `Task` can only be polled
            // by a unique pinned reference.
            debug_assert!(state & REGISTERING == 0);

            // If we're in the notifying state at this moment, just wake and return without
            // registering.
            if state & NOTIFYING != 0 {
                abort_on_panic(|| waker.wake_by_ref());
                return;
            }

            // Mark the state to let other threads know we're registering a new awaiter.
            // The CAS only succeeds if `NOTIFYING` is still clear (checked just above).
            // `compare_exchange_weak` may also fail spuriously; on failure, retry with the
            // freshly observed state.
            match self.state.compare_exchange_weak(
                state,
                state | REGISTERING,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    state |= REGISTERING;
                    break;
                }
                Err(s) => state = s,
            }
        }

        // Put the waker into the awaiter field.
        // SAFETY (reviewer note): we set `REGISTERING` while `NOTIFYING` was clear, and `take`
        // only touches `awaiter` when it observes neither bit set.
        // Cloning the new waker and dropping any previously stored one both run foreign code,
        // hence the panic guard.
        unsafe {
            abort_on_panic(|| (*self.awaiter.get()) = Some(waker.clone()));
        }

        // This variable will contain the newly registered waker if a notification comes in before
        // we complete registration.
        let mut waker = None;

        loop {
            // If there was a notification, take the waker out of the awaiter field.
            // Only this function clears `REGISTERING`, so it is still held here and access to
            // `awaiter` stays exclusive across retries.
            if state & NOTIFYING != 0 {
                if let Some(w) = unsafe { (*self.awaiter.get()).take() } {
                    abort_on_panic(|| waker = Some(w));
                }
            }

            // The new state is not being notified nor registered, but there might or might not be
            // an awaiter depending on whether there was a concurrent notification.
            let new = if waker.is_none() {
                (state & !NOTIFYING & !REGISTERING) | AWAITER
            } else {
                state & !NOTIFYING & !REGISTERING & !AWAITER
            };

            // Retry on failure (spurious, or because another thread changed other bits, e.g. a
            // `take` setting `NOTIFYING`). The next iteration re-checks `NOTIFYING` against the
            // fresh state.
            match self
                .state
                .compare_exchange_weak(state, new, Ordering::AcqRel, Ordering::Acquire)
            {
                Ok(_) => break,
                Err(s) => state = s,
            }
        }

        // If there was a notification during registration, wake the awaiter now.
        if let Some(w) = waker {
            abort_on_panic(|| w.wake());
        }
    }
}
147 
148 impl fmt::Debug for Header {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result149     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
150         let state = self.state.load(Ordering::SeqCst);
151 
152         f.debug_struct("Header")
153             .field("scheduled", &(state & SCHEDULED != 0))
154             .field("running", &(state & RUNNING != 0))
155             .field("completed", &(state & COMPLETED != 0))
156             .field("closed", &(state & CLOSED != 0))
157             .field("awaiter", &(state & AWAITER != 0))
158             .field("task", &(state & TASK != 0))
159             .field("ref_count", &(state / REFERENCE))
160             .finish()
161     }
162 }
163