/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/litert/lite_model.h"
#include <sys/stat.h>
#include <dlfcn.h>
#include <cstring>
#include <iostream>
#include <fstream>
#include <sstream>
#include <functional>
#include <vector>
#include <algorithm>
#include <set>
#include <unordered_set>
#include <unordered_map>
#include <memory>
#include <numeric>
#include "src/common/prim_util.h"
#include "src/common/graph_util.h"
#include "src/common/file_utils.h"
#include "src/common/utils.h"
#include "src/tensor.h"
#include "extendrt/mindir_loader/model_loader.h"
#include "src/common/mmap_utils.h"

namespace mindspore::lite {
namespace {
// Upper bound for a model buffer that is copied into an internal allocation: 2 GB.
constexpr size_t kMaxModelBufferSize = static_cast<size_t>(1024) * 1024 * 1024 * 2;
}  // namespace

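// Releases the resources tied to the serialized model buffer: the (possibly
// mmap-backed) buffer itself, the primitive pointers cached on each node, the
// temporary attribute/node buffers, the inner schema tensor wrappers and, if
// present, the deobfuscation processor. The node and sub-graph objects are kept;
// Destroy() deletes those.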
void LiteModel::Free() {
  if (this->model_buf_by_mmap_) {
    UnmapMmapBuffer(static_cast<void *>(this->buf), this->buf_size_);
    this->buf = nullptr;
  }
  if (this->buf != nullptr && !this->model_buf_by_mmap_) {
    delete[] this->buf;
    this->buf = nullptr;
  }
  auto nodes_size = this->graph_.all_nodes_.size();
  for (size_t i = 0; i < nodes_size; ++i) {
    auto node = this->graph_.all_nodes_[i];
    node->primitive_ = nullptr;
  }
  for (auto &tensor_buf : attr_tensor_bufs_) {
    free(tensor_buf);
    tensor_buf = nullptr;
  }
  attr_tensor_bufs_.resize(0);

  for (auto &node_buf : node_bufs_) {
    free(node_buf);
    node_buf = nullptr;
  }
  node_bufs_.resize(0);

  for (auto *schema_tensor_wrapper : inner_all_tensors_) {
    if (schema_tensor_wrapper != nullptr) {
      delete schema_tensor_wrapper;
    }
  }
  inner_all_tensors_.clear();

  if (this->deobf != nullptr) {
    delete reinterpret_cast<DeObfProcessor *>(this->deobf);
    this->deobf = nullptr;  // avoid a double delete if Free() runs again, e.g. via Destroy()
  }
}

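// Fully tears the model down: releases the buffers via Free(), then deletes every
// node and sub-graph object owned by the graph.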
void LiteModel::Destroy() {
  Free();
  auto nodes_size = this->graph_.all_nodes_.size();
  for (size_t i = 0; i < nodes_size; ++i) {
    auto node = this->graph_.all_nodes_[i];
    MS_ASSERT(node != nullptr);
    delete node;
  }
  this->graph_.all_nodes_.clear();

  auto sub_graph_size = this->graph_.sub_graphs_.size();
  for (size_t i = 0; i < sub_graph_size; ++i) {
    auto sub_graph = this->graph_.sub_graphs_[i];
    delete sub_graph;
  }
  this->graph_.sub_graphs_.clear();  // drop the now-dangling pointers, mirroring all_nodes_
}

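// Converts one flatbuffer schema::SubGraph into an in-memory LiteGraph::SubGraph
// (name plus input/output/node/tensor index lists) and appends it to
// graph_.sub_graphs_. Returns RET_ERROR if a mandatory field is missing.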
int LiteModel::ConvertSubGraph(const schema::SubGraph &sub_graph) {
  if (sub_graph.name() == nullptr || sub_graph.inputIndices() == nullptr || sub_graph.outputIndices() == nullptr ||
      sub_graph.tensorIndices() == nullptr) {
    MS_LOG(ERROR) << "sub_graph is invalid";
    MS_LOG(ERROR) << "sub_graph.name() = " << sub_graph.name()
                  << ", sub_graph.inputIndices() = " << sub_graph.inputIndices()
                  << ", sub_graph.outputIndices() = " << sub_graph.outputIndices()
                  << ", sub_graph.tensorIndices() = " << sub_graph.tensorIndices();
    return RET_ERROR;
  }

  auto *subgraph = new (std::nothrow) LiteGraph::SubGraph();
  if (subgraph == nullptr) {
    MS_LOG(ERROR) << "new subGraph fail!";
    return RET_ERROR;
  }

  subgraph->name_ = sub_graph.name()->c_str();
  auto in_count = sub_graph.inputIndices()->size();
  for (uint32_t i = 0; i < in_count; ++i) {
    subgraph->input_indices_.push_back(sub_graph.inputIndices()->Get(i));
  }
  auto out_count = sub_graph.outputIndices()->size();
  for (uint32_t i = 0; i < out_count; ++i) {
    subgraph->output_indices_.push_back(sub_graph.outputIndices()->Get(i));
  }
  if (sub_graph.nodeIndices() != nullptr) {
    auto node_count = sub_graph.nodeIndices()->size();
    for (uint32_t i = 0; i < node_count; ++i) {
      subgraph->node_indices_.push_back(sub_graph.nodeIndices()->Get(i));
    }
  }
  auto tensor_count = sub_graph.tensorIndices()->size();
  for (uint32_t i = 0; i < tensor_count; ++i) {
    subgraph->tensor_indices_.push_back(sub_graph.tensorIndices()->Get(i));
  }
  this->graph_.sub_graphs_.push_back(subgraph);
  return RET_OK;
}

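// Determines which schema version the verified flatbuffer matches. Only the
// current schema (schema::MetaGraph) is recognized here; anything else is
// reported as SCHEMA_INVALID.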
int LiteModel::VersionVerify(flatbuffers::Verifier *verify) {
  if (verify == nullptr) {
    MS_LOG(ERROR) << "verify is null.";
    return RET_ERROR;
  }
  if (schema::VerifyMetaGraphBuffer(*verify)) {
    return SCHEMA_VERSION::SCHEMA_CUR;
  }
  return SCHEMA_VERSION::SCHEMA_INVALID;
}

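// Validates every node in the graph: the primitive must exist, all input/output
// tensor indices must be in range, output indices must be unique and must not
// point at constant (ValueNode) tensors, a PartialFusion node must reference a
// valid sub-graph other than the one it belongs to, and nodes that are neither
// TensorList nor Partial nodes must not take or produce object-type tensors.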
int LiteModel::NodeVerify() const {
  auto tensor_size = this->graph_.all_tensors_.size();
  uint32_t node_size = this->graph_.all_nodes_.size();
  uint32_t subgraph_size = static_cast<uint32_t>(this->graph_.sub_graphs_.size());

  for (uint32_t node_index = 0; node_index < node_size; node_index++) {
    auto &node = this->graph_.all_nodes_.at(node_index);
    if (node == nullptr || node->primitive_ == nullptr) {
      MS_LOG(ERROR) << "node or its primitive_ is null.";
      return RET_ERROR;
    }
    if (std::any_of(node->input_indices_.begin(), node->input_indices_.end(),
                    [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
      MS_LOG(ERROR) << "Index of node->input_indices_ is beyond size.";
      return RET_ERROR;
    }
    if (std::any_of(node->output_indices_.begin(), node->output_indices_.end(),
                    [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
      MS_LOG(ERROR) << "Index of node->output_indices_ is beyond size.";
      return RET_ERROR;
    }
    if (std::any_of(node->output_indices_.begin(), node->output_indices_.end(), [&, this](const uint32_t &idx) {
          return this->graph_.all_tensors_[idx]->nodeType() == static_cast<int>(NodeType_ValueNode) &&
                 this->graph_.all_tensors_[idx]->data() != nullptr;
        })) {
      MS_LOG(ERROR) << "node output tensor node type is ValueNode, node name: " << node->name_;
      return RET_ERROR;
    }
    if (node->output_indices_.size() !=
        std::unordered_set<uint32_t>(node->output_indices_.begin(), node->output_indices_.end()).size()) {
      MS_LOG(ERROR) << "node output indices contain duplicate.";
      return RET_ERROR;
    }

    if (IsPartialNode(node->primitive_, schema_version_)) {
      auto partial_fusion = reinterpret_cast<const schema::Primitive *>(node->primitive_)->value_as_PartialFusion();
      MS_CHECK_FALSE(partial_fusion == nullptr, RET_ERROR);
      int64_t subgraph_index = partial_fusion->sub_graph_index();
      if (subgraph_index < 0) {
        MS_LOG(ERROR) << "invalid subgraph index:" << subgraph_index;
        return RET_ERROR;
      }
      if (subgraph_index >= static_cast<int64_t>(subgraph_size)) {
        MS_LOG(ERROR) << "subgraph index:" << subgraph_index << " is beyond subgraph_size: " << subgraph_size;
        return RET_ERROR;
      }
      for (uint32_t graph_index = 0; graph_index < subgraph_size; graph_index++) {
        auto &graph = this->graph_.sub_graphs_.at(graph_index);
        if (IsContain(graph->node_indices_, node_index) && graph_index == static_cast<uint32_t>(subgraph_index)) {
          MS_LOG(ERROR) << "The subgraph called by PartialNode is the subgraph where it is located, subgraph index: "
                        << subgraph_index;
          return RET_ERROR;
        }
      }
    }
    if ((!IsTensorListNode(node->primitive_, schema_version_)) && (!IsPartialNode(node->primitive_, schema_version_))) {
      if (std::any_of(node->input_indices_.begin(), node->input_indices_.end(), [this](const uint32_t &idx) {
            return TypeId(this->graph_.all_tensors_[idx]->dataType()) == kObjectTypeTensorType;
          })) {
        MS_LOG(ERROR) << "node input tensor type can't be object type, node name: " << node->name_;
        return RET_ERROR;
      }
      if (std::any_of(node->output_indices_.begin(), node->output_indices_.end(), [this](const uint32_t &idx) {
            return TypeId(this->graph_.all_tensors_[idx]->dataType()) == kObjectTypeTensorType;
          })) {
        MS_LOG(ERROR) << "node output tensor type can't be object type, node name: " << node->name_;
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}

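// Validates every sub-graph: the main graph (index 0) must declare inputs and
// outputs, all input/output/tensor indices must be in range, node index lists
// must be duplicate-free and in range, and each sub-graph must pass
// SubGraphInOutVerify().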
int LiteModel::SubGraphVerify() const {
  auto tensor_size = this->graph_.all_tensors_.size();
  auto node_size = this->graph_.all_nodes_.size();

  if (graph_.sub_graphs_[0]->input_indices_.size() == 0 || graph_.sub_graphs_[0]->output_indices_.size() == 0) {
    MS_LOG(ERROR) << "The model has invalid input and output, please check";
    return RET_ERROR;
  }

  for (auto &graph : this->graph_.sub_graphs_) {
    if (graph == nullptr) {
      MS_LOG(ERROR) << "graph is null.";
      return RET_ERROR;
    }
    if (std::any_of(graph->input_indices_.begin(), graph->input_indices_.end(),
                    [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
      MS_LOG(ERROR) << "Index of graph->input_indices_ is beyond tensor_size.";
      return RET_ERROR;
    }
    if (std::any_of(graph->output_indices_.begin(), graph->output_indices_.end(),
                    [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
      MS_LOG(ERROR) << "Index of graph->output_indices_ is beyond tensor_size.";
      return RET_ERROR;
    }
    if (std::any_of(graph->tensor_indices_.begin(), graph->tensor_indices_.end(),
                    [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
      MS_LOG(ERROR) << "Index of graph->tensor_indices_ is beyond tensor_size.";
      return RET_ERROR;
    }
    if (std::any_of(graph->node_indices_.begin(), graph->node_indices_.end(), [&](const uint32_t &idx) {
          bool repeated = std::count_if(graph->node_indices_.begin(), graph->node_indices_.end(),
                                        [&idx](const uint32_t &index) { return index == idx; }) != 1;
          return repeated || idx >= node_size;
        })) {
      MS_LOG(ERROR) << "The subgraph contains repeated nodes or the node index is beyond node_size.";
      return RET_ERROR;
    }
    auto ret = SubGraphInOutVerify(graph);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Fail to pass the subgraph input output verification.";
      return ret;
    }
  }
  return RET_OK;
}

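// Checks that every graph-level input index also appears as an input of some
// sub-graph, and every graph-level output index appears as an output of some
// sub-graph.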
int LiteModel::GraphInOutVerify() const {
  std::unordered_set<uint32_t> all_subgraphs_inputs;
  std::unordered_set<uint32_t> all_subgraphs_outputs;
  for (auto subgraph : this->graph_.sub_graphs_) {
    for (auto input_idx : subgraph->input_indices_) {
      all_subgraphs_inputs.emplace(input_idx);
    }
    for (auto output_idx : subgraph->output_indices_) {
      all_subgraphs_outputs.emplace(output_idx);
    }
  }

  for (auto input_idx : this->graph_.input_indices_) {
    if (all_subgraphs_inputs.count(input_idx) == 0) {
      MS_LOG(ERROR) << "The graph input is not valid.";
      return RET_ERROR;
    }
  }

  for (auto output_idx : this->graph_.output_indices_) {
    if (all_subgraphs_outputs.count(output_idx) == 0) {
      MS_LOG(ERROR) << "The graph output is not valid.";
      return RET_ERROR;
    }
  }

  return RET_OK;
}

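// Checks a single sub-graph's interface tensors. An input tensor must not be
// produced by any node of the sub-graph and, unless it is a TensorList, must not
// carry constant data. An output tensor must either be a sub-graph input, or be
// produced by a node while carrying no data, or be an isolated tensor that does
// carry data.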
int LiteModel::SubGraphInOutVerify(const LiteGraph::SubGraph *graph) const {
  MS_CHECK_TRUE_RET(graph != nullptr, RET_ERROR);
  auto from_node = [&, this](uint32_t cur_idx) -> bool {
    for (auto node_idx : graph->node_indices_) {
      auto node = this->graph_.all_nodes_.at(node_idx);
      if (std::any_of(node->output_indices_.begin(), node->output_indices_.end(),
                      [&cur_idx](uint32_t idx) { return cur_idx == idx; })) {
        return true;
      }
    }
    return false;
  };
  for (auto in_idx : graph->input_indices_) {
    auto in_tensor = this->graph_.all_tensors_.at(in_idx);
    bool is_from_node = from_node(in_idx);
    bool has_data = in_tensor->data() != nullptr && in_tensor->data()->data() != nullptr;
    if (is_from_node || (in_tensor->dataType() != kObjectTypeTensorType && has_data)) {
      MS_LOG(ERROR) << "The graph input is not valid.";
      return RET_ERROR;
    }
  }
  for (auto out_idx : graph->output_indices_) {
    auto tensor = this->graph_.all_tensors_.at(out_idx);
    bool is_from_node = from_node(out_idx);
    bool is_input = std::any_of(graph->input_indices_.begin(), graph->input_indices_.end(),
                                [&out_idx](uint32_t idx) { return out_idx == idx; });
    bool from_node_and_has_data = is_from_node && (tensor->data() != nullptr && tensor->data()->data() != nullptr);
    bool isolated_and_no_data = !is_from_node && (tensor->data() == nullptr || tensor->data()->data() == nullptr);
    if (!is_input && (from_node_and_has_data || isolated_and_no_data)) {
      MS_LOG(ERROR) << "The graph output is not valid.";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

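// Top-level consistency check run after the model has been parsed: the model must
// have a main graph, non-empty and in-range input/output indices, inputs with a
// data type that can be allocated, and it must pass GraphInOutVerify(),
// NodeVerify() and SubGraphVerify().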
bool LiteModel::ModelVerify() const {
  if (this->graph_.sub_graphs_.empty()) {
    MS_LOG(ERROR) << "Model does not have a main graph.";
    return false;
  }

  if (this->graph_.input_indices_.empty()) {
    MS_LOG(ERROR) << "Model does not have inputs.";
    return false;
  }

  if (this->graph_.output_indices_.empty()) {
    MS_LOG(ERROR) << "Model does not have outputs.";
    return false;
  }

  if (this->graph_.input_indices_ == this->graph_.output_indices_) {
    MS_LOG(ERROR) << "Model outputs can not be totally same as the inputs.";
    return false;
  }

  auto all_tensors_size = this->graph_.all_tensors_.size();
  for (auto input_index : this->graph_.input_indices_) {
    if (input_index >= all_tensors_size) {
      MS_LOG(ERROR) << "Graph input indices is beyond tensor_size.";
      return false;
    }
    auto *tensor = this->graph_.all_tensors_.at(input_index);
    if (tensor == nullptr) {
      MS_LOG(ERROR) << "Tensor in all tensors is nullptr.";
      return false;
    }
    // check the input data type
    if ((static_cast<const TypeId>(tensor->dataType()) <= kNumberTypeBegin ||
         static_cast<const TypeId>(tensor->dataType()) >= kNumberTypeEnd) &&
        static_cast<const TypeId>(tensor->dataType()) != kObjectTypeString) {
      MS_LOG(ERROR) << "The data type is not supported to malloc.";
      return false;
    }
  }
  if (this->graph_.output_indices_.size() == 1 &&
      graph_.sub_graphs_[0]->output_indices_.size() != graph_.output_indices_.size()) {
    MS_LOG(ERROR) << "The main graph output count should be equal to the model output count.";
    return false;
  }

  if (std::any_of(graph_.output_indices_.begin(), graph_.output_indices_.end(),
                  [&all_tensors_size](const uint32_t &idx) { return idx >= all_tensors_size; })) {
    MS_LOG(ERROR) << "Graph output indices is beyond tensor_size.";
    return false;
  }

  if (GraphInOutVerify() != RET_OK) {
    MS_LOG(ERROR) << "The model has invalid input and output.";
    return false;
  }

  return NodeVerify() == RET_OK && SubGraphVerify() == RET_OK;
}

// Static members used by the deobfuscator: the processor factory and the member
// function pointers that are resolved from the optional deobfuscation library.
ObfCreateFunc DeObfRegister::NewDeObfProcessor = DeObfRegister::Fail;
bool (DeObfProcessor::*DeObfRegister::GetModelDeObfReg)(const void *meta_graph, Model *model);
void (DeObfProcessor::*DeObfRegister::DeObfuscateReg)(Model *model);
DeObfRet (DeObfProcessor::*DeObfRegister::CreateDeObfNodeReg)(const schema::Primitive *&src_prim, int i,
                                                              int schema_version);
void *DeObfRegister::deobf_handle = nullptr;

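// Builds the in-memory graph from the flatbuffer in this->buf according to the
// detected schema version. If the optional deobfuscation library
// (libdeobfuscator_lib.z.so) can be loaded, an obfuscated model is first handled
// through the registered DeObfProcessor hooks. Also warns when the model was
// converted with a newer version than the inference engine.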
int LiteModel::GenerateModelByVersion() {
  if (this->buf == nullptr) {
    MS_LOG(ERROR) << "Model buffer not inited";
    return RET_ERROR;
  }
  const void *meta_graph = nullptr;
  if (schema_version_ == SCHEMA_VERSION::SCHEMA_CUR) {
    meta_graph = reinterpret_cast<const void *>(schema::GetMetaGraph(this->buf));
  }
  MS_ASSERT(meta_graph != nullptr);
  int status = RET_ERROR;
  if (dlopen("libdeobfuscator_lib.z.so", RTLD_NOLOAD) == nullptr) {
    DeObfRegister::deobf_handle = dlopen("libdeobfuscator_lib.z.so", RTLD_NOW | RTLD_GLOBAL);
  }
  if (DeObfRegister::deobf_handle == nullptr) {
    MS_LOG(WARNING) << "Deobfuscate ability is disabled, so obfuscated models can not be executed.";
  } else {
    auto CreateDeObfFunc = reinterpret_cast<ObfCreateFunc>(dlsym(DeObfRegister::deobf_handle, "CreateDeObfFunc"));
    if (CreateDeObfFunc == nullptr) {
      MS_LOG(WARNING) << "cannot fetch CreateDeObfFunc";
    } else {
      DeObfRegister::RegisterDeObfuscator(CreateDeObfFunc);
      DeObfRegister::NewDeObfProcessor(*this);
    }
  }
  if (schema_version_ == SCHEMA_VERSION::SCHEMA_CUR) {
    if (this->deobf != nullptr) {
      auto deobf_ptr = reinterpret_cast<DeObfProcessor *>(this->deobf);
      auto ret = (deobf_ptr->*DeObfRegister::GetModelDeObfReg)(meta_graph, this);
      if (!ret) {
        return RET_ERROR;
      }
    }
    status = GenerateModel<schema::MetaGraph, schema::CNode>(*reinterpret_cast<const schema::MetaGraph *>(meta_graph));
  }
  if (this->deobf != nullptr) {
    auto deobf_ptr = reinterpret_cast<DeObfProcessor *>(this->deobf);
    (deobf_ptr->*DeObfRegister::DeObfuscateReg)(this);
  }
  if (DeObfRegister::deobf_handle != nullptr) {
    dlclose(DeObfRegister::deobf_handle);
  }
  if (IsVersionGreaterThan(GetShortVersionStr(this->graph_.version_), GetShortVersionStr(Version()))) {
    MS_LOG(WARNING) << "The current model version " << this->graph_.version_
                    << " is later than the inference engine version " << Version()
                    << ". Use a converter tool whose version is earlier than or equal to "
                    << "the inference engine version to convert the model.";
  }
  MS_LOG(INFO) << "MindSpore Lite inference version: " << Version();
  return status;
}

namespace {
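// Stores the serialized model in the LiteModel: either takes ownership of the
// caller's buffer directly (take_buf == true) or copies it into a freshly
// allocated internal buffer, rejecting buffers larger than kMaxModelBufferSize.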
int InitModelBuffer(LiteModel *model, const char *model_buf, size_t size, bool take_buf) {
  if (model_buf == nullptr || size == 0) {
    MS_LOG(ERROR) << "Input model buffer is nullptr or its size is 0.";
    return RET_INPUT_PARAM_INVALID;
  }
  MS_ASSERT(model != nullptr);
  if (take_buf) {
    model->buf = const_cast<char *>(model_buf);
  } else {
    if (size > kMaxModelBufferSize) {
      MS_LOG(ERROR) << "Input model buffer size invalid, require (0, 2GB].";
      return RET_ERROR;
    }
    model->buf = new (std::nothrow) char[size];
    if (model->buf == nullptr) {
      MS_LOG(ERROR) << "new inner model buf fail!";
      return RET_NULL_PTR;
    }
    memcpy(model->buf, model_buf, size);
  }
  model->buf_size_ = size;
  return RET_OK;
}
}  // namespace

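// Parses and validates a serialized model: initializes the model buffer, verifies
// the flatbuffer and its schema version, builds the graph via
// GenerateModelByVersion(), runs ModelVerify() and finally prepares the inner
// schema tensor wrappers. On failure with take_buf == true, ownership of the
// borrowed buffer is handed back to the caller by clearing this->buf.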
#ifdef ENABLE_LITE_HELPER
int LiteModel::ConstructModel(const char *model_buf, size_t size, bool take_buf,
                              mindspore::infer::helper::InferHelpers *infer_helpers) {
#else
int LiteModel::ConstructModel(const char *model_buf, size_t size, bool take_buf) {
#endif
  auto ret = InitModelBuffer(this, model_buf, size, take_buf);
  if (ret != RET_OK) {
    return ret;
  }

  flatbuffers::Verifier verify((const uint8_t *)this->buf, this->buf_size_, INT32_MAX, INT32_MAX);
  schema_version_ = VersionVerify(&verify);
  if (schema_version_ == SCHEMA_INVALID) {
    MS_LOG(ERROR) << "The model buffer is invalid and fail to create graph.";
    if (take_buf) {
      this->buf = nullptr;
    }
    return RET_ERROR;
  }
  int status = GenerateModelByVersion();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "fail to generate model";
    if (take_buf) {
      this->buf = nullptr;
    }
    return status;
  }
  if (!ModelVerify()) {
    MS_LOG(ERROR) << "ModelVerify failed.";
    if (take_buf) {
      this->buf = nullptr;
    }
    return RET_ERROR;
  }
#ifdef ENABLE_LITE_HELPER
  if (!PrepareInnerTensors(infer_helpers)) {
#else
  if (!PrepareInnerTensors()) {
#endif
    MS_LOG(ERROR) << "PrepareInnerTensors failed.";
    if (take_buf) {
      this->buf = nullptr;
    }
    return RET_ERROR;
  }

  return RET_OK;
}

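// Creates one SchemaTensorWrapper per tensor in the graph. The wrappers give
// uniform access to tensor data, including data that may be stored outside the
// flatbuffer (resolved relative to the model's directory). Must only be called
// once per model.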
#ifdef ENABLE_LITE_HELPER
bool LiteModel::PrepareInnerTensors(mindspore::infer::helper::InferHelpers *infer_helpers) {
#else
bool LiteModel::PrepareInnerTensors() {
#endif
  if (!this->inner_all_tensors_.empty()) {
    MS_LOG(ERROR) << "Already prepared tensors";
    return false;
  }
  auto dir = GetDirectory(this->model_path_);
  this->inner_all_tensors_.resize(graph_.all_tensors_.size());
  for (size_t i = 0; i < graph_.all_tensors_.size(); i++) {
    auto tensor_wrapper = new (std::nothrow) SchemaTensorWrapper();
    if (tensor_wrapper == nullptr) {
      MS_LOG(ERROR) << "Create SchemaTensorWrapper return nullptr";
      return false;
    }
    if (graph_.all_tensors_.at(i) != nullptr) {
#ifdef ENABLE_LITE_HELPER
      if (!tensor_wrapper->Init(*(graph_.all_tensors_.at(i)), static_cast<SCHEMA_VERSION>(schema_version_), dir,
                                infer_helpers)) {
#else
      if (!tensor_wrapper->Init(*(graph_.all_tensors_.at(i)), static_cast<SCHEMA_VERSION>(schema_version_), dir)) {
#endif
        delete tensor_wrapper;
        return false;
      }
    }
    this->inner_all_tensors_[i] = tensor_wrapper;
  }
  return true;
}

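// Returns the SchemaTensorWrapper for the given tensor index, or nullptr when the
// index is out of range.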
SchemaTensorWrapper *LiteModel::GetSchemaTensor(const size_t &tensor_index) const {
  if (tensor_index >= this->inner_all_tensors_.size()) {
    return nullptr;
  }
  return this->inner_all_tensors_.at(tensor_index);
}

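// Reads a model file from disk and constructs a LiteModel that takes ownership of
// the file buffer. Returns nullptr on any failure.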
LiteModel *LiteImportFromPath(const char *model_path) {
  if (model_path == nullptr) {
    MS_LOG(ERROR) << "The model path is nullptr";
    return nullptr;
  }
  size_t size = 0;
  auto buf = ReadFile(model_path, &size);
  if (buf == nullptr) {
    return nullptr;
  }
  auto *model = new (std::nothrow) LiteModel(model_path);
  if (model == nullptr) {
    MS_LOG(ERROR) << "new model fail!";
    delete[] buf;
    return nullptr;
  }

  auto status = model->ConstructModel(buf, size, true);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "construct model failed.";
    // ConstructModel clears model->buf on failure when take_buf is true, so the
    // file buffer is still owned here and must be released.
    delete[] buf;
    delete model;
    return nullptr;
  }
  return model;
}

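// Returns true only when a quant-param list exists and every entry in it is
// marked as initialized.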
bool LiteModel::CheckQuantAllInit(
  const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::QuantParam>> *quant_params) {
  if (quant_params == nullptr) {
    return false;
  }
  for (size_t i = 0; i < quant_params->size(); i++) {
    auto quant_param = quant_params->Get(i);
    if (quant_param != nullptr && !quant_param->inited()) {
      return false;
    }
  }
  return true;
}

Model *ImportFromPath(const char *model_path) { return LiteImportFromPath(model_path); }

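// Imports a model from an in-memory buffer. A registered ModelLoader for the
// given ModelType is tried first; if none handles the buffer, the buffer is
// parsed as a LiteModel. Returns nullptr on failure.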
#ifdef ENABLE_LITE_HELPER
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf, mindspore::ModelType model_type,
                        const std::string &path, mindspore::infer::helper::InferHelpers *infer_helpers) {
#else
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf, mindspore::ModelType model_type,
                        const std::string &path) {
#endif
  auto model_loader = mindspore::infer::ModelLoaderRegistry::GetInstance()->GetModelLoader(model_type);
  if (model_loader != nullptr) {
    MS_LOG(INFO) << "import model from model loader";
    auto model = model_loader->ImportModel(model_buf, size, true);
    if (model != nullptr) {
      return model;
    }
  }

  MS_LOG(INFO) << "import model from lite model";
  auto *model = new (std::nothrow) LiteModel(path);
  if (model == nullptr) {
    MS_LOG(ERROR) << "new model fail!";
    return nullptr;
  }
#ifdef ENABLE_LITE_HELPER
  auto status = model->ConstructModel(model_buf, size, take_buf, infer_helpers);
#else
  auto status = model->ConstructModel(model_buf, size, take_buf);
#endif
  if (status != RET_OK) {
    MS_LOG(ERROR) << "construct model failed.";
    delete model;
    return nullptr;
  }
  return model;
}

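// Dumps the graph structure (nodes, tensors, graph inputs/outputs and the
// per-subgraph index lists) into a human-readable string for debugging.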
std::string LiteGraph::ToString() const {
  std::stringstream ss;
  ss << "all_nodes: " << all_nodes_.size() << std::endl;
  for (size_t i = 0; i < all_nodes_.size(); i++) {
    ss << "- node " << i << ": " << all_nodes_[i]->primitive_ << std::endl;
    ss << "- node " << i << " input_indices_: " << all_nodes_[i]->input_indices_ << std::endl;
    ss << "- node " << i << " output_indices_: " << all_nodes_[i]->output_indices_ << std::endl;
  }
  ss << "all_tensors: " << all_tensors_.size() << std::endl;
  for (size_t i = 0; i < all_tensors_.size(); i++) {
    ss << "- tensor " << i << ": " << all_tensors_[i] << std::endl;
  }
  ss << "input_indices: " << input_indices_ << std::endl;
  ss << "output_indices: " << output_indices_ << std::endl;

  ss << "subgraphs: " << std::endl;
  int count = 0;
  for (auto subgraph : sub_graphs_) {
    ss << "- subgraph " << count++ << std::endl;
    ss << "--- subgraph input " << subgraph->input_indices_ << std::endl;
    ss << "--- subgraph output " << subgraph->output_indices_ << std::endl;
    ss << "--- subgraph node " << subgraph->node_indices_ << std::endl;
    ss << "--- subgraph tensor " << subgraph->tensor_indices_ << std::endl;
  }
  return ss.str();
}

Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); }

Model *Model::Import(const char *filename) { return ImportFromPath(filename); }

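// Copies the serialized model into the caller-provided buffer and stores the
// model size in *len. The buffer must be able to hold buf_size_ bytes; if buffer
// is nullptr, an internal buffer is allocated for the copy instead.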
int Model::Export(Model *model, char *buffer, size_t *len) {
  if (len == nullptr) {
    MS_LOG(ERROR) << "len is nullptr";
    return RET_ERROR;
  }
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr";
    return RET_ERROR;
  }
  auto *liteModel = reinterpret_cast<LiteModel *>(model);

  if (liteModel->buf_size_ == 0 || liteModel->buf == nullptr) {
    MS_LOG(ERROR) << "model buffer is invalid";
    return RET_ERROR;
  }
  if (*len < liteModel->buf_size_ && buffer != nullptr) {
    MS_LOG(ERROR) << "Buffer is too small, Export Failed";
    return RET_ERROR;
  }
  if (buffer == nullptr) {
    buffer = reinterpret_cast<char *>(malloc(liteModel->buf_size_));
    if (buffer == nullptr) {
      MS_LOG(ERROR) << "allocated model buf fail!";
      return RET_ERROR;
    }
  }
  memcpy(buffer, liteModel->buf, liteModel->buf_size_);
  *len = liteModel->buf_size_;
  return RET_OK;
}

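// Writes the serialized model buffer to the given file and, on non-Windows
// platforms, restricts the file permissions to owner read-only.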
int Model::Export(Model *model, const char *filename) {
  if (model == nullptr) {
    MS_LOG(ERROR) << "model is nullptr";
    return RET_ERROR;
  }
  auto *liteModel = reinterpret_cast<LiteModel *>(model);
  if (liteModel->buf_size_ == 0 || liteModel->buf == nullptr) {
    MS_LOG(ERROR) << "model buf is invalid";
    return RET_ERROR;
  }

  // open in binary mode so the flatbuffer is written byte-for-byte
  std::ofstream ofs(filename, std::ios::out | std::ios::binary);
  if (!ofs.good() || !ofs.is_open()) {
    MS_LOG(ERROR) << "Could not open file \"" << filename << "\" for writing";
    return RET_ERROR;
  }

  ofs.seekp(0, std::ios::beg);
  ofs.write(liteModel->buf, liteModel->buf_size_);
  ofs.close();
#ifdef _MSC_VER
  return RET_OK;
#else
  return chmod(filename, S_IRUSR);
#endif
}

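// Builds a JSON-like debug dump of the model: type, graph name/version, input and
// output indices, tensor names, and per-node name/type and index lists.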
std::string ModelDebugString(Model *model) {
  if (model == nullptr) {
    return "";
  }
  std::ostringstream oss;
  std::string deli = "\n";
  oss << "{" << deli;
  oss << "model_type: " << model->model_type_ << deli;

  // debug graph
  oss << "graph: {" << deli;
  oss << "name: " << model->graph_.name_ << deli;
  oss << "version: " << model->graph_.version_ << deli;

  // input indices
  oss << "input_indices: [" << deli;
  for (auto i : model->graph_.input_indices_) {
    oss << i << ", " << deli;
  }
  oss << "]" << deli;

  // output indices
  oss << "output_indices: [" << deli;
  for (auto i : model->graph_.output_indices_) {
    oss << i << ", " << deli;
  }
  oss << "]" << deli;

  // all tensors
  oss << "all_tensors: [" << deli;
  for (auto tensor : model->graph_.all_tensors_) {
    oss << "{" << tensor->name() << "}";
  }
  oss << "]" << deli;

  // all nodes
  oss << "all_nodes: [" << deli;
  for (auto node : model->graph_.all_nodes_) {
    oss << "{" << deli;
    oss << "name: " << node->name_ << deli;
    oss << "op_type: " << node->op_type_ << deli;
    oss << "node_type: " << node->node_type_ << deli;
    oss << "input: [";
    for (auto i : node->input_indices_) {
      oss << i << ", ";
    }
    oss << "]" << deli;
    oss << "output: [";
    for (auto i : node->output_indices_) {
      oss << i << ", ";
    }
    oss << "]" << deli;

    // // primitive
    // auto *primitive = reinterpret_cast<schema

    oss << "}" << deli;
  }
  oss << "]" << deli;

  oss << "}" << deli;
  oss << "}" << deli;

  // dump
  return oss.str();
}
}  // namespace mindspore::lite
789