• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use super::pariter::{Consumer, ParallelIterator};
15 use crate::error::ScheduleError;
16 use crate::executor::{global_default_async, AsyncHandle};
17 use crate::task::{JoinHandle, TaskBuilder};
18 cfg_not_ffrt! {
19     use crate::executor::{async_pool::AsyncPoolSpawner};
20 }
21 
22 cfg_ffrt! {
23     use crate::executor::PlaceholderScheduler;
24     use crate::ffrt::spawner::ffrt_submit;
25     use crate::task::{Task, VirtualTableType};
26     use crate::util::num_cpus::get_cpu_num;
27     use std::sync::Weak;
28 }
29 
core<P, C>(par_iter: P, consumer: C) -> Result<C::Output, ScheduleError> where P: ParallelIterator + Send, C: Consumer<P> + Send + Sync,30 pub(crate) async fn core<P, C>(par_iter: P, consumer: C) -> Result<C::Output, ScheduleError>
31 where
32     P: ParallelIterator + Send,
33     C: Consumer<P> + Send + Sync,
34 {
35     let task_builder = TaskBuilder::new();
36     let runtime = global_default_async();
37 
38     match runtime.async_spawner {
39         #[cfg(feature = "current_thread_runtime")]
40         AsyncHandle::CurrentThread(_) => Ok(consumer.consume(par_iter)),
41         #[cfg(not(feature = "ffrt"))]
42         AsyncHandle::MultiThread(ref runtime) => {
43             const MIN_SPLIT_LEN: usize = 1;
44             let split_time = runtime.exe_mng_info.num_workers;
45             recur(
46                 runtime,
47                 &task_builder,
48                 par_iter,
49                 &consumer,
50                 MIN_SPLIT_LEN,
51                 split_time,
52             )
53             .await
54         }
55         #[cfg(feature = "ffrt")]
56         AsyncHandle::FfrtMultiThread => {
57             const MIN_SPLIT_LEN: usize = 1;
58             let split_time = get_cpu_num() as usize;
59             recur_ffrt(
60                 &task_builder,
61                 par_iter,
62                 &consumer,
63                 MIN_SPLIT_LEN,
64                 split_time,
65             )
66             .await
67         }
68     }
69 }
70 
71 #[cfg(not(feature = "ffrt"))]
recur<P, C>( runtime: &AsyncPoolSpawner, task_builder: &TaskBuilder, par_iter: P, consumer: &C, min_split_len: usize, mut split_time: usize, ) -> Result<C::Output, ScheduleError> where P: ParallelIterator + Send, C: Consumer<P> + Send + Sync,72 async fn recur<P, C>(
73     runtime: &AsyncPoolSpawner,
74     task_builder: &TaskBuilder,
75     par_iter: P,
76     consumer: &C,
77     min_split_len: usize,
78     mut split_time: usize,
79 ) -> Result<C::Output, ScheduleError>
80 where
81     P: ParallelIterator + Send,
82     C: Consumer<P> + Send + Sync,
83 {
84     if (par_iter.len() >> 1) <= min_split_len || split_time == 0 {
85         return Ok(consumer.consume(par_iter));
86     }
87     let (left, right) = par_iter.split();
88     let right = match right {
89         Some(a) => a,
90         None => {
91             return Ok(consumer.consume(left));
92         }
93     };
94     split_time >>= 1;
95     unsafe {
96         let left = spawn_task(
97             runtime,
98             task_builder,
99             left,
100             consumer,
101             min_split_len,
102             split_time,
103         );
104         let right = spawn_task(
105             runtime,
106             task_builder,
107             right,
108             consumer,
109             min_split_len,
110             split_time,
111         );
112         let left = left.await??;
113         let right = right.await??;
114         Ok(C::combine(left, right))
115     }
116 }
117 
118 // Safety
119 // No restriction on lifetime to static, so it must be ensured that the data
120 // pointed to is always valid until the execution is completed, in other word
121 // .await the join handle after it is created.
122 #[cfg(not(feature = "ffrt"))]
123 #[inline]
spawn_task<P, C>( runtime: &AsyncPoolSpawner, task_builder: &TaskBuilder, par_iter: P, consumer: &C, min_split_len: usize, split_time: usize, ) -> JoinHandle<Result<C::Output, ScheduleError>> where P: ParallelIterator + Send, C: Consumer<P> + Send + Sync,124 unsafe fn spawn_task<P, C>(
125     runtime: &AsyncPoolSpawner,
126     task_builder: &TaskBuilder,
127     par_iter: P,
128     consumer: &C,
129     min_split_len: usize,
130     split_time: usize,
131 ) -> JoinHandle<Result<C::Output, ScheduleError>>
132 where
133     P: ParallelIterator + Send,
134     C: Consumer<P> + Send + Sync,
135 {
136     runtime.spawn_with_ref(
137         task_builder,
138         recur(
139             runtime,
140             task_builder,
141             par_iter,
142             consumer,
143             min_split_len,
144             split_time,
145         ),
146     )
147 }
148 
/// Recursively splits `par_iter` and consumes the pieces as ffrt tasks.
///
/// The iterator is consumed directly on the current task when half of its
/// length is not greater than `min_split_len`, when `split_time` has reached
/// zero, or when `split` yields no right half. Otherwise `split_time` is
/// halved, one ffrt task is submitted per half, and the two outputs are
/// merged with `C::combine`.
///
/// # Errors
///
/// Propagates the `ScheduleError` obtained from awaiting either child join
/// handle, or returned by either child task. The left child's error takes
/// precedence over the right child's.
#[cfg(feature = "ffrt")]
async fn recur_ffrt<P, C>(
    task_builder: &TaskBuilder,
    par_iter: P,
    consumer: &C,
    min_split_len: usize,
    mut split_time: usize,
) -> Result<C::Output, ScheduleError>
where
    P: ParallelIterator + Send,
    C: Consumer<P> + Send + Sync,
{
    if (par_iter.len() >> 1) <= min_split_len || split_time == 0 {
        return Ok(consumer.consume(par_iter));
    }
    let (left, right) = par_iter.split();
    let right = match right {
        Some(right) => right,
        None => return Ok(consumer.consume(left)),
    };
    split_time >>= 1;

    // SAFETY: the submitted tasks borrow `task_builder` and `consumer`
    // without a `'static` bound. Both join handles are awaited below before
    // this function returns, on the error path as well as the success path.
    // NOTE(review): this still relies on this future being polled to
    // completion rather than dropped while the children run — confirm with
    // callers.
    let (left, right) = unsafe {
        let left = spawn_task_ffrt(task_builder, left, consumer, min_split_len, split_time);
        let right = spawn_task_ffrt(task_builder, right, consumer, min_split_len, split_time);
        (left, right)
    };

    // Await BOTH children before propagating any error. Returning early on a
    // left-side failure would leave the right task running while it still
    // borrows data owned by the caller.
    let left = left.await;
    let right = right.await;
    Ok(C::combine(left??, right??))
}
180 
/// Wraps a `recur_ffrt` call for `par_iter` in a raw task, submits it through
/// `ffrt_submit`, and returns the join handle of that task.
///
/// # Safety
///
/// The submitted task is not restricted to the `'static` lifetime, so the
/// caller must ensure that the data it points to stays valid until execution
/// has completed — in other words, `.await` the join handle after it is
/// created.
#[cfg(feature = "ffrt")]
#[inline]
unsafe fn spawn_task_ffrt<P, C>(
    task_builder: &TaskBuilder,
    par_iter: P,
    consumer: &C,
    min_split_len: usize,
    split_time: usize,
) -> JoinHandle<Result<C::Output, ScheduleError>>
where
    P: ParallelIterator + Send,
    C: Consumer<P> + Send + Sync,
{
    // The task is created with an empty `Weak` in the scheduler slot.
    let raw = Task::create_raw_task(
        task_builder,
        Weak::<PlaceholderScheduler>::new(),
        recur_ffrt(task_builder, par_iter, consumer, min_split_len, split_time),
        VirtualTableType::Ffrt,
    );
    // Create the handle before the task is handed over for execution.
    let handle = JoinHandle::new(raw);
    ffrt_submit(Task(raw), task_builder);
    handle
}
210