/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "thread/parallel_thread_pool_manager.h"
#include <map>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <vector>
#include "thread/parallel_threadpool.h"

namespace mindspore {
namespace {
const char *kInnerModelParallelRunner = "inner_model_parallel_runner";
const char *kInnerRunnerID = "inner_runner_id";
const char *kInnerModelID = "inner_model_id";
} // namespace
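
// Returns the process-wide singleton manager. Typical use (argument values
// here are illustrative):
//   ParallelThreadPoolManager::GetInstance()->Init(true, runner_id, 4, 1, 8);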
ParallelThreadPoolManager *ParallelThreadPoolManager::GetInstance() {
  static ParallelThreadPoolManager instance;
  return &instance;
}
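
// Registers a runner: records whether thread-pool sharing is enabled and, if
// so, prepares per-runner bookkeeping (pool slots, thread limits, idle count).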
void ParallelThreadPoolManager::Init(bool enable_shared_thread_pool, const std::string &runner_id, int worker_num,
                                     int remaining_thread_num, int thread_num_limit) {
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  if (enable_shared_thread_pool_.find(runner_id) != enable_shared_thread_pool_.end()) {
    THREAD_ERROR("No need to repeat initialization.");
    return;
  }
  enable_shared_thread_pool_[runner_id] = enable_shared_thread_pool;
  if (!enable_shared_thread_pool) {
    THREAD_INFO("shared parallel thread pool is not enabled.");
    return;
  }
  std::vector<ParallelThreadPool *> runner_pools(worker_num, nullptr);
  runner_id_pools_[runner_id] = runner_pools;
  remaining_thread_num_[runner_id] = remaining_thread_num;
  thread_num_limit_[runner_id] = thread_num_limit;
  idle_pool_num_[runner_id] = worker_num;
  runner_worker_num_[runner_id] = worker_num;
  worker_init_num_[runner_id] = 0;
}
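
// Records whether the given runner currently has an idle pool.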
void ParallelThreadPoolManager::SetHasIdlePool(std::string runner_id, bool is_idle) {
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  has_idle_pool_[runner_id] = is_idle;
}
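
// Looks up the runner id in config_info and returns the runner's configured
// thread num limit (used as the task number), or -1 if sharing is disabled.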
int ParallelThreadPoolManager::GetTaskNum(
  const std::map<std::string, std::map<std::string, std::string>> *config_info) {
  if (config_info == nullptr) {
    THREAD_ERROR("config_info is nullptr.");
    return -1;
  }
  std::string runner_id;
  auto it_id = config_info->find(kInnerModelParallelRunner);
  if (it_id != config_info->end()) {
    auto item_runner = it_id->second.find(kInnerRunnerID);
    if (item_runner != it_id->second.end()) {
      runner_id = it_id->second.at(kInnerRunnerID);
    }
  }
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  if (runner_id.empty() || enable_shared_thread_pool_.find(runner_id) == enable_shared_thread_pool_.end() ||
      !enable_shared_thread_pool_[runner_id]) {
    THREAD_INFO("shared parallel thread pool is not enabled.");
    return -1;
  }
  return thread_num_limit_[runner_id];
}
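
// Returns how many workers are recorded for the given pool, or -1 if the pool
// is unknown.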
int ParallelThreadPoolManager::GetThreadPoolSize(ThreadPool *pool) {
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  auto *thread_pool = static_cast<ParallelThreadPool *>(pool);
  if (thread_pool == nullptr) {
    return -1;
  }
  auto iter = pool_workers_.find(thread_pool);
  if (iter == pool_workers_.end()) {
    return -1;
  }
  return static_cast<int>(iter->second.size());
}
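
// Binds a model's thread pool to its runner: records the pool in the slot for
// the model id parsed from config_info and caches the pool's workers so other
// pools can activate them later.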
void ParallelThreadPoolManager::BindPoolToRunner(
  ThreadPool *pool, const std::map<std::string, std::map<std::string, std::string>> *config_info) {
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  if (config_info == nullptr) {
    THREAD_ERROR("config_info is nullptr.");
    return;
  }
  auto it_id = config_info->find(kInnerModelParallelRunner);
  if (it_id == config_info->end()) {
    THREAD_ERROR("shared parallel thread pool is not used.");
    return;
  }
  std::string runner_id;
  auto item_runner = it_id->second.find(kInnerRunnerID);
  if (item_runner != it_id->second.end()) {
    runner_id = it_id->second.at(kInnerRunnerID);
  }
  if (enable_shared_thread_pool_.find(runner_id) == enable_shared_thread_pool_.end() ||
      !enable_shared_thread_pool_[runner_id]) {
    THREAD_ERROR("shared parallel thread pool is not used.");
    return;
  }
  auto parallel_pool = static_cast<ParallelThreadPool *>(pool);
  if (parallel_pool == nullptr) {
    THREAD_ERROR("parallel pool is nullptr.");
    return;
  }
  int model_id = 0;
  auto item_model = it_id->second.find(kInnerModelID);
  if (item_model != it_id->second.end()) {
    model_id = std::atoi(it_id->second.at(kInnerModelID).c_str());
  }
  auto runner_id_pools_iter = runner_id_pools_.find(runner_id);
  if (runner_id_pools_iter == runner_id_pools_.end()) {
    return;
  }
  auto &runner_id_pools = runner_id_pools_iter->second;
  if (static_cast<size_t>(model_id) >= runner_id_pools.size()) {
    return;
  }
  runner_id_pools.at(model_id) = parallel_pool;
  auto all_workers = parallel_pool->GetParallelPoolWorkers();
  for (size_t i = 0; i < all_workers.size(); i++) {
    auto worker = static_cast<ParallelWorker *>(all_workers[i]);
    pool_workers_[parallel_pool].push_back(worker);
  }
  worker_init_num_[runner_id]++;
}
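
// Returns true if thread-pool sharing was enabled for the runner in Init().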
bool ParallelThreadPoolManager::GetEnableSharedThreadPool(std::string runner_id) {
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  if (enable_shared_thread_pool_.find(runner_id) == enable_shared_thread_pool_.end()) {
    return false;
  }
  return enable_shared_thread_pool_[runner_id];
}
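
// Marks the model's pool as in use, wakes its workers, and decrements the
// runner's idle-pool counter.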
void ParallelThreadPoolManager::ActivatePool(const std::string &runner_id, int model_id) {
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  if (enable_shared_thread_pool_.find(runner_id) == enable_shared_thread_pool_.end() ||
      !enable_shared_thread_pool_[runner_id]) {
    return;
  }
  if (idle_pool_num_.find(runner_id) == idle_pool_num_.end()) {
    return;
  }
  auto runner_id_pools_iter = runner_id_pools_.find(runner_id);
  if (runner_id_pools_iter == runner_id_pools_.end()) {
    return;
  }
  auto &runner_id_pools = runner_id_pools_iter->second;
  if (static_cast<size_t>(model_id) < runner_id_pools.size() && runner_id_pools[model_id] != nullptr) {
    // Count the pool as busy only when it is actually activated.
    idle_pool_num_[runner_id]--;
    auto &pool = runner_id_pools[model_id];
    pool->UseThreadPool(1);

    auto pool_workers_iter = pool_workers_.find(pool);
    if (pool_workers_iter != pool_workers_.end()) {
      auto &workers = pool_workers_iter->second;
      for (auto &worker : workers) {
        worker->ActivateByOtherPoolTask();
      }
    }
  }
}
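
// Marks the model's pool as free and increments the runner's idle-pool count.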
void ParallelThreadPoolManager::SetFreePool(const std::string &runner_id, int model_id) {
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  if (enable_shared_thread_pool_.find(runner_id) == enable_shared_thread_pool_.end() ||
      !enable_shared_thread_pool_[runner_id]) {
    return;
  }
  auto runner_id_pools_iter = runner_id_pools_.find(runner_id);
  if (runner_id_pools_iter == runner_id_pools_.end()) {
    return;
  }
  if (idle_pool_num_.find(runner_id) == idle_pool_num_.end()) {
    return;
  }
  auto &runner_id_pools = runner_id_pools_iter->second;
  if (static_cast<size_t>(model_id) < runner_id_pools.size() && runner_id_pools[model_id] != nullptr) {
    runner_id_pools[model_id]->UseThreadPool(-1);
    idle_pool_num_[runner_id]++;
  }
}
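
// Looks for an idle pool of the runner, scanning from the highest model id
// down; on success, activates the pool's workers for the task (keeping
// remaining_thread_num workers in reserve) and returns the pool. Returns
// nullptr if the runner is not fully initialized or no pool is idle.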
ParallelThreadPool *ParallelThreadPoolManager::GetIdleThreadPool(const std::string &runner_id, ParallelTask *task) {
  std::shared_lock<std::shared_mutex> l(pool_manager_mutex_);
  auto runner_worker_num_iter = runner_worker_num_.find(runner_id);
  auto worker_init_num_iter = worker_init_num_.find(runner_id);
  if (runner_worker_num_iter == runner_worker_num_.end() || worker_init_num_iter == worker_init_num_.end() ||
      runner_worker_num_iter->second != worker_init_num_iter->second) {
    return nullptr;
  }
  auto idle_pool_num_iter = idle_pool_num_.find(runner_id);
  if (idle_pool_num_iter == idle_pool_num_.end() || idle_pool_num_iter->second <= 0) {
    return nullptr;
  }

  auto runner_id_pools_iter = runner_id_pools_.find(runner_id);
  if (runner_id_pools_iter == runner_id_pools_.end()) {
    return nullptr;
  }
  auto &all_pools = runner_id_pools_iter->second;
  for (int pool_index = static_cast<int>(all_pools.size()) - 1; pool_index >= 0; pool_index--) {
    auto &pool = all_pools[pool_index];
    if (pool->IsIdlePool()) {
      auto pool_workers_iter = pool_workers_.find(pool);
      if (pool_workers_iter == pool_workers_.end()) {
        pool->UseThreadPool(-1);
        continue;
      }
      auto &workers = pool_workers_iter->second;
      auto remaining_thread_num_iter = remaining_thread_num_.find(runner_id);
      if (remaining_thread_num_iter == remaining_thread_num_.end()) {
        pool->UseThreadPool(-1);
        continue;
      }
      // Keep remaining_thread_num workers for the pool's own model; use signed
      // arithmetic to avoid unsigned underflow when the pool has fewer workers.
      int activate_num = static_cast<int>(workers.size()) - remaining_thread_num_iter->second;
      for (int i = 0; i < activate_num; i++) {
        workers[i]->ActivateByOtherPoolTask(task);
      }
      return pool;
    }
  }
  return nullptr;
}
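
// Drops all bookkeeping associated with the runner id.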
void ParallelThreadPoolManager::ResetParallelThreadPoolManager(const std::string &runner_id) {
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  if (runner_id_pools_.find(runner_id) == runner_id_pools_.end()) {
    return;
  }
  auto pools = runner_id_pools_[runner_id];
  for (auto &pool : pools) {
    pool_workers_.erase(pool);
  }
  runner_id_pools_.erase(runner_id);
  has_idle_pool_.erase(runner_id);
  enable_shared_thread_pool_.erase(runner_id);
  remaining_thread_num_.erase(runner_id);
  thread_num_limit_.erase(runner_id);
  runner_worker_num_.erase(runner_id);
  worker_init_num_.erase(runner_id);
  idle_pool_num_.erase(runner_id);
}
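
// Clears all bookkeeping; the manager does not own the pools, so none are
// deleted here.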
ParallelThreadPoolManager::~ParallelThreadPoolManager() {
  THREAD_INFO("~ParallelThreadPoolManager start.");
  std::unique_lock<std::shared_mutex> l(pool_manager_mutex_);
  pool_workers_.clear();
  runner_id_pools_.clear();
  has_idle_pool_.clear();
  enable_shared_thread_pool_.clear();
  remaining_thread_num_.clear();
  thread_num_limit_.clear();
  runner_worker_num_.clear();
  worker_init_num_.clear();
  idle_pool_num_.clear();
  THREAD_INFO("~ParallelThreadPoolManager end.");
}
} // namespace mindspore