1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #ifndef API_CORE_UTIL_PARALLEL_SORT_H
17 #define API_CORE_UTIL_PARALLEL_SORT_H
18
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iterator>
#include <type_traits>
#include <utility>

#include <core/threading/intf_thread_pool.h>
23
CORE_BEGIN_NAMESPACE()24 CORE_BEGIN_NAMESPACE()
25
26 // Helper class for running lambda as a ThreadPool task.
27 template<typename Fn>
28 class FunctionTask final : public IThreadPool::ITask {
29 public:
30 explicit FunctionTask(Fn&& func) : func_(BASE_NS::move(func)) {};
31
32 void operator()() override
33 {
34 func_();
35 }
36
37 protected:
38 void Destroy() override
39 {
40 delete this;
41 }
42
43 private:
44 Fn func_;
45 };
46
47 template<typename Fn>
CreateFunctionTask(Fn && func)48 inline IThreadPool::ITask::Ptr CreateFunctionTask(Fn&& func)
49 {
50 return IThreadPool::ITask::Ptr { new FunctionTask<Fn>(BASE_NS::move(func)) };
51 }
52
53 template<class RandomIt, class Compare>
ParallelSort(RandomIt first,RandomIt last,Compare comp,IThreadPool * threadPool)54 void ParallelSort(RandomIt first, RandomIt last, Compare comp, IThreadPool* threadPool)
55 {
56 const auto totalSize = std::distance(first, last);
57 if (totalSize <= 1) {
58 return;
59 }
60
61 const auto numThreads = std::max(threadPool->GetNumberOfThreads(), 1u);
62 const auto partSize = (totalSize + numThreads - 1) / numThreads;
63
64 BASE_NS::vector<std::pair<RandomIt, RandomIt>> partitions;
65
66 // partition create
67 RandomIt partStart = first;
68 for (size_t ii = 0; ii < numThreads && partStart != last; ii++) {
69 RandomIt partEnd = partStart;
70 if (std::distance(partStart, last) > partSize) {
71 std::advance(partEnd, partSize);
72 } else {
73 partEnd = last;
74 }
75
76 partitions.push_back({ partStart, partEnd });
77 partStart = partEnd;
78 }
79
80 // partition parallel sort
81 BASE_NS::vector<IThreadPool::IResult::Ptr> sortResults;
82
83 for (const auto& part : partitions) {
84 auto task = CreateFunctionTask([part, &comp]() { std::sort(part.first, part.second, comp); });
85 auto result = threadPool->Push(BASE_NS::move(task));
86 sortResults.push_back(BASE_NS::move(result));
87 }
88
89 // sort task completion wait
90 for (auto& result : sortResults) {
91 result->Wait();
92 }
93
94 // partition merge
95 while (partitions.size() > 1) {
96 BASE_NS::vector<IThreadPool::IResult::Ptr> mergeResults;
97 BASE_NS::vector<std::pair<RandomIt, RandomIt>> newPartitions;
98
99 for (size_t ii = 0; ii + 1 < partitions.size(); ii += 2) { // 2: step
100 const auto begin1 = partitions[ii].first;
101 const auto end1 = partitions[ii].second;
102 const auto begin2 = partitions[ii + 1].first;
103 const auto end2 = partitions[ii + 1].second;
104
105 auto task =
106 CreateFunctionTask([begin1, end1, end2, &comp]() { std::inplace_merge(begin1, end1, end2, comp); });
107 auto result = threadPool->Push(BASE_NS::move(task));
108 mergeResults.push_back(BASE_NS::move(result));
109
110 newPartitions.push_back({ begin1, end2 });
111 }
112
113 if (partitions.size() % 2 == 1) { // 2: step
114 newPartitions.push_back(partitions.back());
115 }
116
117 // merge task completion wait
118 for (auto& result : mergeResults) {
119 result->Wait();
120 }
121
122 partitions = BASE_NS::move(newPartitions);
123 }
124 }
125
126 template<class RandomIt>
ParallelSort(RandomIt first,RandomIt last,IThreadPool * threadPool)127 void ParallelSort(RandomIt first, RandomIt last, IThreadPool* threadPool)
128 {
129 using ValueType = typename std::iterator_traits<RandomIt>::value_type;
130 ParallelSort(first, last, std::less<ValueType>(), threadPool);
131 }
132
133 CORE_END_NAMESPACE()
134
135 #endif
136