1 /**
2 * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include <gtest/gtest.h>
17
18 #include "runtime/include/runtime.h"
19 #include "runtime/thread_pool.h"
20
21 namespace ark::test {
22
// Generic "ten"; used in this file as a delay expressed in milliseconds.
constexpr int TEN {10};
// Factor that converts a fraction into a percentage.
constexpr int PERCENT {100};
25
26 class MockThreadPoolTest : public testing::Test {
27 public:
28 static const size_t TASK_NUMBER = 32;
MockThreadPoolTest()29 MockThreadPoolTest()
30 {
31 RuntimeOptions options;
32 options.SetShouldLoadBootPandaFiles(false);
33 options.SetShouldInitializeIntrinsics(false);
34 Runtime::Create(options);
35 thread_ = ark::MTManagedThread::GetCurrent();
36 thread_->ManagedCodeBegin();
37 }
38
~MockThreadPoolTest()39 ~MockThreadPoolTest() override
40 {
41 thread_->ManagedCodeEnd();
42 Runtime::Destroy();
43 }
44
45 NO_COPY_SEMANTIC(MockThreadPoolTest);
46 NO_MOVE_SEMANTIC(MockThreadPoolTest);
47
48 private:
49 ark::MTManagedThread *thread_;
50 };
51
52 class MockTask : public TaskInterface {
53 public:
MockTask(size_t identifier=0)54 explicit MockTask(size_t identifier = 0) : identifier_(identifier) {}
55
56 enum TaskStatus {
57 NOT_STARTED,
58 IN_QUEUE,
59 PROCESSING,
60 COMPLETED,
61 };
62
IsEmpty() const63 bool IsEmpty() const
64 {
65 return identifier_ == 0;
66 }
67
GetId() const68 size_t GetId() const
69 {
70 return identifier_;
71 }
72
GetStatus() const73 TaskStatus GetStatus() const
74 {
75 return status_;
76 }
77
SetStatus(TaskStatus status)78 void SetStatus(TaskStatus status)
79 {
80 status_ = status;
81 }
82
83 private:
84 size_t identifier_;
85 TaskStatus status_ = NOT_STARTED;
86 };
87
88 class MockQueue : public TaskQueueInterface<MockTask> {
89 public:
MockQueue(mem::InternalAllocatorPtr allocator)90 explicit MockQueue(mem::InternalAllocatorPtr allocator) : queue_(allocator->Adapter()) {}
MockQueue(mem::InternalAllocatorPtr allocator,size_t queueSize)91 MockQueue(mem::InternalAllocatorPtr allocator, size_t queueSize)
92 : TaskQueueInterface<MockTask>(queueSize), queue_(allocator->Adapter())
93 {
94 }
95
GetTask()96 MockTask GetTask() override
97 {
98 if (queue_.empty()) {
99 LOG(DEBUG, RUNTIME) << "Cannot get an element, queue is empty";
100 return MockTask();
101 }
102 auto task = queue_.front();
103 queue_.pop_front();
104 LOG(DEBUG, RUNTIME) << "Extract task " << task.GetId();
105 return task;
106 }
107
108 // NOLINTNEXTLINE(google-default-arguments)
AddTask(MockTask && task,size_t priority=0)109 void AddTask(MockTask &&task, [[maybe_unused]] size_t priority = 0) override
110 {
111 task.SetStatus(MockTask::IN_QUEUE);
112 queue_.push_front(task);
113 }
114
Finalize()115 void Finalize() override
116 {
117 queue_.clear();
118 }
119
120 protected:
GetQueueSize()121 size_t GetQueueSize() override
122 {
123 return queue_.size();
124 }
125
126 private:
127 PandaList<MockTask> queue_;
128 };
129
130 class MockTaskController {
131 public:
132 explicit MockTaskController() = default;
133
SolveTask(MockTask task)134 void SolveTask(MockTask task)
135 {
136 task.SetStatus(MockTask::PROCESSING);
137 // This is required to distribute tasks between different workers rather than solve it instantly
138 // on only one worker.
139 std::this_thread::sleep_for(std::chrono::milliseconds(TEN));
140 task.SetStatus(MockTask::COMPLETED);
141 LOG(DEBUG, RUNTIME) << "Task " << task.GetId() << " has been solved";
142 solvedTasks_++;
143 }
144
GetSolvedTasks()145 size_t GetSolvedTasks()
146 {
147 return solvedTasks_;
148 }
149
150 private:
151 std::atomic_size_t solvedTasks_ = 0;
152 };
153
154 class MockProcessor : public ProcessorInterface<MockTask, MockTaskController *> {
155 public:
MockProcessor(MockTaskController * controller)156 explicit MockProcessor(MockTaskController *controller) : controller_(controller) {}
157
Process(MockTask && task)158 bool Process(MockTask &&task) override
159 {
160 if (task.GetStatus() == MockTask::IN_QUEUE) {
161 controller_->SolveTask(task);
162 return true;
163 }
164 return false;
165 }
166
167 private:
168 MockTaskController *controller_;
169 };
170
CreateTasks(ThreadPool<MockTask,MockProcessor,MockTaskController * > * threadPool,size_t numberOfElements)171 void CreateTasks(ThreadPool<MockTask, MockProcessor, MockTaskController *> *threadPool, size_t numberOfElements)
172 {
173 for (size_t i = 0; i < numberOfElements; i++) {
174 MockTask task(i + 1);
175 LOG(DEBUG, RUNTIME) << "Queue task " << task.GetId();
176 // NOLINTNEXTLINE(performance-move-const-arg)
177 threadPool->PutTask(std::move(task));
178 }
179 }
180
TestThreadPool(size_t initialNumberOfThreads,size_t scaledNumberOfThreads,float scaleThreshold)181 void TestThreadPool(size_t initialNumberOfThreads, size_t scaledNumberOfThreads, float scaleThreshold)
182 {
183 auto allocator = Runtime::GetCurrent()->GetInternalAllocator();
184 auto queue = allocator->New<MockQueue>(allocator);
185 auto controller = allocator->New<MockTaskController>();
186 auto threadPool = allocator->New<ThreadPool<MockTask, MockProcessor, MockTaskController *>>(
187 allocator, queue, controller, initialNumberOfThreads, "Test thread");
188
189 CreateTasks(threadPool, MockThreadPoolTest::TASK_NUMBER);
190
191 if (scaleThreshold < 1.0) {
192 while (controller->GetSolvedTasks() < scaleThreshold * MockThreadPoolTest::TASK_NUMBER) {
193 }
194 threadPool->Scale(scaledNumberOfThreads);
195 }
196
197 for (;;) {
198 auto solvedTasks = controller->GetSolvedTasks();
199 // NOLINTNEXTLINE(readability-magic-numbers)
200 auto rate = static_cast<size_t>((static_cast<float>(solvedTasks) / MockThreadPoolTest::TASK_NUMBER) * 100);
201 LOG(DEBUG, RUNTIME) << "Number of solved tasks is " << solvedTasks << " (" << rate << "%)";
202 if (scaleThreshold == 1.0) {
203 // NOLINTNEXTLINE(readability-magic-numbers)
204 size_t dynamicScaling = rate / 10 + 1;
205 threadPool->Scale(dynamicScaling);
206 }
207
208 if (solvedTasks == MockThreadPoolTest::TASK_NUMBER) {
209 break;
210 }
211 }
212
213 allocator->Delete(threadPool);
214 allocator->Delete(controller);
215 allocator->Delete(queue);
216 }
217
// Constant pool size: with a threshold of 0 TestThreadPool calls Scale() without waiting,
// and the scaled worker count equals the initial one (8 -> 8).
TEST_F(MockThreadPoolTest, SeveralThreads)
{
    constexpr float THRESHOLD = 0.0F;
    constexpr size_t WORKERS_AT_START = 8;
    constexpr size_t WORKERS_AFTER_SCALE = 8;
    TestThreadPool(WORKERS_AT_START, WORKERS_AFTER_SCALE, THRESHOLD);
}
225
// Shrinks the pool from 8 to 4 workers once a quarter of the tasks has been solved.
TEST_F(MockThreadPoolTest, ReduceThreads)
{
    constexpr float THRESHOLD = 0.25F;
    constexpr size_t WORKERS_AT_START = 8;
    constexpr size_t WORKERS_AFTER_SCALE = 4;
    TestThreadPool(WORKERS_AT_START, WORKERS_AFTER_SCALE, THRESHOLD);
}
233
// Grows the pool from 4 to 8 workers once a quarter of the tasks has been solved.
TEST_F(MockThreadPoolTest, IncreaseThreads)
{
    constexpr float THRESHOLD = 0.25F;
    constexpr size_t WORKERS_AT_START = 4;
    constexpr size_t WORKERS_AFTER_SCALE = 8;
    TestThreadPool(WORKERS_AT_START, WORKERS_AFTER_SCALE, THRESHOLD);
}
241
// A threshold of exactly 1.0 puts TestThreadPool into its dynamic-scaling mode, where the worker
// count follows the percentage of solved tasks; the scaled worker count argument is not used then.
TEST_F(MockThreadPoolTest, DifferentNumberOfThreads)
{
    constexpr float THRESHOLD = 1.0F;
    constexpr size_t WORKERS_AT_START = 8;
    constexpr size_t WORKERS_AFTER_SCALE = 8;
    TestThreadPool(WORKERS_AT_START, WORKERS_AFTER_SCALE, THRESHOLD);
}
249
ControllerThreadPutTask(ThreadPool<MockTask,MockProcessor,MockTaskController * > * threadPool,size_t numberOfTasks)250 void ControllerThreadPutTask(ThreadPool<MockTask, MockProcessor, MockTaskController *> *threadPool,
251 size_t numberOfTasks)
252 {
253 CreateTasks(threadPool, numberOfTasks);
254 }
255
ControllerThreadTryPutTask(ThreadPool<MockTask,MockProcessor,MockTaskController * > * threadPool,size_t numberOfTasks)256 void ControllerThreadTryPutTask(ThreadPool<MockTask, MockProcessor, MockTaskController *> *threadPool,
257 size_t numberOfTasks)
258 {
259 for (size_t i = 0; i < numberOfTasks; i++) {
260 for (;;) {
261 if (threadPool->TryPutTask(MockTask {i + 1}) || !threadPool->IsActive()) {
262 break;
263 }
264 }
265 }
266 }
267
ControllerThreadScale(ThreadPool<MockTask,MockProcessor,MockTaskController * > * threadPool,size_t numberOfThreads)268 void ControllerThreadScale(ThreadPool<MockTask, MockProcessor, MockTaskController *> *threadPool,
269 size_t numberOfThreads)
270 {
271 threadPool->Scale(numberOfThreads);
272 }
273
ControllerThreadShutdown(ThreadPool<MockTask,MockProcessor,MockTaskController * > * threadPool,bool isShutdown,bool isForceShutdown)274 void ControllerThreadShutdown(ThreadPool<MockTask, MockProcessor, MockTaskController *> *threadPool, bool isShutdown,
275 bool isForceShutdown)
276 {
277 if (isShutdown) {
278 threadPool->Shutdown(isForceShutdown);
279 }
280 }
281
TestThreadPoolWithControllers(size_t numberOfThreadsInitial,size_t numberOfThreadsScaled,bool isShutdown,bool isForceShutdown)282 void TestThreadPoolWithControllers(size_t numberOfThreadsInitial, size_t numberOfThreadsScaled, bool isShutdown,
283 bool isForceShutdown)
284 {
285 constexpr size_t NUMBER_OF_TASKS = MockThreadPoolTest::TASK_NUMBER / 4;
286 constexpr size_t QUEUE_SIZE = 16;
287
288 auto allocator = Runtime::GetCurrent()->GetInternalAllocator();
289 auto queue = allocator->New<MockQueue>(allocator, QUEUE_SIZE);
290 auto controller = allocator->New<MockTaskController>();
291 auto threadPool = allocator->New<ThreadPool<MockTask, MockProcessor, MockTaskController *>>(
292 allocator, queue, controller, numberOfThreadsInitial, "Test thread");
293
294 std::thread controllerThreadPutTask1(ControllerThreadPutTask, threadPool, NUMBER_OF_TASKS);
295 std::thread controllerThreadPutTask2(ControllerThreadPutTask, threadPool, NUMBER_OF_TASKS);
296 std::thread controllerThreadTryPutTask1(ControllerThreadTryPutTask, threadPool, NUMBER_OF_TASKS);
297 std::thread controllerThreadTryPutTask2(ControllerThreadTryPutTask, threadPool, NUMBER_OF_TASKS);
298 std::thread controllerThreadScale1(ControllerThreadScale, threadPool, numberOfThreadsScaled);
299 std::thread controllerThreadScale2(ControllerThreadScale, threadPool,
300 numberOfThreadsScaled + numberOfThreadsInitial);
301 std::thread controllerThreadShutdown1(ControllerThreadShutdown, threadPool, isShutdown, isForceShutdown);
302 std::thread controllerThreadShutdown2(ControllerThreadShutdown, threadPool, isShutdown, isForceShutdown);
303
304 // Wait for tasks completion.
305 for (;;) {
306 auto solvedTasks = controller->GetSolvedTasks();
307 auto rate = static_cast<size_t>((static_cast<float>(solvedTasks) / MockThreadPoolTest::TASK_NUMBER) * PERCENT);
308 (void)rate;
309 LOG(DEBUG, RUNTIME) << "Number of solved tasks is " << solvedTasks << " (" << rate << "%)";
310 if (solvedTasks == MockThreadPoolTest::TASK_NUMBER || !threadPool->IsActive()) {
311 break;
312 }
313 std::this_thread::sleep_for(std::chrono::milliseconds(TEN));
314 }
315 controllerThreadPutTask1.join();
316 controllerThreadPutTask2.join();
317 controllerThreadTryPutTask1.join();
318 controllerThreadTryPutTask2.join();
319 controllerThreadScale1.join();
320 controllerThreadScale2.join();
321 controllerThreadShutdown1.join();
322 controllerThreadShutdown2.join();
323
324 allocator->Delete(threadPool);
325 allocator->Delete(controller);
326 allocator->Delete(queue);
327 }
328
// Concurrent producers and scalers, no shutdown: the shutdown controllers are started but do nothing.
TEST_F(MockThreadPoolTest, Controllers)
{
    constexpr bool WITH_SHUTDOWN = false;
    constexpr bool FORCED = false;
    constexpr size_t WORKERS_AT_START = 8;
    constexpr size_t WORKERS_AFTER_SCALE = 4;
    TestThreadPoolWithControllers(WORKERS_AT_START, WORKERS_AFTER_SCALE, WITH_SHUTDOWN, FORCED);
}
337
// Concurrent producers and scalers while two controllers call Shutdown(false).
TEST_F(MockThreadPoolTest, ControllersShutdown)
{
    constexpr bool WITH_SHUTDOWN = true;
    constexpr bool FORCED = false;
    constexpr size_t WORKERS_AT_START = 8;
    constexpr size_t WORKERS_AFTER_SCALE = 4;
    TestThreadPoolWithControllers(WORKERS_AT_START, WORKERS_AFTER_SCALE, WITH_SHUTDOWN, FORCED);
}
346
// Concurrent producers and scalers while two controllers call Shutdown(true).
TEST_F(MockThreadPoolTest, ControllersForceShutdown)
{
    constexpr bool WITH_SHUTDOWN = true;
    constexpr bool FORCED = true;
    constexpr size_t WORKERS_AT_START = 8;
    constexpr size_t WORKERS_AFTER_SCALE = 4;
    TestThreadPoolWithControllers(WORKERS_AT_START, WORKERS_AFTER_SCALE, WITH_SHUTDOWN, FORCED);
}
355
356 } // namespace ark::test
357