1 /**
2 * Copyright 2019-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "include/common/debug/dump_proto.h"
17 #include <algorithm>
18 #include <fstream>
19 #include <map>
20 #include <memory>
21 #include <utility>
22 #include <vector>
23 #include "google/protobuf/util/json_util.h"
24 #include "proto/anf_ir.pb.h"
25 #include "proto/mind_ir.pb.h"
26 #include "ir/graph_utils.h"
27 #include "utils/ms_context.h"
28 #include "utils/symbolic.h"
29 #include "include/common/debug/anf_dump_utils.h"
30 #include "utils/anf_utils.h"
31 #include "frontend/parallel/ops_info/ops_utils.h" // todo: use constant string now
32 #include "mindspore/core/utils/file_utils.h"
33
34 namespace mindspore {
35 class ProtoExporter {
36 public:
ProtoExporter()37 ProtoExporter() {}
~ProtoExporter()38 ~ProtoExporter() {}
39
40 std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);
41 void ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto);
42
43 private:
44 void InitModelInfo();
45 void GetOpNodeTypeAndAttrs(const FuncGraphPtr & /* func_graph */, const CNodePtr &cnode, irpb::NodeProto *node_proto);
46 std::string GetOpNodeInputId(const FuncGraphPtr & /* func_graph */, const AnfNodePtr &node,
47 const std::map<AnfNodePtr, size_t> &apply_map,
48 std::map<AnfNodePtr, size_t> *const_map_ptr) const;
49 void SetValueToProtoBasicTypes(const ValuePtr &val, irpb::ValueProto *const value_proto) const;
50 void SetValueToProto(const ValuePtr &val, irpb::ValueProto *value_proto);
51 void SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto) const;
52 void SetSequenceToProto(const ValueSequencePtr &val, irpb::ValueProto *value_proto);
53 void SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto);
54 void SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto);
55 void SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *type_proto);
56
57 void ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto);
58 void ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto,
59 std::map<AnfNodePtr, size_t> *const_map_ptr);
60 void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *apply_map_ptr,
61 std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto);
62 void ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
63 const std::map<AnfNodePtr, size_t> &apply_map, std::map<AnfNodePtr, size_t> *const_map_ptr,
64 irpb::GraphProto *graph_proto);
65 void ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto);
66
GetConstNodeId(size_t idx)67 static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); }
68
69 irpb::ModelProto model_;
70 };
71
72 static std::map<TypeId, irpb::DataType> number_data_type_map = {{kNumberTypeBool, irpb::DT_BOOL},
73 {kNumberTypeInt8, irpb::DT_INT8},
74 {kNumberTypeInt16, irpb::DT_INT16},
75 {kNumberTypeInt32, irpb::DT_INT32},
76 {kNumberTypeInt64, irpb::DT_INT64},
77 {kNumberTypeUInt8, irpb::DT_UINT8},
78 {kNumberTypeUInt16, irpb::DT_UINT16},
79 {kNumberTypeUInt32, irpb::DT_UINT32},
80 {kNumberTypeUInt64, irpb::DT_UINT64},
81 {kNumberTypeFloat16, irpb::DT_FLOAT16},
82 {kNumberTypeFloat32, irpb::DT_FLOAT32},
83 {kNumberTypeFloat64, irpb::DT_FLOAT64},
84 {kNumberTypeBFloat16, irpb::DT_BFLOAT16},
85 {kNumberTypeInt, irpb::DT_BASE_INT},
86 {kNumberTypeUInt, irpb::DT_BASE_UINT},
87 {kNumberTypeFloat, irpb::DT_BASE_FLOAT},
88 {kNumberTypeComplex64, irpb::DT_COMPLEX64},
89 {kNumberTypeComplex128, irpb::DT_COMPLEX128},
90 {kObjectTypeString, irpb::DT_STRING},
91 {kObjectTypeTuple, irpb::DT_TUPLE},
92 {kNumberTypeInt4, irpb::DT_INT4}};
93
GetNumberDataType(const TypePtr & type)94 static irpb::DataType GetNumberDataType(const TypePtr &type) {
95 auto iter = number_data_type_map.find(type->type_id());
96 if (iter != number_data_type_map.end()) {
97 return (*iter).second;
98 } else {
99 MS_LOG(INTERNAL_EXCEPTION) << "Unexpected type " << type->type_name();
100 }
101 }
102
IsKindOfTensorType(const TypePtr & type)103 static inline bool IsKindOfTensorType(const TypePtr &type) {
104 MS_EXCEPTION_IF_NULL(type);
105 return type->isa<TensorType>() || type->isa<RowTensorType>() || type->isa<CSRTensorType>() ||
106 type->isa<COOTensorType>() || type->isa<MapTensorType>();
107 }
108
CheckIfValidType(const TypePtr & type)109 void CheckIfValidType(const TypePtr &type) {
110 MS_EXCEPTION_IF_NULL(type);
111 if (type->isa<Problem>()) {
112 MS_LOG(WARNING) << "The type: " << type->type_name();
113 return;
114 }
115 if (!(type->isa<Number>() || IsKindOfTensorType(type) || type->isa<Tuple>() || type->isa<TypeType>() ||
116 type->isa<List>() || type->isa<TypeAny>() || type->isa<RefKeyType>() || type->isa<RefType>() ||
117 type->isa<Function>() || type->isa<TypeNone>() || type->isa<String>() || type->isa<UndeterminedType>() ||
118 type->isa<SymbolicKeyType>() || type->isa<MonadType>() || type->isa<Dictionary>()) ||
119 type->isa<Slice>()) {
120 MS_LOG(INTERNAL_EXCEPTION) << "Unknown type: " << type->type_name();
121 }
122 }
123
SetTensorType(const TypePtr & type,const BaseShapePtr & shape,irpb::TypeProto * const type_proto)124 void SetTensorType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *const type_proto) {
125 TypePtr elem_type = dyn_cast<TensorType>(type)->element();
126 type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type));
127 type_proto->set_data_type(irpb::DT_TENSOR);
128 if (shape != nullptr && shape->isa<abstract::Shape>()) {
129 abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape);
130 for (const auto &elem : shape_info->shape()) {
131 type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
132 }
133 }
134 }
135
SetNodeOutputType(const TypePtr & type,const BaseShapePtr & shape,irpb::TypeProto * type_proto)136 void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *type_proto) {
137 if (type_proto == nullptr) {
138 return;
139 }
140 if (type == nullptr) {
141 type_proto->set_data_type(irpb::DT_UNDEFINED);
142 return;
143 }
144 if (type->isa<External>()) {
145 return;
146 }
147 CheckIfValidType(type);
148 if (type->isa<Number>()) {
149 type_proto->set_data_type(GetNumberDataType(type));
150 } else if (type->isa<TensorType>()) {
151 SetTensorType(type, shape, type_proto);
152 } else if (type->isa<Tuple>()) {
153 TuplePtr tuple_type = dyn_cast<Tuple>(type);
154 type_proto->set_data_type(irpb::DT_TUPLE);
155 for (const auto &elem_type : tuple_type->elements()) {
156 SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
157 }
158 } else if (type->isa<TypeType>()) {
159 type_proto->set_data_type(irpb::DT_TYPE);
160 } else if (type->isa<List>()) {
161 ListPtr list_type = dyn_cast<List>(type);
162 type_proto->set_data_type(irpb::DT_LIST);
163 for (const auto &elem_type : list_type->elements()) {
164 SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
165 }
166 } else if (type->isa<TypeAny>()) {
167 type_proto->set_data_type(irpb::DT_ANY);
168 } else if (type->isa<RefKeyType>()) {
169 type_proto->set_data_type(irpb::DT_REFKEY);
170 } else if (type->isa<RefType>()) {
171 type_proto->set_data_type(irpb::DT_REF);
172 } else if (type->isa<Function>()) {
173 type_proto->set_data_type(irpb::DT_GRAPH);
174 } else if (type->isa<TypeNone>()) {
175 type_proto->set_data_type(irpb::DT_NONE);
176 } else if (type->isa<String>()) {
177 type_proto->set_data_type(irpb::DT_STRING);
178 } else {
179 type_proto->set_data_type(irpb::DT_SLICE);
180 }
181 }
182
SetNodeOutputType(const AnfNodePtr & node,irpb::TypeProto * type_proto)183 void ProtoExporter::SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto) {
184 if (node == nullptr || type_proto == nullptr) {
185 return;
186 }
187 SetNodeOutputType(node->Type(), node->Shape(), type_proto);
188 }
189
SetValueToProtoBasicTypes(const ValuePtr & val,irpb::ValueProto * const value_proto) const190 void ProtoExporter::SetValueToProtoBasicTypes(const ValuePtr &val, irpb::ValueProto *const value_proto) const {
191 if (val->isa<StringImm>()) {
192 const StringImmPtr &value = dyn_cast<StringImm>(val);
193 value_proto->set_dtype(irpb::DT_STRING);
194 value_proto->set_str_val(value->value());
195 } else if (val->isa<Scalar>()) {
196 SetScalarToProto(dyn_cast<Scalar>(val), value_proto);
197 } else if (val->isa<Bool>()) {
198 value_proto->set_dtype(irpb::DT_TYPE);
199 value_proto->mutable_type_val()->set_data_type(irpb::DT_BOOL);
200 } else if (val->isa<Int>()) {
201 value_proto->set_dtype(irpb::DT_TYPE);
202 value_proto->mutable_type_val()->set_data_type(irpb::DT_BASE_INT);
203 } else if (val->isa<UInt>()) {
204 value_proto->set_dtype(irpb::DT_TYPE);
205 value_proto->mutable_type_val()->set_data_type(irpb::DT_BASE_UINT);
206 } else if (val->isa<Float>() || val->isa<BFloat>()) {
207 value_proto->set_dtype(irpb::DT_TYPE);
208 value_proto->mutable_type_val()->set_data_type(irpb::DT_BASE_FLOAT);
209 }
210 }
211
SetValueToProto(const ValuePtr & val,irpb::ValueProto * value_proto)212 void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value_proto) {
213 if (val == nullptr || value_proto == nullptr) {
214 return;
215 }
216
217 SetValueToProtoBasicTypes(val, value_proto);
218
219 if (val->isa<ValueSequence>()) {
220 SetSequenceToProto(dyn_cast<ValueSequence>(val), value_proto);
221 } else if (val->isa<None>()) {
222 value_proto->set_dtype(irpb::DT_NONE);
223 value_proto->set_str_val("None");
224 } else if (val->isa<SymbolicKeyInstance>()) {
225 SymbolicKeyInstancePtr sym_inst = dyn_cast<SymbolicKeyInstance>(val);
226 ParameterPtr sym_node = dyn_cast<Parameter>(sym_inst->node());
227 value_proto->set_dtype(irpb::DT_SYM_INST);
228 value_proto->set_str_val(sym_node == nullptr ? std::string("nullptr") : sym_node->ToString());
229 } else if (val->isa<ValueDictionary>()) {
230 SetDictionaryToProto(dyn_cast<ValueDictionary>(val), value_proto);
231 } else if (val->isa<tensor::Tensor>()) {
232 tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(val);
233 value_proto->set_dtype(irpb::DT_TENSOR);
234 irpb::TensorProto *tensor_proto = value_proto->mutable_tensor_val();
235 tensor_proto->set_data_type(GetNumberDataType(tensor_ptr->Dtype()));
236 for (auto &elem : tensor_ptr->shape()) {
237 tensor_proto->add_dims(elem);
238 }
239 } else if (val->isa<TensorType>()) {
240 value_proto->set_dtype(irpb::DT_TYPE);
241
242 irpb::TypeProto *type_proto = value_proto->mutable_type_val();
243 type_proto->set_data_type(irpb::DT_TENSOR);
244 TypePtr elem_type = dyn_cast<TensorType>(val)->element();
245 type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type));
246 } else if (val->isa<Monad>() || val->isa<MonadType>()) {
247 value_proto->set_str_val(val->ToString());
248 } else if (val->isa<Complex>()) {
249 value_proto->set_dtype(irpb::DT_TYPE);
250 value_proto->mutable_type_val()->set_data_type(irpb::DT_BASE_COMPLEX);
251 } else {
252 MS_LOG(DEBUG) << "Unsupported type " << val->type_name();
253 }
254 }
255
SetScalarToProto(const ScalarPtr & val,irpb::ValueProto * value_proto) const256 void ProtoExporter::SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto) const {
257 if (val == nullptr || value_proto == nullptr) {
258 return;
259 }
260
261 if (val->isa<BoolImm>()) {
262 const BoolImmPtr &value = dyn_cast<BoolImm>(val);
263 value_proto->set_dtype(irpb::DT_BOOL);
264 value_proto->set_bool_val(value->value());
265 } else if (val->isa<Int8Imm>()) {
266 const Int8ImmPtr &value = dyn_cast<Int8Imm>(val);
267 value_proto->set_dtype(irpb::DT_INT8);
268 value_proto->set_int_val(value->value());
269 } else if (val->isa<Int16Imm>()) {
270 const Int16ImmPtr &value = dyn_cast<Int16Imm>(val);
271 value_proto->set_dtype(irpb::DT_INT16);
272 value_proto->set_int_val(value->value());
273 } else if (val->isa<Int32Imm>()) {
274 const Int32ImmPtr &value = dyn_cast<Int32Imm>(val);
275 value_proto->set_dtype(irpb::DT_INT32);
276 value_proto->set_int_val(value->value());
277 } else if (val->isa<Int64Imm>()) {
278 const Int64ImmPtr &value = dyn_cast<Int64Imm>(val);
279 value_proto->set_dtype(irpb::DT_INT64);
280 value_proto->set_int_val(value->value());
281 } else if (val->isa<UInt8Imm>()) {
282 const UInt8ImmPtr &value = dyn_cast<UInt8Imm>(val);
283 value_proto->set_dtype(irpb::DT_UINT8);
284 value_proto->set_uint_val(value->value());
285 } else if (val->isa<UInt16Imm>()) {
286 const UInt16ImmPtr &value = dyn_cast<UInt16Imm>(val);
287 value_proto->set_dtype(irpb::DT_UINT16);
288 value_proto->set_uint_val(value->value());
289 } else if (val->isa<UInt32Imm>()) {
290 const UInt32ImmPtr &value = dyn_cast<UInt32Imm>(val);
291 value_proto->set_dtype(irpb::DT_UINT32);
292 value_proto->set_uint_val(value->value());
293 } else if (val->isa<UInt64Imm>()) {
294 const UInt64ImmPtr &value = dyn_cast<UInt64Imm>(val);
295 value_proto->set_dtype(irpb::DT_UINT64);
296 value_proto->set_uint_val(value->value());
297 } else if (val->isa<FP32Imm>()) {
298 const FP32ImmPtr &value = dyn_cast<FP32Imm>(val);
299 value_proto->set_dtype(irpb::DT_FLOAT32);
300 value_proto->set_float_val(value->value());
301 } else if (val->isa<FP64Imm>()) {
302 const FP64ImmPtr &value = dyn_cast<FP64Imm>(val);
303 value_proto->set_dtype(irpb::DT_FLOAT64);
304 value_proto->set_double_val(value->value());
305 } else {
306 MS_LOG(INTERNAL_EXCEPTION) << "Unknown scalar type " << val->ToString();
307 }
308 }
309
SetSequenceToProto(const ValueSequencePtr & val,irpb::ValueProto * value_proto)310 void ProtoExporter::SetSequenceToProto(const ValueSequencePtr &val, irpb::ValueProto *value_proto) {
311 if (val == nullptr || value_proto == nullptr) {
312 return;
313 }
314
315 if (val->isa<ValueTuple>()) {
316 const ValueTuplePtr &value = dyn_cast<ValueTuple>(val);
317 value_proto->set_dtype(irpb::DT_TUPLE);
318 for (const auto &item : value->value()) {
319 SetValueToProto(item, value_proto->add_values());
320 }
321 } else if (val->isa<ValueList>()) {
322 const ValueListPtr &value = dyn_cast<ValueList>(val);
323 value_proto->set_dtype(irpb::DT_LIST);
324 for (const auto &item : value->value()) {
325 SetValueToProto(item, value_proto->add_values());
326 }
327 }
328 }
329
SetDictionaryToProto(const ValueDictionaryPtr & val,irpb::ValueProto * value_proto)330 void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto) {
331 if (val == nullptr || value_proto == nullptr) {
332 return;
333 }
334
335 value_proto->set_dtype(irpb::DT_DICT);
336 for (const auto &item : val->value()) {
337 irpb::NamedValueProto *named_val = value_proto->add_dict_val();
338 MS_EXCEPTION_IF_NULL(item.first);
339 if (!item.first->isa<StringImm>()) {
340 MS_LOG(INTERNAL_EXCEPTION) << "The key of NamedValueProto should be string type, but got "
341 << item.first->ToString();
342 }
343 named_val->set_key(GetValue<std::string>(item.first));
344 SetValueToProto(item.second, named_val->mutable_value());
345 }
346 }
347
GetOpNodeTypeAndAttrs(const FuncGraphPtr &,const CNodePtr & cnode,irpb::NodeProto * node_proto)348 void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr & /* func_graph */, const CNodePtr &cnode,
349 irpb::NodeProto *node_proto) {
350 const auto &inputs = cnode->inputs();
351 AnfNodePtr op_node = inputs[0];
352
353 if (op_node == nullptr || node_proto == nullptr) {
354 return;
355 }
356
357 if (op_node->isa<CNode>() || op_node->isa<Parameter>() || IsValueNode<FuncGraph>(op_node)) {
358 MS_LOG(INTERNAL_EXCEPTION) << "Op node can not be CNode, Parameter or ValueNode Graph. But got "
359 << op_node->ToString();
360 }
361
362 if (!IsValueNode<Primitive>(op_node)) {
363 MS_LOG(INTERNAL_EXCEPTION) << "Op node is not primitive: " << op_node->ToString();
364 }
365
366 const PrimitivePtr &prim = GetValueNode<PrimitivePtr>(op_node);
367 node_proto->set_op_type(prim->name());
368 for (const auto &attr : prim->attrs()) {
369 irpb::AttributeProto *attr_proto = node_proto->add_attribute();
370 attr_proto->set_name(attr.first);
371 SetValueToProto(attr.second, attr_proto->mutable_value());
372 }
373
374 // Only CNode save the operator strategy
375 auto strategy_value = AnfDumpHandler::InStrategyValue(cnode);
376 if (strategy_value != nullptr) {
377 irpb::AttributeProto *attr_proto = node_proto->add_attribute();
378 attr_proto->set_name(mindspore::parallel::IN_STRATEGY);
379 SetValueToProto(strategy_value, attr_proto->mutable_value());
380 }
381
382 node_proto->set_scope(op_node->scope()->name());
383 }
384
GetOpNodeInputId(const FuncGraphPtr &,const AnfNodePtr & node,const std::map<AnfNodePtr,size_t> & apply_map,std::map<AnfNodePtr,size_t> * const_map_ptr) const385 std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr & /* func_graph */, const AnfNodePtr &node,
386 const std::map<AnfNodePtr, size_t> &apply_map,
387 std::map<AnfNodePtr, size_t> *const_map_ptr) const {
388 if (node == nullptr || const_map_ptr == nullptr) {
389 return "";
390 }
391
392 if (node->isa<CNode>()) {
393 auto iter = apply_map.find(node);
394 if (iter == apply_map.end()) {
395 MS_LOG(INTERNAL_EXCEPTION) << "Can not find node '" << node->ToString() << "' in apply_map";
396 }
397 return std::to_string(iter->second);
398 }
399
400 if (node->isa<Parameter>()) {
401 return node->ToString();
402 }
403
404 if (AnfUtils::IsCustomActorNode(node)) {
405 return AnfUtils::GetCustomActorName(node);
406 }
407
408 if (node->isa<ValueNode>()) {
409 auto iter = const_map_ptr->find(node);
410 if (iter == const_map_ptr->end()) {
411 // Start index number from 1
412 auto const_idx = const_map_ptr->size() + 1;
413 (*const_map_ptr)[node] = const_idx;
414 }
415 return GetConstNodeId((*const_map_ptr)[node]);
416 }
417
418 MS_LOG(INTERNAL_EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'";
419 }
420
GetFuncGraphProtoString(const FuncGraphPtr & func_graph)421 std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph) {
422 if (func_graph == nullptr) {
423 return "";
424 }
425
426 InitModelInfo();
427 irpb::GraphProto *graph_proto = model_.mutable_graph();
428 ExportFuncGraph(func_graph, graph_proto);
429 return model_.SerializeAsString();
430 }
431
ExportFuncGraph(const FuncGraphPtr & func_graph,irpb::GraphProto * graph_proto)432 void ProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) {
433 if (func_graph == nullptr || graph_proto == nullptr) {
434 return;
435 }
436
437 // map for store ValueNodes of this graph
438 std::map<AnfNodePtr, size_t> const_map;
439
440 // set graph name
441 graph_proto->set_name(func_graph->ToString());
442
443 ExportParameters(func_graph, graph_proto);
444
445 ExportCNodes(func_graph, graph_proto, &const_map);
446
447 ExportValueNodes(const_map, graph_proto);
448 }
449
ExportParameters(const FuncGraphPtr & func_graph,irpb::GraphProto * graph_proto)450 void ProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) {
451 if (func_graph == nullptr || graph_proto == nullptr) {
452 return;
453 }
454
455 std::vector<AnfNodePtr> parameters = func_graph->parameters();
456 for (auto ¶m : parameters) {
457 irpb::ParameterProto *param_proto = graph_proto->add_parameters();
458 param_proto->set_name(param->ToString());
459
460 SetNodeOutputType(param, param_proto->mutable_type());
461
462 const ParameterPtr param_ptr = dyn_cast<Parameter>(param);
463 if (param_ptr == nullptr) {
464 MS_LOG(INTERNAL_EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter.";
465 }
466 }
467 }
468
ExportCNodes(const FuncGraphPtr & func_graph,irpb::GraphProto * graph_proto,std::map<AnfNodePtr,size_t> * const_map_ptr)469 void ProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto,
470 std::map<AnfNodePtr, size_t> *const_map_ptr) {
471 if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) {
472 return;
473 }
474 // topo sort nodes
475 std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
476 std::map<AnfNodePtr, size_t> apply_map;
477 for (const AnfNodePtr &node : nodes) {
478 MS_EXCEPTION_IF_NULL(node);
479 if (!node->isa<CNode>()) {
480 continue;
481 }
482 auto cnode = node->cast<CNodePtr>();
483 if (cnode != func_graph->get_return()) {
484 ExportCNode(func_graph, cnode, &apply_map, const_map_ptr, graph_proto);
485 } else {
486 ExportFuncGraphOutput(func_graph, cnode, apply_map, const_map_ptr, graph_proto);
487 }
488 }
489 }
490
ExportCNode(const FuncGraphPtr & func_graph,const CNodePtr & node,std::map<AnfNodePtr,size_t> * apply_map_ptr,std::map<AnfNodePtr,size_t> * const_map_ptr,irpb::GraphProto * graph_proto)491 void ProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
492 std::map<AnfNodePtr, size_t> *apply_map_ptr,
493 std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) {
494 if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr ||
495 graph_proto == nullptr) {
496 return;
497 }
498
499 auto apply_idx = apply_map_ptr->size() + 1;
500 (*apply_map_ptr)[node] = apply_idx;
501
502 auto &inputs = node->inputs();
503 if (inputs.size() < 1) {
504 MS_LOG(INTERNAL_EXCEPTION) << "Inputs of CNode is empty";
505 }
506 AnfNodePtr op = inputs[0];
507 irpb::NodeProto *node_proto = graph_proto->add_node();
508
509 // CNode/ConstGraph/Const/Parameter
510 if (op->isa<CNode>() || IsValueNode<FuncGraph>(op) || op->isa<Parameter>()) {
511 MS_LOG(DEBUG) << "Operator must be a primitive";
512 } else {
513 GetOpNodeTypeAndAttrs(func_graph, node, node_proto);
514 node_proto->set_name(std::to_string(apply_idx));
515 node_proto->set_scope(node->scope()->name());
516 node_proto->set_full_name(GetKernelNodeName(node));
517
518 // process OP inputs
519 for (size_t i = 1; i < inputs.size(); ++i) {
520 irpb::InputProto *input_proto = node_proto->add_input();
521 input_proto->set_type(irpb::InputProto_EdgeType_DATA_EDGE);
522 std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr);
523 input_proto->set_name(id);
524 }
525
526 // set node output type
527 SetNodeOutputType(node, node_proto->mutable_output_type());
528
529 if (IsValueNode<Primitive>(op)) {
530 PrimitivePtr primitive = GetValueNode<PrimitivePtr>(op);
531 if (!primitive->instance_name().empty()) {
532 node_proto->set_instance_name(primitive->instance_name());
533 }
534 }
535 }
536 }
537
ExportFuncGraphOutput(const FuncGraphPtr & func_graph,const CNodePtr & ret_node,const std::map<AnfNodePtr,size_t> & apply_map,std::map<AnfNodePtr,size_t> * const_map_ptr,irpb::GraphProto * graph_proto)538 void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
539 const std::map<AnfNodePtr, size_t> &apply_map,
540 std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) {
541 if (ret_node == nullptr || !ret_node->isa<CNode>()) {
542 MS_LOG(INTERNAL_EXCEPTION) << "Graph return node is illegal";
543 }
544 // ret node has two input 1 ret op + 1 value
545 const size_t ret_input_size = 2;
546 if (ret_node->size() != ret_input_size) {
547 return;
548 }
549 AnfNodePtr arg = ret_node->input(1);
550 if (graph_proto == nullptr) {
551 MS_LOG(INTERNAL_EXCEPTION) << "graph_proto is nullptr";
552 }
553 irpb::OutputProto *output_proto = graph_proto->add_outputs();
554 if (output_proto == nullptr) {
555 MS_LOG(INTERNAL_EXCEPTION) << "output_proto is nullptr";
556 }
557 std::string id = GetOpNodeInputId(func_graph, arg, apply_map, const_map_ptr);
558 output_proto->set_name(id);
559 SetNodeOutputType(arg, output_proto->mutable_type());
560 }
561
CompareValue(const std::pair<AnfNodePtr,size_t> & x,const std::pair<AnfNodePtr,size_t> & y)562 static bool CompareValue(const std::pair<AnfNodePtr, size_t> &x, const std::pair<AnfNodePtr, size_t> &y) {
563 return x.second < y.second;
564 }
565
ExportValueNodes(const std::map<AnfNodePtr,size_t> & const_map,irpb::GraphProto * graph_proto)566 void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto) {
567 std::vector<std::pair<AnfNodePtr, size_t>> nodes;
568 (void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes),
569 [](const std::pair<AnfNodePtr, size_t> &item) { return item; });
570
571 sort(nodes.begin(), nodes.end(), CompareValue);
572
573 for (auto &item : nodes) {
574 if (graph_proto == nullptr) {
575 MS_LOG(INTERNAL_EXCEPTION) << "graph_proto is nullptr";
576 }
577 irpb::NamedValueProto *named_value = graph_proto->add_const_vals();
578 MS_EXCEPTION_IF_NULL(named_value);
579 named_value->set_key(GetConstNodeId(item.second));
580 SetValueToProto(GetValueNode(item.first), named_value->mutable_value());
581 }
582 }
583
InitModelInfo()584 void ProtoExporter::InitModelInfo() { model_.set_ir_version(static_cast<int64_t>(irpb::IR_VERSION)); }
585
GetFuncGraphProtoString(const FuncGraphPtr & func_graph)586 std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) {
587 ProtoExporter exporter;
588 return exporter.GetFuncGraphProtoString(func_graph);
589 }
590
GetFuncGraphProtoJsonString(const FuncGraphPtr & func_graph)591 std::string GetFuncGraphProtoJsonString(const FuncGraphPtr &func_graph) {
592 ProtoExporter exporter;
593 irpb::GraphProto graph_proto = irpb::GraphProto();
594 exporter.ExportFuncGraph(func_graph, &graph_proto);
595 std::string graph_proto_str;
596 (void)google::protobuf::util::MessageToJsonString(graph_proto, &graph_proto_str);
597 return graph_proto_str;
598 }
599
600 #ifdef ENABLE_DUMP_IR
DumpIRProto(const FuncGraphPtr & func_graph,const std::string & suffix)601 void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) {
602 if (func_graph == nullptr) {
603 MS_LOG(ERROR) << "Func graph is nullptr";
604 return;
605 }
606 std::string file_path = GetSaveGraphsPathName("ms_output_" + suffix + ".pb");
607 auto realpath = Common::CreatePrefixPath(file_path);
608 if (!realpath.has_value()) {
609 MS_LOG(ERROR) << "Get real path failed, path=" << file_path;
610 return;
611 }
612
613 ChangeFileMode(realpath.value(), S_IWUSR);
614 // write to pb file
615 std::ofstream ofs(file_path);
616 if (!ofs.is_open()) {
617 MS_LOG(ERROR) << "Open file '" << file_path << "' failed!" << ErrnoToString(errno);
618 return;
619 }
620 ofs << GetFuncGraphProtoString(func_graph);
621 ofs.close();
622 // set file mode to read only by user
623 ChangeFileMode(file_path, S_IRUSR);
624 }
625 #else
DumpIRProto(const FuncGraphPtr &,const std::string &)626 void DumpIRProto(const FuncGraphPtr &, const std::string &) {
627 static bool already_printed = false;
628 if (already_printed) {
629 return;
630 }
631 already_printed = true;
632 MS_LOG(WARNING) << "The functionality of dumping function graph IR in protobuf format is disabled, "
633 << "please recompile source to enable it. See help of building script.";
634 }
635 #endif
636 } // namespace mindspore
637