/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_LITE_SRC_LITE_MODEL_H_
18 #define MINDSPORE_LITE_SRC_LITE_MODEL_H_
19 
20 #include <string>
21 #include <vector>
22 #include "include/errorcode.h"
23 #include "include/model.h"
24 #include "include/version.h"
25 #include "schema/model_generated.h"
26 #include "src/common/common.h"
27 #include "src/common/log_adapter.h"
28 #include "src/common/version_manager.h"
29 #ifdef ENABLE_V0
30 #include "schema/model_v0_generated.h"
31 #endif
32 #ifdef ENABLE_MODEL_OBF
33 #include "tools/obfuscator/include/deobfuscator.h"
34 #endif
35 
36 namespace mindspore {
37 namespace lite {
38 class LiteModel : public Model {
39  public:
40   int ConstructModel();
41 
42   bool ModelVerify() const;
43 
44   void Free() override;
45 
46   void Destroy() override;
47 
~LiteModel()48   ~LiteModel() override { Destroy(); }
49 
keep_model_buf()50   bool keep_model_buf() const { return this->keep_model_buf_; }
51 
set_keep_model_buf(bool keep)52   void set_keep_model_buf(bool keep) { this->keep_model_buf_ = keep; }
53 
GetSchemaVersion()54   int GetSchemaVersion() const { return schema_version_; }
55 
56  private:
57 #ifdef ENABLE_V0
58   int ConvertAttrs(LiteGraph::Node *node, std::vector<schema::Tensor *> *dst_tensor);
59 
60   int ConvertAttrToTensors();
61 #endif
62 
63   template <typename T = schema::MetaGraph, typename U = schema::CNode>
ConvertNodes(const T & meta_graph)64   bool ConvertNodes(const T &meta_graph) {
65     if (meta_graph.nodes() == nullptr) {
66       MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
67       return false;
68     }
69     for (size_t i = 0; i < meta_graph.nodes()->size(); ++i) {
70       auto *node = new (std::nothrow) LiteGraph::Node();
71       if (node == nullptr) {
72         MS_LOG(ERROR) << "new node fail!";
73         return false;
74       }
75       auto c_node = meta_graph.nodes()->template GetAs<U>(i);
76       if (c_node == nullptr) {
77         MS_LOG(ERROR) << "get as cnode fail!";
78         return false;
79       }
80 #ifdef ENABLE_MODEL_OBF
81       auto src_prim = reinterpret_cast<const schema::Primitive *>(c_node->primitive());
82       auto src_prim_type = src_prim->value_type();
83       unsigned char *dst_prim = nullptr;
84       if (src_prim_type == schema::PrimitiveType_GenOP) {
85         auto src_node_stat = this->all_nodes_stat_[i];
86         auto dst_prim_type = this->all_prims_type_[i];
87         auto ret = DeObfuscatePrimitive(src_prim, src_node_stat, &dst_prim, schema::PrimitiveType(dst_prim_type));
88         if (!ret) {
89           MS_LOG(ERROR) << "Deobfuscate primitive failed!";
90           delete node;
91           return false;
92         }
93         if (dst_prim == nullptr) {
94           this->all_nodes_.push_back(node);
95           continue;
96         }
97         this->deobf_prims_.push_back(dst_prim);
98         src_prim = reinterpret_cast<const schema::Primitive *>(flatbuffers::GetRoot<schema::Primitive>(dst_prim));
99       }
100       node->primitive_ = const_cast<schema::Primitive *>(src_prim);
101 #else
102       node->primitive_ = c_node->primitive();
103 #endif
104       node->quant_type_ = c_node->quantType();
105       if (schema_version_ == SCHEMA_VERSION::SCHEMA_CUR) {
106         SetNodeDeviceType(node, *c_node);
107       }
108 #ifdef ENABLE_V0
109       if (schema_version_ == SCHEMA_VERSION::SCHEMA_V0) {
110         SetNodeDeviceType(node, *c_node);
111       }
112 #endif
113       if (node->quant_type_ == schema::QuantType_PostTraining || node->quant_type_ == schema::QuantType_AwareTraining) {
114         node->quant_type_ = schema::QuantType_QUANT_ALL;
115       } else if (node->quant_type_ == schema::QuantType_WeightQuant) {
116         node->quant_type_ = schema::QuantType_QUANT_WEIGHT;
117       }
118       if (c_node->name() == nullptr) {
119         node->name_ = "";
120       } else {
121         node->name_ = c_node->name()->c_str();
122       }
123       if (c_node->inputIndex() != nullptr) {
124         auto count = c_node->inputIndex()->size();
125         for (uint32_t j = 0; j < count; ++j) {
126           node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs<uint32_t>(j)));
127         }
128       }
129       if (c_node->outputIndex() != nullptr) {
130         auto count = c_node->outputIndex()->size();
131         for (uint32_t j = 0; j < count; ++j) {
132           node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs<uint32_t>(j)));
133         }
134       }
135       this->graph_.all_nodes_.push_back(node);
136     }
137     return true;
138   }
139 
140   template <typename T = schema::MetaGraph>
ConvertTensors(const T & meta_graph)141   bool ConvertTensors(const T &meta_graph) {
142     if (meta_graph.allTensors() == nullptr) {
143       MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
144       return false;
145     }
146     auto tensor_count = meta_graph.allTensors()->size();
147     for (uint32_t i = 0; i < tensor_count; ++i) {
148       auto *tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(i);
149       if (tensor == nullptr) {
150         MS_LOG(ERROR) << i << "the tensor in metagraph is nullptr";
151         return false;
152       }
153       this->graph_.all_tensors_.push_back(const_cast<mindspore::schema::Tensor *>(tensor));
154     }
155     return true;
156   }
157 
158   template <typename T = schema::MetaGraph>
MetaGraphMappingSubGraph(const T & meta_graph)159   int MetaGraphMappingSubGraph(const T &meta_graph) {
160     if (meta_graph.inputIndex() == nullptr || meta_graph.outputIndex() == nullptr || meta_graph.nodes() == nullptr ||
161         meta_graph.allTensors() == nullptr) {
162       MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
163       return RET_ERROR;
164     }
165     auto *subgraph = new (std::nothrow) LiteGraph::SubGraph();
166     if (subgraph == nullptr) {
167       MS_LOG(ERROR) << "new subGraph fail!";
168       return RET_ERROR;
169     }
170     if (meta_graph.name() != nullptr) {
171       subgraph->name_ = meta_graph.name()->c_str();
172     }
173     auto in_count = meta_graph.inputIndex()->size();
174     for (uint32_t i = 0; i < in_count; ++i) {
175       subgraph->input_indices_.push_back(size_t(meta_graph.inputIndex()->template GetAs<uint32_t>(i)));
176     }
177     auto out_count = meta_graph.outputIndex()->size();
178     for (uint32_t i = 0; i < out_count; ++i) {
179       subgraph->output_indices_.push_back(size_t(meta_graph.outputIndex()->template GetAs<uint32_t>(i)));
180     }
181     auto node_count = meta_graph.nodes()->size();
182     for (uint32_t i = 0; i < node_count; ++i) {
183       subgraph->node_indices_.push_back(i);
184     }
185     auto tensor_count = meta_graph.allTensors()->size();
186     for (uint32_t i = 0; i < tensor_count; ++i) {
187       subgraph->tensor_indices_.push_back(i);
188     }
189     this->graph_.sub_graphs_.push_back(subgraph);
190     return RET_OK;
191   }
192 
193   template <typename T = schema::MetaGraph, typename U = schema::CNode>
GenerateModel(const T & meta_graph)194   int GenerateModel(const T &meta_graph) {
195     if (meta_graph.name() != nullptr) {
196       this->graph_.name_ = meta_graph.name()->c_str();
197     }
198     if (meta_graph.version() != nullptr) {
199       this->graph_.version_ = meta_graph.version()->c_str();
200     }
201     if (!ConvertNodes<T, U>(meta_graph)) {
202       MS_LOG(ERROR) << "convert node failed";
203       return RET_ERROR;
204     }
205     if (!ConvertTensors<T>(meta_graph)) {
206       MS_LOG(ERROR) << "convert tensor failed";
207       return RET_ERROR;
208     }
209 
210     if (meta_graph.inputIndex() == nullptr || meta_graph.outputIndex() == nullptr ||
211         meta_graph.allTensors() == nullptr) {
212       MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
213       return RET_ERROR;
214     }
215 
216     // converterInputOutput
217     auto in_count = meta_graph.inputIndex()->size();
218     for (uint32_t i = 0; i < in_count; ++i) {
219       this->graph_.input_indices_.push_back(meta_graph.inputIndex()->Get(i));
220     }
221     auto out_count = meta_graph.outputIndex()->size();
222     for (uint32_t i = 0; i < out_count; ++i) {
223       this->graph_.output_indices_.push_back(meta_graph.outputIndex()->Get(i));
224     }
225 
226     if (meta_graph.subGraph() == nullptr) {
227       int ret = MetaGraphMappingSubGraph<T>(meta_graph);
228       if (ret != RET_OK) {
229         MS_LOG(ERROR) << "converter old version model wrong.";
230         return ret;
231       }
232     } else {
233       auto sub_graphs = meta_graph.subGraph();
234       MS_ASSERT(sub_graphs != nullptr);
235       auto sub_graph_size = sub_graphs->size();
236       for (size_t i = 0; i < sub_graph_size; i++) {
237         auto sub_graph = sub_graphs->template GetAs<schema::SubGraph>(i);
238         int ret = ConvertSubGraph(*sub_graph);
239         if (ret != RET_OK) {
240           MS_LOG(ERROR) << "converter subgraph wrong.";
241           return ret;
242         }
243       }
244     }
245 #ifdef ENABLE_V0
246     if (ConvertAttrToTensors() != RET_OK) {
247       MS_LOG(ERROR) << "fail to convert attr to tensor.";
248       return RET_ERROR;
249     }
250 #endif
251     return RET_OK;
252   }
253 
SetNodeDeviceType(LiteGraph::Node * node,const schema::CNode & c_node)254   void SetNodeDeviceType(LiteGraph::Node *node, const schema::CNode &c_node) {
255     node->device_type_ = c_node.deviceType();
256   }
257 
258 #ifdef ENABLE_V0
SetNodeDeviceType(LiteGraph::Node * node,const schema::v0::CNode & c_node)259   void SetNodeDeviceType(LiteGraph::Node *node, const schema::v0::CNode &c_node) { node->device_type_ = -1; }
260 #endif
261 
262   int VersionVerify(flatbuffers::Verifier *verify) const;
263 
264   const void *GetMetaGraphByVerison();
265 
266   int GenerateModelByVersion(const void *meta_graph);
267 
268   int ConvertSubGraph(const schema::SubGraph &sub_graph);
269 
270   int NodeVerify() const;
271 
272   int SubGraphVerify() const;
273 
274  public:
275   size_t buf_size_ = 0;
276   std::vector<void *> node_bufs_;
277 
278  protected:
279   std::vector<char *> attr_tensor_bufs_;
280   bool keep_model_buf_ = false;
281   int schema_version_ = SCHEMA_VERSION::SCHEMA_CUR;
282 };
283 
284 Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf);
285 }  // namespace lite
286 }  // namespace mindspore
287 
288 #endif  // MINDSPORE_LITE_SRC_LITE_MODEL_H_
289