1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "load_mindir/anf_model_parser.h"
18 #include <climits>
19 #include <functional>
20 #include <map>
21 #include <memory>
22 #include <stack>
23 #include <string>
24 #include <vector>
25 #include <unordered_map>
26 #include <utility>
27 #include "ir/tensor.h"
28 #include "ir/param_info.h"
29 #include "ops/primitive_c.h"
30 #include "abstract/abstract_value.h"
31 #include "utils/log_adapter.h"
32 #include "utils/shape_utils.h"
33 #include "utils/check_convert_utils.h"
34 
35 using std::string;
36 
37 namespace mindspore {
38 std::map<std::string, tensor::TensorPtr> MSANFModelParser::load_tensor_map_;
39 static constexpr char kConstantValueNode[] = "Constant";
40 static constexpr char kCNodeShapeAttr[] = "shape";
41 static constexpr char kCNodeShape1Attr[] = "shape1";
42 static constexpr char kCNodeShape2Attr[] = "shape2";
43 static constexpr char kDoSignaturePrimitivePrefix[] = "S-Prim-";
44 static constexpr char kHyperMapPrefix[] = "hyper_map";
45 
46 enum ParseForm : int {
47   FORM_PARSE_TYPE = 0,
48   FORM_PARSE_SCALAR = 1,
49   FORM_PARSE_TENSOR = 2,
50   FORM_PARSE_NONE = 3,
51   FORM_PARSE_MONAD = 4,
52   FORM_PARSE_UNDEFINE = 5,
53 };
54 
55 static std::map<std::string, ParseForm> kParseTypeSwitchMap{
56   {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR},
57   {"none", FORM_PARSE_NONE}, {"Monad", FORM_PARSE_MONAD},   {"", FORM_PARSE_UNDEFINE}};
58 
59 static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
60   {mind_ir::TensorProto_DataType_BOOL, kNumberTypeBool},
61   {mind_ir::TensorProto_DataType_INT8, kNumberTypeInt8},
62   {mind_ir::TensorProto_DataType_INT16, kNumberTypeInt16},
63   {mind_ir::TensorProto_DataType_INT32, kNumberTypeInt32},
64   {mind_ir::TensorProto_DataType_INT64, kNumberTypeInt64},
65   {mind_ir::TensorProto_DataType_UINT8, kNumberTypeUInt8},
66   {mind_ir::TensorProto_DataType_UINT16, kNumberTypeUInt16},
67   {mind_ir::TensorProto_DataType_UINT32, kNumberTypeUInt32},
68   {mind_ir::TensorProto_DataType_UINT64, kNumberTypeUInt64},
69   {mind_ir::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
70   {mind_ir::TensorProto_DataType_FLOAT, kNumberTypeFloat32},
71   {mind_ir::TensorProto_DataType_FLOAT64, kNumberTypeFloat64},
72   {mind_ir::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
73   {mind_ir::TensorProto_DataType_STRING, kObjectTypeString},
74 };
75 
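// Rebuilds a nested Tuple/List value from a flattened attribute string such as
// "[value1,value2]" (brackets may nest), looking each name up in the given name->value map.
// A stack of name tokens and a stack of values are unwound whenever ']' is met.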
76 template <typename T, typename P>
77 std::shared_ptr<T> ParserAttr(const std::string &str, const std::unordered_map<string, P> &kv) {
78   std::stack<std::string> rules;
79   std::stack<P> value;
80   int count = 0;
81   for (size_t i = 0; i < str.length(); i++) {
82     if (str[i] == '[') {
83       rules.push(std::string("["));
84     } else if (str[i] == ']') {
85       // Pop back to the matching '[' and collect the pushed values.
86       std::vector<P> vec;
87       while (rules.top() != "[") {
88         rules.pop();
89         vec.push_back(value.top());
90         value.pop();
91       }
92       // pop "["
93       rules.pop();
94       // make tuple for names
95       std::string res = "dummy";
96       // make tuple for values
97       reverse(vec.begin(), vec.end());
98       auto vt = std::make_shared<T>(vec);
99       if (rules.empty() && value.empty()) {
100         return vt;
101       }
102       rules.push(res);
103       value.push(vt);
104     } else if (str[i] == ',') {
105       continue;
106     } else {
107       count++;
108       if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
109         auto value_name = str.substr(static_cast<int>(i) - count + 1, count);
110         if (kv.find(value_name) == kv.end()) {
111           MS_LOG(ERROR) << "Node's attributes and shape do not match.";
112           return nullptr;
113         }
114         value.push(kv.at(value_name));
115         rules.push(value_name);
116         count = 0;
117       }
118     }
119   }
120   return {};
121 }
122 
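// Strips the "scalar:", "Tuple" and "List" markers from an attribute name such as
// "scalar:Tuple[value1,value2]" and parses the remaining bracket expression with ParserAttr.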
123 template <typename T>
124 std::shared_ptr<T> ParserScalarAttrValue(const std::string &attr_name, const std::unordered_map<string, ValuePtr> &kv) {
125   std::string str = attr_name;
126   auto replace = [&](const string &orgStr, const string &newStr) {
127     std::string::size_type pos(0);
128     while ((pos = str.find(orgStr)) != std::string::npos) {
129       str.replace(pos, orgStr.length(), newStr);
130     }
131     return str;
132   };
133   // remove "scalar:"
134   str = replace("scalar:", "");
135   // remove "Tuple"
136   str = replace("Tuple", "");
137   // remove "List"
138   str = replace("List", "");
139   auto result = ParserAttr<T>(str, kv);
140   return result;
141 }
142 
143 std::shared_ptr<abstract::AbstractTuple> ParserAttrShape(
144   const std::string &attr_name, const std::unordered_map<string, abstract::AbstractBasePtr> &kv) {
145   std::string str = attr_name;
146   auto replace = [&](const string &orgStr, const string &newStr) {
147     std::string::size_type pos(0);
148     while ((pos = str.find(orgStr)) != std::string::npos) {
149       str.replace(pos, orgStr.length(), newStr);
150     }
151     return str;
152   };
153   // remove "shape:"
154   str = replace("shape:", "");
155   // remove "Tuple"
156   str = replace("Tuple", "");
157   // remove "List"
158   str = replace("List", "");
159 
160   auto result = ParserAttr<abstract::AbstractTuple>(str, kv);
161   return result;
162 }
163 
164 std::string ParseParameterName(const string &name) {
165   string delimiter = ":";
166   size_t pos(0);
167   if ((pos = name.find(delimiter)) != string::npos) {
168     return name.substr(pos + 1, string::npos - (pos + 1));
169   }
170   return name;
171 }
172 
173 std::string ParseCNodeName(const string &name) {
174   string delimiter = ":";
175   size_t pos = name.find(delimiter);
176   size_t end_pos = name.find_last_of(delimiter);
177   if (pos != string::npos && end_pos != string::npos && pos != end_pos) {
178     return name.substr(pos + 1, end_pos - (pos + 1));
179   }
180   return name;
181 }
182 
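// The macros below generate ParseAttrInScalar_<type>_<valuetype> helpers that read one element
// of a repeated proto field (or the single-value field for the *_INT_FORM variant) and wrap it
// in a ValuePtr.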
183 #define PARSE_MINDIR_ATTR_IN_INT_FORM(type, valuetype)                                                    \
184   ValuePtr ParseAttrInScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto, int index) { \
185     auto value = static_cast<valuetype>(attr_proto.ints(index));                                          \
186     return MakeValue<valuetype>(value);                                                                   \
187   }                                                                                                       \
188   ValuePtr ParseAttrInSingleScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto) {      \
189     auto value = static_cast<valuetype>(attr_proto.i());                                                  \
190     return MakeValue<valuetype>(value);                                                                   \
191   }
192 
193 #define PARSE_MINDIR_ATTR_IN_SCALAR_FORM(type, valuetype)                                                 \
194   ValuePtr ParseAttrInScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto, int index) { \
195     auto value = static_cast<valuetype>(attr_proto.type##s(index));                                       \
196     return MakeValue<valuetype>(value);                                                                   \
197   }
198 
199 PARSE_MINDIR_ATTR_IN_INT_FORM(int8_t, int8_t)
200 PARSE_MINDIR_ATTR_IN_INT_FORM(int16_t, int16_t)
201 PARSE_MINDIR_ATTR_IN_INT_FORM(int32_t, int32_t)
202 PARSE_MINDIR_ATTR_IN_INT_FORM(int64_t, int64_t)
203 PARSE_MINDIR_ATTR_IN_INT_FORM(uint8_t, uint8_t)
204 PARSE_MINDIR_ATTR_IN_INT_FORM(uint16_t, uint16_t)
205 PARSE_MINDIR_ATTR_IN_INT_FORM(uint32_t, uint32_t)
206 PARSE_MINDIR_ATTR_IN_INT_FORM(uint64_t, uint64_t)
207 PARSE_MINDIR_ATTR_IN_INT_FORM(int32_t, bool)
208 
209 PARSE_MINDIR_ATTR_IN_SCALAR_FORM(double, double)
210 PARSE_MINDIR_ATTR_IN_SCALAR_FORM(float, float)
211 PARSE_MINDIR_ATTR_IN_SCALAR_FORM(string, string)
212 
213 ValuePtr ParseAttrInSingleScalar_string_string(const mind_ir::AttributeProto &attr_proto) {
214   auto value = static_cast<string>(attr_proto.s());
215   return MakeValue<string>(value);
216 }
217 
218 ValuePtr ParseAttrInSingleScalar_float_float(const mind_ir::AttributeProto &attr_proto) {
219   auto value = static_cast<float>(attr_proto.f());
220   return MakeValue<float>(value);
221 }
222 
223 ValuePtr ParseAttrInSingleScalar_double_double(const mind_ir::AttributeProto &attr_proto) {
224   auto value = static_cast<double>(attr_proto.d());
225   return MakeValue<double>(value);
226 }
227 
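// Creates an empty Tensor with the shape and data type described by the TensorProto;
// callers copy the raw data in afterwards where needed.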
228 tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto) {
229   ShapeVector shape;
230   for (int i = 0; i < tensor_proto.dims_size(); ++i) {
231     shape.push_back(tensor_proto.dims(i));
232   }
233 
234   if (!tensor_proto.has_data_type()) {
235     MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed";
236     MS_LOG(EXCEPTION) << "mind_ir TensorProto has no data_type.";
237   }
238   if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) {
239     MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed";
240     MS_LOG(EXCEPTION) << "mind_ir TensorProto data_type: " << tensor_proto.data_type() << " is not supported yet!";
241   }
242 
243   tensor::TensorPtr tensor_info =
244     std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[tensor_proto.data_type()], shape);
245   return tensor_info;
246 }
247 
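// Builds a graph Parameter from a TensorProto: restores its name, abstract and default value.
// With incremental load enabled, a tensor already recorded under the same name is reused.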
248 bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
249                                                   const mind_ir::TensorProto &parameter_proto) {
250   MS_EXCEPTION_IF_NULL(node);
251 
252   if (!parameter_proto.has_name()) {
253     MS_LOG(ERROR) << "mind_ir TensorProto has no name!";
254     return false;
255   }
256   string debug_info_name = ParseParameterName(parameter_proto.name());
257   auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
258   node->set_debug_info(debug_info_ptr);
259   node->set_name(debug_info_name);
260 
261   tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto);
262   MS_EXCEPTION_IF_NULL(tensor_info);
263   MS_LOG(DEBUG) << "Load parameter name: " << debug_info_name;
264   if (!IsIncLoad() || load_tensor_map_.find(debug_info_name) == load_tensor_map_.end()) {
265     load_tensor_map_[debug_info_name] = tensor_info;
266   } else {
267     MS_LOG(DEBUG) << "Parameter: " << debug_info_name << " has already been loaded, reuse it.";
268     tensor::TensorPtr load_tensor_info = load_tensor_map_[debug_info_name];
269     auto tensor_abstract = load_tensor_info->ToAbstract();
270     MS_EXCEPTION_IF_NULL(tensor_abstract);
271     node->set_abstract(tensor_abstract);
272     node->set_default_param(load_tensor_info);
273     anfnode_build_map_[parameter_proto.name()] = node;
274     return true;
275   }
276   ParamInfoPtr param_info = std::make_shared<ParamInfo>();
277   param_info->set_name(debug_info_name);
278   tensor_info->set_param_info(param_info);
279 
280   auto tensor_abstract = tensor_info->ToAbstract();
281   MS_EXCEPTION_IF_NULL(tensor_abstract);
282   node->set_abstract(tensor_abstract);
283 
284   std::string initial_data = parameter_proto.raw_data();
285   auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
286   MS_EXCEPTION_IF_NULL(tensor_data_buf);
287   auto ret = memcpy_s(tensor_data_buf, static_cast<size_t>(tensor_info->data().nbytes()), initial_data.data(),
288                       initial_data.size());
289   if (ret != 0) {
290     MS_LOG(ERROR) << "Build parameter failed: memcpy_s error.";
291     return false;
292   }
293 
294   node->set_default_param(tensor_info);
295 
296   anfnode_build_map_[parameter_proto.name()] = node;
297   return true;
298 }
299 
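// Builds a graph input Parameter from a ValueInfoProto; only the abstract is restored and
// no default value is attached.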
300 bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto) {
301   MS_EXCEPTION_IF_NULL(node);
302 
303   if (!value_proto.has_name()) {
304     MS_LOG(ERROR) << "mind_ir ValueInfoProto has no name!";
305     return false;
306   }
307   string debug_info_name = ParseParameterName(value_proto.name());
308   auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
309   node->set_debug_info(debug_info_ptr);
310   node->set_name(debug_info_name);
311 
312   // Set abstract of the parameter
313   if (value_proto.tensor_size() > 0) {
314     const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
315     tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
316     MS_EXCEPTION_IF_NULL(tensor_info);
317     auto tensor_abstract = tensor_info->ToAbstract();
318     node->set_abstract(tensor_abstract);
319   } else if (value_proto.has_denotation()) {
320     MS_LOG(DEBUG) << "Not tensor. parameter type: " << value_proto.denotation();
321   }
322   anfnode_build_map_[value_proto.name()] = node;
323   return true;
324 }
325 
326 bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
327                                                 const mind_ir::GraphProto &importProto) {
328   MS_EXCEPTION_IF_NULL(outputFuncGraph);
329   MS_LOG(INFO) << "All inputs size is: " << importProto.input_size();
330   for (int i = 0; i < importProto.input_size(); ++i) {
331     const mind_ir::ValueInfoProto &input_proto = importProto.input(i);
332     if (!BuildInputForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) {
333       MS_LOG(ERROR) << "Build input for funcgraph failed at index: " << i;
334       return false;
335     }
336   }
337 
338   MS_LOG(INFO) << "All Parameters size is: " << importProto.parameter_size();
339   for (int i = 0; i < importProto.parameter_size(); ++i) {
340     const mind_ir::TensorProto &parameter_proto = importProto.parameter(i);
341     if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), parameter_proto)) {
342       MS_LOG(ERROR) << "Build parameter for funcgraph failed at index: " << i;
343       return false;
344     }
345   }
346   return true;
347 }
348 
349 bool MSANFModelParser::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
350   MS_EXCEPTION_IF_NULL(prim);
351   const int attr_tensor_type = attr_proto.tensors(0).data_type();
352   if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
353     MS_LOG(ERROR) << "Obtain attr in type-form does not support input type: " << attr_tensor_type;
354     return false;
355   }
356   prim->AddAttr(attr_proto.name(), TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
357   return true;
358 }
359 
360 ValuePtr MSANFModelParser::ParseAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, int index) {
361   const int attr_type = attr_proto.type();
362   switch (attr_type) {
363     case mind_ir::AttributeProto_AttributeType_STRING: {
364       return ParseAttrInScalar_string_string(attr_proto, index);
365     }
366     case mind_ir::AttributeProto_AttributeType_INT8: {
367       return ParseAttrInScalar_int8_t_int8_t(attr_proto, index);
368     }
369     case mind_ir::AttributeProto_AttributeType_INT16: {
370       return ParseAttrInScalar_int16_t_int16_t(attr_proto, index);
371     }
372     case mind_ir::AttributeProto_AttributeType_INT32: {
373       return ParseAttrInScalar_int32_t_int32_t(attr_proto, index);
374     }
375     case mind_ir::AttributeProto_AttributeType_INT64: {
376       return ParseAttrInScalar_int64_t_int64_t(attr_proto, index);
377     }
378     case mind_ir::AttributeProto_AttributeType_UINT8: {
379       return ParseAttrInScalar_uint8_t_uint8_t(attr_proto, index);
380     }
381     case mind_ir::AttributeProto_AttributeType_UINT16: {
382       return ParseAttrInScalar_uint16_t_uint16_t(attr_proto, index);
383     }
384     case mind_ir::AttributeProto_AttributeType_UINT32: {
385       return ParseAttrInScalar_uint32_t_uint32_t(attr_proto, index);
386     }
387     case mind_ir::AttributeProto_AttributeType_UINT64: {
388       return ParseAttrInScalar_uint64_t_uint64_t(attr_proto, index);
389     }
390     case mind_ir::AttributeProto_AttributeType_FLOAT: {
391       return ParseAttrInScalar_float_float(attr_proto, index);
392     }
393     case mind_ir::AttributeProto_AttributeType_DOUBLE: {
394       return ParseAttrInScalar_double_double(attr_proto, index);
395     }
396     case mind_ir::AttributeProto_AttributeType_BOOL: {
397       return ParseAttrInScalar_int32_t_bool(attr_proto, index);
398     }
399     case mind_ir::AttributeProto_AttributeType_TENSORS: {
400       const int attr_tensor_type = attr_proto.tensors(index).data_type();
401       if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
402         MS_LOG(ERROR) << "Obtain attr in type-form does not support input type: " << attr_tensor_type;
403         return {};
404       }
405       return TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]);
406     }
407     default:
408       MS_LOG(ERROR) << "Obtain attr in scalar-form does not support input type: " << attr_type;
409       return {};
410   }
411   return {};
412 }
413 
414 void MSANFModelParser::ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto &attr_proto,
415                                                    std::unordered_map<std::string, ValuePtr> *multi_value_map) {
416   string name;
417   auto func = [&name, &multi_value_map, this](const mind_ir::AttributeProto &attr_proto, int i) -> void {
418     auto res = this->ParseAttrInScalarForm(attr_proto, i);
419     name = "value" + std::to_string(i + 1);
420     (void)multi_value_map->emplace(name, res);
421   };
422   for (int i = 0; i < attr_proto.ints_size(); i++) {
423     func(attr_proto, i);
424   }
425   for (int i = 0; i < attr_proto.doubles_size(); i++) {
426     func(attr_proto, i);
427   }
428   for (int i = 0; i < attr_proto.floats_size(); i++) {
429     func(attr_proto, i);
430   }
431   for (int i = 0; i < attr_proto.strings_size(); i++) {
432     func(attr_proto, i);
433   }
434   for (int i = 0; i < attr_proto.tensors_size(); i++) {
435     func(attr_proto, i);
436   }
437 }
438 
439 ValuePtr MSANFModelParser::ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto) const {
440   const int attr_type = attr_proto.type();
441   switch (attr_type) {
442     case mind_ir::AttributeProto_AttributeType_STRING: {
443       return ParseAttrInSingleScalar_string_string(attr_proto);
444     }
445     case mind_ir::AttributeProto_AttributeType_INT8: {
446       return ParseAttrInSingleScalar_int8_t_int8_t(attr_proto);
447     }
448     case mind_ir::AttributeProto_AttributeType_INT16: {
449       return ParseAttrInSingleScalar_int16_t_int16_t(attr_proto);
450     }
451     case mind_ir::AttributeProto_AttributeType_INT32: {
452       return ParseAttrInSingleScalar_int32_t_int32_t(attr_proto);
453     }
454     case mind_ir::AttributeProto_AttributeType_INT64: {
455       return ParseAttrInSingleScalar_int64_t_int64_t(attr_proto);
456     }
457     case mind_ir::AttributeProto_AttributeType_UINT8: {
458       return ParseAttrInSingleScalar_uint8_t_uint8_t(attr_proto);
459     }
460     case mind_ir::AttributeProto_AttributeType_UINT16: {
461       return ParseAttrInSingleScalar_uint16_t_uint16_t(attr_proto);
462     }
463     case mind_ir::AttributeProto_AttributeType_UINT32: {
464       return ParseAttrInSingleScalar_uint32_t_uint32_t(attr_proto);
465     }
466     case mind_ir::AttributeProto_AttributeType_UINT64: {
467       return ParseAttrInSingleScalar_uint64_t_uint64_t(attr_proto);
468     }
469     case mind_ir::AttributeProto_AttributeType_FLOAT: {
470       return ParseAttrInSingleScalar_float_float(attr_proto);
471     }
472     case mind_ir::AttributeProto_AttributeType_DOUBLE: {
473       return ParseAttrInSingleScalar_double_double(attr_proto);
474     }
475     case mind_ir::AttributeProto_AttributeType_BOOL: {
476       return ParseAttrInSingleScalar_int32_t_bool(attr_proto);
477     }
478     default:
479       MS_LOG(ERROR) << "Obtain attr in scalar-form does not support input type: " << attr_type;
480       return {};
481   }
482   return {};
483 }
484 
485 bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim,
486                                                    const mind_ir::AttributeProto &attr_proto) {
487   MS_EXCEPTION_IF_NULL(prim);
488   const mind_ir::TensorProto attr_tensor = attr_proto.tensors(0);
489   const int attr_tensor_type = attr_tensor.data_type();
490   ShapeVector shape;
491   for (int i = 0; i < attr_tensor.dims_size(); ++i) {
492     shape.push_back(attr_tensor.dims(i));
493   }
494   tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
495   MS_EXCEPTION_IF_NULL(tensor_info);
496   const std::string &tensor_buf = attr_tensor.raw_data();
497   auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
498   auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
499   if (ret != 0) {
500     MS_LOG(ERROR) << "Obtain CNode attr in tensor-form failed: memcpy_s error.";
501     return false;
502   }
503   prim->AddAttr(attr_proto.name(), MakeValue(tensor_info));
504   return true;
505 }
506 
507 string GetTypeString(const std::string &ref_attr_name, size_t *pos) {
508   if ((*pos = ref_attr_name.find("scalar:")) != std::string::npos) {
509     return ref_attr_name.substr(*pos, string("scalar:").length() - 1);
510   } else if ((*pos = ref_attr_name.find("type:")) != std::string::npos) {
511     return ref_attr_name.substr(*pos, string("type:").length() - 1);
512   } else if ((*pos = ref_attr_name.find("tensor:")) != std::string::npos) {
513     return ref_attr_name.substr(*pos, string("tensor:").length() - 1);
514   }
515   return "";
516 }
517 
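// Restores one primitive attribute from an AttributeProto. The ref_attr_name prefix
// ("type:", "scalar:" or "tensor:") selects how the value is decoded; tuple/list scalars
// are reassembled from multi_value_map after the switch.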
518 bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
519   MS_EXCEPTION_IF_NULL(prim);
520   const std::string &attr_name = attr_proto.name();
521   if (!attr_proto.has_ref_attr_name()) {
522     MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
523     return false;
524   }
525   const std::string &ref_attr_name = attr_proto.ref_attr_name();
526 
527   std::size_t pos(0);
528   string type = GetTypeString(ref_attr_name, &pos);
529   std::unordered_map<std::string, ValuePtr> multi_value_map;
530   switch (kParseTypeSwitchMap[type]) {
531     case FORM_PARSE_TYPE: {
532       ObtainCNodeAttrInTypeForm(prim, attr_proto);
533       break;
534     }
535     case FORM_PARSE_SCALAR: {
536       std::size_t value_pos(0);
537       if ((value_pos = ref_attr_name.find("value0")) != std::string::npos) {
538         ValuePtr res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
539         const std::string &op_type = prim->name();
540         if (!IsLite()) {
541           CheckAndConvertUtils::ConvertAttrValueInLoad(op_type, attr_name, &res);
542         }
543         if (op_type == "HistogramFixedWidth" && attr_name == "dtype" && res->isa<StringImm>()) {
544           auto str_dtype = GetValue<std::string>(res);
545           if (str_dtype == "int32") {
546             const int64_t attr_value = 3;
547             (void)prim->AddAttr(attr_name, MakeValue<int64_t>(attr_value));
548             break;
549           }
550           MS_EXCEPTION(NotSupportError)
551             << "The primitive[HistogramFixedWidth] only supports attribute[dtype] of 'int32', but got "
552             << res->ToString();
553         }
554         prim->AddAttr(attr_name, res);
555         break;
556       }
557       ObtainCNodeAttrInScalarForm(attr_proto, &multi_value_map);
558       break;
559     }
560     case FORM_PARSE_TENSOR: {
561       ObtainCNodeAttrInTensorForm(prim, attr_proto);
562       break;
563     }
564     default:
565       MS_LOG(ERROR) << "Parse attr type doesn't support the ref_attr_name: " << ref_attr_name;
566       return false;
567   }
568 
569   if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR && multi_value_map.size() != 0) {
570     if ((pos = ref_attr_name.find("Tuple")) != std::string::npos) {
571       auto value_tuple_ptr = ParserScalarAttrValue<ValueTuple>(ref_attr_name, multi_value_map);
572       prim->AddAttr(attr_name, value_tuple_ptr);
573     } else {
574       auto value_list_ptr = ParserScalarAttrValue<ValueList>(ref_attr_name, multi_value_map);
575       prim->AddAttr(attr_name, value_list_ptr);
576     }
577   }
578   return true;
579 }
580 
581 bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name,
582                                                    const mind_ir::TensorProto &attr_tensor) {
583   const int attr_tensor_type = attr_tensor.data_type();
584   ShapeVector shape;
585   for (int i = 0; i < attr_tensor.dims_size(); ++i) {
586     shape.push_back(attr_tensor.dims(i));
587   }
588   tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
589   MS_EXCEPTION_IF_NULL(tensor_info);
590   const std::string &tensor_buf = attr_tensor.raw_data();
591   auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
592   auto ret =
593     memcpy_s(tensor_data_buf, static_cast<size_t>(tensor_info->data().nbytes()), tensor_buf.data(), tensor_buf.size());
594   if (ret != 0) {
595     MS_LOG(ERROR) << "Obtain ValueNode in TensorForm failed: memcpy_s error.";
596     return false;
597   }
598 
599   auto new_value_node = NewValueNode(MakeValue(tensor_info));
600   MS_EXCEPTION_IF_NULL(new_value_node);
601   auto tensor_abstract = tensor_info->ToAbstract();
602   MS_EXCEPTION_IF_NULL(tensor_abstract);
603   new_value_node->set_abstract(tensor_abstract);
604   anfnode_build_map_[value_node_name] = new_value_node;
605   return true;
606 }
607 
608 bool MSANFModelParser::ObtainValueNodeInTupleTensorForm(const std::string &value_node_name,
609                                                         const mind_ir::AttributeProto &attr_proto) {
610   std::vector<tensor::TensorPtr> tensor_vec;
611   for (int i = 0; i < attr_proto.tensors_size(); ++i) {
612     mind_ir::TensorProto attr_tensor = attr_proto.tensors(i);
613     const int attr_tensor_type = attr_tensor.data_type();
614     ShapeVector shape;
615     for (int j = 0; j < attr_tensor.dims_size(); ++j) {
616       shape.push_back(attr_tensor.dims(j));
617     }
618     tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
619     const std::string &tensor_buf = attr_tensor.raw_data();
620     auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
621     auto ret = memcpy_s(tensor_data_buf, static_cast<size_t>(tensor_info->data().nbytes()), tensor_buf.data(),
622                         tensor_buf.size());
623     if (ret != 0) {
624       MS_LOG(ERROR) << "Obtain ValueNode in TupleTensorForm failed: memcpy_s error.";
625       return false;
626     }
627     tensor_vec.push_back(tensor_info);
628   }
629   auto new_value_node = NewValueNode(MakeValue(tensor_vec));
630   anfnode_build_map_[value_node_name] = new_value_node;
631   return true;
632 }
633 
634 bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_name,
635                                                  const mind_ir::TensorProto &attr_tensor) {
636   const int attr_tensor_type = attr_tensor.data_type();
637   if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
638     MS_LOG(ERROR) << "Obtain ValueNode attr in type-form does not support input type: " << attr_tensor_type;
639     return false;
640   }
641   auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
642   abstract::AbstractTypePtr abs_type = std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>());
643   new_value_node->set_abstract(abs_type);
644   anfnode_build_map_[value_node_name] = new_value_node;
645   return true;
646 }
647 
648 bool MSANFModelParser::ObtainValueNodeInNoneForm(const std::string &value_node_name) {
649   auto new_value_node = NewValueNode(kNone);
650   MS_EXCEPTION_IF_NULL(new_value_node);
651   new_value_node->set_abstract(kNone->ToAbstract());
652   anfnode_build_map_[value_node_name] = new_value_node;
653   return true;
654 }
655 
656 bool MSANFModelParser::ObtainValueNodeInMonadForm(const std::string &value_node_name,
657                                                   const mind_ir::AttributeProto &attr_proto) {
658   const std::string &ref_attr_name = attr_proto.ref_attr_name();
659   if (ref_attr_name.find("UMonad") != std::string::npos) {
660     auto monad_abs = kUMonad->ToAbstract();
661     auto new_value_node = NewValueNode(kUMonad);
662     MS_EXCEPTION_IF_NULL(new_value_node);
663     new_value_node->set_abstract(monad_abs);
664     anfnode_build_map_[value_node_name] = new_value_node;
665   } else if (ref_attr_name.find("IOMonad") != std::string::npos) {
666     auto monad_abs = kIOMonad->ToAbstract();
667     auto new_value_node = NewValueNode(kIOMonad);
668     MS_EXCEPTION_IF_NULL(new_value_node);
669     new_value_node->set_abstract(monad_abs);
670     anfnode_build_map_[value_node_name] = new_value_node;
671   } else {
672     return false;
673   }
674   return true;
675 }
676 
677 namespace {
678 std::string GetTypeFromAttrName(const std::string &ref_attr_name) {
679   string type = "";
680   std::size_t pos(0);
681   if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
682     return ref_attr_name.substr(pos, string("scalar:").length() - 1);
683   } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
684     return ref_attr_name.substr(pos, string("type:").length() - 1);
685   } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
686     return ref_attr_name.substr(pos, string("tensor:").length() - 1);
687   } else if ((pos = ref_attr_name.find("Monad:")) != std::string::npos) {
688     return ref_attr_name.substr(pos, string("Monad:").length() - 1);
689   } else if (ref_attr_name == "none") {
690     return ref_attr_name;
691   }
692   return type;
693 }
694 }  // namespace
695 
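// Builds a ValueNode for a Constant node according to the ref_attr_name form:
// type, scalar (single value, empty tuple, tuple of tensors, or tuple/list of scalars),
// tensor, none or monad.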
696 bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_name,
697                                                 const mind_ir::AttributeProto &attr_proto) {
698   if (!attr_proto.has_ref_attr_name()) {
699     MS_LOG(ERROR) << "ValueNode parse attr type has no ref_attr_name";
700     return false;
701   }
702   const std::string &ref_attr_name = attr_proto.ref_attr_name();
703   auto type = GetTypeFromAttrName(ref_attr_name);
704   ValueNodePtr new_value_node;
705   std::unordered_map<std::string, ValuePtr> multi_value_map;
706   switch (kParseTypeSwitchMap[type]) {
707     case FORM_PARSE_TYPE: {
708       ObtainValueNodeInTypeForm(value_node_name, attr_proto.tensors(0));
709       break;
710     }
711     case FORM_PARSE_SCALAR: {
712       std::size_t value_pos(0);
713       if ((value_pos = ref_attr_name.find("value0")) != std::string::npos) {
714         auto res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
715         new_value_node = NewValueNode(res);
716         new_value_node->set_abstract(res->ToAbstract());
717         anfnode_build_map_[value_node_name] = new_value_node;
718         break;
719       }
720       if ((value_pos = ref_attr_name.find("Tuple[]")) != std::string::npos) {
721         MS_LOG(INFO) << "Build Tuple() ValueNode for primitive.";
722         ValuePtr res = MakeValue(std::vector<ValuePtr>{});
723         new_value_node = NewValueNode(res);
724         new_value_node->set_abstract(res->ToAbstract());
725         anfnode_build_map_[value_node_name] = new_value_node;
726         break;
727       }
728       if ((value_pos = ref_attr_name.find("Tuple[value")) != std::string::npos && attr_proto.tensors_size() > 1) {
729         MS_LOG(INFO) << "Build TupleTensor ValueNode for primitive.";
730         if (!ObtainValueNodeInTupleTensorForm(value_node_name, attr_proto)) {
731           MS_LOG(ERROR) << "Obtain ValueNode in tuple-tensor form failed.";
732           return false;
733         }
734         break;
735       }
736       ObtainCNodeAttrInScalarForm(attr_proto, &multi_value_map);
737       break;
738     }
739     case FORM_PARSE_TENSOR: {
740       (void)ObtainValueNodeInTensorForm(value_node_name, attr_proto.tensors(0));
741       break;
742     }
743     case FORM_PARSE_NONE: {
744       (void)ObtainValueNodeInNoneForm(value_node_name);
745       break;
746     }
747     case FORM_PARSE_MONAD: {
748       (void)ObtainValueNodeInMonadForm(value_node_name, attr_proto);
749       break;
750     }
751     default:
752       MS_LOG(ERROR) << "Parse attr type doesn't support the ref_attr_name: " << ref_attr_name;
753       return false;
754   }
755 
756   if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR && multi_value_map.size() != 0) {
757     if (ref_attr_name.find("Tuple") != std::string::npos) {
758       auto value_tuple_ptr = ParserScalarAttrValue<ValueTuple>(ref_attr_name, multi_value_map);
759       new_value_node = NewValueNode(value_tuple_ptr);
760       new_value_node->set_abstract(value_tuple_ptr->ToAbstract());
761     } else {
762       auto value_list_ptr = ParserScalarAttrValue<ValueList>(ref_attr_name, multi_value_map);
763       new_value_node = NewValueNode(value_list_ptr);
764       new_value_node->set_abstract(value_list_ptr->ToAbstract());
765     }
766     anfnode_build_map_[value_node_name] = new_value_node;
767   }
768   return true;
769 }
770 
771 bool MSANFModelParser::BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto) {
772   const std::string &value_node_name = node_proto.output(0);
773   const mind_ir::AttributeProto &attr_proto = node_proto.attribute(0);
774   if (!attr_proto.has_ref_attr_name()) {
775     MS_LOG(ERROR) << "Parse ValueNode doesn't have ref_attr_name";
776     return false;
777   }
778   return GetAttrValueForValueNode(value_node_name, attr_proto);
779 }
780 
781 std::unordered_map<std::string, abstract::AbstractBasePtr> MSANFModelParser::GetAbstractForCNode(
782   const mind_ir::AttributeProto &attr_proto) {
783   std::unordered_map<std::string, abstract::AbstractBasePtr> kv;
784   for (int i = 0; i < attr_proto.tensors_size(); ++i) {
785     ShapeVector shape_vec;
786     const mind_ir::TensorProto &attr_tensor = attr_proto.tensors(i);
787     for (int j = 0; j < attr_tensor.dims_size(); ++j) {
788       shape_vec.push_back(attr_tensor.dims(j));
789     }
790     tensor::TensorPtr tensor_info =
791       std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec);
792     MS_EXCEPTION_IF_NULL(tensor_info);
793     auto abstract = tensor_info->ToAbstract();
794     MS_EXCEPTION_IF_NULL(abstract);
795     (void)kv.emplace(attr_tensor.name(), abstract);
796   }
797   return kv;
798 }
799 
800 // S-Prim-xxx or S-Prim-hyper_map[xxx] -> xxx
801 static std::string GetDoSignaturePrimitiveName(const std::string &node_type) {
802   // Remove `S-Prim-` prefix.
803   auto prim_name = node_type.substr(strlen(kDoSignaturePrimitivePrefix));
804   if (prim_name.compare(0, strlen(kHyperMapPrefix), kHyperMapPrefix) != 0) {
805     return prim_name;
806   }
807   // hyper_map[xxx] -> xxx
808   constexpr auto offset = 2;
809   auto op_name = prim_name.substr(strlen(kHyperMapPrefix) + 1, prim_name.length() - strlen(kHyperMapPrefix) - offset);
810   return op_name;
811 }
812 
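// Builds the first input (the operator) of a CNode. A "REF::name" operator refers to an
// already parsed node or FuncGraph; otherwise a Primitive is created, either from the
// registered PrimitiveC constructors, as a DoSignaturePrimitive for "S-Prim-" types, or as
// a plain Primitive, and its attributes are restored from the NodeProto.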
813 AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_proto) {
814   const std::string kOperatorTypeFlag = std::string("REF::");
815   const size_t kOpTypeFlagSize = kOperatorTypeFlag.length();
816   const std::string &node_type = node_proto.op_type();
817   MS_LOG(DEBUG) << "Process Operator :" << node_type;
818   // The operator may be a CNode, a FuncGraph or a Parameter.
819 
820   if (node_type.size() > kOpTypeFlagSize && node_type.substr(0, kOpTypeFlagSize) == kOperatorTypeFlag) {
821     auto anfNode = GetAnfNode(node_type.substr(kOpTypeFlagSize));
822     if (anfNode == nullptr) {
823       MS_LOG(EXCEPTION) << "Can't find the ref:" << node_type;
824     }
825     return anfNode;
826   }
827 
828   // Operator is  primitive.
829   std::shared_ptr<Primitive> prim;
830   auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
831   if (op_primc_fns.find(node_type) != op_primc_fns.end()) {
832     prim = op_primc_fns[node_type]();
833   } else {
834     if (node_type.compare(0, strlen(kDoSignaturePrimitivePrefix), kDoSignaturePrimitivePrefix) == 0) {
835       auto op_name = GetDoSignaturePrimitiveName(node_type);
836       prim = std::make_shared<prim::DoSignaturePrimitive>(op_name, std::make_shared<Primitive>(op_name));
837       MS_EXCEPTION_IF_NULL(prim);
838       prim->set_instance_name(op_name);
839     } else {
840       MS_LOG(DEBUG) << "Special node_type: " << node_type;
841       prim = std::make_shared<Primitive>(node_type);
842       MS_EXCEPTION_IF_NULL(prim);
843       prim->set_instance_name(node_type);
844     }
845   }
846   MS_EXCEPTION_IF_NULL(prim);
847   for (int i = 0; i < node_proto.attribute_size(); ++i) {
848     const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
849     // CNode abstract
850     if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
851       continue;
852     }
853     if (!GetAttrValueForCNode(prim, attr_proto)) {
854       MS_LOG(EXCEPTION) << "Parse prim: " << node_type << " attributes error: " << attr_proto.DebugString();
855     }
856   }
857   prim->set_attr("is_load", MakeValue(true));
858   return std::make_shared<ValueNode>(prim);
859 }
860 
861 // Set CNode abstract.
862 void MSANFModelParser::SetCNodeAbastract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr) {
863   const std::string &node_type = node_proto.op_type();
864   // Handle control flow operator.
865   auto operatorPtr = cnode_ptr->input(0);
866   // Set abstract of switch(c,f,t),switchLayer(c,tup) and
867   // partial(func,args) to null
868   auto prim = GetValueNode<PrimitivePtr>(operatorPtr);
869   if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim) ||
870       IsPrimitiveEquals(prim::kPrimPartial, prim)) {
871     cnode_ptr->set_abstract(nullptr);
872     return;
873   }
874 
875   // If the operator is not a primitive, the abstract will be set to null.
876   // Because some operators do not exist in the front end, the abstract of a primitive should be preserved.
877   if (prim == nullptr) {
878     cnode_ptr->set_abstract(nullptr);
879     return;
880   }
881 
882   std::unordered_map<std::string, abstract::AbstractBasePtr> kv;
883   string shape_ref_attr_name;
884 
885   for (int i = 0; i < node_proto.attribute_size(); ++i) {
886     const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
887     if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
888       shape_ref_attr_name = attr_proto.ref_attr_name();
889       kv = GetAbstractForCNode(attr_proto);
890       break;
891     }
892   }
893 
894   // Because there is no context in unit tests,
895   // abstract->broaden() is replaced by abstract->set_value(kAnyValue).
896   if (kv.size() == 0) {
897     if (node_type == "UpdateState") {
898       cnode_ptr->set_abstract(kUMonad->ToAbstract());
899     } else if (node_type == "Depend") {
900       cnode_ptr->set_abstract(kBool->ToAbstract());
901     } else {
902       AbstractBasePtrList elem;
903       for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
904         auto abs = cnode_ptr->input(index)->abstract();
905         if (abs != nullptr) {
906           if (abs->GetValueTrack() == nullptr) {
907             abs->set_value(kAnyValue);
908           }
909           elem.push_back(abs);
910         }
911       }
912       if (!elem.empty()) {
913         cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
914       }
915     }
916   } else if (kv.size() == 1) {
917     std::unordered_map<std::string, abstract::AbstractBasePtr>::iterator iter = kv.begin();
918     if (iter->second != nullptr) {
919       iter->second->set_value(kAnyValue);
920       cnode_ptr->set_abstract(iter->second);
921     }
922   } else {
923     auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
924     if (abstract == nullptr) {
925       cnode_ptr->set_abstract(nullptr);
926       MS_LOG(ERROR) << "Node's attribute is nullptr.";
927     } else {
928       abstract->set_value(kAnyValue);
929       cnode_ptr->set_abstract(abstract);
930     }
931   }
932 }
933 
934 CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
935                                                   const mind_ir::NodeProto &node_proto) {
936   MS_EXCEPTION_IF_NULL(outputFuncGraph);
937   if (!node_proto.has_op_type()) {
938     MS_LOG(ERROR) << "Get CNode op_type failed!";
939     return nullptr;
940   }
941   const std::string &node_name = node_proto.output(0);
942   MS_LOG(DEBUG) << "Process CNode: " << node_name;
943   // Build inputs.
944   std::vector<AnfNodePtr> inputs;
945   inputs.push_back(BuildOperatorNode(node_proto));
946   for (int i = 0; i < node_proto.input_size(); ++i) {
947     auto anfNode = GetAnfNode(node_proto.input(i));
948     if (anfNode == nullptr) {
949       MS_LOG(ERROR) << node_name << " input " << i << ": " << node_proto.input(i) << " can't be found in the parsed nodes";
950       return nullptr;
951     }
952     inputs.push_back(anfNode);
953   }
954 
955   CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
956   MS_EXCEPTION_IF_NULL(cnode_ptr);
957   SetCNodeAbastract(node_proto, cnode_ptr);
958 
959   const std::string &fullname_with_scope = node_proto.domain();
960   string debug_info_name = ParseCNodeName(node_name);
961   auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
962   cnode_ptr->set_debug_info(debug_info_ptr);
963   cnode_ptr->set_fullname_with_scope(fullname_with_scope);
964   cnode_ptr->set_load_flag(true);
965   if (anfnode_build_map_.count(node_name) > 0) {
966     MS_LOG(EXCEPTION) << "Duplicate CNode name: " << node_name;
967   }
968   anfnode_build_map_[node_name] = cnode_ptr;
969   return cnode_ptr;
970 }
971 
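// Builds the Return node of the graph; multiple outputs are first packed into a MakeTuple.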
972 bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
973                                                const mind_ir::GraphProto &importProto) {
974   MS_EXCEPTION_IF_NULL(outputFuncGraph);
975   std::vector<AnfNodePtr> inputs;
976   if (importProto.output_size() > 1) {
977     inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
978     AbstractBasePtrList elem;
979     for (int out_size = 0; out_size < importProto.output_size(); ++out_size) {
980       const mind_ir::ValueInfoProto &output_node = importProto.output(out_size);
981       const std::string &out_tuple = output_node.name();
982       auto anfNode = GetAnfNode(out_tuple);
983       if (anfNode == nullptr) {
984         MS_LOG(ERROR) << "Missing return node: " << out_tuple;
985         return false;
986       }
987       inputs.push_back(anfNode);
988       elem.push_back(anfNode->abstract());
989     }
990     auto maketuple_ptr = outputFuncGraph->NewCNode(inputs);
991     MS_EXCEPTION_IF_NULL(maketuple_ptr);
992     maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
993     inputs.push_back(NewValueNode(prim::kPrimReturn));
994     inputs.push_back(maketuple_ptr);
995     auto return_node = outputFuncGraph->NewCNode(inputs);
996     MS_EXCEPTION_IF_NULL(return_node);
997     return_node->set_load_flag(true);
998     outputFuncGraph->set_return(return_node);
999     MS_LOG(DEBUG) << "Construct funcgraph finished, all success.";
1000   } else {
1001     inputs.clear();
1002     inputs.push_back(NewValueNode(prim::kPrimReturn));
1003     auto nodeName = importProto.output(0).name();
1004     auto anfNode = GetAnfNode(nodeName);
1005     if (anfNode == nullptr) {
1006       MS_LOG(ERROR) << "Missing return node: " << nodeName;
1007       return false;
1008     }
1009     inputs.push_back(anfNode);
1010     auto return_node = outputFuncGraph->NewCNode(inputs);
1011     MS_EXCEPTION_IF_NULL(return_node);
1012     return_node->set_load_flag(true);
1013     outputFuncGraph->set_return(return_node);
1014     MS_LOG(DEBUG) << "Construct funcgraph finished, all success!";
1015   }
1016   return true;
1017 }
1018 
1019 bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
1020                                            const mind_ir::GraphProto &importProto) {
1021   MS_EXCEPTION_IF_NULL(outputFuncGraph);
1022   CNodePtr cnode_ptr = nullptr;
1023   for (int i = 0; i < importProto.node_size(); ++i) {
1024     const mind_ir::NodeProto &node_proto = importProto.node(i);
1025     const std::string &node_type = node_proto.op_type();
1026     if (node_type == kConstantValueNode) {
1027       if (!BuildValueNodeForFuncGraph(node_proto)) {
1028         MS_LOG(ERROR) << "Build ValueNode for funcgraph failed at index: " << i;
1029         return false;
1030       }
1031       continue;
1032     }
1033     cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto);
1034     if (cnode_ptr == nullptr) {
1035       MS_LOG(ERROR) << "Build CNode for funcgraph failed at index: " << i;
1036       return false;
1037     }
1038   }
1039 
1040   return BuildReturnForFuncGraph(outputFuncGraph, importProto);
1041 }
1042 
1043 bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto) {
1044   MS_EXCEPTION_IF_NULL(outputFuncGraph);
1045   GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info();
1046   MS_EXCEPTION_IF_NULL(debug_info_ptr);
1047   if (importProto.has_name()) {
1048     debug_info_ptr->set_name(importProto.name());
1049   } else {
1050     MS_LOG(ERROR) << "FuncGraph under converting has no name!";
1051   }
1052   if (importProto.has_bprop_hash()) {
1053     outputFuncGraph->set_bprop_hash(importProto.bprop_hash());
1054   }
1055 
1056   if (!ImportParametersForGraph(outputFuncGraph, importProto)) {
1057     MS_LOG(ERROR) << "Import parameters for graph failed!";
1058     return false;
1059   }
1060   return ImportNodesForGraph(outputFuncGraph, importProto);
1061 }
1062 
1063 bool MSANFModelParser::MSANFParseModelConfigureInfo(const mind_ir::ModelProto &model_proto) {
1064   if (!model_proto.has_producer_name()) {
1065     MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
1066     return false;
1067   }
1068   producer_name_ = model_proto.producer_name();
1069   MS_LOG(INFO) << "producer_name :" << producer_name_;
1070 
1071   if (!model_proto.has_model_version()) {
1072     MS_LOG(ERROR) << "Parse model producer version from pb file failed!";
1073     return false;
1074   }
1075   model_version_ = model_proto.model_version();
1076   MS_LOG(INFO) << "model_version : " << model_version_;
1077 
1078   if (!model_proto.has_ir_version()) {
1079     MS_LOG(ERROR) << "Parse model version from pb file failed!";
1080     return false;
1081   }
1082   ir_version_ = model_proto.ir_version();
1083   MS_LOG(INFO) << "ir_version :" << ir_version_;
1084   return true;
1085 }
1086 
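// Entry point: parses a ModelProto into a FuncGraph. Graph names are registered up front so
// that nodes can reference function graphs defined later; the main graph and each function
// graph are then built and the build map is released.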
1087 FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) {
1088   FuncGraphPtr dstGraph = std::make_shared<FuncGraph>();
1089   MS_EXCEPTION_IF_NULL(dstGraph);
1090   if (!MSANFParseModelConfigureInfo(model_proto)) {
1091     MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
1092   }
1093   const mind_ir::GraphProto &graphBuild = model_proto.graph();
1094 
1095   // Forward declare FuncGraph name
1096   // Compatible with the previous proto.
1097   if (graphBuild.has_name()) {
1098     anfnode_build_map_[graphBuild.name()] = std::make_shared<ValueNode>(dstGraph);
1099   }
1100   for (int i = 0; i < model_proto.functions_size(); ++i) {
1101     FuncGraphPtr graph = std::make_shared<FuncGraph>();
1102     const auto &graph_proto = model_proto.functions(i);
1103     if (!graph_proto.has_name()) {
1104       MS_LOG(EXCEPTION) << "The function has no name. Please export the MindIR again.";
1105     }
1106     if (anfnode_build_map_.count(graph_proto.name()) > 0) {
1107       MS_LOG(EXCEPTION) << "There is a duplicate function graph name: " << graph_proto.name();
1108     }
1109     anfnode_build_map_[graph_proto.name()] = std::make_shared<ValueNode>(graph);
1110   }
1111 
1112   // Parser the proto.
1113   if (!BuildFuncGraph(dstGraph, graphBuild)) {
1114     MS_LOG(ERROR) << "Build funcgraph failed!";
1115     return nullptr;
1116   }
1117   MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graphBuild.name();
1118   for (int i = 0; i < model_proto.functions_size(); ++i) {
1119     const auto &graph_proto = model_proto.functions(i);
1120     FuncGraphPtr graph = GetValueNode<FuncGraphPtr>(anfnode_build_map_[graph_proto.name()]);
1121     if (!BuildFuncGraph(graph, graph_proto)) {
1122       MS_LOG(ERROR) << "Build funcgraph failed!";
1123       return nullptr;
1124     }
1125     MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graph_proto.name();
1126   }
1127   // Release resource
1128   anfnode_build_map_.clear();
1129   return dstGraph;
1130 }
1131 
1132 AnfNodePtr MSANFModelParser::GetAnfNode(const std::string &node_name) {
1133   auto it = anfnode_build_map_.find(node_name);
1134   if (it == anfnode_build_map_.end()) {
1135     return nullptr;
1136   }
1137   FuncGraphPtr func_graph_ptr = GetValueNode<FuncGraphPtr>(it->second);
1138   if (func_graph_ptr) {
1139     return NewValueNode(func_graph_ptr);
1140   } else {
1141     return it->second;
1142   }
1143 }
1144 }  // namespace mindspore
1145