• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/label.h"
18 
19 #include <vector>
20 #include "utils/info.h"
21 #include "utils/compile_config.h"
22 
23 namespace {
24 using mindspore::DebugInfoPtr;
25 using mindspore::NodeDebugInfo;
26 using mindspore::TraceInfoPtr;
27 using mindspore::trace::TraceLabelType;
28 
WithUniqueIdPtr()29 bool *WithUniqueIdPtr() {
30   static bool with_unique_id = mindspore::common::GetCompileConfig("TRACE_LABEL_WITH_UNIQUE_ID") == "1";
31   return &with_unique_id;
32 }
33 
GetCurrentTraceLabelType()34 TraceLabelType GetCurrentTraceLabelType() {
35   if (*WithUniqueIdPtr()) {
36     return TraceLabelType::kWithUniqueId;
37   }
38   return TraceLabelType::kShortSymbol;
39 }
40 
CombineUniqueID(const DebugInfoPtr & debug_info)41 std::string CombineUniqueID(const DebugInfoPtr &debug_info) {
42   auto root_info = debug_info;
43   std::string label = "";
44   while (root_info != nullptr) {
45     if (!root_info->name().empty()) {
46       label = label + root_info->name();
47     } else {
48       // The symbol 'U' is for identification of number
49       label = label + "U" + std::to_string(root_info->unique_id());
50     }
51 
52     if (root_info->trace_info() != nullptr) {
53       label = label + "_" + root_info->trace_info()->full_name() + "_";
54       root_info = root_info->trace_info()->debug_info();
55     } else {
56       root_info = nullptr;
57     }
58   }
59   return label;
60 }
61 
62 // Get trace with unique id chain
LabelStringUnique(const DebugInfoPtr & debug_info)63 std::string LabelStringUnique(const DebugInfoPtr &debug_info) { return CombineUniqueID(debug_info); }
64 
// Intermediate result of walking a debug-info trace chain: the name chosen for the chain's root
// plus the trace labels collected on the way there, in the order they were visited.
struct NameWithTrace {
  // Name placed at the end of the final label.
  std::string root_name{};
  // Per-trace labels; they are combined and prefixed to root_name.
  std::vector<std::string> trace_labels{};
};
69 
GetTraceName(const TraceInfoPtr & trace_info,TraceLabelType trace_label)70 static std::string GetTraceName(const TraceInfoPtr &trace_info, TraceLabelType trace_label) {
71   switch (trace_label) {
72     case TraceLabelType::kShortSymbol:
73       return trace_info->symbol();
74     case TraceLabelType::kFullName:
75       return "_" + trace_info->full_name() + "_";
76     default:
77       return "";
78   }
79 }
80 
CollectTraceInfos(const DebugInfoPtr & debug_info,TraceLabelType trace_label)81 NameWithTrace CollectTraceInfos(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
82   NameWithTrace name_and_traces;
83   // Find debug info after Resolve/ExpandJ/GenMetaFuncGraph/GenerateVarArg/GenerateKwArg, it is a new node.
84   MS_EXCEPTION_IF_NULL(debug_info);
85   const auto &shadow_debug_infos_map = debug_info->shadow_debug_infos_map();
86   auto root_info = debug_info;
87   while (root_info != nullptr) {
88     if (root_info->trace_info() == nullptr) {
89       break;
90     }
91     if (root_info->trace_info()->isa<mindspore::TraceParse>() ||
92         root_info->trace_info()->isa<mindspore::TraceResolve>() ||
93         root_info->trace_info()->isa<mindspore::TraceExpandJ>() ||
94         root_info->trace_info()->isa<mindspore::TraceGenMetaFuncGraph>() ||
95         root_info->trace_info()->isa<mindspore::TraceGenerateVarArg>() ||
96         root_info->trace_info()->isa<mindspore::TraceGenerateKwArg>()) {
97       break;
98     }
99     const auto trace_name = GetTraceName(root_info->trace_info(), trace_label);
100     if (!trace_name.empty()) {
101       (void)name_and_traces.trace_labels.emplace_back(trace_name);
102     }
103     // Insert shadow debug info.
104     auto iter = shadow_debug_infos_map.find(root_info);
105     if (iter != shadow_debug_infos_map.end()) {
106       DebugInfoPtr shadowed_debug_info = iter->first;
107       DebugInfoPtr shadow_debug_info = iter->second;
108       MS_LOG(DEBUG) << "Insert debug info, root_info: " << root_info << "/" << root_info->name() << "/"
109                     << root_info->debug_name() << ", shadow_debug_info: " << shadow_debug_info << "/"
110                     << shadow_debug_info->name() << "/" << shadow_debug_info->debug_name()
111                     << ", shadowed_debug_info: " << shadowed_debug_info << "/" << shadowed_debug_info->name() << "/"
112                     << shadowed_debug_info->debug_name();
113       const auto shadow_trace_name = GetTraceName(shadow_debug_info->trace_info(), trace_label);
114       if (!shadow_trace_name.empty()) {
115         (void)name_and_traces.trace_labels.emplace_back(shadow_trace_name);
116       }
117     }
118     root_info = root_info->trace_info()->debug_info();
119   }
120 
121   if (!root_info->name().empty()) {
122     name_and_traces.root_name = root_info->name();
123     return name_and_traces;
124   }
125   // If it's node debug info and no trace label, use current node debug info.
126   auto node_root_info = std::dynamic_pointer_cast<NodeDebugInfo>(root_info);
127   if (node_root_info != nullptr && name_and_traces.trace_labels.empty()) {
128     root_info = debug_info;
129     // Use shadow debug info's name.
130     if (!shadow_debug_infos_map.empty()) {
131       name_and_traces.root_name = root_info->debug_name();
132       for (const auto &shadow_pair : shadow_debug_infos_map) {
133         DebugInfoPtr shadow_debug_info = shadow_pair.second;
134         if (!shadow_debug_info->name().empty()) {
135           name_and_traces.root_name += '$';
136           name_and_traces.root_name += shadow_debug_info->name();
137         }
138       }
139       return name_and_traces;
140     }
141   }
142   name_and_traces.root_name = root_info->debug_name();
143   return name_and_traces;
144 }
145 
// Builds a short label from the collected trace labels followed by root_name.
// Trace labels are emitted in order. A run of N > 1 identical consecutive labels is written once,
// prefixed by N (e.g. AAA -> 3A). Empty labels are skipped, although they still separate runs.
// Plain std::string appends are used: the previous std::stringstream depended on <sstream>, which this
// file never includes.
std::string CombineTraceInfos(const std::string &root_name, const std::vector<std::string> &trace_labels) {
  std::string combined;
  const size_t label_count = trace_labels.size();
  size_t i = 0;
  while (i < label_count) {
    const std::string &label = trace_labels[i];
    if (label.empty()) {
      ++i;
      continue;
    }
    // Combine the same continuous symbols. For example, AAA --> 3A
    size_t run_end = i + 1;
    while (run_end < label_count && trace_labels[run_end] == label) {
      ++run_end;
    }
    const size_t run_length = run_end - i;
    if (run_length > 1) {
      combined += std::to_string(run_length);
    }
    combined += label;
    i = run_end;
  }
  combined += root_name;
  return combined;
}
167 
168 // Get the label name of the node debug info
LabelString(const DebugInfoPtr & debug_info,TraceLabelType trace_label)169 std::string LabelString(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
170   NameWithTrace name_and_traces = CollectTraceInfos(debug_info, trace_label);
171   return CombineTraceInfos(name_and_traces.root_name, name_and_traces.trace_labels);
172 }
173 }  // namespace
174 
175 namespace mindspore {
176 namespace trace {
SetWithUniqueId(bool enabled)177 void SetWithUniqueId(bool enabled) { *WithUniqueIdPtr() = enabled; }
178 
GetGlobalTraceLabelType()179 TraceLabelType GetGlobalTraceLabelType() {
180   static const TraceLabelType global_trace_type =
181     (mindspore::common::GetCompileConfig("TRACE_LABEL_WITH_UNIQUE_ID") == "1") ? TraceLabelType::kWithUniqueId
182                                                                                : TraceLabelType::kShortSymbol;
183   return global_trace_type;
184 }
185 
Label(const DebugInfoPtr & debug_info,TraceLabelType trace_label)186 std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
187   if ((GetGlobalTraceLabelType() == TraceLabelType::kWithUniqueId) ||
188       (GetCurrentTraceLabelType() == TraceLabelType::kWithUniqueId)) {
189     return LabelStringUnique(debug_info);
190   }
191   return LabelString(debug_info, trace_label);
192 }
193 }  // namespace trace
194 }  // namespace mindspore
195