/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/core/profiler/utils/derived_timeline.h"
16
17 #include <algorithm>
18 #include <utility>
19 #include <vector>
20
21 #include "absl/algorithm/container.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/strings/match.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_split.h"
26 #include "absl/strings/string_view.h"
27 #include "absl/types/optional.h"
28 #include "tensorflow/core/lib/gtl/map_util.h"
29 #include "tensorflow/core/platform/path.h"
30 #include "tensorflow/core/platform/types.h"
31 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
32 #include "tensorflow/core/profiler/utils/group_events.h"
33 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
34 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
35 #include "tensorflow/core/profiler/utils/time_utils.h"
36 #include "tensorflow/core/profiler/utils/timespan.h"
37 #include "tensorflow/core/profiler/utils/trace_utils.h"
38 #include "tensorflow/core/profiler/utils/xplane_builder.h"
39 #include "tensorflow/core/profiler/utils/xplane_schema.h"
40 #include "tensorflow/core/profiler/utils/xplane_utils.h"
41 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
42
43 namespace tensorflow {
44 namespace profiler {
45 namespace {
46
47 const absl::string_view kAnnotationDelimiter = "::";
48
CreateXEvent(const XEventMetadata & metadata,int64_t offset_ps,int64_t duration_ps,int64_t group_id_stat_metadata_id,absl::optional<int64> group_id)49 XEvent CreateXEvent(const XEventMetadata& metadata, int64_t offset_ps,
50 int64_t duration_ps, int64_t group_id_stat_metadata_id,
51 absl::optional<int64> group_id) {
52 XEvent event;
53 event.set_metadata_id(metadata.id());
54 // TODO(b/150498419): Normalize with the line start time.
55 event.set_offset_ps(offset_ps);
56 event.set_duration_ps(duration_ps);
57 if (group_id) {
58 XStat* stat = event.add_stats();
59 stat->set_metadata_id(group_id_stat_metadata_id);
60 stat->set_int64_value(*group_id);
61 }
62 return event;
63 }
64
GroupIdOrInvalid(absl::optional<int64> group_id)65 int64 GroupIdOrInvalid(absl::optional<int64> group_id) {
66 if (group_id)
67 return *group_id;
68 else
69 return DerivedXLineBuilder::kInvalidGroupId;
70 }
71
72 } // namespace
73
ProcessTfOpEvent(absl::string_view tf_op_full_name,absl::string_view low_level_event_name,int64_t offset_ps,int64_t duration_ps,absl::optional<int64> group_id,XPlaneBuilder * plane_builder,DerivedXLineBuilder * tf_name_scope_line_builder,DerivedXLineBuilder * tf_op_line_builder)74 void ProcessTfOpEvent(absl::string_view tf_op_full_name,
75 absl::string_view low_level_event_name, int64_t offset_ps,
76 int64_t duration_ps, absl::optional<int64> group_id,
77 XPlaneBuilder* plane_builder,
78 DerivedXLineBuilder* tf_name_scope_line_builder,
79 DerivedXLineBuilder* tf_op_line_builder) {
80 int64_t group_id_stat_metadata_id =
81 plane_builder->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))
82 ->id();
83 TfOp tf_op = ParseTfOpFullname(tf_op_full_name);
84 Category category = tf_op.category;
85 int64_t group_id_or_invalid = GroupIdOrInvalid(group_id);
86 if (category == Category::kTensorFlow || category == Category::kJax) {
87 std::vector<XEvent> name_scope_event_per_level;
88 for (const auto& tf_name_scope : ParseTfNameScopes(tf_op)) {
89 name_scope_event_per_level.push_back(CreateXEvent(
90 *plane_builder->GetOrCreateEventMetadata(tf_name_scope), offset_ps,
91 duration_ps, group_id_stat_metadata_id, group_id));
92 }
93 tf_name_scope_line_builder->ExpandOrAddEvents(
94 name_scope_event_per_level, group_id_or_invalid, low_level_event_name);
95 }
96 XEventMetadata* tf_op_event_metadata =
97 plane_builder->GetOrCreateEventMetadata(tf_op_full_name);
98 // Set the display name to op_type so that the events of the same op_type have
99 // the same color in the trace viewer.
100 tf_op_event_metadata->set_display_name(TfOpEventName(tf_op));
101 tf_op_line_builder->ExpandOrAddEvent(
102 CreateXEvent(*tf_op_event_metadata, offset_ps, duration_ps,
103 group_id_stat_metadata_id, group_id),
104 group_id_or_invalid, low_level_event_name);
105 }
106
DerivedXLineBuilder(XPlaneBuilder * plane,int64_t line_id,absl::string_view name,int64_t timestamp_ns,std::vector<DerivedXLineBuilder * > dependent_lines)107 DerivedXLineBuilder::DerivedXLineBuilder(
108 XPlaneBuilder* plane, int64_t line_id, absl::string_view name,
109 int64_t timestamp_ns, std::vector<DerivedXLineBuilder*> dependent_lines)
110 : line_(plane->GetOrCreateLine(line_id)) {
111 line_.SetName(name);
112 line_.SetTimestampNs(timestamp_ns);
113 dependent_lines_ = std::move(dependent_lines);
114 }
115
// Adds `event` to this line at nesting `level`, or — when it is a
// continuation of the previous event at that level (same metadata, same
// group, and `low_level_event_name` not yet seen) — extends the previous
// event to cover it instead of adding a new one.
void DerivedXLineBuilder::ExpandOrAddLevelEvent(
    const XEvent& event, int64_t group_id,
    absl::string_view low_level_event_name, int level) {
  int64_t offset_ps = event.offset_ps();
  int64_t duration_ps = event.duration_ps();
  auto& last_event = last_event_by_level_[level];
  // If last_event is not nullptr, its offset must be less than or equal to
  // the given event's offset, i.e. events arrive ordered by start time.
  DCHECK(!last_event || last_event->OffsetPs() <= offset_ps);
  auto& last_eventinfo = last_eventinfo_by_level_[level];
  bool merge_last_event = false;
  if (last_event && last_event->MetadataId() == event.metadata_id()) {
    // If last_event is not nullptr and metadata is same, merge the given
    // event into last_event.
    DCHECK(last_eventinfo);  // last_eventinfo must be valid as well.
    // Merges event with last_event if (1) they have the same group_id
    // and (2) low_level_event_name hasn't been seen before. If
    // low_level_event has been seen before, event and last_event are actually
    // different invocations of the same Op, and so they shouldn't be merged.
    merge_last_event =
        (group_id == last_eventinfo->group_id) &&
        !last_eventinfo->low_level_event_names.contains(low_level_event_name);
  }
  if (merge_last_event) {
    // Merge event with last_event: stretch last_event so that it ends where
    // the given event ends.
    last_event->SetDurationPs((offset_ps + duration_ps) -
                              last_event->OffsetPs());
    if (!low_level_event_name.empty()) {
      // One more low_level_event_name associated with last_event.
      last_eventinfo->low_level_event_names.insert(
          std::string(low_level_event_name));
    }
  } else {
    // Otherwise, reset the last events lower than or equal to the given level.
    ResetLastEvents(level);
    // And create a new event for the given level.
    last_event = line_.AddEvent(event);
    // Also create a new XEventInfo for this level.
    last_eventinfo = XEventInfo(group_id, low_level_event_name);
  }
}
157
ResetLastEvents(int level)158 void DerivedXLineBuilder::ResetLastEvents(int level) {
159 for (int i = level, end = last_event_by_level_.size(); i < end; ++i) {
160 last_event_by_level_[i] = absl::nullopt;
161 last_eventinfo_by_level_[i] = absl::nullopt;
162 }
163 if (level == 0) ResetDependentLines();
164 }
165
// Scans all non-derived events on `device_trace` and derives per-category
// timelines (steps, HLO modules, HLO ops, TF name scopes, TF ops, source
// info) from the events' annotation stats, adding them as new lines on the
// same plane. When `step_info_only` is true, only the step line is derived.
// `symbol_resolver` maps (hlo_module, hlo_op) to a TF op name / source info.
void DeriveEventsFromAnnotations(const SymbolResolver& symbol_resolver,
                                 const GroupMetadataMap& group_metadata_map,
                                 XPlane* device_trace, bool step_info_only) {
  // Merge and sort events by Timespan as they come from different lines.
  std::vector<XEventVisitor> events;
  uint64 start_timestamp_ns = 0;
  XPlaneVisitor device_plane = CreateTfXPlaneVisitor(device_trace);
  device_plane.ForEachLine([&](const XLineVisitor& line) {
    if (IsDerivedThreadId(line.Id())) return;  // Skip overhead line.
    // NOTE(review): keeps the timestamp of the last visited line — assumes
    // all device lines share the same start timestamp; confirm if changed.
    start_timestamp_ns = line.TimestampNs();
    line.ForEachEvent(
        [&](const XEventVisitor& event) { events.push_back(event); });
  });
  absl::c_sort(events);

  XPlaneBuilder plane(device_trace);
  // Derived line builders; each lists the lines that must be reset along with
  // it (e.g. a new step resets hlo_modules, which resets name scopes/ops).
  DerivedXLineBuilder tf_ops(&plane, kThreadIdTfOp, kTensorFlowOpLineName,
                             start_timestamp_ns, {});
  DerivedXLineBuilder tf_name_scope(&plane, kThreadIdTfNameScope,
                                    kTensorFlowNameScopeLineName,
                                    start_timestamp_ns, {&tf_ops});
  DerivedXLineBuilder hlo_ops(&plane, kThreadIdHloOp, kXlaOpLineName,
                              start_timestamp_ns, {});
  DerivedXLineBuilder hlo_modules(&plane, kThreadIdHloModule,
                                  kXlaModuleLineName, start_timestamp_ns,
                                  {&tf_name_scope, &hlo_ops});
  DerivedXLineBuilder steps(&plane, kThreadIdStepInfo, kStepLineName,
                            start_timestamp_ns, {&hlo_modules});
  DerivedXLineBuilder source(&plane, kThreadIdSource, kSourceLineName,
                             start_timestamp_ns, {});

  int64_t group_id_stat_metadata_id =
      plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))->id();
  int64_t step_name_stat_metadata_id =
      plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kStepName))->id();

  // Process events in order by start time.
  for (const XEventVisitor& event : events) {
    int64_t offset_ps = event.OffsetPs();
    int64_t duration_ps = event.DurationPs();
    absl::string_view tf_op_full_name;
    absl::string_view hlo_module_name;
    std::vector<absl::string_view> hlo_op_names;
    absl::optional<int64> group_id;
    bool is_kernel = false;
    // Pull the annotations attached to this event.
    event.ForEachStat([&](const XStatVisitor& stat) {
      if (stat.Type() == StatType::kGroupId) {
        group_id = stat.IntValue();
      } else if (stat.Type() == StatType::kLevel0 ||  // old way to carry tf_op
                 stat.Type() == StatType::kTfOp) {
        tf_op_full_name = stat.StrOrRefValue();
      } else if (stat.Type() == StatType::kHloOp) {
        // Nested HLO op names are "::"-joined; split into per-level names.
        hlo_op_names =
            absl::StrSplit(stat.StrOrRefValue(), kAnnotationDelimiter);
      } else if (stat.Type() == StatType::kHloModule) {
        hlo_module_name = stat.StrOrRefValue();
      } else if (stat.Type() == StatType::kKernelDetails) {
        is_kernel = true;
      }
    });
    int64_t group_id_or_invalid = GroupIdOrInvalid(group_id);
    if (group_id) {
      // Derive a step event named after the group id, with the step name
      // attached as a stat when the group metadata provides one.
      XEvent step_event = CreateXEvent(
          *plane.GetOrCreateEventMetadata(absl::StrCat(*group_id)), offset_ps,
          duration_ps, group_id_stat_metadata_id, group_id);
      if (auto group_metadata =
              gtl::FindOrNull(group_metadata_map, *group_id)) {
        XStat* stat = step_event.add_stats();
        stat->set_metadata_id(step_name_stat_metadata_id);
        stat->set_str_value(group_metadata->name);
      }
      steps.ExpandOrAddEvent(step_event, group_id_or_invalid);
    }

    if (step_info_only) continue;

    // For HLO/TF op lines, only use kernel events (i.e. excluding memcpy or
    // allocation events).
    if (!is_kernel) continue;

    if (!hlo_module_name.empty()) {
      hlo_modules.ExpandOrAddEvent(CreateXEvent(
          *plane.GetOrCreateEventMetadata(hlo_module_name), offset_ps,
          duration_ps, group_id_stat_metadata_id, group_id));
    }

    if (!hlo_op_names.empty()) {  // GPU kernel compiled by XLA
      DCHECK(!hlo_module_name.empty());
      std::vector<XEvent> hlo_op_event_per_level;
      for (absl::string_view hlo_op_name : hlo_op_names) {
        DCHECK(!hlo_op_name.empty());
        hlo_op_event_per_level.push_back(CreateXEvent(
            *plane.GetOrCreateEventMetadata(hlo_op_name), offset_ps,
            duration_ps, group_id_stat_metadata_id, group_id));
      }
      hlo_ops.ExpandOrAddEvents(hlo_op_event_per_level, group_id_or_invalid);
      // Resolve the innermost HLO op to its TF op / source location.
      auto symbol = symbol_resolver(hlo_module_name, hlo_op_names.back());
      if (!symbol.tf_op_name.empty()) {
        ProcessTfOpEvent(symbol.tf_op_name,
                         /*low_level_event_name=*/event.Name(), offset_ps,
                         duration_ps, group_id, &plane, &tf_name_scope,
                         &tf_ops);
      }
      if (!symbol.source_info.empty()) {
        source.ExpandOrAddEvent(CreateXEvent(
            *plane.GetOrCreateEventMetadata(symbol.source_info), offset_ps,
            duration_ps, group_id_stat_metadata_id, group_id));
      }
    } else if (!tf_op_full_name.empty()) {  // GPU kernel not compiled by XLA
      ProcessTfOpEvent(tf_op_full_name,
                       /*low_level_event_name=*/event.Name(), offset_ps,
                       duration_ps, group_id, &plane, &tf_name_scope, &tf_ops);
    }
  }
  RemoveEmptyLines(device_trace);
}
282
// Aggregates host-side launch events per (device, group) and adds a
// "Launch Stats" line to each device plane summarizing, per group/step, the
// launch timespan, launch count, and max/avg launch durations.
void DeriveEventsFromHostTrace(const XPlane* host_trace,
                               const GroupMetadataMap& group_metadata_map,
                               std::vector<XPlane*> device_traces) {
  // Per-group aggregate of host launch activity.
  struct GroupLaunchInfo {  // "Group" normally means step.
    Timespan timespan;            // union of all launch event timespans
    int32 num_launches = 0;
    uint64 max_launch_time_ps = 0ULL;
    uint64 total_launch_time_ps = 0ULL;
  };
  typedef absl::flat_hash_map<int64 /*group_id*/, GroupLaunchInfo>
      DeviceLaunchInfo;

  int num_devices = device_traces.size();
  std::vector<DeviceLaunchInfo> per_device_launch_info(num_devices);

  XPlaneVisitor host_plane = CreateTfXPlaneVisitor(host_trace);
  host_plane.ForEachLine([&](const XLineVisitor& line) {
    if (IsDerivedThreadId(line.Id())) return;
    line.ForEachEvent([&](const XEventVisitor& event) {
      absl::optional<int64> group_id;
      absl::optional<int64> device_id;
      absl::optional<int64> correlation_id;
      // Filter out API calls for cuEventRecord/cuEventQuery/cuCtxSynchronize
      // etc for now. TODO: find a better way to filter out only the memcpy and
      // kernel launch events.
      if (absl::StartsWith(event.Name(), "cu")) return;
      event.ForEachStat([&](const XStatVisitor& stat) {
        if (stat.Type() == StatType::kGroupId) {
          group_id = stat.IntValue();
        } else if (stat.Type() == StatType::kDeviceId) {
          device_id = stat.IntOrUintValue();
        } else if (stat.Type() == StatType::kCorrelationId) {
          correlation_id = stat.IntValue();
        }
      });
      if (group_id && device_id && correlation_id && *device_id >= 0 &&
          *device_id < num_devices) {
        // This is a launch event on a known device.
        GroupLaunchInfo& group_launch_info =
            per_device_launch_info[*device_id][*group_id];
        Timespan& group_span = group_launch_info.timespan;
        Timespan event_span = event.GetTimespan();
        if (group_launch_info.num_launches) {  // Existing group.
          group_span.ExpandToInclude(event_span);
        } else {
          // First launch seen for this group: initialize the span.
          group_span = event_span;
        }
        ++group_launch_info.num_launches;
        group_launch_info.max_launch_time_ps = std::max(
            group_launch_info.max_launch_time_ps, event_span.duration_ps());
        group_launch_info.total_launch_time_ps += event_span.duration_ps();
      }
    });
  });

  uint64 host_plane_start = GetStartTimestampNs(*host_trace);
  for (int i = 0; i < num_devices; ++i) {
    if (per_device_launch_info[i].empty()) continue;
    uint64 device_plane_start = GetStartTimestampNs(*device_traces[i]);
    XPlaneBuilder device_plane(device_traces[i]);
    XLineBuilder launch_line =
        device_plane.GetOrCreateLine(kThreadIdKernelLaunch);
    launch_line.SetName(kKernelLaunchLineName);
    // Launch events are timed on the host, so the derived line must start no
    // later than either plane's start.
    launch_line.SetTimestampNs(std::min(device_plane_start, host_plane_start));
    for (const auto& kv : per_device_launch_info[i]) {
      int64_t group_id = kv.first;
      const GroupLaunchInfo& group_info = kv.second;
      // Only emit stats for groups with known metadata (i.e. named steps).
      if (auto group_metadata = gtl::FindOrNull(group_metadata_map, group_id)) {
        XEventBuilder device_event =
            launch_line.AddEvent(*device_plane.GetOrCreateEventMetadata(
                absl::StrCat("Launch Stats for ", group_metadata->name)));
        // Launch timespans are relative to the host plane start.
        device_event.SetTimestampNs(
            host_plane_start + PicosToNanos(group_info.timespan.begin_ps()));
        device_event.SetDurationPs(group_info.timespan.duration_ps());
        device_event.AddStatValue(*device_plane.GetOrCreateStatMetadata(
                                      GetStatTypeStr(StatType::kGroupId)),
                                  group_id);
        device_event.AddStatValue(
            *device_plane.GetOrCreateStatMetadata("num_launches"),
            group_info.num_launches);
        device_event.AddStatValue(
            *device_plane.GetOrCreateStatMetadata("max_launch_time_us"),
            PicosToMicros(group_info.max_launch_time_ps));
        device_event.AddStatValue(
            *device_plane.GetOrCreateStatMetadata("avg_launch_time_us"),
            PicosToMicros(group_info.total_launch_time_ps /
                          group_info.num_launches));
      }
    }
  }
}
374
GenerateDerivedTimeLines(const GroupMetadataMap & group_metadata_map,XSpace * space,bool step_info_only)375 void GenerateDerivedTimeLines(const GroupMetadataMap& group_metadata_map,
376 XSpace* space, bool step_info_only) {
377 // TODO(profiler): Once we capture HLO protos for xla/gpu, we should use that
378 // to look up tensorflow op name from hlo_module/hlo_op.
379 auto dummy_symbol_resolver = [](absl::string_view hlo_module,
380 absl::string_view hlo_op) {
381 return tensorflow::profiler::Symbol();
382 };
383 std::vector<XPlane*> device_traces =
384 FindMutablePlanesWithPrefix(space, kGpuPlanePrefix);
385 for (XPlane* plane : device_traces) {
386 DeriveEventsFromAnnotations(dummy_symbol_resolver, group_metadata_map,
387 plane, step_info_only);
388 }
389 }
390
391 } // namespace profiler
392 } // namespace tensorflow
393