• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "load_mindir/load_model.h"
17 #include <sys/stat.h>
18 #include <sys/types.h>
19 #include <cstring>
20 #include <string>
21 #include <memory>
22 #include <algorithm>
23 #include <fstream>
24 #include <iostream>
25 #include <stack>
26 #include <list>
27 #include <utility>
28 #include <nlohmann/json.hpp>
29 #include "mindspore/core/ops/structure_ops.h"
30 #include "mindspore/core/ops/sequence_ops.h"
31 #include "mindspore/core/ops/framework_ops.h"
32 #include "utils/crypto.h"
33 #include "utils/os.h"
34 #include "ir/value.h"
35 #include "ir/tensor.h"
36 #include "ir/param_info.h"
37 #include "ir/map_tensor.h"
38 #include "ir/functor.h"
39 #include "ops/primitive_c.h"
40 #include "abstract/abstract_value.h"
41 #include "abstract/ops/primitive_infer_map.h"
42 #include "utils/hash_map.h"
43 #include "utils/log_adapter.h"
44 #include "utils/check_convert_utils.h"
45 #include "utils/ms_utils_secure.h"
46 #include "abstract/abstract_function.h"
47 #include "load_mindir/infer_mindir.h"
48 #include "include/common/debug/common.h"
49 #include "proto/mind_ir.pb.h"
50 #include "google/protobuf/io/zero_copy_stream_impl.h"
51 
52 using std::string;
53 using std::vector;
54 
55 namespace mindspore {
56 namespace {
// Node title used when serializing constant ValueNodes into MindIR.
static constexpr char kConstantValueNode[] = "Constant";
// Attribute key under which per-tensor quantization parameters are stored.
static constexpr char kQuantParam[] = "quant_param";
// Tensor-proto name marking quantization parameters of a graph input.
static constexpr char kGraphInputQuantParam[] = "graph_input_quant_param";

// Category of a serialized attribute, derived from its ref_attr_name prefix.
enum ParseForm : int {
  FORM_PARSE_TYPE = 0,
  FORM_PARSE_SCALAR = 1,
  FORM_PARSE_TENSOR = 2,
  FORM_PARSE_NONE = 3,
  FORM_PARSE_MONAD = 4,
  FORM_PARSE_SEQUENCE = 5,
  FORM_PARSE_UNDEFINE = 6,
};

// Maps the ref_attr_name prefix string to its ParseForm category.
// Note the mixed casing ("type" vs "Monad") follows the serialized format.
static std::map<std::string, ParseForm> kParseTypeSwitchMap{
  {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR},
  {"none", FORM_PARSE_NONE}, {"Monad", FORM_PARSE_MONAD},   {"Sequence", FORM_PARSE_SEQUENCE}};

// Maps MindIR proto tensor data types to MindSpore TypeIds.
// FLOAT64 and DOUBLE both map to kNumberTypeFloat64 for backward
// compatibility with older exporters.
static mindspore::HashMap<int, TypeId> kDefaultValueSwitchMap{
  {mind_ir::TensorProto_DataType_BOOL, kNumberTypeBool},
  {mind_ir::TensorProto_DataType_INT8, kNumberTypeInt8},
  {mind_ir::TensorProto_DataType_INT16, kNumberTypeInt16},
  {mind_ir::TensorProto_DataType_INT32, kNumberTypeInt32},
  {mind_ir::TensorProto_DataType_INT64, kNumberTypeInt64},
  {mind_ir::TensorProto_DataType_UINT8, kNumberTypeUInt8},
  {mind_ir::TensorProto_DataType_UINT16, kNumberTypeUInt16},
  {mind_ir::TensorProto_DataType_UINT32, kNumberTypeUInt32},
  {mind_ir::TensorProto_DataType_UINT64, kNumberTypeUInt64},
  {mind_ir::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
  {mind_ir::TensorProto_DataType_BFLOAT16, kNumberTypeBFloat16},
  {mind_ir::TensorProto_DataType_QINT4X2, kNumberTypeInt4},
  {mind_ir::TensorProto_DataType_FLOAT, kNumberTypeFloat32},
  {mind_ir::TensorProto_DataType_FLOAT64, kNumberTypeFloat64},
  {mind_ir::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
  {mind_ir::TensorProto_DataType_STRING, kObjectTypeString},
  {mind_ir::TensorProto_DataType_COMPLEX64, kNumberTypeComplex64},
  {mind_ir::TensorProto_DataType_COMPLEX128, kNumberTypeComplex128}};
94 
// Parses a bracketed expression such as "[a,b,[c,d]]" into a nested
// container of type T, where each leaf name is resolved to a value of type P
// through the lookup table `kv`.
//
// The parser keeps two parallel stacks: `rules` holds structural tokens
// ("[" markers and leaf names), `value` holds the resolved P values. On every
// ']' the entries down to the matching '[' are popped, reversed back into
// source order, and wrapped in a shared_ptr<T>.
//
// Returns the fully built T when the outermost ']' closes with both stacks
// empty; returns nullptr when a name is missing from `kv`; returns an empty
// shared_ptr when the string never reduces to a single value (e.g. no
// brackets at all).
template <typename T, typename P>
std::shared_ptr<T> ParserAttr(const std::string &str, const mindspore::HashMap<string, P> &kv) {
  std::stack<std::string> rules;
  std::stack<P> value;
  size_t count = 0;  // length of the leaf name currently being scanned
  for (size_t i = 0; i < str.length(); i++) {
    if (str[i] == '[') {
      rules.push(std::string("["));
    } else if (str[i] == ']') {
      // Pop every completed entry back to the matching "[".
      std::vector<P> vec;
      while (!rules.empty() && rules.top() != "[") {
        rules.pop();
        vec.push_back(value.top());
        value.pop();
      }
      if (!rules.empty()) {
        // pop "["
        rules.pop();
      }
      // Placeholder token representing the collapsed group on the rule stack.
      std::string res = "dummy";
      // Stack order is reversed relative to the source; restore it.
      reverse(vec.begin(), vec.end());
      auto vt = std::make_shared<T>(vec);
      if (rules.empty() && value.empty()) {
        // Outermost group closed: parsing is complete.
        return vt;
      }
      rules.push(res);
      value.push(vt);
    } else if (str[i] == ',') {
      continue;
    } else {
      count++;
      // A leaf name ends right before the next structural character.
      // NOTE(review): str[i + 1] at the last character reads the terminating
      // '\0', which std::string::operator[] permits since C++11.
      if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
        auto value_name = str.substr((i - count) + 1, count);
        if (kv.find(value_name) == kv.end()) {
          MS_LOG(ERROR) << "Node's attributes and shape do not match.";
          return nullptr;
        }
        value.push(kv.at(value_name));
        rules.push(value_name);
        count = 0;
      }
    }
  }
  return {};
}
143 
144 template <typename T>
ParserScalarAttrValue(const std::string & attr_name,const mindspore::HashMap<string,ValuePtr> & kv)145 std::shared_ptr<T> ParserScalarAttrValue(const std::string &attr_name, const mindspore::HashMap<string, ValuePtr> &kv) {
146   std::string str = attr_name;
147   auto replace = [&](const string &orgStr, const string &newStr) {
148     std::string::size_type pos;
149     while ((pos = str.find(orgStr)) != std::string::npos) {
150       (void)str.replace(pos, orgStr.length(), newStr);
151     }
152     return str;
153   };
154   // remove "scalar:"
155   str = replace("scalar:", "");
156   // remove "Tuple"
157   str = replace("Tuple", "");
158   // remove "List"
159   str = replace("List", "");
160   auto result = ParserAttr<T, ValuePtr>(str, kv);
161   return result;
162 }
163 
ParserAttrShape(const std::string & attr_name,const mindspore::HashMap<string,abstract::AbstractBasePtr> & kv)164 std::shared_ptr<abstract::AbstractTuple> ParserAttrShape(
165   const std::string &attr_name, const mindspore::HashMap<string, abstract::AbstractBasePtr> &kv) {
166   std::string str = attr_name;
167   auto replace = [&](const string &orgStr, const string &newStr) {
168     std::string::size_type pos;
169     while ((pos = str.find(orgStr)) != std::string::npos) {
170       (void)str.replace(pos, orgStr.length(), newStr);
171     }
172     return str;
173   };
174   // remove "scalar:"
175   str = replace("shape:", "");
176   // remove "Tuple"
177   str = replace("Tuple", "");
178   // remove "List"
179   str = replace("List", "");
180 
181   auto result = ParserAttr<abstract::AbstractTuple, abstract::AbstractBasePtr>(str, kv);
182   return result;
183 }
184 
// Strips the graph-name prefix from a serialized parameter name of the form
// "graph:param", returning "param". A name without ':' is returned unchanged.
std::string ParseParameterName(const std::string &name) {
  const std::string delimiter = ":";
  const size_t pos = name.find(delimiter);
  if (pos != std::string::npos) {
    // Fix: the old call substr(pos + 1, string::npos - (pos + 1)) relied on
    // substr clamping an almost-overflowing length; the single-argument form
    // means "everything after the delimiter" directly.
    return name.substr(pos + 1);
  }
  return name;
}
193 
// Extracts the node name from a serialized CNode name of the form
// "scope:name:id", returning the segment between the first and last ':'.
// Names with fewer than two delimiters are returned unchanged.
std::string ParseCNodeName(const std::string &name) {
  const size_t first = name.find(':');
  const size_t last = name.find_last_of(':');
  const bool has_middle_segment = first != std::string::npos && last != std::string::npos && first != last;
  if (!has_middle_segment) {
    return name;
  }
  return name.substr(first + 1, last - (first + 1));
}
203 
// Generates two parser helpers for attributes persisted in the proto's
// integer fields:
//  - ParseAttrInScalar_<type>_<valuetype>(proto, index): reads element
//    `index` of the repeated `ints` field.
//  - ParseAttrInSingleScalar_<type>_<valuetype>(proto): reads the single
//    `i` field.
// Both raise an internal exception when the field is absent or out of range.
#define PARSE_MINDIR_ATTR_IN_INT_FORM(type, valuetype)                                                    \
  ValuePtr ParseAttrInScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto, int index) { \
    if (attr_proto.ints_size() > index) {                                                                 \
      auto value = static_cast<valuetype>(attr_proto.ints(index));                                        \
      return MakeValue<valuetype>(value);                                                                 \
    }                                                                                                     \
    MS_LOG(INTERNAL_EXCEPTION) << "Parse MindIR attr failed.";                                            \
  }                                                                                                       \
  ValuePtr ParseAttrInSingleScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto) {      \
    if (attr_proto.has_i()) {                                                                             \
      auto value = static_cast<valuetype>(attr_proto.i());                                                \
      return MakeValue<valuetype>(value);                                                                 \
    }                                                                                                     \
    MS_LOG(INTERNAL_EXCEPTION) << "Parse MindIR attr failed.";                                            \
  }

// Generates the indexed parser for attributes stored in a dedicated repeated
// field named after the type (e.g. `doubles`, `floats`, `strings`).
#define PARSE_MINDIR_ATTR_IN_SCALAR_FORM(type, valuetype)                                                 \
  ValuePtr ParseAttrInScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto, int index) { \
    if (attr_proto.type##s_size() > index) {                                                              \
      auto value = static_cast<valuetype>(attr_proto.type##s(index));                                     \
      return MakeValue<valuetype>(value);                                                                 \
    }                                                                                                     \
    MS_LOG(INTERNAL_EXCEPTION) << "Parse MindIR attr failed.";                                            \
  }
228 
// Instantiate scalar-attribute parsers for every numeric type MindIR stores
// in the proto's integer fields (bool is persisted as int32 by exporters).
PARSE_MINDIR_ATTR_IN_INT_FORM(int8_t, int8_t)

PARSE_MINDIR_ATTR_IN_INT_FORM(int16_t, int16_t)

PARSE_MINDIR_ATTR_IN_INT_FORM(int32_t, int32_t)

PARSE_MINDIR_ATTR_IN_INT_FORM(int64_t, int64_t)

PARSE_MINDIR_ATTR_IN_INT_FORM(uint8_t, uint8_t)

PARSE_MINDIR_ATTR_IN_INT_FORM(uint16_t, uint16_t)

PARSE_MINDIR_ATTR_IN_INT_FORM(uint32_t, uint32_t)

PARSE_MINDIR_ATTR_IN_INT_FORM(uint64_t, uint64_t)

PARSE_MINDIR_ATTR_IN_INT_FORM(int32_t, bool)

// Floating point and string attributes live in their own repeated fields.
PARSE_MINDIR_ATTR_IN_SCALAR_FORM(double, double)

PARSE_MINDIR_ATTR_IN_SCALAR_FORM(float, float)

PARSE_MINDIR_ATTR_IN_SCALAR_FORM(string, string)
252 
253 ValuePtr ParseAttrInSingleScalar_string_string(const mind_ir::AttributeProto &attr_proto) {
254   auto value = static_cast<string>(attr_proto.s());
255   return MakeValue<string>(value);
256 }
257 
ParseAttrInSingleScalar_float_float(const mind_ir::AttributeProto & attr_proto)258 ValuePtr ParseAttrInSingleScalar_float_float(const mind_ir::AttributeProto &attr_proto) {
259   auto value = static_cast<float>(attr_proto.f());
260   return MakeValue<float>(value);
261 }
262 
ParseAttrInSingleScalar_double_double(const mind_ir::AttributeProto & attr_proto)263 ValuePtr ParseAttrInSingleScalar_double_double(const mind_ir::AttributeProto &attr_proto) {
264   auto value = static_cast<double>(attr_proto.d());
265   return MakeValue<double>(value);
266 }
267 
GetParseFormType(const std::string & ref_attr_name)268 ParseForm GetParseFormType(const std::string &ref_attr_name) {
269   for (const auto &iter : kParseTypeSwitchMap) {
270     if (ref_attr_name.find(iter.first) == 0) {
271       return iter.second;
272     }
273   }
274   return FORM_PARSE_UNDEFINE;
275 }
276 
// Wraps `value` in a ValueNode and attaches the abstract derived from the
// value itself, so the node is usable without a separate inference pass.
template <typename T>
AnfNodePtr NewValueNodeWithAbstract(const T &value) {
  auto node = NewValueNode(value);
  node->set_abstract(value->ToAbstract());
  return node;
}
283 
FindGraphByName(const std::vector<FuncGraphPtr> & graphs,const std::string & name)284 FuncGraphPtr FindGraphByName(const std::vector<FuncGraphPtr> &graphs, const std::string &name) {
285   auto iter = std::find_if(graphs.begin(), graphs.end(), [&name](const auto &g) { return g->ToString() == name; });
286   if (iter != graphs.end()) {
287     return *iter;
288   }
289   return nullptr;
290 }
291 
// Validates the model-level metadata of a MindIR ModelProto before parsing:
//  - producer name and model version must be present (logged for diagnosis);
//  - the mind_ir version must be one this build supports (throws otherwise,
//    so the user gets an actionable upgrade message);
//  - the byte order recorded at export time must match the loading host.
// Returns false on any recoverable mismatch.
bool CheckModelConfigureInfo(const mind_ir::ModelProto &model_proto) {
  if (!model_proto.has_producer_name()) {
    MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
    return false;
  }
  const auto &producer_name = model_proto.producer_name();
  MS_LOG(INFO) << "Producer name: " << producer_name;

  if (!model_proto.has_model_version()) {
    MS_LOG(ERROR) << "Parse model producer version from pb file failed!";
    return false;
  }
  const auto &model_version = model_proto.model_version();
  MS_LOG(INFO) << "Producer version: " << model_version;

  // Models exported before the version field existed default to 0,
  // which Version_IsValid accepts.
  int64_t mind_ir_version = 0;
  if (model_proto.has_mind_ir_version()) {
    mind_ir_version = model_proto.mind_ir_version();
  }
  if (!mind_ir::Version_IsValid(mind_ir_version)) {
    MS_LOG(EXCEPTION) << "This software can only support the maximum mind ir version: " << mind_ir::Version_MAX
                      << ", please install the latest version to support the mind ir version: " << mind_ir_version;
  }
  // Cross-endian loading is not supported: raw tensor bytes would be wrong.
  if (model_proto.has_little_endian()) {
    if (model_proto.little_endian() != common::IsLittleByteOrder()) {
      MS_LOG(ERROR) << "The byte order of export MindIr device and load MindIr device is not same!";
      return false;
    }
  }
  return true;
}
323 }  // namespace
324 
325 namespace {
// Parses a MindIR ModelProto into FuncGraphs: rebuilds parameters, CNodes,
// value nodes and their abstracts, optionally decrypting externally stored
// tensor data. One instance corresponds to one Parse invocation's state.
class MSANFModelParser {
 public:
  MSANFModelParser() = default;
  ~MSANFModelParser() = default;

  static void LoadTensorMapClear();
  // Parses the proto into a fresh top graph; `weights` overrides parameter
  // values by name, and `name_to_node` (if given) receives the name->node map.
  FuncGraphPtr Parse(const mind_ir::ModelProto &model_proto, const std::map<std::string, ValuePtr> &weights = {},
                     mindspore::HashMap<std::string, AnfNodePtr> *name_to_node = nullptr);
  // Parses the proto into caller-provided (pre-created) graphs.
  bool Parse(const mind_ir::ModelProto &model_proto, const std::vector<FuncGraphPtr> &graphs,
             mindspore::HashMap<std::string, AnfNodePtr> *name_to_node = nullptr);
  // Extracts the parallel layout information stored alongside the model.
  const LayoutMap ParseLayout(const mind_ir::ModelProto &model_proto);

  // Configuration setters, to be called before Parse.
  void SetLite() { is_lite_ = true; }
  bool IsLite() const { return is_lite_; }
  void SetMindIRPath(const std::string &file_path) { mindir_path_ = file_path; }
  void SetMindIRDecKey(const unsigned char *dec_key) { mindir_dec_key_ = dec_key; }
  void SetMindIRKeySize(size_t size) { mindir_key_size_ = size; }
  void SetMindIRDecMode(const std::string &dec_mode) { mindir_dec_mode_ = dec_mode; }

 private:
  // Deferred abstract construction for CNodes collected during graph build.
  void TrytoBuildCNodeAbstract();
  bool BuildPrimitiveNode(const mind_ir::PrimitiveProto &primitive_proto);
  abstract::AbstractBasePtr BuildAbstractFunction(const mind_ir::AttributeProto &attr_proto);
  void CorrectFuncGraph(const FuncGraphPtr &root);
  // Graph-level builders: parameters, attributes, nodes, return node.
  bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
  bool BuildAttrForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
  bool BuildAttrForCNode(const CNodePtr &cnode, const mind_ir::NodeProto &node_proto);
  ValuePtr GetValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto);
  bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
  bool ImportMapParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
  bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
  bool BuildParameterForFuncGraph(const ParameterPtr &node, const mind_ir::TensorProto &parameter_proto);
  bool BuildMapParameterFromMapTensorProto(const ParameterPtr &node,
                                           const mind_ir::MapTensorProto &map_parameter_proto);
  // Abstract builders for the various serialized abstract kinds.
  abstract::AbstractMapTensorPtr BuildAbstractMapTensorFromAttrProto(const mind_ir::AttributeProto &attr_proto);
  abstract::AbstractCOOTensorPtr BuildAbstractCOOTensorFromAttrProto(const mind_ir::AttributeProto &attr_proto);
  abstract::AbstractCSRTensorPtr BuildAbstractCSRTensorFromAttrProto(const mind_ir::AttributeProto &attr_proto);
  abstract::AbstractSequencePtr BuildAbstractSequence(const mind_ir::AttributeProto &attr_proto);
  abstract::AbstractScalarPtr BuildAbstractScalar(const mind_ir::AttributeProto &attr_proto) const;
  bool SetValueForTopGraphParameter(const FuncGraphPtr &topGraph, const std::map<std::string, ValuePtr> &weights);
  // Loads (and, if configured, decrypts) tensor payloads stored outside the pb.
  bool GetTensorDataFromExternal(const mind_ir::TensorProto &tensor_proto, const tensor::TensorPtr &tensor_info);
  bool BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto);
  abstract::AbstractTensorPtr GetAbsTensorFromTensorProto(const mind_ir::TensorProto &tensor_proto);
  CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::NodeProto &node_proto);
  bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
  // Attribute-restoration helpers for primitives and value nodes.
  bool GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
  bool SetPrimitiveAttrWithType(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
  bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
  void ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto &attr_proto,
                                   mindspore::HashMap<std::string, ValuePtr> *multi_value_map);
  ValuePtr ParseAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, int index);
  ValuePtr ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto);
  bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
  bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto);
  ValuePtr BuildValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto);
  AnfNodePtr BuildOperatorNode(const mind_ir::NodeProto &node_proto);
  bool SetEmptyTensorProtoCNodeAbstract(const AnfNodePtr &node_ptr);
  void SetCNodeAbstract(const mind_ir::AttributeProto &attr_proto, const CNodePtr &cnode_ptr);
  bool SetNodeAbstractFromAttrProto(const mind_ir::AttributeProto &attr_proto, const AnfNodePtr &node_ptr);
  abstract::AbstractBasePtr GetNodeAbstractFromAttrProtoWithType(const mind_ir::AttributeProto &attr_proto);
  void SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto &node_proto, const CNodePtr &cnode_ptr);
  bool ObtainValueNodeInTensorForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
  bool ObtainValueNodeInTupleTensorForm(const string &value_node_name, const mind_ir::AttributeProto &attr_proto);
  bool GetAttrValueForValueNode(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
  bool GetAttrValueForValueNodeWithType(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
  bool ObtainValueNodeInTypeForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
  bool ObtainValueNodeInNoneForm(const std::string &value_node_name);
  bool ObtainValueNodeInMonadForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
  ValuePtr ObtainValueInSequenceForm(const mind_ir::AttributeProto &attr_proto);
  ValuePtr ObtainValueInDictionaryForm(const mind_ir::AttributeProto &attr_proto);
  std::vector<std::shared_ptr<mindspore::QuantizationParam>> GenerateQuantizationParam(
    const mind_ir::TensorProto &attr_tensor);
  FunctorPtr GenerateFunctorValue(const mind_ir::FunctorProto &functor_proto);
  bool little_endian() const { return little_endian_; }
  mindspore::HashMap<std::string, abstract::AbstractBasePtr> GetAbstractForNode(
    const mind_ir::AttributeProto &attr_proto);
  AnfNodePtr GetAnfNode(const std::string &node_name);
  tensor::TensorPtr GenerateTensorPtrFromTensorProto(const mind_ir::TensorProto &attr_tensor);

  // Incremental-load tensor cache shared across parser instances.
  static tensor::TensorPtr GetIncTensor(const std::string &tensor_name);
  static void SetIncTensor(const std::string &tensor_name, const tensor::TensorPtr &tensor);

  FuncGraphPtr top_graph_ = nullptr;
  bool is_lite_ = false;
  bool abstract_valid_ = false;
  // Name -> node map built while importing; used to resolve node references.
  mindspore::HashMap<std::string, AnfNodePtr> anfnode_build_map_;
  std::string mindir_path_;
  // Decryption settings for externally stored tensor data (key not owned).
  const unsigned char *mindir_dec_key_{nullptr};
  size_t mindir_key_size_{0};
  std::string mindir_dec_mode_;
  bool little_endian_ = common::IsLittleByteOrder();
  // NOTE(review): member name is misspelled ("tenor" vs "tensor"); kept as-is
  // because renaming would touch code outside this view.
  std::map<std::string, std::unique_ptr<Byte[]>> tenor_data_;
  bool is_kernel_graph_{false};
  // CNodes whose abstracts are built later by TrytoBuildCNodeAbstract.
  std::list<std::pair<const CNodePtr, const mind_ir::AttributeProto *>> node_abstract_protos_;
};
421 
// Converts an AttributeProto into a MindSpore Value, dispatching on the
// serialized attribute type. Returns nullptr (with an error log) when the
// proto cannot be interpreted.
ValuePtr MSANFModelParser::GetValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto) {
  auto attr_name = attr_proto.name();
  switch (attr_proto.type()) {
    case mind_ir::AttributeProto_AttributeType_TENSORS: {
      // TENSORS is overloaded: it may carry a real tensor (raw data), a
      // graph-input quantization parameter, or just a data type.
      mind_ir::TensorProto tensor_proto = attr_proto.tensors(0);
      if (tensor_proto.has_raw_data()) {
        // For real tensor.
        tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto);
        if (tensor_info == nullptr) {
          MS_LOG(ERROR) << "Failed to get the tensor for ValueNode.";
          return nullptr;
        }
        return tensor_info;
      } else if (tensor_proto.name() == kGraphInputQuantParam) {
        auto quantization_param_vector = GenerateQuantizationParam(tensor_proto);
        if (!quantization_param_vector.empty()) {
          return quantization_param_vector[0];
        }
        // Empty quant params fall through to the error below.
      } else {
        // For data type.
        const int attr_tensor_type = tensor_proto.data_type();
        auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
        if (iter == kDefaultValueSwitchMap.end()) {
          MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
          return nullptr;
        }
        return TypeIdToType(iter->second);
      }
      MS_LOG(ERROR) << "Failed to get the tensor for value.";
      return nullptr;
    }
    case mind_ir::AttributeProto_AttributeType_NONE: {
      return kNone;
    }
    case mind_ir::AttributeProto_AttributeType_TUPLE:
    case mind_ir::AttributeProto_AttributeType_LIST: {
      auto sequence_value = ObtainValueInSequenceForm(attr_proto);
      if (sequence_value == nullptr) {
        MS_LOG(ERROR) << "Failed to get sequence value for " << attr_name;
        return nullptr;
      }
      return sequence_value;
    }
    case mind_ir::AttributeProto_AttributeType_DICT: {
      auto dict_value = ObtainValueInDictionaryForm(attr_proto);
      if (dict_value == nullptr) {
        MS_LOG(ERROR) << "Failed to get dictionary value for " << attr_name;
        return nullptr;
      }
      return dict_value;
    }
    case mind_ir::AttributeProto_AttributeType_FUNCTOR: {
      auto functor_value = GenerateFunctorValue(attr_proto.functor());
      if (functor_value == nullptr) {
        MS_LOG(ERROR) << "Failed to get functor value for " << attr_name;
        return nullptr;
      }
      return functor_value;
    }
    default: {
      // Everything else is expected to be a single scalar (int/float/etc.).
      ValuePtr value = ObtainCNodeAttrInSingleScalarForm(attr_proto);
      if (value == nullptr) {
        MS_LOG(ERROR) << "Can not get the value for attr: " << attr_name;
        return nullptr;
      }
      return value;
    }
  }
}
491 
GenerateFunctorValue(const mind_ir::FunctorProto & functor_proto)492 FunctorPtr MSANFModelParser::GenerateFunctorValue(const mind_ir::FunctorProto &functor_proto) {
493   auto name = functor_proto.name();
494   auto type = functor_proto.type();
495   auto values = GetValueFromAttributeProto(functor_proto.values(0));
496   if (type == mind_ir::FunctorProto_FunctorType_SHAPE_CALC_FUNCTOR) {
497     auto creator = FunctorRegistry::Instance().GetCreator(name);
498     if (creator == nullptr) {
499       MS_LOG(ERROR) << "Cannot find the functor creator: " << name;
500       return nullptr;
501     }
502     auto functor = creator();
503     functor->FromValue(values);
504     return functor;
505   }
506   MS_LOG(ERROR) << "Unknown functor type: " << type;
507   return nullptr;
508 }
509 
GenerateTensorPtrFromTensorProto(const mind_ir::TensorProto & attr_tensor)510 tensor::TensorPtr MSANFModelParser::GenerateTensorPtrFromTensorProto(const mind_ir::TensorProto &attr_tensor) {
511   ShapeVector shape;
512   const int attr_tensor_type = attr_tensor.data_type();
513   for (int i = 0; i < attr_tensor.dims_size(); ++i) {
514     shape.push_back(attr_tensor.dims(i));
515   }
516   tensor::TensorPtr tensor = nullptr;
517   if (!attr_tensor.has_compression_type() ||
518       attr_tensor.compression_type() == mind_ir::TensorProto_CompressionType_NO_COMPRESSION) {
519     tensor = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
520   } else {
521     auto compression_type = static_cast<TensorCompressionType>(static_cast<int>(attr_tensor.compression_type()));
522     size_t data_size = 0;
523     if (!attr_tensor.has_external_data()) {
524       data_size = attr_tensor.raw_data().size();
525     } else {
526       data_size = LongToSize(attr_tensor.external_data().length());
527     }
528     tensor =
529       std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape, data_size, compression_type);
530   }
531 
532   auto quantization_param_vector = GenerateQuantizationParam(attr_tensor);
533   if (!quantization_param_vector.empty()) {
534     tensor->set_quant_param(quantization_param_vector);
535   }
536 
537   MS_EXCEPTION_IF_NULL(tensor);
538   const std::string &tensor_buf = attr_tensor.raw_data();
539   if (attr_tensor.has_raw_data() && tensor->data().nbytes() != 0) {
540     auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor->data_c());
541     errno_t ret = memcpy_s(tensor_data_buf, tensor->data().nbytes(), tensor_buf.data(), tensor_buf.size());
542     if (ret != EOK) {
543       MS_LOG(ERROR) << "Failed to copy data from tensor proto.";
544       return nullptr;
545     }
546   } else if (attr_tensor.has_external_data()) {
547     auto ret = GetTensorDataFromExternal(attr_tensor, tensor);
548     if (!ret) {
549       MS_LOG(ERROR) << "Failed to get external data from tensor proto.";
550       return nullptr;
551     }
552   } else {
553     MS_LOG(DEBUG) << "Parameter will load initialized data.";
554   }
555   return tensor;
556 }
557 
// Rebuilds the per-tensor quantization parameters serialized alongside a
// TensorProto. Each quant_params entry becomes one QuantizationParam whose
// attributes are restored from LIST-typed attribute protos. Returns an empty
// vector when the tensor has no quant params or an entry is malformed.
std::vector<std::shared_ptr<mindspore::QuantizationParam>> MSANFModelParser::GenerateQuantizationParam(
  const mind_ir::TensorProto &attr_tensor) {
  auto quant_param_proto = attr_tensor.quant_params();
  std::vector<std::shared_ptr<mindspore::QuantizationParam>> quantization_param_vector;
  for (int i = 0; i < quant_param_proto.size(); i++) {
    auto quant_data = quant_param_proto.Get(i);
    QuantizationParam quantization_param(quant_data.quant_algo_name());
    for (int index = 0; index < quant_data.attribute_size(); index++) {
      auto quant_attr_proto = quant_data.attribute().Get(index);
      // Quant attributes are always serialized as LIST; anything else means
      // the proto is corrupt, so the whole result is discarded.
      if (quant_attr_proto.type() != mind_ir::AttributeProto_AttributeType_LIST) {
        MS_LOG(ERROR) << "quant_attr_proto.type is " << quant_attr_proto.type()
                      << ", is should be mind_ir::AttributeProto_AttributeType_LIST ("
                      << mind_ir::AttributeProto_AttributeType_LIST << ")";
        return {};
      }
      // NOTE(review): ObtainValueInSequenceForm may return nullptr; that
      // nullptr is stored via SetAttr unchecked — confirm downstream readers
      // tolerate a null attribute value.
      auto sequence_value = ObtainValueInSequenceForm(quant_attr_proto);
      quantization_param.SetAttr(quant_attr_proto.name(), sequence_value);
    }
    quantization_param_vector.push_back(std::make_shared<mindspore::QuantizationParam>(quantization_param));
  }
  return quantization_param_vector;
}
580 
// Rebuilds a node's abstract (shape/type information) from a typed
// AttributeProto, dispatching on the serialized abstract kind. Returns
// nullptr for kinds this parser does not support.
abstract::AbstractBasePtr MSANFModelParser::GetNodeAbstractFromAttrProtoWithType(
  const mind_ir::AttributeProto &attr_proto) {
  switch (attr_proto.type()) {
    case mind_ir::AttributeProto_AttributeType_TENSORS: {
      // Plain tensor: only the first tensors entry carries the abstract.
      const mind_ir::TensorProto &attr_tensor = attr_proto.tensors(0);
      return GetAbsTensorFromTensorProto(attr_tensor);
    }
    case mind_ir::AttributeProto_AttributeType_CSR_TENSOR: {
      return BuildAbstractCSRTensorFromAttrProto(attr_proto);
    }
    case mind_ir::AttributeProto_AttributeType_COO_TENSOR: {
      return BuildAbstractCOOTensorFromAttrProto(attr_proto);
    }
    case mind_ir::AttributeProto_AttributeType_MAP_TENSOR: {
      return BuildAbstractMapTensorFromAttrProto(attr_proto);
    }
    case mind_ir::AttributeProto_AttributeType_LIST:
    case mind_ir::AttributeProto_AttributeType_TUPLE: {
      return BuildAbstractSequence(attr_proto);
    }
    case mind_ir::AttributeProto_AttributeType_UMONAD: {
      return kUMonad->ToAbstract();
    }
    case mind_ir::AttributeProto_AttributeType_IOMONAD: {
      return kIOMonad->ToAbstract();
    }
    // in old version the bool is load and export in an error type.
    // but MindIR should be Compatible with older versions.
    case mind_ir::AttributeProto_AttributeType_BOOL: {
      return kBool->ToAbstract();
    }
    case mind_ir::AttributeProto_AttributeType_SCALAR: {
      return BuildAbstractScalar(attr_proto);
    }
    case mind_ir::AttributeProto_AttributeType_NONE: {
      return kNone->ToAbstract();
    }
    // All closure kinds share one builder that reconstructs the function
    // abstract (func-graph, primitive, partial, or union).
    case mind_ir::AttributeProto_AttributeType_FUNCGRAPHCLOSURE:
    case mind_ir::AttributeProto_AttributeType_PRIMITIVECLOSURE:
    case mind_ir::AttributeProto_AttributeType_PARTIALCLOSURE:
    case mind_ir::AttributeProto_AttributeType_UNIONFUNCCLOSURE: {
      return BuildAbstractFunction(attr_proto);
    }
    default: {
      // Unsupported kinds are tolerated: caller decides whether a missing
      // abstract is fatal.
      MS_LOG(INFO) << "Not support to get the abstract from AttrProto type: " << attr_proto.type();
      return nullptr;
    }
  }
}
630 
SetNodeAbstractFromAttrProto(const mind_ir::AttributeProto & attr_proto,const AnfNodePtr & node_ptr)631 bool MSANFModelParser::SetNodeAbstractFromAttrProto(const mind_ir::AttributeProto &attr_proto,
632                                                     const AnfNodePtr &node_ptr) {
633   mindspore::HashMap<std::string, abstract::AbstractBasePtr> kv;
634   string shape_ref_attr_name;
635   if (attr_proto.ref_attr_name().find("shape:") == string::npos) {
636     MS_LOG(ERROR) << "Cannot use a attr_proto " << attr_proto.ref_attr_name() << " to init shape.";
637     return false;
638   }
639 
640   shape_ref_attr_name = attr_proto.ref_attr_name();
641   bool is_tuple_or_list =
642     shape_ref_attr_name.find("Tuple[") != string::npos || shape_ref_attr_name.find("List[") != string::npos;
643   kv = GetAbstractForNode(attr_proto);
644   if (kv.empty()) {
645     return SetEmptyTensorProtoCNodeAbstract(node_ptr);
646   } else if (!is_tuple_or_list) {
647     auto iter = kv.begin();
648     if (iter->second != nullptr) {
649       node_ptr->set_abstract(iter->second);
650     }
651   } else {
652     auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
653     node_ptr->set_abstract(abstract);
654     if (abstract == nullptr) {
655       MS_LOG(ERROR) << "Node's attribute is nullptr.";
656       return false;
657     }
658   }
659   return true;
660 }
661 
SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto & node_proto,const CNodePtr & cnode_ptr)662 void MSANFModelParser::SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto &node_proto, const CNodePtr &cnode_ptr) {
663   auto prim_to_add_attr = GetCNodePrimitiveWithoutDoSignature(cnode_ptr);
664   if (prim_to_add_attr != nullptr) {
665     prim_to_add_attr->set_attr("is_load", MakeValue(true));
666   }
667   for (int i = 0; i < node_proto.attribute_size(); ++i) {
668     const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
669     // Compatible with older versions.
670     if (attr_proto.has_ref_attr_name()) {
671       if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
672         SetCNodeAbstract(attr_proto, cnode_ptr);
673         continue;
674       }
675       if (prim_to_add_attr != nullptr && !GetAttrValueForCNode(prim_to_add_attr, attr_proto)) {
676         MS_LOG(ERROR) << "Parse prim: " << prim_to_add_attr->ToString()
677                       << ", attributes error: " << attr_proto.DebugString();
678       }
679     } else {
680       // ref_attr_name is removed in newer versions.
681       if (attr_proto.name() == "shape") {
682         SetCNodeAbstract(attr_proto, cnode_ptr);
683         continue;
684       }
685       if (prim_to_add_attr != nullptr && !SetPrimitiveAttrWithType(prim_to_add_attr, attr_proto)) {
686         MS_LOG(ERROR) << "Parse prim: " << prim_to_add_attr->ToString()
687                       << ", attributes error: " << attr_proto.DebugString();
688       }
689     }
690   }
691 }
692 
GetAbsTensorFromTensorProto(const mind_ir::TensorProto & tensor_proto)693 abstract::AbstractTensorPtr MSANFModelParser::GetAbsTensorFromTensorProto(const mind_ir::TensorProto &tensor_proto) {
694   ShapeVector shape;
695   for (int i = 0; i < tensor_proto.dims_size(); ++i) {
696     (void)shape.emplace_back(tensor_proto.dims(i));
697   }
698 
699   if (!tensor_proto.has_data_type()) {
700     MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed";
701     MS_LOG(ERROR) << "mind_ir TensorProto has no data_type.";
702     return nullptr;
703   }
704   auto iter = kDefaultValueSwitchMap.find(tensor_proto.data_type());
705   if (iter == kDefaultValueSwitchMap.end()) {
706     MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed";
707     MS_LOG(ERROR) << "mind_ir TensorProto data_type: " << tensor_proto.data_type() << " is not support yet!";
708     return nullptr;
709   }
710   auto tensor_shape = std::make_shared<abstract::Shape>(shape);
711   auto tensor_info = std::make_shared<abstract::AbstractTensor>(TypeIdToType(iter->second), tensor_shape);
712   if (tensor_proto.has_ref_key()) {
713     auto ref_key = std::make_shared<RefKey>(tensor_proto.ref_key());
714     auto abs_ref = std::make_shared<abstract::AbstractRefTensor>(tensor_info, ref_key);
715     return abs_ref;
716   }
717   if (tensor_proto.has_name()) {
718     tensor_info->set_name(tensor_proto.name());
719   }
720   return tensor_info;
721 }
722 
BuildParameterForFuncGraph(const ParameterPtr & node,const mind_ir::TensorProto & parameter_proto)723 bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
724                                                   const mind_ir::TensorProto &parameter_proto) {
725   MS_EXCEPTION_IF_NULL(node);
726 
727   if (!parameter_proto.has_name()) {
728     MS_LOG(ERROR) << "mind_ir TensorProto has no name!";
729     return false;
730   }
731   const auto &unique_name = parameter_proto.name();
732   string debug_info_name = ParseParameterName(unique_name);
733   auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
734   node->set_debug_info(debug_info_ptr);
735   node->set_name(debug_info_name);
736 
737   ParamInfoPtr param_info = std::make_shared<ParamInfo>();
738   param_info->set_name(debug_info_name);
739 
740   MS_LOG(DEBUG) << "Load parameter name: " << unique_name;
741   auto tensor = GenerateTensorPtrFromTensorProto(parameter_proto);
742   if (tensor == nullptr) {
743     MS_LOG(ERROR) << "Build tensor failed from the parameter proto.";
744     return false;
745   }
746   tensor->set_param_info(param_info);
747   node->set_default_param(tensor);
748   node->set_abstract(tensor->ToAbstract());
749 
750   anfnode_build_map_[parameter_proto.name()] = node;
751   return true;
752 }
753 
BuildAbstractCOOTensorFromAttrProto(const mind_ir::AttributeProto & attr_proto)754 abstract::AbstractCOOTensorPtr MSANFModelParser::BuildAbstractCOOTensorFromAttrProto(
755   const mind_ir::AttributeProto &attr_proto) {
756   std::vector<abstract::AbstractBasePtr> vec;
757   for (int i = 0; i < attr_proto.values_size(); ++i) {
758     auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto.values(i));
759     if (abs == nullptr) {
760       MS_LOG(WARNING) << "Failed to get the COOTensor's abstract from AttrProto. " << attr_proto.DebugString();
761       return nullptr;
762     }
763     (void)vec.emplace_back(abs);
764   }
765   return std::make_shared<abstract::AbstractCOOTensor>(vec);
766 }
767 
BuildAbstractCSRTensorFromAttrProto(const mind_ir::AttributeProto & attr_proto)768 abstract::AbstractCSRTensorPtr MSANFModelParser::BuildAbstractCSRTensorFromAttrProto(
769   const mind_ir::AttributeProto &attr_proto) {
770   std::vector<abstract::AbstractBasePtr> vec;
771   for (int i = 0; i < attr_proto.values_size(); ++i) {
772     auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto.values(i));
773     if (abs == nullptr) {
774       MS_LOG(WARNING) << "Failed to get the CSRTensor's abstract from AttrProto. " << attr_proto.DebugString();
775       return nullptr;
776     }
777     (void)vec.emplace_back(abs);
778   }
779   return std::make_shared<abstract::AbstractCSRTensor>(vec);
780 }
781 
// Builds an AbstractMapTensor from an AttributeProto that carries one default
// value plus exactly two tensors (key tensor, value tensor). Any structural
// violation of that layout raises an internal exception rather than returning
// nullptr, since it indicates a corrupted/incompatible MindIR file.
abstract::AbstractMapTensorPtr MSANFModelParser::BuildAbstractMapTensorFromAttrProto(
  const mind_ir::AttributeProto &attr_proto) {
  // default value
  if (attr_proto.values_size() != 1) {
    MS_LOG(INTERNAL_EXCEPTION) << "AttrProto for AbstractMapTensor should has 1 value, but got "
                               << attr_proto.values_size();
  }
  const auto &default_value_proto = attr_proto.values(0);
  auto default_value = ObtainCNodeAttrInSingleScalarForm(default_value_proto);
  MS_EXCEPTION_IF_NULL(default_value);

  constexpr int kAbstractMapTensorAttrProtoTensorsSize = 2;
  if (attr_proto.tensors_size() != kAbstractMapTensorAttrProtoTensorsSize) {
    MS_LOG(INTERNAL_EXCEPTION) << "AttrProto for AbstractMapTensor should has 2 tensors, but got "
                               << attr_proto.tensors_size();
  }
  // key tensor
  const auto &key_tensor_proto = attr_proto.tensors(0);
  auto key_tensor_abs = GetAbsTensorFromTensorProto(key_tensor_proto);
  MS_EXCEPTION_IF_NULL(key_tensor_abs);
  // value tensor
  const auto &value_tensor_proto = attr_proto.tensors(1);
  auto value_tensor_abs = GetAbsTensorFromTensorProto(value_tensor_proto);
  MS_EXCEPTION_IF_NULL(value_tensor_abs);
  // The value shape must be a concrete static Shape; other BaseShape kinds
  // (e.g. dynamic/no-shape) are not representable in a MapTensor.
  auto value_build_shape_ptr = value_tensor_abs->BuildShape();
  if (!value_build_shape_ptr->isa<abstract::Shape>()) {
    MS_LOG(INTERNAL_EXCEPTION) << "value_shape of AbstractMapTensor should be a Shape, but got "
                               << value_build_shape_ptr->ToString();
  }
  auto value_shape_ptr = value_build_shape_ptr->cast<abstract::ShapePtr>();
  MS_EXCEPTION_IF_NULL(value_shape_ptr);
  // Assemble the concrete MapTensor (key type, value type, value shape, default)
  // and wrap it in its abstract.
  auto map_tensor = std::make_shared<tensor::MapTensor>(key_tensor_abs->BuildType()->type_id(),
                                                        value_tensor_abs->BuildType()->type_id(),
                                                        value_shape_ptr->shape(), default_value);
  return std::make_shared<abstract::AbstractMapTensor>(map_tensor);
}
818 
BuildAbstractScalar(const mind_ir::AttributeProto & attr_proto) const819 abstract::AbstractScalarPtr MSANFModelParser::BuildAbstractScalar(const mind_ir::AttributeProto &attr_proto) const {
820   const mind_ir::TensorProto &attr_tensor = attr_proto.tensors(0);
821   auto iter = kDefaultValueSwitchMap.find(attr_tensor.data_type());
822   if (iter == kDefaultValueSwitchMap.end()) {
823     MS_LOG(ERROR) << "mind_ir build tensor: " << attr_tensor.name() << " failed";
824     MS_LOG(ERROR) << "mind_ir TensorProto data_type: " << attr_tensor.data_type() << " is not support yet!";
825     return nullptr;
826   }
827   return std::make_shared<abstract::AbstractScalar>(TypeIdToType(iter->second));
828 }
829 
BuildAbstractSequence(const mind_ir::AttributeProto & attr_proto)830 abstract::AbstractSequencePtr MSANFModelParser::BuildAbstractSequence(const mind_ir::AttributeProto &attr_proto) {
831   std::vector<abstract::AbstractBasePtr> vec;
832 
833   for (int i = 0; i < attr_proto.values_size(); ++i) {
834     auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto.values(i));
835     if (abs == nullptr) {
836       MS_LOG(WARNING) << "Failed to get the tuple's abstract from AttrProto. " << attr_proto.DebugString();
837       return nullptr;
838     }
839     (void)vec.emplace_back(abs);
840   }
841   abstract::AbstractSequencePtr seq_abs;
842   if (attr_proto.type() == mind_ir::AttributeProto_AttributeType_TUPLE) {
843     seq_abs = std::make_shared<abstract::AbstractTuple>(vec);
844   } else {
845     seq_abs = std::make_shared<abstract::AbstractList>(vec);
846   }
847   if (attr_proto.has_seq_info()) {
848     auto seq_info = attr_proto.seq_info();
849     seq_abs->set_dynamic_len(seq_info.is_dyn_len());
850     if (seq_info.has_tuple_elem_item()) {
851       auto elem_proto = seq_info.tuple_elem_item();
852       auto elem_abs = GetNodeAbstractFromAttrProtoWithType(elem_proto);
853       seq_abs->set_dynamic_len_element_abs(elem_abs);
854     }
855   }
856   return seq_abs;
857 }
858 
// Builds a MapTensor-backed graph parameter from a MapTensorProto. The proto
// must supply a name, a default value, and key/value/status tensors; a missing
// or unparsable piece logs an error and returns false. On success the node is
// registered in anfnode_build_map_ under the proto's (un-parsed) name.
bool MSANFModelParser::BuildMapParameterFromMapTensorProto(const ParameterPtr &node,
                                                           const mind_ir::MapTensorProto &map_parameter_proto) {
  MS_EXCEPTION_IF_NULL(node);

  if (!map_parameter_proto.has_name()) {
    MS_LOG(ERROR) << "mind_ir MapTensorProto has no name!";
    return false;
  }

  // Human-readable name (derived from the proto name) is used for both debug
  // info and the parameter's own name.
  string debug_info_name = ParseParameterName(map_parameter_proto.name());
  auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
  node->set_debug_info(debug_info_ptr);
  node->set_name(debug_info_name);

  ParamInfoPtr param_info = std::make_shared<ParamInfo>();
  param_info->set_name(debug_info_name);

  MS_LOG(DEBUG) << "Load map parameter name: " << map_parameter_proto.name();
  // default value
  if (!map_parameter_proto.has_default_value()) {
    MS_LOG(ERROR) << "MapTensorProto should have default value: " << map_parameter_proto.name();
    return false;
  }
  const auto &default_value_proto = map_parameter_proto.default_value();
  auto default_value = BuildValueFromAttributeProto(default_value_proto);
  if (default_value == nullptr) {
    MS_LOG(ERROR) << "Build default value from AttributeProto failed.";
    return false;
  }
  // key tensor
  if (!map_parameter_proto.has_key_tensor()) {
    MS_LOG(ERROR) << "MapTensorProto should have key tensor: " << map_parameter_proto.name();
    return false;
  }
  const auto &key_tensor_proto = map_parameter_proto.key_tensor();
  auto key_tensor = GenerateTensorPtrFromTensorProto(key_tensor_proto);
  if (key_tensor == nullptr) {
    MS_LOG(ERROR) << "Generate key tensor from TensorProto failed.";
    return false;
  }
  // value tensor
  if (!map_parameter_proto.has_value_tensor()) {
    MS_LOG(ERROR) << "MapTensorProto should have value tensor: " << map_parameter_proto.name();
    return false;
  }
  const auto &value_tensor_proto = map_parameter_proto.value_tensor();
  auto value_tensor = GenerateTensorPtrFromTensorProto(value_tensor_proto);
  if (value_tensor == nullptr) {
    MS_LOG(ERROR) << "Generate value tensor from TensorProto failed.";
    return false;
  }
  // status tensor
  if (!map_parameter_proto.has_status_tensor()) {
    MS_LOG(ERROR) << "MapTensorProto should have status tensor: " << map_parameter_proto.name();
    return false;
  }
  const auto &status_tensor_proto = map_parameter_proto.status_tensor();
  auto status_tensor = GenerateTensorPtrFromTensorProto(status_tensor_proto);
  if (status_tensor == nullptr) {
    MS_LOG(ERROR) << "Generate status tensor from TensorProto failed.";
    return false;
  }

  // Assemble the MapTensor and install it as the parameter's default value and
  // abstract, then register the node for later lookups by proto name.
  auto map_tensor = std::make_shared<tensor::MapTensor>(key_tensor, value_tensor, status_tensor, default_value);
  map_tensor->set_param_info(param_info);
  node->set_default_param(map_tensor);
  node->set_abstract(map_tensor->ToAbstract());

  anfnode_build_map_[map_parameter_proto.name()] = node;
  return true;
}
930 
// Copies a tensor's payload from an external-data file referenced by the proto
// into tensor_info's buffer. File contents are cached per "location" in
// tenor_data_ (member name sic) so each external file is read/decrypted once.
// Returns false when the proto has no external data or when reading, byte-order
// validation, or the final copy fails.
bool MSANFModelParser::GetTensorDataFromExternal(const mind_ir::TensorProto &tensor_proto,
                                                 const tensor::TensorPtr &tensor_info) {
  if (!tensor_proto.has_external_data()) {
    return false;
  }
  const unsigned char *data = nullptr;
  // Cache hit: reuse the previously loaded (and possibly decrypted) file bytes.
  auto it = tenor_data_.find(tensor_proto.external_data().location());
  if (it != tenor_data_.end()) {
    data = it->second.get();
  } else {
    std::string file = mindir_path_ + "/" + tensor_proto.external_data().location();
    if (mindir_dec_key_ != nullptr) {
      // Encrypted MindIR: decrypt the whole file into memory and cache it.
      size_t plain_len;
      auto plain_data = Decrypt(&plain_len, file, mindir_dec_key_, mindir_key_size_, mindir_dec_mode_);
      if (plain_data == nullptr) {
        MS_LOG(ERROR) << "Decrypt MindIR file failed, please check the correctness of the dec_key or dec_mode.";
        return false;
      }
      data = plain_data.get();
      (void)tenor_data_.emplace(tensor_proto.external_data().location(), std::move(plain_data));
    } else {
      // Read file
      std::basic_ifstream<char> fid(file, std::ios::in | std::ios::binary);
      if (!fid) {
        MS_LOG(EXCEPTION) << "Open file '" << file << "' failed, please check the correct of the file.";
      }
      // Determine size by seeking to the end, then rewind for the actual read.
      (void)fid.seekg(0, std::ios_base::end);
      size_t file_size = static_cast<size_t>(fid.tellg());
      fid.clear();
      (void)fid.seekg(0);
      std::unique_ptr<char[]> plain_data(new (std::nothrow) char[file_size]);
      if (plain_data == nullptr) {
        MS_LOG(ERROR) << "Failed to create file buffer, file size: " << file_size << " bytes";
        return false;
      }
      // Byte 0 of the external file encodes the exporter's endianness (1 = little).
      constexpr Byte is_little_endian = 1;
      constexpr int byte_order_index = 0;
      (void)fid.read(plain_data.get(), SizeToLong(file_size));
      fid.close();
      // if byte order is not same return false
      if ((plain_data[byte_order_index] == is_little_endian) ^ little_endian()) {
        MS_LOG(ERROR) << "The byte order of export MindIr device and load MindIr device is not same!";
        return false;
      }
      data = reinterpret_cast<const unsigned char *>(plain_data.get());
      // Transfer ownership of the buffer into the cache; `data` keeps pointing at it.
      (void)tenor_data_.emplace(tensor_proto.external_data().location(),
                                std::unique_ptr<Byte[]>(reinterpret_cast<Byte *>(plain_data.release())));
    }
  }
  auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
  MS_EXCEPTION_IF_NULL(tensor_data_buf);
  MS_EXCEPTION_IF_NULL(data);

  if (tensor_info->data().nbytes() == 0 || tensor_proto.external_data().length() == 0) {
    // no need to copy data
    return true;
  }

  // Copy `length` bytes starting at the proto's offset within the cached file.
  auto ret =
    common::huge_memcpy(tensor_data_buf, tensor_info->data().nbytes(), data + tensor_proto.external_data().offset(),
                        LongToSize(tensor_proto.external_data().length()));
  if (ret != EOK) {
    MS_LOG(ERROR) << "Build parameter occur memcpy_s error.";
    return false;
  }
  return true;
}
998 
// Builds a graph input parameter from a ValueInfoProto: sets name/debug info,
// derives the abstract from the first tensor (or a monad denotation, or the
// attached attr_info), and — for ref-keyed tensors — links the input to the
// matching top-graph parameter's default value. Always registers the node in
// anfnode_build_map_; only a missing proto name returns false.
bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto) {
  MS_EXCEPTION_IF_NULL(node);

  if (!value_proto.has_name()) {
    MS_LOG(ERROR) << "mind_ir ValueInfoProto has no name!";
    return false;
  }
  string debug_info_name = ParseParameterName(value_proto.name());
  auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
  node->set_debug_info(debug_info_ptr);
  node->set_name(debug_info_name);

  // Set abstract of the parameter
  if (value_proto.tensor_size() > 0) {
    const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
    auto tensor_info = GetAbsTensorFromTensorProto(tensor_proto);
    if (tensor_info == nullptr) {
      MS_LOG(ERROR) << "Get tensor_info fail.";
      return false;
    }
    node->set_abstract(tensor_info);
    // Ref-keyed inputs share their default value with the top graph's parameter
    // that carries the same ref key.
    if (tensor_proto.has_ref_key() && top_graph_ != nullptr) {
      auto parameters = top_graph_->parameters();
      for (const auto &parameter : parameters) {
        auto parameter_abs = parameter->abstract();
        if (parameter_abs->isa<abstract::AbstractRefTensor>()) {
          // NOTE(review): ref_key_value() is dereferenced without a null check
          // before cast — presumably ref abstracts always carry one; confirm.
          auto parameter_abs_value = parameter_abs->cast<abstract::AbstractRefPtr>()->ref_key_value();
          auto ref_key_value = parameter_abs_value->cast<StringImmPtr>();
          if (ref_key_value != nullptr && ref_key_value->value() == tensor_proto.ref_key()) {
            node->set_default_param(parameter->cast<ParameterPtr>()->default_param());
            break;
          }
        }
      }
    }
  } else if (value_proto.has_denotation()) {
    // Monad inputs are encoded by denotation string instead of a tensor.
    if (value_proto.denotation() == "UMonadType") {
      node->set_abstract(kUMonad->ToAbstract());
    } else if (value_proto.denotation() == "IOMonadType") {
      node->set_abstract(kIOMonad->ToAbstract());
    }
    MS_LOG(DEBUG) << "Not tensor. parameter type: " << value_proto.denotation();
  }
  if (value_proto.has_attr_info()) {
    auto attr_proto = value_proto.attr_info();
    // Compatible with the previous proto.
    if (attr_proto.has_ref_attr_name()) {
      if (!SetNodeAbstractFromAttrProto(attr_proto, node)) {
        MS_LOG(ERROR) << "Failed to get abstract for input node " << node->name()
                      << " from proto:" << attr_proto.DebugString();
      }
    } else {
      auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto);
      if (abs == nullptr) {
        MS_LOG(ERROR) << "Failed to get abstract for input node " << node->name()
                      << " from attr_proto:" << attr_proto.DebugString();
      }
      node->set_abstract(abs);
    }
  }
  // A missing abstract is tolerated here (logged at INFO); callers may fill it later.
  if (node->abstract() == nullptr) {
    MS_LOG(INFO) << "Failed to build abstract of node:" << node->name()
                 << " from ValueInfoProto:" << value_proto.DebugString();
  }
  anfnode_build_map_[value_proto.name()] = node;
  return true;
}
1066 
ImportParametersForGraph(const FuncGraphPtr & outputFuncGraph,const mind_ir::GraphProto & importProto)1067 bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
1068                                                 const mind_ir::GraphProto &importProto) {
1069   MS_EXCEPTION_IF_NULL(outputFuncGraph);
1070   MS_LOG(INFO) << "All inputs size is: " << importProto.input_size();
1071   for (int i = 0; i < importProto.input_size(); ++i) {
1072     const mind_ir::ValueInfoProto &input_proto = importProto.input(i);
1073     if (is_kernel_graph_ && anfnode_build_map_.count(input_proto.name()) > 0) {
1074       continue;
1075     }
1076     if (!BuildInputForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) {
1077       MS_LOG(ERROR) << "Build input for funcgraph fail at index: " << i;
1078       return false;
1079     }
1080   }
1081 
1082   MS_LOG(INFO) << "All Parameters size is: " << importProto.parameter_size();
1083   for (int i = 0; i < importProto.parameter_size(); ++i) {
1084     const mind_ir::TensorProto &parameter_proto = importProto.parameter(i);
1085     if (is_kernel_graph_ && anfnode_build_map_.count(parameter_proto.name()) > 0) {
1086       continue;
1087     }
1088     if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), parameter_proto)) {
1089       MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
1090       return false;
1091     }
1092   }
1093   outputFuncGraph->set_fv_param_count(IntToSize(importProto.parameter_size()));
1094   return true;
1095 }
1096 
ImportMapParametersForGraph(const FuncGraphPtr & outputFuncGraph,const mind_ir::GraphProto & importProto)1097 bool MSANFModelParser::ImportMapParametersForGraph(const FuncGraphPtr &outputFuncGraph,
1098                                                    const mind_ir::GraphProto &importProto) {
1099   MS_EXCEPTION_IF_NULL(outputFuncGraph);
1100   MS_LOG(INFO) << "All MapParameters size is: " << importProto.map_parameter_size();
1101   for (int i = 0; i < importProto.map_parameter_size(); ++i) {
1102     const mind_ir::MapTensorProto &map_parameter_proto = importProto.map_parameter(i);
1103     if (!BuildMapParameterFromMapTensorProto(outputFuncGraph->add_parameter(), map_parameter_proto)) {
1104       MS_LOG(ERROR) << "Build map parameter for funcgraph fail at index: " << i;
1105       return false;
1106     }
1107   }
1108   outputFuncGraph->set_fv_param_count(IntToSize(importProto.parameter_size()));
1109   return true;
1110 }
1111 
ObtainCNodeAttrInTypeForm(const PrimitivePtr & prim,const mind_ir::AttributeProto & attr_proto)1112 bool MSANFModelParser::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
1113   MS_EXCEPTION_IF_NULL(prim);
1114   const int attr_tensor_type = attr_proto.tensors(0).data_type();
1115   if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
1116     MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type;
1117     return false;
1118   }
1119   (void)prim->AddAttr(attr_proto.name(), TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
1120   return true;
1121 }
1122 
ParseAttrInScalarForm(const mind_ir::AttributeProto & attr_proto,int index)1123 ValuePtr MSANFModelParser::ParseAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, int index) {
1124   const int attr_type = static_cast<int>(attr_proto.type());
1125   switch (attr_type) {
1126     case mind_ir::AttributeProto_AttributeType_STRING: {
1127       return ParseAttrInScalar_string_string(attr_proto, index);
1128     }
1129     case mind_ir::AttributeProto_AttributeType_INT8: {
1130       return ParseAttrInScalar_int8_t_int8_t(attr_proto, index);
1131     }
1132     case mind_ir::AttributeProto_AttributeType_INT16: {
1133       return ParseAttrInScalar_int16_t_int16_t(attr_proto, index);
1134     }
1135     case mind_ir::AttributeProto_AttributeType_INT32: {
1136       return ParseAttrInScalar_int32_t_int32_t(attr_proto, index);
1137     }
1138     case mind_ir::AttributeProto_AttributeType_INT64: {
1139       return ParseAttrInScalar_int64_t_int64_t(attr_proto, index);
1140     }
1141     case mind_ir::AttributeProto_AttributeType_UINT8: {
1142       return ParseAttrInScalar_uint8_t_uint8_t(attr_proto, index);
1143     }
1144     case mind_ir::AttributeProto_AttributeType_UINT16: {
1145       return ParseAttrInScalar_uint16_t_uint16_t(attr_proto, index);
1146     }
1147     case mind_ir::AttributeProto_AttributeType_UINT32: {
1148       return ParseAttrInScalar_uint32_t_uint32_t(attr_proto, index);
1149     }
1150     case mind_ir::AttributeProto_AttributeType_UINT64: {
1151       return ParseAttrInScalar_uint64_t_uint64_t(attr_proto, index);
1152     }
1153     case mind_ir::AttributeProto_AttributeType_FLOAT: {
1154       return ParseAttrInScalar_float_float(attr_proto, index);
1155     }
1156     case mind_ir::AttributeProto_AttributeType_DOUBLE: {
1157       return ParseAttrInScalar_double_double(attr_proto, index);
1158     }
1159     case mind_ir::AttributeProto_AttributeType_BOOL: {
1160       return ParseAttrInScalar_int32_t_bool(attr_proto, index);
1161     }
1162     case mind_ir::AttributeProto_AttributeType_TENSORS: {
1163       const int attr_tensor_type = attr_proto.tensors(index).data_type();
1164       if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
1165         MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type;
1166         return nullptr;
1167       }
1168       return TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]);
1169     }
1170     default:
1171       MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_type;
1172       return nullptr;
1173   }
1174 }
1175 
ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto & attr_proto,mindspore::HashMap<std::string,ValuePtr> * multi_value_map)1176 void MSANFModelParser::ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto &attr_proto,
1177                                                    mindspore::HashMap<std::string, ValuePtr> *multi_value_map) {
1178   string name;
1179   auto func = [&name, &multi_value_map, this](const mind_ir::AttributeProto &attr_proto, int length) -> void {
1180     for (int i = 0; i < length; ++i) {
1181       auto res = this->ParseAttrInScalarForm(attr_proto, i);
1182       name = "value" + std::to_string(i + 1);
1183       (void)multi_value_map->emplace(name, res);
1184     }
1185   };
1186   func(attr_proto, attr_proto.ints_size());
1187   func(attr_proto, attr_proto.doubles_size());
1188   func(attr_proto, attr_proto.floats_size());
1189   func(attr_proto, attr_proto.strings_size());
1190   func(attr_proto, attr_proto.tensors_size());
1191 }
1192 
ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto & attr_proto)1193 ValuePtr MSANFModelParser::ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto) {
1194   const int attr_type = static_cast<int>(attr_proto.type());
1195   switch (attr_type) {
1196     case mind_ir::AttributeProto_AttributeType_STRING: {
1197       return ParseAttrInSingleScalar_string_string(attr_proto);
1198     }
1199     case mind_ir::AttributeProto_AttributeType_INT8: {
1200       return ParseAttrInSingleScalar_int8_t_int8_t(attr_proto);
1201     }
1202     case mind_ir::AttributeProto_AttributeType_INT16: {
1203       return ParseAttrInSingleScalar_int16_t_int16_t(attr_proto);
1204     }
1205     case mind_ir::AttributeProto_AttributeType_INT32: {
1206       return ParseAttrInSingleScalar_int32_t_int32_t(attr_proto);
1207     }
1208     case mind_ir::AttributeProto_AttributeType_INT64: {
1209       return ParseAttrInSingleScalar_int64_t_int64_t(attr_proto);
1210     }
1211     case mind_ir::AttributeProto_AttributeType_UINT8: {
1212       return ParseAttrInSingleScalar_uint8_t_uint8_t(attr_proto);
1213     }
1214     case mind_ir::AttributeProto_AttributeType_UINT16: {
1215       return ParseAttrInSingleScalar_uint16_t_uint16_t(attr_proto);
1216     }
1217     case mind_ir::AttributeProto_AttributeType_UINT32: {
1218       return ParseAttrInSingleScalar_uint32_t_uint32_t(attr_proto);
1219     }
1220     case mind_ir::AttributeProto_AttributeType_UINT64: {
1221       return ParseAttrInSingleScalar_uint64_t_uint64_t(attr_proto);
1222     }
1223     case mind_ir::AttributeProto_AttributeType_FLOAT: {
1224       return ParseAttrInSingleScalar_float_float(attr_proto);
1225     }
1226     case mind_ir::AttributeProto_AttributeType_DOUBLE: {
1227       return ParseAttrInSingleScalar_double_double(attr_proto);
1228     }
1229     case mind_ir::AttributeProto_AttributeType_BOOL: {
1230       return ParseAttrInSingleScalar_int32_t_bool(attr_proto);
1231     }
1232     default:
1233       MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_type;
1234       return nullptr;
1235   }
1236 }
1237 
ObtainCNodeAttrInTensorForm(const PrimitivePtr & prim,const mind_ir::AttributeProto & attr_proto)1238 bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim,
1239                                                    const mind_ir::AttributeProto &attr_proto) {
1240   MS_EXCEPTION_IF_NULL(prim);
1241   const mind_ir::TensorProto attr_tensor = attr_proto.tensors(0);
1242   auto tensor_info = GenerateTensorPtrFromTensorProto(attr_tensor);
1243   if (tensor_info == nullptr) {
1244     MS_LOG(ERROR) << "Failed to get attr[" << attr_proto.name() << "] for node " << prim->ToString()
1245                   << " from the proto.";
1246     return false;
1247   }
1248   (void)prim->AddAttr(attr_proto.name(), MakeValue(tensor_info));
1249   return true;
1250 }
1251 
SetPrimitiveAttrWithType(const PrimitivePtr & prim,const mind_ir::AttributeProto & attr_proto)1252 bool MSANFModelParser::SetPrimitiveAttrWithType(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
1253   MS_EXCEPTION_IF_NULL(prim);
1254   const std::string &attr_name = attr_proto.name();
1255   auto value = GetValueFromAttributeProto(attr_proto);
1256   if (value == nullptr) {
1257     MS_LOG(ERROR) << "Failed to get value from proto.\n proto info:" << attr_proto.name();
1258     return false;
1259   }
1260   const std::string &op_type = prim->name();
1261   if (is_kernel_graph_) {
1262     (void)prim->AddAttr(attr_name, value);
1263     return true;
1264   }
1265   CheckAndConvertUtils::ConvertAttrValueInLoad(op_type, attr_name, &value);
1266   // Compatible with older versions.
1267   if (op_type == "HistogramFixedWidth" && attr_name == "dtype" && value->isa<StringImm>()) {
1268     auto str_dtype = GetValue<std::string>(value);
1269     if (str_dtype == "int32") {
1270       int64_t index = 3;
1271       (void)prim->AddAttr(attr_name, MakeValue<int64_t>(index));
1272     }
1273     MS_EXCEPTION(NotSupportError)
1274       << "The primtive[HistogramFixedWidth] not supported only support attribute[dtype] is 'int32',but got"
1275       << value->ToString();
1276   }
1277   (void)prim->AddAttr(attr_name, value);
1278   return true;
1279 }
1280 
// Parse one attribute of a CNode's primitive from `attr_proto` and attach it
// to `prim`. The encoding form (type / scalar / tensor / none) is derived from
// the proto's ref_attr_name. Returns false when ref_attr_name is missing or
// names an unsupported form.
bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
  MS_EXCEPTION_IF_NULL(prim);
  const std::string &attr_name = attr_proto.name();
  if (!attr_proto.has_ref_attr_name()) {
    MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
    return false;
  }
  const std::string &ref_attr_name = attr_proto.ref_attr_name();
  ParseForm type = GetParseFormType(ref_attr_name);
  // Collects the individual scalars of a tuple/list attribute; assembled below.
  mindspore::HashMap<std::string, ValuePtr> multi_value_map;
  switch (type) {
    case FORM_PARSE_TYPE: {
      (void)ObtainCNodeAttrInTypeForm(prim, attr_proto);
      break;
    }
    case FORM_PARSE_SCALAR: {
      // "value0" in ref_attr_name marks a single scalar value.
      if (ref_attr_name.find("value0") != std::string::npos) {
        ValuePtr res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
        MS_EXCEPTION_IF_NULL(res);
        const std::string &op_type = prim->name();
        if (is_kernel_graph_) {
          (void)prim->AddAttr(attr_name, res);
          break;
        }
        CheckAndConvertUtils::ConvertAttrValueInLoad(op_type, attr_name, &res);
        // Compatibility: older exports stored HistogramFixedWidth's dtype as a
        // string; only "int32" (converted to the integer 3) is accepted.
        if (op_type == "HistogramFixedWidth" && attr_name == "dtype" && res->isa<StringImm>()) {
          auto str_dtype = GetValue<std::string>(res);
          if (str_dtype == "int32") {
            int64_t index = 3;
            (void)prim->AddAttr(attr_name, MakeValue<int64_t>(index));
            break;
          }
          MS_EXCEPTION(NotSupportError)
            << "The primtive[HistogramFixedWidth] not supported only support attribute[dtype] is 'int32',but got"
            << res->ToString();
        }
        (void)prim->AddAttr(attr_name, res);
        break;
      } else if (ref_attr_name.find("Tuple[]") != std::string::npos) {
        // Empty tuple attribute.
        (void)prim->AddAttr(attr_name, std::make_shared<ValueTuple>(std::vector<ValuePtr>()));
        break;
      } else if (ref_attr_name.find("List[]") != std::string::npos) {
        // Empty list attribute.
        (void)prim->AddAttr(attr_name, std::make_shared<ValueList>(std::vector<ValuePtr>()));
        break;
      }
      // Multi-element scalar sequence: gather the elements keyed by name.
      ObtainCNodeAttrInScalarForm(attr_proto, &multi_value_map);
      break;
    }
    case FORM_PARSE_TENSOR: {
      (void)ObtainCNodeAttrInTensorForm(prim, attr_proto);
      break;
    }
    case FORM_PARSE_NONE: {
      (void)prim->AddAttr(attr_name, kNone);
      break;
    }
    default:
      MS_LOG(ERROR) << "parse attr type don't support the ref_attr_name: " << ref_attr_name;
      return false;
  }

  // Assemble the gathered scalars into a ValueTuple or ValueList attribute.
  if (type == FORM_PARSE_SCALAR && multi_value_map.size() != 0) {
    if (ref_attr_name.find("Tuple") != std::string::npos) {
      auto value_tuple_ptr = ParserScalarAttrValue<ValueTuple>(ref_attr_name, multi_value_map);
      (void)prim->AddAttr(attr_name, value_tuple_ptr);
    } else {
      auto value_list_ptr = ParserScalarAttrValue<ValueList>(ref_attr_name, multi_value_map);
      (void)prim->AddAttr(attr_name, value_list_ptr);
    }
  }
  return true;
}
1353 
ObtainValueNodeInTensorForm(const std::string & value_node_name,const mind_ir::TensorProto & attr_tensor)1354 bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name,
1355                                                    const mind_ir::TensorProto &attr_tensor) {
1356   tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(attr_tensor);
1357   if (tensor_info == nullptr) {
1358     MS_LOG(ERROR) << "Failed to get the tensor for ValueNode.";
1359     return false;
1360   }
1361   auto new_value_node = NewValueNode(MakeValue(tensor_info));
1362   MS_EXCEPTION_IF_NULL(new_value_node);
1363   auto tensor_abstract = tensor_info->ToAbstract();
1364   MS_EXCEPTION_IF_NULL(tensor_abstract);
1365   new_value_node->set_abstract(tensor_abstract);
1366   anfnode_build_map_[value_node_name] = new_value_node;
1367   return true;
1368 }
1369 
ObtainValueNodeInTupleTensorForm(const std::string & value_node_name,const mind_ir::AttributeProto & attr_proto)1370 bool MSANFModelParser::ObtainValueNodeInTupleTensorForm(const std::string &value_node_name,
1371                                                         const mind_ir::AttributeProto &attr_proto) {
1372   std::vector<tensor::TensorPtr> tensor_vec;
1373   for (int i = 0; i < attr_proto.tensors_size(); ++i) {
1374     mind_ir::TensorProto attr_tensor = attr_proto.tensors(i);
1375     const int attr_tensor_type = attr_tensor.data_type();
1376     ShapeVector shape;
1377     for (int j = 0; j < attr_tensor.dims_size(); ++j) {
1378       shape.push_back(attr_tensor.dims(j));
1379     }
1380     tensor::TensorPtr tensor_info = nullptr;
1381     if (!attr_tensor.has_compression_type() ||
1382         attr_tensor.compression_type() == mind_ir::TensorProto_CompressionType_NO_COMPRESSION) {
1383       tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
1384     } else {
1385       auto compression_type = static_cast<TensorCompressionType>(static_cast<int>(attr_tensor.compression_type()));
1386       size_t data_size = 0;
1387       if (!attr_tensor.has_external_data()) {
1388         data_size = attr_tensor.raw_data().size();
1389       } else {
1390         data_size = LongToSize(attr_tensor.external_data().length());
1391       }
1392       tensor_info =
1393         std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape, data_size, compression_type);
1394     }
1395     const std::string &tensor_buf = attr_tensor.raw_data();
1396     auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
1397     errno_t ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
1398     if (ret != EOK) {
1399       MS_LOG(ERROR) << "Obtain ValueNode in TupleTensorForm occur memcpy_s error.";
1400       return false;
1401     }
1402     tensor_vec.push_back(tensor_info);
1403   }
1404   auto value = MakeValue(tensor_vec);
1405   auto new_value_node = NewValueNode(value);
1406   new_value_node->set_abstract(value->ToAbstract());
1407   anfnode_build_map_[value_node_name] = new_value_node;
1408   return true;
1409 }
1410 
ObtainValueNodeInTypeForm(const std::string & value_node_name,const mind_ir::TensorProto & attr_tensor)1411 bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_name,
1412                                                  const mind_ir::TensorProto &attr_tensor) {
1413   const int attr_tensor_type = attr_tensor.data_type();
1414   auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
1415   if (iter == kDefaultValueSwitchMap.end()) {
1416     MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
1417     return false;
1418   }
1419   auto value = TypeIdToType(iter->second);
1420   auto new_value_node = NewValueNode(value);
1421   new_value_node->set_abstract(value->ToAbstract());
1422   anfnode_build_map_[value_node_name] = new_value_node;
1423   return true;
1424 }
1425 
ObtainValueNodeInNoneForm(const std::string & value_node_name)1426 bool MSANFModelParser::ObtainValueNodeInNoneForm(const std::string &value_node_name) {
1427   auto new_value_node = NewValueNode(kNone);
1428   MS_EXCEPTION_IF_NULL(new_value_node);
1429   new_value_node->set_abstract(kNone->ToAbstract());
1430   anfnode_build_map_[value_node_name] = new_value_node;
1431   return true;
1432 }
1433 
ObtainValueNodeInMonadForm(const std::string & value_node_name,const mind_ir::AttributeProto & attr_proto)1434 bool MSANFModelParser::ObtainValueNodeInMonadForm(const std::string &value_node_name,
1435                                                   const mind_ir::AttributeProto &attr_proto) {
1436   const std::string &ref_attr_name = attr_proto.ref_attr_name();
1437   if (ref_attr_name.find("UMonad") != std::string::npos) {
1438     auto monad_abs = kUMonad->ToAbstract();
1439     auto new_value_node = NewValueNode(kUMonad);
1440     MS_EXCEPTION_IF_NULL(new_value_node);
1441     new_value_node->set_abstract(monad_abs);
1442     anfnode_build_map_[value_node_name] = new_value_node;
1443   } else if (ref_attr_name.find("IOMonad") != std::string::npos) {
1444     auto monad_abs = kIOMonad->ToAbstract();
1445     auto new_value_node = NewValueNode(kIOMonad);
1446     MS_EXCEPTION_IF_NULL(new_value_node);
1447     new_value_node->set_abstract(monad_abs);
1448     anfnode_build_map_[value_node_name] = new_value_node;
1449   } else {
1450     return false;
1451   }
1452   return true;
1453 }
1454 
// Rebuild a ValueDictionary from a DICT-typed AttributeProto. Each nested
// AttributeProto contributes one (key, value) pair: its name() is the key and
// its single nested value holds the payload (tensor, dtype, sequence, nested
// dict, or scalar/string). Returns nullptr on any parse failure; throws on a
// structurally malformed proto (missing name, value count != 1).
ValuePtr MSANFModelParser::ObtainValueInDictionaryForm(const mind_ir::AttributeProto &attr_proto) {
  std::vector<std::pair<ValuePtr, ValuePtr>> key_values;
  for (int i = 0; i < attr_proto.values_size(); ++i) {
    const mind_ir::AttributeProto &key_value_proto = attr_proto.values(i);
    if (!key_value_proto.has_name()) {
      MS_LOG(INTERNAL_EXCEPTION) << "Dict type AttributeProto should has name as key of dictionary";
    }
    // The entry's name becomes the dictionary key.
    auto key = std::make_shared<abstract::AbstractScalar>(key_value_proto.name())->BuildValue();
    MS_EXCEPTION_IF_NULL(key);
    auto &values = key_value_proto.values();
    if (values.size() != 1) {
      MS_LOG(INTERNAL_EXCEPTION) << "Dict type AttributeProto should has exactly one value, but got " << values.size()
                                 << " value(s).";
    }
    auto &value = values[0];
    switch (value.type()) {
      case mind_ir::AttributeProto_AttributeType_TENSORS: {
        const mind_ir::TensorProto &tensor_proto = value.tensors(0);
        // raw_data distinguishes a real tensor payload from a bare dtype.
        if (tensor_proto.has_raw_data()) {
          // For real tensor.
          tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto);
          if (tensor_info == nullptr) {
            MS_LOG(ERROR) << "Failed to get the tensor for ValueNode.";
            return nullptr;
          }
          (void)key_values.emplace_back(std::make_pair(key, tensor_info));
        } else {
          // For data type.
          const int attr_tensor_type = tensor_proto.data_type();
          auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
          if (iter == kDefaultValueSwitchMap.end()) {
            MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
            return nullptr;
          }
          (void)key_values.emplace_back(std::make_pair(key, TypeIdToType(iter->second)));
        }
        break;
      }
      case mind_ir::AttributeProto_AttributeType_TUPLE:
      case mind_ir::AttributeProto_AttributeType_LIST: {
        auto sequence_value = ObtainValueInSequenceForm(value);
        if (sequence_value == nullptr) {
          MS_LOG(ERROR) << "Failed to get the sequence value";
          return nullptr;
        }
        (void)key_values.emplace_back(std::make_pair(key, sequence_value));
        break;
      }
      case mind_ir::AttributeProto_AttributeType_DICT: {
        // Nested dictionaries are handled recursively.
        auto dict_value = ObtainValueInDictionaryForm(value);
        if (dict_value == nullptr) {
          MS_LOG(ERROR) << "Failed to get the dictionary value";
          return nullptr;
        }
        (void)key_values.emplace_back(std::make_pair(key, dict_value));
        break;
      }
      default: {
        // For string and scalar.
        auto scalar_value = ParseAttrInScalarForm(value, 0);
        if (scalar_value == nullptr) {
          MS_LOG(ERROR) << "Failed to get the scalar for ValueNode.";
          return nullptr;
        }
        (void)key_values.emplace_back(std::make_pair(key, scalar_value));
      }
    }
  }
  return std::make_shared<ValueDictionary>(key_values);
}
1525 
// Rebuild a ValueTuple or ValueList from a TUPLE/LIST-typed AttributeProto.
// Each nested AttributeProto becomes one element: a real tensor, a
// quantization parameter, a bare dtype, a nested sequence (recursively), or a
// scalar/string. Returns nullptr on element parse failure; throws when
// attr_proto.type() is neither TUPLE nor LIST.
ValuePtr MSANFModelParser::ObtainValueInSequenceForm(const mind_ir::AttributeProto &attr_proto) {
  std::vector<ValuePtr> vec;
  for (int i = 0; i < attr_proto.values_size(); ++i) {
    mind_ir::AttributeProto elem_attr_proto = attr_proto.values(i);
    switch (elem_attr_proto.type()) {
      case mind_ir::AttributeProto_AttributeType_TENSORS: {
        mind_ir::TensorProto tensor_proto = elem_attr_proto.tensors(0);
        if (tensor_proto.has_raw_data()) {
          // For real tensor.
          tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto);
          if (tensor_info == nullptr) {
            MS_LOG(ERROR) << "Failed to get the tensor for ValueNode.";
            return nullptr;
          }
          (void)vec.emplace_back(tensor_info);
        } else if (tensor_proto.name() == kQuantParam) {
          // Quantization parameters are encoded as a specially-named tensor.
          // NOTE: an empty result is silently skipped (no element appended).
          auto quantization_param_vector = GenerateQuantizationParam(tensor_proto);
          if (!quantization_param_vector.empty()) {
            (void)vec.emplace_back(quantization_param_vector[0]);
          }
        } else {
          // For data type.
          const int attr_tensor_type = tensor_proto.data_type();
          auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
          if (iter == kDefaultValueSwitchMap.end()) {
            MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
            return nullptr;
          }
          (void)vec.emplace_back(TypeIdToType(iter->second));
        }
        break;
      }
      case mind_ir::AttributeProto_AttributeType_TUPLE:
      case mind_ir::AttributeProto_AttributeType_LIST: {
        // Nested sequences are handled recursively.
        auto sequence_value = ObtainValueInSequenceForm(elem_attr_proto);
        if (sequence_value == nullptr) {
          MS_LOG(ERROR) << "Failed to get the sequence value";
          return nullptr;
        }
        (void)vec.emplace_back(sequence_value);
        break;
      }
      default: {
        // For string and scalar.
        auto scalar_value = ParseAttrInScalarForm(elem_attr_proto, 0);
        if (scalar_value == nullptr) {
          MS_LOG(ERROR) << "Failed to get the scalar for ValueNode.";
          return nullptr;
        }
        (void)vec.emplace_back(scalar_value);
      }
    }
  }
  // Wrap the elements in the container kind recorded on the outer proto.
  auto type = attr_proto.type();
  ValuePtr value_sequence;
  if (type == mind_ir::AttributeProto_AttributeType_TUPLE) {
    value_sequence = std::make_shared<ValueTuple>(vec);
  } else if (type == mind_ir::AttributeProto_AttributeType_LIST) {
    value_sequence = std::make_shared<ValueList>(vec);
  } else {
    MS_LOG(INTERNAL_EXCEPTION) << "The attribute type should be tuple or list, but it is " << type;
  }

  return value_sequence;
}
1591 
BuildValueFromAttributeProto(const mind_ir::AttributeProto & attr_proto)1592 ValuePtr MSANFModelParser::BuildValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto) {
1593   switch (attr_proto.type()) {
1594     case mind_ir::AttributeProto_AttributeType_TENSORS: {
1595       const auto &tensor_proto = attr_proto.tensors(0);
1596       if (tensor_proto.has_raw_data()) {
1597         // For real tensor.
1598         tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto);
1599         if (tensor_info == nullptr) {
1600           MS_LOG(ERROR) << "Failed to GenerateTensorPtrFromTensorProto.";
1601           return nullptr;
1602         }
1603         return MakeValue(tensor_info);
1604       } else {
1605         // For data type.
1606         const int attr_tensor_type = tensor_proto.data_type();
1607         auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
1608         if (iter == kDefaultValueSwitchMap.end()) {
1609           MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
1610           return nullptr;
1611         }
1612         return TypeIdToType(iter->second);
1613       }
1614     }
1615     case mind_ir::AttributeProto_AttributeType_NONE: {
1616       return kNone;
1617     }
1618     case mind_ir::AttributeProto_AttributeType_UMONAD: {
1619       return kUMonad;
1620     }
1621     case mind_ir::AttributeProto_AttributeType_IOMONAD: {
1622       return kIOMonad;
1623     }
1624     case mind_ir::AttributeProto_AttributeType_TUPLE:
1625     case mind_ir::AttributeProto_AttributeType_LIST: {
1626       return ObtainValueInSequenceForm(attr_proto);
1627     }
1628     case mind_ir::AttributeProto_AttributeType_CLASS_TYPE: {
1629       auto class_type = static_cast<std::string>(attr_proto.s());
1630       return std::make_shared<MindIRClassType>(class_type);
1631     }
1632     case mind_ir::AttributeProto_AttributeType_TYPE_NULL: {
1633       return kTypeNull;
1634     }
1635     case mind_ir::AttributeProto_AttributeType_NAME_SPACE: {
1636       auto name_space = static_cast<std::string>(attr_proto.s());
1637       return std::make_shared<MindIRNameSpace>(name_space);
1638     }
1639     case mind_ir::AttributeProto_AttributeType_SYMBOL: {
1640       auto symbol = static_cast<std::string>(attr_proto.s());
1641       return std::make_shared<MindIRSymbol>(symbol);
1642     }
1643     default: {
1644       return ObtainCNodeAttrInSingleScalarForm(attr_proto);
1645     }
1646   }
1647 }
1648 
GetAttrValueForValueNodeWithType(const std::string & value_node_name,const mind_ir::AttributeProto & attr_proto)1649 bool MSANFModelParser::GetAttrValueForValueNodeWithType(const std::string &value_node_name,
1650                                                         const mind_ir::AttributeProto &attr_proto) {
1651   auto value = BuildValueFromAttributeProto(attr_proto);
1652   if (value == nullptr) {
1653     MS_LOG(ERROR) << "Failed to build value from AttributeProto while building valuenode: " << value_node_name;
1654     return false;
1655   }
1656   auto abstract = value->ToAbstract();
1657   MS_EXCEPTION_IF_NULL(abstract);
1658   ValueNodePtr new_value_node = NewValueNode(value);
1659   new_value_node->set_abstract(abstract);
1660   anfnode_build_map_[value_node_name] = new_value_node;
1661   return true;
1662 }
1663 
// Build a ValueNode named `value_node_name` from an attribute proto that uses
// the legacy ref_attr_name encoding. Dispatches on the parse form derived
// from ref_attr_name; every successfully built node is registered in
// anfnode_build_map_. Returns false when ref_attr_name is missing, names an
// unsupported form, or a tuple/list value cannot be assembled.
bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_name,
                                                const mind_ir::AttributeProto &attr_proto) {
  if (!attr_proto.has_ref_attr_name()) {
    MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
    return false;
  }
  const std::string &ref_attr_name = attr_proto.ref_attr_name();
  ParseForm type = GetParseFormType(ref_attr_name);
  ValueNodePtr new_value_node;
  // Collects the individual scalars of a tuple/list value; assembled below.
  mindspore::HashMap<std::string, ValuePtr> multi_value_map;
  switch (type) {
    case FORM_PARSE_TYPE: {
      (void)ObtainValueNodeInTypeForm(value_node_name, attr_proto.tensors(0));
      break;
    }
    case FORM_PARSE_SCALAR: {
      // "value0" in ref_attr_name marks a single scalar value.
      if (ref_attr_name.find("value0") != std::string::npos) {
        auto res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
        MS_EXCEPTION_IF_NULL(res);
        new_value_node = NewValueNode(res);
        new_value_node->set_abstract(res->ToAbstract());
        anfnode_build_map_[value_node_name] = new_value_node;
        break;
      }
      // Empty tuple value.
      if (ref_attr_name.find("Tuple[]") != std::string::npos) {
        MS_LOG(INFO) << "Build Tuple() ValueNode for primitive.";
        ValuePtr res = MakeValue(std::vector<ValuePtr>{});
        new_value_node = NewValueNode(res);
        new_value_node->set_abstract(res->ToAbstract());
        anfnode_build_map_[value_node_name] = new_value_node;
        break;
      }
      // Tuple of tensors (more than one TensorProto attached).
      if (ref_attr_name.find("Tuple[value") != std::string::npos && attr_proto.tensors_size() > 1) {
        MS_LOG(INFO) << "Build TupleTensor ValueNode for primitive.";
        (void)ObtainValueNodeInTupleTensorForm(value_node_name, attr_proto);
        break;
      }
      // Multi-element scalar sequence: gather the elements keyed by name.
      ObtainCNodeAttrInScalarForm(attr_proto, &multi_value_map);
      break;
    }
    case FORM_PARSE_TENSOR: {
      (void)ObtainValueNodeInTensorForm(value_node_name, attr_proto.tensors(0));
      break;
    }
    case FORM_PARSE_NONE: {
      (void)ObtainValueNodeInNoneForm(value_node_name);
      break;
    }
    case FORM_PARSE_MONAD: {
      (void)ObtainValueNodeInMonadForm(value_node_name, attr_proto);
      break;
    }
    default:
      MS_LOG(ERROR) << "parse attr type don't support the ref_attr_name: " << ref_attr_name;
      return false;
  }
  // Assemble the gathered scalars into a ValueTuple or ValueList node.
  if (type == FORM_PARSE_SCALAR && !multi_value_map.empty()) {
    if (ref_attr_name.find("Tuple") != std::string::npos) {
      auto value_tuple_ptr = ParserScalarAttrValue<ValueTuple>(ref_attr_name, multi_value_map);
      if (value_tuple_ptr == nullptr) {
        MS_LOG(ERROR) << "Failed to build the value of the ValueNode, attr_proto:" << attr_proto.DebugString()
                      << ", value_node_name:" << value_node_name;
        return false;
      }
      new_value_node = NewValueNode(value_tuple_ptr);
      new_value_node->set_abstract(value_tuple_ptr->ToAbstract());
    } else {
      auto value_list_ptr = ParserScalarAttrValue<ValueList>(ref_attr_name, multi_value_map);
      if (value_list_ptr == nullptr) {
        MS_LOG(ERROR) << "Failed to build the value of the ValueNode, attr_proto:" << attr_proto.DebugString()
                      << ", value_node_name:" << value_node_name;
        return false;
      }
      new_value_node = NewValueNode(value_list_ptr);
      new_value_node->set_abstract(value_list_ptr->ToAbstract());
    }
    anfnode_build_map_[value_node_name] = new_value_node;
  }
  return true;
}
1744 
BuildValueNodeForFuncGraph(const mind_ir::NodeProto & node_proto)1745 bool MSANFModelParser::BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto) {
1746   if (node_proto.output_size() == 0) {
1747     MS_LOG(ERROR) << "The Proto output is empty.";
1748     return false;
1749   }
1750   const std::string &value_node_name = node_proto.output(0);
1751   const mind_ir::AttributeProto &attr_proto = node_proto.attribute(0);
1752   if (attr_proto.has_ref_attr_name()) {
1753     return GetAttrValueForValueNode(value_node_name, attr_proto);
1754   }
1755   return GetAttrValueForValueNodeWithType(value_node_name, attr_proto);
1756 }
1757 
GetAbstractForNode(const mind_ir::AttributeProto & attr_proto)1758 mindspore::HashMap<std::string, abstract::AbstractBasePtr> MSANFModelParser::GetAbstractForNode(
1759   const mind_ir::AttributeProto &attr_proto) {
1760   mindspore::HashMap<std::string, abstract::AbstractBasePtr> kv;
1761   for (int i = 0; i < attr_proto.tensors_size(); ++i) {
1762     const mind_ir::TensorProto &attr_tensor = attr_proto.tensors(i);
1763     auto tensor_info = GetAbsTensorFromTensorProto(attr_tensor);
1764     (void)kv.emplace(attr_tensor.name(), tensor_info);
1765   }
1766   return kv;
1767 }
1768 
BuildOperatorNode(const mind_ir::NodeProto & node_proto)1769 AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_proto) {
1770   const std::string kOperatorTypeFlag = std::string("REF::");
1771   const size_t kOpTypeFlagSize = kOperatorTypeFlag.length();
1772   const std::string &node_type = node_proto.op_type();
1773   MS_LOG(DEBUG) << "Process Operator:" << node_type;
1774   // Operator maybe CNode,FuncGraph or Parameter.
1775 
1776   if (node_type.size() > kOpTypeFlagSize && node_type.substr(0, kOpTypeFlagSize) == kOperatorTypeFlag) {
1777     auto anfNode = GetAnfNode(node_type.substr(kOpTypeFlagSize));
1778     if (anfNode == nullptr) {
1779       MS_LOG(ERROR) << "Can't find the ref:" << node_type;
1780       return nullptr;
1781     }
1782     return anfNode;
1783   }
1784 
1785   // Operator is  primitive.
1786   std::shared_ptr<Primitive> prim;
1787   auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
1788   if (op_primc_fns.find(node_type) != op_primc_fns.end()) {
1789     prim = op_primc_fns[node_type]();
1790   } else {
1791     if (node_type.compare(0, strlen(kDoSignaturePrimitivePrefix), kDoSignaturePrimitivePrefix) == 0) {
1792       auto op_name = node_type.substr(strlen(kDoSignaturePrimitivePrefix));
1793       prim = std::make_shared<prim::DoSignaturePrimitive>(op_name, std::make_shared<Primitive>(op_name));
1794       MS_EXCEPTION_IF_NULL(prim);
1795       prim->set_instance_name(op_name);
1796     } else {
1797       MS_LOG(DEBUG) << "Special node_type: " << node_type;
1798       prim = std::make_shared<Primitive>(node_type);
1799       MS_EXCEPTION_IF_NULL(prim);
1800       prim->set_instance_name(node_type);
1801     }
1802   }
1803   prim->set_attr("is_load", MakeValue(true));
1804   return NewValueNodeWithAbstract(prim);
1805 }
1806 
SetEmptyTensorProtoCNodeAbstract(const AnfNodePtr & node_ptr)1807 bool MSANFModelParser::SetEmptyTensorProtoCNodeAbstract(const AnfNodePtr &node_ptr) {
1808   auto primitive = GetCNodePrimitive(node_ptr);
1809   if (primitive != nullptr) {
1810     auto node_type = primitive->name();
1811     if (node_type == "UpdateState") {
1812       node_ptr->set_abstract(kUMonad->ToAbstract());
1813     } else if (node_type == "Depend") {
1814       node_ptr->set_abstract(kBool->ToAbstract());
1815     } else {
1816       auto cnode_ptr = node_ptr->cast<CNodePtr>();
1817       AbstractBasePtrList elem;
1818       for (size_t index = 1; index < cnode_ptr->size(); ++index) {
1819         auto abs = cnode_ptr->input(index)->abstract();
1820         if (abs != nullptr) {
1821           if (abs->GetValueTrack() == nullptr) {
1822             abs->set_value(kValueAny);
1823           }
1824           elem.push_back(abs);
1825         }
1826       }
1827       if (!elem.empty()) {
1828         node_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
1829       }
1830     }
1831   } else {
1832     MS_LOG(ERROR) << "Failed to get the abstract of node:" << node_ptr->DebugString();
1833     return false;
1834   }
1835   return true;
1836 }
1837 
1838 // Set CNode abstract.
SetCNodeAbstract(const mind_ir::AttributeProto & attr_proto,const CNodePtr & cnode_ptr)1839 void MSANFModelParser::SetCNodeAbstract(const mind_ir::AttributeProto &attr_proto, const CNodePtr &cnode_ptr) {
1840   if (attr_proto.has_ref_attr_name()) {
1841     if (!SetNodeAbstractFromAttrProto(attr_proto, cnode_ptr)) {
1842       MS_LOG(ERROR) << "Failed to get CNode abstract from proto.";
1843     }
1844   } else {
1845     auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto);
1846     cnode_ptr->set_abstract(abs);
1847   }
1848   if (cnode_ptr->abstract() == nullptr) {
1849     MS_LOG(INFO) << "Failed to Build CNode abstract from proto. CNode: " << cnode_ptr->ToString()
1850                  << " attr_proto: " << attr_proto.DebugString();
1851     node_abstract_protos_.push_back(std::pair(cnode_ptr, &attr_proto));
1852   }
1853 }
1854 
// Build one CNode of `outputFuncGraph` from a NodeProto: resolve the operator
// and all inputs from already-parsed nodes, create the CNode, register it in
// anfnode_build_map_ under its first output name, then attach debug info,
// abstract and primitive attributes. Returns nullptr on a malformed proto or
// an unresolved input; throws on a duplicate node name.
CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
                                                  const mind_ir::NodeProto &node_proto) {
  MS_EXCEPTION_IF_NULL(outputFuncGraph);
  if (!node_proto.has_op_type()) {
    MS_LOG(ERROR) << "Get CNode op_type failed!";
    return nullptr;
  }
  if (node_proto.output_size() <= 0) {
    MS_LOG(ERROR) << "Get CNode out failed!";
    return nullptr;
  }
  // The first output name identifies the node in anfnode_build_map_.
  const std::string &node_name = node_proto.output(0);
  MS_LOG(DEBUG) << "Process CNode: " << node_name;
  // Build inputs.
  std::vector<AnfNodePtr> inputs;
  // Input 0 is the operator itself (primitive or referenced node).
  auto operator_node = BuildOperatorNode(node_proto);
  if (operator_node == nullptr) {
    MS_LOG(ERROR) << "Build operator node " << node_name << " failed!";
    return nullptr;
  }
  inputs.push_back(operator_node);
  // Remaining inputs must already have been parsed.
  for (int i = 0; i < node_proto.input_size(); ++i) {
    auto anfNode = GetAnfNode(node_proto.input(i));
    if (anfNode == nullptr) {
      MS_LOG(ERROR) << node_name << " input " << i << node_proto.input(i) << "can't find in nodes have parsed";
      return nullptr;
    }
    inputs.push_back(anfNode);
  }

  CNodePtr cnode_ptr = outputFuncGraph->FuncGraph::NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(cnode_ptr);
  if (anfnode_build_map_.count(node_name) > 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "Duplicate CNode name: " << node_name;
  }
  // The proto's domain field carries the full scope name.
  const std::string &fullname_with_scope = node_proto.domain();
  string debug_info_name = ParseCNodeName(node_name);
  auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
  cnode_ptr->set_debug_info(debug_info_ptr);
  cnode_ptr->set_fullname_with_scope(fullname_with_scope);
  cnode_ptr->set_load_flag(true);
  anfnode_build_map_[node_name] = cnode_ptr;

  // Set Abstract and prim attr for CNode
  SetCNodePrimAttrAndAbstract(node_proto, cnode_ptr);
  if (!BuildAttrForCNode(cnode_ptr, node_proto)) {
    MS_LOG(ERROR) << "Failed build attr for node: " << cnode_ptr->DebugString()
                  << ", proto: " << node_proto.DebugString();
  }
  return cnode_ptr;
}
1906 
// Build the Return node of `outputFuncGraph` from the graph proto's outputs.
// Multiple outputs are first packed into a MakeTuple; a single output is
// returned directly (and the Return primitive/node are also registered in
// anfnode_build_map_ under graph-scoped keys). Returns false when an output
// node was never parsed or the proto declares no outputs.
bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
                                               const mind_ir::GraphProto &importProto) {
  MS_EXCEPTION_IF_NULL(outputFuncGraph);
  std::vector<AnfNodePtr> inputs;
  if (importProto.output_size() > 1) {
    // Multi-output graph: Return(MakeTuple(outputs...)).
    inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
    AbstractBasePtrList elem;
    for (int out_size = 0; out_size < importProto.output_size(); ++out_size) {
      const mind_ir::ValueInfoProto &output_node = importProto.output(out_size);
      const std::string &out_tuple = output_node.name();
      auto anfNode = GetAnfNode(out_tuple);
      if (anfNode == nullptr) {
        MS_LOG(ERROR) << "Miss return node: " << out_tuple;
        return false;
      }
      inputs.push_back(anfNode);
      elem.push_back(anfNode->abstract());
    }
    auto maketuple_ptr = outputFuncGraph->FuncGraph::NewCNode(inputs);
    MS_EXCEPTION_IF_NULL(maketuple_ptr);
    maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
    // Reuse the vector for the Return node's inputs.
    inputs.clear();
    inputs.push_back(NewValueNode(prim::kPrimReturn));
    inputs.push_back(maketuple_ptr);
    auto return_node = outputFuncGraph->FuncGraph::NewCNode(inputs);
    MS_EXCEPTION_IF_NULL(return_node);
    return_node->set_abstract(maketuple_ptr->abstract());
    return_node->set_load_flag(true);
    outputFuncGraph->set_return(return_node);
    MS_LOG(DEBUG) << "Construct funcgraph finined, all success.";
    return true;
  } else if (importProto.output_size() == 1) {
    // Single-output graph: Return(output); register the Return primitive and
    // node under graph-scoped keys so they can be referenced later.
    auto graph_name = importProto.name();
    const auto &return_node_input0 = NewValueNode(prim::kPrimReturn);
    anfnode_build_map_[graph_name + kReturnPrimNode] = return_node_input0;
    inputs.push_back(return_node_input0);
    auto node_name = importProto.output(0).name();
    auto anf_node = GetAnfNode(node_name);
    if (anf_node == nullptr) {
      MS_LOG(ERROR) << "Miss return node: " << node_name;
      return false;
    }
    inputs.push_back(anf_node);
    anfnode_build_map_[node_name] = anf_node;
    auto return_node = outputFuncGraph->FuncGraph::NewCNode(inputs);
    MS_EXCEPTION_IF_NULL(return_node);
    return_node->set_abstract(anf_node->abstract());
    return_node->set_load_flag(true);
    outputFuncGraph->set_return(return_node);
    anfnode_build_map_[graph_name + kReturnNode] = return_node;
    MS_LOG(DEBUG) << "Construct funcgraph finined, all success!";
    return true;
  }

  // No outputs declared in the proto.
  return false;
}
1963 
ImportNodesForGraph(const FuncGraphPtr & outputFuncGraph,const mind_ir::GraphProto & importProto)1964 bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
1965                                            const mind_ir::GraphProto &importProto) {
1966   MS_EXCEPTION_IF_NULL(outputFuncGraph);
1967   CNodePtr cnode_ptr = nullptr;
1968   for (int i = 0; i < importProto.node_size(); ++i) {
1969     const mind_ir::NodeProto &node_proto = importProto.node(i);
1970     if (is_kernel_graph_ && anfnode_build_map_.count(node_proto.output(0)) > 0) {
1971       continue;
1972     }
1973     const std::string &node_type = node_proto.op_type();
1974     if (node_type == kConstantValueNode) {
1975       if (!BuildValueNodeForFuncGraph(node_proto)) {
1976         MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: " << i;
1977         return false;
1978       }
1979       continue;
1980     }
1981     cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto);
1982     if (cnode_ptr == nullptr) {
1983       MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: " << i;
1984       return false;
1985     }
1986   }
1987   return BuildReturnForFuncGraph(outputFuncGraph, importProto);
1988 }
1989 
BuildAttrForFuncGraph(const FuncGraphPtr & outputFuncGraph,const mind_ir::GraphProto & importProto)1990 bool MSANFModelParser::BuildAttrForFuncGraph(const FuncGraphPtr &outputFuncGraph,
1991                                              const mind_ir::GraphProto &importProto) {
1992   for (auto i = 0; i < importProto.attribute_size(); ++i) {
1993     const mind_ir::AttributeProto &attr_proto = importProto.attribute(i);
1994     auto value = GetValueFromAttributeProto(attr_proto);
1995     if (value == nullptr) {
1996       MS_LOG(ERROR) << "Failed set func_graph attr to func_graph";
1997       return false;
1998     }
1999     outputFuncGraph->set_attr(attr_proto.name(), value);
2000   }
2001   return true;
2002 }
2003 
// Fills `output_graph` from `import_proto`: debug name, bprop hash/filepath,
// graph attributes, (map-)parameters and nodes. When the imported graph is
// flagged for cell reuse, also updates the global context's cell-reuse level.
// Returns false when the proto carries no name or a required import step fails.
bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &output_graph, const mind_ir::GraphProto &import_proto) {
  MS_EXCEPTION_IF_NULL(output_graph);
  GraphDebugInfoPtr debug_info_ptr = output_graph->debug_info();
  MS_EXCEPTION_IF_NULL(debug_info_ptr);
  if (import_proto.has_name()) {
    debug_info_ptr->set_name(import_proto.name());
  } else {
    MS_LOG(ERROR) << "FuncGraph under converting has not name!";
    return false;
  }
  if (import_proto.has_bprop_hash()) {
    output_graph->set_bprop_hash(import_proto.bprop_hash());
  }

  if (import_proto.has_bprop_filepath()) {
    output_graph->set_bprop_filepath(import_proto.bprop_filepath());
  }
  // NOTE(review): an attribute-import failure is only logged here, not fatal,
  // unlike the steps below — presumably a deliberate best-effort; confirm
  // before tightening.
  if (!BuildAttrForFuncGraph(output_graph, import_proto)) {
    MS_LOG(ERROR) << "Build attribute for graph fail!";
  }
  if (!ImportParametersForGraph(output_graph, import_proto)) {
    MS_LOG(ERROR) << "Import parameters for graph fail!";
    return false;
  }
  if (!ImportMapParametersForGraph(output_graph, import_proto)) {
    MS_LOG(ERROR) << "Import map parameters for graph failed!";
    return false;
  }
  if (!ImportNodesForGraph(output_graph, import_proto)) {
    MS_LOG(ERROR) << "Import nodes for graph failed! " << import_proto.has_name();
    return false;
  }
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  const bool force_no_inline = common::IsDisableRuntimeConfig(common::kRuntimeInline);
  if (output_graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE)) {
    // GE backend outside kernel-by-kernel mode keeps reused cells un-inlined;
    // otherwise lazy inline is used. A runtime config can force no-inline.
    const bool enable_ge = context->backend_policy() == "ge";
    auto cell_reuse_level =
      (enable_ge && !context->IsKByKExecutorMode()) ? CellReuseLevel::kNoInline : CellReuseLevel::kLazyInline;
    if (force_no_inline) {
      cell_reuse_level = CellReuseLevel::kNoInline;
    }
    context->SetCellReuseLevel(cell_reuse_level);
  }
  return true;
}
2050 
// Binds initial values from `weights` (keyed by parameter name) to the
// trailing Ref-typed parameters of `topGraph`. The parameter list is scanned
// from the back and the scan stops at the first non-Ref parameter, so only
// the trailing run of Ref parameters (presumably the weight/fv parameters
// appended after user inputs — confirm with the exporter) is considered.
// A missing weight for a candidate parameter is tolerated (INFO log only).
// Finally records how many parameters received a default value.
bool MSANFModelParser::SetValueForTopGraphParameter(const FuncGraphPtr &topGraph,
                                                    const std::map<std::string, ValuePtr> &weights) {
  size_t fv_param_count = 0;
  auto parameters = topGraph->parameters();
  for (int64_t i = SizeToLong(parameters.size()) - 1; i >= 0; --i) {
    auto parameter = parameters[i]->cast<ParameterPtr>();
    if (parameter == nullptr) {
      MS_LOG(ERROR) << "AnfNode " << parameters[i]->DebugString() << " should be Parameter.";
      return false;
    }
    auto type = parameter->Type();
    if (type == nullptr) {
      MS_LOG(ERROR) << "Parameter " << parameter->DebugString() << " has no type.";
      return false;
    }
    if (!type->isa<RefType>()) {
      // First non-Ref parameter from the back: everything before it is a
      // user input, not a weight.
      break;
    }
    auto parameter_name = parameter->name();
    auto weights_iter = weights.find(parameter_name);
    if (weights_iter == weights.end()) {
      MS_LOG(INFO) << "Find initial weight value for " << parameter_name << " failed.";
      continue;
    }
    parameter->set_default_param(weights_iter->second);
    fv_param_count++;
  }
  topGraph->set_fv_param_count(fv_param_count);
  return true;
}
2081 
// Drains the deferred-abstract queue (node_abstract_protos_), applying each
// queued attribute proto to its CNode via SetCNodeAbstract. A node may appear
// in the queue multiple times (presumably SetCNodeAbstract re-queues entries
// whose dependencies are not ready yet — confirm in SetCNodeAbstract); after
// a node has been seen more than kMaxCount times its entry is dropped with an
// error and the whole load is marked as having invalid abstracts.
void MSANFModelParser::TrytoBuildCNodeAbstract() {
  std::map<CNodePtr, int> visited_times;
  constexpr int kMaxCount = 3;
  while (!node_abstract_protos_.empty()) {
    auto &item = node_abstract_protos_.front();
    auto &count = visited_times[item.first];
    if (count++ > kMaxCount) {
      abstract_valid_ = false;
      MS_LOG(ERROR) << "Parse CNode: " << item.first->ToString() << " abstract error: " << item.second->DebugString();
    } else {
      SetCNodeAbstract(*(item.second), item.first);
    }
    node_abstract_protos_.pop_front();
  }
}
2097 
CheckMindIRVseriosn(const mind_ir::ModelProto & model_proto)2098 bool CheckMindIRVseriosn(const mind_ir::ModelProto &model_proto) {
2099   if (model_proto.has_mind_ir_version()) {
2100     auto mind_ir_version = model_proto.mind_ir_version();
2101     if (mind_ir_version >= 2) {
2102       return true;
2103     }
2104   }
2105   return false;
2106 }
2107 
// Parses a full MindIR model proto into a fresh FuncGraph, optionally seeding
// weight values and sharing a name->node map with the caller.
// Flow: build primitives; forward-declare the top graph and every function
// graph by name (so cross-graph references resolve); build the top graph,
// bind weights, then build each function graph; finally resolve deferred
// CNode abstracts. Returns nullptr on any failure.
// `name_to_node`, when non-null, both seeds and receives anfnode_build_map_.
FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto,
                                     const std::map<std::string, ValuePtr> &weights,
                                     mindspore::HashMap<std::string, AnfNodePtr> *name_to_node) {
  if (IsLite()) {
    // Lite loads skip abstract correction (see the tail of this function).
    abstract_valid_ = true;
  }
  if (name_to_node) {
    anfnode_build_map_ = *name_to_node;
  }
  for (int i = 0; i < model_proto.primitives_size(); ++i) {
    if (!BuildPrimitiveNode(model_proto.primitives(i))) {
      MS_LOG(ERROR) << "Parse primitives info for pb file failed! " << model_proto.primitives(i).DebugString();
      return nullptr;
    }
  }
  FuncGraphPtr dstGraph = std::make_shared<FuncGraph>();
  const mind_ir::GraphProto &graphBuild = model_proto.graph();

  // Forward declare FuncGraph name
  // Compatible with the previous proto.
  if (graphBuild.has_name()) {
    anfnode_build_map_[graphBuild.name()] = NewValueNodeWithAbstract(dstGraph);
  }
  for (int i = 0; i < model_proto.functions_size(); ++i) {
    FuncGraphPtr graph = std::make_shared<FuncGraph>();
    const auto &graph_proto = model_proto.functions(i);
    if (!graph_proto.has_name()) {
      MS_LOG(ERROR) << "The function has not a name. Please export mindIR again. ";
      return nullptr;
    }
    if (anfnode_build_map_.count(graph_proto.name()) > 0) {
      MS_LOG(ERROR) << "There is a duplication function graph name: " << graph_proto.name();
      return nullptr;
    }
    anfnode_build_map_[graph_proto.name()] = NewValueNodeWithAbstract(graph);
  }

  // Parser the proto.
  if (!BuildFuncGraph(dstGraph, graphBuild)) {
    MS_LOG(ERROR) << "Build funcgraph failed!";
    return nullptr;
  }

  if (!weights.empty()) {
    if (!SetValueForTopGraphParameter(dstGraph, weights)) {
      MS_LOG(ERROR) << "Set value for top graph fail.";
      return nullptr;
    }
  }
  // Propagate the "exported with primitive functions" marker to every graph.
  bool generated_from_mindir_with_prim_func = CheckMindIRVseriosn(model_proto);
  dstGraph->set_flag("generated_from_mindir_with_prim_func", generated_from_mindir_with_prim_func);
  MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! graph: " << graphBuild.name() << ": " << dstGraph.get();
  top_graph_ = dstGraph;
  for (int i = 0; i < model_proto.functions_size(); ++i) {
    const auto &graph_proto = model_proto.functions(i);
    FuncGraphPtr graph = GetValueNode<FuncGraphPtr>(anfnode_build_map_[graph_proto.name()]);
    if (!BuildFuncGraph(graph, graph_proto)) {
      MS_LOG(ERROR) << "Build funcgraph failed!";
      return nullptr;
    }
    graph->set_flag("generated_from_mindir_with_prim_func", generated_from_mindir_with_prim_func);
    MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! graph: " << graph_proto.name() << ": " << graph.get();
  }
  TrytoBuildCNodeAbstract();
  if (name_to_node) {
    *name_to_node = anfnode_build_map_;
  }
  // Release resource
  anfnode_build_map_.clear();
  // Correct the null abstract for compatibility with previous versions.
  if (!abstract_valid_ && weights.empty()) {
    CorrectFuncGraph(dstGraph);
  }
  return dstGraph;
}
2183 
// Parses a MindIR model proto INTO pre-existing graphs (typically kernel
// graphs) matched by name, instead of creating fresh FuncGraphs.
// Flow mirrors the other Parse overload: build primitives; register each
// existing graph under its proto name and import its params/attrs; then build
// value nodes / cnodes for the root and, in name-sorted order, for every
// function graph; finally resolve deferred CNode abstracts.
bool MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto, const std::vector<FuncGraphPtr> &graphs,
                             mindspore::HashMap<std::string, AnfNodePtr> *name_to_node) {
  is_kernel_graph_ = graphs.front()->type_name() == kKernelGraphTypeName;
  if (name_to_node) {
    anfnode_build_map_ = *name_to_node;
  }
  // Shared per-graph setup: debug name, graph attributes, parameters and
  // map parameters (no node construction yet).
  auto build_params_attrs = [this](const FuncGraphPtr &graph, const mind_ir::GraphProto &proto) {
    MS_EXCEPTION_IF_NULL(graph);
    if (!proto.has_name()) {
      MS_LOG(ERROR) << "KernelGraph under converting has not name!";
      return false;
    }
    GraphDebugInfoPtr debug_info_ptr = graph->debug_info();
    MS_EXCEPTION_IF_NULL(debug_info_ptr);
    debug_info_ptr->set_name(proto.name());
    // NOTE(review): attribute failure is only logged, not fatal — consistent
    // with BuildFuncGraph; confirm this is intentional.
    if (!BuildAttrForFuncGraph(graph, proto)) {
      MS_LOG(ERROR) << "Build attribute for graph fail!";
    }
    if (!ImportParametersForGraph(graph, proto)) {
      MS_LOG(ERROR) << "Import parameters for graph fail!";
      return false;
    }
    if (!ImportMapParametersForGraph(graph, proto)) {
      MS_LOG(ERROR) << "Import map parameters for graph failed!";
      return false;
    }
    return true;
  };
  for (int i = 0; i < model_proto.primitives_size(); ++i) {
    if (!BuildPrimitiveNode(model_proto.primitives(i))) {
      MS_LOG(ERROR) << "Parse primitives info for pb file failed! " << model_proto.primitives(i).DebugString();
      return false;
    }
  }
  const mind_ir::GraphProto &graph_build = model_proto.graph();
  // Root graph: locate by name among the supplied graphs and set up first.
  const auto &root = FindGraphByName(graphs, graph_build.name());
  MS_EXCEPTION_IF_NULL(root);
  anfnode_build_map_[graph_build.name()] = NewValueNodeWithAbstract(root);
  top_graph_ = root;
  if (!build_params_attrs(root, graph_build)) {
    MS_LOG(ERROR) << "Build funcgraph params and attrs failed.";
    return false;
  }
  for (int i = 0; i < model_proto.functions_size(); ++i) {
    const auto &graph_proto = model_proto.functions(i);
    if (!graph_proto.has_name()) {
      MS_LOG(ERROR) << "The function has not a name. Please export mindIR again. ";
      return false;
    }
    const auto &graph_name = graph_proto.name();
    if (anfnode_build_map_.count(graph_name) > 0) {
      MS_LOG(ERROR) << "There is a duplication function graph name: " << graph_proto.name();
      return false;
    }
    const auto &graph = FindGraphByName(graphs, graph_name);
    MS_EXCEPTION_IF_NULL(graph);
    auto debug_info = graph->debug_info();
    debug_info->set_name(graph_name);
    anfnode_build_map_[graph_name] = NewValueNodeWithAbstract(graph);
    if (!build_params_attrs(graph, graph_proto)) {
      MS_LOG(ERROR) << "Build funcgraph params and attrs failed.";
      return false;
    }
  }

  // Parser the proto.
  if (!ImportNodesForGraph(root, graph_build)) {
    MS_LOG(ERROR) << "Build funcgraph " << graph_build.name() << " value node and cnode failed.";
    return false;
  } else {
    MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! graph: " << graph_build.name();
  }
  // Process function graphs in deterministic (name-sorted) order.
  std::map<std::string, mind_ir::GraphProto> sorted_proto;
  std::for_each(model_proto.functions().begin(), model_proto.functions().end(),
                [&sorted_proto](const auto &proto) { sorted_proto[proto.name()] = proto; });
  for (const auto &[name, proto] : sorted_proto) {
    FuncGraphPtr graph = GetValueNode<FuncGraphPtr>(anfnode_build_map_[name]);
    if (!ImportNodesForGraph(graph, proto)) {
      MS_LOG(ERROR) << "Build funcgraph: " << name << "'s value_node and cnode failed.";
      return false;
    } else {
      MS_LOG(INFO) << "Build FuncGraph Success! graph: " << name;
    }
  }
  TrytoBuildCNodeAbstract();
  if (name_to_node) {
    *name_to_node = anfnode_build_map_;
  }
  // Release resource
  anfnode_build_map_.clear();
  return true;
}
2276 
ParseLayout(const mind_ir::ModelProto & model_proto)2277 const LayoutMap MSANFModelParser::ParseLayout(const mind_ir::ModelProto &model_proto) {
2278   LayoutMap ret;
2279   mind_ir::ParallelProto parallel_proto = model_proto.parallel();
2280   for (int i = 0; i < parallel_proto.layout_size(); ++i) {
2281     const mind_ir::LayoutProto &layout_proto = parallel_proto.layout(i);
2282     LayoutPtr cur_layout = std::make_shared<Layout>();
2283     const std::string name = layout_proto.name();
2284     std::vector<int64_t> device_arrangement;
2285     for (int num = 0; num < layout_proto.device_arrangement_int_size(); ++num) {
2286       (void)device_arrangement.emplace_back(layout_proto.device_arrangement_int(num));
2287     }
2288     std::vector<int64_t> tensor_map;
2289     for (int num = 0; num < layout_proto.tensor_map_int_size(); ++num) {
2290       (void)tensor_map.emplace_back(layout_proto.tensor_map_int(num));
2291     }
2292     std::vector<int64_t> slice_shape;
2293     for (int num = 0; num < layout_proto.slice_shape_int_size(); ++num) {
2294       (void)slice_shape.emplace_back(layout_proto.slice_shape_int(num));
2295     }
2296     int64_t field_size = layout_proto.field_size();
2297     bool uniform_spilt = layout_proto.uniform_split();
2298     const std::string opt_shard_group = layout_proto.opt_shard_group();
2299 
2300     cur_layout->set_device_arrangement(device_arrangement);
2301     cur_layout->set_tensor_map(tensor_map);
2302     cur_layout->set_slice_shape(slice_shape);
2303     cur_layout->set_field_size(field_size);
2304     cur_layout->set_uniform_split(uniform_spilt);
2305     cur_layout->set_opt_shard_group(opt_shard_group);
2306 
2307     // Check optional field for backward compatibility.
2308     if (layout_proto.has_pipeline_shared()) {
2309       bool pipeline_shared = layout_proto.pipeline_shared();
2310       bool is_send = layout_proto.is_send();
2311       int64_t peer_rank = layout_proto.peer_rank();
2312       int64_t sr_tag = layout_proto.sr_tag();
2313 
2314       cur_layout->set_pipeline_shared(pipeline_shared);
2315       cur_layout->set_is_send(is_send);
2316       cur_layout->set_peer_rank(peer_rank);
2317       cur_layout->set_sr_tag(sr_tag);
2318     }
2319     ret[name] = cur_layout;
2320   }
2321   return ret;
2322 }
2323 
GetAnfNode(const std::string & node_name)2324 AnfNodePtr MSANFModelParser::GetAnfNode(const std::string &node_name) {
2325   if (node_name.find("MetaFuncGraph::") == 0) {
2326     auto fg_name = node_name.substr(std::string("MetaFuncGraph::").length());
2327     auto mindir_meta_fg = std::make_shared<MindIRMetaFuncGraph>(fg_name);
2328     return NewValueNode(mindir_meta_fg);
2329   }
2330   if (node_name.find("ClassType::") == 0) {
2331     auto class_type = node_name.substr(std::string("ClassType::").length());
2332     auto mindir_class_type = std::make_shared<MindIRClassType>(class_type);
2333     return NewValueNode(mindir_class_type);
2334   }
2335   auto it = anfnode_build_map_.find(node_name);
2336   if (it == anfnode_build_map_.end()) {
2337     return nullptr;
2338   }
2339   // The FunctionGraph node can't been shared.
2340   FuncGraphPtr func_graph_ptr = GetValueNode<FuncGraphPtr>(it->second);
2341   if (func_graph_ptr != nullptr) {
2342     auto node = NewValueNode(func_graph_ptr);
2343     node->set_abstract(func_graph_ptr->ToAbstract());
2344     return node;
2345   } else {
2346     return it->second;
2347   }
2348 }
2349 
// Builds a Primitive from its proto and registers it in anfnode_build_map_
// under the proto's (unique) instance name. The primitive class is resolved
// in order: registered PrimitiveC factory, DoSignature-prefixed wrapper,
// plain Primitive. Attributes are applied to the unwrapped primitive.
bool MSANFModelParser::BuildPrimitiveNode(const mind_ir::PrimitiveProto &primitive_proto) {
  // Factory map is fetched once and cached for the process lifetime.
  static auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
  auto &prim_type = primitive_proto.op_type();
  const auto &type = primitive_proto.prim_type();
  std::shared_ptr<Primitive> prim;

  auto it = op_primc_fns.find(prim_type);
  if (it != op_primc_fns.end()) {
    prim = it->second();
  } else {
    if (prim_type.compare(0, strlen(kDoSignaturePrimitivePrefix), kDoSignaturePrimitivePrefix) == 0) {
      // "S_Prim_<op>"-style name: wrap the inner primitive in DoSignature.
      auto op_name = prim_type.substr(strlen(kDoSignaturePrimitivePrefix));
      prim = std::make_shared<prim::DoSignaturePrimitive>(op_name, std::make_shared<Primitive>(op_name));
    } else {
      prim = std::make_shared<Primitive>(prim_type);
    }
  }
  if (type == mind_ir::PrimitiveProto_PrimType_PRIMITIVE_FUNCTION) {
    MS_LOG(DEBUG) << "PrimitiveFunction special node_type: " << prim_type;
    prim->AddAttr("primitive_function", MakeValue(true));
  }

  if (primitive_proto.has_instance_name()) {
    prim->set_instance_name(primitive_proto.instance_name());
  }

  // Set primitive attributes
  // Attributes go on the primitive inside any DoSignature wrapper.
  auto prim_to_add_attr = GetValueWithoutDoSignature(prim)->cast<PrimitivePtr>();
  MS_EXCEPTION_IF_NULL(prim_to_add_attr);
  prim_to_add_attr->set_attr("is_load", MakeValue(true));
  for (int i = 0; i < primitive_proto.attribute_size(); ++i) {
    const mind_ir::AttributeProto &attr_proto = primitive_proto.attribute(i);
    if (!SetPrimitiveAttrWithType(prim_to_add_attr, attr_proto)) {
      MS_LOG(ERROR) << "Parse prim: " << prim->ToString() << " attributes error: " << attr_proto.DebugString();
      return false;
    }
  }
  if (anfnode_build_map_.count(primitive_proto.name()) > 0) {
    MS_LOG(ERROR) << "There is a duplication primitive instance name: " << primitive_proto.name();
    return false;
  }
  anfnode_build_map_[primitive_proto.name()] = NewValueNodeWithAbstract(prim);
  return true;
}
2394 
// Reconstructs a function-valued abstract from an attribute proto. Supports
// primitive/funcgraph closures (looked up by name in attr_proto.s()), partial
// closures (rebuilt from the Partial CNode's inputs), and unions of the
// above. Returns nullptr when the referenced node or abstract is unavailable.
abstract::AbstractBasePtr MSANFModelParser::BuildAbstractFunction(const mind_ir::AttributeProto &attr_proto) {
  switch (attr_proto.type()) {
    case mind_ir::AttributeProto_AttributeType_PRIMITIVECLOSURE:
    case mind_ir::AttributeProto_AttributeType_FUNCGRAPHCLOSURE: {
      auto func_node = GetAnfNode(attr_proto.s());
      if (func_node == nullptr) {
        // Unresolvable closure: fall back to an empty graph's abstract rather
        // than failing the whole load.
        FuncGraphPtr dummy_graph = std::make_shared<FuncGraph>();
        MS_LOG(DEBUG) << "Failed to get function graph closure: " << attr_proto.DebugString();
        return dummy_graph->ToAbstract();
      }
      return func_node->abstract();
    }
    case mind_ir::AttributeProto_AttributeType_PARTIALCLOSURE: {
      auto anf_node = GetAnfNode(attr_proto.s());
      if (anf_node == nullptr) {
        return nullptr;
      }
      auto partial_node = anf_node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(partial_node);
      if (!IsPrimitiveCNode(partial_node, prim::kPrimPartial)) {
        MS_LOG(INTERNAL_EXCEPTION) << "Not partial CNode, but got " << partial_node->DebugString();
      }
      // Partial CNode layout: [0] Partial prim, [1] fn, [2..] bound args.
      AbstractBasePtrList args_spec_list;
      auto &inputs = partial_node->inputs();
      const size_t kPartial_args_begin_pos = 2;
      const size_t kPartial_fn_pos = 1;
      if (inputs.size() <= kPartial_args_begin_pos) {
        MS_LOG(ERROR) << "Partial node input size is wrong.";
        return nullptr;
      }
      (void)std::transform(inputs.begin() + kPartial_args_begin_pos, inputs.end(), std::back_inserter(args_spec_list),
                           [](const AnfNodePtr &arg) -> AbstractBasePtr { return arg->abstract(); });
      auto &op_node = inputs[kPartial_fn_pos];
      MS_EXCEPTION_IF_NULL(op_node);
      // Derive the callee's atomic function abstract, falling back to the
      // FuncGraph value when the node's abstract is missing or not atomic.
      abstract::AbstractFuncAtomPtr fn;
      if (op_node->abstract() != nullptr) {
        fn = op_node->abstract()->cast<abstract::AbstractFuncAtomPtr>();
        if (fn == nullptr) {
          MS_LOG(DEBUG) << "Can't get the abstract of partial node: " << op_node->ToString();
          FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(op_node);
          if (fg == nullptr) {
            MS_LOG(INTERNAL_EXCEPTION) << "partial_node: " << partial_node->DebugString()
                                       << ", op_node: " << op_node->DebugString() << ", "
                                       << op_node->abstract()->ToString();
          }
          fn = fg->ToAbstract()->cast<abstract::AbstractFuncAtomPtr>();
        }
      } else {
        MS_LOG(DEBUG) << "Can't get the abstract of partial node: " << op_node->ToString();
        FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(op_node);
        MS_EXCEPTION_IF_NULL(fg);
        fn = fg->ToAbstract()->cast<abstract::AbstractFuncAtomPtr>();
      }
      return std::make_shared<abstract::PartialAbstractClosure>(fn, args_spec_list, partial_node);
    }
    case mind_ir::AttributeProto_AttributeType_UNIONFUNCCLOSURE: {
      // Union of closures: build each member recursively.
      abstract::AbstractFuncAtomPtrList func_list;
      for (int index = 0; index < attr_proto.values_size(); index++) {
        auto &item_proto = attr_proto.values(index);
        auto item_abstract = BuildAbstractFunction(item_proto);
        if (item_abstract == nullptr) {
          MS_LOG(WARNING) << "Can't get the abstract of function union closure: " << item_proto.DebugString();
          return nullptr;
        }
        (void)func_list.emplace_back(item_abstract->cast<abstract::AbstractFuncAtomPtr>());
      }
      return std::make_shared<abstract::AbstractFuncUnion>(func_list);
    }
    default: {
      MS_LOG(ERROR) << "Not support function abstract: " << attr_proto.DebugString();
      return nullptr;
    }
  }
}
2469 
CorrectFuncGraph(const FuncGraphPtr & root)2470 void MSANFModelParser::CorrectFuncGraph(const FuncGraphPtr &root) {
2471   MS_LOG(DEBUG) << "Begin to correct the funcgraph.";
2472   MS_EXCEPTION_IF_NULL(root);
2473   auto inputs = root->get_inputs();
2474   auto valid =
2475     std::all_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &arg) -> bool { return arg->abstract() != nullptr; });
2476   if (valid) {
2477     (void)ValidMindir(root);
2478   } else {
2479     MS_LOG(INFO) << "There are some nullptr of abstract in the top function graph parameters." << root->DumpText();
2480   }
2481   MS_LOG(DEBUG) << "End to correct the funcgraph.";
2482 }
2483 
BuildAttrForCNode(const CNodePtr & cnode,const mind_ir::NodeProto & node_proto)2484 bool MSANFModelParser::BuildAttrForCNode(const CNodePtr &cnode, const mind_ir::NodeProto &node_proto) {
2485   for (auto i = 0; i < node_proto.node_attr_size(); ++i) {
2486     const auto &attr_proto = node_proto.node_attr(i);
2487     auto value = GetValueFromAttributeProto(attr_proto);
2488     if (value == nullptr) {
2489       MS_LOG(ERROR) << "Failed set func_graph attr to func_graph";
2490       return false;
2491     }
2492     cnode->AddAttr(attr_proto.name(), value);
2493   }
2494   for (auto i = 0; i < node_proto.primal_attr_size(); ++i) {
2495     const auto &attr_proto = node_proto.primal_attr(i);
2496     auto value = GetValueFromAttributeProto(attr_proto);
2497     if (value == nullptr) {
2498       MS_LOG(ERROR) << "Failed set func_graph attr to func_graph";
2499       return false;
2500     }
2501     cnode->AddPrimalAttr(attr_proto.name(), value);
2502   }
2503   return true;
2504 }
2505 
// Recursively collects the regular files under `dir_in` into `files`.
// Entries whose names start with '.' (hidden files, ".", "..") are skipped.
// Returns false when `dir_in` is empty, cannot be stat'ed, or is not a
// directory, or when any entry inside fails to stat; throws (MS_LOG
// EXCEPTION) when the directory cannot be opened.
bool get_all_files(const std::string &dir_in, std::vector<std::string> *files) {
  if (dir_in.empty()) {
    return false;
  }
  struct stat s;
  int ret = stat(dir_in.c_str(), &s);
  if (ret != 0) {
    MS_LOG(ERROR) << "stat error, ret is : " << ret;
    return false;
  }
  if (!S_ISDIR(s.st_mode)) {
    return false;
  }
  DIR *open_dir = opendir(dir_in.c_str());
  if (open_dir == nullptr) {
    MS_LOG(EXCEPTION) << "open dir " << dir_in.c_str() << " failed";
  }
  dirent *p = nullptr;
  while ((p = readdir(open_dir)) != nullptr) {
    struct stat st;
    if (p->d_name[0] != '.') {
      std::string name = dir_in + std::string("/") + std::string(p->d_name);
      ret = stat(name.c_str(), &st);
      if (ret != 0) {
        MS_LOG(ERROR) << "stat error, ret is : " << ret;
        closedir(open_dir);
        return false;
      }
      if (S_ISDIR(st.st_mode)) {
        // Recurse into subdirectories.
        if (!get_all_files(name, files)) {
          MS_LOG(ERROR) << "Get files failed, ret is : " << ret;
          closedir(open_dir);
          return false;
        }
      } else if (S_ISREG(st.st_mode)) {
        files->push_back(name);
      }
    }
  }
  closedir(open_dir);
  return true;
}
2548 
// Returns 1 when `s` ends with `sub`, otherwise 0.
// Fixed: the previous `s.rfind(sub) == (s.length() - sub.length())` wrapped
// around when `sub` was longer than `s` (both operands unsigned), making the
// right-hand side equal to npos, so e.g. endsWith("a", "ab") returned 1.
int endsWith(const string s, const string sub) {
  if (sub.length() > s.length()) {
    return 0;
  }
  return s.compare(s.length() - sub.length(), sub.length(), sub) == 0 ? 1 : 0;
}
2550 
// Deserializes the ModelProto at `path` into `model`. When the loader carries
// a decryption key the file is decrypted in memory first; otherwise the file
// is parsed directly from a binary stream. Returns false on decryption or
// parse failure.
bool ParseModelProto(mind_ir::ModelProto *model, const std::string &path, const MindIRLoader *loader) {
  if (loader->dec_key() != nullptr) {
    size_t plain_len;
    auto plain_data = Decrypt(&plain_len, path, loader->dec_key(), loader->key_len(), loader->dec_mode());
    if (plain_data == nullptr) {
      MS_LOG(ERROR)
        << "Decrypt MindIR file failed, please check the correctness of the dec_key or dec_mode or the file integrity.";
      return false;
    }
    if (!model->ParseFromArray(reinterpret_cast<char *>(plain_data.get()), static_cast<int32_t>(plain_len))) {
      MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file, dec_key or dec_mode.";
      return false;
    }
  } else {
    std::fstream input_graph(path, std::ios::in | std::ios::binary);
    if (!input_graph || !model->ParseFromIstream(&input_graph)) {
      MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file.";
      return false;
    }
  }
  return true;
}
2573 
// Deserializes a GraphProto (a MindIR variable file) at `path` into `graph`.
// Mirrors ParseModelProto: decrypt-then-parse when a key is configured,
// plain binary stream parse otherwise. Returns false on failure.
bool ParseGraphProto(mind_ir::GraphProto *graph, const std::string &path, const MindIRLoader *loader) {
  if (loader->dec_key() != nullptr) {
    size_t plain_len;
    auto plain_data = Decrypt(&plain_len, path, loader->dec_key(), loader->key_len(), loader->dec_mode());
    if (plain_data == nullptr) {
      MS_LOG(ERROR)
        << "Decrypt MindIR file failed, please check the correctness of the dec_key or dec_mode or the file integrity.";
      return false;
    }
    if (!graph->ParseFromArray(reinterpret_cast<char *>(plain_data.get()), static_cast<int32_t>(plain_len))) {
      MS_LOG(ERROR) << "Load variable file failed, please check the correctness of the mindir's variable file, "
                       "dec_key or dec_mode";
      return false;
    }
  } else {
    std::fstream input_param(path, std::ios::in | std::ios::binary);
    if (!input_param || !graph->ParseFromIstream(&input_param)) {
      MS_LOG(ERROR) << "Load variable file failed, please check the correctness of mindir's variable file.";
      return false;
    }
  }
  return true;
}
2597 
// Copies the loader's decryption settings (key, key size, mode) onto the
// parser and switches the parser to lite mode when the loader requests it.
void InitModelParser(MSANFModelParser *model_parser, const MindIRLoader *loader) {
  model_parser->SetMindIRDecKey(loader->dec_key());
  model_parser->SetMindIRKeySize(loader->key_len());
  model_parser->SetMindIRDecMode(loader->dec_mode());

  if (loader->is_lite()) {
    model_parser->SetLite();
  }
}
2607 }  // namespace
2608 
LoadPreprocess(const std::string & file_name)2609 std::vector<std::string> MindIRLoader::LoadPreprocess(const std::string &file_name) {
2610   if (file_name.length() > PATH_MAX) {
2611     MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
2612     return {};
2613   }
2614   char abs_path_buff[PATH_MAX];
2615 
2616 #ifdef _WIN32
2617   _fullpath(abs_path_buff, file_name.c_str(), PATH_MAX);
2618 #else
2619   if (!realpath(file_name.c_str(), abs_path_buff)) {
2620     MS_LOG(ERROR) << "Load MindIR get absolute path failed";
2621   }
2622 #endif
2623 
2624   // Read graph
2625   mind_ir::ModelProto origin_model;
2626   if (!ParseModelProto(&origin_model, std::string(abs_path_buff), this)) {
2627     MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file.";
2628     return {};
2629   }
2630 
2631   // Read dataset preprocessor
2632   auto preprocessor = origin_model.preprocessor();
2633 
2634   // Separate columns and parse
2635   std::vector<std::string> input_columns;
2636   for (auto i = 0; i < preprocessor.op_size(); i++) {
2637     std::string column = preprocessor.op()[i].input_columns();
2638     if (std::find(input_columns.begin(), input_columns.end(), column) == input_columns.end()) {
2639       input_columns.push_back(column);
2640     }
2641   }
2642 
2643   // Each column has one string to indicate its preprocess behaviour
2644   std::vector<std::string> map_jsons;
2645   for (std::string &column : input_columns) {
2646     nlohmann::json dataset_json;
2647     nlohmann::json child_dataset_json;
2648     for (auto i = preprocessor.op_size() - 1; i >= 0; i--) {
2649       if (preprocessor.op()[i].input_columns() == column) {
2650         child_dataset_json["input_columns"] = nlohmann::json::parse(preprocessor.op()[i].input_columns());
2651         child_dataset_json["op_type"] = nlohmann::json::parse(preprocessor.op()[i].op_type());
2652         child_dataset_json["operations"] = nlohmann::json::parse(preprocessor.op()[i].operations());
2653         child_dataset_json["output_columns"] = nlohmann::json::parse(preprocessor.op()[i].output_columns());
2654         child_dataset_json["offload"] = preprocessor.op()[i].offload();
2655 
2656         dataset_json["children"] = child_dataset_json;
2657         child_dataset_json = dataset_json;
2658       }
2659     }
2660     map_jsons.push_back(dataset_json["children"].dump());
2661   }
2662   return map_jsons;
2663 }
2664 
LoadMindIRs(const std::vector<std::string> & file_names)2665 std::vector<FuncGraphPtr> MindIRLoader::LoadMindIRs(const std::vector<std::string> &file_names) {
2666   std::vector<FuncGraphPtr> funcgraph_vec;
2667   MS_LOG(DEBUG) << "Load multiple MindIR files.";
2668   for (const auto &file_name : file_names) {
2669     MS_LOG(DEBUG) << "Load " << file_name;
2670     funcgraph_vec.push_back(LoadMindIR(file_name));
2671   }
2672   return funcgraph_vec;
2673 }
2674 
LoadMindIR(const void * buffer,const size_t & size)2675 FuncGraphPtr MindIRLoader::LoadMindIR(const void *buffer, const size_t &size) {
2676   /* mindir -> func_graph
2677    * only support lite */
2678   mind_ir::ModelProto model;
2679   auto ret = model.ParseFromArray(buffer, SizeToInt(size));
2680   if (!ret) {
2681     MS_LOG(ERROR) << "ParseFromArray failed.";
2682     return nullptr;
2683   }
2684   if (!CheckModelConfigureInfo(model)) {
2685     MS_LOG(ERROR) << "Check configuration info for pb file failed!";
2686     return nullptr;
2687   }
2688   MSANFModelParser model_parser;
2689   InitModelParser(&model_parser, this);
2690   FuncGraphPtr func_graph = model_parser.Parse(model);
2691 
2692   return func_graph;
2693 }
2694 
// NOTE(review): file-scope global map from node name to AnfNode. It is not
// referenced anywhere in the visible portion of this file — presumably kept
// for external linkage or legacy reasons; verify before removing.
mindspore::HashMap<std::string, AnfNodePtr> anfnode_build_map_;
// Loads a MindIR model from `file_name` and parses it into a FuncGraph.
// For "*_graph.mindir" files that carry no inline parameters, the weights are
// read from the sibling "variables" directory and merged into the graph proto
// before parsing. `name_to_node` (may be nullptr per the parser's signature)
// receives the mapping from proto names to created nodes.
// Throws (MS_LOG(EXCEPTION)) on path errors; returns nullptr on parse errors.
FuncGraphPtr MindIRLoader::LoadMindIR(const std::string &file_name,
                                      mindspore::HashMap<std::string, AnfNodePtr> *name_to_node) {
  if (file_name.length() > PATH_MAX) {
    MS_LOG(EXCEPTION) << "The length of the file name exceeds the limit.";
  }
  char abs_path_buff[PATH_MAX];
  vector<string> files;

#ifdef _WIN32
  // NOTE(review): _fullpath's result is unchecked here, unlike the POSIX
  // branch below — confirm failure cannot occur for the callers' inputs.
  _fullpath(abs_path_buff, file_name.c_str(), PATH_MAX);
#else
  if (!realpath(file_name.c_str(), abs_path_buff)) {
    MS_LOG(EXCEPTION) << "Load MindIR get absolute path of " << file_name
                      << " failed, errno is: " << ErrnoToString(errno);
  }
#endif
  // Read graph
  mind_ir::ModelProto origin_model;
  if (!ParseModelProto(&origin_model, std::string(abs_path_buff), this)) {
    return nullptr;
  }

  if (!CheckModelConfigureInfo(origin_model)) {
    MS_LOG(ERROR) << "Check configuration info for pb file failed!";
    return nullptr;
  }
  // Load parameter into graph: a "*_graph.mindir" with zero inline parameters
  // stores its weights externally under a sibling "variables" directory.
  if (endsWith(std::string(abs_path_buff), "_graph.mindir") && (origin_model.graph().parameter_size() == 0)) {
    if (strlen(abs_path_buff) < strlen("graph.mindir")) {
      MS_LOG(ERROR) << "The abs_path_buff length is less than 'graph.mindir'.";
      return nullptr;
    }
    // Replace the trailing "graph.mindir" with "variables" to locate the
    // directory holding the externally saved parameters.
    size_t path_len = strlen(abs_path_buff) - strlen("graph.mindir");
    string var_path = std::string(abs_path_buff).substr(0, path_len);
    var_path += "variables";
    std::ifstream ifs(var_path);
    if (ifs.good()) {
      MS_LOG(DEBUG) << "MindIR file has variables path, load parameter into graph.";
      (void)get_all_files(var_path, &files);
    } else {
      MS_LOG(ERROR) << "Load graph's variable folder failed, please check the correctness of variable folder.";
      return nullptr;
    }

    // Parse every variable file and append its parameters to the main graph.
    size_t file_size = files.size();
    mind_ir::GraphProto *mod_graph = origin_model.mutable_graph();
    for (size_t file_index = 0; file_index < file_size; file_index++) {
      mind_ir::GraphProto param_graph;
      if (!ParseGraphProto(&param_graph, files[file_index], this)) {
        return nullptr;
      }

      // Copy only the fields the loader needs: name, dtype, raw bytes,
      // compression type, and shape dims.
      for (int param_index = 0; param_index < param_graph.parameter_size(); param_index++) {
        mind_ir::TensorProto *param_proto = mod_graph->add_parameter();
        param_proto->set_name(param_graph.parameter(param_index).name());
        param_proto->set_data_type(param_graph.parameter(param_index).data_type());
        param_proto->set_raw_data(param_graph.parameter(param_index).raw_data());
        param_proto->set_compression_type(param_graph.parameter(param_index).compression_type());
        for (const auto &dim : param_graph.parameter(param_index).dims()) {
          param_proto->add_dims(dim);
        }
      }
    }
  }

  MSANFModelParser model_parser;

  // Record the containing directory of the mindir file on the parser.
  // NOTE(review): rfind("/") will not strip a Windows "\\" separator —
  // confirm callers normalize the path on _WIN32.
  auto mindir_path = std::string(abs_path_buff);
  model_parser.SetMindIRPath(mindir_path.substr(0, mindir_path.rfind("/")));
  InitModelParser(&model_parser, this);
  FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model, weights_value_map_, name_to_node);
  if (has_parallel_info_) {
    // Extract the parallel layout table only when a caller asked for it.
    layout_map_ = model_parser.ParseLayout(origin_model);
  }
  return dstgraph_ptr;
}
2772 
LoadMindIR(const void * buffer,const size_t & size,const std::string & mindir_path,FuncGraphPtr * func_graph,std::string * user_info_string)2773 bool MindIRLoader::LoadMindIR(const void *buffer, const size_t &size, const std::string &mindir_path,
2774                               FuncGraphPtr *func_graph, std::string *user_info_string) {
2775   mind_ir::ModelProto model;
2776   auto ret = model.ParseFromArray(buffer, SizeToInt(size));
2777   if (!ret) {
2778     MS_LOG(ERROR) << "ParseFromArray failed.";
2779     return false;
2780   }
2781   if (!CheckModelConfigureInfo(model)) {
2782     MS_LOG(ERROR) << "Check configuration info for pb file failed!";
2783     return false;
2784   }
2785   MSANFModelParser model_parser;
2786   InitModelParser(&model_parser, this);
2787   model_parser.SetMindIRPath(mindir_path);
2788   *func_graph = model_parser.Parse(model);
2789   std::stringstream user_info_buffer;
2790   // user_info to string
2791   auto user_info = model.user_info();
2792   user_info_buffer << "{";
2793   for (auto it = user_info.begin(); it != user_info.end(); it++) {
2794     if (it != user_info.begin()) {
2795       user_info_buffer << ", ";
2796     }
2797     user_info_buffer << "\"" << it->first << "\": \"" << it->second + "\"";
2798   }
2799   user_info_buffer << "}";
2800   *user_info_string = user_info_buffer.str();
2801   return true;
2802 }
2803 
LoadMindIR(const void * buffer,const size_t & size,const std::string & mindir_path)2804 FuncGraphPtr MindIRLoader::LoadMindIR(const void *buffer, const size_t &size, const std::string &mindir_path) {
2805   mind_ir::ModelProto model;
2806   auto ret = model.ParseFromArray(buffer, SizeToInt(size));
2807   if (!ret) {
2808     MS_LOG(ERROR) << "ParseFromArray failed.";
2809     return nullptr;
2810   }
2811   if (!CheckModelConfigureInfo(model)) {
2812     MS_LOG(ERROR) << "Check configuration info for pb file failed!";
2813     return nullptr;
2814   }
2815   MSANFModelParser model_parser;
2816   InitModelParser(&model_parser, this);
2817   model_parser.SetMindIRPath(mindir_path);
2818   FuncGraphPtr func_graph = model_parser.Parse(model);
2819   return func_graph;
2820 }
2821 
LoadMindIR(const std::string & file_name,const std::vector<FuncGraphPtr> & graphs,mindspore::HashMap<std::string,AnfNodePtr> * name_to_node)2822 bool MindIRLoader::LoadMindIR(const std::string &file_name, const std::vector<FuncGraphPtr> &graphs,
2823                               mindspore::HashMap<std::string, AnfNodePtr> *name_to_node) {
2824   if (file_name.length() > PATH_MAX) {
2825     MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
2826     return false;
2827   }
2828   char abs_path_buff[PATH_MAX];
2829 #ifdef _WIN32
2830   _fullpath(abs_path_buff, file_name.c_str(), PATH_MAX);
2831 #else
2832   if (!realpath(file_name.c_str(), abs_path_buff)) {
2833     MS_LOG(EXCEPTION) << "Load MindIR get absolute path of " << file_name
2834                       << " failed, errno is: " << ErrnoToString(errno);
2835   }
2836 #endif
2837   mind_ir::ModelProto model_proto;
2838   // Read graph
2839   if (!ParseModelProto(&model_proto, std::string(abs_path_buff), this)) {
2840     return false;
2841   }
2842   if (!CheckModelConfigureInfo(model_proto)) {
2843     MS_LOG(ERROR) << "Check configuration info for pb file failed!";
2844     return false;
2845   }
2846   MSANFModelParser model_parser;
2847   InitModelParser(&model_parser, this);
2848   if (!model_parser.Parse(model_proto, graphs, name_to_node)) {
2849     MS_LOG(ERROR) << "Parse model failed!";
2850     return false;
2851   }
2852   return true;
2853 }
2854 
ReadProtoFile(const std::string & file)2855 std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file) {
2856   if (file.empty()) {
2857     MS_LOG(ERROR) << "file is nullptr";
2858     return nullptr;
2859   }
2860 
2861   char real_path[PATH_MAX] = {0};
2862 #if defined(_WIN32) || defined(_WIN64)
2863   if (_fullpath(real_path, file.c_str(), PATH_MAX) == nullptr) {
2864     MS_LOG(ERROR) << "Get realpath failed, mind ir file is" << file;
2865     return nullptr;
2866   }
2867 #else
2868   if (realpath(file.c_str(), real_path) == nullptr) {
2869     MS_LOG(ERROR) << "Get realpath failed, mind ir file is" << file;
2870     return nullptr;
2871   }
2872 #endif
2873 
2874   std::ifstream ifs(real_path);
2875   if (!ifs.good()) {
2876     MS_LOG(ERROR) << "file: " << real_path << " is not exist";
2877     return nullptr;
2878   }
2879 
2880   if (!ifs.is_open()) {
2881     MS_LOG(ERROR) << "file: " << real_path << "open failed";
2882     return nullptr;
2883   }
2884 
2885   ifs.seekg(0, std::ios::end);
2886   size_t size = ifs.tellg();
2887   std::shared_ptr<std::vector<char>> buf(new (std::nothrow) std::vector<char>(size));
2888   if (buf == nullptr) {
2889     MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
2890     ifs.close();
2891     return nullptr;
2892   }
2893 
2894   ifs.seekg(0, std::ios::beg);
2895   ifs.read(buf->data(), size);
2896   ifs.close();
2897 
2898   return buf;
2899 }
2900 
ConvertStreamToFuncGraph(const char * buf,const size_t buf_size,bool is_lite)2901 FuncGraphPtr ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite) {
2902   MS_EXCEPTION_IF_NULL(buf);
2903   std::string str(buf, buf_size);
2904   mind_ir::ModelProto model_;
2905   if (!model_.ParseFromString(str)) {
2906     MS_LOG(ERROR) << "Parse model from buffer fail!";
2907   }
2908   MSANFModelParser model_parser;
2909   if (is_lite) {
2910     model_parser.SetLite();
2911   }
2912   FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
2913   return dstgraph_ptr;
2914 }
2915 }  // namespace mindspore
2916