1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/label.h"
18 #include <algorithm>
19 #include <sstream>
20 #include <utility>
21
22 #include "utils/info.h"
23 #include "ir/func_graph.h"
24
25 namespace mindspore {
26 namespace label_manage {
27 static TraceLabelType global_trace_type = (common::GetEnv("ENV_TRACE_LABEL_WITH_UNIQUE_ID") == "1")
28 ? TraceLabelType::kWithUniqueId
29 : TraceLabelType::kShortSymbol;
GetGlobalTraceLabelType()30 TraceLabelType GetGlobalTraceLabelType() { return global_trace_type; }
SetGlobalTraceLabelType(TraceLabelType label_type)31 void SetGlobalTraceLabelType(TraceLabelType label_type) { global_trace_type = label_type; }
32
// Result of walking a debug-info trace chain.
struct NameWithTrace {
  // Name of the debug info at which the walk stopped (falls back to its debug name when unnamed).
  std::string name{};
  // One tag per trace step traversed, in the order the steps were visited.
  std::vector<std::string> trace_labels{};
};
GetTraceName(const TraceInfoPtr & trace_info,TraceLabelType trace_label)37 static std::string GetTraceName(const TraceInfoPtr &trace_info, TraceLabelType trace_label) {
38 switch (trace_label) {
39 case TraceLabelType::kShortSymbol:
40 return trace_info->symbol();
41 case TraceLabelType::kFullName:
42 return "_" + trace_info->full_name() + "_";
43 default:
44 return "";
45 }
46 }
47
RootName(const DebugInfoPtr & debug_info,TraceLabelType trace_label)48 NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
49 NameWithTrace trace_name;
50 // find debug info after Resolve/ExpandJ/GenMetaFuncGraph, it is a new node
51 MS_EXCEPTION_IF_NULL(debug_info);
52 auto temp_info = debug_info;
53 while (temp_info != nullptr) {
54 if (temp_info->trace_info() != nullptr) {
55 if (temp_info->trace_info()->isa<TraceResolve>() || temp_info->trace_info()->isa<TraceExpandJ>() ||
56 temp_info->trace_info()->isa<TraceGenMetaFuncGraph>() ||
57 temp_info->trace_info()->isa<TraceGenerateVarArg>() || temp_info->trace_info()->isa<TraceGenerateKwArg>()) {
58 break;
59 }
60 trace_name.trace_labels.push_back(GetTraceName(temp_info->trace_info(), trace_label));
61 temp_info = temp_info->trace_info()->debug_info();
62 } else {
63 break;
64 }
65 }
66 if (!temp_info->name().empty()) {
67 trace_name.name = temp_info->name();
68 } else {
69 trace_name.name = temp_info->debug_name();
70 }
71 return trace_name;
72 }
73
// Concatenates the trace tags in the given order and appends root_name at the end.
// Appends in place so each label is copied once instead of rebuilding the whole string per label.
std::string CombineTraceTypes(const std::string &root_name, const std::vector<std::string> &trace_labels) {
  std::string tags;
  for (const auto &label : trace_labels) {
    tags += label;
  }
  tags += root_name;
  return tags;
}
82
83 // get the label name of the node debug info
LabelString(const DebugInfoPtr & debug_info,TraceLabelType trace_label)84 std::string LabelString(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
85 NameWithTrace trace_name = RootName(debug_info, trace_label);
86 return CombineTraceTypes(trace_name.name, trace_name.trace_labels);
87 }
88
CombineUniqueID(const DebugInfoPtr & debug_info)89 std::string CombineUniqueID(const DebugInfoPtr &debug_info) {
90 auto temp_info = debug_info;
91 std::string label = "";
92 while (temp_info != nullptr) {
93 if (!temp_info->name().empty()) {
94 label = label + temp_info->name();
95 } else {
96 // the symbol 'U' is for identification of number
97 label = label + "U" + std::to_string(temp_info->unique_id());
98 }
99
100 if (temp_info->trace_info() != nullptr) {
101 label = label + "_" + temp_info->trace_info()->full_name() + "_";
102 temp_info = temp_info->trace_info()->debug_info();
103 } else {
104 temp_info = nullptr;
105 }
106 }
107 return label;
108 }
109
110 // get trace with unique id chain
LabelStringUnique(const DebugInfoPtr & debug_info)111 std::string LabelStringUnique(const DebugInfoPtr &debug_info) { return CombineUniqueID(debug_info); }
112
Label(const DebugInfoPtr & debug_info,TraceLabelType trace_label)113 std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
114 if (GetGlobalTraceLabelType() == TraceLabelType::kWithUniqueId) {
115 return LabelStringUnique(debug_info);
116 }
117 return LabelString(debug_info, trace_label);
118 }
119 } // namespace label_manage
120 } // namespace mindspore
121