/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef _MSC_VER
#include <sched.h>
#include <unistd.h>
#endif
#include "thread/parallel_threadpool.h"
#include "thread/core_affinity.h"
#include "thread/parallel_thread_pool_manager.h"

namespace mindspore {
// A task table smaller than this threshold is rescanned linearly in full by
// RunParallel; larger tables use the rotating-window scan in RunTaskOnce.
constexpr int kActorParallelThreshold = 5;

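// Marks the worker dead under its lock, wakes the thread through whichever
// wait it may be parked on (cross-pool handoff when sharing is enabled, the
// pool condition variable otherwise), then joins it.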
ParallelWorker::~ParallelWorker() {
  {
    std::lock_guard<std::mutex> _l(mutex_);
    alive_ = false;
  }
  if (enable_shared_thread_pool_) {
    ActivateByOtherPoolTask(nullptr);
  } else {
    cond_var_->notify_one();
  }
  if (thread_->joinable()) {
    thread_->join();
  }
  pool_ = nullptr;
  parallel_pool_ = nullptr;
}

void ParallelWorker::CreateThread() { thread_ = std::make_unique<std::thread>(&ParallelWorker::ParallelRun, this); }

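// Worker main loop: binds the thread to its core list and names it, then
// spins looking for work. Local kernel tasks and queued actor tasks reset the
// spin counter; once the counter exceeds max_spin_count_ the worker parks,
// either on the pool's condition variable or, when thread sharing is enabled,
// waiting for a task handed over from another pool.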
void ParallelWorker::ParallelRun() {
  if (!core_list_.empty()) {
    SetAffinity();
  }
#if !defined(__APPLE__) && !defined(_MSC_VER)
  (void)pthread_setname_np(pthread_self(), ("OS_Parallel_" + std::to_string(worker_id_)).c_str());
#endif
#ifdef PLATFORM_86
  // Some CPU kernels need the flush-zero mode set to improve performance.
  _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
  _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
#endif
  std::string bind_runner_id = parallel_pool_->GetPoolBindRunnerID();
  enable_shared_thread_pool_ = ParallelThreadPoolManager::GetInstance()->GetEnableSharedThreadPool(bind_runner_id);
  while (alive_) {
    // only run either a local KernelTask or a PoolQueue ActorTask
    if (RunLocalKernelTask() || RunQueueActorTask()) {
      spin_count_ = 0;
    } else {
      DirectRunOtherPoolTask();
      if (++spin_count_ > max_spin_count_) {
        if (!enable_shared_thread_pool_) {
          WaitUntilActive();
          spin_count_ = 0;
        } else {
          WaitOtherPoolTask();
        }
      } else {
        std::this_thread::yield();
      }
    }
  }
}

void ParallelWorker::WaitUntilActive() {
  std::unique_lock<std::mutex> _l(mutex_);
  cond_var_->wait(_l, [&] { return active_num_ > 0 || !alive_; });
  if (active_num_ > 0) {
    active_num_--;
  }
}

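// Runs pieces of a task submitted by another pool. Piece indices are claimed
// with a compare-and-swap on the distributor so each piece runs exactly once;
// the worker bails out as soon as its own pool becomes busy (pool ref count
// no longer 1).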
void ParallelWorker::RunOtherPoolTask(ParallelTask *p_task) {
  bool find = false;
  int finish = 0;
  Distributor expected_index = p_task->distributor;
  bool is_busy = false;
  while (p_task->valid && expected_index.started < expected_index.task_num) {
    if (parallel_pool_->GetPoolRef() != 1) {
      is_busy = true;
      break;
    }
    if (p_task->distributor.compare_exchange_strong(expected_index,
                                                    {expected_index.started + 1, expected_index.task_num})) {
      p_task->status |= p_task->func(p_task->content, expected_index.started, 0, 0);
      find = true;
      expected_index = p_task->distributor;
      finish++;
    }
  }
  if (find && !is_busy) {
    p_task->valid = false;
  }
  p_task->finished += finish;
}

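// Hands a task from another pool to this worker (nullptr means "wake up
// only", as the destructor uses it) and signals the cross-pool condition
// variable.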
void ParallelWorker::ActivateByOtherPoolTask(ParallelTask *task) {
  std::unique_lock<std::mutex> l(other_task_mutex_);
  wait_do_other_task_ = false;
  other_task_ = task;
  cv_other_task_.notify_one();
}

void ParallelWorker::WaitOtherPoolTask() {
  std::unique_lock<std::mutex> l(other_task_mutex_);
  while (alive_ && wait_do_other_task_) {
    cv_other_task_.wait(l);
  }
  wait_do_other_task_ = true;
  if (other_task_ == nullptr) {
    return;
  }
  RunOtherPoolTask(other_task_);
  other_task_ = nullptr;
}

void ParallelWorker::DirectRunOtherPoolTask() {
  std::unique_lock<std::mutex> l(other_task_mutex_);
  if (other_task_ == nullptr) {
    return;
  }
  RunOtherPoolTask(other_task_);
  other_task_ = nullptr;
}

bool ParallelWorker::RunLocalKernelTask() { return parallel_pool_->RunParallel(); }

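// Pops and runs one actor from the pool queue. Only workers whose id is below
// the task-slot count take actor work; the remaining workers are left for
// kernel tasks.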
bool ParallelWorker::RunQueueActorTask() {
  if (worker_id_ < parallel_pool_->tasks_size()) {
    auto actor = parallel_pool_->PopActorFromQueue();
    if (actor == nullptr) {
      return false;
    }
    actor->Run();
    return true;
  }
  return false;
}

void ParallelThreadPool::UseThreadPool(int num) {
  std::lock_guard<std::mutex> l(mutex_pool_ref_count_);
  pool_ref_count_ += num;
}

bool ParallelThreadPool::SetRunnerID(const std::string &runner_id) {
  if (!bind_runner_id_.empty() &&
      ParallelThreadPoolManager::GetInstance()->GetEnableSharedThreadPool(runner_id) != enable_shared_) {
    THREAD_ERROR("cannot set runner id: shared-thread-pool setting conflicts with the bound runner.");
    return false;
  }
  bind_runner_id_ = runner_id;
  return true;
}

std::vector<ParallelWorker *> ParallelThreadPool::GetParallelPoolWorkers() {
  std::vector<ParallelWorker *> workers;
  for (auto worker : workers_) {
    workers.push_back(static_cast<ParallelWorker *>(worker));
  }
  return workers;
}

int ParallelThreadPool::GetPoolRef() {
  std::lock_guard<std::mutex> l(mutex_pool_ref_count_);
  return pool_ref_count_;
}

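// Scans the task slots in [start, end) once and drains the first valid task
// found, claiming piece indices via compare-and-swap on the distributor.
// When the first valid slot sits past `start`, tasks_start_ is advanced to it
// so later scans skip the empty prefix.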
inline bool ParallelThreadPool::RunTaskOnce(int start, int end) {
  bool find = false;
  ParallelTask *p_task;
  for (int i = start; i < end; i++) {
    if (tasks_[i].valid) {
      if (i != start) {
        tasks_start_ = i;
      }
      int finish = 0;
      p_task = &tasks_[i];
      Distributor expected_index = p_task->distributor;
      while (expected_index.started < expected_index.task_num) {
        if (p_task->distributor.compare_exchange_strong(expected_index,
                                                        {expected_index.started + 1, expected_index.task_num})) {
          p_task->status |= p_task->func(p_task->content, expected_index.started, 0, 0);
          find = true;
          expected_index = p_task->distributor;
          finish++;
        }
      }
      if (find) {
        p_task->valid = false;
        p_task->finished += finish;
        break;
      }
    }
  }
  return find;
}

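// Drains the local task table until no valid task is left. Small tables
// (fewer than kActorParallelThreshold slots) are rescanned linearly in full;
// larger ones are scanned through the [tasks_start_, tasks_end_) window,
// splitting the scan in two when the window wraps around the table.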
bool ParallelThreadPool::RunParallel() {
  bool ret = false;
  bool find;
  int max_num = static_cast<int>(tasks_size_);
  if (max_num < kActorParallelThreshold) {
    ParallelTask *p_task;
    do {
      find = false;
      for (int i = 0; i < max_num; i++) {
        if (tasks_[i].valid) {
          int finish = 0;
          p_task = &tasks_[i];
          Distributor expected_index = p_task->distributor;
          while (expected_index.started < expected_index.task_num) {
            if (p_task->distributor.compare_exchange_strong(expected_index,
                                                            {expected_index.started + 1, expected_index.task_num})) {
              p_task->status |= p_task->func(p_task->content, expected_index.started, 0, 0);
              find = true;
              expected_index = p_task->distributor;
              finish++;
            }
          }
          if (find) {
            p_task->valid = false;
            p_task->finished += finish;
            break;
          }
        }
      }
      ret = ret || find;
    } while (find);
  } else {
    do {
      int start = tasks_start_;
      int end = tasks_end_;
      find = false;
      if (start < end) {
        find = RunTaskOnce(start, end);
      } else if (start != end) {
        // The window has wrapped: scan the tail first, then the head.
        find = RunTaskOnce(start, max_num);
        if (!find) {
          find = RunTaskOnce(0, end);
        }
      }
      ret = ret || find;
    } while (find);
  }
  return ret;
}

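// Splits a kernel into task_num pieces and runs them cooperatively. The
// caller claims a free task slot, publishes the task through the slot's
// distributor, wakes helpers (its own workers, plus an idle pool when thread
// sharing is enabled and the local thread count is too small), and keeps
// executing pieces itself until every index is claimed, then busy-waits for
// stragglers to finish.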
int ParallelThreadPool::ParallelLaunch(const Func &func, Content content, int task_num) {
  // if only one piece, run it on the caller (master) thread
  if (task_num <= 1) {
    return SyncRunFunc(func, content, 0, task_num);
  }
  UseThreadPool(1);
  // distribute the task to the KernelThreads and the idle ActorThreads
  // if the task num is greater than the KernelThread num
  size_t task_index;
  bool expected;
  size_t max_task_num = tasks_size_;

  // Claim a free task slot, first from tasks_end_ upward, then wrapping to
  // the beginning of the table.
  for (task_index = static_cast<size_t>(tasks_end_); task_index < max_task_num; task_index++) {
    expected = false;
    if (tasks_[task_index].occupied.compare_exchange_strong(expected, true)) {
      tasks_end_ = static_cast<int>(task_index + 1);
      break;
    }
  }
  if (task_index >= max_task_num) {
    for (task_index = 0; task_index < max_task_num; task_index++) {
      expected = false;
      if (tasks_[task_index].occupied.compare_exchange_strong(expected, true)) {
        tasks_end_ = static_cast<int>(task_index + 1);
        break;
      }
    }
    if (task_index >= max_task_num) {
      // every slot is busy: release the reference taken above and fall back
      // to running the task synchronously
      UseThreadPool(-1);
      return SyncRunFunc(func, content, 0, task_num);
    }
  }

  ParallelTask *p_task = &tasks_[task_index];
  p_task->valid.store(false);
  p_task->func = func;
  p_task->content = content;
  p_task->finished = 1;
  p_task->distributor = {1, task_num};
  p_task->valid.store(true);

  ParallelThreadPool *idle_pool = nullptr;
  if (!enable_shared_) {
    ActiveWorkers();
  } else {
    // wake any workers parked waiting for cross-pool tasks
    for (auto &worker : workers_) {
      static_cast<ParallelWorker *>(worker)->ActivateByOtherPoolTask();
    }
    if (thread_num_ < task_num) {
      idle_pool = ParallelThreadPoolManager::GetInstance()->GetIdleThreadPool(bind_runner_id_, p_task);
    }
  }

  // The caller runs piece 0 itself, then keeps claiming remaining pieces.
  p_task->status |= p_task->func(p_task->content, 0, 0, 0);

  Distributor expected_index = p_task->distributor;
  while (expected_index.started < task_num) {
    if (p_task->distributor.compare_exchange_strong(expected_index, {expected_index.started + 1, task_num})) {
      p_task->status |= p_task->func(p_task->content, expected_index.started, 0, 0);
      (void)++p_task->finished;
      expected_index = p_task->distributor;
    }
  }
  p_task->valid = false;
  while (p_task->finished < task_num) {
    std::this_thread::yield();
  }
  // Capture the task status before releasing the slot and the pool
  // references, so the error path cannot leak a reference.
  int status = p_task->status;
  p_task->occupied = false;
  if (idle_pool != nullptr) {
    idle_pool->UseThreadPool(-1);
  }
  UseThreadPool(-1);
  return status == THREAD_OK ? THREAD_OK : THREAD_ERROR;
}

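// Atomically claims this pool for cross-pool work: succeeds only when the
// reference count is 0, leaving it at 1 so the pool reads as busy until the
// borrower releases it with UseThreadPool(-1).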
bool ParallelThreadPool::IsIdlePool() {
  int expected_ref_count = 0;
  return this->pool_ref_count_.compare_exchange_strong(expected_ref_count, 1);
}

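// Builds the worker set: clamps the thread counts to the hardware
// concurrency, allocates one task slot per actor thread, initializes the
// actor and task queues, and spawns the ParallelWorker threads.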
int ParallelThreadPool::CreateParallelThreads(size_t actor_thread_num, size_t all_thread_num,
                                              const std::vector<int> &core_list) {
  if (actor_thread_num == 0) {
    THREAD_ERROR("thread num is invalid");
    return THREAD_ERROR;
  }
  if (ActorQueueInit() != THREAD_OK) {
    return THREAD_ERROR;
  }
  if (affinity_ != nullptr) {
    affinity_->SetCoreId(core_list);
  }
  size_t core_num = std::thread::hardware_concurrency();
  all_thread_num = all_thread_num < core_num ? all_thread_num : core_num;
  actor_thread_num_ = actor_thread_num < all_thread_num ? actor_thread_num : all_thread_num;
  size_t tasks_size = actor_thread_num;

  tasks_ = new (std::nothrow) ParallelTask[tasks_size]();
  THREAD_ERROR_IF_NULL(tasks_);
  tasks_size_ = tasks_size;
  if (TaskQueuesInit(all_thread_num) != THREAD_OK) {
    return THREAD_ERROR;
  }
  enable_shared_ = ParallelThreadPoolManager::GetInstance()->GetEnableSharedThreadPool(bind_runner_id_);
  auto ret = ThreadPool::CreateThreads<ParallelWorker>(all_thread_num, core_list);
  if (ret != THREAD_OK) {
    return THREAD_ERROR;
  }
  thread_num_ = static_cast<int>(thread_num());
  return THREAD_OK;
}

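// Factory for a fully initialized pool. Allocation, runner binding, affinity
// setup, and thread creation are each checked; any failure tears the pool
// down and returns nullptr.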
ParallelThreadPool *ParallelThreadPool::CreateThreadPool(size_t actor_thread_num, size_t all_thread_num,
                                                         const std::vector<int> &core_list, BindMode bind_mode,
                                                         std::string runner_id) {
  std::lock_guard<std::mutex> lock(create_thread_pool_muntex_);
  ParallelThreadPool *pool = new (std::nothrow) ParallelThreadPool();
  if (pool == nullptr) {
    return nullptr;
  }
  if (!pool->SetRunnerID(runner_id)) {
    delete pool;
    return nullptr;
  }
  int ret = pool->InitAffinityInfo();
  if (ret != THREAD_OK) {
    delete pool;
    return nullptr;
  }
  if (core_list.empty()) {
    ret = pool->CreateParallelThreads(actor_thread_num, all_thread_num,
                                      pool->affinity_->GetCoreId(all_thread_num, bind_mode));
  } else {
    ret = pool->CreateParallelThreads(actor_thread_num, all_thread_num, core_list);
  }
  if (ret != THREAD_OK) {
    delete pool;
    return nullptr;
  }
  return pool;
}
}  // namespace mindspore