1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "threading/task_queue_factory.h"
17
#include <algorithm>
#include <condition_variable>
#include <cstddef>
#include <deque>
#include <functional>
#include <mutex>
#include <thread>
23
24 #include <base/containers/array_view.h>
25 #include <base/containers/atomics.h>
26 #include <base/containers/iterator.h>
27 #include <base/containers/shared_ptr.h>
28 #include <base/containers/type_traits.h>
29 #include <base/containers/unique_ptr.h>
30 #include <base/math/mathf.h>
31 #include <base/util/uid.h>
32 #include <core/log.h>
33 #include <core/perf/cpu_perf_scope.h>
34 #include <core/threading/intf_thread_pool.h>
35
36 #include "threading/dispatcher_impl.h"
37 #include "threading/parallel_impl.h"
38 #include "threading/sequential_impl.h"
39
40 #ifdef PLATFORM_HAS_JAVA
41 #include <os/java/java_internal.h>
42 #endif
43
44 CORE_BEGIN_NAMESPACE()
45 using BASE_NS::array_view;
46 using BASE_NS::make_unique;
47 using BASE_NS::move;
48 using BASE_NS::unique_ptr;
49 using BASE_NS::Math::max;
50
51 namespace {
#ifdef PLATFORM_HAS_JAVA
/**
 * RAII helper that attaches the constructing thread to the Java VM and detaches it again on destruction.
 * Detach is only performed when the attach call in the constructor reported success.
 * Non-copyable and non-movable: a copy would detach the thread a second time.
 */
class JavaThreadContext final {
public:
    JavaThreadContext() : javaVm_(java_internal::GetJavaVM())
    {
        if (!javaVm_) {
            CORE_LOG_W("JavaThreadContext: no Java VM available, thread not attached");
            return;
        }

        JNIEnv* env = nullptr;

#ifndef NDEBUG
        // Check that the thread was not already attached.
        // It's not really a problem as another attach is a no-op, but we will be detaching the
        // thread later and it may be unexpected for the user.
        jint result = javaVm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6);
        CORE_ASSERT_MSG((result != JNI_OK), "Thread already attached");
#endif

        // Remember whether the attach succeeded so the destructor only detaches after a successful attach.
        attached_ = (javaVm_->AttachCurrentThread(&env, nullptr) == JNI_OK);
        if (!attached_) {
            CORE_LOG_W("JavaThreadContext: failed to attach thread to Java VM");
        }
    }

    ~JavaThreadContext()
    {
        if (attached_) {
            javaVm_->DetachCurrentThread();
        }
    }

    JavaThreadContext(const JavaThreadContext&) = delete;
    JavaThreadContext(JavaThreadContext&&) = delete;
    JavaThreadContext& operator=(const JavaThreadContext&) = delete;
    JavaThreadContext& operator=(JavaThreadContext&&) = delete;

private:
    JavaVM* javaVm_ { nullptr };
    bool attached_ { false };
};
#endif // PLATFORM_HAS_JAVA
79
80 // -- TaskResult, returned by ThreadPool::Push and can be waited on.
81 class TaskResult final : public IThreadPool::IResult {
82 public:
83 // Task state which can be waited and marked as done.
84 class State {
85 public:
Done()86 void Done()
87 {
88 {
89 auto lock = std::lock_guard(mutex_);
90 done_ = true;
91 }
92 cv_.notify_all();
93 }
94
Wait()95 void Wait()
96 {
97 auto lock = std::unique_lock(mutex_);
98 cv_.wait(lock, [this]() { return done_; });
99 }
100
IsDone() const101 bool IsDone() const
102 {
103 auto lock = std::lock_guard(mutex_);
104 return done_;
105 }
106
107 private:
108 mutable std::mutex mutex_;
109 std::condition_variable cv_;
110 bool done_ { false };
111 };
112
TaskResult(BASE_NS::shared_ptr<State> && future)113 explicit TaskResult(BASE_NS::shared_ptr<State>&& future) : future_(BASE_NS::move(future)) {}
114
Wait()115 void Wait() final
116 {
117 if (future_) {
118 future_->Wait();
119 }
120 }
IsDone() const121 bool IsDone() const final
122 {
123 if (future_) {
124 return future_->IsDone();
125 }
126 return true;
127 }
128
129 protected:
Destroy()130 void Destroy() final
131 {
132 delete this;
133 }
134
135 private:
136 BASE_NS::shared_ptr<State> future_;
137 };
138
139 // -- ThreadPool
140 class ThreadPool final : public IThreadPool {
141 public:
ThreadPool(size_t threadCount)142 explicit ThreadPool(size_t threadCount)
143 : threadCount_(max(size_t(1), threadCount)), threads_(make_unique<ThreadContext[]>(threadCount_))
144 {
145 CORE_ASSERT(threads_);
146
147 if (threadCount == 0U) {
148 CORE_LOG_W("Threadpool minimum thread count is 1");
149 }
150 // Create thread containers.
151 auto threads = array_view<ThreadContext>(threads_.get(), threadCount_);
152 for (ThreadContext& context : threads) {
153 // Set-up thread function.
154 context.thread = std::thread(&ThreadPool::ThreadProc, this, std::ref(context));
155 }
156 }
157
158 ThreadPool(const ThreadPool&) = delete;
159 ThreadPool(ThreadPool&&) = delete;
160 ThreadPool& operator=(const ThreadPool&) = delete;
161 ThreadPool& operator=(ThreadPool&&) = delete;
162
Push(ITask::Ptr task)163 IResult::Ptr Push(ITask::Ptr task) override
164 {
165 auto taskState = BASE_NS::make_shared<TaskResult::State>();
166 if (taskState) {
167 if (task) {
168 std::lock_guard lock(mutex_);
169 q_.push_back(BASE_NS::make_shared<Task>(BASE_NS::move(task), taskState));
170 cv_.notify_one();
171 } else {
172 // mark as done if the there was no function.
173 taskState->Done();
174 }
175 }
176 return IResult::Ptr { new TaskResult(BASE_NS::move(taskState)) };
177 }
178
Push(ITask::Ptr task,BASE_NS::array_view<const ITask * const> dependencies)179 IResult::Ptr Push(ITask::Ptr task, BASE_NS::array_view<const ITask* const> dependencies) override
180 {
181 if (dependencies.empty()) {
182 return Push(BASE_NS::move(task));
183 }
184 auto taskState = BASE_NS::make_shared<TaskResult::State>();
185 if (taskState) {
186 if (task) {
187 BASE_NS::vector<BASE_NS::weak_ptr<Task>> deps;
188 deps.reserve(dependencies.size());
189 {
190 std::lock_guard lock(mutex_);
191 for (const ITask* dep : dependencies) {
192 if (auto pos = std::find_if(
193 q_.cbegin(), q_.cend(),
194 [dep](const BASE_NS::shared_ptr<Task> &task) { return task && (*task == dep); });
195 pos != q_.cend()) {
196 deps.push_back(*pos);
197 }
198 }
199 q_.push_back(BASE_NS::make_shared<Task>(BASE_NS::move(task), taskState, BASE_NS::move(deps)));
200 cv_.notify_one();
201 }
202 } else {
203 // mark as done if the there was no function.
204 taskState->Done();
205 }
206 }
207 return IResult::Ptr { new TaskResult(BASE_NS::move(taskState)) };
208 }
209
PushNoWait(ITask::Ptr task)210 void PushNoWait(ITask::Ptr task) override
211 {
212 if (task) {
213 std::lock_guard lock(mutex_);
214 q_.push_back(BASE_NS::make_shared<Task>(BASE_NS::move(task)));
215 cv_.notify_one();
216 }
217 }
218
PushNoWait(ITask::Ptr task,BASE_NS::array_view<const ITask * const> dependencies)219 void PushNoWait(ITask::Ptr task, BASE_NS::array_view<const ITask* const> dependencies) override
220 {
221 if (dependencies.empty()) {
222 return PushNoWait(BASE_NS::move(task));
223 }
224
225 if (task) {
226 BASE_NS::vector<BASE_NS::weak_ptr<Task>> deps;
227 deps.reserve(dependencies.size());
228 {
229 std::lock_guard lock(mutex_);
230 for (const ITask* dep : dependencies) {
231 if (auto pos = std::find_if(
232 q_.cbegin(), q_.cend(),
233 [dep](const BASE_NS::shared_ptr<Task> &task) { return task && (*task == dep); });
234 pos != q_.cend()) {
235 deps.push_back(*pos);
236 }
237 }
238 q_.push_back(BASE_NS::make_shared<Task>(BASE_NS::move(task), BASE_NS::move(deps)));
239 cv_.notify_one();
240 }
241 }
242 }
243
GetNumberOfThreads() const244 uint32_t GetNumberOfThreads() const override
245 {
246 return static_cast<uint32_t>(threadCount_);
247 }
248
249 // IInterface
GetInterface(const BASE_NS::Uid & uid) const250 const IInterface* GetInterface(const BASE_NS::Uid& uid) const override
251 {
252 if ((uid == IThreadPool::UID) || (uid == IInterface::UID)) {
253 return this;
254 }
255 return nullptr;
256 }
257
GetInterface(const BASE_NS::Uid & uid)258 IInterface* GetInterface(const BASE_NS::Uid& uid) override
259 {
260 if ((uid == IThreadPool::UID) || (uid == IInterface::UID)) {
261 return this;
262 }
263 return nullptr;
264 }
265
Ref()266 void Ref() override
267 {
268 BASE_NS::AtomicIncrementRelaxed(&refcnt_);
269 }
270
Unref()271 void Unref() override
272 {
273 if (BASE_NS::AtomicDecrementRelease(&refcnt_) == 0) {
274 BASE_NS::AtomicFenceAcquire();
275 delete this;
276 }
277 }
278
279 protected:
~ThreadPool()280 ~ThreadPool() final
281 {
282 Stop();
283 }
284
285 private:
286 struct ThreadContext {
287 std::thread thread;
288 };
289
290 // Helper which holds a pointer to a queued task function and the result state.
291 struct Task {
292 ITask::Ptr function_;
293 BASE_NS::shared_ptr<TaskResult::State> state_;
294 BASE_NS::vector<BASE_NS::weak_ptr<Task>> dependencies_;
295 bool running_ { false };
296
297 ~Task() = default;
298 Task() = default;
299
Task__anonc0e55d560111::ThreadPool::Task300 Task(ITask::Ptr&& function, BASE_NS::shared_ptr<TaskResult::State> state,
301 BASE_NS::vector<BASE_NS::weak_ptr<Task>>&& dependencies)
302 : function_(BASE_NS::move(function)),
303 state_(BASE_NS::move(state)), dependencies_ { BASE_NS::move(dependencies) }
304 {
305 CORE_ASSERT(this->function_ && this->state_);
306 }
307
Task__anonc0e55d560111::ThreadPool::Task308 Task(ITask::Ptr&& function, BASE_NS::shared_ptr<TaskResult::State> state)
309 : function_(BASE_NS::move(function)), state_(BASE_NS::move(state))
310 {
311 CORE_ASSERT(this->function_ && this->state_);
312 }
313
Task__anonc0e55d560111::ThreadPool::Task314 explicit Task(ITask::Ptr&& function) : function_(BASE_NS::move(function))
315 {
316 CORE_ASSERT(this->function_);
317 }
318
Task__anonc0e55d560111::ThreadPool::Task319 Task(ITask::Ptr&& function, BASE_NS::vector<BASE_NS::weak_ptr<Task>>&& dependencies)
320 : function_(BASE_NS::move(function)), dependencies_ { BASE_NS::move(dependencies) }
321 {
322 CORE_ASSERT(this->function_);
323 }
324
325 Task(Task&&) = default;
326 Task& operator=(Task&&) = default;
327 Task(const Task&) = delete;
328 Task& operator=(const Task&) = delete;
329
operator ()__anonc0e55d560111::ThreadPool::Task330 inline void operator()() const
331 {
332 (*function_)();
333 if (state_) {
334 state_->Done();
335 }
336 }
337
operator ==__anonc0e55d560111::ThreadPool::Task338 inline bool operator==(const ITask* task) const
339 {
340 return function_.get() == task;
341 }
342
343 // Task can run if it's not already running and there are no dependencies, or all the dependencies are ready.
CanRun__anonc0e55d560111::ThreadPool::Task344 inline bool CanRun() const
345 {
346 return !running_ &&
347 (dependencies_.empty() ||
348 std::all_of(std::begin(dependencies_), std::end(dependencies_),
349 [](const BASE_NS::weak_ptr<Task>& dependency) { return dependency.expired(); }));
350 }
351 };
352
353 // Looks for a task that can be executed.
FindRunnable()354 BASE_NS::shared_ptr<Task> FindRunnable()
355 {
356 if (q_.empty()) {
357 return {};
358 }
359 for (auto& task : q_) {
360 if (task && task->CanRun()) {
361 task->running_ = true;
362 return task;
363 }
364 }
365 return {};
366 }
367
Clear()368 void Clear()
369 {
370 std::lock_guard lock(mutex_);
371 q_.clear();
372 }
373
374 // At the moment Stop is called only from the destructor with waitForAllTasksToComplete=true.
375 // If this doesn't change the class can be simplified a bit.
Stop()376 void Stop()
377 {
378 // Wait all tasks to complete before returning.
379 if (isDone_) {
380 return;
381 }
382 {
383 std::lock_guard lock(mutex_);
384 isDone_ = true;
385 }
386
387 // Trigger all waiting threads.
388 cv_.notify_all();
389
390 // Wait for all threads to finish.
391 auto threads = array_view(threads_.get(), threadCount_);
392 for (auto& context : threads) {
393 if (context.thread.joinable()) {
394 context.thread.join();
395 }
396 }
397
398 Clear();
399 }
400
ThreadProc(ThreadContext & context)401 void ThreadProc(ThreadContext& context)
402 {
403 #ifdef PLATFORM_HAS_JAVA
404 // RAII class for handling thread setup/release.
405 JavaThreadContext javaContext;
406 #endif
407
408 while (true) {
409 // Function to process.
410 BASE_NS::shared_ptr<Task> task;
411 {
412 std::unique_lock lock(mutex_);
413
414 // Try to wait for next task to process.
415 cv_.wait(lock, [this, &task]() {
416 task = FindRunnable();
417 return task || isDone_;
418 });
419 }
420 // If there was no task it means we are stopping and thread can exit.
421 if (!task) {
422 return;
423 }
424
425 while (task) {
426 // Run task function.
427 {
428 CORE_CPU_PERF_SCOPE("CORE", "ThreadPoolTask", "", CORE_PROFILER_DEFAULT_COLOR);
429 (*task)();
430 }
431
432 std::lock_guard lock(mutex_);
433 // After running the task remove it from the queue. Any dependent tasks will see their weak_ptr expire
434 // idicating that the dependency has been completed.
435 if (auto pos = std::find_if(
436 q_.cbegin(), q_.cend(),
437 [&task](const BASE_NS::shared_ptr<Task> &queuedTask) { return queuedTask == task; });
438 pos != q_.cend()) {
439 q_.erase(pos);
440 }
441 task.reset();
442
443 // Get next function.
444 if (auto pos =
445 std::find_if(std::begin(q_), std::end(q_),
446 [](const BASE_NS::shared_ptr<Task> &task) { return (task) && (task->CanRun()); });
447 pos != std::end(q_)) {
448 task = *pos;
449 task->running_ = true;
450 // Check if there are more runnable tasks and notify workers as needed.
451 auto runnable = std::min(static_cast<ptrdiff_t>(threadCount_),
452 std::count_if(pos + 1, std::end(q_),
453 [](const BASE_NS::shared_ptr<Task>& task) { return (task) && (task->CanRun()); }));
454 while (runnable--) {
455 cv_.notify_one();
456 }
457 }
458 }
459 }
460 }
461
462 size_t threadCount_ { 0 };
463 unique_ptr<ThreadContext[]> threads_;
464
465 std::deque<BASE_NS::shared_ptr<Task>> q_;
466
467 bool isDone_ { false };
468
469 std::mutex mutex_;
470 std::condition_variable cv_;
471 int32_t refcnt_ { 0 };
472 };
473 } // namespace
474
GetNumberOfCores() const475 uint32_t TaskQueueFactory::GetNumberOfCores() const
476 {
477 uint32_t result = std::thread::hardware_concurrency();
478 if (result == 0) {
479 // If not detectable, default to 4.
480 result = 4;
481 }
482
483 return result;
484 }
485
CreateThreadPool(const uint32_t threadCount) const486 IThreadPool::Ptr TaskQueueFactory::CreateThreadPool(const uint32_t threadCount) const
487 {
488 return IThreadPool::Ptr { new ThreadPool(threadCount) };
489 }
490
CreateDispatcherTaskQueue(const IThreadPool::Ptr & threadPool) const491 IDispatcherTaskQueue::Ptr TaskQueueFactory::CreateDispatcherTaskQueue(const IThreadPool::Ptr& threadPool) const
492 {
493 return IDispatcherTaskQueue::Ptr { make_unique<DispatcherImpl>(threadPool).release() };
494 }
495
CreateParallelTaskQueue(const IThreadPool::Ptr & threadPool) const496 IParallelTaskQueue::Ptr TaskQueueFactory::CreateParallelTaskQueue(const IThreadPool::Ptr& threadPool) const
497 {
498 return IParallelTaskQueue::Ptr { make_unique<ParallelImpl>(threadPool).release() };
499 }
500
CreateSequentialTaskQueue(const IThreadPool::Ptr & threadPool) const501 ISequentialTaskQueue::Ptr TaskQueueFactory::CreateSequentialTaskQueue(const IThreadPool::Ptr& threadPool) const
502 {
503 return ISequentialTaskQueue::Ptr { make_unique<SequentialImpl>(threadPool).release() };
504 }
505
506 // IInterface
GetInterface(const BASE_NS::Uid & uid) const507 const IInterface* TaskQueueFactory::GetInterface(const BASE_NS::Uid& uid) const
508 {
509 if (uid == ITaskQueueFactory::UID) {
510 return static_cast<const ITaskQueueFactory*>(this);
511 }
512 return nullptr;
513 }
514
GetInterface(const BASE_NS::Uid & uid)515 IInterface* TaskQueueFactory::GetInterface(const BASE_NS::Uid& uid)
516 {
517 if (uid == ITaskQueueFactory::UID) {
518 return static_cast<ITaskQueueFactory*>(this);
519 }
520 return nullptr;
521 }
522
Ref()523 void TaskQueueFactory::Ref() {}
524
Unref()525 void TaskQueueFactory::Unref() {}
526 CORE_END_NAMESPACE()
527