1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "load_mindir/anf_model_parser.h"
18 #include <climits>
19 #include <functional>
20 #include <map>
21 #include <memory>
22 #include <stack>
23 #include <string>
24 #include <vector>
25 #include <unordered_map>
26 #include <utility>
27 #include "ir/tensor.h"
28 #include "ir/param_info.h"
29 #include "ops/primitive_c.h"
30 #include "abstract/abstract_value.h"
31 #include "utils/log_adapter.h"
32 #include "utils/shape_utils.h"
33 #include "utils/check_convert_utils.h"
34
35 using std::string;
36
37 namespace mindspore {
38 std::map<std::string, tensor::TensorPtr> MSANFModelParser::load_tensor_map_;
39 static constexpr char kConstantValueNode[] = "Constant";
40 static constexpr char kCNodeShapeAttr[] = "shape";
41 static constexpr char kCNodeShape1Attr[] = "shape1";
42 static constexpr char kCNodeShape2Attr[] = "shape2";
43 static constexpr char kDoSignaturePrimitivePrefix[] = "S-Prim-";
44 static constexpr char kHyperMapPrefix[] = "hyper_map";
45
// Parsing strategy for an attribute, selected by the category tag extracted from its ref_attr_name.
enum ParseForm : int {
  FORM_PARSE_TYPE = 0,      // tag "type"
  FORM_PARSE_SCALAR = 1,    // tag "scalar"
  FORM_PARSE_TENSOR = 2,    // tag "tensor"
  FORM_PARSE_NONE = 3,      // tag "none"
  FORM_PARSE_MONAD = 4,     // tag "Monad"
  FORM_PARSE_UNDEFINE = 5,  // empty tag: no known category was found
};

// Category tag -> parsing strategy. Left non-const because callers index it with operator[].
static std::map<std::string, ParseForm> kParseTypeSwitchMap = {
  {"type", FORM_PARSE_TYPE},
  {"scalar", FORM_PARSE_SCALAR},
  {"tensor", FORM_PARSE_TENSOR},
  {"none", FORM_PARSE_NONE},
  {"Monad", FORM_PARSE_MONAD},
  {"", FORM_PARSE_UNDEFINE},
};
58
59 static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
60 {mind_ir::TensorProto_DataType_BOOL, kNumberTypeBool},
61 {mind_ir::TensorProto_DataType_INT8, kNumberTypeInt8},
62 {mind_ir::TensorProto_DataType_INT16, kNumberTypeInt16},
63 {mind_ir::TensorProto_DataType_INT32, kNumberTypeInt32},
64 {mind_ir::TensorProto_DataType_INT64, kNumberTypeInt64},
65 {mind_ir::TensorProto_DataType_UINT8, kNumberTypeUInt8},
66 {mind_ir::TensorProto_DataType_UINT16, kNumberTypeUInt16},
67 {mind_ir::TensorProto_DataType_UINT32, kNumberTypeUInt32},
68 {mind_ir::TensorProto_DataType_UINT64, kNumberTypeUInt64},
69 {mind_ir::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
70 {mind_ir::TensorProto_DataType_FLOAT, kNumberTypeFloat32},
71 {mind_ir::TensorProto_DataType_FLOAT64, kNumberTypeFloat64},
72 {mind_ir::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
73 {mind_ir::TensorProto_DataType_STRING, kObjectTypeString},
74 };
75
76 template <typename T, typename P>
ParserAttr(const std::string & str,const std::unordered_map<string,P> & kv)77 std::shared_ptr<T> ParserAttr(const std::string &str, const std::unordered_map<string, P> &kv) {
78 std::stack<std::string> rules;
79 std::stack<P> value;
80 int count = 0;
81 for (size_t i = 0; i < str.length(); i++) {
82 if (str[i] == '[') {
83 rules.push(std::string("["));
84 } else if (str[i] == ']') {
85 // rules
86 std::vector<P> vec;
87 while (rules.top() != "[") {
88 rules.pop();
89 vec.push_back(value.top());
90 value.pop();
91 }
92 // pop "["
93 rules.pop();
94 // make tuple for names
95 std::string res = "dummy";
96 // make tuple for values
97 reverse(vec.begin(), vec.end());
98 auto vt = std::make_shared<T>(vec);
99 if (rules.empty() && value.empty()) {
100 return vt;
101 }
102 rules.push(res);
103 value.push(vt);
104 } else if (str[i] == ',') {
105 continue;
106 } else {
107 count++;
108 if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
109 auto value_name = str.substr(static_cast<int>(i) - count + 1, count);
110 if (kv.find(value_name) == kv.end()) {
111 MS_LOG(ERROR) << "Node's attributes and shape do not match.";
112 return nullptr;
113 }
114 value.push(kv.at(value_name));
115 rules.push(value_name);
116 count = 0;
117 }
118 }
119 }
120 return {};
121 }
122
123 template <typename T>
ParserScalarAttrValue(const std::string & attr_name,const std::unordered_map<string,ValuePtr> & kv)124 std::shared_ptr<T> ParserScalarAttrValue(const std::string &attr_name, const std::unordered_map<string, ValuePtr> &kv) {
125 std::string str = attr_name;
126 auto replace = [&](const string &orgStr, const string &newStr) {
127 std::string::size_type pos(0);
128 while ((pos = str.find(orgStr)) != std::string::npos) {
129 str.replace(pos, orgStr.length(), newStr);
130 }
131 return str;
132 };
133 // remove "scalar:"
134 str = replace("scalar:", "");
135 // remove "Tuple"
136 str = replace("Tuple", "");
137 // remove "List"
138 str = replace("List", "");
139 auto result = ParserAttr<T>(str, kv);
140 return result;
141 }
142
ParserAttrShape(const std::string & attr_name,const std::unordered_map<string,abstract::AbstractBasePtr> & kv)143 std::shared_ptr<abstract::AbstractTuple> ParserAttrShape(
144 const std::string &attr_name, const std::unordered_map<string, abstract::AbstractBasePtr> &kv) {
145 std::string str = attr_name;
146 auto replace = [&](const string &orgStr, const string &newStr) {
147 std::string::size_type pos(0);
148 while ((pos = str.find(orgStr)) != std::string::npos) {
149 str.replace(pos, orgStr.length(), newStr);
150 }
151 return str;
152 };
153 // remove "scalar:"
154 str = replace("shape:", "");
155 // remove "Tuple"
156 str = replace("Tuple", "");
157 // remove "List"
158 str = replace("List", "");
159
160 auto result = ParserAttr<abstract::AbstractTuple>(str, kv);
161 return result;
162 }
163
// Returns the part of `name` that follows its first ':'.
// A name without ':' is returned unchanged; a name ending in its only ':' yields an empty string.
std::string ParseParameterName(const string &name) {
  const size_t colon = name.find(':');
  if (colon == string::npos) {
    return name;
  }
  return name.substr(colon + 1);
}
172
// Returns the text between the first and the last ':' of `name`.
// When `name` contains fewer than two ':' characters it is returned unchanged.
std::string ParseCNodeName(const string &name) {
  const size_t first_colon = name.find(':');
  if (first_colon == string::npos) {
    return name;
  }
  const size_t last_colon = name.rfind(':');
  if (last_colon == first_colon) {
    return name;
  }
  return name.substr(first_colon + 1, last_colon - first_colon - 1);
}
182
183 #define PARSE_MINDIR_ATTR_IN_INT_FORM(type, valuetype) \
184 ValuePtr ParseAttrInScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto, int index) { \
185 auto value = static_cast<valuetype>(attr_proto.ints(index)); \
186 return MakeValue<valuetype>(value); \
187 } \
188 ValuePtr ParseAttrInSingleScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto) { \
189 auto value = static_cast<valuetype>(attr_proto.i()); \
190 return MakeValue<valuetype>(value); \
191 }
192
193 #define PARSE_MINDIR_ATTR_IN_SCALAR_FORM(type, valuetype) \
194 ValuePtr ParseAttrInScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto, int index) { \
195 auto value = static_cast<valuetype>(attr_proto.type##s(index)); \
196 return MakeValue<valuetype>(value); \
197 }
198
PARSE_MINDIR_ATTR_IN_INT_FORM(int8_t,int8_t)199 PARSE_MINDIR_ATTR_IN_INT_FORM(int8_t, int8_t)
200 PARSE_MINDIR_ATTR_IN_INT_FORM(int16_t, int16_t)
201 PARSE_MINDIR_ATTR_IN_INT_FORM(int32_t, int32_t)
202 PARSE_MINDIR_ATTR_IN_INT_FORM(int64_t, int64_t)
203 PARSE_MINDIR_ATTR_IN_INT_FORM(uint8_t, uint8_t)
204 PARSE_MINDIR_ATTR_IN_INT_FORM(uint16_t, uint16_t)
205 PARSE_MINDIR_ATTR_IN_INT_FORM(uint32_t, uint32_t)
206 PARSE_MINDIR_ATTR_IN_INT_FORM(uint64_t, uint64_t)
207 PARSE_MINDIR_ATTR_IN_INT_FORM(int32_t, bool)
208
209 PARSE_MINDIR_ATTR_IN_SCALAR_FORM(double, double)
210 PARSE_MINDIR_ATTR_IN_SCALAR_FORM(float, float)
211 PARSE_MINDIR_ATTR_IN_SCALAR_FORM(string, string)
212
213 ValuePtr ParseAttrInSingleScalar_string_string(const mind_ir::AttributeProto &attr_proto) {
214 auto value = static_cast<string>(attr_proto.s());
215 return MakeValue<string>(value);
216 }
217
ParseAttrInSingleScalar_float_float(const mind_ir::AttributeProto & attr_proto)218 ValuePtr ParseAttrInSingleScalar_float_float(const mind_ir::AttributeProto &attr_proto) {
219 auto value = static_cast<float>(attr_proto.f());
220 return MakeValue<float>(value);
221 }
222
ParseAttrInSingleScalar_double_double(const mind_ir::AttributeProto & attr_proto)223 ValuePtr ParseAttrInSingleScalar_double_double(const mind_ir::AttributeProto &attr_proto) {
224 auto value = static_cast<double>(attr_proto.d());
225 return MakeValue<double>(value);
226 }
227
BuildTensorInfoForFuncGraph(const mind_ir::TensorProto & tensor_proto)228 tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto) {
229 ShapeVector shape;
230 for (int i = 0; i < tensor_proto.dims_size(); ++i) {
231 shape.push_back(tensor_proto.dims(i));
232 }
233
234 if (!tensor_proto.has_data_type()) {
235 MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed";
236 MS_LOG(EXCEPTION) << "mind_ir TensorProto has no data_type.";
237 }
238 if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) {
239 MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed";
240 MS_LOG(EXCEPTION) << "mind_ir TensorProto data_type: " << tensor_proto.data_type() << " is not support yet!";
241 }
242
243 tensor::TensorPtr tensor_info =
244 std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[tensor_proto.data_type()], shape);
245 return tensor_info;
246 }
247
BuildParameterForFuncGraph(const ParameterPtr & node,const mind_ir::TensorProto & parameter_proto)248 bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
249 const mind_ir::TensorProto ¶meter_proto) {
250 MS_EXCEPTION_IF_NULL(node);
251
252 if (!parameter_proto.has_name()) {
253 MS_LOG(ERROR) << "mind_ir TensorProto has no name!";
254 return false;
255 }
256 string debug_info_name = ParseParameterName(parameter_proto.name());
257 auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
258 node->set_debug_info(debug_info_ptr);
259 node->set_name(debug_info_name);
260
261 tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto);
262 MS_EXCEPTION_IF_NULL(tensor_info);
263 MS_LOG(DEBUG) << "Load parameter name: " << debug_info_name;
264 if (!IsIncLoad() || load_tensor_map_.find(debug_info_name) == load_tensor_map_.end()) {
265 load_tensor_map_[debug_info_name] = tensor_info;
266 } else {
267 MS_LOG(DEBUG) << "Parameter: " << debug_info_name << " has been already loaded, use it again.";
268 tensor::TensorPtr load_tensor_info = load_tensor_map_[debug_info_name];
269 auto tensor_abstract = load_tensor_info->ToAbstract();
270 MS_EXCEPTION_IF_NULL(tensor_abstract);
271 node->set_abstract(tensor_abstract);
272 node->set_default_param(load_tensor_info);
273 anfnode_build_map_[parameter_proto.name()] = node;
274 return true;
275 }
276 ParamInfoPtr param_info = std::make_shared<ParamInfo>();
277 param_info->set_name(debug_info_name);
278 tensor_info->set_param_info(param_info);
279
280 auto tensor_abstract = tensor_info->ToAbstract();
281 MS_EXCEPTION_IF_NULL(tensor_abstract);
282 node->set_abstract(tensor_abstract);
283
284 std::string initial_data = parameter_proto.raw_data();
285 auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
286 MS_EXCEPTION_IF_NULL(tensor_data_buf);
287 auto ret = memcpy_s(tensor_data_buf, static_cast<size_t>(tensor_info->data().nbytes()), initial_data.data(),
288 initial_data.size());
289 if (ret != 0) {
290 MS_LOG(ERROR) << "Build parameter occur memcpy_s error.";
291 return false;
292 }
293
294 node->set_default_param(tensor_info);
295
296 anfnode_build_map_[parameter_proto.name()] = node;
297 return true;
298 }
299
BuildInputForFuncGraph(const ParameterPtr & node,const mind_ir::ValueInfoProto & value_proto)300 bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto) {
301 MS_EXCEPTION_IF_NULL(node);
302
303 if (!value_proto.has_name()) {
304 MS_LOG(ERROR) << "mind_ir ValueInfoProto has no name!";
305 return false;
306 }
307 string debug_info_name = ParseParameterName(value_proto.name());
308 auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
309 node->set_debug_info(debug_info_ptr);
310 node->set_name(debug_info_name);
311
312 // Set abstract of the parameter
313 if (value_proto.tensor_size() > 0) {
314 const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
315 tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
316 MS_EXCEPTION_IF_NULL(tensor_info);
317 auto tensor_abstract = tensor_info->ToAbstract();
318 node->set_abstract(tensor_abstract);
319 } else if (value_proto.has_denotation()) {
320 MS_LOG(DEBUG) << "Not tensor. parameter type: " << value_proto.denotation();
321 }
322 anfnode_build_map_[value_proto.name()] = node;
323 return true;
324 }
325
ImportParametersForGraph(const FuncGraphPtr & outputFuncGraph,const mind_ir::GraphProto & importProto)326 bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
327 const mind_ir::GraphProto &importProto) {
328 MS_EXCEPTION_IF_NULL(outputFuncGraph);
329 MS_LOG(INFO) << "All inputs size is: " << importProto.input_size();
330 for (int i = 0; i < importProto.input_size(); ++i) {
331 const mind_ir::ValueInfoProto &input_proto = importProto.input(i);
332 if (!BuildInputForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) {
333 MS_LOG(ERROR) << "Build input for funcgraph fail at index: " << i;
334 return false;
335 }
336 }
337
338 MS_LOG(INFO) << "All Parameters size is: " << importProto.parameter_size();
339 for (int i = 0; i < importProto.parameter_size(); ++i) {
340 const mind_ir::TensorProto ¶meter_proto = importProto.parameter(i);
341 if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), parameter_proto)) {
342 MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
343 return false;
344 }
345 }
346 return true;
347 }
348
ObtainCNodeAttrInTypeForm(const PrimitivePtr & prim,const mind_ir::AttributeProto & attr_proto)349 bool MSANFModelParser::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
350 MS_EXCEPTION_IF_NULL(prim);
351 const int attr_tensor_type = attr_proto.tensors(0).data_type();
352 if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
353 MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type;
354 return false;
355 }
356 prim->AddAttr(attr_proto.name(), TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
357 return true;
358 }
359
ParseAttrInScalarForm(const mind_ir::AttributeProto & attr_proto,int index)360 ValuePtr MSANFModelParser::ParseAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, int index) {
361 const int attr_type = attr_proto.type();
362 switch (attr_type) {
363 case mind_ir::AttributeProto_AttributeType_STRING: {
364 return ParseAttrInScalar_string_string(attr_proto, index);
365 }
366 case mind_ir::AttributeProto_AttributeType_INT8: {
367 return ParseAttrInScalar_int8_t_int8_t(attr_proto, index);
368 }
369 case mind_ir::AttributeProto_AttributeType_INT16: {
370 return ParseAttrInScalar_int16_t_int16_t(attr_proto, index);
371 }
372 case mind_ir::AttributeProto_AttributeType_INT32: {
373 return ParseAttrInScalar_int32_t_int32_t(attr_proto, index);
374 }
375 case mind_ir::AttributeProto_AttributeType_INT64: {
376 return ParseAttrInScalar_int64_t_int64_t(attr_proto, index);
377 }
378 case mind_ir::AttributeProto_AttributeType_UINT8: {
379 return ParseAttrInScalar_uint8_t_uint8_t(attr_proto, index);
380 }
381 case mind_ir::AttributeProto_AttributeType_UINT16: {
382 return ParseAttrInScalar_uint16_t_uint16_t(attr_proto, index);
383 }
384 case mind_ir::AttributeProto_AttributeType_UINT32: {
385 return ParseAttrInScalar_uint32_t_uint32_t(attr_proto, index);
386 }
387 case mind_ir::AttributeProto_AttributeType_UINT64: {
388 return ParseAttrInScalar_uint64_t_uint64_t(attr_proto, index);
389 }
390 case mind_ir::AttributeProto_AttributeType_FLOAT: {
391 return ParseAttrInScalar_float_float(attr_proto, index);
392 }
393 case mind_ir::AttributeProto_AttributeType_DOUBLE: {
394 return ParseAttrInScalar_double_double(attr_proto, index);
395 }
396 case mind_ir::AttributeProto_AttributeType_BOOL: {
397 return ParseAttrInScalar_int32_t_bool(attr_proto, index);
398 }
399 case mind_ir::AttributeProto_AttributeType_TENSORS: {
400 const int attr_tensor_type = attr_proto.tensors(index).data_type();
401 if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
402 MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type;
403 return {};
404 }
405 return TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]);
406 }
407 default:
408 MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_type;
409 return {};
410 }
411 return {};
412 }
413
ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto & attr_proto,std::unordered_map<std::string,ValuePtr> * multi_value_map)414 void MSANFModelParser::ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto &attr_proto,
415 std::unordered_map<std::string, ValuePtr> *multi_value_map) {
416 string name;
417 auto func = [&name, &multi_value_map, this](const mind_ir::AttributeProto &attr_proto, int i) -> void {
418 auto res = this->ParseAttrInScalarForm(attr_proto, i);
419 name = "value" + std::to_string(i + 1);
420 (void)multi_value_map->emplace(name, res);
421 };
422 for (int i = 0; i < attr_proto.ints_size(); i++) {
423 func(attr_proto, i);
424 }
425 for (int i = 0; i < attr_proto.doubles_size(); i++) {
426 func(attr_proto, i);
427 }
428 for (int i = 0; i < attr_proto.floats_size(); i++) {
429 func(attr_proto, i);
430 }
431 for (int i = 0; i < attr_proto.strings_size(); i++) {
432 func(attr_proto, i);
433 }
434 for (int i = 0; i < attr_proto.tensors_size(); i++) {
435 func(attr_proto, i);
436 }
437 }
438
ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto & attr_proto) const439 ValuePtr MSANFModelParser::ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto) const {
440 const int attr_type = attr_proto.type();
441 switch (attr_type) {
442 case mind_ir::AttributeProto_AttributeType_STRING: {
443 return ParseAttrInSingleScalar_string_string(attr_proto);
444 }
445 case mind_ir::AttributeProto_AttributeType_INT8: {
446 return ParseAttrInSingleScalar_int8_t_int8_t(attr_proto);
447 }
448 case mind_ir::AttributeProto_AttributeType_INT16: {
449 return ParseAttrInSingleScalar_int16_t_int16_t(attr_proto);
450 }
451 case mind_ir::AttributeProto_AttributeType_INT32: {
452 return ParseAttrInSingleScalar_int32_t_int32_t(attr_proto);
453 }
454 case mind_ir::AttributeProto_AttributeType_INT64: {
455 return ParseAttrInSingleScalar_int64_t_int64_t(attr_proto);
456 }
457 case mind_ir::AttributeProto_AttributeType_UINT8: {
458 return ParseAttrInSingleScalar_uint8_t_uint8_t(attr_proto);
459 }
460 case mind_ir::AttributeProto_AttributeType_UINT16: {
461 return ParseAttrInSingleScalar_uint16_t_uint16_t(attr_proto);
462 }
463 case mind_ir::AttributeProto_AttributeType_UINT32: {
464 return ParseAttrInSingleScalar_uint32_t_uint32_t(attr_proto);
465 }
466 case mind_ir::AttributeProto_AttributeType_UINT64: {
467 return ParseAttrInSingleScalar_uint64_t_uint64_t(attr_proto);
468 }
469 case mind_ir::AttributeProto_AttributeType_FLOAT: {
470 return ParseAttrInSingleScalar_float_float(attr_proto);
471 }
472 case mind_ir::AttributeProto_AttributeType_DOUBLE: {
473 return ParseAttrInSingleScalar_double_double(attr_proto);
474 }
475 case mind_ir::AttributeProto_AttributeType_BOOL: {
476 return ParseAttrInSingleScalar_int32_t_bool(attr_proto);
477 }
478 default:
479 MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_type;
480 return {};
481 }
482 return {};
483 }
484
ObtainCNodeAttrInTensorForm(const PrimitivePtr & prim,const mind_ir::AttributeProto & attr_proto)485 bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim,
486 const mind_ir::AttributeProto &attr_proto) {
487 MS_EXCEPTION_IF_NULL(prim);
488 const mind_ir::TensorProto attr_tensor = attr_proto.tensors(0);
489 const int attr_tensor_type = attr_tensor.data_type();
490 ShapeVector shape;
491 for (int i = 0; i < attr_tensor.dims_size(); ++i) {
492 shape.push_back(attr_tensor.dims(i));
493 }
494 tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
495 MS_EXCEPTION_IF_NULL(tensor_info);
496 const std::string &tensor_buf = attr_tensor.raw_data();
497 auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
498 auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
499 if (ret != 0) {
500 MS_LOG(ERROR) << "Obtain CNode in TensorForm occur memcpy_s error.";
501 return false;
502 }
503 prim->AddAttr(attr_proto.name(), MakeValue(tensor_info));
504 return true;
505 }
506
// Finds the category tag of `ref_attr_name`.
// "scalar:", "type:" and "tensor:" are searched in that order; for the first one found, `*pos`
// receives its offset and the tag is returned without the trailing ':'. When none is present an
// empty string is returned and `*pos` is std::string::npos.
string GetTypeString(const std::string &ref_attr_name, size_t *pos) {
  const std::string kTags[] = {"scalar:", "type:", "tensor:"};
  for (const std::string &tag : kTags) {
    *pos = ref_attr_name.find(tag);
    if (*pos != std::string::npos) {
      return ref_attr_name.substr(*pos, tag.length() - 1);
    }
  }
  return "";
}
517
GetAttrValueForCNode(const PrimitivePtr & prim,const mind_ir::AttributeProto & attr_proto)518 bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
519 MS_EXCEPTION_IF_NULL(prim);
520 const std::string &attr_name = attr_proto.name();
521 if (!attr_proto.has_ref_attr_name()) {
522 MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
523 return false;
524 }
525 const std::string &ref_attr_name = attr_proto.ref_attr_name();
526
527 std::size_t pos(0);
528 string type = GetTypeString(ref_attr_name, &pos);
529 std::unordered_map<std::string, ValuePtr> multi_value_map;
530 switch (kParseTypeSwitchMap[type]) {
531 case FORM_PARSE_TYPE: {
532 ObtainCNodeAttrInTypeForm(prim, attr_proto);
533 break;
534 }
535 case FORM_PARSE_SCALAR: {
536 std::size_t value_pos(0);
537 if ((value_pos = ref_attr_name.find("value0")) != std::string::npos) {
538 ValuePtr res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
539 const std::string &op_type = prim->name();
540 if (!IsLite()) {
541 CheckAndConvertUtils::ConvertAttrValueInLoad(op_type, attr_name, &res);
542 }
543 if (op_type == "HistogramFixedWidth" && attr_name == "dtype" && res->isa<StringImm>()) {
544 auto str_dtype = GetValue<std::string>(res);
545 if (str_dtype == "int32") {
546 const int64_t attr_value = 3;
547 (void)prim->AddAttr(attr_name, MakeValue<int64_t>(attr_value));
548 break;
549 }
550 MS_EXCEPTION(NotSupportError)
551 << "The primtive[HistogramFixedWidth] not supported only support attribute[dtype] is 'int32',but got"
552 << res->ToString();
553 }
554 prim->AddAttr(attr_name, res);
555 break;
556 }
557 ObtainCNodeAttrInScalarForm(attr_proto, &multi_value_map);
558 break;
559 }
560 case FORM_PARSE_TENSOR: {
561 ObtainCNodeAttrInTensorForm(prim, attr_proto);
562 break;
563 }
564 default:
565 MS_LOG(ERROR) << "parse attr type don't support the ref_attr_name: " << ref_attr_name;
566 return false;
567 }
568
569 if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR && multi_value_map.size() != 0) {
570 if ((pos = ref_attr_name.find("Tuple")) != std::string::npos) {
571 auto value_tuple_ptr = ParserScalarAttrValue<ValueTuple>(ref_attr_name, multi_value_map);
572 prim->AddAttr(attr_name, value_tuple_ptr);
573 } else {
574 auto value_list_ptr = ParserScalarAttrValue<ValueList>(ref_attr_name, multi_value_map);
575 prim->AddAttr(attr_name, value_list_ptr);
576 }
577 }
578 return true;
579 }
580
ObtainValueNodeInTensorForm(const std::string & value_node_name,const mind_ir::TensorProto & attr_tensor)581 bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name,
582 const mind_ir::TensorProto &attr_tensor) {
583 const int attr_tensor_type = attr_tensor.data_type();
584 ShapeVector shape;
585 for (int i = 0; i < attr_tensor.dims_size(); ++i) {
586 shape.push_back(attr_tensor.dims(i));
587 }
588 tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
589 MS_EXCEPTION_IF_NULL(tensor_info);
590 const std::string &tensor_buf = attr_tensor.raw_data();
591 auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
592 auto ret =
593 memcpy_s(tensor_data_buf, static_cast<size_t>(tensor_info->data().nbytes()), tensor_buf.data(), tensor_buf.size());
594 if (ret != 0) {
595 MS_LOG(ERROR) << "Obtain ValueNode in TensorForm occur memcpy_s error.";
596 return false;
597 }
598
599 auto new_value_node = NewValueNode(MakeValue(tensor_info));
600 MS_EXCEPTION_IF_NULL(new_value_node);
601 auto tensor_abstract = tensor_info->ToAbstract();
602 MS_EXCEPTION_IF_NULL(tensor_abstract);
603 new_value_node->set_abstract(tensor_abstract);
604 anfnode_build_map_[value_node_name] = new_value_node;
605 return true;
606 }
607
ObtainValueNodeInTupleTensorForm(const std::string & value_node_name,const mind_ir::AttributeProto & attr_proto)608 bool MSANFModelParser::ObtainValueNodeInTupleTensorForm(const std::string &value_node_name,
609 const mind_ir::AttributeProto &attr_proto) {
610 std::vector<tensor::TensorPtr> tensor_vec;
611 for (int i = 0; i < attr_proto.tensors_size(); ++i) {
612 mind_ir::TensorProto attr_tensor = attr_proto.tensors(i);
613 const int attr_tensor_type = attr_tensor.data_type();
614 ShapeVector shape;
615 for (int j = 0; j < attr_tensor.dims_size(); ++j) {
616 shape.push_back(attr_tensor.dims(j));
617 }
618 tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
619 const std::string &tensor_buf = attr_tensor.raw_data();
620 auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
621 auto ret = memcpy_s(tensor_data_buf, static_cast<size_t>(tensor_info->data().nbytes()), tensor_buf.data(),
622 tensor_buf.size());
623 if (ret != 0) {
624 MS_LOG(ERROR) << "Obtain ValueNode in TupleTensorForm occur memcpy_s error.";
625 return false;
626 }
627 tensor_vec.push_back(tensor_info);
628 }
629 auto new_value_node = NewValueNode(MakeValue(tensor_vec));
630 anfnode_build_map_[value_node_name] = new_value_node;
631 return true;
632 }
633
ObtainValueNodeInTypeForm(const std::string & value_node_name,const mind_ir::TensorProto & attr_tensor)634 bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_name,
635 const mind_ir::TensorProto &attr_tensor) {
636 const int attr_tensor_type = attr_tensor.data_type();
637 if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
638 MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
639 return false;
640 }
641 auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
642 abstract::AbstractTypePtr abs_type = std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>());
643 new_value_node->set_abstract(abs_type);
644 anfnode_build_map_[value_node_name] = new_value_node;
645 return true;
646 }
647
ObtainValueNodeInNoneForm(const std::string & value_node_name)648 bool MSANFModelParser::ObtainValueNodeInNoneForm(const std::string &value_node_name) {
649 auto new_value_node = NewValueNode(kNone);
650 MS_EXCEPTION_IF_NULL(new_value_node);
651 new_value_node->set_abstract(kNone->ToAbstract());
652 anfnode_build_map_[value_node_name] = new_value_node;
653 return true;
654 }
655
ObtainValueNodeInMonadForm(const std::string & value_node_name,const mind_ir::AttributeProto & attr_proto)656 bool MSANFModelParser::ObtainValueNodeInMonadForm(const std::string &value_node_name,
657 const mind_ir::AttributeProto &attr_proto) {
658 const std::string &ref_attr_name = attr_proto.ref_attr_name();
659 if (ref_attr_name.find("UMonad") != std::string::npos) {
660 auto monad_abs = kUMonad->ToAbstract();
661 auto new_value_node = NewValueNode(kUMonad);
662 MS_EXCEPTION_IF_NULL(new_value_node);
663 new_value_node->set_abstract(monad_abs);
664 anfnode_build_map_[value_node_name] = new_value_node;
665 } else if (ref_attr_name.find("IOMonad") != std::string::npos) {
666 auto monad_abs = kIOMonad->ToAbstract();
667 auto new_value_node = NewValueNode(kIOMonad);
668 MS_EXCEPTION_IF_NULL(new_value_node);
669 new_value_node->set_abstract(monad_abs);
670 anfnode_build_map_[value_node_name] = new_value_node;
671 } else {
672 return false;
673 }
674 return true;
675 }
676
namespace {
// Extracts the category tag from `ref_attr_name`.
// "scalar:", "type:", "tensor:" and "Monad:" are searched in that order and the first one found is
// returned without its trailing ':'. If none occurs, the exact string "none" is returned as is; any
// other input yields an empty string.
std::string GetTypeFromAttrName(const std::string &ref_attr_name) {
  const std::string kTags[] = {"scalar:", "type:", "tensor:", "Monad:"};
  for (const std::string &tag : kTags) {
    const std::size_t at = ref_attr_name.find(tag);
    if (at != std::string::npos) {
      return ref_attr_name.substr(at, tag.length() - 1);
    }
  }
  if (ref_attr_name == "none") {
    return ref_attr_name;
  }
  return "";
}
}  // namespace
695
GetAttrValueForValueNode(const std::string & value_node_name,const mind_ir::AttributeProto & attr_proto)696 bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_name,
697 const mind_ir::AttributeProto &attr_proto) {
698 if (!attr_proto.has_ref_attr_name()) {
699 MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
700 return false;
701 }
702 const std::string &ref_attr_name = attr_proto.ref_attr_name();
703 auto type = GetTypeFromAttrName(ref_attr_name);
704 ValueNodePtr new_value_node;
705 std::unordered_map<std::string, ValuePtr> multi_value_map;
706 switch (kParseTypeSwitchMap[type]) {
707 case FORM_PARSE_TYPE: {
708 ObtainValueNodeInTypeForm(value_node_name, attr_proto.tensors(0));
709 break;
710 }
711 case FORM_PARSE_SCALAR: {
712 std::size_t value_pos(0);
713 if ((value_pos = ref_attr_name.find("value0")) != std::string::npos) {
714 auto res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
715 new_value_node = NewValueNode(res);
716 new_value_node->set_abstract(res->ToAbstract());
717 anfnode_build_map_[value_node_name] = new_value_node;
718 break;
719 }
720 if ((value_pos = ref_attr_name.find("Tuple[]")) != std::string::npos) {
721 MS_LOG(INFO) << "Build Tuple() ValueNode for primitive.";
722 ValuePtr res = MakeValue(std::vector<ValuePtr>{});
723 new_value_node = NewValueNode(res);
724 new_value_node->set_abstract(res->ToAbstract());
725 anfnode_build_map_[value_node_name] = new_value_node;
726 break;
727 }
728 if ((value_pos = ref_attr_name.find("Tuple[value")) != std::string::npos && attr_proto.tensors_size() > 1) {
729 MS_LOG(INFO) << "Build TupleTensor ValueNode for primitive.";
730 if (!ObtainValueNodeInTupleTensorForm(value_node_name, attr_proto)) {
731 MS_LOG(ERROR) << "Obtain valuenode in tuple tensor Form failed. ";
732 return false;
733 }
734 break;
735 }
736 ObtainCNodeAttrInScalarForm(attr_proto, &multi_value_map);
737 break;
738 }
739 case FORM_PARSE_TENSOR: {
740 (void)ObtainValueNodeInTensorForm(value_node_name, attr_proto.tensors(0));
741 break;
742 }
743 case FORM_PARSE_NONE: {
744 (void)ObtainValueNodeInNoneForm(value_node_name);
745 break;
746 }
747 case FORM_PARSE_MONAD: {
748 (void)ObtainValueNodeInMonadForm(value_node_name, attr_proto);
749 break;
750 }
751 default:
752 MS_LOG(ERROR) << "parse attr type don't support the ref_attr_name: " << ref_attr_name;
753 return false;
754 }
755
756 if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR && multi_value_map.size() != 0) {
757 if (ref_attr_name.find("Tuple") != std::string::npos) {
758 auto value_tuple_ptr = ParserScalarAttrValue<ValueTuple>(ref_attr_name, multi_value_map);
759 new_value_node = NewValueNode(value_tuple_ptr);
760 new_value_node->set_abstract(value_tuple_ptr->ToAbstract());
761 } else {
762 auto value_list_ptr = ParserScalarAttrValue<ValueList>(ref_attr_name, multi_value_map);
763 new_value_node = NewValueNode(value_list_ptr);
764 new_value_node->set_abstract(value_list_ptr->ToAbstract());
765 }
766 anfnode_build_map_[value_node_name] = new_value_node;
767 }
768 return true;
769 }
770
BuildValueNodeForFuncGraph(const mind_ir::NodeProto & node_proto)771 bool MSANFModelParser::BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto) {
772 const std::string &value_node_name = node_proto.output(0);
773 const mind_ir::AttributeProto &attr_proto = node_proto.attribute(0);
774 if (!attr_proto.has_ref_attr_name()) {
775 MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name";
776 return false;
777 }
778 return GetAttrValueForValueNode(value_node_name, attr_proto);
779 }
780
GetAbstractForCNode(const mind_ir::AttributeProto & attr_proto)781 std::unordered_map<std::string, abstract::AbstractBasePtr> MSANFModelParser::GetAbstractForCNode(
782 const mind_ir::AttributeProto &attr_proto) {
783 std::unordered_map<std::string, abstract::AbstractBasePtr> kv;
784 for (int i = 0; i < attr_proto.tensors_size(); ++i) {
785 ShapeVector shape_vec;
786 const mind_ir::TensorProto &attr_tensor = attr_proto.tensors(i);
787 for (int j = 0; j < attr_tensor.dims_size(); ++j) {
788 shape_vec.push_back(attr_tensor.dims(j));
789 }
790 tensor::TensorPtr tensor_info =
791 std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec);
792 MS_EXCEPTION_IF_NULL(tensor_info);
793 auto abstract = tensor_info->ToAbstract();
794 MS_EXCEPTION_IF_NULL(abstract);
795 (void)kv.emplace(attr_tensor.name(), abstract);
796 }
797 return kv;
798 }
799
800 // S-Prim-xxx or S-Prim-hyper_map[xxx] -> xxx
GetDoSignaturePrimitiveName(const std::string & node_type)801 static std::string GetDoSignaturePrimitiveName(const std::string &node_type) {
802 // Remove `S-Prim-` prefix.
803 auto prim_name = node_type.substr(strlen(kDoSignaturePrimitivePrefix));
804 if (prim_name.compare(0, strlen(kHyperMapPrefix), kHyperMapPrefix) != 0) {
805 return prim_name;
806 }
807 // hyper_map[xxx] -> xxx
808 constexpr auto offset = 2;
809 auto op_name = prim_name.substr(strlen(kHyperMapPrefix) + 1, prim_name.length() - strlen(kHyperMapPrefix) - offset);
810 return op_name;
811 }
812
BuildOperatorNode(const mind_ir::NodeProto & node_proto)813 AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_proto) {
814 const std::string kOperatorTypeFlag = std::string("REF::");
815 const size_t kOpTypeFlagSize = kOperatorTypeFlag.length();
816 const std::string &node_type = node_proto.op_type();
817 MS_LOG(DEBUG) << "Process Operator :" << node_type;
818 // Operator maybe CNode,FuncGraph or Parameter.
819
820 if (node_type.size() > kOpTypeFlagSize && node_type.substr(0, kOpTypeFlagSize) == kOperatorTypeFlag) {
821 auto anfNode = GetAnfNode(node_type.substr(kOpTypeFlagSize));
822 if (anfNode == nullptr) {
823 MS_LOG(EXCEPTION) << "Can't find the ref:" << node_type;
824 }
825 return anfNode;
826 }
827
828 // Operator is primitive.
829 std::shared_ptr<Primitive> prim;
830 auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
831 if (op_primc_fns.find(node_type) != op_primc_fns.end()) {
832 prim = op_primc_fns[node_type]();
833 } else {
834 if (node_type.compare(0, strlen(kDoSignaturePrimitivePrefix), kDoSignaturePrimitivePrefix) == 0) {
835 auto op_name = GetDoSignaturePrimitiveName(node_type);
836 prim = std::make_shared<prim::DoSignaturePrimitive>(op_name, std::make_shared<Primitive>(op_name));
837 MS_EXCEPTION_IF_NULL(prim);
838 prim->set_instance_name(op_name);
839 } else {
840 MS_LOG(DEBUG) << "Special node_type: " << node_type;
841 prim = std::make_shared<Primitive>(node_type);
842 MS_EXCEPTION_IF_NULL(prim);
843 prim->set_instance_name(node_type);
844 }
845 }
846 MS_EXCEPTION_IF_NULL(prim);
847 for (int i = 0; i < node_proto.attribute_size(); ++i) {
848 const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
849 // CNode abstract
850 if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
851 continue;
852 }
853 if (!GetAttrValueForCNode(prim, attr_proto)) {
854 MS_LOG(EXCEPTION) << "Parser prim: " << node_type << " attributes error : " << attr_proto.DebugString();
855 }
856 }
857 prim->set_attr("is_load", MakeValue(true));
858 return std::make_shared<ValueNode>(prim);
859 }
860
861 // Set CNode abstract.
SetCNodeAbastract(const mind_ir::NodeProto & node_proto,CNodePtr cnode_ptr)862 void MSANFModelParser::SetCNodeAbastract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr) {
863 const std::string &node_type = node_proto.op_type();
864 // Handle control flow operator.
865 auto operatorPtr = cnode_ptr->input(0);
866 // Set abstract of switch(c,f,t),switchLayer(c,tup) and
867 // partial(func,args) to null
868 auto prim = GetValueNode<PrimitivePtr>(operatorPtr);
869 if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim) ||
870 IsPrimitiveEquals(prim::kPrimPartial, prim)) {
871 cnode_ptr->set_abstract(nullptr);
872 return;
873 }
874
875 // If the operator is not a primitive, the abstract will been set to null.
876 // Because there are not some operators in front end, the abstract of primitive should be reserved.
877 if (prim == nullptr) {
878 cnode_ptr->set_abstract(nullptr);
879 return;
880 }
881
882 std::unordered_map<std::string, abstract::AbstractBasePtr> kv;
883 string shape_ref_attr_name;
884
885 for (int i = 0; i < node_proto.attribute_size(); ++i) {
886 const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
887 if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
888 shape_ref_attr_name = attr_proto.ref_attr_name();
889 kv = GetAbstractForCNode(attr_proto);
890 break;
891 }
892 }
893
894 // Because there is not context in unit test,
895 // abstract->broaden() is replaced by abstract->set_value(kAnyValue).
896 if (kv.size() == 0) {
897 if (node_type == "UpdateState") {
898 cnode_ptr->set_abstract(kUMonad->ToAbstract());
899 } else if (node_type == "Depend") {
900 cnode_ptr->set_abstract(kBool->ToAbstract());
901 } else {
902 AbstractBasePtrList elem;
903 for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
904 auto abs = cnode_ptr->input(index)->abstract();
905 if (abs != nullptr) {
906 if (abs->GetValueTrack() == nullptr) {
907 abs->set_value(kAnyValue);
908 }
909 elem.push_back(abs);
910 }
911 }
912 if (!elem.empty()) {
913 cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
914 }
915 }
916 } else if (kv.size() == 1) {
917 std::unordered_map<std::string, abstract::AbstractBasePtr>::iterator iter = kv.begin();
918 if (iter->second != nullptr) {
919 iter->second->set_value(kAnyValue);
920 cnode_ptr->set_abstract(iter->second);
921 }
922 } else {
923 auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
924 if (abstract == nullptr) {
925 cnode_ptr->set_abstract(nullptr);
926 MS_LOG(ERROR) << "Node's attribute is nullptr.";
927 } else {
928 abstract->set_value(kAnyValue);
929 cnode_ptr->set_abstract(abstract);
930 }
931 }
932 }
933
BuildCNodeForFuncGraph(const FuncGraphPtr & outputFuncGraph,const mind_ir::NodeProto & node_proto)934 CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
935 const mind_ir::NodeProto &node_proto) {
936 MS_EXCEPTION_IF_NULL(outputFuncGraph);
937 if (!node_proto.has_op_type()) {
938 MS_LOG(ERROR) << "Get CNode op_type failed!";
939 return nullptr;
940 }
941 const std::string &node_name = node_proto.output(0);
942 MS_LOG(DEBUG) << "Process CNode: " << node_name;
943 // Build inputs.
944 std::vector<AnfNodePtr> inputs;
945 inputs.push_back(BuildOperatorNode(node_proto));
946 for (int i = 0; i < node_proto.input_size(); ++i) {
947 auto anfNode = GetAnfNode(node_proto.input(i));
948 if (anfNode == nullptr) {
949 MS_LOG(ERROR) << node_name << " input " << i << node_proto.input(i) << "can't find in nodes have parsed";
950 return nullptr;
951 }
952 inputs.push_back(anfNode);
953 }
954
955 CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
956 MS_EXCEPTION_IF_NULL(cnode_ptr);
957 SetCNodeAbastract(node_proto, cnode_ptr);
958
959 const std::string &fullname_with_scope = node_proto.domain();
960 string debug_info_name = ParseCNodeName(node_name);
961 auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
962 cnode_ptr->set_debug_info(debug_info_ptr);
963 cnode_ptr->set_fullname_with_scope(fullname_with_scope);
964 cnode_ptr->set_load_flag(true);
965 if (anfnode_build_map_.count(node_name) > 0) {
966 MS_LOG(EXCEPTION) << "Duplicate CNode name: " << node_name;
967 }
968 anfnode_build_map_[node_name] = cnode_ptr;
969 return cnode_ptr;
970 }
971
BuildReturnForFuncGraph(const FuncGraphPtr & outputFuncGraph,const mind_ir::GraphProto & importProto)972 bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
973 const mind_ir::GraphProto &importProto) {
974 MS_EXCEPTION_IF_NULL(outputFuncGraph);
975 std::vector<AnfNodePtr> inputs;
976 if (importProto.output_size() > 1) {
977 inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
978 AbstractBasePtrList elem;
979 for (int out_size = 0; out_size < importProto.output_size(); ++out_size) {
980 const mind_ir::ValueInfoProto &output_node = importProto.output(out_size);
981 const std::string &out_tuple = output_node.name();
982 auto anfNode = GetAnfNode(out_tuple);
983 if (anfNode == nullptr) {
984 MS_LOG(ERROR) << "Miss return node: " << out_tuple;
985 return false;
986 }
987 inputs.push_back(anfNode);
988 elem.push_back(anfNode->abstract());
989 }
990 auto maketuple_ptr = outputFuncGraph->NewCNode(inputs);
991 MS_EXCEPTION_IF_NULL(maketuple_ptr);
992 maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
993 inputs.push_back(NewValueNode(prim::kPrimReturn));
994 inputs.push_back(maketuple_ptr);
995 auto return_node = outputFuncGraph->NewCNode(inputs);
996 MS_EXCEPTION_IF_NULL(return_node);
997 return_node->set_load_flag(true);
998 outputFuncGraph->set_return(return_node);
999 MS_LOG(DEBUG) << "Construct funcgraph finined, all success.";
1000 } else {
1001 inputs.clear();
1002 inputs.push_back(NewValueNode(prim::kPrimReturn));
1003 auto nodeName = importProto.output(0).name();
1004 auto anfNode = GetAnfNode(nodeName);
1005 if (anfNode == nullptr) {
1006 MS_LOG(ERROR) << "Miss return node: " << nodeName;
1007 return false;
1008 }
1009 inputs.push_back(anfNode);
1010 auto return_node = outputFuncGraph->NewCNode(inputs);
1011 MS_EXCEPTION_IF_NULL(return_node);
1012 return_node->set_load_flag(true);
1013 outputFuncGraph->set_return(return_node);
1014 MS_LOG(DEBUG) << "Construct funcgraph finined, all success!";
1015 }
1016 return true;
1017 }
1018
ImportNodesForGraph(const FuncGraphPtr & outputFuncGraph,const mind_ir::GraphProto & importProto)1019 bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
1020 const mind_ir::GraphProto &importProto) {
1021 MS_EXCEPTION_IF_NULL(outputFuncGraph);
1022 CNodePtr cnode_ptr = nullptr;
1023 for (int i = 0; i < importProto.node_size(); ++i) {
1024 const mind_ir::NodeProto &node_proto = importProto.node(i);
1025 const std::string &node_type = node_proto.op_type();
1026 if (node_type == kConstantValueNode) {
1027 if (!BuildValueNodeForFuncGraph(node_proto)) {
1028 MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: " << i;
1029 return false;
1030 }
1031 continue;
1032 }
1033 cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto);
1034 if (cnode_ptr == nullptr) {
1035 MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: " << i;
1036 return false;
1037 }
1038 }
1039
1040 return BuildReturnForFuncGraph(outputFuncGraph, importProto);
1041 }
1042
BuildFuncGraph(const FuncGraphPtr & outputFuncGraph,const mind_ir::GraphProto & importProto)1043 bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto) {
1044 MS_EXCEPTION_IF_NULL(outputFuncGraph);
1045 GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info();
1046 MS_EXCEPTION_IF_NULL(debug_info_ptr);
1047 if (importProto.has_name()) {
1048 debug_info_ptr->set_name(importProto.name());
1049 } else {
1050 MS_LOG(ERROR) << "FuncGraph under converting has not name!";
1051 }
1052 if (importProto.has_bprop_hash()) {
1053 outputFuncGraph->set_bprop_hash(importProto.bprop_hash());
1054 }
1055
1056 if (!ImportParametersForGraph(outputFuncGraph, importProto)) {
1057 MS_LOG(ERROR) << "import parameters for graph fail!";
1058 return false;
1059 }
1060 return ImportNodesForGraph(outputFuncGraph, importProto);
1061 }
1062
MSANFParseModelConfigureInfo(const mind_ir::ModelProto & model_proto)1063 bool MSANFModelParser::MSANFParseModelConfigureInfo(const mind_ir::ModelProto &model_proto) {
1064 if (!model_proto.has_producer_name()) {
1065 MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
1066 return false;
1067 }
1068 producer_name_ = model_proto.producer_name();
1069 MS_LOG(INFO) << "producer_name :" << producer_name_;
1070
1071 if (!model_proto.has_model_version()) {
1072 MS_LOG(ERROR) << "Parse model producer version from pb file failed!";
1073 return false;
1074 }
1075 model_version_ = model_proto.model_version();
1076 MS_LOG(INFO) << "producer_version : " << model_version_;
1077
1078 if (!model_proto.has_ir_version()) {
1079 MS_LOG(ERROR) << "Parse model version from pb file failed!";
1080 return false;
1081 }
1082 ir_version_ = model_proto.ir_version();
1083 MS_LOG(INFO) << "ir_version :" << ir_version_;
1084 return true;
1085 }
1086
Parse(const mind_ir::ModelProto & model_proto)1087 FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) {
1088 FuncGraphPtr dstGraph = std::make_shared<FuncGraph>();
1089 MS_EXCEPTION_IF_NULL(dstGraph);
1090 if (!MSANFParseModelConfigureInfo(model_proto)) {
1091 MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
1092 }
1093 const mind_ir::GraphProto &graphBuild = model_proto.graph();
1094
1095 // Forward declare FuncGraph name
1096 // Compatible with the previous proto.
1097 if (graphBuild.has_name()) {
1098 anfnode_build_map_[graphBuild.name()] = std::make_shared<ValueNode>(dstGraph);
1099 }
1100 for (int i = 0; i < model_proto.functions_size(); ++i) {
1101 FuncGraphPtr graph = std::make_shared<FuncGraph>();
1102 const auto &graph_proto = model_proto.functions(i);
1103 if (!graph_proto.has_name()) {
1104 MS_LOG(EXCEPTION) << "The function has not a name. Please export mindIR again. ";
1105 }
1106 if (anfnode_build_map_.count(graph_proto.name()) > 0) {
1107 MS_LOG(EXCEPTION) << "There is a duplication function graph name: " << graph_proto.name();
1108 }
1109 anfnode_build_map_[graph_proto.name()] = std::make_shared<ValueNode>(graph);
1110 }
1111
1112 // Parser the proto.
1113 if (!BuildFuncGraph(dstGraph, graphBuild)) {
1114 MS_LOG(ERROR) << "Build funcgraph failed!";
1115 return nullptr;
1116 }
1117 MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graphBuild.name();
1118 for (int i = 0; i < model_proto.functions_size(); ++i) {
1119 const auto &graph_proto = model_proto.functions(i);
1120 FuncGraphPtr graph = GetValueNode<FuncGraphPtr>(anfnode_build_map_[graph_proto.name()]);
1121 if (!BuildFuncGraph(graph, graph_proto)) {
1122 MS_LOG(ERROR) << "Build funcgraph failed!";
1123 return nullptr;
1124 }
1125 MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graph_proto.name();
1126 }
1127 // Release resource
1128 anfnode_build_map_.clear();
1129 return dstGraph;
1130 }
1131
GetAnfNode(const std::string & node_name)1132 AnfNodePtr MSANFModelParser::GetAnfNode(const std::string &node_name) {
1133 auto it = anfnode_build_map_.find(node_name);
1134 if (it == anfnode_build_map_.end()) {
1135 return nullptr;
1136 }
1137 FuncGraphPtr func_graph_ptr = GetValueNode<FuncGraphPtr>(it->second);
1138 if (func_graph_ptr) {
1139 return NewValueNode(func_graph_ptr);
1140 } else {
1141 return it->second;
1142 }
1143 }
1144 } // namespace mindspore
1145