• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "threading/parallel_task_queue.h"
17 
18 #include <condition_variable>
19 #include <mutex>
20 #include <queue>
21 #include <thread>
22 
23 #include <base/containers/unordered_map.h>
24 #include <base/containers/vector.h>
25 #include <core/log.h>
26 #include <core/namespace.h>
27 
28 #include "os/platform.h"
29 
30 CORE_BEGIN_NAMESPACE()
31 using BASE_NS::unordered_map;
32 using BASE_NS::vector;
33 
34 struct ParallelTaskQueue::TaskState {
35     unordered_map<uint64_t, bool> finished;
36     std::condition_variable cv;
37     std::mutex mutex;
38 };
39 
40 class ParallelTaskQueue::Task final : public IThreadPool::ITask {
41 public:
42     explicit Task(TaskState& state, IThreadPool::ITask& task, uint64_t id);
43     ~Task() = default;
44 
45     void operator()() override;
46 
47 protected:
48     void Destroy() override;
49 
50 private:
51     TaskState& state_;
52     IThreadPool::ITask& task_;
53     uint64_t id_;
54 };
55 
Task(TaskState & state,IThreadPool::ITask & task,uint64_t id)56 ParallelTaskQueue::Task::Task(TaskState& state, IThreadPool::ITask& task, uint64_t id)
57     : state_(state), task_(task), id_(id)
58 {}
59 
operator ()()60 void ParallelTaskQueue::Task::operator()()
61 {
62     // Run task.
63     task_();
64 
65     // Mark task as completed.
66     std::unique_lock lock(state_.mutex);
67     state_.finished[id_] = true;
68 
69     // Notify that there is completed task.
70     state_.cv.notify_one();
71 }
72 
Destroy()73 void ParallelTaskQueue::Task::Destroy()
74 {
75     delete this;
76 }
77 
78 // -- Parallel task queue.
ParallelTaskQueue(const IThreadPool::Ptr & threadPool)79 ParallelTaskQueue::ParallelTaskQueue(const IThreadPool::Ptr& threadPool) : TaskQueue(threadPool) {}
80 
~ParallelTaskQueue()81 ParallelTaskQueue::~ParallelTaskQueue()
82 {
83     Wait();
84 }
85 
Submit(uint64_t taskIdentifier,IThreadPool::ITask::Ptr && task)86 void ParallelTaskQueue::Submit(uint64_t taskIdentifier, IThreadPool::ITask::Ptr&& task)
87 {
88     CORE_ASSERT(std::find(tasks_.begin(), tasks_.end(), taskIdentifier) == tasks_.end());
89 
90     tasks_.emplace_back(taskIdentifier, std::move(task));
91 }
92 
SubmitAfter(uint64_t afterIdentifier,uint64_t taskIdentifier,IThreadPool::ITask::Ptr && task)93 void ParallelTaskQueue::SubmitAfter(uint64_t afterIdentifier, uint64_t taskIdentifier, IThreadPool::ITask::Ptr&& task)
94 {
95     CORE_ASSERT(std::find(tasks_.begin(), tasks_.end(), taskIdentifier) == tasks_.end());
96 
97     auto it = std::find(tasks_.begin(), tasks_.end(), afterIdentifier);
98     if (it != tasks_.end()) {
99         Entry entry(taskIdentifier, std::move(task));
100         entry.dependencies.push_back(afterIdentifier);
101 
102         tasks_.push_back(std::move(entry));
103     } else {
104         tasks_.emplace_back(taskIdentifier, std::move(task));
105     }
106 }
107 
Remove(uint64_t taskIdentifier)108 void ParallelTaskQueue::Remove(uint64_t taskIdentifier)
109 {
110     auto it = std::find(tasks_.begin(), tasks_.end(), taskIdentifier);
111     if (it != tasks_.end()) {
112         tasks_.erase(it);
113     }
114 }
115 
Clear()116 void ParallelTaskQueue::Clear()
117 {
118     Wait();
119     tasks_.clear();
120 }
121 
QueueTasks(vector<size_t> & waiting,TaskState & state)122 void ParallelTaskQueue::QueueTasks(vector<size_t>& waiting, TaskState& state)
123 {
124     if (waiting.empty()) {
125         // No more tasks to proecss.
126         return;
127     }
128 
129     for (vector<size_t>::iterator it = waiting.begin(); it != waiting.end();) {
130         // Entry to handle.
131         Entry& entry = tasks_[*it];
132 
133         // Can run this task?
134         bool canRun = true;
135         for (const auto& dep : entry.dependencies) {
136             if (!state.finished.contains(dep)) {
137                 // Task that is marked as dependency is not executed yet.
138                 canRun = false;
139                 break;
140             }
141         }
142 
143         if (canRun) {
144             // This task can be executed.
145             // Remove task from waiting list.
146             it = waiting.erase(it);
147 
148             // Push to execution queue.
149             threadPool_->PushNoWait(IThreadPool::ITask::Ptr { new Task(state, *entry.task, entry.identifier) });
150         } else {
151             ++it;
152         }
153     }
154 }
155 
Execute()156 void ParallelTaskQueue::Execute()
157 {
158 #if (CORE_VALIDATION_ENABLED == 1)
159     // NOTE: Check the integrity of the task queue (no circular deps etc.)
160 #endif
161     vector<size_t> waiting;
162     waiting.resize(tasks_.size());
163     for (size_t i = 0; i < tasks_.size(); ++i) {
164         waiting[i] = i;
165     }
166 
167     TaskState state;
168     state.finished.reserve(tasks_.size());
169 
170     {
171         // Keep on pushing tasks to queue until all done.
172         std::unique_lock lock(state.mutex);
173         state.cv.wait(lock, [this, &waiting, &state]() {
174             // Push new tasks to queue.
175             QueueTasks(waiting, state);
176             return state.finished.size() == tasks_.size();
177         });
178     }
179 }
180 CORE_END_NAMESPACE()
181