• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "threading/task_queue_factory.h"
17 
18 #include <algorithm>
19 #include <condition_variable>
20 #include <cstddef>
21 #include <deque>
22 #include <thread>
23 
24 #include <base/containers/array_view.h>
25 #include <base/containers/atomics.h>
26 #include <base/containers/iterator.h>
27 #include <base/containers/shared_ptr.h>
28 #include <base/containers/type_traits.h>
29 #include <base/containers/unique_ptr.h>
30 #include <base/math/mathf.h>
31 #include <base/util/uid.h>
32 #include <core/log.h>
33 #include <core/perf/cpu_perf_scope.h>
34 #include <core/threading/intf_thread_pool.h>
35 
36 #include "threading/dispatcher_impl.h"
37 #include "threading/parallel_impl.h"
38 #include "threading/sequential_impl.h"
39 
40 #ifdef PLATFORM_HAS_JAVA
41 #include <os/java/java_internal.h>
42 #endif
43 
44 CORE_BEGIN_NAMESPACE()
45 using BASE_NS::array_view;
46 using BASE_NS::make_unique;
47 using BASE_NS::move;
48 using BASE_NS::unique_ptr;
49 using BASE_NS::Math::max;
50 
51 namespace {
#ifdef PLATFORM_HAS_JAVA
/** RAII class for handling thread setup/release.
 *
 * Attaches the constructing thread to the Java VM returned by java_internal::GetJavaVM() and detaches it again in
 * the destructor. The detach only happens if the attach reported success, and the class is neither copyable nor
 * movable so the detach can run at most once.
 */
class JavaThreadContext final {
public:
    JavaThreadContext()
    {
        javaVm_ = java_internal::GetJavaVM();
        if (javaVm_ == nullptr) {
            // Nothing to attach to; the destructor will not detach either.
            return;
        }

        JNIEnv* env = nullptr;
#ifndef NDEBUG
        // Check that the thread was not already attached.
        // It's not really a problem as another attach is a no-op, but we will be detaching the
        // thread later and it may be unexpected for the user.
        jint result = javaVm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6);
        CORE_ASSERT_MSG((result != JNI_OK), "Thread already attached");
#endif

        // Only remember the attachment if it actually succeeded, so a failed attach is never followed by a detach.
        attached_ = (javaVm_->AttachCurrentThread(&env, nullptr) == JNI_OK);
    }

    ~JavaThreadContext()
    {
        if (attached_) {
            javaVm_->DetachCurrentThread();
        }
    }

    // Copying or moving would let two objects detach the same thread.
    JavaThreadContext(const JavaThreadContext&) = delete;
    JavaThreadContext(JavaThreadContext&&) = delete;
    JavaThreadContext& operator=(const JavaThreadContext&) = delete;
    JavaThreadContext& operator=(JavaThreadContext&&) = delete;

private:
    JavaVM* javaVm_ { nullptr };
    bool attached_ { false };
};
#endif // PLATFORM_HAS_JAVA
79 
80 // -- TaskResult, returned by ThreadPool::Push and can be waited on.
81 class TaskResult final : public IThreadPool::IResult {
82 public:
83     // Task state which can be waited and marked as done.
84     class State {
85     public:
Done()86         void Done()
87         {
88             {
89                 auto lock = std::lock_guard(mutex_);
90                 done_ = true;
91             }
92             cv_.notify_all();
93         }
94 
Wait()95         void Wait()
96         {
97             auto lock = std::unique_lock(mutex_);
98             cv_.wait(lock, [this]() { return done_; });
99         }
100 
IsDone() const101         bool IsDone() const
102         {
103             auto lock = std::lock_guard(mutex_);
104             return done_;
105         }
106 
107     private:
108         mutable std::mutex mutex_;
109         std::condition_variable cv_;
110         bool done_ { false };
111     };
112 
TaskResult(BASE_NS::shared_ptr<State> && future)113     explicit TaskResult(BASE_NS::shared_ptr<State>&& future) : future_(BASE_NS::move(future)) {}
114 
Wait()115     void Wait() final
116     {
117         if (future_) {
118             future_->Wait();
119         }
120     }
IsDone() const121     bool IsDone() const final
122     {
123         if (future_) {
124             return future_->IsDone();
125         }
126         return true;
127     }
128 
129 protected:
Destroy()130     void Destroy() final
131     {
132         delete this;
133     }
134 
135 private:
136     BASE_NS::shared_ptr<State> future_;
137 };
138 
139 // -- ThreadPool
140 class ThreadPool final : public IThreadPool {
141 public:
ThreadPool(size_t threadCount)142     explicit ThreadPool(size_t threadCount)
143         : threadCount_(max(size_t(1), threadCount)), threads_(make_unique<ThreadContext[]>(threadCount_))
144     {
145         CORE_ASSERT(threads_);
146 
147         if (threadCount == 0U) {
148             CORE_LOG_W("Threadpool minimum thread count is 1");
149         }
150         // Create thread containers.
151         auto threads = array_view<ThreadContext>(threads_.get(), threadCount_);
152         for (ThreadContext& context : threads) {
153             // Set-up thread function.
154             context.thread = std::thread(&ThreadPool::ThreadProc, this, std::ref(context));
155         }
156     }
157 
158     ThreadPool(const ThreadPool&) = delete;
159     ThreadPool(ThreadPool&&) = delete;
160     ThreadPool& operator=(const ThreadPool&) = delete;
161     ThreadPool& operator=(ThreadPool&&) = delete;
162 
Push(ITask::Ptr task)163     IResult::Ptr Push(ITask::Ptr task) override
164     {
165         auto taskState = BASE_NS::make_shared<TaskResult::State>();
166         if (taskState) {
167             if (task) {
168                 std::lock_guard lock(mutex_);
169                 q_.push_back(BASE_NS::make_shared<Task>(BASE_NS::move(task), taskState));
170                 cv_.notify_one();
171             } else {
172                 // mark as done if the there was no function.
173                 taskState->Done();
174             }
175         }
176         return IResult::Ptr { new TaskResult(BASE_NS::move(taskState)) };
177     }
178 
Push(ITask::Ptr task,BASE_NS::array_view<const ITask * const> dependencies)179     IResult::Ptr Push(ITask::Ptr task, BASE_NS::array_view<const ITask* const> dependencies) override
180     {
181         if (dependencies.empty()) {
182             return Push(BASE_NS::move(task));
183         }
184         auto taskState = BASE_NS::make_shared<TaskResult::State>();
185         if (taskState) {
186             if (task) {
187                 BASE_NS::vector<BASE_NS::weak_ptr<Task>> deps;
188                 deps.reserve(dependencies.size());
189                 {
190                     std::lock_guard lock(mutex_);
191                     for (const ITask* dep : dependencies) {
192                         if (auto pos = std::find_if(
193                                 q_.cbegin(), q_.cend(),
194                                 [dep](const BASE_NS::shared_ptr<Task> &task) { return task && (*task == dep); });
195                             pos != q_.cend()) {
196                             deps.push_back(*pos);
197                         }
198                     }
199                     q_.push_back(BASE_NS::make_shared<Task>(BASE_NS::move(task), taskState, BASE_NS::move(deps)));
200                     cv_.notify_one();
201                 }
202             } else {
203                 // mark as done if the there was no function.
204                 taskState->Done();
205             }
206         }
207         return IResult::Ptr { new TaskResult(BASE_NS::move(taskState)) };
208     }
209 
PushNoWait(ITask::Ptr task)210     void PushNoWait(ITask::Ptr task) override
211     {
212         if (task) {
213             std::lock_guard lock(mutex_);
214             q_.push_back(BASE_NS::make_shared<Task>(BASE_NS::move(task)));
215             cv_.notify_one();
216         }
217     }
218 
PushNoWait(ITask::Ptr task,BASE_NS::array_view<const ITask * const> dependencies)219     void PushNoWait(ITask::Ptr task, BASE_NS::array_view<const ITask* const> dependencies) override
220     {
221         if (dependencies.empty()) {
222             return PushNoWait(BASE_NS::move(task));
223         }
224 
225         if (task) {
226             BASE_NS::vector<BASE_NS::weak_ptr<Task>> deps;
227             deps.reserve(dependencies.size());
228             {
229                 std::lock_guard lock(mutex_);
230                 for (const ITask* dep : dependencies) {
231                     if (auto pos = std::find_if(
232                             q_.cbegin(), q_.cend(),
233                             [dep](const BASE_NS::shared_ptr<Task> &task) { return task && (*task == dep); });
234                         pos != q_.cend()) {
235                         deps.push_back(*pos);
236                     }
237                 }
238                 q_.push_back(BASE_NS::make_shared<Task>(BASE_NS::move(task), BASE_NS::move(deps)));
239                 cv_.notify_one();
240             }
241         }
242     }
243 
GetNumberOfThreads() const244     uint32_t GetNumberOfThreads() const override
245     {
246         return static_cast<uint32_t>(threadCount_);
247     }
248 
249     // IInterface
GetInterface(const BASE_NS::Uid & uid) const250     const IInterface* GetInterface(const BASE_NS::Uid& uid) const override
251     {
252         if ((uid == IThreadPool::UID) || (uid == IInterface::UID)) {
253             return this;
254         }
255         return nullptr;
256     }
257 
GetInterface(const BASE_NS::Uid & uid)258     IInterface* GetInterface(const BASE_NS::Uid& uid) override
259     {
260         if ((uid == IThreadPool::UID) || (uid == IInterface::UID)) {
261             return this;
262         }
263         return nullptr;
264     }
265 
Ref()266     void Ref() override
267     {
268         BASE_NS::AtomicIncrementRelaxed(&refcnt_);
269     }
270 
Unref()271     void Unref() override
272     {
273         if (BASE_NS::AtomicDecrementRelease(&refcnt_) == 0) {
274             BASE_NS::AtomicFenceAcquire();
275             delete this;
276         }
277     }
278 
279 protected:
~ThreadPool()280     ~ThreadPool() final
281     {
282         Stop();
283     }
284 
285 private:
286     struct ThreadContext {
287         std::thread thread;
288     };
289 
290     // Helper which holds a pointer to a queued task function and the result state.
291     struct Task {
292         ITask::Ptr function_;
293         BASE_NS::shared_ptr<TaskResult::State> state_;
294         BASE_NS::vector<BASE_NS::weak_ptr<Task>> dependencies_;
295         bool running_ { false };
296 
297         ~Task() = default;
298         Task() = default;
299 
Task__anonc0e55d560111::ThreadPool::Task300         Task(ITask::Ptr&& function, BASE_NS::shared_ptr<TaskResult::State> state,
301             BASE_NS::vector<BASE_NS::weak_ptr<Task>>&& dependencies)
302             : function_(BASE_NS::move(function)),
303               state_(BASE_NS::move(state)), dependencies_ { BASE_NS::move(dependencies) }
304         {
305             CORE_ASSERT(this->function_ && this->state_);
306         }
307 
Task__anonc0e55d560111::ThreadPool::Task308         Task(ITask::Ptr&& function, BASE_NS::shared_ptr<TaskResult::State> state)
309             : function_(BASE_NS::move(function)), state_(BASE_NS::move(state))
310         {
311             CORE_ASSERT(this->function_ && this->state_);
312         }
313 
Task__anonc0e55d560111::ThreadPool::Task314         explicit Task(ITask::Ptr&& function) : function_(BASE_NS::move(function))
315         {
316             CORE_ASSERT(this->function_);
317         }
318 
Task__anonc0e55d560111::ThreadPool::Task319         Task(ITask::Ptr&& function, BASE_NS::vector<BASE_NS::weak_ptr<Task>>&& dependencies)
320             : function_(BASE_NS::move(function)), dependencies_ { BASE_NS::move(dependencies) }
321         {
322             CORE_ASSERT(this->function_);
323         }
324 
325         Task(Task&&) = default;
326         Task& operator=(Task&&) = default;
327         Task(const Task&) = delete;
328         Task& operator=(const Task&) = delete;
329 
operator ()__anonc0e55d560111::ThreadPool::Task330         inline void operator()() const
331         {
332             (*function_)();
333             if (state_) {
334                 state_->Done();
335             }
336         }
337 
operator ==__anonc0e55d560111::ThreadPool::Task338         inline bool operator==(const ITask* task) const
339         {
340             return function_.get() == task;
341         }
342 
343         // Task can run if it's not already running and there are no dependencies, or all the dependencies are ready.
CanRun__anonc0e55d560111::ThreadPool::Task344         inline bool CanRun() const
345         {
346             return !running_ &&
347                    (dependencies_.empty() ||
348                        std::all_of(std::begin(dependencies_), std::end(dependencies_),
349                            [](const BASE_NS::weak_ptr<Task>& dependency) { return dependency.expired(); }));
350         }
351     };
352 
353     // Looks for a task that can be executed.
FindRunnable()354     BASE_NS::shared_ptr<Task> FindRunnable()
355     {
356         if (q_.empty()) {
357             return {};
358         }
359         for (auto& task : q_) {
360             if (task && task->CanRun()) {
361                 task->running_ = true;
362                 return task;
363             }
364         }
365         return {};
366     }
367 
Clear()368     void Clear()
369     {
370         std::lock_guard lock(mutex_);
371         q_.clear();
372     }
373 
374     // At the moment Stop is called only from the destructor with waitForAllTasksToComplete=true.
375     // If this doesn't change the class can be simplified a bit.
Stop()376     void Stop()
377     {
378         // Wait all tasks to complete before returning.
379         if (isDone_) {
380             return;
381         }
382         {
383             std::lock_guard lock(mutex_);
384             isDone_ = true;
385         }
386 
387         // Trigger all waiting threads.
388         cv_.notify_all();
389 
390         // Wait for all threads to finish.
391         auto threads = array_view(threads_.get(), threadCount_);
392         for (auto& context : threads) {
393             if (context.thread.joinable()) {
394                 context.thread.join();
395             }
396         }
397 
398         Clear();
399     }
400 
ThreadProc(ThreadContext & context)401     void ThreadProc(ThreadContext& context)
402     {
403 #ifdef PLATFORM_HAS_JAVA
404         // RAII class for handling thread setup/release.
405         JavaThreadContext javaContext;
406 #endif
407 
408         while (true) {
409             // Function to process.
410             BASE_NS::shared_ptr<Task> task;
411             {
412                 std::unique_lock lock(mutex_);
413 
414                 // Try to wait for next task to process.
415                 cv_.wait(lock, [this, &task]() {
416                     task = FindRunnable();
417                     return task || isDone_;
418                 });
419             }
420             // If there was no task it means we are stopping and thread can exit.
421             if (!task) {
422                 return;
423             }
424 
425             while (task) {
426                 // Run task function.
427                 {
428                     CORE_CPU_PERF_SCOPE("CORE", "ThreadPoolTask", "", CORE_PROFILER_DEFAULT_COLOR);
429                     (*task)();
430                 }
431 
432                 std::lock_guard lock(mutex_);
433                 // After running the task remove it from the queue. Any dependent tasks will see their weak_ptr expire
434                 // idicating that the dependency has been completed.
435                 if (auto pos = std::find_if(
436                         q_.cbegin(), q_.cend(),
437                         [&task](const BASE_NS::shared_ptr<Task> &queuedTask) { return queuedTask == task; });
438                     pos != q_.cend()) {
439                     q_.erase(pos);
440                 }
441                 task.reset();
442 
443                 // Get next function.
444                 if (auto pos =
445                         std::find_if(std::begin(q_), std::end(q_),
446                                      [](const BASE_NS::shared_ptr<Task> &task) { return (task) && (task->CanRun()); });
447                     pos != std::end(q_)) {
448                     task = *pos;
449                     task->running_ = true;
450                     // Check if there are more runnable tasks and notify workers as needed.
451                     auto runnable = std::min(static_cast<ptrdiff_t>(threadCount_),
452                         std::count_if(pos + 1, std::end(q_),
453                             [](const BASE_NS::shared_ptr<Task>& task) { return (task) && (task->CanRun()); }));
454                     while (runnable--) {
455                         cv_.notify_one();
456                     }
457                 }
458             }
459         }
460     }
461 
462     size_t threadCount_ { 0 };
463     unique_ptr<ThreadContext[]> threads_;
464 
465     std::deque<BASE_NS::shared_ptr<Task>> q_;
466 
467     bool isDone_ { false };
468 
469     std::mutex mutex_;
470     std::condition_variable cv_;
471     int32_t refcnt_ { 0 };
472 };
473 } // namespace
474 
GetNumberOfCores() const475 uint32_t TaskQueueFactory::GetNumberOfCores() const
476 {
477     uint32_t result = std::thread::hardware_concurrency();
478     if (result == 0) {
479         // If not detectable, default to 4.
480         result = 4;
481     }
482 
483     return result;
484 }
485 
CreateThreadPool(const uint32_t threadCount) const486 IThreadPool::Ptr TaskQueueFactory::CreateThreadPool(const uint32_t threadCount) const
487 {
488     return IThreadPool::Ptr { new ThreadPool(threadCount) };
489 }
490 
CreateDispatcherTaskQueue(const IThreadPool::Ptr & threadPool) const491 IDispatcherTaskQueue::Ptr TaskQueueFactory::CreateDispatcherTaskQueue(const IThreadPool::Ptr& threadPool) const
492 {
493     return IDispatcherTaskQueue::Ptr { make_unique<DispatcherImpl>(threadPool).release() };
494 }
495 
CreateParallelTaskQueue(const IThreadPool::Ptr & threadPool) const496 IParallelTaskQueue::Ptr TaskQueueFactory::CreateParallelTaskQueue(const IThreadPool::Ptr& threadPool) const
497 {
498     return IParallelTaskQueue::Ptr { make_unique<ParallelImpl>(threadPool).release() };
499 }
500 
CreateSequentialTaskQueue(const IThreadPool::Ptr & threadPool) const501 ISequentialTaskQueue::Ptr TaskQueueFactory::CreateSequentialTaskQueue(const IThreadPool::Ptr& threadPool) const
502 {
503     return ISequentialTaskQueue::Ptr { make_unique<SequentialImpl>(threadPool).release() };
504 }
505 
506 // IInterface
GetInterface(const BASE_NS::Uid & uid) const507 const IInterface* TaskQueueFactory::GetInterface(const BASE_NS::Uid& uid) const
508 {
509     if (uid == ITaskQueueFactory::UID) {
510         return static_cast<const ITaskQueueFactory*>(this);
511     }
512     return nullptr;
513 }
514 
GetInterface(const BASE_NS::Uid & uid)515 IInterface* TaskQueueFactory::GetInterface(const BASE_NS::Uid& uid)
516 {
517     if (uid == ITaskQueueFactory::UID) {
518         return static_cast<ITaskQueueFactory*>(this);
519     }
520     return nullptr;
521 }
522 
Ref()523 void TaskQueueFactory::Ref() {}
524 
Unref()525 void TaskQueueFactory::Unref() {}
526 CORE_END_NAMESPACE()
527