/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "include/c_api/model_c.h"
#include "type_c_private.h"
#include "context_c.h"
#include <vector>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include "include/api/context.h"
#include "include/api/serialization.h"
#include "include/api/types.h"
#include "src/litert/cxx_api/tensor/tensor_impl.h"
#include "src/litert/cxx_api/model/model_impl.h"
#ifdef ENABLE_HI_APP_EVENT
#include "src/common/hi_app_event/hi_app_event.h"
#endif

namespace mindspore {
class ModelC {
 public:
  ModelC() : model_(nullptr) {}
  ~ModelC() {
    for (auto in : inputs_) {
      if (in != nullptr) {
        delete in;
      }
    }
    for (auto out : outputs_) {
      if (out != nullptr) {
        delete out;
      }
    }
    for (auto out : outputs_train_) {
      if (out != nullptr) {
        delete out;
      }
    }

    // In the zero-copy scenario the user calls the allocator set/get functions, but the allocator
    // table is not released when a model is destroyed; it keeps growing and leaks memory. Clean
    // the allocator table here when ModelC is destroyed.
    CleanAllocatorTable();
  }

  MSTensor **GetInputs(size_t *input_num);
  MSTensor **GetOutputs(size_t *output_num);
  mindspore::MSKernelCallBack TransCallBack(const OH_AI_KernelCallBack &ms_callback);
  std::shared_ptr<Model> model_;
  std::shared_ptr<Context> context_;

 private:
  MSTensor **GetOutputsTensor(size_t *output_num, std::vector<MSTensor *> *vec_tensors);
  std::vector<MSTensor *> inputs_;
  std::vector<MSTensor *> outputs_;
  std::vector<MSTensor *> outputs_train_;
};

MSTensor **ModelC::GetInputs(size_t *input_num) {
  if (model_ == nullptr) {
    MS_LOG(ERROR) << "model_ is nullptr.";
    return nullptr;
  }
  if (!inputs_.empty()) {
    *input_num = inputs_.size();
    return inputs_.data();
  }
  auto inputs = model_->GetInputs();
  *input_num = inputs.size();
  inputs_.resize(inputs.size(), nullptr);
  for (size_t i = 0; i < inputs.size(); i++) {
    inputs_[i] = new (std::nothrow) MSTensor(inputs[i].impl());
    if (inputs_[i] == nullptr) {
      // Release the tensors that were already allocated before clearing the cache.
      for (size_t j = 0; j < i; j++) {
        delete inputs_[j];
      }
      inputs_.clear();
      return nullptr;
    }
  }
  return inputs_.data();
}

MSTensor **ModelC::GetOutputs(size_t *output_num) {
  if (model_ == nullptr) {
    MS_LOG(ERROR) << "model_ is nullptr.";
    return nullptr;
  }
  if (model_->GetTrainMode()) {
    return GetOutputsTensor(output_num, &outputs_train_);
  } else {
    return GetOutputsTensor(output_num, &outputs_);
  }
}

MSTensor **ModelC::GetOutputsTensor(size_t *output_num, std::vector<MSTensor *> *vec_tensors) {
  if (model_ == nullptr) {
    MS_LOG(ERROR) << "model_ is nullptr.";
    return nullptr;
  }
  if (!vec_tensors->empty()) {
    *output_num = vec_tensors->size();
    return vec_tensors->data();
  }

  auto outputs = model_->GetOutputs();
  *output_num = outputs.size();
  vec_tensors->resize(outputs.size(), nullptr);
  for (size_t i = 0; i < outputs.size(); i++) {
    (*vec_tensors)[i] = new (std::nothrow) MSTensor(outputs[i].impl());
    if ((*vec_tensors)[i] == nullptr) {
      // Release the tensors that were already allocated before clearing the cache.
      for (size_t j = 0; j < i; j++) {
        delete (*vec_tensors)[j];
      }
      vec_tensors->clear();
      return nullptr;
    }
  }
  return vec_tensors->data();
}

mindspore::MSKernelCallBack ModelC::TransCallBack(const OH_AI_KernelCallBack &ms_callback) {
  mindspore::MSKernelCallBack call_back = nullptr;
  if (ms_callback != nullptr) {
    // Capture the C callback by value: the returned std::function can outlive the ms_callback
    // reference parameter of this function.
    call_back = [ms_callback](const std::vector<mindspore::MSTensor> &inputs,
                              const std::vector<mindspore::MSTensor> &outputs,
                              const mindspore::MSCallBackParam &op_info) {
      std::vector<OH_AI_TensorHandle> vec_inputs;
      std::vector<OH_AI_TensorHandle> vec_outputs;
      OH_AI_CallBackParam call_back_param = {const_cast<char *>(op_info.node_name.c_str()),
                                             const_cast<char *>(op_info.node_type.c_str())};
      size_t inputs_handle_num = inputs.size();
      for (size_t i = 0; i < inputs_handle_num; i++) {
        // Hand out the address of the tensor in place; copying the vector here would expose
        // pointers into a destroyed temporary.
        vec_inputs.push_back(static_cast<OH_AI_TensorHandle>(const_cast<mindspore::MSTensor *>(&inputs[i])));
      }
      size_t outputs_handle_num = outputs.size();
      for (size_t i = 0; i < outputs_handle_num; i++) {
        vec_outputs.push_back(static_cast<OH_AI_TensorHandle>(const_cast<mindspore::MSTensor *>(&outputs[i])));
      }
      OH_AI_TensorHandleArray handle_inputs = {inputs_handle_num, vec_inputs.data()};
      OH_AI_TensorHandleArray handle_outputs = {outputs_handle_num, vec_outputs.data()};
      return ms_callback(handle_inputs, handle_outputs, call_back_param);
    };
  }
  return call_back;
}
}  // namespace mindspore

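// Example of a user-side kernel callback that TransCallBack above adapts. This is an illustrative
// sketch, not part of this file: it assumes the OH_AI_KernelCallBack signature and the
// OH_AI_CallBackParam node_name/node_type fields declared in the public C API headers.
//
//   bool PrintNodeCallBack(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
//                          const OH_AI_CallBackParam kernel_info) {
//     printf("node %s (%s): %zu inputs, %zu outputs\n", kernel_info.node_name, kernel_info.node_type,
//            inputs.handle_num, outputs.handle_num);
//     return true;  // return true so execution continues
//   }
//
// Such a function can be passed as the `before`/`after` arguments of OH_AI_ModelPredict below.
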
OH_AI_ModelHandle OH_AI_ModelCreate() {
  auto impl = new (std::nothrow) mindspore::ModelC();
  if (impl == nullptr) {
    MS_LOG(ERROR) << "Model implement is nullptr.";
    return nullptr;
  }
  impl->model_ = std::make_shared<mindspore::Model>();
  if (impl->model_ == nullptr) {
    MS_LOG(ERROR) << "model_ is nullptr.";
    delete impl;
    return nullptr;
  }
  return static_cast<OH_AI_ModelHandle>(impl);
}

void OH_AI_ModelDestroy(OH_AI_ModelHandle *model) {
  if (model == nullptr || *model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return;
  }
  auto impl = static_cast<mindspore::ModelC *>(*model);
  delete impl;
  *model = nullptr;
}

void OH_AI_ModelSetWorkspace(OH_AI_ModelHandle model, void *workspace, size_t workspace_size) {
  MS_LOG(ERROR) << "Unsupported Feature.";
  return;
}

size_t OH_AI_ModelCalcWorkspaceSize(OH_AI_ModelHandle model) {
  MS_LOG(ERROR) << "Unsupported Feature.";
  return 0;
}

OH_AI_Status OH_AI_ModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size,
                              OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context) {
  if (model == nullptr || model_data == nullptr || model_context == nullptr) {
    MS_LOG(ERROR) << "model/model_data/model_context is nullptr.";
    return OH_AI_STATUS_LITE_NULLPTR;
  }
  if (model_type == OH_AI_MODELTYPE_INVALID) {
    MS_LOG(ERROR) << "model_type is invalid.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
  auto impl = static_cast<mindspore::ModelC *>(model);
  if (impl->context_.get() != context->context_ && context->owned_by_model_) {
    MS_LOG(ERROR) << "context is owned by other model.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  if (impl->context_.get() != context->context_) {
    impl->context_.reset(context->context_);
    context->owned_by_model_ = true;
  }
  auto ret = impl->model_->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), impl->context_);
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

OH_AI_Status OH_AI_ModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path, OH_AI_ModelType model_type,
                                      const OH_AI_ContextHandle model_context) {
  if (model == nullptr || model_path == nullptr || model_context == nullptr) {
    MS_LOG(ERROR) << "model/model_path/model_context is nullptr.";
    return OH_AI_STATUS_LITE_NULLPTR;
  }
  if (model_type == OH_AI_MODELTYPE_INVALID) {
    MS_LOG(ERROR) << "model_type is invalid.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
  auto impl = static_cast<mindspore::ModelC *>(model);
  if (impl->context_.get() != context->context_ && context->owned_by_model_) {
    MS_LOG(ERROR) << "context is owned by other model.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  if (impl->context_.get() != context->context_) {
    impl->context_.reset(context->context_);
    context->owned_by_model_ = true;
  }
  auto ret = impl->model_->Build(model_path, static_cast<mindspore::ModelType>(model_type), impl->context_);
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

OH_AI_Status OH_AI_ModelResize(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs,
                               OH_AI_ShapeInfo *shape_infos, size_t shape_info_num) {
  if (model == nullptr || shape_infos == nullptr) {
    MS_LOG(ERROR) << "model/shape_infos is nullptr.";
    return OH_AI_STATUS_LITE_NULLPTR;
  }
  std::vector<mindspore::MSTensor> vec_inputs;
  for (size_t i = 0; i < inputs.handle_num; ++i) {
    vec_inputs.push_back(*static_cast<mindspore::MSTensor *>(inputs.handle_list[i]));
  }

  std::vector<std::vector<int64_t>> vec_dims;
  for (size_t i = 0; i < shape_info_num; i++) {
    std::vector<int64_t> shape(shape_infos[i].shape, shape_infos[i].shape + shape_infos[i].shape_num);
    if (std::any_of(shape.begin(), shape.end(), [](int64_t val) { return val < 0 || val > INT32_MAX; })) {
      MS_LOG(ERROR) << "Invalid shape: " << shape << ", each dimension must be in [0, INT32_MAX]";
      return OH_AI_STATUS_LITE_PARAM_INVALID;
    }
    vec_dims.push_back(shape);
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  auto ret = impl->model_->Resize(vec_inputs, vec_dims);
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

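// Example of preparing an OH_AI_ShapeInfo for OH_AI_ModelResize above. Illustrative only, for a
// model with a single NHWC input; it assumes OH_AI_ShapeInfo carries a shape_num count plus a
// fixed-size int64_t shape array, exactly as read in the loop above.
//
//   OH_AI_TensorHandleArray model_inputs = OH_AI_ModelGetInputs(model);
//   int64_t dims[4] = {1, 224, 224, 3};
//   OH_AI_ShapeInfo shape_info;
//   shape_info.shape_num = 4;
//   for (size_t i = 0; i < shape_info.shape_num; i++) {
//     shape_info.shape[i] = dims[i];
//   }
//   OH_AI_Status status = OH_AI_ModelResize(model, model_inputs, &shape_info, 1);
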
OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs,
                                OH_AI_TensorHandleArray *outputs, const OH_AI_KernelCallBack before,
                                const OH_AI_KernelCallBack after) {
  if (model == nullptr || outputs == nullptr) {
    MS_LOG(ERROR) << "model/outputs is nullptr.";
    return OH_AI_STATUS_LITE_NULLPTR;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  size_t input_num = 0;
  (void)impl->GetInputs(&input_num);
  if (input_num != inputs.handle_num) {
    MS_LOG(ERROR) << "Wrong input size.";
    return OH_AI_STATUS_LITE_ERROR;
  }

  std::vector<mindspore::MSTensor> ms_tensor_inputs;
  for (size_t i = 0; i < inputs.handle_num; i++) {
    if (inputs.handle_list[i] != nullptr) {
      auto user_input = static_cast<mindspore::MSTensor *>(inputs.handle_list[i]);
      ms_tensor_inputs.push_back(*user_input);
    } else {
      MS_LOG(ERROR) << "input handle is nullptr.";
      return OH_AI_STATUS_LITE_NULLPTR;
    }
  }

  mindspore::MSKernelCallBack before_call_back = impl->TransCallBack(before);
  mindspore::MSKernelCallBack after_call_back = impl->TransCallBack(after);
  std::vector<mindspore::MSTensor> ms_tensor_outputs;

  size_t output_num = 0;
  (void)impl->GetOutputs(&output_num);
  auto handle_num = outputs->handle_num;
  if (handle_num == output_num) {
    MS_LOG(INFO) << "use user provided output";
    for (size_t i = 0; i < output_num; i++) {
      if (outputs->handle_list[i] == nullptr) {
        MS_LOG(ERROR) << "user provided output array handle_list[" << i << "] is nullptr";
        return OH_AI_STATUS_LITE_NULLPTR;
      }
      ms_tensor_outputs.push_back(*static_cast<mindspore::MSTensor *>(outputs->handle_list[i]));
    }
  }

  auto ret = impl->model_->Predict(ms_tensor_inputs, &ms_tensor_outputs, before_call_back, after_call_back);
  if (!ret.IsOk()) {
    MS_LOG(ERROR) << "Predict fail, ret :" << ret;
    return static_cast<OH_AI_Status>(ret.StatusCode());
  }

  if (handle_num == output_num) {
    return OH_AI_STATUS_SUCCESS;
  }

  outputs->handle_list = reinterpret_cast<OH_AI_TensorHandle *>(impl->GetOutputs(&(outputs->handle_num)));
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

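// Example end-to-end inference flow through this C API. A minimal sketch only: context creation
// uses OH_AI_ContextCreate/OH_AI_DeviceInfoCreate/OH_AI_ContextAddDeviceInfo from the C context
// API, the model path is a placeholder, and error handling is trimmed to the essentials.
//
//   OH_AI_ContextHandle ctx = OH_AI_ContextCreate();
//   OH_AI_DeviceInfoHandle cpu = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
//   OH_AI_ContextAddDeviceInfo(ctx, cpu);
//   OH_AI_ModelHandle model = OH_AI_ModelCreate();
//   if (OH_AI_ModelBuildFromFile(model, "model.ms", OH_AI_MODELTYPE_MINDIR, ctx) == OH_AI_STATUS_SUCCESS) {
//     OH_AI_TensorHandleArray in = OH_AI_ModelGetInputs(model);
//     // ... fill the input tensors through the tensor C API ...
//     OH_AI_TensorHandleArray out = {0, NULL};
//     OH_AI_Status status = OH_AI_ModelPredict(model, in, &out, NULL, NULL);
//   }
//   OH_AI_ModelDestroy(&model);
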
OH_AI_Status OH_AI_ModelRunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before,
                                const OH_AI_KernelCallBack after) {
  MS_LOG(ERROR) << "Unsupported Feature.";
  return OH_AI_STATUS_LITE_NOT_SUPPORT;
}

OH_AI_Status OH_AI_ModelExportWeight(const OH_AI_ModelHandle model, const char *export_path) {
  MS_LOG(ERROR) << "Unsupported Feature.";
  return OH_AI_STATUS_LITE_NOT_SUPPORT;
}

OH_AI_TensorHandleArray OH_AI_ModelGetInputs(const OH_AI_ModelHandle model) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return {0, nullptr};
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  size_t input_num = 0;
  auto handles = reinterpret_cast<OH_AI_TensorHandle *>(impl->GetInputs(&input_num));
  return {input_num, handles};
}

OH_AI_TensorHandleArray OH_AI_ModelGetOutputs(const OH_AI_ModelHandle model) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return {0, nullptr};
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  size_t output_num = 0;
  auto handles = reinterpret_cast<OH_AI_TensorHandle *>(impl->GetOutputs(&output_num));
  return {output_num, handles};
}

OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) {
  if (model == nullptr || tensor_name == nullptr) {
    MS_LOG(ERROR) << "model/tensor_name is nullptr.";
    return nullptr;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  size_t input_num = 0;
  auto inputs = impl->GetInputs(&input_num);
  for (size_t i = 0; i < input_num; i++) {
    if (inputs[i]->Name() == tensor_name) {
      return static_cast<OH_AI_TensorHandle>(inputs[i]);
    }
  }
  MS_LOG(ERROR) << "tensor does not exist.";
  return nullptr;
}

OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) {
  if (model == nullptr || tensor_name == nullptr) {
    MS_LOG(ERROR) << "model/tensor_name is nullptr.";
    return nullptr;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  size_t output_num = 0;
  auto outputs = impl->GetOutputs(&output_num);
  for (size_t i = 0; i < output_num; i++) {
    if (outputs[i]->Name() == tensor_name) {
      return static_cast<OH_AI_TensorHandle>(outputs[i]);
    }
  }
  MS_LOG(ERROR) << "tensor does not exist.";
  return nullptr;
}

OH_AI_TrainCfgHandle OH_AI_TrainCfgCreate() {
  auto impl = new (std::nothrow) mindspore::TrainCfg();
  if (impl == nullptr) {
    MS_LOG(ERROR) << "TrainCfg implement is nullptr.";
    return nullptr;
  }
  return static_cast<OH_AI_TrainCfgHandle>(impl);
}

void OH_AI_TrainCfgDestroy(OH_AI_TrainCfgHandle *train_cfg) {
  if (train_cfg != nullptr && *train_cfg != nullptr) {
    auto impl = static_cast<mindspore::TrainCfg *>(*train_cfg);
    delete impl;
    *train_cfg = nullptr;
  }
}

char **OH_AI_TrainCfgGetLossName(OH_AI_TrainCfgHandle train_cfg, size_t *num) {
  if (train_cfg == nullptr || num == nullptr) {
    MS_LOG(ERROR) << "train_cfg/num is nullptr.";
    return nullptr;
  }
  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
  auto loss_name = impl->GetLossName();
  *num = loss_name.size();
  char **name = static_cast<char **>(malloc(loss_name.size() * sizeof(char *)));
  if (name == nullptr) {
    MS_LOG(ERROR) << "Failed to malloc loss_name.";
    return nullptr;
  }
  for (size_t i = 0; i < loss_name.size(); i++) {
    name[i] = static_cast<char *>(malloc(loss_name[i].size() + 1));
    if (name[i] == nullptr) {
      MS_LOG(ERROR) << "Failed to malloc name.";
      // Release everything allocated so far, including the outer array.
      for (size_t j = 0; j < i; j++) {
        free(name[j]);
      }
      free(name);
      return nullptr;
    }
    memcpy(name[i], loss_name[i].c_str(), loss_name[i].size() + 1);
  }
  return name;
}

void OH_AI_TrainCfgSetLossName(OH_AI_TrainCfgHandle train_cfg, const char **loss_name, size_t num) {
  if (train_cfg == nullptr || loss_name == nullptr || *loss_name == nullptr) {
    MS_LOG(ERROR) << "train_cfg/loss_name is nullptr.";
    return;
  }
  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
  std::vector<std::string> vec_name;
  for (size_t i = 0; i < num; i++) {
    vec_name.push_back(loss_name[i]);
  }
  impl->SetLossName(vec_name);
}

OH_AI_OptimizationLevel OH_AI_TrainCfgGetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg) {
  if (train_cfg == nullptr) {
    MS_LOG(ERROR) << "train_cfg is nullptr, return OH_AI_KO0";
    return OH_AI_KO0;
  }
  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
  return static_cast<OH_AI_OptimizationLevel>(impl->optimization_level_);
}

void OH_AI_TrainCfgSetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg, OH_AI_OptimizationLevel level) {
  if (train_cfg == nullptr) {
    MS_LOG(ERROR) << "train_cfg is nullptr.";
    return;
  }
  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
  impl->optimization_level_ = static_cast<mindspore::OptimizationLevel>(level);
}

OH_AI_Status OH_AI_TrainModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size,
                                   OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context,
                                   const OH_AI_TrainCfgHandle train_cfg) {
  if (model == nullptr || model_data == nullptr || model_context == nullptr) {
    MS_LOG(ERROR) << "model/model_data/model_context is nullptr.";
    return OH_AI_STATUS_LITE_NULLPTR;
  }
  if (model_type == OH_AI_MODELTYPE_INVALID) {
    MS_LOG(ERROR) << "model_type is invalid.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);

  mindspore::Graph graph;
  auto status =
    mindspore::Serialization::Load(model_data, data_size, static_cast<mindspore::ModelType>(model_type), &graph);
  if (status != mindspore::kSuccess) {
    MS_LOG(ERROR) << "load ms file failed.";
    return OH_AI_STATUS_LITE_ERROR;
  }
  auto context = static_cast<mindspore::ContextC *>(model_context);
  auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg);
  if (impl->context_.get() != context->context_ && context->owned_by_model_) {
    MS_LOG(ERROR) << "context is owned by other model.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  if (impl->context_.get() != context->context_) {
    impl->context_.reset(context->context_);
    context->owned_by_model_ = true;
  }
  auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_,
                                 std::shared_ptr<mindspore::TrainCfg>(build_train_cfg));
  if (ret != mindspore::kSuccess) {
    MS_LOG(ERROR) << "Load and compile failed";
  }
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path, OH_AI_ModelType model_type,
                                           const OH_AI_ContextHandle model_context,
                                           const OH_AI_TrainCfgHandle train_cfg) {
  if (model == nullptr || model_path == nullptr || model_context == nullptr) {
    MS_LOG(ERROR) << "model/model_path/model_context is nullptr.";
    return OH_AI_STATUS_LITE_NULLPTR;
  }
  if (model_type == OH_AI_MODELTYPE_INVALID) {
    MS_LOG(ERROR) << "model_type is invalid.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);

  mindspore::Graph graph;
  auto status = mindspore::Serialization::Load(model_path, static_cast<mindspore::ModelType>(model_type), &graph);
  if (status != mindspore::kSuccess) {
    MS_LOG(ERROR) << "load ms file failed. " << model_path;
    return OH_AI_STATUS_LITE_ERROR;
  }
  auto context = static_cast<mindspore::ContextC *>(model_context);
  auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg);
  if (impl->context_.get() != context->context_ && context->owned_by_model_) {
    MS_LOG(ERROR) << "context is owned by other model.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  if (impl->context_.get() != context->context_) {
    impl->context_.reset(context->context_);
    context->owned_by_model_ = true;
  }
  auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_,
                                 std::shared_ptr<mindspore::TrainCfg>(build_train_cfg));
  if (ret != mindspore::kSuccess) {
    MS_LOG(ERROR) << "Load and compile failed";
  }
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

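// Example of the training path using the TrainCfg handles and build functions above. Illustrative
// only: the file path is a placeholder, `ctx` is a context handle prepared as in the inference
// sketch earlier, and TrainCfg lifetime after the build follows the caller's own contract.
//
//   OH_AI_TrainCfgHandle cfg = OH_AI_TrainCfgCreate();
//   OH_AI_TrainCfgSetOptimizationLevel(cfg, OH_AI_KO0);
//   OH_AI_ModelHandle model = OH_AI_ModelCreate();
//   OH_AI_Status status = OH_AI_TrainModelBuildFromFile(model, "train_net.ms", OH_AI_MODELTYPE_MINDIR, ctx, cfg);
//   if (status == OH_AI_STATUS_SUCCESS) {
//     OH_AI_ModelSetTrainMode(model, true);
//     status = OH_AI_RunStep(model, NULL, NULL);
//   }
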
OH_AI_Status OH_AI_ModelSetLearningRate(OH_AI_ModelHandle model, float learning_rate) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  auto ret = impl->model_->SetLearningRate(learning_rate);
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

float OH_AI_ModelGetLearningRate(OH_AI_ModelHandle model) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    // This function returns a learning rate, not a status code, so report 0.0f on error.
    return 0.0f;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  return impl->model_->GetLearningRate();
}

OH_AI_Status OH_AI_RunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before,
                           const OH_AI_KernelCallBack after) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  auto ret = impl->model_->RunStep(impl->TransCallBack(before), impl->TransCallBack(after));
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

OH_AI_TensorHandleArray OH_AI_ModelGetWeights(OH_AI_ModelHandle model) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return {0, nullptr};
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  auto features = impl->model_->GetFeatureMaps();
  size_t handle_num = features.size();

  mindspore::MSTensor **handle_list =
    static_cast<mindspore::MSTensor **>(malloc(handle_num * sizeof(mindspore::MSTensor *)));
  if (handle_list == nullptr) {
    MS_LOG(ERROR) << "Failed to malloc handle_list.";
    return {0, nullptr};
  }
  for (size_t i = 0; i < handle_num; i++) {
    handle_list[i] = new (std::nothrow) mindspore::MSTensor(features[i].impl());
    if (handle_list[i] == nullptr) {
      MS_LOG(ERROR) << "Failed to allocate feature map tensor.";
      for (size_t j = 0; j < i; j++) {
        delete handle_list[j];
      }
      free(handle_list);
      return {0, nullptr};
    }
  }
  return {handle_num, reinterpret_cast<OH_AI_TensorHandle *>(handle_list)};
}

OH_AI_Status OH_AI_ModelUpdateWeights(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray new_weights) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  std::vector<mindspore::MSTensor> weights;
  for (size_t i = 0; i < new_weights.handle_num; i++) {
    weights.push_back(*static_cast<mindspore::MSTensor *>(new_weights.handle_list[i]));
  }
  auto ret = impl->model_->UpdateWeights(weights);
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

bool OH_AI_ModelGetTrainMode(OH_AI_ModelHandle model) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return false;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  return impl->model_->GetTrainMode();
}

OH_AI_Status OH_AI_ModelSetTrainMode(OH_AI_ModelHandle model, bool train) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  auto ret = impl->model_->SetTrainMode(train);
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

OH_AI_Status OH_AI_ModelSetupVirtualBatch(OH_AI_ModelHandle model, int virtual_batch_multiplier, float lr,
                                          float momentum) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  auto ret = impl->model_->SetupVirtualBatch(virtual_batch_multiplier, lr, momentum);
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

OH_AI_Status OH_AI_ExportModel(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *model_file,
                               OH_AI_QuantizationType quantization_type, bool export_inference_only,
                               char **output_tensor_name, size_t num) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  std::vector<std::string> tensor_name;
  for (size_t i = 0; i < num; i++) {
    tensor_name.push_back(output_tensor_name[i]);
  }
  auto ret = mindspore::Serialization::ExportModel(
    *(impl->model_.get()), static_cast<mindspore::ModelType>(model_type), model_file,
    static_cast<mindspore::QuantizationType>(quantization_type), export_inference_only, tensor_name);
  if (!ret.IsOk()) {
    MS_LOG(ERROR) << "export model fail, ret :" << ret;
  }
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

OH_AI_Status OH_AI_ExportModelBuffer(OH_AI_ModelHandle model, OH_AI_ModelType model_type, char **model_data,
                                     size_t *data_size, OH_AI_QuantizationType quantization_type,
                                     bool export_inference_only, char **output_tensor_name, size_t num) {
  if (model == nullptr || model_data == nullptr || data_size == nullptr) {
    MS_LOG(ERROR) << "model/model_data/data_size is nullptr.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  std::vector<std::string> tensor_name;
  for (size_t i = 0; i < num; i++) {
    tensor_name.push_back(output_tensor_name[i]);
  }
  mindspore::Buffer buffer;
  auto ret = mindspore::Serialization::ExportModel(*(impl->model_.get()), static_cast<mindspore::ModelType>(model_type),
                                                   &buffer, static_cast<mindspore::QuantizationType>(quantization_type),
                                                   export_inference_only, tensor_name);
  if (!ret.IsOk()) {
    MS_LOG(ERROR) << "export model fail, ret :" << ret;
    return static_cast<OH_AI_Status>(ret.StatusCode());
  }
  *model_data = static_cast<char *>(malloc(buffer.DataSize()));
  if (*model_data == nullptr) {
    MS_LOG(ERROR) << "malloc model_data failed.";
    return OH_AI_STATUS_LITE_NULLPTR;
  }
  *data_size = buffer.DataSize();
  memcpy(*model_data, buffer.MutableData(), buffer.DataSize());
  return static_cast<OH_AI_Status>(ret.StatusCode());
}

OH_AI_Status OH_AI_ExportWeightsCollaborateWithMicro(OH_AI_ModelHandle model, OH_AI_ModelType model_type,
                                                     const char *weight_file, bool is_inference, bool enable_fp16,
                                                     char **changeable_weights_name, size_t num) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr.";
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  auto impl = static_cast<mindspore::ModelC *>(model);
  std::vector<std::string> weights_name;
  for (size_t i = 0; i < num; i++) {
    weights_name.push_back(changeable_weights_name[i]);
  }
  auto ret = mindspore::Serialization::ExportWeightsCollaborateWithMicro(
    *(impl->model_.get()), static_cast<mindspore::ModelType>(model_type), weight_file, is_inference, enable_fp16,
    weights_name);
  if (!ret.IsOk()) {
    MS_LOG(ERROR) << "export model fail, ret :" << ret;
  }
  return static_cast<OH_AI_Status>(ret.StatusCode());
}