1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_LITE_MODEL_H_ 18 #define MINDSPORE_LITE_SRC_LITE_MODEL_H_ 19 20 #include <string> 21 #include <vector> 22 #include "include/errorcode.h" 23 #include "include/model.h" 24 #include "include/version.h" 25 #include "schema/model_generated.h" 26 #include "src/common/common.h" 27 #include "src/common/log_adapter.h" 28 #include "src/common/version_manager.h" 29 #ifdef ENABLE_V0 30 #include "schema/model_v0_generated.h" 31 #endif 32 #ifdef ENABLE_MODEL_OBF 33 #include "tools/obfuscator/include/deobfuscator.h" 34 #endif 35 36 namespace mindspore { 37 namespace lite { 38 class LiteModel : public Model { 39 public: 40 int ConstructModel(); 41 42 bool ModelVerify() const; 43 44 void Free() override; 45 46 void Destroy() override; 47 ~LiteModel()48 ~LiteModel() override { Destroy(); } 49 keep_model_buf()50 bool keep_model_buf() const { return this->keep_model_buf_; } 51 set_keep_model_buf(bool keep)52 void set_keep_model_buf(bool keep) { this->keep_model_buf_ = keep; } 53 GetSchemaVersion()54 int GetSchemaVersion() const { return schema_version_; } 55 56 private: 57 #ifdef ENABLE_V0 58 int ConvertAttrs(LiteGraph::Node *node, std::vector<schema::Tensor *> *dst_tensor); 59 60 int ConvertAttrToTensors(); 61 #endif 62 63 template <typename T = schema::MetaGraph, typename U = schema::CNode> ConvertNodes(const T & meta_graph)64 bool ConvertNodes(const T &meta_graph) { 65 if (meta_graph.nodes() == nullptr) { 66 MS_LOG(ERROR) << "meta_graph is invalid, please check your model file."; 67 return false; 68 } 69 for (size_t i = 0; i < meta_graph.nodes()->size(); ++i) { 70 auto *node = new (std::nothrow) LiteGraph::Node(); 71 if (node == nullptr) { 72 MS_LOG(ERROR) << "new node fail!"; 73 return false; 74 } 75 auto c_node = meta_graph.nodes()->template GetAs<U>(i); 76 if (c_node == nullptr) { 77 MS_LOG(ERROR) << "get as cnode fail!"; 78 return false; 79 } 80 #ifdef ENABLE_MODEL_OBF 81 auto src_prim = reinterpret_cast<const schema::Primitive *>(c_node->primitive()); 82 auto src_prim_type = src_prim->value_type(); 83 unsigned char *dst_prim = nullptr; 84 if (src_prim_type == schema::PrimitiveType_GenOP) { 85 auto src_node_stat = this->all_nodes_stat_[i]; 86 auto dst_prim_type = this->all_prims_type_[i]; 87 auto ret = DeObfuscatePrimitive(src_prim, src_node_stat, &dst_prim, schema::PrimitiveType(dst_prim_type)); 88 if (!ret) { 89 MS_LOG(ERROR) << "Deobfuscate primitive failed!"; 90 delete node; 91 return false; 92 } 93 if (dst_prim == nullptr) { 94 this->all_nodes_.push_back(node); 95 continue; 96 } 97 this->deobf_prims_.push_back(dst_prim); 98 src_prim = reinterpret_cast<const schema::Primitive *>(flatbuffers::GetRoot<schema::Primitive>(dst_prim)); 99 } 100 node->primitive_ = const_cast<schema::Primitive *>(src_prim); 101 #else 102 node->primitive_ = c_node->primitive(); 103 #endif 104 node->quant_type_ = c_node->quantType(); 105 if (schema_version_ == SCHEMA_VERSION::SCHEMA_CUR) { 106 SetNodeDeviceType(node, *c_node); 107 } 108 #ifdef ENABLE_V0 109 if (schema_version_ == SCHEMA_VERSION::SCHEMA_V0) { 110 SetNodeDeviceType(node, *c_node); 111 } 112 #endif 113 if (node->quant_type_ == schema::QuantType_PostTraining || node->quant_type_ == schema::QuantType_AwareTraining) { 114 node->quant_type_ = schema::QuantType_QUANT_ALL; 115 } else if (node->quant_type_ == schema::QuantType_WeightQuant) { 116 node->quant_type_ = schema::QuantType_QUANT_WEIGHT; 117 } 118 if (c_node->name() == nullptr) { 119 node->name_ = ""; 120 } else { 121 node->name_ = c_node->name()->c_str(); 122 } 123 if (c_node->inputIndex() != nullptr) { 124 auto count = c_node->inputIndex()->size(); 125 for (uint32_t j = 0; j < count; ++j) { 126 node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs<uint32_t>(j))); 127 } 128 } 129 if (c_node->outputIndex() != nullptr) { 130 auto count = c_node->outputIndex()->size(); 131 for (uint32_t j = 0; j < count; ++j) { 132 node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs<uint32_t>(j))); 133 } 134 } 135 this->graph_.all_nodes_.push_back(node); 136 } 137 return true; 138 } 139 140 template <typename T = schema::MetaGraph> ConvertTensors(const T & meta_graph)141 bool ConvertTensors(const T &meta_graph) { 142 if (meta_graph.allTensors() == nullptr) { 143 MS_LOG(ERROR) << "meta_graph is invalid, please check your model file."; 144 return false; 145 } 146 auto tensor_count = meta_graph.allTensors()->size(); 147 for (uint32_t i = 0; i < tensor_count; ++i) { 148 auto *tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(i); 149 if (tensor == nullptr) { 150 MS_LOG(ERROR) << i << "the tensor in metagraph is nullptr"; 151 return false; 152 } 153 this->graph_.all_tensors_.push_back(const_cast<mindspore::schema::Tensor *>(tensor)); 154 } 155 return true; 156 } 157 158 template <typename T = schema::MetaGraph> MetaGraphMappingSubGraph(const T & meta_graph)159 int MetaGraphMappingSubGraph(const T &meta_graph) { 160 if (meta_graph.inputIndex() == nullptr || meta_graph.outputIndex() == nullptr || meta_graph.nodes() == nullptr || 161 meta_graph.allTensors() == nullptr) { 162 MS_LOG(ERROR) << "meta_graph is invalid, please check your model file."; 163 return RET_ERROR; 164 } 165 auto *subgraph = new (std::nothrow) LiteGraph::SubGraph(); 166 if (subgraph == nullptr) { 167 MS_LOG(ERROR) << "new subGraph fail!"; 168 return RET_ERROR; 169 } 170 if (meta_graph.name() != nullptr) { 171 subgraph->name_ = meta_graph.name()->c_str(); 172 } 173 auto in_count = meta_graph.inputIndex()->size(); 174 for (uint32_t i = 0; i < in_count; ++i) { 175 subgraph->input_indices_.push_back(size_t(meta_graph.inputIndex()->template GetAs<uint32_t>(i))); 176 } 177 auto out_count = meta_graph.outputIndex()->size(); 178 for (uint32_t i = 0; i < out_count; ++i) { 179 subgraph->output_indices_.push_back(size_t(meta_graph.outputIndex()->template GetAs<uint32_t>(i))); 180 } 181 auto node_count = meta_graph.nodes()->size(); 182 for (uint32_t i = 0; i < node_count; ++i) { 183 subgraph->node_indices_.push_back(i); 184 } 185 auto tensor_count = meta_graph.allTensors()->size(); 186 for (uint32_t i = 0; i < tensor_count; ++i) { 187 subgraph->tensor_indices_.push_back(i); 188 } 189 this->graph_.sub_graphs_.push_back(subgraph); 190 return RET_OK; 191 } 192 193 template <typename T = schema::MetaGraph, typename U = schema::CNode> GenerateModel(const T & meta_graph)194 int GenerateModel(const T &meta_graph) { 195 if (meta_graph.name() != nullptr) { 196 this->graph_.name_ = meta_graph.name()->c_str(); 197 } 198 if (meta_graph.version() != nullptr) { 199 this->graph_.version_ = meta_graph.version()->c_str(); 200 } 201 if (!ConvertNodes<T, U>(meta_graph)) { 202 MS_LOG(ERROR) << "convert node failed"; 203 return RET_ERROR; 204 } 205 if (!ConvertTensors<T>(meta_graph)) { 206 MS_LOG(ERROR) << "convert tensor failed"; 207 return RET_ERROR; 208 } 209 210 if (meta_graph.inputIndex() == nullptr || meta_graph.outputIndex() == nullptr || 211 meta_graph.allTensors() == nullptr) { 212 MS_LOG(ERROR) << "meta_graph is invalid, please check your model file."; 213 return RET_ERROR; 214 } 215 216 // converterInputOutput 217 auto in_count = meta_graph.inputIndex()->size(); 218 for (uint32_t i = 0; i < in_count; ++i) { 219 this->graph_.input_indices_.push_back(meta_graph.inputIndex()->Get(i)); 220 } 221 auto out_count = meta_graph.outputIndex()->size(); 222 for (uint32_t i = 0; i < out_count; ++i) { 223 this->graph_.output_indices_.push_back(meta_graph.outputIndex()->Get(i)); 224 } 225 226 if (meta_graph.subGraph() == nullptr) { 227 int ret = MetaGraphMappingSubGraph<T>(meta_graph); 228 if (ret != RET_OK) { 229 MS_LOG(ERROR) << "converter old version model wrong."; 230 return ret; 231 } 232 } else { 233 auto sub_graphs = meta_graph.subGraph(); 234 MS_ASSERT(sub_graphs != nullptr); 235 auto sub_graph_size = sub_graphs->size(); 236 for (size_t i = 0; i < sub_graph_size; i++) { 237 auto sub_graph = sub_graphs->template GetAs<schema::SubGraph>(i); 238 int ret = ConvertSubGraph(*sub_graph); 239 if (ret != RET_OK) { 240 MS_LOG(ERROR) << "converter subgraph wrong."; 241 return ret; 242 } 243 } 244 } 245 #ifdef ENABLE_V0 246 if (ConvertAttrToTensors() != RET_OK) { 247 MS_LOG(ERROR) << "fail to convert attr to tensor."; 248 return RET_ERROR; 249 } 250 #endif 251 return RET_OK; 252 } 253 SetNodeDeviceType(LiteGraph::Node * node,const schema::CNode & c_node)254 void SetNodeDeviceType(LiteGraph::Node *node, const schema::CNode &c_node) { 255 node->device_type_ = c_node.deviceType(); 256 } 257 258 #ifdef ENABLE_V0 SetNodeDeviceType(LiteGraph::Node * node,const schema::v0::CNode & c_node)259 void SetNodeDeviceType(LiteGraph::Node *node, const schema::v0::CNode &c_node) { node->device_type_ = -1; } 260 #endif 261 262 int VersionVerify(flatbuffers::Verifier *verify) const; 263 264 const void *GetMetaGraphByVerison(); 265 266 int GenerateModelByVersion(const void *meta_graph); 267 268 int ConvertSubGraph(const schema::SubGraph &sub_graph); 269 270 int NodeVerify() const; 271 272 int SubGraphVerify() const; 273 274 public: 275 size_t buf_size_ = 0; 276 std::vector<void *> node_bufs_; 277 278 protected: 279 std::vector<char *> attr_tensor_bufs_; 280 bool keep_model_buf_ = false; 281 int schema_version_ = SCHEMA_VERSION::SCHEMA_CUR; 282 }; 283 284 Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf); 285 } // namespace lite 286 } // namespace mindspore 287 288 #endif // MINDSPORE_LITE_SRC_LITE_MODEL_H_ 289