• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/litert/thread_pool_reuse_manager.h"
18 #include <mutex>
19 #include "src/common/log_adapter.h"
20 
21 namespace mindspore {
22 namespace lite {
namespace {
// File-local lock serializing every access to ThreadPoolReuseManager::thread_pool_container_
// (taken by the destructor, GetThreadPool and RetrieveThreadPool in this file).
std::mutex l{};
}  // namespace
26 
~ThreadPoolReuseManager()27 ThreadPoolReuseManager::~ThreadPoolReuseManager() {
28   std::lock_guard<std::mutex> lock(l);
29   for (auto &pair : thread_pool_container_) {
30     for (auto &thread_pool : pair.second) {
31       if (thread_pool) {
32         delete thread_pool;
33       }
34       thread_pool = nullptr;
35     }
36   }
37   thread_pool_container_.clear();
38 }
39 
GetThreadPool(size_t actor_num,size_t inter_op_parallel_num,size_t thread_num,BindMode bind_mode,const std::vector<int> & core_list,std::string runner_id)40 ThreadPool *ThreadPoolReuseManager::GetThreadPool(size_t actor_num, size_t inter_op_parallel_num, size_t thread_num,
41                                                   BindMode bind_mode, const std::vector<int> &core_list,
42                                                   std::string runner_id) {
43 #ifdef SERVER_INFERENCE
44   auto hash_key = ComputeHash(actor_num, inter_op_parallel_num, thread_num, bind_mode, core_list);
45   std::lock_guard<std::mutex> lock(l);
46   if (thread_pool_container_.find(hash_key) == thread_pool_container_.end()) {
47     return nullptr;
48   }
49   if (thread_pool_container_[hash_key].empty()) {
50     return nullptr;
51   }
52   auto thread_pool = thread_pool_container_[hash_key].back();
53   if (inter_op_parallel_num > 1 && !thread_pool->SetRunnerID(runner_id)) {
54     MS_LOG(WARNING) << "can not reuse thread pool.";
55     return nullptr;
56   }
57   thread_pool_container_[hash_key].pop_back();
58   return thread_pool;
59 #else
60   return nullptr;
61 #endif
62 }
63 
RetrieveThreadPool(size_t actor_num,size_t inter_op_parallel_num,size_t thread_num,BindMode bind_mode,const std::vector<int> & core_list,ThreadPool * thread_pool)64 void ThreadPoolReuseManager::RetrieveThreadPool(size_t actor_num, size_t inter_op_parallel_num, size_t thread_num,
65                                                 BindMode bind_mode, const std::vector<int> &core_list,
66                                                 ThreadPool *thread_pool) {
67   if (thread_pool == nullptr) {
68     return;
69   }
70 #ifdef SERVER_INFERENCE
71   auto hash_key = ComputeHash(actor_num, inter_op_parallel_num, thread_num, bind_mode, core_list);
72   std::lock_guard<std::mutex> lock(l);
73   thread_pool_container_[hash_key].push_back(thread_pool);
74 #else
75   delete thread_pool;
76 #endif
77 }
78 
ComputeHash(size_t actor_num,size_t inter_op_parallel_num,size_t thread_num,BindMode bind_mode,const std::vector<int> & core_list)79 std::string ThreadPoolReuseManager::ComputeHash(size_t actor_num, size_t inter_op_parallel_num, size_t thread_num,
80                                                 BindMode bind_mode, const std::vector<int> &core_list) {
81   std::string hash_key = std::to_string(actor_num) + std::to_string(inter_op_parallel_num) +
82                          std::to_string(thread_num) + std::to_string(bind_mode);
83   for (auto val : core_list) {
84     hash_key += std::to_string(val);
85   }
86   return hash_key;
87 }
88 }  // namespace lite
89 }  // namespace mindspore
90