/*
 * Copyright (C) 2023 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "threading/task_queue_factory.h"

#include <atomic>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>

#include <base/containers/array_view.h>
#include <base/containers/type_traits.h>
#include <base/containers/unique_ptr.h>
#include <core/log.h>

#include "os/platform.h"
#include "threading/dispatcher_impl.h"
#include "threading/parallel_impl.h"
#include "threading/sequential_impl.h"

#ifdef PLATFORM_HAS_JAVA
#include <os/java/java_internal.h>
#endif

CORE_BEGIN_NAMESPACE()
using BASE_NS::array_view;
using BASE_NS::make_unique;
using BASE_NS::move;
using BASE_NS::unique_ptr;

namespace {
#ifdef PLATFORM_HAS_JAVA
/** RAII class for handling thread setup/release. */
class JavaThreadContext final {
public:
    JavaThreadContext()
    {
        JNIEnv* env = nullptr;
        javaVm_ = java_internal::GetJavaVM();

#ifndef NDEBUG
        // Check that the thread was not already attached.
        // This is not really a problem, as a second attach is a no-op, but the thread will be
        // detached later, which may be unexpected for the user.
        jint result = javaVm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6);
        CORE_ASSERT_MSG((result != JNI_OK), "Thread already attached");
#endif

        javaVm_->AttachCurrentThread(&env, nullptr);
    }

    ~JavaThreadContext()
    {
        javaVm_->DetachCurrentThread();
    }
    JavaVM* javaVm_ { nullptr };
};
#endif // PLATFORM_HAS_JAVA
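// Each worker thread created by the pool instantiates the RAII helper above (see ThreadProc below)
// so that it is attached to the JVM while it runs tasks and detached again when it exits.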

// -- TaskResult, returned by ThreadPool::Push; can be waited on.
class TaskResult final : public IThreadPool::IResult {
public:
    // Task state which can be waited on and marked as done.
    class State {
    public:
        void Done()
        {
            {
                auto lock = std::lock_guard(mutex_);
                done_ = true;
            }
            cv_.notify_all();
        }

        void Wait()
        {
            auto lock = std::unique_lock(mutex_);
            cv_.wait(lock, [this]() { return done_; });
        }

    private:
        std::mutex mutex_;
        std::condition_variable cv_;
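        // done_ is written only while mutex_ is held so that Wait() cannot miss the notification.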
        bool done_ { false };
    };

    explicit TaskResult(std::shared_ptr<State>&& future) : future_(BASE_NS::move(future)) {}
    ~TaskResult() = default;

    void Wait() override
    {
        if (future_) {
            future_->Wait();
        }
    }

protected:
    void Destroy() override
    {
        delete this;
    }

private:
    std::shared_ptr<State> future_;
};
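
// Illustrative usage of the waitable result (sketch only; 'factory' is a TaskQueueFactory instance
// and 'task' is some ITask::Ptr, both assumed to exist):
//   IThreadPool::Ptr pool = factory.CreateThreadPool(4U);
//   IThreadPool::IResult::Ptr result = pool->Push(move(task));
//   result->Wait(); // Blocks until a worker has run the task and marked its State as done.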

// -- ThreadPool
class ThreadPool final : public IThreadPool {
public:
    explicit ThreadPool(size_t threadCount)
        : threadCount_(threadCount), threads_(make_unique<ThreadContext[]>(threadCount))
    {
        CORE_ASSERT(threads_);

        // Create thread containers.
        auto threads = array_view<ThreadContext>(threads_.get(), threadCount_);
        for (ThreadContext& context : threads) {
            // Set up the thread function.
            context.thread = std::thread(&ThreadPool::ThreadProc, this, std::ref(context));
        }
    }

    IResult::Ptr Push(ITask::Ptr function) override
    {
        auto taskState = std::make_shared<TaskResult::State>();
        if (taskState) {
            if (function) {
                {
                    std::lock_guard lock(mutex_);
                    q_.Push(Task(move(function), taskState));
                }
                cv_.notify_one();
            } else {
                // Mark as done if there was no function.
                taskState->Done();
            }
        }
        return IResult::Ptr { new TaskResult(BASE_NS::move(taskState)) };
    }

    void PushNoWait(ITask::Ptr function) override
    {
        if (function) {
            {
                std::lock_guard lock(mutex_);
                q_.Push(Task(move(function)));
            }
            cv_.notify_one();
        }
    }

    uint32_t GetNumberOfThreads() const override
    {
        return static_cast<uint32_t>(threadCount_);
    }

    // IInterface
    const IInterface* GetInterface(const BASE_NS::Uid& uid) const override
    {
        if (uid == IThreadPool::UID) {
            return static_cast<const IThreadPool*>(this);
        }
        return nullptr;
    }

    IInterface* GetInterface(const BASE_NS::Uid& uid) override
    {
        if (uid == IThreadPool::UID) {
            return static_cast<IThreadPool*>(this);
        }
        return nullptr;
    }

    void Ref() override
    {
        refcnt_.fetch_add(1, std::memory_order_relaxed);
    }

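    // The reference count uses a relaxed increment and a release decrement paired with an acquire
    // fence so that the deleting thread observes all writes made while the object was alive.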
    void Unref() override
    {
        if (std::atomic_fetch_sub_explicit(&refcnt_, 1, std::memory_order_release) == 1) {
            std::atomic_thread_fence(std::memory_order_acquire);
            delete this;
        }
    }

protected:
    virtual ~ThreadPool()
    {
        Stop(true);
    }

private:
    ThreadPool(const ThreadPool&) = delete;
    ThreadPool(ThreadPool&&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;
    ThreadPool& operator=(ThreadPool&&) = delete;

    // Helper which holds a pointer to a queued task function and the result state.
    struct Task {
        ITask::Ptr function_;
        std::shared_ptr<TaskResult::State> state_;

        ~Task() = default;
        Task() = default;
        explicit Task(ITask::Ptr&& function, std::shared_ptr<TaskResult::State> state)
            : function_(move(function)), state_(CORE_NS::move(state))
        {
            CORE_ASSERT(this->function_ && this->state_);
        }
        explicit Task(ITask::Ptr&& function) : function_(move(function))
        {
            CORE_ASSERT(this->function_);
        }
        Task(Task&&) = default;
        Task& operator=(Task&&) = default;
        Task(const Task&) = delete;
        Task& operator=(const Task&) = delete;

        void operator()()
        {
            (*function_)();
            if (state_) {
                state_->Done();
            }
        }
    };

    template<typename T>
    class Queue {
    public:
        bool Push(T&& value)
        {
            q_.push(move(value));
            return true;
        }

        bool Pop(T& v)
        {
            if (q_.empty()) {
                v = {};
                return false;
            }
            v = CORE_NS::move(q_.front());
            q_.pop();
            return true;
        }

    private:
        std::queue<T> q_;
    };

    struct ThreadContext {
        std::thread thread;
        bool exit { false };
    };

    void Clear()
    {
        Task f;
        std::lock_guard lock(mutex_);
        while (q_.Pop(f)) {
            // Intentionally empty.
        }
    }

    // At the moment Stop is called only from the destructor with waitForAllTasksToComplete=true.
    // If this doesn't change, the class can be simplified a bit.
    void Stop(bool waitForAllTasksToComplete)
    {
        if (isStop_) {
            return;
        }
        if (waitForAllTasksToComplete) {
            // Wait for all tasks to complete before returning.
            if (isDone_) {
                return;
            }
            std::lock_guard lock(mutex_);
            isDone_ = true;
        } else {
            isStop_ = true;

            // Ask all the threads to stop and not process any more tasks.
            auto threads = array_view(threads_.get(), threadCount_);
            {
                auto lock = std::lock_guard(mutex_);
                for (auto& context : threads) {
                    context.exit = true;
                }
            }
            Clear();
        }

        // Trigger all waiting threads.
        cv_.notify_all();

        // Wait for all threads to finish.
        auto threads = array_view(threads_.get(), threadCount_);
        for (auto& context : threads) {
            if (context.thread.joinable()) {
                context.thread.join();
            }
        }

        Clear();
    }

    void ThreadProc(ThreadContext& context)
    {
#ifdef PLATFORM_HAS_JAVA
        // RAII class for handling thread setup/release.
        JavaThreadContext javaContext;
#endif

        // Get a function to process.
        Task func;
        bool isPop = [this, &func]() {
            std::lock_guard lock(mutex_);
            return q_.Pop(func);
        }();

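        // Main worker loop: drain the queue, then block on the condition variable until new work
        // arrives or shutdown is requested.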
        while (true) {
            while (isPop) {
                // Run the task function.
                func();

                // If the thread has been asked to stop, return even if the queue is not empty yet.
                std::lock_guard lock(mutex_);
                if (context.exit) {
                    return;
                }

                // Get the next function.
                isPop = q_.Pop(func);
            }

            // The queue is empty here; wait for the next task.
            std::unique_lock lock(mutex_);

            // Try to wait for the next task to process.
            cv_.wait(lock, [this, &func, &isPop, &context]() {
                isPop = q_.Pop(func);
                return isPop || isDone_ || context.exit;
            });

            if (!isPop) {
                return;
            }
        }
    }

    size_t threadCount_ { 0 };
    unique_ptr<ThreadContext[]> threads_;

    Queue<Task> q_;
    bool isDone_ { false };
    bool isStop_ { false };

    std::mutex mutex_;
    std::condition_variable cv_;
    std::atomic<int32_t> refcnt_ { 0 };
};
} // namespace

uint32_t TaskQueueFactory::GetNumberOfCores() const
{
    uint32_t result = std::thread::hardware_concurrency();
    if (result == 0) {
        // If not detectable, default to 4.
        result = 4;
    }

    return result;
}

IThreadPool::Ptr TaskQueueFactory::CreateThreadPool(const uint32_t threadCount) const
{
    return IThreadPool::Ptr { new ThreadPool(threadCount) };
}

IDispatcherTaskQueue::Ptr TaskQueueFactory::CreateDispatcherTaskQueue(const IThreadPool::Ptr& threadPool) const
{
    return IDispatcherTaskQueue::Ptr { make_unique<DispatcherImpl>(threadPool).release() };
}

IParallelTaskQueue::Ptr TaskQueueFactory::CreateParallelTaskQueue(const IThreadPool::Ptr& threadPool) const
{
    return IParallelTaskQueue::Ptr { make_unique<ParallelImpl>(threadPool).release() };
}

ISequentialTaskQueue::Ptr TaskQueueFactory::CreateSequentialTaskQueue(const IThreadPool::Ptr& threadPool) const
{
    return ISequentialTaskQueue::Ptr { make_unique<SequentialImpl>(threadPool).release() };
}
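
// Illustrative factory usage (sketch only; how the factory instance is obtained is outside this file):
//   ITaskQueueFactory& factory = ...;
//   IThreadPool::Ptr pool = factory.CreateThreadPool(factory.GetNumberOfCores());
//   IParallelTaskQueue::Ptr parallelQueue = factory.CreateParallelTaskQueue(pool);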

// IInterface
const IInterface* TaskQueueFactory::GetInterface(const BASE_NS::Uid& uid) const
{
    if (uid == ITaskQueueFactory::UID) {
        return static_cast<const ITaskQueueFactory*>(this);
    }
    return nullptr;
}

IInterface* TaskQueueFactory::GetInterface(const BASE_NS::Uid& uid)
{
    if (uid == ITaskQueueFactory::UID) {
        return static_cast<ITaskQueueFactory*>(this);
    }
    return nullptr;
}

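// Ref/Unref are intentionally empty: the factory is not reference counted here, presumably because
// its lifetime is managed by its owner.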
void TaskQueueFactory::Ref() {}

void TaskQueueFactory::Unref() {}
CORE_END_NAMESPACE()