1 use crate::future::Future;
2 use crate::runtime::task::core::{Core, Trailer};
3 use crate::runtime::task::{Cell, Harness, Header, Id, Schedule, State};
4
5 use std::ptr::NonNull;
6 use std::task::{Poll, Waker};
7
/// Raw, type-erased handle to a heap-allocated task cell.
///
/// The handle is just a pointer to the cell's `Header`; all type-specific
/// operations are dispatched through the vtable stored in that header.
/// Copying the handle does not touch the task's reference count.
pub(crate) struct RawTask {
    // Pointer to the `Header` at the start of the task cell.
    ptr: NonNull<Header>,
}
12
13 pub(super) struct Vtable {
14 /// Polls the future.
15 pub(super) poll: unsafe fn(NonNull<Header>),
16
17 /// Schedules the task for execution on the runtime.
18 pub(super) schedule: unsafe fn(NonNull<Header>),
19
20 /// Deallocates the memory.
21 pub(super) dealloc: unsafe fn(NonNull<Header>),
22
23 /// Reads the task output, if complete.
24 pub(super) try_read_output: unsafe fn(NonNull<Header>, *mut (), &Waker),
25
26 /// The join handle has been dropped.
27 pub(super) drop_join_handle_slow: unsafe fn(NonNull<Header>),
28
29 /// An abort handle has been dropped.
30 pub(super) drop_abort_handle: unsafe fn(NonNull<Header>),
31
32 /// Scheduler is being shutdown.
33 pub(super) shutdown: unsafe fn(NonNull<Header>),
34
35 /// The number of bytes that the `trailer` field is offset from the header.
36 pub(super) trailer_offset: usize,
37
38 /// The number of bytes that the `scheduler` field is offset from the header.
39 pub(super) scheduler_offset: usize,
40
41 /// The number of bytes that the `id` field is offset from the header.
42 pub(super) id_offset: usize,
43 }
44
45 /// Get the vtable for the requested `T` and `S` generics.
vtable<T: Future, S: Schedule>() -> &'static Vtable46 pub(super) fn vtable<T: Future, S: Schedule>() -> &'static Vtable {
47 &Vtable {
48 poll: poll::<T, S>,
49 schedule: schedule::<S>,
50 dealloc: dealloc::<T, S>,
51 try_read_output: try_read_output::<T, S>,
52 drop_join_handle_slow: drop_join_handle_slow::<T, S>,
53 drop_abort_handle: drop_abort_handle::<T, S>,
54 shutdown: shutdown::<T, S>,
55 trailer_offset: OffsetHelper::<T, S>::TRAILER_OFFSET,
56 scheduler_offset: OffsetHelper::<T, S>::SCHEDULER_OFFSET,
57 id_offset: OffsetHelper::<T, S>::ID_OFFSET,
58 }
59 }
60
61 /// Calling `get_trailer_offset` directly in vtable doesn't work because it
62 /// prevents the vtable from being promoted to a static reference.
63 ///
64 /// See this thread for more info:
65 /// <https://users.rust-lang.org/t/custom-vtables-with-integers/78508>
66 struct OffsetHelper<T, S>(T, S);
67 impl<T: Future, S: Schedule> OffsetHelper<T, S> {
68 // Pass `size_of`/`align_of` as arguments rather than calling them directly
69 // inside `get_trailer_offset` because trait bounds on generic parameters
70 // of const fn are unstable on our MSRV.
71 const TRAILER_OFFSET: usize = get_trailer_offset(
72 std::mem::size_of::<Header>(),
73 std::mem::size_of::<Core<T, S>>(),
74 std::mem::align_of::<Core<T, S>>(),
75 std::mem::align_of::<Trailer>(),
76 );
77
78 // The `scheduler` is the first field of `Core`, so it has the same
79 // offset as `Core`.
80 const SCHEDULER_OFFSET: usize = get_core_offset(
81 std::mem::size_of::<Header>(),
82 std::mem::align_of::<Core<T, S>>(),
83 );
84
85 const ID_OFFSET: usize = get_id_offset(
86 std::mem::size_of::<Header>(),
87 std::mem::align_of::<Core<T, S>>(),
88 std::mem::size_of::<S>(),
89 std::mem::align_of::<Id>(),
90 );
91 }
92
/// Computes the byte offset of the `Trailer` field in `Cell<T, S>` using the
/// `#[repr(C)]` layout algorithm: `Header` at offset 0, then `Core` padded up
/// to its alignment, then `Trailer` padded up to its alignment.
///
/// Pseudo-code for the `#[repr(C)]` algorithm can be found here:
/// <https://doc.rust-lang.org/reference/type-layout.html#reprc-structs>
const fn get_trailer_offset(
    header_size: usize,
    core_size: usize,
    core_align: usize,
    trailer_align: usize,
) -> usize {
    // `Core` starts at the first `core_align` boundary at or after the header.
    let core_start = match header_size % core_align {
        0 => header_size,
        rem => header_size + (core_align - rem),
    };

    // `Trailer` starts at the first `trailer_align` boundary after `Core`.
    let core_end = core_start + core_size;
    match core_end % trailer_align {
        0 => core_end,
        rem => core_end + (trailer_align - rem),
    }
}
119
/// Computes the byte offset of the `Core<T, S>` field in `Cell<T, S>` using
/// the `#[repr(C)]` layout algorithm: the first `core_align` boundary at or
/// after the end of the `Header`.
///
/// Pseudo-code for the `#[repr(C)]` algorithm can be found here:
/// <https://doc.rust-lang.org/reference/type-layout.html#reprc-structs>
const fn get_core_offset(header_size: usize, core_align: usize) -> usize {
    match header_size % core_align {
        // The header already ends on a `core_align` boundary: no padding.
        0 => header_size,
        // Otherwise pad up to the next boundary.
        rem => header_size + (core_align - rem),
    }
}
135
/// Computes the byte offset of the `Id` field in `Cell<T, S>` using the
/// `#[repr(C)]` layout algorithm. First, `Core` is placed after the `Header`
/// (padded to `core_align`). Next, the scheduler (the first field of `Core`,
/// `scheduler_size` bytes) is skipped. Finally, the result is padded up to
/// `id_align`.
///
/// Pseudo-code for the `#[repr(C)]` algorithm can be found here:
/// <https://doc.rust-lang.org/reference/type-layout.html#reprc-structs>
const fn get_id_offset(
    header_size: usize,
    core_align: usize,
    scheduler_size: usize,
    id_align: usize,
) -> usize {
    // Start of `Core`, which is also where the scheduler lives.
    let core_start = match header_size % core_align {
        0 => header_size,
        rem => header_size + (core_align - rem),
    };

    // The `Id` follows the scheduler, padded to its own alignment.
    let after_scheduler = core_start + scheduler_size;
    match after_scheduler % id_align {
        0 => after_scheduler,
        rem => after_scheduler + (id_align - rem),
    }
}
157
158 impl RawTask {
new<T, S>(task: T, scheduler: S, id: Id) -> RawTask where T: Future, S: Schedule,159 pub(super) fn new<T, S>(task: T, scheduler: S, id: Id) -> RawTask
160 where
161 T: Future,
162 S: Schedule,
163 {
164 let ptr = Box::into_raw(Cell::<_, S>::new(task, scheduler, State::new(), id));
165 let ptr = unsafe { NonNull::new_unchecked(ptr as *mut Header) };
166
167 RawTask { ptr }
168 }
169
from_raw(ptr: NonNull<Header>) -> RawTask170 pub(super) unsafe fn from_raw(ptr: NonNull<Header>) -> RawTask {
171 RawTask { ptr }
172 }
173
header_ptr(&self) -> NonNull<Header>174 pub(super) fn header_ptr(&self) -> NonNull<Header> {
175 self.ptr
176 }
177
trailer_ptr(&self) -> NonNull<Trailer>178 pub(super) fn trailer_ptr(&self) -> NonNull<Trailer> {
179 unsafe { Header::get_trailer(self.ptr) }
180 }
181
182 /// Returns a reference to the task's header.
header(&self) -> &Header183 pub(super) fn header(&self) -> &Header {
184 unsafe { self.ptr.as_ref() }
185 }
186
187 /// Returns a reference to the task's trailer.
trailer(&self) -> &Trailer188 pub(super) fn trailer(&self) -> &Trailer {
189 unsafe { &*self.trailer_ptr().as_ptr() }
190 }
191
192 /// Returns a reference to the task's state.
state(&self) -> &State193 pub(super) fn state(&self) -> &State {
194 &self.header().state
195 }
196
197 /// Safety: mutual exclusion is required to call this function.
poll(self)198 pub(crate) fn poll(self) {
199 let vtable = self.header().vtable;
200 unsafe { (vtable.poll)(self.ptr) }
201 }
202
schedule(self)203 pub(super) fn schedule(self) {
204 let vtable = self.header().vtable;
205 unsafe { (vtable.schedule)(self.ptr) }
206 }
207
dealloc(self)208 pub(super) fn dealloc(self) {
209 let vtable = self.header().vtable;
210 unsafe {
211 (vtable.dealloc)(self.ptr);
212 }
213 }
214
215 /// Safety: `dst` must be a `*mut Poll<super::Result<T::Output>>` where `T`
216 /// is the future stored by the task.
try_read_output(self, dst: *mut (), waker: &Waker)217 pub(super) unsafe fn try_read_output(self, dst: *mut (), waker: &Waker) {
218 let vtable = self.header().vtable;
219 (vtable.try_read_output)(self.ptr, dst, waker);
220 }
221
drop_join_handle_slow(self)222 pub(super) fn drop_join_handle_slow(self) {
223 let vtable = self.header().vtable;
224 unsafe { (vtable.drop_join_handle_slow)(self.ptr) }
225 }
226
drop_abort_handle(self)227 pub(super) fn drop_abort_handle(self) {
228 let vtable = self.header().vtable;
229 unsafe { (vtable.drop_abort_handle)(self.ptr) }
230 }
231
shutdown(self)232 pub(super) fn shutdown(self) {
233 let vtable = self.header().vtable;
234 unsafe { (vtable.shutdown)(self.ptr) }
235 }
236
237 /// Increment the task's reference count.
238 ///
239 /// Currently, this is used only when creating an `AbortHandle`.
ref_inc(self)240 pub(super) fn ref_inc(self) {
241 self.header().state.ref_inc();
242 }
243
244 /// Get the queue-next pointer
245 ///
246 /// This is for usage by the injection queue
247 ///
248 /// Safety: make sure only one queue uses this and access is synchronized.
get_queue_next(self) -> Option<RawTask>249 pub(crate) unsafe fn get_queue_next(self) -> Option<RawTask> {
250 self.header()
251 .queue_next
252 .with(|ptr| *ptr)
253 .map(|p| RawTask::from_raw(p))
254 }
255
256 /// Sets the queue-next pointer
257 ///
258 /// This is for usage by the injection queue
259 ///
260 /// Safety: make sure only one queue uses this and access is synchronized.
set_queue_next(self, val: Option<RawTask>)261 pub(crate) unsafe fn set_queue_next(self, val: Option<RawTask>) {
262 self.header().set_next(val.map(|task| task.ptr));
263 }
264 }
265
266 impl Clone for RawTask {
clone(&self) -> Self267 fn clone(&self) -> Self {
268 RawTask { ptr: self.ptr }
269 }
270 }
271
272 impl Copy for RawTask {}
273
poll<T: Future, S: Schedule>(ptr: NonNull<Header>)274 unsafe fn poll<T: Future, S: Schedule>(ptr: NonNull<Header>) {
275 let harness = Harness::<T, S>::from_raw(ptr);
276 harness.poll();
277 }
278
schedule<S: Schedule>(ptr: NonNull<Header>)279 unsafe fn schedule<S: Schedule>(ptr: NonNull<Header>) {
280 use crate::runtime::task::{Notified, Task};
281
282 let scheduler = Header::get_scheduler::<S>(ptr);
283 scheduler
284 .as_ref()
285 .schedule(Notified(Task::from_raw(ptr.cast())));
286 }
287
dealloc<T: Future, S: Schedule>(ptr: NonNull<Header>)288 unsafe fn dealloc<T: Future, S: Schedule>(ptr: NonNull<Header>) {
289 let harness = Harness::<T, S>::from_raw(ptr);
290 harness.dealloc();
291 }
292
try_read_output<T: Future, S: Schedule>( ptr: NonNull<Header>, dst: *mut (), waker: &Waker, )293 unsafe fn try_read_output<T: Future, S: Schedule>(
294 ptr: NonNull<Header>,
295 dst: *mut (),
296 waker: &Waker,
297 ) {
298 let out = &mut *(dst as *mut Poll<super::Result<T::Output>>);
299
300 let harness = Harness::<T, S>::from_raw(ptr);
301 harness.try_read_output(out, waker);
302 }
303
drop_join_handle_slow<T: Future, S: Schedule>(ptr: NonNull<Header>)304 unsafe fn drop_join_handle_slow<T: Future, S: Schedule>(ptr: NonNull<Header>) {
305 let harness = Harness::<T, S>::from_raw(ptr);
306 harness.drop_join_handle_slow()
307 }
308
drop_abort_handle<T: Future, S: Schedule>(ptr: NonNull<Header>)309 unsafe fn drop_abort_handle<T: Future, S: Schedule>(ptr: NonNull<Header>) {
310 let harness = Harness::<T, S>::from_raw(ptr);
311 harness.drop_reference();
312 }
313
shutdown<T: Future, S: Schedule>(ptr: NonNull<Header>)314 unsafe fn shutdown<T: Future, S: Schedule>(ptr: NonNull<Header>) {
315 let harness = Harness::<T, S>::from_raw(ptr);
316 harness.shutdown()
317 }
318