• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "include/common/debug/dump_proto.h"
17 #include <algorithm>
18 #include <fstream>
19 #include <map>
20 #include <memory>
21 #include <utility>
22 #include <vector>
23 #include "google/protobuf/util/json_util.h"
24 #include "proto/anf_ir.pb.h"
25 #include "proto/mind_ir.pb.h"
26 #include "ir/graph_utils.h"
27 #include "utils/ms_context.h"
28 #include "utils/symbolic.h"
29 #include "include/common/debug/anf_dump_utils.h"
30 #include "utils/anf_utils.h"
31 #include "frontend/parallel/ops_info/ops_utils.h"  // todo: use constant string now
32 #include "mindspore/core/utils/file_utils.h"
33 
34 namespace mindspore {
35 class ProtoExporter {
36  public:
ProtoExporter()37   ProtoExporter() {}
~ProtoExporter()38   ~ProtoExporter() {}
39 
40   std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);
41   void ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto);
42 
43  private:
44   void InitModelInfo();
45   void GetOpNodeTypeAndAttrs(const FuncGraphPtr & /* func_graph */, const CNodePtr &cnode, irpb::NodeProto *node_proto);
46   std::string GetOpNodeInputId(const FuncGraphPtr & /* func_graph */, const AnfNodePtr &node,
47                                const std::map<AnfNodePtr, size_t> &apply_map,
48                                std::map<AnfNodePtr, size_t> *const_map_ptr) const;
49   void SetValueToProtoBasicTypes(const ValuePtr &val, irpb::ValueProto *const value_proto) const;
50   void SetValueToProto(const ValuePtr &val, irpb::ValueProto *value_proto);
51   void SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto) const;
52   void SetSequenceToProto(const ValueSequencePtr &val, irpb::ValueProto *value_proto);
53   void SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto);
54   void SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto);
55   void SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *type_proto);
56 
57   void ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto);
58   void ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto,
59                     std::map<AnfNodePtr, size_t> *const_map_ptr);
60   void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *apply_map_ptr,
61                    std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto);
62   void ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
63                              const std::map<AnfNodePtr, size_t> &apply_map, std::map<AnfNodePtr, size_t> *const_map_ptr,
64                              irpb::GraphProto *graph_proto);
65   void ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto);
66 
GetConstNodeId(size_t idx)67   static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); }
68 
69   irpb::ModelProto model_;
70 };
71 
72 static std::map<TypeId, irpb::DataType> number_data_type_map = {{kNumberTypeBool, irpb::DT_BOOL},
73                                                                 {kNumberTypeInt8, irpb::DT_INT8},
74                                                                 {kNumberTypeInt16, irpb::DT_INT16},
75                                                                 {kNumberTypeInt32, irpb::DT_INT32},
76                                                                 {kNumberTypeInt64, irpb::DT_INT64},
77                                                                 {kNumberTypeUInt8, irpb::DT_UINT8},
78                                                                 {kNumberTypeUInt16, irpb::DT_UINT16},
79                                                                 {kNumberTypeUInt32, irpb::DT_UINT32},
80                                                                 {kNumberTypeUInt64, irpb::DT_UINT64},
81                                                                 {kNumberTypeFloat16, irpb::DT_FLOAT16},
82                                                                 {kNumberTypeFloat32, irpb::DT_FLOAT32},
83                                                                 {kNumberTypeFloat64, irpb::DT_FLOAT64},
84                                                                 {kNumberTypeBFloat16, irpb::DT_BFLOAT16},
85                                                                 {kNumberTypeInt, irpb::DT_BASE_INT},
86                                                                 {kNumberTypeUInt, irpb::DT_BASE_UINT},
87                                                                 {kNumberTypeFloat, irpb::DT_BASE_FLOAT},
88                                                                 {kNumberTypeComplex64, irpb::DT_COMPLEX64},
89                                                                 {kNumberTypeComplex128, irpb::DT_COMPLEX128},
90                                                                 {kObjectTypeString, irpb::DT_STRING},
91                                                                 {kObjectTypeTuple, irpb::DT_TUPLE},
92                                                                 {kNumberTypeInt4, irpb::DT_INT4}};
93 
GetNumberDataType(const TypePtr & type)94 static irpb::DataType GetNumberDataType(const TypePtr &type) {
95   auto iter = number_data_type_map.find(type->type_id());
96   if (iter != number_data_type_map.end()) {
97     return (*iter).second;
98   } else {
99     MS_LOG(INTERNAL_EXCEPTION) << "Unexpected type " << type->type_name();
100   }
101 }
102 
IsKindOfTensorType(const TypePtr & type)103 static inline bool IsKindOfTensorType(const TypePtr &type) {
104   MS_EXCEPTION_IF_NULL(type);
105   return type->isa<TensorType>() || type->isa<RowTensorType>() || type->isa<CSRTensorType>() ||
106          type->isa<COOTensorType>() || type->isa<MapTensorType>();
107 }
108 
CheckIfValidType(const TypePtr & type)109 void CheckIfValidType(const TypePtr &type) {
110   MS_EXCEPTION_IF_NULL(type);
111   if (type->isa<Problem>()) {
112     MS_LOG(WARNING) << "The type: " << type->type_name();
113     return;
114   }
115   if (!(type->isa<Number>() || IsKindOfTensorType(type) || type->isa<Tuple>() || type->isa<TypeType>() ||
116         type->isa<List>() || type->isa<TypeAny>() || type->isa<RefKeyType>() || type->isa<RefType>() ||
117         type->isa<Function>() || type->isa<TypeNone>() || type->isa<String>() || type->isa<UndeterminedType>() ||
118         type->isa<SymbolicKeyType>() || type->isa<MonadType>() || type->isa<Dictionary>()) ||
119       type->isa<Slice>()) {
120     MS_LOG(INTERNAL_EXCEPTION) << "Unknown type: " << type->type_name();
121   }
122 }
123 
SetTensorType(const TypePtr & type,const BaseShapePtr & shape,irpb::TypeProto * const type_proto)124 void SetTensorType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *const type_proto) {
125   TypePtr elem_type = dyn_cast<TensorType>(type)->element();
126   type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type));
127   type_proto->set_data_type(irpb::DT_TENSOR);
128   if (shape != nullptr && shape->isa<abstract::Shape>()) {
129     abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape);
130     for (const auto &elem : shape_info->shape()) {
131       type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
132     }
133   }
134 }
135 
SetNodeOutputType(const TypePtr & type,const BaseShapePtr & shape,irpb::TypeProto * type_proto)136 void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *type_proto) {
137   if (type_proto == nullptr) {
138     return;
139   }
140   if (type == nullptr) {
141     type_proto->set_data_type(irpb::DT_UNDEFINED);
142     return;
143   }
144   if (type->isa<External>()) {
145     return;
146   }
147   CheckIfValidType(type);
148   if (type->isa<Number>()) {
149     type_proto->set_data_type(GetNumberDataType(type));
150   } else if (type->isa<TensorType>()) {
151     SetTensorType(type, shape, type_proto);
152   } else if (type->isa<Tuple>()) {
153     TuplePtr tuple_type = dyn_cast<Tuple>(type);
154     type_proto->set_data_type(irpb::DT_TUPLE);
155     for (const auto &elem_type : tuple_type->elements()) {
156       SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
157     }
158   } else if (type->isa<TypeType>()) {
159     type_proto->set_data_type(irpb::DT_TYPE);
160   } else if (type->isa<List>()) {
161     ListPtr list_type = dyn_cast<List>(type);
162     type_proto->set_data_type(irpb::DT_LIST);
163     for (const auto &elem_type : list_type->elements()) {
164       SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
165     }
166   } else if (type->isa<TypeAny>()) {
167     type_proto->set_data_type(irpb::DT_ANY);
168   } else if (type->isa<RefKeyType>()) {
169     type_proto->set_data_type(irpb::DT_REFKEY);
170   } else if (type->isa<RefType>()) {
171     type_proto->set_data_type(irpb::DT_REF);
172   } else if (type->isa<Function>()) {
173     type_proto->set_data_type(irpb::DT_GRAPH);
174   } else if (type->isa<TypeNone>()) {
175     type_proto->set_data_type(irpb::DT_NONE);
176   } else if (type->isa<String>()) {
177     type_proto->set_data_type(irpb::DT_STRING);
178   } else {
179     type_proto->set_data_type(irpb::DT_SLICE);
180   }
181 }
182 
SetNodeOutputType(const AnfNodePtr & node,irpb::TypeProto * type_proto)183 void ProtoExporter::SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto) {
184   if (node == nullptr || type_proto == nullptr) {
185     return;
186   }
187   SetNodeOutputType(node->Type(), node->Shape(), type_proto);
188 }
189 
SetValueToProtoBasicTypes(const ValuePtr & val,irpb::ValueProto * const value_proto) const190 void ProtoExporter::SetValueToProtoBasicTypes(const ValuePtr &val, irpb::ValueProto *const value_proto) const {
191   if (val->isa<StringImm>()) {
192     const StringImmPtr &value = dyn_cast<StringImm>(val);
193     value_proto->set_dtype(irpb::DT_STRING);
194     value_proto->set_str_val(value->value());
195   } else if (val->isa<Scalar>()) {
196     SetScalarToProto(dyn_cast<Scalar>(val), value_proto);
197   } else if (val->isa<Bool>()) {
198     value_proto->set_dtype(irpb::DT_TYPE);
199     value_proto->mutable_type_val()->set_data_type(irpb::DT_BOOL);
200   } else if (val->isa<Int>()) {
201     value_proto->set_dtype(irpb::DT_TYPE);
202     value_proto->mutable_type_val()->set_data_type(irpb::DT_BASE_INT);
203   } else if (val->isa<UInt>()) {
204     value_proto->set_dtype(irpb::DT_TYPE);
205     value_proto->mutable_type_val()->set_data_type(irpb::DT_BASE_UINT);
206   } else if (val->isa<Float>() || val->isa<BFloat>()) {
207     value_proto->set_dtype(irpb::DT_TYPE);
208     value_proto->mutable_type_val()->set_data_type(irpb::DT_BASE_FLOAT);
209   }
210 }
211 
SetValueToProto(const ValuePtr & val,irpb::ValueProto * value_proto)212 void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value_proto) {
213   if (val == nullptr || value_proto == nullptr) {
214     return;
215   }
216 
217   SetValueToProtoBasicTypes(val, value_proto);
218 
219   if (val->isa<ValueSequence>()) {
220     SetSequenceToProto(dyn_cast<ValueSequence>(val), value_proto);
221   } else if (val->isa<None>()) {
222     value_proto->set_dtype(irpb::DT_NONE);
223     value_proto->set_str_val("None");
224   } else if (val->isa<SymbolicKeyInstance>()) {
225     SymbolicKeyInstancePtr sym_inst = dyn_cast<SymbolicKeyInstance>(val);
226     ParameterPtr sym_node = dyn_cast<Parameter>(sym_inst->node());
227     value_proto->set_dtype(irpb::DT_SYM_INST);
228     value_proto->set_str_val(sym_node == nullptr ? std::string("nullptr") : sym_node->ToString());
229   } else if (val->isa<ValueDictionary>()) {
230     SetDictionaryToProto(dyn_cast<ValueDictionary>(val), value_proto);
231   } else if (val->isa<tensor::Tensor>()) {
232     tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(val);
233     value_proto->set_dtype(irpb::DT_TENSOR);
234     irpb::TensorProto *tensor_proto = value_proto->mutable_tensor_val();
235     tensor_proto->set_data_type(GetNumberDataType(tensor_ptr->Dtype()));
236     for (auto &elem : tensor_ptr->shape()) {
237       tensor_proto->add_dims(elem);
238     }
239   } else if (val->isa<TensorType>()) {
240     value_proto->set_dtype(irpb::DT_TYPE);
241 
242     irpb::TypeProto *type_proto = value_proto->mutable_type_val();
243     type_proto->set_data_type(irpb::DT_TENSOR);
244     TypePtr elem_type = dyn_cast<TensorType>(val)->element();
245     type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type));
246   } else if (val->isa<Monad>() || val->isa<MonadType>()) {
247     value_proto->set_str_val(val->ToString());
248   } else if (val->isa<Complex>()) {
249     value_proto->set_dtype(irpb::DT_TYPE);
250     value_proto->mutable_type_val()->set_data_type(irpb::DT_BASE_COMPLEX);
251   } else {
252     MS_LOG(DEBUG) << "Unsupported type " << val->type_name();
253   }
254 }
255 
SetScalarToProto(const ScalarPtr & val,irpb::ValueProto * value_proto) const256 void ProtoExporter::SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto) const {
257   if (val == nullptr || value_proto == nullptr) {
258     return;
259   }
260 
261   if (val->isa<BoolImm>()) {
262     const BoolImmPtr &value = dyn_cast<BoolImm>(val);
263     value_proto->set_dtype(irpb::DT_BOOL);
264     value_proto->set_bool_val(value->value());
265   } else if (val->isa<Int8Imm>()) {
266     const Int8ImmPtr &value = dyn_cast<Int8Imm>(val);
267     value_proto->set_dtype(irpb::DT_INT8);
268     value_proto->set_int_val(value->value());
269   } else if (val->isa<Int16Imm>()) {
270     const Int16ImmPtr &value = dyn_cast<Int16Imm>(val);
271     value_proto->set_dtype(irpb::DT_INT16);
272     value_proto->set_int_val(value->value());
273   } else if (val->isa<Int32Imm>()) {
274     const Int32ImmPtr &value = dyn_cast<Int32Imm>(val);
275     value_proto->set_dtype(irpb::DT_INT32);
276     value_proto->set_int_val(value->value());
277   } else if (val->isa<Int64Imm>()) {
278     const Int64ImmPtr &value = dyn_cast<Int64Imm>(val);
279     value_proto->set_dtype(irpb::DT_INT64);
280     value_proto->set_int_val(value->value());
281   } else if (val->isa<UInt8Imm>()) {
282     const UInt8ImmPtr &value = dyn_cast<UInt8Imm>(val);
283     value_proto->set_dtype(irpb::DT_UINT8);
284     value_proto->set_uint_val(value->value());
285   } else if (val->isa<UInt16Imm>()) {
286     const UInt16ImmPtr &value = dyn_cast<UInt16Imm>(val);
287     value_proto->set_dtype(irpb::DT_UINT16);
288     value_proto->set_uint_val(value->value());
289   } else if (val->isa<UInt32Imm>()) {
290     const UInt32ImmPtr &value = dyn_cast<UInt32Imm>(val);
291     value_proto->set_dtype(irpb::DT_UINT32);
292     value_proto->set_uint_val(value->value());
293   } else if (val->isa<UInt64Imm>()) {
294     const UInt64ImmPtr &value = dyn_cast<UInt64Imm>(val);
295     value_proto->set_dtype(irpb::DT_UINT64);
296     value_proto->set_uint_val(value->value());
297   } else if (val->isa<FP32Imm>()) {
298     const FP32ImmPtr &value = dyn_cast<FP32Imm>(val);
299     value_proto->set_dtype(irpb::DT_FLOAT32);
300     value_proto->set_float_val(value->value());
301   } else if (val->isa<FP64Imm>()) {
302     const FP64ImmPtr &value = dyn_cast<FP64Imm>(val);
303     value_proto->set_dtype(irpb::DT_FLOAT64);
304     value_proto->set_double_val(value->value());
305   } else {
306     MS_LOG(INTERNAL_EXCEPTION) << "Unknown scalar type " << val->ToString();
307   }
308 }
309 
SetSequenceToProto(const ValueSequencePtr & val,irpb::ValueProto * value_proto)310 void ProtoExporter::SetSequenceToProto(const ValueSequencePtr &val, irpb::ValueProto *value_proto) {
311   if (val == nullptr || value_proto == nullptr) {
312     return;
313   }
314 
315   if (val->isa<ValueTuple>()) {
316     const ValueTuplePtr &value = dyn_cast<ValueTuple>(val);
317     value_proto->set_dtype(irpb::DT_TUPLE);
318     for (const auto &item : value->value()) {
319       SetValueToProto(item, value_proto->add_values());
320     }
321   } else if (val->isa<ValueList>()) {
322     const ValueListPtr &value = dyn_cast<ValueList>(val);
323     value_proto->set_dtype(irpb::DT_LIST);
324     for (const auto &item : value->value()) {
325       SetValueToProto(item, value_proto->add_values());
326     }
327   }
328 }
329 
SetDictionaryToProto(const ValueDictionaryPtr & val,irpb::ValueProto * value_proto)330 void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto) {
331   if (val == nullptr || value_proto == nullptr) {
332     return;
333   }
334 
335   value_proto->set_dtype(irpb::DT_DICT);
336   for (const auto &item : val->value()) {
337     irpb::NamedValueProto *named_val = value_proto->add_dict_val();
338     MS_EXCEPTION_IF_NULL(item.first);
339     if (!item.first->isa<StringImm>()) {
340       MS_LOG(INTERNAL_EXCEPTION) << "The key of NamedValueProto should be string type, but got "
341                                  << item.first->ToString();
342     }
343     named_val->set_key(GetValue<std::string>(item.first));
344     SetValueToProto(item.second, named_val->mutable_value());
345   }
346 }
347 
GetOpNodeTypeAndAttrs(const FuncGraphPtr &,const CNodePtr & cnode,irpb::NodeProto * node_proto)348 void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr & /* func_graph */, const CNodePtr &cnode,
349                                           irpb::NodeProto *node_proto) {
350   const auto &inputs = cnode->inputs();
351   AnfNodePtr op_node = inputs[0];
352 
353   if (op_node == nullptr || node_proto == nullptr) {
354     return;
355   }
356 
357   if (op_node->isa<CNode>() || op_node->isa<Parameter>() || IsValueNode<FuncGraph>(op_node)) {
358     MS_LOG(INTERNAL_EXCEPTION) << "Op node can not be CNode, Parameter or ValueNode Graph. But got "
359                                << op_node->ToString();
360   }
361 
362   if (!IsValueNode<Primitive>(op_node)) {
363     MS_LOG(INTERNAL_EXCEPTION) << "Op node is not primitive: " << op_node->ToString();
364   }
365 
366   const PrimitivePtr &prim = GetValueNode<PrimitivePtr>(op_node);
367   node_proto->set_op_type(prim->name());
368   for (const auto &attr : prim->attrs()) {
369     irpb::AttributeProto *attr_proto = node_proto->add_attribute();
370     attr_proto->set_name(attr.first);
371     SetValueToProto(attr.second, attr_proto->mutable_value());
372   }
373 
374   // Only CNode save the operator strategy
375   auto strategy_value = AnfDumpHandler::InStrategyValue(cnode);
376   if (strategy_value != nullptr) {
377     irpb::AttributeProto *attr_proto = node_proto->add_attribute();
378     attr_proto->set_name(mindspore::parallel::IN_STRATEGY);
379     SetValueToProto(strategy_value, attr_proto->mutable_value());
380   }
381 
382   node_proto->set_scope(op_node->scope()->name());
383 }
384 
GetOpNodeInputId(const FuncGraphPtr &,const AnfNodePtr & node,const std::map<AnfNodePtr,size_t> & apply_map,std::map<AnfNodePtr,size_t> * const_map_ptr) const385 std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr & /* func_graph */, const AnfNodePtr &node,
386                                             const std::map<AnfNodePtr, size_t> &apply_map,
387                                             std::map<AnfNodePtr, size_t> *const_map_ptr) const {
388   if (node == nullptr || const_map_ptr == nullptr) {
389     return "";
390   }
391 
392   if (node->isa<CNode>()) {
393     auto iter = apply_map.find(node);
394     if (iter == apply_map.end()) {
395       MS_LOG(INTERNAL_EXCEPTION) << "Can not find node '" << node->ToString() << "' in apply_map";
396     }
397     return std::to_string(iter->second);
398   }
399 
400   if (node->isa<Parameter>()) {
401     return node->ToString();
402   }
403 
404   if (AnfUtils::IsCustomActorNode(node)) {
405     return AnfUtils::GetCustomActorName(node);
406   }
407 
408   if (node->isa<ValueNode>()) {
409     auto iter = const_map_ptr->find(node);
410     if (iter == const_map_ptr->end()) {
411       // Start index number from 1
412       auto const_idx = const_map_ptr->size() + 1;
413       (*const_map_ptr)[node] = const_idx;
414     }
415     return GetConstNodeId((*const_map_ptr)[node]);
416   }
417 
418   MS_LOG(INTERNAL_EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'";
419 }
420 
GetFuncGraphProtoString(const FuncGraphPtr & func_graph)421 std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph) {
422   if (func_graph == nullptr) {
423     return "";
424   }
425 
426   InitModelInfo();
427   irpb::GraphProto *graph_proto = model_.mutable_graph();
428   ExportFuncGraph(func_graph, graph_proto);
429   return model_.SerializeAsString();
430 }
431 
ExportFuncGraph(const FuncGraphPtr & func_graph,irpb::GraphProto * graph_proto)432 void ProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) {
433   if (func_graph == nullptr || graph_proto == nullptr) {
434     return;
435   }
436 
437   // map for store ValueNodes of this graph
438   std::map<AnfNodePtr, size_t> const_map;
439 
440   // set graph name
441   graph_proto->set_name(func_graph->ToString());
442 
443   ExportParameters(func_graph, graph_proto);
444 
445   ExportCNodes(func_graph, graph_proto, &const_map);
446 
447   ExportValueNodes(const_map, graph_proto);
448 }
449 
ExportParameters(const FuncGraphPtr & func_graph,irpb::GraphProto * graph_proto)450 void ProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) {
451   if (func_graph == nullptr || graph_proto == nullptr) {
452     return;
453   }
454 
455   std::vector<AnfNodePtr> parameters = func_graph->parameters();
456   for (auto &param : parameters) {
457     irpb::ParameterProto *param_proto = graph_proto->add_parameters();
458     param_proto->set_name(param->ToString());
459 
460     SetNodeOutputType(param, param_proto->mutable_type());
461 
462     const ParameterPtr param_ptr = dyn_cast<Parameter>(param);
463     if (param_ptr == nullptr) {
464       MS_LOG(INTERNAL_EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter.";
465     }
466   }
467 }
468 
ExportCNodes(const FuncGraphPtr & func_graph,irpb::GraphProto * graph_proto,std::map<AnfNodePtr,size_t> * const_map_ptr)469 void ProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto,
470                                  std::map<AnfNodePtr, size_t> *const_map_ptr) {
471   if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) {
472     return;
473   }
474   // topo sort nodes
475   std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
476   std::map<AnfNodePtr, size_t> apply_map;
477   for (const AnfNodePtr &node : nodes) {
478     MS_EXCEPTION_IF_NULL(node);
479     if (!node->isa<CNode>()) {
480       continue;
481     }
482     auto cnode = node->cast<CNodePtr>();
483     if (cnode != func_graph->get_return()) {
484       ExportCNode(func_graph, cnode, &apply_map, const_map_ptr, graph_proto);
485     } else {
486       ExportFuncGraphOutput(func_graph, cnode, apply_map, const_map_ptr, graph_proto);
487     }
488   }
489 }
490 
ExportCNode(const FuncGraphPtr & func_graph,const CNodePtr & node,std::map<AnfNodePtr,size_t> * apply_map_ptr,std::map<AnfNodePtr,size_t> * const_map_ptr,irpb::GraphProto * graph_proto)491 void ProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
492                                 std::map<AnfNodePtr, size_t> *apply_map_ptr,
493                                 std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) {
494   if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr ||
495       graph_proto == nullptr) {
496     return;
497   }
498 
499   auto apply_idx = apply_map_ptr->size() + 1;
500   (*apply_map_ptr)[node] = apply_idx;
501 
502   auto &inputs = node->inputs();
503   if (inputs.size() < 1) {
504     MS_LOG(INTERNAL_EXCEPTION) << "Inputs of CNode is empty";
505   }
506   AnfNodePtr op = inputs[0];
507   irpb::NodeProto *node_proto = graph_proto->add_node();
508 
509   // CNode/ConstGraph/Const/Parameter
510   if (op->isa<CNode>() || IsValueNode<FuncGraph>(op) || op->isa<Parameter>()) {
511     MS_LOG(DEBUG) << "Operator must be a primitive";
512   } else {
513     GetOpNodeTypeAndAttrs(func_graph, node, node_proto);
514     node_proto->set_name(std::to_string(apply_idx));
515     node_proto->set_scope(node->scope()->name());
516     node_proto->set_full_name(GetKernelNodeName(node));
517 
518     // process OP inputs
519     for (size_t i = 1; i < inputs.size(); ++i) {
520       irpb::InputProto *input_proto = node_proto->add_input();
521       input_proto->set_type(irpb::InputProto_EdgeType_DATA_EDGE);
522       std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr);
523       input_proto->set_name(id);
524     }
525 
526     // set node output type
527     SetNodeOutputType(node, node_proto->mutable_output_type());
528 
529     if (IsValueNode<Primitive>(op)) {
530       PrimitivePtr primitive = GetValueNode<PrimitivePtr>(op);
531       if (!primitive->instance_name().empty()) {
532         node_proto->set_instance_name(primitive->instance_name());
533       }
534     }
535   }
536 }
537 
ExportFuncGraphOutput(const FuncGraphPtr & func_graph,const CNodePtr & ret_node,const std::map<AnfNodePtr,size_t> & apply_map,std::map<AnfNodePtr,size_t> * const_map_ptr,irpb::GraphProto * graph_proto)538 void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
539                                           const std::map<AnfNodePtr, size_t> &apply_map,
540                                           std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) {
541   if (ret_node == nullptr || !ret_node->isa<CNode>()) {
542     MS_LOG(INTERNAL_EXCEPTION) << "Graph return node is illegal";
543   }
544   // ret node has two input 1 ret op + 1 value
545   const size_t ret_input_size = 2;
546   if (ret_node->size() != ret_input_size) {
547     return;
548   }
549   AnfNodePtr arg = ret_node->input(1);
550   if (graph_proto == nullptr) {
551     MS_LOG(INTERNAL_EXCEPTION) << "graph_proto is nullptr";
552   }
553   irpb::OutputProto *output_proto = graph_proto->add_outputs();
554   if (output_proto == nullptr) {
555     MS_LOG(INTERNAL_EXCEPTION) << "output_proto is nullptr";
556   }
557   std::string id = GetOpNodeInputId(func_graph, arg, apply_map, const_map_ptr);
558   output_proto->set_name(id);
559   SetNodeOutputType(arg, output_proto->mutable_type());
560 }
561 
CompareValue(const std::pair<AnfNodePtr,size_t> & x,const std::pair<AnfNodePtr,size_t> & y)562 static bool CompareValue(const std::pair<AnfNodePtr, size_t> &x, const std::pair<AnfNodePtr, size_t> &y) {
563   return x.second < y.second;
564 }
565 
ExportValueNodes(const std::map<AnfNodePtr,size_t> & const_map,irpb::GraphProto * graph_proto)566 void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto) {
567   std::vector<std::pair<AnfNodePtr, size_t>> nodes;
568   (void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes),
569                        [](const std::pair<AnfNodePtr, size_t> &item) { return item; });
570 
571   sort(nodes.begin(), nodes.end(), CompareValue);
572 
573   for (auto &item : nodes) {
574     if (graph_proto == nullptr) {
575       MS_LOG(INTERNAL_EXCEPTION) << "graph_proto is nullptr";
576     }
577     irpb::NamedValueProto *named_value = graph_proto->add_const_vals();
578     MS_EXCEPTION_IF_NULL(named_value);
579     named_value->set_key(GetConstNodeId(item.second));
580     SetValueToProto(GetValueNode(item.first), named_value->mutable_value());
581   }
582 }
583 
InitModelInfo()584 void ProtoExporter::InitModelInfo() { model_.set_ir_version(static_cast<int64_t>(irpb::IR_VERSION)); }
585 
GetFuncGraphProtoString(const FuncGraphPtr & func_graph)586 std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) {
587   ProtoExporter exporter;
588   return exporter.GetFuncGraphProtoString(func_graph);
589 }
590 
GetFuncGraphProtoJsonString(const FuncGraphPtr & func_graph)591 std::string GetFuncGraphProtoJsonString(const FuncGraphPtr &func_graph) {
592   ProtoExporter exporter;
593   irpb::GraphProto graph_proto = irpb::GraphProto();
594   exporter.ExportFuncGraph(func_graph, &graph_proto);
595   std::string graph_proto_str;
596   (void)google::protobuf::util::MessageToJsonString(graph_proto, &graph_proto_str);
597   return graph_proto_str;
598 }
599 
600 #ifdef ENABLE_DUMP_IR
DumpIRProto(const FuncGraphPtr & func_graph,const std::string & suffix)601 void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) {
602   if (func_graph == nullptr) {
603     MS_LOG(ERROR) << "Func graph is nullptr";
604     return;
605   }
606   std::string file_path = GetSaveGraphsPathName("ms_output_" + suffix + ".pb");
607   auto realpath = Common::CreatePrefixPath(file_path);
608   if (!realpath.has_value()) {
609     MS_LOG(ERROR) << "Get real path failed, path=" << file_path;
610     return;
611   }
612 
613   ChangeFileMode(realpath.value(), S_IWUSR);
614   // write to pb file
615   std::ofstream ofs(file_path);
616   if (!ofs.is_open()) {
617     MS_LOG(ERROR) << "Open file '" << file_path << "' failed!" << ErrnoToString(errno);
618     return;
619   }
620   ofs << GetFuncGraphProtoString(func_graph);
621   ofs.close();
622   // set file mode to read only by user
623   ChangeFileMode(file_path, S_IRUSR);
624 }
625 #else
DumpIRProto(const FuncGraphPtr &,const std::string &)626 void DumpIRProto(const FuncGraphPtr &, const std::string &) {
627   static bool already_printed = false;
628   if (already_printed) {
629     return;
630   }
631   already_printed = true;
632   MS_LOG(WARNING) << "The functionality of dumping function graph IR in protobuf format is disabled, "
633                   << "please recompile source to enable it. See help of building script.";
634 }
635 #endif
636 }  // namespace mindspore
637