• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "include/api/model.h"
17 #include "include/api/model_group.h"
18 #include "include/api/model_parallel_runner.h"
19 #include "src/common/log_adapter.h"
20 #include "mindspore/lite/python/src/common_pybind.h"
21 #include "pybind11/pybind11.h"
22 #include "pybind11/stl.h"
23 #include "pybind11/functional.h"
24 
25 namespace mindspore::lite {
26 namespace py = pybind11;
27 
PyModelPredict(Model * model,const std::vector<MSTensorPtr> & inputs_ptr,const std::vector<MSTensorPtr> & outputs_ptr)28 std::vector<MSTensorPtr> PyModelPredict(Model *model, const std::vector<MSTensorPtr> &inputs_ptr,
29                                         const std::vector<MSTensorPtr> &outputs_ptr) {
30   if (model == nullptr) {
31     MS_LOG(ERROR) << "Model object cannot be nullptr";
32     return {};
33   }
34   std::vector<MSTensor> inputs = MSTensorPtrToMSTensor(inputs_ptr);
35   std::vector<MSTensor> outputs;
36   if (!outputs_ptr.empty()) {
37     outputs = MSTensorPtrToMSTensor(outputs_ptr);
38   }
39   if (!model->Predict(inputs, &outputs).IsOk()) {
40     return {};
41   }
42   if (!outputs_ptr.empty()) {
43     for (size_t i = 0; i < outputs.size(); i++) {
44       outputs_ptr[i]->SetShape(outputs[i].Shape());
45       outputs_ptr[i]->SetDataType(outputs[i].DataType());
46     }
47     return outputs_ptr;
48   }
49   return MSTensorToMSTensorPtr(outputs);
50 }
51 
PyModelResize(Model * model,const std::vector<MSTensorPtr> & inputs_ptr,const std::vector<std::vector<int64_t>> & new_shapes)52 Status PyModelResize(Model *model, const std::vector<MSTensorPtr> &inputs_ptr,
53                      const std::vector<std::vector<int64_t>> &new_shapes) {
54   if (model == nullptr) {
55     MS_LOG(ERROR) << "Model object cannot be nullptr";
56     return kLiteError;
57   }
58   auto inputs = MSTensorPtrToMSTensor(inputs_ptr);
59   return model->Resize(inputs, new_shapes);
60 }
61 
PyModelUpdateConfig(Model * model,const std::string & key,const std::map<std::string,std::string> & value)62 Status PyModelUpdateConfig(Model *model, const std::string &key, const std::map<std::string, std::string> &value) {
63   if (model == nullptr) {
64     MS_LOG(ERROR) << "Model object cannot be nullptr";
65     return kLiteError;
66   }
67   for (auto &item : value) {
68     if (model->UpdateConfig(key, item).IsError()) {
69       MS_LOG(ERROR) << "Update config failed, section: " << key << ", config name: " << item.first
70                     << ", config value: " << item.second;
71       return kLiteError;
72     }
73   }
74   return kSuccess;
75 }
76 
PyModelGetInputs(Model * model)77 std::vector<MSTensorPtr> PyModelGetInputs(Model *model) {
78   if (model == nullptr) {
79     MS_LOG(ERROR) << "Model object cannot be nullptr";
80     return {};
81   }
82   return MSTensorToMSTensorPtr(model->GetInputs());
83 }
84 
PyModelGetOutputs(Model * model)85 std::vector<MSTensorPtr> PyModelGetOutputs(Model *model) {
86   if (model == nullptr) {
87     MS_LOG(ERROR) << "Model object cannot be nullptr";
88     return {};
89   }
90   return MSTensorToMSTensorPtr(model->GetOutputs());
91 }
92 
PyModelGetModelInfo(Model * model,const std::string & key)93 std::string PyModelGetModelInfo(Model *model, const std::string &key) {
94   std::string empty;
95   if (model == nullptr) {
96     MS_LOG(ERROR) << "Model object cannot be nullptr";
97     return empty;
98   }
99   return model->GetModelInfo(key);
100 }
101 
PyModelUpdateWeights(Model * model,const std::vector<std::vector<MSTensorPtr>> & weights)102 Status PyModelUpdateWeights(Model *model, const std::vector<std::vector<MSTensorPtr>> &weights) {
103   if (model == nullptr) {
104     MS_LOG(ERROR) << "Model object cannot be nullptr";
105     return {};
106   }
107   std::vector<std::vector<MSTensor>> new_weights;
108   for (auto &weight : weights) {
109     std::vector<MSTensor> new_weight = MSTensorPtrToMSTensor(weight);
110     new_weights.push_back(new_weight);
111   }
112   if (!model->UpdateWeights(new_weights).IsOk()) {
113     return kLiteError;
114   }
115   return kSuccess;
116 }
117 
// Register the core Model bindings on module `m`:
//  - the ModelType and StatusCode enums,
//  - the Status wrapper class,
//  - "ModelBind": build (from file/buffer, optionally decrypted), config,
//    resize, predict, weight update and tensor accessors.
void ModelPyBind(const py::module &m) {
  // Supported serialized model formats.
  (void)py::enum_<ModelType>(m, "ModelType")
    .value("kMindIR", ModelType::kMindIR)
    .value("kMindIR_Lite", ModelType::kMindIR_Lite);

  // Mirror the C++ StatusCode enum so Python callers can interpret Status results.
  (void)py::enum_<StatusCode>(m, "StatusCode")
    .value("kSuccess", StatusCode::kSuccess)
    .value("kLiteError", StatusCode::kLiteError)
    .value("kLiteNullptr", StatusCode::kLiteNullptr)
    .value("kLiteParamInvalid", StatusCode::kLiteParamInvalid)
    .value("kLiteNoChange", StatusCode::kLiteNoChange)
    .value("kLiteSuccessExit", StatusCode::kLiteSuccessExit)
    .value("kLiteMemoryFailed", StatusCode::kLiteMemoryFailed)
    .value("kLiteNotSupport", StatusCode::kLiteNotSupport)
    .value("kLiteThreadPoolError", StatusCode::kLiteThreadPoolError)
    .value("kLiteUninitializedObj", StatusCode::kLiteUninitializedObj)
    .value("kLiteFileError", StatusCode::kLiteFileError)
    .value("kLiteServiceDeny", StatusCode::kLiteServiceDeny)
    .value("kLiteOutOfTensorRange", StatusCode::kLiteOutOfTensorRange)
    .value("kLiteInputTensorError", StatusCode::kLiteInputTensorError)
    .value("kLiteReentrantError", StatusCode::kLiteReentrantError)
    .value("kLiteGraphFileError", StatusCode::kLiteGraphFileError)
    .value("kLiteNotFindOp", StatusCode::kLiteNotFindOp)
    .value("kLiteInvalidOpName", StatusCode::kLiteInvalidOpName)
    .value("kLiteInvalidOpAttr", StatusCode::kLiteInvalidOpAttr)
    .value("kLiteOpExecuteFailure", StatusCode::kLiteOpExecuteFailure)
    .value("kLiteFormatError", StatusCode::kLiteFormatError)
    .value("kLiteInferError", StatusCode::kLiteInferError)
    .value("kLiteInferInvalid", StatusCode::kLiteInferInvalid)
    .value("kLiteInputParamInvalid", StatusCode::kLiteInputParamInvalid)
    // LLM-engine specific status codes (distributed link/KV-cache management).
    .value("kLiteLLMKVCacheNotExist", StatusCode::kLiteLLMKVCacheNotExist)
    .value("kLiteLLMWaitProcessTimeOut", StatusCode::kLiteLLMWaitProcessTimeOut)
    .value("kLiteLLMRepeatRequest", StatusCode::kLiteLLMRepeatRequest)
    .value("kLiteLLMRequestAlreadyCompleted", StatusCode::kLiteLLMRequestAlreadyCompleted)
    .value("kLiteLLMEngineFinalized", StatusCode::kLiteLLMEngineFinalized)
    .value("kLiteLLMNotYetLink", StatusCode::kLiteLLMNotYetLink)
    .value("kLiteLLMAlreadyLink", StatusCode::kLiteLLMAlreadyLink)
    .value("kLiteLLMLinkFailed", StatusCode::kLiteLLMLinkFailed)
    .value("kLiteLLMUnlinkFailed", StatusCode::kLiteLLMUnlinkFailed)
    // NOTE(review): "Nofiry" looks like a typo of "Notify" in the C++ enum itself;
    // the binding must keep the existing spelling to match the API.
    .value("kLiteLLMNofiryPromptUnlinkFailed", StatusCode::kLiteLLMNofiryPromptUnlinkFailed)
    .value("kLiteLLMClusterNumExceedLimit", StatusCode::kLiteLLMClusterNumExceedLimit)
    .value("kLiteLLMProcessingLink", StatusCode::kLiteLLMProcessingLink)
    .value("kLiteLLMOutOfMemory", StatusCode::kLiteLLMOutOfMemory)
    .value("kLiteLLMPrefixAlreadyExist", StatusCode::kLiteLLMPrefixAlreadyExist)
    .value("kLiteLLMPrefixNotExist", StatusCode::kLiteLLMPrefixNotExist)
    .value("kLiteLLMSeqLenOverLimit", StatusCode::kLiteLLMSeqLenOverLimit)
    .value("kLiteLLMNoFreeBlock", StatusCode::kLiteLLMNoFreeBlock)
    .value("kLiteLLMBlockOutOfMemory", StatusCode::kLiteLLMBlockOutOfMemory);

  // Thin wrapper over the C++ Status object returned by most Model operations.
  (void)py::class_<Status, std::shared_ptr<Status>>(m, "Status")
    .def(py::init<>())
    .def(py::init<StatusCode>())
    .def("ToString", &Status::ToString)
    .def("IsOk", &Status::IsOk)
    .def("IsError", &Status::IsError)
    .def("StatusCode", &Status::StatusCode);

  // Long-running entry points (build, predict, update_weights) release the GIL
  // so other Python threads can run while inference executes.
  (void)py::class_<Model, std::shared_ptr<Model>>(m, "ModelBind")
    .def(py::init<>())
    .def("build_from_buff",
         py::overload_cast<const void *, size_t, ModelType, const std::shared_ptr<Context> &>(&Model::Build),
         py::call_guard<py::gil_scoped_release>())
    .def("build_from_file",
         py::overload_cast<const std::string &, ModelType, const std::shared_ptr<Context> &>(&Model::Build),
         py::call_guard<py::gil_scoped_release>())
    .def("build_from_buff_with_decrypt",
         py::overload_cast<const void *, size_t, ModelType, const std::shared_ptr<Context> &, const Key &,
                           const std::string &, const std::string &>(&Model::Build))
    .def("build_from_file_with_decrypt",
         py::overload_cast<const std::string &, ModelType, const std::shared_ptr<Context> &, const Key &,
                           const std::string &, const std::string &>(&Model::Build))
    .def("load_config", py::overload_cast<const std::string &>(&Model::LoadConfig))
    .def("update_config", &PyModelUpdateConfig)
    .def("resize", &PyModelResize)
    .def("predict", &PyModelPredict, py::call_guard<py::gil_scoped_release>())
    .def("update_weights", &PyModelUpdateWeights, py::call_guard<py::gil_scoped_release>())
    .def("get_inputs", &PyModelGetInputs)
    .def("get_outputs", &PyModelGetOutputs)
    .def("get_model_info", &PyModelGetModelInfo)
    .def("get_input_by_tensor_name",
         [](Model &model, const std::string &tensor_name) { return model.GetInputByTensorName(tensor_name); })
    .def("get_output_by_tensor_name",
         [](Model &model, const std::string &tensor_name) { return model.GetOutputByTensorName(tensor_name); });
}
202 
203 #ifdef PARALLEL_INFERENCE
PyModelParallelRunnerPredict(ModelParallelRunner * runner,const std::vector<MSTensorPtr> & inputs_ptr,const std::vector<MSTensorPtr> & outputs_ptr,const MSKernelCallBack & before=nullptr,const MSKernelCallBack & after=nullptr)204 std::vector<MSTensorPtr> PyModelParallelRunnerPredict(ModelParallelRunner *runner,
205                                                       const std::vector<MSTensorPtr> &inputs_ptr,
206                                                       const std::vector<MSTensorPtr> &outputs_ptr,
207                                                       const MSKernelCallBack &before = nullptr,
208                                                       const MSKernelCallBack &after = nullptr) {
209   if (runner == nullptr) {
210     MS_LOG(ERROR) << "ModelParallelRunner object cannot be nullptr";
211     return {};
212   }
213   std::vector<MSTensor> inputs = MSTensorPtrToMSTensor(inputs_ptr);
214   std::vector<MSTensor> outputs;
215   if (!outputs_ptr.empty()) {
216     outputs = MSTensorPtrToMSTensor(outputs_ptr);
217   }
218   if (!runner->Predict(inputs, &outputs, before, after).IsOk()) {
219     return {};
220   }
221   return MSTensorToMSTensorPtr(outputs);
222 }
223 
PyModelParallelRunnerGetInputs(ModelParallelRunner * runner)224 std::vector<MSTensorPtr> PyModelParallelRunnerGetInputs(ModelParallelRunner *runner) {
225   if (runner == nullptr) {
226     MS_LOG(ERROR) << "ModelParallelRunner object cannot be nullptr";
227     return {};
228   }
229   return MSTensorToMSTensorPtr(runner->GetInputs());
230 }
231 
PyModelParallelRunnerGetOutputs(ModelParallelRunner * runner)232 std::vector<MSTensorPtr> PyModelParallelRunnerGetOutputs(ModelParallelRunner *runner) {
233   if (runner == nullptr) {
234     MS_LOG(ERROR) << "ModelParallelRunner object cannot be nullptr";
235     return {};
236   }
237   return MSTensorToMSTensorPtr(runner->GetOutputs());
238 }
239 #endif
240 
// Register the parallel-inference bindings on module `m`: RunnerConfigBind
// (worker/context configuration) and ModelParallelRunnerBind (init/predict).
// Compiles to an empty function unless PARALLEL_INFERENCE is defined.
void ModelParallelRunnerPyBind(const py::module &m) {
#ifdef PARALLEL_INFERENCE
  (void)py::class_<RunnerConfig, std::shared_ptr<RunnerConfig>>(m, "RunnerConfigBind")
    .def(py::init<>())
    .def("set_config_info", py::overload_cast<const std::string &, const std::map<std::string, std::string> &>(
                              &RunnerConfig::SetConfigInfo))
    .def("get_config_info", &RunnerConfig::GetConfigInfo)
    .def("set_config_path", py::overload_cast<const std::string &>(&RunnerConfig::SetConfigPath))
    .def("get_config_path", &RunnerConfig::GetConfigPath)
    .def("set_workers_num", &RunnerConfig::SetWorkersNum)
    .def("get_workers_num", &RunnerConfig::GetWorkersNum)
    .def("set_context", &RunnerConfig::SetContext)
    .def("get_context", &RunnerConfig::GetContext)
    .def("set_device_ids", &RunnerConfig::SetDeviceIds)
    .def("get_device_ids", &RunnerConfig::GetDeviceIds)
    // Human-readable one-line summary of the runner's thread configuration.
    .def("get_context_info",
         [](RunnerConfig &runner_config) {
           const auto &context = runner_config.GetContext();
           std::string result = "thread num: " + std::to_string(context->GetThreadNum()) +
                                ", bind mode: " + std::to_string(context->GetThreadAffinityMode());
           return result;
         })
    // Flatten the section -> {name: value} config map into a printable string.
    .def("get_config_info_string", [](RunnerConfig &runner_config) {
      std::string result = "";
      const auto &config_info = runner_config.GetConfigInfo();
      for (auto &section : config_info) {
        result += section.first + ": ";
        for (auto &config : section.second) {
          auto temp = config.first + " " + config.second + "\n";
          result += temp;
        }
      }
      return result;
    });

  // init and predict release the GIL: both can block for a long time.
  (void)py::class_<ModelParallelRunner, std::shared_ptr<ModelParallelRunner>>(m, "ModelParallelRunnerBind")
    .def(py::init<>())
    .def("init",
         py::overload_cast<const std::string &, const std::shared_ptr<RunnerConfig> &>(&ModelParallelRunner::Init),
         py::call_guard<py::gil_scoped_release>())
    .def("get_inputs", &PyModelParallelRunnerGetInputs)
    .def("get_outputs", &PyModelParallelRunnerGetOutputs)
    .def("predict", &PyModelParallelRunnerPredict, py::call_guard<py::gil_scoped_release>());
#endif
}
286 
PyModelGroupAddModelByObject(ModelGroup * model_group,const std::vector<Model * > & models_ptr)287 Status PyModelGroupAddModelByObject(ModelGroup *model_group, const std::vector<Model *> &models_ptr) {
288   if (model_group == nullptr) {
289     MS_LOG(ERROR) << "Model group object cannot be nullptr";
290     return {};
291   }
292   std::vector<Model> models;
293   for (auto model_ptr : models_ptr) {
294     if (model_ptr == nullptr) {
295       MS_LOG(ERROR) << "Model object cannot be nullptr";
296       return {};
297     }
298     models.push_back(*model_ptr);
299   }
300   return model_group->AddModel(models);
301 }
302 
// Register ModelGroup bindings on module `m`: the ModelGroupFlag enum
// (weight vs. workspace sharing) and the ModelGroupBind class.
void ModelGroupPyBind(const py::module &m) {
  (void)py::enum_<ModelGroupFlag>(m, "ModelGroupFlag")
    .value("kShareWeight", ModelGroupFlag::kShareWeight)
    .value("kShareWorkspace", ModelGroupFlag::kShareWorkspace);

  (void)py::class_<ModelGroup, std::shared_ptr<ModelGroup>>(m, "ModelGroupBind")
    .def(py::init<ModelGroupFlag>())
    // add_model: register models by file path; add_model_by_object: by Model instance.
    .def("add_model", py::overload_cast<const std::vector<std::string> &>(&ModelGroup::AddModel))
    .def("add_model_by_object", &PyModelGroupAddModelByObject)
    .def("cal_max_size_of_workspace",
         py::overload_cast<ModelType, const std::shared_ptr<Context> &>(&ModelGroup::CalMaxSizeOfWorkspace));
}
315 }  // namespace mindspore::lite
316