// Copyright (c) 2023 Huawei Device Co., Ltd.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cell::UnsafeCell;
use std::future::Future;
use std::mem;
use std::pin::Pin;
use std::ptr::{addr_of_mut, NonNull};
use std::sync::Weak;
use std::task::{Context, Poll, Waker};

use crate::error::ScheduleError;
use crate::executor::Schedule;
use crate::task::state::TaskState;
use crate::task::task_handle::TaskHandle;
use crate::task::{TaskBuilder, VirtualTableType};
use crate::util::linked_list::{Link, Node};

cfg_ffrt! {
    use crate::ffrt::ffrt_task::FfrtTaskCtx;
}

pub(crate) struct TaskVirtualTable {
    /// Task running method
    pub(crate) run: unsafe fn(NonNull<Header>) -> bool,
    /// Task scheduling method
    pub(crate) schedule: unsafe fn(NonNull<Header>, bool),
    /// Task result-getting method
    pub(crate) get_result: unsafe fn(NonNull<Header>, *mut ()),
    /// JoinHandle drop method
    pub(crate) drop_join_handle: unsafe fn(NonNull<Header>),
    /// Task reference drop method
    pub(crate) drop_ref: unsafe fn(NonNull<Header>),
    /// Task waker setting method
    pub(crate) set_waker: unsafe fn(NonNull<Header>, cur_state: usize, waker: *const ()) -> bool,
    /// Task release method
    #[cfg(not(feature = "ffrt"))]
    pub(crate) release: unsafe fn(NonNull<Header>),
    /// Task cancel method
    pub(crate) cancel: unsafe fn(NonNull<Header>),
}
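
// Note: this is a hand-rolled vtable (similar in spirit to `std::task::RawWakerVTable`):
// every entry takes the type-erased `NonNull<Header>`, and the concrete implementations
// further below recover the `TaskHandle<T, S>` behind it. Erasing `T` and `S` this way
// lets the scheduler keep heterogeneous tasks in a single intrusive queue.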

#[repr(C)]
pub(crate) struct Header {
    pub(crate) state: TaskState,
    pub(crate) vtable: &'static TaskVirtualTable,
    // Node inside the global queue
    node: Node<Header>,
}
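
// `Header` is the type-erased prefix of every task: `TaskMngInfo<T, S>` further below is
// `repr(C)` with its `header` as the first field, so a pointer to the whole allocation
// and a pointer to its `Header` refer to the same address. Everything the scheduler
// needs without knowing `T` or `S` lives here: the task state, the vtable used for
// dispatch, and the queue node.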

cfg_not_ffrt! {
    fn get_default_vtable() -> &'static TaskVirtualTable {
        unsafe fn default_run(_task: NonNull<Header>) -> bool {
            false
        }
        unsafe fn default_schedule(_task: NonNull<Header>, _fifo: bool) {}
        unsafe fn default_get_result(_task: NonNull<Header>, _result: *mut ()) {}
        unsafe fn default_drop_handle(_task: NonNull<Header>) {}
        unsafe fn default_set_waker(
            _task: NonNull<Header>,
            _cur_state: usize,
            _waker: *const (),
        ) -> bool {
            false
        }
        unsafe fn default_drop_ref(_task: NonNull<Header>) {}
        unsafe fn default_release(_task: NonNull<Header>) {}
        unsafe fn default_cancel(_task: NonNull<Header>) {}

        &TaskVirtualTable {
            run: default_run,
            schedule: default_schedule,
            get_result: default_get_result,
            drop_join_handle: default_drop_handle,
            drop_ref: default_drop_ref,
            set_waker: default_set_waker,
            release: default_release,
            cancel: default_cancel,
        }
    }

    impl Default for Header {
        fn default() -> Self {
            Self {
                state: TaskState::new(),
                vtable: get_default_vtable(),
                node: Default::default(),
            }
        }
    }
}
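
// Assumption (based on this file only): the no-op vtable above exists purely to back
// `Header::default()`, so a placeholder `Header` (for example a sentinel node for the
// intrusive linked list) can be constructed without a real future behind it. Genuine
// tasks always receive their vtable from `create_vtable` / `create_ffrt_vtable` below.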

unsafe impl Link for Header {
    unsafe fn node(mut ptr: NonNull<Self>) -> NonNull<Node<Self>>
    where
        Self: Sized,
    {
        let node_ptr = addr_of_mut!(ptr.as_mut().node);
        NonNull::new_unchecked(node_ptr)
    }
}

#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub(crate) struct RawTask {
    pub(crate) ptr: NonNull<Header>,
}

impl RawTask {
    pub(crate) fn header(&self) -> &Header {
        unsafe { self.ptr.as_ref() }
    }

    pub(crate) fn run(self) -> bool {
        let vir_table = self.header().vtable;
        unsafe { (vir_table.run)(self.ptr) }
    }

    pub(crate) unsafe fn get_result(self, res: *mut ()) {
        let vir_table = self.header().vtable;
        (vir_table.get_result)(self.ptr, res);
    }

    pub(crate) unsafe fn cancel(self) {
        let vir_table = self.header().vtable;
        (vir_table.cancel)(self.ptr)
    }

    pub(crate) unsafe fn set_waker(self, cur_state: usize, waker: *const ()) -> bool {
        let vir_table = self.header().vtable;
        (vir_table.set_waker)(self.ptr, cur_state, waker)
    }

    pub(crate) fn drop_ref(self) {
        let vir_table = self.header().vtable;
        unsafe {
            (vir_table.drop_ref)(self.ptr);
        }
    }

    pub(crate) fn drop_join_handle(self) {
        let vir_table = self.header().vtable;
        unsafe {
            (vir_table.drop_join_handle)(self.ptr);
        }
    }
}

#[cfg(not(feature = "ffrt"))]
impl RawTask {
    #[inline]
    pub(crate) fn form_raw(ptr: NonNull<Header>) -> RawTask {
        RawTask { ptr }
    }

    pub(super) fn shutdown(self) {
        let vir_table = self.header().vtable;
        unsafe {
            (vir_table.release)(self.ptr);
        }
    }
}

pub(crate) enum Stage<T: Future> {
    Executing(T),
    Executed,
    StoreData(Result<T::Output, ScheduleError>),
    UsedData,
}
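
// Expected `Stage` lifecycle, based on the `Inner` methods below:
//
//     Executing(T) --(poll returns Ready)-----> Executed
//     Executed ------(send_result)------------> StoreData(output)
//     StoreData -----(turning_to_get_data)----> UsedData
//
// `turning_to_used_data` can also move the stage straight to `UsedData`, presumably when
// the output is dropped without being read (assumption: its callers live outside this
// file).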

#[repr(C)]
pub(crate) struct Inner<T: Future, S: Schedule> {
    /// The execution stage of the future
    pub(crate) stage: UnsafeCell<Stage<T>>,
    /// The scheduler of the task queue
    pub(crate) scheduler: Weak<S>,
    /// Waker of the task waiting on this task
    pub(crate) waker: UnsafeCell<Option<Waker>>,
    /// Task in adaptive runtime
    #[cfg(feature = "ffrt")]
    pub(crate) task: UnsafeCell<Option<FfrtTaskCtx>>,
}

impl<T, S> Inner<T, S>
where
    T: Future,
    S: Schedule,
{
    fn new(task: T, scheduler: Weak<S>) -> Self {
        Inner {
            stage: UnsafeCell::new(Stage::Executing(task)),
            scheduler,
            waker: UnsafeCell::new(None),
            #[cfg(feature = "ffrt")]
            task: UnsafeCell::new(None),
        }
    }

    #[cfg(feature = "ffrt")]
    pub(crate) fn get_task_ctx(&self) {
        unsafe {
            if (*self.task.get()).is_none() {
                (*self.task.get()).replace(FfrtTaskCtx::get_current());
            }
        }
    }

    fn turning_to_executed(&self) {
        let stage = self.stage.get();
        unsafe {
            *stage = Stage::Executed;
        }
    }

    pub(crate) fn turning_to_used_data(&self) {
        let stage = self.stage.get();
        unsafe {
            *stage = Stage::UsedData;
        }
    }

    fn turning_to_store_data(&self, output: std::result::Result<T::Output, ScheduleError>) {
        let stage = self.stage.get();
        unsafe {
            *stage = Stage::StoreData(output);
        }
    }

    pub(crate) fn turning_to_get_data(&self) -> Result<T::Output, ScheduleError> {
        let stage = self.stage.get();
        let data = mem::replace(unsafe { &mut *stage }, Stage::UsedData);
        match data {
            Stage::StoreData(output) => output,
            _ => panic!("invalid task stage: the output is not stored inside the task"),
        }
    }

    pub(crate) fn poll(&self, context: &mut Context) -> Poll<T::Output> {
        let stage = self.stage.get();
        let future = match unsafe { &mut *stage } {
            Stage::Executing(future) => future,
            _ => panic!("invalid task stage: task polled while not being executed"),
        };

        let future = unsafe { Pin::new_unchecked(future) };
        let res = future.poll(context);

        // If the result is ready, move the task to the executed stage
        if res.is_ready() {
            self.turning_to_executed();
        }
        res
    }

    pub(crate) fn send_result(&self, output: Result<T::Output, ScheduleError>) {
        self.turning_to_store_data(output);
    }

    pub(crate) fn wake_join(&self) {
        let waker = self.waker.get();
        match unsafe { &*waker } {
            Some(waker) => {
                waker.wake_by_ref();
            }
            None => panic!("task waker has not been set"),
        }
    }
}
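
// Safety note (assumption, inferred from this file alone): `stage` and `waker` are plain
// `UnsafeCell`s and the methods above dereference them without any locking, so callers
// must guarantee exclusive access externally, presumably through the `TaskState`
// transitions recorded in the task's `Header`.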

/// Manages the task's information.
/// `repr(C)` is necessary because we cast a pointer to [`TaskMngInfo`] into a
/// pointer to [`Header`].
#[repr(C)]
pub(crate) struct TaskMngInfo<T: Future, S: Schedule> {
    /// The type-erased header of the heap-allocated task
    header: Header,
    inner: Inner<T, S>,
}
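
// Layout sketch: with `repr(C)`, `header` sits at offset 0, so the erasing cast mentioned
// in the doc comment above is just a pointer cast. A hypothetical helper (not part of
// this module) could look like:
//
//     fn erase<T: Future, S: Schedule>(task: NonNull<TaskMngInfo<T, S>>) -> NonNull<Header> {
//         task.cast::<Header>()
//     }
//
// `TaskHandle::<T, S>::from_raw` (used by the free functions below) is expected to
// perform the inverse cast to recover the concrete `Inner<T, S>`.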

unsafe fn get_result<T, S>(ptr: NonNull<Header>, res: *mut ())
where
    T: Future,
    S: Schedule,
{
    let out = &mut *(res.cast::<Poll<Result<T::Output, ScheduleError>>>());
    let task_handle = TaskHandle::<T, S>::from_raw(ptr);
    task_handle.get_result(out);
}

unsafe fn drop_ref<T, S>(ptr: NonNull<Header>)
where
    T: Future,
    S: Schedule,
{
    let task_handle = TaskHandle::<T, S>::from_raw(ptr);
    task_handle.drop_ref();
}

unsafe fn set_waker<T, S>(ptr: NonNull<Header>, cur_state: usize, waker: *const ()) -> bool
where
    T: Future,
    S: Schedule,
{
    let waker = &*(waker.cast::<Waker>());
    let task_handle = TaskHandle::<T, S>::from_raw(ptr);
    task_handle.set_waker(cur_state, waker)
}

unsafe fn drop_join_handle<T, S>(ptr: NonNull<Header>)
where
    T: Future,
    S: Schedule,
{
    let task_handle = TaskHandle::<T, S>::from_raw(ptr);
    task_handle.drop_join_handle();
}

cfg_not_ffrt! {
    unsafe fn run<T, S>(ptr: NonNull<Header>) -> bool
    where
        T: Future,
        S: Schedule,
    {
        let task_handle = TaskHandle::<T, S>::from_raw(ptr);
        task_handle.run();
        true
    }

    unsafe fn schedule<T, S>(ptr: NonNull<Header>, flag: bool)
    where
        T: Future,
        S: Schedule,
    {
        let task_handle = TaskHandle::<T, S>::from_raw(ptr);
        if flag {
            task_handle.wake();
        } else {
            task_handle.wake_by_ref();
        }
    }

    unsafe fn release<T, S>(ptr: NonNull<Header>)
    where
        T: Future,
        S: Schedule,
    {
        let task_handle = TaskHandle::<T, S>::from_raw(ptr);
        task_handle.shutdown();
    }

    unsafe fn cancel<T, S>(ptr: NonNull<Header>)
    where
        T: Future,
        S: Schedule,
    {
        let task_handle = TaskHandle::<T, S>::from_raw(ptr);
        task_handle.set_canceled();
    }

    fn create_vtable<T, S>() -> &'static TaskVirtualTable
    where
        T: Future,
        S: Schedule,
    {
        &TaskVirtualTable {
            run: run::<T, S>,
            schedule: schedule::<T, S>,
            get_result: get_result::<T, S>,
            drop_join_handle: drop_join_handle::<T, S>,
            drop_ref: drop_ref::<T, S>,
            set_waker: set_waker::<T, S>,
            #[cfg(not(feature = "ffrt"))]
            release: release::<T, S>,
            cancel: cancel::<T, S>,
        }
    }
}
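
// Note: `create_vtable::<T, S>` returns a distinct `&'static TaskVirtualTable` for every
// (T, S) instantiation. Returning a reference to a temporary is sound here because the
// struct expression consists solely of function pointers (constants), so it is promoted
// to a `'static` value by the compiler.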

cfg_ffrt! {
    unsafe fn ffrt_run<T, S>(ptr: NonNull<Header>) -> bool
    where
        T: Future,
        S: Schedule,
    {
        let task_handle = TaskHandle::<T, S>::from_raw(ptr);
        task_handle.ffrt_run()
    }

    unsafe fn ffrt_schedule<T, S>(ptr: NonNull<Header>, flag: bool)
    where
        T: Future,
        S: Schedule,
    {
        let task_handle = TaskHandle::<T, S>::from_raw(ptr);
        if flag {
            task_handle.ffrt_wake();
        } else {
            task_handle.ffrt_wake_by_ref();
        }
    }

    unsafe fn ffrt_cancel<T, S>(ptr: NonNull<Header>)
    where
        T: Future,
        S: Schedule,
    {
        let task_handle = TaskHandle::<T, S>::from_raw(ptr);
        task_handle.ffrt_set_canceled();
    }

    fn create_ffrt_vtable<T, S>() -> &'static TaskVirtualTable
    where
        T: Future,
        S: Schedule,
    {
        &TaskVirtualTable {
            run: ffrt_run::<T, S>,
            schedule: ffrt_schedule::<T, S>,
            get_result: get_result::<T, S>,
            drop_join_handle: drop_join_handle::<T, S>,
            drop_ref: drop_ref::<T, S>,
            set_waker: set_waker::<T, S>,
            cancel: ffrt_cancel::<T, S>,
        }
    }
}

impl<T, S> TaskMngInfo<T, S>
where
    T: Future,
    S: Schedule,
{
    /// Creates non-stackful task info.
    // TODO: builder information is not used yet; it might be used in the future
    // (e.g. for qos), so keep it for now.
    pub(crate) fn new(
        _builder: &TaskBuilder,
        scheduler: Weak<S>,
        task: T,
        virtual_table_type: VirtualTableType,
    ) -> Box<Self> {
        let vtable = match virtual_table_type {
            #[cfg(not(feature = "ffrt"))]
            VirtualTableType::Ylong => create_vtable::<T, S>(),
            #[cfg(feature = "ffrt")]
            VirtualTableType::Ffrt => create_ffrt_vtable::<T, S>(),
        };
        // Create the common header
        let header = Header {
            state: TaskState::new(),
            vtable,
            node: Node::new(),
        };
        // Create task private info
        let inner = Inner::<T, S>::new(task, scheduler);
        // Allocate it onto the heap
        Box::new(TaskMngInfo { header, inner })
    }

    pub(crate) fn header(&self) -> &Header {
        &self.header
    }

    pub(crate) fn inner(&self) -> &Inner<T, S> {
        &self.inner
    }
}
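
// Overall flow (informal sketch; the spawning side lives outside this file). A spawner is
// expected to build a task roughly like this, where `builder`, `scheduler` and `future`
// are hypothetical local values:
//
//     let task = TaskMngInfo::<T, S>::new(&builder, scheduler, future, VirtualTableType::Ylong);
//     let header = NonNull::new(Box::into_raw(task)).unwrap().cast::<Header>();
//     let raw = RawTask { ptr: header };
//
// From then on, running, waking, fetching the result, and dropping the JoinHandle all go
// through the type-erased `RawTask` methods above, which dispatch through the vtable
// stored in the `Header`.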