1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "threading/parallel_task_queue.h"
17
18 #include <algorithm>
19
20 #include <base/containers/array_view.h>
21 #include <base/containers/iterator.h>
22 #include <base/containers/refcnt_ptr.h>
23 #include <base/containers/type_traits.h>
24 #include <base/containers/unique_ptr.h>
25 #include <base/containers/unordered_map.h>
26 #include <base/containers/vector.h>
27 #include <core/log.h>
28 #include <core/namespace.h>
29 #include <core/threading/intf_thread_pool.h>
30
31 CORE_BEGIN_NAMESPACE()
32 using BASE_NS::array_view;
33 using BASE_NS::unordered_map;
34 using BASE_NS::vector;
35
36 // -- Parallel task queue.
ParallelTaskQueue(const IThreadPool::Ptr & threadPool)37 ParallelTaskQueue::ParallelTaskQueue(const IThreadPool::Ptr& threadPool) : TaskQueue(threadPool) {}
38
~ParallelTaskQueue()39 ParallelTaskQueue::~ParallelTaskQueue()
40 {
41 Wait();
42 }
43
Submit(uint64_t taskIdentifier,IThreadPool::ITask::Ptr && task)44 void ParallelTaskQueue::Submit(uint64_t taskIdentifier, IThreadPool::ITask::Ptr&& task)
45 {
46 CORE_ASSERT(std::find(tasks_.cbegin(), tasks_.cend(), taskIdentifier) == tasks_.cend());
47
48 tasks_.emplace_back(taskIdentifier, std::move(task));
49 }
50
SubmitAfter(uint64_t afterIdentifier,uint64_t taskIdentifier,IThreadPool::ITask::Ptr && task)51 void ParallelTaskQueue::SubmitAfter(uint64_t afterIdentifier, uint64_t taskIdentifier, IThreadPool::ITask::Ptr&& task)
52 {
53 CORE_ASSERT(std::find(tasks_.cbegin(), tasks_.cend(), taskIdentifier) == tasks_.cend());
54
55 auto it = std::find(tasks_.begin(), tasks_.end(), afterIdentifier);
56 if (it != tasks_.end()) {
57 Entry entry(taskIdentifier, std::move(task));
58 entry.dependencies.push_back(afterIdentifier);
59
60 tasks_.push_back(std::move(entry));
61 } else {
62 tasks_.emplace_back(taskIdentifier, std::move(task));
63 }
64 }
65
SubmitAfter(array_view<const uint64_t> afterIdentifiers,uint64_t taskIdentifier,IThreadPool::ITask::Ptr && task)66 void ParallelTaskQueue::SubmitAfter(
67 array_view<const uint64_t> afterIdentifiers, uint64_t taskIdentifier, IThreadPool::ITask::Ptr&& task)
68 {
69 if (std::all_of(
70 afterIdentifiers.cbegin(), afterIdentifiers.cend(), [&tasks = tasks_](const uint64_t afterIdentifier) {
71 return std::any_of(tasks.cbegin(), tasks.cend(),
72 [afterIdentifier](const TaskQueue::Entry& entry) { return entry.identifier == afterIdentifier; });
73 })) {
74 Entry entry(taskIdentifier, std::move(task));
75 entry.dependencies.insert(entry.dependencies.cend(), afterIdentifiers.begin(), afterIdentifiers.end());
76
77 tasks_.push_back(std::move(entry));
78 } else {
79 tasks_.emplace_back(taskIdentifier, std::move(task));
80 }
81 }
82
Remove(uint64_t taskIdentifier)83 void ParallelTaskQueue::Remove(uint64_t taskIdentifier)
84 {
85 auto it = std::find(tasks_.cbegin(), tasks_.cend(), taskIdentifier);
86 if (it != tasks_.cend()) {
87 tasks_.erase(it);
88 }
89 }
90
Clear()91 void ParallelTaskQueue::Clear()
92 {
93 Wait();
94 tasks_.clear();
95 }
96
Execute()97 void ParallelTaskQueue::Execute()
98 {
99 #if (CORE_VALIDATION_ENABLED == 1)
100 // NOTE: Check the integrity of the task queue (no circular deps etc.)
101 #endif
102 // gather dependencies for each task
103 vector<vector<const CORE_NS::IThreadPool::ITask*>> dependencies;
104 dependencies.reserve(tasks_.size());
105 for (auto& task : tasks_) {
106 auto& deps = dependencies.emplace_back();
107 for (const auto& dependency : task.dependencies) {
108 if (auto pos = std::find_if(tasks_.cbegin(), tasks_.cend(),
109 [dependency](const Entry &entry) { return entry.identifier == dependency; });
110 pos != tasks_.cend()) {
111 deps.push_back(pos->task.get());
112 }
113 }
114 }
115
116 // submit each task with its dependency information. threadpool will run a task when the dependencies are ready. now
117 // we have an IResult for every task, but we could use PushNowWait for tasks that are leafs.
118 vector<CORE_NS::IThreadPool::IResult::Ptr> states;
119 states.reserve(tasks_.size());
120 std::transform(std::begin(tasks_), std::end(tasks_), std::begin(dependencies), std::back_inserter(states),
121 [this](TaskQueue::Entry& entry, const vector<const CORE_NS::IThreadPool::ITask*>& dependencies) {
122 if (dependencies.empty()) {
123 return threadPool_->Push(BASE_NS::move(entry.task));
124 }
125 return threadPool_->Push(BASE_NS::move(entry.task), dependencies);
126 });
127 tasks_.clear();
128
129 // wait for tasks to complete.
130 for (const auto& state : states) {
131 state->Wait();
132 }
133 }
134 CORE_END_NAMESPACE()
135