1 /** 2 * Copyright 2020-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_RUNTIME_LITE_MODEL_H_ 18 #define MINDSPORE_LITE_SRC_RUNTIME_LITE_MODEL_H_ 19 20 #include <string> 21 #include <utility> 22 #include <vector> 23 #include "include/errorcode.h" 24 #include "include/model.h" 25 #include "schema/model_generated.h" 26 #include "src/common/common.h" 27 #include "src/common/log_adapter.h" 28 #include "src/common/version_manager.h" 29 #include "src/litert/schema_tensor_wrapper.h" 30 #include "nnacl/op_base.h" 31 #include "src/common/prim_util.h" 32 #include "include/api/types.h" 33 #ifdef ENABLE_LITE_HELPER 34 #include "src/common/helper/infer_helpers.h" 35 #endif 36 #include "include/registry/deobf_processor.h" 37 38 namespace mindspore { 39 namespace lite { 40 class MS_API LiteModel : public Model { 41 public: model_path_(std::move (model_path))42 explicit LiteModel(std::string model_path = "") : model_path_(std::move(model_path)) {} 43 44 #ifdef ENABLE_LITE_HELPER 45 int ConstructModel(const char *model_buf, size_t size, bool take_buf, 46 mindspore::infer::helper::InferHelpers *infer_helpers = nullptr); 47 #else 48 int ConstructModel(const char *model_buf, size_t size, bool take_buf); 49 #endif 50 51 bool ModelVerify() const; 52 53 void Free() override; 54 55 void Destroy() override; 56 ~LiteModel()57 ~LiteModel() override { Destroy(); } 58 keep_model_buf()59 bool keep_model_buf() const { return this->keep_model_buf_; } 60 set_keep_model_buf(bool keep)61 void set_keep_model_buf(bool keep) { this->keep_model_buf_ = keep; } 62 GetSchemaVersion()63 int GetSchemaVersion() const { return schema_version_; } 64 65 SchemaTensorWrapper *GetSchemaTensor(const size_t &tensor_index) const; 66 67 static int VersionVerify(flatbuffers::Verifier *verify); 68 69 #ifdef ENABLE_LITE_HELPER 70 bool PrepareInnerTensors(mindspore::infer::helper::InferHelpers *infer_helpers = nullptr); 71 #else 72 bool PrepareInnerTensors(); 73 #endif 74 75 private: 76 bool CheckQuantAllInit(const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::QuantParam>> *quant_params); 77 78 template <typename T = schema::MetaGraph, typename U = schema::CNode> SetQuantType(const T & meta_graph,const U * c_node,LiteGraph::Node * node)79 int SetQuantType(const T &meta_graph, const U *c_node, LiteGraph::Node *node) { 80 node->quant_type_ = c_node->quantType(); 81 if (node->quant_type_ < schema::QuantType_MIN || node->quant_type_ > schema::QuantType_MAX) { 82 MS_LOG(ERROR) << "node->quant_type_:" << node->quant_type_ << " is invalid."; 83 delete node; 84 return RET_ERROR; 85 } 86 if (schema_version_ == SCHEMA_VERSION::SCHEMA_CUR) { 87 SetNodeDeviceType(node, *c_node); 88 } 89 std::string version = meta_graph.version() == NULL ? "" : meta_graph.version()->str(); 90 const int min_version_length = 5; 91 if (version.length() > min_version_length) { 92 version = version.substr(version.length() - min_version_length, version.length()); 93 } 94 bool old_version_weight_quant = 95 ((meta_graph.version() == nullptr || version < "1.3.0") && node->quant_type_ == schema::QuantType_QUANT_NONE && 96 CheckNeedWeightQuant(meta_graph, c_node->inputIndex())); 97 if (node->quant_type_ == schema::QuantType_PostTraining || node->quant_type_ == schema::QuantType_AwareTraining) { 98 node->quant_type_ = schema::QuantType_QUANT_ALL; 99 } else if (node->quant_type_ == schema::QuantType_WeightQuant || old_version_weight_quant) { 100 node->quant_type_ = schema::QuantType_QUANT_WEIGHT; 101 } 102 return RET_OK; 103 } 104 105 template <typename T> CheckNeedWeightQuant(const T & meta_graph,const flatbuffers::Vector<uint32_t> * in_tensor_index)106 bool CheckNeedWeightQuant(const T &meta_graph, const flatbuffers::Vector<uint32_t> *in_tensor_index) { 107 if (in_tensor_index == nullptr) { 108 MS_LOG(ERROR) << "in_tensor_index is nullptr"; 109 return false; 110 } 111 const size_t min_quant_size = 2; 112 if (in_tensor_index->size() < min_quant_size) { 113 return false; 114 } 115 bool global_init_flag = false; 116 for (size_t i = 0; i < in_tensor_index->size(); i++) { 117 auto index = size_t(in_tensor_index->template GetAs<uint32_t>(i)); 118 if (meta_graph.allTensors() == nullptr) { 119 MS_LOG(ERROR) << "meta_graph.allTensors() is null."; 120 return false; 121 } 122 if (index >= meta_graph.allTensors()->size()) { 123 MS_LOG(ERROR) << "in_tensor_index is invalid."; 124 return false; 125 } 126 auto tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(index); 127 bool cur_tensor_init_flag = CheckQuantAllInit(tensor->quantParams()); 128 global_init_flag = global_init_flag || cur_tensor_init_flag; 129 if (tensor->data() == nullptr && cur_tensor_init_flag) { 130 MS_LOG(DEBUG) << tensor->name() 131 << " is a non-const tensor, but there are quantization parameters, which may belong to full " 132 "quantization."; 133 return false; 134 } 135 } 136 return global_init_flag; 137 } 138 139 template <typename T = schema::MetaGraph, typename U = schema::CNode> ConvertNodes(const T & meta_graph)140 bool ConvertNodes(const T &meta_graph) { 141 MS_CHECK_TRUE_MSG(meta_graph.nodes() != nullptr, false, "meta_graph is invalid, please check your model file."); 142 for (size_t i = 0; i < meta_graph.nodes()->size(); ++i) { 143 auto *node = new (std::nothrow) LiteGraph::Node(); 144 MS_CHECK_TRUE_MSG(node != nullptr, false, "new node fail!"); 145 auto c_node = meta_graph.nodes()->template GetAs<U>(i); 146 MS_CHECK_TRUE_MSG(c_node != nullptr, false, "get as cnode fail!"); 147 node->node_type_ = GetPrimitiveType(c_node->primitive(), schema_version_); 148 if(this->deobf != nullptr){ 149 auto src_prim = reinterpret_cast<const schema::Primitive *>(c_node->primitive()); 150 auto deobf_ptr = reinterpret_cast<DeObfProcessor *>(this->deobf); 151 DeObfRet ret = (deobf_ptr->*DeObfRegister::CreateDeObfNodeReg)(src_prim,i,schema_version_); 152 if(ret == kDeObfFailed){ 153 delete node; 154 return false; 155 } 156 if(ret == kNoObf){ 157 this->graph_.all_nodes_.push_back(node); 158 continue; 159 } 160 node->primitive_ = const_cast<schema::Primitive *>(src_prim); 161 } 162 else{ 163 node->primitive_ = c_node->primitive(); 164 } 165 auto status = SetQuantType(meta_graph, c_node, node); 166 if (status == RET_ERROR) { 167 return false; 168 } 169 if (c_node->name() == nullptr) { 170 node->name_ = ""; 171 } else { 172 node->name_ = c_node->name()->c_str(); 173 } 174 if (c_node->inputIndex() != nullptr) { 175 auto count = c_node->inputIndex()->size(); 176 for (uint32_t j = 0; j < count; ++j) { 177 node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs<uint32_t>(j))); 178 } 179 } 180 if (c_node->outputIndex() != nullptr) { 181 auto count = c_node->outputIndex()->size(); 182 for (uint32_t j = 0; j < count; ++j) { 183 node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs<uint32_t>(j))); 184 } 185 } 186 this->graph_.all_nodes_.push_back(node); 187 } 188 return true; 189 } 190 191 template <typename T = schema::MetaGraph> ConvertTensors(const T & meta_graph)192 bool ConvertTensors(const T &meta_graph) { 193 if (meta_graph.allTensors() == nullptr) { 194 MS_LOG(ERROR) << "meta_graph is invalid, please check your model file."; 195 return false; 196 } 197 auto tensor_count = meta_graph.allTensors()->size(); 198 for (uint32_t i = 0; i < tensor_count; ++i) { 199 auto *tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(i); 200 if (tensor == nullptr) { 201 MS_LOG(ERROR) << i << "the tensor in metagraph is nullptr"; 202 return false; 203 } 204 MS_CHECK_TRUE_RET(tensor->format() >= schema::Format_MIN && tensor->format() <= schema::Format_MAX, false); 205 this->graph_.all_tensors_.push_back(const_cast<mindspore::schema::Tensor *>(tensor)); 206 } 207 return true; 208 } 209 210 template <typename T = schema::MetaGraph> MetaGraphMappingSubGraph(const T & meta_graph)211 int MetaGraphMappingSubGraph(const T &meta_graph) { 212 if (meta_graph.inputIndex() == nullptr || meta_graph.outputIndex() == nullptr || meta_graph.nodes() == nullptr || 213 meta_graph.allTensors() == nullptr) { 214 MS_LOG(ERROR) << "meta_graph is invalid, please check your model file."; 215 return RET_ERROR; 216 } 217 auto *subgraph = new (std::nothrow) LiteGraph::SubGraph(); 218 if (subgraph == nullptr) { 219 MS_LOG(ERROR) << "new subGraph fail!"; 220 return RET_ERROR; 221 } 222 if (meta_graph.name() != nullptr) { 223 subgraph->name_ = meta_graph.name()->c_str(); 224 } 225 auto in_count = meta_graph.inputIndex()->size(); 226 for (uint32_t i = 0; i < in_count; ++i) { 227 subgraph->input_indices_.push_back(size_t(meta_graph.inputIndex()->template GetAs<uint32_t>(i))); 228 } 229 auto out_count = meta_graph.outputIndex()->size(); 230 for (uint32_t i = 0; i < out_count; ++i) { 231 subgraph->output_indices_.push_back(size_t(meta_graph.outputIndex()->template GetAs<uint32_t>(i))); 232 } 233 auto node_count = meta_graph.nodes()->size(); 234 for (uint32_t i = 0; i < node_count; ++i) { 235 subgraph->node_indices_.push_back(i); 236 } 237 auto tensor_count = meta_graph.allTensors()->size(); 238 for (uint32_t i = 0; i < tensor_count; ++i) { 239 subgraph->tensor_indices_.push_back(i); 240 } 241 this->graph_.sub_graphs_.push_back(subgraph); 242 return RET_OK; 243 } 244 245 template <typename T = schema::MetaGraph, typename U = schema::CNode> GenerateModel(const T & meta_graph)246 int GenerateModel(const T &meta_graph) { 247 if (meta_graph.name() != nullptr) { 248 this->graph_.name_ = meta_graph.name()->c_str(); 249 } 250 if (meta_graph.version() != nullptr) { 251 this->graph_.version_ = meta_graph.version()->c_str(); 252 } 253 if (!ConvertNodes<T, U>(meta_graph)) { 254 MS_LOG(ERROR) << "convert node failed"; 255 return RET_ERROR; 256 } 257 if (!ConvertTensors<T>(meta_graph)) { 258 MS_LOG(ERROR) << "convert tensor failed"; 259 return RET_ERROR; 260 } 261 262 if (meta_graph.inputIndex() == nullptr || meta_graph.outputIndex() == nullptr || 263 meta_graph.allTensors() == nullptr) { 264 MS_LOG(ERROR) << "meta_graph is invalid, please check your model file."; 265 return RET_ERROR; 266 } 267 268 // converterInputOutput 269 auto in_count = meta_graph.inputIndex()->size(); 270 for (uint32_t i = 0; i < in_count; ++i) { 271 this->graph_.input_indices_.push_back(meta_graph.inputIndex()->Get(i)); 272 } 273 auto out_count = meta_graph.outputIndex()->size(); 274 for (uint32_t i = 0; i < out_count; ++i) { 275 this->graph_.output_indices_.push_back(meta_graph.outputIndex()->Get(i)); 276 } 277 278 if (meta_graph.subGraph() == nullptr) { 279 int ret = MetaGraphMappingSubGraph<T>(meta_graph); 280 if (ret != RET_OK) { 281 MS_LOG(ERROR) << "converter old version model wrong."; 282 return ret; 283 } 284 } else { 285 auto sub_graphs = meta_graph.subGraph(); 286 MS_ASSERT(sub_graphs != nullptr); 287 auto sub_graph_size = sub_graphs->size(); 288 for (size_t i = 0; i < sub_graph_size; i++) { 289 auto sub_graph = sub_graphs->template GetAs<schema::SubGraph>(i); 290 int ret = ConvertSubGraph(*sub_graph); 291 if (ret != RET_OK) { 292 MS_LOG(ERROR) << "converter subgraph wrong."; 293 return ret; 294 } 295 } 296 } 297 return RET_OK; 298 } 299 SetNodeDeviceType(LiteGraph::Node * node,const schema::CNode & c_node)300 void SetNodeDeviceType(LiteGraph::Node *node, const schema::CNode &c_node) { 301 node->device_type_ = c_node.deviceType(); 302 } 303 304 int GenerateModelByVersion(); 305 306 int ConvertSubGraph(const schema::SubGraph &sub_graph); 307 308 int NodeVerify() const; 309 310 int GraphInOutVerify() const; 311 312 int SubGraphVerify() const; 313 314 int SubGraphInOutVerify(const LiteGraph::SubGraph *graph) const; 315 316 public: 317 std::vector<void *> node_bufs_; 318 bool model_buf_by_mmap_ = false; 319 320 protected: 321 std::vector<char *> attr_tensor_bufs_; 322 bool keep_model_buf_ = false; 323 int schema_version_ = SCHEMA_VERSION::SCHEMA_CUR; 324 // tensor_index --- external_data 325 std::vector<SchemaTensorWrapper *> inner_all_tensors_; 326 const std::string model_path_; 327 }; 328 329 #ifdef ENABLE_LITE_HELPER 330 Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf, 331 mindspore::ModelType model_type = mindspore::ModelType::kMindIR_Lite, 332 const std::string &path = "", mindspore::infer::helper::InferHelpers *infer_helpers = nullptr); 333 #else 334 Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf, 335 mindspore::ModelType model_type = mindspore::ModelType::kMindIR_Lite, 336 const std::string &path = ""); 337 #endif 338 LiteModel *LiteImportFromPath(const char *model_path); 339 Model *ImportFromPath(const char *model_path); 340 341 std::string ModelDebugString(Model *model); 342 } // namespace lite 343 } // namespace mindspore 344 345 #endif // MINDSPORE_LITE_SRC_RUNTIME_LITE_MODEL_H_ 346