• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/lite_model.h"
18 #include <sys/stat.h>
19 #include <iostream>
20 #include <fstream>
21 #include <vector>
22 #include <set>
23 #include <unordered_map>
24 #include <memory>
25 #include "src/common/prim_util.h"
26 #include "src/common/graph_util.h"
27 #include "src/common/file_utils.h"
28 #ifdef ENABLE_V0
29 #include "src/ops/compat/compat_register.h"
30 #endif
31 
32 namespace mindspore::lite {
33 #ifdef ENABLE_V0
ConvertAttrs(LiteGraph::Node * node,std::vector<schema::Tensor * > * dst_tensor)34 int LiteModel::ConvertAttrs(LiteGraph::Node *node, std::vector<schema::Tensor *> *dst_tensor) {
35   if (node == nullptr || dst_tensor == nullptr) {
36     MS_LOG(ERROR) << "node or tensor_vec is nullptr.";
37     return RET_ERROR;
38   }
39   auto primitive = node->primitive_;
40   if (primitive == nullptr) {
41     MS_LOG(ERROR) << "primitive is nullptr.";
42     return RET_ERROR;
43   }
44   auto prim = reinterpret_cast<const schema::v0::Primitive *>(primitive);
45   int primitive_type = prim->value_type();
46   auto creator = CompatRegistry::GetInstance()->GetTransferAttrFunc(SCHEMA_VERSION::SCHEMA_V0, primitive_type);
47   if (creator == nullptr) {
48     MS_LOG(DEBUG) << "the node don't need to convert attr to tensor.";
49     return RET_OK;
50   }
51   int status = creator(node, dst_tensor, &this->attr_tensor_bufs_);
52   if (status != RET_OK && status != RET_NO_CHANGE) {
53     MS_LOG(ERROR) << "translate attr to tensor failed.";
54     return status;
55   }
56   return RET_OK;
57 }
58 
ConvertAttrToTensors()59 int LiteModel::ConvertAttrToTensors() {
60   if (schema_version_ != SCHEMA_VERSION::SCHEMA_V0) {
61     MS_LOG(DEBUG) << "no need to convert attr to tensor.";
62     return RET_OK;
63   }
64   std::unordered_map<int, std::set<int>> subgraph_node_indexes;
65   for (size_t subgraph_index = 0; subgraph_index < this->graph_.sub_graphs_.size(); ++subgraph_index) {
66     for (size_t node_index = 0; node_index < this->graph_.sub_graphs_[subgraph_index]->node_indices_.size();
67          ++node_index) {
68       subgraph_node_indexes[subgraph_index].insert(this->graph_.sub_graphs_[subgraph_index]->node_indices_[node_index]);
69     }
70   }
71   int cur_all_tensors_size = this->graph_.all_tensors_.size();
72   for (size_t index = 0; index < this->graph_.all_nodes_.size(); ++index) {
73     std::vector<schema::Tensor *> dst_tensors;
74     int status = ConvertAttrs(this->graph_.all_nodes_[index], &dst_tensors);
75     if (status != RET_OK) {
76       MS_LOG(ERROR) << "fail to convert attr to tensor.";
77       return RET_ERROR;
78     }
79     if (dst_tensors.empty()) {
80       continue;
81     }
82     std::vector<int> subgraphs_with_node;
83     for (size_t subgraph_index = 0; subgraph_index < this->graph_.sub_graphs_.size(); ++subgraph_index) {
84       if (subgraph_node_indexes[subgraph_index].find(index) == subgraph_node_indexes[subgraph_index].end()) {
85         continue;
86       }
87       subgraphs_with_node.push_back(subgraph_index);
88     }
89     for (auto tensor : dst_tensors) {
90       for (auto subgraph_index : subgraphs_with_node) {
91         this->graph_.sub_graphs_[subgraph_index]->tensor_indices_.push_back(cur_all_tensors_size);
92       }
93       this->graph_.all_nodes_[index]->input_indices_.push_back(cur_all_tensors_size++);
94       this->graph_.all_tensors_.push_back(tensor);
95     }
96   }
97   return RET_OK;
98 }
99 #endif
100 
Free()101 void LiteModel::Free() {
102   if (this->buf != nullptr) {
103     delete[](this->buf);
104     this->buf = nullptr;
105   }
106   auto nodes_size = this->graph_.all_nodes_.size();
107   for (size_t i = 0; i < nodes_size; ++i) {
108     auto node = this->graph_.all_nodes_[i];
109     node->primitive_ = nullptr;
110   }
111   for (auto &tensor_buf : attr_tensor_bufs_) {
112     free(tensor_buf);
113     tensor_buf = nullptr;
114   }
115   attr_tensor_bufs_.resize(0);
116 
117   for (auto &node_buf : node_bufs_) {
118     free(node_buf);
119     node_buf = nullptr;
120   }
121   node_bufs_.resize(0);
122 #ifdef ENABLE_MODEL_OBF
123   for (auto &prim : deobf_prims_) {
124     free(prim);
125   }
126   deobf_prims_.resize(0);
127 #endif
128 }
129 
Destroy()130 void LiteModel::Destroy() {
131   Free();
132   auto nodes_size = this->graph_.all_nodes_.size();
133   for (size_t i = 0; i < nodes_size; ++i) {
134     auto node = this->graph_.all_nodes_[i];
135     MS_ASSERT(node != nullptr);
136     delete node;
137   }
138   this->graph_.all_nodes_.clear();
139 
140   auto sub_graph_size = this->graph_.sub_graphs_.size();
141   for (size_t i = 0; i < sub_graph_size; ++i) {
142     auto sub_graph = this->graph_.sub_graphs_[i];
143     delete sub_graph;
144   }
145 }
146 
ConvertSubGraph(const schema::SubGraph & sub_graph)147 int LiteModel::ConvertSubGraph(const schema::SubGraph &sub_graph) {
148   if (sub_graph.name() == nullptr || sub_graph.inputIndices() == nullptr || sub_graph.outputIndices() == nullptr ||
149       sub_graph.tensorIndices() == nullptr) {
150     MS_LOG(ERROR) << "sub_graph is invalid";
151     return RET_ERROR;
152   }
153 
154   auto *subgraph = new (std::nothrow) LiteGraph::SubGraph();
155   if (subgraph == nullptr) {
156     MS_LOG(ERROR) << "new subGraph fail!";
157     return RET_ERROR;
158   }
159 
160   subgraph->name_ = sub_graph.name()->c_str();
161   auto in_count = sub_graph.inputIndices()->size();
162   for (uint32_t i = 0; i < in_count; ++i) {
163     subgraph->input_indices_.push_back(sub_graph.inputIndices()->Get(i));
164   }
165   auto out_count = sub_graph.outputIndices()->size();
166   for (uint32_t i = 0; i < out_count; ++i) {
167     subgraph->output_indices_.push_back(sub_graph.outputIndices()->Get(i));
168   }
169   if (sub_graph.nodeIndices() != nullptr) {
170     auto node_count = sub_graph.nodeIndices()->size();
171     for (uint32_t i = 0; i < node_count; ++i) {
172       subgraph->node_indices_.push_back(sub_graph.nodeIndices()->Get(i));
173     }
174   }
175   auto tensor_count = sub_graph.tensorIndices()->size();
176   for (uint32_t i = 0; i < tensor_count; ++i) {
177     subgraph->tensor_indices_.push_back(sub_graph.tensorIndices()->Get(i));
178   }
179   this->graph_.sub_graphs_.push_back(subgraph);
180   return RET_OK;
181 }
182 
VersionVerify(flatbuffers::Verifier * verify) const183 int LiteModel::VersionVerify(flatbuffers::Verifier *verify) const {
184   if (verify == nullptr) {
185     MS_LOG(ERROR) << "verify is null.";
186     return RET_ERROR;
187   }
188   if (schema::VerifyMetaGraphBuffer(*verify)) {
189     return SCHEMA_VERSION::SCHEMA_CUR;
190   }
191 #ifdef ENABLE_V0
192   if (schema::v0::VerifyMetaGraphBuffer(*verify)) {
193     return SCHEMA_VERSION::SCHEMA_V0;
194   }
195 #endif
196   return SCHEMA_VERSION::SCHEMA_INVALID;
197 }
198 
NodeVerify() const199 int LiteModel::NodeVerify() const {
200   auto tensor_size = this->graph_.all_tensors_.size();
201   uint32_t subgraph_size = this->graph_.sub_graphs_.size();
202 
203   for (auto &node : this->graph_.all_nodes_) {
204     if (node == nullptr || node->primitive_ == nullptr) {
205       MS_LOG(ERROR) << "node or its primitive_ is null.";
206       return RET_ERROR;
207     }
208     if (std::any_of(node->input_indices_.begin(), node->input_indices_.end(),
209                     [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
210       MS_LOG(ERROR) << "Index of node->input_indices_ is beyond size.";
211       return RET_ERROR;
212     }
213     if (std::any_of(node->output_indices_.begin(), node->output_indices_.end(),
214                     [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
215       MS_LOG(ERROR) << "Index of node->output_indices_ is beyond size.";
216       return RET_ERROR;
217     }
218 
219     if (IsPartialNode(node->primitive_, schema_version_)) {
220       auto subgraph_index = GetPartialGraphIndex(node->primitive_, schema_version_);
221       if (static_cast<uint32_t>(subgraph_index) >= subgraph_size) {
222         MS_LOG(ERROR) << "subgraph index:" << subgraph_index << " is beyond subgraph_size: " << subgraph_size;
223         return RET_ERROR;
224       }
225     }
226   }
227   return RET_OK;
228 }
229 
SubGraphVerify() const230 int LiteModel::SubGraphVerify() const {
231   auto tensor_size = this->graph_.all_tensors_.size();
232   auto node_size = this->graph_.all_nodes_.size();
233 
234   if (graph_.sub_graphs_[0]->input_indices_.size() == 0 || graph_.sub_graphs_[0]->output_indices_.size() == 0) {
235     MS_LOG(ERROR) << "The model has invalid input and output, please check";
236     return RET_ERROR;
237   }
238 
239   for (auto &graph : this->graph_.sub_graphs_) {
240     if (graph == nullptr) {
241       MS_LOG(ERROR) << "graph is null.";
242       return RET_ERROR;
243     }
244     if (std::any_of(graph->input_indices_.begin(), graph->input_indices_.end(),
245                     [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
246       MS_LOG(ERROR) << "Index of graph->input_indices_ is beyond tensor_size.";
247       return RET_ERROR;
248     }
249     if (std::any_of(graph->output_indices_.begin(), graph->output_indices_.end(),
250                     [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
251       MS_LOG(ERROR) << "Index of graph->output_indices_ is beyond tensor_size.";
252       return RET_ERROR;
253     }
254     if (std::any_of(graph->tensor_indices_.begin(), graph->tensor_indices_.end(),
255                     [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
256       MS_LOG(ERROR) << "Index of graph->tensor_indices_ is beyond tensor_size.";
257       return RET_ERROR;
258     }
259     if (std::any_of(graph->node_indices_.begin(), graph->node_indices_.end(),
260                     [&node_size](const uint32_t &idx) { return idx >= node_size; })) {
261       MS_LOG(ERROR) << "Index of graph->node_indices_ is beyond node_size.";
262       return RET_ERROR;
263     }
264   }
265   return RET_OK;
266 }
267 
ModelVerify() const268 bool LiteModel::ModelVerify() const {
269   if (this->graph_.sub_graphs_.empty()) {
270     MS_LOG(ERROR) << "Model does not have a main graph.";
271     return false;
272   }
273 
274   auto all_tensors_size = this->graph_.all_tensors_.size();
275   for (auto input_index : this->graph_.input_indices_) {
276     if (input_index >= all_tensors_size) {
277       MS_LOG(ERROR) << "Graph input indices is beyond tensor_size.";
278       return false;
279     }
280     auto *tensor = static_cast<schema::Tensor *>(this->graph_.all_tensors_.at(input_index));
281     if (tensor == nullptr) {
282       MS_LOG(ERROR) << "Tensor in all tensors is nullptr.";
283       return false;
284     }
285   }
286 
287   if (std::any_of(this->graph_.output_indices_.begin(), this->graph_.output_indices_.end(),
288                   [&all_tensors_size](const uint32_t &idx) { return idx >= all_tensors_size; })) {
289     MS_LOG(ERROR) << "Graph output indices is beyond tensor_size.";
290     return false;
291   }
292   return NodeVerify() == RET_OK && SubGraphVerify() == RET_OK;
293 }
294 
GetMetaGraphByVerison()295 const void *LiteModel::GetMetaGraphByVerison() {
296   MS_ASSERT(this->buf != nullptr);
297   if (schema_version_ == SCHEMA_VERSION::SCHEMA_CUR) {
298     return reinterpret_cast<const void *>(schema::GetMetaGraph(this->buf));
299   }
300 #ifdef ENABLE_V0
301   if (schema_version_ == SCHEMA_VERSION::SCHEMA_V0) {
302     return reinterpret_cast<const void *>(schema::v0::GetMetaGraph(buf));
303   }
304 #endif
305   return nullptr;
306 }
307 
GenerateModelByVersion(const void * meta_graph)308 int LiteModel::GenerateModelByVersion(const void *meta_graph) {
309   MS_ASSERT(meta_graph != nullptr);
310   int status = RET_ERROR;
311 #ifdef ENABLE_MODEL_OBF
312   DeObfuscator *model_deobf = nullptr;
313 #endif
314   if (schema_version_ == SCHEMA_VERSION::SCHEMA_CUR) {
315 #ifdef ENABLE_MODEL_OBF
316     if (IsMetaGraphObfuscated<schema::MetaGraph>(*reinterpret_cast<const schema::MetaGraph *>(meta_graph))) {
317       model_deobf =
318         GetModelDeObfuscator<schema::MetaGraph>(*reinterpret_cast<const schema::MetaGraph *>(meta_graph), this);
319       this->model_obfuscated_ = true;
320       if (model_deobf == nullptr) {
321         return RET_ERROR;
322       }
323     }
324 #endif
325     status = GenerateModel<schema::MetaGraph, schema::CNode>(*reinterpret_cast<const schema::MetaGraph *>(meta_graph));
326   }
327 #ifdef ENABLE_V0
328   if (schema_version_ == SCHEMA_VERSION::SCHEMA_V0) {
329     status = GenerateModel<schema::v0::MetaGraph, schema::v0::CNode>(
330       *reinterpret_cast<const schema::v0::MetaGraph *>(meta_graph));
331   }
332 #endif
333 #ifdef ENABLE_MODEL_OBF
334   if (this->model_obfuscated_) {
335     MS_ASSERT(model_deobf != nullptr);
336     status = DeObfuscateModel(this, model_deobf);
337     if (status != RET_OK) {
338       MS_LOG(ERROR) << "deobfuscate model wrong.";
339       std::cerr << "deobfuscate model wrong." << std::endl;
340     }
341     delete (model_deobf);
342   }
343 #endif
344   return status;
345 }
346 
ConstructModel()347 int LiteModel::ConstructModel() {
348   if (this->buf == nullptr || this->buf_size_ <= 0) {
349     MS_LOG(ERROR) << "cannot construct model.";
350     return RET_NULL_PTR;
351   }
352   flatbuffers::Verifier verify((const uint8_t *)this->buf, this->buf_size_);
353   schema_version_ = VersionVerify(&verify);
354   if (schema_version_ == SCHEMA_INVALID) {
355     MS_LOG(ERROR) << "The model buffer is invalid and fail to create graph.";
356 #ifndef ENABLE_V0
357     MS_LOG(ERROR) << "Maybe this is a model transferred out using the conversion tool before 1.1.0";
358     MS_LOG(ERROR) << unsupport_v0_log;
359 #endif
360     return RET_ERROR;
361   }
362   const void *meta_graph = GetMetaGraphByVerison();
363   if (meta_graph == nullptr) {
364     MS_LOG(ERROR) << "meta_graph is nullptr!";
365     return RET_NULL_PTR;
366   }
367 
368   int status = GenerateModelByVersion(meta_graph);
369   if (status != RET_OK) {
370     MS_LOG(ERROR) << "fail to generate model";
371     return status;
372   }
373 
374   if (this->graph_.version_ != Version()) {
375     MS_LOG(WARNING) << "model version is " << this->graph_.version_ << ", inference version is " << Version()
376                     << " not equal";
377   }
378   if (this->graph_.sub_graphs_.empty()) {
379     return RET_ERROR;
380   }
381 
382   return ModelVerify() ? RET_OK : RET_ERROR;
383 }
384 
namespace {
// Largest model buffer accepted when the buffer has to be copied: 2 GiB.
constexpr size_t kMaxModelBufferSize = static_cast<size_t>(2) << 30;
}  // namespace
388 
ImportFromBuffer(const char * model_buf,size_t size,bool take_buf)389 Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
390   if (model_buf == nullptr) {
391     MS_LOG(ERROR) << "The model buf is nullptr";
392     return nullptr;
393   }
394   auto *model = new (std::nothrow) LiteModel();
395   if (model == nullptr) {
396     MS_LOG(ERROR) << "new model fail!";
397     return nullptr;
398   }
399   if (take_buf) {
400     model->buf = const_cast<char *>(model_buf);
401   } else {
402     if (size == 0 || size > kMaxModelBufferSize) {
403       MS_LOG(ERROR) << "Input model buffer size invalid, require (0, 2GB].";
404       delete (model);
405       return nullptr;
406     }
407     model->buf = new char[size];
408     if (model->buf == nullptr) {
409       MS_LOG(ERROR) << "new inner model buf fail!";
410       delete (model);
411       return nullptr;
412     }
413     memcpy(model->buf, model_buf, size);
414   }
415   model->buf_size_ = size;
416   auto status = model->ConstructModel();
417   if (status != RET_OK) {
418     if (take_buf) {
419       model->buf = nullptr;
420     }
421     MS_LOG(ERROR) << "construct model failed.";
422     delete model;
423     return nullptr;
424   }
425   return model;
426 }
427 
Import(const char * model_buf,size_t size)428 Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); }
429 
Import(const char * filename)430 Model *Model::Import(const char *filename) {
431   size_t size = -1;
432   auto buf = ReadFile(filename, &size);
433   if (buf == nullptr) {
434     return nullptr;
435   }
436   return ImportFromBuffer(buf, size, true);
437 }
438 
Export(Model * model,char * buffer,size_t * len)439 int Model::Export(Model *model, char *buffer, size_t *len) {
440   if (len == nullptr) {
441     MS_LOG(ERROR) << "len is nullptr";
442     return RET_ERROR;
443   }
444   auto *liteModel = reinterpret_cast<LiteModel *>(model);
445 
446   if (liteModel->buf_size_ == 0 || liteModel->buf == nullptr) {
447     MS_LOG(ERROR) << "model buffer is invalid";
448     return RET_ERROR;
449   }
450   if (*len < liteModel->buf_size_ && buffer != nullptr) {
451     MS_LOG(ERROR) << "Buffer is too small, Export Failed";
452     return RET_ERROR;
453   }
454   if (buffer == nullptr) {
455     buffer = reinterpret_cast<char *>(malloc(liteModel->buf_size_));
456     if (buffer == nullptr) {
457       MS_LOG(ERROR) << "allocated model buf fail!";
458       return RET_ERROR;
459     }
460   }
461   memcpy(buffer, liteModel->buf, liteModel->buf_size_);
462   *len = liteModel->buf_size_;
463   return RET_OK;
464 }
465 
Export(Model * model,const char * filename)466 int Model::Export(Model *model, const char *filename) {
467   auto *liteModel = reinterpret_cast<LiteModel *>(model);
468   if (liteModel->buf_size_ == 0 || liteModel->buf == nullptr) {
469     MS_LOG(ERROR) << "model buf is invalid";
470     return RET_ERROR;
471   }
472 
473   std::ofstream ofs(filename);
474   if (!ofs.good() || !ofs.is_open()) {
475     MS_LOG(ERROR) << "Could not open file \"" << filename << "\" for writing";
476     return RET_ERROR;
477   }
478 
479   ofs.seekp(0, std::ios::beg);
480   ofs.write(liteModel->buf, liteModel->buf_size_);
481   ofs.close();
482 #ifdef SUPPORT_MSVC
483   return RET_OK;
484 #else
485   return chmod(filename, S_IRUSR);
486 #endif
487 }
488 }  // namespace mindspore::lite
489