• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! Core task module.
2 //!
3 //! # Safety
4 //!
5 //! The functions in this module are private to the `task` module. All of them
6 //! should be considered `unsafe` to use, but are not marked as such since it
7 //! would be too noisy.
8 //!
9 //! Make sure to consult the relevant safety section of each function before
10 //! use.
11 
12 use crate::future::Future;
13 use crate::loom::cell::UnsafeCell;
14 use crate::runtime::context;
15 use crate::runtime::task::raw::{self, Vtable};
16 use crate::runtime::task::state::State;
17 use crate::runtime::task::{Id, Schedule};
18 use crate::util::linked_list;
19 
20 use std::pin::Pin;
21 use std::ptr::NonNull;
22 use std::task::{Context, Poll, Waker};
23 
/// The task cell. Contains the components of the task.
///
/// It is critical for `Header` to be the first field as the task structure will
/// be referenced by both *mut Cell and *mut Header.
///
/// Any changes to the layout of this struct _must_ also be reflected in the
/// const fns in raw.rs.
///
/// `#[repr(C)]` fixes the field order. `Header::get_trailer` relies on this:
/// it locates `trailer` from a header pointer using an offset stored in the
/// task's vtable.
#[repr(C)]
pub(super) struct Cell<T: Future, S> {
    /// Hot task state data
    pub(super) header: Header,

    /// Either the future or output, depending on the execution stage.
    pub(super) core: Core<T, S>,

    /// Cold data
    pub(super) trailer: Trailer,
}
42 
/// Interior-mutable slot holding the task's current `Stage`: the future, its
/// output, or nothing (`Consumed`).
///
/// Nothing here enforces exclusive access. Code that mutates the stage must
/// guarantee it itself.
pub(super) struct CoreStage<T: Future> {
    stage: UnsafeCell<Stage<T>>,
}
46 
/// The core of the task.
///
/// Holds the scheduler, the task id, and the future or output, depending on
/// the stage of execution.
///
/// Any changes to the layout of this struct _must_ also be reflected in the
/// const fns in raw.rs.
///
/// `scheduler` and `task_id` are also reached by raw offset from the header
/// (`Header::get_scheduler` / `Header::get_id_ptr`), using offsets recorded in
/// the task's vtable.
#[repr(C)]
pub(super) struct Core<T: Future, S> {
    /// Scheduler used to drive this future.
    pub(super) scheduler: S,

    /// The task's ID, used for populating `JoinError`s.
    pub(super) task_id: Id,

    /// Either the future or the output.
    pub(super) stage: CoreStage<T>,
}
64 
/// The hot, first-placed part of every task cell.
///
/// Crate public as this is also needed by the pool.
///
/// `#[repr(C)]` keeps the field order fixed. A test in this file asserts that
/// the header fits in eight pointer-sized words.
#[repr(C)]
pub(crate) struct Header {
    /// Task state.
    pub(super) state: State,

    /// Pointer to next task, used with the injection queue.
    pub(super) queue_next: UnsafeCell<Option<NonNull<Header>>>,

    /// Table of function pointers for executing actions on the task.
    ///
    /// It also carries the field offsets (`trailer_offset`, `scheduler_offset`,
    /// `id_offset`) that the `Header::get_*` pointer helpers use to reach the
    /// rest of the cell.
    pub(super) vtable: &'static Vtable,

    /// This integer contains the id of the OwnedTasks or LocalOwnedTasks that
    /// this task is stored in. If the task is not in any list, should be the
    /// id of the list that it was previously in, or zero if it has never been
    /// in any list.
    ///
    /// Once a task has been bound to a list, it can never be bound to another
    /// list, even if removed from the first list.
    ///
    /// The id is not unset when removed from a list because we want to be able
    /// to read the id without synchronization, even if it is concurrently being
    /// removed from the list.
    pub(super) owner_id: UnsafeCell<u64>,

