• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "include/c_api/model_c.h"

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iterator>
#include <memory>
#include <new>
#include <string>
#include <utility>
#include <vector>

#include "type_c_private.h"
#include "context_c.h"
#include "include/api/context.h"
#include "include/api/serialization.h"
#include "include/api/types.h"
#include "src/litert/cxx_api/tensor/tensor_impl.h"
#include "src/litert/cxx_api/model/model_impl.h"
#ifdef ENABLE_HI_APP_EVENT
#include "src/common/hi_app_event/hi_app_event.h"
#endif
29 
30 namespace mindspore {
31 class ModelC {
32  public:
ModelC()33   ModelC() : model_(nullptr) {}
~ModelC()34   ~ModelC() {
35     for (auto in : inputs_) {
36       if (in != nullptr) {
37         delete in;
38       }
39     }
40     for (auto out : outputs_) {
41       if (out != nullptr) {
42         delete out;
43       }
44     }
45     for (auto out : outputs_train_) {
46       if (out != nullptr) {
47         delete out;
48       }
49     }
50 
51     // In zero copy scene where user will call set or get allocator function, but when model is destroyed, the allocator
52     // table will not be freed, and its size continues to grow causing memory leak, so when ModelC is destroyed, clean
53     // allocator table.
54     CleanAllocatorTable();
55   }
56 
57   MSTensor **GetInputs(size_t *input_num);
58   MSTensor **GetOutputs(size_t *output_num);
59   mindspore::MSKernelCallBack TransCallBack(const OH_AI_KernelCallBack &ms_callback);
60   std::shared_ptr<Model> model_;
61   std::shared_ptr<Context> context_;
62 
63  private:
64   MSTensor **GetOutputsTensor(size_t *output_num, std::vector<MSTensor *> *vec_tensors);
65   std::vector<MSTensor *> inputs_;
66   std::vector<MSTensor *> outputs_;
67   std::vector<MSTensor *> outputs_train_;
68 };
69 
GetInputs(size_t * input_num)70 MSTensor **ModelC::GetInputs(size_t *input_num) {
71   if (model_ == nullptr) {
72     MS_LOG(ERROR) << "model_ is nullptr.";
73     return nullptr;
74   }
75   if (!inputs_.empty()) {
76     *input_num = inputs_.size();
77     return inputs_.data();
78   }
79   auto inputs = model_->GetInputs();
80   *input_num = inputs.size();
81   inputs_.resize(inputs.size(), nullptr);
82   for (size_t i = 0; i < inputs.size(); i++) {
83     inputs_[i] = new (std::nothrow) MSTensor(inputs[i].impl());
84     if (inputs_[i] == nullptr) {
85       inputs_.clear();
86       return nullptr;
87     }
88   }
89   return inputs_.data();
90 }
91 
GetOutputs(size_t * output_num)92 MSTensor **ModelC::GetOutputs(size_t *output_num) {
93   if (model_->GetTrainMode() == true) {
94     return GetOutputsTensor(output_num, &outputs_train_);
95   } else {
96     return GetOutputsTensor(output_num, &outputs_);
97   }
98 }
99 
GetOutputsTensor(size_t * output_num,std::vector<MSTensor * > * vec_tensors)100 MSTensor **ModelC::GetOutputsTensor(size_t *output_num, std::vector<MSTensor *> *vec_tensors) {
101   if (model_ == nullptr) {
102     MS_LOG(ERROR) << "model_ is nullptr.";
103     return nullptr;
104   }
105   if (!vec_tensors->empty()) {
106     *output_num = vec_tensors->size();
107     return vec_tensors->data();
108   }
109 
110   auto outputs = model_->GetOutputs();
111   *output_num = outputs.size();
112   vec_tensors->resize(outputs.size(), nullptr);
113   for (size_t i = 0; i < outputs.size(); i++) {
114     (*vec_tensors)[i] = new (std::nothrow) MSTensor(outputs[i].impl());
115     if ((*vec_tensors)[i] == nullptr) {
116       vec_tensors->clear();
117       return nullptr;
118     }
119   }
120   return vec_tensors->data();
121 }
122 
TransCallBack(const OH_AI_KernelCallBack & ms_callback)123 mindspore::MSKernelCallBack ModelC::TransCallBack(const OH_AI_KernelCallBack &ms_callback) {
124   mindspore::MSKernelCallBack call_back = nullptr;
125   if (ms_callback != nullptr) {
126     call_back = [&](const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
127                     const mindspore::MSCallBackParam &opInfo) {
128       std::vector<OH_AI_TensorHandle> vec_inputs;
129       std::vector<OH_AI_TensorHandle> vec_outputs;
130       OH_AI_CallBackParam call_back = {const_cast<char *>(opInfo.node_name.c_str()),
131                                     const_cast<char *>(opInfo.node_type.c_str())};
132       size_t inputs_handle_num = inputs.size();
133       for (size_t i = 0; i < inputs_handle_num; i++) {
134         vec_inputs.push_back(static_cast<OH_AI_TensorHandle>(&(static_cast<std::vector<mindspore::MSTensor>>(inputs)[i])));
135       }
136       size_t outputs_handle_num = outputs.size();
137       for (size_t i = 0; i < outputs_handle_num; i++) {
138         vec_outputs.push_back(
139           static_cast<OH_AI_TensorHandle>(&(static_cast<std::vector<mindspore::MSTensor>>(outputs)[i])));
140       }
141       OH_AI_TensorHandleArray handle_inputs = {inputs_handle_num, vec_inputs.data()};
142       OH_AI_TensorHandleArray handle_outputs = {outputs_handle_num, vec_outputs.data()};
143       return ms_callback(handle_inputs, handle_outputs, call_back);
144     };
145   }
146   return call_back;
147 }
148 }  // namespace mindspore
149 
OH_AI_ModelCreate()150 OH_AI_ModelHandle OH_AI_ModelCreate() {
151   MS_LOG(INFO) << "Start to create ms model";
152   auto impl = new (std::nothrow) mindspore::ModelC();
153   if (impl == nullptr) {
154     MS_LOG(ERROR) << "Model implement is nullptr.";
155     return nullptr;
156   }
157   impl->model_ = std::make_shared<mindspore::Model>();
158   if (impl->model_ == nullptr) {
159     MS_LOG(ERROR) << "inner model object is nullptr.";
160     delete impl;
161     return nullptr;
162   }
163   MS_LOG(INFO) << "Created ms model successfully";
164   return static_cast<OH_AI_ModelHandle>(impl);
165 }
166 
OH_AI_ModelDestroy(OH_AI_ModelHandle * model)167 void OH_AI_ModelDestroy(OH_AI_ModelHandle *model) {
168   MS_LOG(INFO) << "Start to destroy ms model";
169   if (model == nullptr || *model == nullptr) {
170     MS_LOG(ERROR) << "model is nullptr.";
171     return;
172   }
173   auto impl = static_cast<mindspore::ModelC *>(*model);
174   delete impl;
175   *model = nullptr;
176   MS_LOG(INFO) << "Destroyed ms model successfully";
177 }
178 
OH_AI_ModelSetWorkspace(OH_AI_ModelHandle model,void * workspace,size_t workspace_size)179 void OH_AI_ModelSetWorkspace(OH_AI_ModelHandle model, void *workspace, size_t workspace_size) {
180   MS_LOG(ERROR) << "Unsupported Feature.";
181   return;
182 }
183 
OH_AI_ModelCalcWorkspaceSize(OH_AI_ModelHandle model)184 size_t OH_AI_ModelCalcWorkspaceSize(OH_AI_ModelHandle model) {
185   MS_LOG(ERROR) << "Unsupported Feature.";
186   return 0;
187 }
188 
OH_AI_ModelBuild(OH_AI_ModelHandle model,const void * model_data,size_t data_size,OH_AI_ModelType model_type,const OH_AI_ContextHandle model_context)189 OH_AI_Status OH_AI_ModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size, OH_AI_ModelType model_type,
190                       const OH_AI_ContextHandle model_context) {
191   MS_LOG(INFO) << "Start to build ms model";
192   if (model == nullptr || model_data == nullptr || model_context == nullptr) {
193     MS_LOG(ERROR) << "model or model_data or model_context is nullptr.";
194     return OH_AI_STATUS_LITE_NULLPTR;
195   }
196   if (model_type == OH_AI_MODELTYPE_INVALID) {
197     MS_LOG(ERROR) << "model_type is invalid.";
198     return OH_AI_STATUS_LITE_PARAM_INVALID;
199   }
200   mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
201   auto impl = static_cast<mindspore::ModelC *>(model);
202   if (impl->context_.get() != context->context_ && context->owned_by_model_) {
203     MS_LOG(ERROR) << "context is owned by other model.";
204     return OH_AI_STATUS_LITE_PARAM_INVALID;
205   }
206   if (impl->context_.get() != context->context_) {
207     impl->context_.reset(context->context_);
208     context->owned_by_model_ = true;
209   }
210   auto ret = impl->model_->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), impl->context_);
211   if (ret.IsOk()) {
212     MS_LOG(INFO) << "Built ms model successfully";
213   } else {
214     MS_LOG(ERROR) << "Built ms model failed, ret: " << ret;
215   }
216   return static_cast<OH_AI_Status>(ret.StatusCode());
217 }
218 
OH_AI_ModelBuildFromFile(OH_AI_ModelHandle model,const char * model_path,OH_AI_ModelType model_type,const OH_AI_ContextHandle model_context)219 OH_AI_Status OH_AI_ModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path, OH_AI_ModelType model_type,
220                               const OH_AI_ContextHandle model_context) {
221   MS_LOG(INFO) << "Start to build ms model from file";
222   if (model == nullptr || model_path == nullptr || model_context == nullptr) {
223     MS_LOG(ERROR) << "model or model_path or model_context is nullptr.";
224     return OH_AI_STATUS_LITE_NULLPTR;
225   }
226   if (model_type == OH_AI_MODELTYPE_INVALID) {
227     MS_LOG(ERROR) << "model_type is invalid.";
228     return OH_AI_STATUS_LITE_PARAM_INVALID;
229   }
230   mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
231   auto impl = static_cast<mindspore::ModelC *>(model);
232   if (impl->context_.get() != context->context_ && context->owned_by_model_) {
233     MS_LOG(ERROR) << "context is owned by other model.";
234     return OH_AI_STATUS_LITE_PARAM_INVALID;
235   }
236   if (impl->context_.get() != context->context_) {
237     impl->context_.reset(context->context_);
238     context->owned_by_model_ = true;
239   }
240   auto ret = impl->model_->Build(model_path, static_cast<mindspore::ModelType>(model_type), impl->context_);
241   if (ret.IsOk()) {
242     MS_LOG(INFO) << "Built ms model from file successfully";
243   } else {
244     MS_LOG(ERROR) << "Built ms model from file failed, ret: " << ret;
245   }
246   return static_cast<OH_AI_Status>(ret.StatusCode());
247 }
248 
OH_AI_ModelResize(OH_AI_ModelHandle model,const OH_AI_TensorHandleArray inputs,OH_AI_ShapeInfo * shape_infos,size_t shape_info_num)249 OH_AI_Status OH_AI_ModelResize(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs, OH_AI_ShapeInfo *shape_infos,
250                        size_t shape_info_num) {
251   MS_LOG(INFO) << "Start to resize ms model";
252   if (model == nullptr || shape_infos == nullptr) {
253     MS_LOG(ERROR) << "model or shape_infos is nullptr.";
254     return OH_AI_STATUS_LITE_NULLPTR;
255   }
256   std::vector<mindspore::MSTensor> vec_inputs;
257   for (size_t i = 0; i < inputs.handle_num; ++i) {
258     vec_inputs.push_back(*static_cast<mindspore::MSTensor *>(inputs.handle_list[i]));
259   }
260 
261   std::vector<std::vector<int64_t>> vec_dims;
262   for (size_t i = 0; i < shape_info_num; i++) {
263     std::vector<int64_t> shape(shape_infos[i].shape, shape_infos[i].shape + shape_infos[i].shape_num);
264     if (std::any_of(shape.begin(), shape.end(), [](int64_t val) { return val < 0 || val > INT32_MAX; })) {
265       MS_LOG(ERROR) << "Invalid shape: " << shape << ", each dimension must be in [0, INT32_MAX]";
266       return OH_AI_STATUS_LITE_PARAM_INVALID;
267     }
268     vec_dims.push_back(shape);
269   }
270   auto impl = static_cast<mindspore::ModelC *>(model);
271   auto ret = impl->model_->Resize(vec_inputs, vec_dims);
272   if (ret.IsOk()) {
273     MS_LOG(INFO) << "Resized ms model successfully";
274   } else {
275     MS_LOG(ERROR) << "Resized ms model failed, ret: " << ret;
276   }
277   return static_cast<OH_AI_Status>(ret.StatusCode());
278 }
279 
OH_AI_ModelPredict(OH_AI_ModelHandle model,const OH_AI_TensorHandleArray inputs,OH_AI_TensorHandleArray * outputs,const OH_AI_KernelCallBack before,const OH_AI_KernelCallBack after)280 OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs, OH_AI_TensorHandleArray *outputs,
281                         const OH_AI_KernelCallBack before, const OH_AI_KernelCallBack after) {
282   MS_LOG(INFO) << "Start to predict ms model";
283   if (model == nullptr) {
284     MS_LOG(ERROR) << "model is nullptr.";
285     return OH_AI_STATUS_LITE_NULLPTR;
286   }
287   auto impl = static_cast<mindspore::ModelC *>(model);
288   size_t input_num;
289   (void)impl->GetInputs(&input_num);
290   if (input_num != inputs.handle_num) {
291     MS_LOG(ERROR) << "Wrong input size.";
292     return OH_AI_STATUS_LITE_ERROR;
293   }
294 
295   std::vector<mindspore::MSTensor> ms_tensor_inputs;
296   for (size_t i = 0; i < inputs.handle_num; i++) {
297     if (inputs.handle_list[i] != nullptr) {
298       auto user_input = static_cast<mindspore::MSTensor *>(inputs.handle_list[i]);
299       ms_tensor_inputs.push_back(*user_input);
300     } else {
301       MS_LOG(ERROR) << "input handle is nullptr.";
302       return OH_AI_STATUS_LITE_NULLPTR;
303     }
304   }
305 
306   mindspore::MSKernelCallBack before_call_back = impl->TransCallBack(before);
307   mindspore::MSKernelCallBack after_call_back = impl->TransCallBack(after);
308   std::vector<mindspore::MSTensor> ms_tensor_outputs;
309 
310   size_t output_num;
311   (void)impl->GetOutputs(&output_num);
312   auto handle_num = outputs->handle_num;
313   if (handle_num == output_num) {
314     MS_LOG(INFO) << "use user provided output";
315     for (size_t i = 0; i < output_num; i++) {
316       if (outputs->handle_list[i] == nullptr) {
317         MS_LOG(ERROR) << "user provided output array handle_list[" << i << "] is nullptr";
318         return OH_AI_STATUS_LITE_NULLPTR;
319       }
320       ms_tensor_outputs.push_back(*static_cast<mindspore::MSTensor *>(outputs->handle_list[i]));
321     }
322   }
323 
324   auto ret = impl->model_->Predict(ms_tensor_inputs, &ms_tensor_outputs, before_call_back, after_call_back);
325   if (!ret.IsOk()) {
326     MS_LOG(ERROR) << "Predict fail, ret :" << ret;
327     return static_cast<OH_AI_Status>(ret.StatusCode());
328   }
329 
330   if (handle_num == output_num) {
331     return OH_AI_STATUS_SUCCESS;
332   }
333 
334   outputs->handle_list = reinterpret_cast<OH_AI_TensorHandle *>(impl->GetOutputs(&(outputs->handle_num)));
335   MS_LOG(INFO) << "Predicted ms model successfully";
336   return static_cast<OH_AI_Status>(ret.StatusCode());
337 }
338 
OH_AI_ModelRunStep(OH_AI_ModelHandle model,const OH_AI_KernelCallBack before,const OH_AI_KernelCallBack after)339 OH_AI_Status OH_AI_ModelRunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before, const OH_AI_KernelCallBack after) {
340   MS_LOG(ERROR) << "Unsupported Feature.";
341   return OH_AI_STATUS_LITE_NOT_SUPPORT;
342 }
343 
OH_AI_ModelExportWeight(const OH_AI_ModelHandle model,const char * export_path)344 OH_AI_Status OH_AI_ModelExportWeight(const OH_AI_ModelHandle model, const char *export_path) {
345   MS_LOG(ERROR) << "Unsupported Feature.";
346   return OH_AI_STATUS_LITE_NOT_SUPPORT;
347 }
348 
OH_AI_ModelGetInputs(const OH_AI_ModelHandle model)349 OH_AI_TensorHandleArray OH_AI_ModelGetInputs(const OH_AI_ModelHandle model) {
350   MS_LOG(INFO) << "Start to get ms model inputs";
351   if (model == nullptr) {
352     MS_LOG(ERROR) << "model is nullptr.";
353     return {0, nullptr};
354   }
355   auto impl = static_cast<mindspore::ModelC *>(model);
356   size_t input_num = 0;
357   auto handles = reinterpret_cast<OH_AI_TensorHandle *>(impl->GetInputs(&input_num));
358   MS_LOG(INFO) << "Got ms model " << input_num << " inputs successfully";
359   return {input_num, handles};
360 }
361 
OH_AI_ModelGetOutputs(const OH_AI_ModelHandle model)362 OH_AI_TensorHandleArray OH_AI_ModelGetOutputs(const OH_AI_ModelHandle model) {
363   MS_LOG(INFO) << "Start to get ms model outputs";
364   if (model == nullptr) {
365     MS_LOG(ERROR) << "model is nullptr.";
366     return {0, nullptr};
367   }
368   auto impl = static_cast<mindspore::ModelC *>(model);
369   size_t output_num;
370   auto handles = reinterpret_cast<OH_AI_TensorHandle *>(impl->GetOutputs(&output_num));
371   MS_LOG(INFO) << "Got ms model " << output_num << " outputs successfully";
372   return {output_num, handles};
373 }
374 
OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHandle model,const char * tensor_name)375 OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) {
376   MS_LOG(INFO) << "Start to get ms model input by name";
377   if (model == nullptr || tensor_name == nullptr) {
378     MS_LOG(ERROR) << "model or tensor_name is nullptr.";
379     return nullptr;
380   }
381   auto impl = static_cast<mindspore::ModelC *>(model);
382   size_t input_num;
383   auto inputs = impl->GetInputs(&input_num);
384   for (size_t i = 0; i < input_num; i++) {
385     if (inputs[i]->Name() == tensor_name) {
386       MS_LOG(INFO) << "Got ms model input by name successfully";
387       return static_cast<OH_AI_TensorHandle>(inputs[i]);
388     }
389   }
390   MS_LOG(ERROR) << "Input tensor is not exist";
391   return nullptr;
392 }
393 
OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle model,const char * tensor_name)394 OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) {
395   MS_LOG(INFO) << "Start to get ms model output by name";
396   if (model == nullptr || tensor_name == nullptr) {
397     MS_LOG(ERROR) << "model or tensor_name is nullptr.";
398     return nullptr;
399   }
400   auto impl = static_cast<mindspore::ModelC *>(model);
401   size_t output_num;
402   auto outputs = impl->GetOutputs(&output_num);
403   for (size_t i = 0; i < output_num; i++) {
404     if (outputs[i]->Name() == tensor_name) {
405       MS_LOG(INFO) << "Got ms model output by name successfully";
406       return static_cast<OH_AI_TensorHandle>(outputs[i]);
407     }
408   }
409   MS_LOG(ERROR) << "Output tensor is not exist";
410   return nullptr;
411 }
412 
OH_AI_TrainCfgCreate()413 OH_AI_TrainCfgHandle OH_AI_TrainCfgCreate() {
414   auto impl = new (std::nothrow) mindspore::TrainCfg();
415   if (impl == nullptr) {
416     MS_LOG(ERROR) << "TrainCfg implement is nullptr.";
417     return nullptr;
418   }
419   return static_cast<OH_AI_TrainCfgHandle>(impl);
420 }
421 
OH_AI_TrainCfgDestroy(OH_AI_TrainCfgHandle * train_cfg)422 void OH_AI_TrainCfgDestroy(OH_AI_TrainCfgHandle *train_cfg) {
423   if (train_cfg != nullptr && *train_cfg != nullptr) {
424     auto impl = static_cast<mindspore::TrainCfg *>(*train_cfg);
425     delete impl;
426     *train_cfg = nullptr;
427   }
428 }
429 
OH_AI_TrainCfgGetLossName(OH_AI_TrainCfgHandle train_cfg,size_t * num)430 char **OH_AI_TrainCfgGetLossName(OH_AI_TrainCfgHandle train_cfg, size_t *num) {
431   if (train_cfg == nullptr || num == nullptr) {
432     MS_LOG(ERROR) << "train_cfg or num is nullptr.";
433     return nullptr;
434   }
435   auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
436   auto loss_name = impl->GetLossName();
437   *num = loss_name.size();
438   char **name = static_cast<char **>(malloc(loss_name.size() * sizeof(char *)));
439   if (name == nullptr) {
440     MS_LOG(ERROR) << "Failed to malloc loss_name.";
441     return nullptr;
442   }
443   for (size_t i = 0; i < loss_name.size(); i++) {
444     name[i] = static_cast<char *>(malloc(loss_name[i].size() + 1));
445     if (name[i] == nullptr) {
446       for(size_t j = 0; j < i; j++){
447         free(name[j]);
448       }
449       MS_LOG(ERROR) << "Failed to malloc name.";
450       return nullptr;
451     }
452     memcpy(name[i], loss_name[i].c_str(), loss_name[i].size() + 1);
453   }
454   return name;
455 }
456 
OH_AI_TrainCfgSetLossName(OH_AI_TrainCfgHandle train_cfg,const char ** loss_name,size_t num)457 void OH_AI_TrainCfgSetLossName(OH_AI_TrainCfgHandle train_cfg, const char **loss_name, size_t num) {
458   if (train_cfg == nullptr || loss_name == nullptr || *loss_name == nullptr) {
459     MS_LOG(ERROR) << "train_cfg or loss_name is nullptr.";
460     return;
461   }
462   auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
463   std::vector<std::string> vec_name;
464   for (size_t i = 0; i < num; i++) {
465     vec_name.push_back(loss_name[i]);
466   }
467   impl->SetLossName(vec_name);
468 }
469 
OH_AI_TrainCfgGetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg)470 OH_AI_OptimizationLevel OH_AI_TrainCfgGetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg) {
471   if (train_cfg == nullptr) {
472     MS_LOG(ERROR) << "train_cfg is nullptr, return OH_AI_KO0";
473     return OH_AI_KO0;
474   }
475   auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
476   return static_cast<OH_AI_OptimizationLevel>(impl->optimization_level_);
477 }
478 
OH_AI_TrainCfgSetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg,OH_AI_OptimizationLevel level)479 void OH_AI_TrainCfgSetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg, OH_AI_OptimizationLevel level) {
480   if (train_cfg == nullptr) {
481     MS_LOG(ERROR) << "train_cfg is nullptr.";
482     return;
483   }
484   auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
485   impl->optimization_level_ = static_cast<mindspore::OptimizationLevel>(level);
486 }
487 
OH_AI_TrainModelBuild(OH_AI_ModelHandle model,const void * model_data,size_t data_size,OH_AI_ModelType model_type,const OH_AI_ContextHandle model_context,const OH_AI_TrainCfgHandle train_cfg)488 OH_AI_Status OH_AI_TrainModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size, OH_AI_ModelType model_type,
489                            const OH_AI_ContextHandle model_context, const OH_AI_TrainCfgHandle train_cfg) {
490   if (model == nullptr || model_data == nullptr || model_context == nullptr) {
491     MS_LOG(ERROR) << "model or model_data or model_context is nullptr.";
492     return OH_AI_STATUS_LITE_NULLPTR;
493   }
494   if (model_type == OH_AI_MODELTYPE_INVALID) {
495     MS_LOG(ERROR) << "model_type is invalid.";
496     return OH_AI_STATUS_LITE_PARAM_INVALID;
497   }
498   auto impl = static_cast<mindspore::ModelC *>(model);
499 
500   mindspore::Graph graph;
501   auto status =
502     mindspore::Serialization::Load(model_data, data_size, static_cast<mindspore::ModelType>(model_type), &graph);
503   if (status != mindspore::kSuccess) {
504     MS_LOG(ERROR) << "load ms file failed.";
505     return OH_AI_STATUS_LITE_ERROR;
506   }
507   auto context = static_cast<mindspore::ContextC *>(model_context);
508   auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg);
509   if (impl->context_.get() != context->context_ && context->owned_by_model_) {
510     MS_LOG(ERROR) << "context is owned by other model.";
511     return OH_AI_STATUS_LITE_PARAM_INVALID;
512   }
513   if (impl->context_.get() != context->context_) {
514     impl->context_.reset(context->context_);
515     context->owned_by_model_ = true;
516   }
517   auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_,
518                                  std::shared_ptr<mindspore::TrainCfg>(build_train_cfg));
519   if (ret != mindspore::kSuccess) {
520     MS_LOG(ERROR) << "Load and compile failed";
521   }
522   return static_cast<OH_AI_Status>(ret.StatusCode());
523 }
524 
OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model,const char * model_path,OH_AI_ModelType model_type,const OH_AI_ContextHandle model_context,const OH_AI_TrainCfgHandle train_cfg)525 OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path, OH_AI_ModelType model_type,
526                                    const OH_AI_ContextHandle model_context, const OH_AI_TrainCfgHandle train_cfg) {
527   if (model == nullptr || model_path == nullptr || model_context == nullptr) {
528     MS_LOG(ERROR) << "model or model_path or model_context is nullptr.";
529     return OH_AI_STATUS_LITE_NULLPTR;
530   }
531   if (model_type == OH_AI_MODELTYPE_INVALID) {
532     MS_LOG(ERROR) << "model_type is invalid.";
533     return OH_AI_STATUS_LITE_PARAM_INVALID;
534   }
535   auto impl = static_cast<mindspore::ModelC *>(model);
536 
537   mindspore::Graph graph;
538   auto status = mindspore::Serialization::Load(model_path, static_cast<mindspore::ModelType>(model_type), &graph);
539   if (status != mindspore::kSuccess) {
540     MS_LOG(ERROR) << "load ms file failed. " << model_path;
541     return OH_AI_STATUS_LITE_ERROR;
542   }
543   auto context = static_cast<mindspore::ContextC *>(model_context);
544   auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg);
545   if (impl->context_.get() != context->context_ && context->owned_by_model_) {
546     MS_LOG(ERROR) << "context is owned by other model.";
547     return OH_AI_STATUS_LITE_PARAM_INVALID;
548   }
549   if (impl->context_.get() != context->context_) {
550     impl->context_.reset(context->context_);
551     context->owned_by_model_ = true;
552   }
553   auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_,
554                                  std::shared_ptr<mindspore::TrainCfg>(build_train_cfg));
555   if (ret != mindspore::kSuccess) {
556     MS_LOG(ERROR) << "Load and compile failed";
557   }
558   return static_cast<OH_AI_Status>(ret.StatusCode());
559 }
560 
OH_AI_ModelSetLearningRate(OH_AI_ModelHandle model,float learning_rate)561 OH_AI_Status OH_AI_ModelSetLearningRate(OH_AI_ModelHandle model, float learning_rate) {
562   if (model == nullptr) {
563     MS_LOG(ERROR) << "model is nullptr.";
564     return OH_AI_STATUS_LITE_PARAM_INVALID;
565   }
566   auto impl = static_cast<mindspore::ModelC *>(model);
567   auto ret = impl->model_->SetLearningRate(learning_rate);
568   return static_cast<OH_AI_Status>(ret.StatusCode());
569 }
570 
OH_AI_ModelGetLearningRate(OH_AI_ModelHandle model)571 float OH_AI_ModelGetLearningRate(OH_AI_ModelHandle model) {
572   if (model == nullptr) {
573     MS_LOG(ERROR) << "model is nullptr.";
574     return OH_AI_STATUS_LITE_PARAM_INVALID;
575   }
576   auto impl = static_cast<mindspore::ModelC *>(model);
577   return impl->model_->GetLearningRate();
578 }
579 
OH_AI_RunStep(OH_AI_ModelHandle model,const OH_AI_KernelCallBack before,const OH_AI_KernelCallBack after)580 OH_AI_Status OH_AI_RunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before, const OH_AI_KernelCallBack after) {
581   if (model == nullptr) {
582     MS_LOG(ERROR) << "model is nullptr.";
583     return OH_AI_STATUS_LITE_PARAM_INVALID;
584   }
585   auto impl = static_cast<mindspore::ModelC *>(model);
586   auto ret = impl->model_->RunStep(impl->TransCallBack(before), impl->TransCallBack(after));
587   return static_cast<OH_AI_Status>(ret.StatusCode());
588 }
589 
OH_AI_ModelGetWeights(OH_AI_ModelHandle model)590 OH_AI_TensorHandleArray OH_AI_ModelGetWeights(OH_AI_ModelHandle model) {
591   if (model == nullptr) {
592     MS_LOG(ERROR) << "model is nullptr.";
593     return {0, nullptr};
594   }
595   auto impl = static_cast<mindspore::ModelC *>(model);
596   auto features = impl->model_->GetFeatureMaps();
597   size_t handle_num = features.size();
598 
599   mindspore::MSTensor **handle_list =
600     static_cast<mindspore::MSTensor **>(malloc(handle_num * sizeof(mindspore::MSTensor *)));
601   if (handle_list == nullptr) {
602     MS_LOG(ERROR) << "Failed to malloc handle_list.";
603     return {0, nullptr};
604   }
605   for (size_t i = 0; i < handle_num; i++) {
606     handle_list[i] = new (std::nothrow) mindspore::MSTensor(features[i].impl());
607   }
608   return {handle_num, reinterpret_cast<OH_AI_TensorHandle *>(handle_list)};
609 }
610 
OH_AI_ModelUpdateWeights(OH_AI_ModelHandle model,const OH_AI_TensorHandleArray new_weights)611 OH_AI_Status OH_AI_ModelUpdateWeights(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray new_weights) {
612   if (model == nullptr) {
613     MS_LOG(ERROR) << "model is nullptr.";
614     return OH_AI_STATUS_LITE_PARAM_INVALID;
615   }
616   auto impl = static_cast<mindspore::ModelC *>(model);
617   std::vector<mindspore::MSTensor> weights;
618   for (size_t i = 0; i < new_weights.handle_num; i++) {
619     weights.push_back(*static_cast<mindspore::MSTensor *>(new_weights.handle_list[i]));
620   }
621   auto ret = impl->model_->UpdateWeights(weights);
622   return static_cast<OH_AI_Status>(ret.StatusCode());
623 }
624 
OH_AI_ModelGetTrainMode(OH_AI_ModelHandle model)625 bool OH_AI_ModelGetTrainMode(OH_AI_ModelHandle model) {
626   if (model == nullptr) {
627     MS_LOG(ERROR) << "model is nullptr.";
628     return false;
629   }
630   auto impl = static_cast<mindspore::ModelC *>(model);
631   return impl->model_->GetTrainMode();
632 }
633 
OH_AI_ModelSetTrainMode(OH_AI_ModelHandle model,bool train)634 OH_AI_Status OH_AI_ModelSetTrainMode(OH_AI_ModelHandle model, bool train) {
635   if (model == nullptr) {
636     MS_LOG(ERROR) << "model is nullptr.";
637     return OH_AI_STATUS_LITE_PARAM_INVALID;
638   }
639   auto impl = static_cast<mindspore::ModelC *>(model);
640   auto ret = impl->model_->SetTrainMode(train);
641   return static_cast<OH_AI_Status>(ret.StatusCode());
642 }
643 
OH_AI_ModelSetupVirtualBatch(OH_AI_ModelHandle model,int virtual_batch_multiplier,float lr,float momentum)644 OH_AI_Status OH_AI_ModelSetupVirtualBatch(OH_AI_ModelHandle model, int virtual_batch_multiplier, float lr, float momentum) {
645   if (model == nullptr) {
646     MS_LOG(ERROR) << "model is nullptr.";
647     return OH_AI_STATUS_LITE_PARAM_INVALID;
648   }
649   auto impl = static_cast<mindspore::ModelC *>(model);
650   auto ret = impl->model_->SetupVirtualBatch(virtual_batch_multiplier, lr, momentum);
651   return static_cast<OH_AI_Status>(ret.StatusCode());
652 }
653 
OH_AI_ExportModel(OH_AI_ModelHandle model,OH_AI_ModelType model_type,const char * model_file,OH_AI_QuantizationType quantization_type,bool export_inference_only,char ** output_tensor_name,size_t num)654 OH_AI_Status OH_AI_ExportModel(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *model_file,
655                        OH_AI_QuantizationType quantization_type, bool export_inference_only, char **output_tensor_name,
656                        size_t num) {
657   if (model == nullptr) {
658     MS_LOG(ERROR) << "model is nullptr.";
659     return OH_AI_STATUS_LITE_PARAM_INVALID;
660   }
661   auto impl = static_cast<mindspore::ModelC *>(model);
662   std::vector<std::string> tensor_name;
663   for (size_t i = 0; i < num; i++) {
664     tensor_name.push_back(output_tensor_name[i]);
665   }
666   auto ret = mindspore::Serialization::ExportModel(
667     *(impl->model_.get()), static_cast<mindspore::ModelType>(model_type), model_file,
668     static_cast<mindspore::QuantizationType>(quantization_type), export_inference_only, tensor_name);
669   if (!ret.IsOk()) {
670     MS_LOG(ERROR) << "export model fail, ret :" << ret;
671   }
672   return static_cast<OH_AI_Status>(ret.StatusCode());
673 }
674 
OH_AI_ExportModelBuffer(OH_AI_ModelHandle model,OH_AI_ModelType model_type,char ** model_data,size_t * data_size,OH_AI_QuantizationType quantization_type,bool export_inference_only,char ** output_tensor_name,size_t num)675 OH_AI_Status OH_AI_ExportModelBuffer(OH_AI_ModelHandle model, OH_AI_ModelType model_type, char **model_data, size_t *data_size,
676                              OH_AI_QuantizationType quantization_type, bool export_inference_only,
677                              char **output_tensor_name, size_t num) {
678   if (model == nullptr) {
679     MS_LOG(ERROR) << "model is nullptr.";
680     return OH_AI_STATUS_LITE_PARAM_INVALID;
681   }
682   auto impl = static_cast<mindspore::ModelC *>(model);
683   std::vector<std::string> tensor_name;
684   for (size_t i = 0; i < num; i++) {
685     tensor_name.push_back(output_tensor_name[i]);
686   }
687   mindspore::Buffer buffer;
688   auto ret = mindspore::Serialization::ExportModel(*(impl->model_.get()), static_cast<mindspore::ModelType>(model_type),
689                                                    &buffer, static_cast<mindspore::QuantizationType>(quantization_type),
690                                                    export_inference_only, tensor_name);
691   auto data = reinterpret_cast<char *>(buffer.MutableData());
692   *model_data = reinterpret_cast<char *>(malloc(buffer.DataSize()));
693   if (*model_data == nullptr) {
694     MS_LOG(ERROR) << "malloc model_data failed.";
695     return OH_AI_STATUS_LITE_NULLPTR;
696   }
697   *data_size = buffer.DataSize();
698   memcpy(*model_data, data, buffer.DataSize());
699   if (!ret.IsOk()) {
700     MS_LOG(ERROR) << "export model fail, ret :" << ret;
701   }
702   return static_cast<OH_AI_Status>(ret.StatusCode());
703 }
704 
OH_AI_ExportWeightsCollaborateWithMicro(OH_AI_ModelHandle model,OH_AI_ModelType model_type,const char * weight_file,bool is_inference,bool enable_fp16,char ** changeable_weights_name,size_t num)705 OH_AI_Status OH_AI_ExportWeightsCollaborateWithMicro(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *weight_file,
706                                              bool is_inference, bool enable_fp16, char **changeable_weights_name,
707                                              size_t num) {
708   if (model == nullptr) {
709     MS_LOG(ERROR) << "model is nullptr.";
710     return OH_AI_STATUS_LITE_PARAM_INVALID;
711   }
712   auto impl = static_cast<mindspore::ModelC *>(model);
713   std::vector<std::string> weights_name;
714   for (size_t i = 0; i < num; i++) {
715     weights_name.push_back(changeable_weights_name[i]);
716   }
717   auto ret = mindspore::Serialization::ExportWeightsCollaborateWithMicro(
718     *(impl->model_.get()), static_cast<mindspore::ModelType>(model_type), weight_file, is_inference, enable_fp16,
719     weights_name);
720   if (!ret.IsOk()) {
721     MS_LOG(ERROR) << "export model fail, ret :" << ret;
722   }
723   return static_cast<OH_AI_Status>(ret.StatusCode());
724 }
725 
OH_AI_ModelLoadConfig(OH_AI_ModelHandle model,const char * config_file_path)726 OH_AI_Status OH_AI_ModelLoadConfig(OH_AI_ModelHandle model, const char *config_file_path) {
727   MS_LOG(INFO) << "Start to load config file for ms model";
728   if (model == nullptr || config_file_path == nullptr) {
729     MS_LOG(ERROR) << "model or config_file_path is nullptr.";
730     return OH_AI_STATUS_LITE_NULLPTR;
731   }
732   MS_LOG(INFO) << "config_file_path: " << config_file_path;
733 
734   auto impl = static_cast<mindspore::ModelC *>(model);
735   auto ret = impl->model_->LoadConfig(config_file_path);
736 
737   if (ret.IsOk()) {
738     MS_LOG(INFO) << "Loaded ms model config file successfully";
739   } else {
740     MS_LOG(ERROR) << "Loaded ms model config file failed, ret: " << ret;
741   }
742   return static_cast<OH_AI_Status>(ret.StatusCode());
743 }
744