/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "thread/parallel_thread_pool_manager.h"
#include <cstdlib>
#include <map>
#include <string>
#include <vector>
#include "thread/parallel_threadpool.h"

namespace mindspore {
namespace {
const char *kInnerModelParallelRunner = "inner_model_parallel_runner";
const char *kInnerRunnerID = "inner_runner_id";
const char *kInnerModelID = "inner_model_id";
}  // namespace
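// Process-wide singleton that coordinates thread pool sharing across the
// workers of a model parallel runner.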
ParallelThreadPoolManager *ParallelThreadPoolManager::GetInstance() {
  static ParallelThreadPoolManager instance;
  return &instance;
}

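// Registers a runner for pool sharing. When sharing is enabled, this reserves
// one pool slot per worker and records the per-runner thread limits; repeated
// Init calls for the same runner id are rejected.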
void ParallelThreadPoolManager::Init(bool enable_shared_thread_pool, const std::string &runner_id, int worker_num,
                                     int remaining_thread_num, int thread_num_limit) {
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  if (enable_shared_thread_pool_.find(runner_id) != enable_shared_thread_pool_.end()) {
    THREAD_ERROR("No need to repeat init.");
    return;
  }
  enable_shared_thread_pool_[runner_id] = enable_shared_thread_pool;
  if (!enable_shared_thread_pool) {
    THREAD_INFO("shared parallel thread pool is not enabled.");
    return;
  }
  std::vector<ParallelThreadPool *> runner_pools(worker_num, nullptr);
  runner_id_pools_[runner_id] = runner_pools;
  remaining_thread_num_[runner_id] = remaining_thread_num;
  thread_num_limit_[runner_id] = thread_num_limit;
  idle_pool_num_[runner_id] = worker_num;
  runner_worker_num_[runner_id] = worker_num;
  worker_init_num_[runner_id] = 0;
}

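// Records whether the given runner currently has an idle pool available.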
void ParallelThreadPoolManager::SetHasIdlePool(std::string runner_id, bool is_idle) {
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  has_idle_pool_[runner_id] = is_idle;
}

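// Looks up the runner id in config_info and returns its configured thread
// number limit, or -1 when the config is missing or sharing is disabled.
// (Despite the name, the returned value is thread_num_limit_, which callers
// appear to use as a task split count.)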
int ParallelThreadPoolManager::GetTaskNum(
  const std::map<std::string, std::map<std::string, std::string>> *config_info) {
  if (config_info == nullptr) {
    THREAD_ERROR("config_info is nullptr.");
    return -1;
  }
  std::string runner_id;
  auto it_id = config_info->find(kInnerModelParallelRunner);
  if (it_id != config_info->end()) {
    auto item_runner = it_id->second.find(kInnerRunnerID);
    if (item_runner != it_id->second.end()) {
      runner_id = item_runner->second;
    }
  }
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  if (runner_id.empty() || enable_shared_thread_pool_.find(runner_id) == enable_shared_thread_pool_.end() ||
      !enable_shared_thread_pool_[runner_id]) {
    THREAD_INFO("shared parallel thread pool is not enabled.");
    return -1;
  }
  return thread_num_limit_[runner_id];
}

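// Returns the number of workers recorded for the given pool, or -1 when the
// pool is null or unknown to the manager.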
int ParallelThreadPoolManager::GetThreadPoolSize(ThreadPool *pool) {
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  auto *thread_pool = static_cast<ParallelThreadPool *>(pool);
  if (thread_pool == nullptr) {
    return -1;
  }
  auto iter = pool_workers_.find(thread_pool);
  if (iter != pool_workers_.end()) {
    return static_cast<int>(iter->second.size());
  }
  return -1;
}

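// Binds a worker's pool into the slot indexed by its model id and caches the
// pool's workers so that sibling pools of the same runner can wake them later.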
void ParallelThreadPoolManager::BindPoolToRunner(
  ThreadPool *pool, const std::map<std::string, std::map<std::string, std::string>> *config_info) {
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  if (config_info == nullptr) {
    THREAD_ERROR("config_info is nullptr.");
    return;
  }
  auto it_id = config_info->find(kInnerModelParallelRunner);
  if (it_id == config_info->end()) {
    THREAD_ERROR("config_info has no parallel runner section.");
    return;
  }
  std::string runner_id;
  auto item_runner = it_id->second.find(kInnerRunnerID);
  if (item_runner != it_id->second.end()) {
    runner_id = item_runner->second;
  }
  if (enable_shared_thread_pool_.find(runner_id) == enable_shared_thread_pool_.end() ||
      !enable_shared_thread_pool_[runner_id]) {
    THREAD_ERROR("shared parallel thread pool is not enabled.");
    return;
  }
  auto *parallel_pool = static_cast<ParallelThreadPool *>(pool);
  if (parallel_pool == nullptr) {
    THREAD_ERROR("parallel pool is nullptr.");
    return;
  }
  int model_id = 0;
  auto item_model = it_id->second.find(kInnerModelID);
  if (item_model != it_id->second.end()) {
    model_id = std::atoi(item_model->second.c_str());
  }
  auto runner_id_pools_iter = runner_id_pools_.find(runner_id);
  if (runner_id_pools_iter == runner_id_pools_.end()) {
    return;
  }
  auto &runner_pools = runner_id_pools_iter->second;
  if (static_cast<size_t>(model_id) >= runner_pools.size()) {
    return;
  }
  runner_pools.at(model_id) = parallel_pool;
  auto all_workers = parallel_pool->GetParallelPoolWorkers();
  for (size_t i = 0; i < all_workers.size(); i++) {
    auto *worker = static_cast<ParallelWorker *>(all_workers[i]);
    pool_workers_[parallel_pool].push_back(worker);
  }
  worker_init_num_[runner_id]++;
}

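// Returns true if thread pool sharing was enabled for this runner in Init().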
bool ParallelThreadPoolManager::GetEnableSharedThreadPool(std::string runner_id) {
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  if (enable_shared_thread_pool_.find(runner_id) == enable_shared_thread_pool_.end()) {
    return false;
  }
  return enable_shared_thread_pool_[runner_id];
}

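// Marks the pool of the given worker (model id) as in use and wakes its
// workers so they can pick up tasks from other pools of the same runner.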
void ParallelThreadPoolManager::ActivatePool(const std::string &runner_id, int model_id) {
  // idle_pool_num_ is mutated below, so take the lock exclusively; a shared
  // lock would let concurrent callers race on the counter.
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  if (enable_shared_thread_pool_.find(runner_id) == enable_shared_thread_pool_.end() ||
      !enable_shared_thread_pool_[runner_id]) {
    return;
  }
  if (idle_pool_num_.find(runner_id) == idle_pool_num_.end()) {
    return;
  }
  auto runner_id_pools_iter = runner_id_pools_.find(runner_id);
  if (runner_id_pools_iter == runner_id_pools_.end()) {
    return;
  }
  idle_pool_num_[runner_id]--;
  auto &runner_pools = runner_id_pools_iter->second;
  if (static_cast<size_t>(model_id) < runner_pools.size()) {
    auto &pool = runner_pools[model_id];
    // The slot may still be unbound (nullptr) if BindPoolToRunner has not run
    // for this model id.
    if (pool == nullptr) {
      return;
    }
    pool->UseThreadPool(1);

    auto pool_workers_iter = pool_workers_.find(pool);
    if (pool_workers_iter != pool_workers_.end()) {
      auto &workers = pool_workers_iter->second;
      for (auto &worker : workers) {
        worker->ActivateByOtherPoolTask();
      }
    }
  }
}

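// Returns the pool of the given worker (model id) to the idle state, undoing
// ActivatePool.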
void ParallelThreadPoolManager::SetFreePool(const std::string &runner_id, int model_id) {
  // Exclusive lock for the same reason as ActivatePool: idle_pool_num_ is
  // written below.
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  if (enable_shared_thread_pool_.find(runner_id) == enable_shared_thread_pool_.end() ||
      !enable_shared_thread_pool_[runner_id]) {
    return;
  }
  auto runner_id_pools_iter = runner_id_pools_.find(runner_id);
  if (runner_id_pools_iter == runner_id_pools_.end()) {
    return;
  }
  if (idle_pool_num_.find(runner_id) == idle_pool_num_.end()) {
    return;
  }
  auto &runner_pools = runner_id_pools_iter->second;
  if (static_cast<size_t>(model_id) < runner_pools.size()) {
    auto &pool = runner_pools[model_id];
    if (pool == nullptr) {
      return;
    }
    pool->UseThreadPool(-1);
    idle_pool_num_[runner_id]++;
  }
}

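// Searches the runner's pools (from the last bound to the first) for an idle
// pool, wakes all but remaining_thread_num_ of its workers on the given task,
// and returns it; returns nullptr when the runner is not fully initialized or
// no pool is idle.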
ParallelThreadPool *ParallelThreadPoolManager::GetIdleThreadPool(const std::string &runner_id, ParallelTask *task) {
  std::shared_lock<std::shared_mutex> l(pool_manager_mutex_);
  auto runner_worker_num_iter = runner_worker_num_.find(runner_id);
  auto worker_init_num_iter = worker_init_num_.find(runner_id);
  if (runner_worker_num_iter == runner_worker_num_.end() || worker_init_num_iter == worker_init_num_.end() ||
      runner_worker_num_iter->second != worker_init_num_iter->second) {
    return nullptr;
  }
  auto idle_pool_num_iter = idle_pool_num_.find(runner_id);
  if (idle_pool_num_iter == idle_pool_num_.end() || idle_pool_num_iter->second <= 0) {
    return nullptr;
  }

  auto runner_id_pools_iter = runner_id_pools_.find(runner_id);
  if (runner_id_pools_iter == runner_id_pools_.end()) {
    return nullptr;
  }
  auto &all_pools = runner_id_pools_iter->second;
  for (int pool_index = static_cast<int>(all_pools.size()) - 1; pool_index >= 0; pool_index--) {
    auto &pool = all_pools[pool_index];
    if (pool->IsIdlePool()) {
      auto pool_workers_iter = pool_workers_.find(pool);
      if (pool_workers_iter == pool_workers_.end()) {
        pool->UseThreadPool(-1);
        continue;
      }
      auto &workers = pool_workers_iter->second;
      auto remaining_thread_num_iter = remaining_thread_num_.find(runner_id);
      if (remaining_thread_num_iter == remaining_thread_num_.end()) {
        pool->UseThreadPool(-1);
        continue;
      }
      // Compute in signed arithmetic: if remaining_thread_num_ exceeded the
      // worker count, the original size_t subtraction would wrap around and
      // index out of bounds.
      int activate_num = static_cast<int>(workers.size()) - remaining_thread_num_iter->second;
      for (int i = 0; i < activate_num; i++) {
        workers[i]->ActivateByOtherPoolTask(task);
      }
      return pool;
    }
  }
  return nullptr;
}

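// Drops all bookkeeping for a runner id. The manager only stores raw
// pointers, so the pools themselves are not destroyed here.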
void ParallelThreadPoolManager::ResetParallelThreadPoolManager(const std::string &runner_id) {
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  if (runner_id_pools_.find(runner_id) == runner_id_pools_.end()) {
    return;
  }
  auto pools = runner_id_pools_[runner_id];
  for (auto &pool : pools) {
    pool_workers_.erase(pool);
  }
  runner_id_pools_.erase(runner_id);
  has_idle_pool_.erase(runner_id);
  enable_shared_thread_pool_.erase(runner_id);
  remaining_thread_num_.erase(runner_id);
  thread_num_limit_.erase(runner_id);
  runner_worker_num_.erase(runner_id);
  worker_init_num_.erase(runner_id);
  idle_pool_num_.erase(runner_id);
}

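// Clears all bookkeeping; as a function-local static singleton this runs at
// process exit.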
ParallelThreadPoolManager::~ParallelThreadPoolManager() {
  THREAD_INFO("~ParallelThreadPoolManager start.");
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  pool_workers_.clear();
  runner_id_pools_.clear();
  has_idle_pool_.clear();
  enable_shared_thread_pool_.clear();
  remaining_thread_num_.clear();
  thread_num_limit_.clear();
  runner_worker_num_.clear();
  worker_init_num_.clear();
  idle_pool_num_.clear();
  THREAD_INFO("~ParallelThreadPoolManager end.");
}
}  // namespace mindspore