• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "debug/debugger/proto_exporter.h"

#include <algorithm>
#include <fstream>
#include <iterator>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "utils/hash_set.h"
#include "include/common/debug/anf_dump_utils.h"
#include "include/backend/debug/data_dump/dump_utils.h"
#include "include/common/debug/common.h"
#include "include/backend/debug/debugger/debugger.h"
#include "ir/graph_utils.h"
#include "utils/symbolic.h"
#include "utils/trace_base.h"
#include "include/backend/debug/data_dump/e2e_dump.h"
#include "mindspore/core/utils/file_utils.h"
#include "utils/anf_utils.h"
#include "debug/debugger/debugger_utils.h"
36 
37 namespace mindspore {
38 
39 using TypeInfoToProtoTypeMap = std::vector<std::pair<uint32_t, debugger::DataType>>;
40 
41 void SetOutputType(const TypePtr &type, const BaseShapePtr &shape, debugger::TypeProto *type_proto);
42 
CheckIfValidType(const TypePtr & type,debugger::TypeProto * const type_proto)43 void CheckIfValidType(const TypePtr &type, debugger::TypeProto *const type_proto) {
44   if (!(type->isa<Number>() || type->isa<TensorType>() || type->isa<Tuple>() || type->isa<TypeType>() ||
45         type->isa<List>() || type->isa<TypeAny>() || type->isa<RefKeyType>() || type->isa<RefType>() ||
46         type->isa<Function>() || type->isa<TypeNone>() || type->isa<String>() || type->isa<SymbolicKeyType>() ||
47         type->isa<MapTensorType>() || type->isa<UMonadType>() || type->isa<IOMonadType>())) {
48     MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
49   }
50   if (type->isa<Number>()) {
51     type_proto->set_data_type(GetDebuggerNumberDataType(type));
52   }
53 }
54 
SetTensorTypeProto(const TypePtr & type,const BaseShapePtr & shape,debugger::TypeProto * type_proto)55 void SetTensorTypeProto(const TypePtr &type, const BaseShapePtr &shape, debugger::TypeProto *type_proto) {
56   TypePtr elem_type = dyn_cast<TensorType>(type)->element();
57   type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type));
58   if (shape != nullptr && shape->isa<abstract::Shape>()) {
59     abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape);
60     for (const auto &elem : shape_info->shape()) {
61       type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
62     }
63   }
64 }
65 
SetTupleTypeProto(const TypePtr & type,debugger::TypeProto * type_proto)66 void SetTupleTypeProto(const TypePtr &type, debugger::TypeProto *type_proto) {
67   TuplePtr tuple_type = dyn_cast<Tuple>(type);
68   for (const auto &elem_type : tuple_type->elements()) {
69     SetOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
70   }
71 }
72 
SetListTypeProto(const TypePtr & type,debugger::TypeProto * type_proto)73 void SetListTypeProto(const TypePtr &type, debugger::TypeProto *type_proto) {
74   ListPtr list_type = dyn_cast<List>(type);
75   for (const auto &elem_type : list_type->elements()) {
76     SetOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
77   }
78 }
79 
80 static TypeInfoToProtoTypeMap type_info_to_proto_type = {
81   {TensorType::kTypeId, debugger::DT_TENSOR}, {Tuple::kTypeId, debugger::DT_TUPLE},
82   {TypeType::kTypeId, debugger::DT_TYPE},     {List::kTypeId, debugger::DT_LIST},
83   {TypeAny::kTypeId, debugger::DT_ANY},       {RefKeyType::kTypeId, debugger::DT_REFKEY},
84   {RefType::kTypeId, debugger::DT_REF},       {Function::kTypeId, debugger::DT_GRAPH},
85   {TypeNone::kTypeId, debugger::DT_NONE},     {String::kTypeId, debugger::DT_STRING},
86   {UMonadType::kTypeId, debugger::DT_UMONAD}, {IOMonadType::kTypeId, debugger::DT_IOMONAD}};
87 
SetOutputType(const TypePtr & type,const BaseShapePtr & shape,debugger::TypeProto * type_proto)88 void SetOutputType(const TypePtr &type, const BaseShapePtr &shape, debugger::TypeProto *type_proto) {
89   if (type_proto == nullptr) {
90     return;
91   }
92   if (type == nullptr) {
93     type_proto->set_data_type(debugger::DT_UNDEFINED);
94     return;
95   }
96   CheckIfValidType(type, type_proto);
97   for (auto &it : type_info_to_proto_type) {
98     if (type->IsFromTypeId(it.first)) {
99       type_proto->set_data_type(it.second);
100       break;
101     }
102   }
103   if (type->isa<TensorType>()) {
104     SetTensorTypeProto(type, shape, type_proto);
105     return;
106   }
107   if (type->isa<Tuple>()) {
108     SetTupleTypeProto(type, type_proto);
109     return;
110   }
111   if (type->isa<List>()) {
112     SetListTypeProto(type, type_proto);
113   }
114 }
115 
SetNodeOutputType(const AnfNodePtr & node,debugger::TypeProto * type_proto) const116 void DebuggerProtoExporter::SetNodeOutputType(const AnfNodePtr &node, debugger::TypeProto *type_proto) const {
117   if (node == nullptr || type_proto == nullptr) {
118     return;
119   }
120   SetOutputType(node->Type(), node->Shape(), type_proto);
121 }
122 
SetValueToProto(const ValuePtr & val,debugger::ValueProto * value_proto) const123 void DebuggerProtoExporter::SetValueToProto(const ValuePtr &val, debugger::ValueProto *value_proto) const {
124   if (val == nullptr || value_proto == nullptr) {
125     return;
126   }
127 
128   if (val->isa<StringImm>()) {
129     const StringImmPtr &value = dyn_cast<StringImm>(val);
130     value_proto->set_dtype(debugger::DT_STRING);
131     value_proto->set_str_val(value->value());
132   } else if (val->isa<Scalar>()) {
133     SetScalarToProto(dyn_cast<Scalar>(val), value_proto);
134   } else if (val->isa<Bool>()) {
135     value_proto->set_dtype(debugger::DT_TYPE);
136     value_proto->mutable_type_val()->set_data_type(debugger::DT_BOOL);
137   } else if (val->isa<Int>()) {
138     value_proto->set_dtype(debugger::DT_TYPE);
139     value_proto->mutable_type_val()->set_data_type(debugger::DT_BASE_INT);
140   } else if (val->isa<Float>()) {
141     value_proto->set_dtype(debugger::DT_TYPE);
142     value_proto->mutable_type_val()->set_data_type(debugger::DT_BASE_FLOAT);
143   } else if (val->isa<ValueSequence>()) {
144     SetSequenceToProto(dyn_cast<ValueSequence>(val), value_proto);
145   } else if (val->isa<None>()) {
146     value_proto->set_dtype(debugger::DT_NONE);
147     value_proto->set_str_val("None");
148   } else if (val->isa<SymbolicKeyInstance>()) {
149     SymbolicKeyInstancePtr sym_inst = dyn_cast<SymbolicKeyInstance>(val);
150     ParameterPtr sym_node = dyn_cast<Parameter>(sym_inst->node());
151     value_proto->set_dtype(debugger::DT_SYM_INST);
152     value_proto->set_str_val(sym_node == nullptr ? std::string("nullptr") : sym_node->ToString());
153   } else if (val->isa<ValueDictionary>()) {
154     SetDictionaryToProto(dyn_cast<ValueDictionary>(val), value_proto);
155   } else if (val->isa<tensor::Tensor>()) {
156     tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(val);
157     value_proto->set_dtype(debugger::DT_TENSOR);
158     debugger::TensorProto *tensor_proto = value_proto->mutable_tensor_val();
159     tensor_proto->set_data_type(GetDebuggerNumberDataType(tensor_ptr->Dtype()));
160     for (auto &elem : tensor_ptr->shape()) {
161       tensor_proto->add_dims(elem);
162     }
163     tensor_proto->set_tensor_content(tensor_ptr->data_c(), tensor_ptr->data().nbytes());
164   } else if (val->isa<TensorType>()) {
165     value_proto->set_dtype(debugger::DT_TYPE);
166 
167     debugger::TypeProto *type_proto = value_proto->mutable_type_val();
168     type_proto->set_data_type(debugger::DT_TENSOR);
169     TypePtr elem_type = dyn_cast<TensorType>(val)->element();
170     type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type));
171   } else {
172     MS_LOG(INFO) << "Unsupported type " << val->type_name();
173   }
174 }
175 
SetScalarToProto(const ScalarPtr & val,debugger::ValueProto * value_proto) const176 void DebuggerProtoExporter::SetScalarToProto(const ScalarPtr &val, debugger::ValueProto *value_proto) const {
177   if (val == nullptr || value_proto == nullptr) {
178     return;
179   }
180 
181   if (val->isa<BoolImm>()) {
182     const BoolImmPtr &value = dyn_cast<BoolImm>(val);
183     value_proto->set_dtype(debugger::DT_BOOL);
184     value_proto->set_bool_val(value->value());
185   } else if (val->isa<Int8Imm>()) {
186     const Int8ImmPtr &value = dyn_cast<Int8Imm>(val);
187     value_proto->set_dtype(debugger::DT_INT8);
188     value_proto->set_int_val(value->value());
189   } else if (val->isa<Int16Imm>()) {
190     const Int16ImmPtr &value = dyn_cast<Int16Imm>(val);
191     value_proto->set_dtype(debugger::DT_INT16);
192     value_proto->set_int_val(value->value());
193   } else if (val->isa<Int32Imm>()) {
194     const Int32ImmPtr &value = dyn_cast<Int32Imm>(val);
195     value_proto->set_dtype(debugger::DT_INT32);
196     value_proto->set_int_val(value->value());
197   } else if (val->isa<Int64Imm>()) {
198     const Int64ImmPtr &value = dyn_cast<Int64Imm>(val);
199     value_proto->set_dtype(debugger::DT_INT64);
200     value_proto->set_int_val(value->value());
201   } else if (val->isa<UInt8Imm>()) {
202     const UInt8ImmPtr &value = dyn_cast<UInt8Imm>(val);
203     value_proto->set_dtype(debugger::DT_UINT8);
204     value_proto->set_uint_val(value->value());
205   } else if (val->isa<UInt16Imm>()) {
206     const UInt16ImmPtr &value = dyn_cast<UInt16Imm>(val);
207     value_proto->set_dtype(debugger::DT_UINT16);
208     value_proto->set_uint_val(value->value());
209   } else if (val->isa<UInt32Imm>()) {
210     const UInt32ImmPtr &value = dyn_cast<UInt32Imm>(val);
211     value_proto->set_dtype(debugger::DT_UINT32);
212     value_proto->set_uint_val(value->value());
213   } else if (val->isa<UInt64Imm>()) {
214     const UInt64ImmPtr &value = dyn_cast<UInt64Imm>(val);
215     value_proto->set_dtype(debugger::DT_UINT64);
216     value_proto->set_uint_val(value->value());
217   } else if (val->isa<FP32Imm>()) {
218     const FP32ImmPtr &value = dyn_cast<FP32Imm>(val);
219     value_proto->set_dtype(debugger::DT_FLOAT32);
220     value_proto->set_float_val(value->value());
221   } else if (val->isa<FP64Imm>()) {
222     const FP64ImmPtr &value = dyn_cast<FP64Imm>(val);
223     value_proto->set_dtype(debugger::DT_FLOAT64);
224     value_proto->set_double_val(value->value());
225   } else {
226     MS_LOG(EXCEPTION) << "Unknown scalar type " << val->ToString();
227   }
228 }
229 
SetSequenceToProto(const ValueSequencePtr & val,debugger::ValueProto * value_proto) const230 void DebuggerProtoExporter::SetSequenceToProto(const ValueSequencePtr &val, debugger::ValueProto *value_proto) const {
231   if (val == nullptr || value_proto == nullptr) {
232     return;
233   }
234 
235   if (val->isa<ValueTuple>()) {
236     const ValueTuplePtr &value = dyn_cast<ValueTuple>(val);
237     value_proto->set_dtype(debugger::DT_TUPLE);
238     for (const auto &item : value->value()) {
239       SetValueToProto(item, value_proto->add_values());
240     }
241   } else if (val->isa<ValueList>()) {
242     const ValueListPtr &value = dyn_cast<ValueList>(val);
243     value_proto->set_dtype(debugger::DT_LIST);
244     for (const auto &item : value->value()) {
245       SetValueToProto(item, value_proto->add_values());
246     }
247   }
248 }
249 
SetDictionaryToProto(const ValueDictionaryPtr & val,debugger::ValueProto * value_proto) const250 void DebuggerProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val,
251                                                  debugger::ValueProto *value_proto) const {
252   if (val == nullptr || value_proto == nullptr) {
253     return;
254   }
255 
256   value_proto->set_dtype(debugger::DT_DICT);
257   for (const auto &item : val->value()) {
258     debugger::NamedValueProto *named_val = value_proto->add_dict_val();
259     if (!item.first->isa<StringImm>()) {
260       MS_LOG(EXCEPTION) << "The key of NamedValueProto should be string type, but got " << item.first->ToString();
261     }
262     named_val->set_key(GetValue<std::string>(item.first));
263     SetValueToProto(item.second, named_val->mutable_value());
264   }
265 }
266 
GetOpNodeTypeAndAttrs(const FuncGraphPtr &,const AnfNodePtr & node,debugger::NodeProto * node_proto) const267 void DebuggerProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr &, const AnfNodePtr &node,
268                                                   debugger::NodeProto *node_proto) const {
269   if (node == nullptr || node_proto == nullptr) {
270     return;
271   }
272 
273   if (node->isa<CNode>() || node->isa<Parameter>() || IsValueNode<FuncGraph>(node)) {
274     MS_LOG(EXCEPTION) << "Op node can not be CNode, Parameter or ValueNode Graph. But got " << node->ToString();
275   }
276 
277   if (!IsValueNode<Primitive>(node)) {
278     MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString();
279   }
280 
281   const PrimitivePtr &prim = GetValueNode<PrimitivePtr>(node);
282   node_proto->set_op_type(prim->name());
283   for (const auto &attr : prim->attrs()) {
284     debugger::AttributeProto *attr_proto = node_proto->add_attribute();
285     attr_proto->set_name(attr.first);
286     SetValueToProto(attr.second, attr_proto->mutable_value());
287   }
288   node_proto->set_scope(node->scope()->name());
289 }
290 
GetOpNodeInputId(const FuncGraphPtr &,const AnfNodePtr & node,const std::map<AnfNodePtr,size_t> & apply_map,std::map<AnfNodePtr,size_t> * const_map_ptr) const291 std::string DebuggerProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodePtr &node,
292                                                     const std::map<AnfNodePtr, size_t> &apply_map,
293                                                     std::map<AnfNodePtr, size_t> *const_map_ptr) const {
294   if (node == nullptr || const_map_ptr == nullptr) {
295     return "";
296   }
297 
298   if (node->isa<CNode>()) {
299     auto iter = apply_map.find(node);
300     if (iter == apply_map.end()) {
301       MS_LOG(EXCEPTION) << "Can not find node '" << node->ToString() << "' in apply_map";
302     }
303     return std::to_string(iter->second);
304   }
305 
306   if (AnfUtils::IsCustomActorNode(node)) {
307     return AnfUtils::GetCustomActorName(node);
308   }
309 
310   if (node->isa<Parameter>()) {
311     return node->ToString();
312   }
313 
314   if (node->isa<ValueNode>()) {
315     std::map<AnfNodePtr, size_t>::const_iterator iter = const_map_ptr->find(node);
316     if (iter == const_map_ptr->end()) {
317       // Start index number from 1
318       const auto const_idx = const_map_ptr->size() + 1;
319       (*const_map_ptr)[node] = const_idx;
320     }
321     return GetConstNodeId((*const_map_ptr)[node]);
322   }
323 
324   MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'";
325 }
326 
GetFuncGraphProtoString(const FuncGraphPtr & func_graph,LocDebugDumpMode dump_location)327 std::string DebuggerProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph,
328                                                            LocDebugDumpMode dump_location) {
329   if (func_graph == nullptr) {
330     return "";
331   }
332 
333   InitModelInfo();
334   debugger::GraphProto *graph_proto = model_.mutable_graph();
335   ExportFuncGraph(func_graph, graph_proto, dump_location);
336   return model_.SerializeAsString();
337 }
338 
GetFuncGraphProto(const FuncGraphPtr & func_graph)339 debugger::ModelProto DebuggerProtoExporter::GetFuncGraphProto(const FuncGraphPtr &func_graph) {
340   if (func_graph == nullptr) {
341     return debugger::ModelProto();
342   }
343 
344   InitModelInfo();
345   debugger::GraphProto *graph_proto = model_.mutable_graph();
346   ExportFuncGraph(func_graph, graph_proto);
347   return model_;
348 }
349 
ExportFuncGraph(const FuncGraphPtr & func_graph,debugger::GraphProto * const graph_proto,LocDebugDumpMode dump_location)350 void DebuggerProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, debugger::GraphProto *const graph_proto,
351                                             LocDebugDumpMode dump_location) {
352   if (func_graph == nullptr || graph_proto == nullptr) {
353     return;
354   }
355 
356   // map for store ValueNodes of this graph
357   std::map<AnfNodePtr, size_t> const_map;
358 
359   // set graph name
360   graph_proto->set_name(func_graph->ToString());
361 
362   MS_LOG(INFO) << "graph names: " << func_graph->ToString();
363 
364   // cast FuncGraph to KernelGraph to access root_graph_id()
365   auto kernel_graph = static_cast<session::KernelGraph *>(func_graph.get());
366   uint32_t root_graph_id = kernel_graph->root_graph_id();
367   uint32_t graph_id = kernel_graph->graph_id();
368   MS_LOG(INFO) << "root graph id: " << root_graph_id;
369 
370   // set root graph id
371   if (kernel_graph->is_graph_run_mode()) {
372     graph_proto->set_root_name(std::to_string(root_graph_id));
373   } else {
374     graph_proto->set_root_name(std::to_string(graph_id));
375   }
376 
377   ExportParameters(func_graph, graph_proto);
378 
379   ExportCNodes(func_graph, graph_proto, &const_map, dump_location);
380 
381   ExportValueNodes(const_map, graph_proto);
382 }
383 
ExportParameters(const FuncGraphPtr & func_graph,debugger::GraphProto * graph_proto) const384 void DebuggerProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto) const {
385   if (func_graph == nullptr || graph_proto == nullptr) {
386     return;
387   }
388 
389   // cast FuncGraph to KernelGraph to access inputs()
390   std::vector<AnfNodePtr> parameters = static_cast<session::KernelGraph *>(func_graph.get())->inputs();
391 
392   for (auto &param : parameters) {
393     debugger::ParameterProto *param_proto = graph_proto->add_parameters();
394     param_proto->set_name(param->ToString());
395 
396     SetNodeOutputType(param, param_proto->mutable_type());
397 
398     const ParameterPtr param_ptr = dyn_cast<Parameter>(param);
399     if (param_ptr == nullptr) {
400       MS_LOG(INFO) << "Parameter '" << param->ToString() << "' could not cast to parameter.";
401     }
402   }
403 }
404 
ExportCNodes(const FuncGraphPtr & func_graph,debugger::GraphProto * const graph_proto,std::map<AnfNodePtr,size_t> * const_map_ptr,LocDebugDumpMode dump_location)405 void DebuggerProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, debugger::GraphProto *const graph_proto,
406                                          std::map<AnfNodePtr, size_t> *const_map_ptr, LocDebugDumpMode dump_location) {
407   if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) {
408     return;
409   }
410   // topo sort nodes
411   std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
412   std::map<AnfNodePtr, size_t> apply_map;
413   for (const AnfNodePtr &node : nodes) {
414     MS_EXCEPTION_IF_NULL(node);
415     if (!node->isa<CNode>()) {
416       continue;
417     }
418     auto cnode = node->cast<CNodePtr>();
419     if (cnode != func_graph->get_return()) {
420       ExportCNode(func_graph, cnode, &apply_map, const_map_ptr, graph_proto, dump_location);
421     } else {
422       ExportFuncGraphOutput(func_graph, cnode, apply_map, const_map_ptr, graph_proto);
423     }
424   }
425 }
426 
ExportCNode(const FuncGraphPtr & func_graph,const CNodePtr & node,std::map<AnfNodePtr,size_t> * apply_map_ptr,std::map<AnfNodePtr,size_t> * const_map_ptr,debugger::GraphProto * const graph_proto,LocDebugDumpMode dump_location)427 void DebuggerProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
428                                         std::map<AnfNodePtr, size_t> *apply_map_ptr,
429                                         std::map<AnfNodePtr, size_t> *const_map_ptr,
430                                         debugger::GraphProto *const graph_proto, LocDebugDumpMode dump_location) {
431   if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr ||
432       graph_proto == nullptr) {
433     return;
434   }
435 
436   auto apply_idx = apply_map_ptr->size() + 1;
437   (*apply_map_ptr)[node] = apply_idx;
438 
439   auto &inputs = node->inputs();
440   if (inputs.size() < 1) {
441     MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
442   }
443   AnfNodePtr op = inputs[0];
444   debugger::NodeProto *node_proto = graph_proto->add_node();
445 
446   // CNode/ConstGraph/Const/Parameter
447   if (op->isa<CNode>() || IsValueNode<FuncGraph>(op) || op->isa<Parameter>()) {
448     MS_LOG(WARNING) << "Operator must be a primitive";
449   } else {
450     GetOpNodeTypeAndAttrs(func_graph, op, node_proto);
451     node_proto->set_name(std::to_string(apply_idx));
452     node_proto->set_scope(node->scope()->name());
453 
454     // add full_name for debugger
455     std::string full_name = GetKernelNodeName(node);
456     node_proto->set_full_name(full_name);
457     MS_LOG(INFO) << "full_name: " << full_name;
458     if (dump_location == kDebugWholeStack) {
459       std::ostringstream buffer;
460       auto traces = mindspore::trace::GetSourceLineList(node);
461       for (auto &trace : traces) {
462         buffer << "      # " << trace;
463       }
464       node_proto->set_source_address(buffer.str());
465     }
466     // process OP inputs
467     for (size_t i = 1; i < inputs.size(); ++i) {
468       debugger::InputProto *input_proto = node_proto->add_input();
469       input_proto->set_type(debugger::InputProto_EdgeType_DATA_EDGE);
470       std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr);
471       input_proto->set_name(id);
472     }
473 
474     // set node output type
475     SetNodeOutputType(node, node_proto->mutable_output_type());
476   }
477 }
478 
ExportFuncGraphOutput(const FuncGraphPtr & func_graph,const CNodePtr & ret_node,const std::map<AnfNodePtr,size_t> & apply_map,std::map<AnfNodePtr,size_t> * const_map_ptr,debugger::GraphProto * graph_proto) const479 void DebuggerProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
480                                                   const std::map<AnfNodePtr, size_t> &apply_map,
481                                                   std::map<AnfNodePtr, size_t> *const_map_ptr,
482                                                   debugger::GraphProto *graph_proto) const {
483   if (ret_node == nullptr || !ret_node->isa<CNode>()) {
484     MS_LOG(EXCEPTION) << "Graph return node is illegal";
485   }
486   AnfNodePtr arg = ret_node->input(1);
487   if (graph_proto == nullptr) {
488     MS_LOG(EXCEPTION) << "graph_proto is nullptr";
489   }
490   debugger::OutputProto *output_proto = graph_proto->add_outputs();
491   if (output_proto == nullptr) {
492     MS_LOG(EXCEPTION) << "output_proto is nullptr";
493   }
494   std::string id = GetOpNodeInputId(func_graph, arg, apply_map, const_map_ptr);
495   output_proto->set_name(id);
496   SetNodeOutputType(arg, output_proto->mutable_type());
497 }
498 
CompareValue(const std::pair<AnfNodePtr,size_t> & x,const std::pair<AnfNodePtr,size_t> & y)499 static bool CompareValue(const std::pair<AnfNodePtr, size_t> &x, const std::pair<AnfNodePtr, size_t> &y) {
500   return x.second < y.second;
501 }
502 
ExportValueNodes(const std::map<AnfNodePtr,size_t> & const_map,debugger::GraphProto * graph_proto) const503 void DebuggerProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map,
504                                              debugger::GraphProto *graph_proto) const {
505   std::vector<std::pair<AnfNodePtr, size_t>> nodes;
506   (void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes),
507                        [](const std::pair<AnfNodePtr, size_t> &item) { return item; });
508 
509   sort(nodes.begin(), nodes.end(), CompareValue);
510 
511   for (auto &item : nodes) {
512     if (graph_proto == nullptr) {
513       MS_LOG(EXCEPTION) << "graph_proto is nullptr";
514     }
515     debugger::NamedValueProto *named_value = graph_proto->add_const_vals();
516     MS_EXCEPTION_IF_NULL(named_value);
517     named_value->set_key(GetConstNodeId(item.second));
518 
519     // cst full name: Default--data-x
520     std::string node_name = GetKernelNodeName(item.first);
521     GetFileKernelName(NOT_NULL(&node_name));
522     named_value->set_full_name(node_name);
523     if (GetValueNode(item.first)->isa<tensor::Tensor>()) {
524       continue;
525     }
526     SetValueToProto(GetValueNode(item.first), named_value->mutable_value());
527   }
528 }
529 
InitModelInfo()530 void DebuggerProtoExporter::InitModelInfo() { model_.set_ir_version(static_cast<int64_t>(debugger::IR_VERSION)); }
531 
GetDebuggerFuncGraphProto(const FuncGraphPtr & func_graph)532 debugger::ModelProto GetDebuggerFuncGraphProto(const FuncGraphPtr &func_graph) {
533   DebuggerProtoExporter exporter;
534   return exporter.GetFuncGraphProto(func_graph);
535 }
536 
GetDebuggerNumberDataType(const TypePtr & type)537 debugger::DataType GetDebuggerNumberDataType(const TypePtr &type) {
538   switch (type->type_id()) {
539     case kNumberTypeBool:
540       return debugger::DT_BOOL;
541     case kNumberTypeInt8:
542       return debugger::DT_INT8;
543     case kNumberTypeInt16:
544       return debugger::DT_INT16;
545     case kNumberTypeInt32:
546       return debugger::DT_INT32;
547     case kNumberTypeInt64:
548       return debugger::DT_INT64;
549     case kNumberTypeUInt8:
550       return debugger::DT_UINT8;
551     case kNumberTypeUInt16:
552       return debugger::DT_UINT16;
553     case kNumberTypeUInt32:
554       return debugger::DT_UINT32;
555     case kNumberTypeUInt64:
556       return debugger::DT_UINT64;
557     case kNumberTypeFloat16:
558       return debugger::DT_FLOAT16;
559     case kNumberTypeFloat32:
560       return debugger::DT_FLOAT32;
561     case kNumberTypeFloat64:
562       return debugger::DT_FLOAT64;
563     case kNumberTypeBFloat16:
564       return debugger::DT_BFLOAT16;
565     case kNumberTypeInt:
566       return debugger::DT_BASE_INT;
567     case kNumberTypeUInt:
568       return debugger::DT_BASE_UINT;
569     case kNumberTypeFloat:
570       return debugger::DT_BASE_FLOAT;
571     default:
572       MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
573   }
574 }
575 
576 #ifdef ENABLE_DUMP_IR
DumpIRProtoWithSrcInfo(const FuncGraphPtr & func_graph,const std::string & suffix,const std::string & target_dir,LocDebugDumpMode dump_location)577 void DumpIRProtoWithSrcInfo(const FuncGraphPtr &func_graph, const std::string &suffix, const std::string &target_dir,
578                             LocDebugDumpMode dump_location) {
579   DebuggerProtoExporter exporter;
580   std::string graph_proto = exporter.GetFuncGraphProtoString(func_graph, dump_location);
581   if (func_graph == nullptr) {
582     MS_LOG(ERROR) << "Func graph is nullptr";
583     return;
584   }
585   std::string file_path = target_dir + "/" + "ms_output_" + suffix + ".pb";
586   auto realpath = Common::CreatePrefixPath(file_path);
587   if (!realpath.has_value()) {
588     MS_LOG(ERROR) << "Get real path failed, path=" << file_path;
589     return;
590   }
591   ChangeFileMode(realpath.value(), S_IWUSR);
592 
593   // write to pb file
594   std::ofstream ofs(realpath.value());
595   if (!ofs.is_open()) {
596     MS_LOG(ERROR) << "Open file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
597     return;
598   }
599   ofs << graph_proto;
600   ofs.close();
601   // set file mode to read only by user
602   ChangeFileMode(file_path, S_IRUSR);
603 }
604 
DumpConstantInfo(const KernelGraphPtr & graph,const std::string & target_dir)605 void DumpConstantInfo(const KernelGraphPtr &graph, const std::string &target_dir) {
606   // Dump constant to npy file
607   MS_LOG(INFO) << "Start e2e dump Const values";
608   auto debugger = Debugger::GetInstance();
609   MS_EXCEPTION_IF_NULL(debugger);
610   E2eDump::DumpConstantData(graph.get(), target_dir, debugger.get());
611 }
612 #else
DumpIRProtoWithSrcInfo(const FuncGraphPtr &,const std::string &,const std::string &,LocDebugDumpMode)613 void DumpIRProtoWithSrcInfo(const FuncGraphPtr &, const std::string &, const std::string &, LocDebugDumpMode) {
614   static bool already_printed = false;
615   if (already_printed) {
616     return;
617   }
618   already_printed = true;
619   MS_LOG(WARNING) << "The functionality of dumping function graph IR in protobuf format is disabled,"
620                   << "because ENABLE_DEBUGGER option is off"
621                   << "please recompile source to enable it. See help of building script.";
622 }
DumpConstantInfo(const KernelGraphPtr &,const std::string &)623 void DumpConstantInfo(const KernelGraphPtr &, const std::string &) {
624   static bool already_printed = false;
625   if (already_printed) {
626     return;
627   }
628   already_printed = true;
629   MS_LOG(WARNING) << "The functionality of dumping function graph constant is disabled, "
630                   << "please recompile source to enable it. See help of building script.";
631 }
632 #endif
633 }  // namespace mindspore
634