1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "include/api/model_parallel_runner.h"
17 #include "src/extendrt/cxx_api/model_pool/model_parallel_runner_impl.h"
18 #include "src/extendrt/cxx_api/model_pool/runner_config.h"
19 #include "src/common/log_adapter.h"
20 #ifdef CAPTURE_SIGNALS
21 #include "src/extendrt/signal_handler.h"
22 #endif
23 namespace mindspore {
namespace {
// Section-count limit checked by RunnerConfig::SetConfigInfo before a section is stored.
constexpr size_t kMaxSectionNum{100};
// Limit on the number of key/value entries in one section's config map (checked by SetConfigInfo).
constexpr size_t kMaxConfigNumPerSection{1000};
}  // namespace
#ifdef USE_GLOG
// Logging bootstrap with C linkage; only declared here (its definition lives elsewhere).
// ModelParallelRunner::Init invokes it when the build enables glog.
extern "C" void mindspore_log_init();
#endif
33
// Namespace-scope lock taken by ModelParallelRunner::Init while it (optionally) initializes logging
// and lazily creates the runner implementation object.
std::mutex g_model_parallel_runner_mutex{};
35
RunnerConfig()36 RunnerConfig::RunnerConfig() : data_(std::make_shared<Data>()) {}
37
~RunnerConfig()38 RunnerConfig::~RunnerConfig() {}
39
SetWorkersNum(int32_t workers_num)40 void RunnerConfig::SetWorkersNum(int32_t workers_num) {
41 if (data_ == nullptr) {
42 MS_LOG(ERROR) << "Runner config data is nullptr.";
43 return;
44 }
45 data_->workers_num = workers_num;
46 }
47
SetContext(const std::shared_ptr<Context> & context)48 void RunnerConfig::SetContext(const std::shared_ptr<Context> &context) {
49 if (data_ == nullptr) {
50 MS_LOG(ERROR) << "Runner config data is nullptr.";
51 return;
52 }
53 MS_CHECK_TRUE_RET_VOID(context != nullptr);
54 data_->context = context;
55 }
56
GetWorkersNum() const57 int32_t RunnerConfig::GetWorkersNum() const {
58 if (data_ == nullptr) {
59 MS_LOG(ERROR) << "Runner config data is nullptr.";
60 return -1;
61 }
62 return data_->workers_num;
63 }
64
GetContext() const65 std::shared_ptr<Context> RunnerConfig::GetContext() const {
66 if (data_ == nullptr) {
67 MS_LOG(ERROR) << "Runner config data is nullptr.";
68 return nullptr;
69 }
70 return data_->context;
71 }
72
SetConfigInfo(const std::vector<char> & section,const std::map<std::vector<char>,std::vector<char>> & config)73 void RunnerConfig::SetConfigInfo(const std::vector<char> §ion,
74 const std::map<std::vector<char>, std::vector<char>> &config) {
75 if (data_ == nullptr) {
76 MS_LOG(ERROR) << "Runner config data is nullptr.";
77 return;
78 }
79 if (data_->config_info.size() > kMaxSectionNum) {
80 MS_LOG(ERROR) << "The number of added sessions exceeds the maximum[" << kMaxSectionNum << "] limit.";
81 return;
82 }
83 if (config.size() > kMaxConfigNumPerSection) {
84 MS_LOG(ERROR) << "The number of added config exceeds the maximum[" << kMaxConfigNumPerSection << "] limit.";
85 return;
86 }
87 data_->config_info[CharToString(section)] = MapVectorCharToString(config);
88 return;
89 }
90
SetConfigPath(const std::vector<char> & config_path)91 void RunnerConfig::SetConfigPath(const std::vector<char> &config_path) {
92 if (data_ == nullptr) {
93 MS_LOG(ERROR) << "Runner config data is nullptr.";
94 return;
95 }
96 data_->config_path = CharToString(config_path);
97 return;
98 }
99
GetConfigPathChar() const100 std::vector<char> RunnerConfig::GetConfigPathChar() const {
101 if (data_ == nullptr) {
102 MS_LOG(ERROR) << "Runner config data is nullptr.";
103 std::vector<char> empty;
104 return empty;
105 }
106 return StringToChar(data_->config_path);
107 }
108
GetConfigInfoChar() const109 std::map<std::vector<char>, std::map<std::vector<char>, std::vector<char>>> RunnerConfig::GetConfigInfoChar() const {
110 if (data_ == nullptr) {
111 MS_LOG(ERROR) << "Runner config data is nullptr.";
112 std::map<std::vector<char>, std::map<std::vector<char>, std::vector<char>>> empty;
113 return empty;
114 }
115 return MapMapStringToChar(data_->config_info);
116 }
117
SetDeviceIds(const std::vector<uint32_t> & device_ids)118 void RunnerConfig::SetDeviceIds(const std::vector<uint32_t> &device_ids) {
119 if (data_ == nullptr) {
120 MS_LOG(ERROR) << "Runner config data is nullptr.";
121 return;
122 }
123 data_->device_ids = device_ids;
124 }
125
GetDeviceIds() const126 std::vector<uint32_t> RunnerConfig::GetDeviceIds() const {
127 if (data_ == nullptr) {
128 MS_LOG(ERROR) << "Runner config data is nullptr.";
129 return {};
130 }
131 return data_->device_ids;
132 }
133
ModelParallelRunner()134 ModelParallelRunner::ModelParallelRunner() {}
135
~ModelParallelRunner()136 ModelParallelRunner::~ModelParallelRunner() {}
137
Init(const std::vector<char> & model_path,const std::shared_ptr<RunnerConfig> & runner_config)138 Status ModelParallelRunner::Init(const std::vector<char> &model_path,
139 const std::shared_ptr<RunnerConfig> &runner_config) {
140 {
141 std::lock_guard<std::mutex> l(g_model_parallel_runner_mutex);
142 #ifdef USE_GLOG
143 mindspore::mindspore_log_init();
144 #endif
145 if (model_parallel_runner_impl_ == nullptr) {
146 model_parallel_runner_impl_ = std::make_shared<ModelParallelRunnerImpl>();
147 if (model_parallel_runner_impl_ == nullptr) {
148 MS_LOG(ERROR) << "new model pool failed, model pool is nullptr.";
149 return kLiteNullptr;
150 }
151 }
152 }
153 return model_parallel_runner_impl_->Init(CharToString(model_path), runner_config);
154 }
155
Init(const void * model_data,size_t data_size,const std::shared_ptr<RunnerConfig> & runner_config)156 Status ModelParallelRunner::Init(const void *model_data, size_t data_size,
157 const std::shared_ptr<RunnerConfig> &runner_config) {
158 {
159 std::lock_guard<std::mutex> l(g_model_parallel_runner_mutex);
160 #ifdef USE_GLOG
161 mindspore::mindspore_log_init();
162 #endif
163 if (model_parallel_runner_impl_ == nullptr) {
164 model_parallel_runner_impl_ = std::make_shared<ModelParallelRunnerImpl>();
165 if (model_parallel_runner_impl_ == nullptr) {
166 MS_LOG(ERROR) << "new model pool failed, model pool is nullptr.";
167 return kLiteNullptr;
168 }
169 }
170 }
171 return model_parallel_runner_impl_->Init(model_data, data_size, runner_config);
172 }
173
GetInputs()174 std::vector<MSTensor> ModelParallelRunner::GetInputs() {
175 if (model_parallel_runner_impl_ == nullptr) {
176 std::vector<MSTensor> empty;
177 MS_LOG(ERROR) << "Please initialize ModelParallelRunner before calling GetInput API.";
178 return empty;
179 }
180 return model_parallel_runner_impl_->GetInputs();
181 }
182
GetOutputs()183 std::vector<MSTensor> ModelParallelRunner::GetOutputs() {
184 if (model_parallel_runner_impl_ == nullptr) {
185 std::vector<MSTensor> empty;
186 MS_LOG(ERROR) << "Please initialize ModelParallelRunner before calling GetOutputs API.";
187 return empty;
188 }
189 return model_parallel_runner_impl_->GetOutputs();
190 }
191
Predict(const std::vector<MSTensor> & inputs,std::vector<MSTensor> * outputs,const MSKernelCallBack & before,const MSKernelCallBack & after)192 Status ModelParallelRunner::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
193 const MSKernelCallBack &before, const MSKernelCallBack &after) {
194 if (model_parallel_runner_impl_ == nullptr) {
195 MS_LOG(ERROR) << "ModelParallelRunner Not Initialize.";
196 return kLiteNullptr;
197 }
198 return model_parallel_runner_impl_->Predict(inputs, outputs, before, after);
199 }
200 } // namespace mindspore
201