• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <limits>
#include <memory>
#include <string>

#include "extendrt/mindir_loader/mindir_model/mindir_model_loader.h"
#include "extendrt/mindir_loader/mindir_model/mindir_model_util.h"
#include "src/litert/kernel_registry.h"
#include "ops/primitive_c.h"
24 
25 namespace mindspore::infer::mindir {
// op_type of MindIR nodes that carry constant tensors (in TENSORS attributes) instead of an operator.
constexpr char kNodeTypeConstant[] = "Constant";
27 
ImportModel(const char * model_buf,size_t size,bool take_buf)28 AbstractBaseModel *MindirModelLoader::ImportModel(const char *model_buf, size_t size, bool take_buf) {
29   this->model_ = new MindirModel();
30   MS_CHECK_TRUE_MSG(this->model_ != nullptr, nullptr,
31                     "MindirModelLoader: Import model failed: new mindir model failed.");
32   this->model_->model_type_ = mindspore::lite::ModelType_MindIR;
33   auto ret = this->InitModelBuffer(this->model_, model_buf, size, take_buf);
34   MS_CHECK_TRUE_MSG(ret == RET_OK, nullptr,
35                     "MindirModelLoader: Import model failed: init model buffer error with " << ret);
36 
37   // mind_ir::ModelProto model_proto;
38   MS_CHECK_TRUE_MSG(this->model_->mindir_model_proto_.ParseFromArray(this->model_->buf, static_cast<int32_t>(size)),
39                     nullptr, "MindirModelLoader: Import model failed, please check the correctness of the file.");
40 
41   MS_LOG(ERROR) << "model_proto: " << this->model_->mindir_model_proto_.DebugString();
42 
43   if (!this->ConvertModel(this->model_->mindir_model_proto_)) {
44     MS_LOG(ERROR)
45       << "MindirModelLoader: Import model failed, convert model error, please check the correctness of the file.";
46     delete this->model_;
47     this->model_ = nullptr;
48     return nullptr;
49   }
50 
51   return this->model_;
52 }
53 
ConvertModel(const mind_ir::ModelProto & model_proto)54 bool MindirModelLoader::ConvertModel(const mind_ir::ModelProto &model_proto) {
55   this->model_->graph_.name_ = "";
56   if (model_proto.has_model_version()) {
57     this->model_->graph_.version_ = model_proto.model_version();
58   }
59 
60   MS_CHECK_TRUE_MSG(
61     ConvertPrimitives(model_proto), false,
62     "MindirModelLoader: Import model failed, convert primitives error, please check the correctness of the file.");
63   this->tensor_count_ = 0;
64   this->node_count_ = 0;
65   if (model_proto.has_graph()) {
66     this->model_->graph_.name_ = model_proto.graph().name();
67     // root graph, do not pass sub graph
68     if (model_proto.functions_size() > 0) {
69       MS_CHECK_TRUE_MSG(
70         ConvertGraph(model_proto.graph(), nullptr, true), false,
71         "MindirModelLoader: Import model failed, convert root graph error, please check the correctness of the file.");
72     } else {
73       // no subgraph, add graph to subgraph
74       auto *sub_graph = new LiteGraph::SubGraph();
75       sub_graph->name_ = model_proto.graph().name();
76       MS_CHECK_TRUE_MSG(
77         ConvertGraph(model_proto.graph(), sub_graph, true), false,
78         "MindirModelLoader: Import model failed, convert root graph error, please check the correctness of the file.");
79       this->model_->graph_.sub_graphs_.push_back(sub_graph);
80     }
81   }
82 
83   for (int i = 0; i < model_proto.functions_size(); i++) {
84     auto sub_graph_proto = model_proto.functions(i);
85     auto *sub_graph = new LiteGraph::SubGraph();
86     if (sub_graph == nullptr) {
87       MS_LOG(ERROR) << "MindirModelLoader: Import model failed, new sub graph failed.";
88       return mindspore::lite::RET_ERROR;
89     }
90     // MS_CHECK_FALSE_MSG(sub_graph == nullptr, mindspore::lite::RET_ERROR,
91     //                    "MindirModelLoader: Import model failed, new sub graph failed.");
92     sub_graph->name_ = sub_graph_proto.name();
93     MS_CHECK_TRUE_MSG(
94       ConvertGraph(sub_graph_proto, sub_graph), false,
95       "MindirModelLoader: Import model failed, convert sub graph error, please check the correctness of the file.");
96     this->model_->graph_.sub_graphs_.push_back(sub_graph);
97   }
98   MS_LOG(INFO) << "MindirModelLoader: Import model successful.";
99   return true;
100 }
101 
ConvertPrimitives(const mind_ir::ModelProto & model_proto)102 bool MindirModelLoader::ConvertPrimitives(const mind_ir::ModelProto &model_proto) {
103   static auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
104   for (int i = 0; i < model_proto.primitives_size(); i++) {
105     auto primitive_proto = model_proto.primitives(i);
106     auto op_type = primitive_proto.op_type();
107     std::shared_ptr<mindspore::Primitive> prim;
108     auto it = op_primc_fns.find(op_type);
109     if (it == op_primc_fns.end()) {
110       MS_LOG(WARNING) << "MindirModelLoader: Convert primitives failed, unsupported op primitive type: " << op_type;
111       continue;
112     }
113     prim = it->second();
114     MS_CHECK_TRUE_MSG(prim != nullptr, false, "MindirModelLoader: Convert primitives failed, the prim is nullptr.");
115     prim->set_instance_name(op_type);
116     for (int j = 0; j < primitive_proto.attribute_size(); j++) {
117       auto attr_proto = primitive_proto.attribute(j);
118       auto value_ptr = MindirModelUtil::MakeValueFromAttribute(attr_proto);
119       MS_CHECK_TRUE_MSG(value_ptr != nullptr, false,
120                         "MindirModelLoader: convert primitives failed, parse prim: "
121                           << prim->ToString() << " attributes error: " << attr_proto.DebugString());
122       (void)prim->AddAttr(attr_proto.name(), value_ptr);
123     }
124     static auto operator_fns = ops::OperatorRegister::GetInstance().GetOperatorMap();
125     auto op_it = operator_fns.find(op_type);
126     if (op_it == operator_fns.end()) {
127       MS_LOG(WARNING) << "MindirModelLoader: Convert primitives failed, unsupported op operator type: " << op_type;
128       continue;
129     }
130     auto base_operator = op_it->second(prim);
131     MS_CHECK_TRUE_MSG(this->all_operators_.count(primitive_proto.name()) <= 0, false,
132                       "MindirModelLoader: There is a duplication primitive instance name: " << primitive_proto.name());
133     this->all_operators_[primitive_proto.name()] = base_operator;
134   }
135   return true;
136 }
137 
ConvertGraph(const mind_ir::GraphProto & graph_proto,LiteGraph::SubGraph * sub_graph,bool is_main_graph)138 bool MindirModelLoader::ConvertGraph(const mind_ir::GraphProto &graph_proto, LiteGraph::SubGraph *sub_graph,
139                                      bool is_main_graph) {
140   MS_CHECK_TRUE_MSG(
141     ConvertTensors(graph_proto, sub_graph, is_main_graph), false,
142     "MindirModelLoader: Convert Graph failed, convert tensors error, please check the correctness of the file.");
143   MS_CHECK_TRUE_MSG(
144     ConvertNodes(graph_proto, sub_graph, is_main_graph), false,
145     "MindirModelLoader: Convert Graph failed, convert nodes error, please check the correctness of the file.");
146   return true;
147 }
148 
ConvertTensors(const mind_ir::GraphProto & graph_proto,LiteGraph::SubGraph * sub_graph,bool is_main_graph)149 bool MindirModelLoader::ConvertTensors(const mind_ir::GraphProto &graph_proto, LiteGraph::SubGraph *sub_graph,
150                                        bool is_main_graph) {
151   for (int i = 0; i < graph_proto.input_size(); i++) {
152     const mind_ir::TensorProto &tensor_proto = graph_proto.input(i).tensor(0);
153     TensorProtoWrap tensor_wrap(graph_proto.input(i).name(), tensor_proto);
154     this->model_->all_mindir_tensors_.push_back(tensor_wrap);
155     this->tensor_index_map_[graph_proto.input(i).name()] = this->tensor_count_;
156     if (sub_graph != nullptr) {
157       sub_graph->input_indices_.push_back(this->tensor_count_);
158       sub_graph->tensor_indices_.push_back(this->tensor_count_);
159     }
160     if (is_main_graph) {
161       this->model_->graph_.input_indices_.push_back(this->tensor_count_);
162     }
163     this->tensor_count_++;
164   }
165   for (int i = 0; i < graph_proto.output_size(); i++) {
166     const mind_ir::TensorProto &tensor_proto = graph_proto.output(i).tensor(0);
167     TensorProtoWrap tensor_wrap(graph_proto.output(i).name(), tensor_proto);
168     this->model_->all_mindir_tensors_.push_back(tensor_wrap);
169     this->tensor_index_map_[graph_proto.output(i).name()] = this->tensor_count_;
170     if (sub_graph != nullptr) {
171       sub_graph->output_indices_.push_back(this->tensor_count_);
172       sub_graph->tensor_indices_.push_back(this->tensor_count_);
173     }
174     if (is_main_graph) {
175       this->model_->graph_.output_indices_.push_back(this->tensor_count_);
176     }
177     this->tensor_count_++;
178   }
179   for (int i = 0; i < graph_proto.parameter_size(); i++) {
180     const mind_ir::TensorProto &tensor_proto = graph_proto.parameter(i);
181     TensorProtoWrap tensor_wrap(tensor_proto.name(), tensor_proto);
182     this->model_->all_mindir_tensors_.push_back(tensor_wrap);
183     this->tensor_index_map_[tensor_proto.name()] = this->tensor_count_;
184     if (sub_graph != nullptr) {
185       sub_graph->tensor_indices_.push_back(this->tensor_count_);
186     }
187     this->tensor_count_++;
188   }
189   return true;
190 }
191 
ConvertNodes(const mind_ir::GraphProto & graph_proto,LiteGraph::SubGraph * sub_graph,bool is_main_graph)192 bool MindirModelLoader::ConvertNodes(const mind_ir::GraphProto &graph_proto, LiteGraph::SubGraph *sub_graph,
193                                      bool is_main_graph) {
194   for (int i = 0; i < graph_proto.node_size(); i++) {
195     auto node_proto = graph_proto.node(i);
196     if (node_proto.op_type() == kNodeTypeConstant) {
197       // Constant node, convert to tensor
198       for (int j = 0; j < node_proto.attribute_size(); j++) {
199         auto attribute_proto = node_proto.attribute(j);
200         if (attribute_proto.type() == mind_ir::AttributeProto_AttributeType_TENSORS) {
201           const mind_ir::TensorProto &tensor_proto = attribute_proto.tensors(0);
202           TensorProtoWrap tensor_wrap(node_proto.name(), tensor_proto);
203           this->model_->all_mindir_tensors_.push_back(tensor_wrap);
204           this->tensor_index_map_[node_proto.name()] = this->tensor_count_;
205           if (sub_graph != nullptr) {
206             sub_graph->tensor_indices_.push_back(this->tensor_count_);
207           }
208           this->tensor_count_++;
209         }
210       }
211       continue;
212     }
213     auto *node = new LiteGraph::Node();
214     if (node == nullptr) {
215       MS_LOG(ERROR) << "MindirModelLoader: Convert nodes failed, new node failed.";
216       return false;
217     }
218     node->name_ = node_proto.name();
219     node->base_operator_ = this->MakePrimitiveC(node_proto.op_type());
220     auto base_operator = std::reinterpret_pointer_cast<ops::BaseOperator>(node->base_operator_);
221     node->op_type_ = base_operator->GetPrim()->instance_name();
222 
223     // solve input
224     for (int j = 0; j < node_proto.input_size(); j++) {
225       std::string input_name = node_proto.input(j);
226       auto it = this->tensor_index_map_.find(input_name);
227       if (it == this->tensor_index_map_.end()) {
228         MS_LOG(WARNING) << "MindirModelLoader: Convert nodes failed, cannot find input index with " << input_name;
229         continue;
230       }
231       node->input_indices_.push_back(it->second);
232     }
233 
234     // solve output
235     for (int j = 0; j < node_proto.output_size(); j++) {
236       std::string output_name = node_proto.output(j);
237       auto it = this->tensor_index_map_.find(output_name);
238       if (it == this->tensor_index_map_.end()) {
239         MS_LOG(WARNING) << "MindirModelLoader: Convert nodes failed, cannot find output index with " << output_name;
240         continue;
241       }
242       node->output_indices_.push_back(it->second);
243     }
244 
245     this->model_->graph_.all_nodes_.push_back(node);
246     if (sub_graph != nullptr) {
247       sub_graph->node_indices_.push_back(this->node_count_);
248     }
249     this->node_count_++;
250   }
251   return true;
252 }
253 
MakePrimitiveC(const std::string & node_type)254 std::shared_ptr<void> MindirModelLoader::MakePrimitiveC(const std::string &node_type) {
255   const std::string kOperatorTypeFlag = std::string("REF::");
256   const size_t kOpTypeFlagSize = kOperatorTypeFlag.length();
257   if (node_type.size() > kOpTypeFlagSize && node_type.substr(0, kOpTypeFlagSize) == kOperatorTypeFlag) {
258     auto it = this->all_operators_.find(node_type.substr(kOpTypeFlagSize));
259     if (it == this->all_operators_.end()) {
260       MS_LOG(ERROR) << "MindirModelLoader: make primitiveC failed, can't find the primitive ref:" << node_type;
261       return nullptr;
262     }
263     return it->second;
264   }
265 
266   // node_type is not ref: pointer
267   auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
268   if (op_primc_fns.find(node_type) != op_primc_fns.end()) {
269     // registered primitive
270     auto prim = (op_primc_fns[node_type]());
271     MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "MindirModelLoader: Make primitiveC failed, the prim is nullptr.");
272     prim->set_instance_name(node_type);
273     static auto operator_fns = ops::OperatorRegister::GetInstance().GetOperatorMap();
274     auto op_it = operator_fns.find(node_type);
275     if (op_it == operator_fns.end()) {
276       MS_LOG(WARNING) << "MindirModelLoader: Make primitiveC failed, unsupported op operator type: " << node_type;
277       return nullptr;
278     }
279     return op_it->second(prim);
280   } else {
281     // S_Prim_xxx or S_Prim_hyper_map[xxx] and custom node type, now not support
282     MS_LOG(ERROR) << "MindirModelLoader: make primitiveC failed, not support node type: " << node_type;
283     return nullptr;
284   }
285 }
286 
MindirModelLoaderCreator()287 static std::shared_ptr<ModelLoader> MindirModelLoaderCreator() { return std::make_shared<MindirModelLoader>(); }
288 
289 REG_MODEL_LOADER(kMindIR, MindirModelLoaderCreator);
290 }  // namespace mindspore::infer::mindir
291