1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "threading/task_queue_factory.h"
17
#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <deque>
#include <functional>
#include <memory>
#include <mutex>
#include <thread>

#include <base/containers/array_view.h>
#include <base/containers/iterator.h>
#include <base/containers/type_traits.h>
#include <base/containers/unique_ptr.h>
#include <base/containers/vector.h>
#include <base/math/mathf.h>
#include <base/util/uid.h>
#include <core/log.h>
#include <core/perf/cpu_perf_scope.h>
#include <core/threading/intf_thread_pool.h>
35
36 #include "threading/dispatcher_impl.h"
37 #include "threading/parallel_impl.h"
38 #include "threading/sequential_impl.h"
39
40 #ifdef PLATFORM_HAS_JAVA
41 #include <os/java/java_internal.h>
42 #endif
43
44 CORE_BEGIN_NAMESPACE()
45 using BASE_NS::array_view;
46 using BASE_NS::make_unique;
47 using BASE_NS::move;
48 using BASE_NS::unique_ptr;
49 using BASE_NS::Math::max;
50
51 namespace {
#ifdef PLATFORM_HAS_JAVA
/**
 * RAII helper that attaches the current thread to the Java VM for the lifetime of the object.
 * The thread is detached again in the destructor, but only if the attach call succeeded.
 * Non-copyable and non-movable so that the detach can happen at most once.
 */
class JavaThreadContext final {
public:
    JavaThreadContext()
    {
        javaVm_ = java_internal::GetJavaVM();
        if (javaVm_ == nullptr) {
            // Without a VM there is nothing to attach to; the destructor will not try to detach either.
            CORE_LOG_E("No Java VM available, thread not attached");
            return;
        }

        JNIEnv* env = nullptr;
#ifndef NDEBUG
        // Check that the thread was not already attached.
        // It's not really a problem as another attach is a no-op, but we will be detaching the
        // thread later and it may be unexpected for the user.
        const jint result = javaVm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6);
        CORE_ASSERT_MSG((result != JNI_OK), "Thread already attached");
#endif

        // Remember whether attaching worked so the destructor only detaches a thread we actually attached.
        attached_ = (javaVm_->AttachCurrentThread(&env, nullptr) == JNI_OK);
    }

    ~JavaThreadContext()
    {
        if (attached_) {
            javaVm_->DetachCurrentThread();
        }
    }

    JavaThreadContext(const JavaThreadContext&) = delete;
    JavaThreadContext& operator=(const JavaThreadContext&) = delete;
    JavaThreadContext(JavaThreadContext&&) = delete;
    JavaThreadContext& operator=(JavaThreadContext&&) = delete;

private:
    JavaVM* javaVm_ { nullptr };
    bool attached_ { false };
};
#endif // PLATFORM_HAS_JAVA
79
80 // -- TaskResult, returned by ThreadPool::Push and can be waited on.
81 class TaskResult final : public IThreadPool::IResult {
82 public:
83 // Task state which can be waited and marked as done.
84 class State {
85 public:
Done()86 void Done()
87 {
88 {
89 auto lock = std::lock_guard(mutex_);
90 done_ = true;
91 }
92 cv_.notify_all();
93 }
94
Wait()95 void Wait()
96 {
97 auto lock = std::unique_lock(mutex_);
98 cv_.wait(lock, [this]() { return done_; });
99 }
100
IsDone() const101 bool IsDone() const
102 {
103 auto lock = std::lock_guard(mutex_);
104 return done_;
105 }
106
107 private:
108 mutable std::mutex mutex_;
109 std::condition_variable cv_;
110 bool done_ { false };
111 };
112
TaskResult(std::shared_ptr<State> && future)113 explicit TaskResult(std::shared_ptr<State>&& future) : future_(BASE_NS::move(future)) {}
114
Wait()115 void Wait() final
116 {
117 if (future_) {
118 future_->Wait();
119 }
120 }
IsDone() const121 bool IsDone() const final
122 {
123 if (future_) {
124 return future_->IsDone();
125 }
126 return true;
127 }
128
129 protected:
Destroy()130 void Destroy() final
131 {
132 delete this;
133 }
134
135 private:
136 std::shared_ptr<State> future_;
137 };
138
139 // -- ThreadPool
140 class ThreadPool final : public IThreadPool {
141 public:
ThreadPool(size_t threadCount)142 explicit ThreadPool(size_t threadCount)
143 : threadCount_(max(size_t(1), threadCount)), threads_(make_unique<ThreadContext[]>(threadCount_))
144 {
145 CORE_ASSERT(threads_);
146
147 if (threadCount == 0U) {
148 CORE_LOG_W("Threadpool minimum thread count is 1");
149 }
150 // Create thread containers.
151 auto threads = array_view<ThreadContext>(threads_.get(), threadCount_);
152 for (ThreadContext& context : threads) {
153 // Set-up thread function.
154 context.thread = std::thread(&ThreadPool::ThreadProc, this, std::ref(context));
155 }
156 }
157
158 ThreadPool(const ThreadPool&) = delete;
159 ThreadPool(ThreadPool&&) = delete;
160 ThreadPool& operator=(const ThreadPool&) = delete;
161 ThreadPool& operator=(ThreadPool&&) = delete;
162
Push(ITask::Ptr task)163 IResult::Ptr Push(ITask::Ptr task) override
164 {
165 auto taskState = std::make_shared<TaskResult::State>();
166 if (taskState) {
167 if (task) {
168 std::lock_guard lock(mutex_);
169 q_.push_back(std::make_shared<Task>(BASE_NS::move(task), taskState));
170 cv_.notify_one();
171 } else {
172 // mark as done if the there was no function.
173 taskState->Done();
174 }
175 }
176 return IResult::Ptr { new TaskResult(BASE_NS::move(taskState)) };
177 }
178
Push(ITask::Ptr task,BASE_NS::array_view<const ITask * const> dependencies)179 IResult::Ptr Push(ITask::Ptr task, BASE_NS::array_view<const ITask* const> dependencies) override
180 {
181 if (dependencies.empty()) {
182 return Push(BASE_NS::move(task));
183 }
184 auto taskState = std::make_shared<TaskResult::State>();
185 if (taskState) {
186 if (task) {
187 BASE_NS::vector<std::weak_ptr<Task>> deps;
188 deps.reserve(dependencies.size());
189 {
190 std::lock_guard lock(mutex_);
191 for (const ITask* dep : dependencies) {
192 if (auto pos = std::find_if(q_.cbegin(), q_.cend(),
193 [dep](const std::shared_ptr<Task>& task) { return task && (*task == dep); });
194 pos != q_.cend()) {
195 deps.push_back(*pos);
196 }
197 }
198 q_.push_back(std::make_shared<Task>(BASE_NS::move(task), taskState, BASE_NS::move(deps)));
199 cv_.notify_one();
200 }
201 } else {
202 // mark as done if the there was no function.
203 taskState->Done();
204 }
205 }
206 return IResult::Ptr { new TaskResult(BASE_NS::move(taskState)) };
207 }
208
PushNoWait(ITask::Ptr task)209 void PushNoWait(ITask::Ptr task) override
210 {
211 if (task) {
212 std::lock_guard lock(mutex_);
213 q_.push_back(std::make_shared<Task>(BASE_NS::move(task)));
214 cv_.notify_one();
215 }
216 }
217
PushNoWait(ITask::Ptr task,BASE_NS::array_view<const ITask * const> dependencies)218 void PushNoWait(ITask::Ptr task, BASE_NS::array_view<const ITask* const> dependencies) override
219 {
220 if (dependencies.empty()) {
221 return PushNoWait(BASE_NS::move(task));
222 }
223
224 if (task) {
225 BASE_NS::vector<std::weak_ptr<Task>> deps;
226 deps.reserve(dependencies.size());
227 {
228 std::lock_guard lock(mutex_);
229 for (const ITask* dep : dependencies) {
230 if (auto pos = std::find_if(q_.cbegin(), q_.cend(),
231 [dep](const std::shared_ptr<Task>& task) { return task && (*task == dep); });
232 pos != q_.cend()) {
233 deps.push_back(*pos);
234 }
235 }
236 q_.push_back(std::make_shared<Task>(BASE_NS::move(task), BASE_NS::move(deps)));
237 cv_.notify_one();
238 }
239 }
240 }
241
GetNumberOfThreads() const242 uint32_t GetNumberOfThreads() const override
243 {
244 return static_cast<uint32_t>(threadCount_);
245 }
246
247 // IInterface
GetInterface(const BASE_NS::Uid & uid) const248 const IInterface* GetInterface(const BASE_NS::Uid& uid) const override
249 {
250 if ((uid == IThreadPool::UID) || (uid == IInterface::UID)) {
251 return this;
252 }
253 return nullptr;
254 }
255
GetInterface(const BASE_NS::Uid & uid)256 IInterface* GetInterface(const BASE_NS::Uid& uid) override
257 {
258 if ((uid == IThreadPool::UID) || (uid == IInterface::UID)) {
259 return this;
260 }
261 return nullptr;
262 }
263
Ref()264 void Ref() override
265 {
266 refcnt_.fetch_add(1, std::memory_order_relaxed);
267 }
268
Unref()269 void Unref() override
270 {
271 if (std::atomic_fetch_sub_explicit(&refcnt_, 1, std::memory_order_release) == 1) {
272 std::atomic_thread_fence(std::memory_order_acquire);
273 delete this;
274 }
275 }
276
277 protected:
~ThreadPool()278 ~ThreadPool() final
279 {
280 Stop();
281 }
282
283 private:
284 struct ThreadContext {
285 std::thread thread;
286 };
287
288 // Helper which holds a pointer to a queued task function and the result state.
289 struct Task {
290 ITask::Ptr function_;
291 std::shared_ptr<TaskResult::State> state_;
292 BASE_NS::vector<std::weak_ptr<Task>> dependencies_;
293 bool running_ { false };
294
295 ~Task() = default;
296 Task() = default;
297
Task__anon0bea27f40111::ThreadPool::Task298 Task(ITask::Ptr&& function, std::shared_ptr<TaskResult::State> state,
299 BASE_NS::vector<std::weak_ptr<Task>>&& dependencies)
300 : function_(BASE_NS::move(function)),
301 state_(BASE_NS::move(state)), dependencies_ { BASE_NS::move(dependencies) }
302 {
303 CORE_ASSERT(this->function_ && this->state_);
304 }
305
Task__anon0bea27f40111::ThreadPool::Task306 Task(ITask::Ptr&& function, std::shared_ptr<TaskResult::State> state)
307 : function_(BASE_NS::move(function)), state_(BASE_NS::move(state))
308 {
309 CORE_ASSERT(this->function_ && this->state_);
310 }
311
Task__anon0bea27f40111::ThreadPool::Task312 explicit Task(ITask::Ptr&& function) : function_(BASE_NS::move(function))
313 {
314 CORE_ASSERT(this->function_);
315 }
316
Task__anon0bea27f40111::ThreadPool::Task317 Task(ITask::Ptr&& function, BASE_NS::vector<std::weak_ptr<Task>>&& dependencies)
318 : function_(BASE_NS::move(function)), dependencies_ { BASE_NS::move(dependencies) }
319 {
320 CORE_ASSERT(this->function_);
321 }
322
323 Task(Task&&) = default;
324 Task& operator=(Task&&) = default;
325 Task(const Task&) = delete;
326 Task& operator=(const Task&) = delete;
327
operator ()__anon0bea27f40111::ThreadPool::Task328 inline void operator()() const
329 {
330 (*function_)();
331 if (state_) {
332 state_->Done();
333 }
334 }
335
operator ==__anon0bea27f40111::ThreadPool::Task336 inline bool operator==(const ITask* task) const
337 {
338 return function_.get() == task;
339 }
340
341 // Task can run if it's not already running and there are no dependencies, or all the dependencies are ready.
CanRun__anon0bea27f40111::ThreadPool::Task342 inline bool CanRun() const
343 {
344 return !running_ && (dependencies_.empty() ||
345 std::all_of(std::begin(dependencies_), std::end(dependencies_),
346 [](const std::weak_ptr<Task>& dependency) { return dependency.expired(); }));
347 }
348 };
349
350 // Looks for a task that can be executed.
FindRunnable()351 std::shared_ptr<Task> FindRunnable()
352 {
353 if (q_.empty()) {
354 return {};
355 }
356 for (auto& task : q_) {
357 if (task && task->CanRun()) {
358 task->running_ = true;
359 return task;
360 }
361 }
362 return {};
363 }
364
Clear()365 void Clear()
366 {
367 std::lock_guard lock(mutex_);
368 q_.clear();
369 }
370
371 // At the moment Stop is called only from the destructor with waitForAllTasksToComplete=true.
372 // If this doesn't change the class can be simplified a bit.
Stop()373 void Stop()
374 {
375 // Wait all tasks to complete before returning.
376 if (isDone_) {
377 return;
378 }
379 {
380 std::lock_guard lock(mutex_);
381 isDone_ = true;
382 }
383
384 // Trigger all waiting threads.
385 cv_.notify_all();
386
387 // Wait for all threads to finish.
388 auto threads = array_view(threads_.get(), threadCount_);
389 for (auto& context : threads) {
390 if (context.thread.joinable()) {
391 context.thread.join();
392 }
393 }
394
395 Clear();
396 }
397
ThreadProc(ThreadContext & context)398 void ThreadProc(ThreadContext& context)
399 {
400 #ifdef PLATFORM_HAS_JAVA
401 // RAII class for handling thread setup/release.
402 JavaThreadContext javaContext;
403 #endif
404
405 while (true) {
406 // Function to process.
407 std::shared_ptr<Task> task;
408 {
409 std::unique_lock lock(mutex_);
410
411 // Try to wait for next task to process.
412 cv_.wait(lock, [this, &task]() {
413 task = FindRunnable();
414 return task || isDone_;
415 });
416 }
417 // If there was no task it means we are stopping and thread can exit.
418 if (!task) {
419 return;
420 }
421
422 while (task) {
423 // Run task function.
424 {
425 CORE_CPU_PERF_SCOPE("CORE", "ThreadPoolTask", "", CORE_PROFILER_DEFAULT_COLOR);
426 (*task)();
427 }
428
429 std::lock_guard lock(mutex_);
430 // After running the task remove it from the queue. Any dependent tasks will see their weak_ptr expire
431 // idicating that the dependency has been completed.
432 if (auto pos = std::find_if(q_.cbegin(), q_.cend(),
433 [&task](const std::shared_ptr<Task>& queuedTask) { return queuedTask == task; });
434 pos != q_.cend()) {
435 q_.erase(pos);
436 }
437 task.reset();
438
439 // Get next function.
440 if (auto pos = std::find_if(std::begin(q_), std::end(q_),
441 [](const std::shared_ptr<Task>& task) { return (task) && (task->CanRun()); });
442 pos != std::end(q_)) {
443 task = *pos;
444 task->running_ = true;
445 // Check if there are more runnable tasks and notify workers as needed.
446 auto runnable = std::min(static_cast<ptrdiff_t>(threadCount_),
447 std::count_if(pos + 1, std::end(q_),
448 [](const std::shared_ptr<Task>& task) { return (task) && (task->CanRun()); }));
449 while (runnable--) {
450 cv_.notify_one();
451 }
452 }
453 }
454 }
455 }
456
457 size_t threadCount_ { 0 };
458 unique_ptr<ThreadContext[]> threads_;
459
460 std::deque<std::shared_ptr<Task>> q_;
461
462 bool isDone_ { false };
463
464 std::mutex mutex_;
465 std::condition_variable cv_;
466 std::atomic<int32_t> refcnt_ { 0 };
467 };
468 } // namespace
469
GetNumberOfCores() const470 uint32_t TaskQueueFactory::GetNumberOfCores() const
471 {
472 uint32_t result = std::thread::hardware_concurrency();
473 if (result == 0) {
474 // If not detectable, default to 4.
475 result = 4;
476 }
477
478 return result;
479 }
480
CreateThreadPool(const uint32_t threadCount) const481 IThreadPool::Ptr TaskQueueFactory::CreateThreadPool(const uint32_t threadCount) const
482 {
483 return IThreadPool::Ptr { new ThreadPool(threadCount) };
484 }
485
CreateDispatcherTaskQueue(const IThreadPool::Ptr & threadPool) const486 IDispatcherTaskQueue::Ptr TaskQueueFactory::CreateDispatcherTaskQueue(const IThreadPool::Ptr& threadPool) const
487 {
488 return IDispatcherTaskQueue::Ptr { make_unique<DispatcherImpl>(threadPool).release() };
489 }
490
CreateParallelTaskQueue(const IThreadPool::Ptr & threadPool) const491 IParallelTaskQueue::Ptr TaskQueueFactory::CreateParallelTaskQueue(const IThreadPool::Ptr& threadPool) const
492 {
493 return IParallelTaskQueue::Ptr { make_unique<ParallelImpl>(threadPool).release() };
494 }
495
CreateSequentialTaskQueue(const IThreadPool::Ptr & threadPool) const496 ISequentialTaskQueue::Ptr TaskQueueFactory::CreateSequentialTaskQueue(const IThreadPool::Ptr& threadPool) const
497 {
498 return ISequentialTaskQueue::Ptr { make_unique<SequentialImpl>(threadPool).release() };
499 }
500
501 // IInterface
GetInterface(const BASE_NS::Uid & uid) const502 const IInterface* TaskQueueFactory::GetInterface(const BASE_NS::Uid& uid) const
503 {
504 if (uid == ITaskQueueFactory::UID) {
505 return static_cast<const ITaskQueueFactory*>(this);
506 }
507 return nullptr;
508 }
509
GetInterface(const BASE_NS::Uid & uid)510 IInterface* TaskQueueFactory::GetInterface(const BASE_NS::Uid& uid)
511 {
512 if (uid == ITaskQueueFactory::UID) {
513 return static_cast<ITaskQueueFactory*>(this);
514 }
515 return nullptr;
516 }
517
Ref()518 void TaskQueueFactory::Ref() {}
519
Unref()520 void TaskQueueFactory::Unref() {}
521 CORE_END_NAMESPACE()
522