1 /*
2 * Copyright (C) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "threading/task_queue_factory.h"
17
18 #include <atomic>
19 #include <condition_variable>
20 #include <memory>
21 #include <mutex>
22 #include <queue>
23 #include <thread>
24
25 #include <base/containers/array_view.h>
26 #include <base/containers/type_traits.h>
27 #include <base/containers/unique_ptr.h>
28 #include <core/log.h>
29
30 #include "os/platform.h"
31 #include "threading/dispatcher_impl.h"
32 #include "threading/parallel_impl.h"
33 #include "threading/sequential_impl.h"
34
35 #ifdef PLATFORM_HAS_JAVA
36 #include <os/java/java_internal.h>
37 #endif
38
39 CORE_BEGIN_NAMESPACE()
40 using BASE_NS::array_view;
41 using BASE_NS::make_unique;
42 using BASE_NS::move;
43 using BASE_NS::unique_ptr;
44
45 namespace {
#ifdef PLATFORM_HAS_JAVA
/**
 * RAII class for handling thread setup/release: attaches the calling thread to the Java VM
 * obtained from java_internal::GetJavaVM() on construction and detaches it on destruction.
 * Non-copyable and non-movable, because a second instance would detach the thread again.
 */
class JavaThreadContext final {
public:
    JavaThreadContext()
    {
        JNIEnv* env = nullptr;
        javaVm_ = java_internal::GetJavaVM();

#ifndef NDEBUG
        // Check that the thread was not already attached.
        // It's not really a problem as another attach is a no-op, but we will be detaching the
        // thread later and it may be unexpected for the user.
        jint result = javaVm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6);
        CORE_ASSERT_MSG((result != JNI_OK), "Thread already attached");
#endif

        // Remember whether the attach succeeded so the destructor only undoes a real attach.
        attached_ = (javaVm_->AttachCurrentThread(&env, nullptr) == JNI_OK);
    }

    ~JavaThreadContext()
    {
        if (attached_) {
            javaVm_->DetachCurrentThread();
        }
    }

    JavaThreadContext(const JavaThreadContext&) = delete;
    JavaThreadContext& operator=(const JavaThreadContext&) = delete;
    JavaThreadContext(JavaThreadContext&&) = delete;
    JavaThreadContext& operator=(JavaThreadContext&&) = delete;

    JavaVM* javaVm_ { nullptr };

