1 /**
2 * Copyright 2020-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "debug/debugger/proto_exporter.h"
17
#include <algorithm>
#include <cerrno>
#include <fstream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
23
24 #include "utils/hash_set.h"
25 #include "include/common/debug/anf_dump_utils.h"
26 #include "include/backend/debug/data_dump/dump_utils.h"
27 #include "include/common/debug/common.h"
28 #include "include/backend/debug/debugger/debugger.h"
29 #include "ir/graph_utils.h"
30 #include "utils/symbolic.h"
31 #include "utils/trace_base.h"
32 #include "include/backend/debug/data_dump/e2e_dump.h"
33 #include "mindspore/core/utils/file_utils.h"
34 #include "utils/anf_utils.h"
35 #include "debug/debugger/debugger_utils.h"
36
37 namespace mindspore {
38
39 using TypeInfoToProtoTypeMap = std::vector<std::pair<uint32_t, debugger::DataType>>;
40
41 void SetOutputType(const TypePtr &type, const BaseShapePtr &shape, debugger::TypeProto *type_proto);
42
CheckIfValidType(const TypePtr & type,debugger::TypeProto * const type_proto)43 void CheckIfValidType(const TypePtr &type, debugger::TypeProto *const type_proto) {
44 if (!(type->isa<Number>() || type->isa<TensorType>() || type->isa<Tuple>() || type->isa<TypeType>() ||
45 type->isa<List>() || type->isa<TypeAny>() || type->isa<RefKeyType>() || type->isa<RefType>() ||
46 type->isa<Function>() || type->isa<TypeNone>() || type->isa<String>() || type->isa<SymbolicKeyType>() ||
47 type->isa<MapTensorType>() || type->isa<UMonadType>() || type->isa<IOMonadType>())) {
48 MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
49 }
50 if (type->isa<Number>()) {
51 type_proto->set_data_type(GetDebuggerNumberDataType(type));
52 }
53 }
54
SetTensorTypeProto(const TypePtr & type,const BaseShapePtr & shape,debugger::TypeProto * type_proto)55 void SetTensorTypeProto(const TypePtr &type, const BaseShapePtr &shape, debugger::TypeProto *type_proto) {
56 TypePtr elem_type = dyn_cast<TensorType>(type)->element();
57 type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type));
58 if (shape != nullptr && shape->isa<abstract::Shape>()) {
59 abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape);
60 for (const auto &elem : shape_info->shape()) {
61 type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
62 }
63 }
64 }
65
SetTupleTypeProto(const TypePtr & type,debugger::TypeProto * type_proto)66 void SetTupleTypeProto(const TypePtr &type, debugger::TypeProto *type_proto) {
67 TuplePtr tuple_type = dyn_cast<Tuple>(type);
68 for (const auto &elem_type : tuple_type->elements()) {
69 SetOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
70 }
71 }
72
SetListTypeProto(const TypePtr & type,debugger::TypeProto * type_proto)73 void SetListTypeProto(const TypePtr &type, debugger::TypeProto *type_proto) {
74 ListPtr list_type = dyn_cast<List>(type);
75 for (const auto &elem_type : list_type->elements()) {
76 SetOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
77 }
78 }
79
80 static TypeInfoToProtoTypeMap type_info_to_proto_type = {
81 {TensorType::kTypeId, debugger::DT_TENSOR}, {Tuple::kTypeId, debugger::DT_TUPLE},
82 {TypeType::kTypeId, debugger::DT_TYPE}, {List::kTypeId, debugger::DT_LIST},
83 {TypeAny::kTypeId, debugger::DT_ANY}, {RefKeyType::kTypeId, debugger::DT_REFKEY},
84 {RefType::kTypeId, debugger::DT_REF}, {Function::kTypeId, debugger::DT_GRAPH},
85 {TypeNone::kTypeId, debugger::DT_NONE}, {String::kTypeId, debugger::DT_STRING},
86 {UMonadType::kTypeId, debugger::DT_UMONAD}, {IOMonadType::kTypeId, debugger::DT_IOMONAD}};
87
SetOutputType(const TypePtr & type,const BaseShapePtr & shape,debugger::TypeProto * type_proto)88 void SetOutputType(const TypePtr &type, const BaseShapePtr &shape, debugger::TypeProto *type_proto) {
89 if (type_proto == nullptr) {
90 return;
91 }
92 if (type == nullptr) {
93 type_proto->set_data_type(debugger::DT_UNDEFINED);
94 return;
95 }
96 CheckIfValidType(type, type_proto);
97 for (auto &it : type_info_to_proto_type) {
98 if (type->IsFromTypeId(it.first)) {
99 type_proto->set_data_type(it.second);
100 break;
101 }
102 }
103 if (type->isa<TensorType>()) {
104 SetTensorTypeProto(type, shape, type_proto);
105 return;
106 }
107 if (type->isa<Tuple>()) {
108 SetTupleTypeProto(type, type_proto);
109 return;
110 }
111 if (type->isa<List>()) {
112 SetListTypeProto(type, type_proto);
113 }
114 }
115
SetNodeOutputType(const AnfNodePtr & node,debugger::TypeProto * type_proto) const116 void DebuggerProtoExporter::SetNodeOutputType(const AnfNodePtr &node, debugger::TypeProto *type_proto) const {
117 if (node == nullptr || type_proto == nullptr) {
118 return;
119 }
120 SetOutputType(node->Type(), node->Shape(), type_proto);
121 }
122
SetValueToProto(const ValuePtr & val,debugger::ValueProto * value_proto) const123 void DebuggerProtoExporter::SetValueToProto(const ValuePtr &val, debugger::ValueProto *value_proto) const {
124 if (val == nullptr || value_proto == nullptr) {
125 return;
126 }
127
128 if (val->isa<StringImm>()) {
129 const StringImmPtr &value = dyn_cast<StringImm>(val);
130 value_proto->set_dtype(debugger::DT_STRING);
131 value_proto->set_str_val(value->value());
132 } else if (val->isa<Scalar>()) {
133 SetScalarToProto(dyn_cast<Scalar>(val), value_proto);
134 } else if (val->isa<Bool>()) {
135 value_proto->set_dtype(debugger::DT_TYPE);
136 value_proto->mutable_type_val()->set_data_type(debugger::DT_BOOL);
137 } else if (val->isa<Int>()) {
138 value_proto->set_dtype(debugger::DT_TYPE);
139 value_proto->mutable_type_val()->set_data_type(debugger::DT_BASE_INT);
140 } else if (val->isa<Float>()) {
141 value_proto->set_dtype(debugger::DT_TYPE);
142 value_proto->mutable_type_val()->set_data_type(debugger::DT_BASE_FLOAT);
143 } else if (val->isa<ValueSequence>()) {
144 SetSequenceToProto(dyn_cast<ValueSequence>(val), value_proto);
145 } else if (val->isa<None>()) {
146 value_proto->set_dtype(debugger::DT_NONE);
147 value_proto->set_str_val("None");
148 } else if (val->isa<SymbolicKeyInstance>()) {
149 SymbolicKeyInstancePtr sym_inst = dyn_cast<SymbolicKeyInstance>(val);
150 ParameterPtr sym_node = dyn_cast<Parameter>(sym_inst->node());
151 value_proto->set_dtype(debugger::DT_SYM_INST);
152 value_proto->set_str_val(sym_node == nullptr ? std::string("nullptr") : sym_node->ToString());
153 } else if (val->isa<ValueDictionary>()) {
154 SetDictionaryToProto(dyn_cast<ValueDictionary>(val), value_proto);
155 } else if (val->isa<tensor::Tensor>()) {
156 tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(val);
157 value_proto->set_dtype(debugger::DT_TENSOR);
158 debugger::TensorProto *tensor_proto = value_proto->mutable_tensor_val();
159 tensor_proto->set_data_type(GetDebuggerNumberDataType(tensor_ptr->Dtype()));
160 for (auto &elem : tensor_ptr->shape()) {
161 tensor_proto->add_dims(elem);
162 }
163 tensor_proto->set_tensor_content(tensor_ptr->data_c(), tensor_ptr->data().nbytes());
164 } else if (val->isa<TensorType>()) {
165 value_proto->set_dtype(debugger::DT_TYPE);
166
167 debugger::TypeProto *type_proto = value_proto->mutable_type_val();
168 type_proto->set_data_type(debugger::DT_TENSOR);
169 TypePtr elem_type = dyn_cast<TensorType>(val)->element();
170 type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type));
171 } else {
172 MS_LOG(INFO) << "Unsupported type " << val->type_name();
173 }
174 }
175
SetScalarToProto(const ScalarPtr & val,debugger::ValueProto * value_proto) const176 void DebuggerProtoExporter::SetScalarToProto(const ScalarPtr &val, debugger::ValueProto *value_proto) const {
177 if (val == nullptr || value_proto == nullptr) {
178 return;
179 }
180
181 if (val->isa<BoolImm>()) {
182 const BoolImmPtr &value = dyn_cast<BoolImm>(val);
183 value_proto->set_dtype(debugger::DT_BOOL);
184 value_proto->set_bool_val(value->value());
185 } else if (val->isa<Int8Imm>()) {
186 const Int8ImmPtr &value = dyn_cast<Int8Imm>(val);
187 value_proto->set_dtype(debugger::DT_INT8);
188 value_proto->set_int_val(value->value());
189 } else if (val->isa<Int16Imm>()) {
190 const Int16ImmPtr &value = dyn_cast<Int16Imm>(val);
191 value_proto->set_dtype(debugger::DT_INT16);
192 value_proto->set_int_val(value->value());
193 } else if (val->isa<Int32Imm>()) {
194 const Int32ImmPtr &value = dyn_cast<Int32Imm>(val);
195 value_proto->set_dtype(debugger::DT_INT32);
196 value_proto->set_int_val(value->value());
197 } else if (val->isa<Int64Imm>()) {
198 const Int64ImmPtr &value = dyn_cast<Int64Imm>(val);
199 value_proto->set_dtype(debugger::DT_INT64);
200 value_proto->set_int_val(value->value());
201 } else if (val->isa<UInt8Imm>()) {
202 const UInt8ImmPtr &value = dyn_cast<UInt8Imm>(val);
203 value_proto->set_dtype(debugger::DT_UINT8);
204 value_proto->set_uint_val(value->value());
205 } else if (val->isa<UInt16Imm>()) {
206 const UInt16ImmPtr &value = dyn_cast<UInt16Imm>(val);
207 value_proto->set_dtype(debugger::DT_UINT16);
208 value_proto->set_uint_val(value->value());
209 } else if (val->isa<UInt32Imm>()) {
210 const UInt32ImmPtr &value = dyn_cast<UInt32Imm>(val);
211 value_proto->set_dtype(debugger::DT_UINT32);
212 value_proto->set_uint_val(value->value());
213 } else if (val->isa<UInt64Imm>()) {
214 const UInt64ImmPtr &value = dyn_cast<UInt64Imm>(val);
215 value_proto->set_dtype(debugger::DT_UINT64);
216 value_proto->set_uint_val(value->value());
217 } else if (val->isa<FP32Imm>()) {
218 const FP32ImmPtr &value = dyn_cast<FP32Imm>(val);
219 value_proto->set_dtype(debugger::DT_FLOAT32);
220 value_proto->set_float_val(value->value());
221 } else if (val->isa<FP64Imm>()) {
222 const FP64ImmPtr &value = dyn_cast<FP64Imm>(val);
223 value_proto->set_dtype(debugger::DT_FLOAT64);
224 value_proto->set_double_val(value->value());
225 } else {
226 MS_LOG(EXCEPTION) << "Unknown scalar type " << val->ToString();
227 }
228 }
229
SetSequenceToProto(const ValueSequencePtr & val,debugger::ValueProto * value_proto) const230 void DebuggerProtoExporter::SetSequenceToProto(const ValueSequencePtr &val, debugger::ValueProto *value_proto) const {
231 if (val == nullptr || value_proto == nullptr) {
232 return;
233 }
234
235 if (val->isa<ValueTuple>()) {
236 const ValueTuplePtr &value = dyn_cast<ValueTuple>(val);
237 value_proto->set_dtype(debugger::DT_TUPLE);
238 for (const auto &item : value->value()) {
239 SetValueToProto(item, value_proto->add_values());
240 }
241 } else if (val->isa<ValueList>()) {
242 const ValueListPtr &value = dyn_cast<ValueList>(val);
243 value_proto->set_dtype(debugger::DT_LIST);
244 for (const auto &item : value->value()) {
245 SetValueToProto(item, value_proto->add_values());
246 }
247 }
248 }
249
SetDictionaryToProto(const ValueDictionaryPtr & val,debugger::ValueProto * value_proto) const250 void DebuggerProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val,
251 debugger::ValueProto *value_proto) const {
252 if (val == nullptr || value_proto == nullptr) {
253 return;
254 }
255
256 value_proto->set_dtype(debugger::DT_DICT);
257 for (const auto &item : val->value()) {
258 debugger::NamedValueProto *named_val = value_proto->add_dict_val();
259 if (!item.first->isa<StringImm>()) {
260 MS_LOG(EXCEPTION) << "The key of NamedValueProto should be string type, but got " << item.first->ToString();
261 }
262 named_val->set_key(GetValue<std::string>(item.first));
263 SetValueToProto(item.second, named_val->mutable_value());
264 }
265 }
266
GetOpNodeTypeAndAttrs(const FuncGraphPtr &,const AnfNodePtr & node,debugger::NodeProto * node_proto) const267 void DebuggerProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr &, const AnfNodePtr &node,
268 debugger::NodeProto *node_proto) const {
269 if (node == nullptr || node_proto == nullptr) {
270 return;
271 }
272
273 if (node->isa<CNode>() || node->isa<Parameter>() || IsValueNode<FuncGraph>(node)) {
274 MS_LOG(EXCEPTION) << "Op node can not be CNode, Parameter or ValueNode Graph. But got " << node->ToString();
275 }
276
277 if (!IsValueNode<Primitive>(node)) {
278 MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString();
279 }
280
281 const PrimitivePtr &prim = GetValueNode<PrimitivePtr>(node);
282 node_proto->set_op_type(prim->name());
283 for (const auto &attr : prim->attrs()) {
284 debugger::AttributeProto *attr_proto = node_proto->add_attribute();
285 attr_proto->set_name(attr.first);
286 SetValueToProto(attr.second, attr_proto->mutable_value());
287 }
288 node_proto->set_scope(node->scope()->name());
289 }
290
GetOpNodeInputId(const FuncGraphPtr &,const AnfNodePtr & node,const std::map<AnfNodePtr,size_t> & apply_map,std::map<AnfNodePtr,size_t> * const_map_ptr) const291 std::string DebuggerProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodePtr &node,
292 const std::map<AnfNodePtr, size_t> &apply_map,
293 std::map<AnfNodePtr, size_t> *const_map_ptr) const {
294 if (node == nullptr || const_map_ptr == nullptr) {
295 return "";
296 }
297
298 if (node->isa<CNode>()) {
299 auto iter = apply_map.find(node);
300 if (iter == apply_map.end()) {
301 MS_LOG(EXCEPTION) << "Can not find node '" << node->ToString() << "' in apply_map";
302 }
303 return std::to_string(iter->second);
304 }
305
306 if (AnfUtils::IsCustomActorNode(node)) {
307 return AnfUtils::GetCustomActorName(node);
308 }
309
310 if (node->isa<Parameter>()) {
311 return node->ToString();
312 }
313
314 if (node->isa<ValueNode>()) {
315 std::map<AnfNodePtr, size_t>::const_iterator iter = const_map_ptr->find(node);
316 if (iter == const_map_ptr->end()) {
317 // Start index number from 1
318 const auto const_idx = const_map_ptr->size() + 1;
319 (*const_map_ptr)[node] = const_idx;
320 }
321 return GetConstNodeId((*const_map_ptr)[node]);
322 }
323
324 MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'";
325 }
326
GetFuncGraphProtoString(const FuncGraphPtr & func_graph,LocDebugDumpMode dump_location)327 std::string DebuggerProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph,
328 LocDebugDumpMode dump_location) {
329 if (func_graph == nullptr) {
330 return "";
331 }
332
333 InitModelInfo();
334 debugger::GraphProto *graph_proto = model_.mutable_graph();
335 ExportFuncGraph(func_graph, graph_proto, dump_location);
336 return model_.SerializeAsString();
337 }
338
GetFuncGraphProto(const FuncGraphPtr & func_graph)339 debugger::ModelProto DebuggerProtoExporter::GetFuncGraphProto(const FuncGraphPtr &func_graph) {
340 if (func_graph == nullptr) {
341 return debugger::ModelProto();
342 }
343
344 InitModelInfo();
345 debugger::GraphProto *graph_proto = model_.mutable_graph();
346 ExportFuncGraph(func_graph, graph_proto);
347 return model_;
348 }
349
ExportFuncGraph(const FuncGraphPtr & func_graph,debugger::GraphProto * const graph_proto,LocDebugDumpMode dump_location)350 void DebuggerProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, debugger::GraphProto *const graph_proto,
351 LocDebugDumpMode dump_location) {
352 if (func_graph == nullptr || graph_proto == nullptr) {
353 return;
354 }
355
356 // map for store ValueNodes of this graph
357 std::map<AnfNodePtr, size_t> const_map;
358
359 // set graph name
360 graph_proto->set_name(func_graph->ToString());
361
362 MS_LOG(INFO) << "graph names: " << func_graph->ToString();
363
364 // cast FuncGraph to KernelGraph to access root_graph_id()
365 auto kernel_graph = static_cast<session::KernelGraph *>(func_graph.get());
366 uint32_t root_graph_id = kernel_graph->root_graph_id();
367 uint32_t graph_id = kernel_graph->graph_id();
368 MS_LOG(INFO) << "root graph id: " << root_graph_id;
369
370 // set root graph id
371 if (kernel_graph->is_graph_run_mode()) {
372 graph_proto->set_root_name(std::to_string(root_graph_id));
373 } else {
374 graph_proto->set_root_name(std::to_string(graph_id));
375 }
376
377 ExportParameters(func_graph, graph_proto);
378
379 ExportCNodes(func_graph, graph_proto, &const_map, dump_location);
380
381 ExportValueNodes(const_map, graph_proto);
382 }
383
ExportParameters(const FuncGraphPtr & func_graph,debugger::GraphProto * graph_proto) const384 void DebuggerProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto) const {
385 if (func_graph == nullptr || graph_proto == nullptr) {
386 return;
387 }
388
389 // cast FuncGraph to KernelGraph to access inputs()
390 std::vector<AnfNodePtr> parameters = static_cast<session::KernelGraph *>(func_graph.get())->inputs();
391
392 for (auto ¶m : parameters) {
393 debugger::ParameterProto *param_proto = graph_proto->add_parameters();
394 param_proto->set_name(param->ToString());
395
396 SetNodeOutputType(param, param_proto->mutable_type());
397
398 const ParameterPtr param_ptr = dyn_cast<Parameter>(param);
399 if (param_ptr == nullptr) {
400 MS_LOG(INFO) << "Parameter '" << param->ToString() << "' could not cast to parameter.";
401 }
402 }
403 }
404
ExportCNodes(const FuncGraphPtr & func_graph,debugger::GraphProto * const graph_proto,std::map<AnfNodePtr,size_t> * const_map_ptr,LocDebugDumpMode dump_location)405 void DebuggerProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, debugger::GraphProto *const graph_proto,
406 std::map<AnfNodePtr, size_t> *const_map_ptr, LocDebugDumpMode dump_location) {
407 if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) {
408 return;
409 }
410 // topo sort nodes
411 std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
412 std::map<AnfNodePtr, size_t> apply_map;
413 for (const AnfNodePtr &node : nodes) {
414 MS_EXCEPTION_IF_NULL(node);
415 if (!node->isa<CNode>()) {
416 continue;
417 }
418 auto cnode = node->cast<CNodePtr>();
419 if (cnode != func_graph->get_return()) {
420 ExportCNode(func_graph, cnode, &apply_map, const_map_ptr, graph_proto, dump_location);
421 } else {
422 ExportFuncGraphOutput(func_graph, cnode, apply_map, const_map_ptr, graph_proto);
423 }
424 }
425 }
426
ExportCNode(const FuncGraphPtr & func_graph,const CNodePtr & node,std::map<AnfNodePtr,size_t> * apply_map_ptr,std::map<AnfNodePtr,size_t> * const_map_ptr,debugger::GraphProto * const graph_proto,LocDebugDumpMode dump_location)427 void DebuggerProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
428 std::map<AnfNodePtr, size_t> *apply_map_ptr,
429 std::map<AnfNodePtr, size_t> *const_map_ptr,
430 debugger::GraphProto *const graph_proto, LocDebugDumpMode dump_location) {
431 if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr ||
432 graph_proto == nullptr) {
433 return;
434 }
435
436 auto apply_idx = apply_map_ptr->size() + 1;
437 (*apply_map_ptr)[node] = apply_idx;
438
439 auto &inputs = node->inputs();
440 if (inputs.size() < 1) {
441 MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
442 }
443 AnfNodePtr op = inputs[0];
444 debugger::NodeProto *node_proto = graph_proto->add_node();
445
446 // CNode/ConstGraph/Const/Parameter
447 if (op->isa<CNode>() || IsValueNode<FuncGraph>(op) || op->isa<Parameter>()) {
448 MS_LOG(WARNING) << "Operator must be a primitive";
449 } else {
450 GetOpNodeTypeAndAttrs(func_graph, op, node_proto);
451 node_proto->set_name(std::to_string(apply_idx));
452 node_proto->set_scope(node->scope()->name());
453
454 // add full_name for debugger
455 std::string full_name = GetKernelNodeName(node);
456 node_proto->set_full_name(full_name);
457 MS_LOG(INFO) << "full_name: " << full_name;
458 if (dump_location == kDebugWholeStack) {
459 std::ostringstream buffer;
460 auto traces = mindspore::trace::GetSourceLineList(node);
461 for (auto &trace : traces) {
462 buffer << " # " << trace;
463 }
464 node_proto->set_source_address(buffer.str());
465 }
466 // process OP inputs
467 for (size_t i = 1; i < inputs.size(); ++i) {
468 debugger::InputProto *input_proto = node_proto->add_input();
469 input_proto->set_type(debugger::InputProto_EdgeType_DATA_EDGE);
470 std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr);
471 input_proto->set_name(id);
472 }
473
474 // set node output type
475 SetNodeOutputType(node, node_proto->mutable_output_type());
476 }
477 }
478
ExportFuncGraphOutput(const FuncGraphPtr & func_graph,const CNodePtr & ret_node,const std::map<AnfNodePtr,size_t> & apply_map,std::map<AnfNodePtr,size_t> * const_map_ptr,debugger::GraphProto * graph_proto) const479 void DebuggerProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
480 const std::map<AnfNodePtr, size_t> &apply_map,
481 std::map<AnfNodePtr, size_t> *const_map_ptr,
482 debugger::GraphProto *graph_proto) const {
483 if (ret_node == nullptr || !ret_node->isa<CNode>()) {
484 MS_LOG(EXCEPTION) << "Graph return node is illegal";
485 }
486 AnfNodePtr arg = ret_node->input(1);
487 if (graph_proto == nullptr) {
488 MS_LOG(EXCEPTION) << "graph_proto is nullptr";
489 }
490 debugger::OutputProto *output_proto = graph_proto->add_outputs();
491 if (output_proto == nullptr) {
492 MS_LOG(EXCEPTION) << "output_proto is nullptr";
493 }
494 std::string id = GetOpNodeInputId(func_graph, arg, apply_map, const_map_ptr);
495 output_proto->set_name(id);
496 SetNodeOutputType(arg, output_proto->mutable_type());
497 }
498
CompareValue(const std::pair<AnfNodePtr,size_t> & x,const std::pair<AnfNodePtr,size_t> & y)499 static bool CompareValue(const std::pair<AnfNodePtr, size_t> &x, const std::pair<AnfNodePtr, size_t> &y) {
500 return x.second < y.second;
501 }
502
ExportValueNodes(const std::map<AnfNodePtr,size_t> & const_map,debugger::GraphProto * graph_proto) const503 void DebuggerProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map,
504 debugger::GraphProto *graph_proto) const {
505 std::vector<std::pair<AnfNodePtr, size_t>> nodes;
506 (void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes),
507 [](const std::pair<AnfNodePtr, size_t> &item) { return item; });
508
509 sort(nodes.begin(), nodes.end(), CompareValue);
510
511 for (auto &item : nodes) {
512 if (graph_proto == nullptr) {
513 MS_LOG(EXCEPTION) << "graph_proto is nullptr";
514 }
515 debugger::NamedValueProto *named_value = graph_proto->add_const_vals();
516 MS_EXCEPTION_IF_NULL(named_value);
517 named_value->set_key(GetConstNodeId(item.second));
518
519 // cst full name: Default--data-x
520 std::string node_name = GetKernelNodeName(item.first);
521 GetFileKernelName(NOT_NULL(&node_name));
522 named_value->set_full_name(node_name);
523 if (GetValueNode(item.first)->isa<tensor::Tensor>()) {
524 continue;
525 }
526 SetValueToProto(GetValueNode(item.first), named_value->mutable_value());
527 }
528 }
529
InitModelInfo()530 void DebuggerProtoExporter::InitModelInfo() { model_.set_ir_version(static_cast<int64_t>(debugger::IR_VERSION)); }
531
GetDebuggerFuncGraphProto(const FuncGraphPtr & func_graph)532 debugger::ModelProto GetDebuggerFuncGraphProto(const FuncGraphPtr &func_graph) {
533 DebuggerProtoExporter exporter;
534 return exporter.GetFuncGraphProto(func_graph);
535 }
536
GetDebuggerNumberDataType(const TypePtr & type)537 debugger::DataType GetDebuggerNumberDataType(const TypePtr &type) {
538 switch (type->type_id()) {
539 case kNumberTypeBool:
540 return debugger::DT_BOOL;
541 case kNumberTypeInt8:
542 return debugger::DT_INT8;
543 case kNumberTypeInt16:
544 return debugger::DT_INT16;
545 case kNumberTypeInt32:
546 return debugger::DT_INT32;
547 case kNumberTypeInt64:
548 return debugger::DT_INT64;
549 case kNumberTypeUInt8:
550 return debugger::DT_UINT8;
551 case kNumberTypeUInt16:
552 return debugger::DT_UINT16;
553 case kNumberTypeUInt32:
554 return debugger::DT_UINT32;
555 case kNumberTypeUInt64:
556 return debugger::DT_UINT64;
557 case kNumberTypeFloat16:
558 return debugger::DT_FLOAT16;
559 case kNumberTypeFloat32:
560 return debugger::DT_FLOAT32;
561 case kNumberTypeFloat64:
562 return debugger::DT_FLOAT64;
563 case kNumberTypeBFloat16:
564 return debugger::DT_BFLOAT16;
565 case kNumberTypeInt:
566 return debugger::DT_BASE_INT;
567 case kNumberTypeUInt:
568 return debugger::DT_BASE_UINT;
569 case kNumberTypeFloat:
570 return debugger::DT_BASE_FLOAT;
571 default:
572 MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
573 }
574 }
575
576 #ifdef ENABLE_DUMP_IR
DumpIRProtoWithSrcInfo(const FuncGraphPtr & func_graph,const std::string & suffix,const std::string & target_dir,LocDebugDumpMode dump_location)577 void DumpIRProtoWithSrcInfo(const FuncGraphPtr &func_graph, const std::string &suffix, const std::string &target_dir,
578 LocDebugDumpMode dump_location) {
579 DebuggerProtoExporter exporter;
580 std::string graph_proto = exporter.GetFuncGraphProtoString(func_graph, dump_location);
581 if (func_graph == nullptr) {
582 MS_LOG(ERROR) << "Func graph is nullptr";
583 return;
584 }
585 std::string file_path = target_dir + "/" + "ms_output_" + suffix + ".pb";
586 auto realpath = Common::CreatePrefixPath(file_path);
587 if (!realpath.has_value()) {
588 MS_LOG(ERROR) << "Get real path failed, path=" << file_path;
589 return;
590 }
591 ChangeFileMode(realpath.value(), S_IWUSR);
592
593 // write to pb file
594 std::ofstream ofs(realpath.value());
595 if (!ofs.is_open()) {
596 MS_LOG(ERROR) << "Open file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
597 return;
598 }
599 ofs << graph_proto;
600 ofs.close();
601 // set file mode to read only by user
602 ChangeFileMode(file_path, S_IRUSR);
603 }
604
DumpConstantInfo(const KernelGraphPtr & graph,const std::string & target_dir)605 void DumpConstantInfo(const KernelGraphPtr &graph, const std::string &target_dir) {
606 // Dump constant to npy file
607 MS_LOG(INFO) << "Start e2e dump Const values";
608 auto debugger = Debugger::GetInstance();
609 MS_EXCEPTION_IF_NULL(debugger);
610 E2eDump::DumpConstantData(graph.get(), target_dir, debugger.get());
611 }
612 #else
DumpIRProtoWithSrcInfo(const FuncGraphPtr &,const std::string &,const std::string &,LocDebugDumpMode)613 void DumpIRProtoWithSrcInfo(const FuncGraphPtr &, const std::string &, const std::string &, LocDebugDumpMode) {
614 static bool already_printed = false;
615 if (already_printed) {
616 return;
617 }
618 already_printed = true;
619 MS_LOG(WARNING) << "The functionality of dumping function graph IR in protobuf format is disabled,"
620 << "because ENABLE_DEBUGGER option is off"
621 << "please recompile source to enable it. See help of building script.";
622 }
DumpConstantInfo(const KernelGraphPtr &,const std::string &)623 void DumpConstantInfo(const KernelGraphPtr &, const std::string &) {
624 static bool already_printed = false;
625 if (already_printed) {
626 return;
627 }
628 already_printed = true;
629 MS_LOG(WARNING) << "The functionality of dumping function graph constant is disabled, "
630 << "please recompile source to enable it. See help of building script.";
631 }
632 #endif
633 } // namespace mindspore
634