//! Core task module.
//!
//! # Safety
//!
//! The functions in this module are private to the `task` module. All of them
//! should be considered `unsafe` to use, but are not marked as such since it
//! would be too noisy.
//!
//! Make sure to consult the relevant safety section of each function before
//! use.

use crate::future::Future;
use crate::loom::cell::UnsafeCell;
use crate::runtime::context;
use crate::runtime::task::raw::{self, Vtable};
use crate::runtime::task::state::State;
use crate::runtime::task::{Id, Schedule};
use crate::util::linked_list;

use std::pin::Pin;
use std::ptr::NonNull;
use std::task::{Context, Poll, Waker};

/// The task cell. Contains the components of the task.
///
/// It is critical for `Header` to be the first field as the task structure will
/// be referenced by both `*mut Cell` and `*mut Header`.
///
/// Any changes to the layout of this struct _must_ also be reflected in the
/// const fns in `raw.rs`.
#[repr(C)]
pub(super) struct Cell<T: Future, S> {
    /// Hot task state data.
    pub(super) header: Header,

    /// Either the future or output, depending on the execution stage.
    pub(super) core: Core<T, S>,

    /// Cold data.
    pub(super) trailer: Trailer,
}

/// Wrapper around the task's stage, which holds either the future or its
/// output inside an `UnsafeCell`.
pub(super) struct CoreStage<T: Future> {
    stage: UnsafeCell<Stage<T>>,
}

/// The core of the task.
///
/// Holds the future or output, depending on the stage of execution.
///
/// Any changes to the layout of this struct _must_ also be reflected in the
/// const fns in `raw.rs`.
#[repr(C)]
pub(super) struct Core<T: Future, S> {
    /// Scheduler used to drive this future.
    pub(super) scheduler: S,

    /// The task's ID, used for populating `JoinError`s.
    pub(super) task_id: Id,

    /// Either the future or the output.
    pub(super) stage: CoreStage<T>,
}

/// Crate public as this is also needed by the pool.
#[repr(C)]
pub(crate) struct Header {
    /// Task state.
    pub(super) state: State,

    /// Pointer to next task, used with the injection queue.
    pub(super) queue_next: UnsafeCell<Option<NonNull<Header>>>,

    /// Table of function pointers for executing actions on the task.
    pub(super) vtable: &'static Vtable,

    /// This integer contains the id of the `OwnedTasks` or `LocalOwnedTasks`
    /// that this task is stored in. If the task is not in any list, it is the
    /// id of the list that it was previously in, or zero if it has never been
    /// in any list.
    ///
    /// Once a task has been bound to a list, it can never be bound to another
    /// list, even if removed from the first list.
    ///
    /// The id is not unset when removed from a list because we want to be able
    /// to read the id without synchronization, even if the task is concurrently
    /// being removed from the list.
    pub(super) owner_id: UnsafeCell<u64>,

    /// The tracing ID for this instrumented task.
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    pub(super) tracing_id: Option<tracing::Id>,
}
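
// A minimal sketch of the `owner_id` protocol described above. The id value `1`
// is illustrative only; real ids are assigned by the runtime's owned-task lists.
//
//     unsafe { header.set_owner_id(1) };    // bound when the task is inserted into a list
//     assert_eq!(header.get_owner_id(), 1);
//     // ... the task is later removed from the list ...
//     assert_eq!(header.get_owner_id(), 1); // the id is intentionally left set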

unsafe impl Send for Header {}
unsafe impl Sync for Header {}

/// Cold data is stored after the future. Data is considered cold if it is only
/// used during creation or shutdown of the task.
pub(super) struct Trailer {
    /// Pointers for the linked list in the `OwnedTasks` that owns this task.
    pub(super) owned: linked_list::Pointers<Header>,
    /// Consumer task waiting on completion of this task.
    pub(super) waker: UnsafeCell<Option<Waker>>,
}

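// `generate_addr_of_methods!` is expected to expand to a raw-pointer field
// projection: computing the address of `owned` from a `NonNull<Trailer>`
// without materializing an intermediate `&Trailer` reference (see the macro
// definition for the exact expansion).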
generate_addr_of_methods! {
    impl<> Trailer {
        pub(super) unsafe fn addr_of_owned(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Header>> {
            &self.owned
        }
    }
}

/// Either the future or the output.
pub(super) enum Stage<T: Future> {
    Running(T),
    Finished(super::Result<T::Output>),
    Consumed,
}

impl<T: Future, S: Schedule> Cell<T, S> {
    /// Allocates a new task cell, containing the header, trailer, and core
    /// structures.
    pub(super) fn new(future: T, scheduler: S, state: State, task_id: Id) -> Box<Cell<T, S>> {
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let tracing_id = future.id();
        let result = Box::new(Cell {
            header: Header {
                state,
                queue_next: UnsafeCell::new(None),
                vtable: raw::vtable::<T, S>(),
                owner_id: UnsafeCell::new(0),
                #[cfg(all(tokio_unstable, feature = "tracing"))]
                tracing_id,
            },
            core: Core {
                scheduler,
                stage: CoreStage {
                    stage: UnsafeCell::new(Stage::Running(future)),
                },
                task_id,
            },
            trailer: Trailer {
                waker: UnsafeCell::new(None),
                owned: linked_list::Pointers::new(),
            },
        });

        // In debug builds, verify that the field addresses computed through the
        // vtable offsets match the actual locations of the fields.
        #[cfg(debug_assertions)]
        {
            let trailer_addr = (&result.trailer) as *const Trailer as usize;
            let trailer_ptr = unsafe { Header::get_trailer(NonNull::from(&result.header)) };
            assert_eq!(trailer_addr, trailer_ptr.as_ptr() as usize);

            let scheduler_addr = (&result.core.scheduler) as *const S as usize;
            let scheduler_ptr =
                unsafe { Header::get_scheduler::<S>(NonNull::from(&result.header)) };
            assert_eq!(scheduler_addr, scheduler_ptr.as_ptr() as usize);

            let id_addr = (&result.core.task_id) as *const Id as usize;
            let id_ptr = unsafe { Header::get_id_ptr(NonNull::from(&result.header)) };
            assert_eq!(id_addr, id_ptr.as_ptr() as usize);
        }

        result
    }
}

impl<T: Future> CoreStage<T> {
    /// Provides mutable access to the stage through a raw pointer.
    pub(super) fn with_mut<R>(&self, f: impl FnOnce(*mut Stage<T>) -> R) -> R {
        self.stage.with_mut(f)
    }
}

/// Set and clear the task id in the context when the future is executed or
/// dropped, or when the output produced by the future is dropped.
pub(crate) struct TaskIdGuard {
    parent_task_id: Option<Id>,
}

impl TaskIdGuard {
    fn enter(id: Id) -> Self {
        TaskIdGuard {
            parent_task_id: context::set_current_task_id(Some(id)),
        }
    }
}

impl Drop for TaskIdGuard {
    fn drop(&mut self) {
        context::set_current_task_id(self.parent_task_id);
    }
}
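
// A minimal usage sketch of the guard (mirroring how `Core::poll` and
// `Core::set_stage` below use it); the previous task id is restored when the
// guard goes out of scope:
//
//     let _guard = TaskIdGuard::enter(task_id);
//     // ... run code that may observe the current task id via the context ...
//     // `_guard` dropped here; the parent task id is restored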

impl<T: Future, S: Schedule> Core<T, S> {
    /// Polls the future.
    ///
    /// # Safety
    ///
    /// The caller must ensure it is safe to mutate the `state` field. This
    /// requires ensuring mutual exclusion between any concurrent thread that
    /// might modify the future or output field.
    ///
    /// The mutual exclusion is implemented by `Harness` and the `Lifecycle`
    /// component of the task state.
    ///
    /// `self` must also be pinned. This is handled by storing the task on the
    /// heap.
    pub(super) fn poll(&self, mut cx: Context<'_>) -> Poll<T::Output> {
        let res = {
            self.stage.stage.with_mut(|ptr| {
                // Safety: The caller ensures mutual exclusion to the field.
                let future = match unsafe { &mut *ptr } {
                    Stage::Running(future) => future,
                    _ => unreachable!("unexpected stage"),
                };

                // Safety: The caller ensures the future is pinned.
                let future = unsafe { Pin::new_unchecked(future) };

                let _guard = TaskIdGuard::enter(self.task_id);
                future.poll(&mut cx)
            })
        };

        if res.is_ready() {
            self.drop_future_or_output();
        }

        res
    }

    /// Drops the future.
    ///
    /// # Safety
    ///
    /// The caller must ensure it is safe to mutate the `stage` field.
    pub(super) fn drop_future_or_output(&self) {
        // Safety: the caller ensures mutual exclusion to the field.
        unsafe {
            self.set_stage(Stage::Consumed);
        }
    }

    /// Stores the task output.
    ///
    /// # Safety
    ///
    /// The caller must ensure it is safe to mutate the `stage` field.
    pub(super) fn store_output(&self, output: super::Result<T::Output>) {
        // Safety: the caller ensures mutual exclusion to the field.
        unsafe {
            self.set_stage(Stage::Finished(output));
        }
    }

    /// Takes the task output.
    ///
    /// # Safety
    ///
    /// The caller must ensure it is safe to mutate the `stage` field.
    pub(super) fn take_output(&self) -> super::Result<T::Output> {
        use std::mem;

        self.stage.stage.with_mut(|ptr| {
            // Safety: the caller ensures mutual exclusion to the field.
            match mem::replace(unsafe { &mut *ptr }, Stage::Consumed) {
                Stage::Finished(output) => output,
                _ => panic!("JoinHandle polled after completion"),
            }
        })
    }

    unsafe fn set_stage(&self, stage: Stage<T>) {
        let _guard = TaskIdGuard::enter(self.task_id);
        self.stage.stage.with_mut(|ptr| *ptr = stage)
    }
}
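
// A minimal sketch of the expected call order on `Core`, assuming the mutual
// exclusion described in the safety sections above (the actual sequencing is
// driven by the harness and the task `State`, which are not shown here):
//
//     let poll = core.poll(cx);          // Stage::Running -> Stage::Consumed once ready
//     core.store_output(Ok(output));     // Stage::Consumed -> Stage::Finished
//     let result = core.take_output();   // Stage::Finished -> Stage::Consumed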

cfg_rt_multi_thread! {
    impl Header {
        /// Sets the next pointer used by the injection queue.
        pub(super) unsafe fn set_next(&self, next: Option<NonNull<Header>>) {
            self.queue_next.with_mut(|ptr| *ptr = next);
        }
    }
}

impl Header {
    // safety: The caller must guarantee exclusive access to this field, and
    // must ensure that the id is either 0 or the id of the OwnedTasks
    // containing this task.
    pub(super) unsafe fn set_owner_id(&self, owner: u64) {
        self.owner_id.with_mut(|ptr| *ptr = owner);
    }

    pub(super) fn get_owner_id(&self) -> u64 {
        // safety: If there are concurrent writes, then that write has violated
        // the safety requirements on `set_owner_id`.
        unsafe { self.owner_id.with(|ptr| *ptr) }
    }

    /// Gets a pointer to the `Trailer` of the task containing this `Header`.
    ///
    /// # Safety
    ///
    /// The provided raw pointer must point at the header of a task.
    pub(super) unsafe fn get_trailer(me: NonNull<Header>) -> NonNull<Trailer> {
        let offset = me.as_ref().vtable.trailer_offset;
        let trailer = me.as_ptr().cast::<u8>().add(offset).cast::<Trailer>();
        NonNull::new_unchecked(trailer)
    }

    /// Gets a pointer to the scheduler of the task containing this `Header`.
    ///
    /// # Safety
    ///
    /// The provided raw pointer must point at the header of a task.
    ///
    /// The generic type S must be set to the correct scheduler type for this
    /// task.
    pub(super) unsafe fn get_scheduler<S>(me: NonNull<Header>) -> NonNull<S> {
        let offset = me.as_ref().vtable.scheduler_offset;
        let scheduler = me.as_ptr().cast::<u8>().add(offset).cast::<S>();
        NonNull::new_unchecked(scheduler)
    }

    /// Gets a pointer to the id of the task containing this `Header`.
    ///
    /// # Safety
    ///
    /// The provided raw pointer must point at the header of a task.
    pub(super) unsafe fn get_id_ptr(me: NonNull<Header>) -> NonNull<Id> {
        let offset = me.as_ref().vtable.id_offset;
        let id = me.as_ptr().cast::<u8>().add(offset).cast::<Id>();
        NonNull::new_unchecked(id)
    }

    /// Gets the id of the task containing this `Header`.
    ///
    /// # Safety
    ///
    /// The provided raw pointer must point at the header of a task.
    pub(super) unsafe fn get_id(me: NonNull<Header>) -> Id {
        let ptr = Header::get_id_ptr(me).as_ptr();
        *ptr
    }

    /// Gets the tracing id of the task containing this `Header`.
    ///
    /// # Safety
    ///
    /// The provided raw pointer must point at the header of a task.
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    pub(super) unsafe fn get_tracing_id(me: &NonNull<Header>) -> Option<&tracing::Id> {
        me.as_ref().tracing_id.as_ref()
    }
}

impl Trailer {
    /// Sets the consumer task's waker.
    ///
    /// # Safety
    ///
    /// The caller must ensure it is safe to mutate the `waker` field.
    pub(super) unsafe fn set_waker(&self, waker: Option<Waker>) {
        self.waker.with_mut(|ptr| {
            *ptr = waker;
        });
    }

    /// Returns `true` if the stored waker would wake the same task as `waker`.
    ///
    /// # Safety
    ///
    /// The caller must ensure that a waker is stored and that it is safe to
    /// read the `waker` field.
    pub(super) unsafe fn will_wake(&self, waker: &Waker) -> bool {
        self.waker
            .with(|ptr| (*ptr).as_ref().unwrap().will_wake(waker))
    }

    /// Wakes the consumer task waiting on completion of this task.
    ///
    /// Panics if no waker is stored.
    pub(super) fn wake_join(&self) {
        self.waker.with(|ptr| match unsafe { &*ptr } {
            Some(waker) => waker.wake_by_ref(),
            None => panic!("waker missing"),
        });
    }
}

// `Header` holds the hot task state, so it should stay within (at most) eight
// pointer-sized words, i.e. a single 64-byte cache line on 64-bit targets.
#[test]
#[cfg(not(loom))]
fn header_lte_cache_line() {
    use std::mem::size_of;

    assert!(size_of::<Header>() <= 8 * size_of::<*const ()>());
}