1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/lite_model.h"
18 #include <sys/stat.h>
19 #include <iostream>
20 #include <fstream>
21 #include <vector>
22 #include <set>
23 #include <unordered_map>
24 #include <memory>
25 #include "src/common/prim_util.h"
26 #include "src/common/graph_util.h"
27 #include "src/common/file_utils.h"
28 #ifdef ENABLE_V0
29 #include "src/ops/compat/compat_register.h"
30 #endif
31
32 namespace mindspore::lite {
33 #ifdef ENABLE_V0
ConvertAttrs(LiteGraph::Node * node,std::vector<schema::Tensor * > * dst_tensor)34 int LiteModel::ConvertAttrs(LiteGraph::Node *node, std::vector<schema::Tensor *> *dst_tensor) {
35 if (node == nullptr || dst_tensor == nullptr) {
36 MS_LOG(ERROR) << "node or tensor_vec is nullptr.";
37 return RET_ERROR;
38 }
39 auto primitive = node->primitive_;
40 if (primitive == nullptr) {
41 MS_LOG(ERROR) << "primitive is nullptr.";
42 return RET_ERROR;
43 }
44 auto prim = reinterpret_cast<const schema::v0::Primitive *>(primitive);
45 int primitive_type = prim->value_type();
46 auto creator = CompatRegistry::GetInstance()->GetTransferAttrFunc(SCHEMA_VERSION::SCHEMA_V0, primitive_type);
47 if (creator == nullptr) {
48 MS_LOG(DEBUG) << "the node don't need to convert attr to tensor.";
49 return RET_OK;
50 }
51 int status = creator(node, dst_tensor, &this->attr_tensor_bufs_);
52 if (status != RET_OK && status != RET_NO_CHANGE) {
53 MS_LOG(ERROR) << "translate attr to tensor failed.";
54 return status;
55 }
56 return RET_OK;
57 }
58
ConvertAttrToTensors()59 int LiteModel::ConvertAttrToTensors() {
60 if (schema_version_ != SCHEMA_VERSION::SCHEMA_V0) {
61 MS_LOG(DEBUG) << "no need to convert attr to tensor.";
62 return RET_OK;
63 }
64 std::unordered_map<int, std::set<int>> subgraph_node_indexes;
65 for (size_t subgraph_index = 0; subgraph_index < this->graph_.sub_graphs_.size(); ++subgraph_index) {
66 for (size_t node_index = 0; node_index < this->graph_.sub_graphs_[subgraph_index]->node_indices_.size();
67 ++node_index) {
68 subgraph_node_indexes[subgraph_index].insert(this->graph_.sub_graphs_[subgraph_index]->node_indices_[node_index]);
69 }
70 }
71 int cur_all_tensors_size = this->graph_.all_tensors_.size();
72 for (size_t index = 0; index < this->graph_.all_nodes_.size(); ++index) {
73 std::vector<schema::Tensor *> dst_tensors;
74 int status = ConvertAttrs(this->graph_.all_nodes_[index], &dst_tensors);
75 if (status != RET_OK) {
76 MS_LOG(ERROR) << "fail to convert attr to tensor.";
77 return RET_ERROR;
78 }
79 if (dst_tensors.empty()) {
80 continue;
81 }
82 std::vector<int> subgraphs_with_node;
83 for (size_t subgraph_index = 0; subgraph_index < this->graph_.sub_graphs_.size(); ++subgraph_index) {
84 if (subgraph_node_indexes[subgraph_index].find(index) == subgraph_node_indexes[subgraph_index].end()) {
85 continue;
86 }
87 subgraphs_with_node.push_back(subgraph_index);
88 }
89 for (auto tensor : dst_tensors) {
90 for (auto subgraph_index : subgraphs_with_node) {
91 this->graph_.sub_graphs_[subgraph_index]->tensor_indices_.push_back(cur_all_tensors_size);
92 }
93 this->graph_.all_nodes_[index]->input_indices_.push_back(cur_all_tensors_size++);
94 this->graph_.all_tensors_.push_back(tensor);
95 }
96 }
97 return RET_OK;
98 }
99 #endif
100
Free()101 void LiteModel::Free() {
102 if (this->buf != nullptr) {
103 delete[](this->buf);
104 this->buf = nullptr;
105 }
106 auto nodes_size = this->graph_.all_nodes_.size();
107 for (size_t i = 0; i < nodes_size; ++i) {
108 auto node = this->graph_.all_nodes_[i];
109 node->primitive_ = nullptr;
110 }
111 for (auto &tensor_buf : attr_tensor_bufs_) {
112 free(tensor_buf);
113 tensor_buf = nullptr;
114 }
115 attr_tensor_bufs_.resize(0);
116
117 for (auto &node_buf : node_bufs_) {
118 free(node_buf);
119 node_buf = nullptr;
120 }
121 node_bufs_.resize(0);
122 #ifdef ENABLE_MODEL_OBF
123 for (auto &prim : deobf_prims_) {
124 free(prim);
125 }
126 deobf_prims_.resize(0);
127 #endif
128 }
129
Destroy()130 void LiteModel::Destroy() {
131 Free();
132 auto nodes_size = this->graph_.all_nodes_.size();
133 for (size_t i = 0; i < nodes_size; ++i) {
134 auto node = this->graph_.all_nodes_[i];
135 MS_ASSERT(node != nullptr);
136 delete node;
137 }
138 this->graph_.all_nodes_.clear();
139
140 auto sub_graph_size = this->graph_.sub_graphs_.size();
141 for (size_t i = 0; i < sub_graph_size; ++i) {
142 auto sub_graph = this->graph_.sub_graphs_[i];
143 delete sub_graph;
144 }
145 }
146
ConvertSubGraph(const schema::SubGraph & sub_graph)147 int LiteModel::ConvertSubGraph(const schema::SubGraph &sub_graph) {
148 if (sub_graph.name() == nullptr || sub_graph.inputIndices() == nullptr || sub_graph.outputIndices() == nullptr ||
149 sub_graph.tensorIndices() == nullptr) {
150 MS_LOG(ERROR) << "sub_graph is invalid";
151 return RET_ERROR;
152 }
153
154 auto *subgraph = new (std::nothrow) LiteGraph::SubGraph();
155 if (subgraph == nullptr) {
156 MS_LOG(ERROR) << "new subGraph fail!";
157 return RET_ERROR;
158 }
159
160 subgraph->name_ = sub_graph.name()->c_str();
161 auto in_count = sub_graph.inputIndices()->size();
162 for (uint32_t i = 0; i < in_count; ++i) {
163 subgraph->input_indices_.push_back(sub_graph.inputIndices()->Get(i));
164 }
165 auto out_count = sub_graph.outputIndices()->size();
166 for (uint32_t i = 0; i < out_count; ++i) {
167 subgraph->output_indices_.push_back(sub_graph.outputIndices()->Get(i));
168 }
169 if (sub_graph.nodeIndices() != nullptr) {
170 auto node_count = sub_graph.nodeIndices()->size();
171 for (uint32_t i = 0; i < node_count; ++i) {
172 subgraph->node_indices_.push_back(sub_graph.nodeIndices()->Get(i));
173 }
174 }
175 auto tensor_count = sub_graph.tensorIndices()->size();
176 for (uint32_t i = 0; i < tensor_count; ++i) {
177 subgraph->tensor_indices_.push_back(sub_graph.tensorIndices()->Get(i));
178 }
179 this->graph_.sub_graphs_.push_back(subgraph);
180 return RET_OK;
181 }
182
VersionVerify(flatbuffers::Verifier * verify) const183 int LiteModel::VersionVerify(flatbuffers::Verifier *verify) const {
184 if (verify == nullptr) {
185 MS_LOG(ERROR) << "verify is null.";
186 return RET_ERROR;
187 }
188 if (schema::VerifyMetaGraphBuffer(*verify)) {
189 return SCHEMA_VERSION::SCHEMA_CUR;
190 }
191 #ifdef ENABLE_V0
192 if (schema::v0::VerifyMetaGraphBuffer(*verify)) {
193 return SCHEMA_VERSION::SCHEMA_V0;
194 }
195 #endif
196 return SCHEMA_VERSION::SCHEMA_INVALID;
197 }
198
NodeVerify() const199 int LiteModel::NodeVerify() const {
200 auto tensor_size = this->graph_.all_tensors_.size();
201 uint32_t subgraph_size = this->graph_.sub_graphs_.size();
202
203 for (auto &node : this->graph_.all_nodes_) {
204 if (node == nullptr || node->primitive_ == nullptr) {
205 MS_LOG(ERROR) << "node or its primitive_ is null.";
206 return RET_ERROR;
207 }
208 if (std::any_of(node->input_indices_.begin(), node->input_indices_.end(),
209 [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
210 MS_LOG(ERROR) << "Index of node->input_indices_ is beyond size.";
211 return RET_ERROR;
212 }
213 if (std::any_of(node->output_indices_.begin(), node->output_indices_.end(),
214 [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
215 MS_LOG(ERROR) << "Index of node->output_indices_ is beyond size.";
216 return RET_ERROR;
217 }
218
219 if (IsPartialNode(node->primitive_, schema_version_)) {
220 auto subgraph_index = GetPartialGraphIndex(node->primitive_, schema_version_);
221 if (static_cast<uint32_t>(subgraph_index) >= subgraph_size) {
222 MS_LOG(ERROR) << "subgraph index:" << subgraph_index << " is beyond subgraph_size: " << subgraph_size;
223 return RET_ERROR;
224 }
225 }
226 }
227 return RET_OK;
228 }
229
SubGraphVerify() const230 int LiteModel::SubGraphVerify() const {
231 auto tensor_size = this->graph_.all_tensors_.size();
232 auto node_size = this->graph_.all_nodes_.size();
233
234 if (graph_.sub_graphs_[0]->input_indices_.size() == 0 || graph_.sub_graphs_[0]->output_indices_.size() == 0) {
235 MS_LOG(ERROR) << "The model has invalid input and output, please check";
236 return RET_ERROR;
237 }
238
239 for (auto &graph : this->graph_.sub_graphs_) {
240 if (graph == nullptr) {
241 MS_LOG(ERROR) << "graph is null.";
242 return RET_ERROR;
243 }
244 if (std::any_of(graph->input_indices_.begin(), graph->input_indices_.end(),
245 [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
246 MS_LOG(ERROR) << "Index of graph->input_indices_ is beyond tensor_size.";
247 return RET_ERROR;
248 }
249 if (std::any_of(graph->output_indices_.begin(), graph->output_indices_.end(),
250 [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
251 MS_LOG(ERROR) << "Index of graph->output_indices_ is beyond tensor_size.";
252 return RET_ERROR;
253 }
254 if (std::any_of(graph->tensor_indices_.begin(), graph->tensor_indices_.end(),
255 [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
256 MS_LOG(ERROR) << "Index of graph->tensor_indices_ is beyond tensor_size.";
257 return RET_ERROR;
258 }
259 if (std::any_of(graph->node_indices_.begin(), graph->node_indices_.end(),
260 [&node_size](const uint32_t &idx) { return idx >= node_size; })) {
261 MS_LOG(ERROR) << "Index of graph->node_indices_ is beyond node_size.";
262 return RET_ERROR;
263 }
264 }
265 return RET_OK;
266 }
267
ModelVerify() const268 bool LiteModel::ModelVerify() const {
269 if (this->graph_.sub_graphs_.empty()) {
270 MS_LOG(ERROR) << "Model does not have a main graph.";
271 return false;
272 }
273
274 auto all_tensors_size = this->graph_.all_tensors_.size();
275 for (auto input_index : this->graph_.input_indices_) {
276 if (input_index >= all_tensors_size) {
277 MS_LOG(ERROR) << "Graph input indices is beyond tensor_size.";
278 return false;
279 }
280 auto *tensor = static_cast<schema::Tensor *>(this->graph_.all_tensors_.at(input_index));
281 if (tensor == nullptr) {
282 MS_LOG(ERROR) << "Tensor in all tensors is nullptr.";
283 return false;
284 }
285 }
286
287 if (std::any_of(this->graph_.output_indices_.begin(), this->graph_.output_indices_.end(),
288 [&all_tensors_size](const uint32_t &idx) { return idx >= all_tensors_size; })) {
289 MS_LOG(ERROR) << "Graph output indices is beyond tensor_size.";
290 return false;
291 }
292 return NodeVerify() == RET_OK && SubGraphVerify() == RET_OK;
293 }
294
GetMetaGraphByVerison()295 const void *LiteModel::GetMetaGraphByVerison() {
296 MS_ASSERT(this->buf != nullptr);
297 if (schema_version_ == SCHEMA_VERSION::SCHEMA_CUR) {
298 return reinterpret_cast<const void *>(schema::GetMetaGraph(this->buf));
299 }
300 #ifdef ENABLE_V0
301 if (schema_version_ == SCHEMA_VERSION::SCHEMA_V0) {
302 return reinterpret_cast<const void *>(schema::v0::GetMetaGraph(buf));
303 }
304 #endif
305 return nullptr;
306 }
307
GenerateModelByVersion(const void * meta_graph)308 int LiteModel::GenerateModelByVersion(const void *meta_graph) {
309 MS_ASSERT(meta_graph != nullptr);
310 int status = RET_ERROR;
311 #ifdef ENABLE_MODEL_OBF
312 DeObfuscator *model_deobf = nullptr;
313 #endif
314 if (schema_version_ == SCHEMA_VERSION::SCHEMA_CUR) {
315 #ifdef ENABLE_MODEL_OBF
316 if (IsMetaGraphObfuscated<schema::MetaGraph>(*reinterpret_cast<const schema::MetaGraph *>(meta_graph))) {
317 model_deobf =
318 GetModelDeObfuscator<schema::MetaGraph>(*reinterpret_cast<const schema::MetaGraph *>(meta_graph), this);
319 this->model_obfuscated_ = true;
320 if (model_deobf == nullptr) {
321 return RET_ERROR;
322 }
323 }
324 #endif
325 status = GenerateModel<schema::MetaGraph, schema::CNode>(*reinterpret_cast<const schema::MetaGraph *>(meta_graph));
326 }
327 #ifdef ENABLE_V0
328 if (schema_version_ == SCHEMA_VERSION::SCHEMA_V0) {
329 status = GenerateModel<schema::v0::MetaGraph, schema::v0::CNode>(
330 *reinterpret_cast<const schema::v0::MetaGraph *>(meta_graph));
331 }
332 #endif
333 #ifdef ENABLE_MODEL_OBF
334 if (this->model_obfuscated_) {
335 MS_ASSERT(model_deobf != nullptr);
336 status = DeObfuscateModel(this, model_deobf);
337 if (status != RET_OK) {
338 MS_LOG(ERROR) << "deobfuscate model wrong.";
339 std::cerr << "deobfuscate model wrong." << std::endl;
340 }
341 delete (model_deobf);
342 }
343 #endif
344 return status;
345 }
346
ConstructModel()347 int LiteModel::ConstructModel() {
348 if (this->buf == nullptr || this->buf_size_ <= 0) {
349 MS_LOG(ERROR) << "cannot construct model.";
350 return RET_NULL_PTR;
351 }
352 flatbuffers::Verifier verify((const uint8_t *)this->buf, this->buf_size_);
353 schema_version_ = VersionVerify(&verify);
354 if (schema_version_ == SCHEMA_INVALID) {
355 MS_LOG(ERROR) << "The model buffer is invalid and fail to create graph.";
356 #ifndef ENABLE_V0
357 MS_LOG(ERROR) << "Maybe this is a model transferred out using the conversion tool before 1.1.0";
358 MS_LOG(ERROR) << unsupport_v0_log;
359 #endif
360 return RET_ERROR;
361 }
362 const void *meta_graph = GetMetaGraphByVerison();
363 if (meta_graph == nullptr) {
364 MS_LOG(ERROR) << "meta_graph is nullptr!";
365 return RET_NULL_PTR;
366 }
367
368 int status = GenerateModelByVersion(meta_graph);
369 if (status != RET_OK) {
370 MS_LOG(ERROR) << "fail to generate model";
371 return status;
372 }
373
374 if (this->graph_.version_ != Version()) {
375 MS_LOG(WARNING) << "model version is " << this->graph_.version_ << ", inference version is " << Version()
376 << " not equal";
377 }
378 if (this->graph_.sub_graphs_.empty()) {
379 return RET_ERROR;
380 }
381
382 return ModelVerify() ? RET_OK : RET_ERROR;
383 }
384
namespace {
// Largest caller-owned model buffer (2 GiB) that ImportFromBuffer accepts for copying.
constexpr size_t kMaxModelBufferSize = static_cast<size_t>(2) << 30;
}  // namespace
388
ImportFromBuffer(const char * model_buf,size_t size,bool take_buf)389 Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
390 if (model_buf == nullptr) {
391 MS_LOG(ERROR) << "The model buf is nullptr";
392 return nullptr;
393 }
394 auto *model = new (std::nothrow) LiteModel();
395 if (model == nullptr) {
396 MS_LOG(ERROR) << "new model fail!";
397 return nullptr;
398 }
399 if (take_buf) {
400 model->buf = const_cast<char *>(model_buf);
401 } else {
402 if (size == 0 || size > kMaxModelBufferSize) {
403 MS_LOG(ERROR) << "Input model buffer size invalid, require (0, 2GB].";
404 delete (model);
405 return nullptr;
406 }
407 model->buf = new char[size];
408 if (model->buf == nullptr) {
409 MS_LOG(ERROR) << "new inner model buf fail!";
410 delete (model);
411 return nullptr;
412 }
413 memcpy(model->buf, model_buf, size);
414 }
415 model->buf_size_ = size;
416 auto status = model->ConstructModel();
417 if (status != RET_OK) {
418 if (take_buf) {
419 model->buf = nullptr;
420 }
421 MS_LOG(ERROR) << "construct model failed.";
422 delete model;
423 return nullptr;
424 }
425 return model;
426 }
427
Import(const char * model_buf,size_t size)428 Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); }
429
Import(const char * filename)430 Model *Model::Import(const char *filename) {
431 size_t size = -1;
432 auto buf = ReadFile(filename, &size);
433 if (buf == nullptr) {
434 return nullptr;
435 }
436 return ImportFromBuffer(buf, size, true);
437 }
438
Export(Model * model,char * buffer,size_t * len)439 int Model::Export(Model *model, char *buffer, size_t *len) {
440 if (len == nullptr) {
441 MS_LOG(ERROR) << "len is nullptr";
442 return RET_ERROR;
443 }
444 auto *liteModel = reinterpret_cast<LiteModel *>(model);
445
446 if (liteModel->buf_size_ == 0 || liteModel->buf == nullptr) {
447 MS_LOG(ERROR) << "model buffer is invalid";
448 return RET_ERROR;
449 }
450 if (*len < liteModel->buf_size_ && buffer != nullptr) {
451 MS_LOG(ERROR) << "Buffer is too small, Export Failed";
452 return RET_ERROR;
453 }
454 if (buffer == nullptr) {
455 buffer = reinterpret_cast<char *>(malloc(liteModel->buf_size_));
456 if (buffer == nullptr) {
457 MS_LOG(ERROR) << "allocated model buf fail!";
458 return RET_ERROR;
459 }
460 }
461 memcpy(buffer, liteModel->buf, liteModel->buf_size_);
462 *len = liteModel->buf_size_;
463 return RET_OK;
464 }
465
Export(Model * model,const char * filename)466 int Model::Export(Model *model, const char *filename) {
467 auto *liteModel = reinterpret_cast<LiteModel *>(model);
468 if (liteModel->buf_size_ == 0 || liteModel->buf == nullptr) {
469 MS_LOG(ERROR) << "model buf is invalid";
470 return RET_ERROR;
471 }
472
473 std::ofstream ofs(filename);
474 if (!ofs.good() || !ofs.is_open()) {
475 MS_LOG(ERROR) << "Could not open file \"" << filename << "\" for writing";
476 return RET_ERROR;
477 }
478
479 ofs.seekp(0, std::ios::beg);
480 ofs.write(liteModel->buf, liteModel->buf_size_);
481 ofs.close();
482 #ifdef SUPPORT_MSVC
483 return RET_OK;
484 #else
485 return chmod(filename, S_IRUSR);
486 #endif
487 }
488 } // namespace mindspore::lite
489