/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cache_session.h"
#include "src/common/context_util.h"
#include "src/common/tensor_util.h"
#include "src/common/mmap_utils.h"
#include "src/common/file_utils.h"
#include "src/litert/delegate/nnrt/nnrt_model_kernel.h"

namespace mindspore {
namespace lite {
CacheSession::~CacheSession() {
  if (nn_executor_ != nullptr) {
    OH_NNExecutor_Destroy(&nn_executor_);
    MS_LOG(INFO) << "Destroy NNExecutor Finish.";
  }
}
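
// Compile the model: convert the graph input/output tensors, create the NNRT kernel and prepare the executor.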
int CacheSession::CompileGraph(Model *model) {
  bool expected = false;
  if (!is_running_.compare_exchange_strong(expected, true)) {
    MS_LOG(ERROR) << "Multi-threading is not supported";
    return RET_ERROR;
  }
  // Convert graph input/output tensors to the lite tensor representation
  auto ret = ConvertInOutTensors(model);
  context_->set_schema_version(reinterpret_cast<LiteModel *>(model)->GetSchemaVersion());
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConvertInOutTensors failed: " << ret;
    is_running_.store(false);
    return ret;
  }
  InitGraphInputTensors(model);
  InitGraphOutputTensors(model);

  // Create the NNRT kernel
  ret = ScheduleToNNRTKernel();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Schedule NNRT kernel failed: " << ret;
    is_running_.store(false);
    return ret;
  }

  InitGraphInOutTensorsMap(model);
  ret = PrepareKernels(model);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare kernels failed: " << ret;
    is_running_.store(false);
    return ret;
  }

  ret = InitExecutor();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "InitExecutor failed: " << ret;
    is_running_.store(false);
    return ret;
  }

  MarkSharedWeight(kernels_);
  FreePackOpWeight(kernels_);

  is_running_.store(false);
  return RET_OK;
}
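
// Create the Executor and prepare it with the scheduled kernels and the graph input/output tensors.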
int CacheSession::InitExecutor() {
  executor_ = new (std::nothrow) Executor();
  if (executor_ == nullptr) {
    MS_LOG(ERROR) << "New Executor failed";
    return RET_ERROR;
  }
  auto ret = executor_->Prepare(kernels_, inputs_, outputs_, context_.get());
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare executor failed: " << ret;
    return ret;
  }
  return RET_OK;
}
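
// Convert only the graph input/output tensors of the model; all other tensor slots are filled with
// nullptr placeholders so that tensor indices stay aligned with the model.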
int CacheSession::ConvertInOutTensors(const lite::Model *model) {
  MS_ASSERT(model != nullptr);
  auto lite_model = reinterpret_cast<const lite::LiteModel *>(model);
  uint32_t tensor_count = model->graph_.all_tensors_.size();
  auto model_input_indices = model->graph_.input_indices_;
  auto model_output_indices = model->graph_.output_indices_;

  for (uint32_t i = 0; i < tensor_count; ++i) {
    auto *src_tensor = model->graph_.all_tensors_[i];
    if (!IsContain(model_input_indices, i) && !IsContain(model_output_indices, i)) {
      this->tensors_.emplace_back(nullptr);
      continue;
    }
    if (src_tensor == nullptr) {
      MS_LOG(ERROR) << i << "th tensor in model is nullptr";
      return RET_NULL_PTR;
    }
    auto *dst_tensor = ConvertTensor(*src_tensor);
    if (dst_tensor == nullptr) {
      MS_LOG(ERROR) << "Convert new " << i << "th tensor failed!";
      return RET_NULL_PTR;
    }
    auto ret = ConvertTensorsData(lite_model, i, dst_tensor);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Convert data of " << i << "th tensor failed";
      delete dst_tensor;
      return ret;
    }
    ConvertTensorsQuantParam(src_tensor, dst_tensor);
    if (IsContain(model_input_indices, i)) {
      dst_tensor->set_category(Category::GRAPH_INPUT);
    }
    if (IsContain(model_output_indices, i)) {
      // A tensor that is both an input and an output is treated as an input.
      if (!dst_tensor->IsGraphInput()) {
        dst_tensor->set_category(Category::GRAPH_OUTPUT);
      }
    }

    ret = CheckTensorValid(dst_tensor);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Check " << i << "th tensor failed";
      delete dst_tensor;
      return ret;
    }

    this->tensors_.emplace_back(dst_tensor);
  }
  return RET_OK;
}
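
// Initialize the session from an InnerContext; the context must contain an NNRT device, whose
// extension options (such as the cache path and version) are parsed here.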
int CacheSession::Init(const std::shared_ptr<InnerContext> &context) {
  if (context == nullptr) {
    MS_LOG(ERROR) << "context is nullptr";
    return RET_NULL_PTR;
  }
  bool expected = false;
  if (!is_running_.compare_exchange_strong(expected, true)) {
    MS_LOG(ERROR) << "Multi-threading is not supported";
    return RET_ERROR;
  }
  context_ = context;
  auto ret = context_->Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init Context failed";
    is_running_.store(false);
    return ret;
  }
  ms_context_ = MSContextFromContext(context);
  if (ms_context_ == nullptr) {
    MS_LOG(ERROR) << "Transfer context to ms context failed.";
    is_running_.store(false);
    return RET_NULL_PTR;
  }

  auto iter = std::find_if(context_->device_list_.begin(), context_->device_list_.end(),
                           [](DeviceContext &device) { return device.device_type_ == lite::DT_NNRT; });
  if (iter == context_->device_list_.end()) {
    MS_LOG(ERROR) << "No NNRT device info found in context";
    is_running_.store(false);
    return RET_ERROR;
  }
  nnrt_device_info_ = iter->device_info_.nnrt_device_info_;

  const auto &extensions = nnrt_device_info_.extensions_;
  mindspore::lite::nnrt::ExtensionOptionsParser::Parse(extensions, &extension_options_);

  is_running_.store(false);
  return RET_OK;
}
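
// Parse only the graph input/output information from a flatbuffer model buffer into the LiteModel,
// without building the full graph.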
int CacheSession::ParseInputOutputFromModelBuffer(const char *model_buf, LiteModel *model) {
  const void *meta_graph = reinterpret_cast<const void *>(schema::GetMetaGraph(model_buf));
  if (meta_graph == nullptr) {
    MS_LOG(ERROR) << "Get meta graph from model buffer failed";
    return RET_NULL_PTR;
  }

  auto status = GenerateModelInputOutput<schema::MetaGraph, schema::CNode>(
    *reinterpret_cast<const schema::MetaGraph *>(meta_graph), model->graph_);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Failed to generate model input/output";
    return status;
  }
  model->buf = const_cast<char *>(model_buf);
  return RET_OK;
}
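
// Load the model file (via mmap when enabled), import its input/output information and compile the
// graph against the NNRT cache.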
int CacheSession::LoadModelAndCompileByPath(const std::string &model_path, mindspore::ModelType model_type) {
  size_t model_size;
  bool use_mmap = IsMmapEnable();
  auto model_buf = LoadModelByPath(model_path, model_type, &model_size, use_mmap);
  if (model_buf == nullptr) {
    MS_LOG(ERROR) << "Read model file failed";
    return RET_ERROR;
  }

  if (extension_options_.cache_path_.empty()) {
    MS_LOG(ERROR) << "cache path is empty";
    return RET_ERROR;
  }
  Model *model = ImportInOutFromBuffer(model_buf, model_size, true, model_type, model_path);
  if (model == nullptr) {
    MS_LOG(ERROR) << "Import model failed";
    return RET_ERROR;
  }
  dynamic_cast<LiteModel *>(model)->PrepareInnerTensors();

  if (use_mmap) {
    reinterpret_cast<lite::LiteModel *>(model)->model_buf_by_mmap_ = true;
  } else {
    MS_LOG(WARNING) << "Memory may exceed the limit of business demands.";
  }
  reinterpret_cast<lite::LiteModel *>(model)->set_keep_model_buf(true);
  auto ret = CompileGraph(model);
  if (ret != lite::RET_OK) {
    MS_LOG(ERROR) << "Compile model failed";
    model->buf = nullptr;
    delete model;
    return RET_ERROR;
  }
  set_model(model);
  return RET_OK;
}
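
// Create a LiteModel that holds only the graph input/output description parsed from the buffer;
// the buffer itself is attached to the model without copying.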
Model *CacheSession::ImportInOutFromBuffer(const char *model_buf, size_t size, bool take_buf,
                                           mindspore::ModelType model_type, const std::string &path) {
  MS_LOG(INFO) << "import model from lite model";
  auto *model = new (std::nothrow) LiteModel(path);
  if (model == nullptr) {
    MS_LOG(ERROR) << "new model fail!";
    return nullptr;
  }

  auto status = ParseInputOutputFromModelBuffer(model_buf, model);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "construct model failed.";
    delete model;
    return nullptr;
  }
  model->buf = const_cast<char *>(model_buf);
  model->buf_size_ = size;
  return model;
}
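
// Schedule the whole model as a single NNRT kernel; only Kirin NPU devices with online inference
// are supported by the cache session.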
int CacheSession::ScheduleToNNRTKernel() {
  if (!IsKirinNPUWithOnlineInference(nnrt_device_info_.device_id_)) {
    MS_LOG(ERROR) << "Only devices whose name starts with NPU_ are supported.";
    return RET_ERROR;
  }
  auto ret = CreateFullModelKernel();
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "Build npu model failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
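
// A device counts as a Kirin NPU with online inference if its NNRT device name starts with "NPU_".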
bool CacheSession::IsKirinNPUWithOnlineInference(size_t device_id) {
  const std::string kirin_npu_name_prefix = "NPU_";
  const char *device_name = nullptr;
  auto ret = OH_NNDevice_GetName(device_id, &device_name);
  if (ret != OH_NN_SUCCESS) {
    MS_LOG(WARNING) << "Get name of device: " << device_id << " failed, error: " << ret;
    return false;
  }

  if (strncmp(kirin_npu_name_prefix.c_str(), device_name, kirin_npu_name_prefix.size()) != 0) {
    MS_LOG(WARNING) << "Device: " << device_id << " is not a Kirin NPU, device_name: " << device_name;
    return false;
  }
  return true;
}
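
// Build the full-model NNRT kernel from the compiled cache: construct an NNCompilation, create an
// NNExecutor from it, and wrap the executor in a single delegate KernelExec.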
Status CacheSession::CreateFullModelKernel() {
  OH_NNCompilation *nn_compilation = OH_NNCompilation_ConstructForCache();
  if (nn_compilation == nullptr) {
    MS_LOG(ERROR) << "Construct NNCompilation failed";
    return kLiteError;
  }
  MS_LOG(DEBUG) << "Construct NNCompilation success.";

  auto ret_code = InitNNCompilation(nn_compilation);
  if (ret_code != kSuccess) {
    MS_LOG(ERROR) << "Init NNCompilation failed";
    OH_NNCompilation_Destroy(&nn_compilation);
    return kLiteError;
  }

  OH_NNExecutor *nn_executor = OH_NNExecutor_Construct(nn_compilation);
  if (nn_executor == nullptr) {
    MS_LOG(ERROR) << "Construct NNExecutor failed";
    OH_NNCompilation_Destroy(&nn_compilation);
    return kLiteError;
  }
  OH_NNCompilation_Destroy(&nn_compilation);

  ms_inputs_ = LiteTensorsToMSTensors(inputs_);
  ms_outputs_ = LiteTensorsToMSTensors(outputs_);
  auto nnrt_model_kernel = new (std::nothrow) NNRTModelKernel(nn_executor, nnrt_device_info_, ms_inputs_, ms_outputs_);
  if (nnrt_model_kernel == nullptr) {
    OH_NNExecutor_Destroy(&nn_executor);
    MS_LOG(ERROR) << "new NNRTModelKernel failed";
    return kLiteError;
  }
  nn_executor_ = nn_executor;

  std::shared_ptr<kernel::Kernel> shared_kernel(nnrt_model_kernel);
  auto *kernel_exec = new (std::nothrow) kernel::KernelExec(shared_kernel);
  if (kernel_exec == nullptr) {
    MS_LOG(ERROR) << "nnrt kernel exec create failed.";
    return kLiteError;
  }
  // Use FP16 as the delegate data type if any input tensor is FP16, otherwise FP32.
  auto delegate_type = kNumberTypeFloat32;
  for (auto &input : nnrt_model_kernel->inputs()) {
    if (static_cast<TypeId>(input.DataType()) == kNumberTypeFloat16) {
      delegate_type = kNumberTypeFloat16;
      break;
    }
  }
  kernel::KernelKey delegate_desc{kernel::kDelegate, delegate_type, NHWC, schema::PrimitiveType_NONE, "", ""};
  kernel_exec->set_desc(delegate_desc);
  kernel_exec->set_context(context_.get());
  kernels_.push_back(kernel_exec);

  return kSuccess;
}
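
// Configure the NNCompilation: device, performance mode, priority, FP16, the mandatory cache
// path/version and any extension configs, then build it.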
Status CacheSession::InitNNCompilation(OH_NNCompilation *nn_compilation) const {
  auto ret_code = OH_NNCompilation_SetDevice(nn_compilation, nnrt_device_info_.device_id_);
  if (ret_code != OH_NN_SUCCESS) {
    MS_LOG(ERROR) << "NNCompilation set device id failed, ret: " << ret_code;
    return kLiteError;
  }
  ret_code = OH_NNCompilation_SetPerformanceMode(nn_compilation,
                                                 (OH_NN_PerformanceMode)(nnrt_device_info_.performance_mode_));
  if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
    MS_LOG(ERROR) << "NNCompilation set performance mode failed, ret: " << ret_code;
    return kLiteError;
  }
  ret_code = OH_NNCompilation_SetPriority(nn_compilation, (OH_NN_Priority)(nnrt_device_info_.priority_));
  if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
    MS_LOG(ERROR) << "NNCompilation set priority failed, ret: " << ret_code;
    return kLiteError;
  }
  ret_code = OH_NNCompilation_EnableFloat16(nn_compilation, nnrt_device_info_.enable_fp16_);
  if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
    MS_LOG(ERROR) << "NNCompilation enable fp16 failed, ret: " << ret_code;
    return kLiteError;
  }

  if (!extension_options_.cache_path_.empty()) {
    ret_code = OH_NNCompilation_SetCache(nn_compilation, extension_options_.cache_path_.c_str(),
                                         extension_options_.cache_version_);
    if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
      MS_LOG(ERROR) << "NNCompilation set cache failed, ret: " << ret_code;
      return kLiteError;
    }
  } else {
    MS_LOG(ERROR) << "NNCompilation must set Cache.";
    return kLiteError;
  }

  size_t extension_size = nnrt_device_info_.extensions_.size();
  for (size_t i = 0; i < extension_size; i++) {
    auto &src_extension = nnrt_device_info_.extensions_[i];
    ret_code = OH_NNCompilation_AddExtensionConfig(nn_compilation, src_extension.name.c_str(),
                                                   (char *)((void *)src_extension.value.data()),
                                                   src_extension.value.size());
    if (ret_code != OH_NN_SUCCESS) {
      MS_LOG(ERROR) << "OH_NNCompilation_AddExtensionConfig " << i << ": " << src_extension.name
                    << " failed, ret: " << ret_code;
      return kLiteError;
    }
  }

  ret_code = OH_NNCompilation_Build(nn_compilation);
  if (ret_code != OH_NN_SUCCESS) {
    MS_LOG(ERROR) << "Build NNCompilation failed, ret: " << ret_code;
    return kLiteError;
  }
  return kSuccess;
}
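
// Read the model file into memory, either via mmap or a plain file read, and let LoadModelByBuff
// resolve it into a buffer usable by the lite runtime. Returns nullptr on failure.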
const char *CacheSession::LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size,
                                          bool use_mmap) {
  size_t buf_size;
  char *model_buf;
  if (use_mmap) {
    model_buf = reinterpret_cast<char *>(lite::ReadFileByMmap(file.c_str(), &buf_size, false));
  } else {
    MS_LOG(WARNING) << "Memory may exceed the limit of business demands.";
    model_buf = lite::ReadFile(file.c_str(), &buf_size);
  }
  if (model_buf == nullptr) {
    MS_LOG(ERROR) << "The model path is invalid";
    return model_buf;
  }

  char *lite_buf = nullptr;
  auto buf_model_type = LoadModelByBuff(model_buf, buf_size, &lite_buf, size, model_type);
  if (buf_model_type == mindspore::ModelType::kUnknownType || lite_buf == nullptr) {
    if (use_mmap) {
      lite::UnmapMmapBuffer(const_cast<void *>(static_cast<const void *>(model_buf)), buf_size);
    } else {
      delete[] model_buf;
    }
    model_buf = nullptr;
    return nullptr;
  }

  return lite_buf;
}
}  // namespace lite
}  // namespace mindspore