1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <limits>
#include <memory>
#include <new>
#include <string>
19
20 #include "extendrt/mindir_loader/mindir_model/mindir_model_loader.h"
21 #include "extendrt/mindir_loader/mindir_model/mindir_model_util.h"
22 #include "src/litert/kernel_registry.h"
23 #include "ops/primitive_c.h"
24
25 namespace mindspore::infer::mindir {
// NodeProto::op_type() value that marks a constant node; ConvertNodes turns such nodes into tensors.
constexpr char kNodeTypeConstant[] = "Constant";
27
ImportModel(const char * model_buf,size_t size,bool take_buf)28 AbstractBaseModel *MindirModelLoader::ImportModel(const char *model_buf, size_t size, bool take_buf) {
29 this->model_ = new MindirModel();
30 MS_CHECK_TRUE_MSG(this->model_ != nullptr, nullptr,
31 "MindirModelLoader: Import model failed: new mindir model failed.");
32 this->model_->model_type_ = mindspore::lite::ModelType_MindIR;
33 auto ret = this->InitModelBuffer(this->model_, model_buf, size, take_buf);
34 MS_CHECK_TRUE_MSG(ret == RET_OK, nullptr,
35 "MindirModelLoader: Import model failed: init model buffer error with " << ret);
36
37 // mind_ir::ModelProto model_proto;
38 MS_CHECK_TRUE_MSG(this->model_->mindir_model_proto_.ParseFromArray(this->model_->buf, static_cast<int32_t>(size)),
39 nullptr, "MindirModelLoader: Import model failed, please check the correctness of the file.");
40
41 MS_LOG(ERROR) << "model_proto: " << this->model_->mindir_model_proto_.DebugString();
42
43 if (!this->ConvertModel(this->model_->mindir_model_proto_)) {
44 MS_LOG(ERROR)
45 << "MindirModelLoader: Import model failed, convert model error, please check the correctness of the file.";
46 delete this->model_;
47 this->model_ = nullptr;
48 return nullptr;
49 }
50
51 return this->model_;
52 }
53
ConvertModel(const mind_ir::ModelProto & model_proto)54 bool MindirModelLoader::ConvertModel(const mind_ir::ModelProto &model_proto) {
55 this->model_->graph_.name_ = "";
56 if (model_proto.has_model_version()) {
57 this->model_->graph_.version_ = model_proto.model_version();
58 }
59
60 MS_CHECK_TRUE_MSG(
61 ConvertPrimitives(model_proto), false,
62 "MindirModelLoader: Import model failed, convert primitives error, please check the correctness of the file.");
63 this->tensor_count_ = 0;
64 this->node_count_ = 0;
65 if (model_proto.has_graph()) {
66 this->model_->graph_.name_ = model_proto.graph().name();
67 // root graph, do not pass sub graph
68 if (model_proto.functions_size() > 0) {
69 MS_CHECK_TRUE_MSG(
70 ConvertGraph(model_proto.graph(), nullptr, true), false,
71 "MindirModelLoader: Import model failed, convert root graph error, please check the correctness of the file.");
72 } else {
73 // no subgraph, add graph to subgraph
74 auto *sub_graph = new LiteGraph::SubGraph();
75 sub_graph->name_ = model_proto.graph().name();
76 MS_CHECK_TRUE_MSG(
77 ConvertGraph(model_proto.graph(), sub_graph, true), false,
78 "MindirModelLoader: Import model failed, convert root graph error, please check the correctness of the file.");
79 this->model_->graph_.sub_graphs_.push_back(sub_graph);
80 }
81 }
82
83 for (int i = 0; i < model_proto.functions_size(); i++) {
84 auto sub_graph_proto = model_proto.functions(i);
85 auto *sub_graph = new LiteGraph::SubGraph();
86 if (sub_graph == nullptr) {
87 MS_LOG(ERROR) << "MindirModelLoader: Import model failed, new sub graph failed.";
88 return mindspore::lite::RET_ERROR;
89 }
90 // MS_CHECK_FALSE_MSG(sub_graph == nullptr, mindspore::lite::RET_ERROR,
91 // "MindirModelLoader: Import model failed, new sub graph failed.");
92 sub_graph->name_ = sub_graph_proto.name();
93 MS_CHECK_TRUE_MSG(
94 ConvertGraph(sub_graph_proto, sub_graph), false,
95 "MindirModelLoader: Import model failed, convert sub graph error, please check the correctness of the file.");
96 this->model_->graph_.sub_graphs_.push_back(sub_graph);
97 }
98 MS_LOG(INFO) << "MindirModelLoader: Import model successful.";
99 return true;
100 }
101
ConvertPrimitives(const mind_ir::ModelProto & model_proto)102 bool MindirModelLoader::ConvertPrimitives(const mind_ir::ModelProto &model_proto) {
103 static auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
104 for (int i = 0; i < model_proto.primitives_size(); i++) {
105 auto primitive_proto = model_proto.primitives(i);
106 auto op_type = primitive_proto.op_type();
107 std::shared_ptr<mindspore::Primitive> prim;
108 auto it = op_primc_fns.find(op_type);
109 if (it == op_primc_fns.end()) {
110 MS_LOG(WARNING) << "MindirModelLoader: Convert primitives failed, unsupported op primitive type: " << op_type;
111 continue;
112 }
113 prim = it->second();
114 MS_CHECK_TRUE_MSG(prim != nullptr, false, "MindirModelLoader: Convert primitives failed, the prim is nullptr.");
115 prim->set_instance_name(op_type);
116 for (int j = 0; j < primitive_proto.attribute_size(); j++) {
117 auto attr_proto = primitive_proto.attribute(j);
118 auto value_ptr = MindirModelUtil::MakeValueFromAttribute(attr_proto);
119 MS_CHECK_TRUE_MSG(value_ptr != nullptr, false,
120 "MindirModelLoader: convert primitives failed, parse prim: "
121 << prim->ToString() << " attributes error: " << attr_proto.DebugString());
122 (void)prim->AddAttr(attr_proto.name(), value_ptr);
123 }
124 static auto operator_fns = ops::OperatorRegister::GetInstance().GetOperatorMap();
125 auto op_it = operator_fns.find(op_type);
126 if (op_it == operator_fns.end()) {
127 MS_LOG(WARNING) << "MindirModelLoader: Convert primitives failed, unsupported op operator type: " << op_type;
128 continue;
129 }
130 auto base_operator = op_it->second(prim);
131 MS_CHECK_TRUE_MSG(this->all_operators_.count(primitive_proto.name()) <= 0, false,
132 "MindirModelLoader: There is a duplication primitive instance name: " << primitive_proto.name());
133 this->all_operators_[primitive_proto.name()] = base_operator;
134 }
135 return true;
136 }
137
ConvertGraph(const mind_ir::GraphProto & graph_proto,LiteGraph::SubGraph * sub_graph,bool is_main_graph)138 bool MindirModelLoader::ConvertGraph(const mind_ir::GraphProto &graph_proto, LiteGraph::SubGraph *sub_graph,
139 bool is_main_graph) {
140 MS_CHECK_TRUE_MSG(
141 ConvertTensors(graph_proto, sub_graph, is_main_graph), false,
142 "MindirModelLoader: Convert Graph failed, convert tensors error, please check the correctness of the file.");
143 MS_CHECK_TRUE_MSG(
144 ConvertNodes(graph_proto, sub_graph, is_main_graph), false,
145 "MindirModelLoader: Convert Graph failed, convert nodes error, please check the correctness of the file.");
146 return true;
147 }
148
ConvertTensors(const mind_ir::GraphProto & graph_proto,LiteGraph::SubGraph * sub_graph,bool is_main_graph)149 bool MindirModelLoader::ConvertTensors(const mind_ir::GraphProto &graph_proto, LiteGraph::SubGraph *sub_graph,
150 bool is_main_graph) {
151 for (int i = 0; i < graph_proto.input_size(); i++) {
152 const mind_ir::TensorProto &tensor_proto = graph_proto.input(i).tensor(0);
153 TensorProtoWrap tensor_wrap(graph_proto.input(i).name(), tensor_proto);
154 this->model_->all_mindir_tensors_.push_back(tensor_wrap);
155 this->tensor_index_map_[graph_proto.input(i).name()] = this->tensor_count_;
156 if (sub_graph != nullptr) {
157 sub_graph->input_indices_.push_back(this->tensor_count_);
158 sub_graph->tensor_indices_.push_back(this->tensor_count_);
159 }
160 if (is_main_graph) {
161 this->model_->graph_.input_indices_.push_back(this->tensor_count_);
162 }
163 this->tensor_count_++;
164 }
165 for (int i = 0; i < graph_proto.output_size(); i++) {
166 const mind_ir::TensorProto &tensor_proto = graph_proto.output(i).tensor(0);
167 TensorProtoWrap tensor_wrap(graph_proto.output(i).name(), tensor_proto);
168 this->model_->all_mindir_tensors_.push_back(tensor_wrap);
169 this->tensor_index_map_[graph_proto.output(i).name()] = this->tensor_count_;
170 if (sub_graph != nullptr) {
171 sub_graph->output_indices_.push_back(this->tensor_count_);
172 sub_graph->tensor_indices_.push_back(this->tensor_count_);
173 }
174 if (is_main_graph) {
175 this->model_->graph_.output_indices_.push_back(this->tensor_count_);
176 }
177 this->tensor_count_++;
178 }
179 for (int i = 0; i < graph_proto.parameter_size(); i++) {
180 const mind_ir::TensorProto &tensor_proto = graph_proto.parameter(i);
181 TensorProtoWrap tensor_wrap(tensor_proto.name(), tensor_proto);
182 this->model_->all_mindir_tensors_.push_back(tensor_wrap);
183 this->tensor_index_map_[tensor_proto.name()] = this->tensor_count_;
184 if (sub_graph != nullptr) {
185 sub_graph->tensor_indices_.push_back(this->tensor_count_);
186 }
187 this->tensor_count_++;
188 }
189 return true;
190 }
191
ConvertNodes(const mind_ir::GraphProto & graph_proto,LiteGraph::SubGraph * sub_graph,bool is_main_graph)192 bool MindirModelLoader::ConvertNodes(const mind_ir::GraphProto &graph_proto, LiteGraph::SubGraph *sub_graph,
193 bool is_main_graph) {
194 for (int i = 0; i < graph_proto.node_size(); i++) {
195 auto node_proto = graph_proto.node(i);
196 if (node_proto.op_type() == kNodeTypeConstant) {
197 // Constant node, convert to tensor
198 for (int j = 0; j < node_proto.attribute_size(); j++) {
199 auto attribute_proto = node_proto.attribute(j);
200 if (attribute_proto.type() == mind_ir::AttributeProto_AttributeType_TENSORS) {
201 const mind_ir::TensorProto &tensor_proto = attribute_proto.tensors(0);
202 TensorProtoWrap tensor_wrap(node_proto.name(), tensor_proto);
203 this->model_->all_mindir_tensors_.push_back(tensor_wrap);
204 this->tensor_index_map_[node_proto.name()] = this->tensor_count_;
205 if (sub_graph != nullptr) {
206 sub_graph->tensor_indices_.push_back(this->tensor_count_);
207 }
208 this->tensor_count_++;
209 }
210 }
211 continue;
212 }
213 auto *node = new LiteGraph::Node();
214 if (node == nullptr) {
215 MS_LOG(ERROR) << "MindirModelLoader: Convert nodes failed, new node failed.";
216 return false;
217 }
218 node->name_ = node_proto.name();
219 node->base_operator_ = this->MakePrimitiveC(node_proto.op_type());
220 auto base_operator = std::reinterpret_pointer_cast<ops::BaseOperator>(node->base_operator_);
221 node->op_type_ = base_operator->GetPrim()->instance_name();
222
223 // solve input
224 for (int j = 0; j < node_proto.input_size(); j++) {
225 std::string input_name = node_proto.input(j);
226 auto it = this->tensor_index_map_.find(input_name);
227 if (it == this->tensor_index_map_.end()) {
228 MS_LOG(WARNING) << "MindirModelLoader: Convert nodes failed, cannot find input index with " << input_name;
229 continue;
230 }
231 node->input_indices_.push_back(it->second);
232 }
233
234 // solve output
235 for (int j = 0; j < node_proto.output_size(); j++) {
236 std::string output_name = node_proto.output(j);
237 auto it = this->tensor_index_map_.find(output_name);
238 if (it == this->tensor_index_map_.end()) {
239 MS_LOG(WARNING) << "MindirModelLoader: Convert nodes failed, cannot find output index with " << output_name;
240 continue;
241 }
242 node->output_indices_.push_back(it->second);
243 }
244
245 this->model_->graph_.all_nodes_.push_back(node);
246 if (sub_graph != nullptr) {
247 sub_graph->node_indices_.push_back(this->node_count_);
248 }
249 this->node_count_++;
250 }
251 return true;
252 }
253
MakePrimitiveC(const std::string & node_type)254 std::shared_ptr<void> MindirModelLoader::MakePrimitiveC(const std::string &node_type) {
255 const std::string kOperatorTypeFlag = std::string("REF::");
256 const size_t kOpTypeFlagSize = kOperatorTypeFlag.length();
257 if (node_type.size() > kOpTypeFlagSize && node_type.substr(0, kOpTypeFlagSize) == kOperatorTypeFlag) {
258 auto it = this->all_operators_.find(node_type.substr(kOpTypeFlagSize));
259 if (it == this->all_operators_.end()) {
260 MS_LOG(ERROR) << "MindirModelLoader: make primitiveC failed, can't find the primitive ref:" << node_type;
261 return nullptr;
262 }
263 return it->second;
264 }
265
266 // node_type is not ref: pointer
267 auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
268 if (op_primc_fns.find(node_type) != op_primc_fns.end()) {
269 // registered primitive
270 auto prim = (op_primc_fns[node_type]());
271 MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "MindirModelLoader: Make primitiveC failed, the prim is nullptr.");
272 prim->set_instance_name(node_type);
273 static auto operator_fns = ops::OperatorRegister::GetInstance().GetOperatorMap();
274 auto op_it = operator_fns.find(node_type);
275 if (op_it == operator_fns.end()) {
276 MS_LOG(WARNING) << "MindirModelLoader: Make primitiveC failed, unsupported op operator type: " << node_type;
277 return nullptr;
278 }
279 return op_it->second(prim);
280 } else {
281 // S_Prim_xxx or S_Prim_hyper_map[xxx] and custom node type, now not support
282 MS_LOG(ERROR) << "MindirModelLoader: make primitiveC failed, not support node type: " << node_type;
283 return nullptr;
284 }
285 }
286
MindirModelLoaderCreator()287 static std::shared_ptr<ModelLoader> MindirModelLoaderCreator() { return std::make_shared<MindirModelLoader>(); }
288
289 REG_MODEL_LOADER(kMindIR, MindirModelLoaderCreator);
290 } // namespace mindspore::infer::mindir
291