/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #include "include/api/model.h"
17 #include "include/api/model_group.h"
18 #include "include/api/model_parallel_runner.h"
19 #include "src/common/log_adapter.h"
20 #include "mindspore/lite/python/src/common_pybind.h"
21 #include "pybind11/pybind11.h"
22 #include "pybind11/stl.h"
23 #include "pybind11/functional.h"
24
25 namespace mindspore::lite {
26 namespace py = pybind11;
27
PyModelPredict(Model * model,const std::vector<MSTensorPtr> & inputs_ptr,const std::vector<MSTensorPtr> & outputs_ptr)28 std::vector<MSTensorPtr> PyModelPredict(Model *model, const std::vector<MSTensorPtr> &inputs_ptr,
29 const std::vector<MSTensorPtr> &outputs_ptr) {
30 if (model == nullptr) {
31 MS_LOG(ERROR) << "Model object cannot be nullptr";
32 return {};
33 }
34 std::vector<MSTensor> inputs = MSTensorPtrToMSTensor(inputs_ptr);
35 std::vector<MSTensor> outputs;
36 if (!outputs_ptr.empty()) {
37 outputs = MSTensorPtrToMSTensor(outputs_ptr);
38 }
39 if (!model->Predict(inputs, &outputs).IsOk()) {
40 return {};
41 }
42 if (!outputs_ptr.empty()) {
43 for (size_t i = 0; i < outputs.size(); i++) {
44 outputs_ptr[i]->SetShape(outputs[i].Shape());
45 outputs_ptr[i]->SetDataType(outputs[i].DataType());
46 }
47 return outputs_ptr;
48 }
49 return MSTensorToMSTensorPtr(outputs);
50 }
51
PyModelResize(Model * model,const std::vector<MSTensorPtr> & inputs_ptr,const std::vector<std::vector<int64_t>> & new_shapes)52 Status PyModelResize(Model *model, const std::vector<MSTensorPtr> &inputs_ptr,
53 const std::vector<std::vector<int64_t>> &new_shapes) {
54 if (model == nullptr) {
55 MS_LOG(ERROR) << "Model object cannot be nullptr";
56 return kLiteError;
57 }
58 auto inputs = MSTensorPtrToMSTensor(inputs_ptr);
59 return model->Resize(inputs, new_shapes);
60 }
61
PyModelUpdateConfig(Model * model,const std::string & key,const std::map<std::string,std::string> & value)62 Status PyModelUpdateConfig(Model *model, const std::string &key, const std::map<std::string, std::string> &value) {
63 if (model == nullptr) {
64 MS_LOG(ERROR) << "Model object cannot be nullptr";
65 return kLiteError;
66 }
67 for (auto &item : value) {
68 if (model->UpdateConfig(key, item).IsError()) {
69 MS_LOG(ERROR) << "Update config failed, section: " << key << ", config name: " << item.first
70 << ", config value: " << item.second;
71 return kLiteError;
72 }
73 }
74 return kSuccess;
75 }
76
PyModelGetInputs(Model * model)77 std::vector<MSTensorPtr> PyModelGetInputs(Model *model) {
78 if (model == nullptr) {
79 MS_LOG(ERROR) << "Model object cannot be nullptr";
80 return {};
81 }
82 return MSTensorToMSTensorPtr(model->GetInputs());
83 }
84
PyModelGetOutputs(Model * model)85 std::vector<MSTensorPtr> PyModelGetOutputs(Model *model) {
86 if (model == nullptr) {
87 MS_LOG(ERROR) << "Model object cannot be nullptr";
88 return {};
89 }
90 return MSTensorToMSTensorPtr(model->GetOutputs());
91 }
92
PyModelGetModelInfo(Model * model,const std::string & key)93 std::string PyModelGetModelInfo(Model *model, const std::string &key) {
94 std::string empty;
95 if (model == nullptr) {
96 MS_LOG(ERROR) << "Model object cannot be nullptr";
97 return empty;
98 }
99 return model->GetModelInfo(key);
100 }
101
PyModelUpdateWeights(Model * model,const std::vector<std::vector<MSTensorPtr>> & weights)102 Status PyModelUpdateWeights(Model *model, const std::vector<std::vector<MSTensorPtr>> &weights) {
103 if (model == nullptr) {
104 MS_LOG(ERROR) << "Model object cannot be nullptr";
105 return {};
106 }
107 std::vector<std::vector<MSTensor>> new_weights;
108 for (auto &weight : weights) {
109 std::vector<MSTensor> new_weight = MSTensorPtrToMSTensor(weight);
110 new_weights.push_back(new_weight);
111 }
112 if (!model->UpdateWeights(new_weights).IsOk()) {
113 return kLiteError;
114 }
115 return kSuccess;
116 }
117
ModelPyBind(const py::module & m)118 void ModelPyBind(const py::module &m) {
119 (void)py::enum_<ModelType>(m, "ModelType")
120 .value("kMindIR", ModelType::kMindIR)
121 .value("kMindIR_Lite", ModelType::kMindIR_Lite);
122
123 (void)py::enum_<StatusCode>(m, "StatusCode")
124 .value("kSuccess", StatusCode::kSuccess)
125 .value("kLiteError", StatusCode::kLiteError)
126 .value("kLiteNullptr", StatusCode::kLiteNullptr)
127 .value("kLiteParamInvalid", StatusCode::kLiteParamInvalid)
128 .value("kLiteNoChange", StatusCode::kLiteNoChange)
129 .value("kLiteSuccessExit", StatusCode::kLiteSuccessExit)
130 .value("kLiteMemoryFailed", StatusCode::kLiteMemoryFailed)
131 .value("kLiteNotSupport", StatusCode::kLiteNotSupport)
132 .value("kLiteThreadPoolError", StatusCode::kLiteThreadPoolError)
133 .value("kLiteUninitializedObj", StatusCode::kLiteUninitializedObj)
134 .value("kLiteFileError", StatusCode::kLiteFileError)
135 .value("kLiteServiceDeny", StatusCode::kLiteServiceDeny)
136 .value("kLiteOutOfTensorRange", StatusCode::kLiteOutOfTensorRange)
137 .value("kLiteInputTensorError", StatusCode::kLiteInputTensorError)
138 .value("kLiteReentrantError", StatusCode::kLiteReentrantError)
139 .value("kLiteGraphFileError", StatusCode::kLiteGraphFileError)
140 .value("kLiteNotFindOp", StatusCode::kLiteNotFindOp)
141 .value("kLiteInvalidOpName", StatusCode::kLiteInvalidOpName)
142 .value("kLiteInvalidOpAttr", StatusCode::kLiteInvalidOpAttr)
143 .value("kLiteOpExecuteFailure", StatusCode::kLiteOpExecuteFailure)
144 .value("kLiteFormatError", StatusCode::kLiteFormatError)
145 .value("kLiteInferError", StatusCode::kLiteInferError)
146 .value("kLiteInferInvalid", StatusCode::kLiteInferInvalid)
147 .value("kLiteInputParamInvalid", StatusCode::kLiteInputParamInvalid)
148 .value("kLiteLLMKVCacheNotExist", StatusCode::kLiteLLMKVCacheNotExist)
149 .value("kLiteLLMWaitProcessTimeOut", StatusCode::kLiteLLMWaitProcessTimeOut)
150 .value("kLiteLLMRepeatRequest", StatusCode::kLiteLLMRepeatRequest)
151 .value("kLiteLLMRequestAlreadyCompleted", StatusCode::kLiteLLMRequestAlreadyCompleted)
152 .value("kLiteLLMEngineFinalized", StatusCode::kLiteLLMEngineFinalized)
153 .value("kLiteLLMNotYetLink", StatusCode::kLiteLLMNotYetLink)
154 .value("kLiteLLMAlreadyLink", StatusCode::kLiteLLMAlreadyLink)
155 .value("kLiteLLMLinkFailed", StatusCode::kLiteLLMLinkFailed)
156 .value("kLiteLLMUnlinkFailed", StatusCode::kLiteLLMUnlinkFailed)
157 .value("kLiteLLMNofiryPromptUnlinkFailed", StatusCode::kLiteLLMNofiryPromptUnlinkFailed)
158 .value("kLiteLLMClusterNumExceedLimit", StatusCode::kLiteLLMClusterNumExceedLimit)
159 .value("kLiteLLMProcessingLink", StatusCode::kLiteLLMProcessingLink)
160 .value("kLiteLLMOutOfMemory", StatusCode::kLiteLLMOutOfMemory)
161 .value("kLiteLLMPrefixAlreadyExist", StatusCode::kLiteLLMPrefixAlreadyExist)
162 .value("kLiteLLMPrefixNotExist", StatusCode::kLiteLLMPrefixNotExist)
163 .value("kLiteLLMSeqLenOverLimit", StatusCode::kLiteLLMSeqLenOverLimit)
164 .value("kLiteLLMNoFreeBlock", StatusCode::kLiteLLMNoFreeBlock)
165 .value("kLiteLLMBlockOutOfMemory", StatusCode::kLiteLLMBlockOutOfMemory);
166
167 (void)py::class_<Status, std::shared_ptr<Status>>(m, "Status")
168 .def(py::init<>())
169 .def(py::init<StatusCode>())
170 .def("ToString", &Status::ToString)
171 .def("IsOk", &Status::IsOk)
172 .def("IsError", &Status::IsError)
173 .def("StatusCode", &Status::StatusCode);
174
175 (void)py::class_<Model, std::shared_ptr<Model>>(m, "ModelBind")
176 .def(py::init<>())
177 .def("build_from_buff",
178 py::overload_cast<const void *, size_t, ModelType, const std::shared_ptr<Context> &>(&Model::Build),
179 py::call_guard<py::gil_scoped_release>())
180 .def("build_from_file",
181 py::overload_cast<const std::string &, ModelType, const std::shared_ptr<Context> &>(&Model::Build),
182 py::call_guard<py::gil_scoped_release>())
183 .def("build_from_buff_with_decrypt",
184 py::overload_cast<const void *, size_t, ModelType, const std::shared_ptr<Context> &, const Key &,
185 const std::string &, const std::string &>(&Model::Build))
186 .def("build_from_file_with_decrypt",
187 py::overload_cast<const std::string &, ModelType, const std::shared_ptr<Context> &, const Key &,
188 const std::string &, const std::string &>(&Model::Build))
189 .def("load_config", py::overload_cast<const std::string &>(&Model::LoadConfig))
190 .def("update_config", &PyModelUpdateConfig)
191 .def("resize", &PyModelResize)
192 .def("predict", &PyModelPredict, py::call_guard<py::gil_scoped_release>())
193 .def("update_weights", &PyModelUpdateWeights, py::call_guard<py::gil_scoped_release>())
194 .def("get_inputs", &PyModelGetInputs)
195 .def("get_outputs", &PyModelGetOutputs)
196 .def("get_model_info", &PyModelGetModelInfo)
197 .def("get_input_by_tensor_name",
198 [](Model &model, const std::string &tensor_name) { return model.GetInputByTensorName(tensor_name); })
199 .def("get_output_by_tensor_name",
200 [](Model &model, const std::string &tensor_name) { return model.GetOutputByTensorName(tensor_name); });
201 }
202
203 #ifdef PARALLEL_INFERENCE
PyModelParallelRunnerPredict(ModelParallelRunner * runner,const std::vector<MSTensorPtr> & inputs_ptr,const std::vector<MSTensorPtr> & outputs_ptr,const MSKernelCallBack & before=nullptr,const MSKernelCallBack & after=nullptr)204 std::vector<MSTensorPtr> PyModelParallelRunnerPredict(ModelParallelRunner *runner,
205 const std::vector<MSTensorPtr> &inputs_ptr,
206 const std::vector<MSTensorPtr> &outputs_ptr,
207 const MSKernelCallBack &before = nullptr,
208 const MSKernelCallBack &after = nullptr) {
209 if (runner == nullptr) {
210 MS_LOG(ERROR) << "ModelParallelRunner object cannot be nullptr";
211 return {};
212 }
213 std::vector<MSTensor> inputs = MSTensorPtrToMSTensor(inputs_ptr);
214 std::vector<MSTensor> outputs;
215 if (!outputs_ptr.empty()) {
216 outputs = MSTensorPtrToMSTensor(outputs_ptr);
217 }
218 if (!runner->Predict(inputs, &outputs, before, after).IsOk()) {
219 return {};
220 }
221 return MSTensorToMSTensorPtr(outputs);
222 }
223
PyModelParallelRunnerGetInputs(ModelParallelRunner * runner)224 std::vector<MSTensorPtr> PyModelParallelRunnerGetInputs(ModelParallelRunner *runner) {
225 if (runner == nullptr) {
226 MS_LOG(ERROR) << "ModelParallelRunner object cannot be nullptr";
227 return {};
228 }
229 return MSTensorToMSTensorPtr(runner->GetInputs());
230 }
231
PyModelParallelRunnerGetOutputs(ModelParallelRunner * runner)232 std::vector<MSTensorPtr> PyModelParallelRunnerGetOutputs(ModelParallelRunner *runner) {
233 if (runner == nullptr) {
234 MS_LOG(ERROR) << "ModelParallelRunner object cannot be nullptr";
235 return {};
236 }
237 return MSTensorToMSTensorPtr(runner->GetOutputs());
238 }
239 #endif
240
ModelParallelRunnerPyBind(const py::module & m)241 void ModelParallelRunnerPyBind(const py::module &m) {
242 #ifdef PARALLEL_INFERENCE
243 (void)py::class_<RunnerConfig, std::shared_ptr<RunnerConfig>>(m, "RunnerConfigBind")
244 .def(py::init<>())
245 .def("set_config_info", py::overload_cast<const std::string &, const std::map<std::string, std::string> &>(
246 &RunnerConfig::SetConfigInfo))
247 .def("get_config_info", &RunnerConfig::GetConfigInfo)
248 .def("set_config_path", py::overload_cast<const std::string &>(&RunnerConfig::SetConfigPath))
249 .def("get_config_path", &RunnerConfig::GetConfigPath)
250 .def("set_workers_num", &RunnerConfig::SetWorkersNum)
251 .def("get_workers_num", &RunnerConfig::GetWorkersNum)
252 .def("set_context", &RunnerConfig::SetContext)
253 .def("get_context", &RunnerConfig::GetContext)
254 .def("set_device_ids", &RunnerConfig::SetDeviceIds)
255 .def("get_device_ids", &RunnerConfig::GetDeviceIds)
256 .def("get_context_info",
257 [](RunnerConfig &runner_config) {
258 const auto &context = runner_config.GetContext();
259 std::string result = "thread num: " + std::to_string(context->GetThreadNum()) +
260 ", bind mode: " + std::to_string(context->GetThreadAffinityMode());
261 return result;
262 })
263 .def("get_config_info_string", [](RunnerConfig &runner_config) {
264 std::string result = "";
265 const auto &config_info = runner_config.GetConfigInfo();
266 for (auto §ion : config_info) {
267 result += section.first + ": ";
268 for (auto &config : section.second) {
269 auto temp = config.first + " " + config.second + "\n";
270 result += temp;
271 }
272 }
273 return result;
274 });
275
276 (void)py::class_<ModelParallelRunner, std::shared_ptr<ModelParallelRunner>>(m, "ModelParallelRunnerBind")
277 .def(py::init<>())
278 .def("init",
279 py::overload_cast<const std::string &, const std::shared_ptr<RunnerConfig> &>(&ModelParallelRunner::Init),
280 py::call_guard<py::gil_scoped_release>())
281 .def("get_inputs", &PyModelParallelRunnerGetInputs)
282 .def("get_outputs", &PyModelParallelRunnerGetOutputs)
283 .def("predict", &PyModelParallelRunnerPredict, py::call_guard<py::gil_scoped_release>());
284 #endif
285 }
286
PyModelGroupAddModelByObject(ModelGroup * model_group,const std::vector<Model * > & models_ptr)287 Status PyModelGroupAddModelByObject(ModelGroup *model_group, const std::vector<Model *> &models_ptr) {
288 if (model_group == nullptr) {
289 MS_LOG(ERROR) << "Model group object cannot be nullptr";
290 return {};
291 }
292 std::vector<Model> models;
293 for (auto model_ptr : models_ptr) {
294 if (model_ptr == nullptr) {
295 MS_LOG(ERROR) << "Model object cannot be nullptr";
296 return {};
297 }
298 models.push_back(*model_ptr);
299 }
300 return model_group->AddModel(models);
301 }
302
ModelGroupPyBind(const py::module & m)303 void ModelGroupPyBind(const py::module &m) {
304 (void)py::enum_<ModelGroupFlag>(m, "ModelGroupFlag")
305 .value("kShareWeight", ModelGroupFlag::kShareWeight)
306 .value("kShareWorkspace", ModelGroupFlag::kShareWorkspace);
307
308 (void)py::class_<ModelGroup, std::shared_ptr<ModelGroup>>(m, "ModelGroupBind")
309 .def(py::init<ModelGroupFlag>())
310 .def("add_model", py::overload_cast<const std::vector<std::string> &>(&ModelGroup::AddModel))
311 .def("add_model_by_object", &PyModelGroupAddModelByObject)
312 .def("cal_max_size_of_workspace",
313 py::overload_cast<ModelType, const std::shared_ptr<Context> &>(&ModelGroup::CalMaxSizeOfWorkspace));
314 }
315 } // namespace mindspore::lite
316