/*
 * Copyright (c) 2024 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "threading/task_queue_factory.h"

#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <deque>
#include <memory>
#include <thread>

#include <base/containers/array_view.h>
#include <base/containers/iterator.h>
#include <base/containers/type_traits.h>
#include <base/containers/unique_ptr.h>
#include <base/containers/vector.h>
#include <base/math/mathf.h>
#include <base/util/uid.h>
#include <core/log.h>
#include <core/perf/cpu_perf_scope.h>
#include <core/threading/intf_thread_pool.h>

#include "threading/dispatcher_impl.h"
#include "threading/parallel_impl.h"
#include "threading/sequential_impl.h"

#ifdef PLATFORM_HAS_JAVA
#include <os/java/java_internal.h>
#endif

CORE_BEGIN_NAMESPACE()
using BASE_NS::array_view;
using BASE_NS::make_unique;
using BASE_NS::move;
using BASE_NS::unique_ptr;
using BASE_NS::Math::max;

namespace {
#ifdef PLATFORM_HAS_JAVA
/** RAII class for handling thread setup/release. */
class JavaThreadContext final {
public:
    JavaThreadContext()
    {
        JNIEnv* env = nullptr;
        javaVm_ = java_internal::GetJavaVM();

#ifndef NDEBUG
        // Check that the thread was not already attached.
        // It's not really a problem as another attach is a no-op, but we will be detaching the
        // thread later and it may be unexpected for the user.
        jint result = javaVm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6);
        CORE_ASSERT_MSG((result != JNI_OK), "Thread already attached");
#endif

        javaVm_->AttachCurrentThread(&env, nullptr);
    }

    ~JavaThreadContext()
    {
        javaVm_->DetachCurrentThread();
    }
    JavaVM* javaVm_ { nullptr };
};
#endif // PLATFORM_HAS_JAVA

// -- TaskResult is returned by ThreadPool::Push and can be waited on.
class TaskResult final : public IThreadPool::IResult {
public:
    // Task state which can be waited on and marked as done.
    class State {
    public:
        void Done()
        {
            {
                auto lock = std::lock_guard(mutex_);
                done_ = true;
            }
            cv_.notify_all();
        }

        void Wait()
        {
            auto lock = std::unique_lock(mutex_);
            cv_.wait(lock, [this]() { return done_; });
        }

        bool IsDone() const
        {
            auto lock = std::lock_guard(mutex_);
            return done_;
        }

    private:
        mutable std::mutex mutex_;
        std::condition_variable cv_;
        bool done_ { false };
    };

    explicit TaskResult(std::shared_ptr<State>&& future) : future_(BASE_NS::move(future)) {}

    void Wait() final
    {
        if (future_) {
            future_->Wait();
        }
    }

    bool IsDone() const final
    {
        if (future_) {
            return future_->IsDone();
        }
        return true;
    }

protected:
    void Destroy() final
    {
        delete this;
    }

private:
    std::shared_ptr<State> future_;
};
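
// A minimal sketch of the wait/done handshake that TaskResult::State implements above
// (illustrative only, not part of the production code; the worker lambda stands in for a
// pool thread):
//
//     auto state = std::make_shared<TaskResult::State>();
//     std::thread worker([state]() {
//         // ... run the task ...
//         state->Done(); // releases anyone blocked in Wait()
//     });
//     state->Wait();     // blocks until Done() has been called
//     worker.join();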

// -- ThreadPool
class ThreadPool final : public IThreadPool {
public:
    explicit ThreadPool(size_t threadCount)
        : threadCount_(max(size_t(1), threadCount)), threads_(make_unique<ThreadContext[]>(threadCount_))
    {
        CORE_ASSERT(threads_);

        if (threadCount == 0U) {
            CORE_LOG_W("Threadpool minimum thread count is 1");
        }
        // Create thread contexts and start the worker threads.
        auto threads = array_view<ThreadContext>(threads_.get(), threadCount_);
        for (ThreadContext& context : threads) {
            // Set up the thread function.
            context.thread = std::thread(&ThreadPool::ThreadProc, this, std::ref(context));
        }
    }

    ThreadPool(const ThreadPool&) = delete;
    ThreadPool(ThreadPool&&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;
    ThreadPool& operator=(ThreadPool&&) = delete;

    IResult::Ptr Push(ITask::Ptr task) override
    {
        auto taskState = std::make_shared<TaskResult::State>();
        if (taskState) {
            if (task) {
                std::lock_guard lock(mutex_);
                q_.push_back(std::make_shared<Task>(BASE_NS::move(task), taskState));
                cv_.notify_one();
            } else {
                // Mark as done if there was no task function.
                taskState->Done();
            }
        }
        return IResult::Ptr { new TaskResult(BASE_NS::move(taskState)) };
    }

    IResult::Ptr Push(ITask::Ptr task, BASE_NS::array_view<const ITask* const> dependencies) override
    {
        if (dependencies.empty()) {
            return Push(BASE_NS::move(task));
        }
        auto taskState = std::make_shared<TaskResult::State>();
        if (taskState) {
            if (task) {
                BASE_NS::vector<std::weak_ptr<Task>> deps;
                deps.reserve(dependencies.size());
                {
                    std::lock_guard lock(mutex_);
                    for (const ITask* dep : dependencies) {
                        if (auto pos = std::find_if(q_.cbegin(), q_.cend(),
                            [dep](const std::shared_ptr<Task>& task) { return task && (*task == dep); });
                            pos != q_.cend()) {
                            deps.push_back(*pos);
                        }
                    }
                    q_.push_back(std::make_shared<Task>(BASE_NS::move(task), taskState, BASE_NS::move(deps)));
                    cv_.notify_one();
                }
            } else {
                // Mark as done if there was no task function.
                taskState->Done();
            }
        }
        return IResult::Ptr { new TaskResult(BASE_NS::move(taskState)) };
    }
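
    // Illustrative call pattern for the dependency-aware Push above (sketch only; 'MyTask' is a
    // hypothetical ITask implementation and 'pool' an instance of this class):
    //
    //     ITask::Ptr first = ...;  // e.g. some MyTask instance
    //     ITask::Ptr second = ...;
    //     const ITask* firstDep = first.get();
    //     pool.Push(BASE_NS::move(first));
    //     const ITask* deps[] = { firstDep };
    //     auto result = pool.Push(BASE_NS::move(second), deps); // 'second' runs only after 'first'
    //     result->Wait();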

    void PushNoWait(ITask::Ptr task) override
    {
        if (task) {
            std::lock_guard lock(mutex_);
            q_.push_back(std::make_shared<Task>(BASE_NS::move(task)));
            cv_.notify_one();
        }
    }

    void PushNoWait(ITask::Ptr task, BASE_NS::array_view<const ITask* const> dependencies) override
    {
        if (dependencies.empty()) {
            return PushNoWait(BASE_NS::move(task));
        }

        if (task) {
            BASE_NS::vector<std::weak_ptr<Task>> deps;
            deps.reserve(dependencies.size());
            {
                std::lock_guard lock(mutex_);
                for (const ITask* dep : dependencies) {
                    if (auto pos = std::find_if(q_.cbegin(), q_.cend(),
                        [dep](const std::shared_ptr<Task>& task) { return task && (*task == dep); });
                        pos != q_.cend()) {
                        deps.push_back(*pos);
                    }
                }
                q_.push_back(std::make_shared<Task>(BASE_NS::move(task), BASE_NS::move(deps)));
                cv_.notify_one();
            }
        }
    }
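
    // Fire-and-forget variants (illustrative usage): no TaskResult::State is allocated, so the
    // caller cannot wait on completion:
    //
    //     pool.PushNoWait(BASE_NS::move(task));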

    uint32_t GetNumberOfThreads() const override
    {
        return static_cast<uint32_t>(threadCount_);
    }

    // IInterface
    const IInterface* GetInterface(const BASE_NS::Uid& uid) const override
    {
        if ((uid == IThreadPool::UID) || (uid == IInterface::UID)) {
            return this;
        }
        return nullptr;
    }

    IInterface* GetInterface(const BASE_NS::Uid& uid) override
    {
        if ((uid == IThreadPool::UID) || (uid == IInterface::UID)) {
            return this;
        }
        return nullptr;
    }

    void Ref() override
    {
        refcnt_.fetch_add(1, std::memory_order_relaxed);
    }

    void Unref() override
    {
        if (std::atomic_fetch_sub_explicit(&refcnt_, 1, std::memory_order_release) == 1) {
            std::atomic_thread_fence(std::memory_order_acquire);
            delete this;
        }
    }

protected:
    ~ThreadPool() final
    {
        Stop();
    }

private:
    struct ThreadContext {
        std::thread thread;
    };

    // Helper which holds a pointer to a queued task function and the result state.
    struct Task {
        ITask::Ptr function_;
        std::shared_ptr<TaskResult::State> state_;
        BASE_NS::vector<std::weak_ptr<Task>> dependencies_;
        bool running_ { false };

        ~Task() = default;
        Task() = default;

        Task(ITask::Ptr&& function, std::shared_ptr<TaskResult::State> state,
            BASE_NS::vector<std::weak_ptr<Task>>&& dependencies)
            : function_(BASE_NS::move(function)),
              state_(BASE_NS::move(state)), dependencies_ { BASE_NS::move(dependencies) }
        {
            CORE_ASSERT(this->function_ && this->state_);
        }

        Task(ITask::Ptr&& function, std::shared_ptr<TaskResult::State> state)
            : function_(BASE_NS::move(function)), state_(BASE_NS::move(state))
        {
            CORE_ASSERT(this->function_ && this->state_);
        }

        explicit Task(ITask::Ptr&& function) : function_(BASE_NS::move(function))
        {
            CORE_ASSERT(this->function_);
        }

        Task(ITask::Ptr&& function, BASE_NS::vector<std::weak_ptr<Task>>&& dependencies)
            : function_(BASE_NS::move(function)), dependencies_ { BASE_NS::move(dependencies) }
        {
            CORE_ASSERT(this->function_);
        }

        Task(Task&&) = default;
        Task& operator=(Task&&) = default;
        Task(const Task&) = delete;
        Task& operator=(const Task&) = delete;

        inline void operator()() const
        {
            (*function_)();
            if (state_) {
                state_->Done();
            }
        }

        inline bool operator==(const ITask* task) const
        {
            return function_.get() == task;
        }

        // Task can run if it's not already running and there are no dependencies, or all the dependencies are ready.
        inline bool CanRun() const
        {
            return !running_ && (dependencies_.empty() ||
                                    std::all_of(std::begin(dependencies_), std::end(dependencies_),
                                        [](const std::weak_ptr<Task>& dependency) { return dependency.expired(); }));
        }
    };
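
    // The dependency check in CanRun() relies on std::weak_ptr::expired(): ThreadProc erases a
    // task's shared_ptr from the queue once it has run, which expires every weak_ptr held by
    // dependent tasks and lets their CanRun() return true. A standalone sketch of the mechanism
    // (illustrative only):
    //
    //     auto owner = std::make_shared<int>(42);      // "task in the queue"
    //     std::weak_ptr<int> dependency = owner;       // "another task's dependency"
    //     owner.reset();                               // "task finished and erased from the queue"
    //     const bool satisfied = dependency.expired(); // true -> dependency satisfied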

    // Looks for a task that can be executed.
    std::shared_ptr<Task> FindRunnable()
    {
        if (q_.empty()) {
            return {};
        }
        for (auto& task : q_) {
            if (task && task->CanRun()) {
                task->running_ = true;
                return task;
            }
        }
        return {};
    }

    void Clear()
    {
        std::lock_guard lock(mutex_);
        q_.clear();
    }

    // At the moment Stop is called only from the destructor and always waits for all tasks to
    // complete. If this doesn't change, the class can be simplified a bit.
    void Stop()
    {
        // Wait for all tasks to complete before returning.
        if (isDone_) {
            return;
        }
        {
            std::lock_guard lock(mutex_);
            isDone_ = true;
        }

        // Trigger all waiting threads.
        cv_.notify_all();

        // Wait for all threads to finish.
        auto threads = array_view(threads_.get(), threadCount_);
        for (auto& context : threads) {
            if (context.thread.joinable()) {
                context.thread.join();
            }
        }

        Clear();
    }

    void ThreadProc(ThreadContext& context)
    {
#ifdef PLATFORM_HAS_JAVA
        // RAII object for attaching/detaching this thread to/from the Java VM.
        JavaThreadContext javaContext;
#endif

        while (true) {
            // Task to process.
            std::shared_ptr<Task> task;
            {
                std::unique_lock lock(mutex_);

                // Wait for the next task to process.
                cv_.wait(lock, [this, &task]() {
                    task = FindRunnable();
                    return task || isDone_;
                });
            }
            // If there was no task, we are stopping and the thread can exit.
            if (!task) {
                return;
            }

            while (task) {
                // Run the task function.
                {
                    CORE_CPU_PERF_SCOPE("CORE", "ThreadPoolTask", "", CORE_PROFILER_DEFAULT_COLOR);
                    (*task)();
                }

                std::lock_guard lock(mutex_);
                // After running the task remove it from the queue. Any dependent tasks will see their weak_ptr expire,
                // indicating that the dependency has been completed.
                if (auto pos = std::find_if(q_.cbegin(), q_.cend(),
                    [&task](const std::shared_ptr<Task>& queuedTask) { return queuedTask == task; });
                    pos != q_.cend()) {
                    q_.erase(pos);
                }
                task.reset();

                // Get the next runnable task.
                if (auto pos = std::find_if(std::begin(q_), std::end(q_),
                    [](const std::shared_ptr<Task>& task) { return (task) && (task->CanRun()); });
                    pos != std::end(q_)) {
                    task = *pos;
                    task->running_ = true;
                    // Check if there are more runnable tasks and notify workers as needed.
                    auto runnable = std::min(static_cast<ptrdiff_t>(threadCount_),
                        std::count_if(pos + 1, std::end(q_),
                            [](const std::shared_ptr<Task>& task) { return (task) && (task->CanRun()); }));
                    while (runnable--) {
                        cv_.notify_one();
                    }
                }
            }
        }
    }

    size_t threadCount_ { 0 };
    unique_ptr<ThreadContext[]> threads_;

    std::deque<std::shared_ptr<Task>> q_;

    bool isDone_ { false };

    std::mutex mutex_;
    std::condition_variable cv_;
    std::atomic<int32_t> refcnt_ { 0 };
};
} // namespace

uint32_t TaskQueueFactory::GetNumberOfCores() const
{
    uint32_t result = std::thread::hardware_concurrency();
    if (result == 0) {
        // If not detectable, default to 4.
        result = 4;
    }

    return result;
}
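
// A common sizing pattern (illustrative only; 'factory' is assumed to be a TaskQueueFactory
// instance) is to feed the detected core count into CreateThreadPool() below:
//
//     auto pool = factory.CreateThreadPool(factory.GetNumberOfCores());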

IThreadPool::Ptr TaskQueueFactory::CreateThreadPool(const uint32_t threadCount) const
{
    return IThreadPool::Ptr { new ThreadPool(threadCount) };
}

IDispatcherTaskQueue::Ptr TaskQueueFactory::CreateDispatcherTaskQueue(const IThreadPool::Ptr& threadPool) const
{
    return IDispatcherTaskQueue::Ptr { make_unique<DispatcherImpl>(threadPool).release() };
}

IParallelTaskQueue::Ptr TaskQueueFactory::CreateParallelTaskQueue(const IThreadPool::Ptr& threadPool) const
{
    return IParallelTaskQueue::Ptr { make_unique<ParallelImpl>(threadPool).release() };
}

ISequentialTaskQueue::Ptr TaskQueueFactory::CreateSequentialTaskQueue(const IThreadPool::Ptr& threadPool) const
{
    return ISequentialTaskQueue::Ptr { make_unique<SequentialImpl>(threadPool).release() };
}
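
// Illustrative sketch of creating the different queue types on top of a shared pool
// (sketch only; 'factory' and 'pool' are assumed to exist as shown above):
//
//     auto dispatcher = factory.CreateDispatcherTaskQueue(pool);
//     auto parallel = factory.CreateParallelTaskQueue(pool);
//     auto sequential = factory.CreateSequentialTaskQueue(pool);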

// IInterface
const IInterface* TaskQueueFactory::GetInterface(const BASE_NS::Uid& uid) const
{
    if (uid == ITaskQueueFactory::UID) {
        return static_cast<const ITaskQueueFactory*>(this);
    }
    return nullptr;
}

IInterface* TaskQueueFactory::GetInterface(const BASE_NS::Uid& uid)
{
    if (uid == ITaskQueueFactory::UID) {
        return static_cast<ITaskQueueFactory*>(this);
    }
    return nullptr;
}

void TaskQueueFactory::Ref() {}

void TaskQueueFactory::Unref() {}
CORE_END_NAMESPACE()