/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "include/c_api/model_c.h"
#include "type_c_private.h"
#include "context_c.h"
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <vector>
#include "include/api/context.h"
#include "include/api/serialization.h"
#include "include/api/types.h"
#include "src/litert/cxx_api/tensor/tensor_impl.h"
#include "src/litert/cxx_api/model/model_impl.h"
#ifdef ENABLE_HI_APP_EVENT
#include "src/common/hi_app_event/hi_app_event.h"
#endif

namespace mindspore {
class ModelC {
 public:
  ModelC() : model_(nullptr) {}
  ~ModelC() {
    for (auto in : inputs_) {
      if (in != nullptr) {
        delete in;
      }
    }
    for (auto out : outputs_) {
      if (out != nullptr) {
        delete out;
      }
    }
    for (auto out : outputs_train_) {
      if (out != nullptr) {
        delete out;
      }
    }

    // In the zero-copy scenario the user may call the set/get allocator functions, but the allocator table is
    // not freed when the model is destroyed, so its size keeps growing and leaks memory. Clean the table here
    // when ModelC is destroyed.
    CleanAllocatorTable();
  }

  MSTensor **GetInputs(size_t *input_num);
  MSTensor **GetOutputs(size_t *output_num);
  mindspore::MSKernelCallBack TransCallBack(const OH_AI_KernelCallBack &ms_callback);
  std::shared_ptr<Model> model_;
  std::shared_ptr<Context> context_;

 private:
  MSTensor **GetOutputsTensor(size_t *output_num, std::vector<MSTensor *> *vec_tensors);
  std::vector<MSTensor *> inputs_;
  std::vector<MSTensor *> outputs_;
  std::vector<MSTensor *> outputs_train_;
};

MSTensor **ModelC::GetInputs(size_t *input_num) {
  if (model_ == nullptr) {
    MS_LOG(ERROR) << "model_ is nullptr.";
    return nullptr;
  }
  if (!inputs_.empty()) {
    *input_num = inputs_.size();
    return inputs_.data();
  }
  auto inputs = model_->GetInputs();
  *input_num = inputs.size();
  inputs_.resize(inputs.size(), nullptr);
  for (size_t i = 0; i < inputs.size(); i++) {
    inputs_[i] = new (std::nothrow) MSTensor(inputs[i].impl());
    if (inputs_[i] == nullptr) {
      // Release the tensors created so far before clearing the vector, otherwise they leak.
      for (auto &tensor : inputs_) {
        delete tensor;
      }
      inputs_.clear();
      return nullptr;
    }
  }
  return inputs_.data();
}

MSTensor **ModelC::GetOutputs(size_t *output_num) {
  if (model_ == nullptr) {
    MS_LOG(ERROR) << "model_ is nullptr.";
    return nullptr;
  }
  if (model_->GetTrainMode()) {
    return GetOutputsTensor(output_num, &outputs_train_);
  } else {
    return GetOutputsTensor(output_num, &outputs_);
  }
}

MSTensor **ModelC::GetOutputsTensor(size_t *output_num, std::vector<MSTensor *> *vec_tensors) {
  if (model_ == nullptr) {
    MS_LOG(ERROR) << "model_ is nullptr.";
    return nullptr;
  }
  if (!vec_tensors->empty()) {
    *output_num = vec_tensors->size();
    return vec_tensors->data();
  }

  auto outputs = model_->GetOutputs();
  *output_num = outputs.size();
  vec_tensors->resize(outputs.size(), nullptr);
  for (size_t i = 0; i < outputs.size(); i++) {
    (*vec_tensors)[i] = new (std::nothrow) MSTensor(outputs[i].impl());
    if ((*vec_tensors)[i] == nullptr) {
      // Release the tensors created so far before clearing the vector, otherwise they leak.
      for (auto &tensor : *vec_tensors) {
        delete tensor;
      }
      vec_tensors->clear();
      return nullptr;
    }
  }
  return vec_tensors->data();
}

mindspore::MSKernelCallBack ModelC::TransCallBack(const OH_AI_KernelCallBack &ms_callback) {
  mindspore::MSKernelCallBack call_back = nullptr;
  if (ms_callback != nullptr) {
    // Capture the C callback pointer by value so the returned std::function does not depend on the
    // lifetime of the reference parameter.
    call_back = [ms_callback](const std::vector<mindspore::MSTensor> &inputs,
                              const std::vector<mindspore::MSTensor> &outputs,
                              const mindspore::MSCallBackParam &opInfo) {
      std::vector<OH_AI_TensorHandle> vec_inputs;
      std::vector<OH_AI_TensorHandle> vec_outputs;
      OH_AI_CallBackParam call_back_param = {const_cast<char *>(opInfo.node_name.c_str()),
                                             const_cast<char *>(opInfo.node_type.c_str())};
      size_t inputs_handle_num = inputs.size();
      for (size_t i = 0; i < inputs_handle_num; i++) {
        // Take the address of the element itself; casting the vector would create a temporary copy
        // whose elements dangle by the time the user callback is invoked.
        vec_inputs.push_back(static_cast<OH_AI_TensorHandle>(const_cast<mindspore::MSTensor *>(&inputs[i])));
      }
      size_t outputs_handle_num = outputs.size();
      for (size_t i = 0; i < outputs_handle_num; i++) {
        vec_outputs.push_back(static_cast<OH_AI_TensorHandle>(const_cast<mindspore::MSTensor *>(&outputs[i])));
      }
      OH_AI_TensorHandleArray handle_inputs = {inputs_handle_num, vec_inputs.data()};
      OH_AI_TensorHandleArray handle_outputs = {outputs_handle_num, vec_outputs.data()};
      return ms_callback(handle_inputs, handle_outputs, call_back_param);
    };
  }
  return call_back;
}
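
// A minimal user-side kernel callback, shown here only as a sketch. The OH_AI_KernelCallBack
// signature and the OH_AI_CallBackParam fields are assumed from model_c.h; returning true tells the
// runtime to keep executing the remaining nodes.
//
//   bool DumpNodeInfo(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
//                     const OH_AI_CallBackParam info) {
//     printf("node %s (%s): %zu inputs, %zu outputs\n", info.node_name, info.node_type,
//            inputs.handle_num, outputs.handle_num);
//     return true;
//   }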
}  // namespace mindspore

OH_AI_ModelHandle OH_AI_ModelCreate() {
  MS_LOG(INFO) << "Start to create ms model";
  auto impl = new (std::nothrow) mindspore::ModelC();
  if (impl == nullptr) {
    MS_LOG(ERROR) << "Model implement is nullptr.";
    return nullptr;
  }
  impl->model_ = std::make_shared<mindspore::Model>();
  if (impl->model_ == nullptr) {
    MS_LOG(ERROR) << "inner model object is nullptr.";
    delete impl;
    return nullptr;
  }
  MS_LOG(INFO) << "Created ms model successfully";
  return static_cast<OH_AI_ModelHandle>(impl);
}

void OH_AI_ModelDestroy(OH_AI_ModelHandle *model) {
  MS_LOG(INFO) << "Start to destroy ms model";
  if (model == nullptr || *model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return;
  }
  auto impl = static_cast<mindspore::ModelC *>(*model);
  delete impl;
  *model = nullptr;
  MS_LOG(INFO) << "Destroyed ms model successfully";
}

void OH_AI_ModelSetWorkspace(OH_AI_ModelHandle model, void *workspace, size_t workspace_size) {
  MS_LOG(ERROR) << "Unsupported Feature.";
  return;
}

size_t OH_AI_ModelCalcWorkspaceSize(OH_AI_ModelHandle model) {
  MS_LOG(ERROR) << "Unsupported Feature.";
  return 0;
}

OH_AI_Status OH_AI_ModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size,
                              OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context) {
  MS_LOG(INFO) << "Start to build ms model";
  if (model == nullptr || model_data == nullptr || model_context == nullptr) {
    MS_LOG(ERROR) << "model or model_data or model_context is nullptr.";
    return OH_AI_STATUS_LITE_NULLPTR;
  }
  if (model_type == OH_AI_MODELTYPE_INVALID) {
    MS_LOG(ERROR) << "model_type is invalid.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
  auto impl = static_cast<mindspore::ModelC *>(model);
  if (impl->context_.get() != context->context_ && context->owned_by_model_) {
    MS_LOG(ERROR) << "context is owned by other model.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  if (impl->context_.get() != context->context_) {
    impl->context_.reset(context->context_);
    context->owned_by_model_ = true;
  }
  auto ret = impl->model_->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), impl->context_);
  if (ret.IsOk()) {
    MS_LOG(INFO) << "Built ms model successfully";
  } else {
    MS_LOG(ERROR) << "Build ms model failed, ret: " << ret;
  }
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

OH_AI_Status OH_AI_ModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path, OH_AI_ModelType model_type,
                                      const OH_AI_ContextHandle model_context) {
  MS_LOG(INFO) << "Start to build ms model from file";
  if (model == nullptr || model_path == nullptr || model_context == nullptr) {
    MS_LOG(ERROR) << "model or model_path or model_context is nullptr.";
    return OH_AI_STATUS_LITE_NULLPTR;
  }
  if (model_type == OH_AI_MODELTYPE_INVALID) {
    MS_LOG(ERROR) << "model_type is invalid.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
  auto impl = static_cast<mindspore::ModelC *>(model);
  if (impl->context_.get() != context->context_ && context->owned_by_model_) {
    MS_LOG(ERROR) << "context is owned by other model.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  if (impl->context_.get() != context->context_) {
    impl->context_.reset(context->context_);
    context->owned_by_model_ = true;
  }
  auto ret = impl->model_->Build(model_path, static_cast<mindspore::ModelType>(model_type), impl->context_);
  if (ret.IsOk()) {
    MS_LOG(INFO) << "Built ms model from file successfully";
  } else {
    MS_LOG(ERROR) << "Build ms model from file failed, ret: " << ret;
  }
  return static_cast<OH_AI_Status>(ret.StatusCode());
}
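
// A typical build-from-file flow, as a hypothetical usage sketch. OH_AI_ContextCreate,
// OH_AI_DeviceInfoCreate and OH_AI_ContextAddDeviceInfo are assumed from context_c.h, and
// "model.ms" is a placeholder path, not a file referenced by this source.
//
//   OH_AI_ContextHandle context = OH_AI_ContextCreate();
//   OH_AI_DeviceInfoHandle cpu_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
//   OH_AI_ContextAddDeviceInfo(context, cpu_info);
//   OH_AI_ModelHandle model = OH_AI_ModelCreate();
//   OH_AI_Status status = OH_AI_ModelBuildFromFile(model, "model.ms", OH_AI_MODELTYPE_MINDIR, context);
//   if (status != OH_AI_STATUS_SUCCESS) {
//     OH_AI_ModelDestroy(&model);
//   }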

OH_AI_Status OH_AI_ModelResize(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs,
                               OH_AI_ShapeInfo *shape_infos, size_t shape_info_num) {
  MS_LOG(INFO) << "Start to resize ms model";
  if (model == nullptr || shape_infos == nullptr) {
    MS_LOG(ERROR) << "model or shape_infos is nullptr.";
    return OH_AI_STATUS_LITE_NULLPTR;
  }
  std::vector<mindspore::MSTensor> vec_inputs;
  for (size_t i = 0; i < inputs.handle_num; ++i) {
    vec_inputs.push_back(*static_cast<mindspore::MSTensor *>(inputs.handle_list[i]));
  }

  std::vector<std::vector<int64_t>> vec_dims;
  for (size_t i = 0; i < shape_info_num; i++) {
    std::vector<int64_t> shape(shape_infos[i].shape, shape_infos[i].shape + shape_infos[i].shape_num);
    if (std::any_of(shape.begin(), shape.end(), [](int64_t val) { return val < 0 || val > INT32_MAX; })) {
      MS_LOG(ERROR) << "Invalid shape: " << shape << ", each dimension must be in [0, INT32_MAX]";
      return OH_AI_STATUS_LITE_PARAM_INVALID;
    }
    vec_dims.push_back(shape);
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  auto ret = impl->model_->Resize(vec_inputs, vec_dims);
  if (ret.IsOk()) {
    MS_LOG(INFO) << "Resized ms model successfully";
  } else {
    MS_LOG(ERROR) << "Resize ms model failed, ret: " << ret;
  }
  return static_cast<OH_AI_Status>(ret.StatusCode());
}
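
// A resize sketch under assumed shapes: change the first (and only) input to a 1x224x224x3 NHWC
// tensor. The dimensions and the single-input assumption are illustrative only.
//
//   OH_AI_TensorHandleArray model_inputs = OH_AI_ModelGetInputs(model);
//   OH_AI_ShapeInfo shape_info = {4, {1, 224, 224, 3}};
//   OH_AI_Status status = OH_AI_ModelResize(model, model_inputs, &shape_info, 1);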

OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs,
                                OH_AI_TensorHandleArray *outputs, const OH_AI_KernelCallBack before,
                                const OH_AI_KernelCallBack after) {
  MS_LOG(INFO) << "Start to predict ms model";
  if (model == nullptr || outputs == nullptr) {
    MS_LOG(ERROR) << "model or outputs is nullptr.";
    return OH_AI_STATUS_LITE_NULLPTR;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  size_t input_num = 0;
  (void)impl->GetInputs(&input_num);
  if (input_num != inputs.handle_num) {
    MS_LOG(ERROR) << "Wrong input size.";
    return OH_AI_STATUS_LITE_ERROR;
  }

  std::vector<mindspore::MSTensor> ms_tensor_inputs;
  for (size_t i = 0; i < inputs.handle_num; i++) {
    if (inputs.handle_list[i] != nullptr) {
      auto user_input = static_cast<mindspore::MSTensor *>(inputs.handle_list[i]);
      ms_tensor_inputs.push_back(*user_input);
    } else {
      MS_LOG(ERROR) << "input handle is nullptr.";
      return OH_AI_STATUS_LITE_NULLPTR;
    }
  }

  mindspore::MSKernelCallBack before_call_back = impl->TransCallBack(before);
  mindspore::MSKernelCallBack after_call_back = impl->TransCallBack(after);
  std::vector<mindspore::MSTensor> ms_tensor_outputs;

  size_t output_num = 0;
  (void)impl->GetOutputs(&output_num);
  auto handle_num = outputs->handle_num;
  if (handle_num == output_num) {
    MS_LOG(INFO) << "use user provided output";
    for (size_t i = 0; i < output_num; i++) {
      if (outputs->handle_list[i] == nullptr) {
        MS_LOG(ERROR) << "user provided output array handle_list[" << i << "] is nullptr";
        return OH_AI_STATUS_LITE_NULLPTR;
      }
      ms_tensor_outputs.push_back(*static_cast<mindspore::MSTensor *>(outputs->handle_list[i]));
    }
  }

  auto ret = impl->model_->Predict(ms_tensor_inputs, &ms_tensor_outputs, before_call_back, after_call_back);
  if (!ret.IsOk()) {
    MS_LOG(ERROR) << "Predict failed, ret: " << ret;
    return static_cast<OH_AI_Status>(ret.StatusCode());
  }

  if (handle_num == output_num) {
    return OH_AI_STATUS_SUCCESS;
  }

  outputs->handle_list = reinterpret_cast<OH_AI_TensorHandle *>(impl->GetOutputs(&(outputs->handle_num)));
  MS_LOG(INFO) << "Predicted ms model successfully";
  return static_cast<OH_AI_Status>(ret.StatusCode());
}
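
// A predict sketch assuming the input tensors have already been filled with data. Passing a
// zero-initialized output array lets this wrapper return the model-owned output handles.
//
//   OH_AI_TensorHandleArray model_inputs = OH_AI_ModelGetInputs(model);
//   // ... copy input data into each handle in model_inputs ...
//   OH_AI_TensorHandleArray model_outputs = {0, NULL};
//   OH_AI_Status status = OH_AI_ModelPredict(model, model_inputs, &model_outputs, NULL, NULL);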

OH_AI_Status OH_AI_ModelRunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before,
                                const OH_AI_KernelCallBack after) {
  MS_LOG(ERROR) << "Unsupported Feature.";
  return OH_AI_STATUS_LITE_NOT_SUPPORT;
}

OH_AI_Status OH_AI_ModelExportWeight(const OH_AI_ModelHandle model, const char *export_path) {
  MS_LOG(ERROR) << "Unsupported Feature.";
  return OH_AI_STATUS_LITE_NOT_SUPPORT;
}

OH_AI_TensorHandleArray OH_AI_ModelGetInputs(const OH_AI_ModelHandle model) {
  MS_LOG(INFO) << "Start to get ms model inputs";
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return {0, nullptr};
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  size_t input_num = 0;
  auto handles = reinterpret_cast<OH_AI_TensorHandle *>(impl->GetInputs(&input_num));
  MS_LOG(INFO) << "Got ms model " << input_num << " inputs successfully";
  return {input_num, handles};
}

OH_AI_TensorHandleArray OH_AI_ModelGetOutputs(const OH_AI_ModelHandle model) {
  MS_LOG(INFO) << "Start to get ms model outputs";
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return {0, nullptr};
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  size_t output_num = 0;
  auto handles = reinterpret_cast<OH_AI_TensorHandle *>(impl->GetOutputs(&output_num));
  MS_LOG(INFO) << "Got ms model " << output_num << " outputs successfully";
  return {output_num, handles};
}

OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) {
  MS_LOG(INFO) << "Start to get ms model input by name";
  if (model == nullptr || tensor_name == nullptr) {
    MS_LOG(ERROR) << "model or tensor_name is nullptr.";
    return nullptr;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  size_t input_num = 0;
  auto inputs = impl->GetInputs(&input_num);
  for (size_t i = 0; i < input_num; i++) {
    if (inputs[i]->Name() == tensor_name) {
      MS_LOG(INFO) << "Got ms model input by name successfully";
      return static_cast<OH_AI_TensorHandle>(inputs[i]);
    }
  }
  MS_LOG(ERROR) << "Input tensor does not exist";
  return nullptr;
}

OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) {
  MS_LOG(INFO) << "Start to get ms model output by name";
  if (model == nullptr || tensor_name == nullptr) {
    MS_LOG(ERROR) << "model or tensor_name is nullptr.";
    return nullptr;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  size_t output_num = 0;
  auto outputs = impl->GetOutputs(&output_num);
  for (size_t i = 0; i < output_num; i++) {
    if (outputs[i]->Name() == tensor_name) {
      MS_LOG(INFO) << "Got ms model output by name successfully";
      return static_cast<OH_AI_TensorHandle>(outputs[i]);
    }
  }
  MS_LOG(ERROR) << "Output tensor does not exist";
  return nullptr;
}
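
// A lookup-and-fill sketch under assumed names: "data" is a placeholder tensor name, and
// OH_AI_TensorGetMutableData / OH_AI_TensorDataSize are assumed from tensor_c.h.
//
//   OH_AI_TensorHandle in_tensor = OH_AI_ModelGetInputByTensorName(model, "data");
//   if (in_tensor != NULL) {
//     float *in_data = (float *)OH_AI_TensorGetMutableData(in_tensor);
//     size_t in_bytes = OH_AI_TensorDataSize(in_tensor);
//     memset(in_data, 0, in_bytes);  // replace with real input data instead of zeros
//   }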

OH_AI_TrainCfgHandle OH_AI_TrainCfgCreate() {
  auto impl = new (std::nothrow) mindspore::TrainCfg();
  if (impl == nullptr) {
    MS_LOG(ERROR) << "TrainCfg implement is nullptr.";
    return nullptr;
  }
  return static_cast<OH_AI_TrainCfgHandle>(impl);
}

void OH_AI_TrainCfgDestroy(OH_AI_TrainCfgHandle *train_cfg) {
  if (train_cfg != nullptr && *train_cfg != nullptr) {
    auto impl = static_cast<mindspore::TrainCfg *>(*train_cfg);
    delete impl;
    *train_cfg = nullptr;
  }
}

char **OH_AI_TrainCfgGetLossName(OH_AI_TrainCfgHandle train_cfg, size_t *num) {
  if (train_cfg == nullptr || num == nullptr) {
    MS_LOG(ERROR) << "train_cfg or num is nullptr.";
    return nullptr;
  }
  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
  auto loss_name = impl->GetLossName();
  *num = loss_name.size();
  char **name = static_cast<char **>(malloc(loss_name.size() * sizeof(char *)));
  if (name == nullptr) {
    MS_LOG(ERROR) << "Failed to malloc loss_name.";
    return nullptr;
  }
  for (size_t i = 0; i < loss_name.size(); i++) {
    name[i] = static_cast<char *>(malloc(loss_name[i].size() + 1));
    if (name[i] == nullptr) {
      MS_LOG(ERROR) << "Failed to malloc name.";
      // Release everything allocated so far, including the outer array, before failing.
      for (size_t j = 0; j < i; j++) {
        free(name[j]);
      }
      free(name);
      return nullptr;
    }
    memcpy(name[i], loss_name[i].c_str(), loss_name[i].size() + 1);
  }
  return name;
}
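
// The returned array and every string in it are malloc'd above, so the caller is expected to free
// them; a minimal cleanup sketch:
//
//   size_t loss_num = 0;
//   char **loss_names = OH_AI_TrainCfgGetLossName(train_cfg, &loss_num);
//   if (loss_names != NULL) {
//     for (size_t i = 0; i < loss_num; i++) {
//       free(loss_names[i]);
//     }
//     free(loss_names);
//   }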

void OH_AI_TrainCfgSetLossName(OH_AI_TrainCfgHandle train_cfg, const char **loss_name, size_t num) {
  if (train_cfg == nullptr || loss_name == nullptr || *loss_name == nullptr) {
    MS_LOG(ERROR) << "train_cfg or loss_name is nullptr.";
    return;
  }
  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
  std::vector<std::string> vec_name;
  for (size_t i = 0; i < num; i++) {
    vec_name.push_back(loss_name[i]);
  }
  impl->SetLossName(vec_name);
}

OH_AI_OptimizationLevel OH_AI_TrainCfgGetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg) {
  if (train_cfg == nullptr) {
    MS_LOG(ERROR) << "train_cfg is nullptr, return OH_AI_KO0";
    return OH_AI_KO0;
  }
  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
  return static_cast<OH_AI_OptimizationLevel>(impl->optimization_level_);
}

void OH_AI_TrainCfgSetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg, OH_AI_OptimizationLevel level) {
  if (train_cfg == nullptr) {
    MS_LOG(ERROR) << "train_cfg is nullptr.";
    return;
  }
  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
  impl->optimization_level_ = static_cast<mindspore::OptimizationLevel>(level);
}

OH_AI_Status OH_AI_TrainModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size,
                                   OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context,
                                   const OH_AI_TrainCfgHandle train_cfg) {
  if (model == nullptr || model_data == nullptr || model_context == nullptr) {
    MS_LOG(ERROR) << "model or model_data or model_context is nullptr.";
    return OH_AI_STATUS_LITE_NULLPTR;
  }
  if (model_type == OH_AI_MODELTYPE_INVALID) {
    MS_LOG(ERROR) << "model_type is invalid.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);

  mindspore::Graph graph;
  auto status =
    mindspore::Serialization::Load(model_data, data_size, static_cast<mindspore::ModelType>(model_type), &graph);
  if (status != mindspore::kSuccess) {
    MS_LOG(ERROR) << "load ms file failed.";
    return OH_AI_STATUS_LITE_ERROR;
  }
  auto context = static_cast<mindspore::ContextC *>(model_context);
  auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg);
  if (impl->context_.get() != context->context_ && context->owned_by_model_) {
    MS_LOG(ERROR) << "context is owned by other model.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  if (impl->context_.get() != context->context_) {
    impl->context_.reset(context->context_);
    context->owned_by_model_ = true;
  }
  auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_,
                                 std::shared_ptr<mindspore::TrainCfg>(build_train_cfg));
  if (ret != mindspore::kSuccess) {
    MS_LOG(ERROR) << "Load and compile failed";
  }
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path, OH_AI_ModelType model_type,
                                           const OH_AI_ContextHandle model_context,
                                           const OH_AI_TrainCfgHandle train_cfg) {
  if (model == nullptr || model_path == nullptr || model_context == nullptr) {
    MS_LOG(ERROR) << "model or model_path or model_context is nullptr.";
    return OH_AI_STATUS_LITE_NULLPTR;
  }
  if (model_type == OH_AI_MODELTYPE_INVALID) {
    MS_LOG(ERROR) << "model_type is invalid.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);

  mindspore::Graph graph;
  auto status = mindspore::Serialization::Load(model_path, static_cast<mindspore::ModelType>(model_type), &graph);
  if (status != mindspore::kSuccess) {
    MS_LOG(ERROR) << "load ms file failed. " << model_path;
    return OH_AI_STATUS_LITE_ERROR;
  }
  auto context = static_cast<mindspore::ContextC *>(model_context);
  auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg);
  if (impl->context_.get() != context->context_ && context->owned_by_model_) {
    MS_LOG(ERROR) << "context is owned by other model.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  if (impl->context_.get() != context->context_) {
    impl->context_.reset(context->context_);
    context->owned_by_model_ = true;
  }
  auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_,
                                 std::shared_ptr<mindspore::TrainCfg>(build_train_cfg));
  if (ret != mindspore::kSuccess) {
    MS_LOG(ERROR) << "Load and compile failed";
  }
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

OH_AI_Status OH_AI_ModelSetLearningRate(OH_AI_ModelHandle model, float learning_rate) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  auto ret = impl->model_->SetLearningRate(learning_rate);
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

float OH_AI_ModelGetLearningRate(OH_AI_ModelHandle model) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  return impl->model_->GetLearningRate();
}

OH_AI_Status OH_AI_RunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before,
                           const OH_AI_KernelCallBack after) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  auto ret = impl->model_->RunStep(impl->TransCallBack(before), impl->TransCallBack(after));
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

OH_AI_TensorHandleArray OH_AI_ModelGetWeights(OH_AI_ModelHandle model) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return {0, nullptr};
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  auto features = impl->model_->GetFeatureMaps();
  size_t handle_num = features.size();

  mindspore::MSTensor **handle_list =
    static_cast<mindspore::MSTensor **>(malloc(handle_num * sizeof(mindspore::MSTensor *)));
  if (handle_list == nullptr) {
    MS_LOG(ERROR) << "Failed to malloc handle_list.";
    return {0, nullptr};
  }
  for (size_t i = 0; i < handle_num; i++) {
    handle_list[i] = new (std::nothrow) mindspore::MSTensor(features[i].impl());
    if (handle_list[i] == nullptr) {
      MS_LOG(ERROR) << "Failed to new MSTensor.";
      // Release the tensors created so far and the handle array before failing.
      for (size_t j = 0; j < i; j++) {
        delete handle_list[j];
      }
      free(handle_list);
      return {0, nullptr};
    }
  }
  return {handle_num, reinterpret_cast<OH_AI_TensorHandle *>(handle_list)};
}
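
// Note the handle array above is malloc'd and each entry wraps a newly allocated MSTensor; nothing
// in this file frees them, so a caller-side cleanup could look like the following sketch
// (OH_AI_TensorDestroy is assumed from tensor_c.h):
//
//   OH_AI_TensorHandleArray weights = OH_AI_ModelGetWeights(model);
//   for (size_t i = 0; i < weights.handle_num; i++) {
//     OH_AI_TensorDestroy(&weights.handle_list[i]);
//   }
//   free(weights.handle_list);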

OH_AI_Status OH_AI_ModelUpdateWeights(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray new_weights) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  std::vector<mindspore::MSTensor> weights;
  for (size_t i = 0; i < new_weights.handle_num; i++) {
    weights.push_back(*static_cast<mindspore::MSTensor *>(new_weights.handle_list[i]));
  }
  auto ret = impl->model_->UpdateWeights(weights);
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

bool OH_AI_ModelGetTrainMode(OH_AI_ModelHandle model) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return false;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  return impl->model_->GetTrainMode();
}

OH_AI_Status OH_AI_ModelSetTrainMode(OH_AI_ModelHandle model, bool train) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  auto ret = impl->model_->SetTrainMode(train);
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

OH_AI_Status OH_AI_ModelSetupVirtualBatch(OH_AI_ModelHandle model, int virtual_batch_multiplier, float lr,
                                          float momentum) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  auto ret = impl->model_->SetupVirtualBatch(virtual_batch_multiplier, lr, momentum);
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

OH_AI_Status OH_AI_ExportModel(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *model_file,
                               OH_AI_QuantizationType quantization_type, bool export_inference_only,
                               char **output_tensor_name, size_t num) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  std::vector<std::string> tensor_name;
  for (size_t i = 0; i < num; i++) {
    tensor_name.push_back(output_tensor_name[i]);
  }
  auto ret = mindspore::Serialization::ExportModel(
    *(impl->model_.get()), static_cast<mindspore::ModelType>(model_type), model_file,
    static_cast<mindspore::QuantizationType>(quantization_type), export_inference_only, tensor_name);
  if (!ret.IsOk()) {
    MS_LOG(ERROR) << "export model failed, ret: " << ret;
  }
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

OH_AI_Status OH_AI_ExportModelBuffer(OH_AI_ModelHandle model, OH_AI_ModelType model_type, char **model_data,
                                     size_t *data_size, OH_AI_QuantizationType quantization_type,
                                     bool export_inference_only, char **output_tensor_name, size_t num) {
  if (model == nullptr || model_data == nullptr || data_size == nullptr) {
    MS_LOG(ERROR) << "model or model_data or data_size is nullptr.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  std::vector<std::string> tensor_name;
  for (size_t i = 0; i < num; i++) {
    tensor_name.push_back(output_tensor_name[i]);
  }
  mindspore::Buffer buffer;
  auto ret = mindspore::Serialization::ExportModel(*(impl->model_.get()), static_cast<mindspore::ModelType>(model_type),
                                                   &buffer, static_cast<mindspore::QuantizationType>(quantization_type),
                                                   export_inference_only, tensor_name);
  if (!ret.IsOk()) {
    MS_LOG(ERROR) << "export model failed, ret: " << ret;
    return static_cast<OH_AI_Status>(ret.StatusCode());
  }
  auto data = reinterpret_cast<char *>(buffer.MutableData());
  *model_data = reinterpret_cast<char *>(malloc(buffer.DataSize()));
  if (*model_data == nullptr) {
    MS_LOG(ERROR) << "malloc model_data failed.";
    return OH_AI_STATUS_LITE_NULLPTR;
  }
  *data_size = buffer.DataSize();
  memcpy(*model_data, data, buffer.DataSize());
  return static_cast<OH_AI_Status>(ret.StatusCode());
}
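
// The exported buffer is copied into a malloc'd block, so the caller releases it when done; a
// hypothetical usage sketch (OH_AI_NO_QUANT is assumed from model_c.h):
//
//   char *buf = NULL;
//   size_t buf_size = 0;
//   OH_AI_Status status = OH_AI_ExportModelBuffer(model, OH_AI_MODELTYPE_MINDIR, &buf, &buf_size,
//                                                 OH_AI_NO_QUANT, false, NULL, 0);
//   if (status == OH_AI_STATUS_SUCCESS) {
//     // ... write buf[0..buf_size) somewhere ...
//     free(buf);
//   }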

OH_AI_Status OH_AI_ExportWeightsCollaborateWithMicro(OH_AI_ModelHandle model, OH_AI_ModelType model_type,
                                                     const char *weight_file, bool is_inference, bool enable_fp16,
                                                     char **changeable_weights_name, size_t num) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  std::vector<std::string> weights_name;
  for (size_t i = 0; i < num; i++) {
    weights_name.push_back(changeable_weights_name[i]);
  }
  auto ret = mindspore::Serialization::ExportWeightsCollaborateWithMicro(
    *(impl->model_.get()), static_cast<mindspore::ModelType>(model_type), weight_file, is_inference, enable_fp16,
    weights_name);
  if (!ret.IsOk()) {
    MS_LOG(ERROR) << "export weights failed, ret: " << ret;
  }
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

OH_AI_Status OH_AI_ModelLoadConfig(OH_AI_ModelHandle model, const char *config_file_path) {
  MS_LOG(INFO) << "Start to load config file for ms model";
  if (model == nullptr || config_file_path == nullptr) {
    MS_LOG(ERROR) << "model or config_file_path is nullptr.";
    return OH_AI_STATUS_LITE_NULLPTR;
  }
  MS_LOG(INFO) << "config_file_path: " << config_file_path;

  auto impl = static_cast<mindspore::ModelC *>(model);
  auto ret = impl->model_->LoadConfig(config_file_path);

  if (ret.IsOk()) {
    MS_LOG(INFO) << "Loaded ms model config file successfully";
  } else {
    MS_LOG(ERROR) << "Load ms model config file failed, ret: " << ret;
  }
  return static_cast<OH_AI_Status>(ret.StatusCode());
}
