1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "debug/debugger/proto_exporter.h"

#include <algorithm>
#include <fstream>
#include <iterator>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "debug/anf_ir_utils.h"
#include "debug/common.h"
#include "debug/debugger/debugger.h"
#include "debug/data_dump/dump_json_parser.h"
#include "proto/debug_graph.pb.h"
#include "ir/graph_utils.h"
#include "utils/symbolic.h"
#include "utils/trace_base.h"
34
35 namespace mindspore {
36
// Ordered lookup table: RTTI type name (as produced by typeid(T).name()) -> debugger proto data type.
using TypeInfoToProtoTypeMap = std::vector<std::pair<const char *, debugger::DataType>>;

// Forward declaration: SetOutputType and the tuple/list helpers defined below call each other recursively.
void SetOutputType(const TypePtr &node, const BaseShapePtr &shape, debugger::TypeProto *type_proto);
40
CheckIfValidType(const TypePtr & type,debugger::TypeProto * const type_proto)41 void CheckIfValidType(const TypePtr &type, debugger::TypeProto *const type_proto) {
42 if (!(type->isa<Number>() || type->isa<TensorType>() || type->isa<Tuple>() || type->isa<TypeType>() ||
43 type->isa<List>() || type->isa<TypeAnything>() || type->isa<RefKeyType>() || type->isa<RefType>() ||
44 type->isa<Function>() || type->isa<TypeNone>() || type->isa<String>() || type->isa<SymbolicKeyType>() ||
45 type->isa<UMonadType>() || type->isa<IOMonadType>())) {
46 MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
47 }
48 if (type->isa<Number>()) {
49 type_proto->set_data_type(GetDebuggerNumberDataType(type));
50 }
51 }
52
SetTensorTypeProto(const TypePtr & type,const BaseShapePtr & shape,debugger::TypeProto * type_proto)53 void SetTensorTypeProto(const TypePtr &type, const BaseShapePtr &shape, debugger::TypeProto *type_proto) {
54 TypePtr elem_type = dyn_cast<TensorType>(type)->element();
55 type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type));
56 if (shape != nullptr && shape->isa<abstract::Shape>()) {
57 abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape);
58 for (const auto &elem : shape_info->shape()) {
59 type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
60 }
61 }
62 }
63
SetTupleTypeProto(const TypePtr & type,debugger::TypeProto * type_proto)64 void SetTupleTypeProto(const TypePtr &type, debugger::TypeProto *type_proto) {
65 TuplePtr tuple_type = dyn_cast<Tuple>(type);
66 for (const auto &elem_type : tuple_type->elements()) {
67 SetOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
68 }
69 }
70
SetListTypeProto(const TypePtr & type,debugger::TypeProto * type_proto)71 void SetListTypeProto(const TypePtr &type, debugger::TypeProto *type_proto) {
72 ListPtr list_type = dyn_cast<List>(type);
73 for (const auto &elem_type : list_type->elements()) {
74 SetOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
75 }
76 }
77
78 static TypeInfoToProtoTypeMap type_info_to_proto_type = {
79 {typeid(TensorType).name(), debugger::DT_TENSOR}, {typeid(Tuple).name(), debugger::DT_TUPLE},
80 {typeid(TypeType).name(), debugger::DT_TYPE}, {typeid(List).name(), debugger::DT_LIST},
81 {typeid(TypeAnything).name(), debugger::DT_ANYTHING}, {typeid(RefKeyType).name(), debugger::DT_REFKEY},
82 {typeid(RefType).name(), debugger::DT_REF}, {typeid(Function).name(), debugger::DT_GRAPH},
83 {typeid(TypeNone).name(), debugger::DT_NONE}, {typeid(String).name(), debugger::DT_STRING},
84 {typeid(UMonadType).name(), debugger::DT_UMONAD}, {typeid(IOMonadType).name(), debugger::DT_IOMONAD}};
85
SetOutputType(const TypePtr & type,const BaseShapePtr & shape,debugger::TypeProto * type_proto)86 void SetOutputType(const TypePtr &type, const BaseShapePtr &shape, debugger::TypeProto *type_proto) {
87 if (type_proto == nullptr) {
88 return;
89 }
90 if (type == nullptr) {
91 type_proto->set_data_type(debugger::DT_UNDEFINED);
92 return;
93 }
94 CheckIfValidType(type, type_proto);
95 for (auto &it : type_info_to_proto_type) {
96 if (type->IsFromTypeId(Base::GetTypeId(it.first))) {
97 type_proto->set_data_type(it.second);
98 break;
99 }
100 }
101 if (type->isa<TensorType>()) {
102 SetTensorTypeProto(type, shape, type_proto);
103 return;
104 }
105 if (type->isa<Tuple>()) {
106 SetTupleTypeProto(type, type_proto);
107 return;
108 }
109 if (type->isa<List>()) {
110 SetListTypeProto(type, type_proto);
111 }
112 }
113
SetNodeOutputType(const AnfNodePtr & node,debugger::TypeProto * type_proto)114 void DebuggerProtoExporter::SetNodeOutputType(const AnfNodePtr &node, debugger::TypeProto *type_proto) {
115 if (node == nullptr || type_proto == nullptr) {
116 return;
117 }
118 SetOutputType(node->Type(), node->Shape(), type_proto);
119 }
120
SetValueToProto(const ValuePtr & val,debugger::ValueProto * value_proto)121 void DebuggerProtoExporter::SetValueToProto(const ValuePtr &val, debugger::ValueProto *value_proto) {
122 if (val == nullptr || value_proto == nullptr) {
123 return;
124 }
125
126 if (val->isa<StringImm>()) {
127 const StringImmPtr &value = dyn_cast<StringImm>(val);
128 value_proto->set_dtype(debugger::DT_STRING);
129 value_proto->set_str_val(value->value());
130 } else if (val->isa<Scalar>()) {
131 SetScalarToProto(dyn_cast<Scalar>(val), value_proto);
132 } else if (val->isa<Bool>()) {
133 value_proto->set_dtype(debugger::DT_TYPE);
134 value_proto->mutable_type_val()->set_data_type(debugger::DT_BOOL);
135 } else if (val->isa<Int>()) {
136 value_proto->set_dtype(debugger::DT_TYPE);
137 value_proto->mutable_type_val()->set_data_type(debugger::DT_BASE_INT);
138 } else if (val->isa<Float>()) {
139 value_proto->set_dtype(debugger::DT_TYPE);
140 value_proto->mutable_type_val()->set_data_type(debugger::DT_BASE_FLOAT);
141 } else if (val->isa<ValueSequeue>()) {
142 SetSequenceToProto(dyn_cast<ValueSequeue>(val), value_proto);
143 } else if (val->isa<None>()) {
144 value_proto->set_dtype(debugger::DT_NONE);
145 value_proto->set_str_val("None");
146 } else if (val->isa<SymbolicKeyInstance>()) {
147 SymbolicKeyInstancePtr sym_inst = dyn_cast<SymbolicKeyInstance>(val);
148 ParameterPtr sym_node = dyn_cast<Parameter>(sym_inst->node());
149 value_proto->set_dtype(debugger::DT_SYM_INST);
150 value_proto->set_str_val(sym_node == nullptr ? std::string("nullptr") : sym_node->ToString());
151 } else if (val->isa<ValueDictionary>()) {
152 SetDictionaryToProto(dyn_cast<ValueDictionary>(val), value_proto);
153 } else if (val->isa<tensor::Tensor>()) {
154 tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(val);
155 value_proto->set_dtype(debugger::DT_TENSOR);
156 debugger::TensorProto *tensor_proto = value_proto->mutable_tensor_val();
157 tensor_proto->set_data_type(GetDebuggerNumberDataType(tensor_ptr->Dtype()));
158 for (auto &elem : tensor_ptr->shape()) {
159 tensor_proto->add_dims(elem);
160 }
161 tensor_proto->set_tensor_content(tensor_ptr->data_c(), tensor_ptr->data().nbytes());
162 } else if (val->isa<TensorType>()) {
163 value_proto->set_dtype(debugger::DT_TYPE);
164
165 debugger::TypeProto *type_proto = value_proto->mutable_type_val();
166 type_proto->set_data_type(debugger::DT_TENSOR);
167 TypePtr elem_type = dyn_cast<TensorType>(val)->element();
168 type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type));
169 } else {
170 MS_LOG(INFO) << "Unsupported type " << val->type_name();
171 }
172 }
173
SetScalarToProto(const ScalarPtr & val,debugger::ValueProto * value_proto)174 void DebuggerProtoExporter::SetScalarToProto(const ScalarPtr &val, debugger::ValueProto *value_proto) {
175 if (val == nullptr || value_proto == nullptr) {
176 return;
177 }
178
179 if (val->isa<BoolImm>()) {
180 const BoolImmPtr &value = dyn_cast<BoolImm>(val);
181 value_proto->set_dtype(debugger::DT_BOOL);
182 value_proto->set_bool_val(value->value());
183 } else if (val->isa<Int8Imm>()) {
184 const Int8ImmPtr &value = dyn_cast<Int8Imm>(val);
185 value_proto->set_dtype(debugger::DT_INT8);
186 value_proto->set_int_val(value->value());
187 } else if (val->isa<Int16Imm>()) {
188 const Int16ImmPtr &value = dyn_cast<Int16Imm>(val);
189 value_proto->set_dtype(debugger::DT_INT16);
190 value_proto->set_int_val(value->value());
191 } else if (val->isa<Int32Imm>()) {
192 const Int32ImmPtr &value = dyn_cast<Int32Imm>(val);
193 value_proto->set_dtype(debugger::DT_INT32);
194 value_proto->set_int_val(value->value());
195 } else if (val->isa<Int64Imm>()) {
196 const Int64ImmPtr &value = dyn_cast<Int64Imm>(val);
197 value_proto->set_dtype(debugger::DT_INT64);
198 value_proto->set_int_val(value->value());
199 } else if (val->isa<UInt8Imm>()) {
200 const UInt8ImmPtr &value = dyn_cast<UInt8Imm>(val);
201 value_proto->set_dtype(debugger::DT_UINT8);
202 value_proto->set_uint_val(value->value());
203 } else if (val->isa<UInt16Imm>()) {
204 const UInt16ImmPtr &value = dyn_cast<UInt16Imm>(val);
205 value_proto->set_dtype(debugger::DT_UINT16);
206 value_proto->set_uint_val(value->value());
207 } else if (val->isa<UInt32Imm>()) {
208 const UInt32ImmPtr &value = dyn_cast<UInt32Imm>(val);
209 value_proto->set_dtype(debugger::DT_UINT32);
210 value_proto->set_uint_val(value->value());
211 } else if (val->isa<UInt64Imm>()) {
212 const UInt64ImmPtr &value = dyn_cast<UInt64Imm>(val);
213 value_proto->set_dtype(debugger::DT_UINT64);
214 value_proto->set_uint_val(value->value());
215 } else if (val->isa<FP32Imm>()) {
216 const FP32ImmPtr &value = dyn_cast<FP32Imm>(val);
217 value_proto->set_dtype(debugger::DT_FLOAT32);
218 value_proto->set_float_val(value->value());
219 } else if (val->isa<FP64Imm>()) {
220 const FP64ImmPtr &value = dyn_cast<FP64Imm>(val);
221 value_proto->set_dtype(debugger::DT_FLOAT64);
222 value_proto->set_double_val(value->value());
223 } else {
224 MS_LOG(EXCEPTION) << "Unknown scalar type " << val->ToString();
225 }
226 }
227
SetSequenceToProto(const ValueSequeuePtr & val,debugger::ValueProto * value_proto)228 void DebuggerProtoExporter::SetSequenceToProto(const ValueSequeuePtr &val, debugger::ValueProto *value_proto) {
229 if (val == nullptr || value_proto == nullptr) {
230 return;
231 }
232
233 if (val->isa<ValueTuple>()) {
234 const ValueTuplePtr &value = dyn_cast<ValueTuple>(val);
235 value_proto->set_dtype(debugger::DT_TUPLE);
236 for (const auto &item : value->value()) {
237 SetValueToProto(item, value_proto->add_values());
238 }
239 } else if (val->isa<ValueList>()) {
240 const ValueListPtr &value = dyn_cast<ValueList>(val);
241 value_proto->set_dtype(debugger::DT_LIST);
242 for (const auto &item : value->value()) {
243 SetValueToProto(item, value_proto->add_values());
244 }
245 }
246 }
247
SetDictionaryToProto(const ValueDictionaryPtr & val,debugger::ValueProto * value_proto)248 void DebuggerProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val, debugger::ValueProto *value_proto) {
249 if (val == nullptr || value_proto == nullptr) {
250 return;
251 }
252
253 value_proto->set_dtype(debugger::DT_DICT);
254 for (const auto &item : val->value()) {
255 debugger::NamedValueProto *named_val = value_proto->add_dict_val();
256 named_val->set_key(item.first);
257 SetValueToProto(item.second, named_val->mutable_value());
258 }
259 }
260
GetOpNodeTypeAndAttrs(const FuncGraphPtr &,const AnfNodePtr & node,debugger::NodeProto * node_proto)261 void DebuggerProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr &, const AnfNodePtr &node,
262 debugger::NodeProto *node_proto) {
263 if (node == nullptr || node_proto == nullptr) {
264 return;
265 }
266
267 if (node->isa<CNode>() || node->isa<Parameter>() || IsValueNode<FuncGraph>(node)) {
268 MS_LOG(EXCEPTION) << "Op node can not be CNode, Parameter or ValueNode Graph. But got " << node->ToString();
269 }
270
271 if (!IsValueNode<Primitive>(node)) {
272 MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString();
273 }
274
275 const PrimitivePtr &prim = GetValueNode<PrimitivePtr>(node);
276 node_proto->set_op_type(prim->name());
277 for (const auto &attr : prim->attrs()) {
278 debugger::AttributeProto *attr_proto = node_proto->add_attribute();
279 attr_proto->set_name(attr.first);
280 SetValueToProto(attr.second, attr_proto->mutable_value());
281 }
282 node_proto->set_scope(node->scope()->name());
283 }
284
GetOpNodeInputId(const FuncGraphPtr &,const AnfNodePtr & node,const std::map<AnfNodePtr,size_t> & apply_map,std::map<AnfNodePtr,size_t> * const_map_ptr)285 std::string DebuggerProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodePtr &node,
286 const std::map<AnfNodePtr, size_t> &apply_map,
287 std::map<AnfNodePtr, size_t> *const_map_ptr) {
288 if (node == nullptr || const_map_ptr == nullptr) {
289 return "";
290 }
291
292 if (node->isa<CNode>()) {
293 auto iter = apply_map.find(node);
294 if (iter == apply_map.end()) {
295 MS_LOG(EXCEPTION) << "Can not find node '" << node->ToString() << "' in apply_map";
296 }
297 return std::to_string(iter->second);
298 }
299
300 if (node->isa<Parameter>()) {
301 return node->ToString();
302 }
303
304 if (node->isa<ValueNode>()) {
305 auto iter = const_map_ptr->find(node);
306 if (iter == const_map_ptr->end()) {
307 // Start index number from 1
308 auto const_idx = const_map_ptr->size() + 1;
309 (*const_map_ptr)[node] = const_idx;
310 }
311 return GetConstNodeId((*const_map_ptr)[node]);
312 }
313
314 MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'";
315 }
316
GetFuncGraphProtoString(const FuncGraphPtr & func_graph,LocDebugDumpMode dump_location)317 std::string DebuggerProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph,
318 LocDebugDumpMode dump_location) {
319 if (func_graph == nullptr) {
320 return "";
321 }
322
323 InitModelInfo();
324 debugger::GraphProto *graph_proto = model_.mutable_graph();
325 ExportFuncGraph(func_graph, graph_proto, dump_location);
326 return model_.SerializeAsString();
327 }
328
GetFuncGraphProto(const FuncGraphPtr & func_graph)329 debugger::ModelProto DebuggerProtoExporter::GetFuncGraphProto(const FuncGraphPtr &func_graph) {
330 if (func_graph == nullptr) {
331 return ModelProto();
332 }
333
334 InitModelInfo();
335 debugger::GraphProto *graph_proto = model_.mutable_graph();
336 ExportFuncGraph(func_graph, graph_proto);
337 return model_;
338 }
339
ExportFuncGraph(const FuncGraphPtr & func_graph,debugger::GraphProto * const graph_proto,LocDebugDumpMode dump_location)340 void DebuggerProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, debugger::GraphProto *const graph_proto,
341 LocDebugDumpMode dump_location) {
342 if (func_graph == nullptr || graph_proto == nullptr) {
343 return;
344 }
345
346 // map for store ValueNodes of this graph
347 std::map<AnfNodePtr, size_t> const_map;
348
349 // set graph name
350 graph_proto->set_name(func_graph->ToString());
351
352 MS_LOG(INFO) << "graph names: " << func_graph->ToString();
353
354 // cast FuncGraph to KernelGraph to access root_graph_id()
355 uint32_t root_graph_id = static_cast<session::KernelGraph *>(func_graph.get())->root_graph_id();
356
357 MS_LOG(INFO) << "root graph id: " << root_graph_id;
358
359 // set root graph id
360 graph_proto->set_root_name(std::to_string(root_graph_id));
361
362 ExportParameters(func_graph, graph_proto);
363
364 ExportCNodes(func_graph, graph_proto, &const_map, dump_location);
365
366 ExportValueNodes(const_map, graph_proto);
367 }
368
ExportParameters(const FuncGraphPtr & func_graph,debugger::GraphProto * graph_proto)369 void DebuggerProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto) {
370 if (func_graph == nullptr || graph_proto == nullptr) {
371 return;
372 }
373
374 // cast FuncGraph to KernelGraph to access inputs()
375 std::vector<AnfNodePtr> parameters = static_cast<session::KernelGraph *>(func_graph.get())->inputs();
376
377 for (auto ¶m : parameters) {
378 debugger::ParameterProto *param_proto = graph_proto->add_parameters();
379 param_proto->set_name(param->ToString());
380
381 SetNodeOutputType(param, param_proto->mutable_type());
382
383 const ParameterPtr param_ptr = dyn_cast<Parameter>(param);
384 if (param_ptr == nullptr) {
385 MS_LOG(INFO) << "Parameter '" << param->ToString() << "' could not cast to parameter.";
386 }
387 }
388 }
389
ExportCNodes(const FuncGraphPtr & func_graph,debugger::GraphProto * const graph_proto,std::map<AnfNodePtr,size_t> * const_map_ptr,LocDebugDumpMode dump_location)390 void DebuggerProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, debugger::GraphProto *const graph_proto,
391 std::map<AnfNodePtr, size_t> *const_map_ptr, LocDebugDumpMode dump_location) {
392 if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) {
393 return;
394 }
395 // topo sort nodes
396 std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
397 std::map<AnfNodePtr, size_t> apply_map;
398 for (const AnfNodePtr &node : nodes) {
399 MS_EXCEPTION_IF_NULL(node);
400 if (!node->isa<CNode>()) {
401 continue;
402 }
403 auto cnode = node->cast<CNodePtr>();
404 if (cnode != func_graph->get_return()) {
405 ExportCNode(func_graph, cnode, &apply_map, const_map_ptr, graph_proto, dump_location);
406 } else {
407 ExportFuncGraphOutput(func_graph, cnode, apply_map, const_map_ptr, graph_proto);
408 }
409 }
410 }
411
ExportCNode(const FuncGraphPtr & func_graph,const CNodePtr & node,std::map<AnfNodePtr,size_t> * apply_map_ptr,std::map<AnfNodePtr,size_t> * const_map_ptr,debugger::GraphProto * const graph_proto,LocDebugDumpMode dump_location)412 void DebuggerProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
413 std::map<AnfNodePtr, size_t> *apply_map_ptr,
414 std::map<AnfNodePtr, size_t> *const_map_ptr,
415 debugger::GraphProto *const graph_proto, LocDebugDumpMode dump_location) {
416 if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr ||
417 graph_proto == nullptr) {
418 return;
419 }
420
421 auto apply_idx = apply_map_ptr->size() + 1;
422 (*apply_map_ptr)[node] = apply_idx;
423
424 auto &inputs = node->inputs();
425 if (inputs.size() < 1) {
426 MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
427 }
428 AnfNodePtr op = inputs[0];
429 debugger::NodeProto *node_proto = graph_proto->add_node();
430
431 // CNode/ConstGraph/Const/Parameter
432 if (op->isa<CNode>() || IsValueNode<FuncGraph>(op) || op->isa<Parameter>()) {
433 MS_LOG(WARNING) << "Operator must be a primitive";
434 } else {
435 GetOpNodeTypeAndAttrs(func_graph, op, node_proto);
436 node_proto->set_name(std::to_string(apply_idx));
437 node_proto->set_scope(node->scope()->name());
438
439 // add full_name for debugger
440 std::string full_name = GetKernelNodeName(node);
441 node_proto->set_full_name(full_name);
442 MS_LOG(INFO) << "full_name: " << full_name;
443 if (dump_location == kDebugWholeStack) {
444 std::ostringstream buffer;
445 auto traces = mindspore::trace::GetSourceLineList(node);
446 for (auto &trace : traces) {
447 buffer << " # " << trace;
448 }
449 node_proto->set_source_address(buffer.str());
450 }
451 // process OP inputs
452 for (size_t i = 1; i < inputs.size(); ++i) {
453 debugger::InputProto *input_proto = node_proto->add_input();
454 input_proto->set_type(debugger::InputProto_EdgeType_DATA_EDGE);
455 std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr);
456 input_proto->set_name(id);
457 }
458
459 // set node output type
460 SetNodeOutputType(node, node_proto->mutable_output_type());
461 }
462 }
463
ExportFuncGraphOutput(const FuncGraphPtr & func_graph,const CNodePtr & ret_node,const std::map<AnfNodePtr,size_t> & apply_map,std::map<AnfNodePtr,size_t> * const_map_ptr,debugger::GraphProto * graph_proto)464 void DebuggerProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
465 const std::map<AnfNodePtr, size_t> &apply_map,
466 std::map<AnfNodePtr, size_t> *const_map_ptr,
467 debugger::GraphProto *graph_proto) {
468 if (ret_node == nullptr || !ret_node->isa<CNode>()) {
469 MS_LOG(EXCEPTION) << "Graph return node is illegal";
470 }
471 AnfNodePtr arg = ret_node->input(1);
472 if (graph_proto == nullptr) {
473 MS_LOG(EXCEPTION) << "graph_proto is nullptr";
474 }
475 debugger::OutputProto *output_proto = graph_proto->add_outputs();
476 if (output_proto == nullptr) {
477 MS_LOG(EXCEPTION) << "output_proto is nullptr";
478 }
479 std::string id = GetOpNodeInputId(func_graph, arg, apply_map, const_map_ptr);
480 output_proto->set_name(id);
481 SetNodeOutputType(arg, output_proto->mutable_type());
482 }
483
CompareValue(const std::pair<AnfNodePtr,size_t> & x,const std::pair<AnfNodePtr,size_t> & y)484 static bool CompareValue(const std::pair<AnfNodePtr, size_t> &x, const std::pair<AnfNodePtr, size_t> &y) {
485 return x.second < y.second;
486 }
487
ExportValueNodes(const std::map<AnfNodePtr,size_t> & const_map,debugger::GraphProto * graph_proto)488 void DebuggerProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map,
489 debugger::GraphProto *graph_proto) {
490 std::vector<std::pair<AnfNodePtr, size_t>> nodes;
491 (void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes),
492 [](const std::pair<AnfNodePtr, size_t> &item) { return item; });
493
494 sort(nodes.begin(), nodes.end(), CompareValue);
495
496 for (auto &item : nodes) {
497 if (graph_proto == nullptr) {
498 MS_LOG(EXCEPTION) << "graph_proto is nullptr";
499 }
500 debugger::NamedValueProto *named_value = graph_proto->add_const_vals();
501 MS_EXCEPTION_IF_NULL(named_value);
502 named_value->set_key(GetConstNodeId(item.second));
503 SetValueToProto(GetValueNode(item.first), named_value->mutable_value());
504 }
505 }
506
InitModelInfo()507 void DebuggerProtoExporter::InitModelInfo() { model_.set_ir_version(debugger::IR_VERSION); }
508
GetDebuggerFuncGraphProto(const FuncGraphPtr & func_graph)509 debugger::ModelProto GetDebuggerFuncGraphProto(const FuncGraphPtr &func_graph) {
510 DebuggerProtoExporter exporter;
511 return exporter.GetFuncGraphProto(func_graph);
512 }
513
GetDebuggerNumberDataType(const TypePtr & type)514 debugger::DataType GetDebuggerNumberDataType(const TypePtr &type) {
515 switch (type->type_id()) {
516 case kNumberTypeBool:
517 return debugger::DT_BOOL;
518 case kNumberTypeInt8:
519 return debugger::DT_INT8;
520 case kNumberTypeInt16:
521 return debugger::DT_INT16;
522 case kNumberTypeInt32:
523 return debugger::DT_INT32;
524 case kNumberTypeInt64:
525 return debugger::DT_INT64;
526 case kNumberTypeUInt8:
527 return debugger::DT_UINT8;
528 case kNumberTypeUInt16:
529 return debugger::DT_UINT16;
530 case kNumberTypeUInt32:
531 return debugger::DT_UINT32;
532 case kNumberTypeUInt64:
533 return debugger::DT_UINT64;
534 case kNumberTypeFloat16:
535 return debugger::DT_FLOAT16;
536 case kNumberTypeFloat32:
537 return debugger::DT_FLOAT32;
538 case kNumberTypeFloat64:
539 return debugger::DT_FLOAT64;
540 case kNumberTypeInt:
541 return debugger::DT_BASE_INT;
542 case kNumberTypeUInt:
543 return debugger::DT_BASE_UINT;
544 case kNumberTypeFloat:
545 return debugger::DT_BASE_FLOAT;
546 default:
547 MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
548 }
549 }
550
551 #ifdef ENABLE_DUMP_IR
DumpIRProtoWithSrcInfo(const FuncGraphPtr & func_graph,const std::string & suffix,const std::string & target_dir,LocDebugDumpMode dump_location)552 void DumpIRProtoWithSrcInfo(const FuncGraphPtr &func_graph, const std::string &suffix, const std::string &target_dir,
553 LocDebugDumpMode dump_location) {
554 DebuggerProtoExporter exporter;
555 std::string graph_proto = exporter.GetFuncGraphProtoString(func_graph, dump_location);
556 if (func_graph == nullptr) {
557 MS_LOG(ERROR) << "Func graph is nullptr";
558 return;
559 }
560 std::string file_path = target_dir + "/" + "ms_output_" + suffix + ".pb";
561 auto realpath = Common::CreatePrefixPath(file_path);
562 if (!realpath.has_value()) {
563 MS_LOG(ERROR) << "Get real path failed, path=" << file_path;
564 return;
565 }
566 ChangeFileMode(realpath.value(), S_IWUSR);
567
568 // write to pb file
569 std::ofstream ofs(realpath.value());
570 if (!ofs.is_open()) {
571 MS_LOG(ERROR) << "Open file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
572 return;
573 }
574 ofs << graph_proto;
575 ofs.close();
576 // set file mode to read only by user
577 ChangeFileMode(file_path, S_IRUSR);
578 }
579 #else
DumpIRProtoWithSrcInfo(const FuncGraphPtr &,const std::string &,const std::string &,LocDebugDumpMode)580 void DumpIRProtoWithSrcInfo(const FuncGraphPtr &, const std::string &, const std::string &, LocDebugDumpMode) {
581 static bool already_printed = false;
582 if (already_printed) {
583 return;
584 }
585 already_printed = true;
586 MS_LOG(WARNING) << "The functionality of dumping function graph IR in protobuf format is disabled,"
587 << "because ENABLE_DEBUGGER option is off"
588 << "please recompile source to enable it. See help of building script.";
589 }
590 #endif
591 } // namespace mindspore
592