1 /**
2 * Copyright 2020-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "load_mindir/load_model.h"
17 #include <sys/stat.h>
18 #include <sys/types.h>
19 #include <cstring>
20 #include <string>
21 #include <memory>
22 #include <algorithm>
23 #include <fstream>
24 #include <iostream>
25 #include <stack>
26 #include <list>
27 #include <utility>
28 #include <nlohmann/json.hpp>
29 #include "mindspore/core/ops/structure_ops.h"
30 #include "mindspore/core/ops/sequence_ops.h"
31 #include "mindspore/core/ops/framework_ops.h"
32 #include "utils/crypto.h"
33 #include "utils/os.h"
34 #include "ir/value.h"
35 #include "ir/tensor.h"
36 #include "ir/param_info.h"
37 #include "ir/map_tensor.h"
38 #include "ir/functor.h"
39 #include "ops/primitive_c.h"
40 #include "abstract/abstract_value.h"
41 #include "abstract/ops/primitive_infer_map.h"
42 #include "utils/hash_map.h"
43 #include "utils/log_adapter.h"
44 #include "utils/check_convert_utils.h"
45 #include "utils/ms_utils_secure.h"
46 #include "abstract/abstract_function.h"
47 #include "load_mindir/infer_mindir.h"
48 #include "include/common/debug/common.h"
49 #include "proto/mind_ir.pb.h"
50 #include "google/protobuf/io/zero_copy_stream_impl.h"
51
52 using std::string;
53 using std::vector;
54
55 namespace mindspore {
56 namespace {
57 static constexpr char kConstantValueNode[] = "Constant";
58 static constexpr char kQuantParam[] = "quant_param";
59 static constexpr char kGraphInputQuantParam[] = "graph_input_quant_param";
60
61 enum ParseForm : int {
62 FORM_PARSE_TYPE = 0,
63 FORM_PARSE_SCALAR = 1,
64 FORM_PARSE_TENSOR = 2,
65 FORM_PARSE_NONE = 3,
66 FORM_PARSE_MONAD = 4,
67 FORM_PARSE_SEQUENCE = 5,
68 FORM_PARSE_UNDEFINE = 6,
69 };
70
71 static std::map<std::string, ParseForm> kParseTypeSwitchMap{
72 {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR},
73 {"none", FORM_PARSE_NONE}, {"Monad", FORM_PARSE_MONAD}, {"Sequence", FORM_PARSE_SEQUENCE}};
74
75 static mindspore::HashMap<int, TypeId> kDefaultValueSwitchMap{
76 {mind_ir::TensorProto_DataType_BOOL, kNumberTypeBool},
77 {mind_ir::TensorProto_DataType_INT8, kNumberTypeInt8},
78 {mind_ir::TensorProto_DataType_INT16, kNumberTypeInt16},
79 {mind_ir::TensorProto_DataType_INT32, kNumberTypeInt32},
80 {mind_ir::TensorProto_DataType_INT64, kNumberTypeInt64},
81 {mind_ir::TensorProto_DataType_UINT8, kNumberTypeUInt8},
82 {mind_ir::TensorProto_DataType_UINT16, kNumberTypeUInt16},
83 {mind_ir::TensorProto_DataType_UINT32, kNumberTypeUInt32},
84 {mind_ir::TensorProto_DataType_UINT64, kNumberTypeUInt64},
85 {mind_ir::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
86 {mind_ir::TensorProto_DataType_BFLOAT16, kNumberTypeBFloat16},
87 {mind_ir::TensorProto_DataType_QINT4X2, kNumberTypeInt4},
88 {mind_ir::TensorProto_DataType_FLOAT, kNumberTypeFloat32},
89 {mind_ir::TensorProto_DataType_FLOAT64, kNumberTypeFloat64},
90 {mind_ir::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
91 {mind_ir::TensorProto_DataType_STRING, kObjectTypeString},
92 {mind_ir::TensorProto_DataType_COMPLEX64, kNumberTypeComplex64},
93 {mind_ir::TensorProto_DataType_COMPLEX128, kNumberTypeComplex128}};
94
95 template <typename T, typename P>
ParserAttr(const std::string & str,const mindspore::HashMap<string,P> & kv)96 std::shared_ptr<T> ParserAttr(const std::string &str, const mindspore::HashMap<string, P> &kv) {
97 std::stack<std::string> rules;
98 std::stack<P> value;
99 size_t count = 0;
100 for (size_t i = 0; i < str.length(); i++) {
101 if (str[i] == '[') {
102 rules.push(std::string("["));
103 } else if (str[i] == ']') {
104 // rules
105 std::vector<P> vec;
106 while (!rules.empty() && rules.top() != "[") {
107 rules.pop();
108 vec.push_back(value.top());
109 value.pop();
110 }
111 if (!rules.empty()) {
112 // pop "["
113 rules.pop();
114 }
115 // make tuple for names
116 std::string res = "dummy";
117 // make tuple for values
118 reverse(vec.begin(), vec.end());
119 auto vt = std::make_shared<T>(vec);
120 if (rules.empty() && value.empty()) {
121 return vt;
122 }
123 rules.push(res);
124 value.push(vt);
125 } else if (str[i] == ',') {
126 continue;
127 } else {
128 count++;
129 if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
130 auto value_name = str.substr((i - count) + 1, count);
131 if (kv.find(value_name) == kv.end()) {
132 MS_LOG(ERROR) << "Node's attributes and shape do not match.";
133 return nullptr;
134 }
135 value.push(kv.at(value_name));
136 rules.push(value_name);
137 count = 0;
138 }
139 }
140 }
141 return {};
142 }
143
144 template <typename T>
ParserScalarAttrValue(const std::string & attr_name,const mindspore::HashMap<string,ValuePtr> & kv)145 std::shared_ptr<T> ParserScalarAttrValue(const std::string &attr_name, const mindspore::HashMap<string, ValuePtr> &kv) {
146 std::string str = attr_name;
147 auto replace = [&](const string &orgStr, const string &newStr) {
148 std::string::size_type pos;
149 while ((pos = str.find(orgStr)) != std::string::npos) {
150 (void)str.replace(pos, orgStr.length(), newStr);
151 }
152 return str;
153 };
154 // remove "scalar:"
155 str = replace("scalar:", "");
156 // remove "Tuple"
157 str = replace("Tuple", "");
158 // remove "List"
159 str = replace("List", "");
160 auto result = ParserAttr<T, ValuePtr>(str, kv);
161 return result;
162 }
163
ParserAttrShape(const std::string & attr_name,const mindspore::HashMap<string,abstract::AbstractBasePtr> & kv)164 std::shared_ptr<abstract::AbstractTuple> ParserAttrShape(
165 const std::string &attr_name, const mindspore::HashMap<string, abstract::AbstractBasePtr> &kv) {
166 std::string str = attr_name;
167 auto replace = [&](const string &orgStr, const string &newStr) {
168 std::string::size_type pos;
169 while ((pos = str.find(orgStr)) != std::string::npos) {
170 (void)str.replace(pos, orgStr.length(), newStr);
171 }
172 return str;
173 };
174 // remove "scalar:"
175 str = replace("shape:", "");
176 // remove "Tuple"
177 str = replace("Tuple", "");
178 // remove "List"
179 str = replace("List", "");
180
181 auto result = ParserAttr<abstract::AbstractTuple, abstract::AbstractBasePtr>(str, kv);
182 return result;
183 }
184
// Returns the part of `name` after its first ':' (e.g. "a:b:c" -> "b:c"), or `name` unchanged if it has no ':'.
std::string ParseParameterName(const std::string &name) {
  const auto pos = name.find(':');
  if (pos == std::string::npos) {
    return name;
  }
  return name.substr(pos + 1);
}
193
// Returns the text strictly between the first and the last ':' of `name` (e.g. "a:b:c" -> "b").
// If `name` contains fewer than two ':' characters, it is returned unchanged.
std::string ParseCNodeName(const std::string &name) {
  const std::string::size_type first = name.find(':');
  const std::string::size_type last = name.rfind(':');
  const bool has_two_delimiters = first != std::string::npos && last != std::string::npos && first != last;
  return has_two_delimiters ? name.substr(first + 1, last - first - 1) : name;
}
203
// Defines two parsers for an integer-backed attribute type. Both convert the proto's integer fields with
// static_cast<valuetype> and wrap the result with MakeValue<valuetype>:
//   ParseAttrInScalar_<type>_<valuetype>(attr_proto, index) reads attr_proto.ints(index);
//   ParseAttrInSingleScalar_<type>_<valuetype>(attr_proto)   reads attr_proto.i().
// `type` only contributes to the generated function names. When the requested value is absent (index not below
// ints_size(), or has_i() false), they log with MS_LOG(INTERNAL_EXCEPTION). No return statement follows, so they
// rely on that log level not returning.
// NOTE(review): a negative `index` is not rejected by the `ints_size() > index` check.
#define PARSE_MINDIR_ATTR_IN_INT_FORM(type, valuetype)                                                    \
  ValuePtr ParseAttrInScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto, int index) { \
    if (attr_proto.ints_size() > index) {                                                                 \
      auto value = static_cast<valuetype>(attr_proto.ints(index));                                        \
      return MakeValue<valuetype>(value);                                                                 \
    }                                                                                                     \
    MS_LOG(INTERNAL_EXCEPTION) << "Parse MindIR attr failed.";                                            \
  }                                                                                                       \
  ValuePtr ParseAttrInSingleScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto) {      \
    if (attr_proto.has_i()) {                                                                             \
      auto value = static_cast<valuetype>(attr_proto.i());                                                \
      return MakeValue<valuetype>(value);                                                                 \
    }                                                                                                     \
    MS_LOG(INTERNAL_EXCEPTION) << "Parse MindIR attr failed.";                                            \
  }

// Defines ParseAttrInScalar_<type>_<valuetype>(attr_proto, index) for attributes stored in the repeated proto
// field `<type>s` (doubles, floats and strings in the instantiations below). It returns
// MakeValue<valuetype>(static_cast<valuetype>(attr_proto.<type>s(index))), or logs with
// MS_LOG(INTERNAL_EXCEPTION) when `index` is out of range; no return follows that log.
// The matching ParseAttrInSingleScalar_* functions for these types are written by hand.
#define PARSE_MINDIR_ATTR_IN_SCALAR_FORM(type, valuetype)                                                 \
  ValuePtr ParseAttrInScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto, int index) { \
    if (attr_proto.type##s_size() > index) {                                                              \
      auto value = static_cast<valuetype>(attr_proto.type##s(index));                                     \
      return MakeValue<valuetype>(value);                                                                 \
    }                                                                                                     \
    MS_LOG(INTERNAL_EXCEPTION) << "Parse MindIR attr failed.";                                            \
  }

// Integer-backed parsers. bool is also read from the integer fields, yielding
// ParseAttrInScalar_int32_t_bool / ParseAttrInSingleScalar_int32_t_bool.
PARSE_MINDIR_ATTR_IN_INT_FORM(int8_t, int8_t)

PARSE_MINDIR_ATTR_IN_INT_FORM(int16_t, int16_t)

PARSE_MINDIR_ATTR_IN_INT_FORM(int32_t, int32_t)

PARSE_MINDIR_ATTR_IN_INT_FORM(int64_t, int64_t)

PARSE_MINDIR_ATTR_IN_INT_FORM(uint8_t, uint8_t)

PARSE_MINDIR_ATTR_IN_INT_FORM(uint16_t, uint16_t)

PARSE_MINDIR_ATTR_IN_INT_FORM(uint32_t, uint32_t)

PARSE_MINDIR_ATTR_IN_INT_FORM(uint64_t, uint64_t)

PARSE_MINDIR_ATTR_IN_INT_FORM(int32_t, bool)

// Repeated-field parsers for doubles / floats / strings.
PARSE_MINDIR_ATTR_IN_SCALAR_FORM(double, double)

PARSE_MINDIR_ATTR_IN_SCALAR_FORM(float, float)

PARSE_MINDIR_ATTR_IN_SCALAR_FORM(string, string)
252
253 ValuePtr ParseAttrInSingleScalar_string_string(const mind_ir::AttributeProto &attr_proto) {
254 auto value = static_cast<string>(attr_proto.s());
255 return MakeValue<string>(value);
256 }
257
ParseAttrInSingleScalar_float_float(const mind_ir::AttributeProto & attr_proto)258 ValuePtr ParseAttrInSingleScalar_float_float(const mind_ir::AttributeProto &attr_proto) {
259 auto value = static_cast<float>(attr_proto.f());
260 return MakeValue<float>(value);
261 }
262
ParseAttrInSingleScalar_double_double(const mind_ir::AttributeProto & attr_proto)263 ValuePtr ParseAttrInSingleScalar_double_double(const mind_ir::AttributeProto &attr_proto) {
264 auto value = static_cast<double>(attr_proto.d());
265 return MakeValue<double>(value);
266 }
267
GetParseFormType(const std::string & ref_attr_name)268 ParseForm GetParseFormType(const std::string &ref_attr_name) {
269 for (const auto &iter : kParseTypeSwitchMap) {
270 if (ref_attr_name.find(iter.first) == 0) {
271 return iter.second;
272 }
273 }
274 return FORM_PARSE_UNDEFINE;
275 }
276
277 template <typename T>
NewValueNodeWithAbstract(const T & value)278 AnfNodePtr NewValueNodeWithAbstract(const T &value) {
279 auto node = NewValueNode(value);
280 node->set_abstract(value->ToAbstract());
281 return node;
282 }
283
FindGraphByName(const std::vector<FuncGraphPtr> & graphs,const std::string & name)284 FuncGraphPtr FindGraphByName(const std::vector<FuncGraphPtr> &graphs, const std::string &name) {
285 auto iter = std::find_if(graphs.begin(), graphs.end(), [&name](const auto &g) { return g->ToString() == name; });
286 if (iter != graphs.end()) {
287 return *iter;
288 }
289 return nullptr;
290 }
291
CheckModelConfigureInfo(const mind_ir::ModelProto & model_proto)292 bool CheckModelConfigureInfo(const mind_ir::ModelProto &model_proto) {
293 if (!model_proto.has_producer_name()) {
294 MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
295 return false;
296 }
297 const auto &producer_name = model_proto.producer_name();
298 MS_LOG(INFO) << "Producer name: " << producer_name;
299
300 if (!model_proto.has_model_version()) {
301 MS_LOG(ERROR) << "Parse model producer version from pb file failed!";
302 return false;
303 }
304 const auto &model_version = model_proto.model_version();
305 MS_LOG(INFO) << "Producer version: " << model_version;
306
307 int64_t mind_ir_version = 0;
308 if (model_proto.has_mind_ir_version()) {
309 mind_ir_version = model_proto.mind_ir_version();
310 }
311 if (!mind_ir::Version_IsValid(mind_ir_version)) {
312 MS_LOG(EXCEPTION) << "This software can only support the maximum mind ir version: " << mind_ir::Version_MAX
313 << ", please install the latest version to support the mind ir version: " << mind_ir_version;
314 }
315 if (model_proto.has_little_endian()) {
316 if (model_proto.little_endian() != common::IsLittleByteOrder()) {
317 MS_LOG(ERROR) << "The byte order of export MindIr device and load MindIr device is not same!";
318 return false;
319 }
320 }
321 return true;
322 }
323 } // namespace
324
325 namespace {
326 class MSANFModelParser {
327 public:
328 MSANFModelParser() = default;
329 ~MSANFModelParser() = default;
330
331 static void LoadTensorMapClear();
332 FuncGraphPtr Parse(const mind_ir::ModelProto &model_proto, const std::map<std::string, ValuePtr> &weights = {},
333 mindspore::HashMap<std::string, AnfNodePtr> *name_to_node = nullptr);
334 bool Parse(const mind_ir::ModelProto &model_proto, const std::vector<FuncGraphPtr> &graphs,
335 mindspore::HashMap<std::string, AnfNodePtr> *name_to_node = nullptr);
336 const LayoutMap ParseLayout(const mind_ir::ModelProto &model_proto);
337
SetLite()338 void SetLite() { is_lite_ = true; }
IsLite() const339 bool IsLite() const { return is_lite_; }
SetMindIRPath(const std::string & file_path)340 void SetMindIRPath(const std::string &file_path) { mindir_path_ = file_path; }
SetMindIRDecKey(const unsigned char * dec_key)341 void SetMindIRDecKey(const unsigned char *dec_key) { mindir_dec_key_ = dec_key; }
SetMindIRKeySize(size_t size)342 void SetMindIRKeySize(size_t size) { mindir_key_size_ = size; }
SetMindIRDecMode(const std::string & dec_mode)343 void SetMindIRDecMode(const std::string &dec_mode) { mindir_dec_mode_ = dec_mode; }
344
345 private:
346 void TrytoBuildCNodeAbstract();
347 bool BuildPrimitiveNode(const mind_ir::PrimitiveProto &primitive_proto);
348 abstract::AbstractBasePtr BuildAbstractFunction(const mind_ir::AttributeProto &attr_proto);
349 void CorrectFuncGraph(const FuncGraphPtr &root);
350 bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
351 bool BuildAttrForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
352 bool BuildAttrForCNode(const CNodePtr &cnode, const mind_ir::NodeProto &node_proto);
353 ValuePtr GetValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto);
354 bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
355 bool ImportMapParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
356 bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
357 bool BuildParameterForFuncGraph(const ParameterPtr &node, const mind_ir::TensorProto ¶meter_proto);
358 bool BuildMapParameterFromMapTensorProto(const ParameterPtr &node,
359 const mind_ir::MapTensorProto &map_parameter_proto);
360 abstract::AbstractMapTensorPtr BuildAbstractMapTensorFromAttrProto(const mind_ir::AttributeProto &attr_proto);
361 abstract::AbstractCOOTensorPtr BuildAbstractCOOTensorFromAttrProto(const mind_ir::AttributeProto &attr_proto);
362 abstract::AbstractCSRTensorPtr BuildAbstractCSRTensorFromAttrProto(const mind_ir::AttributeProto &attr_proto);
363 abstract::AbstractSequencePtr BuildAbstractSequence(const mind_ir::AttributeProto &attr_proto);
364 abstract::AbstractScalarPtr BuildAbstractScalar(const mind_ir::AttributeProto &attr_proto) const;
365 bool SetValueForTopGraphParameter(const FuncGraphPtr &topGraph, const std::map<std::string, ValuePtr> &weights);
366 bool GetTensorDataFromExternal(const mind_ir::TensorProto &tensor_proto, const tensor::TensorPtr &tensor_info);
367 bool BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto);
368 abstract::AbstractTensorPtr GetAbsTensorFromTensorProto(const mind_ir::TensorProto &tensor_proto);
369 CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::NodeProto &node_proto);
370 bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
371 bool GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
372 bool SetPrimitiveAttrWithType(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
373 bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
374 void ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto &attr_proto,
375 mindspore::HashMap<std::string, ValuePtr> *multi_value_map);
376 ValuePtr ParseAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, int index);
377 ValuePtr ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto);
378 bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
379 bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto);
380 ValuePtr BuildValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto);
381 AnfNodePtr BuildOperatorNode(const mind_ir::NodeProto &node_proto);
382 bool SetEmptyTensorProtoCNodeAbstract(const AnfNodePtr &node_ptr);
383 void SetCNodeAbstract(const mind_ir::AttributeProto &attr_proto, const CNodePtr &cnode_ptr);
384 bool SetNodeAbstractFromAttrProto(const mind_ir::AttributeProto &attr_proto, const AnfNodePtr &node_ptr);
385 abstract::AbstractBasePtr GetNodeAbstractFromAttrProtoWithType(const mind_ir::AttributeProto &attr_proto);
386 void SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto &node_proto, const CNodePtr &cnode_ptr);
387 bool ObtainValueNodeInTensorForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
388 bool ObtainValueNodeInTupleTensorForm(const string &value_node_name, const mind_ir::AttributeProto &attr_proto);
389 bool GetAttrValueForValueNode(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
390 bool GetAttrValueForValueNodeWithType(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
391 bool ObtainValueNodeInTypeForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
392 bool ObtainValueNodeInNoneForm(const std::string &value_node_name);
393 bool ObtainValueNodeInMonadForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
394 ValuePtr ObtainValueInSequenceForm(const mind_ir::AttributeProto &attr_proto);
395 ValuePtr ObtainValueInDictionaryForm(const mind_ir::AttributeProto &attr_proto);
396 std::vector<std::shared_ptr<mindspore::QuantizationParam>> GenerateQuantizationParam(
397 const mind_ir::TensorProto &attr_tensor);
398 FunctorPtr GenerateFunctorValue(const mind_ir::FunctorProto &functor_proto);
little_endian() const399 bool little_endian() const { return little_endian_; }
400 mindspore::HashMap<std::string, abstract::AbstractBasePtr> GetAbstractForNode(
401 const mind_ir::AttributeProto &attr_proto);
402 AnfNodePtr GetAnfNode(const std::string &node_name);
403 tensor::TensorPtr GenerateTensorPtrFromTensorProto(const mind_ir::TensorProto &attr_tensor);
404
405 static tensor::TensorPtr GetIncTensor(const std::string &tensor_name);
406 static void SetIncTensor(const std::string &tensor_name, const tensor::TensorPtr &tensor);
407
408 FuncGraphPtr top_graph_ = nullptr;
409 bool is_lite_ = false;
410 bool abstract_valid_ = false;
411 mindspore::HashMap<std::string, AnfNodePtr> anfnode_build_map_;
412 std::string mindir_path_;
413 const unsigned char *mindir_dec_key_{nullptr};
414 size_t mindir_key_size_{0};
415 std::string mindir_dec_mode_;
416 bool little_endian_ = common::IsLittleByteOrder();
417 std::map<std::string, std::unique_ptr<Byte[]>> tenor_data_;
418 bool is_kernel_graph_{false};
419 std::list<std::pair<const CNodePtr, const mind_ir::AttributeProto *>> node_abstract_protos_;
420 };
421
GetValueFromAttributeProto(const mind_ir::AttributeProto & attr_proto)422 ValuePtr MSANFModelParser::GetValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto) {
423 auto attr_name = attr_proto.name();
424 switch (attr_proto.type()) {
425 case mind_ir::AttributeProto_AttributeType_TENSORS: {
426 mind_ir::TensorProto tensor_proto = attr_proto.tensors(0);
427 if (tensor_proto.has_raw_data()) {
428 // For real tensor.
429 tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto);
430 if (tensor_info == nullptr) {
431 MS_LOG(ERROR) << "Failed to get the tensor for ValueNode.";
432 return nullptr;
433 }
434 return tensor_info;
435 } else if (tensor_proto.name() == kGraphInputQuantParam) {
436 auto quantization_param_vector = GenerateQuantizationParam(tensor_proto);
437 if (!quantization_param_vector.empty()) {
438 return quantization_param_vector[0];
439 }
440 } else {
441 // For data type.
442 const int attr_tensor_type = tensor_proto.data_type();
443 auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
444 if (iter == kDefaultValueSwitchMap.end()) {
445 MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
446 return nullptr;
447 }
448 return TypeIdToType(iter->second);
449 }
450 MS_LOG(ERROR) << "Failed to get the tensor for value.";
451 return nullptr;
452 }
453 case mind_ir::AttributeProto_AttributeType_NONE: {
454 return kNone;
455 }
456 case mind_ir::AttributeProto_AttributeType_TUPLE:
457 case mind_ir::AttributeProto_AttributeType_LIST: {
458 auto sequence_value = ObtainValueInSequenceForm(attr_proto);
459 if (sequence_value == nullptr) {
460 MS_LOG(ERROR) << "Failed to get sequence value for " << attr_name;
461 return nullptr;
462 }
463 return sequence_value;
464 }
465 case mind_ir::AttributeProto_AttributeType_DICT: {
466 auto dict_value = ObtainValueInDictionaryForm(attr_proto);
467 if (dict_value == nullptr) {
468 MS_LOG(ERROR) << "Failed to get dictionary value for " << attr_name;
469 return nullptr;
470 }
471 return dict_value;
472 }
473 case mind_ir::AttributeProto_AttributeType_FUNCTOR: {
474 auto functor_value = GenerateFunctorValue(attr_proto.functor());
475 if (functor_value == nullptr) {
476 MS_LOG(ERROR) << "Failed to get functor value for " << attr_name;
477 return nullptr;
478 }
479 return functor_value;
480 }
481 default: {
482 ValuePtr value = ObtainCNodeAttrInSingleScalarForm(attr_proto);
483 if (value == nullptr) {
484 MS_LOG(ERROR) << "Can not get the value for attr: " << attr_name;
485 return nullptr;
486 }
487 return value;
488 }
489 }
490 }
491
GenerateFunctorValue(const mind_ir::FunctorProto & functor_proto)492 FunctorPtr MSANFModelParser::GenerateFunctorValue(const mind_ir::FunctorProto &functor_proto) {
493 auto name = functor_proto.name();
494 auto type = functor_proto.type();
495 auto values = GetValueFromAttributeProto(functor_proto.values(0));
496 if (type == mind_ir::FunctorProto_FunctorType_SHAPE_CALC_FUNCTOR) {
497 auto creator = FunctorRegistry::Instance().GetCreator(name);
498 if (creator == nullptr) {
499 MS_LOG(ERROR) << "Cannot find the functor creator: " << name;
500 return nullptr;
501 }
502 auto functor = creator();
503 functor->FromValue(values);
504 return functor;
505 }
506 MS_LOG(ERROR) << "Unknown functor type: " << type;
507 return nullptr;
508 }
509
GenerateTensorPtrFromTensorProto(const mind_ir::TensorProto & attr_tensor)510 tensor::TensorPtr MSANFModelParser::GenerateTensorPtrFromTensorProto(const mind_ir::TensorProto &attr_tensor) {
511 ShapeVector shape;
512 const int attr_tensor_type = attr_tensor.data_type();
513 for (int i = 0; i < attr_tensor.dims_size(); ++i) {
514 shape.push_back(attr_tensor.dims(i));
515 }
516 tensor::TensorPtr tensor = nullptr;
517 if (!attr_tensor.has_compression_type() ||
518 attr_tensor.compression_type() == mind_ir::TensorProto_CompressionType_NO_COMPRESSION) {
519 tensor = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
520 } else {
521 auto compression_type = static_cast<TensorCompressionType>(static_cast<int>(attr_tensor.compression_type()));
522 size_t data_size = 0;
523 if (!attr_tensor.has_external_data()) {
524 data_size = attr_tensor.raw_data().size();
525 } else {
526 data_size = LongToSize(attr_tensor.external_data().length());
527 }
528 tensor =
529 std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape, data_size, compression_type);
530 }
531
532 auto quantization_param_vector = GenerateQuantizationParam(attr_tensor);
533 if (!quantization_param_vector.empty()) {
534 tensor->set_quant_param(quantization_param_vector);
535 }
536
537 MS_EXCEPTION_IF_NULL(tensor);
538 const std::string &tensor_buf = attr_tensor.raw_data();
539 if (attr_tensor.has_raw_data() && tensor->data().nbytes() != 0) {
540 auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor->data_c());
541 errno_t ret = memcpy_s(tensor_data_buf, tensor->data().nbytes(), tensor_buf.data(), tensor_buf.size());
542 if (ret != EOK) {
543 MS_LOG(ERROR) << "Failed to copy data from tensor proto.";
544 return nullptr;
545 }
546 } else if (attr_tensor.has_external_data()) {
547 auto ret = GetTensorDataFromExternal(attr_tensor, tensor);
548 if (!ret) {
549 MS_LOG(ERROR) << "Failed to get external data from tensor proto.";
550 return nullptr;
551 }
552 } else {
553 MS_LOG(DEBUG) << "Parameter will load initialized data.";
554 }
555 return tensor;
556 }
557
GenerateQuantizationParam(const mind_ir::TensorProto & attr_tensor)558 std::vector<std::shared_ptr<mindspore::QuantizationParam>> MSANFModelParser::GenerateQuantizationParam(
559 const mind_ir::TensorProto &attr_tensor) {
560 auto quant_param_proto = attr_tensor.quant_params();
561 std::vector<std::shared_ptr<mindspore::QuantizationParam>> quantization_param_vector;
562 for (int i = 0; i < quant_param_proto.size(); i++) {
563 auto quant_data = quant_param_proto.Get(i);
564 QuantizationParam quantization_param(quant_data.quant_algo_name());
565 for (int index = 0; index < quant_data.attribute_size(); index++) {
566 auto quant_attr_proto = quant_data.attribute().Get(index);
567 if (quant_attr_proto.type() != mind_ir::AttributeProto_AttributeType_LIST) {
568 MS_LOG(ERROR) << "quant_attr_proto.type is " << quant_attr_proto.type()
569 << ", is should be mind_ir::AttributeProto_AttributeType_LIST ("
570 << mind_ir::AttributeProto_AttributeType_LIST << ")";
571 return {};
572 }
573 auto sequence_value = ObtainValueInSequenceForm(quant_attr_proto);
574 quantization_param.SetAttr(quant_attr_proto.name(), sequence_value);
575 }
576 quantization_param_vector.push_back(std::make_shared<mindspore::QuantizationParam>(quantization_param));
577 }
578 return quantization_param_vector;
579 }
580
GetNodeAbstractFromAttrProtoWithType(const mind_ir::AttributeProto & attr_proto)581 abstract::AbstractBasePtr MSANFModelParser::GetNodeAbstractFromAttrProtoWithType(
582 const mind_ir::AttributeProto &attr_proto) {
583 switch (attr_proto.type()) {
584 case mind_ir::AttributeProto_AttributeType_TENSORS: {
585 const mind_ir::TensorProto &attr_tensor = attr_proto.tensors(0);
586 return GetAbsTensorFromTensorProto(attr_tensor);
587 }
588 case mind_ir::AttributeProto_AttributeType_CSR_TENSOR: {
589 return BuildAbstractCSRTensorFromAttrProto(attr_proto);
590 }
591 case mind_ir::AttributeProto_AttributeType_COO_TENSOR: {
592 return BuildAbstractCOOTensorFromAttrProto(attr_proto);
593 }
594 case mind_ir::AttributeProto_AttributeType_MAP_TENSOR: {
595 return BuildAbstractMapTensorFromAttrProto(attr_proto);
596 }
597 case mind_ir::AttributeProto_AttributeType_LIST:
598 case mind_ir::AttributeProto_AttributeType_TUPLE: {
599 return BuildAbstractSequence(attr_proto);
600 }
601 case mind_ir::AttributeProto_AttributeType_UMONAD: {
602 return kUMonad->ToAbstract();
603 }
604 case mind_ir::AttributeProto_AttributeType_IOMONAD: {
605 return kIOMonad->ToAbstract();
606 }
607 // in old version the bool is load and export in an error type.
608 // but MindIR should be Compatible with older versions.
609 case mind_ir::AttributeProto_AttributeType_BOOL: {
610 return kBool->ToAbstract();
611 }
612 case mind_ir::AttributeProto_AttributeType_SCALAR: {
613 return BuildAbstractScalar(attr_proto);
614 }
615 case mind_ir::AttributeProto_AttributeType_NONE: {
616 return kNone->ToAbstract();
617 }
618 case mind_ir::AttributeProto_AttributeType_FUNCGRAPHCLOSURE:
619 case mind_ir::AttributeProto_AttributeType_PRIMITIVECLOSURE:
620 case mind_ir::AttributeProto_AttributeType_PARTIALCLOSURE:
621 case mind_ir::AttributeProto_AttributeType_UNIONFUNCCLOSURE: {
622 return BuildAbstractFunction(attr_proto);
623 }
624 default: {
625 MS_LOG(INFO) << "Not support to get the abstract from AttrProto type: " << attr_proto.type();
626 return nullptr;
627 }
628 }
629 }
630
SetNodeAbstractFromAttrProto(const mind_ir::AttributeProto & attr_proto,const AnfNodePtr & node_ptr)631 bool MSANFModelParser::SetNodeAbstractFromAttrProto(const mind_ir::AttributeProto &attr_proto,
632 const AnfNodePtr &node_ptr) {
633 mindspore::HashMap<std::string, abstract::AbstractBasePtr> kv;
634 string shape_ref_attr_name;
635 if (attr_proto.ref_attr_name().find("shape:") == string::npos) {
636 MS_LOG(ERROR) << "Cannot use a attr_proto " << attr_proto.ref_attr_name() << " to init shape.";
637 return false;
638 }
639
640 shape_ref_attr_name = attr_proto.ref_attr_name();
641 bool is_tuple_or_list =
642 shape_ref_attr_name.find("Tuple[") != string::npos || shape_ref_attr_name.find("List[") != string::npos;
643 kv = GetAbstractForNode(attr_proto);
644 if (kv.empty()) {
645 return SetEmptyTensorProtoCNodeAbstract(node_ptr);
646 } else if (!is_tuple_or_list) {
647 auto iter = kv.begin();
648 if (iter->second != nullptr) {
649 node_ptr->set_abstract(iter->second);
650 }
651 } else {
652 auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
653 node_ptr->set_abstract(abstract);
654 if (abstract == nullptr) {
655 MS_LOG(ERROR) << "Node's attribute is nullptr.";
656 return false;
657 }
658 }
659 return true;
660 }
661
SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto & node_proto,const CNodePtr & cnode_ptr)662 void MSANFModelParser::SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto &node_proto, const CNodePtr &cnode_ptr) {
663 auto prim_to_add_attr = GetCNodePrimitiveWithoutDoSignature(cnode_ptr);
664 if (prim_to_add_attr != nullptr) {
665 prim_to_add_attr->set_attr("is_load", MakeValue(true));
666 }
667 for (int i = 0; i < node_proto.attribute_size(); ++i) {
668 const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
669 // Compatible with older versions.
670 if (attr_proto.has_ref_attr_name()) {
671 if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
672 SetCNodeAbstract(attr_proto, cnode_ptr);
673 continue;
674 }
675 if (prim_to_add_attr != nullptr && !GetAttrValueForCNode(prim_to_add_attr, attr_proto)) {
676 MS_LOG(ERROR) << "Parse prim: " << prim_to_add_attr->ToString()
677 << ", attributes error: " << attr_proto.DebugString();
678 }
679 } else {
680 // ref_attr_name is removed in newer versions.
681 if (attr_proto.name() == "shape") {
682 SetCNodeAbstract(attr_proto, cnode_ptr);
683 continue;
684 }
685 if (prim_to_add_attr != nullptr && !SetPrimitiveAttrWithType(prim_to_add_attr, attr_proto)) {
686 MS_LOG(ERROR) << "Parse prim: " << prim_to_add_attr->ToString()
687 << ", attributes error: " << attr_proto.DebugString();
688 }
689 }
690 }
691 }
692
GetAbsTensorFromTensorProto(const mind_ir::TensorProto & tensor_proto)693 abstract::AbstractTensorPtr MSANFModelParser::GetAbsTensorFromTensorProto(const mind_ir::TensorProto &tensor_proto) {
694 ShapeVector shape;
695 for (int i = 0; i < tensor_proto.dims_size(); ++i) {
696 (void)shape.emplace_back(tensor_proto.dims(i));
697 }
698
699 if (!tensor_proto.has_data_type()) {
700 MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed";
701 MS_LOG(ERROR) << "mind_ir TensorProto has no data_type.";
702 return nullptr;
703 }
704 auto iter = kDefaultValueSwitchMap.find(tensor_proto.data_type());
705 if (iter == kDefaultValueSwitchMap.end()) {
706 MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed";
707 MS_LOG(ERROR) << "mind_ir TensorProto data_type: " << tensor_proto.data_type() << " is not support yet!";
708 return nullptr;
709 }
710 auto tensor_shape = std::make_shared<abstract::Shape>(shape);
711 auto tensor_info = std::make_shared<abstract::AbstractTensor>(TypeIdToType(iter->second), tensor_shape);
712 if (tensor_proto.has_ref_key()) {
713 auto ref_key = std::make_shared<RefKey>(tensor_proto.ref_key());
714 auto abs_ref = std::make_shared<abstract::AbstractRefTensor>(tensor_info, ref_key);
715 return abs_ref;
716 }
717 if (tensor_proto.has_name()) {
718 tensor_info->set_name(tensor_proto.name());
719 }
720 return tensor_info;
721 }
722
BuildParameterForFuncGraph(const ParameterPtr & node,const mind_ir::TensorProto & parameter_proto)723 bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
724 const mind_ir::TensorProto ¶meter_proto) {
725 MS_EXCEPTION_IF_NULL(node);
726
727 if (!parameter_proto.has_name()) {
728 MS_LOG(ERROR) << "mind_ir TensorProto has no name!";
729 return false;
730 }
731 const auto &unique_name = parameter_proto.name();
732 string debug_info_name = ParseParameterName(unique_name);
733 auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
734 node->set_debug_info(debug_info_ptr);
735 node->set_name(debug_info_name);
736
737 ParamInfoPtr param_info = std::make_shared<ParamInfo>();
738 param_info->set_name(debug_info_name);
739
740 MS_LOG(DEBUG) << "Load parameter name: " << unique_name;
741 auto tensor = GenerateTensorPtrFromTensorProto(parameter_proto);
742 if (tensor == nullptr) {
743 MS_LOG(ERROR) << "Build tensor failed from the parameter proto.";
744 return false;
745 }
746 tensor->set_param_info(param_info);
747 node->set_default_param(tensor);
748 node->set_abstract(tensor->ToAbstract());
749
750 anfnode_build_map_[parameter_proto.name()] = node;
751 return true;
752 }
753
BuildAbstractCOOTensorFromAttrProto(const mind_ir::AttributeProto & attr_proto)754 abstract::AbstractCOOTensorPtr MSANFModelParser::BuildAbstractCOOTensorFromAttrProto(
755 const mind_ir::AttributeProto &attr_proto) {
756 std::vector<abstract::AbstractBasePtr> vec;
757 for (int i = 0; i < attr_proto.values_size(); ++i) {
758 auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto.values(i));
759 if (abs == nullptr) {
760 MS_LOG(WARNING) << "Failed to get the COOTensor's abstract from AttrProto. " << attr_proto.DebugString();
761 return nullptr;
762 }
763 (void)vec.emplace_back(abs);
764 }
765 return std::make_shared<abstract::AbstractCOOTensor>(vec);
766 }
767
BuildAbstractCSRTensorFromAttrProto(const mind_ir::AttributeProto & attr_proto)768 abstract::AbstractCSRTensorPtr MSANFModelParser::BuildAbstractCSRTensorFromAttrProto(
769 const mind_ir::AttributeProto &attr_proto) {
770 std::vector<abstract::AbstractBasePtr> vec;
771 for (int i = 0; i < attr_proto.values_size(); ++i) {
772 auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto.values(i));
773 if (abs == nullptr) {
774 MS_LOG(WARNING) << "Failed to get the CSRTensor's abstract from AttrProto. " << attr_proto.DebugString();
775 return nullptr;
776 }
777 (void)vec.emplace_back(abs);
778 }
779 return std::make_shared<abstract::AbstractCSRTensor>(vec);
780 }
781
BuildAbstractMapTensorFromAttrProto(const mind_ir::AttributeProto & attr_proto)782 abstract::AbstractMapTensorPtr MSANFModelParser::BuildAbstractMapTensorFromAttrProto(
783 const mind_ir::AttributeProto &attr_proto) {
784 // default value
785 if (attr_proto.values_size() != 1) {
786 MS_LOG(INTERNAL_EXCEPTION) << "AttrProto for AbstractMapTensor should has 1 value, but got "
787 << attr_proto.values_size();
788 }
789 const auto &default_value_proto = attr_proto.values(0);
790 auto default_value = ObtainCNodeAttrInSingleScalarForm(default_value_proto);
791 MS_EXCEPTION_IF_NULL(default_value);
792
793 constexpr int kAbstractMapTensorAttrProtoTensorsSize = 2;
794 if (attr_proto.tensors_size() != kAbstractMapTensorAttrProtoTensorsSize) {
795 MS_LOG(INTERNAL_EXCEPTION) << "AttrProto for AbstractMapTensor should has 2 tensors, but got "
796 << attr_proto.tensors_size();
797 }
798 // key tensor
799 const auto &key_tensor_proto = attr_proto.tensors(0);
800 auto key_tensor_abs = GetAbsTensorFromTensorProto(key_tensor_proto);
801 MS_EXCEPTION_IF_NULL(key_tensor_abs);
802 // value tensor
803 const auto &value_tensor_proto = attr_proto.tensors(1);
804 auto value_tensor_abs = GetAbsTensorFromTensorProto(value_tensor_proto);
805 MS_EXCEPTION_IF_NULL(value_tensor_abs);
806 auto value_build_shape_ptr = value_tensor_abs->BuildShape();
807 if (!value_build_shape_ptr->isa<abstract::Shape>()) {
808 MS_LOG(INTERNAL_EXCEPTION) << "value_shape of AbstractMapTensor should be a Shape, but got "
809 << value_build_shape_ptr->ToString();
810 }
811 auto value_shape_ptr = value_build_shape_ptr->cast<abstract::ShapePtr>();
812 MS_EXCEPTION_IF_NULL(value_shape_ptr);
813 auto map_tensor = std::make_shared<tensor::MapTensor>(key_tensor_abs->BuildType()->type_id(),
814 value_tensor_abs->BuildType()->type_id(),
815 value_shape_ptr->shape(), default_value);
816 return std::make_shared<abstract::AbstractMapTensor>(map_tensor);
817 }
818
BuildAbstractScalar(const mind_ir::AttributeProto & attr_proto) const819 abstract::AbstractScalarPtr MSANFModelParser::BuildAbstractScalar(const mind_ir::AttributeProto &attr_proto) const {
820 const mind_ir::TensorProto &attr_tensor = attr_proto.tensors(0);
821 auto iter = kDefaultValueSwitchMap.find(attr_tensor.data_type());
822 if (iter == kDefaultValueSwitchMap.end()) {
823 MS_LOG(ERROR) << "mind_ir build tensor: " << attr_tensor.name() << " failed";
824 MS_LOG(ERROR) << "mind_ir TensorProto data_type: " << attr_tensor.data_type() << " is not support yet!";
825 return nullptr;
826 }
827 return std::make_shared<abstract::AbstractScalar>(TypeIdToType(iter->second));
828 }
829
BuildAbstractSequence(const mind_ir::AttributeProto & attr_proto)830 abstract::AbstractSequencePtr MSANFModelParser::BuildAbstractSequence(const mind_ir::AttributeProto &attr_proto) {
831 std::vector<abstract::AbstractBasePtr> vec;
832
833 for (int i = 0; i < attr_proto.values_size(); ++i) {
834 auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto.values(i));
835 if (abs == nullptr) {
836 MS_LOG(WARNING) << "Failed to get the tuple's abstract from AttrProto. " << attr_proto.DebugString();
837 return nullptr;
838 }
839 (void)vec.emplace_back(abs);
840 }
841 abstract::AbstractSequencePtr seq_abs;
842 if (attr_proto.type() == mind_ir::AttributeProto_AttributeType_TUPLE) {
843 seq_abs = std::make_shared<abstract::AbstractTuple>(vec);
844 } else {
845 seq_abs = std::make_shared<abstract::AbstractList>(vec);
846 }
847 if (attr_proto.has_seq_info()) {
848 auto seq_info = attr_proto.seq_info();
849 seq_abs->set_dynamic_len(seq_info.is_dyn_len());
850 if (seq_info.has_tuple_elem_item()) {
851 auto elem_proto = seq_info.tuple_elem_item();
852 auto elem_abs = GetNodeAbstractFromAttrProtoWithType(elem_proto);
853 seq_abs->set_dynamic_len_element_abs(elem_abs);
854 }
855 }
856 return seq_abs;
857 }
858
BuildMapParameterFromMapTensorProto(const ParameterPtr & node,const mind_ir::MapTensorProto & map_parameter_proto)859 bool MSANFModelParser::BuildMapParameterFromMapTensorProto(const ParameterPtr &node,
860 const mind_ir::MapTensorProto &map_parameter_proto) {
861 MS_EXCEPTION_IF_NULL(node);
862
863 if (!map_parameter_proto.has_name()) {
864 MS_LOG(ERROR) << "mind_ir MapTensorProto has no name!";
865 return false;
866 }
867
868 string debug_info_name = ParseParameterName(map_parameter_proto.name());
869 auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
870 node->set_debug_info(debug_info_ptr);
871 node->set_name(debug_info_name);
872
873 ParamInfoPtr param_info = std::make_shared<ParamInfo>();
874 param_info->set_name(debug_info_name);
875
876 MS_LOG(DEBUG) << "Load map parameter name: " << map_parameter_proto.name();
877 // default value
878 if (!map_parameter_proto.has_default_value()) {
879 MS_LOG(ERROR) << "MapTensorProto should have default value: " << map_parameter_proto.name();
880 return false;
881 }
882 const auto &default_value_proto = map_parameter_proto.default_value();
883 auto default_value = BuildValueFromAttributeProto(default_value_proto);
884 if (default_value == nullptr) {
885 MS_LOG(ERROR) << "Build default value from AttributeProto failed.";
886 return false;
887 }
888 // key tensor
889 if (!map_parameter_proto.has_key_tensor()) {
890 MS_LOG(ERROR) << "MapTensorProto should have key tensor: " << map_parameter_proto.name();
891 return false;
892 }
893 const auto &key_tensor_proto = map_parameter_proto.key_tensor();
894 auto key_tensor = GenerateTensorPtrFromTensorProto(key_tensor_proto);
895 if (key_tensor == nullptr) {
896 MS_LOG(ERROR) << "Generate key tensor from TensorProto failed.";
897 return false;
898 }
899 // value tensor
900 if (!map_parameter_proto.has_value_tensor()) {
901 MS_LOG(ERROR) << "MapTensorProto should have value tensor: " << map_parameter_proto.name();
902 return false;
903 }
904 const auto &value_tensor_proto = map_parameter_proto.value_tensor();
905 auto value_tensor = GenerateTensorPtrFromTensorProto(value_tensor_proto);
906 if (value_tensor == nullptr) {
907 MS_LOG(ERROR) << "Generate value tensor from TensorProto failed.";
908 return false;
909 }
910 // status tensor
911 if (!map_parameter_proto.has_status_tensor()) {
912 MS_LOG(ERROR) << "MapTensorProto should have status tensor: " << map_parameter_proto.name();
913 return false;
914 }
915 const auto &status_tensor_proto = map_parameter_proto.status_tensor();
916 auto status_tensor = GenerateTensorPtrFromTensorProto(status_tensor_proto);
917 if (status_tensor == nullptr) {
918 MS_LOG(ERROR) << "Generate status tensor from TensorProto failed.";
919 return false;
920 }
921
922 auto map_tensor = std::make_shared<tensor::MapTensor>(key_tensor, value_tensor, status_tensor, default_value);
923 map_tensor->set_param_info(param_info);
924 node->set_default_param(map_tensor);
925 node->set_abstract(map_tensor->ToAbstract());
926
927 anfnode_build_map_[map_parameter_proto.name()] = node;
928 return true;
929 }
930
GetTensorDataFromExternal(const mind_ir::TensorProto & tensor_proto,const tensor::TensorPtr & tensor_info)931 bool MSANFModelParser::GetTensorDataFromExternal(const mind_ir::TensorProto &tensor_proto,
932 const tensor::TensorPtr &tensor_info) {
933 if (!tensor_proto.has_external_data()) {
934 return false;
935 }
936 const unsigned char *data = nullptr;
937 auto it = tenor_data_.find(tensor_proto.external_data().location());
938 if (it != tenor_data_.end()) {
939 data = it->second.get();
940 } else {
941 std::string file = mindir_path_ + "/" + tensor_proto.external_data().location();
942 if (mindir_dec_key_ != nullptr) {
943 size_t plain_len;
944 auto plain_data = Decrypt(&plain_len, file, mindir_dec_key_, mindir_key_size_, mindir_dec_mode_);
945 if (plain_data == nullptr) {
946 MS_LOG(ERROR) << "Decrypt MindIR file failed, please check the correctness of the dec_key or dec_mode.";
947 return false;
948 }
949 data = plain_data.get();
950 (void)tenor_data_.emplace(tensor_proto.external_data().location(), std::move(plain_data));
951 } else {
952 // Read file
953 std::basic_ifstream<char> fid(file, std::ios::in | std::ios::binary);
954 if (!fid) {
955 MS_LOG(EXCEPTION) << "Open file '" << file << "' failed, please check the correct of the file.";
956 }
957 (void)fid.seekg(0, std::ios_base::end);
958 size_t file_size = static_cast<size_t>(fid.tellg());
959 fid.clear();
960 (void)fid.seekg(0);
961 std::unique_ptr<char[]> plain_data(new (std::nothrow) char[file_size]);
962 if (plain_data == nullptr) {
963 MS_LOG(ERROR) << "Failed to create file buffer, file size: " << file_size << " bytes";
964 return false;
965 }
966 constexpr Byte is_little_endian = 1;
967 constexpr int byte_order_index = 0;
968 (void)fid.read(plain_data.get(), SizeToLong(file_size));
969 fid.close();
970 // if byte order is not same return false
971 if ((plain_data[byte_order_index] == is_little_endian) ^ little_endian()) {
972 MS_LOG(ERROR) << "The byte order of export MindIr device and load MindIr device is not same!";
973 return false;
974 }
975 data = reinterpret_cast<const unsigned char *>(plain_data.get());
976 (void)tenor_data_.emplace(tensor_proto.external_data().location(),
977 std::unique_ptr<Byte[]>(reinterpret_cast<Byte *>(plain_data.release())));
978 }
979 }
980 auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
981 MS_EXCEPTION_IF_NULL(tensor_data_buf);
982 MS_EXCEPTION_IF_NULL(data);
983
984 if (tensor_info->data().nbytes() == 0 || tensor_proto.external_data().length() == 0) {
985 // no need to copy data
986 return true;
987 }
988
989 auto ret =
990 common::huge_memcpy(tensor_data_buf, tensor_info->data().nbytes(), data + tensor_proto.external_data().offset(),
991 LongToSize(tensor_proto.external_data().length()));
992 if (ret != EOK) {
993 MS_LOG(ERROR) << "Build parameter occur memcpy_s error.";
994 return false;
995 }
996 return true;
997 }
998
BuildInputForFuncGraph(const ParameterPtr & node,const mind_ir::ValueInfoProto & value_proto)999 bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto) {
1000 MS_EXCEPTION_IF_NULL(node);
1001
1002 if (!value_proto.has_name()) {
1003 MS_LOG(ERROR) << "mind_ir ValueInfoProto has no name!";
1004 return false;
1005 }
1006 string debug_info_name = ParseParameterName(value_proto.name());
1007 auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
1008 node->set_debug_info(debug_info_ptr);
1009 node->set_name(debug_info_name);
1010
1011 // Set abstract of the parameter
1012 if (value_proto.tensor_size() > 0) {
1013 const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
1014 auto tensor_info = GetAbsTensorFromTensorProto(tensor_proto);
1015 if (tensor_info == nullptr) {
1016 MS_LOG(ERROR) << "Get tensor_info fail.";
1017 return false;
1018 }
1019 node->set_abstract(tensor_info);
1020 if (tensor_proto.has_ref_key() && top_graph_ != nullptr) {
1021 auto parameters = top_graph_->parameters();
1022 for (const auto ¶meter : parameters) {
1023 auto parameter_abs = parameter->abstract();
1024 if (parameter_abs->isa<abstract::AbstractRefTensor>()) {
1025 auto parameter_abs_value = parameter_abs->cast<abstract::AbstractRefPtr>()->ref_key_value();
1026 auto ref_key_value = parameter_abs_value->cast<StringImmPtr>();
1027 if (ref_key_value != nullptr && ref_key_value->value() == tensor_proto.ref_key()) {
1028 node->set_default_param(parameter->cast<ParameterPtr>()->default_param());
1029 break;
1030 }
1031 }
1032 }
1033 }
1034 } else if (value_proto.has_denotation()) {
1035 if (value_proto.denotation() == "UMonadType") {
1036 node->set_abstract(kUMonad->ToAbstract());
1037 } else if (value_proto.denotation() == "IOMonadType") {
1038 node->set_abstract(kIOMonad->ToAbstract());
1039 }
1040 MS_LOG(DEBUG) << "Not tensor. parameter type: " << value_proto.denotation();
1041 }
1042 if (value_proto.has_attr_info()) {
1043 auto attr_proto = value_proto.attr_info();
1044 // Compatible with the previous proto.
1045 if (attr_proto.has_ref_attr_name()) {
1046 if (!SetNodeAbstractFromAttrProto(attr_proto, node)) {
1047 MS_LOG(ERROR) << "Failed to get abstract for input node " << node->name()
1048 << " from proto:" << attr_proto.DebugString();
1049 }
1050 } else {
1051 auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto);
1052 if (abs == nullptr) {
1053 MS_LOG(ERROR) << "Failed to get abstract for input node " << node->name()
1054 << " from attr_proto:" << attr_proto.DebugString();
1055 }
1056 node->set_abstract(abs);
1057 }
1058 }
1059 if (node->abstract() == nullptr) {
1060 MS_LOG(INFO) << "Failed to build abstract of node:" << node->name()
1061 << " from ValueInfoProto:" << value_proto.DebugString();
1062 }
1063 anfnode_build_map_[value_proto.name()] = node;
1064 return true;
1065 }
1066
ImportParametersForGraph(const FuncGraphPtr & outputFuncGraph,const mind_ir::GraphProto & importProto)1067 bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
1068 const mind_ir::GraphProto &importProto) {
1069 MS_EXCEPTION_IF_NULL(outputFuncGraph);
1070 MS_LOG(INFO) << "All inputs size is: " << importProto.input_size();
1071 for (int i = 0; i < importProto.input_size(); ++i) {
1072 const mind_ir::ValueInfoProto &input_proto = importProto.input(i);
1073 if (is_kernel_graph_ && anfnode_build_map_.count(input_proto.name()) > 0) {
1074 continue;
1075 }
1076 if (!BuildInputForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) {
1077 MS_LOG(ERROR) << "Build input for funcgraph fail at index: " << i;
1078 return false;
1079 }
1080 }
1081
1082 MS_LOG(INFO) << "All Parameters size is: " << importProto.parameter_size();
1083 for (int i = 0; i < importProto.parameter_size(); ++i) {
1084 const mind_ir::TensorProto ¶meter_proto = importProto.parameter(i);
1085 if (is_kernel_graph_ && anfnode_build_map_.count(parameter_proto.name()) > 0) {
1086 continue;
1087 }
1088 if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), parameter_proto)) {
1089 MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
1090 return false;
1091 }
1092 }
1093 outputFuncGraph->set_fv_param_count(IntToSize(importProto.parameter_size()));
1094 return true;
1095 }
1096
ImportMapParametersForGraph(const FuncGraphPtr & outputFuncGraph,const mind_ir::GraphProto & importProto)1097 bool MSANFModelParser::ImportMapParametersForGraph(const FuncGraphPtr &outputFuncGraph,
1098 const mind_ir::GraphProto &importProto) {
1099 MS_EXCEPTION_IF_NULL(outputFuncGraph);
1100 MS_LOG(INFO) << "All MapParameters size is: " << importProto.map_parameter_size();
1101 for (int i = 0; i < importProto.map_parameter_size(); ++i) {
1102 const mind_ir::MapTensorProto &map_parameter_proto = importProto.map_parameter(i);
1103 if (!BuildMapParameterFromMapTensorProto(outputFuncGraph->add_parameter(), map_parameter_proto)) {
1104 MS_LOG(ERROR) << "Build map parameter for funcgraph fail at index: " << i;
1105 return false;
1106 }
1107 }
1108 outputFuncGraph->set_fv_param_count(IntToSize(importProto.parameter_size()));
1109 return true;
1110 }
1111
ObtainCNodeAttrInTypeForm(const PrimitivePtr & prim,const mind_ir::AttributeProto & attr_proto)1112 bool MSANFModelParser::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
1113 MS_EXCEPTION_IF_NULL(prim);
1114 const int attr_tensor_type = attr_proto.tensors(0).data_type();
1115 if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
1116 MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type;
1117 return false;
1118 }
1119 (void)prim->AddAttr(attr_proto.name(), TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
1120 return true;
1121 }
1122
ParseAttrInScalarForm(const mind_ir::AttributeProto & attr_proto,int index)1123 ValuePtr MSANFModelParser::ParseAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, int index) {
1124 const int attr_type = static_cast<int>(attr_proto.type());
1125 switch (attr_type) {
1126 case mind_ir::AttributeProto_AttributeType_STRING: {
1127 return ParseAttrInScalar_string_string(attr_proto, index);
1128 }
1129 case mind_ir::AttributeProto_AttributeType_INT8: {
1130 return ParseAttrInScalar_int8_t_int8_t(attr_proto, index);
1131 }
1132 case mind_ir::AttributeProto_AttributeType_INT16: {
1133 return ParseAttrInScalar_int16_t_int16_t(attr_proto, index);
1134 }
1135 case mind_ir::AttributeProto_AttributeType_INT32: {
1136 return ParseAttrInScalar_int32_t_int32_t(attr_proto, index);
1137 }
1138 case mind_ir::AttributeProto_AttributeType_INT64: {
1139 return ParseAttrInScalar_int64_t_int64_t(attr_proto, index);
1140 }
1141 case mind_ir::AttributeProto_AttributeType_UINT8: {
1142 return ParseAttrInScalar_uint8_t_uint8_t(attr_proto, index);
1143 }
1144 case mind_ir::AttributeProto_AttributeType_UINT16: {
1145 return ParseAttrInScalar_uint16_t_uint16_t(attr_proto, index);
1146 }
1147 case mind_ir::AttributeProto_AttributeType_UINT32: {
1148 return ParseAttrInScalar_uint32_t_uint32_t(attr_proto, index);
1149 }
1150 case mind_ir::AttributeProto_AttributeType_UINT64: {
1151 return ParseAttrInScalar_uint64_t_uint64_t(attr_proto, index);
1152 }
1153 case mind_ir::AttributeProto_AttributeType_FLOAT: {
1154 return ParseAttrInScalar_float_float(attr_proto, index);
1155 }
1156 case mind_ir::AttributeProto_AttributeType_DOUBLE: {
1157 return ParseAttrInScalar_double_double(attr_proto, index);
1158 }
1159 case mind_ir::AttributeProto_AttributeType_BOOL: {
1160 return ParseAttrInScalar_int32_t_bool(attr_proto, index);
1161 }
1162 case mind_ir::AttributeProto_AttributeType_TENSORS: {
1163 const int attr_tensor_type = attr_proto.tensors(index).data_type();
1164 if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
1165 MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type;
1166 return nullptr;
1167 }
1168 return TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]);
1169 }
1170 default:
1171 MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_type;
1172 return nullptr;
1173 }
1174 }
1175
ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto & attr_proto,mindspore::HashMap<std::string,ValuePtr> * multi_value_map)1176 void MSANFModelParser::ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto &attr_proto,
1177 mindspore::HashMap<std::string, ValuePtr> *multi_value_map) {
1178 string name;
1179 auto func = [&name, &multi_value_map, this](const mind_ir::AttributeProto &attr_proto, int length) -> void {
1180 for (int i = 0; i < length; ++i) {
1181 auto res = this->ParseAttrInScalarForm(attr_proto, i);
1182 name = "value" + std::to_string(i + 1);
1183 (void)multi_value_map->emplace(name, res);
1184 }
1185 };
1186 func(attr_proto, attr_proto.ints_size());
1187 func(attr_proto, attr_proto.doubles_size());
1188 func(attr_proto, attr_proto.floats_size());
1189 func(attr_proto, attr_proto.strings_size());
1190 func(attr_proto, attr_proto.tensors_size());
1191 }
1192
ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto & attr_proto)1193 ValuePtr MSANFModelParser::ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto) {
1194 const int attr_type = static_cast<int>(attr_proto.type());
1195 switch (attr_type) {
1196 case mind_ir::AttributeProto_AttributeType_STRING: {
1197 return ParseAttrInSingleScalar_string_string(attr_proto);
1198 }
1199 case mind_ir::AttributeProto_AttributeType_INT8: {
1200 return ParseAttrInSingleScalar_int8_t_int8_t(attr_proto);
1201 }
1202 case mind_ir::AttributeProto_AttributeType_INT16: {
1203 return ParseAttrInSingleScalar_int16_t_int16_t(attr_proto);
1204 }
1205 case mind_ir::AttributeProto_AttributeType_INT32: {
1206 return ParseAttrInSingleScalar_int32_t_int32_t(attr_proto);
1207 }
1208 case mind_ir::AttributeProto_AttributeType_INT64: {
1209 return ParseAttrInSingleScalar_int64_t_int64_t(attr_proto);
1210 }
1211 case mind_ir::AttributeProto_AttributeType_UINT8: {
1212 return ParseAttrInSingleScalar_uint8_t_uint8_t(attr_proto);
1213 }
1214 case mind_ir::AttributeProto_AttributeType_UINT16: {
1215 return ParseAttrInSingleScalar_uint16_t_uint16_t(attr_proto);
1216 }
1217 case mind_ir::AttributeProto_AttributeType_UINT32: {
1218 return ParseAttrInSingleScalar_uint32_t_uint32_t(attr_proto);
1219 }
1220 case mind_ir::AttributeProto_AttributeType_UINT64: {
1221 return ParseAttrInSingleScalar_uint64_t_uint64_t(attr_proto);
1222 }
1223 case mind_ir::AttributeProto_AttributeType_FLOAT: {
1224 return ParseAttrInSingleScalar_float_float(attr_proto);
1225 }
1226 case mind_ir::AttributeProto_AttributeType_DOUBLE: {
1227 return ParseAttrInSingleScalar_double_double(attr_proto);
1228 }
1229 case mind_ir::AttributeProto_AttributeType_BOOL: {
1230 return ParseAttrInSingleScalar_int32_t_bool(attr_proto);
1231 }
1232 default:
1233 MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_type;
1234 return nullptr;
1235 }
1236 }
1237
ObtainCNodeAttrInTensorForm(const PrimitivePtr & prim,const mind_ir::AttributeProto & attr_proto)1238 bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim,
1239 const mind_ir::AttributeProto &attr_proto) {
1240 MS_EXCEPTION_IF_NULL(prim);
1241 const mind_ir::TensorProto attr_tensor = attr_proto.tensors(0);
1242 auto tensor_info = GenerateTensorPtrFromTensorProto(attr_tensor);
1243 if (tensor_info == nullptr) {
1244 MS_LOG(ERROR) << "Failed to get attr[" << attr_proto.name() << "] for node " << prim->ToString()
1245 << " from the proto.";
1246 return false;
1247 }
1248 (void)prim->AddAttr(attr_proto.name(), MakeValue(tensor_info));
1249 return true;
1250 }
1251
SetPrimitiveAttrWithType(const PrimitivePtr & prim,const mind_ir::AttributeProto & attr_proto)1252 bool MSANFModelParser::SetPrimitiveAttrWithType(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
1253 MS_EXCEPTION_IF_NULL(prim);
1254 const std::string &attr_name = attr_proto.name();
1255 auto value = GetValueFromAttributeProto(attr_proto);
1256 if (value == nullptr) {
1257 MS_LOG(ERROR) << "Failed to get value from proto.\n proto info:" << attr_proto.name();
1258 return false;
1259 }
1260 const std::string &op_type = prim->name();
1261 if (is_kernel_graph_) {
1262 (void)prim->AddAttr(attr_name, value);
1263 return true;
1264 }
1265 CheckAndConvertUtils::ConvertAttrValueInLoad(op_type, attr_name, &value);
1266 // Compatible with older versions.
1267 if (op_type == "HistogramFixedWidth" && attr_name == "dtype" && value->isa<StringImm>()) {
1268 auto str_dtype = GetValue<std::string>(value);
1269 if (str_dtype == "int32") {
1270 int64_t index = 3;
1271 (void)prim->AddAttr(attr_name, MakeValue<int64_t>(index));
1272 }
1273 MS_EXCEPTION(NotSupportError)
1274 << "The primtive[HistogramFixedWidth] not supported only support attribute[dtype] is 'int32',but got"
1275 << value->ToString();
1276 }
1277 (void)prim->AddAttr(attr_name, value);
1278 return true;
1279 }
1280
GetAttrValueForCNode(const PrimitivePtr & prim,const mind_ir::AttributeProto & attr_proto)1281 bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
1282 MS_EXCEPTION_IF_NULL(prim);
1283 const std::string &attr_name = attr_proto.name();
1284 if (!attr_proto.has_ref_attr_name()) {
1285 MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
1286 return false;
1287 }
1288 const std::string &ref_attr_name = attr_proto.ref_attr_name();
1289 ParseForm type = GetParseFormType(ref_attr_name);
1290 mindspore::HashMap<std::string, ValuePtr> multi_value_map;
1291 switch (type) {
1292 case FORM_PARSE_TYPE: {
1293 (void)ObtainCNodeAttrInTypeForm(prim, attr_proto);
1294 break;
1295 }
1296 case FORM_PARSE_SCALAR: {
1297 if (ref_attr_name.find("value0") != std::string::npos) {
1298 ValuePtr res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
1299 MS_EXCEPTION_IF_NULL(res);
1300 const std::string &op_type = prim->name();
1301 if (is_kernel_graph_) {
1302 (void)prim->AddAttr(attr_name, res);
1303 break;
1304 }
1305 CheckAndConvertUtils::ConvertAttrValueInLoad(op_type, attr_name, &res);
1306 if (op_type == "HistogramFixedWidth" && attr_name == "dtype" && res->isa<StringImm>()) {
1307 auto str_dtype = GetValue<std::string>(res);
1308 if (str_dtype == "int32") {
1309 int64_t index = 3;
1310 (void)prim->AddAttr(attr_name, MakeValue<int64_t>(index));
1311 break;
1312 }
1313 MS_EXCEPTION(NotSupportError)
1314 << "The primtive[HistogramFixedWidth] not supported only support attribute[dtype] is 'int32',but got"
1315 << res->ToString();
1316 }
1317 (void)prim->AddAttr(attr_name, res);
1318 break;
1319 } else if (ref_attr_name.find("Tuple[]") != std::string::npos) {
1320 (void)prim->AddAttr(attr_name, std::make_shared<ValueTuple>(std::vector<ValuePtr>()));
1321 break;
1322 } else if (ref_attr_name.find("List[]") != std::string::npos) {
1323 (void)prim->AddAttr(attr_name, std::make_shared<ValueList>(std::vector<ValuePtr>()));
1324 break;
1325 }
1326 ObtainCNodeAttrInScalarForm(attr_proto, &multi_value_map);
1327 break;
1328 }
1329 case FORM_PARSE_TENSOR: {
1330 (void)ObtainCNodeAttrInTensorForm(prim, attr_proto);
1331 break;
1332 }
1333 case FORM_PARSE_NONE: {
1334 (void)prim->AddAttr(attr_name, kNone);
1335 break;
1336 }
1337 default:
1338 MS_LOG(ERROR) << "parse attr type don't support the ref_attr_name: " << ref_attr_name;
1339 return false;
1340 }
1341
1342 if (type == FORM_PARSE_SCALAR && multi_value_map.size() != 0) {
1343 if (ref_attr_name.find("Tuple") != std::string::npos) {
1344 auto value_tuple_ptr = ParserScalarAttrValue<ValueTuple>(ref_attr_name, multi_value_map);
1345 (void)prim->AddAttr(attr_name, value_tuple_ptr);
1346 } else {
1347 auto value_list_ptr = ParserScalarAttrValue<ValueList>(ref_attr_name, multi_value_map);
1348 (void)prim->AddAttr(attr_name, value_list_ptr);
1349 }
1350 }
1351 return true;
1352 }
1353
ObtainValueNodeInTensorForm(const std::string & value_node_name,const mind_ir::TensorProto & attr_tensor)1354 bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name,
1355 const mind_ir::TensorProto &attr_tensor) {
1356 tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(attr_tensor);
1357 if (tensor_info == nullptr) {
1358 MS_LOG(ERROR) << "Failed to get the tensor for ValueNode.";
1359 return false;
1360 }
1361 auto new_value_node = NewValueNode(MakeValue(tensor_info));
1362 MS_EXCEPTION_IF_NULL(new_value_node);
1363 auto tensor_abstract = tensor_info->ToAbstract();
1364 MS_EXCEPTION_IF_NULL(tensor_abstract);
1365 new_value_node->set_abstract(tensor_abstract);
1366 anfnode_build_map_[value_node_name] = new_value_node;
1367 return true;
1368 }
1369
ObtainValueNodeInTupleTensorForm(const std::string & value_node_name,const mind_ir::AttributeProto & attr_proto)1370 bool MSANFModelParser::ObtainValueNodeInTupleTensorForm(const std::string &value_node_name,
1371 const mind_ir::AttributeProto &attr_proto) {
1372 std::vector<tensor::TensorPtr> tensor_vec;
1373 for (int i = 0; i < attr_proto.tensors_size(); ++i) {
1374 mind_ir::TensorProto attr_tensor = attr_proto.tensors(i);
1375 const int attr_tensor_type = attr_tensor.data_type();
1376 ShapeVector shape;
1377 for (int j = 0; j < attr_tensor.dims_size(); ++j) {
1378 shape.push_back(attr_tensor.dims(j));
1379 }
1380 tensor::TensorPtr tensor_info = nullptr;
1381 if (!attr_tensor.has_compression_type() ||
1382 attr_tensor.compression_type() == mind_ir::TensorProto_CompressionType_NO_COMPRESSION) {
1383 tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
1384 } else {
1385 auto compression_type = static_cast<TensorCompressionType>(static_cast<int>(attr_tensor.compression_type()));
1386 size_t data_size = 0;
1387 if (!attr_tensor.has_external_data()) {
1388 data_size = attr_tensor.raw_data().size();
1389 } else {
1390 data_size = LongToSize(attr_tensor.external_data().length());
1391 }
1392 tensor_info =
1393 std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape, data_size, compression_type);
1394 }
1395 const std::string &tensor_buf = attr_tensor.raw_data();
1396 auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
1397 errno_t ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
1398 if (ret != EOK) {
1399 MS_LOG(ERROR) << "Obtain ValueNode in TupleTensorForm occur memcpy_s error.";
1400 return false;
1401 }
1402 tensor_vec.push_back(tensor_info);
1403 }
1404 auto value = MakeValue(tensor_vec);
1405 auto new_value_node = NewValueNode(value);
1406 new_value_node->set_abstract(value->ToAbstract());
1407 anfnode_build_map_[value_node_name] = new_value_node;
1408 return true;
1409 }
1410
ObtainValueNodeInTypeForm(const std::string & value_node_name,const mind_ir::TensorProto & attr_tensor)1411 bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_name,
1412 const mind_ir::TensorProto &attr_tensor) {
1413 const int attr_tensor_type = attr_tensor.data_type();
1414 auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
1415 if (iter == kDefaultValueSwitchMap.end()) {
1416 MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
1417 return false;
1418 }
1419 auto value = TypeIdToType(iter->second);
1420 auto new_value_node = NewValueNode(value);
1421 new_value_node->set_abstract(value->ToAbstract());
1422 anfnode_build_map_[value_node_name] = new_value_node;
1423 return true;
1424 }
1425
ObtainValueNodeInNoneForm(const std::string & value_node_name)1426 bool MSANFModelParser::ObtainValueNodeInNoneForm(const std::string &value_node_name) {
1427 auto new_value_node = NewValueNode(kNone);
1428 MS_EXCEPTION_IF_NULL(new_value_node);
1429 new_value_node->set_abstract(kNone->ToAbstract());
1430 anfnode_build_map_[value_node_name] = new_value_node;
1431 return true;
1432 }
1433
ObtainValueNodeInMonadForm(const std::string & value_node_name,const mind_ir::AttributeProto & attr_proto)1434 bool MSANFModelParser::ObtainValueNodeInMonadForm(const std::string &value_node_name,
1435 const mind_ir::AttributeProto &attr_proto) {
1436 const std::string &ref_attr_name = attr_proto.ref_attr_name();
1437 if (ref_attr_name.find("UMonad") != std::string::npos) {
1438 auto monad_abs = kUMonad->ToAbstract();
1439 auto new_value_node = NewValueNode(kUMonad);
1440 MS_EXCEPTION_IF_NULL(new_value_node);
1441 new_value_node->set_abstract(monad_abs);
1442 anfnode_build_map_[value_node_name] = new_value_node;
1443 } else if (ref_attr_name.find("IOMonad") != std::string::npos) {
1444 auto monad_abs = kIOMonad->ToAbstract();
1445 auto new_value_node = NewValueNode(kIOMonad);
1446 MS_EXCEPTION_IF_NULL(new_value_node);
1447 new_value_node->set_abstract(monad_abs);
1448 anfnode_build_map_[value_node_name] = new_value_node;
1449 } else {
1450 return false;
1451 }
1452 return true;
1453 }
1454
ObtainValueInDictionaryForm(const mind_ir::AttributeProto & attr_proto)1455 ValuePtr MSANFModelParser::ObtainValueInDictionaryForm(const mind_ir::AttributeProto &attr_proto) {
1456 std::vector<std::pair<ValuePtr, ValuePtr>> key_values;
1457 for (int i = 0; i < attr_proto.values_size(); ++i) {
1458 const mind_ir::AttributeProto &key_value_proto = attr_proto.values(i);
1459 if (!key_value_proto.has_name()) {
1460 MS_LOG(INTERNAL_EXCEPTION) << "Dict type AttributeProto should has name as key of dictionary";
1461 }
1462 auto key = std::make_shared<abstract::AbstractScalar>(key_value_proto.name())->BuildValue();
1463 MS_EXCEPTION_IF_NULL(key);
1464 auto &values = key_value_proto.values();
1465 if (values.size() != 1) {
1466 MS_LOG(INTERNAL_EXCEPTION) << "Dict type AttributeProto should has exactly one value, but got " << values.size()
1467 << " value(s).";
1468 }
1469 auto &value = values[0];
1470 switch (value.type()) {
1471 case mind_ir::AttributeProto_AttributeType_TENSORS: {
1472 const mind_ir::TensorProto &tensor_proto = value.tensors(0);
1473 if (tensor_proto.has_raw_data()) {
1474 // For real tensor.
1475 tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto);
1476 if (tensor_info == nullptr) {
1477 MS_LOG(ERROR) << "Failed to get the tensor for ValueNode.";
1478 return nullptr;
1479 }
1480 (void)key_values.emplace_back(std::make_pair(key, tensor_info));
1481 } else {
1482 // For data type.
1483 const int attr_tensor_type = tensor_proto.data_type();
1484 auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
1485 if (iter == kDefaultValueSwitchMap.end()) {
1486 MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
1487 return nullptr;
1488 }
1489 (void)key_values.emplace_back(std::make_pair(key, TypeIdToType(iter->second)));
1490 }
1491 break;
1492 }
1493 case mind_ir::AttributeProto_AttributeType_TUPLE:
1494 case mind_ir::AttributeProto_AttributeType_LIST: {
1495 auto sequence_value = ObtainValueInSequenceForm(value);
1496 if (sequence_value == nullptr) {
1497 MS_LOG(ERROR) << "Failed to get the sequence value";
1498 return nullptr;
1499 }
1500 (void)key_values.emplace_back(std::make_pair(key, sequence_value));
1501 break;
1502 }
1503 case mind_ir::AttributeProto_AttributeType_DICT: {
1504 auto dict_value = ObtainValueInDictionaryForm(value);
1505 if (dict_value == nullptr) {
1506 MS_LOG(ERROR) << "Failed to get the dictionary value";
1507 return nullptr;
1508 }
1509 (void)key_values.emplace_back(std::make_pair(key, dict_value));
1510 break;
1511 }
1512 default: {
1513 // For string and scalar.
1514 auto scalar_value = ParseAttrInScalarForm(value, 0);
1515 if (scalar_value == nullptr) {
1516 MS_LOG(ERROR) << "Failed to get the scalar for ValueNode.";
1517 return nullptr;
1518 }
1519 (void)key_values.emplace_back(std::make_pair(key, scalar_value));
1520 }
1521 }
1522 }
1523 return std::make_shared<ValueDictionary>(key_values);
1524 }
1525
ObtainValueInSequenceForm(const mind_ir::AttributeProto & attr_proto)1526 ValuePtr MSANFModelParser::ObtainValueInSequenceForm(const mind_ir::AttributeProto &attr_proto) {
1527 std::vector<ValuePtr> vec;
1528 for (int i = 0; i < attr_proto.values_size(); ++i) {
1529 mind_ir::AttributeProto elem_attr_proto = attr_proto.values(i);
1530 switch (elem_attr_proto.type()) {
1531 case mind_ir::AttributeProto_AttributeType_TENSORS: {
1532 mind_ir::TensorProto tensor_proto = elem_attr_proto.tensors(0);
1533 if (tensor_proto.has_raw_data()) {
1534 // For real tensor.
1535 tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto);
1536 if (tensor_info == nullptr) {
1537 MS_LOG(ERROR) << "Failed to get the tensor for ValueNode.";
1538 return nullptr;
1539 }
1540 (void)vec.emplace_back(tensor_info);
1541 } else if (tensor_proto.name() == kQuantParam) {
1542 auto quantization_param_vector = GenerateQuantizationParam(tensor_proto);
1543 if (!quantization_param_vector.empty()) {
1544 (void)vec.emplace_back(quantization_param_vector[0]);
1545 }
1546 } else {
1547 // For data type.
1548 const int attr_tensor_type = tensor_proto.data_type();
1549 auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
1550 if (iter == kDefaultValueSwitchMap.end()) {
1551 MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
1552 return nullptr;
1553 }
1554 (void)vec.emplace_back(TypeIdToType(iter->second));
1555 }
1556 break;
1557 }
1558 case mind_ir::AttributeProto_AttributeType_TUPLE:
1559 case mind_ir::AttributeProto_AttributeType_LIST: {
1560 auto sequence_value = ObtainValueInSequenceForm(elem_attr_proto);
1561 if (sequence_value == nullptr) {
1562 MS_LOG(ERROR) << "Failed to get the sequence value";
1563 return nullptr;
1564 }
1565 (void)vec.emplace_back(sequence_value);
1566 break;
1567 }
1568 default: {
1569 // For string and scalar.
1570 auto scalar_value = ParseAttrInScalarForm(elem_attr_proto, 0);
1571 if (scalar_value == nullptr) {
1572 MS_LOG(ERROR) << "Failed to get the scalar for ValueNode.";
1573 return nullptr;
1574 }
1575 (void)vec.emplace_back(scalar_value);
1576 }
1577 }
1578 }
1579 auto type = attr_proto.type();
1580 ValuePtr value_sequence;
1581 if (type == mind_ir::AttributeProto_AttributeType_TUPLE) {
1582 value_sequence = std::make_shared<ValueTuple>(vec);
1583 } else if (type == mind_ir::AttributeProto_AttributeType_LIST) {
1584 value_sequence = std::make_shared<ValueList>(vec);
1585 } else {
1586 MS_LOG(INTERNAL_EXCEPTION) << "The attribute type should be tuple or list, but it is " << type;
1587 }
1588
1589 return value_sequence;
1590 }
1591
BuildValueFromAttributeProto(const mind_ir::AttributeProto & attr_proto)1592 ValuePtr MSANFModelParser::BuildValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto) {
1593 switch (attr_proto.type()) {
1594 case mind_ir::AttributeProto_AttributeType_TENSORS: {
1595 const auto &tensor_proto = attr_proto.tensors(0);
1596 if (tensor_proto.has_raw_data()) {
1597 // For real tensor.
1598 tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto);
1599 if (tensor_info == nullptr) {
1600 MS_LOG(ERROR) << "Failed to GenerateTensorPtrFromTensorProto.";
1601 return nullptr;
1602 }
1603 return MakeValue(tensor_info);
1604 } else {
1605 // For data type.
1606 const int attr_tensor_type = tensor_proto.data_type();
1607 auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
1608 if (iter == kDefaultValueSwitchMap.end()) {
1609 MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
1610 return nullptr;
1611 }
1612 return TypeIdToType(iter->second);
1613 }
1614 }
1615 case mind_ir::AttributeProto_AttributeType_NONE: {
1616 return kNone;
1617 }
1618 case mind_ir::AttributeProto_AttributeType_UMONAD: {
1619 return kUMonad;
1620 }
1621 case mind_ir::AttributeProto_AttributeType_IOMONAD: {
1622 return kIOMonad;
1623 }
1624 case mind_ir::AttributeProto_AttributeType_TUPLE:
1625 case mind_ir::AttributeProto_AttributeType_LIST: {
1626 return ObtainValueInSequenceForm(attr_proto);
1627 }
1628 case mind_ir::AttributeProto_AttributeType_CLASS_TYPE: {
1629 auto class_type = static_cast<std::string>(attr_proto.s());
1630 return std::make_shared<MindIRClassType>(class_type);
1631 }
1632 case mind_ir::AttributeProto_AttributeType_TYPE_NULL: {
1633 return kTypeNull;
1634 }
1635 case mind_ir::AttributeProto_AttributeType_NAME_SPACE: {
1636 auto name_space = static_cast<std::string>(attr_proto.s());
1637 return std::make_shared<MindIRNameSpace>(name_space);
1638 }
1639 case mind_ir::AttributeProto_AttributeType_SYMBOL: {
1640 auto symbol = static_cast<std::string>(attr_proto.s());
1641 return std::make_shared<MindIRSymbol>(symbol);
1642 }
1643 default: {
1644 return ObtainCNodeAttrInSingleScalarForm(attr_proto);
1645 }
1646 }
1647 }
1648
GetAttrValueForValueNodeWithType(const std::string & value_node_name,const mind_ir::AttributeProto & attr_proto)1649 bool MSANFModelParser::GetAttrValueForValueNodeWithType(const std::string &value_node_name,
1650 const mind_ir::AttributeProto &attr_proto) {
1651 auto value = BuildValueFromAttributeProto(attr_proto);
1652 if (value == nullptr) {
1653 MS_LOG(ERROR) << "Failed to build value from AttributeProto while building valuenode: " << value_node_name;
1654 return false;
1655 }
1656 auto abstract = value->ToAbstract();
1657 MS_EXCEPTION_IF_NULL(abstract);
1658 ValueNodePtr new_value_node = NewValueNode(value);
1659 new_value_node->set_abstract(abstract);
1660 anfnode_build_map_[value_node_name] = new_value_node;
1661 return true;
1662 }
1663
GetAttrValueForValueNode(const std::string & value_node_name,const mind_ir::AttributeProto & attr_proto)1664 bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_name,
1665 const mind_ir::AttributeProto &attr_proto) {
1666 if (!attr_proto.has_ref_attr_name()) {
1667 MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
1668 return false;
1669 }
1670 const std::string &ref_attr_name = attr_proto.ref_attr_name();
1671 ParseForm type = GetParseFormType(ref_attr_name);
1672 ValueNodePtr new_value_node;
1673 mindspore::HashMap<std::string, ValuePtr> multi_value_map;
1674 switch (type) {
1675 case FORM_PARSE_TYPE: {
1676 (void)ObtainValueNodeInTypeForm(value_node_name, attr_proto.tensors(0));
1677 break;
1678 }
1679 case FORM_PARSE_SCALAR: {
1680 if (ref_attr_name.find("value0") != std::string::npos) {
1681 auto res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
1682 MS_EXCEPTION_IF_NULL(res);
1683 new_value_node = NewValueNode(res);
1684 new_value_node->set_abstract(res->ToAbstract());
1685 anfnode_build_map_[value_node_name] = new_value_node;
1686 break;
1687 }
1688 if (ref_attr_name.find("Tuple[]") != std::string::npos) {
1689 MS_LOG(INFO) << "Build Tuple() ValueNode for primitive.";
1690 ValuePtr res = MakeValue(std::vector<ValuePtr>{});
1691 new_value_node = NewValueNode(res);
1692 new_value_node->set_abstract(res->ToAbstract());
1693 anfnode_build_map_[value_node_name] = new_value_node;
1694 break;
1695 }
1696 if (ref_attr_name.find("Tuple[value") != std::string::npos && attr_proto.tensors_size() > 1) {
1697 MS_LOG(INFO) << "Build TupleTensor ValueNode for primitive.";
1698 (void)ObtainValueNodeInTupleTensorForm(value_node_name, attr_proto);
1699 break;
1700 }
1701 ObtainCNodeAttrInScalarForm(attr_proto, &multi_value_map);
1702 break;
1703 }
1704 case FORM_PARSE_TENSOR: {
1705 (void)ObtainValueNodeInTensorForm(value_node_name, attr_proto.tensors(0));
1706 break;
1707 }
1708 case FORM_PARSE_NONE: {
1709 (void)ObtainValueNodeInNoneForm(value_node_name);
1710 break;
1711 }
1712 case FORM_PARSE_MONAD: {
1713 (void)ObtainValueNodeInMonadForm(value_node_name, attr_proto);
1714 break;
1715 }
1716 default:
1717 MS_LOG(ERROR) << "parse attr type don't support the ref_attr_name: " << ref_attr_name;
1718 return false;
1719 }
1720 if (type == FORM_PARSE_SCALAR && !multi_value_map.empty()) {
1721 if (ref_attr_name.find("Tuple") != std::string::npos) {
1722 auto value_tuple_ptr = ParserScalarAttrValue<ValueTuple>(ref_attr_name, multi_value_map);
1723 if (value_tuple_ptr == nullptr) {
1724 MS_LOG(ERROR) << "Failed to build the value of the ValueNode, attr_proto:" << attr_proto.DebugString()
1725 << ", value_node_name:" << value_node_name;
1726 return false;
1727 }
1728 new_value_node = NewValueNode(value_tuple_ptr);
1729 new_value_node->set_abstract(value_tuple_ptr->ToAbstract());
1730 } else {
1731 auto value_list_ptr = ParserScalarAttrValue<ValueList>(ref_attr_name, multi_value_map);
1732 if (value_list_ptr == nullptr) {
1733 MS_LOG(ERROR) << "Failed to build the value of the ValueNode, attr_proto:" << attr_proto.DebugString()
1734 << ", value_node_name:" << value_node_name;
1735 return false;
1736 }
1737 new_value_node = NewValueNode(value_list_ptr);
1738 new_value_node->set_abstract(value_list_ptr->ToAbstract());
1739 }
1740 anfnode_build_map_[value_node_name] = new_value_node;
1741 }
1742 return true;
1743 }
1744
BuildValueNodeForFuncGraph(const mind_ir::NodeProto & node_proto)1745 bool MSANFModelParser::BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto) {
1746 if (node_proto.output_size() == 0) {
1747 MS_LOG(ERROR) << "The Proto output is empty.";
1748 return false;
1749 }
1750 const std::string &value_node_name = node_proto.output(0);
1751 const mind_ir::AttributeProto &attr_proto = node_proto.attribute(0);
1752 if (attr_proto.has_ref_attr_name()) {
1753 return GetAttrValueForValueNode(value_node_name, attr_proto);
1754 }
1755 return GetAttrValueForValueNodeWithType(value_node_name, attr_proto);
1756 }
1757
GetAbstractForNode(const mind_ir::AttributeProto & attr_proto)1758 mindspore::HashMap<std::string, abstract::AbstractBasePtr> MSANFModelParser::GetAbstractForNode(
1759 const mind_ir::AttributeProto &attr_proto) {
1760 mindspore::HashMap<std::string, abstract::AbstractBasePtr> kv;
1761 for (int i = 0; i < attr_proto.tensors_size(); ++i) {
1762 const mind_ir::TensorProto &attr_tensor = attr_proto.tensors(i);
1763 auto tensor_info = GetAbsTensorFromTensorProto(attr_tensor);
1764 (void)kv.emplace(attr_tensor.name(), tensor_info);
1765 }
1766 return kv;
1767 }
1768
BuildOperatorNode(const mind_ir::NodeProto & node_proto)1769 AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_proto) {
1770 const std::string kOperatorTypeFlag = std::string("REF::");
1771 const size_t kOpTypeFlagSize = kOperatorTypeFlag.length();
1772 const std::string &node_type = node_proto.op_type();
1773 MS_LOG(DEBUG) << "Process Operator:" << node_type;
1774 // Operator maybe CNode,FuncGraph or Parameter.
1775
1776 if (node_type.size() > kOpTypeFlagSize && node_type.substr(0, kOpTypeFlagSize) == kOperatorTypeFlag) {
1777 auto anfNode = GetAnfNode(node_type.substr(kOpTypeFlagSize));
1778 if (anfNode == nullptr) {
1779 MS_LOG(ERROR) << "Can't find the ref:" << node_type;
1780 return nullptr;
1781 }
1782 return anfNode;
1783 }
1784
1785 // Operator is primitive.
1786 std::shared_ptr<Primitive> prim;
1787 auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
1788 if (op_primc_fns.find(node_type) != op_primc_fns.end()) {
1789 prim = op_primc_fns[node_type]();
1790 } else {
1791 if (node_type.compare(0, strlen(kDoSignaturePrimitivePrefix), kDoSignaturePrimitivePrefix) == 0) {
1792 auto op_name = node_type.substr(strlen(kDoSignaturePrimitivePrefix));
1793 prim = std::make_shared<prim::DoSignaturePrimitive>(op_name, std::make_shared<Primitive>(op_name));
1794 MS_EXCEPTION_IF_NULL(prim);
1795 prim->set_instance_name(op_name);
1796 } else {
1797 MS_LOG(DEBUG) << "Special node_type: " << node_type;
1798 prim = std::make_shared<Primitive>(node_type);
1799 MS_EXCEPTION_IF_NULL(prim);
1800 prim->set_instance_name(node_type);
1801 }
1802 }
1803 prim->set_attr("is_load", MakeValue(true));
1804 return NewValueNodeWithAbstract(prim);
1805 }
1806
SetEmptyTensorProtoCNodeAbstract(const AnfNodePtr & node_ptr)1807 bool MSANFModelParser::SetEmptyTensorProtoCNodeAbstract(const AnfNodePtr &node_ptr) {
1808 auto primitive = GetCNodePrimitive(node_ptr);
1809 if (primitive != nullptr) {
1810 auto node_type = primitive->name();
1811 if (node_type == "UpdateState") {
1812 node_ptr->set_abstract(kUMonad->ToAbstract());
1813 } else if (node_type == "Depend") {
1814 node_ptr->set_abstract(kBool->ToAbstract());
1815 } else {
1816 auto cnode_ptr = node_ptr->cast<CNodePtr>();
1817 AbstractBasePtrList elem;
1818 for (size_t index = 1; index < cnode_ptr->size(); ++index) {
1819 auto abs = cnode_ptr->input(index)->abstract();
1820 if (abs != nullptr) {
1821 if (abs->GetValueTrack() == nullptr) {
1822 abs->set_value(kValueAny);
1823 }
1824 elem.push_back(abs);
1825 }
1826 }
1827 if (!elem.empty()) {
1828 node_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
1829 }
1830 }
1831 } else {
1832 MS_LOG(ERROR) << "Failed to get the abstract of node:" << node_ptr->DebugString();
1833 return false;
1834 }
1835 return true;
1836 }
1837
1838 // Set CNode abstract.
SetCNodeAbstract(const mind_ir::AttributeProto & attr_proto,const CNodePtr & cnode_ptr)1839 void MSANFModelParser::SetCNodeAbstract(const mind_ir::AttributeProto &attr_proto, const CNodePtr &cnode_ptr) {
1840 if (attr_proto.has_ref_attr_name()) {
1841 if (!SetNodeAbstractFromAttrProto(attr_proto, cnode_ptr)) {
1842 MS_LOG(ERROR) << "Failed to get CNode abstract from proto.";
1843 }
1844 } else {
1845 auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto);
1846 cnode_ptr->set_abstract(abs);
1847 }
1848 if (cnode_ptr->abstract() == nullptr) {
1849 MS_LOG(INFO) << "Failed to Build CNode abstract from proto. CNode: " << cnode_ptr->ToString()
1850 << " attr_proto: " << attr_proto.DebugString();
1851 node_abstract_protos_.push_back(std::pair(cnode_ptr, &attr_proto));
1852 }
1853 }
1854
BuildCNodeForFuncGraph(const FuncGraphPtr & outputFuncGraph,const mind_ir::NodeProto & node_proto)1855 CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
1856 const mind_ir::NodeProto &node_proto) {
1857 MS_EXCEPTION_IF_NULL(outputFuncGraph);
1858 if (!node_proto.has_op_type()) {
1859 MS_LOG(ERROR) << "Get CNode op_type failed!";
1860 return nullptr;
1861 }
1862 if (node_proto.output_size() <= 0) {
1863 MS_LOG(ERROR) << "Get CNode out failed!";
1864 return nullptr;
1865 }
1866 const std::string &node_name = node_proto.output(0);
1867 MS_LOG(DEBUG) << "Process CNode: " << node_name;
1868 // Build inputs.
1869 std::vector<AnfNodePtr> inputs;
1870 auto operator_node = BuildOperatorNode(node_proto);
1871 if (operator_node == nullptr) {
1872 MS_LOG(ERROR) << "Build operator node " << node_name << " failed!";
1873 return nullptr;
1874 }
1875 inputs.push_back(operator_node);
1876 for (int i = 0; i < node_proto.input_size(); ++i) {
1877 auto anfNode = GetAnfNode(node_proto.input(i));
1878 if (anfNode == nullptr) {
1879 MS_LOG(ERROR) << node_name << " input " << i << node_proto.input(i) << "can't find in nodes have parsed";
1880 return nullptr;
1881 }
1882 inputs.push_back(anfNode);
1883 }
1884
1885 CNodePtr cnode_ptr = outputFuncGraph->FuncGraph::NewCNode(inputs);
1886 MS_EXCEPTION_IF_NULL(cnode_ptr);
1887 if (anfnode_build_map_.count(node_name) > 0) {
1888 MS_LOG(INTERNAL_EXCEPTION) << "Duplicate CNode name: " << node_name;
1889 }
1890 const std::string &fullname_with_scope = node_proto.domain();
1891 string debug_info_name = ParseCNodeName(node_name);
1892 auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
1893 cnode_ptr->set_debug_info(debug_info_ptr);
1894 cnode_ptr->set_fullname_with_scope(fullname_with_scope);
1895 cnode_ptr->set_load_flag(true);
1896 anfnode_build_map_[node_name] = cnode_ptr;
1897
1898 // Set Abstract and prim attr for CNode
1899 SetCNodePrimAttrAndAbstract(node_proto, cnode_ptr);
1900 if (!BuildAttrForCNode(cnode_ptr, node_proto)) {
1901 MS_LOG(ERROR) << "Failed build attr for node: " << cnode_ptr->DebugString()
1902 << ", proto: " << node_proto.DebugString();
1903 }
1904 return cnode_ptr;
1905 }
1906
BuildReturnForFuncGraph(const FuncGraphPtr & outputFuncGraph,const mind_ir::GraphProto & importProto)1907 bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
1908 const mind_ir::GraphProto &importProto) {
1909 MS_EXCEPTION_IF_NULL(outputFuncGraph);
1910 std::vector<AnfNodePtr> inputs;
1911 if (importProto.output_size() > 1) {
1912 inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
1913 AbstractBasePtrList elem;
1914 for (int out_size = 0; out_size < importProto.output_size(); ++out_size) {
1915 const mind_ir::ValueInfoProto &output_node = importProto.output(out_size);
1916 const std::string &out_tuple = output_node.name();
1917 auto anfNode = GetAnfNode(out_tuple);
1918 if (anfNode == nullptr) {
1919 MS_LOG(ERROR) << "Miss return node: " << out_tuple;
1920 return false;
1921 }
1922 inputs.push_back(anfNode);
1923 elem.push_back(anfNode->abstract());
1924 }
1925 auto maketuple_ptr = outputFuncGraph->FuncGraph::NewCNode(inputs);
1926 MS_EXCEPTION_IF_NULL(maketuple_ptr);
1927 maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
1928 inputs.clear();
1929 inputs.push_back(NewValueNode(prim::kPrimReturn));
1930 inputs.push_back(maketuple_ptr);
1931 auto return_node = outputFuncGraph->FuncGraph::NewCNode(inputs);
1932 MS_EXCEPTION_IF_NULL(return_node);
1933 return_node->set_abstract(maketuple_ptr->abstract());
1934 return_node->set_load_flag(true);
1935 outputFuncGraph->set_return(return_node);
1936 MS_LOG(DEBUG) << "Construct funcgraph finined, all success.";
1937 return true;
1938 } else if (importProto.output_size() == 1) {
1939 auto graph_name = importProto.name();
1940 const auto &return_node_input0 = NewValueNode(prim::kPrimReturn);
1941 anfnode_build_map_[graph_name + kReturnPrimNode] = return_node_input0;
1942 inputs.push_back(return_node_input0);
1943 auto node_name = importProto.output(0).name();
1944 auto anf_node = GetAnfNode(node_name);
1945 if (anf_node == nullptr) {
1946 MS_LOG(ERROR) << "Miss return node: " << node_name;
1947 return false;
1948 }
1949 inputs.push_back(anf_node);
1950 anfnode_build_map_[node_name] = anf_node;
1951 auto return_node = outputFuncGraph->FuncGraph::NewCNode(inputs);
1952 MS_EXCEPTION_IF_NULL(return_node);
1953 return_node->set_abstract(anf_node->abstract());
1954 return_node->set_load_flag(true);
1955 outputFuncGraph->set_return(return_node);
1956 anfnode_build_map_[graph_name + kReturnNode] = return_node;
1957 MS_LOG(DEBUG) << "Construct funcgraph finined, all success!";
1958 return true;
1959 }
1960
1961 return false;
1962 }
1963
ImportNodesForGraph(const FuncGraphPtr & outputFuncGraph,const mind_ir::GraphProto & importProto)1964 bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
1965 const mind_ir::GraphProto &importProto) {
1966 MS_EXCEPTION_IF_NULL(outputFuncGraph);
1967 CNodePtr cnode_ptr = nullptr;
1968 for (int i = 0; i < importProto.node_size(); ++i) {
1969 const mind_ir::NodeProto &node_proto = importProto.node(i);
1970 if (is_kernel_graph_ && anfnode_build_map_.count(node_proto.output(0)) > 0) {
1971 continue;
1972 }
1973 const std::string &node_type = node_proto.op_type();
1974 if (node_type == kConstantValueNode) {
1975 if (!BuildValueNodeForFuncGraph(node_proto)) {
1976 MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: " << i;
1977 return false;
1978 }
1979 continue;
1980 }
1981 cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto);
1982 if (cnode_ptr == nullptr) {
1983 MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: " << i;
1984 return false;
1985 }
1986 }
1987 return BuildReturnForFuncGraph(outputFuncGraph, importProto);
1988 }
1989
BuildAttrForFuncGraph(const FuncGraphPtr & outputFuncGraph,const mind_ir::GraphProto & importProto)1990 bool MSANFModelParser::BuildAttrForFuncGraph(const FuncGraphPtr &outputFuncGraph,
1991 const mind_ir::GraphProto &importProto) {
1992 for (auto i = 0; i < importProto.attribute_size(); ++i) {
1993 const mind_ir::AttributeProto &attr_proto = importProto.attribute(i);
1994 auto value = GetValueFromAttributeProto(attr_proto);
1995 if (value == nullptr) {
1996 MS_LOG(ERROR) << "Failed set func_graph attr to func_graph";
1997 return false;
1998 }
1999 outputFuncGraph->set_attr(attr_proto.name(), value);
2000 }
2001 return true;
2002 }
2003
BuildFuncGraph(const FuncGraphPtr & output_graph,const mind_ir::GraphProto & import_proto)2004 bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &output_graph, const mind_ir::GraphProto &import_proto) {
2005 MS_EXCEPTION_IF_NULL(output_graph);
2006 GraphDebugInfoPtr debug_info_ptr = output_graph->debug_info();
2007 MS_EXCEPTION_IF_NULL(debug_info_ptr);
2008 if (import_proto.has_name()) {
2009 debug_info_ptr->set_name(import_proto.name());
2010 } else {
2011 MS_LOG(ERROR) << "FuncGraph under converting has not name!";
2012 return false;
2013 }
2014 if (import_proto.has_bprop_hash()) {
2015 output_graph->set_bprop_hash(import_proto.bprop_hash());
2016 }
2017
2018 if (import_proto.has_bprop_filepath()) {
2019 output_graph->set_bprop_filepath(import_proto.bprop_filepath());
2020 }
2021 if (!BuildAttrForFuncGraph(output_graph, import_proto)) {
2022 MS_LOG(ERROR) << "Build attribute for graph fail!";
2023 }
2024 if (!ImportParametersForGraph(output_graph, import_proto)) {
2025 MS_LOG(ERROR) << "Import parameters for graph fail!";
2026 return false;
2027 }
2028 if (!ImportMapParametersForGraph(output_graph, import_proto)) {
2029 MS_LOG(ERROR) << "Import map parameters for graph failed!";
2030 return false;
2031 }
2032 if (!ImportNodesForGraph(output_graph, import_proto)) {
2033 MS_LOG(ERROR) << "Import nodes for graph failed! " << import_proto.has_name();
2034 return false;
2035 }
2036 auto context = MsContext::GetInstance();
2037 MS_EXCEPTION_IF_NULL(context);
2038 const bool force_no_inline = common::IsDisableRuntimeConfig(common::kRuntimeInline);
2039 if (output_graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE)) {
2040 const bool enable_ge = context->backend_policy() == "ge";
2041 auto cell_reuse_level =
2042 (enable_ge && !context->IsKByKExecutorMode()) ? CellReuseLevel::kNoInline : CellReuseLevel::kLazyInline;
2043 if (force_no_inline) {
2044 cell_reuse_level = CellReuseLevel::kNoInline;
2045 }
2046 context->SetCellReuseLevel(cell_reuse_level);
2047 }
2048 return true;
2049 }
2050
SetValueForTopGraphParameter(const FuncGraphPtr & topGraph,const std::map<std::string,ValuePtr> & weights)2051 bool MSANFModelParser::SetValueForTopGraphParameter(const FuncGraphPtr &topGraph,
2052 const std::map<std::string, ValuePtr> &weights) {
2053 size_t fv_param_count = 0;
2054 auto parameters = topGraph->parameters();
2055 for (int64_t i = SizeToLong(parameters.size()) - 1; i >= 0; --i) {
2056 auto parameter = parameters[i]->cast<ParameterPtr>();
2057 if (parameter == nullptr) {
2058 MS_LOG(ERROR) << "AnfNode " << parameters[i]->DebugString() << " should be Parameter.";
2059 return false;
2060 }
2061 auto type = parameter->Type();
2062 if (type == nullptr) {
2063 MS_LOG(ERROR) << "Parameter " << parameter->DebugString() << " has no type.";
2064 return false;
2065 }
2066 if (!type->isa<RefType>()) {
2067 break;
2068 }
2069 auto parameter_name = parameter->name();
2070 auto weights_iter = weights.find(parameter_name);
2071 if (weights_iter == weights.end()) {
2072 MS_LOG(INFO) << "Find initial weight value for " << parameter_name << " failed.";
2073 continue;
2074 }
2075 parameter->set_default_param(weights_iter->second);
2076 fv_param_count++;
2077 }
2078 topGraph->set_fv_param_count(fv_param_count);
2079 return true;
2080 }
2081
TrytoBuildCNodeAbstract()2082 void MSANFModelParser::TrytoBuildCNodeAbstract() {
2083 std::map<CNodePtr, int> visited_times;
2084 constexpr int kMaxCount = 3;
2085 while (!node_abstract_protos_.empty()) {
2086 auto &item = node_abstract_protos_.front();
2087 auto &count = visited_times[item.first];
2088 if (count++ > kMaxCount) {
2089 abstract_valid_ = false;
2090 MS_LOG(ERROR) << "Parse CNode: " << item.first->ToString() << " abstract error: " << item.second->DebugString();
2091 } else {
2092 SetCNodeAbstract(*(item.second), item.first);
2093 }
2094 node_abstract_protos_.pop_front();
2095 }
2096 }
2097
CheckMindIRVseriosn(const mind_ir::ModelProto & model_proto)2098 bool CheckMindIRVseriosn(const mind_ir::ModelProto &model_proto) {
2099 if (model_proto.has_mind_ir_version()) {
2100 auto mind_ir_version = model_proto.mind_ir_version();
2101 if (mind_ir_version >= 2) {
2102 return true;
2103 }
2104 }
2105 return false;
2106 }
2107
// Parses a MindIR ModelProto into FuncGraphs and returns the top-level graph.
//
// Order of work:
//   1. Seed anfnode_build_map_ from *name_to_node when provided, and build model_proto.primitives().
//   2. Forward-declare the top graph (model_proto.graph()) and every graph in model_proto.functions() as
//      ValueNodes keyed by graph name, so nodes can reference graphs that are built later.
//   3. Build the top graph; if `weights` is non-empty, attach them via SetValueForTopGraphParameter.
//   4. Build each function graph, then retry deferred CNode abstracts (TrytoBuildCNodeAbstract).
//   5. Export the name map to *name_to_node, clear anfnode_build_map_, and run CorrectFuncGraph on the top graph
//      when abstract_valid_ is false and no weights were given.
// Every built graph gets the "generated_from_mindir_with_prim_func" flag from CheckMindIRVseriosn(model_proto).
//
// @param model_proto   Deserialized MindIR model.
// @param weights       Parameter name -> default value for the top graph's parameters; may be empty.
// @param name_to_node  Optional (may be nullptr) in/out map of node name -> AnfNode. It is read before parsing and
//                      overwritten with the final map on success.
// @return The top-level FuncGraph, or nullptr on failure (the cause is logged).
// NOTE(review): the early `return nullptr` paths skip anfnode_build_map_.clear() and leave node_abstract_protos_
// as-is, so state from a failed parse persists in this parser object.
FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto,
                                     const std::map<std::string, ValuePtr> &weights,
                                     mindspore::HashMap<std::string, AnfNodePtr> *name_to_node) {
  // In lite builds the abstracts are treated as valid up front, which also skips CorrectFuncGraph at the end.
  if (IsLite()) {
    abstract_valid_ = true;
  }
  if (name_to_node) {
    anfnode_build_map_ = *name_to_node;
  }
  for (int i = 0; i < model_proto.primitives_size(); ++i) {
    if (!BuildPrimitiveNode(model_proto.primitives(i))) {
      MS_LOG(ERROR) << "Parse primitives info for pb file failed! " << model_proto.primitives(i).DebugString();
      return nullptr;
    }
  }
  FuncGraphPtr dstGraph = std::make_shared<FuncGraph>();
  const mind_ir::GraphProto &graphBuild = model_proto.graph();

  // Forward declare FuncGraph name
  // Compatible with the previous proto.
  if (graphBuild.has_name()) {
    anfnode_build_map_[graphBuild.name()] = NewValueNodeWithAbstract(dstGraph);
  }
  // Create an empty FuncGraph per function up front; it is filled in by BuildFuncGraph in the loop further below.
  for (int i = 0; i < model_proto.functions_size(); ++i) {
    FuncGraphPtr graph = std::make_shared<FuncGraph>();
    const auto &graph_proto = model_proto.functions(i);
    if (!graph_proto.has_name()) {
      MS_LOG(ERROR) << "The function has not a name. Please export mindIR again. ";
      return nullptr;
    }
    if (anfnode_build_map_.count(graph_proto.name()) > 0) {
      MS_LOG(ERROR) << "There is a duplication function graph name: " << graph_proto.name();
      return nullptr;
    }
    anfnode_build_map_[graph_proto.name()] = NewValueNodeWithAbstract(graph);
  }

  // Parser the proto.
  if (!BuildFuncGraph(dstGraph, graphBuild)) {
    MS_LOG(ERROR) << "Build funcgraph failed!";
    return nullptr;
  }

  if (!weights.empty()) {
    if (!SetValueForTopGraphParameter(dstGraph, weights)) {
      MS_LOG(ERROR) << "Set value for top graph fail.";
      return nullptr;
    }
  }
  bool generated_from_mindir_with_prim_func = CheckMindIRVseriosn(model_proto);
  dstGraph->set_flag("generated_from_mindir_with_prim_func", generated_from_mindir_with_prim_func);
  MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! graph: " << graphBuild.name() << ": " << dstGraph.get();
  top_graph_ = dstGraph;
  // Fill the function graphs forward-declared above, looking each one up by name.
  for (int i = 0; i < model_proto.functions_size(); ++i) {
    const auto &graph_proto = model_proto.functions(i);
    FuncGraphPtr graph = GetValueNode<FuncGraphPtr>(anfnode_build_map_[graph_proto.name()]);
    if (!BuildFuncGraph(graph, graph_proto)) {
      MS_LOG(ERROR) << "Build funcgraph failed!";
      return nullptr;
    }
    graph->set_flag("generated_from_mindir_with_prim_func", generated_from_mindir_with_prim_func);
    MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! graph: " << graph_proto.name() << ": " << graph.get();
  }
  // Every graph is built now, so abstracts that depended on not-yet-built nodes can be retried.
  TrytoBuildCNodeAbstract();
  if (name_to_node) {
    *name_to_node = anfnode_build_map_;
  }
  // Release resource
  anfnode_build_map_.clear();
  // Correct the null abstract for compatibility with previous versions.
  if (!abstract_valid_ && weights.empty()) {
    CorrectFuncGraph(dstGraph);
  }
  return dstGraph;
}
2183
Parse(const mind_ir::ModelProto & model_proto,const std::vector<FuncGraphPtr> & graphs,mindspore::HashMap<std::string,AnfNodePtr> * name_to_node)2184 bool MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto, const std::vector<FuncGraphPtr> &graphs,
2185 mindspore::HashMap<std::string, AnfNodePtr> *name_to_node) {
2186 is_kernel_graph_ = graphs.front()->type_name() == kKernelGraphTypeName;
2187 if (name_to_node) {
2188 anfnode_build_map_ = *name_to_node;
2189 }
2190 auto build_params_attrs = [this](const FuncGraphPtr &graph, const mind_ir::GraphProto &proto) {
2191 MS_EXCEPTION_IF_NULL(graph);
2192 if (!proto.has_name()) {
2193 MS_LOG(ERROR) << "KernelGraph under converting has not name!";
2194 return false;
2195 }
2196 GraphDebugInfoPtr debug_info_ptr = graph->debug_info();
2197 MS_EXCEPTION_IF_NULL(debug_info_ptr);
2198 debug_info_ptr->set_name(proto.name());
2199 if (!BuildAttrForFuncGraph(graph, proto)) {
2200 MS_LOG(ERROR) << "Build attribute for graph fail!";
2201 }
2202 if (!ImportParametersForGraph(graph, proto)) {
2203 MS_LOG(ERROR) << "Import parameters for graph fail!";
2204 return false;
2205 }
2206 if (!ImportMapParametersForGraph(graph, proto)) {
2207 MS_LOG(ERROR) << "Import map parameters for graph failed!";
2208 return false;
2209 }
2210 return true;
2211 };
2212 for (int i = 0; i < model_proto.primitives_size(); ++i) {
2213 if (!BuildPrimitiveNode(model_proto.primitives(i))) {
2214 MS_LOG(ERROR) << "Parse primitives info for pb file failed! " << model_proto.primitives(i).DebugString();
2215 return false;
2216 }
2217 }
2218 const mind_ir::GraphProto &graph_build = model_proto.graph();
2219 const auto &root = FindGraphByName(graphs, graph_build.name());
2220 MS_EXCEPTION_IF_NULL(root);
2221 anfnode_build_map_[graph_build.name()] = NewValueNodeWithAbstract(root);
2222 top_graph_ = root;
2223 if (!build_params_attrs(root, graph_build)) {
2224 MS_LOG(ERROR) << "Build funcgraph params and attrs failed.";
2225 return false;
2226 }
2227 for (int i = 0; i < model_proto.functions_size(); ++i) {
2228 const auto &graph_proto = model_proto.functions(i);
2229 if (!graph_proto.has_name()) {
2230 MS_LOG(ERROR) << "The function has not a name. Please export mindIR again. ";
2231 return false;
2232 }
2233 const auto &graph_name = graph_proto.name();
2234 if (anfnode_build_map_.count(graph_name) > 0) {
2235 MS_LOG(ERROR) << "There is a duplication function graph name: " << graph_proto.name();
2236 return false;
2237 }
2238 const auto &graph = FindGraphByName(graphs, graph_name);
2239 MS_EXCEPTION_IF_NULL(graph);
2240 auto debug_info = graph->debug_info();
2241 debug_info->set_name(graph_name);
2242 anfnode_build_map_[graph_name] = NewValueNodeWithAbstract(graph);
2243 if (!build_params_attrs(graph, graph_proto)) {
2244 MS_LOG(ERROR) << "Build funcgraph params and attrs failed.";
2245 return false;
2246 }
2247 }
2248
2249 // Parser the proto.
2250 if (!ImportNodesForGraph(root, graph_build)) {
2251 MS_LOG(ERROR) << "Build funcgraph " << graph_build.name() << " value node and cnode failed.";
2252 return false;
2253 } else {
2254 MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! graph: " << graph_build.name();
2255 }
2256 std::map<std::string, mind_ir::GraphProto> sorted_proto;
2257 std::for_each(model_proto.functions().begin(), model_proto.functions().end(),
2258 [&sorted_proto](const auto &proto) { sorted_proto[proto.name()] = proto; });
2259 for (const auto &[name, proto] : sorted_proto) {
2260 FuncGraphPtr graph = GetValueNode<FuncGraphPtr>(anfnode_build_map_[name]);
2261 if (!ImportNodesForGraph(graph, proto)) {
2262 MS_LOG(ERROR) << "Build funcgraph: " << name << "'s value_node and cnode failed.";
2263 return false;
2264 } else {
2265 MS_LOG(INFO) << "Build FuncGraph Success! graph: " << name;
2266 }
2267 }
2268 TrytoBuildCNodeAbstract();
2269 if (name_to_node) {
2270 *name_to_node = anfnode_build_map_;
2271 }
2272 // Release resource
2273 anfnode_build_map_.clear();
2274 return true;
2275 }
2276
ParseLayout(const mind_ir::ModelProto & model_proto)2277 const LayoutMap MSANFModelParser::ParseLayout(const mind_ir::ModelProto &model_proto) {
2278 LayoutMap ret;
2279 mind_ir::ParallelProto parallel_proto = model_proto.parallel();
2280 for (int i = 0; i < parallel_proto.layout_size(); ++i) {
2281 const mind_ir::LayoutProto &layout_proto = parallel_proto.layout(i);
2282 LayoutPtr cur_layout = std::make_shared<Layout>();
2283 const std::string name = layout_proto.name();
2284 std::vector<int64_t> device_arrangement;
2285 for (int num = 0; num < layout_proto.device_arrangement_int_size(); ++num) {
2286 (void)device_arrangement.emplace_back(layout_proto.device_arrangement_int(num));
2287 }
2288 std::vector<int64_t> tensor_map;
2289 for (int num = 0; num < layout_proto.tensor_map_int_size(); ++num) {
2290 (void)tensor_map.emplace_back(layout_proto.tensor_map_int(num));
2291 }
2292 std::vector<int64_t> slice_shape;
2293 for (int num = 0; num < layout_proto.slice_shape_int_size(); ++num) {
2294 (void)slice_shape.emplace_back(layout_proto.slice_shape_int(num));
2295 }
2296 int64_t field_size = layout_proto.field_size();
2297 bool uniform_spilt = layout_proto.uniform_split();
2298 const std::string opt_shard_group = layout_proto.opt_shard_group();
2299
2300 cur_layout->set_device_arrangement(device_arrangement);
2301 cur_layout->set_tensor_map(tensor_map);
2302 cur_layout->set_slice_shape(slice_shape);
2303 cur_layout->set_field_size(field_size);
2304 cur_layout->set_uniform_split(uniform_spilt);
2305 cur_layout->set_opt_shard_group(opt_shard_group);
2306
2307 // Check optional field for backward compatibility.
2308 if (layout_proto.has_pipeline_shared()) {
2309 bool pipeline_shared = layout_proto.pipeline_shared();
2310 bool is_send = layout_proto.is_send();
2311 int64_t peer_rank = layout_proto.peer_rank();
2312 int64_t sr_tag = layout_proto.sr_tag();
2313
2314 cur_layout->set_pipeline_shared(pipeline_shared);
2315 cur_layout->set_is_send(is_send);
2316 cur_layout->set_peer_rank(peer_rank);
2317 cur_layout->set_sr_tag(sr_tag);
2318 }
2319 ret[name] = cur_layout;
2320 }
2321 return ret;
2322 }
2323
GetAnfNode(const std::string & node_name)2324 AnfNodePtr MSANFModelParser::GetAnfNode(const std::string &node_name) {
2325 if (node_name.find("MetaFuncGraph::") == 0) {
2326 auto fg_name = node_name.substr(std::string("MetaFuncGraph::").length());
2327 auto mindir_meta_fg = std::make_shared<MindIRMetaFuncGraph>(fg_name);
2328 return NewValueNode(mindir_meta_fg);
2329 }
2330 if (node_name.find("ClassType::") == 0) {
2331 auto class_type = node_name.substr(std::string("ClassType::").length());
2332 auto mindir_class_type = std::make_shared<MindIRClassType>(class_type);
2333 return NewValueNode(mindir_class_type);
2334 }
2335 auto it = anfnode_build_map_.find(node_name);
2336 if (it == anfnode_build_map_.end()) {
2337 return nullptr;
2338 }
2339 // The FunctionGraph node can't been shared.
2340 FuncGraphPtr func_graph_ptr = GetValueNode<FuncGraphPtr>(it->second);
2341 if (func_graph_ptr != nullptr) {
2342 auto node = NewValueNode(func_graph_ptr);
2343 node->set_abstract(func_graph_ptr->ToAbstract());
2344 return node;
2345 } else {
2346 return it->second;
2347 }
2348 }
2349
BuildPrimitiveNode(const mind_ir::PrimitiveProto & primitive_proto)2350 bool MSANFModelParser::BuildPrimitiveNode(const mind_ir::PrimitiveProto &primitive_proto) {
2351 static auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
2352 auto &prim_type = primitive_proto.op_type();
2353 const auto &type = primitive_proto.prim_type();
2354 std::shared_ptr<Primitive> prim;
2355
2356 auto it = op_primc_fns.find(prim_type);
2357 if (it != op_primc_fns.end()) {
2358 prim = it->second();
2359 } else {
2360 if (prim_type.compare(0, strlen(kDoSignaturePrimitivePrefix), kDoSignaturePrimitivePrefix) == 0) {
2361 auto op_name = prim_type.substr(strlen(kDoSignaturePrimitivePrefix));
2362 prim = std::make_shared<prim::DoSignaturePrimitive>(op_name, std::make_shared<Primitive>(op_name));
2363 } else {
2364 prim = std::make_shared<Primitive>(prim_type);
2365 }
2366 }
2367 if (type == mind_ir::PrimitiveProto_PrimType_PRIMITIVE_FUNCTION) {
2368 MS_LOG(DEBUG) << "PrimitiveFunction special node_type: " << prim_type;
2369 prim->AddAttr("primitive_function", MakeValue(true));
2370 }
2371
2372 if (primitive_proto.has_instance_name()) {
2373 prim->set_instance_name(primitive_proto.instance_name());
2374 }
2375
2376 // Set primitive attributes
2377 auto prim_to_add_attr = GetValueWithoutDoSignature(prim)->cast<PrimitivePtr>();
2378 MS_EXCEPTION_IF_NULL(prim_to_add_attr);
2379 prim_to_add_attr->set_attr("is_load", MakeValue(true));
2380 for (int i = 0; i < primitive_proto.attribute_size(); ++i) {
2381 const mind_ir::AttributeProto &attr_proto = primitive_proto.attribute(i);
2382 if (!SetPrimitiveAttrWithType(prim_to_add_attr, attr_proto)) {
2383 MS_LOG(ERROR) << "Parse prim: " << prim->ToString() << " attributes error: " << attr_proto.DebugString();
2384 return false;
2385 }
2386 }
2387 if (anfnode_build_map_.count(primitive_proto.name()) > 0) {
2388 MS_LOG(ERROR) << "There is a duplication primitive instance name: " << primitive_proto.name();
2389 return false;
2390 }
2391 anfnode_build_map_[primitive_proto.name()] = NewValueNodeWithAbstract(prim);
2392 return true;
2393 }
2394
BuildAbstractFunction(const mind_ir::AttributeProto & attr_proto)2395 abstract::AbstractBasePtr MSANFModelParser::BuildAbstractFunction(const mind_ir::AttributeProto &attr_proto) {
2396 switch (attr_proto.type()) {
2397 case mind_ir::AttributeProto_AttributeType_PRIMITIVECLOSURE:
2398 case mind_ir::AttributeProto_AttributeType_FUNCGRAPHCLOSURE: {
2399 auto func_node = GetAnfNode(attr_proto.s());
2400 if (func_node == nullptr) {
2401 FuncGraphPtr dummy_graph = std::make_shared<FuncGraph>();
2402 MS_LOG(DEBUG) << "Failed to get function graph closure: " << attr_proto.DebugString();
2403 return dummy_graph->ToAbstract();
2404 }
2405 return func_node->abstract();
2406 }
2407 case mind_ir::AttributeProto_AttributeType_PARTIALCLOSURE: {
2408 auto anf_node = GetAnfNode(attr_proto.s());
2409 if (anf_node == nullptr) {
2410 return nullptr;
2411 }
2412 auto partial_node = anf_node->cast<CNodePtr>();
2413 MS_EXCEPTION_IF_NULL(partial_node);
2414 if (!IsPrimitiveCNode(partial_node, prim::kPrimPartial)) {
2415 MS_LOG(INTERNAL_EXCEPTION) << "Not partial CNode, but got " << partial_node->DebugString();
2416 }
2417 AbstractBasePtrList args_spec_list;
2418 auto &inputs = partial_node->inputs();
2419 const size_t kPartial_args_begin_pos = 2;
2420 const size_t kPartial_fn_pos = 1;
2421 if (inputs.size() <= kPartial_args_begin_pos) {
2422 MS_LOG(ERROR) << "Partial node input size is wrong.";
2423 return nullptr;
2424 }
2425 (void)std::transform(inputs.begin() + kPartial_args_begin_pos, inputs.end(), std::back_inserter(args_spec_list),
2426 [](const AnfNodePtr &arg) -> AbstractBasePtr { return arg->abstract(); });
2427 auto &op_node = inputs[kPartial_fn_pos];
2428 MS_EXCEPTION_IF_NULL(op_node);
2429 abstract::AbstractFuncAtomPtr fn;
2430 if (op_node->abstract() != nullptr) {
2431 fn = op_node->abstract()->cast<abstract::AbstractFuncAtomPtr>();
2432 if (fn == nullptr) {
2433 MS_LOG(DEBUG) << "Can't get the abstract of partial node: " << op_node->ToString();
2434 FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(op_node);
2435 if (fg == nullptr) {
2436 MS_LOG(INTERNAL_EXCEPTION) << "partial_node: " << partial_node->DebugString()
2437 << ", op_node: " << op_node->DebugString() << ", "
2438 << op_node->abstract()->ToString();
2439 }
2440 fn = fg->ToAbstract()->cast<abstract::AbstractFuncAtomPtr>();
2441 }
2442 } else {
2443 MS_LOG(DEBUG) << "Can't get the abstract of partial node: " << op_node->ToString();
2444 FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(op_node);
2445 MS_EXCEPTION_IF_NULL(fg);
2446 fn = fg->ToAbstract()->cast<abstract::AbstractFuncAtomPtr>();
2447 }
2448 return std::make_shared<abstract::PartialAbstractClosure>(fn, args_spec_list, partial_node);
2449 }
2450 case mind_ir::AttributeProto_AttributeType_UNIONFUNCCLOSURE: {
2451 abstract::AbstractFuncAtomPtrList func_list;
2452 for (int index = 0; index < attr_proto.values_size(); index++) {
2453 auto &item_proto = attr_proto.values(index);
2454 auto item_abstract = BuildAbstractFunction(item_proto);
2455 if (item_abstract == nullptr) {
2456 MS_LOG(WARNING) << "Can't get the abstract of function union closure: " << item_proto.DebugString();
2457 return nullptr;
2458 }
2459 (void)func_list.emplace_back(item_abstract->cast<abstract::AbstractFuncAtomPtr>());
2460 }
2461 return std::make_shared<abstract::AbstractFuncUnion>(func_list);
2462 }
2463 default: {
2464 MS_LOG(ERROR) << "Not support function abstract: " << attr_proto.DebugString();
2465 return nullptr;
2466 }
2467 }
2468 }
2469
CorrectFuncGraph(const FuncGraphPtr & root)2470 void MSANFModelParser::CorrectFuncGraph(const FuncGraphPtr &root) {
2471 MS_LOG(DEBUG) << "Begin to correct the funcgraph.";
2472 MS_EXCEPTION_IF_NULL(root);
2473 auto inputs = root->get_inputs();
2474 auto valid =
2475 std::all_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &arg) -> bool { return arg->abstract() != nullptr; });
2476 if (valid) {
2477 (void)ValidMindir(root);
2478 } else {
2479 MS_LOG(INFO) << "There are some nullptr of abstract in the top function graph parameters." << root->DumpText();
2480 }
2481 MS_LOG(DEBUG) << "End to correct the funcgraph.";
2482 }
2483
BuildAttrForCNode(const CNodePtr & cnode,const mind_ir::NodeProto & node_proto)2484 bool MSANFModelParser::BuildAttrForCNode(const CNodePtr &cnode, const mind_ir::NodeProto &node_proto) {
2485 for (auto i = 0; i < node_proto.node_attr_size(); ++i) {
2486 const auto &attr_proto = node_proto.node_attr(i);
2487 auto value = GetValueFromAttributeProto(attr_proto);
2488 if (value == nullptr) {
2489 MS_LOG(ERROR) << "Failed set func_graph attr to func_graph";
2490 return false;
2491 }
2492 cnode->AddAttr(attr_proto.name(), value);
2493 }
2494 for (auto i = 0; i < node_proto.primal_attr_size(); ++i) {
2495 const auto &attr_proto = node_proto.primal_attr(i);
2496 auto value = GetValueFromAttributeProto(attr_proto);
2497 if (value == nullptr) {
2498 MS_LOG(ERROR) << "Failed set func_graph attr to func_graph";
2499 return false;
2500 }
2501 cnode->AddPrimalAttr(attr_proto.name(), value);
2502 }
2503 return true;
2504 }
2505
get_all_files(const std::string & dir_in,std::vector<std::string> * files)2506 bool get_all_files(const std::string &dir_in, std::vector<std::string> *files) {
2507 if (dir_in.empty()) {
2508 return false;
2509 }
2510 struct stat s;
2511 int ret = stat(dir_in.c_str(), &s);
2512 if (ret != 0) {
2513 MS_LOG(ERROR) << "stat error, ret is : " << ret;
2514 return false;
2515 }
2516 if (!S_ISDIR(s.st_mode)) {
2517 return false;
2518 }
2519 DIR *open_dir = opendir(dir_in.c_str());
2520 if (open_dir == nullptr) {
2521 MS_LOG(EXCEPTION) << "open dir " << dir_in.c_str() << " failed";
2522 }
2523 dirent *p = nullptr;
2524 while ((p = readdir(open_dir)) != nullptr) {
2525 struct stat st;
2526 if (p->d_name[0] != '.') {
2527 std::string name = dir_in + std::string("/") + std::string(p->d_name);
2528 ret = stat(name.c_str(), &st);
2529 if (ret != 0) {
2530 MS_LOG(ERROR) << "stat error, ret is : " << ret;
2531 closedir(open_dir);
2532 return false;
2533 }
2534 if (S_ISDIR(st.st_mode)) {
2535 if (!get_all_files(name, files)) {
2536 MS_LOG(ERROR) << "Get files failed, ret is : " << ret;
2537 closedir(open_dir);
2538 return false;
2539 }
2540 } else if (S_ISREG(st.st_mode)) {
2541 files->push_back(name);
2542 }
2543 }
2544 }
2545 closedir(open_dir);
2546 return true;
2547 }
2548
endsWith(const string s,const string sub)2549 int endsWith(const string s, const string sub) { return s.rfind(sub) == (s.length() - sub.length()) ? 1 : 0; }
2550
ParseModelProto(mind_ir::ModelProto * model,const std::string & path,const MindIRLoader * loader)2551 bool ParseModelProto(mind_ir::ModelProto *model, const std::string &path, const MindIRLoader *loader) {
2552 if (loader->dec_key() != nullptr) {
2553 size_t plain_len;
2554 auto plain_data = Decrypt(&plain_len, path, loader->dec_key(), loader->key_len(), loader->dec_mode());
2555 if (plain_data == nullptr) {
2556 MS_LOG(ERROR)
2557 << "Decrypt MindIR file failed, please check the correctness of the dec_key or dec_mode or the file integrity.";
2558 return false;
2559 }
2560 if (!model->ParseFromArray(reinterpret_cast<char *>(plain_data.get()), static_cast<int32_t>(plain_len))) {
2561 MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file, dec_key or dec_mode.";
2562 return false;
2563 }
2564 } else {
2565 std::fstream input_graph(path, std::ios::in | std::ios::binary);
2566 if (!input_graph || !model->ParseFromIstream(&input_graph)) {
2567 MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file.";
2568 return false;
2569 }
2570 }
2571 return true;
2572 }
2573
ParseGraphProto(mind_ir::GraphProto * graph,const std::string & path,const MindIRLoader * loader)2574 bool ParseGraphProto(mind_ir::GraphProto *graph, const std::string &path, const MindIRLoader *loader) {
2575 if (loader->dec_key() != nullptr) {
2576 size_t plain_len;
2577 auto plain_data = Decrypt(&plain_len, path, loader->dec_key(), loader->key_len(), loader->dec_mode());
2578 if (plain_data == nullptr) {
2579 MS_LOG(ERROR)
2580 << "Decrypt MindIR file failed, please check the correctness of the dec_key or dec_mode or the file integrity.";
2581 return false;
2582 }
2583 if (!graph->ParseFromArray(reinterpret_cast<char *>(plain_data.get()), static_cast<int32_t>(plain_len))) {
2584 MS_LOG(ERROR) << "Load variable file failed, please check the correctness of the mindir's variable file, "
2585 "dec_key or dec_mode";
2586 return false;
2587 }
2588 } else {
2589 std::fstream input_param(path, std::ios::in | std::ios::binary);
2590 if (!input_param || !graph->ParseFromIstream(&input_param)) {
2591 MS_LOG(ERROR) << "Load variable file failed, please check the correctness of mindir's variable file.";
2592 return false;
2593 }
2594 }
2595 return true;
2596 }
2597
InitModelParser(MSANFModelParser * model_parser,const MindIRLoader * loader)2598 void InitModelParser(MSANFModelParser *model_parser, const MindIRLoader *loader) {
2599 model_parser->SetMindIRDecKey(loader->dec_key());
2600 model_parser->SetMindIRKeySize(loader->key_len());
2601 model_parser->SetMindIRDecMode(loader->dec_mode());
2602
2603 if (loader->is_lite()) {
2604 model_parser->SetLite();
2605 }
2606 }
2607 } // namespace
2608
LoadPreprocess(const std::string & file_name)2609 std::vector<std::string> MindIRLoader::LoadPreprocess(const std::string &file_name) {
2610 if (file_name.length() > PATH_MAX) {
2611 MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
2612 return {};
2613 }
2614 char abs_path_buff[PATH_MAX];
2615
2616 #ifdef _WIN32
2617 _fullpath(abs_path_buff, file_name.c_str(), PATH_MAX);
2618 #else
2619 if (!realpath(file_name.c_str(), abs_path_buff)) {
2620 MS_LOG(ERROR) << "Load MindIR get absolute path failed";
2621 }
2622 #endif
2623
2624 // Read graph
2625 mind_ir::ModelProto origin_model;
2626 if (!ParseModelProto(&origin_model, std::string(abs_path_buff), this)) {
2627 MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file.";
2628 return {};
2629 }
2630
2631 // Read dataset preprocessor
2632 auto preprocessor = origin_model.preprocessor();
2633
2634 // Separate columns and parse
2635 std::vector<std::string> input_columns;
2636 for (auto i = 0; i < preprocessor.op_size(); i++) {
2637 std::string column = preprocessor.op()[i].input_columns();
2638 if (std::find(input_columns.begin(), input_columns.end(), column) == input_columns.end()) {
2639 input_columns.push_back(column);
2640 }
2641 }
2642
2643 // Each column has one string to indicate its preprocess behaviour
2644 std::vector<std::string> map_jsons;
2645 for (std::string &column : input_columns) {
2646 nlohmann::json dataset_json;
2647 nlohmann::json child_dataset_json;
2648 for (auto i = preprocessor.op_size() - 1; i >= 0; i--) {
2649 if (preprocessor.op()[i].input_columns() == column) {
2650 child_dataset_json["input_columns"] = nlohmann::json::parse(preprocessor.op()[i].input_columns());
2651 child_dataset_json["op_type"] = nlohmann::json::parse(preprocessor.op()[i].op_type());
2652 child_dataset_json["operations"] = nlohmann::json::parse(preprocessor.op()[i].operations());
2653 child_dataset_json["output_columns"] = nlohmann::json::parse(preprocessor.op()[i].output_columns());
2654 child_dataset_json["offload"] = preprocessor.op()[i].offload();
2655
2656 dataset_json["children"] = child_dataset_json;
2657 child_dataset_json = dataset_json;
2658 }
2659 }
2660 map_jsons.push_back(dataset_json["children"].dump());
2661 }
2662 return map_jsons;
2663 }
2664
LoadMindIRs(const std::vector<std::string> & file_names)2665 std::vector<FuncGraphPtr> MindIRLoader::LoadMindIRs(const std::vector<std::string> &file_names) {
2666 std::vector<FuncGraphPtr> funcgraph_vec;
2667 MS_LOG(DEBUG) << "Load multiple MindIR files.";
2668 for (const auto &file_name : file_names) {
2669 MS_LOG(DEBUG) << "Load " << file_name;
2670 funcgraph_vec.push_back(LoadMindIR(file_name));
2671 }
2672 return funcgraph_vec;
2673 }
2674
LoadMindIR(const void * buffer,const size_t & size)2675 FuncGraphPtr MindIRLoader::LoadMindIR(const void *buffer, const size_t &size) {
2676 /* mindir -> func_graph
2677 * only support lite */
2678 mind_ir::ModelProto model;
2679 auto ret = model.ParseFromArray(buffer, SizeToInt(size));
2680 if (!ret) {
2681 MS_LOG(ERROR) << "ParseFromArray failed.";
2682 return nullptr;
2683 }
2684 if (!CheckModelConfigureInfo(model)) {
2685 MS_LOG(ERROR) << "Check configuration info for pb file failed!";
2686 return nullptr;
2687 }
2688 MSANFModelParser model_parser;
2689 InitModelParser(&model_parser, this);
2690 FuncGraphPtr func_graph = model_parser.Parse(model);
2691
2692 return func_graph;
2693 }
2694
// NOTE(review): namespace-scope map that shares its name with the parser's (presumably member) anfnode_build_map_.
// Inside MSANFModelParser member functions the member is found first, and no code in this part of the file appears
// to reference this global. Confirm it is unused elsewhere and consider removing it.
mindspore::HashMap<std::string, AnfNodePtr> anfnode_build_map_;
LoadMindIR(const std::string & file_name,mindspore::HashMap<std::string,AnfNodePtr> * name_to_node)2696 FuncGraphPtr MindIRLoader::LoadMindIR(const std::string &file_name,
2697 mindspore::HashMap<std::string, AnfNodePtr> *name_to_node) {
2698 if (file_name.length() > PATH_MAX) {
2699 MS_LOG(EXCEPTION) << "The length of the file name exceeds the limit.";
2700 }
2701 char abs_path_buff[PATH_MAX];
2702 vector<string> files;
2703
2704 #ifdef _WIN32
2705 _fullpath(abs_path_buff, file_name.c_str(), PATH_MAX);
2706 #else
2707 if (!realpath(file_name.c_str(), abs_path_buff)) {
2708 MS_LOG(EXCEPTION) << "Load MindIR get absolute path of " << file_name
2709 << " failed, errno is: " << ErrnoToString(errno);
2710 }
2711 #endif
2712 // Read graph
2713 mind_ir::ModelProto origin_model;
2714 if (!ParseModelProto(&origin_model, std::string(abs_path_buff), this)) {
2715 return nullptr;
2716 }
2717
2718 if (!CheckModelConfigureInfo(origin_model)) {
2719 MS_LOG(ERROR) << "Check configuration info for pb file failed!";
2720 return nullptr;
2721 }
2722 // Load parameter into graph
2723 if (endsWith(std::string(abs_path_buff), "_graph.mindir") && (origin_model.graph().parameter_size() == 0)) {
2724 if (strlen(abs_path_buff) < strlen("graph.mindir")) {
2725 MS_LOG(ERROR) << "The abs_path_buff length is less than 'graph.mindir'.";
2726 return nullptr;
2727 }
2728 size_t path_len = strlen(abs_path_buff) - strlen("graph.mindir");
2729 string var_path = std::string(abs_path_buff).substr(0, path_len);
2730 var_path += "variables";
2731 std::ifstream ifs(var_path);
2732 if (ifs.good()) {
2733 MS_LOG(DEBUG) << "MindIR file has variables path, load parameter into graph.";
2734 (void)get_all_files(var_path, &files);
2735 } else {
2736 MS_LOG(ERROR) << "Load graph's variable folder failed, please check the correctness of variable folder.";
2737 return nullptr;
2738 }
2739
2740 size_t file_size = files.size();
2741 mind_ir::GraphProto *mod_graph = origin_model.mutable_graph();
2742 for (size_t file_index = 0; file_index < file_size; file_index++) {
2743 mind_ir::GraphProto param_graph;
2744 if (!ParseGraphProto(¶m_graph, files[file_index], this)) {
2745 return nullptr;
2746 }
2747
2748 for (int param_index = 0; param_index < param_graph.parameter_size(); param_index++) {
2749 mind_ir::TensorProto *param_proto = mod_graph->add_parameter();
2750 param_proto->set_name(param_graph.parameter(param_index).name());
2751 param_proto->set_data_type(param_graph.parameter(param_index).data_type());
2752 param_proto->set_raw_data(param_graph.parameter(param_index).raw_data());
2753 param_proto->set_compression_type(param_graph.parameter(param_index).compression_type());
2754 for (const auto &dim : param_graph.parameter(param_index).dims()) {
2755 param_proto->add_dims(dim);
2756 }
2757 }
2758 }
2759 }
2760
2761 MSANFModelParser model_parser;
2762
2763 auto mindir_path = std::string(abs_path_buff);
2764 model_parser.SetMindIRPath(mindir_path.substr(0, mindir_path.rfind("/")));
2765 InitModelParser(&model_parser, this);
2766 FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model, weights_value_map_, name_to_node);
2767 if (has_parallel_info_) {
2768 layout_map_ = model_parser.ParseLayout(origin_model);
2769 }
2770 return dstgraph_ptr;
2771 }
2772
LoadMindIR(const void * buffer,const size_t & size,const std::string & mindir_path,FuncGraphPtr * func_graph,std::string * user_info_string)2773 bool MindIRLoader::LoadMindIR(const void *buffer, const size_t &size, const std::string &mindir_path,
2774 FuncGraphPtr *func_graph, std::string *user_info_string) {
2775 mind_ir::ModelProto model;
2776 auto ret = model.ParseFromArray(buffer, SizeToInt(size));
2777 if (!ret) {
2778 MS_LOG(ERROR) << "ParseFromArray failed.";
2779 return false;
2780 }
2781 if (!CheckModelConfigureInfo(model)) {
2782 MS_LOG(ERROR) << "Check configuration info for pb file failed!";
2783 return false;
2784 }
2785 MSANFModelParser model_parser;
2786 InitModelParser(&model_parser, this);
2787 model_parser.SetMindIRPath(mindir_path);
2788 *func_graph = model_parser.Parse(model);
2789 std::stringstream user_info_buffer;
2790 // user_info to string
2791 auto user_info = model.user_info();
2792 user_info_buffer << "{";
2793 for (auto it = user_info.begin(); it != user_info.end(); it++) {
2794 if (it != user_info.begin()) {
2795 user_info_buffer << ", ";
2796 }
2797 user_info_buffer << "\"" << it->first << "\": \"" << it->second + "\"";
2798 }
2799 user_info_buffer << "}";
2800 *user_info_string = user_info_buffer.str();
2801 return true;
2802 }
2803
LoadMindIR(const void * buffer,const size_t & size,const std::string & mindir_path)2804 FuncGraphPtr MindIRLoader::LoadMindIR(const void *buffer, const size_t &size, const std::string &mindir_path) {
2805 mind_ir::ModelProto model;
2806 auto ret = model.ParseFromArray(buffer, SizeToInt(size));
2807 if (!ret) {
2808 MS_LOG(ERROR) << "ParseFromArray failed.";
2809 return nullptr;
2810 }
2811 if (!CheckModelConfigureInfo(model)) {
2812 MS_LOG(ERROR) << "Check configuration info for pb file failed!";
2813 return nullptr;
2814 }
2815 MSANFModelParser model_parser;
2816 InitModelParser(&model_parser, this);
2817 model_parser.SetMindIRPath(mindir_path);
2818 FuncGraphPtr func_graph = model_parser.Parse(model);
2819 return func_graph;
2820 }
2821
LoadMindIR(const std::string & file_name,const std::vector<FuncGraphPtr> & graphs,mindspore::HashMap<std::string,AnfNodePtr> * name_to_node)2822 bool MindIRLoader::LoadMindIR(const std::string &file_name, const std::vector<FuncGraphPtr> &graphs,
2823 mindspore::HashMap<std::string, AnfNodePtr> *name_to_node) {
2824 if (file_name.length() > PATH_MAX) {
2825 MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
2826 return false;
2827 }
2828 char abs_path_buff[PATH_MAX];
2829 #ifdef _WIN32
2830 _fullpath(abs_path_buff, file_name.c_str(), PATH_MAX);
2831 #else
2832 if (!realpath(file_name.c_str(), abs_path_buff)) {
2833 MS_LOG(EXCEPTION) << "Load MindIR get absolute path of " << file_name
2834 << " failed, errno is: " << ErrnoToString(errno);
2835 }
2836 #endif
2837 mind_ir::ModelProto model_proto;
2838 // Read graph
2839 if (!ParseModelProto(&model_proto, std::string(abs_path_buff), this)) {
2840 return false;
2841 }
2842 if (!CheckModelConfigureInfo(model_proto)) {
2843 MS_LOG(ERROR) << "Check configuration info for pb file failed!";
2844 return false;
2845 }
2846 MSANFModelParser model_parser;
2847 InitModelParser(&model_parser, this);
2848 if (!model_parser.Parse(model_proto, graphs, name_to_node)) {
2849 MS_LOG(ERROR) << "Parse model failed!";
2850 return false;
2851 }
2852 return true;
2853 }
2854
ReadProtoFile(const std::string & file)2855 std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file) {
2856 if (file.empty()) {
2857 MS_LOG(ERROR) << "file is nullptr";
2858 return nullptr;
2859 }
2860
2861 char real_path[PATH_MAX] = {0};
2862 #if defined(_WIN32) || defined(_WIN64)
2863 if (_fullpath(real_path, file.c_str(), PATH_MAX) == nullptr) {
2864 MS_LOG(ERROR) << "Get realpath failed, mind ir file is" << file;
2865 return nullptr;
2866 }
2867 #else
2868 if (realpath(file.c_str(), real_path) == nullptr) {
2869 MS_LOG(ERROR) << "Get realpath failed, mind ir file is" << file;
2870 return nullptr;
2871 }
2872 #endif
2873
2874 std::ifstream ifs(real_path);
2875 if (!ifs.good()) {
2876 MS_LOG(ERROR) << "file: " << real_path << " is not exist";
2877 return nullptr;
2878 }
2879
2880 if (!ifs.is_open()) {
2881 MS_LOG(ERROR) << "file: " << real_path << "open failed";
2882 return nullptr;
2883 }
2884
2885 ifs.seekg(0, std::ios::end);
2886 size_t size = ifs.tellg();
2887 std::shared_ptr<std::vector<char>> buf(new (std::nothrow) std::vector<char>(size));
2888 if (buf == nullptr) {
2889 MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
2890 ifs.close();
2891 return nullptr;
2892 }
2893
2894 ifs.seekg(0, std::ios::beg);
2895 ifs.read(buf->data(), size);
2896 ifs.close();
2897
2898 return buf;
2899 }
2900
ConvertStreamToFuncGraph(const char * buf,const size_t buf_size,bool is_lite)2901 FuncGraphPtr ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite) {
2902 MS_EXCEPTION_IF_NULL(buf);
2903 std::string str(buf, buf_size);
2904 mind_ir::ModelProto model_;
2905 if (!model_.ParseFromString(str)) {
2906 MS_LOG(ERROR) << "Parse model from buffer fail!";
2907 }
2908 MSANFModelParser model_parser;
2909 if (is_lite) {
2910 model_parser.SetLite();
2911 }
2912 FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
2913 return dstgraph_ptr;
2914 }
2915 } // namespace mindspore
2916