• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::future::Future;
15 use std::panic;
16 use std::ptr::NonNull;
17 use std::task::{Context, Poll, Waker};
18 
19 use crate::error::{ErrorKind, ScheduleError};
20 use crate::executor::Schedule;
21 use crate::task::raw::{Header, Inner, TaskMngInfo};
22 use crate::task::state::StateAction;
23 use crate::task::waker::WakerRefHeader;
24 use crate::task::{state, Task};
25 
26 cfg_not_ffrt! {
27     use crate::task::raw::Stage;
28 }
29 
30 pub(crate) struct TaskHandle<T: Future, S: Schedule> {
31     task: NonNull<TaskMngInfo<T, S>>,
32 }
33 
34 impl<T, S> TaskHandle<T, S>
35 where
36     T: Future,
37     S: Schedule,
38 {
from_raw(ptr: NonNull<Header>) -> Self39     pub(crate) unsafe fn from_raw(ptr: NonNull<Header>) -> Self {
40         TaskHandle {
41             task: ptr.cast::<TaskMngInfo<T, S>>(),
42         }
43     }
44 
header(&self) -> &Header45     fn header(&self) -> &Header {
46         unsafe { self.task.as_ref().header() }
47     }
48 
inner(&self) -> &Inner<T, S>49     fn inner(&self) -> &Inner<T, S> {
50         unsafe { self.task.as_ref().inner() }
51     }
52 }
53 
54 impl<T, S> TaskHandle<T, S>
55 where
56     T: Future,
57     S: Schedule,
58 {
release(self)59     pub(crate) fn release(self) {
60         unsafe { drop(Box::from_raw(self.task.as_ptr())) };
61     }
62 
drop_ref(self)63     pub(crate) fn drop_ref(self) {
64         let prev = self.header().state.dec_ref();
65         if state::is_last_ref_count(prev) {
66             self.release();
67         }
68     }
69 
finish(self, state: usize, output: Result<T::Output, ScheduleError>)70     fn finish(self, state: usize, output: Result<T::Output, ScheduleError>) {
71         // send result if the JoinHandle is not dropped
72         if state::is_care_join_handle(state) {
73             self.inner().send_result(output);
74         } else {
75             self.inner().turning_to_used_data();
76         }
77 
78         let res = self.header().state.turning_to_finish();
79         let cur = match res {
80             Ok(cur) => cur,
81             Err(e) => panic!("{}", e.as_str()),
82         };
83 
84         if state::is_set_waker(cur) {
85             self.inner().wake_join();
86         }
87         self.drop_ref();
88     }
89 
90     // Runs the task
run(self)91     pub(crate) fn run(self) {
92         let action = self.header().state.turning_to_running();
93 
94         match action {
95             StateAction::Success => {}
96             StateAction::Canceled(cur) => {
97                 let output = self.get_canceled();
98                 return self.finish(cur, Err(output));
99             }
100             StateAction::Failed(state) => panic!("task state invalid: {}", state),
101             _ => unreachable!(),
102         };
103 
104         // turn the task header into a waker
105         let waker = WakerRefHeader::<'_>::new::<T>(self.header());
106         let mut context = Context::from_waker(&waker);
107 
108         let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
109             self.inner().poll(&mut context).map(Ok)
110         }));
111 
112         let cur = self.header().state.get_current_state();
113         match res {
114             Ok(Poll::Ready(output)) => {
115                 // send result if the JoinHandle is not dropped
116                 self.finish(cur, output);
117             }
118 
119             Ok(Poll::Pending) => match self.header().state.turning_to_idle() {
120                 StateAction::Enqueue => {
121                     self.get_scheduled(true);
122                 }
123                 StateAction::Failed(state) => panic!("task state invalid: {}", state),
124                 StateAction::Canceled(state) => {
125                     let output = self.get_canceled();
126                     self.finish(state, Err(output));
127                 }
128                 _ => {}
129             },
130 
131             Err(_) => {
132                 let output = Err(ScheduleError::new(ErrorKind::Panic, "panic happen"));
133                 self.finish(cur, output);
134             }
135         }
136     }
137 
get_result(self, out: &mut Poll<std::result::Result<T::Output, ScheduleError>>)138     pub(crate) fn get_result(self, out: &mut Poll<std::result::Result<T::Output, ScheduleError>>) {
139         *out = Poll::Ready(self.inner().turning_to_get_data());
140     }
141 
drop_join_handle(self)142     pub(crate) fn drop_join_handle(self) {
143         if self.header().state.try_turning_to_un_join_handle() {
144             return;
145         }
146 
147         match self.header().state.turn_to_un_join_handle() {
148             Ok(_) => {}
149             Err(_) => {
150                 self.inner().turning_to_used_data();
151             }
152         }
153         self.drop_ref();
154     }
155 
set_waker_inner(&self, des_waker: Waker, cur_state: usize) -> Result<usize, usize>156     fn set_waker_inner(&self, des_waker: Waker, cur_state: usize) -> Result<usize, usize> {
157         if !state::is_care_join_handle(cur_state) || state::is_set_waker(cur_state) {
158             panic!("set waker failed: the join handle either get dropped or the task already has a waker set");
159         }
160         unsafe {
161             let waker = self.inner().waker.get();
162             *waker = Some(des_waker);
163         }
164         let result = self.header().state.turn_to_set_waker();
165         if result.is_err() {
166             unsafe {
167                 let waker = self.inner().waker.get();
168                 *waker = None;
169             }
170         }
171         result
172     }
173 
set_waker(self, cur: usize, des_waker: &Waker) -> bool174     pub(crate) fn set_waker(self, cur: usize, des_waker: &Waker) -> bool {
175         let res = if state::is_set_waker(cur) {
176             let is_same_waker = unsafe {
177                 let waker = self.inner().waker.get();
178                 (*waker).as_ref().unwrap().will_wake(des_waker)
179             };
180             // we don't register the same waker
181             if is_same_waker {
182                 return false;
183             }
184             self.header()
185                 .state
186                 .turn_to_un_set_waker()
187                 .and_then(|cur| self.set_waker_inner(des_waker.clone(), cur))
188         } else {
189             self.set_waker_inner(des_waker.clone(), cur)
190         };
191 
192         if let Err(cur) = res {
193             if !state::is_finished(cur) {
194                 panic!("setting waker should only be failed due to the task's completion");
195             }
196             return true;
197         }
198 
199         false
200     }
201 
wake(self)202     pub(crate) fn wake(self) {
203         self.wake_by_ref();
204         self.drop_ref();
205     }
206 
wake_by_ref(&self)207     pub(crate) fn wake_by_ref(&self) {
208         let prev = self.header().state.turn_to_scheduling();
209         if state::need_enqueue(prev) {
210             self.get_scheduled(false);
211         }
212     }
213 
214     // Actually cancels the task during running
get_canceled(&self) -> ScheduleError215     fn get_canceled(&self) -> ScheduleError {
216         self.inner().turning_to_used_data();
217         ErrorKind::TaskCanceled.into()
218     }
219 
220     // Sets task state into canceled and scheduled
set_canceled(&self)221     pub(crate) fn set_canceled(&self) {
222         if self.header().state.turn_to_canceled_and_scheduled() {
223             self.get_scheduled(false);
224         }
225     }
226 
to_task(&self) -> Task227     fn to_task(&self) -> Task {
228         unsafe { Task::from_raw(self.header().into()) }
229     }
230 
get_scheduled(&self, lifo: bool)231     fn get_scheduled(&self, lifo: bool) {
232         self.inner()
233             .scheduler
234             .upgrade()
235             .unwrap()
236             .schedule(self.to_task(), lifo);
237     }
238 }
239 
240 #[cfg(not(feature = "ffrt"))]
241 impl<T, S> TaskHandle<T, S>
242 where
243     T: Future,
244     S: Schedule,
245 {
shutdown(self)246     pub(crate) unsafe fn shutdown(self) {
247         self.header().state.set_cancel();
248         // Check if the JoinHandle gets dropped already. If JoinHandle is still there,
249         // wakes the JoinHandle.
250         let cur = self.header().state.get_current_state();
251         if state::is_care_join_handle(cur) {
252             let stage = self.inner().stage.get();
253             *stage = Stage::StoreData(Err(ErrorKind::TaskCanceled.into()));
254             self.header().state.set_running();
255             let _ = self.header().state.turning_to_finish();
256             if state::is_set_waker(cur) {
257                 self.inner().wake_join();
258             }
259             self.drop_ref();
260         }
261     }
262 }
263 
#[cfg(feature = "ffrt")]
impl<T, S> TaskHandle<T, S>
where
    T: Future,
    S: Schedule,
{
    /// Wakes the ffrt task stored in the task's inner data.
    ///
    /// # Panics
    ///
    /// Panics if no ffrt task is stored.
    fn notify_ffrt_task(&self) {
        // SAFETY: NOTE(review) reads the `task` cell through a raw pointer;
        // presumably nothing writes it concurrently — confirm with the cell's
        // owner.
        let ffrt_task = unsafe { (*self.inner().task.get()).as_ref().unwrap() };
        ffrt_task.wake_task();
    }

    /// Completes the task under ffrt.
    ///
    /// * `state` - state value used to decide whether a `JoinHandle` still
    ///   cares about the result.
    /// * `output` - the task's result; forwarded through `send_result` only
    ///   when the `JoinHandle` still cares.
    ///
    /// Afterwards the state is turned to finished and the join waker is woken
    /// if the resulting state has one set. This method does not call
    /// `drop_ref`.
    ///
    /// # Panics
    ///
    /// Panics with the error's text if `turning_to_finish` fails.
    fn ffrt_finish(self, state: usize, output: Result<T::Output, ScheduleError>) {
        let inner = self.inner();
        if !state::is_care_join_handle(state) {
            inner.turning_to_used_data();
        } else {
            inner.send_result(output);
        }

        let cur = self
            .header()
            .state
            .turning_to_finish()
            .unwrap_or_else(|e| panic!("{}", e.as_str()));

        if state::is_set_waker(cur) {
            inner.wake_join();
        }
    }

    /// Runs the task under ffrt: turns it to running, polls the future once and
    /// reacts to the outcome.
    ///
    /// Returns `true` when the task was finished during this call (ready,
    /// canceled, or panicked) and `false` when it is still pending.
    ///
    /// # Panics
    ///
    /// Panics if a state transition reports `StateAction::Failed`.
    pub(crate) fn ffrt_run(self) -> bool {
        self.inner().get_task_ctx();

        match self.header().state.turning_to_running() {
            StateAction::Canceled(cur) => {
                let err = self.ffrt_get_canceled();
                self.ffrt_finish(cur, Err(err));
                return true;
            }
            StateAction::Failed(state) => panic!("turning to running failed: {:b}", state),
            _ => {}
        }

        let header_waker = WakerRefHeader::<'_>::new::<T>(self.header());
        let mut cx = Context::from_waker(&header_waker);

        let polled = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            self.inner().poll(&mut cx).map(Ok)
        }));

        let cur = self.header().state.get_current_state();
        match polled {
            Ok(Poll::Ready(output)) => {
                // send result if the JoinHandle is not dropped
                self.ffrt_finish(cur, output);
                true
            }

            Ok(Poll::Pending) => match self.header().state.turning_to_idle() {
                StateAction::Enqueue => {
                    self.notify_ffrt_task();
                    false
                }
                StateAction::Failed(state) => panic!("task state invalid: {:b}", state),
                StateAction::Canceled(state) => {
                    let err = self.ffrt_get_canceled();
                    self.ffrt_finish(state, Err(err));
                    true
                }
                _ => false,
            },

            Err(_) => {
                self.ffrt_finish(
                    cur,
                    Err(ScheduleError::new(ErrorKind::Panic, "panic happen")),
                );
                true
            }
        }
    }

    /// Wakes the task and then drops one reference.
    pub(crate) fn ffrt_wake(self) {
        self.ffrt_wake_by_ref();
        self.drop_ref();
    }

    /// Turns the state to scheduling and wakes the ffrt task when
    /// `state::need_enqueue` accepts the value returned by the transition.
    pub(crate) fn ffrt_wake_by_ref(&self) {
        if state::need_enqueue(self.header().state.turn_to_scheduling()) {
            self.notify_ffrt_task();
        }
    }

    /// Performs the cancellation while the task is being run: marks the stored
    /// data as used and returns the `TaskCanceled` error to report.
    fn ffrt_get_canceled(&self) -> ScheduleError {
        self.inner().turning_to_used_data();
        ErrorKind::TaskCanceled.into()
    }

    /// Sets the task state to canceled-and-scheduled; wakes the ffrt task when
    /// the transition returns `true`.
    pub(crate) fn ffrt_set_canceled(&self) {
        let must_wake = self.header().state.turn_to_canceled_and_scheduled();
        if must_wake {
            self.notify_ffrt_task();
        }
    }
}
365