• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "include/js_api/mslite_model_napi.h"
17 #include <climits>
18 #include <algorithm>
19 #include <random>
20 #include <cstring>
21 #include <memory>
22 #include <map>
23 #include <vector>
24 #include <unistd.h>
25 #include <fcntl.h>
26 #include <sys/mman.h>
27 #include <sys/stat.h>
28 #include "include/js_api/mstensor_napi.h"
29 #include "include/js_api/common_napi.h"
30 #include "include/js_api/ms_parameters_napi.h"
31 #include "include/js_api/ms_errors.h"
32 #include "include/js_api/mslite_model_callback_napi.h"
33 #include "src/common/log.h"
34 #include "mindspore/lite/src/common/log.h"
35 #include "include/c_api/model_c.h"
36 #include "include/c_api/context_c.h"
37 #include "include/c_api/types_c.h"
38 #include "include/js_api/nnrt_device_desc_napi.h"
39 
40 namespace mindspore {
// Static member definitions.
// constructor_ caches the JS class constructor per thread. model_info_ and context_ are
// filled by the static load entry points and consumed by Constructor() (guarded elsewhere
// by create_mutex_ — NOTE(review): assumed from the mutex name; confirm against callers).
thread_local napi_ref MSLiteModelNapi::constructor_ = nullptr;
ModelInfo *MSLiteModelNapi::model_info_ = nullptr;
ContextInfo *MSLiteModelNapi::context_ = nullptr;
std::mutex MSLiteModelNapi::create_mutex_;
// Persistent references to the enum-style JS objects exported on the Model class
// (created by the Create*Object helpers below).
napi_ref MSLiteModelNapi::tensorFormat_ = nullptr;
napi_ref MSLiteModelNapi::tensorDataType_ = nullptr;
napi_ref MSLiteModelNapi::contextThreadAffinityMode_ = nullptr;
napi_ref MSLiteModelNapi::contextQuantizationType_ = nullptr;
napi_ref MSLiteModelNapi::contextOptimizationLevel_ = nullptr;
napi_ref MSLiteModelNapi::contextPerformanceMode_ = nullptr;
napi_ref MSLiteModelNapi::contextPriority_ = nullptr;
napi_ref MSLiteModelNapi::contextNnrtDeviceType_ = nullptr;

// Unpacks the N-API callback info into locals `argc`, `argv`, `thisVar` and `data`
// in the calling scope.
#define GET_PARAMS(env, info, num) \
  size_t argc = num;               \
  napi_value argv[num] = {0};      \
  napi_value thisVar = nullptr;    \
  void *data;                      \
  napi_get_cb_info(env, info, &argc, argv, &thisVar, &data)

namespace {
// Expected argument counts for the N-API entry points.
const int ARGS_ONE = 1;
const int ARGS_TWO = 2;
const int ARGS_THREE = 3;
const int ARGS_FOUR = 4;

// Argument position indices.
const int PARAM0 = 0;
const int PARAM1 = 1;
const int PARAM2 = 2;
const int PARAM3 = 3;
const int PARAM4 = 4;
// Marks an optional context field the user did not set (see GetDeviceInfoContext).
const int UNSET_VALUE = -1;

// Maximum byte length accepted when reading the model path string argument.
const int SIZE = 100;

// Name of the JS class exported to user code.
const std::string CLASS_NAME = "Model";

// Maps JS "target" device strings to internal device types.
const std::unordered_map<std::string, DeviceType> kDeviceTypes{
  {"cpu", kCPU},
  {"nnrt", kNNRt},
  {"gpu", kGPU},
};
}  // namespace
84 
// Default-constructs an empty wrapper; the native model is attached later in Constructor().
MSLiteModelNapi::MSLiteModelNapi() : native_model_(nullptr), env_(nullptr) {
  MS_LOG(INFO) << "MSLiteModelNapi Instances create.";
}
88 
// Drops the native model reference (a shared_ptr assigned in Constructor()); the model
// itself is released through the shared_ptr, not explicitly here.
MSLiteModelNapi::~MSLiteModelNapi() {
  native_model_ = nullptr;
  env_ = nullptr;
  MS_LOG(INFO) << "MSLiteModelNapi Instances destroy.";
}
94 
Finalize(napi_env env,void * nativeObject,void * finalize)95 void MSLiteModelNapi::Finalize(napi_env env, void *nativeObject, void *finalize) {
96   (void)env;
97   (void)finalize;
98   if (nativeObject != nullptr) {
99     // delete nativeObject
100     auto obj = static_cast<MSLiteModelNapi *>(nativeObject);
101     delete obj;
102     obj = nullptr;
103   }
104   MS_LOG(INFO) << "Finalize success";
105 }
106 
// Registers the JS "Model" class on `exports`: per-instance methods/accessors, static
// loader functions, and the enum-style constant objects (Format, DataType, ...).
// Returns `exports` on success, nullptr on any N-API failure.
napi_value MSLiteModelNapi::Init(napi_env env, napi_value exports) {
  // Methods and accessors available on each Model instance.
  napi_property_descriptor properties[] = {
    DECLARE_NAPI_FUNCTION("getInputs", GetInputs),
    DECLARE_NAPI_FUNCTION("resize", Resize),
    DECLARE_NAPI_FUNCTION("predict", PredictAsync),
    DECLARE_NAPI_FUNCTION("runStep", RunStep),
    DECLARE_NAPI_FUNCTION("getWeights", GetWeights),
    DECLARE_NAPI_FUNCTION("updateWeights", UpdateWeights),
    DECLARE_NAPI_FUNCTION("setupVirtualBatch", SetupVirtualBatch),
    DECLARE_NAPI_FUNCTION("exportModel", ExportModel),
    DECLARE_NAPI_FUNCTION("exportWeightsCollaborateWithMicro", ExportWeightsCollaborateWithMicro),
    DECLARE_NAPI_GETTER_SETTER("trainMode", GetTrainMode, SetTrainMode),
    DECLARE_NAPI_GETTER_SETTER("learningRate", GetLearningRate, SetLearningRate),
    };

  // Static members on the Model class: model loaders plus the enum constant objects.
  napi_property_descriptor staticProperty[] = {
    DECLARE_NAPI_STATIC_FUNCTION("loadModelFromFile", LoadMSLiteModelFromFile),
    DECLARE_NAPI_STATIC_FUNCTION("loadModelFromBuffer", LoadMSLiteModelFromBuffer),
    DECLARE_NAPI_STATIC_FUNCTION("loadModelFromFd", LoadMSLiteModelFromFd),
    DECLARE_NAPI_STATIC_FUNCTION("loadTrainModelFromFile", LoadMSLiteTrainModelFromFile),
    DECLARE_NAPI_STATIC_FUNCTION("loadTrainModelFromBuffer", LoadMSLiteTrainModelFromBuffer),
    DECLARE_NAPI_STATIC_FUNCTION("loadTrainModelFromFd", LoadMSLiteTrainModelFromFd),
    DECLARE_NAPI_STATIC_FUNCTION("getAllNNRTDeviceDescriptions", GetAllNnrtDeviceDescs),
    DECLARE_NAPI_PROPERTY("Format", CreateFormatObject(env)),
    DECLARE_NAPI_PROPERTY("DataType", CreateDataTypeObject(env)),
    DECLARE_NAPI_PROPERTY("ThreadAffinityMode", CreateThreadAffinityModeObject(env)),
    DECLARE_NAPI_PROPERTY("QuantizationType", CreateQuantizationTypeObject(env)),
    DECLARE_NAPI_PROPERTY("OptimizationLevel", CreateOptimizationLevelObject(env)),
    DECLARE_NAPI_PROPERTY("PerformanceMode", CreatePerformanceModeObject(env)),
    DECLARE_NAPI_PROPERTY("Priority", CreatePriorityObject(env)),
    DECLARE_NAPI_PROPERTY("NNRTDeviceType", CreateNnrtDeviceTypeObject(env)),
  };

  napi_value constructor = nullptr;
  napi_status status = napi_define_class(env, CLASS_NAME.c_str(), NAPI_AUTO_LENGTH, Constructor, nullptr,
                                         sizeof(properties) / sizeof(properties[0]), properties, &constructor);
  if (status != napi_ok) {
    MS_LOG(ERROR) << "Failed to define MSLiteModel class";
    return nullptr;
  }

  // Hold a persistent reference so the constructor can be reused after this call returns.
  status = napi_create_reference(env, constructor, REFERENCE_CREATION_COUNT, &constructor_);
  if (status != napi_ok) {
    MS_LOG(ERROR) << "Failed to create reference of constructor";
    return nullptr;
  }

  status = napi_set_named_property(env, exports, CLASS_NAME.c_str(), constructor);
  if (status != napi_ok) {
    MS_LOG(ERROR) << "Failed to set constructor";
    return nullptr;
  }

  status = napi_define_properties(env, exports, sizeof(staticProperty) / sizeof(staticProperty[0]), staticProperty);
  if (status != napi_ok) {
    MS_LOG(ERROR) << "Failed to define static function";
    return nullptr;
  }

  MS_LOG(INFO) << "init success";
  return exports;
}
169 
CreateFormatObject(napi_env env)170 napi_value MSLiteModelNapi::CreateFormatObject(napi_env env)
171 {
172   napi_value result = nullptr;
173   napi_status status;
174   std::string propName;
175   int32_t refCount = 1;
176 
177   status = napi_create_object(env, &result);
178   if (status == napi_ok) {
179     for (auto &iter : tensorFormatMap) {
180       propName = iter.first;
181       status = AddNamedProperty(env, result, propName, iter.second);
182       if (status != napi_ok) {
183         MS_LOG(ERROR) << "Failed to add named prop in CreateFormatObject.";
184         break;
185       }
186       propName.clear();
187     }
188     if (status == napi_ok) {
189       status = napi_create_reference(env, result, refCount, &tensorFormat_);
190       if (status == napi_ok) {
191         return result;
192       }
193     }
194   }
195   MS_LOG(ERROR) << "CreateFormatObject is Failed!";
196   napi_get_undefined(env, &result);
197   return result;
198 }
199 
CreateDataTypeObject(napi_env env)200 napi_value MSLiteModelNapi::CreateDataTypeObject(napi_env env) {
201   napi_value result = nullptr;
202   napi_status status;
203   std::string propName;
204   int32_t refCount = 1;
205 
206   status = napi_create_object(env, &result);
207   if (status == napi_ok) {
208     for (auto &iter : tensorDataTypeMap) {
209       propName = iter.first;
210       status = AddNamedProperty(env, result, propName, iter.second);
211       if (status != napi_ok) {
212         MS_LOG(ERROR) << "Failed to add named prop in CreateDataTypeObject.";
213         break;
214       }
215       propName.clear();
216     }
217     if (status == napi_ok) {
218       status = napi_create_reference(env, result, refCount, &tensorDataType_);
219       if (status == napi_ok) {
220         return result;
221       }
222     }
223   }
224   MS_LOG(ERROR) << "CreateDataTypeObject is Failed!";
225   napi_get_undefined(env, &result);
226   return result;
227 }
228 
CreateThreadAffinityModeObject(napi_env env)229 napi_value MSLiteModelNapi::CreateThreadAffinityModeObject(napi_env env) {
230   napi_value result = nullptr;
231   napi_status status;
232   std::string propName;
233   int32_t refCount = 1;
234 
235   status = napi_create_object(env, &result);
236   if (status == napi_ok) {
237     for (auto &iter : contextThreadAffinityModeMap) {
238       propName = iter.first;
239       status = AddNamedProperty(env, result, propName, iter.second);
240       if (status != napi_ok) {
241         MS_LOG(ERROR) << "Failed to add named prop in CreateThreadAffinityModeObject.";
242         break;
243       }
244       propName.clear();
245     }
246     if (status == napi_ok) {
247       status = napi_create_reference(env, result, refCount, &contextThreadAffinityMode_);
248       if (status == napi_ok) {
249         return result;
250       }
251     }
252   }
253   MS_LOG(ERROR) << "CreateThreadAffinityModeObject is Failed!";
254   napi_get_undefined(env, &result);
255   return result;
256 }
257 
CreateQuantizationTypeObject(napi_env env)258 napi_value MSLiteModelNapi::CreateQuantizationTypeObject(napi_env env) {
259   napi_value result = nullptr;
260   napi_status status;
261   std::string propName;
262   int32_t refCount = 1;
263 
264   status = napi_create_object(env, &result);
265   if (status == napi_ok) {
266     for (auto &iter : contextQuantizationTypeMap) {
267       propName = iter.first;
268       status = AddNamedProperty(env, result, propName, iter.second);
269       if (status != napi_ok) {
270         MS_LOG(ERROR) << "Failed to add named prop in CreateQuantizationTypeObject.";
271         break;
272       }
273       propName.clear();
274     }
275     if (status == napi_ok) {
276       status = napi_create_reference(env, result, refCount, &contextQuantizationType_);
277       if (status == napi_ok) {
278         return result;
279       }
280     }
281   }
282   MS_LOG(ERROR) << "CreateQuantizationTypeObject is Failed!";
283   napi_get_undefined(env, &result);
284   return result;
285 }
286 
CreateOptimizationLevelObject(napi_env env)287 napi_value MSLiteModelNapi::CreateOptimizationLevelObject(napi_env env) {
288   napi_value result = nullptr;
289   napi_status status;
290   std::string propName;
291   int32_t refCount = 1;
292 
293   status = napi_create_object(env, &result);
294   if (status == napi_ok) {
295     for (auto &iter : contextOptimizationLevelTypeMap) {
296       propName = iter.first;
297       status = AddNamedProperty(env, result, propName, iter.second);
298       if (status != napi_ok) {
299         MS_LOG(ERROR) << "Failed to add named prop in CreateOptimizationLevelObject.";
300         break;
301       }
302       propName.clear();
303     }
304     if (status == napi_ok) {
305       status = napi_create_reference(env, result, refCount, &contextOptimizationLevel_);
306       if (status == napi_ok) {
307         return result;
308       }
309     }
310   }
311   MS_LOG(ERROR) << "CreateOptimizationLevelObject is Failed!";
312   napi_get_undefined(env, &result);
313   return result;
314 }
315 
CreatePerformanceModeObject(napi_env env)316 napi_value MSLiteModelNapi::CreatePerformanceModeObject(napi_env env) {
317   napi_value result = nullptr;
318   napi_status status;
319   std::string propName;
320   int32_t refCount = 1;
321 
322   status = napi_create_object(env, &result);
323   if (status == napi_ok) {
324     for (auto &iter : contextPerformanceModeTypeMap) {
325       propName = iter.first;
326       status = AddNamedProperty(env, result, propName, iter.second);
327       if (status != napi_ok) {
328         MS_LOG(ERROR) << "Failed to add named prop in CreatePerformanceModeObject.";
329         break;
330       }
331       propName.clear();
332     }
333     if (status == napi_ok) {
334       status = napi_create_reference(env, result, refCount, &contextPerformanceMode_);
335       if (status == napi_ok) {
336         return result;
337       }
338     }
339   }
340   MS_LOG(ERROR) << "CreatePerformanceModeObject is Failed!";
341   napi_get_undefined(env, &result);
342   return result;
343 }
344 
CreatePriorityObject(napi_env env)345 napi_value MSLiteModelNapi::CreatePriorityObject(napi_env env) {
346   napi_value result = nullptr;
347   napi_status status;
348   std::string propName;
349   int32_t refCount = 1;
350 
351   status = napi_create_object(env, &result);
352   if (status == napi_ok) {
353     for (auto &iter : contextPriorityTypeMap) {
354       propName = iter.first;
355       status = AddNamedProperty(env, result, propName, iter.second);
356       if (status != napi_ok) {
357         MS_LOG(ERROR) << "Failed to add named prop in CreatePriorityObject.";
358         break;
359       }
360       propName.clear();
361     }
362     if (status == napi_ok) {
363       status = napi_create_reference(env, result, refCount, &contextPriority_);
364       if (status == napi_ok) {
365         return result;
366       }
367     }
368   }
369   MS_LOG(ERROR) << "CreatePriorityObject is Failed!";
370   napi_get_undefined(env, &result);
371   return result;
372 }
373 
CreateNnrtDeviceTypeObject(napi_env env)374 napi_value MSLiteModelNapi::CreateNnrtDeviceTypeObject(napi_env env) {
375   napi_value result = nullptr;
376   napi_status status;
377   std::string propName;
378   int32_t refCount = 1;
379 
380   status = napi_create_object(env, &result);
381   if (status == napi_ok) {
382     for (auto &iter : contextNnrtDeviceTypeTypeMap) {
383       propName = iter.first;
384       status = AddNamedProperty(env, result, propName, iter.second);
385       if (status != napi_ok) {
386         MS_LOG(ERROR) << "Failed to add named prop in CreateNnrtDeviceTypeObject.";
387         break;
388       }
389       propName.clear();
390     }
391     if (status == napi_ok) {
392       status = napi_create_reference(env, result, refCount, &contextNnrtDeviceType_);
393       if (status == napi_ok) {
394         return result;
395       }
396     }
397   }
398   MS_LOG(ERROR) << "CreateNnrtDeviceTypeObject is Failed!";
399   napi_get_undefined(env, &result);
400   return result;
401 }
402 
AddNamedProperty(napi_env env,napi_value object,const std::string name,int32_t enumValue)403 napi_status MSLiteModelNapi::AddNamedProperty(napi_env env, napi_value object, const std::string name,
404                                               int32_t enumValue) {
405   napi_status status;
406   napi_value enumNapiValue;
407 
408   status = napi_create_int32(env, enumValue, &enumNapiValue);
409   if (status == napi_ok) {
410     status = napi_set_named_property(env, object, name.c_str(), enumNapiValue);
411   }
412   return status;
413 }
414 
GetAllNnrtDeviceDescs(napi_env env,napi_callback_info info)415 napi_value MSLiteModelNapi::GetAllNnrtDeviceDescs(napi_env env, napi_callback_info info) {
416   size_t num;
417   napi_value jsResult = nullptr;
418   NNRTDeviceDesc *devices = OH_AI_GetAllNNRTDeviceDescs(&num);
419   if (devices == nullptr) {
420     MS_LOG(ERROR) << "Get all nnrt devices error, may nnrt is not supported.";
421     OH_AI_DestroyAllNNRTDeviceDescs(&devices);
422     return jsResult;
423   }
424 
425   MS_LOG(INFO) << "all nnrt devices size: " << num;
426   napi_create_array_with_length(env, num, &jsResult);
427   for (size_t i = 0; i < num; i++) {
428     NnrtDeviceDesc nnrt_device;
429     NNRTDeviceDesc *nnrt_device_desc = OH_AI_GetElementOfNNRTDeviceDescs(devices, i);
430     nnrt_device.name.assign(OH_AI_GetNameFromNNRTDeviceDesc(nnrt_device_desc));
431     size_t id = OH_AI_GetDeviceIdFromNNRTDeviceDesc(nnrt_device_desc);
432     nnrt_device.id = id;
433     nnrt_device.type = static_cast<ContextNnrtDeviceType>(OH_AI_GetTypeFromNNRTDeviceDesc(nnrt_device_desc));
434     auto status = napi_set_element(env, jsResult, i, NnrtDeviceDescNapi::NewInstance(env, nnrt_device));
435     if (status != napi_ok) {
436       MS_LOG(ERROR) << "napi_set_element failed! code: " << status;
437       OH_AI_DestroyAllNNRTDeviceDescs(&devices);
438       return jsResult;
439     }
440   }
441   MS_LOG(INFO) << "get All nnrt devices success!";
442   OH_AI_DestroyAllNNRTDeviceDescs(&devices);
443   return jsResult;
444 }
445 
CreateModel(ModelInfo * model_info_ptr,ContextInfo * context_info_ptr)446 std::shared_ptr<mindspore::Model> MSLiteModelNapi::CreateModel(ModelInfo *model_info_ptr,
447                                                                ContextInfo *context_info_ptr) {
448   if (context_info_ptr == nullptr) {
449     MS_LOG(ERROR) << "context_info_ptr is nullptr.";
450     return nullptr;
451   }
452   // create and init context
453   std::string s;
454   for (const auto &device_name : context_info_ptr->target) {
455     s += device_name + " ";
456   }
457   MS_LOG(DEBUG) << "target device: " << s.c_str();
458 
459   auto context = std::make_shared<mindspore::Context>();
460   if (context == nullptr) {
461     MS_LOG(ERROR) << "Failed to new context.";
462     return nullptr;
463   }
464 
465   auto &device_infos = context->MutableDeviceInfo();
466   if (context_info_ptr->target.empty()) {
467     MS_LOG(ERROR) << "context is empty.";
468     return nullptr;
469   }
470   if (GetDeviceInfoContext(context_info_ptr, device_infos) != SUCCESS) {
471     MS_LOG(ERROR) << "Create context failed.";
472     return nullptr;
473   }
474   context->SetThreadNum(context_info_ptr->cpu_device.thread_num);
475   MS_LOG(DEBUG) << "current thread num is : " << context->GetThreadNum();
476 
477   switch (model_info_ptr->mode) {
478     case kBuffer: {
479       MS_LOG(DEBUG) << "input model buffer, model_buffer_total: " << model_info_ptr->model_buffer_total;
480       if (model_info_ptr->model_buffer_data == nullptr || model_info_ptr->model_buffer_total <= 0) {
481         MS_LOG(ERROR) << "Failed to build model.";
482         return nullptr;
483       }
484       std::shared_ptr<mindspore::Model> model_ptr = std::make_shared<mindspore::Model>();
485       if (model_ptr == nullptr) {
486         MS_LOG(ERROR) << "Failed to new mindspore::model.";
487         return nullptr;
488       }
489       auto ret = model_ptr->Build(model_info_ptr->model_buffer_data, model_info_ptr->model_buffer_total,
490                                   mindspore::kMindIR, context);
491       if (ret == mindspore::kSuccess) {
492         MS_LOG(INFO) << "Build model from buffer success.";
493         return model_ptr;
494       }
495       break;
496     }
497     case kPath: {
498       MS_LOG(DEBUG) << "input model path, model_buffer_total: " << model_info_ptr->model_path.c_str();
499       std::shared_ptr<mindspore::Model> model_ptr = std::make_shared<mindspore::Model>();
500       if (model_ptr == nullptr) {
501         MS_LOG(ERROR) << "Failed to new mindspore::model.";
502         return nullptr;
503       }
504       auto ret = model_ptr->Build(model_info_ptr->model_path, mindspore::kMindIR, context);
505       if (ret == mindspore::kSuccess) {
506         MS_LOG(INFO) << "Build model from path success.";
507         return model_ptr;
508       }
509       return nullptr;
510     }
511     case kFD: {
512       MS_LOG(DEBUG) << "input model fd:" << model_info_ptr->model_fd
513                     << ", model_buffer_total: " << model_info_ptr->model_buffer_total;
514       std::shared_ptr<mindspore::Model> model_ptr = std::make_shared<mindspore::Model>();
515       if (model_ptr == nullptr) {
516         MS_LOG(ERROR) << "Failed to new mindspore::model.";
517         return nullptr;
518       }
519       auto ret = model_ptr->Build(model_info_ptr->model_buffer_data, model_info_ptr->model_buffer_total,
520                                   mindspore::kMindIR, context);
521 
522       (void)munmap(model_info_ptr->model_buffer_data, model_info_ptr->model_buffer_total);
523       if (ret == mindspore::kSuccess) {
524         MS_LOG(INFO) << "Build model from fd success.";
525         return model_ptr;
526       }
527 
528       break;
529     }
530     default: {
531       MS_LOG(ERROR) << "Invalid model mode.";
532     }
533   }
534   MS_LOG(ERROR) << "Build model failed.";
535   return nullptr;
536 }
537 
CreateTrainModel(ModelInfo * model_info_ptr,ContextInfo * context_info_ptr)538 std::shared_ptr<mindspore::Model> MSLiteModelNapi::CreateTrainModel(ModelInfo *model_info_ptr,
539                                                                     ContextInfo *context_info_ptr) {
540   // create and init context
541   std::string s;
542   for (const auto &device_name : context_info_ptr->target) {
543     s += device_name + " ";
544   }
545   MS_LOG(DEBUG) << "target device: " << s.c_str();
546 
547   auto context = std::make_shared<mindspore::Context>();
548   if (context == nullptr) {
549     MS_LOG(ERROR) << "Failed to new context.";
550     return nullptr;
551   }
552 
553   auto &device_infos = context->MutableDeviceInfo();
554   if (context_info_ptr->target.empty()) {
555     MS_LOG(ERROR) << "context is empty.";
556     return nullptr;
557   }
558   if (GetDeviceInfoContext(context_info_ptr, device_infos) != SUCCESS) {
559     MS_LOG(ERROR) << "Create context failed.";
560     return nullptr;
561   }
562 
563   auto train_cfg = std::make_shared<TrainCfg>();
564   std::vector<std::string> loss_names;
565   for (const auto &name : train_cfg->GetLossName()) {
566       loss_names.push_back(name);
567   }
568   for (const auto &name : context_info_ptr->train_cfg.loss_names) {
569       loss_names.push_back(name);
570   }
571   train_cfg->SetLossName(loss_names);
572   train_cfg->optimization_level_ = static_cast<OptimizationLevel>(context_info_ptr->train_cfg.optimization_level);
573 
574   switch (model_info_ptr->mode) {
575     case kBuffer: {
576       MS_LOG(DEBUG) << "input model buffer, model_buffer_total: " << model_info_ptr->model_buffer_total;
577       if (model_info_ptr->model_buffer_data == nullptr || model_info_ptr->model_buffer_total <= 0) {
578         MS_LOG(ERROR) << "Failed to build model.";
579         return nullptr;
580       }
581       std::shared_ptr<mindspore::Model> model_ptr = std::make_shared<mindspore::Model>();
582       if (model_ptr == nullptr) {
583         MS_LOG(ERROR) << "Failed to new mindspore::model.";
584         return nullptr;
585       }
586       mindspore::Graph graph;
587       auto status = mindspore::Serialization::Load(model_info_ptr->model_buffer_data,
588                                                    model_info_ptr->model_buffer_total, mindspore::kMindIR, &graph);
589       if (status != mindspore::kSuccess) {
590         MS_LOG(ERROR) << "load ms file failed.";
591         return nullptr;
592       }
593       auto ret = model_ptr->Build(static_cast<mindspore::GraphCell>(graph), context, train_cfg);
594       if (ret == mindspore::kSuccess) {
595         MS_LOG(INFO) << "Build model from buffer success.";
596         return model_ptr;
597       }
598       break;
599     }
600     case kPath: {
601       MS_LOG(DEBUG) << "input model path, model_buffer_total: " << model_info_ptr->model_path.c_str();
602       std::shared_ptr<mindspore::Model> model_ptr = std::make_shared<mindspore::Model>();
603       if (model_ptr == nullptr) {
604         MS_LOG(ERROR) << "Failed to new mindspore::model.";
605         return nullptr;
606       }
607 
608       mindspore::Graph graph;
609       auto status = mindspore::Serialization::Load(model_info_ptr->model_path, mindspore::kMindIR, &graph);
610       if (status != mindspore::kSuccess) {
611         MS_LOG(ERROR) << "load ms file failed.";
612         return nullptr;
613       }
614       auto ret = model_ptr->Build(static_cast<mindspore::GraphCell>(graph), context, train_cfg);
615       if (ret == mindspore::kSuccess) {
616         MS_LOG(INFO) << "Build model from path success.";
617         return model_ptr;
618       }
619       return nullptr;
620     }
621     case kFD: {
622       MS_LOG(DEBUG) << "input model fd:" << model_info_ptr->model_fd
623                     << ", model_buffer_total: " << model_info_ptr->model_buffer_total;
624       std::shared_ptr<mindspore::Model> model_ptr = std::make_shared<mindspore::Model>();
625       if (model_ptr == nullptr) {
626         MS_LOG(ERROR) << "Failed to new mindspore::model.";
627         return nullptr;
628       }
629 
630       mindspore::Graph graph;
631       auto status = mindspore::Serialization::Load(model_info_ptr->model_buffer_data,
632                                                    model_info_ptr->model_buffer_total, mindspore::kMindIR, &graph);
633       if (status != mindspore::kSuccess) {
634         MS_LOG(ERROR) << "load ms file failed.";
635         return nullptr;
636       }
637       auto ret = model_ptr->Build(static_cast<mindspore::GraphCell>(graph), context, train_cfg);
638       (void)munmap(model_info_ptr->model_buffer_data, model_info_ptr->model_buffer_total);
639       if (ret == mindspore::kSuccess) {
640         MS_LOG(INFO) << "Build model from fd success.";
641         return model_ptr;
642       }
643 
644       break;
645     }
646     default: {
647       MS_LOG(ERROR) << "Invalid model mode.";
648     }
649   }
650   MS_LOG(ERROR) << "Build model failed.";
651   return nullptr;
652 }
653 
GetDeviceInfoContext(ContextInfo * context_ptr,std::vector<std::shared_ptr<DeviceInfoContext>> & device_infos)654 int32_t MSLiteModelNapi::GetDeviceInfoContext(ContextInfo *context_ptr,
655                                               std::vector<std::shared_ptr<DeviceInfoContext>> &device_infos) {
656   for (auto device_name : context_ptr->target) {
657     if (kDeviceTypes.find(device_name) == kDeviceTypes.end()) {
658       MS_LOG(ERROR) << "Invalid device: " << device_name.c_str();
659       return ERR_INVALID_OPERATION;
660     }
661 
662     auto device_type = kDeviceTypes.at(device_name);
663     switch (device_type) {
664       case kCPU: {
665         auto cpu_device = std::make_shared<mindspore::CPUDeviceInfo>();
666         if (cpu_device == nullptr) {
667           MS_LOG(ERROR) << "Failed to new CPU deviceInfo.";
668           return ERR_INVALID_OPERATION;
669         }
670         bool is_fp16 = (context_ptr->cpu_device.precision_mode.compare("preferred_fp16") == 0) ? true : false;
671         cpu_device->SetEnableFP16(is_fp16);
672         device_infos.push_back(cpu_device);
673         break;
674       }
675       case kNNRt: {
676         auto nnrt_device = std::make_shared<mindspore::NNRTDeviceInfo>();
677         if (nnrt_device == nullptr) {
678           MS_LOG(ERROR) << "Failed to new NNRT deviceInfo.";
679           return ERR_INVALID_OPERATION;
680         }
681         nnrt_device->SetDeviceID(context_ptr->nnrt_device.device_id);
682         if (context_ptr->nnrt_device.performance_mode != UNSET_VALUE) {
683           nnrt_device->SetPerformanceMode(context_ptr->nnrt_device.performance_mode);
684         }
685         if (context_ptr->nnrt_device.priority != UNSET_VALUE) {
686           nnrt_device->SetPriority(context_ptr->nnrt_device.priority);
687         }
688         // ignore extensions
689         device_infos.push_back(nnrt_device);
690         break;
691       }
692       default: {
693         MS_LOG(ERROR) << "invalid device.";
694         return ERR_INVALID_OPERATION;
695       }
696     }
697   }
698   return SUCCESS;
699 }
700 
Constructor(napi_env env,napi_callback_info info)701 napi_value MSLiteModelNapi::Constructor(napi_env env, napi_callback_info info) {
702   napi_status status;
703   napi_value result = nullptr;
704   napi_get_undefined(env, &result);
705   GET_PARAMS(env, info, ARGS_TWO);
706 
707   std::unique_ptr<MSLiteModelNapi> model_napi = std::make_unique<MSLiteModelNapi>();
708   if (model_napi == nullptr) {
709     MS_LOG(ERROR) << "No memory";
710     return result;
711   }
712 
713   model_napi->env_ = env;
714   if (model_info_->train_model) {
715     model_napi->native_model_ = CreateTrainModel(model_info_, context_);
716   } else {
717     model_napi->native_model_ = CreateModel(model_info_, context_);
718   }
719   if (model_napi->native_model_ == nullptr) {
720     MS_LOG(ERROR) << "Failed to create model.";
721     return result;
722   }
723 
724   status =
725     napi_wrap(env, thisVar, reinterpret_cast<void *>(model_napi.get()), MSLiteModelNapi::Finalize, nullptr, nullptr);
726   if (status == napi_ok) {
727     model_napi.release();
728     return thisVar;
729   }
730   return result;
731 }
732 
ParseModelInfo(napi_env env,napi_value root,ModelInfo & model_info)733 int32_t MSLiteModelNapi::ParseModelInfo(napi_env env, napi_value root, ModelInfo &model_info) {
734   napi_valuetype valueType;
735   napi_status status = napi_typeof(env, root, &valueType);
736   if (status != napi_ok) {
737     MS_LOG(ERROR) << "napi_typeof error.";
738     return ERR_INVALID_PARAM;
739   }
740   if ((valueType != napi_object) && (valueType != napi_string) && (valueType != napi_number)) {
741     MS_LOG(ERROR) << "model is invaild.";
742     return ERR_INVALID_PARAM;
743   }
744 
745   bool is_model_buffer = false;
746   napi_is_arraybuffer(env, root, &is_model_buffer);
747   if (is_model_buffer) {
748     // copy buffer
749     char *array_buffer_data;
750     size_t array_buffer_total;
751     status = napi_get_arraybuffer_info(env, root, reinterpret_cast<void **>(&array_buffer_data), &array_buffer_total);
752     if ((status != napi_ok) || (array_buffer_total <= 0)) {
753       MS_LOG(ERROR) << "Parse model buffer failed.";
754       return ERR_INVALID_PARAM;
755     }
756 
757     // shallow copy
758     model_info.model_buffer_data = array_buffer_data;
759     model_info.model_buffer_total = array_buffer_total;
760     model_info.mode = kBuffer;
761   } else if (valueType == napi_number) {
762     int32_t fd;
763     status = napi_get_value_int32(env, root, &fd);
764     if ((status != napi_ok) || (fd <= 0)) {
765       MS_LOG(ERROR) << "Parse model FD failed.";
766       return ERR_INVALID_PARAM;
767     }
768 
769     int size = lseek(fd, 0, SEEK_END);
770     (void)lseek(fd, 0, SEEK_SET);
771     auto mmap_buffers = mmap(NULL, size, PROT_READ, MAP_SHARED, fd, 0);
772     if (mmap_buffers == NULL) {
773       MS_LOG(ERROR) << "mmap_buffers is NULL.";
774       return ERR_INVALID_PARAM;
775     }
776     model_info.model_fd = fd;
777     model_info.model_buffer_data = static_cast<char *>(mmap_buffers);
778     model_info.model_buffer_total = size;
779     model_info.mode = kFD;
780   } else {
781     char char_buf[SIZE];
782     size_t buf_length = 0;
783     status = napi_get_value_string_utf8(env, root, char_buf, SIZE, &buf_length);
784     if ((status != napi_ok) || (buf_length <= 0)) {
785       MS_LOG(ERROR) << "Parse model file failed.";
786       return ERR_INVALID_PARAM;
787     }
788     model_info.model_path.assign(char_buf, char_buf + buf_length);
789     model_info.mode = kPath;
790     MS_LOG(DEBUG) << "model_path: " << model_info.model_path.c_str();
791   }
792   return SUCCESS;
793 }
794 
ParseContextInfo(napi_env env,napi_value args,ContextInfo & context)795 int32_t MSLiteModelNapi::ParseContextInfo(napi_env env, napi_value args, ContextInfo &context) {
796   napi_valuetype valueType;
797   napi_status status = napi_typeof(env, args, &valueType);
798   if ((status != napi_ok) || (valueType != napi_object)) {
799     MS_LOG(ERROR) << "context is invaild.";
800     return ERR_NOT_EXISTED_PARAM;
801   }
802 
803   std::vector<std::string> str_values;
804   auto ret = CommonNapi::GetPropertyStringArray(env, args, "target", str_values);
805   if (ret != SUCCESS) {
806     MS_LOG(ERROR) << "Get context target failed.";
807     return ret;
808   }
809   context.target.assign(str_values.begin(), str_values.end());
810 
811   ret = GetCpuDeviceInfo(env, args, context);
812   if (ret != ERR_NOT_EXISTED_PARAM && ret != SUCCESS) {
813     MS_LOG(ERROR) << "Get context CpuDeviceInfo failed.";
814     return ret;
815   }
816 
817   ret = GetNNRTDeviceInfo(env, args, context);
818   if (ret != ERR_NOT_EXISTED_PARAM && ret != SUCCESS) {
819     MS_LOG(ERROR) << "Get context NnrtDeviceInfo failed.";
820     return ret;
821   }
822   return SUCCESS;
823 }
824 
ParseTrainCfgInfo(napi_env env,napi_value root,TrainConfig & cfg)825 int32_t MSLiteModelNapi::ParseTrainCfgInfo(napi_env env, napi_value root, TrainConfig &cfg) {
826   napi_valuetype valueType;
827   napi_status status = napi_typeof(env, root, &valueType);
828   if ((status != napi_ok) || (valueType != napi_object)) {
829     MS_LOG(ERROR) << "TrainCfg is invaild.";
830     return ERR_NOT_EXISTED_PARAM;
831   }
832   std::vector<std::string> str_values;
833   auto ret = CommonNapi::GetPropertyStringArray(env, root, "lossName", str_values);
834   if (ret != SUCCESS && ret != ERR_NOT_EXISTED_PARAM) {
835     MS_LOG(ERROR) << "Get lossName failed.";
836     return ret;
837   }
838   cfg.loss_names.assign(str_values.begin(), str_values.end());
839 
840   int32_t int_value = 0;
841   ret = CommonNapi::GetPropertyInt32(env, root, "optimizationLevel", int_value);
842   if (ret != SUCCESS && ret != ERR_NOT_EXISTED_PARAM) {
843     MS_LOG(ERROR) << "Get optimization level failed";
844     return ret;
845   } else {
846     cfg.optimization_level = int_value;
847   }
848   return SUCCESS;
849 }
850 
CreateMSLiteModelWrapper(napi_env env,MSLiteModelAsyncContext * async_context)851 napi_value MSLiteModelNapi::CreateMSLiteModelWrapper(napi_env env, MSLiteModelAsyncContext *async_context) {
852   std::lock_guard<std::mutex> lock(create_mutex_);
853   napi_status status;
854   napi_value result = nullptr;
855   napi_value constructor;
856   napi_get_undefined(env, &result);
857 
858   status = napi_get_reference_value(env, constructor_, &constructor);
859   if (status != napi_ok) {
860     MS_LOG(ERROR) << "get reference failed.";
861     return result;
862   }
863   model_info_ = &(async_context->model_info);
864   context_ = &(async_context->context);
865   status = napi_new_instance(env, constructor, 0, nullptr, &result);
866   if (status == napi_ok) {
867     return result;
868   }
869 
870   return result;
871 }
872 
GetMSLiteModelAsyncCallbackComplete(napi_env env,napi_status status,void * data)873 void MSLiteModelNapi::GetMSLiteModelAsyncCallbackComplete(napi_env env, napi_status status, void *data) {
874   napi_value valueParam = nullptr;
875   auto async_context = static_cast<MSLiteModelAsyncContext *>(data);
876 
877   if (async_context != nullptr) {
878     if (!async_context->status) {
879       valueParam = CreateMSLiteModelWrapper(env, async_context);
880     }
881     CommonCallbackRoutine(env, async_context, valueParam);
882   } else {
883     MS_LOG(ERROR) << "GetMSLiteModelAsyncCallbackComplete asyncContext is Null!";
884   }
885 }
886 
// Settles the JS side of a finished async operation: resolves/rejects the
// deferred promise when one exists, otherwise invokes the user callback, then
// tears down the async work and frees the context. asyncContext is passed by
// reference so it can be nulled out after deletion, preventing the caller
// from using a dangling pointer.
void MSLiteModelNapi::CommonCallbackRoutine(napi_env env, MSLiteModelAsyncContext *&asyncContext,
                                            const napi_value &valueParam) {
  napi_value result[ARGS_ONE] = {0};
  napi_value retVal;
  napi_value error = nullptr;

  // status == 0 means success: pass the produced value through. Otherwise
  // build a JS Error whose message and code are derived from the numeric
  // status and leave result[PARAM0] as undefined.
  if (!asyncContext->status) {
    result[PARAM0] = valueParam;
  } else {
    napi_value message = nullptr;
    std::string messageValue = CommonNapi::getMessageByCode(asyncContext->status);
    napi_create_string_utf8(env, messageValue.c_str(), NAPI_AUTO_LENGTH, &message);

    napi_value code = nullptr;
    napi_create_string_utf8(env, (std::to_string(asyncContext->status)).c_str(), NAPI_AUTO_LENGTH, &code);

    napi_create_error(env, code, message, &error);
    napi_get_undefined(env, &result[PARAM0]);
  }

  // Promise mode: settle the deferred. Callback mode: invoke the stored
  // callback once and release the reference taken when it was registered.
  // NOTE(review): in callback mode the error object is not passed to the
  // callback (only result[PARAM0], undefined on failure) — confirm intended.
  if (asyncContext->deferred != nullptr) {
    if (!asyncContext->status) {
      napi_resolve_deferred(env, asyncContext->deferred, result[PARAM0]);
    } else {
      napi_reject_deferred(env, asyncContext->deferred, error);
    }
  } else {
    napi_value callback = nullptr;
    napi_get_reference_value(env, asyncContext->callbackRef, &callback);
    napi_call_function(env, nullptr, callback, ARGS_ONE, result, &retVal);
    napi_delete_reference(env, asyncContext->callbackRef);
  }
  napi_delete_async_work(env, asyncContext->work);

  // The context was released (not freed) by the function that queued the
  // async work; this is the single owner responsible for deleting it.
  delete asyncContext;
  asyncContext = nullptr;
}
924 
// JS entry point: create an MSLiteModel from a model file.
// Arguments: (model, context | callback, callback?) — when a callback
// function is supplied the result is delivered through it and undefined is
// returned; otherwise a promise is returned.
// NOTE(review): on argument-parse failure this returns nullptr (undefined to
// JS) without throwing — confirm this is the intended API contract.
napi_value MSLiteModelNapi::LoadMSLiteModelFromFile(napi_env env, napi_callback_info info) {
  napi_status status;
  napi_value result = nullptr;
  const int32_t refCount = 1;
  GET_PARAMS(env, info, ARGS_THREE);  // macro: declares argc and argv[] for this callback
  napi_valuetype valueType = napi_undefined;

  std::unique_ptr<MSLiteModelAsyncContext> asyncContext = std::make_unique<MSLiteModelAsyncContext>();

  int32_t ret;
  for (size_t i = PARAM0; i < argc; i++) {
    if (i == PARAM0) {
      // Arg 0: the model source.
      ret = ParseModelInfo(env, argv[i], asyncContext->model_info);
      if (ret != SUCCESS) {
        MS_LOG(ERROR) << "Parsing model failed.";
        return result;
      }
    } else if (i == PARAM1) {
      // Arg 1: either the callback function or the context object.
      napi_typeof(env, argv[i], &valueType);
      if (valueType == napi_function) {
        napi_create_reference(env, argv[i], refCount, &asyncContext->callbackRef);
      } else {
        ret = ParseContextInfo(env, argv[i], asyncContext->context);
        if (ret != SUCCESS) {
          MS_LOG(ERROR) << "Parsing context failed.";
          return result;
        }
      }
    } else if (i == PARAM2) {
      // Arg 2: optional trailing callback (meaningful when arg 1 was the context).
      napi_typeof(env, argv[i], &valueType);
      if (valueType == napi_function) {
        napi_create_reference(env, argv[i], refCount, &asyncContext->callbackRef);
      }
      break;
    } else {
      MS_LOG(ERROR) << "Invalid input params.";
      return result;
    }
  }

  // No callback registered -> promise mode; otherwise return undefined and
  // deliver the result through the stored callback reference.
  if (asyncContext->callbackRef == nullptr) {
    status = napi_create_promise(env, &asyncContext->deferred, &result);
    if (status != napi_ok) {
      MS_LOG(ERROR) << "create promise failed.";
      return result;
    }
  } else {
    status = napi_get_undefined(env, &result);
    if (status != napi_ok) {
      MS_LOG(ERROR) << "create callback failed.";
      return result;
    }
  }

  napi_value resource = nullptr;
  napi_create_string_utf8(env, "LoadMSLiteModelFromFile", NAPI_AUTO_LENGTH, &resource);
  // The execute lambda only marks success; the actual model creation happens
  // in GetMSLiteModelAsyncCallbackComplete on the JS thread.
  status = napi_create_async_work(
    env, nullptr, resource,
    [](napi_env env, void *data) {
      auto context = static_cast<MSLiteModelAsyncContext *>(data);
      context->status = SUCCESS;
    },
    GetMSLiteModelAsyncCallbackComplete, static_cast<void *>(asyncContext.get()), &asyncContext->work);
  if (status != napi_ok) {
    result = nullptr;
  } else {
    status = napi_queue_async_work(env, asyncContext->work);
    if (status == napi_ok) {
      // Ownership of the context passes to the queued work; it is freed in
      // CommonCallbackRoutine.
      asyncContext.release();
    } else {
      result = nullptr;
    }
  }
  return result;
}
1000 
// JS entry point: create a trainable MSLiteModel from a model file.
// Arguments: (model, trainCfg, context) — always promise mode here: no
// callback argument is parsed, so callbackRef stays null.
// NOTE(review): on argument-parse failure this returns nullptr (undefined to
// JS) without throwing — confirm this is the intended API contract.
napi_value MSLiteModelNapi::LoadMSLiteTrainModelFromFile(napi_env env, napi_callback_info info) {
  napi_status status;
  napi_value result = nullptr;
  GET_PARAMS(env, info, ARGS_THREE);  // macro: declares argc and argv[] for this callback

  std::unique_ptr<MSLiteModelAsyncContext> asyncContext = std::make_unique<MSLiteModelAsyncContext>();

  // Mark the context so the model is created in training mode.
  asyncContext->model_info.train_model = true;
  int32_t ret;
  for (size_t i = PARAM0; i < argc; i++) {
    if (i == PARAM0) {
      // Arg 0: the model source.
      ret = ParseModelInfo(env, argv[i], asyncContext->model_info);
      if (ret != SUCCESS) {
        MS_LOG(ERROR) << "Parsing model failed.";
        return result;
      }
    } else if (i == PARAM1) {
      // Arg 1: the training configuration.
      ret = ParseTrainCfgInfo(env, argv[i], asyncContext->context.train_cfg);
      if (ret != SUCCESS) {
        MS_LOG(ERROR) << "Parsing TrainCfg failed.";
        return result;
      }
    } else if (i == PARAM2) {
      // Arg 2: the device context.
      ret = ParseContextInfo(env, argv[i], asyncContext->context);
      if (ret != SUCCESS) {
        MS_LOG(ERROR) << "Parsing context failed.";
        return result;
      }
    } else {
      MS_LOG(ERROR) << "Invalid input params.";
      return result;
    }
  }

  // callbackRef is never set above, so this always takes the promise branch.
  if (asyncContext->callbackRef == nullptr) {
    status = napi_create_promise(env, &asyncContext->deferred, &result);
    if (status != napi_ok) {
      MS_LOG(ERROR) << "create promise failed.";
      return result;
    }
  } else {
    status = napi_get_undefined(env, &result);
    if (status != napi_ok) {
      MS_LOG(ERROR) << "create callback failed.";
      return result;
    }
  }

  napi_value resource = nullptr;
  napi_create_string_utf8(env, "LoadMSLiteTrainModelFromFile", NAPI_AUTO_LENGTH, &resource);
  // The execute lambda only marks success; the actual model creation happens
  // in GetMSLiteModelAsyncCallbackComplete on the JS thread.
  status = napi_create_async_work(
    env, nullptr, resource,
    [](napi_env env, void *data) {
      auto context = static_cast<MSLiteModelAsyncContext *>(data);
      context->status = SUCCESS;
    },
    GetMSLiteModelAsyncCallbackComplete, static_cast<void *>(asyncContext.get()), &asyncContext->work);
  if (status != napi_ok) {
    result = nullptr;
  } else {
    status = napi_queue_async_work(env, asyncContext->work);
    if (status == napi_ok) {
      // Ownership of the context passes to the queued work; it is freed in
      // CommonCallbackRoutine.
      asyncContext.release();
    } else {
      result = nullptr;
    }
  }
  return result;
}
1070 
// JS entry point: create a trainable MSLiteModel from an in-memory buffer.
// Arguments: (model, trainCfg, context) — always promise mode here: no
// callback argument is parsed, so callbackRef stays null.
// NOTE(review): on argument-parse failure this returns nullptr (undefined to
// JS) without throwing — confirm this is the intended API contract.
napi_value MSLiteModelNapi::LoadMSLiteTrainModelFromBuffer(napi_env env, napi_callback_info info) {
  napi_status status;
  napi_value result = nullptr;
  GET_PARAMS(env, info, ARGS_THREE);  // macro: declares argc and argv[] for this callback

  std::unique_ptr<MSLiteModelAsyncContext> asyncContext = std::make_unique<MSLiteModelAsyncContext>();

  // Mark the context so the model is created in training mode.
  asyncContext->model_info.train_model = true;
  int32_t ret;
  for (size_t i = PARAM0; i < argc; i++) {
    if (i == PARAM0) {
      // Arg 0: the model source.
      ret = ParseModelInfo(env, argv[i], asyncContext->model_info);
      if (ret != SUCCESS) {
        MS_LOG(ERROR) << "Parsing model failed.";
        return result;
      }
    } else if (i == PARAM1) {
      // Arg 1: the training configuration.
      ret = ParseTrainCfgInfo(env, argv[i], asyncContext->context.train_cfg);
      if (ret != SUCCESS) {
        MS_LOG(ERROR) << "Parsing TrainCfg failed.";
        return result;
      }
    } else if (i == PARAM2) {
      // Arg 2: the device context.
      ret = ParseContextInfo(env, argv[i], asyncContext->context);
      if (ret != SUCCESS) {
        MS_LOG(ERROR) << "Parsing context failed.";
        return result;
      }
    } else {
      MS_LOG(ERROR) << "Invalid input params.";
      return result;
    }
  }

  // callbackRef is never set above, so this always takes the promise branch.
  if (asyncContext->callbackRef == nullptr) {
    status = napi_create_promise(env, &asyncContext->deferred, &result);
    if (status != napi_ok) {
      MS_LOG(ERROR) << "create promise failed.";
      return result;
    }
  } else {
    status = napi_get_undefined(env, &result);
    if (status != napi_ok) {
      MS_LOG(ERROR) << "create callback failed.";
      return result;
    }
  }

  napi_value resource = nullptr;
  napi_create_string_utf8(env, "LoadMSLiteTrainModelFromBuffer", NAPI_AUTO_LENGTH, &resource);
  // The execute lambda only marks success; the actual model creation happens
  // in GetMSLiteModelAsyncCallbackComplete on the JS thread.
  status = napi_create_async_work(
    env, nullptr, resource,
    [](napi_env env, void *data) {
      auto context = static_cast<MSLiteModelAsyncContext *>(data);
      context->status = SUCCESS;
    },
    GetMSLiteModelAsyncCallbackComplete, static_cast<void *>(asyncContext.get()), &asyncContext->work);
  if (status != napi_ok) {
    result = nullptr;
  } else {
    status = napi_queue_async_work(env, asyncContext->work);
    if (status == napi_ok) {
      // Ownership of the context passes to the queued work; it is freed in
      // CommonCallbackRoutine.
      asyncContext.release();
    } else {
      result = nullptr;
    }
  }
  return result;
}
1140 
// JS entry point: create a trainable MSLiteModel from a file descriptor.
// Arguments: (model, trainCfg, context) — always promise mode here: no
// callback argument is parsed, so callbackRef stays null.
// NOTE(review): on argument-parse failure this returns nullptr (undefined to
// JS) without throwing — confirm this is the intended API contract.
napi_value MSLiteModelNapi::LoadMSLiteTrainModelFromFd(napi_env env, napi_callback_info info) {
  napi_status status;
  napi_value result = nullptr;
  GET_PARAMS(env, info, ARGS_THREE);  // macro: declares argc and argv[] for this callback

  std::unique_ptr<MSLiteModelAsyncContext> asyncContext = std::make_unique<MSLiteModelAsyncContext>();

  // Mark the context so the model is created in training mode.
  asyncContext->model_info.train_model = true;
  int32_t ret;
  for (size_t i = PARAM0; i < argc; i++) {
    if (i == PARAM0) {
      // Arg 0: the model source.
      ret = ParseModelInfo(env, argv[i], asyncContext->model_info);
      if (ret != SUCCESS) {
        MS_LOG(ERROR) << "Parsing model failed.";
        return result;
      }
    } else if (i == PARAM1) {
      // Arg 1: the training configuration.
      ret = ParseTrainCfgInfo(env, argv[i], asyncContext->context.train_cfg);
      if (ret != SUCCESS) {
        MS_LOG(ERROR) << "Parsing TrainCfg failed.";
        return result;
      }
    } else if (i == PARAM2) {
      // Arg 2: the device context.
      ret = ParseContextInfo(env, argv[i], asyncContext->context);
      if (ret != SUCCESS) {
        MS_LOG(ERROR) << "Parsing context failed.";
        return result;
      }
    } else {
      MS_LOG(ERROR) << "Invalid input params.";
      return result;
    }
  }

  // callbackRef is never set above, so this always takes the promise branch.
  if (asyncContext->callbackRef == nullptr) {
    status = napi_create_promise(env, &asyncContext->deferred, &result);
    if (status != napi_ok) {
      MS_LOG(ERROR) << "create promise failed.";
      return result;
    }
  } else {
    status = napi_get_undefined(env, &result);
    if (status != napi_ok) {
      MS_LOG(ERROR) << "create callback failed.";
      return result;
    }
  }

  napi_value resource = nullptr;
  napi_create_string_utf8(env, "LoadMSLiteTrainModelFromFd", NAPI_AUTO_LENGTH, &resource);
  // The execute lambda only marks success; the actual model creation happens
  // in GetMSLiteModelAsyncCallbackComplete on the JS thread.
  status = napi_create_async_work(
    env, nullptr, resource,
    [](napi_env env, void *data) {
      auto context = static_cast<MSLiteModelAsyncContext *>(data);
      context->status = SUCCESS;
    },
    GetMSLiteModelAsyncCallbackComplete, static_cast<void *>(asyncContext.get()), &asyncContext->work);
  if (status != napi_ok) {
    result = nullptr;
  } else {
    status = napi_queue_async_work(env, asyncContext->work);
    if (status == napi_ok) {
      // Ownership of the context passes to the queued work; it is freed in
      // CommonCallbackRoutine.
      asyncContext.release();
    } else {
      result = nullptr;
    }
  }
  return result;
}
1210 
// JS entry point: create an MSLiteModel from an in-memory buffer.
// Arguments: (model, context | callback, callback?) — when a callback
// function is supplied the result is delivered through it and undefined is
// returned; otherwise a promise is returned.
// NOTE(review): on argument-parse failure this returns nullptr (undefined to
// JS) without throwing — confirm this is the intended API contract.
napi_value MSLiteModelNapi::LoadMSLiteModelFromBuffer(napi_env env, napi_callback_info info) {
  napi_status status;
  napi_value result = nullptr;
  const int32_t refCount = 1;
  GET_PARAMS(env, info, ARGS_THREE);  // macro: declares argc and argv[] for this callback
  napi_valuetype valueType = napi_undefined;

  std::unique_ptr<MSLiteModelAsyncContext> asyncContext = std::make_unique<MSLiteModelAsyncContext>();

  int32_t ret;
  for (size_t i = PARAM0; i < argc; i++) {
    if (i == PARAM0) {
      // Arg 0: the model source.
      ret = ParseModelInfo(env, argv[i], asyncContext->model_info);
      if (ret != SUCCESS) {
        MS_LOG(ERROR) << "Parsing model failed.";
        return result;
      }
    } else if (i == PARAM1) {
      // Arg 1: either the callback function or the context object.
      napi_typeof(env, argv[i], &valueType);
      if (valueType == napi_function) {
        napi_create_reference(env, argv[i], refCount, &asyncContext->callbackRef);
      } else {
        ret = ParseContextInfo(env, argv[i], asyncContext->context);
        if (ret != SUCCESS) {
          MS_LOG(ERROR) << "Parsing context failed.";
          return result;
        }
      }
    } else if (i == PARAM2) {
      // Arg 2: optional trailing callback (meaningful when arg 1 was the context).
      napi_typeof(env, argv[i], &valueType);
      if (valueType == napi_function) {
        napi_create_reference(env, argv[i], refCount, &asyncContext->callbackRef);
      }
      break;
    } else {
      MS_LOG(ERROR) << "Invalid input params.";
      return result;
    }
  }

  // No callback registered -> promise mode; otherwise return undefined and
  // deliver the result through the stored callback reference.
  if (asyncContext->callbackRef == nullptr) {
    status = napi_create_promise(env, &asyncContext->deferred, &result);
    if (status != napi_ok) {
      MS_LOG(ERROR) << "create promise failed.";
      return result;
    }
  } else {
    status = napi_get_undefined(env, &result);
    if (status != napi_ok) {
      MS_LOG(ERROR) << "create callback failed.";
      return result;
    }
  }

  napi_value resource = nullptr;
  napi_create_string_utf8(env, "LoadMSLiteModelFromBuffer", NAPI_AUTO_LENGTH, &resource);
  // The execute lambda only marks success; the actual model creation happens
  // in GetMSLiteModelAsyncCallbackComplete on the JS thread.
  status = napi_create_async_work(
    env, nullptr, resource,
    [](napi_env env, void *data) {
      auto context = static_cast<MSLiteModelAsyncContext *>(data);
      context->status = SUCCESS;
    },
    GetMSLiteModelAsyncCallbackComplete, static_cast<void *>(asyncContext.get()), &asyncContext->work);
  if (status != napi_ok) {
    result = nullptr;
  } else {
    status = napi_queue_async_work(env, asyncContext->work);
    if (status == napi_ok) {
      // Ownership of the context passes to the queued work; it is freed in
      // CommonCallbackRoutine.
      asyncContext.release();
    } else {
      result = nullptr;
    }
  }
  return result;
}
1286 
// JS entry point: create an MSLiteModel from a file descriptor.
// Arguments: (model, context | callback, callback?) — when a callback
// function is supplied the result is delivered through it and undefined is
// returned; otherwise a promise is returned.
// NOTE(review): on argument-parse failure this returns nullptr (undefined to
// JS) without throwing — confirm this is the intended API contract.
napi_value MSLiteModelNapi::LoadMSLiteModelFromFd(napi_env env, napi_callback_info info) {
  napi_status status;
  napi_value result = nullptr;
  const int32_t refCount = 1;
  GET_PARAMS(env, info, ARGS_THREE);  // macro: declares argc and argv[] for this callback
  napi_valuetype valueType = napi_undefined;

  std::unique_ptr<MSLiteModelAsyncContext> asyncContext = std::make_unique<MSLiteModelAsyncContext>();

  int32_t ret;
  for (size_t i = PARAM0; i < argc; i++) {
    if (i == PARAM0) {
      // Arg 0: the model source.
      ret = ParseModelInfo(env, argv[i], asyncContext->model_info);
      if (ret != SUCCESS) {
        MS_LOG(ERROR) << "Parsing model failed.";
        return result;
      }
    } else if (i == PARAM1) {
      // Arg 1: either the callback function or the context object.
      napi_typeof(env, argv[i], &valueType);
      if (valueType == napi_function) {
        napi_create_reference(env, argv[i], refCount, &asyncContext->callbackRef);
      } else {
        ret = ParseContextInfo(env, argv[i], asyncContext->context);
        if (ret != SUCCESS) {
          MS_LOG(ERROR) << "Parsing context failed.";
          return result;
        }
      }
    } else if (i == PARAM2) {
      // Arg 2: optional trailing callback (meaningful when arg 1 was the context).
      napi_typeof(env, argv[i], &valueType);
      if (valueType == napi_function) {
        napi_create_reference(env, argv[i], refCount, &asyncContext->callbackRef);
      }
      break;
    } else {
      MS_LOG(ERROR) << "Invalid input params.";
      return result;
    }
  }

  // No callback registered -> promise mode; otherwise return undefined and
  // deliver the result through the stored callback reference.
  if (asyncContext->callbackRef == nullptr) {
    status = napi_create_promise(env, &asyncContext->deferred, &result);
    if (status != napi_ok) {
      MS_LOG(ERROR) << "create promise failed.";
      return result;
    }
  } else {
    status = napi_get_undefined(env, &result);
    if (status != napi_ok) {
      MS_LOG(ERROR) << "create callback failed.";
      return result;
    }
  }

  napi_value resource = nullptr;
  napi_create_string_utf8(env, "LoadMSLiteModelFromFd", NAPI_AUTO_LENGTH, &resource);
  // The execute lambda only marks success; the actual model creation happens
  // in GetMSLiteModelAsyncCallbackComplete on the JS thread.
  status = napi_create_async_work(
    env, nullptr, resource,
    [](napi_env env, void *data) {
      auto context = static_cast<MSLiteModelAsyncContext *>(data);
      context->status = SUCCESS;
    },
    GetMSLiteModelAsyncCallbackComplete, static_cast<void *>(asyncContext.get()), &asyncContext->work);
  if (status != napi_ok) {
    result = nullptr;
  } else {
    status = napi_queue_async_work(env, asyncContext->work);
    if (status == napi_ok) {
      // Ownership of the context passes to the queued work; it is freed in
      // CommonCallbackRoutine.
      asyncContext.release();
    } else {
      result = nullptr;
    }
  }
  return result;
}
1362 
GetCpuDeviceInfo(napi_env env,napi_value args,ContextInfo & context)1363 int32_t MSLiteModelNapi::GetCpuDeviceInfo(napi_env env, napi_value args, ContextInfo &context) {
1364   bool has_cpu_property = false;
1365   napi_status status = napi_has_named_property(env, args, "cpu", &has_cpu_property);
1366   if (status != napi_ok) {
1367     MS_LOG(ERROR) << "can not find cpu property";
1368     return ERR_INVALID_OPERATION;
1369   }
1370   if (!has_cpu_property) {
1371     return ERR_NOT_EXISTED_PARAM;
1372   }
1373 
1374   napi_value config_item = nullptr;
1375   status = napi_get_named_property(env, args, "cpu", &config_item);
1376   if (status != napi_ok) {
1377     MS_LOG(ERROR) << "can not get cpu property";
1378     return ERR_INVALID_OPERATION;
1379   }
1380 
1381   int32_t int_value = 0;
1382   std::string str_value = "";
1383   std::vector<int32_t> affinity_cores;
1384 
1385   if (CommonNapi::GetPropertyInt32(env, config_item, "threadNum", int_value) == SUCCESS) {
1386     MS_LOG(DEBUG) << "threadNum: " << int_value;
1387     context.cpu_device.thread_num = int_value;
1388   } else {
1389     context.cpu_device.thread_num = PARAM2;
1390   }
1391 
1392   if (CommonNapi::GetPropertyInt32(env, config_item, "threadAffinityMode", int_value) == SUCCESS) {
1393     MS_LOG(DEBUG) << "threadAffinityMode: " << int_value;
1394     if (int_value > PARAM2 || int_value < PARAM0) {
1395       MS_LOG(ERROR) << "threadAffinityMode value is set: " << int_value << ", is out of limition";
1396       return ERR_INVALID_OPERATION;
1397     }
1398     context.cpu_device.thread_affinity_mode = int_value;
1399   } else {
1400     context.cpu_device.thread_affinity_mode = PARAM0;
1401   }
1402 
1403   if (CommonNapi::GetPropertyInt32Array(env, config_item, "threadAffinityCoreList", affinity_cores) == SUCCESS) {
1404     MS_LOG(DEBUG) << "affinityCores size: " << affinity_cores.size();
1405     context.cpu_device.thread_affinity_cores.assign(affinity_cores.begin(), affinity_cores.end());
1406   } else {
1407     context.cpu_device.thread_affinity_cores = {};
1408   }
1409 
1410   if (CommonNapi::GetPropertyString(env, config_item, "precisionMode", str_value) == SUCCESS) {
1411     MS_LOG(DEBUG) << "precisionMode: " << str_value.c_str();
1412     context.cpu_device.precision_mode = str_value;
1413   } else {
1414     context.cpu_device.precision_mode = "enforce_fp32";
1415   }
1416   return SUCCESS;
1417 }
1418 
GetNNRTDeviceInfo(napi_env env,napi_value args,ContextInfo & context)1419 int32_t MSLiteModelNapi::GetNNRTDeviceInfo(napi_env env, napi_value args, ContextInfo &context) {
1420   bool has_nnrt_property = false;
1421   napi_status status = napi_has_named_property(env, args, "nnrt", &has_nnrt_property);
1422   if (status != napi_ok) {
1423     MS_LOG(ERROR) << "can not find nnrt property";
1424     return ERR_ILLEGAL_STATE;
1425   }
1426   if (!has_nnrt_property) {
1427     return ERR_NOT_EXISTED_PARAM;
1428   }
1429 
1430   napi_value config_item = nullptr;
1431   status = napi_get_named_property(env, args, "nnrt", &config_item);
1432   if (status != napi_ok) {
1433     MS_LOG(ERROR) << "can not get nnrt property";
1434     return ERR_INVALID_PARAM;
1435   }
1436 
1437   int32_t int_value = 0;
1438   std::string str_value = "";
1439   std::vector<int32_t> affinity_cores;
1440 
1441   uint64_t device_id;
1442   auto ret = CommonNapi::GetPropertyBigIntUint64(env, config_item, "deviceID", device_id);
1443   if (ret == SUCCESS) {
1444     MS_LOG(DEBUG) << "deviceID: " << device_id;
1445     context.nnrt_device.device_id = static_cast<size_t>(device_id);
1446   } else if (ret == ERR_NOT_EXISTED_PARAM) {
1447     size_t num = 0;
1448     auto *desc = OH_AI_GetAllNNRTDeviceDescs(&num);
1449     if (desc == nullptr || num == 0) {
1450       MS_LOG(WARNING) << "Failed to get nnrt device id, skip adding nnrt device info.";
1451       return ERR_NOT_EXISTED_PARAM;
1452     }
1453     auto id = OH_AI_GetDeviceIdFromNNRTDeviceDesc(desc);
1454     OH_AI_DestroyAllNNRTDeviceDescs(&desc);
1455     MS_LOG(INFO) << "set nnrt device id to " << id;
1456     context.nnrt_device.device_id = id;
1457   } else {
1458     return ERR_INVALID_PARAM;
1459   }
1460 
1461   ret = CommonNapi::GetPropertyInt32(env, config_item, "performanceMode", int_value);
1462   if (ret == SUCCESS) {
1463     MS_LOG(DEBUG) << "performanceMode: " << int_value;
1464     if (int_value > PARAM4 || int_value < PARAM0) {
1465       MS_LOG(ERROR) << "performanceMode value is set to: " << int_value << ", which is out of range";
1466       return ERR_INVALID_PARAM;
1467     }
1468     context.nnrt_device.performance_mode = int_value;
1469   } else if (ret == ERR_NOT_EXISTED_PARAM) {
1470     context.nnrt_device.performance_mode = UNSET_VALUE;
1471   } else {
1472     return ERR_INVALID_PARAM;
1473   }
1474 
1475   ret = CommonNapi::GetPropertyInt32(env, config_item, "priority", int_value);
1476   if (ret == SUCCESS) {
1477     MS_LOG(DEBUG) << "priority: " << int_value;
1478     if (int_value > PARAM3 || int_value < PARAM0) {
1479       MS_LOG(ERROR) << "priority value is set to: " << int_value << ", which is out of range";
1480       return ERR_INVALID_PARAM;
1481     }
1482     context.nnrt_device.priority = int_value;
1483   } else if (ret == ERR_NOT_EXISTED_PARAM) {
1484     context.nnrt_device.priority = UNSET_VALUE;
1485   } else {
1486     return ERR_INVALID_PARAM;
1487   }
1488 
1489   // ignore extensions for now
1490   return SUCCESS;
1491 }
1492 
GetInputs(napi_env env,napi_callback_info info)1493 napi_value MSLiteModelNapi::GetInputs(napi_env env, napi_callback_info info) {
1494   napi_value undefinedResult = nullptr;
1495   napi_get_undefined(env, &undefinedResult);
1496 
1497   size_t argCount = 0;
1498   napi_value jsThis = nullptr;
1499   napi_value jsResult = nullptr;
1500   MSLiteModelNapi *modelNapi = nullptr;
1501 
1502   napi_status status = napi_get_cb_info(env, info, &argCount, nullptr, &jsThis, nullptr);
1503   if (status != napi_ok || jsThis == nullptr) {
1504     MS_LOG(ERROR) << "failed to retrieve details about the callback";
1505     return undefinedResult;
1506   }
1507 
1508   status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
1509   if (status != napi_ok || modelNapi == nullptr) {
1510     MS_LOG(ERROR) << "failed to get model";
1511     return undefinedResult;
1512   }
1513 
1514   if (modelNapi->native_model_ == nullptr) {
1515     MS_LOG(ERROR) << "model is released(null), please create model again";
1516     return undefinedResult;
1517   }
1518   std::vector<MSTensor> inputs = modelNapi->native_model_->GetInputs();
1519   std::vector<MSTensor> tensor_inputs;
1520   for (size_t i = 0; i < inputs.size(); i++) {
1521     auto tensor = mindspore::MSTensor::CreateTensor(inputs.at(i).Name(), inputs.at(i).DataType(), {}, nullptr, 0);
1522     if (tensor == nullptr) {
1523       MS_LOG(ERROR) << "create tensor failed.";
1524       return undefinedResult;
1525     }
1526     tensor->SetShape(inputs.at(i).Shape());
1527     tensor->SetFormat(inputs.at(i).format());
1528     tensor->SetDataType(inputs.at(i).DataType());
1529     tensor_inputs.push_back(*tensor);
1530     delete tensor;
1531   }
1532 
1533   size_t size = inputs.size();
1534   MS_LOG(INFO) << "inputs size: " << size;
1535   napi_create_array_with_length(env, size, &jsResult);
1536   for (size_t i = 0; i < size; i++) {
1537     status = napi_set_element(env, jsResult, i, MSTensorNapi::NewInstance(env, tensor_inputs[i]));
1538     if (status != napi_ok) {
1539       MS_LOG(ERROR) << "napi_set_element failed! code: " << status;
1540     }
1541   }
1542   MS_LOG(INFO) << "get model inputs success: " << inputs[0].Name().c_str();
1543   return jsResult;
1544 }
1545 
Resize(napi_env env,napi_callback_info info)1546 napi_value MSLiteModelNapi::Resize(napi_env env, napi_callback_info info) {
1547   napi_value undefinedResult = nullptr;
1548   bool result = false;
1549   napi_status status = napi_get_boolean(env, result, &undefinedResult);
1550   if (status != napi_ok) {
1551     MS_LOG(ERROR) << "get bool error";
1552     return undefinedResult;
1553   }
1554 
1555   napi_value jsThis = nullptr;
1556   napi_value jsResult = nullptr;
1557   MSLiteModelNapi *modelNapi = nullptr;
1558   napi_value argv[ARGS_TWO] = {0};
1559   size_t argCount = PARAM2;
1560   status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
1561   if (status != napi_ok || jsThis == nullptr) {
1562     MS_LOG(ERROR) << "failed to retrieve details about the callback";
1563     return undefinedResult;
1564   }
1565   status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
1566   if (status != napi_ok || modelNapi == nullptr) {
1567     MS_LOG(ERROR) << "get model napi error";
1568     return undefinedResult;
1569   }
1570 
1571   if (modelNapi->native_model_ == nullptr) {
1572     MS_LOG(ERROR) << "model is released(null), please create model again";
1573     return undefinedResult;
1574   }
1575   std::vector<MSTensor> inputs = modelNapi->native_model_->GetInputs();
1576   std::vector<MSTensor> tensor_inputs;
1577   std::vector<std::vector<int64_t>> dims;
1578 
1579   // set inputs data
1580   uint32_t array_length = 0;
1581   status = napi_get_array_length(env, argv[PARAM0], &array_length);
1582   if (status != napi_ok || array_length <= 0) {
1583     MS_LOG(ERROR) << "get inputs tensor length failed.";
1584     return undefinedResult;
1585   }
1586   if (inputs.size() != array_length) {
1587     MS_LOG(ERROR) << "array length not equal to model inputs size.";
1588     return undefinedResult;
1589   }
1590   for (size_t i = 0; i < array_length; i++) {
1591     napi_value element = nullptr;
1592     status = napi_get_element(env, argv[PARAM0], i, &element);
1593     if (status != napi_ok) {
1594       MS_LOG(ERROR) << "can not get element";
1595       return undefinedResult;
1596     }
1597 
1598     std::string property_name = "getData";
1599     bool exist = false;
1600     napi_value data_func = nullptr;
1601 
1602     status = napi_has_named_property(env, element, property_name.c_str(), &exist);
1603     if (status != napi_ok || !exist) {
1604       MS_LOG(ERROR) << "can not find target property";
1605       return undefinedResult;
1606     }
1607 
1608     if (status != napi_ok || !exist) {
1609       MS_LOG(ERROR) << "can not find " << property_name.c_str() << " property.";
1610       return undefinedResult;
1611     }
1612 
1613     if (napi_get_named_property(env, element, property_name.c_str(), &data_func) != napi_ok) {
1614       MS_LOG(ERROR) << "get " << property_name.c_str() << " property fail.";
1615       return undefinedResult;
1616     }
1617     void *js_data = nullptr;
1618     size_t length = 0;
1619     napi_value return_val;
1620 
1621     status = napi_call_function(env, element, data_func, 0, nullptr, &return_val);
1622     if (status != napi_ok || return_val == nullptr) {
1623       MS_LOG(ERROR) << "napi call function error.";
1624       return undefinedResult;
1625     }
1626 
1627     status = napi_get_arraybuffer_info(env, return_val, &js_data, &length);
1628     if (status != napi_ok || js_data == nullptr) {
1629       MS_LOG(ERROR) << "get js data error.";
1630       return undefinedResult;
1631     }
1632     if (inputs[i].DataSize() != length) {
1633       MS_LOG(ERROR) << "tensor size is: " << static_cast<int>(inputs[i].DataSize()) << ", but data length got "
1634                     << static_cast<int>(length);
1635       return undefinedResult;
1636     }
1637 
1638     auto tensor_data = inputs[i].MutableData();
1639     if (tensor_data == nullptr) {
1640       MS_LOG(ERROR) << "malloc data for tensor failed.";
1641       return undefinedResult;
1642     }
1643     memcpy(tensor_data, js_data, length);
1644   }
1645 
1646   napi_value dim_num = nullptr;
1647   int64_t dim_ele = 0;
1648   uint32_t dims_size = 0;
1649   uint32_t dim_size = 0;
1650 
1651   status = napi_is_array(env, argv[PARAM1], &result);
1652   if (status != napi_ok || result == false) {
1653     MS_LOG(ERROR) << "new dim is not a array";
1654     return undefinedResult;
1655   }
1656 
1657   status = napi_get_array_length(env, argv[PARAM1], &dims_size);
1658   if (status != napi_ok) {
1659     MS_LOG(ERROR) << "get new dims size error";
1660     return undefinedResult;
1661   }
1662   for (size_t i = 0; i < dims_size; i++) {
1663     napi_value dim_element = nullptr;
1664     status = napi_get_element(env, argv[PARAM1], i, &dim_element);
1665     if (status != napi_ok) {
1666       MS_LOG(ERROR) << "can not get element";
1667       return undefinedResult;
1668     }
1669 
1670     status = napi_is_array(env, dim_element, &result);
1671     if (status != napi_ok || result == false) {
1672       MS_LOG(ERROR) << "new dim's element is not a array";
1673       return undefinedResult;
1674     }
1675 
1676     status = napi_get_array_length(env, dim_element, &dim_size);
1677     if (status != napi_ok) {
1678       MS_LOG(ERROR) << "get new dim size error";
1679       return undefinedResult;
1680     }
1681     std::vector<int64_t> dim(dim_size);
1682     for (size_t j = 0; j < dim_size; j++) {
1683       status = napi_get_element(env, dim_element, j, &dim_num);
1684       if (status != napi_ok) {
1685         MS_LOG(ERROR) << "get dim num error";
1686         return undefinedResult;
1687       }
1688       status = napi_get_value_int64(env, dim_num, &dim_ele);
1689       if (status != napi_ok) {
1690         MS_LOG(ERROR) << "get dim element error";
1691         return undefinedResult;
1692       }
1693       dim[j] = dim_ele;
1694     }
1695     dims.push_back(dim);
1696   }
1697   if (modelNapi->native_model_->Resize(inputs, dims) != mindspore::kSuccess) {
1698     MS_LOG(ERROR) << "resize failed";
1699     return undefinedResult;
1700   }
1701   status = napi_get_boolean(env, result, &jsResult);
1702   if (status != napi_ok) {
1703     MS_LOG(ERROR) << "get bool error";
1704     return undefinedResult;
1705   }
1706   return jsResult;
1707 }
1708 
/// Fill a raw buffer with values of type T drawn from `distribution`.
/// @param size destination buffer size in bytes; only size / sizeof(T)
///        elements are written.
/// @param data destination buffer (must hold at least `size` bytes).
/// @param distribution callable invoked with the engine to produce each value.
template <typename T, typename Distribution>
void GenerateRandomData(int size, void *data, Distribution distribution) {
  // Default-seeded engine: the produced sequence is deterministic across runs.
  std::mt19937 engine;
  const int count = size / static_cast<int>(sizeof(T));
  T *out = static_cast<T *>(data);
  for (int idx = 0; idx < count; ++idx) {
    out[idx] = static_cast<T>(distribution(engine));
  }
}
1716 
GenerateInputDataWithRandom(std::vector<mindspore::MSTensor> inputs)1717 int GenerateInputDataWithRandom(std::vector<mindspore::MSTensor> inputs) {
1718   for (auto tensor : inputs) {
1719     auto input_data = tensor.MutableData();
1720     if (input_data == nullptr) {
1721       std::cerr << "mallocData for inTensor failed." << std::endl;
1722       return -1;
1723     }
1724     GenerateRandomData<float>(tensor.DataSize(), input_data, std::uniform_real_distribution<float>(0.1f, 1.0f));
1725   }
1726   return mindspore::kSuccess;
1727 }
1728 
// Asynchronous Predict(): arg0 is an array of JS tensors whose data is copied
// into the native model's inputs; optional arg1 is a JS callback function.
// Returns a promise when no callback is supplied, undefined in callback mode,
// and queues the inference on a napi async work item.
napi_value MSLiteModelNapi::PredictAsync(napi_env env, napi_callback_info info) {
  napi_status status = napi_ok;
  napi_value undefinedResult = nullptr;
  napi_value result = nullptr;
  const int32_t refCount = 1;
  napi_valuetype valueType;

  std::unique_ptr<MSLiteModelAsyncContext> asyncContext = std::make_unique<MSLiteModelAsyncContext>();
  if (asyncContext == nullptr) {
    MS_LOG(ERROR) << "MSLiteModelAsyncContext object create failed.";
    return undefinedResult;
  }

  // GET_PARAMS declares argc / argv / thisVar from the callback info.
  GET_PARAMS(env, info, ARGS_TWO);
  // If a second argument is present it must be a function; keep a strong
  // reference so it survives until the async work completes.
  for (size_t i = PARAM0; i < argc; i++) {
    if (i == PARAM1) {
      status = napi_typeof(env, argv[i], &valueType);
      if ((status != napi_ok) || (valueType != napi_function)) {
        MS_LOG(ERROR) << "napi_typeof check callback failed.";
        return result;
      }
      status = napi_create_reference(env, argv[i], refCount, &asyncContext->callbackRef);
      if (status != napi_ok) {
        MS_LOG(ERROR) << "failed to create reference of callback";
        return result;
      }
    }
  }

  // Copy JS input data into the native model here, on the JS thread — JS
  // values must not be touched from the async worker.
  if (SetTensorData(env, thisVar, argv[PARAM0], asyncContext.get()) != SUCCESS) {
    MS_LOG(ERROR) << "Set tensor data failed.";
    return undefinedResult;
  }

  // Promise mode when no callback was registered; callback mode otherwise.
  if (asyncContext->callbackRef == nullptr) {
    status = napi_create_promise(env, &asyncContext->deferred, &result);
    if (status != napi_ok) {
      MS_LOG(ERROR) << "create promise failed.";
      return result;
    }
  } else {
    status = napi_get_undefined(env, &result);
    if (status != napi_ok) {
      MS_LOG(ERROR) << "create callback failed.";
      return result;
    }
  }

  napi_value resource = nullptr;
  napi_create_string_utf8(env, "Predict", NAPI_AUTO_LENGTH, &resource);
  // The execute callback only marks the context as OK; the actual Predict()
  // happens in PredictAsyncCallbackComplete on the main thread.
  status = napi_create_async_work(
    env, nullptr, resource,
    [](napi_env env, void *data) {
      auto context = static_cast<MSLiteModelAsyncContext *>(data);
      context->status = SUCCESS;
    },
    PredictAsyncCallbackComplete, static_cast<void *>(asyncContext.get()), &asyncContext->work);
  if (status != napi_ok) {
    result = nullptr;
  } else {
    status = napi_queue_async_work(env, asyncContext->work);
    if (status == napi_ok) {
      // Ownership of the context transfers to the queued work; it is
      // reclaimed in the completion routine.
      asyncContext.release();
    } else {
      result = nullptr;
    }
  }
  return result;
}
1798 
SetTensorData(napi_env env,napi_value thisVar,napi_value argv,MSLiteModelAsyncContext * async_context)1799 int32_t MSLiteModelNapi::SetTensorData(napi_env env, napi_value thisVar, napi_value argv,
1800                                        MSLiteModelAsyncContext *async_context) {
1801   uint32_t array_length = 0;
1802   napi_status status = napi_get_array_length(env, argv, &array_length);
1803   if (status != napi_ok || array_length <= 0) {
1804     MS_LOG(ERROR) << "get inputs tensor length failed.";
1805     return ERR_INVALID_PARAM;
1806   }
1807 
1808   status = napi_unwrap(env, thisVar, reinterpret_cast<void **>(&(async_context->lite_model)));
1809   if (status != napi_ok || async_context->lite_model == nullptr) {
1810     MS_LOG(ERROR) << "get model napi error";
1811     return ERROR;
1812   }
1813   auto modelNapi = async_context->lite_model;
1814   if (modelNapi->native_model_ == nullptr) {
1815     MS_LOG(ERROR) << "model is released(null), please create model again";
1816     return ERROR;
1817   }
1818 
1819   auto inputs = modelNapi->native_model_->GetInputs();
1820   if (inputs.size() != array_length) {
1821     MS_LOG(ERROR) << "array length not equal to model inputs size.";
1822     return ERR_INVALID_PARAM;
1823   }
1824 
1825   for (size_t i = 0; i < array_length; i++) {
1826     napi_value element = nullptr;
1827     status = napi_get_element(env, argv, i, &element);
1828     if (status != napi_ok) {
1829       MS_LOG(ERROR) << "can not get element";
1830       return ERROR;
1831     }
1832 
1833     std::string property_name = "getData";
1834     bool exist = false;
1835     napi_value data_func = nullptr;
1836 
1837     napi_status status = napi_has_named_property(env, element, property_name.c_str(), &exist);
1838 
1839     if (status != napi_ok || !exist) {
1840       MS_LOG(ERROR) << "can not find " << property_name.c_str() << " property.";
1841       return ERROR;
1842     }
1843 
1844     if (napi_get_named_property(env, element, property_name.c_str(), &data_func) != napi_ok) {
1845       MS_LOG(ERROR) << "get " << property_name.c_str() << " property fail.";
1846       return ERROR;
1847     }
1848     void *js_data = nullptr;
1849     size_t length = 0;
1850     napi_value return_val;
1851 
1852     status = napi_call_function(env, element, data_func, 0, nullptr, &return_val);
1853     if (status != napi_ok || return_val == nullptr) {
1854       MS_LOG(ERROR) << "napi call function error.";
1855       return ERROR;
1856     }
1857     status = napi_get_arraybuffer_info(env, return_val, &js_data, &length);
1858     if (status != napi_ok || js_data == nullptr) {
1859       MS_LOG(ERROR) << "Get js data error.";
1860       return ERROR;
1861     }
1862     if (inputs[i].DataSize() != length) {
1863       MS_LOG(ERROR) << "tensor size is: " << static_cast<int>(inputs[i].DataSize()) << ", but data length got "
1864                     << static_cast<int>(length);
1865       return ERROR;
1866     }
1867 
1868     auto tensor_data = inputs[i].MutableData();
1869     if (tensor_data == nullptr) {
1870       MS_LOG(ERROR) << "malloc data for tensor failed.";
1871       return ERROR;
1872     }
1873     memcpy(tensor_data, js_data, length);
1874   }
1875   return SUCCESS;
1876 }
1877 
// Completion routine for PredictAsync: runs on the main thread after the
// (no-op) execute phase, performs the actual inference, and forwards the
// output tensor array to the stored promise/callback via CommonCallbackRoutine.
void MSLiteModelNapi::PredictAsyncCallbackComplete(napi_env env, napi_status status, void *data) {
  napi_value valueParam = nullptr;
  auto asyncContext = static_cast<MSLiteModelAsyncContext *>(data);

  if (asyncContext != nullptr) {
    // The execute phase sets asyncContext->status to SUCCESS (0), so this
    // branch is taken on the success path.
    if (!asyncContext->status) {
      auto modelNapi = asyncContext->lite_model;
      if (modelNapi->native_model_ == nullptr) {
        MS_LOG(ERROR) << "model is released(null), please create model again";
        // NOTE(review): returning here skips CommonCallbackRoutine, so the
        // pending promise/callback is never settled and asyncContext is not
        // released — confirm whether this leak is intended.
        return;
      }
      auto inputs = modelNapi->native_model_->GetInputs();
      std::vector<MSTensor> outputs;

      auto predict_ret = modelNapi->native_model_->Predict(inputs, &outputs);
      if (predict_ret != mindspore::kSuccess) {
        MS_LOG(ERROR) << "model predict failed.";
        // NOTE(review): same early-return concern as above.
        return;
      }

      // Wrap each native output tensor in a JS MSTensor and collect them
      // into the array handed to the promise/callback.
      napi_create_array_with_length(env, outputs.size(), &valueParam);
      for (size_t i = 0; i < outputs.size(); i++) {
        status = napi_set_element(env, valueParam, i, MSTensorNapi::NewInstance(env, outputs[i]));
        if (status != napi_ok) {
          MS_LOG(ERROR) << "napi_set_element failed! code: " << status;
        }
      }
      MS_LOG(INFO) << "predict model success.";
    }
    CommonCallbackRoutine(env, asyncContext, valueParam);
  } else {
    MS_LOG(ERROR) << "ERROR: PredictAsyncCallbackComplete asyncContext is Null!";
  }
}
1912 
// Returns the model's feature maps (trainable weights) to JS as an array of
// MSTensor objects. Takes no arguments.
napi_value MSLiteModelNapi::GetWeights(napi_env env, napi_callback_info info) {
  napi_value undefinedResult = nullptr;
  napi_get_undefined(env, &undefinedResult);

  size_t argCount = 0;
  napi_value jsThis = nullptr;
  napi_value jsResult = nullptr;
  MSLiteModelNapi *modelNapi = nullptr;

  napi_status status = napi_get_cb_info(env, info, &argCount, nullptr, &jsThis, nullptr);
  if (status != napi_ok || jsThis == nullptr) {
    MS_LOG(ERROR) << "failed to retrieve details about the callback";
    return undefinedResult;
  }

  status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
  if (status != napi_ok || modelNapi == nullptr) {
    MS_LOG(ERROR) << "failed to get model";
    return undefinedResult;
  }

  if (modelNapi->native_model_ == nullptr) {
    MS_LOG(ERROR) << "model is released(null), please create model again";
    return undefinedResult;
  }
  std::vector<MSTensor> weights = modelNapi->native_model_->GetFeatureMaps();
  std::vector<MSTensor> feature_maps;
  // Re-wrap each weight in a freshly created tensor that mirrors its name,
  // shape, format and dtype but borrows the original data buffer.
  for (size_t i = 0; i < weights.size(); i++) {
    auto tensor = mindspore::MSTensor::CreateTensor(weights.at(i).Name(), weights.at(i).DataType(), {}, nullptr, 0);
    if (tensor == nullptr) {
      MS_LOG(ERROR) << "create tensor failed.";
      return undefinedResult;
    }
    tensor->SetShape(weights.at(i).Shape());
    tensor->SetFormat(weights.at(i).format());
    tensor->SetDataType(weights.at(i).DataType());
    // NOTE(review): the `false` presumably means the tensor does not take
    // ownership of the buffer, so the `delete tensor` below must not free
    // it — confirm against MSTensor::SetData semantics.
    tensor->SetData(weights.at(i).MutableData(), false);
    feature_maps.push_back(*tensor);
    delete tensor;
  }

  size_t size = weights.size();
  MS_LOG(INFO) << "weights size: " << size;
  napi_create_array_with_length(env, size, &jsResult);
  for (size_t i = 0; i < size; i++) {
    status = napi_set_element(env, jsResult, i, MSTensorNapi::NewInstance(env, feature_maps[i]));
    if (status != napi_ok) {
      MS_LOG(ERROR) << "napi_set_element failed! code: " << status;
    }
  }
  MS_LOG(INFO) << "get model weights success";
  return jsResult;
}
1966 
SetModelInputs(napi_env env,napi_value argv,std::shared_ptr<Model> model)1967 int32_t SetModelInputs(napi_env env, napi_value argv, std::shared_ptr<Model> model) {
1968   uint32_t array_length = 0;
1969   napi_status status = napi_get_array_length(env, argv, &array_length);
1970   if (status != napi_ok || array_length <= 0) {
1971     MS_LOG(ERROR) << "get inputs tensor length failed.";
1972     return ERR_INVALID_PARAM;
1973   }
1974 
1975   if (model == nullptr) {
1976     MS_LOG(ERROR) << "model is nullptr";
1977     return ERR_INVALID_PARAM;
1978   }
1979 
1980   auto inputs = model->GetInputs();
1981   if (inputs.size() != array_length) {
1982     MS_LOG(ERROR) << "array length not equal to model inputs size.";
1983     return ERR_INVALID_PARAM;
1984   }
1985 
1986   for (size_t i = 0; i < array_length; i++) {
1987     napi_value element = nullptr;
1988     status = napi_get_element(env, argv, i, &element);
1989     if (status != napi_ok) {
1990       MS_LOG(ERROR) << "can not get element";
1991       return ERROR;
1992     }
1993 
1994     std::string property_name = "getData";
1995     bool exist = false;
1996     napi_value data_func = nullptr;
1997 
1998     napi_status status = napi_has_named_property(env, element, property_name.c_str(), &exist);
1999 
2000     if (status != napi_ok || !exist) {
2001       MS_LOG(ERROR) << "can not find " << property_name.c_str() << " property.";
2002       return ERROR;
2003     }
2004 
2005     if (napi_get_named_property(env, element, property_name.c_str(), &data_func) != napi_ok) {
2006       MS_LOG(ERROR) << "get " << property_name.c_str() << " property fail.";
2007       return ERROR;
2008     }
2009     void *js_data = nullptr;
2010     size_t length = 0;
2011     napi_value return_val;
2012 
2013     status = napi_call_function(env, element, data_func, 0, nullptr, &return_val);
2014     if (status != napi_ok || return_val == nullptr) {
2015       MS_LOG(ERROR) << "napi call function error.";
2016       return ERROR;
2017     }
2018     status = napi_get_arraybuffer_info(env, return_val, &js_data, &length);
2019     if (status != napi_ok || js_data == nullptr) {
2020       MS_LOG(ERROR) << "Get js data error.";
2021       return ERROR;
2022     }
2023     if (inputs[i].DataSize() != length) {
2024       MS_LOG(ERROR) << "tensor size is: " << static_cast<int>(inputs[i].DataSize()) << ", but data length got "
2025                     << static_cast<int>(length);
2026       return ERROR;
2027     }
2028 
2029     auto tensor_data = inputs[i].MutableData();
2030     if (tensor_data == nullptr) {
2031       MS_LOG(ERROR) << "malloc data for tensor failed.";
2032       return ERROR;
2033     }
2034     memcpy(tensor_data, js_data, length);
2035   }
2036   return SUCCESS;
2037 }
2038 
RunStep(napi_env env,napi_callback_info info)2039 napi_value MSLiteModelNapi::RunStep(napi_env env, napi_callback_info info) {
2040   napi_value undefinedResult = nullptr;
2041   bool result = false;
2042   napi_status status = napi_get_boolean(env, result, &undefinedResult);
2043   if (status != napi_ok) {
2044     MS_LOG(ERROR) << "get bool error";
2045     return undefinedResult;
2046   }
2047 
2048   napi_value jsThis = nullptr;
2049   MSLiteModelNapi *modelNapi = nullptr;
2050   size_t argCount = PARAM1;
2051   napi_value argv[ARGS_ONE] = {0};
2052 
2053   status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
2054   if (status != napi_ok || jsThis == nullptr) {
2055     MS_LOG(ERROR) << "failed to retrieve details about the callback";
2056     return undefinedResult;
2057   }
2058 
2059   if (argCount < ARGS_ONE) {
2060     MS_LOG(ERROR) << "argument num is less than one, please give input tensors";
2061     return undefinedResult;
2062   }
2063 
2064   status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
2065   if (status != napi_ok || modelNapi == nullptr) {
2066     MS_LOG(ERROR) << "get model napi error";
2067     return undefinedResult;
2068   }
2069 
2070   if (SetModelInputs(env, argv[PARAM0], modelNapi->native_model_) != SUCCESS) {
2071     MS_LOG(ERROR) << "set tensor data failed";
2072     return undefinedResult;
2073   }
2074 
2075   if (modelNapi->native_model_ == nullptr) {
2076     MS_LOG(ERROR) << "model is released(null), please create model again";
2077     return undefinedResult;
2078   }
2079 
2080   auto ret = modelNapi->native_model_->RunStep();
2081   if (ret != kSuccess) {
2082     MS_LOG(ERROR) << "Model run step failed";
2083     return undefinedResult;
2084   }
2085   status = napi_get_boolean(env, true, &undefinedResult);
2086   if (status != napi_ok) {
2087     MS_LOG(ERROR) << "create bool true value failed";
2088     return undefinedResult;
2089   }
2090   return undefinedResult;
2091 }
2092 
UpdateWeights(napi_env env,napi_callback_info info)2093 napi_value MSLiteModelNapi::UpdateWeights(napi_env env, napi_callback_info info) {
2094   napi_value undefinedResult = nullptr;
2095   bool result = false;
2096   napi_status status = napi_get_boolean(env, result, &undefinedResult);
2097   if (status != napi_ok) {
2098     MS_LOG(ERROR) << "get bool error";
2099     return undefinedResult;
2100   }
2101 
2102   napi_value jsThis = nullptr;
2103   napi_value jsResult = nullptr;
2104   MSLiteModelNapi *modelNapi = nullptr;
2105   napi_value argv[ARGS_ONE] = {0};
2106   size_t argCount = PARAM1;
2107   status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
2108   if (status != napi_ok || jsThis == nullptr) {
2109     MS_LOG(ERROR) << "failed to retrieve details about the callback";
2110     return undefinedResult;
2111   }
2112   status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
2113   if (status != napi_ok || modelNapi == nullptr) {
2114     MS_LOG(ERROR) << "get model napi error";
2115     return undefinedResult;
2116   }
2117 
2118   if (modelNapi->native_model_ == nullptr) {
2119     MS_LOG(ERROR) << "model is released(null), please create model again";
2120     return undefinedResult;
2121   }
2122 
2123   // set inputs data
2124   uint32_t array_length = 0;
2125   status = napi_get_array_length(env, argv[PARAM0], &array_length);
2126   if (status != napi_ok || array_length <= 0) {
2127     MS_LOG(ERROR) << "get inputs tensor length failed.";
2128     return undefinedResult;
2129   }
2130 
2131   std::vector<MSTensor> weights;
2132   for (size_t i = 0; i < array_length; i++) {
2133     napi_value element = nullptr;
2134     status = napi_get_element(env, argv[PARAM0], i, &element);
2135     if (status != napi_ok) {
2136       MS_LOG(ERROR) << "can not get element";
2137       return undefinedResult;
2138     }
2139 
2140     // get tensor name
2141     std::string tensor_name;
2142     auto ret = CommonNapi::GetPropertyString(env, element, "name", tensor_name);
2143     if (ret != SUCCESS) {
2144       MS_LOG(ERROR) << "get tensor name property failed";
2145       return undefinedResult;
2146     }
2147 
2148     // get tensor format
2149     int format;
2150     ret = CommonNapi::GetPropertyInt32(env, element, "format", format);
2151     if (ret != SUCCESS) {
2152       MS_LOG(ERROR) << "get format property failed";
2153       return undefinedResult;
2154     }
2155 
2156     // get dtype
2157     int dtype;
2158     ret = CommonNapi::GetPropertyInt32(env, element, "dtype", dtype);
2159     if (ret != SUCCESS) {
2160       MS_LOG(ERROR) << "get format property failed";
2161       return undefinedResult;
2162     }
2163 
2164     // get data size
2165     int data_size;
2166     ret = CommonNapi::GetPropertyInt32(env, element, "dataSize", data_size);
2167     if (ret != SUCCESS) {
2168       MS_LOG(ERROR) << "get dataSize property failed";
2169       return undefinedResult;
2170     }
2171 
2172     // get shape
2173     std::vector<int32_t> shape;
2174     ret = CommonNapi::GetPropertyInt32Array(env, element, "shape", shape);
2175     if (ret != SUCCESS) {
2176       MS_LOG(ERROR) << "get shape property failed";
2177       return undefinedResult;
2178     }
2179 
2180     // get data
2181     std::string property_name = "getData";
2182     bool exist = false;
2183     napi_value data_func = nullptr;
2184 
2185     status = napi_has_named_property(env, element, property_name.c_str(), &exist);
2186     if (status != napi_ok || !exist) {
2187       MS_LOG(ERROR) << "can not find target property";
2188       return undefinedResult;
2189     }
2190 
2191     if (napi_get_named_property(env, element, property_name.c_str(), &data_func) != napi_ok) {
2192       MS_LOG(ERROR) << "get " << property_name.c_str() << " property fail.";
2193       return undefinedResult;
2194     }
2195     void *js_data = nullptr;
2196     size_t length = 0;
2197 
2198     napi_value return_val;
2199     status = napi_call_function(env, element, data_func, 0, nullptr, &return_val);
2200     if (status != napi_ok || return_val == nullptr) {
2201       MS_LOG(ERROR) << "napi call function error.";
2202       return undefinedResult;
2203     }
2204 
2205     status = napi_get_arraybuffer_info(env, return_val, &js_data, &length);
2206     if (status != napi_ok || js_data == nullptr) {
2207       MS_LOG(ERROR) << "get js data error.";
2208       return undefinedResult;
2209     }
2210 
2211     std::vector<int64_t> int64_shape;
2212     int64_shape.reserve(shape.size());
2213     std::transform(shape.begin(), shape.end(), std::back_inserter(int64_shape), [](int32_t value) {
2214       return static_cast<int64_t>(value);
2215     });
2216     auto tensor = mindspore::MSTensor::CreateTensor(tensor_name, static_cast<mindspore::DataType>(dtype), int64_shape, nullptr, 0);
2217     if (tensor == nullptr) {
2218       MS_LOG(ERROR) << "create tensor failed.";
2219       return undefinedResult;
2220     }
2221     tensor->SetFormat(static_cast<mindspore::Format>(format));
2222     auto tensor_data = tensor->MutableData();
2223     if (tensor_data == nullptr) {
2224       MS_LOG(ERROR) << "mutable tensor data failed, get nullptr";
2225       return undefinedResult;
2226     }
2227 
2228     if (tensor->DataSize() != length) {
2229       MS_LOG(ERROR) << "tensor size is: " << static_cast<int>(tensor->DataSize()) << ", but data length got "
2230                     << static_cast<int>(length);
2231       return undefinedResult;
2232     }
2233 
2234     memcpy(tensor_data, js_data, length);
2235     weights.push_back(*tensor);
2236     delete tensor;
2237   }
2238 
2239   if (modelNapi->native_model_->UpdateFeatureMaps(weights) != mindspore::kSuccess) {
2240     MS_LOG(ERROR) << "UpdateFeatureMaps failed";
2241     return undefinedResult;
2242   }
2243   status = napi_get_boolean(env, true, &jsResult);
2244   if (status != napi_ok) {
2245     MS_LOG(ERROR) << "get bool error";
2246     return undefinedResult;
2247   }
2248   return jsResult;
2249 }
2250 
// Serializes the current model to the file path in argv[0]. Optional args:
// argv[1] quantization type (int enum, defaults to kNoQuant), argv[2]
// export-inference-only flag (defaults to true), argv[3] array of output
// tensor names. Returns JS true on success, JS false otherwise.
napi_value MSLiteModelNapi::ExportModel(napi_env env, napi_callback_info info) {
  napi_value undefinedResult = nullptr;
  bool result = false;
  // Pre-build the `false` return value used by every error path.
  napi_status status = napi_get_boolean(env, result, &undefinedResult);
  if (status != napi_ok) {
    MS_LOG(ERROR) << "get bool error";
    return undefinedResult;
  }

  napi_value jsThis = nullptr;
  napi_value jsResult = nullptr;
  MSLiteModelNapi *modelNapi = nullptr;
  napi_value argv[ARGS_FOUR] = {0};
  size_t argCount = PARAM4;
  status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
  if (status != napi_ok || jsThis == nullptr) {
    MS_LOG(ERROR) << "failed to retrieve details about the callback";
    return undefinedResult;
  }
  status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
  if (status != napi_ok || modelNapi == nullptr) {
    MS_LOG(ERROR) << "get model napi error";
    return undefinedResult;
  }

  if (modelNapi->native_model_ == nullptr) {
    MS_LOG(ERROR) << "model is released(null), please create model again";
    return undefinedResult;
  }

  // get modelfile
  char char_buf[SIZE];
  size_t buf_length = 0;
  status = napi_get_value_string_utf8(env, argv[PARAM0], char_buf, SIZE, &buf_length);
  if ((status != napi_ok) || (buf_length <= 0)) {
    MS_LOG(ERROR) << "Parse model file failed.";
    return undefinedResult;
  }

  std::string model_path;
  model_path.assign(char_buf, char_buf + buf_length);
  MS_LOG(DEBUG) << "model_path: " << model_path.c_str();

  mindspore::QuantizationType quantization_type = kNoQuant;
  int32_t quantization_type_value;
  // get quantization
  if (argCount >= ARGS_TWO) {
    if (napi_get_value_int32(env, argv[PARAM1], &quantization_type_value) != napi_ok) {
      MS_LOG(WARNING) << "fail to get int32_t value from quantizationType";
      return undefinedResult;
    }
    quantization_type = static_cast<mindspore::QuantizationType>(quantization_type_value);
  }

  // get inference mode
  bool export_inference_only = true;
  if (argCount >= ARGS_THREE) {
    if (napi_get_value_bool(env, argv[PARAM2], &export_inference_only) != napi_ok) {
      MS_LOG(WARNING) << "fail to get bool value from exportInferenceOnly";
      return undefinedResult;
    }
  }

  // get output names
  std::vector<std::string> output_tensor_name;
  if (argCount >= ARGS_FOUR) {
    auto ret = CommonNapi::GetStringArray(env, argv[PARAM3], output_tensor_name);
    if (ret != SUCCESS) {
      MS_LOG(ERROR) << "Get context target failed.";
      return undefinedResult;
    }
  }

  // Export always uses the MindIR model type.
  auto ret = mindspore::Serialization::ExportModel(*(modelNapi->native_model_.get()), static_cast<mindspore::ModelType>(kMindIR),
                                        model_path, static_cast<mindspore::QuantizationType>(quantization_type),
                                        export_inference_only, output_tensor_name);
  if (ret != mindspore::kSuccess) {
    MS_LOG(ERROR) << "Export model failed";
    return undefinedResult;
  }

  status = napi_get_boolean(env, true, &jsResult);
  if (status != napi_ok) {
    MS_LOG(ERROR) << "get bool error";
    return undefinedResult;
  }
  MS_LOG(DEBUG) << "Export Model Success";
  return jsResult;
}
2340 
// Exports only the model's weights, for use with micro-deployment, to the
// file path in argv[0]. Optional args: argv[1] isInference flag (defaults to
// true), argv[2] enableFp16 flag (defaults to false), argv[3] array of
// changeable weight names. Returns JS true on success, JS false otherwise.
napi_value MSLiteModelNapi::ExportWeightsCollaborateWithMicro(napi_env env, napi_callback_info info) {
  napi_value undefinedResult = nullptr;
  bool result = false;
  // Pre-build the `false` return value used by every error path.
  napi_status status = napi_get_boolean(env, result, &undefinedResult);
  if (status != napi_ok) {
    MS_LOG(ERROR) << "get bool error";
    return undefinedResult;
  }

  napi_value jsThis = nullptr;
  napi_value jsResult = nullptr;
  MSLiteModelNapi *modelNapi = nullptr;
  napi_value argv[ARGS_FOUR] = {0};
  size_t argCount = PARAM4;
  status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
  if (status != napi_ok || jsThis == nullptr) {
    MS_LOG(ERROR) << "failed to retrieve details about the callback";
    return undefinedResult;
  }
  status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
  if (status != napi_ok || modelNapi == nullptr) {
    MS_LOG(ERROR) << "get model napi error";
    return undefinedResult;
  }

  if (modelNapi->native_model_ == nullptr) {
    MS_LOG(ERROR) << "model is released(null), please create model again";
    return undefinedResult;
  }

  // get weight file
  char char_buf[SIZE];
  size_t buf_length = 0;
  status = napi_get_value_string_utf8(env, argv[PARAM0], char_buf, SIZE, &buf_length);
  if ((status != napi_ok) || (buf_length <= 0)) {
    MS_LOG(ERROR) << "Parse model file failed.";
    return undefinedResult;
  }

  std::string weight_file;
  weight_file.assign(char_buf, char_buf + buf_length);
  MS_LOG(DEBUG) << "weight_file: " << weight_file.c_str();

  // get is inference
  bool is_inference = true;
  if (argCount >= ARGS_TWO) {
    if (napi_get_value_bool(env, argv[PARAM1], &is_inference) != napi_ok) {
      MS_LOG(WARNING) << "fail to get bool value from isInference";
      return undefinedResult;
    }
  }

  // get inference mode
  bool enable_fp16 = false;
  if (argCount >= ARGS_THREE) {
    if (napi_get_value_bool(env, argv[PARAM2], &enable_fp16) != napi_ok) {
      MS_LOG(WARNING) << "fail to get bool value from enableFp16";
      return undefinedResult;
    }
  }

  // get output names
  std::vector<std::string> changeable_weights_name;
  if (argCount >= ARGS_FOUR) {
    auto ret = CommonNapi::GetStringArray(env, argv[PARAM3], changeable_weights_name);
    if (ret != SUCCESS) {
      MS_LOG(ERROR) << "failed to get string array from changeableWeightsName";
      return undefinedResult;
    }
  }

  // Export always uses the MindIR model type.
  auto ret = mindspore::Serialization::ExportWeightsCollaborateWithMicro(*(modelNapi->native_model_.get()), static_cast<mindspore::ModelType>(kMindIR),
                                                              weight_file, is_inference, enable_fp16, changeable_weights_name);

  if (ret != mindspore::kSuccess) {
    MS_LOG(ERROR) << "ExportWeightsCollaborateWithMicro failed";
    return undefinedResult;
  }

  status = napi_get_boolean(env, true, &jsResult);
  if (status != napi_ok) {
    MS_LOG(ERROR) << "get bool error";
    return undefinedResult;
  }
  MS_LOG(DEBUG) << "ExportWeightsCollaborateWithMicro Success";
  return jsResult;
}
2428 
SetupVirtualBatch(napi_env env,napi_callback_info info)2429 napi_value MSLiteModelNapi::SetupVirtualBatch(napi_env env, napi_callback_info info) {
2430   napi_value undefinedResult = nullptr;
2431   bool result = false;
2432   napi_status status = napi_get_boolean(env, result, &undefinedResult);
2433   if (status != napi_ok) {
2434     MS_LOG(ERROR) << "get bool error";
2435     return undefinedResult;
2436   }
2437 
2438   napi_value jsThis = nullptr;
2439   napi_value jsResult = nullptr;
2440   MSLiteModelNapi *modelNapi = nullptr;
2441   napi_value argv[ARGS_THREE] = {0};
2442   size_t argCount = ARGS_THREE;
2443   status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
2444   if (status != napi_ok || jsThis == nullptr) {
2445     MS_LOG(ERROR) << "failed to retrieve details about the callback";
2446     return undefinedResult;
2447   }
2448   status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
2449   if (status != napi_ok || modelNapi == nullptr) {
2450     MS_LOG(ERROR) << "get model napi error";
2451     return undefinedResult;
2452   }
2453 
2454   if (modelNapi->native_model_ == nullptr) {
2455     MS_LOG(ERROR) << "model is released(null), please create model again";
2456     return undefinedResult;
2457   }
2458 
2459   // get virtual batch
2460   int virtual_batch_multiplier;
2461   if (napi_get_value_int32(env, argv[PARAM0], &virtual_batch_multiplier) != napi_ok) {
2462     MS_LOG(WARNING) << "fail to get int32 value from virtualBatchMultiplier";
2463     return undefinedResult;
2464   }
2465 
2466   // get lr
2467   double lr = -1.0f;
2468   if (argCount >= ARGS_TWO) {
2469     if (napi_get_value_double(env, argv[PARAM1], &lr) != napi_ok) {
2470       MS_LOG(WARNING) << "fail to get double value from lr";
2471       return undefinedResult;
2472     }
2473   }
2474 
2475   // get lr
2476   double momentum = -1.0f;
2477   if (argCount >= ARGS_THREE) {
2478     if (napi_get_value_double(env, argv[PARAM2], &momentum) != napi_ok) {
2479       MS_LOG(WARNING) << "fail to get double value from momentum";
2480       return undefinedResult;
2481     }
2482   }
2483 
2484 
2485   auto ret = modelNapi->native_model_->SetupVirtualBatch(virtual_batch_multiplier, static_cast<float>(lr), static_cast<float>(momentum));
2486 
2487   if (ret != mindspore::kSuccess) {
2488     MS_LOG(ERROR) << "SetupVirtualBatch failed";
2489     return undefinedResult;
2490   }
2491 
2492   status = napi_get_boolean(env, true, &jsResult);
2493   if (status != napi_ok) {
2494     MS_LOG(ERROR) << "get bool error";
2495     return undefinedResult;
2496   }
2497   return jsResult;
2498 }
GetTrainMode(napi_env env,napi_callback_info info)2499 napi_value MSLiteModelNapi::GetTrainMode(napi_env env, napi_callback_info info) {
2500   napi_value undefinedResult = nullptr;
2501 
2502   napi_value jsThis = nullptr;
2503   napi_value jsResult = nullptr;
2504   MSLiteModelNapi *modelNapi = nullptr;
2505   napi_value argv[ARGS_ONE] = {0};
2506   size_t argCount = ARGS_ONE;
2507   auto status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
2508   if (status != napi_ok || jsThis == nullptr) {
2509     MS_LOG(ERROR) << "failed to retrieve details about the callback";
2510     return undefinedResult;
2511   }
2512   status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
2513   if (status != napi_ok || modelNapi == nullptr) {
2514     MS_LOG(ERROR) << "get model napi error";
2515     return undefinedResult;
2516   }
2517   if (modelNapi->native_model_ == nullptr) {
2518     MS_LOG(ERROR) << "model is released(null), please create model again";
2519     return undefinedResult;
2520   }
2521 
2522   auto train_mode = modelNapi->native_model_->GetTrainMode();
2523 
2524   status = napi_get_boolean(env, train_mode, &jsResult);
2525   if (status != napi_ok) {
2526     MS_LOG(WARNING) << "create bool value error";
2527     return undefinedResult;
2528   }
2529   return jsResult;
2530 }
SetTrainMode(napi_env env,napi_callback_info info)2531 napi_value MSLiteModelNapi::SetTrainMode(napi_env env, napi_callback_info info) {
2532   napi_value undefinedResult = nullptr;
2533 
2534   napi_value jsThis = nullptr;
2535   napi_value jsResult = nullptr;
2536   MSLiteModelNapi *modelNapi = nullptr;
2537   napi_value argv[ARGS_ONE] = {0};
2538   size_t argCount = ARGS_ONE;
2539   auto status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
2540   if (status != napi_ok || jsThis == nullptr) {
2541     MS_LOG(ERROR) << "failed to retrieve details about the callback";
2542     return undefinedResult;
2543   }
2544   status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
2545   if (status != napi_ok || modelNapi == nullptr) {
2546     MS_LOG(ERROR) << "get model napi error";
2547     return undefinedResult;
2548   }
2549   if (modelNapi->native_model_ == nullptr) {
2550     MS_LOG(ERROR) << "model is released(null), please create model again";
2551     return undefinedResult;
2552   }
2553 
2554   bool train_mode;
2555   if (napi_get_value_bool(env, argv[PARAM0], &train_mode) != napi_ok) {
2556     MS_LOG(WARNING) << "failed to get bool value from input train mode.";
2557     return undefinedResult;
2558   }
2559   if (!model_info_->train_model) {
2560     MS_LOG(WARNING) << "current model is not train model, unable to set train or eval mode";
2561     return undefinedResult;
2562   }
2563   if (modelNapi->native_model_->SetTrainMode(train_mode) != kSuccess) {
2564     MS_LOG(ERROR) << "set train mode failed";
2565     return undefinedResult;
2566   }
2567 
2568   status = napi_get_boolean(env, true, &jsResult);
2569   if (status != napi_ok) {
2570     MS_LOG(WARNING) << "create bool value error";
2571     return undefinedResult;
2572   }
2573   return jsResult;
2574 }
GetLearningRate(napi_env env,napi_callback_info info)2575 napi_value MSLiteModelNapi::GetLearningRate(napi_env env, napi_callback_info info) {
2576   napi_value undefinedResult = nullptr;
2577 
2578   napi_value jsThis = nullptr;
2579   napi_value jsResult = nullptr;
2580   MSLiteModelNapi *modelNapi = nullptr;
2581   napi_value argv[ARGS_ONE] = {0};
2582   size_t argCount = ARGS_ONE;
2583   auto status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
2584   if (status != napi_ok || jsThis == nullptr) {
2585     MS_LOG(ERROR) << "failed to retrieve details about the callback";
2586     return undefinedResult;
2587   }
2588   status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
2589   if (status != napi_ok || modelNapi == nullptr) {
2590     MS_LOG(ERROR) << "get model napi error";
2591     return undefinedResult;
2592   }
2593   if (modelNapi->native_model_ == nullptr) {
2594     MS_LOG(ERROR) << "model is released(null), please create model again";
2595     return undefinedResult;
2596   }
2597 
2598   auto lr = modelNapi->native_model_->GetLearningRate();
2599 
2600   status = napi_create_double(env, lr, &jsResult);
2601   if (status != napi_ok) {
2602     MS_LOG(WARNING) << "create double value error";
2603     return undefinedResult;
2604   }
2605   return jsResult;
2606 }
SetLearningRate(napi_env env,napi_callback_info info)2607 napi_value MSLiteModelNapi::SetLearningRate(napi_env env, napi_callback_info info) {
2608   napi_value undefinedResult = nullptr;
2609 
2610   napi_value jsThis = nullptr;
2611   napi_value jsResult = nullptr;
2612   MSLiteModelNapi *modelNapi = nullptr;
2613   napi_value argv[ARGS_ONE] = {0};
2614   size_t argCount = ARGS_ONE;
2615   auto status = napi_get_cb_info(env, info, &argCount, argv, &jsThis, nullptr);
2616   if (status != napi_ok || jsThis == nullptr) {
2617     MS_LOG(ERROR) << "failed to retrieve details about the callback";
2618     return undefinedResult;
2619   }
2620   status = napi_unwrap(env, jsThis, reinterpret_cast<void **>(&modelNapi));
2621   if (status != napi_ok || modelNapi == nullptr) {
2622     MS_LOG(ERROR) << "get model napi error";
2623     return undefinedResult;
2624   }
2625   if (modelNapi->native_model_ == nullptr) {
2626     MS_LOG(ERROR) << "model is released(null), please create model again";
2627     return undefinedResult;
2628   }
2629 
2630   if (!model_info_->train_model) {
2631     MS_LOG(WARNING) << "current model is not train model, unable to set learning rate";
2632     return undefinedResult;
2633   }
2634 
2635   double lr;
2636   if (napi_get_value_double(env, argv[PARAM0], &lr) != napi_ok) {
2637     MS_LOG(WARNING) << "failed to get double value.";
2638     return undefinedResult;
2639   }
2640 
2641   if (modelNapi->native_model_->SetLearningRate(static_cast<float>(lr)) != kSuccess) {
2642     MS_LOG(ERROR) << "set learning rate failed";
2643     return undefinedResult;
2644   }
2645 
2646   status = napi_get_boolean(env, true, &jsResult);
2647   if (status != napi_ok) {
2648     MS_LOG(WARNING) << "create bool value error";
2649     return undefinedResult;
2650   }
2651   return jsResult;
2652 }
2653 }  // namespace mindspore
2654