1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/litert/thread_pool_reuse_manager.h"
18 #include <mutex>
19 #include "src/common/log_adapter.h"
20
21 namespace mindspore {
22 namespace lite {
23 namespace {
// File-local mutex: every access to thread_pool_container_ in this file is made while holding it.
24 std::mutex l;
25 } // namespace
26
~ThreadPoolReuseManager()27 ThreadPoolReuseManager::~ThreadPoolReuseManager() {
28 std::lock_guard<std::mutex> lock(l);
29 for (auto &pair : thread_pool_container_) {
30 for (auto &thread_pool : pair.second) {
31 if (thread_pool) {
32 delete thread_pool;
33 }
34 thread_pool = nullptr;
35 }
36 }
37 thread_pool_container_.clear();
38 }
39
GetThreadPool(size_t actor_num,size_t inter_op_parallel_num,size_t thread_num,BindMode bind_mode,const std::vector<int> & core_list,std::string runner_id)40 ThreadPool *ThreadPoolReuseManager::GetThreadPool(size_t actor_num, size_t inter_op_parallel_num, size_t thread_num,
41 BindMode bind_mode, const std::vector<int> &core_list,
42 std::string runner_id) {
43 #ifdef SERVER_INFERENCE
44 auto hash_key = ComputeHash(actor_num, inter_op_parallel_num, thread_num, bind_mode, core_list);
45 std::lock_guard<std::mutex> lock(l);
46 if (thread_pool_container_.find(hash_key) == thread_pool_container_.end()) {
47 return nullptr;
48 }
49 if (thread_pool_container_[hash_key].empty()) {
50 return nullptr;
51 }
52 auto thread_pool = thread_pool_container_[hash_key].back();
53 if (inter_op_parallel_num > 1 && !thread_pool->SetRunnerID(runner_id)) {
54 MS_LOG(WARNING) << "can not reuse thread pool.";
55 return nullptr;
56 }
57 thread_pool_container_[hash_key].pop_back();
58 return thread_pool;
59 #else
60 return nullptr;
61 #endif
62 }
63
RetrieveThreadPool(size_t actor_num,size_t inter_op_parallel_num,size_t thread_num,BindMode bind_mode,const std::vector<int> & core_list,ThreadPool * thread_pool)64 void ThreadPoolReuseManager::RetrieveThreadPool(size_t actor_num, size_t inter_op_parallel_num, size_t thread_num,
65 BindMode bind_mode, const std::vector<int> &core_list,
66 ThreadPool *thread_pool) {
67 if (thread_pool == nullptr) {
68 return;
69 }
70 #ifdef SERVER_INFERENCE
71 auto hash_key = ComputeHash(actor_num, inter_op_parallel_num, thread_num, bind_mode, core_list);
72 std::lock_guard<std::mutex> lock(l);
73 thread_pool_container_[hash_key].push_back(thread_pool);
74 #else
75 delete thread_pool;
76 #endif
77 }
78
ComputeHash(size_t actor_num,size_t inter_op_parallel_num,size_t thread_num,BindMode bind_mode,const std::vector<int> & core_list)79 std::string ThreadPoolReuseManager::ComputeHash(size_t actor_num, size_t inter_op_parallel_num, size_t thread_num,
80 BindMode bind_mode, const std::vector<int> &core_list) {
81 std::string hash_key = std::to_string(actor_num) + std::to_string(inter_op_parallel_num) +
82 std::to_string(thread_num) + std::to_string(bind_mode);
83 for (auto val : core_list) {
84 hash_key += std::to_string(val);
85 }
86 return hash_key;
87 }
88 } // namespace lite
89 } // namespace mindspore
90