1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 use std::cell::UnsafeCell;
15 use std::future::Future;
16 use std::mem;
17 use std::pin::Pin;
18 use std::ptr::{addr_of_mut, NonNull};
19 use std::sync::Weak;
20 use std::task::{Context, Poll, Waker};
21
22 use crate::error::ScheduleError;
23 use crate::executor::Schedule;
24 use crate::task::state::TaskState;
25 use crate::task::task_handle::TaskHandle;
26 use crate::task::{TaskBuilder, VirtualTableType};
27 use crate::util::linked_list::{Link, Node};
28
29 cfg_ffrt! {
30 use crate::ffrt::ffrt_task::FfrtTaskCtx;
31 }
32
/// Type-erased operations on a task.
///
/// Every [`Header`] carries a `&'static TaskVirtualTable`. The entries are
/// monomorphized for the task's concrete future/scheduler types (see
/// `create_vtable` / `create_ffrt_vtable`), which lets callers such as
/// [`RawTask`] drive a task through a bare `NonNull<Header>`.
pub(crate) struct TaskVirtualTable {
    /// Task running method. The non-ffrt implementation always returns `true`;
    /// the ffrt implementation forwards the result of `TaskHandle::ffrt_run`.
    pub(crate) run: unsafe fn(NonNull<Header>) -> bool,
    /// Task scheduling method. The `bool` selects the handle's `wake` path
    /// (`true`) or its `wake_by_ref` path (`false`), or the `ffrt_` variants.
    pub(crate) schedule: unsafe fn(NonNull<Header>, bool),
    /// Task result-getting method. The `*mut ()` is a type-erased pointer to the
    /// `Poll<Result<T::Output, ScheduleError>>` slot that receives the output.
    pub(crate) get_result: unsafe fn(NonNull<Header>, *mut ()),
    /// JoinHandle drop method
    pub(crate) drop_join_handle: unsafe fn(NonNull<Header>),
    /// Task reference drop method
    pub(crate) drop_ref: unsafe fn(NonNull<Header>),
    /// Task waker setting method. `waker` is a type-erased `*const Waker`;
    /// `cur_state` is forwarded unchanged to `TaskHandle::set_waker`.
    pub(crate) set_waker: unsafe fn(NonNull<Header>, cur_state: usize, waker: *const ()) -> bool,
    /// Task release method (the non-ffrt implementation calls
    /// `TaskHandle::shutdown`). Compiled out when `ffrt` is enabled.
    #[cfg(not(feature = "ffrt"))]
    pub(crate) release: unsafe fn(NonNull<Header>),
    /// Task cancel method
    pub(crate) cancel: unsafe fn(NonNull<Header>),
}
52
/// Fields shared by every task, independent of its future and scheduler types.
///
/// [`RawTask`] and the vtable functions operate on a `NonNull<Header>`. The
/// header is the first field of the `repr(C)` `TaskMngInfo`, so a pointer to
/// the whole task can be reinterpreted as a pointer to its header.
#[repr(C)]
pub(crate) struct Header {
    /// Task state (see [`TaskState`]).
    pub(crate) state: TaskState,
    /// Function table used to operate on the task without knowing its concrete types.
    pub(crate) vtable: &'static TaskVirtualTable,
    // Node inside the global queue; its address is handed to the intrusive list
    // through `Header`'s `Link` impl.
    node: Node<Header>,
}
60
61 cfg_not_ffrt! {
62 fn get_default_vtable() -> &'static TaskVirtualTable {
63 unsafe fn default_run(_task: NonNull<Header>) -> bool {
64 false
65 }
66 unsafe fn default_schedule(_task: NonNull<Header>, _fifo: bool) {}
67 unsafe fn default_get_result(_task: NonNull<Header>, _result: *mut ()) {}
68 unsafe fn default_drop_handle(_task: NonNull<Header>) {}
69 unsafe fn default_set_waker(
70 _task: NonNull<Header>,
71 _cur_state: usize,
72 _waker: *const (),
73 ) -> bool {
74 false
75 }
76 unsafe fn default_drop_ref(_task: NonNull<Header>) {}
77 unsafe fn default_release(_task: NonNull<Header>) {}
78 unsafe fn default_cancel(_task: NonNull<Header>) {}
79
80 &TaskVirtualTable {
81 run: default_run,
82 schedule: default_schedule,
83 get_result: default_get_result,
84 drop_join_handle: default_drop_handle,
85 drop_ref: default_drop_ref,
86 set_waker: default_set_waker,
87 release: default_release,
88 cancel: default_cancel,
89 }
90 }
91
92 impl Default for Header {
93 fn default() -> Self {
94 Self {
95 state: TaskState::new(),
96 vtable: get_default_vtable(),
97 node: Default::default(),
98 }
99 }
100 }
101 }
102
103 unsafe impl Link for Header {
node(mut ptr: NonNull<Self>) -> NonNull<Node<Self>> where Self: Sized,104 unsafe fn node(mut ptr: NonNull<Self>) -> NonNull<Node<Self>>
105 where
106 Self: Sized,
107 {
108 let node_ptr = addr_of_mut!(ptr.as_mut().node);
109 NonNull::new_unchecked(node_ptr)
110 }
111 }
112
/// Type-erased, copyable handle to a task, identified by its header pointer.
///
/// Copying a `RawTask` only copies the pointer; the task's lifetime is managed
/// explicitly through the vtable (e.g. `drop_ref`). Equality and hashing
/// compare the header pointer, i.e. task identity.
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub(crate) struct RawTask {
    /// Header of the task this handle refers to.
    pub(crate) ptr: NonNull<Header>,
}
117
118 impl RawTask {
header(&self) -> &Header119 pub(crate) fn header(&self) -> &Header {
120 unsafe { self.ptr.as_ref() }
121 }
122
run(self) -> bool123 pub(crate) fn run(self) -> bool {
124 let vir_table = self.header().vtable;
125 unsafe { (vir_table.run)(self.ptr) }
126 }
127
get_result(self, res: *mut ())128 pub(crate) unsafe fn get_result(self, res: *mut ()) {
129 let vir_table = self.header().vtable;
130 (vir_table.get_result)(self.ptr, res);
131 }
132
cancel(self)133 pub(crate) unsafe fn cancel(self) {
134 let vir_table = self.header().vtable;
135 (vir_table.cancel)(self.ptr)
136 }
137
set_waker(self, cur_state: usize, waker: *const ()) -> bool138 pub(crate) unsafe fn set_waker(self, cur_state: usize, waker: *const ()) -> bool {
139 let vir_table = self.header().vtable;
140 (vir_table.set_waker)(self.ptr, cur_state, waker)
141 }
142
drop_ref(self)143 pub(crate) fn drop_ref(self) {
144 let vir_table = self.header().vtable;
145 unsafe {
146 (vir_table.drop_ref)(self.ptr);
147 }
148 }
149
drop_join_handle(self)150 pub(crate) fn drop_join_handle(self) {
151 let vir_table = self.header().vtable;
152 unsafe {
153 (vir_table.drop_join_handle)(self.ptr);
154 }
155 }
156 }
157
158 #[cfg(not(feature = "ffrt"))]
159 impl RawTask {
160 #[inline]
form_raw(ptr: NonNull<Header>) -> RawTask161 pub(crate) fn form_raw(ptr: NonNull<Header>) -> RawTask {
162 RawTask { ptr }
163 }
164
shutdown(self)165 pub(super) fn shutdown(self) {
166 let vir_table = self.header().vtable;
167 unsafe {
168 (vir_table.release)(self.ptr);
169 }
170 }
171 }
172
/// Lifecycle of a task's future and its result, stored in `Inner::stage`.
pub(crate) enum Stage<T: Future> {
    /// The future has not completed yet and is polled in place.
    Executing(T),
    /// `Inner::poll` returned `Ready`: the future has been dropped and its
    /// output was returned to the poller.
    Executed,
    /// The task's result (output or scheduling error), kept until it is taken
    /// with `Inner::turning_to_get_data`.
    StoreData(Result<T::Output, ScheduleError>),
    /// The stored result has been taken or discarded.
    UsedData,
}
179
/// Task-private state that depends on the future type `T` and scheduler `S`.
///
/// Fields are reached through `UnsafeCell` without a lock. NOTE(review):
/// exclusive access is presumably guaranteed by the task state machine; confirm
/// before adding new accessors.
#[repr(C)]
pub(crate) struct Inner<T: Future, S: Schedule> {
    /// The execution stage of the future
    pub(crate) stage: UnsafeCell<Stage<T>>,
    /// The scheduler of the task queue (held weakly, so the task does not keep
    /// the scheduler alive)
    pub(crate) scheduler: Weak<S>,
    /// Waker of the task waiting on this task; woken by `wake_join`
    pub(crate) waker: UnsafeCell<Option<Waker>>,
    /// Task in adaptive runtime (ffrt task context, captured lazily by
    /// `get_task_ctx`)
    #[cfg(feature = "ffrt")]
    pub(crate) task: UnsafeCell<Option<FfrtTaskCtx>>,
}
192
193 impl<T, S> Inner<T, S>
194 where
195 T: Future,
196 S: Schedule,
197 {
new(task: T, scheduler: Weak<S>) -> Self198 fn new(task: T, scheduler: Weak<S>) -> Self {
199 Inner {
200 stage: UnsafeCell::new(Stage::Executing(task)),
201 scheduler,
202 waker: UnsafeCell::new(None),
203 #[cfg(feature = "ffrt")]
204 task: UnsafeCell::new(None),
205 }
206 }
207
208 #[cfg(feature = "ffrt")]
get_task_ctx(&self)209 pub(crate) fn get_task_ctx(&self) {
210 unsafe {
211 if (*self.task.get()).is_none() {
212 (*self.task.get()).replace(FfrtTaskCtx::get_current());
213 }
214 }
215 }
216
turning_to_executed(&self)217 fn turning_to_executed(&self) {
218 let stage = self.stage.get();
219 unsafe {
220 *stage = Stage::Executed;
221 }
222 }
223
turning_to_used_data(&self)224 pub(crate) fn turning_to_used_data(&self) {
225 let stage = self.stage.get();
226 unsafe {
227 *stage = Stage::UsedData;
228 }
229 }
230
turning_to_store_data(&self, output: std::result::Result<T::Output, ScheduleError>)231 fn turning_to_store_data(&self, output: std::result::Result<T::Output, ScheduleError>) {
232 let stage = self.stage.get();
233 unsafe {
234 *stage = Stage::StoreData(output);
235 }
236 }
237
turning_to_get_data(&self) -> Result<T::Output, ScheduleError>238 pub(crate) fn turning_to_get_data(&self) -> Result<T::Output, ScheduleError> {
239 let stage = self.stage.get();
240 let data = mem::replace(unsafe { &mut *stage }, Stage::UsedData);
241 match data {
242 Stage::StoreData(output) => output,
243 _ => panic!("invalid task stage: the output is not stored inside the task"),
244 }
245 }
246
poll(&self, context: &mut Context) -> Poll<T::Output>247 pub(crate) fn poll(&self, context: &mut Context) -> Poll<T::Output> {
248 let stage = self.stage.get();
249 let future = match unsafe { &mut *stage } {
250 Stage::Executing(future) => future,
251 _ => panic!("invalid task stage: task polled while not being executed"),
252 };
253
254 let future = unsafe { Pin::new_unchecked(future) };
255 let res = future.poll(context);
256
257 // if result is received, turn the task to finished
258 if res.is_ready() {
259 self.turning_to_executed();
260 }
261 res
262 }
263
send_result(&self, output: Result<T::Output, ScheduleError>)264 pub(crate) fn send_result(&self, output: Result<T::Output, ScheduleError>) {
265 self.turning_to_store_data(output);
266 }
267
wake_join(&self)268 pub(crate) fn wake_join(&self) {
269 let waker = self.waker.get();
270 match unsafe { &*waker } {
271 Some(waker) => {
272 waker.wake_by_ref();
273 }
274 None => panic!("task waker has not been set"),
275 }
276 }
277 }
278
/// Storage for a single task: the type-erased header followed by the
/// future-specific state. `TaskMngInfo::new` allocates it on the heap.
///
/// `repr(C)` is necessary because we cast a pointer of [`TaskMngInfo`] into a
/// pointer of [`Header`].
#[repr(C)]
pub(crate) struct TaskMngInfo<T: Future, S: Schedule> {
    /// Type-erased header shared by all tasks. It must stay the first field so a
    /// pointer to the whole task can be reinterpreted as a `Header` pointer.
    header: Header,
    /// Future-specific state (stage, scheduler, join waker).
    inner: Inner<T, S>,
}
288
get_result<T, S>(ptr: NonNull<Header>, res: *mut ()) where T: Future, S: Schedule,289 unsafe fn get_result<T, S>(ptr: NonNull<Header>, res: *mut ())
290 where
291 T: Future,
292 S: Schedule,
293 {
294 let out = &mut *(res.cast::<Poll<Result<T::Output, ScheduleError>>>());
295 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
296 task_handle.get_result(out);
297 }
298
drop_ref<T, S>(ptr: NonNull<Header>) where T: Future, S: Schedule,299 unsafe fn drop_ref<T, S>(ptr: NonNull<Header>)
300 where
301 T: Future,
302 S: Schedule,
303 {
304 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
305 task_handle.drop_ref();
306 }
307
set_waker<T, S>(ptr: NonNull<Header>, cur_state: usize, waker: *const ()) -> bool where T: Future, S: Schedule,308 unsafe fn set_waker<T, S>(ptr: NonNull<Header>, cur_state: usize, waker: *const ()) -> bool
309 where
310 T: Future,
311 S: Schedule,
312 {
313 let waker = &*(waker.cast::<Waker>());
314 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
315 task_handle.set_waker(cur_state, waker)
316 }
317
drop_join_handle<T, S>(ptr: NonNull<Header>) where T: Future, S: Schedule,318 unsafe fn drop_join_handle<T, S>(ptr: NonNull<Header>)
319 where
320 T: Future,
321 S: Schedule,
322 {
323 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
324 task_handle.drop_join_handle();
325 }
326
327 cfg_not_ffrt! {
328 unsafe fn run<T, S>(ptr: NonNull<Header>) -> bool
329 where
330 T: Future,
331 S: Schedule,
332 {
333 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
334 task_handle.run();
335 true
336 }
337
338 unsafe fn schedule<T, S>(ptr: NonNull<Header>, flag: bool)
339 where
340 T: Future,
341 S: Schedule,
342 {
343 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
344 if flag {
345 task_handle.wake();
346 } else {
347 task_handle.wake_by_ref();
348 }
349 }
350
351 unsafe fn release<T, S>(ptr: NonNull<Header>)
352 where
353 T: Future,
354 S: Schedule,
355 {
356 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
357 task_handle.shutdown();
358 }
359
360 unsafe fn cancel<T, S>(ptr: NonNull<Header>)
361 where
362 T: Future,
363 S: Schedule,
364 {
365 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
366 task_handle.set_canceled();
367 }
368
369 fn create_vtable<T, S>() -> &'static TaskVirtualTable
370 where
371 T: Future,
372 S: Schedule,
373 {
374 &TaskVirtualTable {
375 run: run::<T, S>,
376 schedule: schedule::<T, S>,
377 get_result: get_result::<T, S>,
378 drop_join_handle: drop_join_handle::<T, S>,
379 drop_ref: drop_ref::<T, S>,
380 set_waker: set_waker::<T, S>,
381 #[cfg(not(feature = "ffrt"))]
382 release: release::<T, S>,
383 cancel: cancel::<T, S>,
384 }
385 }
386 }
387
388 cfg_ffrt! {
389 unsafe fn ffrt_run<T, S>(ptr: NonNull<Header>) -> bool
390 where
391 T: Future,
392 S: Schedule,
393 {
394 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
395 task_handle.ffrt_run()
396 }
397
398 unsafe fn ffrt_schedule<T, S>(ptr: NonNull<Header>, flag: bool)
399 where
400 T: Future,
401 S: Schedule,
402 {
403 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
404 if flag {
405 task_handle.ffrt_wake();
406 } else {
407 task_handle.ffrt_wake_by_ref();
408 }
409 }
410
411 unsafe fn ffrt_cancel<T, S>(ptr: NonNull<Header>)
412 where
413 T: Future,
414 S: Schedule,
415 {
416 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
417 task_handle.ffrt_set_canceled();
418 }
419
420 fn create_ffrt_vtable<T, S>() -> &'static TaskVirtualTable
421 where
422 T: Future,
423 S: Schedule,
424 {
425 &TaskVirtualTable {
426 run: ffrt_run::<T, S>,
427 schedule: ffrt_schedule::<T, S>,
428 get_result: get_result::<T, S>,
429 drop_join_handle: drop_join_handle::<T, S>,
430 drop_ref: drop_ref::<T, S>,
431 set_waker: set_waker::<T, S>,
432 cancel: ffrt_cancel::<T, S>,
433 }
434 }
435 }
436
437 impl<T, S> TaskMngInfo<T, S>
438 where
439 T: Future,
440 S: Schedule,
441 {
442 /// Creates non-stackful task info.
443 // TODO: builder information currently is not used yet. Might use in the future
444 // (e.g. qos), so keep it now.
new( _builder: &TaskBuilder, scheduler: Weak<S>, task: T, virtual_table_type: VirtualTableType, ) -> Box<Self>445 pub(crate) fn new(
446 _builder: &TaskBuilder,
447 scheduler: Weak<S>,
448 task: T,
449 virtual_table_type: VirtualTableType,
450 ) -> Box<Self> {
451 let vtable = match virtual_table_type {
452 #[cfg(not(feature = "ffrt"))]
453 VirtualTableType::Ylong => create_vtable::<T, S>(),
454 #[cfg(feature = "ffrt")]
455 VirtualTableType::Ffrt => create_ffrt_vtable::<T, S>(),
456 };
457 // Create the common header
458 let header = Header {
459 state: TaskState::new(),
460 vtable,
461 node: Node::new(),
462 };
463 // Create task private info
464 let inner = Inner::<T, S>::new(task, scheduler);
465 // Allocate it onto the heap
466 Box::new(TaskMngInfo { header, inner })
467 }
468
header(&self) -> &Header469 pub(crate) fn header(&self) -> &Header {
470 &self.header
471 }
472
inner(&self) -> &Inner<T, S>473 pub(crate) fn inner(&self) -> &Inner<T, S> {
474 &self.inner
475 }
476 }
477