1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 use super::pariter::{Consumer, ParallelIterator};
15 use crate::error::ScheduleError;
16 use crate::executor::{global_default_async, AsyncHandle};
17 use crate::task::{JoinHandle, TaskBuilder};
18 cfg_not_ffrt! {
19 use crate::executor::{async_pool::AsyncPoolSpawner};
20 }
21
22 cfg_ffrt! {
23 use crate::executor::PlaceholderScheduler;
24 use crate::ffrt::spawner::ffrt_submit;
25 use crate::task::{Task, VirtualTableType};
26 use crate::util::num_cpus::get_cpu_num;
27 use std::sync::Weak;
28 }
29
core<P, C>(par_iter: P, consumer: C) -> Result<C::Output, ScheduleError> where P: ParallelIterator + Send, C: Consumer<P> + Send + Sync,30 pub(crate) async fn core<P, C>(par_iter: P, consumer: C) -> Result<C::Output, ScheduleError>
31 where
32 P: ParallelIterator + Send,
33 C: Consumer<P> + Send + Sync,
34 {
35 let task_builder = TaskBuilder::new();
36 let runtime = global_default_async();
37
38 match runtime.async_spawner {
39 #[cfg(feature = "current_thread_runtime")]
40 AsyncHandle::CurrentThread(_) => Ok(consumer.consume(par_iter)),
41 #[cfg(not(feature = "ffrt"))]
42 AsyncHandle::MultiThread(ref runtime) => {
43 const MIN_SPLIT_LEN: usize = 1;
44 let split_time = runtime.exe_mng_info.num_workers;
45 recur(
46 runtime,
47 &task_builder,
48 par_iter,
49 &consumer,
50 MIN_SPLIT_LEN,
51 split_time,
52 )
53 .await
54 }
55 #[cfg(feature = "ffrt")]
56 AsyncHandle::FfrtMultiThread => {
57 const MIN_SPLIT_LEN: usize = 1;
58 let split_time = get_cpu_num() as usize;
59 recur_ffrt(
60 &task_builder,
61 par_iter,
62 &consumer,
63 MIN_SPLIT_LEN,
64 split_time,
65 )
66 .await
67 }
68 }
69 }
70
71 #[cfg(not(feature = "ffrt"))]
recur<P, C>( runtime: &AsyncPoolSpawner, task_builder: &TaskBuilder, par_iter: P, consumer: &C, min_split_len: usize, mut split_time: usize, ) -> Result<C::Output, ScheduleError> where P: ParallelIterator + Send, C: Consumer<P> + Send + Sync,72 async fn recur<P, C>(
73 runtime: &AsyncPoolSpawner,
74 task_builder: &TaskBuilder,
75 par_iter: P,
76 consumer: &C,
77 min_split_len: usize,
78 mut split_time: usize,
79 ) -> Result<C::Output, ScheduleError>
80 where
81 P: ParallelIterator + Send,
82 C: Consumer<P> + Send + Sync,
83 {
84 if (par_iter.len() >> 1) <= min_split_len || split_time == 0 {
85 return Ok(consumer.consume(par_iter));
86 }
87 let (left, right) = par_iter.split();
88 let right = match right {
89 Some(a) => a,
90 None => {
91 return Ok(consumer.consume(left));
92 }
93 };
94 split_time >>= 1;
95 unsafe {
96 let left = spawn_task(
97 runtime,
98 task_builder,
99 left,
100 consumer,
101 min_split_len,
102 split_time,
103 );
104 let right = spawn_task(
105 runtime,
106 task_builder,
107 right,
108 consumer,
109 min_split_len,
110 split_time,
111 );
112 let left = left.await??;
113 let right = right.await??;
114 Ok(C::combine(left, right))
115 }
116 }
117
118 // Safety
119 // No restriction on lifetime to static, so it must be ensured that the data
120 // pointed to is always valid until the execution is completed, in other word
121 // .await the join handle after it is created.
122 #[cfg(not(feature = "ffrt"))]
123 #[inline]
spawn_task<P, C>( runtime: &AsyncPoolSpawner, task_builder: &TaskBuilder, par_iter: P, consumer: &C, min_split_len: usize, split_time: usize, ) -> JoinHandle<Result<C::Output, ScheduleError>> where P: ParallelIterator + Send, C: Consumer<P> + Send + Sync,124 unsafe fn spawn_task<P, C>(
125 runtime: &AsyncPoolSpawner,
126 task_builder: &TaskBuilder,
127 par_iter: P,
128 consumer: &C,
129 min_split_len: usize,
130 split_time: usize,
131 ) -> JoinHandle<Result<C::Output, ScheduleError>>
132 where
133 P: ParallelIterator + Send,
134 C: Consumer<P> + Send + Sync,
135 {
136 runtime.spawn_with_ref(
137 task_builder,
138 recur(
139 runtime,
140 task_builder,
141 par_iter,
142 consumer,
143 min_split_len,
144 split_time,
145 ),
146 )
147 }
148
/// Recursively splits `par_iter` and submits both halves as FFRT tasks.
///
/// `split_time` is halved at every level. The iterator is consumed on the current
/// task instead of being split when `split_time` reaches zero, when half of its
/// length is no greater than `min_split_len`, or when `split` yields no right half.
/// The two partial results are merged with `C::combine`.
///
/// # Errors
/// Returns a `ScheduleError` if joining either sub-task fails or if either
/// sub-task itself reports an error. When both fail, the left error is returned.
#[cfg(feature = "ffrt")]
async fn recur_ffrt<P, C>(
    task_builder: &TaskBuilder,
    par_iter: P,
    consumer: &C,
    min_split_len: usize,
    mut split_time: usize,
) -> Result<C::Output, ScheduleError>
where
    P: ParallelIterator + Send,
    C: Consumer<P> + Send + Sync,
{
    if (par_iter.len() >> 1) <= min_split_len || split_time == 0 {
        return Ok(consumer.consume(par_iter));
    }
    let (left, right) = par_iter.split();
    let right = match right {
        Some(a) => a,
        None => return Ok(consumer.consume(left)),
    };
    split_time >>= 1;

    // SAFETY: the submitted tasks borrow `task_builder` and `consumer` without a
    // `'static` bound. Both join handles are awaited below before this function
    // returns, on the error path as well, so the borrowed data stays valid for
    // as long as either task can run.
    // NOTE(review): this cannot protect against the enclosing future being
    // dropped while it is suspended at one of the awaits below.
    let (left, right) = unsafe {
        (
            spawn_task_ffrt(task_builder, left, consumer, min_split_len, split_time),
            spawn_task_ffrt(task_builder, right, consumer, min_split_len, split_time),
        )
    };

    // Await both handles before propagating any error. Bailing out on a failed
    // `left` would drop `right` while that task may still be using the borrows.
    let left = left.await;
    let right = right.await;
    Ok(C::combine(left??, right??))
}
180
/// Wraps `recur_ffrt` over `par_iter` in a raw task and submits it via `ffrt_submit`.
///
/// # Safety
/// The task's future borrows `task_builder` and `consumer` with no `'static`
/// bound. The caller must therefore keep them valid until the task has finished
/// executing. In other words, the returned join handle must be `.await`ed after
/// it is created, before the borrowed data goes out of scope.
#[cfg(feature = "ffrt")]
#[inline]
unsafe fn spawn_task_ffrt<P, C>(
    task_builder: &TaskBuilder,
    par_iter: P,
    consumer: &C,
    min_split_len: usize,
    split_time: usize,
) -> JoinHandle<Result<C::Output, ScheduleError>>
where
    P: ParallelIterator + Send,
    C: Consumer<P> + Send + Sync,
{
    let sub_job = recur_ffrt(task_builder, par_iter, consumer, min_split_len, split_time);
    // No scheduler instance is attached: an empty `Weak` never upgrades.
    let raw = Task::create_raw_task(
        task_builder,
        Weak::<PlaceholderScheduler>::new(),
        sub_job,
        VirtualTableType::Ffrt,
    );
    // The join handle is created from the raw task before the task is submitted.
    let handle = JoinHandle::new(raw);
    ffrt_submit(Task(raw), task_builder);
    handle
}
210