• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 mod keeper;
15 mod running_task;
16 use std::collections::{HashMap, HashSet};
17 use std::sync::atomic::{AtomicBool, Ordering};
18 use std::sync::Arc;
19 
20 use keeper::SAKeeper;
21 
22 cfg_oh! {
23     use crate::ability::SYSTEM_CONFIG_MANAGER;
24 }
25 use ylong_runtime::task::JoinHandle;
26 
27 use crate::config::Mode;
28 use crate::error::ErrorCode;
29 use crate::manage::database::RequestDb;
30 use crate::manage::events::{TaskEvent, TaskManagerEvent};
31 use crate::manage::scheduler::qos::{QosChanges, QosDirection};
32 use crate::manage::scheduler::queue::running_task::RunningTask;
33 use crate::manage::task_manager::TaskManagerTx;
34 use crate::service::active_counter::ActiveCounter;
35 use crate::service::client::ClientManagerEntry;
36 use crate::service::run_count::RunCountManagerEntry;
37 use crate::task::config::Action;
38 use crate::task::info::State;
39 use crate::task::reason::Reason;
40 use crate::task::request_task::RequestTask;
41 use crate::utils::runtime_spawn;
42 
43 pub(crate) struct RunningQueue {
44     download_queue: HashMap<(u64, u32), Arc<RequestTask>>,
45     upload_queue: HashMap<(u64, u32), Arc<RequestTask>>,
46     running_tasks: HashMap<(u64, u32), Option<AbortHandle>>,
47     keeper: SAKeeper,
48     tx: TaskManagerTx,
49     run_count_manager: RunCountManagerEntry,
50     client_manager: ClientManagerEntry,
51     // paused and then resume upload task need to upload from the breakpoint
52     pub(crate) upload_resume: HashSet<u32>,
53 }
54 
55 impl RunningQueue {
new( tx: TaskManagerTx, run_count_manager: RunCountManagerEntry, client_manager: ClientManagerEntry, active_counter: ActiveCounter, ) -> Self56     pub(crate) fn new(
57         tx: TaskManagerTx,
58         run_count_manager: RunCountManagerEntry,
59         client_manager: ClientManagerEntry,
60         active_counter: ActiveCounter,
61     ) -> Self {
62         Self {
63             download_queue: HashMap::new(),
64             upload_queue: HashMap::new(),
65             keeper: SAKeeper::new(tx.clone(), active_counter),
66             tx,
67             running_tasks: HashMap::new(),
68             run_count_manager,
69             client_manager,
70             upload_resume: HashSet::new(),
71         }
72     }
73 
get_task(&self, uid: u64, task_id: u32) -> Option<&Arc<RequestTask>>74     pub(crate) fn get_task(&self, uid: u64, task_id: u32) -> Option<&Arc<RequestTask>> {
75         self.download_queue
76             .get(&(uid, task_id))
77             .or_else(|| self.upload_queue.get(&(uid, task_id)))
78     }
79 
get_task_clone(&self, uid: u64, task_id: u32) -> Option<Arc<RequestTask>>80     pub(crate) fn get_task_clone(&self, uid: u64, task_id: u32) -> Option<Arc<RequestTask>> {
81         self.download_queue
82             .get(&(uid, task_id))
83             .cloned()
84             .or_else(|| self.upload_queue.get(&(uid, task_id)).cloned())
85     }
86 
task_finish(&mut self, uid: u64, task_id: u32)87     pub(crate) fn task_finish(&mut self, uid: u64, task_id: u32) {
88         self.running_tasks.remove(&(uid, task_id));
89     }
90 
try_restart(&mut self, uid: u64, task_id: u32) -> bool91     pub(crate) fn try_restart(&mut self, uid: u64, task_id: u32) -> bool {
92         if let Some(task) = self
93             .download_queue
94             .get(&(uid, task_id))
95             .or(self.upload_queue.get(&(uid, task_id)))
96         {
97             if self.running_tasks.contains_key(&(uid, task_id)) {
98                 return false;
99             }
100             info!("{} restart running", task_id);
101             let running_task = RunningTask::new(task.clone(), self.tx.clone(), self.keeper.clone());
102             let abort_flag = Arc::new(AtomicBool::new(false));
103             let abort_flag_clone = abort_flag.clone();
104             let join_handle = runtime_spawn(async move {
105                 running_task.run(abort_flag_clone.clone()).await;
106             });
107             let uid = task.uid();
108             let task_id = task.task_id();
109             self.running_tasks.insert(
110                 (uid, task_id),
111                 Some(AbortHandle::new(abort_flag, join_handle)),
112             );
113             true
114         } else {
115             false
116         }
117     }
118 
tasks(&self) -> impl Iterator<Item = &Arc<RequestTask>>119     pub(crate) fn tasks(&self) -> impl Iterator<Item = &Arc<RequestTask>> {
120         self.download_queue
121             .values()
122             .chain(self.upload_queue.values())
123     }
124 
running_tasks(&self) -> usize125     pub(crate) fn running_tasks(&self) -> usize {
126         self.download_queue.len() + self.upload_queue.len()
127     }
128 
reschedule(&mut self, qos: QosChanges, qos_remove_queue: &mut Vec<(u64, u32)>)129     pub(crate) fn reschedule(&mut self, qos: QosChanges, qos_remove_queue: &mut Vec<(u64, u32)>) {
130         if let Some(vec) = qos.download {
131             self.reschedule_inner(Action::Download, vec, qos_remove_queue)
132         }
133         if let Some(vec) = qos.upload {
134             self.reschedule_inner(Action::Upload, vec, qos_remove_queue)
135         }
136     }
137 
reschedule_inner( &mut self, action: Action, qos_vec: Vec<QosDirection>, qos_remove_queue: &mut Vec<(u64, u32)>, )138     pub(crate) fn reschedule_inner(
139         &mut self,
140         action: Action,
141         qos_vec: Vec<QosDirection>,
142         qos_remove_queue: &mut Vec<(u64, u32)>,
143     ) {
144         let mut new_queue = HashMap::new();
145 
146         let queue = if action == Action::Download {
147             &mut self.download_queue
148         } else {
149             &mut self.upload_queue
150         };
151 
152         // We need to decide which tasks need to continue running based on `QosChanges`.
153         for qos_direction in qos_vec.iter() {
154             let uid = qos_direction.uid();
155             let task_id = qos_direction.task_id();
156 
157             if let Some(task) = queue.remove(&(uid, task_id)) {
158                 // If we can find that the task is running in `running_tasks`,
159                 // we just need to adjust its rate.
160                 task.speed_limit(qos_direction.direction() as u64);
161                 // Then we put it into `satisfied_tasks`.
162                 new_queue.insert((uid, task_id), task);
163                 continue;
164             }
165 
166             // If the task is not in the current running queue, retrieve
167             // the corresponding task from the database and start it.
168 
169             #[cfg(feature = "oh")]
170             let system_config = unsafe { SYSTEM_CONFIG_MANAGER.assume_init_ref().system_config() };
171             let upload_resume = self.upload_resume.remove(&task_id);
172 
173             let task = match RequestDb::get_instance().get_task(
174                 task_id,
175                 #[cfg(feature = "oh")]
176                 system_config,
177                 &self.client_manager,
178                 upload_resume,
179             ) {
180                 Ok(task) => task,
181                 Err(ErrorCode::TaskNotFound) => continue, // If we cannot find the task, skip it.
182                 Err(ErrorCode::TaskStateErr) => continue, // If we cannot find the task, skip it.
183                 Err(e) => {
184                     info!("get task {} error:{:?}", task_id, e);
185                     if let Some(info) = RequestDb::get_instance().get_task_qos_info(task_id) {
186                         self.tx.send_event(TaskManagerEvent::Task(TaskEvent::Failed(
187                             task_id,
188                             uid,
189                             Reason::OthersError,
190                             Mode::from(info.mode),
191                         )));
192                     }
193                     qos_remove_queue.push((uid, task_id));
194                     continue;
195                 }
196             };
197             task.speed_limit(qos_direction.direction() as u64);
198 
199             new_queue.insert((uid, task_id), task.clone());
200 
201             if self.running_tasks.contains_key(&(uid, task_id)) {
202                 info!("task {} not finished", task_id);
203                 continue;
204             }
205 
206             info!("{} create running", task_id);
207             let running_task = RunningTask::new(task.clone(), self.tx.clone(), self.keeper.clone());
208             RequestDb::get_instance().update_task_state(
209                 running_task.task_id(),
210                 State::Running,
211                 Reason::Default,
212             );
213             let abort_flag = Arc::new(AtomicBool::new(false));
214             let abort_flag_clone = abort_flag.clone();
215             let join_handle = runtime_spawn(async move {
216                 running_task.run(abort_flag_clone).await;
217             });
218 
219             let uid = task.uid();
220             let task_id = task.task_id();
221             self.running_tasks.insert(
222                 (uid, task_id),
223                 Some(AbortHandle::new(abort_flag, join_handle)),
224             );
225         }
226         // every satisfied tasks in running has been moved, set left tasks to Waiting
227 
228         for task in queue.values() {
229             if let Some(join_handle) = self.running_tasks.get_mut(&(task.uid(), task.task_id())) {
230                 if let Some(join_handle) = join_handle.take() {
231                     join_handle.cancel();
232                 };
233             }
234         }
235         *queue = new_queue;
236 
237         #[cfg(feature = "oh")]
238         self.run_count_manager
239             .notify_run_count(self.download_queue.len() + self.upload_queue.len());
240     }
241 
retry_all_tasks(&mut self)242     pub(crate) fn retry_all_tasks(&mut self) {
243         for task in self.running_tasks.iter_mut() {
244             if let Some(handle) = task.1.take() {
245                 handle.cancel();
246             }
247         }
248     }
249 
cancel_task(&mut self, task_id: u32, uid: u64) -> bool250     pub(crate) fn cancel_task(&mut self, task_id: u32, uid: u64) -> bool {
251         let handle = match self
252             .running_tasks
253             .get_mut(&(uid, task_id))
254             .and_then(|task| task.take())
255         {
256             Some(h) => h,
257             None => return false,
258         };
259         let task = match self
260             .upload_queue
261             .get(&(uid, task_id))
262             .or_else(|| self.download_queue.get(&(uid, task_id)))
263         {
264             Some(t) => t,
265             None => {
266                 return false;
267             }
268         };
269 
270         let progress_lock = task.progress.lock().unwrap();
271         handle.cancel();
272         drop(progress_lock);
273 
274         task.update_progress_in_database();
275         true
276     }
277 }
278 
279 struct AbortHandle {
280     abort_flag: Arc<AtomicBool>,
281     join_handle: JoinHandle<()>,
282 }
283 
284 impl AbortHandle {
new(abort_flag: Arc<AtomicBool>, join_handle: JoinHandle<()>) -> Self285     fn new(abort_flag: Arc<AtomicBool>, join_handle: JoinHandle<()>) -> Self {
286         Self {
287             abort_flag,
288             join_handle,
289         }
290     }
cancel(self)291     fn cancel(self) {
292         self.abort_flag.store(true, Ordering::Release);
293         self.join_handle.cancel();
294     }
295 }
296