• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::runtime::task::core::{Cell, Core, CoreStage, Header, Scheduler, Trailer};
2 use crate::runtime::task::state::Snapshot;
3 use crate::runtime::task::waker::waker_ref;
4 use crate::runtime::task::{JoinError, Notified, Schedule, Task};
5 
6 use std::future::Future;
7 use std::mem;
8 use std::panic;
9 use std::ptr::NonNull;
10 use std::task::{Context, Poll, Waker};
11 
12 /// Typed raw task handle
13 pub(super) struct Harness<T: Future, S: 'static> {
14     cell: NonNull<Cell<T, S>>,
15 }
16 
17 impl<T, S> Harness<T, S>
18 where
19     T: Future,
20     S: 'static,
21 {
from_raw(ptr: NonNull<Header>) -> Harness<T, S>22     pub(super) unsafe fn from_raw(ptr: NonNull<Header>) -> Harness<T, S> {
23         Harness {
24             cell: ptr.cast::<Cell<T, S>>(),
25         }
26     }
27 
header(&self) -> &Header28     fn header(&self) -> &Header {
29         unsafe { &self.cell.as_ref().header }
30     }
31 
trailer(&self) -> &Trailer32     fn trailer(&self) -> &Trailer {
33         unsafe { &self.cell.as_ref().trailer }
34     }
35 
core(&self) -> &Core<T, S>36     fn core(&self) -> &Core<T, S> {
37         unsafe { &self.cell.as_ref().core }
38     }
39 
scheduler_view(&self) -> SchedulerView<'_, S>40     fn scheduler_view(&self) -> SchedulerView<'_, S> {
41         SchedulerView {
42             header: self.header(),
43             scheduler: &self.core().scheduler,
44         }
45     }
46 }
47 
48 impl<T, S> Harness<T, S>
49 where
50     T: Future,
51     S: Schedule,
52 {
53     /// Polls the inner future.
54     ///
55     /// All necessary state checks and transitions are performed.
56     ///
57     /// Panics raised while polling the future are handled.
poll(self)58     pub(super) fn poll(self) {
59         match self.poll_inner() {
60             PollFuture::Notified => {
61                 // Signal yield
62                 self.core().scheduler.yield_now(Notified(self.to_task()));
63                 // The ref-count was incremented as part of
64                 // `transition_to_idle`.
65                 self.drop_reference();
66             }
67             PollFuture::DropReference => {
68                 self.drop_reference();
69             }
70             PollFuture::Complete(out, is_join_interested) => {
71                 self.complete(out, is_join_interested);
72             }
73             PollFuture::None => (),
74         }
75     }
76 
poll_inner(&self) -> PollFuture<T::Output>77     fn poll_inner(&self) -> PollFuture<T::Output> {
78         let snapshot = match self.scheduler_view().transition_to_running() {
79             TransitionToRunning::Ok(snapshot) => snapshot,
80             TransitionToRunning::DropReference => return PollFuture::DropReference,
81         };
82 
83         // The transition to `Running` done above ensures that a lock on the
84         // future has been obtained. This also ensures the `*mut T` pointer
85         // contains the future (as opposed to the output) and is initialized.
86 
87         let waker_ref = waker_ref::<T, S>(self.header());
88         let cx = Context::from_waker(&*waker_ref);
89         poll_future(self.header(), &self.core().stage, snapshot, cx)
90     }
91 
dealloc(self)92     pub(super) fn dealloc(self) {
93         // Release the join waker, if there is one.
94         self.trailer().waker.with_mut(drop);
95 
96         // Check causality
97         self.core().stage.with_mut(drop);
98         self.core().scheduler.with_mut(drop);
99 
100         unsafe {
101             drop(Box::from_raw(self.cell.as_ptr()));
102         }
103     }
104 
105     // ===== join handle =====
106 
107     /// Read the task output into `dst`.
try_read_output(self, dst: &mut Poll<super::Result<T::Output>>, waker: &Waker)108     pub(super) fn try_read_output(self, dst: &mut Poll<super::Result<T::Output>>, waker: &Waker) {
109         if can_read_output(self.header(), self.trailer(), waker) {
110             *dst = Poll::Ready(self.core().stage.take_output());
111         }
112     }
113 
drop_join_handle_slow(self)114     pub(super) fn drop_join_handle_slow(self) {
115         // Try to unset `JOIN_INTEREST`. This must be done as a first step in
116         // case the task concurrently completed.
117         if self.header().state.unset_join_interested().is_err() {
118             // It is our responsibility to drop the output. This is critical as
119             // the task output may not be `Send` and as such must remain with
120             // the scheduler or `JoinHandle`. i.e. if the output remains in the
121             // task structure until the task is deallocated, it may be dropped
122             // by a Waker on any arbitrary thread.
123             self.core().stage.drop_future_or_output();
124         }
125 
126         // Drop the `JoinHandle` reference, possibly deallocating the task
127         self.drop_reference();
128     }
129 
130     // ===== waker behavior =====
131 
wake_by_val(self)132     pub(super) fn wake_by_val(self) {
133         self.wake_by_ref();
134         self.drop_reference();
135     }
136 
wake_by_ref(&self)137     pub(super) fn wake_by_ref(&self) {
138         if self.header().state.transition_to_notified() {
139             self.core().scheduler.schedule(Notified(self.to_task()));
140         }
141     }
142 
drop_reference(self)143     pub(super) fn drop_reference(self) {
144         if self.header().state.ref_dec() {
145             self.dealloc();
146         }
147     }
148 
149     /// Forcibly shutdown the task
150     ///
151     /// Attempt to transition to `Running` in order to forcibly shutdown the
152     /// task. If the task is currently running or in a state of completion, then
153     /// there is nothing further to do. When the task completes running, it will
154     /// notice the `CANCELLED` bit and finalize the task.
shutdown(self)155     pub(super) fn shutdown(self) {
156         if !self.header().state.transition_to_shutdown() {
157             // The task is concurrently running. No further work needed.
158             return;
159         }
160 
161         // By transitioning the lifcycle to `Running`, we have permission to
162         // drop the future.
163         let err = cancel_task(&self.core().stage);
164         self.complete(Err(err), true)
165     }
166 
167     // ====== internal ======
168 
complete(self, output: super::Result<T::Output>, is_join_interested: bool)169     fn complete(self, output: super::Result<T::Output>, is_join_interested: bool) {
170         if is_join_interested {
171             // Store the output. The future has already been dropped
172             //
173             // Safety: Mutual exclusion is obtained by having transitioned the task
174             // state -> Running
175             let stage = &self.core().stage;
176             stage.store_output(output);
177 
178             // Transition to `Complete`, notifying the `JoinHandle` if necessary.
179             transition_to_complete(self.header(), stage, &self.trailer());
180         }
181 
182         // The task has completed execution and will no longer be scheduled.
183         //
184         // Attempts to batch a ref-dec with the state transition below.
185 
186         if self
187             .scheduler_view()
188             .transition_to_terminal(is_join_interested)
189         {
190             self.dealloc()
191         }
192     }
193 
to_task(&self) -> Task<S>194     fn to_task(&self) -> Task<S> {
195         self.scheduler_view().to_task()
196     }
197 }
198 
199 enum TransitionToRunning {
200     Ok(Snapshot),
201     DropReference,
202 }
203 
204 struct SchedulerView<'a, S> {
205     header: &'a Header,
206     scheduler: &'a Scheduler<S>,
207 }
208 
209 impl<'a, S> SchedulerView<'a, S>
210 where
211     S: Schedule,
212 {
to_task(&self) -> Task<S>213     fn to_task(&self) -> Task<S> {
214         // SAFETY The header is from the same struct containing the scheduler `S` so  the cast is safe
215         unsafe { Task::from_raw(self.header.into()) }
216     }
217 
218     /// Returns true if the task should be deallocated.
transition_to_terminal(&self, is_join_interested: bool) -> bool219     fn transition_to_terminal(&self, is_join_interested: bool) -> bool {
220         let ref_dec = if self.scheduler.is_bound() {
221             if let Some(task) = self.scheduler.release(self.to_task()) {
222                 mem::forget(task);
223                 true
224             } else {
225                 false
226             }
227         } else {
228             false
229         };
230 
231         // This might deallocate
232         let snapshot = self
233             .header
234             .state
235             .transition_to_terminal(!is_join_interested, ref_dec);
236 
237         snapshot.ref_count() == 0
238     }
239 
transition_to_running(&self) -> TransitionToRunning240     fn transition_to_running(&self) -> TransitionToRunning {
241         // If this is the first time the task is polled, the task will be bound
242         // to the scheduler, in which case the task ref count must be
243         // incremented.
244         let is_not_bound = !self.scheduler.is_bound();
245 
246         // Transition the task to the running state.
247         //
248         // A failure to transition here indicates the task has been cancelled
249         // while in the run queue pending execution.
250         let snapshot = match self.header.state.transition_to_running(is_not_bound) {
251             Ok(snapshot) => snapshot,
252             Err(_) => {
253                 // The task was shutdown while in the run queue. At this point,
254                 // we just hold a ref counted reference. Since we do not have access to it here
255                 // return `DropReference` so the caller drops it.
256                 return TransitionToRunning::DropReference;
257             }
258         };
259 
260         if is_not_bound {
261             // Ensure the task is bound to a scheduler instance. Since this is
262             // the first time polling the task, a scheduler instance is pulled
263             // from the local context and assigned to the task.
264             //
265             // The scheduler maintains ownership of the task and responds to
266             // `wake` calls.
267             //
268             // The task reference count has been incremented.
269             //
270             // Safety: Since we have unique access to the task so that we can
271             // safely call `bind_scheduler`.
272             self.scheduler.bind_scheduler(self.to_task());
273         }
274         TransitionToRunning::Ok(snapshot)
275     }
276 }
277 
278 /// Transitions the task's lifecycle to `Complete`. Notifies the
279 /// `JoinHandle` if it still has interest in the completion.
transition_to_complete<T>(header: &Header, stage: &CoreStage<T>, trailer: &Trailer) where T: Future,280 fn transition_to_complete<T>(header: &Header, stage: &CoreStage<T>, trailer: &Trailer)
281 where
282     T: Future,
283 {
284     // Transition the task's lifecycle to `Complete` and get a snapshot of
285     // the task's sate.
286     let snapshot = header.state.transition_to_complete();
287 
288     if !snapshot.is_join_interested() {
289         // The `JoinHandle` is not interested in the output of this task. It
290         // is our responsibility to drop the output.
291         stage.drop_future_or_output();
292     } else if snapshot.has_join_waker() {
293         // Notify the join handle. The previous transition obtains the
294         // lock on the waker cell.
295         trailer.wake_join();
296     }
297 }
298 
can_read_output(header: &Header, trailer: &Trailer, waker: &Waker) -> bool299 fn can_read_output(header: &Header, trailer: &Trailer, waker: &Waker) -> bool {
300     // Load a snapshot of the current task state
301     let snapshot = header.state.load();
302 
303     debug_assert!(snapshot.is_join_interested());
304 
305     if !snapshot.is_complete() {
306         // The waker must be stored in the task struct.
307         let res = if snapshot.has_join_waker() {
308             // There already is a waker stored in the struct. If it matches
309             // the provided waker, then there is no further work to do.
310             // Otherwise, the waker must be swapped.
311             let will_wake = unsafe {
312                 // Safety: when `JOIN_INTEREST` is set, only `JOIN_HANDLE`
313                 // may mutate the `waker` field.
314                 trailer.will_wake(waker)
315             };
316 
317             if will_wake {
318                 // The task is not complete **and** the waker is up to date,
319                 // there is nothing further that needs to be done.
320                 return false;
321             }
322 
323             // Unset the `JOIN_WAKER` to gain mutable access to the `waker`
324             // field then update the field with the new join worker.
325             //
326             // This requires two atomic operations, unsetting the bit and
327             // then resetting it. If the task transitions to complete
328             // concurrently to either one of those operations, then setting
329             // the join waker fails and we proceed to reading the task
330             // output.
331             header
332                 .state
333                 .unset_waker()
334                 .and_then(|snapshot| set_join_waker(header, trailer, waker.clone(), snapshot))
335         } else {
336             set_join_waker(header, trailer, waker.clone(), snapshot)
337         };
338 
339         match res {
340             Ok(_) => return false,
341             Err(snapshot) => {
342                 assert!(snapshot.is_complete());
343             }
344         }
345     }
346     true
347 }
348 
set_join_waker( header: &Header, trailer: &Trailer, waker: Waker, snapshot: Snapshot, ) -> Result<Snapshot, Snapshot>349 fn set_join_waker(
350     header: &Header,
351     trailer: &Trailer,
352     waker: Waker,
353     snapshot: Snapshot,
354 ) -> Result<Snapshot, Snapshot> {
355     assert!(snapshot.is_join_interested());
356     assert!(!snapshot.has_join_waker());
357 
358     // Safety: Only the `JoinHandle` may set the `waker` field. When
359     // `JOIN_INTEREST` is **not** set, nothing else will touch the field.
360     unsafe {
361         trailer.set_waker(Some(waker));
362     }
363 
364     // Update the `JoinWaker` state accordingly
365     let res = header.state.set_join_waker();
366 
367     // If the state could not be updated, then clear the join waker
368     if res.is_err() {
369         unsafe {
370             trailer.set_waker(None);
371         }
372     }
373 
374     res
375 }
376 
377 enum PollFuture<T> {
378     Complete(Result<T, JoinError>, bool),
379     DropReference,
380     Notified,
381     None,
382 }
383 
cancel_task<T: Future>(stage: &CoreStage<T>) -> JoinError384 fn cancel_task<T: Future>(stage: &CoreStage<T>) -> JoinError {
385     // Drop the future from a panic guard.
386     let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
387         stage.drop_future_or_output();
388     }));
389 
390     if let Err(err) = res {
391         // Dropping the future panicked, complete the join
392         // handle with the panic to avoid dropping the panic
393         // on the ground.
394         JoinError::panic(err)
395     } else {
396         JoinError::cancelled()
397     }
398 }
399 
poll_future<T: Future>( header: &Header, core: &CoreStage<T>, snapshot: Snapshot, cx: Context<'_>, ) -> PollFuture<T::Output>400 fn poll_future<T: Future>(
401     header: &Header,
402     core: &CoreStage<T>,
403     snapshot: Snapshot,
404     cx: Context<'_>,
405 ) -> PollFuture<T::Output> {
406     if snapshot.is_cancelled() {
407         PollFuture::Complete(Err(JoinError::cancelled()), snapshot.is_join_interested())
408     } else {
409         let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
410             struct Guard<'a, T: Future> {
411                 core: &'a CoreStage<T>,
412             }
413 
414             impl<T: Future> Drop for Guard<'_, T> {
415                 fn drop(&mut self) {
416                     self.core.drop_future_or_output();
417                 }
418             }
419 
420             let guard = Guard { core };
421 
422             let res = guard.core.poll(cx);
423 
424             // prevent the guard from dropping the future
425             mem::forget(guard);
426 
427             res
428         }));
429         match res {
430             Ok(Poll::Pending) => match header.state.transition_to_idle() {
431                 Ok(snapshot) => {
432                     if snapshot.is_notified() {
433                         PollFuture::Notified
434                     } else {
435                         PollFuture::None
436                     }
437                 }
438                 Err(_) => PollFuture::Complete(Err(cancel_task(core)), true),
439             },
440             Ok(Poll::Ready(ok)) => PollFuture::Complete(Ok(ok), snapshot.is_join_interested()),
441             Err(err) => {
442                 PollFuture::Complete(Err(JoinError::panic(err)), snapshot.is_join_interested())
443             }
444         }
445     }
446 }
447