1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "include/c_api/model_c.h"
17 #include "type_c_private.h"
18 #include "context_c.h"
19 #include <vector>
20 #include <cstdint>
21 #include "include/api/context.h"
22 #include "include/api/serialization.h"
23 #include "include/api/types.h"
24 #include "src/litert/cxx_api/tensor/tensor_impl.h"
25 #include "src/litert/cxx_api/model/model_impl.h"
26 #ifdef ENABLE_HI_APP_EVENT
27 #include "src/common/hi_app_event/hi_app_event.h"
28 #endif
29
30 namespace mindspore {
31 class ModelC {
32 public:
ModelC()33 ModelC() : model_(nullptr) {}
~ModelC()34 ~ModelC() {
35 for (auto in : inputs_) {
36 if (in != nullptr) {
37 delete in;
38 }
39 }
40 for (auto out : outputs_) {
41 if (out != nullptr) {
42 delete out;
43 }
44 }
45 for (auto out : outputs_train_) {
46 if (out != nullptr) {
47 delete out;
48 }
49 }
50
51 // In zero copy scene where user will call set or get allocator function, but when model is destroyed, the allocator
52 // table will not be freed, and its size continues to grow causing memory leak, so when ModelC is destroyed, clean
53 // allocator table.
54 CleanAllocatorTable();
55 }
56
57 MSTensor **GetInputs(size_t *input_num);
58 MSTensor **GetOutputs(size_t *output_num);
59 mindspore::MSKernelCallBack TransCallBack(const OH_AI_KernelCallBack &ms_callback);
60 std::shared_ptr<Model> model_;
61 std::shared_ptr<Context> context_;
62
63 private:
64 MSTensor **GetOutputsTensor(size_t *output_num, std::vector<MSTensor *> *vec_tensors);
65 std::vector<MSTensor *> inputs_;
66 std::vector<MSTensor *> outputs_;
67 std::vector<MSTensor *> outputs_train_;
68 };
69
GetInputs(size_t * input_num)70 MSTensor **ModelC::GetInputs(size_t *input_num) {
71 if (model_ == nullptr) {
72 MS_LOG(ERROR) << "model_ is nullptr.";
73 return nullptr;
74 }
75 if (!inputs_.empty()) {
76 *input_num = inputs_.size();
77 return inputs_.data();
78 }
79 auto inputs = model_->GetInputs();
80 *input_num = inputs.size();
81 inputs_.resize(inputs.size(), nullptr);
82 for (size_t i = 0; i < inputs.size(); i++) {
83 inputs_[i] = new (std::nothrow) MSTensor(inputs[i].impl());
84 if (inputs_[i] == nullptr) {
85 inputs_.clear();
86 return nullptr;
87 }
88 }
89 return inputs_.data();
90 }
91
GetOutputs(size_t * output_num)92 MSTensor **ModelC::GetOutputs(size_t *output_num) {
93 if (model_->GetTrainMode() == true) {
94 return GetOutputsTensor(output_num, &outputs_train_);
95 } else {
96 return GetOutputsTensor(output_num, &outputs_);
97 }
98 }
99
GetOutputsTensor(size_t * output_num,std::vector<MSTensor * > * vec_tensors)100 MSTensor **ModelC::GetOutputsTensor(size_t *output_num, std::vector<MSTensor *> *vec_tensors) {
101 if (model_ == nullptr) {
102 MS_LOG(ERROR) << "model_ is nullptr.";
103 return nullptr;
104 }
105 if (!vec_tensors->empty()) {
106 *output_num = vec_tensors->size();
107 return vec_tensors->data();
108 }
109
110 auto outputs = model_->GetOutputs();
111 *output_num = outputs.size();
112 vec_tensors->resize(outputs.size(), nullptr);
113 for (size_t i = 0; i < outputs.size(); i++) {
114 (*vec_tensors)[i] = new (std::nothrow) MSTensor(outputs[i].impl());
115 if ((*vec_tensors)[i] == nullptr) {
116 vec_tensors->clear();
117 return nullptr;
118 }
119 }
120 return vec_tensors->data();
121 }
122
TransCallBack(const OH_AI_KernelCallBack & ms_callback)123 mindspore::MSKernelCallBack ModelC::TransCallBack(const OH_AI_KernelCallBack &ms_callback) {
124 mindspore::MSKernelCallBack call_back = nullptr;
125 if (ms_callback != nullptr) {
126 call_back = [&](const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
127 const mindspore::MSCallBackParam &opInfo) {
128 std::vector<OH_AI_TensorHandle> vec_inputs;
129 std::vector<OH_AI_TensorHandle> vec_outputs;
130 OH_AI_CallBackParam call_back = {const_cast<char *>(opInfo.node_name.c_str()),
131 const_cast<char *>(opInfo.node_type.c_str())};
132 size_t inputs_handle_num = inputs.size();
133 for (size_t i = 0; i < inputs_handle_num; i++) {
134 vec_inputs.push_back(static_cast<OH_AI_TensorHandle>(&(static_cast<std::vector<mindspore::MSTensor>>(inputs)[i])));
135 }
136 size_t outputs_handle_num = outputs.size();
137 for (size_t i = 0; i < outputs_handle_num; i++) {
138 vec_outputs.push_back(
139 static_cast<OH_AI_TensorHandle>(&(static_cast<std::vector<mindspore::MSTensor>>(outputs)[i])));
140 }
141 OH_AI_TensorHandleArray handle_inputs = {inputs_handle_num, vec_inputs.data()};
142 OH_AI_TensorHandleArray handle_outputs = {outputs_handle_num, vec_outputs.data()};
143 return ms_callback(handle_inputs, handle_outputs, call_back);
144 };
145 }
146 return call_back;
147 }
148 } // namespace mindspore
149
OH_AI_ModelCreate()150 OH_AI_ModelHandle OH_AI_ModelCreate() {
151 auto impl = new (std::nothrow) mindspore::ModelC();
152 if (impl == nullptr) {
153 MS_LOG(ERROR) << "Model implement is nullptr.";
154 return nullptr;
155 }
156 impl->model_ = std::make_shared<mindspore::Model>();
157 if (impl->model_ == nullptr) {
158 MS_LOG(ERROR) << "model_ is nullptr.";
159 delete impl;
160 return nullptr;
161 }
162 return static_cast<OH_AI_ModelHandle>(impl);
163 }
164
OH_AI_ModelDestroy(OH_AI_ModelHandle * model)165 void OH_AI_ModelDestroy(OH_AI_ModelHandle *model) {
166 if (model == nullptr || *model == nullptr) {
167 MS_LOG(ERROR) << "model is nullptr.";
168 return;
169 }
170 auto impl = static_cast<mindspore::ModelC *>(*model);
171 delete impl;
172 *model = nullptr;
173 }
174
OH_AI_ModelSetWorkspace(OH_AI_ModelHandle model,void * workspace,size_t workspace_size)175 void OH_AI_ModelSetWorkspace(OH_AI_ModelHandle model, void *workspace, size_t workspace_size) {
176 MS_LOG(ERROR) << "Unsupported Feature.";
177 return;
178 }
179
OH_AI_ModelCalcWorkspaceSize(OH_AI_ModelHandle model)180 size_t OH_AI_ModelCalcWorkspaceSize(OH_AI_ModelHandle model) {
181 MS_LOG(ERROR) << "Unsupported Feature.";
182 return 0;
183 }
184
OH_AI_ModelBuild(OH_AI_ModelHandle model,const void * model_data,size_t data_size,OH_AI_ModelType model_type,const OH_AI_ContextHandle model_context)185 OH_AI_Status OH_AI_ModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size, OH_AI_ModelType model_type,
186 const OH_AI_ContextHandle model_context) {
187 if (model == nullptr || model_data == nullptr || model_context == nullptr) {
188 MS_LOG(ERROR) << "model/model_data/model_context is nullptr.";
189 return OH_AI_STATUS_LITE_NULLPTR;
190 }
191 if (model_type == OH_AI_MODELTYPE_INVALID) {
192 MS_LOG(ERROR) << "model_type is invalid.";
193 return OH_AI_STATUS_LITE_PARAM_INVALID;
194 }
195 mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
196 auto impl = static_cast<mindspore::ModelC *>(model);
197 if (impl->context_.get() != context->context_ && context->owned_by_model_) {
198 MS_LOG(ERROR) << "context is owned by other model.";
199 return OH_AI_STATUS_LITE_PARAM_INVALID;
200 }
201 if (impl->context_.get() != context->context_) {
202 impl->context_.reset(context->context_);
203 context->owned_by_model_ = true;
204 }
205 auto ret = impl->model_->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), impl->context_);
206 return static_cast<OH_AI_Status>(ret.StatusCode());
207 }
208
OH_AI_ModelBuildFromFile(OH_AI_ModelHandle model,const char * model_path,OH_AI_ModelType model_type,const OH_AI_ContextHandle model_context)209 OH_AI_Status OH_AI_ModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path, OH_AI_ModelType model_type,
210 const OH_AI_ContextHandle model_context) {
211 if (model == nullptr || model_path == nullptr || model_context == nullptr) {
212 MS_LOG(ERROR) << "model/model_path/model_context is nullptr.";
213 return OH_AI_STATUS_LITE_NULLPTR;
214 }
215 if (model_type == OH_AI_MODELTYPE_INVALID) {
216 MS_LOG(ERROR) << "model_type is invalid.";
217 return OH_AI_STATUS_LITE_PARAM_INVALID;
218 }
219 mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
220 auto impl = static_cast<mindspore::ModelC *>(model);
221 if (impl->context_.get() != context->context_ && context->owned_by_model_) {
222 MS_LOG(ERROR) << "context is owned by other model.";
223 return OH_AI_STATUS_LITE_PARAM_INVALID;
224 }
225 if (impl->context_.get() != context->context_) {
226 impl->context_.reset(context->context_);
227 context->owned_by_model_ = true;
228 }
229 auto ret = impl->model_->Build(model_path, static_cast<mindspore::ModelType>(model_type), impl->context_);
230 return static_cast<OH_AI_Status>(ret.StatusCode());
231 }
232
OH_AI_ModelResize(OH_AI_ModelHandle model,const OH_AI_TensorHandleArray inputs,OH_AI_ShapeInfo * shape_infos,size_t shape_info_num)233 OH_AI_Status OH_AI_ModelResize(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs, OH_AI_ShapeInfo *shape_infos,
234 size_t shape_info_num) {
235 if (model == nullptr || shape_infos == nullptr) {
236 MS_LOG(ERROR) << "model/shape_infos is nullptr.";
237 return OH_AI_STATUS_LITE_NULLPTR;
238 }
239 std::vector<mindspore::MSTensor> vec_inputs;
240 for (size_t i = 0; i < inputs.handle_num; ++i) {
241 vec_inputs.push_back(*static_cast<mindspore::MSTensor *>(inputs.handle_list[i]));
242 }
243
244 std::vector<std::vector<int64_t>> vec_dims;
245 for (size_t i = 0; i < shape_info_num; i++) {
246 std::vector<int64_t> shape(shape_infos[i].shape, shape_infos[i].shape + shape_infos[i].shape_num);
247 if (std::any_of(shape.begin(), shape.end(), [](int64_t val) { return val < 0 || val > INT32_MAX; })) {
248 MS_LOG(ERROR) << "Invalid shape: " << shape << ", each dimension must be in [0, INT32_MAX]";
249 return OH_AI_STATUS_LITE_PARAM_INVALID;
250 }
251 vec_dims.push_back(shape);
252 }
253 auto impl = static_cast<mindspore::ModelC *>(model);
254 auto ret = impl->model_->Resize(vec_inputs, vec_dims);
255 return static_cast<OH_AI_Status>(ret.StatusCode());
256 }
257
OH_AI_ModelPredict(OH_AI_ModelHandle model,const OH_AI_TensorHandleArray inputs,OH_AI_TensorHandleArray * outputs,const OH_AI_KernelCallBack before,const OH_AI_KernelCallBack after)258 OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs, OH_AI_TensorHandleArray *outputs,
259 const OH_AI_KernelCallBack before, const OH_AI_KernelCallBack after) {
260 if (model == nullptr) {
261 MS_LOG(ERROR) << "model is nullptr.";
262 return OH_AI_STATUS_LITE_NULLPTR;
263 }
264 auto impl = static_cast<mindspore::ModelC *>(model);
265 size_t input_num;
266 (void)impl->GetInputs(&input_num);
267 if (input_num != inputs.handle_num) {
268 MS_LOG(ERROR) << "Wrong input size.";
269 return OH_AI_STATUS_LITE_ERROR;
270 }
271
272 std::vector<mindspore::MSTensor> ms_tensor_inputs;
273 for (size_t i = 0; i < inputs.handle_num; i++) {
274 if (inputs.handle_list[i] != nullptr) {
275 auto user_input = static_cast<mindspore::MSTensor *>(inputs.handle_list[i]);
276 ms_tensor_inputs.push_back(*user_input);
277 } else {
278 MS_LOG(ERROR) << "input handle is nullptr.";
279 return OH_AI_STATUS_LITE_NULLPTR;
280 }
281 }
282
283 mindspore::MSKernelCallBack before_call_back = impl->TransCallBack(before);
284 mindspore::MSKernelCallBack after_call_back = impl->TransCallBack(after);
285 std::vector<mindspore::MSTensor> ms_tensor_outputs;
286
287 size_t output_num;
288 (void)impl->GetOutputs(&output_num);
289 auto handle_num = outputs->handle_num;
290 if (handle_num == output_num) {
291 MS_LOG(INFO) << "use user provided output";
292 for (size_t i = 0; i < output_num; i++) {
293 if (outputs->handle_list[i] == nullptr) {
294 MS_LOG(ERROR) << "user provided output array handle_list[" << i << "] is nullptr";
295 return OH_AI_STATUS_LITE_NULLPTR;
296 }
297 ms_tensor_outputs.push_back(*static_cast<mindspore::MSTensor *>(outputs->handle_list[i]));
298 }
299 }
300
301 auto ret = impl->model_->Predict(ms_tensor_inputs, &ms_tensor_outputs, before_call_back, after_call_back);
302 if (!ret.IsOk()) {
303 MS_LOG(ERROR) << "Predict fail, ret :" << ret;
304 return static_cast<OH_AI_Status>(ret.StatusCode());
305 }
306
307 if (handle_num == output_num) {
308 return OH_AI_STATUS_SUCCESS;
309 }
310
311 outputs->handle_list = reinterpret_cast<OH_AI_TensorHandle *>(impl->GetOutputs(&(outputs->handle_num)));
312 return static_cast<OH_AI_Status>(ret.StatusCode());
313 }
314
OH_AI_ModelRunStep(OH_AI_ModelHandle model,const OH_AI_KernelCallBack before,const OH_AI_KernelCallBack after)315 OH_AI_Status OH_AI_ModelRunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before, const OH_AI_KernelCallBack after) {
316 MS_LOG(ERROR) << "Unsupported Feature.";
317 return OH_AI_STATUS_LITE_NOT_SUPPORT;
318 }
319
OH_AI_ModelExportWeight(const OH_AI_ModelHandle model,const char * export_path)320 OH_AI_Status OH_AI_ModelExportWeight(const OH_AI_ModelHandle model, const char *export_path) {
321 MS_LOG(ERROR) << "Unsupported Feature.";
322 return OH_AI_STATUS_LITE_NOT_SUPPORT;
323 }
324
OH_AI_ModelGetInputs(const OH_AI_ModelHandle model)325 OH_AI_TensorHandleArray OH_AI_ModelGetInputs(const OH_AI_ModelHandle model) {
326 if (model == nullptr) {
327 MS_LOG(ERROR) << "model is nullptr.";
328 return {0, nullptr};
329 }
330 auto impl = static_cast<mindspore::ModelC *>(model);
331 size_t input_num = 0;
332 auto handles = reinterpret_cast<OH_AI_TensorHandle *>(impl->GetInputs(&input_num));
333 return {input_num, handles};
334 }
335
OH_AI_ModelGetOutputs(const OH_AI_ModelHandle model)336 OH_AI_TensorHandleArray OH_AI_ModelGetOutputs(const OH_AI_ModelHandle model) {
337 if (model == nullptr) {
338 MS_LOG(ERROR) << "model is nullptr.";
339 return {0, nullptr};
340 }
341 auto impl = static_cast<mindspore::ModelC *>(model);
342 size_t output_num;
343 auto handles = reinterpret_cast<OH_AI_TensorHandle *>(impl->GetOutputs(&output_num));
344 return {output_num, handles};
345 }
346
OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHandle model,const char * tensor_name)347 OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) {
348 if (model == nullptr || tensor_name == nullptr) {
349 MS_LOG(ERROR) << "model/tensor_name is nullptr.";
350 return nullptr;
351 }
352 auto impl = static_cast<mindspore::ModelC *>(model);
353 size_t input_num;
354 auto inputs = impl->GetInputs(&input_num);
355 for (size_t i = 0; i < input_num; i++) {
356 if (inputs[i]->Name() == tensor_name) {
357 return static_cast<OH_AI_TensorHandle>(inputs[i]);
358 }
359 }
360 MS_LOG(ERROR) << "tensor is not exist.";
361 return nullptr;
362 }
363
OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle model,const char * tensor_name)364 OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) {
365 if (model == nullptr || tensor_name == nullptr) {
366 MS_LOG(ERROR) << "model/tensor_name is nullptr.";
367 return nullptr;
368 }
369 auto impl = static_cast<mindspore::ModelC *>(model);
370 size_t output_num;
371 auto outputs = impl->GetOutputs(&output_num);
372 for (size_t i = 0; i < output_num; i++) {
373 if (outputs[i]->Name() == tensor_name) {
374 return static_cast<OH_AI_TensorHandle>(outputs[i]);
375 }
376 }
377 MS_LOG(ERROR) << "tensor is not exist.";
378 return nullptr;
379 }
380
OH_AI_TrainCfgCreate()381 OH_AI_TrainCfgHandle OH_AI_TrainCfgCreate() {
382 auto impl = new (std::nothrow) mindspore::TrainCfg();
383 if (impl == nullptr) {
384 MS_LOG(ERROR) << "TrainCfg implement is nullptr.";
385 return nullptr;
386 }
387 return static_cast<OH_AI_TrainCfgHandle>(impl);
388 }
389
OH_AI_TrainCfgDestroy(OH_AI_TrainCfgHandle * train_cfg)390 void OH_AI_TrainCfgDestroy(OH_AI_TrainCfgHandle *train_cfg) {
391 if (train_cfg != nullptr && *train_cfg != nullptr) {
392 auto impl = static_cast<mindspore::TrainCfg *>(*train_cfg);
393 delete impl;
394 *train_cfg = nullptr;
395 }
396 }
397
OH_AI_TrainCfgGetLossName(OH_AI_TrainCfgHandle train_cfg,size_t * num)398 char **OH_AI_TrainCfgGetLossName(OH_AI_TrainCfgHandle train_cfg, size_t *num) {
399 if (train_cfg == nullptr || num == nullptr) {
400 MS_LOG(ERROR) << "train_cfg/num is nullptr.";
401 return nullptr;
402 }
403 auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
404 auto loss_name = impl->GetLossName();
405 *num = loss_name.size();
406 char **name = static_cast<char **>(malloc(loss_name.size() * sizeof(char *)));
407 if (name == nullptr) {
408 MS_LOG(ERROR) << "Failed to malloc loss_name.";
409 return nullptr;
410 }
411 for (size_t i = 0; i < loss_name.size(); i++) {
412 name[i] = static_cast<char *>(malloc(loss_name[i].size() + 1));
413 if (name[i] == nullptr) {
414 for(size_t j = 0; j < i; j++){
415 free(name[j]);
416 }
417 MS_LOG(ERROR) << "Failed to malloc name.";
418 return nullptr;
419 }
420 memcpy(name[i], loss_name[i].c_str(), loss_name[i].size() + 1);
421 }
422 return name;
423 }
424
OH_AI_TrainCfgSetLossName(OH_AI_TrainCfgHandle train_cfg,const char ** loss_name,size_t num)425 void OH_AI_TrainCfgSetLossName(OH_AI_TrainCfgHandle train_cfg, const char **loss_name, size_t num) {
426 if (train_cfg == nullptr || loss_name == nullptr || *loss_name == nullptr) {
427 MS_LOG(ERROR) << "train_cfg/loss_name is nullptr.";
428 return;
429 }
430 auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
431 std::vector<std::string> vec_name;
432 for (size_t i = 0; i < num; i++) {
433 vec_name.push_back(loss_name[i]);
434 }
435 impl->SetLossName(vec_name);
436 }
437
OH_AI_TrainCfgGetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg)438 OH_AI_OptimizationLevel OH_AI_TrainCfgGetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg) {
439 if (train_cfg == nullptr) {
440 MS_LOG(ERROR) << "train_cfg is nullptr, return OH_AI_KO0";
441 return OH_AI_KO0;
442 }
443 auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
444 return static_cast<OH_AI_OptimizationLevel>(impl->optimization_level_);
445 }
446
OH_AI_TrainCfgSetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg,OH_AI_OptimizationLevel level)447 void OH_AI_TrainCfgSetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg, OH_AI_OptimizationLevel level) {
448 if (train_cfg == nullptr) {
449 MS_LOG(ERROR) << "train_cfg is nullptr.";
450 return;
451 }
452 auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
453 impl->optimization_level_ = static_cast<mindspore::OptimizationLevel>(level);
454 }
455
OH_AI_TrainModelBuild(OH_AI_ModelHandle model,const void * model_data,size_t data_size,OH_AI_ModelType model_type,const OH_AI_ContextHandle model_context,const OH_AI_TrainCfgHandle train_cfg)456 OH_AI_Status OH_AI_TrainModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size, OH_AI_ModelType model_type,
457 const OH_AI_ContextHandle model_context, const OH_AI_TrainCfgHandle train_cfg) {
458 if (model == nullptr || model_data == nullptr || model_context == nullptr) {
459 MS_LOG(ERROR) << "model/model_data/model_context is nullptr.";
460 return OH_AI_STATUS_LITE_NULLPTR;
461 }
462 if (model_type == OH_AI_MODELTYPE_INVALID) {
463 MS_LOG(ERROR) << "model_type is invalid.";
464 return OH_AI_STATUS_LITE_PARAM_INVALID;
465 }
466 auto impl = static_cast<mindspore::ModelC *>(model);
467
468 mindspore::Graph graph;
469 auto status =
470 mindspore::Serialization::Load(model_data, data_size, static_cast<mindspore::ModelType>(model_type), &graph);
471 if (status != mindspore::kSuccess) {
472 MS_LOG(ERROR) << "load ms file failed.";
473 return OH_AI_STATUS_LITE_ERROR;
474 }
475 auto context = static_cast<mindspore::ContextC *>(model_context);
476 auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg);
477 if (impl->context_.get() != context->context_ && context->owned_by_model_) {
478 MS_LOG(ERROR) << "context is owned by other model.";
479 return OH_AI_STATUS_LITE_PARAM_INVALID;
480 }
481 if (impl->context_.get() != context->context_) {
482 impl->context_.reset(context->context_);
483 context->owned_by_model_ = true;
484 }
485 auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_,
486 std::shared_ptr<mindspore::TrainCfg>(build_train_cfg));
487 if (ret != mindspore::kSuccess) {
488 MS_LOG(ERROR) << "Load and compile failed";
489 }
490 return static_cast<OH_AI_Status>(ret.StatusCode());
491 }
492
OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model,const char * model_path,OH_AI_ModelType model_type,const OH_AI_ContextHandle model_context,const OH_AI_TrainCfgHandle train_cfg)493 OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path, OH_AI_ModelType model_type,
494 const OH_AI_ContextHandle model_context, const OH_AI_TrainCfgHandle train_cfg) {
495 if (model == nullptr || model_path == nullptr || model_context == nullptr) {
496 MS_LOG(ERROR) << "model/model_path/model_context is nullptr.";
497 return OH_AI_STATUS_LITE_NULLPTR;
498 }
499 if (model_type == OH_AI_MODELTYPE_INVALID) {
500 MS_LOG(ERROR) << "model_type is invalid.";
501 return OH_AI_STATUS_LITE_PARAM_INVALID;
502 }
503 auto impl = static_cast<mindspore::ModelC *>(model);
504
505 mindspore::Graph graph;
506 auto status = mindspore::Serialization::Load(model_path, static_cast<mindspore::ModelType>(model_type), &graph);
507 if (status != mindspore::kSuccess) {
508 MS_LOG(ERROR) << "load ms file failed. " << model_path;
509 return OH_AI_STATUS_LITE_ERROR;
510 }
511 auto context = static_cast<mindspore::ContextC *>(model_context);
512 auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg);
513 if (impl->context_.get() != context->context_ && context->owned_by_model_) {
514 MS_LOG(ERROR) << "context is owned by other model.";
515 return OH_AI_STATUS_LITE_PARAM_INVALID;
516 }
517 if (impl->context_.get() != context->context_) {
518 impl->context_.reset(context->context_);
519 context->owned_by_model_ = true;
520 }
521 auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_,
522 std::shared_ptr<mindspore::TrainCfg>(build_train_cfg));
523 if (ret != mindspore::kSuccess) {
524 MS_LOG(ERROR) << "Load and compile failed";
525 }
526 return static_cast<OH_AI_Status>(ret.StatusCode());
527 }
528
OH_AI_ModelSetLearningRate(OH_AI_ModelHandle model,float learning_rate)529 OH_AI_Status OH_AI_ModelSetLearningRate(OH_AI_ModelHandle model, float learning_rate) {
530 if (model == nullptr) {
531 MS_LOG(ERROR) << "model is nullptr.";
532 return OH_AI_STATUS_LITE_PARAM_INVALID;
533 }
534 auto impl = static_cast<mindspore::ModelC *>(model);
535 auto ret = impl->model_->SetLearningRate(learning_rate);
536 return static_cast<OH_AI_Status>(ret.StatusCode());
537 }
538
OH_AI_ModelGetLearningRate(OH_AI_ModelHandle model)539 float OH_AI_ModelGetLearningRate(OH_AI_ModelHandle model) {
540 if (model == nullptr) {
541 MS_LOG(ERROR) << "model is nullptr.";
542 return OH_AI_STATUS_LITE_PARAM_INVALID;
543 }
544 auto impl = static_cast<mindspore::ModelC *>(model);
545 return impl->model_->GetLearningRate();
546 }
547
OH_AI_RunStep(OH_AI_ModelHandle model,const OH_AI_KernelCallBack before,const OH_AI_KernelCallBack after)548 OH_AI_Status OH_AI_RunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before, const OH_AI_KernelCallBack after) {
549 if (model == nullptr) {
550 MS_LOG(ERROR) << "model is nullptr.";
551 return OH_AI_STATUS_LITE_PARAM_INVALID;
552 }
553 auto impl = static_cast<mindspore::ModelC *>(model);
554 auto ret = impl->model_->RunStep(impl->TransCallBack(before), impl->TransCallBack(after));
555 return static_cast<OH_AI_Status>(ret.StatusCode());
556 }
557
OH_AI_ModelGetWeights(OH_AI_ModelHandle model)558 OH_AI_TensorHandleArray OH_AI_ModelGetWeights(OH_AI_ModelHandle model) {
559 if (model == nullptr) {
560 MS_LOG(ERROR) << "model is nullptr.";
561 return {0, nullptr};
562 }
563 auto impl = static_cast<mindspore::ModelC *>(model);
564 auto features = impl->model_->GetFeatureMaps();
565 size_t handle_num = features.size();
566
567 mindspore::MSTensor **handle_list =
568 static_cast<mindspore::MSTensor **>(malloc(handle_num * sizeof(mindspore::MSTensor *)));
569 if (handle_list == nullptr) {
570 MS_LOG(ERROR) << "Failed to malloc handle_list.";
571 return {0, nullptr};
572 }
573 for (size_t i = 0; i < handle_num; i++) {
574 handle_list[i] = new (std::nothrow) mindspore::MSTensor(features[i].impl());
575 }
576 return {handle_num, reinterpret_cast<OH_AI_TensorHandle *>(handle_list)};
577 }
578
OH_AI_ModelUpdateWeights(OH_AI_ModelHandle model,const OH_AI_TensorHandleArray new_weights)579 OH_AI_Status OH_AI_ModelUpdateWeights(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray new_weights) {
580 if (model == nullptr) {
581 MS_LOG(ERROR) << "model is nullptr.";
582 return OH_AI_STATUS_LITE_PARAM_INVALID;
583 }
584 auto impl = static_cast<mindspore::ModelC *>(model);
585 std::vector<mindspore::MSTensor> weights;
586 for (size_t i = 0; i < new_weights.handle_num; i++) {
587 weights.push_back(*static_cast<mindspore::MSTensor *>(new_weights.handle_list[i]));
588 }
589 auto ret = impl->model_->UpdateWeights(weights);
590 return static_cast<OH_AI_Status>(ret.StatusCode());
591 }
592
OH_AI_ModelGetTrainMode(OH_AI_ModelHandle model)593 bool OH_AI_ModelGetTrainMode(OH_AI_ModelHandle model) {
594 if (model == nullptr) {
595 MS_LOG(ERROR) << "model is nullptr.";
596 return false;
597 }
598 auto impl = static_cast<mindspore::ModelC *>(model);
599 return impl->model_->GetTrainMode();
600 }
601
OH_AI_ModelSetTrainMode(OH_AI_ModelHandle model,bool train)602 OH_AI_Status OH_AI_ModelSetTrainMode(OH_AI_ModelHandle model, bool train) {
603 if (model == nullptr) {
604 MS_LOG(ERROR) << "model is nullptr.";
605 return OH_AI_STATUS_LITE_PARAM_INVALID;
606 }
607 auto impl = static_cast<mindspore::ModelC *>(model);
608 auto ret = impl->model_->SetTrainMode(train);
609 return static_cast<OH_AI_Status>(ret.StatusCode());
610 }
611
OH_AI_ModelSetupVirtualBatch(OH_AI_ModelHandle model,int virtual_batch_multiplier,float lr,float momentum)612 OH_AI_Status OH_AI_ModelSetupVirtualBatch(OH_AI_ModelHandle model, int virtual_batch_multiplier, float lr, float momentum) {
613 if (model == nullptr) {
614 MS_LOG(ERROR) << "model is nullptr.";
615 return OH_AI_STATUS_LITE_PARAM_INVALID;
616 }
617 auto impl = static_cast<mindspore::ModelC *>(model);
618 auto ret = impl->model_->SetupVirtualBatch(virtual_batch_multiplier, lr, momentum);
619 return static_cast<OH_AI_Status>(ret.StatusCode());
620 }
621
OH_AI_ExportModel(OH_AI_ModelHandle model,OH_AI_ModelType model_type,const char * model_file,OH_AI_QuantizationType quantization_type,bool export_inference_only,char ** output_tensor_name,size_t num)622 OH_AI_Status OH_AI_ExportModel(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *model_file,
623 OH_AI_QuantizationType quantization_type, bool export_inference_only, char **output_tensor_name,
624 size_t num) {
625 if (model == nullptr) {
626 MS_LOG(ERROR) << "model is nullptr.";
627 return OH_AI_STATUS_LITE_PARAM_INVALID;
628 }
629 auto impl = static_cast<mindspore::ModelC *>(model);
630 std::vector<std::string> tensor_name;
631 for (size_t i = 0; i < num; i++) {
632 tensor_name.push_back(output_tensor_name[i]);
633 }
634 auto ret = mindspore::Serialization::ExportModel(
635 *(impl->model_.get()), static_cast<mindspore::ModelType>(model_type), model_file,
636 static_cast<mindspore::QuantizationType>(quantization_type), export_inference_only, tensor_name);
637 if (!ret.IsOk()) {
638 MS_LOG(ERROR) << "export model fail, ret :" << ret;
639 }
640 return static_cast<OH_AI_Status>(ret.StatusCode());
641 }
642
OH_AI_ExportModelBuffer(OH_AI_ModelHandle model,OH_AI_ModelType model_type,char ** model_data,size_t * data_size,OH_AI_QuantizationType quantization_type,bool export_inference_only,char ** output_tensor_name,size_t num)643 OH_AI_Status OH_AI_ExportModelBuffer(OH_AI_ModelHandle model, OH_AI_ModelType model_type, char **model_data, size_t *data_size,
644 OH_AI_QuantizationType quantization_type, bool export_inference_only,
645 char **output_tensor_name, size_t num) {
646 if (model == nullptr) {
647 MS_LOG(ERROR) << "model is nullptr.";
648 return OH_AI_STATUS_LITE_PARAM_INVALID;
649 }
650 auto impl = static_cast<mindspore::ModelC *>(model);
651 std::vector<std::string> tensor_name;
652 for (size_t i = 0; i < num; i++) {
653 tensor_name.push_back(output_tensor_name[i]);
654 }
655 mindspore::Buffer buffer;
656 auto ret = mindspore::Serialization::ExportModel(*(impl->model_.get()), static_cast<mindspore::ModelType>(model_type),
657 &buffer, static_cast<mindspore::QuantizationType>(quantization_type),
658 export_inference_only, tensor_name);
659 auto data = reinterpret_cast<char *>(buffer.MutableData());
660 *model_data = reinterpret_cast<char *>(malloc(buffer.DataSize()));
661 if (*model_data == nullptr) {
662 MS_LOG(ERROR) << "malloc model_data failed.";
663 return OH_AI_STATUS_LITE_NULLPTR;
664 }
665 *data_size = buffer.DataSize();
666 memcpy(*model_data, data, buffer.DataSize());
667 if (!ret.IsOk()) {
668 MS_LOG(ERROR) << "export model fail, ret :" << ret;
669 }
670 return static_cast<OH_AI_Status>(ret.StatusCode());
671 }
672
OH_AI_ExportWeightsCollaborateWithMicro(OH_AI_ModelHandle model,OH_AI_ModelType model_type,const char * weight_file,bool is_inference,bool enable_fp16,char ** changeable_weights_name,size_t num)673 OH_AI_Status OH_AI_ExportWeightsCollaborateWithMicro(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *weight_file,
674 bool is_inference, bool enable_fp16, char **changeable_weights_name,
675 size_t num) {
676 if (model == nullptr) {
677 MS_LOG(ERROR) << "model is nullptr.";
678 return OH_AI_STATUS_LITE_PARAM_INVALID;
679 }
680 auto impl = static_cast<mindspore::ModelC *>(model);
681 std::vector<std::string> weights_name;
682 for (size_t i = 0; i < num; i++) {
683 weights_name.push_back(changeable_weights_name[i]);
684 }
685 auto ret = mindspore::Serialization::ExportWeightsCollaborateWithMicro(
686 *(impl->model_.get()), static_cast<mindspore::ModelType>(model_type), weight_file, is_inference, enable_fp16,
687 weights_name);
688 if (!ret.IsOk()) {
689 MS_LOG(ERROR) << "export model fail, ret :" << ret;
690 }
691 return static_cast<OH_AI_Status>(ret.StatusCode());
692 }
693