private:
    bool attached_ { false };
};
#endif // PLATFORM_HAS_JAVA
73
74 // -- TaskResult, returned by ThreadPool::Push and can be waited on.
75 class TaskResult final : public IThreadPool::IResult {
76 public:
77 // Task state which can be waited and marked as done.
78 class State {
79 public:
Done()80 void Done()
81 {
82 {
83 auto lock = std::lock_guard(mutex_);
84 done_ = true;
85 }
86 cv_.notify_all();
87 }
88
Wait()89 void Wait()
90 {
91 auto lock = std::unique_lock(mutex_);
92 cv_.wait(lock, [this]() { return done_; });
93 }
94
95 private:
96 std::mutex mutex_;
97 std::condition_variable cv_;
98 bool done_ { false };
99 };
100
TaskResult(std::shared_ptr<State> && future)101 explicit TaskResult(std::shared_ptr<State>&& future) : future_(BASE_NS::move(future)) {}
102 ~TaskResult() = default;
103
Wait()104 void Wait() override
105 {
106 if (future_) {
107 future_->Wait();
108 }
109 }
110
111 protected:
Destroy()112 void Destroy() override
113 {
114 delete this;
115 }
116
117 private:
118 std::shared_ptr<State> future_;
119 };
120
121 // -- ThreadPool
122 class ThreadPool final : public IThreadPool {
123 public:
ThreadPool(size_t threadCount)124 explicit ThreadPool(size_t threadCount)
125 : threadCount_(threadCount), threads_(make_unique<ThreadContext[]>(threadCount))
126 {
127 CORE_ASSERT(threads_);
128
129 // Create thread containers.
130 auto threads = array_view<ThreadContext>(threads_.get(), threadCount_);
131 for (ThreadContext& context : threads) {
132 // Set-up thread function.
133 context.thread = std::thread(&ThreadPool::ThreadProc, this, std::ref(context));
134 }
135 }
136
Push(ITask::Ptr function)137 IResult::Ptr Push(ITask::Ptr function) override
138 {
139 auto taskState = std::make_shared<TaskResult::State>();
140 if (taskState) {
141 if (function) {
142 {
143 std::lock_guard lock(mutex_);
144 q_.Push(Task(move(function), taskState));
145 }
146 cv_.notify_one();
147 } else {
148 // mark as done if the there was no function.
149 taskState->Done();
150 }
151 }
152 return IResult::Ptr { new TaskResult(BASE_NS::move(taskState)) };
153 }
154
PushNoWait(ITask::Ptr function)155 void PushNoWait(ITask::Ptr function) override
156 {
157 if (function) {
158 {
159 std::lock_guard lock(mutex_);
160 q_.Push(Task(move(function)));
161 }
162 cv_.notify_one();
163 }
164 }
165
GetNumberOfThreads() const166 uint32_t GetNumberOfThreads() const override
167 {
168 return static_cast<uint32_t>(threadCount_);
169 }
170
171 // IInterface
GetInterface(const BASE_NS::Uid & uid) const172 const IInterface* GetInterface(const BASE_NS::Uid& uid) const override
173 {
174 if (uid == IThreadPool::UID) {
175 return static_cast<const IThreadPool*>(this);
176 }
177 return nullptr;
178 }
179
GetInterface(const BASE_NS::Uid & uid)180 IInterface* GetInterface(const BASE_NS::Uid& uid) override
181 {
182 if (uid == IThreadPool::UID) {
183 return static_cast<IThreadPool*>(this);
184 }
185 return nullptr;
186 }
187
Ref()188 void Ref() override
189 {
190 refcnt_.fetch_add(1, std::memory_order_relaxed);
191 }
192
Unref()193 void Unref() override
194 {
195 if (std::atomic_fetch_sub_explicit(&refcnt_, 1, std::memory_order_release) == 1) {
196 std::atomic_thread_fence(std::memory_order_acquire);
197 delete this;
198 }
199 }
200
201 protected:
~ThreadPool()202 virtual ~ThreadPool()
203 {
204 Stop(true);
205 }
206
207 private:
208 ThreadPool(const ThreadPool&) = delete;
209 ThreadPool(ThreadPool&&) = delete;
210 ThreadPool& operator=(const ThreadPool&) = delete;
211 ThreadPool& operator=(ThreadPool&&) = delete;
212
213 // Helper which holds a pointer to a queued task function and the result state.
214 struct Task {
215 ITask::Ptr function_;
216 std::shared_ptr<TaskResult::State> state_;
217
218 ~Task() = default;
219 Task() = default;
Task__anond6001d750111::ThreadPool::Task220 explicit Task(ITask::Ptr&& function, std::shared_ptr<TaskResult::State> state)
221 : function_(move(function)), state_(CORE_NS::move(state))
222 {
223 CORE_ASSERT(this->function_ && this->state_);
224 }
Task__anond6001d750111::ThreadPool::Task225 explicit Task(ITask::Ptr&& function) : function_(move(function))
226 {
227 CORE_ASSERT(this->function_);
228 }
229 Task(Task&&) = default;
230 Task& operator=(Task&&) = default;
231 Task(const Task&) = delete;
232 Task& operator=(const Task&) = delete;
233
operator ()__anond6001d750111::ThreadPool::Task234 void operator()()
235 {
236 (*function_)();
237 if (state_) {
238 state_->Done();
239 }
240 }
241 };
242
243 template<typename T>
244 class Queue {
245 public:
Push(T && value)246 bool Push(T&& value)
247 {
248 q_.push(move(value));
249 return true;
250 }
251
Pop(T & v)252 bool Pop(T& v)
253 {
254 if (q_.empty()) {
255 v = {};
256 return false;
257 }
258 v = CORE_NS::move(q_.front());
259 q_.pop();
260 return true;
261 }
262
263 private:
264 std::queue<T> q_;
265 };
266
267 struct ThreadContext {
268 std::thread thread;
269 bool exit { false };
270 };
271
Clear()272 void Clear()
273 {
274 Task f;
275 std::lock_guard lock(mutex_);
276 while (q_.Pop(f)) {
277 // Intentionally empty.
278 }
279 }
280
281 // At the moment Stop is called only from the destructor with waitForAllTasksToComplete=true.
282 // If this doesn't change the class can be simplified a bit.
Stop(bool waitForAllTasksToComplete)283 void Stop(bool waitForAllTasksToComplete)
284 {
285 if (isStop_) {
286 return;
287 }
288 if (waitForAllTasksToComplete) {
289 // Wait all tasks to complete before returning.
290 if (isDone_) {
291 return;
292 }
293 std::lock_guard lock(mutex_);
294 isDone_ = true;
295 } else {
296 isStop_ = true;
297
298 // Ask all the threads to stop and not process any more tasks.
299 auto threads = array_view(threads_.get(), threadCount_);
300 {
301 auto lock = std::lock_guard(mutex_);
302 for (auto& context : threads) {
303 context.exit = true;
304 }
305 }
306 Clear();
307 }
308
309 // Trigger all waiting threads.
310 cv_.notify_all();
311
312 // Wait for all threads to finish.
313 auto threads = array_view(threads_.get(), threadCount_);
314 for (auto& context : threads) {
315 if (context.thread.joinable()) {
316 context.thread.join();
317 }
318 }
319
320 Clear();
321 }
322
ThreadProc(ThreadContext & context)323 void ThreadProc(ThreadContext& context)
324 {
325 #ifdef PLATFORM_HAS_HAVA
326 // RAII class for handling thread setup/release.
327 JavaThreadContext javaContext;
328 #endif
329
330 // Get function to process.
331 Task func;
332 bool isPop = [this, &func]() {
333 std::lock_guard lock(mutex_);
334 return q_.Pop(func);
335 }();
336
337 while (true) {
338 while (isPop) {
339 // Run task function.
340 func();
341
342 // If the thread is wanted to stop, return even if the queue is not empty yet.
343 std::lock_guard lock(mutex_);
344 if (context.exit) {
345 return;
346 }
347
348 // Get next function.
349 isPop = q_.Pop(func);
350 }
351
352 // The queue is empty here, wait for the next task.
353 std::unique_lock lock(mutex_);
354
355 // Try to wait for next task to process.
356 cv_.wait(lock, [this, &func, &isPop, &context]() {
357 isPop = q_.Pop(func);
358 return isPop || isDone_ || context.exit;
359 });
360
361 if (!isPop) {
362 return;
363 }
364 }
365 }
366
367 size_t threadCount_ { 0 };
368 unique_ptr<ThreadContext[]> threads_;
369
370 Queue<Task> q_;
371 bool isDone_ { false };
372 bool isStop_ { false };
373
374 std::mutex mutex_;
375 std::condition_variable cv_;
376 std::atomic<int32_t> refcnt_ { 0 };
377 };
378 } // namespace
379
GetNumberOfCores() const380 uint32_t TaskQueueFactory::GetNumberOfCores() const
381 {
382 uint32_t result = std::thread::hardware_concurrency();
383 if (result == 0) {
384 // If not detectable, default to 4.
385 result = 4;
386 }
387
388 return result;
389 }
390
CreateThreadPool(const uint32_t threadCount) const391 IThreadPool::Ptr TaskQueueFactory::CreateThreadPool(const uint32_t threadCount) const
392 {
393 return IThreadPool::Ptr { new ThreadPool(threadCount) };
394 }
395
CreateDispatcherTaskQueue(const IThreadPool::Ptr & threadPool) const396 IDispatcherTaskQueue::Ptr TaskQueueFactory::CreateDispatcherTaskQueue(const IThreadPool::Ptr& threadPool) const
397 {
398 return IDispatcherTaskQueue::Ptr { make_unique<DispatcherImpl>(threadPool).release() };
399 }
400
CreateParallelTaskQueue(const IThreadPool::Ptr & threadPool) const401 IParallelTaskQueue::Ptr TaskQueueFactory::CreateParallelTaskQueue(const IThreadPool::Ptr& threadPool) const
402 {
403 return IParallelTaskQueue::Ptr { make_unique<ParallelImpl>(threadPool).release() };
404 }
405
CreateSequentialTaskQueue(const IThreadPool::Ptr & threadPool) const406 ISequentialTaskQueue::Ptr TaskQueueFactory::CreateSequentialTaskQueue(const IThreadPool::Ptr& threadPool) const
407 {
408 return ISequentialTaskQueue::Ptr { make_unique<SequentialImpl>(threadPool).release() };
409 }
410
411 // IInterface
GetInterface(const BASE_NS::Uid & uid) const412 const IInterface* TaskQueueFactory::GetInterface(const BASE_NS::Uid& uid) const
413 {
414 if (uid == ITaskQueueFactory::UID) {
415 return static_cast<const ITaskQueueFactory*>(this);
416 }
417 return nullptr;
418 }
419
GetInterface(const BASE_NS::Uid & uid)420 IInterface* TaskQueueFactory::GetInterface(const BASE_NS::Uid& uid)
421 {
422 if (uid == ITaskQueueFactory::UID) {
423 return static_cast<ITaskQueueFactory*>(this);
424 }
425 return nullptr;
426 }
427
Ref()428 void TaskQueueFactory::Ref() {}
429
Unref()430 void TaskQueueFactory::Unref() {}
431 CORE_END_NAMESPACE()
432