    /// The tracing ID for this instrumented task.
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    pub(super) tracing_id: Option<tracing::Id>,
}
94 
// SAFETY: `Header` holds `NonNull` pointers, which are neither `Send` nor
// `Sync`, and `UnsafeCell`s, which are not `Sync`. Sharing a header across
// threads relies on every mutation of those fields going through the unsafe
// accessors (`set_next`, `set_owner_id`, ...), whose callers must guarantee
// exclusive access.
// NOTE(review): this rationale is inferred from the accessor contracts in this
// file. Confirm it against the task state machine.
unsafe impl Send for Header {}
unsafe impl Sync for Header {}
97 
/// Cold data is stored after the future. Data is considered cold if it is only
/// used during creation or shutdown of the task.
///
/// Reached from a header pointer through `Header::get_trailer`.
pub(super) struct Trailer {
    /// Pointers for the linked list in the `OwnedTasks` that owns this task.
    pub(super) owned: linked_list::Pointers<Header>,
    /// Consumer task waiting on completion of this task (the join waker).
    /// It is written by `set_waker` and woken by `wake_join`.
    pub(super) waker: UnsafeCell<Option<Waker>>,
}
106 
// Generates `Trailer::addr_of_owned`, which maps a `NonNull<Trailer>` to a
// `NonNull` pointing at that trailer's `owned` list pointers.
// NOTE(review): presumably the macro computes the field address without
// materializing an intermediate `&Trailer`. Confirm against the macro
// definition.
generate_addr_of_methods! {
    impl<> Trailer {
        pub(super) unsafe fn addr_of_owned(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Header>> {
            &self.owned
        }
    }
}
114 
/// Either the future or the output.
pub(super) enum Stage<T: Future> {
    /// The future has not completed yet; this is the stage a new cell starts in.
    Running(T),
    /// The future's result, stored by `Core::store_output`.
    Finished(super::Result<T::Output>),
    /// Neither is held. The future or output has been dropped
    /// (`Core::drop_future_or_output`) or taken (`Core::take_output`).
    Consumed,
}
121 
122 impl<T: Future, S: Schedule> Cell<T, S> {
123     /// Allocates a new task cell, containing the header, trailer, and core
124     /// structures.
new(future: T, scheduler: S, state: State, task_id: Id) -> Box<Cell<T, S>>125     pub(super) fn new(future: T, scheduler: S, state: State, task_id: Id) -> Box<Cell<T, S>> {
126         #[cfg(all(tokio_unstable, feature = "tracing"))]
127         let tracing_id = future.id();
128         let result = Box::new(Cell {
129             header: Header {
130                 state,
131                 queue_next: UnsafeCell::new(None),
132                 vtable: raw::vtable::<T, S>(),
133                 owner_id: UnsafeCell::new(0),
134                 #[cfg(all(tokio_unstable, feature = "tracing"))]
135                 tracing_id,
136             },
137             core: Core {
138                 scheduler,
139                 stage: CoreStage {
140                     stage: UnsafeCell::new(Stage::Running(future)),
141                 },
142                 task_id,
143             },
144             trailer: Trailer {
145                 waker: UnsafeCell::new(None),
146                 owned: linked_list::Pointers::new(),
147             },
148         });
149 
150         #[cfg(debug_assertions)]
151         {
152             let trailer_addr = (&result.trailer) as *const Trailer as usize;
153             let trailer_ptr = unsafe { Header::get_trailer(NonNull::from(&result.header)) };
154             assert_eq!(trailer_addr, trailer_ptr.as_ptr() as usize);
155 
156             let scheduler_addr = (&result.core.scheduler) as *const S as usize;
157             let scheduler_ptr =
158                 unsafe { Header::get_scheduler::<S>(NonNull::from(&result.header)) };
159             assert_eq!(scheduler_addr, scheduler_ptr.as_ptr() as usize);
160 
161             let id_addr = (&result.core.task_id) as *const Id as usize;
162             let id_ptr = unsafe { Header::get_id_ptr(NonNull::from(&result.header)) };
163             assert_eq!(id_addr, id_ptr.as_ptr() as usize);
164         }
165 
166         result
167     }
168 }
169 
170 impl<T: Future> CoreStage<T> {
with_mut<R>(&self, f: impl FnOnce(*mut Stage<T>) -> R) -> R171     pub(super) fn with_mut<R>(&self, f: impl FnOnce(*mut Stage<T>) -> R) -> R {
172         self.stage.with_mut(f)
173     }
174 }
175 
176 /// Set and clear the task id in the context when the future is executed or
177 /// dropped, or when the output produced by the future is dropped.
178 pub(crate) struct TaskIdGuard {
179     parent_task_id: Option<Id>,
180 }
181 
182 impl TaskIdGuard {
enter(id: Id) -> Self183     fn enter(id: Id) -> Self {
184         TaskIdGuard {
185             parent_task_id: context::set_current_task_id(Some(id)),
186         }
187     }
188 }
189 
190 impl Drop for TaskIdGuard {
drop(&mut self)191     fn drop(&mut self) {
192         context::set_current_task_id(self.parent_task_id);
193     }
194 }
195 
196 impl<T: Future, S: Schedule> Core<T, S> {
197     /// Polls the future.
198     ///
199     /// # Safety
200     ///
201     /// The caller must ensure it is safe to mutate the `state` field. This
202     /// requires ensuring mutual exclusion between any concurrent thread that
203     /// might modify the future or output field.
204     ///
205     /// The mutual exclusion is implemented by `Harness` and the `Lifecycle`
206     /// component of the task state.
207     ///
208     /// `self` must also be pinned. This is handled by storing the task on the
209     /// heap.
poll(&self, mut cx: Context<'_>) -> Poll<T::Output>210     pub(super) fn poll(&self, mut cx: Context<'_>) -> Poll<T::Output> {
211         let res = {
212             self.stage.stage.with_mut(|ptr| {
213                 // Safety: The caller ensures mutual exclusion to the field.
214                 let future = match unsafe { &mut *ptr } {
215                     Stage::Running(future) => future,
216                     _ => unreachable!("unexpected stage"),
217                 };
218 
219                 // Safety: The caller ensures the future is pinned.
220                 let future = unsafe { Pin::new_unchecked(future) };
221 
222                 let _guard = TaskIdGuard::enter(self.task_id);
223                 future.poll(&mut cx)
224             })
225         };
226 
227         if res.is_ready() {
228             self.drop_future_or_output();
229         }
230 
231         res
232     }
233 
234     /// Drops the future.
235     ///
236     /// # Safety
237     ///
238     /// The caller must ensure it is safe to mutate the `stage` field.
drop_future_or_output(&self)239     pub(super) fn drop_future_or_output(&self) {
240         // Safety: the caller ensures mutual exclusion to the field.
241         unsafe {
242             self.set_stage(Stage::Consumed);
243         }
244     }
245 
246     /// Stores the task output.
247     ///
248     /// # Safety
249     ///
250     /// The caller must ensure it is safe to mutate the `stage` field.
store_output(&self, output: super::Result<T::Output>)251     pub(super) fn store_output(&self, output: super::Result<T::Output>) {
252         // Safety: the caller ensures mutual exclusion to the field.
253         unsafe {
254             self.set_stage(Stage::Finished(output));
255         }
256     }
257 
258     /// Takes the task output.
259     ///
260     /// # Safety
261     ///
262     /// The caller must ensure it is safe to mutate the `stage` field.
take_output(&self) -> super::Result<T::Output>263     pub(super) fn take_output(&self) -> super::Result<T::Output> {
264         use std::mem;
265 
266         self.stage.stage.with_mut(|ptr| {
267             // Safety:: the caller ensures mutual exclusion to the field.
268             match mem::replace(unsafe { &mut *ptr }, Stage::Consumed) {
269                 Stage::Finished(output) => output,
270                 _ => panic!("JoinHandle polled after completion"),
271             }
272         })
273     }
274 
set_stage(&self, stage: Stage<T>)275     unsafe fn set_stage(&self, stage: Stage<T>) {
276         let _guard = TaskIdGuard::enter(self.task_id);
277         self.stage.stage.with_mut(|ptr| *ptr = stage)
278     }
279 }
280 
281 cfg_rt_multi_thread! {
282     impl Header {
283         pub(super) unsafe fn set_next(&self, next: Option<NonNull<Header>>) {
284             self.queue_next.with_mut(|ptr| *ptr = next);
285         }
286     }
287 }
288 
289 impl Header {
290     // safety: The caller must guarantee exclusive access to this field, and
291     // must ensure that the id is either 0 or the id of the OwnedTasks
292     // containing this task.
set_owner_id(&self, owner: u64)293     pub(super) unsafe fn set_owner_id(&self, owner: u64) {
294         self.owner_id.with_mut(|ptr| *ptr = owner);
295     }
296 
get_owner_id(&self) -> u64297     pub(super) fn get_owner_id(&self) -> u64 {
298         // safety: If there are concurrent writes, then that write has violated
299         // the safety requirements on `set_owner_id`.
300         unsafe { self.owner_id.with(|ptr| *ptr) }
301     }
302 
303     /// Gets a pointer to the `Trailer` of the task containing this `Header`.
304     ///
305     /// # Safety
306     ///
307     /// The provided raw pointer must point at the header of a task.
get_trailer(me: NonNull<Header>) -> NonNull<Trailer>308     pub(super) unsafe fn get_trailer(me: NonNull<Header>) -> NonNull<Trailer> {
309         let offset = me.as_ref().vtable.trailer_offset;
310         let trailer = me.as_ptr().cast::<u8>().add(offset).cast::<Trailer>();
311         NonNull::new_unchecked(trailer)
312     }
313 
314     /// Gets a pointer to the scheduler of the task containing this `Header`.
315     ///
316     /// # Safety
317     ///
318     /// The provided raw pointer must point at the header of a task.
319     ///
320     /// The generic type S must be set to the correct scheduler type for this
321     /// task.
get_scheduler<S>(me: NonNull<Header>) -> NonNull<S>322     pub(super) unsafe fn get_scheduler<S>(me: NonNull<Header>) -> NonNull<S> {
323         let offset = me.as_ref().vtable.scheduler_offset;
324         let scheduler = me.as_ptr().cast::<u8>().add(offset).cast::<S>();
325         NonNull::new_unchecked(scheduler)
326     }
327 
328     /// Gets a pointer to the id of the task containing this `Header`.
329     ///
330     /// # Safety
331     ///
332     /// The provided raw pointer must point at the header of a task.
get_id_ptr(me: NonNull<Header>) -> NonNull<Id>333     pub(super) unsafe fn get_id_ptr(me: NonNull<Header>) -> NonNull<Id> {
334         let offset = me.as_ref().vtable.id_offset;
335         let id = me.as_ptr().cast::<u8>().add(offset).cast::<Id>();
336         NonNull::new_unchecked(id)
337     }
338 
339     /// Gets the id of the task containing this `Header`.
340     ///
341     /// # Safety
342     ///
343     /// The provided raw pointer must point at the header of a task.
get_id(me: NonNull<Header>) -> Id344     pub(super) unsafe fn get_id(me: NonNull<Header>) -> Id {
345         let ptr = Header::get_id_ptr(me).as_ptr();
346         *ptr
347     }
348 
349     /// Gets the tracing id of the task containing this `Header`.
350     ///
351     /// # Safety
352     ///
353     /// The provided raw pointer must point at the header of a task.
354     #[cfg(all(tokio_unstable, feature = "tracing"))]
get_tracing_id(me: &NonNull<Header>) -> Option<&tracing::Id>355     pub(super) unsafe fn get_tracing_id(me: &NonNull<Header>) -> Option<&tracing::Id> {
356         me.as_ref().tracing_id.as_ref()
357     }
358 }
359 
360 impl Trailer {
set_waker(&self, waker: Option<Waker>)361     pub(super) unsafe fn set_waker(&self, waker: Option<Waker>) {
362         self.waker.with_mut(|ptr| {
363             *ptr = waker;
364         });
365     }
366 
will_wake(&self, waker: &Waker) -> bool367     pub(super) unsafe fn will_wake(&self, waker: &Waker) -> bool {
368         self.waker
369             .with(|ptr| (*ptr).as_ref().unwrap().will_wake(waker))
370     }
371 
wake_join(&self)372     pub(super) fn wake_join(&self) {
373         self.waker.with(|ptr| match unsafe { &*ptr } {
374             Some(waker) => waker.wake_by_ref(),
375             None => panic!("waker missing"),
376         });
377     }
378 }
379 
/// Guards against `Header` growing past eight pointer-sized words. That is 64
/// bytes on 64-bit targets, which the test name treats as one cache line.
#[test]
#[cfg(not(loom))]
fn header_lte_cache_line() {
    use std::mem::size_of;

    assert!(size_of::<Header>() <= 8 * size_of::<*const ()>());
}
387