1 // Copyright (c) 2023 Huawei Device Co., Ltd. 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 use std::future::Future; 15 use std::panic; 16 use std::ptr::NonNull; 17 use std::task::{Context, Poll, Waker}; 18 19 use crate::error::{ErrorKind, ScheduleError}; 20 use crate::executor::Schedule; 21 use crate::task::raw::{Header, Inner, TaskMngInfo}; 22 use crate::task::state::StateAction; 23 use crate::task::waker::WakerRefHeader; 24 use crate::task::{state, Task}; 25 26 cfg_not_ffrt! { 27 use crate::task::raw::Stage; 28 } 29 30 pub(crate) struct TaskHandle<T: Future, S: Schedule> { 31 task: NonNull<TaskMngInfo<T, S>>, 32 } 33 34 impl<T, S> TaskHandle<T, S> 35 where 36 T: Future, 37 S: Schedule, 38 { from_raw(ptr: NonNull<Header>) -> Self39 pub(crate) unsafe fn from_raw(ptr: NonNull<Header>) -> Self { 40 TaskHandle { 41 task: ptr.cast::<TaskMngInfo<T, S>>(), 42 } 43 } 44 header(&self) -> &Header45 fn header(&self) -> &Header { 46 unsafe { self.task.as_ref().header() } 47 } 48 inner(&self) -> &Inner<T, S>49 fn inner(&self) -> &Inner<T, S> { 50 unsafe { self.task.as_ref().inner() } 51 } 52 } 53 54 impl<T, S> TaskHandle<T, S> 55 where 56 T: Future, 57 S: Schedule, 58 { release(self)59 pub(crate) fn release(self) { 60 unsafe { drop(Box::from_raw(self.task.as_ptr())) }; 61 } 62 drop_ref(self)63 pub(crate) fn drop_ref(self) { 64 let prev = self.header().state.dec_ref(); 65 if state::is_last_ref_count(prev) { 66 self.release(); 67 } 68 } 69 finish(self, state: usize, output: Result<T::Output, ScheduleError>)70 fn finish(self, state: usize, output: Result<T::Output, ScheduleError>) { 71 // send result if the JoinHandle is not dropped 72 if state::is_care_join_handle(state) { 73 self.inner().send_result(output); 74 } else { 75 self.inner().turning_to_used_data(); 76 } 77 78 let res = self.header().state.turning_to_finish(); 79 let cur = match res { 80 Ok(cur) => cur, 81 Err(e) => panic!("{}", e.as_str()), 82 }; 83 84 if state::is_set_waker(cur) { 85 self.inner().wake_join(); 86 } 87 self.drop_ref(); 88 } 89 90 // Runs the task run(self)91 pub(crate) fn run(self) { 92 let action = self.header().state.turning_to_running(); 93 94 match action { 95 StateAction::Success => {} 96 StateAction::Canceled(cur) => { 97 let output = self.get_canceled(); 98 return self.finish(cur, Err(output)); 99 } 100 StateAction::Failed(state) => panic!("task state invalid: {}", state), 101 _ => unreachable!(), 102 }; 103 104 // turn the task header into a waker 105 let waker = WakerRefHeader::<'_>::new::<T>(self.header()); 106 let mut context = Context::from_waker(&waker); 107 108 let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { 109 self.inner().poll(&mut context).map(Ok) 110 })); 111 112 let cur = self.header().state.get_current_state(); 113 match res { 114 Ok(Poll::Ready(output)) => { 115 // send result if the JoinHandle is not dropped 116 self.finish(cur, output); 117 } 118 119 Ok(Poll::Pending) => match self.header().state.turning_to_idle() { 120 StateAction::Enqueue => { 121 self.get_scheduled(true); 122 } 123 StateAction::Failed(state) => panic!("task state invalid: {}", state), 124 StateAction::Canceled(state) => { 125 let output = self.get_canceled(); 126 self.finish(state, Err(output)); 127 } 128 _ => {} 129 }, 130 131 Err(_) => { 132 let output = Err(ScheduleError::new(ErrorKind::Panic, "panic happen")); 133 self.finish(cur, output); 134 } 135 } 136 } 137 get_result(self, out: &mut Poll<std::result::Result<T::Output, ScheduleError>>)138 pub(crate) fn get_result(self, out: &mut Poll<std::result::Result<T::Output, ScheduleError>>) { 139 *out = Poll::Ready(self.inner().turning_to_get_data()); 140 } 141 drop_join_handle(self)142 pub(crate) fn drop_join_handle(self) { 143 if self.header().state.try_turning_to_un_join_handle() { 144 return; 145 } 146 147 match self.header().state.turn_to_un_join_handle() { 148 Ok(_) => {} 149 Err(_) => { 150 self.inner().turning_to_used_data(); 151 } 152 } 153 self.drop_ref(); 154 } 155 set_waker_inner(&self, des_waker: Waker, cur_state: usize) -> Result<usize, usize>156 fn set_waker_inner(&self, des_waker: Waker, cur_state: usize) -> Result<usize, usize> { 157 if !state::is_care_join_handle(cur_state) || state::is_set_waker(cur_state) { 158 panic!("set waker failed: the join handle either get dropped or the task already has a waker set"); 159 } 160 unsafe { 161 let waker = self.inner().waker.get(); 162 *waker = Some(des_waker); 163 } 164 let result = self.header().state.turn_to_set_waker(); 165 if result.is_err() { 166 unsafe { 167 let waker = self.inner().waker.get(); 168 *waker = None; 169 } 170 } 171 result 172 } 173 set_waker(self, cur: usize, des_waker: &Waker) -> bool174 pub(crate) fn set_waker(self, cur: usize, des_waker: &Waker) -> bool { 175 let res = if state::is_set_waker(cur) { 176 let is_same_waker = unsafe { 177 let waker = self.inner().waker.get(); 178 (*waker).as_ref().unwrap().will_wake(des_waker) 179 }; 180 // we don't register the same waker 181 if is_same_waker { 182 return false; 183 } 184 self.header() 185 .state 186 .turn_to_un_set_waker() 187 .and_then(|cur| self.set_waker_inner(des_waker.clone(), cur)) 188 } else { 189 self.set_waker_inner(des_waker.clone(), cur) 190 }; 191 192 if let Err(cur) = res { 193 if !state::is_finished(cur) { 194 panic!("setting waker should only be failed due to the task's completion"); 195 } 196 return true; 197 } 198 199 false 200 } 201 wake(self)202 pub(crate) fn wake(self) { 203 self.wake_by_ref(); 204 self.drop_ref(); 205 } 206 wake_by_ref(&self)207 pub(crate) fn wake_by_ref(&self) { 208 let prev = self.header().state.turn_to_scheduling(); 209 if state::need_enqueue(prev) { 210 self.get_scheduled(false); 211 } 212 } 213 214 // Actually cancels the task during running get_canceled(&self) -> ScheduleError215 fn get_canceled(&self) -> ScheduleError { 216 self.inner().turning_to_used_data(); 217 ErrorKind::TaskCanceled.into() 218 } 219 220 // Sets task state into canceled and scheduled set_canceled(&self)221 pub(crate) fn set_canceled(&self) { 222 if self.header().state.turn_to_canceled_and_scheduled() { 223 self.get_scheduled(false); 224 } 225 } 226 to_task(&self) -> Task227 fn to_task(&self) -> Task { 228 unsafe { Task::from_raw(self.header().into()) } 229 } 230 get_scheduled(&self, lifo: bool)231 fn get_scheduled(&self, lifo: bool) { 232 self.inner() 233 .scheduler 234 .upgrade() 235 .unwrap() 236 .schedule(self.to_task(), lifo); 237 } 238 } 239 240 #[cfg(not(feature = "ffrt"))] 241 impl<T, S> TaskHandle<T, S> 242 where 243 T: Future, 244 S: Schedule, 245 { shutdown(self)246 pub(crate) unsafe fn shutdown(self) { 247 self.header().state.set_cancel(); 248 // Check if the JoinHandle gets dropped already. If JoinHandle is still there, 249 // wakes the JoinHandle. 250 let cur = self.header().state.get_current_state(); 251 if state::is_care_join_handle(cur) { 252 let stage = self.inner().stage.get(); 253 *stage = Stage::StoreData(Err(ErrorKind::TaskCanceled.into())); 254 self.header().state.set_running(); 255 let _ = self.header().state.turning_to_finish(); 256 if state::is_set_waker(cur) { 257 self.inner().wake_join(); 258 } 259 self.drop_ref(); 260 } 261 } 262 } 263 264 #[cfg(feature = "ffrt")] 265 impl<T, S> TaskHandle<T, S> 266 where 267 T: Future, 268 S: Schedule, 269 { ffrt_finish(self, state: usize, output: Result<T::Output, ScheduleError>)270 fn ffrt_finish(self, state: usize, output: Result<T::Output, ScheduleError>) { 271 if state::is_care_join_handle(state) { 272 self.inner().send_result(output); 273 } else { 274 self.inner().turning_to_used_data(); 275 } 276 277 let cur = match self.header().state.turning_to_finish() { 278 Ok(cur) => cur, 279 Err(e) => panic!("{}", e.as_str()), 280 }; 281 282 if state::is_set_waker(cur) { 283 self.inner().wake_join(); 284 } 285 } 286 ffrt_run(self) -> bool287 pub(crate) fn ffrt_run(self) -> bool { 288 self.inner().get_task_ctx(); 289 290 match self.header().state.turning_to_running() { 291 StateAction::Failed(state) => panic!("turning to running failed: {:b}", state), 292 StateAction::Canceled(cur) => { 293 let output = self.ffrt_get_canceled(); 294 self.ffrt_finish(cur, Err(output)); 295 return true; 296 } 297 _ => {} 298 } 299 300 let waker = WakerRefHeader::<'_>::new::<T>(self.header()); 301 let mut context = Context::from_waker(&waker); 302 303 let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { 304 self.inner().poll(&mut context).map(Ok) 305 })); 306 307 let cur = self.header().state.get_current_state(); 308 match res { 309 Ok(Poll::Ready(output)) => { 310 // send result if the JoinHandle is not dropped 311 self.ffrt_finish(cur, output); 312 true 313 } 314 315 Ok(Poll::Pending) => match self.header().state.turning_to_idle() { 316 StateAction::Enqueue => { 317 let ffrt_task = unsafe { (*self.inner().task.get()).as_ref().unwrap() }; 318 ffrt_task.wake_task(); 319 false 320 } 321 StateAction::Failed(state) => panic!("task state invalid: {:b}", state), 322 StateAction::Canceled(state) => { 323 let output = self.ffrt_get_canceled(); 324 self.ffrt_finish(state, Err(output)); 325 true 326 } 327 _ => false, 328 }, 329 330 Err(_) => { 331 let output = Err(ScheduleError::new(ErrorKind::Panic, "panic happen")); 332 self.ffrt_finish(cur, output); 333 true 334 } 335 } 336 } 337 ffrt_wake(self)338 pub(crate) fn ffrt_wake(self) { 339 self.ffrt_wake_by_ref(); 340 self.drop_ref(); 341 } 342 ffrt_wake_by_ref(&self)343 pub(crate) fn ffrt_wake_by_ref(&self) { 344 let prev = self.header().state.turn_to_scheduling(); 345 if state::need_enqueue(prev) { 346 let ffrt_task = unsafe { (*self.inner().task.get()).as_ref().unwrap() }; 347 ffrt_task.wake_task(); 348 } 349 } 350 351 // Actually cancels the task during running ffrt_get_canceled(&self) -> ScheduleError352 fn ffrt_get_canceled(&self) -> ScheduleError { 353 self.inner().turning_to_used_data(); 354 ErrorKind::TaskCanceled.into() 355 } 356 357 // Sets task state into canceled and scheduled ffrt_set_canceled(&self)358 pub(crate) fn ffrt_set_canceled(&self) { 359 if self.header().state.turn_to_canceled_and_scheduled() { 360 let ffrt_task = unsafe { (*self.inner().task.get()).as_ref().unwrap() }; 361 ffrt_task.wake_task(); 362 } 363 } 364 } 365