// Copyright (c) 2023 Huawei Device Co., Ltd.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14 use std::cell::UnsafeCell;
15 use std::future::Future;
16 use std::mem;
17 use std::pin::Pin;
18 use std::ptr::NonNull;
19 use std::sync::Weak;
20 use std::task::{Context, Poll, Waker};
21
22 use crate::error::ScheduleError;
23 use crate::executor::Schedule;
24 use crate::task::state::TaskState;
25 use crate::task::task_handle::TaskHandle;
26 use crate::task::{TaskBuilder, VirtualTableType};
27
28 cfg_ffrt! {
29 use crate::ffrt::ffrt_task::FfrtTaskCtx;
30 }
31
32 pub(crate) struct TaskVirtualTable {
33 /// Task running method
34 pub(crate) run: unsafe fn(NonNull<Header>) -> bool,
35 /// Task scheduling method
36 pub(crate) schedule: unsafe fn(NonNull<Header>, bool),
37 /// Task result-getting method
38 pub(crate) get_result: unsafe fn(NonNull<Header>, *mut ()),
39 /// JoinHandle drop method
40 pub(crate) drop_join_handle: unsafe fn(NonNull<Header>),
41 /// Task reference drop method
42 pub(crate) drop_ref: unsafe fn(NonNull<Header>),
43 /// Task waker setting method
44 pub(crate) set_waker: unsafe fn(NonNull<Header>, cur_state: usize, waker: *const ()) -> bool,
45 /// Task release method
46 #[cfg(not(feature = "ffrt"))]
47 pub(crate) release: unsafe fn(NonNull<Header>),
48 /// Task cancel method
49 pub(crate) cancel: unsafe fn(NonNull<Header>),
50 }
51
/// Type-erased prefix shared by every task allocation.
///
/// `repr(C)` fixes the field order, so a pointer to a `repr(C)` task struct
/// whose first field is a `Header` can be handled as a `NonNull<Header>`.
/// Do not reorder these fields.
#[repr(C)]
pub(crate) struct Header {
    /// Current state of the task.
    pub(crate) state: TaskState,
    /// Entry points instantiated for the task's concrete future/scheduler types.
    pub(crate) vtable: &'static TaskVirtualTable,
}
57
58 #[derive(PartialEq, Eq, Hash, Clone, Copy)]
59 pub(crate) struct RawTask {
60 pub(crate) ptr: NonNull<Header>,
61 }
62
63 impl RawTask {
form_raw(ptr: NonNull<Header>) -> RawTask64 pub(crate) fn form_raw(ptr: NonNull<Header>) -> RawTask {
65 RawTask { ptr }
66 }
67
header(&self) -> &Header68 pub(crate) fn header(&self) -> &Header {
69 unsafe { self.ptr.as_ref() }
70 }
71
run(self) -> bool72 pub(crate) fn run(self) -> bool {
73 let vir_table = self.header().vtable;
74 unsafe { (vir_table.run)(self.ptr) }
75 }
76
get_result(self, res: *mut ())77 pub(crate) unsafe fn get_result(self, res: *mut ()) {
78 let vir_table = self.header().vtable;
79 (vir_table.get_result)(self.ptr, res);
80 }
81
cancel(self)82 pub(crate) unsafe fn cancel(self) {
83 let vir_table = self.header().vtable;
84 (vir_table.cancel)(self.ptr)
85 }
86
set_waker(self, cur_state: usize, waker: *const ()) -> bool87 pub(crate) unsafe fn set_waker(self, cur_state: usize, waker: *const ()) -> bool {
88 let vir_table = self.header().vtable;
89 (vir_table.set_waker)(self.ptr, cur_state, waker)
90 }
91
drop_ref(self)92 pub(crate) fn drop_ref(self) {
93 let vir_table = self.header().vtable;
94 unsafe {
95 (vir_table.drop_ref)(self.ptr);
96 }
97 }
98
drop_join_handle(self)99 pub(crate) fn drop_join_handle(self) {
100 let vir_table = self.header().vtable;
101 unsafe {
102 (vir_table.drop_join_handle)(self.ptr);
103 }
104 }
105 }
106
107 #[cfg(not(feature = "ffrt"))]
108 impl RawTask {
shutdown(self)109 pub(super) fn shutdown(self) {
110 let vir_table = self.header().vtable;
111 unsafe {
112 (vir_table.release)(self.ptr);
113 }
114 }
115 }
116
117 pub(crate) enum Stage<T: Future> {
118 Executing(T),
119 Executed,
120 StoreData(Result<T::Output, ScheduleError>),
121 UsedData,
122 }
123
/// Type-specific part of a task allocation.
///
/// The `UnsafeCell` fields are mutated through `&self` by the methods on
/// `Inner`. NOTE(review): exclusive access is presumably guaranteed by the task
/// state machine — confirm before adding new accessors.
#[repr(C)]
pub(crate) struct Inner<T: Future, S: Schedule> {
    /// Current stage: the pending future, the stored output, or neither.
    pub(crate) stage: UnsafeCell<Stage<T>>,
    /// Weak handle to the scheduler that owns the task queue; being `Weak`, it
    /// does not keep the scheduler alive.
    pub(crate) scheduler: Weak<S>,
    /// Waker of whoever is waiting on this task's result; woken by `wake_join`.
    pub(crate) waker: UnsafeCell<Option<Waker>>,
    /// ffrt (adaptive runtime) task context, captured lazily by `get_task_ctx`.
    #[cfg(feature = "ffrt")]
    pub(crate) task: UnsafeCell<Option<FfrtTaskCtx>>,
}
136
137 impl<T, S> Inner<T, S>
138 where
139 T: Future,
140 S: Schedule,
141 {
new(task: T, scheduler: Weak<S>) -> Self142 fn new(task: T, scheduler: Weak<S>) -> Self {
143 Inner {
144 stage: UnsafeCell::new(Stage::Executing(task)),
145 scheduler,
146 waker: UnsafeCell::new(None),
147 #[cfg(feature = "ffrt")]
148 task: UnsafeCell::new(None),
149 }
150 }
151
152 #[cfg(feature = "ffrt")]
get_task_ctx(&self)153 pub(crate) fn get_task_ctx(&self) {
154 unsafe {
155 if (*self.task.get()).is_none() {
156 (*self.task.get()).replace(FfrtTaskCtx::get_current());
157 }
158 }
159 }
160
turning_to_executed(&self)161 fn turning_to_executed(&self) {
162 let stage = self.stage.get();
163 unsafe {
164 *stage = Stage::Executed;
165 }
166 }
167
turning_to_used_data(&self)168 pub(crate) fn turning_to_used_data(&self) {
169 let stage = self.stage.get();
170 unsafe {
171 *stage = Stage::UsedData;
172 }
173 }
174
turning_to_store_data(&self, output: std::result::Result<T::Output, ScheduleError>)175 fn turning_to_store_data(&self, output: std::result::Result<T::Output, ScheduleError>) {
176 let stage = self.stage.get();
177 unsafe {
178 *stage = Stage::StoreData(output);
179 }
180 }
181
turning_to_get_data(&self) -> Result<T::Output, ScheduleError>182 pub(crate) fn turning_to_get_data(&self) -> Result<T::Output, ScheduleError> {
183 let stage = self.stage.get();
184 let data = mem::replace(unsafe { &mut *stage }, Stage::UsedData);
185 match data {
186 Stage::StoreData(output) => output,
187 _ => panic!("invalid task stage: the output is not stored inside the task"),
188 }
189 }
190
poll(&self, context: &mut Context) -> Poll<T::Output>191 pub(crate) fn poll(&self, context: &mut Context) -> Poll<T::Output> {
192 let stage = self.stage.get();
193 let future = match unsafe { &mut *stage } {
194 Stage::Executing(future) => future,
195 _ => panic!("invalid task stage: task polled while not being executed"),
196 };
197
198 let future = unsafe { Pin::new_unchecked(future) };
199 let res = future.poll(context);
200
201 // if result is received, turn the task to finished
202 if res.is_ready() {
203 self.turning_to_executed();
204 }
205 res
206 }
207
send_result(&self, output: Result<T::Output, ScheduleError>)208 pub(crate) fn send_result(&self, output: Result<T::Output, ScheduleError>) {
209 self.turning_to_store_data(output);
210 }
211
wake_join(&self)212 pub(crate) fn wake_join(&self) {
213 let waker = self.waker.get();
214 match unsafe { &*waker } {
215 Some(waker) => {
216 waker.wake_by_ref();
217 }
218 None => panic!("task waker has not been set"),
219 }
220 }
221 }
222
/// Heap-allocated body of a task: the type-erased [`Header`] followed by the
/// type-specific [`Inner`] data.
///
/// `repr(C)` is necessary because a pointer to a [`TaskMngInfo`] is cast into
/// a pointer to its [`Header`]; `header` must therefore stay the first field.
#[repr(C)]
pub(crate) struct TaskMngInfo<T: Future, S: Schedule> {
    /// Common task header (state and vtable), embedded by value; must stay the
    /// first field.
    header: Header,
    /// Future/output stage, scheduler handle and join waker of this task.
    inner: Inner<T, S>,
}
232
run<T, S>(ptr: NonNull<Header>) -> bool where T: Future, S: Schedule,233 unsafe fn run<T, S>(ptr: NonNull<Header>) -> bool
234 where
235 T: Future,
236 S: Schedule,
237 {
238 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
239 task_handle.run();
240 true
241 }
242
schedule<T, S>(ptr: NonNull<Header>, flag: bool) where T: Future, S: Schedule,243 unsafe fn schedule<T, S>(ptr: NonNull<Header>, flag: bool)
244 where
245 T: Future,
246 S: Schedule,
247 {
248 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
249 if flag {
250 task_handle.wake();
251 } else {
252 task_handle.wake_by_ref();
253 }
254 }
255
get_result<T, S>(ptr: NonNull<Header>, res: *mut ()) where T: Future, S: Schedule,256 unsafe fn get_result<T, S>(ptr: NonNull<Header>, res: *mut ())
257 where
258 T: Future,
259 S: Schedule,
260 {
261 let out = &mut *(res as *mut Poll<Result<T::Output, ScheduleError>>);
262 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
263 task_handle.get_result(out);
264 }
265
drop_ref<T, S>(ptr: NonNull<Header>) where T: Future, S: Schedule,266 unsafe fn drop_ref<T, S>(ptr: NonNull<Header>)
267 where
268 T: Future,
269 S: Schedule,
270 {
271 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
272 task_handle.drop_ref();
273 }
274
set_waker<T, S>(ptr: NonNull<Header>, cur_state: usize, waker: *const ()) -> bool where T: Future, S: Schedule,275 unsafe fn set_waker<T, S>(ptr: NonNull<Header>, cur_state: usize, waker: *const ()) -> bool
276 where
277 T: Future,
278 S: Schedule,
279 {
280 let waker = &*(waker as *const Waker);
281 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
282 task_handle.set_waker(cur_state, waker)
283 }
284
drop_join_handle<T, S>(ptr: NonNull<Header>) where T: Future, S: Schedule,285 unsafe fn drop_join_handle<T, S>(ptr: NonNull<Header>)
286 where
287 T: Future,
288 S: Schedule,
289 {
290 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
291 task_handle.drop_join_handle();
292 }
293
294 #[cfg(not(feature = "ffrt"))]
release<T, S>(ptr: NonNull<Header>) where T: Future, S: Schedule,295 unsafe fn release<T, S>(ptr: NonNull<Header>)
296 where
297 T: Future,
298 S: Schedule,
299 {
300 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
301 task_handle.shutdown();
302 }
303
cancel<T, S>(ptr: NonNull<Header>) where T: Future, S: Schedule,304 unsafe fn cancel<T, S>(ptr: NonNull<Header>)
305 where
306 T: Future,
307 S: Schedule,
308 {
309 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
310 task_handle.set_canceled();
311 }
312
create_vtable<T, S>() -> &'static TaskVirtualTable where T: Future, S: Schedule,313 fn create_vtable<T, S>() -> &'static TaskVirtualTable
314 where
315 T: Future,
316 S: Schedule,
317 {
318 &TaskVirtualTable {
319 run: run::<T, S>,
320 schedule: schedule::<T, S>,
321 get_result: get_result::<T, S>,
322 drop_join_handle: drop_join_handle::<T, S>,
323 drop_ref: drop_ref::<T, S>,
324 set_waker: set_waker::<T, S>,
325 #[cfg(not(feature = "ffrt"))]
326 release: release::<T, S>,
327 cancel: cancel::<T, S>,
328 }
329 }
330
331 cfg_ffrt! {
332 unsafe fn ffrt_run<T, S>(ptr: NonNull<Header>) -> bool
333 where
334 T: Future,
335 S: Schedule,
336 {
337 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
338 task_handle.ffrt_run()
339 }
340
341 unsafe fn ffrt_schedule<T, S>(ptr: NonNull<Header>, flag: bool)
342 where
343 T: Future,
344 S: Schedule,
345 {
346 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
347 if flag {
348 task_handle.ffrt_wake();
349 } else {
350 task_handle.ffrt_wake_by_ref();
351 }
352 }
353
354 unsafe fn ffrt_cancel<T, S>(ptr: NonNull<Header>)
355 where
356 T: Future,
357 S: Schedule,
358 {
359 let task_handle = TaskHandle::<T, S>::from_raw(ptr);
360 task_handle.ffrt_set_canceled();
361 }
362
363 fn create_ffrt_vtable<T, S>() -> &'static TaskVirtualTable
364 where
365 T: Future,
366 S: Schedule,
367 {
368 &TaskVirtualTable {
369 run: ffrt_run::<T, S>,
370 schedule: ffrt_schedule::<T, S>,
371 get_result: get_result::<T, S>,
372 drop_join_handle: drop_join_handle::<T, S>,
373 drop_ref: drop_ref::<T, S>,
374 set_waker: set_waker::<T, S>,
375 cancel: ffrt_cancel::<T, S>,
376 }
377 }
378 }
379
380 impl<T, S> TaskMngInfo<T, S>
381 where
382 T: Future,
383 S: Schedule,
384 {
385 /// Creates non-stackful task info.
386 // TODO: builder information currently is not used yet. Might use in the future
387 // (e.g. qos), so keep it now.
new( _builder: &TaskBuilder, scheduler: Weak<S>, task: T, virtual_table_type: VirtualTableType, ) -> Box<Self>388 pub(crate) fn new(
389 _builder: &TaskBuilder,
390 scheduler: Weak<S>,
391 task: T,
392 virtual_table_type: VirtualTableType,
393 ) -> Box<Self> {
394 let vtable = match virtual_table_type {
395 VirtualTableType::Ylong => create_vtable::<T, S>(),
396 #[cfg(feature = "ffrt")]
397 VirtualTableType::Ffrt => create_ffrt_vtable::<T, S>(),
398 };
399 // Create the common header
400 let header = Header {
401 state: TaskState::new(),
402 vtable,
403 };
404 // Create task private info
405 let inner = Inner::<T, S>::new(task, scheduler);
406 // Allocate it onto the heap
407 Box::new(TaskMngInfo { header, inner })
408 }
409
header(&self) -> &Header410 pub(crate) fn header(&self) -> &Header {
411 &self.header
412 }
413
inner(&self) -> &Inner<T, S>414 pub(crate) fn inner(&self) -> &Inner<T, S> {
415 &self.inner
416 }
417 }
418