• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_RUNTIME_LITE_MODEL_H_
18 #define MINDSPORE_LITE_SRC_RUNTIME_LITE_MODEL_H_
19 
20 #include <string>
21 #include <utility>
22 #include <vector>
23 #include "include/errorcode.h"
24 #include "include/model.h"
25 #include "schema/model_generated.h"
26 #include "src/common/common.h"
27 #include "src/common/log_adapter.h"
28 #include "src/common/version_manager.h"
29 #include "src/litert/schema_tensor_wrapper.h"
30 #include "nnacl/op_base.h"
31 #include "src/common/prim_util.h"
32 #include "include/api/types.h"
33 #ifdef ENABLE_LITE_HELPER
34 #include "src/common/helper/infer_helpers.h"
35 #endif
36 #include "include/registry/deobf_processor.h"
37 
38 namespace mindspore {
39 namespace lite {
40 class MS_API LiteModel : public Model {
41  public:
model_path_(std::move (model_path))42   explicit LiteModel(std::string model_path = "") : model_path_(std::move(model_path)) {}
43 
44 #ifdef ENABLE_LITE_HELPER
45   int ConstructModel(const char *model_buf, size_t size, bool take_buf,
46                      mindspore::infer::helper::InferHelpers *infer_helpers = nullptr);
47 #else
48   int ConstructModel(const char *model_buf, size_t size, bool take_buf);
49 #endif
50 
51   bool ModelVerify() const;
52 
53   void Free() override;
54 
55   void Destroy() override;
56 
~LiteModel()57   ~LiteModel() override { Destroy(); }
58 
keep_model_buf()59   bool keep_model_buf() const { return this->keep_model_buf_; }
60 
set_keep_model_buf(bool keep)61   void set_keep_model_buf(bool keep) { this->keep_model_buf_ = keep; }
62 
GetSchemaVersion()63   int GetSchemaVersion() const { return schema_version_; }
64 
65   SchemaTensorWrapper *GetSchemaTensor(const size_t &tensor_index) const;
66 
67   static int VersionVerify(flatbuffers::Verifier *verify);
68 
69 #ifdef ENABLE_LITE_HELPER
70   bool PrepareInnerTensors(mindspore::infer::helper::InferHelpers *infer_helpers = nullptr);
71 #else
72   bool PrepareInnerTensors();
73 #endif
74 
75  private:
76   bool CheckQuantAllInit(const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::QuantParam>> *quant_params);
77 
78   template <typename T = schema::MetaGraph, typename U = schema::CNode>
SetQuantType(const T & meta_graph,const U * c_node,LiteGraph::Node * node)79   int SetQuantType(const T &meta_graph, const U *c_node, LiteGraph::Node *node) {
80     node->quant_type_ = c_node->quantType();
81     if (node->quant_type_ < schema::QuantType_MIN || node->quant_type_ > schema::QuantType_MAX) {
82       MS_LOG(ERROR) << "node->quant_type_:" << node->quant_type_ << " is invalid.";
83       delete node;
84       return RET_ERROR;
85     }
86     if (schema_version_ == SCHEMA_VERSION::SCHEMA_CUR) {
87       SetNodeDeviceType(node, *c_node);
88     }
89     std::string version = meta_graph.version() == NULL ? "" : meta_graph.version()->str();
90     const int min_version_length = 5;
91     if (version.length() > min_version_length) {
92       version = version.substr(version.length() - min_version_length, version.length());
93     }
94     bool old_version_weight_quant =
95       ((meta_graph.version() == nullptr || version < "1.3.0") && node->quant_type_ == schema::QuantType_QUANT_NONE &&
96        CheckNeedWeightQuant(meta_graph, c_node->inputIndex()));
97     if (node->quant_type_ == schema::QuantType_PostTraining || node->quant_type_ == schema::QuantType_AwareTraining) {
98       node->quant_type_ = schema::QuantType_QUANT_ALL;
99     } else if (node->quant_type_ == schema::QuantType_WeightQuant || old_version_weight_quant) {
100       node->quant_type_ = schema::QuantType_QUANT_WEIGHT;
101     }
102     return RET_OK;
103   }
104 
105   template <typename T>
CheckNeedWeightQuant(const T & meta_graph,const flatbuffers::Vector<uint32_t> * in_tensor_index)106   bool CheckNeedWeightQuant(const T &meta_graph, const flatbuffers::Vector<uint32_t> *in_tensor_index) {
107     if (in_tensor_index == nullptr) {
108       MS_LOG(ERROR) << "in_tensor_index is nullptr";
109       return false;
110     }
111     const size_t min_quant_size = 2;
112     if (in_tensor_index->size() < min_quant_size) {
113       return false;
114     }
115     bool global_init_flag = false;
116     for (size_t i = 0; i < in_tensor_index->size(); i++) {
117       auto index = size_t(in_tensor_index->template GetAs<uint32_t>(i));
118       if (meta_graph.allTensors() == nullptr) {
119         MS_LOG(ERROR) << "meta_graph.allTensors() is null.";
120         return false;
121       }
122       if (index >= meta_graph.allTensors()->size()) {
123         MS_LOG(ERROR) << "in_tensor_index is invalid.";
124         return false;
125       }
126       auto tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(index);
127       bool cur_tensor_init_flag = CheckQuantAllInit(tensor->quantParams());
128       global_init_flag = global_init_flag || cur_tensor_init_flag;
129       if (tensor->data() == nullptr && cur_tensor_init_flag) {
130         MS_LOG(DEBUG) << tensor->name()
131                       << " is a non-const tensor, but there are quantization parameters, which may belong to full "
132                          "quantization.";
133         return false;
134       }
135     }
136     return global_init_flag;
137   }
138 
139   template <typename T = schema::MetaGraph, typename U = schema::CNode>
ConvertNodes(const T & meta_graph)140   bool ConvertNodes(const T &meta_graph) {
141     MS_CHECK_TRUE_MSG(meta_graph.nodes() != nullptr, false, "meta_graph is invalid, please check your model file.");
142     for (size_t i = 0; i < meta_graph.nodes()->size(); ++i) {
143       auto *node = new (std::nothrow) LiteGraph::Node();
144       MS_CHECK_TRUE_MSG(node != nullptr, false, "new node fail!");
145       auto c_node = meta_graph.nodes()->template GetAs<U>(i);
146       MS_CHECK_TRUE_MSG(c_node != nullptr, false, "get as cnode fail!");
147       node->node_type_ = GetPrimitiveType(c_node->primitive(), schema_version_);
148       if(this->deobf != nullptr){
149         auto src_prim = reinterpret_cast<const schema::Primitive *>(c_node->primitive());
150         auto deobf_ptr = reinterpret_cast<DeObfProcessor *>(this->deobf);
151         DeObfRet ret = (deobf_ptr->*DeObfRegister::CreateDeObfNodeReg)(src_prim,i,schema_version_);
152         if(ret == kDeObfFailed){
153           delete node;
154           return false;
155         }
156         if(ret == kNoObf){
157           this->graph_.all_nodes_.push_back(node);
158           continue;
159         }
160         node->primitive_ = const_cast<schema::Primitive *>(src_prim);
161       }
162       else{
163         node->primitive_ = c_node->primitive();
164       }
165       auto status = SetQuantType(meta_graph, c_node, node);
166       if (status == RET_ERROR) {
167         return false;
168       }
169       if (c_node->name() == nullptr) {
170         node->name_ = "";
171       } else {
172         node->name_ = c_node->name()->c_str();
173       }
174       if (c_node->inputIndex() != nullptr) {
175         auto count = c_node->inputIndex()->size();
176         for (uint32_t j = 0; j < count; ++j) {
177           node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs<uint32_t>(j)));
178         }
179       }
180       if (c_node->outputIndex() != nullptr) {
181         auto count = c_node->outputIndex()->size();
182         for (uint32_t j = 0; j < count; ++j) {
183           node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs<uint32_t>(j)));
184         }
185       }
186       this->graph_.all_nodes_.push_back(node);
187     }
188     return true;
189   }
190 
191   template <typename T = schema::MetaGraph>
ConvertTensors(const T & meta_graph)192   bool ConvertTensors(const T &meta_graph) {
193     if (meta_graph.allTensors() == nullptr) {
194       MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
195       return false;
196     }
197     auto tensor_count = meta_graph.allTensors()->size();
198     for (uint32_t i = 0; i < tensor_count; ++i) {
199       auto *tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(i);
200       if (tensor == nullptr) {
201         MS_LOG(ERROR) << i << "the tensor in metagraph is nullptr";
202         return false;
203       }
204       MS_CHECK_TRUE_RET(tensor->format() >= schema::Format_MIN && tensor->format() <= schema::Format_MAX, false);
205       this->graph_.all_tensors_.push_back(const_cast<mindspore::schema::Tensor *>(tensor));
206     }
207     return true;
208   }
209 
210   template <typename T = schema::MetaGraph>
MetaGraphMappingSubGraph(const T & meta_graph)211   int MetaGraphMappingSubGraph(const T &meta_graph) {
212     if (meta_graph.inputIndex() == nullptr || meta_graph.outputIndex() == nullptr || meta_graph.nodes() == nullptr ||
213         meta_graph.allTensors() == nullptr) {
214       MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
215       return RET_ERROR;
216     }
217     auto *subgraph = new (std::nothrow) LiteGraph::SubGraph();
218     if (subgraph == nullptr) {
219       MS_LOG(ERROR) << "new subGraph fail!";
220       return RET_ERROR;
221     }
222     if (meta_graph.name() != nullptr) {
223       subgraph->name_ = meta_graph.name()->c_str();
224     }
225     auto in_count = meta_graph.inputIndex()->size();
226     for (uint32_t i = 0; i < in_count; ++i) {
227       subgraph->input_indices_.push_back(size_t(meta_graph.inputIndex()->template GetAs<uint32_t>(i)));
228     }
229     auto out_count = meta_graph.outputIndex()->size();
230     for (uint32_t i = 0; i < out_count; ++i) {
231       subgraph->output_indices_.push_back(size_t(meta_graph.outputIndex()->template GetAs<uint32_t>(i)));
232     }
233     auto node_count = meta_graph.nodes()->size();
234     for (uint32_t i = 0; i < node_count; ++i) {
235       subgraph->node_indices_.push_back(i);
236     }
237     auto tensor_count = meta_graph.allTensors()->size();
238     for (uint32_t i = 0; i < tensor_count; ++i) {
239       subgraph->tensor_indices_.push_back(i);
240     }
241     this->graph_.sub_graphs_.push_back(subgraph);
242     return RET_OK;
243   }
244 
245   template <typename T = schema::MetaGraph, typename U = schema::CNode>
GenerateModel(const T & meta_graph)246   int GenerateModel(const T &meta_graph) {
247     if (meta_graph.name() != nullptr) {
248       this->graph_.name_ = meta_graph.name()->c_str();
249     }
250     if (meta_graph.version() != nullptr) {
251       this->graph_.version_ = meta_graph.version()->c_str();
252     }
253     if (!ConvertNodes<T, U>(meta_graph)) {
254       MS_LOG(ERROR) << "convert node failed";
255       return RET_ERROR;
256     }
257     if (!ConvertTensors<T>(meta_graph)) {
258       MS_LOG(ERROR) << "convert tensor failed";
259       return RET_ERROR;
260     }
261 
262     if (meta_graph.inputIndex() == nullptr || meta_graph.outputIndex() == nullptr ||
263         meta_graph.allTensors() == nullptr) {
264       MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
265       return RET_ERROR;
266     }
267 
268     // converterInputOutput
269     auto in_count = meta_graph.inputIndex()->size();
270     for (uint32_t i = 0; i < in_count; ++i) {
271       this->graph_.input_indices_.push_back(meta_graph.inputIndex()->Get(i));
272     }
273     auto out_count = meta_graph.outputIndex()->size();
274     for (uint32_t i = 0; i < out_count; ++i) {
275       this->graph_.output_indices_.push_back(meta_graph.outputIndex()->Get(i));
276     }
277 
278     if (meta_graph.subGraph() == nullptr) {
279       int ret = MetaGraphMappingSubGraph<T>(meta_graph);
280       if (ret != RET_OK) {
281         MS_LOG(ERROR) << "converter old version model wrong.";
282         return ret;
283       }
284     } else {
285       auto sub_graphs = meta_graph.subGraph();
286       MS_ASSERT(sub_graphs != nullptr);
287       auto sub_graph_size = sub_graphs->size();
288       for (size_t i = 0; i < sub_graph_size; i++) {
289         auto sub_graph = sub_graphs->template GetAs<schema::SubGraph>(i);
290         int ret = ConvertSubGraph(*sub_graph);
291         if (ret != RET_OK) {
292           MS_LOG(ERROR) << "converter subgraph wrong.";
293           return ret;
294         }
295       }
296     }
297     return RET_OK;
298   }
299 
SetNodeDeviceType(LiteGraph::Node * node,const schema::CNode & c_node)300   void SetNodeDeviceType(LiteGraph::Node *node, const schema::CNode &c_node) {
301     node->device_type_ = c_node.deviceType();
302   }
303 
304   int GenerateModelByVersion();
305 
306   int ConvertSubGraph(const schema::SubGraph &sub_graph);
307 
308   int NodeVerify() const;
309 
310   int GraphInOutVerify() const;
311 
312   int SubGraphVerify() const;
313 
314   int SubGraphInOutVerify(const LiteGraph::SubGraph *graph) const;
315 
316  public:
317   std::vector<void *> node_bufs_;
318   bool model_buf_by_mmap_ = false;
319 
320  protected:
321   std::vector<char *> attr_tensor_bufs_;
322   bool keep_model_buf_ = false;
323   int schema_version_ = SCHEMA_VERSION::SCHEMA_CUR;
324   // tensor_index --- external_data
325   std::vector<SchemaTensorWrapper *> inner_all_tensors_;
326   const std::string model_path_;
327 };
328 
329 #ifdef ENABLE_LITE_HELPER
330 Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf,
331                         mindspore::ModelType model_type = mindspore::ModelType::kMindIR_Lite,
332                         const std::string &path = "", mindspore::infer::helper::InferHelpers *infer_helpers = nullptr);
333 #else
334 Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf,
335                         mindspore::ModelType model_type = mindspore::ModelType::kMindIR_Lite,
336                         const std::string &path = "");
337 #endif
338 LiteModel *LiteImportFromPath(const char *model_path);
339 Model *ImportFromPath(const char *model_path);
340 
341 std::string ModelDebugString(Model *model);
342 }  // namespace lite
343 }  // namespace mindspore
344 
345 #endif  // MINDSPORE_LITE_SRC_RUNTIME_LITE_MODEL_H_
346