1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "include/js_api/mslite_model_napi.h"
17 #include <climits>
18 #include <algorithm>
19 #include <random>
20 #include <cstring>
21 #include <memory>
22 #include <map>
23 #include <vector>
24 #include <unistd.h>
25 #include <fcntl.h>
26 #include <sys/mman.h>
27 #include <sys/stat.h>
28 #include "include/js_api/mstensor_napi.h"
29 #include "include/js_api/common_napi.h"
30 #include "include/js_api/ms_parameters_napi.h"
31 #include "include/js_api/ms_errors.h"
32 #include "include/js_api/mslite_model_callback_napi.h"
33 #include "src/common/log.h"
34 #include "mindspore/lite/src/common/log.h"
35 #include "include/c_api/model_c.h"
36 #include "include/c_api/context_c.h"
37 #include "include/c_api/types_c.h"
38 #include "include/js_api/nnrt_device_desc_napi.h"
39
40 namespace mindspore {
41 thread_local napi_ref MSLiteModelNapi::constructor_ = nullptr;
42 ModelInfo *MSLiteModelNapi::model_info_ = nullptr;
43 ContextInfo *MSLiteModelNapi::context_ = nullptr;
44 std::mutex MSLiteModelNapi::create_mutex_;
45 napi_ref MSLiteModelNapi::tensorFormat_ = nullptr;
46 napi_ref MSLiteModelNapi::tensorDataType_ = nullptr;
47 napi_ref MSLiteModelNapi::contextThreadAffinityMode_ = nullptr;
48 napi_ref MSLiteModelNapi::contextQuantizationType_ = nullptr;
49 napi_ref MSLiteModelNapi::contextOptimizationLevel_ = nullptr;
50 napi_ref MSLiteModelNapi::contextPerformanceMode_ = nullptr;
51 napi_ref MSLiteModelNapi::contextPriority_ = nullptr;
52 napi_ref MSLiteModelNapi::contextNnrtDeviceType_ = nullptr;
53
// Declares argc/argv/thisVar/data in the calling scope and fills them from the N-API callback info.
// `data` is initialized so it never holds an indeterminate value if napi_get_cb_info fails.
#define GET_PARAMS(env, info, num)  \
  size_t argc = num;                \
  napi_value argv[num] = {0};       \
  napi_value thisVar = nullptr;     \
  void *data = nullptr;             \
  napi_get_cb_info(env, info, &argc, argv, &thisVar, &data)
60
61 namespace {
62 const int ARGS_ONE = 1;
63 const int ARGS_TWO = 2;
64 const int ARGS_THREE = 3;
65 const int ARGS_FOUR = 4;
66
67 const int PARAM0 = 0;
68 const int PARAM1 = 1;
69 const int PARAM2 = 2;
70 const int PARAM3 = 3;
71 const int PARAM4 = 4;
72 const int UNSET_VALUE = -1;
73
74 const int SIZE = 100;
75
76 const std::string CLASS_NAME = "Model";
77
78 const std::unordered_map<std::string, DeviceType> kDeviceTypes{
79 {"cpu", kCPU},
80 {"nnrt", kNNRt},
81 {"gpu", kGPU},
82 };
83 } // namespace
84
MSLiteModelNapi()85 MSLiteModelNapi::MSLiteModelNapi() : native_model_(nullptr), env_(nullptr) {
86 MS_LOG(INFO) << "MSLiteModelNapi Instances create.";
87 }
88
~MSLiteModelNapi()89 MSLiteModelNapi::~MSLiteModelNapi() {
90 native_model_ = nullptr;
91 env_ = nullptr;
92 MS_LOG(INFO) << "MSLiteModelNapi Instances destroy.";
93 }
94
Finalize(napi_env env,void * nativeObject,void * finalize)95 void MSLiteModelNapi::Finalize(napi_env env, void *nativeObject, void *finalize) {
96 (void)env;
97 (void)finalize;
98 if (nativeObject != nullptr) {
99 // delete nativeObject
100 auto obj = static_cast<MSLiteModelNapi *>(nativeObject);
101 delete obj;
102 obj = nullptr;
103 }
104 MS_LOG(INFO) << "Finalize success";
105 }
106
Init(napi_env env,napi_value exports)107 napi_value MSLiteModelNapi::Init(napi_env env, napi_value exports) {
108 napi_property_descriptor properties[] = {
109 DECLARE_NAPI_FUNCTION("getInputs", GetInputs),
110 DECLARE_NAPI_FUNCTION("resize", Resize),
111 DECLARE_NAPI_FUNCTION("predict", PredictAsync),
112 DECLARE_NAPI_FUNCTION("runStep", RunStep),
113 DECLARE_NAPI_FUNCTION("getWeights", GetWeights),
114 DECLARE_NAPI_FUNCTION("updateWeights", UpdateWeights),
115 DECLARE_NAPI_FUNCTION("setupVirtualBatch", SetupVirtualBatch),
116 DECLARE_NAPI_FUNCTION("exportModel", ExportModel),
117 DECLARE_NAPI_FUNCTION("exportWeightsCollaborateWithMicro", ExportWeightsCollaborateWithMicro),
118 DECLARE_NAPI_GETTER_SETTER("trainMode", GetTrainMode, SetTrainMode),
119 DECLARE_NAPI_GETTER_SETTER("learningRate", GetLearningRate, SetLearningRate),
120 };
121
122 napi_property_descriptor staticProperty[] = {
123 DECLARE_NAPI_STATIC_FUNCTION("loadModelFromFile", LoadMSLiteModelFromFile),
124 DECLARE_NAPI_STATIC_FUNCTION("loadModelFromBuffer", LoadMSLiteModelFromBuffer),
125 DECLARE_NAPI_STATIC_FUNCTION("loadModelFromFd", LoadMSLiteModelFromFd),
126 DECLARE_NAPI_STATIC_FUNCTION("loadTrainModelFromFile", LoadMSLiteTrainModelFromFile),
127 DECLARE_NAPI_STATIC_FUNCTION("loadTrainModelFromBuffer", LoadMSLiteTrainModelFromBuffer),
128 DECLARE_NAPI_STATIC_FUNCTION("loadTrainModelFromFd", LoadMSLiteTrainModelFromFd),
129 DECLARE_NAPI_STATIC_FUNCTION("getAllNNRTDeviceDescriptions", GetAllNnrtDeviceDescs),
130 DECLARE_NAPI_PROPERTY("Format", CreateFormatObject(env)),
131 DECLARE_NAPI_PROPERTY("DataType", CreateDataTypeObject(env)),
132 DECLARE_NAPI_PROPERTY("ThreadAffinityMode", CreateThreadAffinityModeObject(env)),
133 DECLARE_NAPI_PROPERTY("QuantizationType", CreateQuantizationTypeObject(env)),
134 DECLARE_NAPI_PROPERTY("OptimizationLevel", CreateOptimizationLevelObject(env)),
135 DECLARE_NAPI_PROPERTY("PerformanceMode", CreatePerformanceModeObject(env)),
136 DECLARE_NAPI_PROPERTY("Priority", CreatePriorityObject(env)),
137 DECLARE_NAPI_PROPERTY("NNRTDeviceType", CreateNnrtDeviceTypeObject(env)),
138 };
139
140 napi_value constructor = nullptr;
141 napi_status status = napi_define_class(env, CLASS_NAME.c_str(), NAPI_AUTO_LENGTH, Constructor, nullptr,
142 sizeof(properties) / sizeof(properties[0]), properties, &constructor);
143 if (status != napi_ok) {
144 MS_LOG(ERROR) << "Failed to define MSLiteModel class";
145 return nullptr;
146 }
147
148 status = napi_create_reference(env, constructor, REFERENCE_CREATION_COUNT, &constructor_);
149 if (status != napi_ok) {
150 MS_LOG(ERROR) << "Failed to create reference of constructor";
151 return nullptr;
152 }
153
154 status = napi_set_named_property(env, exports, CLASS_NAME.c_str(), constructor);
155 if (status != napi_ok) {
156 MS_LOG(ERROR) << "Failed to set constructor";
157 return nullptr;
158 }
159
160 status = napi_define_properties(env, exports, sizeof(staticProperty) / sizeof(staticProperty[0]), staticProperty);
161 if (status != napi_ok) {
162 MS_LOG(ERROR) << "Failed to define static function";
163 return nullptr;
164 }
165
166 MS_LOG(INFO) << "init success";
167 return exports;
168 }
169
CreateFormatObject(napi_env env)170 napi_value MSLiteModelNapi::CreateFormatObject(napi_env env)
171 {
172 napi_value result = nullptr;
173 napi_status status;
174 std::string propName;
175 int32_t refCount = 1;
176
177 status = napi_create_object(env, &result);
178 if (status == napi_ok) {
179 for (auto &iter : tensorFormatMap) {
180 propName = iter.first;
181 status = AddNamedProperty(env, result, propName, iter.second);
182 if (status != napi_ok) {
183 MS_LOG(ERROR) << "Failed to add named prop in CreateFormatObject.";
184 break;
185 }
186 propName.clear();
187 }
188 if (status == napi_ok) {
189 status = napi_create_reference(env, result, refCount, &tensorFormat_);
190 if (status == napi_ok) {
191 return result;
192 }
193 }
194 }
195 MS_LOG(ERROR) << "CreateFormatObject is Failed!";
196 napi_get_undefined(env, &result);
197 return result;
198 }
199
CreateDataTypeObject(napi_env env)200 napi_value MSLiteModelNapi::CreateDataTypeObject(napi_env env) {
201 napi_value result = nullptr;
202 napi_status status;
203 std::string propName;
204 int32_t refCount = 1;
205
206 status = napi_create_object(env, &result);
207 if (status == napi_ok) {
208 for (auto &iter : tensorDataTypeMap) {
209 propName = iter.first;
210 status = AddNamedProperty(env, result, propName, iter.second);
211 if (status != napi_ok) {
212 MS_LOG(ERROR) << "Failed to add named prop in CreateDataTypeObject.";
213 break;
214 }
215 propName.clear();
216 }
217 if (status == napi_ok) {
218 status = napi_create_reference(env, result, refCount, &tensorDataType_);
219 if (status == napi_ok) {
220 return result;
221 }
222 }
223 }
224 MS_LOG(ERROR) << "CreateDataTypeObject is Failed!";
225 napi_get_undefined(env, &result);
226 return result;
227 }
228
CreateThreadAffinityModeObject(napi_env env)229 napi_value MSLiteModelNapi::CreateThreadAffinityModeObject(napi_env env) {
230 napi_value result = nullptr;
231 napi_status status;
232 std::string propName;
233 int32_t refCount = 1;
234
235 status = napi_create_object(env, &result);
236 if (status == napi_ok) {
237 for (auto &iter : contextThreadAffinityModeMap) {
238 propName = iter.first;
239 status = AddNamedProperty(env, result, propName, iter.second);
240 if (status != napi_ok) {
241 MS_LOG(ERROR) << "Failed to add named prop in CreateThreadAffinityModeObject.";
242 break;
243 }
244 propName.clear();
245 }
246 if (status == napi_ok) {
247 status = napi_create_reference(env, result, refCount, &contextThreadAffinityMode_);
248 if (status == napi_ok) {
249 return result;
250 }
251 }
252 }
253 MS_LOG(ERROR) << "CreateThreadAffinityModeObject is Failed!";
254 napi_get_undefined(env, &result);
255 return result;
256 }
257
CreateQuantizationTypeObject(napi_env env)258 napi_value MSLiteModelNapi::CreateQuantizationTypeObject(napi_env env) {
259 napi_value result = nullptr;
260 napi_status status;
261 std::string propName;
262 int32_t refCount = 1;
263
264 status = napi_create_object(env, &result);
265 if (status == napi_ok) {
266 for (auto &iter : contextQuantizationTypeMap) {
267 propName = iter.first;
268 status = AddNamedProperty(env, result, propName, iter.second);
269 if (status != napi_ok) {
270 MS_LOG(ERROR) << "Failed to add named prop in CreateQuantizationTypeObject.";
271 break;
272 }
273 propName.clear();
274 }
275 if (status == napi_ok) {
276 status = napi_create_reference(env, result, refCount, &contextQuantizationType_);
277 if (status == napi_ok) {
278 return result;
279 }
280 }
281 }
282 MS_LOG(ERROR) << "CreateQuantizationTypeObject is Failed!";
283 napi_get_undefined(env, &result);
284 return result;
285 }
286
CreateOptimizationLevelObject(napi_env env)287 napi_value MSLiteModelNapi::CreateOptimizationLevelObject(napi_env env) {
288 napi_value result = nullptr;
289 napi_status status;
290 std::string propName;
291 int32_t refCount = 1;
292
293 status = napi_create_object(env, &result);
294 if (status == napi_ok) {
295 for (auto &iter : contextOptimizationLevelTypeMap) {
296 propName = iter.first;
297 status = AddNamedProperty(env, result, propName, iter.second);
298 if (status != napi_ok) {
299 MS_LOG(ERROR) << "Failed to add named prop in CreateOptimizationLevelObject.";
300 break;
301 }
302 propName.clear();
303 }
304 if (status == napi_ok) {
305 status = napi_create_reference(env, result, refCount, &contextOptimizationLevel_);
306 if (status == napi_ok) {
307 return result;
308 }
309 }
310 }
311 MS_LOG(ERROR) << "CreateOptimizationLevelObject is Failed!";
312 napi_get_undefined(env, &result);
313 return result;
314 }
315
CreatePerformanceModeObject(napi_env env)316 napi_value MSLiteModelNapi::CreatePerformanceModeObject(napi_env env) {
317 napi_value result = nullptr;
318 napi_status status;
319 std::string propName;
320 int32_t refCount = 1;
321
322 status = napi_create_object(env, &result);
323 if (status == napi_ok) {
324 for (auto &iter : contextPerformanceModeTypeMap) {
325 propName = iter.first;
326 status = AddNamedProperty(env, result, propName, iter.second);
327 if (status != napi_ok) {
328 MS_LOG(ERROR) << "Failed to add named prop in CreatePerformanceModeObject.";
329 break;
330 }
331 propName.clear();
332 }
333 if (status == napi_ok) {
334 status = napi_create_reference(env, result, refCount, &contextPerformanceMode_);
335 if (status == napi_ok) {
336 return result;
337 }
338 }
339 }
340 MS_LOG(ERROR) << "CreatePerformanceModeObject is Failed!";
341 napi_get_undefined(env, &result);
342 return result;
343 }
344
CreatePriorityObject(napi_env env)345 napi_value MSLiteModelNapi::CreatePriorityObject(napi_env env) {
346 napi_value result = nullptr;
347 napi_status status;
348 std::string propName;
349 int32_t refCount = 1;
350
351 status = napi_create_object(env, &result);
352 if (status == napi_ok) {
353 for (auto &iter : contextPriorityTypeMap) {
354 propName = iter.first;
355 status = AddNamedProperty(env, result, propName, iter.second);
356 if (status != napi_ok) {
357 MS_LOG(ERROR) << "Failed to add named prop in CreatePriorityObject.";
358 break;
359 }
360 propName.clear();
361 }
362 if (status == napi_ok) {
363 status = napi_create_reference(env, result, refCount, &contextPriority_);
364 if (status == napi_ok) {
365 return result;
366 }
367 }
368 }
369 MS_LOG(ERROR) << "CreatePriorityObject is Failed!";
370 napi_get_undefined(env, &result);
371 return result;
372 }
373
CreateNnrtDeviceTypeObject(napi_env env)374 napi_value MSLiteModelNapi::CreateNnrtDeviceTypeObject(napi_env env) {
375 napi_value result = nullptr;
376 napi_status status;
377 std::string propName;
378 int32_t refCount = 1;
379
380 status = napi_create_object(env, &result);
381 if (status == napi_ok) {
382 for (auto &iter : contextNnrtDeviceTypeTypeMap) {
383 propName = iter.first;
384 status = AddNamedProperty(env, result, propName, iter.second);
385 if (status != napi_ok) {
386 MS_LOG(ERROR) << "Failed to add named prop in CreateNnrtDeviceTypeObject.";
387 break;
388 }
389 propName.clear();
390 }
391 if (status == napi_ok) {
392 status = napi_create_reference(env, result, refCount, &contextNnrtDeviceType_);
393 if (status == napi_ok) {
394 return result;
395 }
396 }
397 }
398 MS_LOG(ERROR) << "CreateNnrtDeviceTypeObject is Failed!";
399 napi_get_undefined(env, &result);
400 return result;
401 }
402
AddNamedProperty(napi_env env,napi_value object,const std::string name,int32_t enumValue)403 napi_status MSLiteModelNapi::AddNamedProperty(napi_env env, napi_value object, const std::string name,
404 int32_t enumValue) {
405 napi_status status;
406 napi_value enumNapiValue;
407
408 status = napi_create_int32(env, enumValue, &enumNapiValue);
409 if (status == napi_ok) {
410 status = napi_set_named_property(env, object, name.c_str(), enumNapiValue);
411 }
412 return status;
413 }
414
GetAllNnrtDeviceDescs(napi_env env,napi_callback_info info)415 napi_value MSLiteModelNapi::GetAllNnrtDeviceDescs(napi_env env, napi_callback_info info) {
416 size_t num;
417 napi_value jsResult = nullptr;
418 NNRTDeviceDesc *devices = OH_AI_GetAllNNRTDeviceDescs(&num);
419 if (devices == nullptr) {
420 MS_LOG(ERROR) << "Get all nnrt devices error, may nnrt is not supported.";
421 OH_AI_DestroyAllNNRTDeviceDescs(&devices);
422 return jsResult;
423 }
424
425 MS_LOG(INFO) << "all nnrt devices size: " << num;
426 napi_create_array_with_length(env, num, &jsResult);
427 for (size_t i = 0; i < num; i++) {
428 NnrtDeviceDesc nnrt_device;
429 NNRTDeviceDesc *nnrt_device_desc = OH_AI_GetElementOfNNRTDeviceDescs(devices, i);
430 nnrt_device.name.assign(OH_AI_GetNameFromNNRTDeviceDesc(nnrt_device_desc));
431 size_t id = OH_AI_GetDeviceIdFromNNRTDeviceDesc(nnrt_device_desc);
432 nnrt_device.id = id;
433 nnrt_device.type = static_cast<ContextNnrtDeviceType>(OH_AI_GetTypeFromNNRTDeviceDesc(nnrt_device_desc));
434 auto status = napi_set_element(env, jsResult, i, NnrtDeviceDescNapi::NewInstance(env, nnrt_device));
435 if (status != napi_ok) {
436 MS_LOG(ERROR) << "napi_set_element failed! code: " << status;
437 OH_AI_DestroyAllNNRTDeviceDescs(&devices);
438 return jsResult;
439 }
440 }
441 MS_LOG(INFO) << "get All nnrt devices success!";
442 OH_AI_DestroyAllNNRTDeviceDescs(&devices);
443 return jsResult;
444 }
445
CreateModel(ModelInfo * model_info_ptr,ContextInfo * context_info_ptr)446 std::shared_ptr<mindspore::Model> MSLiteModelNapi::CreateModel(ModelInfo *model_info_ptr,
447 ContextInfo *context_info_ptr) {
448 if (context_info_ptr == nullptr) {
449 MS_LOG(ERROR) << "context_info_ptr is nullptr.";
450 return nullptr;
451 }
452 // create and init context
453 std::string s;
454 for (const auto &device_name : context_info_ptr->target) {
455 s += device_name + " ";
456 }
457 MS_LOG(DEBUG) << "target device: " << s.c_str();
458
459 auto context = std::make_shared<mindspore::Context>();
460 if (context == nullptr) {
461 MS_LOG(ERROR) << "Failed to new context.";
462 return nullptr;
463 }
464
465 auto &device_infos = context->MutableDeviceInfo();
466 if (context_info_ptr->target.empty()) {
467 MS_LOG(ERROR) << "context is empty.";
468 return nullptr;
469 }
470 if (GetDeviceInfoContext(context_info_ptr, device_infos) != SUCCESS) {
471 MS_LOG(ERROR) << "Create context failed.";
472 return nullptr;
473 }
474 context->SetThreadNum(context_info_ptr->cpu_device.thread_num);
475 MS_LOG(DEBUG) << "current thread num is : " << context->GetThreadNum();
476
477 switch (model_info_ptr->mode) {
478 case kBuffer: {
479 MS_LOG(DEBUG) << "input model buffer, model_buffer_total: " << model_info_ptr->model_buffer_total;
480 if (model_info_ptr->model_buffer_data == nullptr || model_info_ptr->model_buffer_total <= 0) {
481 MS_LOG(ERROR) << "Failed to build model.";
482 return nullptr;
483 }
484 std::shared_ptr<mindspore::Model> model_ptr = std::make_shared<mindspore::Model>();
485 if (model_ptr == nullptr) {
486 MS_LOG(ERROR) << "Failed to new mindspore::model.";
487 return nullptr;
488 }
489 auto ret = model_ptr->Build(model_info_ptr->model_buffer_data, model_info_ptr->model_buffer_total,
490 mindspore::kMindIR, context);
491 if (ret == mindspore::kSuccess) {
492 MS_LOG(INFO) << "Build model from buffer success.";
493 return model_ptr;
494 }
495 break;
496 }
497 case kPath: {
498 MS_LOG(DEBUG) << "input model path, model_buffer_total: " << model_info_ptr->model_path.c_str();
499 std::shared_ptr<mindspore::Model> model_ptr = std::make_shared<mindspore::Model>();
500 if (model_ptr == nullptr) {
501 MS_LOG(ERROR) << "Failed to new mindspore::model.";
502 return nullptr;
503 }
504 auto ret = model_ptr->Build(model_info_ptr->model_path, mindspore::kMindIR, context);
505 if (ret == mindspore::kSuccess) {
506 MS_LOG(INFO) << "Build model from path success.";
507 return model_ptr;
508 }
509 return nullptr;
510 }
511 case kFD: {
512 MS_LOG(DEBUG) << "input model fd:" << model_info_ptr->model_fd
513 << ", model_buffer_total: " << model_info_ptr->model_buffer_total;
514 std::shared_ptr<mindspore::Model> model_ptr = std::make_shared<mindspore::Model>();
515 if (model_ptr == nullptr) {
516 MS_LOG(ERROR) << "Failed to new mindspore::model.";
517 return nullptr;
518 }
519 auto ret = model_ptr->Build(model_info_ptr->model_buffer_data, model_info_ptr->model_buffer_total,
520 mindspore::kMindIR, context);
521
522 (void)munmap(model_info_ptr->model_buffer_data, model_info_ptr->model_buffer_total);
523 if (ret == mindspore::kSuccess) {
524 MS_LOG(INFO) << "Build model from fd success.";
525 return model_ptr;
526 }
527
528 break;
529 }
530 default: {
531 MS_LOG(ERROR) << "Invalid model mode.";
532 }
533 }
534 MS_LOG(ERROR) << "Build model failed.";
535 return nullptr;
536 }
537
CreateTrainModel(ModelInfo * model_info_ptr,ContextInfo * context_info_ptr)538 std::shared_ptr<mindspore::Model> MSLiteModelNapi::CreateTrainModel(ModelInfo *model_info_ptr,
539 ContextInfo *context_info_ptr) {
540 // create and init context
541 std::string s;
542 for (const auto &device_name : context_info_ptr->target) {
543 s += device_name + " ";
544 }
545 MS_LOG(DEBUG) << "target device: " << s.c_str();
546
547 auto context = std::make_shared<mindspore::Context>();
548 if (context == nullptr) {
549 MS_LOG(ERROR) << "Failed to new context.";
550 return nullptr;
551 }
552
553 auto &device_infos = context->MutableDeviceInfo();
554 if (context_info_ptr->target.empty()) {
555 MS_LOG(ERROR) << "context is empty.";
556 return nullptr;
557 }
558 if (GetDeviceInfoContext(context_info_ptr, device_infos) != SUCCESS) {
559 MS_LOG(ERROR) << "Create context failed.";
560 return nullptr;
561 }
562
563 auto train_cfg = std::make_shared<TrainCfg>();
564 std::vector<std::string> loss_names;
565 for (const auto &name : train_cfg->GetLossName()) {
566 loss_names.push_back(name);
567 }
568 for (const auto &name : context_info_ptr->train_cfg.loss_names) {
569 loss_names.push_back(name);
570 }
571 train_cfg->SetLossName(loss_names);
572 train_cfg->optimization_level_ = static_cast<OptimizationLevel>(context_info_ptr->train_cfg.optimization_level);
573
574 switch (model_info_ptr->mode) {
575 case kBuffer: {
576 MS_LOG(DEBUG) << "input model buffer, model_buffer_total: " << model_info_ptr->model_buffer_total;
577 if (model_info_ptr->model_buffer_data == nullptr || model_info_ptr->model_buffer_total <= 0) {
578 MS_LOG(ERROR) << "Failed to build model.";
579 return nullptr;
580 }
581 std::shared_ptr<mindspore::Model> model_ptr = std::make_shared<mindspore::Model>();
582 if (model_ptr == nullptr) {
583 MS_LOG(ERROR) << "Failed to new mindspore::model.";
584 return nullptr;
585 }
586 mindspore::Graph graph;
587 auto status = mindspore::Serialization::Load(model_info_ptr->model_buffer_data,
588 model_info_ptr->model_buffer_total, mindspore::kMindIR, &graph);
589 if (status != mindspore::kSuccess) {
590 MS_LOG(ERROR) << "load ms file failed.";
591 return nullptr;
592 }
593 auto ret = model_ptr->Build(static_cast<mindspore::GraphCell>(graph), context, train_cfg);
594 if (ret == mindspore::kSuccess) {
595 MS_LOG(INFO) << "Build model from buffer success.";
596 return model_ptr;
597 }
598 break;
599 }
600 case kPath: {
601 MS_LOG(DEBUG) << "input model path, model_buffer_total: " << model_info_ptr->model_path.c_str();
602 std::shared_ptr<mindspore::Model> model_ptr = std::make_shared<mindspore::Model>();
603 if (model_ptr == nullptr) {
604 MS_LOG(ERROR) << "Failed to new mindspore::model.";
605 return nullptr;
606 }
607
608 mindspore::Graph graph;
609 auto status = mindspore::Serialization::Load(model_info_ptr->model_path, mindspore::kMindIR, &graph);
610 if (status != mindspore::kSuccess) {
611 MS_LOG(ERROR) << "load ms file failed.";
612 return nullptr;
613 }
614 auto ret = model_ptr->Build(static_cast<mindspore::GraphCell>(graph), context, train_cfg);
615 if (ret == mindspore::kSuccess) {
616 MS_LOG(INFO) << "Build model from path success.";
617 return model_ptr;
618 }
619 return nullptr;
620 }
621 case kFD: {
622 MS_LOG(DEBUG) << "input model fd:" << model_info_ptr->model_fd
623 << ", model_buffer_total: " << model_info_ptr->model_buffer_total;
624 std::shared_ptr<mindspore::Model> model_ptr = std::make_shared<mindspore::Model>();
625 if (model_ptr == nullptr) {
626 MS_LOG(ERROR) << "Failed to new mindspore::model.";
627 return nullptr;
628 }
629
630 mindspore::Graph graph;
631 auto status = mindspore::Serialization::Load(model_info_ptr->model_buffer_data,
632 model_info_ptr->model_buffer_total, mindspore::kMindIR, &graph);
633 if (status != mindspore::kSuccess) {
634 MS_LOG(ERROR) << "load ms file failed.";
635 return nullptr;
636 }
637 auto ret = model_ptr->Build(static_cast<mindspore::GraphCell>(graph), context, train_cfg);
638 (void)munmap(model_info_ptr->model_buffer_data, model_info_ptr->model_buffer_total);
639 if (ret == mindspore::kSuccess) {
640 MS_LOG(INFO) << "Build model from fd success.";
641 return model_ptr;
642 }
643
644 break;
645 }
646 default: {
647 MS_LOG(ERROR) << "Invalid model mode.";
648 }
649 }
650 MS_LOG(ERROR) << "Build model failed.";
651 return nullptr;
652 }
653
GetDeviceInfoContext(ContextInfo * context_ptr,std::vector<std::shared_ptr<DeviceInfoContext>> & device_infos)654 int32_t MSLiteModelNapi::GetDeviceInfoContext(ContextInfo *context_ptr,
655 std::vector<std::shared_ptr<DeviceInfoContext>> &device_infos) {
656 for (auto device_name : context_ptr->target) {
657 if (kDeviceTypes.find(device_name) == kDeviceTypes.end()) {
658 MS_LOG(ERROR) << "Invalid device: " << device_name.c_str();
659 return ERR_INVALID_OPERATION;
660 }
661
662 auto device_type = kDeviceTypes.at(device_name);
663 switch (device_type) {
664 case kCPU: {
665 auto cpu_device = std::make_shared<mindspore::CPUDeviceInfo>();
666 if (cpu_device == nullptr) {
667 MS_LOG(ERROR) << "Failed to new CPU deviceInfo.";
668 return ERR_INVALID_OPERATION;
669 }
670 bool is_fp16 = (context_ptr->cpu_device.precision_mode.compare("preferred_fp16") == 0) ? true : false;
671 cpu_device->SetEnableFP16(is_fp16);
672 device_infos.push_back(cpu_device);
673 break;
674 }
675 case kNNRt: {
676 auto nnrt_device = std::make_shared<mindspore::NNRTDeviceInfo>();
677 if (nnrt_device == nullptr) {
678 MS_LOG(ERROR) << "Failed to new NNRT deviceInfo.";
679 return ERR_INVALID_OPERATION;
680 }
681 nnrt_device->SetDeviceID(context_ptr->nnrt_device.device_id);
682 if (context_ptr->nnrt_device.performance_mode != UNSET_VALUE) {
683 nnrt_device->SetPerformanceMode(context_ptr->nnrt_device.performance_mode);
684 }
685 if (context_ptr->nnrt_device.priority != UNSET_VALUE) {
686 nnrt_device->SetPriority(context_ptr->nnrt_device.priority);
687 }
688 // ignore extensions
689 device_infos.push_back(nnrt_device);
690 break;
691 }
692 default: {
693 MS_LOG(ERROR) << "invalid device.";
694 return ERR_INVALID_OPERATION;
695 }
696 }
697 }
698 return SUCCESS;
699 }
700
Constructor(napi_env env,napi_callback_info info)701 napi_value MSLiteModelNapi::Constructor(napi_env env, napi_callback_info info) {
702 napi_status status;
703 napi_value result = nullptr;
704 napi_get_undefined(env, &result);
705 GET_PARAMS(env, info, ARGS_TWO);
706
707 std::unique_ptr<MSLiteModelNapi> model_napi = std::make_unique<MSLiteModelNapi>();
708 if (model_napi == nullptr) {
709 MS_LOG(ERROR) << "No memory";
710 return result;
711 }
712
713 model_napi->env_ = env;
714 if (model_info_->train_model) {
715 model_napi->native_model_ = CreateTrainModel(model_info_, context_);
716 } else {
717 model_napi->native_model_ = CreateModel(model_info_, context_);
718 }
719 if (model_napi->native_model_ == nullptr) {
720 MS_LOG(ERROR) << "Failed to create model.";
721 return result;
722 }
723
724 status =
725 napi_wrap(env, thisVar, reinterpret_cast<void *>(model_napi.get()), MSLiteModelNapi::Finalize, nullptr, nullptr);
726 if (status == napi_ok) {
727 model_napi.release();
728 return thisVar;
729 }
730 return result;
731 }
732
ParseModelInfo(napi_env env,napi_value root,ModelInfo & model_info)733 int32_t MSLiteModelNapi::ParseModelInfo(napi_env env, napi_value root, ModelInfo &model_info) {
734 napi_valuetype valueType;
735 napi_status status = napi_typeof(env, root, &valueType);
736 if (status != napi_ok) {
737 MS_LOG(ERROR) << "napi_typeof error.";
738 return ERR_INVALID_PARAM;
739 }
740 if ((valueType != napi_object) && (valueType != napi_string) && (valueType != napi_number)) {
741 MS_LOG(ERROR) << "model is invaild.";
742 return ERR_INVALID_PARAM;
743 }
744
745 bool is_model_buffer = false;
746 napi_is_arraybuffer(env, root, &is_model_buffer);
747 if (is_model_buffer) {
748 // copy buffer
749 char *array_buffer_data;
750 size_t array_buffer_total;
751 status = napi_get_arraybuffer_info(env, root, reinterpret_cast<void **>(&array_buffer_data), &array_buffer_total);
752 if ((status != napi_ok) || (array_buffer_total <= 0)) {
753 MS_LOG(ERROR) << "Parse model buffer failed.";
754 return ERR_INVALID_PARAM;
755 }
756
757 // shallow copy
758 model_info.model_buffer_data = array_buffer_data;
759 model_info.model_buffer_total = array_buffer_total;
760 model_info.mode = kBuffer;
761 } else if (valueType == napi_number) {
762 int32_t fd;
763 status = napi_get_value_int32(env, root, &fd);
764 if ((status != napi_ok) || (fd <= 0)) {
765 MS_LOG(ERROR) << "Parse model FD failed.";
766 return ERR_INVALID_PARAM;
767 }
768
769 int size = lseek(fd, 0, SEEK_END);
770 (void)lseek(fd, 0, SEEK_SET);
771 auto mmap_buffers = mmap(NULL, size, PROT_READ, MAP_SHARED, fd, 0);
772 if (mmap_buffers == NULL) {
773 MS_LOG(ERROR) << "mmap_buffers is NULL.";
774 return ERR_INVALID_PARAM;
775 }
776 model_info.model_fd = fd;
777 model_info.model_buffer_data = static_cast<char *>(mmap_buffers);
778 model_info.model_buffer_total = size;
779 model_info.mode = kFD;
780 } else {
781 char char_buf[SIZE];
782 size_t buf_length = 0;
783 status = napi_get_value_string_utf8(env, root, char_buf, SIZE, &buf_length);
784 if ((status != napi_ok) || (buf_length <= 0)) {
785 MS_LOG(ERROR) << "Parse model file failed.";
786 return ERR_INVALID_PARAM;
787 }
788 model_info.model_path.assign(char_buf, char_buf + buf_length);
789 model_info.mode = kPath;
790 MS_LOG(DEBUG) << "model_path: " << model_info.model_path.c_str();
791 }
792 return SUCCESS;
793 }
794
ParseContextInfo(napi_env env,napi_value args,ContextInfo & context)795 int32_t MSLiteModelNapi::ParseContextInfo(napi_env env, napi_value args, ContextInfo &context) {
796 napi_valuetype valueType;
797 napi_status status = napi_typeof(env, args, &valueType);
798 if ((status != napi_ok) || (valueType != napi_object)) {
799 MS_LOG(ERROR) << "context is invaild.";
800 return ERR_NOT_EXISTED_PARAM;
801 }
802
803 std::vector<std::string> str_values;
804 auto ret = CommonNapi::GetPropertyStringArray(env, args, "target", str_values);
805 if (ret != SUCCESS) {
806 MS_LOG(ERROR) << "Get context target failed.";
807 return ret;
808 }
809 context.target.assign(str_values.begin(), str_values.end());
810
811 ret = GetCpuDeviceInfo(env, args, context);
812 if (ret != ERR_NOT_EXISTED_PARAM && ret != SUCCESS) {
813 MS_LOG(ERROR) << "Get context CpuDeviceInfo failed.";
814 return ret;
815 }
816
817 ret = GetNNRTDeviceInfo(env, args, context);
818 if (ret != ERR_NOT_EXISTED_PARAM && ret != SUCCESS) {
819 MS_LOG(ERROR) << "Get context NnrtDeviceInfo failed.";
820 return ret;
821 }
822 return SUCCESS;
823 }
824
ParseTrainCfgInfo(napi_env env,napi_value root,TrainConfig & cfg)825 int32_t MSLiteModelNapi::ParseTrainCfgInfo(napi_env env, napi_value root, TrainConfig &cfg) {
826 napi_valuetype valueType;
827 napi_status status = napi_typeof(env, root, &valueType);
828 if ((status != napi_ok) || (valueType != napi_object)) {
829 MS_LOG(ERROR) << "TrainCfg is invaild.";
830 return ERR_NOT_EXISTED_PARAM;
831 }
832 std::vector<std::string> str_values;
833 auto ret = CommonNapi::GetPropertyStringArray(env, root, "lossName", str_values);
834 if (ret != SUCCESS && ret != ERR_NOT_EXISTED_PARAM) {
835 MS_LOG(ERROR) << "Get lossName failed.";
836 return ret;
837 }
838 cfg.loss_names.assign(str_values.begin(), str_values.end());
839
840 int32_t int_value = 0;
841 ret = CommonNapi::GetPropertyInt32(env, root, "optimizationLevel", int_value);
842 if (ret != SUCCESS && ret != ERR_NOT_EXISTED_PARAM) {
843 MS_LOG(ERROR) << "Get optimization level failed";
844 return ret;
845 } else {
846 cfg.optimization_level = int_value;
847 }
848 return SUCCESS;
849 }
850
CreateMSLiteModelWrapper(napi_env env,MSLiteModelAsyncContext * async_context)851 napi_value MSLiteModelNapi::CreateMSLiteModelWrapper(napi_env env, MSLiteModelAsyncContext *async_context) {
852 std::lock_guard<std::mutex> lock(create_mutex_);
853 napi_status status;
854 napi_value result = nullptr;
855 napi_value constructor;
856 napi_get_undefined(env, &result);
857
858 status = napi_get_reference_value(env, constructor_, &constructor);
859 if (status != napi_ok) {
860 MS_LOG(ERROR) << "get reference failed.";
861 return result;
862 }
863 model_info_ = &(async_context->model_info);
864 context_ = &(async_context->context);
865 status = napi_new_instance(env, constructor, 0, nullptr, &result);
866 if (status == napi_ok) {
867 return result;
868 }
869
870 return result;
871 }
872
GetMSLiteModelAsyncCallbackComplete(napi_env env,napi_status status,void * data)873 void MSLiteModelNapi::GetMSLiteModelAsyncCallbackComplete(napi_env env, napi_status status, void *data) {
874 napi_value valueParam = nullptr;
875 auto async_context = static_cast<MSLiteModelAsyncContext *>(data);
876
877 if (async_context != nullptr) {
878 if (!async_context->status) {
879 valueParam = CreateMSLiteModelWrapper(env, async_context);
880 }
881 CommonCallbackRoutine(env, async_context, valueParam);
882 } else {
883 MS_LOG(ERROR) << "GetMSLiteModelAsyncCallbackComplete asyncContext is Null!";
884 }
885 }
886
CommonCallbackRoutine(napi_env env,MSLiteModelAsyncContext * & asyncContext,const napi_value & valueParam)887 void MSLiteModelNapi::CommonCallbackRoutine(napi_env env, MSLiteModelAsyncContext *&asyncContext,
888 const napi_value &valueParam) {
889 napi_value result[ARGS_ONE] = {0};
890 napi_value retVal;
891 napi_value error = nullptr;
892
893 if (!asyncContext->status) {
894 result[PARAM0] = valueParam;
895 } else {
896 napi_value message = nullptr;
897 std::string messageValue = CommonNapi::getMessageByCode(asyncContext->status);
898 napi_create_string_utf8(env, messageValue.c_str(), NAPI_AUTO_LENGTH, &message);
899
900 napi_value code = nullptr;
901 napi_create_string_utf8(env, (std::to_string(asyncContext->status)).c_str(), NAPI_AUTO_LENGTH, &code);
902
903 napi_create_error(env, code, message, &error);
904 napi_get_undefined(env, &result[PARAM0]);
905 }
906
907 if (asyncContext->deferred != nullptr) {
908 if (!asyncContext->status) {
909 napi_resolve_deferred(env, asyncContext->deferred, result[PARAM0]);
910 } else {
911 napi_reject_deferred(env, asyncContext->deferred, error);
912 }
913 } else {
914 napi_value callback = nullptr;
915 napi_get_reference_value(env, asyncContext->callbackRef, &callback);
916 napi_call_function(env, nullptr, callback, ARGS_ONE, result, &retVal);
917 napi_delete_reference(env, asyncContext->callbackRef);
918 }
919 napi_delete_async_work(env, asyncContext->work);
920
921 delete asyncContext;
922 asyncContext = nullptr;
923 }
924
LoadMSLiteModelFromFile(napi_env env,napi_callback_info info)925 napi_value MSLiteModelNapi::LoadMSLiteModelFromFile(napi_env env, napi_callback_info info) {
926 napi_status status;
927 napi_value result = nullptr;
928 const int32_t refCount = 1;
929 GET_PARAMS(env, info, ARGS_THREE);
930 napi_valuetype valueType = napi_undefined;
931
932 std::unique_ptr<MSLiteModelAsyncContext> asyncContext = std::make_unique<MSLiteModelAsyncContext>();
933
934 int32_t ret;
935 for (size_t i = PARAM0; i < argc; i++) {
936 if (i == PARAM0) {
937 ret = ParseModelInfo(env, argv[i], asyncContext->model_info);
938 if (ret != SUCCESS) {
939 MS_LOG(ERROR) << "Parsing model failed.";
940 return result;
941 }
942 } else if (i == PARAM1) {
943 napi_typeof(env, argv[i], &valueType);
944 if (valueType == napi_function) {
945 napi_create_reference(env, argv[i], refCount, &asyncContext->callbackRef);
946 } else {
947 ret = ParseContextInfo(env, argv[i], asyncContext->context);
948 if (ret != SUCCESS) {
949 MS_LOG(ERROR) << "Parsing context failed.";
950 return result;
951 }
952 }
953 } else if (i == PARAM2) {
954 napi_typeof(env, argv[i], &valueType);
955 if (valueType == napi_function) {
956 napi_create_reference(env, argv[i], refCount, &asyncContext->callbackRef);
957 }
958 break;
959 } else {
960 MS_LOG(ERROR) << "Invalid input params.";
961 return result;
962 }
963 }
964
965 if (asyncContext->callbackRef == nullptr) {
966 status = napi_create_promise(env, &asyncContext->deferred, &result);
967 if (status != napi_ok) {
968 MS_LOG(ERROR) << "create promise failed.";
969 return result;
970 }
971 } else {
972 status = napi_get_undefined(env, &result);
973 if (status != napi_ok) {
974 MS_LOG(ERROR) << "create callback failed.";
975 return result;
976 }
977 }
978
979 napi_value resource = nullptr;
980 napi_create_string_utf8(env, "LoadMSLiteModelFromFile", NAPI_AUTO_LENGTH, &resource);
981 status = napi_create_async_work(
982 env, nullptr, resource,
983 [](napi_env env, void *data) {
984 auto context = static_cast<MSLiteModelAsyncContext *>(data);
985 context->status = SUCCESS;
986 },
987 GetMSLiteModelAsyncCallbackComplete, static_cast<void *>(asyncContext.get()), &asyncContext->work);
988 if (status != napi_ok) {
989 result = nullptr;
990 } else {
991 status = napi_queue_async_work(env, asyncContext->work);
992 if (status == napi_ok) {
993 asyncContext.release();
994 } else {
995 result = nullptr;
996 }
997 }
998 return result;
999 }
1000
LoadMSLiteTrainModelFromFile(napi_env env,napi_callback_info info)1001 napi_value MSLiteModelNapi::LoadMSLiteTrainModelFromFile(napi_env env, napi_callback_info info) {
1002 napi_status status;
1003 napi_value result = nullptr;
1004 GET_PARAMS(env, info, ARGS_THREE);
1005
1006 std::unique_ptr<MSLiteModelAsyncContext> asyncContext = std::make_unique<MSLiteModelAsyncContext>();
1007
1008 asyncContext->model_info.train_model = true;
1009 int32_t ret;
1010 for (size_t i = PARAM0; i < argc; i++) {
1011 if (i == PARAM0) {
1012 ret = ParseModelInfo(env, argv[i], asyncContext->model_info);
1013 if (ret != SUCCESS) {
1014 MS_LOG(ERROR) << "Parsing model failed.";
1015 return result;
1016 }
1017 } else if (i == PARAM1) {
1018 ret = ParseTrainCfgInfo(env, argv[i], asyncContext->context.train_cfg);
1019 if (ret != SUCCESS) {
1020 MS_LOG(ERROR) << "Parsing TrainCfg failed.";
1021 return result;
1022 }
1023 } else if (i == PARAM2) {
1024 ret = ParseContextInfo(env, argv[i], asyncContext->context);
1025 if (ret != SUCCESS) {
1026 MS_LOG(ERROR) << "Parsing context failed.";
1027 return result;
1028 }
1029 } else {
1030 MS_LOG(ERROR) << "Invalid input params.";
1031 return result;
1032 }
1033 }
1034
1035 if (asyncContext->callbackRef == nullptr) {
1036 status = napi_create_promise(env, &asyncContext->deferred, &result);
1037 if (status != napi_ok) {
1038 MS_LOG(ERROR) << "create promise failed.";
1039 return result;
1040 }
1041 } else {
1042 status = napi_get_undefined(env, &result);
1043 if (status != napi_ok) {
1044 MS_LOG(ERROR) << "create callback failed.";
1045 return result;
1046 }
1047 }
1048
1049 napi_value resource = nullptr;
1050 napi_create_string_utf8(env, "LoadMSLiteTrainModelFromFile", NAPI_AUTO_LENGTH, &resource);
1051 status = napi_create_async_work(
1052 env, nullptr, resource,
1053 [](napi_env env, void *data) {
1054 auto context = static_cast<MSLiteModelAsyncContext *>(data);
1055 context->status = SUCCESS;
1056 },
1057 GetMSLiteModelAsyncCallbackComplete, static_cast<void *>(asyncContext.get()), &asyncContext->work);
1058 if (status != napi_ok) {
1059 result = nullptr;
1060 } else {
1061 status = napi_queue_async_work(env, asyncContext->work);
1062 if (status == napi_ok) {
1063 asyncContext.release();
1064 } else {
1065 result = nullptr;
1066 }
1067 }
1068 return result;
1069 }
1070
LoadMSLiteTrainModelFromBuffer(napi_env env,napi_callback_info info)1071 napi_value MSLiteModelNapi::LoadMSLiteTrainModelFromBuffer(napi_env env, napi_callback_info info) {
1072 napi_status status;
1073 napi_value result = nullptr;
1074 GET_PARAMS(env, info, ARGS_THREE);
1075
1076 std::unique_ptr<MSLiteModelAsyncContext> asyncContext = std::make_unique<MSLiteModelAsyncContext>();
1077
1078 asyncContext->model_info.train_model = true;
1079 int32_t ret;
1080 for (size_t i = PARAM0; i < argc; i++) {
1081 if (i == PARAM0) {
1082 ret = ParseModelInfo(env, argv[i], asyncContext->model_info);
1083 if (ret != SUCCESS) {
1084 MS_LOG(ERROR) << "Parsing model failed.";
1085 return result;
1086 }
1087 } else if (i == PARAM1) {
1088 ret = ParseTrainCfgInfo(env, argv[i], asyncContext->context.train_cfg);
1089 if (ret != SUCCESS) {
1090 MS_LOG(ERROR) << "Parsing TrainCfg failed.";
1091 return result;
1092 }
1093 } else if (i == PARAM2) {
1094 ret = ParseContextInfo(env, argv[i], asyncContext->context);
1095 if (ret != SUCCESS) {
1096 MS_LOG(ERROR) << "Parsing context failed.";
1097 return result;
1098 }
1099 } else {
1100 MS_LOG(ERROR) << "Invalid input params.";
1101 return result;
1102 }
1103 }
1104
1105 if (asyncContext->callbackRef == nullptr) {
1106 status = napi_create_promise(env, &asyncContext->deferred, &result);
1107 if (status != napi_ok) {
1108 MS_LOG(ERROR) << "create promise failed.";
1109 return result;
1110 }
1111 } else {
1112 status = napi_get_undefined(env, &result);
1113 if (status != napi_ok) {
1114 MS_LOG(ERROR) << "create callback failed.";
1115 return result;
1116 }
1117 }
1118
1119 napi_value resource = nullptr;
1120 napi_create_string_utf8(env, "LoadMSLiteTrainModelFromBuffer", NAPI_AUTO_LENGTH, &resource);
1121 status = napi_create_async_work(
1122 env, nullptr, resource,
1123 [](napi_env env, void *data) {
1124 auto context = static_cast<MSLiteModelAsyncContext *>(data);
1125 context->status = SUCCESS;
1126 },
1127 GetMSLiteModelAsyncCallbackComplete, static_cast<void *>(asyncContext.get()), &asyncContext->work);
1128 if (status != napi_ok) {
1129 result = nullptr;
1130 } else {
1131 status = napi_queue_async_work(env, asyncContext->work);
1132 if (status == napi_ok) {
1133 asyncContext.release();
1134 } else {
1135 result = nullptr;
1136 }
1137 }
1138 return result;
1139 }
1140
LoadMSLiteTrainModelFromFd(napi_env env,napi_callback_info info)1141 napi_value MSLiteModelNapi::LoadMSLiteTrainModelFromFd(napi_env env, napi_callback_info info) {
1142 napi_status status;
1143 napi_value result = nullptr;
1144 GET_PARAMS(env, info, ARGS_THREE);
1145
1146 std::unique_ptr<MSLiteModelAsyncContext> asyncContext = std::make_unique<MSLiteModelAsyncContext>();
1147
1148 asyncContext->model_info.train_model = true;
1149 int32_t ret;
1150 for (size_t i = PARAM0; i < argc; i++) {
1151 if (i == PARAM0) {
1152 ret = ParseModelInfo(env, argv[i], asyncContext->model_info);
1153 if (ret != SUCCESS) {
1154 MS_LOG(ERROR) << "Parsing model failed.";
1155 return result;
1156 }
1157 } else if (i == PARAM1) {
1158 ret = ParseTrainCfgInfo(env, argv[i], asyncContext->context.train_cfg);
1159 if (ret != SUCCESS) {
1160 MS_LOG(ERROR) << "Parsing TrainCfg failed.";
1161 return result;
1162 }
1163 } else if (i == PARAM2) {
1164 ret = ParseContextInfo(env, argv[i], asyncContext->context);
1165 if (ret != SUCCESS) {
1166 MS_LOG(ERROR) << "Parsing context failed.";
1167 return result;
1168 }
1169 } else {
1170 MS_LOG(ERROR) << "Invalid input params.";
1171 return result;
1172 }
1173 }
1174
1175 if (asyncContext->callbackRef == nullptr) {
1176 status = napi_create_promise(env, &asyncContext->deferred, &result);
1177 if (status != napi_ok) {
1178 MS_LOG(ERROR) << "create promise failed.";
1179 return result;
1180 }
1181 } else {
1182 status = napi_get_undefined(env, &result);
1183 if (status != napi_ok) {
1184 MS_LOG(ERROR) << "create callback failed.";
1185 return result;
1186 }
1187 }
1188
1189 napi_value resource = nullptr;
1190 napi_create_string_utf8(env, "LoadMSLiteTrainModelFromFd", NAPI_AUTO_LENGTH, &resource);
1191 status = napi_create_async_work(
1192 env, nullptr, resource,
1193 [](napi_env env, void *data) {
1194 auto context = static_cast<MSLiteModelAsyncContext *>(data);
1195 context->status = SUCCESS;
1196 },
1197 GetMSLiteModelAsyncCallbackComplete, static_cast<void *>(asyncContext.get()), &asyncContext->work);
1198 if (status != napi_ok) {
1199 result = nullptr;
1200 } else {
1201 status = napi_queue_async_work(env, asyncContext->work);
1202 if (status == napi_ok) {
1203 asyncContext.release();
1204 } else {
1205 result = nullptr;
1206 }
1207 }
1208 return result;
1209 }
1210
LoadMSLiteModelFromBuffer(napi_env env,napi_callback_info info)1211 napi_value MSLiteModelNapi::LoadMSLiteModelFromBuffer(napi_env env, napi_callback_info info) {
1212 napi_status status;
1213 napi_value result = nullptr;
1214 const int32_t refCount = 1;
1215 GET_PARAMS(env, info, ARGS_THREE);
1216 napi_valuetype valueType = napi_undefined;
1217
1218 std::unique_ptr<MSLiteModelAsyncContext> asyncContext = std::make_unique<MSLiteModelAsyncContext>();
1219
1220 int32_t ret;
1221 for (size_t i = PARAM0; i < argc; i++) {
1222 if (i == PARAM0) {
1223 ret = ParseModelInfo(env, argv[i], asyncContext->model_info);
1224 if (ret != SUCCESS) {
1225 MS_LOG(ERROR) << "Parsing model failed.";
1226 return result;
1227 }
1228 } else if (i == PARAM1) {
1229 napi_typeof(env, argv[i], &valueType);
1230 if (valueType == napi_function) {
1231 napi_create_reference(env, argv[i], refCount, &asyncContext->callbackRef);
1232 } else {
1233 ret = ParseContextInfo(env, argv[i], asyncContext->context);
1234 if (ret != SUCCESS) {
1235 MS_LOG(ERROR) << "Parsing context failed.";
1236 return result;
1237 }
1238 }
1239 } else if (i == PARAM2) {
1240 napi_typeof(env, argv[i], &valueType);
1241 if (valueType == napi_function) {
1242 napi_create_reference(env, argv[i], refCount, &asyncContext->callbackRef);
1243 }
1244 break;
1245 } else {
1246 MS_LOG(ERROR) << "Invalid input params.";
1247 return result;
1248 }
1249 }
1250
1251 if (asyncContext->callbackRef == nullptr) {
1252 status = napi_create_promise(env, &asyncContext->deferred, &result);
1253 if (status != napi_ok) {
1254 MS_LOG(ERROR) << "create promise failed.";
1255 return result;
1256 }
1257 } else {
1258 status = napi_get_undefined(env, &result);
1259 if (status != napi_ok) {
1260 MS_LOG(ERROR) << "create callback failed.";
1261 return result;
1262 }
1263 }
1264
1265 napi_value resource = nullptr;
1266 napi_create_string_utf8(env, "LoadMSLiteModelFromBuffer", NAPI_AUTO_LENGTH, &resource);
1267 status = napi_create_async_work(
1268 env, nullptr, resource,
1269 [](napi_env env, void *data) {
1270 auto context = static_cast<MSLiteModelAsyncContext *>(data);
1271 context->status = SUCCESS;
1272 },
1273 GetMSLiteModelAsyncCallbackComplete, static_cast<void *>(asyncContext.get()), &asyncContext->work);
1274 if (status != napi_ok) {
1275 result = nullptr;
1276 } else {
1277 status = napi_queue_async_work(env, asyncContext->work);
1278 if (status == napi_ok) {
1279 asyncContext.release();
1280 } else {
1281 result = nullptr;
1282 }
1283 }
1284 return result;
1285 }
1286
LoadMSLiteModelFromFd(napi_env env,napi_callback_info info)1287 napi_value MSLiteModelNapi::LoadMSLiteModelFromFd(napi_env env, napi_callback_info info) {
1288 napi_status status;
1289 napi_value result = nullptr;
1290 const int32_t refCount = 1;
1291 GET_PARAMS(env, info, ARGS_THREE);
1292 napi_valuetype valueType = napi_undefined;
1293
1294 std::unique_ptr<MSLiteModelAsyncContext> asyncContext = std::make_unique<MSLiteModelAsyncContext>();
1295
1296 int32_t ret;
1297 for (size_t i = PARAM0; i < argc; i++) {
1298 if (i == PARAM0) {
1299 ret = ParseModelInfo(env, argv[i], asyncContext->model_info);
1300 if (ret != SUCCESS) {
1301 MS_LOG(ERROR) << "Parsing model failed.";
1302 return result;
1303 }
1304 } else if (i == PARAM1) {
1305 napi_typeof(env, argv[i], &valueType);
1306 if (valueType == napi_function) {
1307 napi_create_reference(env, argv[i], refCount, &asyncContext->callbackRef);
1308 } else {
1309 ret = ParseContextInfo(env, argv[i], asyncContext->context);
1310 if (ret != SUCCESS) {
1311 MS_LOG(ERROR) << "Parsing context failed.";
1312 return result;
1313 }
1314 }
1315 } else if (i == PARAM2) {
1316 napi_typeof(env, argv[i], &valueType);
1317 if (valueType == napi_function) {
1318 napi_create_reference(env, argv[i], refCount, &asyncContext->callbackRef);
1319 }
1320 break;
1321 } else {
1322 MS_LOG(ERROR) << "Invalid input params.";
1323 return result;
1324 }
1325 }
1326
1327 if (asyncContext->callbackRef == nullptr) {
1328 status = napi_create_promise(env, &asyncContext->deferred, &result);
1329 if (status != napi_ok) {
1330 MS_LOG(ERROR) << "create promise failed.";
1331 return result;
1332 }
1333 } else {
1334 status = napi_get_undefined(env, &result);
1335 if (status != napi_ok) {
1336 MS_LOG(ERROR) << "create callback failed.";
1337 return result;
1338 }
1339 }
1340
1341 napi_value resource = nullptr;
1342 napi_create_string_utf8(env, "LoadMSLiteModelFromFd", NAPI_AUTO_LENGTH, &resource);
1343 status = napi_create_async_work(
1344 env, nullptr, resource,
1345 [](napi_env env, void *data) {
1346 auto context = static_cast<MSLiteModelAsyncContext *>(data);
1347 context->status = SUCCESS;
1348 },
1349 GetMSLiteModelAsyncCallbackComplete, static_cast<void *>(asyncContext.get()), &asyncContext->work);
1350 if (status != napi_ok) {
1351 result = nullptr;
1352 } else {
1353 status = napi_queue_async_work(env, asyncContext->work);
1354 if (status == napi_ok) {
1355 asyncContext.release();
1356 } else {
1357 result = nullptr;
1358 }
1359 }
1360 return result;
1361 }
1362
GetCpuDeviceInfo(napi_env env,napi_value args,ContextInfo & context)1363 int32_t MSLiteModelNapi::GetCpuDeviceInfo(napi_env env, napi_value args, ContextInfo &context) {
1364 bool has_cpu_property = false;
1365 napi_status status = napi_has_named_property(env, args, "cpu", &has_cpu_property);
1366 if (status != napi_ok) {
1367 MS_LOG(ERROR) << "can not find cpu property";
1368 return ERR_INVALID_OPERATION;
1369 }
1370 if (!has_cpu_property) {
1371 return ERR_NOT_EXISTED_PARAM;
1372 }
1373
1374 napi_value config_item = nullptr;
1375 status = napi_get_named_property(env, args, "cpu", &config_item);
1376 if (status != napi_ok) {
1377 MS_LOG(ERROR) << "can not get cpu property";
1378 return ERR_INVALID_OPERATION;
1379 }
1380
1381 int32_t int_value = 0;
1382 std::string str_value = "";
1383 std::vector<int32_t> affinity_cores;
1384
1385 if (CommonNapi::GetPropertyInt32(env, config_item, "threadNum", int_value) == SUCCESS) {
1386 MS_LOG(DEBUG) << "threadNum: " << int_value;
1387 context.cpu_device.thread_num = int_value;
1388 } else {
1389 context.cpu_device.thread_num = PARAM2;
1390 }
1391
1392 if (CommonNapi::GetPropertyInt32(env, config_item, "threadAffinityMode", int_value) == SUCCESS) {
1393 MS_LOG(DEBUG) << "threadAffinityMode: " << int_value;
1394 if (int_value > PARAM2 || int_value < PARAM0) {
1395 MS_LOG(ERROR) << "threadAffinityMode value is set: " << int_value << ", is out of limition";
1396 return ERR_INVALID_OPERATION;
1397 }
1398 context.cpu_device.thread_affinity_mode = int_value;
1399 } else {
1400 context.cpu_device.thread_affinity_mode = PARAM0;
1401 }
1402
1403 if (CommonNapi::GetPropertyInt32Array(env, config_item, "threadAffinityCoreList", affinity_cores) == SUCCESS) {
1404 MS_LOG(DEBUG) << "affinityCores size: " << affinity_cores.size();
1405 context.cpu_device.thread_affinity_cores.assign(affinity_cores.begin(), affinity_cores.end());
1406 } else {
1407 context.cpu_device.thread_affinity_cores = {};
1408 }
1409
1410 if (CommonNapi::GetPropertyString(env, config_item, "precisionMode", str_value) == SUCCESS) {
1411 MS_LOG(DEBUG) << "precisionMode: " << str_value.c_str();
1412 context.cpu_device.precision_mode = str_value;
1413 } else {
1414 context.cpu_device.precision_mode = "enforce_fp32";
1415 }
1416 return SUCCESS;
1417 }
1418
GetNNRTDeviceInfo(napi_env env,napi_value args,ContextInfo & context)1419 int32_t MSLiteModelNapi::GetNNRTDeviceInfo(napi_env env, napi_value args, ContextInfo &context) {
1420 bool has_nnrt_property = false;
1421 napi_status status = napi_has_named_property(env, args, "nnrt", &has_nnrt_property);
1422 if (status != napi_ok) {
1423 MS_LOG(ERROR) << "can not find nnrt property";
1424 return ERR_ILLEGAL_STATE;
1425 }
1426 if (!has_nnrt_property) {
1427 return ERR_NOT_EXISTED_PARAM;
1428 }
1429
1430 napi_value config_item = nullptr;
1431 status = napi_get_named_property(env, args, "nnrt", &config_item);
1432 if (status != napi_ok) {
1433 MS_LOG(ERROR) << "can not get nnrt property";
1434 return ERR_INVALID_PARAM;
1435 }
1436
1437 int32_t int_value = 0;
1438 std::string str_value = "";
1439 std::vector<int32_t> affinity_cores;
1440
1441 uint64_t device_id;
1442 auto ret = CommonNapi::GetPropertyBigIntUint64(env, config_item, "deviceID", device_id);
1443 if (ret == SUCCESS) {
1444 MS_LOG(DEBUG) << "deviceID: " << device_id;
1445 context.nnrt_device.device_id = static_cast<size_t>(device_id);
1446 } else if (ret == ERR_NOT_EXISTED_PARAM) {
1447 size_t num = 0;
1448 auto *desc = OH_AI_GetAllNNRTDeviceDescs(&num);
1449 if (desc == nullptr || num == 0) {
1450 MS_LOG(WARNING) << "Failed to get nnrt device id, skip adding nnrt device info.";
1451 return ERR_NOT_EXISTED_PARAM;
1452 }
1453 auto id = OH_AI_GetDeviceIdFromNNRTDeviceDesc(desc);
1454 OH_AI_DestroyAllNNRTDeviceDescs(&desc);
1455 MS_LOG(INFO) << "set nnrt device id to " << id;
1456 context.nnrt_device.device_id = id;
1457 } else {
1458 return ERR_INVALID_PARAM;
1459 }
1460
1461 ret = CommonNapi::GetPropertyInt32(env, config_item, "performanceMode", int_value);
1462 if (ret == SUCCESS) {
1463 MS_LOG(DEBUG) << "performanceMode: " << int_value;
1464 if (int_value > PARAM4 || int_value < PARAM0) {
1465 MS_LOG(ERROR) << "performanceMode value is set to: " << int_value << ", which is out of range";
1466 return ERR_INVALID_PARAM;
1467 }
1468 context.nnrt_device.performance_mode = int_value;
1469 } else if (ret == ERR_NOT_EXISTED_PARAM) {
1470 context.nnrt_device.performance_mode = UNSET_VALUE;
1471 } else {
1472 return ERR_INVALID_PARAM;
1473 }
1474
1475 ret = CommonNapi::GetPropertyInt32(env, config_item, "priority", int_value);
1476 if (ret == SUCCESS) {
1477 MS_LOG(DEBUG) << "priority: " << int_value;
1478 if (int_value > PARAM3 || int_value < PARAM0) {
1479 MS_LOG(ERROR) << "priority value is set to: " << int_value << ", which is out of range";
1480 return ERR_INVALID_PARAM;
1481 }
1482 context.nnrt_device.priority = int_value;
1483 } else if (ret == ERR_NOT_EXISTED_PARAM) {
1484 context.nnrt_device.priority = UNSET_VALUE;
1485 } else {
1486 return ERR_INVALID_PARAM;
1487 }
1488
1489 // ignore extensions for now
1490 return SUCCESS;
1491 }
1492
GetInputs(napi_env env,napi_callback_info info)1493 napi_value MSLiteModelNapi::GetInputs(napi_env env, napi_callback_info info) {
1494 napi_value undefinedResult = nullptr;
1495 napi_get_undefined(env, &undefinedResult);
1496
1497 size_t argCount = 0;
1498 napi_value jsThis = nullptr;
1499 napi_value jsResult = nullptr;
1500 MSLiteModelNapi *modelNapi = nullptr;
1501
1502 napi_status status = napi_get_cb_info(env, info, &argCount, nullptr, &jsThis, nullptr);
1503 if (status != napi_ok || jsThis == nullptr) {
1504 MS_LOG(ERROR) << "failed to retrieve details about the callback";
1505 return undefinedResult;
1506 }
1507
1508 status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
1509 if (status != napi_ok || modelNapi == nullptr) {
1510 MS_LOG(ERROR) << "failed to get model";
1511 return undefinedResult;
1512 }
1513
1514 if (modelNapi->native_model_ == nullptr) {
1515 MS_LOG(ERROR) << "model is released(null), please create model again";
1516 return undefinedResult;
1517 }
1518 std::vector<MSTensor> inputs = modelNapi->native_model_->GetInputs();
1519 std::vector<MSTensor> tensor_inputs;
1520 for (size_t i = 0; i < inputs.size(); i++) {
1521 auto tensor = mindspore::MSTensor::CreateTensor(inputs.at(i).Name(), inputs.at(i).DataType(), {}, nullptr, 0);
1522 if (tensor == nullptr) {
1523 MS_LOG(ERROR) << "create tensor failed.";
1524 return undefinedResult;
1525 }
1526 tensor->SetShape(inputs.at(i).Shape());
1527 tensor->SetFormat(inputs.at(i).format());
1528 tensor->SetDataType(inputs.at(i).DataType());
1529 tensor_inputs.push_back(*tensor);
1530 delete tensor;
1531 }
1532
1533 size_t size = inputs.size();
1534 MS_LOG(INFO) << "inputs size: " << size;
1535 napi_create_array_with_length(env, size, &jsResult);
1536 for (size_t i = 0; i < size; i++) {
1537 status = napi_set_element(env, jsResult, i, MSTensorNapi::NewInstance(env, tensor_inputs[i]));
1538 if (status != napi_ok) {
1539 MS_LOG(ERROR) << "napi_set_element failed! code: " << status;
1540 }
1541 }
1542 MS_LOG(INFO) << "get model inputs success: " << inputs[0].Name().c_str();
1543 return jsResult;
1544 }
1545
Resize(napi_env env,napi_callback_info info)1546 napi_value MSLiteModelNapi::Resize(napi_env env, napi_callback_info info) {
1547 napi_value undefinedResult = nullptr;
1548 bool result = false;
1549 napi_status status = napi_get_boolean(env, result, &undefinedResult);
1550 if (status != napi_ok) {
1551 MS_LOG(ERROR) << "get bool error";
1552 return undefinedResult;
1553 }
1554
1555 napi_value jsThis = nullptr;
1556 napi_value jsResult = nullptr;
1557 MSLiteModelNapi *modelNapi = nullptr;
1558 napi_value argv[ARGS_TWO] = {0};
1559 size_t argCount = PARAM2;
1560 status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
1561 if (status != napi_ok || jsThis == nullptr) {
1562 MS_LOG(ERROR) << "failed to retrieve details about the callback";
1563 return undefinedResult;
1564 }
1565 status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
1566 if (status != napi_ok || modelNapi == nullptr) {
1567 MS_LOG(ERROR) << "get model napi error";
1568 return undefinedResult;
1569 }
1570
1571 if (modelNapi->native_model_ == nullptr) {
1572 MS_LOG(ERROR) << "model is released(null), please create model again";
1573 return undefinedResult;
1574 }
1575 std::vector<MSTensor> inputs = modelNapi->native_model_->GetInputs();
1576 std::vector<MSTensor> tensor_inputs;
1577 std::vector<std::vector<int64_t>> dims;
1578
1579 // set inputs data
1580 uint32_t array_length = 0;
1581 status = napi_get_array_length(env, argv[PARAM0], &array_length);
1582 if (status != napi_ok || array_length <= 0) {
1583 MS_LOG(ERROR) << "get inputs tensor length failed.";
1584 return undefinedResult;
1585 }
1586 if (inputs.size() != array_length) {
1587 MS_LOG(ERROR) << "array length not equal to model inputs size.";
1588 return undefinedResult;
1589 }
1590 for (size_t i = 0; i < array_length; i++) {
1591 napi_value element = nullptr;
1592 status = napi_get_element(env, argv[PARAM0], i, &element);
1593 if (status != napi_ok) {
1594 MS_LOG(ERROR) << "can not get element";
1595 return undefinedResult;
1596 }
1597
1598 std::string property_name = "getData";
1599 bool exist = false;
1600 napi_value data_func = nullptr;
1601
1602 status = napi_has_named_property(env, element, property_name.c_str(), &exist);
1603 if (status != napi_ok || !exist) {
1604 MS_LOG(ERROR) << "can not find target property";
1605 return undefinedResult;
1606 }
1607
1608 if (status != napi_ok || !exist) {
1609 MS_LOG(ERROR) << "can not find " << property_name.c_str() << " property.";
1610 return undefinedResult;
1611 }
1612
1613 if (napi_get_named_property(env, element, property_name.c_str(), &data_func) != napi_ok) {
1614 MS_LOG(ERROR) << "get " << property_name.c_str() << " property fail.";
1615 return undefinedResult;
1616 }
1617 void *js_data = nullptr;
1618 size_t length = 0;
1619 napi_value return_val;
1620
1621 status = napi_call_function(env, element, data_func, 0, nullptr, &return_val);
1622 if (status != napi_ok || return_val == nullptr) {
1623 MS_LOG(ERROR) << "napi call function error.";
1624 return undefinedResult;
1625 }
1626
1627 status = napi_get_arraybuffer_info(env, return_val, &js_data, &length);
1628 if (status != napi_ok || js_data == nullptr) {
1629 MS_LOG(ERROR) << "get js data error.";
1630 return undefinedResult;
1631 }
1632 if (inputs[i].DataSize() != length) {
1633 MS_LOG(ERROR) << "tensor size is: " << static_cast<int>(inputs[i].DataSize()) << ", but data length got "
1634 << static_cast<int>(length);
1635 return undefinedResult;
1636 }
1637
1638 auto tensor_data = inputs[i].MutableData();
1639 if (tensor_data == nullptr) {
1640 MS_LOG(ERROR) << "malloc data for tensor failed.";
1641 return undefinedResult;
1642 }
1643 memcpy(tensor_data, js_data, length);
1644 }
1645
1646 napi_value dim_num = nullptr;
1647 int64_t dim_ele = 0;
1648 uint32_t dims_size = 0;
1649 uint32_t dim_size = 0;
1650
1651 status = napi_is_array(env, argv[PARAM1], &result);
1652 if (status != napi_ok || result == false) {
1653 MS_LOG(ERROR) << "new dim is not a array";
1654 return undefinedResult;
1655 }
1656
1657 status = napi_get_array_length(env, argv[PARAM1], &dims_size);
1658 if (status != napi_ok) {
1659 MS_LOG(ERROR) << "get new dims size error";
1660 return undefinedResult;
1661 }
1662 for (size_t i = 0; i < dims_size; i++) {
1663 napi_value dim_element = nullptr;
1664 status = napi_get_element(env, argv[PARAM1], i, &dim_element);
1665 if (status != napi_ok) {
1666 MS_LOG(ERROR) << "can not get element";
1667 return undefinedResult;
1668 }
1669
1670 status = napi_is_array(env, dim_element, &result);
1671 if (status != napi_ok || result == false) {
1672 MS_LOG(ERROR) << "new dim's element is not a array";
1673 return undefinedResult;
1674 }
1675
1676 status = napi_get_array_length(env, dim_element, &dim_size);
1677 if (status != napi_ok) {
1678 MS_LOG(ERROR) << "get new dim size error";
1679 return undefinedResult;
1680 }
1681 std::vector<int64_t> dim(dim_size);
1682 for (size_t j = 0; j < dim_size; j++) {
1683 status = napi_get_element(env, dim_element, j, &dim_num);
1684 if (status != napi_ok) {
1685 MS_LOG(ERROR) << "get dim num error";
1686 return undefinedResult;
1687 }
1688 status = napi_get_value_int64(env, dim_num, &dim_ele);
1689 if (status != napi_ok) {
1690 MS_LOG(ERROR) << "get dim element error";
1691 return undefinedResult;
1692 }
1693 dim[j] = dim_ele;
1694 }
1695 dims.push_back(dim);
1696 }
1697 if (modelNapi->native_model_->Resize(inputs, dims) != mindspore::kSuccess) {
1698 MS_LOG(ERROR) << "resize failed";
1699 return undefinedResult;
1700 }
1701 status = napi_get_boolean(env, result, &jsResult);
1702 if (status != napi_ok) {
1703 MS_LOG(ERROR) << "get bool error";
1704 return undefinedResult;
1705 }
1706 return jsResult;
1707 }
1708
/**
 * @brief Fills a raw buffer with pseudo-random values of type T.
 *
 * @param size         Buffer size in bytes; floor(size / sizeof(T)) elements are written. A
 *                     non-positive size writes nothing.
 * @param data         Destination buffer; nothing is written when it is null.
 * @param distribution Distribution sampled once per element; each sample is cast to T.
 *
 * The engine is a default-constructed std::mt19937, so the sequence is the same on every call.
 */
template <typename T, typename Distribution>
void GenerateRandomData(int size, void *data, Distribution distribution) {
  // Guard null buffers and non-positive sizes; a negative int would otherwise be converted to a huge
  // unsigned value by the division below.
  if (size <= 0 || data == nullptr) {
    return;
  }
  std::mt19937 random_engine;
  const size_t elements_num = static_cast<size_t>(size) / sizeof(T);
  (void)std::generate_n(static_cast<T *>(data), elements_num,
                        [&distribution, &random_engine]() { return static_cast<T>(distribution(random_engine)); });
}
1716
GenerateInputDataWithRandom(std::vector<mindspore::MSTensor> inputs)1717 int GenerateInputDataWithRandom(std::vector<mindspore::MSTensor> inputs) {
1718 for (auto tensor : inputs) {
1719 auto input_data = tensor.MutableData();
1720 if (input_data == nullptr) {
1721 std::cerr << "mallocData for inTensor failed." << std::endl;
1722 return -1;
1723 }
1724 GenerateRandomData<float>(tensor.DataSize(), input_data, std::uniform_real_distribution<float>(0.1f, 1.0f));
1725 }
1726 return mindspore::kSuccess;
1727 }
1728
PredictAsync(napi_env env,napi_callback_info info)1729 napi_value MSLiteModelNapi::PredictAsync(napi_env env, napi_callback_info info) {
1730 napi_status status = napi_ok;
1731 napi_value undefinedResult = nullptr;
1732 napi_value result = nullptr;
1733 const int32_t refCount = 1;
1734 napi_valuetype valueType;
1735
1736 std::unique_ptr<MSLiteModelAsyncContext> asyncContext = std::make_unique<MSLiteModelAsyncContext>();
1737 if (asyncContext == nullptr) {
1738 MS_LOG(ERROR) << "MSLiteModelAsyncContext object create failed.";
1739 return undefinedResult;
1740 }
1741
1742 GET_PARAMS(env, info, ARGS_TWO);
1743 for (size_t i = PARAM0; i < argc; i++) {
1744 if (i == PARAM1) {
1745 status = napi_typeof(env, argv[i], &valueType);
1746 if ((status != napi_ok) || (valueType != napi_function)) {
1747 MS_LOG(ERROR) << "napi_typeof check callback failed.";
1748 return result;
1749 }
1750 status = napi_create_reference(env, argv[i], refCount, &asyncContext->callbackRef);
1751 if (status != napi_ok) {
1752 MS_LOG(ERROR) << "failed to create reference of callback";
1753 return result;
1754 }
1755 }
1756 }
1757
1758 if (SetTensorData(env, thisVar, argv[PARAM0], asyncContext.get()) != SUCCESS) {
1759 MS_LOG(ERROR) << "Set tensor data failed.";
1760 return undefinedResult;
1761 }
1762
1763 if (asyncContext->callbackRef == nullptr) {
1764 status = napi_create_promise(env, &asyncContext->deferred, &result);
1765 if (status != napi_ok) {
1766 MS_LOG(ERROR) << "create promise failed.";
1767 return result;
1768 }
1769 } else {
1770 status = napi_get_undefined(env, &result);
1771 if (status != napi_ok) {
1772 MS_LOG(ERROR) << "create callback failed.";
1773 return result;
1774 }
1775 }
1776
1777 napi_value resource = nullptr;
1778 napi_create_string_utf8(env, "Predict", NAPI_AUTO_LENGTH, &resource);
1779 status = napi_create_async_work(
1780 env, nullptr, resource,
1781 [](napi_env env, void *data) {
1782 auto context = static_cast<MSLiteModelAsyncContext *>(data);
1783 context->status = SUCCESS;
1784 },
1785 PredictAsyncCallbackComplete, static_cast<void *>(asyncContext.get()), &asyncContext->work);
1786 if (status != napi_ok) {
1787 result = nullptr;
1788 } else {
1789 status = napi_queue_async_work(env, asyncContext->work);
1790 if (status == napi_ok) {
1791 asyncContext.release();
1792 } else {
1793 result = nullptr;
1794 }
1795 }
1796 return result;
1797 }
1798
SetTensorData(napi_env env,napi_value thisVar,napi_value argv,MSLiteModelAsyncContext * async_context)1799 int32_t MSLiteModelNapi::SetTensorData(napi_env env, napi_value thisVar, napi_value argv,
1800 MSLiteModelAsyncContext *async_context) {
1801 uint32_t array_length = 0;
1802 napi_status status = napi_get_array_length(env, argv, &array_length);
1803 if (status != napi_ok || array_length <= 0) {
1804 MS_LOG(ERROR) << "get inputs tensor length failed.";
1805 return ERR_INVALID_PARAM;
1806 }
1807
1808 status = napi_unwrap(env, thisVar, reinterpret_cast<void **>(&(async_context->lite_model)));
1809 if (status != napi_ok || async_context->lite_model == nullptr) {
1810 MS_LOG(ERROR) << "get model napi error";
1811 return ERROR;
1812 }
1813 auto modelNapi = async_context->lite_model;
1814 if (modelNapi->native_model_ == nullptr) {
1815 MS_LOG(ERROR) << "model is released(null), please create model again";
1816 return ERROR;
1817 }
1818
1819 auto inputs = modelNapi->native_model_->GetInputs();
1820 if (inputs.size() != array_length) {
1821 MS_LOG(ERROR) << "array length not equal to model inputs size.";
1822 return ERR_INVALID_PARAM;
1823 }
1824
1825 for (size_t i = 0; i < array_length; i++) {
1826 napi_value element = nullptr;
1827 status = napi_get_element(env, argv, i, &element);
1828 if (status != napi_ok) {
1829 MS_LOG(ERROR) << "can not get element";
1830 return ERROR;
1831 }
1832
1833 std::string property_name = "getData";
1834 bool exist = false;
1835 napi_value data_func = nullptr;
1836
1837 napi_status status = napi_has_named_property(env, element, property_name.c_str(), &exist);
1838
1839 if (status != napi_ok || !exist) {
1840 MS_LOG(ERROR) << "can not find " << property_name.c_str() << " property.";
1841 return ERROR;
1842 }
1843
1844 if (napi_get_named_property(env, element, property_name.c_str(), &data_func) != napi_ok) {
1845 MS_LOG(ERROR) << "get " << property_name.c_str() << " property fail.";
1846 return ERROR;
1847 }
1848 void *js_data = nullptr;
1849 size_t length = 0;
1850 napi_value return_val;
1851
1852 status = napi_call_function(env, element, data_func, 0, nullptr, &return_val);
1853 if (status != napi_ok || return_val == nullptr) {
1854 MS_LOG(ERROR) << "napi call function error.";
1855 return ERROR;
1856 }
1857 status = napi_get_arraybuffer_info(env, return_val, &js_data, &length);
1858 if (status != napi_ok || js_data == nullptr) {
1859 MS_LOG(ERROR) << "Get js data error.";
1860 return ERROR;
1861 }
1862 if (inputs[i].DataSize() != length) {
1863 MS_LOG(ERROR) << "tensor size is: " << static_cast<int>(inputs[i].DataSize()) << ", but data length got "
1864 << static_cast<int>(length);
1865 return ERROR;
1866 }
1867
1868 auto tensor_data = inputs[i].MutableData();
1869 if (tensor_data == nullptr) {
1870 MS_LOG(ERROR) << "malloc data for tensor failed.";
1871 return ERROR;
1872 }
1873 memcpy(tensor_data, js_data, length);
1874 }
1875 return SUCCESS;
1876 }
1877
PredictAsyncCallbackComplete(napi_env env,napi_status status,void * data)1878 void MSLiteModelNapi::PredictAsyncCallbackComplete(napi_env env, napi_status status, void *data) {
1879 napi_value valueParam = nullptr;
1880 auto asyncContext = static_cast<MSLiteModelAsyncContext *>(data);
1881
1882 if (asyncContext != nullptr) {
1883 if (!asyncContext->status) {
1884 auto modelNapi = asyncContext->lite_model;
1885 if (modelNapi->native_model_ == nullptr) {
1886 MS_LOG(ERROR) << "model is released(null), please create model again";
1887 return;
1888 }
1889 auto inputs = modelNapi->native_model_->GetInputs();
1890 std::vector<MSTensor> outputs;
1891
1892 auto predict_ret = modelNapi->native_model_->Predict(inputs, &outputs);
1893 if (predict_ret != mindspore::kSuccess) {
1894 MS_LOG(ERROR) << "model predict failed.";
1895 return;
1896 }
1897
1898 napi_create_array_with_length(env, outputs.size(), &valueParam);
1899 for (size_t i = 0; i < outputs.size(); i++) {
1900 status = napi_set_element(env, valueParam, i, MSTensorNapi::NewInstance(env, outputs[i]));
1901 if (status != napi_ok) {
1902 MS_LOG(ERROR) << "napi_set_element failed! code: " << status;
1903 }
1904 }
1905 MS_LOG(INFO) << "predict model success.";
1906 }
1907 CommonCallbackRoutine(env, asyncContext, valueParam);
1908 } else {
1909 MS_LOG(ERROR) << "ERROR: PredictAsyncCallbackComplete asyncContext is Null!";
1910 }
1911 }
1912
GetWeights(napi_env env,napi_callback_info info)1913 napi_value MSLiteModelNapi::GetWeights(napi_env env, napi_callback_info info) {
1914 napi_value undefinedResult = nullptr;
1915 napi_get_undefined(env, &undefinedResult);
1916
1917 size_t argCount = 0;
1918 napi_value jsThis = nullptr;
1919 napi_value jsResult = nullptr;
1920 MSLiteModelNapi *modelNapi = nullptr;
1921
1922 napi_status status = napi_get_cb_info(env, info, &argCount, nullptr, &jsThis, nullptr);
1923 if (status != napi_ok || jsThis == nullptr) {
1924 MS_LOG(ERROR) << "failed to retrieve details about the callback";
1925 return undefinedResult;
1926 }
1927
1928 status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
1929 if (status != napi_ok || modelNapi == nullptr) {
1930 MS_LOG(ERROR) << "failed to get model";
1931 return undefinedResult;
1932 }
1933
1934 if (modelNapi->native_model_ == nullptr) {
1935 MS_LOG(ERROR) << "model is released(null), please create model again";
1936 return undefinedResult;
1937 }
1938 std::vector<MSTensor> weights = modelNapi->native_model_->GetFeatureMaps();
1939 std::vector<MSTensor> feature_maps;
1940 for (size_t i = 0; i < weights.size(); i++) {
1941 auto tensor = mindspore::MSTensor::CreateTensor(weights.at(i).Name(), weights.at(i).DataType(), {}, nullptr, 0);
1942 if (tensor == nullptr) {
1943 MS_LOG(ERROR) << "create tensor failed.";
1944 return undefinedResult;
1945 }
1946 tensor->SetShape(weights.at(i).Shape());
1947 tensor->SetFormat(weights.at(i).format());
1948 tensor->SetDataType(weights.at(i).DataType());
1949 tensor->SetData(weights.at(i).MutableData(), false);
1950 feature_maps.push_back(*tensor);
1951 delete tensor;
1952 }
1953
1954 size_t size = weights.size();
1955 MS_LOG(INFO) << "weights size: " << size;
1956 napi_create_array_with_length(env, size, &jsResult);
1957 for (size_t i = 0; i < size; i++) {
1958 status = napi_set_element(env, jsResult, i, MSTensorNapi::NewInstance(env, feature_maps[i]));
1959 if (status != napi_ok) {
1960 MS_LOG(ERROR) << "napi_set_element failed! code: " << status;
1961 }
1962 }
1963 MS_LOG(INFO) << "get model weights success";
1964 return jsResult;
1965 }
1966
SetModelInputs(napi_env env,napi_value argv,std::shared_ptr<Model> model)1967 int32_t SetModelInputs(napi_env env, napi_value argv, std::shared_ptr<Model> model) {
1968 uint32_t array_length = 0;
1969 napi_status status = napi_get_array_length(env, argv, &array_length);
1970 if (status != napi_ok || array_length <= 0) {
1971 MS_LOG(ERROR) << "get inputs tensor length failed.";
1972 return ERR_INVALID_PARAM;
1973 }
1974
1975 if (model == nullptr) {
1976 MS_LOG(ERROR) << "model is nullptr";
1977 return ERR_INVALID_PARAM;
1978 }
1979
1980 auto inputs = model->GetInputs();
1981 if (inputs.size() != array_length) {
1982 MS_LOG(ERROR) << "array length not equal to model inputs size.";
1983 return ERR_INVALID_PARAM;
1984 }
1985
1986 for (size_t i = 0; i < array_length; i++) {
1987 napi_value element = nullptr;
1988 status = napi_get_element(env, argv, i, &element);
1989 if (status != napi_ok) {
1990 MS_LOG(ERROR) << "can not get element";
1991 return ERROR;
1992 }
1993
1994 std::string property_name = "getData";
1995 bool exist = false;
1996 napi_value data_func = nullptr;
1997
1998 napi_status status = napi_has_named_property(env, element, property_name.c_str(), &exist);
1999
2000 if (status != napi_ok || !exist) {
2001 MS_LOG(ERROR) << "can not find " << property_name.c_str() << " property.";
2002 return ERROR;
2003 }
2004
2005 if (napi_get_named_property(env, element, property_name.c_str(), &data_func) != napi_ok) {
2006 MS_LOG(ERROR) << "get " << property_name.c_str() << " property fail.";
2007 return ERROR;
2008 }
2009 void *js_data = nullptr;
2010 size_t length = 0;
2011 napi_value return_val;
2012
2013 status = napi_call_function(env, element, data_func, 0, nullptr, &return_val);
2014 if (status != napi_ok || return_val == nullptr) {
2015 MS_LOG(ERROR) << "napi call function error.";
2016 return ERROR;
2017 }
2018 status = napi_get_arraybuffer_info(env, return_val, &js_data, &length);
2019 if (status != napi_ok || js_data == nullptr) {
2020 MS_LOG(ERROR) << "Get js data error.";
2021 return ERROR;
2022 }
2023 if (inputs[i].DataSize() != length) {
2024 MS_LOG(ERROR) << "tensor size is: " << static_cast<int>(inputs[i].DataSize()) << ", but data length got "
2025 << static_cast<int>(length);
2026 return ERROR;
2027 }
2028
2029 auto tensor_data = inputs[i].MutableData();
2030 if (tensor_data == nullptr) {
2031 MS_LOG(ERROR) << "malloc data for tensor failed.";
2032 return ERROR;
2033 }
2034 memcpy(tensor_data, js_data, length);
2035 }
2036 return SUCCESS;
2037 }
2038
RunStep(napi_env env,napi_callback_info info)2039 napi_value MSLiteModelNapi::RunStep(napi_env env, napi_callback_info info) {
2040 napi_value undefinedResult = nullptr;
2041 bool result = false;
2042 napi_status status = napi_get_boolean(env, result, &undefinedResult);
2043 if (status != napi_ok) {
2044 MS_LOG(ERROR) << "get bool error";
2045 return undefinedResult;
2046 }
2047
2048 napi_value jsThis = nullptr;
2049 MSLiteModelNapi *modelNapi = nullptr;
2050 size_t argCount = PARAM1;
2051 napi_value argv[ARGS_ONE] = {0};
2052
2053 status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
2054 if (status != napi_ok || jsThis == nullptr) {
2055 MS_LOG(ERROR) << "failed to retrieve details about the callback";
2056 return undefinedResult;
2057 }
2058
2059 if (argCount < ARGS_ONE) {
2060 MS_LOG(ERROR) << "argument num is less than one, please give input tensors";
2061 return undefinedResult;
2062 }
2063
2064 status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
2065 if (status != napi_ok || modelNapi == nullptr) {
2066 MS_LOG(ERROR) << "get model napi error";
2067 return undefinedResult;
2068 }
2069
2070 if (SetModelInputs(env, argv[PARAM0], modelNapi->native_model_) != SUCCESS) {
2071 MS_LOG(ERROR) << "set tensor data failed";
2072 return undefinedResult;
2073 }
2074
2075 if (modelNapi->native_model_ == nullptr) {
2076 MS_LOG(ERROR) << "model is released(null), please create model again";
2077 return undefinedResult;
2078 }
2079
2080 auto ret = modelNapi->native_model_->RunStep();
2081 if (ret != kSuccess) {
2082 MS_LOG(ERROR) << "Model run step failed";
2083 return undefinedResult;
2084 }
2085 status = napi_get_boolean(env, true, &undefinedResult);
2086 if (status != napi_ok) {
2087 MS_LOG(ERROR) << "create bool true value failed";
2088 return undefinedResult;
2089 }
2090 return undefinedResult;
2091 }
2092
UpdateWeights(napi_env env,napi_callback_info info)2093 napi_value MSLiteModelNapi::UpdateWeights(napi_env env, napi_callback_info info) {
2094 napi_value undefinedResult = nullptr;
2095 bool result = false;
2096 napi_status status = napi_get_boolean(env, result, &undefinedResult);
2097 if (status != napi_ok) {
2098 MS_LOG(ERROR) << "get bool error";
2099 return undefinedResult;
2100 }
2101
2102 napi_value jsThis = nullptr;
2103 napi_value jsResult = nullptr;
2104 MSLiteModelNapi *modelNapi = nullptr;
2105 napi_value argv[ARGS_ONE] = {0};
2106 size_t argCount = PARAM1;
2107 status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
2108 if (status != napi_ok || jsThis == nullptr) {
2109 MS_LOG(ERROR) << "failed to retrieve details about the callback";
2110 return undefinedResult;
2111 }
2112 status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
2113 if (status != napi_ok || modelNapi == nullptr) {
2114 MS_LOG(ERROR) << "get model napi error";
2115 return undefinedResult;
2116 }
2117
2118 if (modelNapi->native_model_ == nullptr) {
2119 MS_LOG(ERROR) << "model is released(null), please create model again";
2120 return undefinedResult;
2121 }
2122
2123 // set inputs data
2124 uint32_t array_length = 0;
2125 status = napi_get_array_length(env, argv[PARAM0], &array_length);
2126 if (status != napi_ok || array_length <= 0) {
2127 MS_LOG(ERROR) << "get inputs tensor length failed.";
2128 return undefinedResult;
2129 }
2130
2131 std::vector<MSTensor> weights;
2132 for (size_t i = 0; i < array_length; i++) {
2133 napi_value element = nullptr;
2134 status = napi_get_element(env, argv[PARAM0], i, &element);
2135 if (status != napi_ok) {
2136 MS_LOG(ERROR) << "can not get element";
2137 return undefinedResult;
2138 }
2139
2140 // get tensor name
2141 std::string tensor_name;
2142 auto ret = CommonNapi::GetPropertyString(env, element, "name", tensor_name);
2143 if (ret != SUCCESS) {
2144 MS_LOG(ERROR) << "get tensor name property failed";
2145 return undefinedResult;
2146 }
2147
2148 // get tensor format
2149 int format;
2150 ret = CommonNapi::GetPropertyInt32(env, element, "format", format);
2151 if (ret != SUCCESS) {
2152 MS_LOG(ERROR) << "get format property failed";
2153 return undefinedResult;
2154 }
2155
2156 // get dtype
2157 int dtype;
2158 ret = CommonNapi::GetPropertyInt32(env, element, "dtype", dtype);
2159 if (ret != SUCCESS) {
2160 MS_LOG(ERROR) << "get format property failed";
2161 return undefinedResult;
2162 }
2163
2164 // get data size
2165 int data_size;
2166 ret = CommonNapi::GetPropertyInt32(env, element, "dataSize", data_size);
2167 if (ret != SUCCESS) {
2168 MS_LOG(ERROR) << "get dataSize property failed";
2169 return undefinedResult;
2170 }
2171
2172 // get shape
2173 std::vector<int32_t> shape;
2174 ret = CommonNapi::GetPropertyInt32Array(env, element, "shape", shape);
2175 if (ret != SUCCESS) {
2176 MS_LOG(ERROR) << "get shape property failed";
2177 return undefinedResult;
2178 }
2179
2180 // get data
2181 std::string property_name = "getData";
2182 bool exist = false;
2183 napi_value data_func = nullptr;
2184
2185 status = napi_has_named_property(env, element, property_name.c_str(), &exist);
2186 if (status != napi_ok || !exist) {
2187 MS_LOG(ERROR) << "can not find target property";
2188 return undefinedResult;
2189 }
2190
2191 if (napi_get_named_property(env, element, property_name.c_str(), &data_func) != napi_ok) {
2192 MS_LOG(ERROR) << "get " << property_name.c_str() << " property fail.";
2193 return undefinedResult;
2194 }
2195 void *js_data = nullptr;
2196 size_t length = 0;
2197
2198 napi_value return_val;
2199 status = napi_call_function(env, element, data_func, 0, nullptr, &return_val);
2200 if (status != napi_ok || return_val == nullptr) {
2201 MS_LOG(ERROR) << "napi call function error.";
2202 return undefinedResult;
2203 }
2204
2205 status = napi_get_arraybuffer_info(env, return_val, &js_data, &length);
2206 if (status != napi_ok || js_data == nullptr) {
2207 MS_LOG(ERROR) << "get js data error.";
2208 return undefinedResult;
2209 }
2210
2211 std::vector<int64_t> int64_shape;
2212 int64_shape.reserve(shape.size());
2213 std::transform(shape.begin(), shape.end(), std::back_inserter(int64_shape), [](int32_t value) {
2214 return static_cast<int64_t>(value);
2215 });
2216 auto tensor = mindspore::MSTensor::CreateTensor(tensor_name, static_cast<mindspore::DataType>(dtype), int64_shape, nullptr, 0);
2217 if (tensor == nullptr) {
2218 MS_LOG(ERROR) << "create tensor failed.";
2219 return undefinedResult;
2220 }
2221 tensor->SetFormat(static_cast<mindspore::Format>(format));
2222 auto tensor_data = tensor->MutableData();
2223 if (tensor_data == nullptr) {
2224 MS_LOG(ERROR) << "mutable tensor data failed, get nullptr";
2225 return undefinedResult;
2226 }
2227
2228 if (tensor->DataSize() != length) {
2229 MS_LOG(ERROR) << "tensor size is: " << static_cast<int>(tensor->DataSize()) << ", but data length got "
2230 << static_cast<int>(length);
2231 return undefinedResult;
2232 }
2233
2234 memcpy(tensor_data, js_data, length);
2235 weights.push_back(*tensor);
2236 delete tensor;
2237 }
2238
2239 if (modelNapi->native_model_->UpdateFeatureMaps(weights) != mindspore::kSuccess) {
2240 MS_LOG(ERROR) << "UpdateFeatureMaps failed";
2241 return undefinedResult;
2242 }
2243 status = napi_get_boolean(env, true, &jsResult);
2244 if (status != napi_ok) {
2245 MS_LOG(ERROR) << "get bool error";
2246 return undefinedResult;
2247 }
2248 return jsResult;
2249 }
2250
ExportModel(napi_env env,napi_callback_info info)2251 napi_value MSLiteModelNapi::ExportModel(napi_env env, napi_callback_info info) {
2252 napi_value undefinedResult = nullptr;
2253 bool result = false;
2254 napi_status status = napi_get_boolean(env, result, &undefinedResult);
2255 if (status != napi_ok) {
2256 MS_LOG(ERROR) << "get bool error";
2257 return undefinedResult;
2258 }
2259
2260 napi_value jsThis = nullptr;
2261 napi_value jsResult = nullptr;
2262 MSLiteModelNapi *modelNapi = nullptr;
2263 napi_value argv[ARGS_FOUR] = {0};
2264 size_t argCount = PARAM4;
2265 status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
2266 if (status != napi_ok || jsThis == nullptr) {
2267 MS_LOG(ERROR) << "failed to retrieve details about the callback";
2268 return undefinedResult;
2269 }
2270 status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
2271 if (status != napi_ok || modelNapi == nullptr) {
2272 MS_LOG(ERROR) << "get model napi error";
2273 return undefinedResult;
2274 }
2275
2276 if (modelNapi->native_model_ == nullptr) {
2277 MS_LOG(ERROR) << "model is released(null), please create model again";
2278 return undefinedResult;
2279 }
2280
2281 // get modelfile
2282 char char_buf[SIZE];
2283 size_t buf_length = 0;
2284 status = napi_get_value_string_utf8(env, argv[PARAM0], char_buf, SIZE, &buf_length);
2285 if ((status != napi_ok) || (buf_length <= 0)) {
2286 MS_LOG(ERROR) << "Parse model file failed.";
2287 return undefinedResult;
2288 }
2289
2290 std::string model_path;
2291 model_path.assign(char_buf, char_buf + buf_length);
2292 MS_LOG(DEBUG) << "model_path: " << model_path.c_str();
2293
2294 mindspore::QuantizationType quantization_type = kNoQuant;
2295 int32_t quantization_type_value;
2296 // get quantization
2297 if (argCount >= ARGS_TWO) {
2298 if (napi_get_value_int32(env, argv[PARAM1], &quantization_type_value) != napi_ok) {
2299 MS_LOG(WARNING) << "fail to get int32_t value from quantizationType";
2300 return undefinedResult;
2301 }
2302 quantization_type = static_cast<mindspore::QuantizationType>(quantization_type_value);
2303 }
2304
2305 // get inference mode
2306 bool export_inference_only = true;
2307 if (argCount >= ARGS_THREE) {
2308 if (napi_get_value_bool(env, argv[PARAM2], &export_inference_only) != napi_ok) {
2309 MS_LOG(WARNING) << "fail to get bool value from exportInferenceOnly";
2310 return undefinedResult;
2311 }
2312 }
2313
2314 // get output names
2315 std::vector<std::string> output_tensor_name;
2316 if (argCount >= ARGS_FOUR) {
2317 auto ret = CommonNapi::GetStringArray(env, argv[PARAM3], output_tensor_name);
2318 if (ret != SUCCESS) {
2319 MS_LOG(ERROR) << "Get context target failed.";
2320 return undefinedResult;
2321 }
2322 }
2323
2324 auto ret = mindspore::Serialization::ExportModel(*(modelNapi->native_model_.get()), static_cast<mindspore::ModelType>(kMindIR),
2325 model_path, static_cast<mindspore::QuantizationType>(quantization_type),
2326 export_inference_only, output_tensor_name);
2327 if (ret != mindspore::kSuccess) {
2328 MS_LOG(ERROR) << "Export model failed";
2329 return undefinedResult;
2330 }
2331
2332 status = napi_get_boolean(env, true, &jsResult);
2333 if (status != napi_ok) {
2334 MS_LOG(ERROR) << "get bool error";
2335 return undefinedResult;
2336 }
2337 MS_LOG(DEBUG) << "Export Model Success";
2338 return jsResult;
2339 }
2340
ExportWeightsCollaborateWithMicro(napi_env env,napi_callback_info info)2341 napi_value MSLiteModelNapi::ExportWeightsCollaborateWithMicro(napi_env env, napi_callback_info info) {
2342 napi_value undefinedResult = nullptr;
2343 bool result = false;
2344 napi_status status = napi_get_boolean(env, result, &undefinedResult);
2345 if (status != napi_ok) {
2346 MS_LOG(ERROR) << "get bool error";
2347 return undefinedResult;
2348 }
2349
2350 napi_value jsThis = nullptr;
2351 napi_value jsResult = nullptr;
2352 MSLiteModelNapi *modelNapi = nullptr;
2353 napi_value argv[ARGS_FOUR] = {0};
2354 size_t argCount = PARAM4;
2355 status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
2356 if (status != napi_ok || jsThis == nullptr) {
2357 MS_LOG(ERROR) << "failed to retrieve details about the callback";
2358 return undefinedResult;
2359 }
2360 status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
2361 if (status != napi_ok || modelNapi == nullptr) {
2362 MS_LOG(ERROR) << "get model napi error";
2363 return undefinedResult;
2364 }
2365
2366 if (modelNapi->native_model_ == nullptr) {
2367 MS_LOG(ERROR) << "model is released(null), please create model again";
2368 return undefinedResult;
2369 }
2370
2371 // get weight file
2372 char char_buf[SIZE];
2373 size_t buf_length = 0;
2374 status = napi_get_value_string_utf8(env, argv[PARAM0], char_buf, SIZE, &buf_length);
2375 if ((status != napi_ok) || (buf_length <= 0)) {
2376 MS_LOG(ERROR) << "Parse model file failed.";
2377 return undefinedResult;
2378 }
2379
2380 std::string weight_file;
2381 weight_file.assign(char_buf, char_buf + buf_length);
2382 MS_LOG(DEBUG) << "weight_file: " << weight_file.c_str();
2383
2384 // get is inference
2385 bool is_inference = true;
2386 if (argCount >= ARGS_TWO) {
2387 if (napi_get_value_bool(env, argv[PARAM1], &is_inference) != napi_ok) {
2388 MS_LOG(WARNING) << "fail to get bool value from isInference";
2389 return undefinedResult;
2390 }
2391 }
2392
2393 // get inference mode
2394 bool enable_fp16 = false;
2395 if (argCount >= ARGS_THREE) {
2396 if (napi_get_value_bool(env, argv[PARAM2], &enable_fp16) != napi_ok) {
2397 MS_LOG(WARNING) << "fail to get bool value from enableFp16";
2398 return undefinedResult;
2399 }
2400 }
2401
2402 // get output names
2403 std::vector<std::string> changeable_weights_name;
2404 if (argCount >= ARGS_FOUR) {
2405 auto ret = CommonNapi::GetStringArray(env, argv[PARAM3], changeable_weights_name);
2406 if (ret != SUCCESS) {
2407 MS_LOG(ERROR) << "failed to get string array from changeableWeightsName";
2408 return undefinedResult;
2409 }
2410 }
2411
2412 auto ret = mindspore::Serialization::ExportWeightsCollaborateWithMicro(*(modelNapi->native_model_.get()), static_cast<mindspore::ModelType>(kMindIR),
2413 weight_file, is_inference, enable_fp16, changeable_weights_name);
2414
2415 if (ret != mindspore::kSuccess) {
2416 MS_LOG(ERROR) << "ExportWeightsCollaborateWithMicro failed";
2417 return undefinedResult;
2418 }
2419
2420 status = napi_get_boolean(env, true, &jsResult);
2421 if (status != napi_ok) {
2422 MS_LOG(ERROR) << "get bool error";
2423 return undefinedResult;
2424 }
2425 MS_LOG(DEBUG) << "ExportWeightsCollaborateWithMicro Success";
2426 return jsResult;
2427 }
2428
SetupVirtualBatch(napi_env env,napi_callback_info info)2429 napi_value MSLiteModelNapi::SetupVirtualBatch(napi_env env, napi_callback_info info) {
2430 napi_value undefinedResult = nullptr;
2431 bool result = false;
2432 napi_status status = napi_get_boolean(env, result, &undefinedResult);
2433 if (status != napi_ok) {
2434 MS_LOG(ERROR) << "get bool error";
2435 return undefinedResult;
2436 }
2437
2438 napi_value jsThis = nullptr;
2439 napi_value jsResult = nullptr;
2440 MSLiteModelNapi *modelNapi = nullptr;
2441 napi_value argv[ARGS_THREE] = {0};
2442 size_t argCount = ARGS_THREE;
2443 status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
2444 if (status != napi_ok || jsThis == nullptr) {
2445 MS_LOG(ERROR) << "failed to retrieve details about the callback";
2446 return undefinedResult;
2447 }
2448 status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
2449 if (status != napi_ok || modelNapi == nullptr) {
2450 MS_LOG(ERROR) << "get model napi error";
2451 return undefinedResult;
2452 }
2453
2454 if (modelNapi->native_model_ == nullptr) {
2455 MS_LOG(ERROR) << "model is released(null), please create model again";
2456 return undefinedResult;
2457 }
2458
2459 // get virtual batch
2460 int virtual_batch_multiplier;
2461 if (napi_get_value_int32(env, argv[PARAM0], &virtual_batch_multiplier) != napi_ok) {
2462 MS_LOG(WARNING) << "fail to get int32 value from virtualBatchMultiplier";
2463 return undefinedResult;
2464 }
2465
2466 // get lr
2467 double lr = -1.0f;
2468 if (argCount >= ARGS_TWO) {
2469 if (napi_get_value_double(env, argv[PARAM1], &lr) != napi_ok) {
2470 MS_LOG(WARNING) << "fail to get double value from lr";
2471 return undefinedResult;
2472 }
2473 }
2474
2475 // get lr
2476 double momentum = -1.0f;
2477 if (argCount >= ARGS_THREE) {
2478 if (napi_get_value_double(env, argv[PARAM2], &momentum) != napi_ok) {
2479 MS_LOG(WARNING) << "fail to get double value from momentum";
2480 return undefinedResult;
2481 }
2482 }
2483
2484
2485 auto ret = modelNapi->native_model_->SetupVirtualBatch(virtual_batch_multiplier, static_cast<float>(lr), static_cast<float>(momentum));
2486
2487 if (ret != mindspore::kSuccess) {
2488 MS_LOG(ERROR) << "SetupVirtualBatch failed";
2489 return undefinedResult;
2490 }
2491
2492 status = napi_get_boolean(env, true, &jsResult);
2493 if (status != napi_ok) {
2494 MS_LOG(ERROR) << "get bool error";
2495 return undefinedResult;
2496 }
2497 return jsResult;
2498 }
GetTrainMode(napi_env env,napi_callback_info info)2499 napi_value MSLiteModelNapi::GetTrainMode(napi_env env, napi_callback_info info) {
2500 napi_value undefinedResult = nullptr;
2501
2502 napi_value jsThis = nullptr;
2503 napi_value jsResult = nullptr;
2504 MSLiteModelNapi *modelNapi = nullptr;
2505 napi_value argv[ARGS_ONE] = {0};
2506 size_t argCount = ARGS_ONE;
2507 auto status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
2508 if (status != napi_ok || jsThis == nullptr) {
2509 MS_LOG(ERROR) << "failed to retrieve details about the callback";
2510 return undefinedResult;
2511 }
2512 status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
2513 if (status != napi_ok || modelNapi == nullptr) {
2514 MS_LOG(ERROR) << "get model napi error";
2515 return undefinedResult;
2516 }
2517 if (modelNapi->native_model_ == nullptr) {
2518 MS_LOG(ERROR) << "model is released(null), please create model again";
2519 return undefinedResult;
2520 }
2521
2522 auto train_mode = modelNapi->native_model_->GetTrainMode();
2523
2524 status = napi_get_boolean(env, train_mode, &jsResult);
2525 if (status != napi_ok) {
2526 MS_LOG(WARNING) << "create bool value error";
2527 return undefinedResult;
2528 }
2529 return jsResult;
2530 }
SetTrainMode(napi_env env,napi_callback_info info)2531 napi_value MSLiteModelNapi::SetTrainMode(napi_env env, napi_callback_info info) {
2532 napi_value undefinedResult = nullptr;
2533
2534 napi_value jsThis = nullptr;
2535 napi_value jsResult = nullptr;
2536 MSLiteModelNapi *modelNapi = nullptr;
2537 napi_value argv[ARGS_ONE] = {0};
2538 size_t argCount = ARGS_ONE;
2539 auto status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
2540 if (status != napi_ok || jsThis == nullptr) {
2541 MS_LOG(ERROR) << "failed to retrieve details about the callback";
2542 return undefinedResult;
2543 }
2544 status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
2545 if (status != napi_ok || modelNapi == nullptr) {
2546 MS_LOG(ERROR) << "get model napi error";
2547 return undefinedResult;
2548 }
2549 if (modelNapi->native_model_ == nullptr) {
2550 MS_LOG(ERROR) << "model is released(null), please create model again";
2551 return undefinedResult;
2552 }
2553
2554 bool train_mode;
2555 if (napi_get_value_bool(env, argv[PARAM0], &train_mode) != napi_ok) {
2556 MS_LOG(WARNING) << "failed to get bool value from input train mode.";
2557 return undefinedResult;
2558 }
2559 if (!model_info_->train_model) {
2560 MS_LOG(WARNING) << "current model is not train model, unable to set train or eval mode";
2561 return undefinedResult;
2562 }
2563 if (modelNapi->native_model_->SetTrainMode(train_mode) != kSuccess) {
2564 MS_LOG(ERROR) << "set train mode failed";
2565 return undefinedResult;
2566 }
2567
2568 status = napi_get_boolean(env, true, &jsResult);
2569 if (status != napi_ok) {
2570 MS_LOG(WARNING) << "create bool value error";
2571 return undefinedResult;
2572 }
2573 return jsResult;
2574 }
GetLearningRate(napi_env env,napi_callback_info info)2575 napi_value MSLiteModelNapi::GetLearningRate(napi_env env, napi_callback_info info) {
2576 napi_value undefinedResult = nullptr;
2577
2578 napi_value jsThis = nullptr;
2579 napi_value jsResult = nullptr;
2580 MSLiteModelNapi *modelNapi = nullptr;
2581 napi_value argv[ARGS_ONE] = {0};
2582 size_t argCount = ARGS_ONE;
2583 auto status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
2584 if (status != napi_ok || jsThis == nullptr) {
2585 MS_LOG(ERROR) << "failed to retrieve details about the callback";
2586 return undefinedResult;
2587 }
2588 status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
2589 if (status != napi_ok || modelNapi == nullptr) {
2590 MS_LOG(ERROR) << "get model napi error";
2591 return undefinedResult;
2592 }
2593 if (modelNapi->native_model_ == nullptr) {
2594 MS_LOG(ERROR) << "model is released(null), please create model again";
2595 return undefinedResult;
2596 }
2597
2598 auto lr = modelNapi->native_model_->GetLearningRate();
2599
2600 status = napi_create_double(env, lr, &jsResult);
2601 if (status != napi_ok) {
2602 MS_LOG(WARNING) << "create double value error";
2603 return undefinedResult;
2604 }
2605 return jsResult;
2606 }
SetLearningRate(napi_env env,napi_callback_info info)2607 napi_value MSLiteModelNapi::SetLearningRate(napi_env env, napi_callback_info info) {
2608 napi_value undefinedResult = nullptr;
2609
2610 napi_value jsThis = nullptr;
2611 napi_value jsResult = nullptr;
2612 MSLiteModelNapi *modelNapi = nullptr;
2613 napi_value argv[ARGS_ONE] = {0};
2614 size_t argCount = ARGS_ONE;
2615 auto status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
2616 if (status != napi_ok || jsThis == nullptr) {
2617 MS_LOG(ERROR) << "failed to retrieve details about the callback";
2618 return undefinedResult;
2619 }
2620 status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
2621 if (status != napi_ok || modelNapi == nullptr) {
2622 MS_LOG(ERROR) << "get model napi error";
2623 return undefinedResult;
2624 }
2625 if (modelNapi->native_model_ == nullptr) {
2626 MS_LOG(ERROR) << "model is released(null), please create model again";
2627 return undefinedResult;
2628 }
2629
2630 if (!model_info_->train_model) {
2631 MS_LOG(WARNING) << "current model is not train model, unable to set learning rate";
2632 return undefinedResult;
2633 }
2634
2635 double lr;
2636 if (napi_get_value_double(env, argv[PARAM0], &lr) != napi_ok) {
2637 MS_LOG(WARNING) << "failed to get double value.";
2638 return undefinedResult;
2639 }
2640
2641 if (modelNapi->native_model_->SetLearningRate(static_cast<float>(lr)) != kSuccess) {
2642 MS_LOG(ERROR) << "set learning rate failed";
2643 return undefinedResult;
2644 }
2645
2646 status = napi_get_boolean(env, true, &jsResult);
2647 if (status != napi_ok) {
2648 MS_LOG(WARNING) << "create bool value error";
2649 return undefinedResult;
2650 }
2651 return jsResult;
2652 }
2653 } // namespace mindspore
2654