• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/profiler/utils/derived_timeline.h"
16 
17 #include <algorithm>
18 #include <utility>
19 #include <vector>
20 
21 #include "absl/algorithm/container.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/strings/match.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_split.h"
26 #include "absl/strings/string_view.h"
27 #include "absl/types/optional.h"
28 #include "tensorflow/core/lib/gtl/map_util.h"
29 #include "tensorflow/core/platform/path.h"
30 #include "tensorflow/core/platform/types.h"
31 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
32 #include "tensorflow/core/profiler/utils/group_events.h"
33 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
34 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
35 #include "tensorflow/core/profiler/utils/time_utils.h"
36 #include "tensorflow/core/profiler/utils/timespan.h"
37 #include "tensorflow/core/profiler/utils/trace_utils.h"
38 #include "tensorflow/core/profiler/utils/xplane_builder.h"
39 #include "tensorflow/core/profiler/utils/xplane_schema.h"
40 #include "tensorflow/core/profiler/utils/xplane_utils.h"
41 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
42 
43 namespace tensorflow {
44 namespace profiler {
45 namespace {
46 
47 const absl::string_view kAnnotationDelimiter = "::";
48 
CreateXEvent(const XEventMetadata & metadata,int64_t offset_ps,int64_t duration_ps,int64_t group_id_stat_metadata_id,absl::optional<int64> group_id)49 XEvent CreateXEvent(const XEventMetadata& metadata, int64_t offset_ps,
50                     int64_t duration_ps, int64_t group_id_stat_metadata_id,
51                     absl::optional<int64> group_id) {
52   XEvent event;
53   event.set_metadata_id(metadata.id());
54   // TODO(b/150498419): Normalize with the line start time.
55   event.set_offset_ps(offset_ps);
56   event.set_duration_ps(duration_ps);
57   if (group_id) {
58     XStat* stat = event.add_stats();
59     stat->set_metadata_id(group_id_stat_metadata_id);
60     stat->set_int64_value(*group_id);
61   }
62   return event;
63 }
64 
GroupIdOrInvalid(absl::optional<int64> group_id)65 int64 GroupIdOrInvalid(absl::optional<int64> group_id) {
66   if (group_id)
67     return *group_id;
68   else
69     return DerivedXLineBuilder::kInvalidGroupId;
70 }
71 
72 }  // namespace
73 
ProcessTfOpEvent(absl::string_view tf_op_full_name,absl::string_view low_level_event_name,int64_t offset_ps,int64_t duration_ps,absl::optional<int64> group_id,XPlaneBuilder * plane_builder,DerivedXLineBuilder * tf_name_scope_line_builder,DerivedXLineBuilder * tf_op_line_builder)74 void ProcessTfOpEvent(absl::string_view tf_op_full_name,
75                       absl::string_view low_level_event_name, int64_t offset_ps,
76                       int64_t duration_ps, absl::optional<int64> group_id,
77                       XPlaneBuilder* plane_builder,
78                       DerivedXLineBuilder* tf_name_scope_line_builder,
79                       DerivedXLineBuilder* tf_op_line_builder) {
80   int64_t group_id_stat_metadata_id =
81       plane_builder->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))
82           ->id();
83   TfOp tf_op = ParseTfOpFullname(tf_op_full_name);
84   Category category = tf_op.category;
85   int64_t group_id_or_invalid = GroupIdOrInvalid(group_id);
86   if (category == Category::kTensorFlow || category == Category::kJax) {
87     std::vector<XEvent> name_scope_event_per_level;
88     for (const auto& tf_name_scope : ParseTfNameScopes(tf_op)) {
89       name_scope_event_per_level.push_back(CreateXEvent(
90           *plane_builder->GetOrCreateEventMetadata(tf_name_scope), offset_ps,
91           duration_ps, group_id_stat_metadata_id, group_id));
92     }
93     tf_name_scope_line_builder->ExpandOrAddEvents(
94         name_scope_event_per_level, group_id_or_invalid, low_level_event_name);
95   }
96   XEventMetadata* tf_op_event_metadata =
97       plane_builder->GetOrCreateEventMetadata(tf_op_full_name);
98   // Set the display name to op_type so that the events of the same op_type have
99   // the same color in the trace viewer.
100   tf_op_event_metadata->set_display_name(TfOpEventName(tf_op));
101   tf_op_line_builder->ExpandOrAddEvent(
102       CreateXEvent(*tf_op_event_metadata, offset_ps, duration_ps,
103                    group_id_stat_metadata_id, group_id),
104       group_id_or_invalid, low_level_event_name);
105 }
106 
DerivedXLineBuilder(XPlaneBuilder * plane,int64_t line_id,absl::string_view name,int64_t timestamp_ns,std::vector<DerivedXLineBuilder * > dependent_lines)107 DerivedXLineBuilder::DerivedXLineBuilder(
108     XPlaneBuilder* plane, int64_t line_id, absl::string_view name,
109     int64_t timestamp_ns, std::vector<DerivedXLineBuilder*> dependent_lines)
110     : line_(plane->GetOrCreateLine(line_id)) {
111   line_.SetName(name);
112   line_.SetTimestampNs(timestamp_ns);
113   dependent_lines_ = std::move(dependent_lines);
114 }
115 
ExpandOrAddLevelEvent(const XEvent & event,int64_t group_id,absl::string_view low_level_event_name,int level)116 void DerivedXLineBuilder::ExpandOrAddLevelEvent(
117     const XEvent& event, int64_t group_id,
118     absl::string_view low_level_event_name, int level) {
119   int64_t offset_ps = event.offset_ps();
120   int64_t duration_ps = event.duration_ps();
121   auto& last_event = last_event_by_level_[level];
122   // If last_event is not nullptr, its offset must be less than or equal to
123   // the given event's offset.
124   DCHECK(!last_event || last_event->OffsetPs() <= offset_ps);
125   auto& last_eventinfo = last_eventinfo_by_level_[level];
126   bool merge_last_event = false;
127   if (last_event && last_event->MetadataId() == event.metadata_id()) {
128     // If last_event is not nullptr and metadata is same, merge the given
129     // event into last_event.
130     DCHECK(last_eventinfo);  // last_eventinfo must be valid as well.
131     // Merges event with last_event if (1) they have the same group_id
132     // and (2) low_level_event_name hasn't been seen before. If
133     // low_level_event has been seen before, event and last_event are actually
134     // different invocations of the same Op, and so they shouldn't be merged.
135     merge_last_event =
136         (group_id == last_eventinfo->group_id) &&
137         !last_eventinfo->low_level_event_names.contains(low_level_event_name);
138   }
139   if (merge_last_event) {
140     // Merge event with last_event.
141     last_event->SetDurationPs((offset_ps + duration_ps) -
142                               last_event->OffsetPs());
143     if (!low_level_event_name.empty()) {
144       // One more low_level_event_name associated with last_event.
145       last_eventinfo->low_level_event_names.insert(
146           std::string(low_level_event_name));
147     }
148   } else {
149     // Otherwise, reset the last events lower than or equal to the given level.
150     ResetLastEvents(level);
151     // And create a new event for the given level.
152     last_event = line_.AddEvent(event);
153     // Also create a new XEventInfo for this level.
154     last_eventinfo = XEventInfo(group_id, low_level_event_name);
155   }
156 }
157 
ResetLastEvents(int level)158 void DerivedXLineBuilder::ResetLastEvents(int level) {
159   for (int i = level, end = last_event_by_level_.size(); i < end; ++i) {
160     last_event_by_level_[i] = absl::nullopt;
161     last_eventinfo_by_level_[i] = absl::nullopt;
162   }
163   if (level == 0) ResetDependentLines();
164 }
165 
DeriveEventsFromAnnotations(const SymbolResolver & symbol_resolver,const GroupMetadataMap & group_metadata_map,XPlane * device_trace,bool step_info_only)166 void DeriveEventsFromAnnotations(const SymbolResolver& symbol_resolver,
167                                  const GroupMetadataMap& group_metadata_map,
168                                  XPlane* device_trace, bool step_info_only) {
169   // Merge and sort events by Timespan as they come from different lines.
170   std::vector<XEventVisitor> events;
171   uint64 start_timestamp_ns = 0;
172   XPlaneVisitor device_plane = CreateTfXPlaneVisitor(device_trace);
173   device_plane.ForEachLine([&](const XLineVisitor& line) {
174     if (IsDerivedThreadId(line.Id())) return;  // Skip overhead line.
175     start_timestamp_ns = line.TimestampNs();
176     line.ForEachEvent(
177         [&](const XEventVisitor& event) { events.push_back(event); });
178   });
179   absl::c_sort(events);
180 
181   XPlaneBuilder plane(device_trace);
182   DerivedXLineBuilder tf_ops(&plane, kThreadIdTfOp, kTensorFlowOpLineName,
183                              start_timestamp_ns, {});
184   DerivedXLineBuilder tf_name_scope(&plane, kThreadIdTfNameScope,
185                                     kTensorFlowNameScopeLineName,
186                                     start_timestamp_ns, {&tf_ops});
187   DerivedXLineBuilder hlo_ops(&plane, kThreadIdHloOp, kXlaOpLineName,
188                               start_timestamp_ns, {});
189   DerivedXLineBuilder hlo_modules(&plane, kThreadIdHloModule,
190                                   kXlaModuleLineName, start_timestamp_ns,
191                                   {&tf_name_scope, &hlo_ops});
192   DerivedXLineBuilder steps(&plane, kThreadIdStepInfo, kStepLineName,
193                             start_timestamp_ns, {&hlo_modules});
194   DerivedXLineBuilder source(&plane, kThreadIdSource, kSourceLineName,
195                              start_timestamp_ns, {});
196 
197   int64_t group_id_stat_metadata_id =
198       plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))->id();
199   int64_t step_name_stat_metadata_id =
200       plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kStepName))->id();
201 
202   // Process events in order by start time.
203   for (const XEventVisitor& event : events) {
204     int64_t offset_ps = event.OffsetPs();
205     int64_t duration_ps = event.DurationPs();
206     absl::string_view tf_op_full_name;
207     absl::string_view hlo_module_name;
208     std::vector<absl::string_view> hlo_op_names;
209     absl::optional<int64> group_id;
210     bool is_kernel = false;
211     event.ForEachStat([&](const XStatVisitor& stat) {
212       if (stat.Type() == StatType::kGroupId) {
213         group_id = stat.IntValue();
214       } else if (stat.Type() == StatType::kLevel0 ||  // old way to carry tf_op
215                  stat.Type() == StatType::kTfOp) {
216         tf_op_full_name = stat.StrOrRefValue();
217       } else if (stat.Type() == StatType::kHloOp) {
218         hlo_op_names =
219             absl::StrSplit(stat.StrOrRefValue(), kAnnotationDelimiter);
220       } else if (stat.Type() == StatType::kHloModule) {
221         hlo_module_name = stat.StrOrRefValue();
222       } else if (stat.Type() == StatType::kKernelDetails) {
223         is_kernel = true;
224       }
225     });
226     int64_t group_id_or_invalid = GroupIdOrInvalid(group_id);
227     if (group_id) {
228       XEvent step_event = CreateXEvent(
229           *plane.GetOrCreateEventMetadata(absl::StrCat(*group_id)), offset_ps,
230           duration_ps, group_id_stat_metadata_id, group_id);
231       if (auto group_metadata =
232               gtl::FindOrNull(group_metadata_map, *group_id)) {
233         XStat* stat = step_event.add_stats();
234         stat->set_metadata_id(step_name_stat_metadata_id);
235         stat->set_str_value(group_metadata->name);
236       }
237       steps.ExpandOrAddEvent(step_event, group_id_or_invalid);
238     }
239 
240     if (step_info_only) continue;
241 
242     // For HLO/TF op lines, only use kernel events (i.e. excluding memcpy or
243     // allocation events).
244     if (!is_kernel) continue;
245 
246     if (!hlo_module_name.empty()) {
247       hlo_modules.ExpandOrAddEvent(CreateXEvent(
248           *plane.GetOrCreateEventMetadata(hlo_module_name), offset_ps,
249           duration_ps, group_id_stat_metadata_id, group_id));
250     }
251 
252     if (!hlo_op_names.empty()) {  // GPU kernel compiled by XLA
253       DCHECK(!hlo_module_name.empty());
254       std::vector<XEvent> hlo_op_event_per_level;
255       for (absl::string_view hlo_op_name : hlo_op_names) {
256         DCHECK(!hlo_op_name.empty());
257         hlo_op_event_per_level.push_back(CreateXEvent(
258             *plane.GetOrCreateEventMetadata(hlo_op_name), offset_ps,
259             duration_ps, group_id_stat_metadata_id, group_id));
260       }
261       hlo_ops.ExpandOrAddEvents(hlo_op_event_per_level, group_id_or_invalid);
262       auto symbol = symbol_resolver(hlo_module_name, hlo_op_names.back());
263       if (!symbol.tf_op_name.empty()) {
264         ProcessTfOpEvent(symbol.tf_op_name,
265                          /*low_level_event_name=*/event.Name(), offset_ps,
266                          duration_ps, group_id, &plane, &tf_name_scope,
267                          &tf_ops);
268       }
269       if (!symbol.source_info.empty()) {
270         source.ExpandOrAddEvent(CreateXEvent(
271             *plane.GetOrCreateEventMetadata(symbol.source_info), offset_ps,
272             duration_ps, group_id_stat_metadata_id, group_id));
273       }
274     } else if (!tf_op_full_name.empty()) {  // GPU kernel not compiled by XLA
275       ProcessTfOpEvent(tf_op_full_name,
276                        /*low_level_event_name=*/event.Name(), offset_ps,
277                        duration_ps, group_id, &plane, &tf_name_scope, &tf_ops);
278     }
279   }
280   RemoveEmptyLines(device_trace);
281 }
282 
DeriveEventsFromHostTrace(const XPlane * host_trace,const GroupMetadataMap & group_metadata_map,std::vector<XPlane * > device_traces)283 void DeriveEventsFromHostTrace(const XPlane* host_trace,
284                                const GroupMetadataMap& group_metadata_map,
285                                std::vector<XPlane*> device_traces) {
286   struct GroupLaunchInfo {  // "Group" normally means step.
287     Timespan timespan;
288     int32 num_launches = 0;
289     uint64 max_launch_time_ps = 0ULL;
290     uint64 total_launch_time_ps = 0ULL;
291   };
292   typedef absl::flat_hash_map<int64 /*group_id*/, GroupLaunchInfo>
293       DeviceLaunchInfo;
294 
295   int num_devices = device_traces.size();
296   std::vector<DeviceLaunchInfo> per_device_launch_info(num_devices);
297 
298   XPlaneVisitor host_plane = CreateTfXPlaneVisitor(host_trace);
299   host_plane.ForEachLine([&](const XLineVisitor& line) {
300     if (IsDerivedThreadId(line.Id())) return;
301     line.ForEachEvent([&](const XEventVisitor& event) {
302       absl::optional<int64> group_id;
303       absl::optional<int64> device_id;
304       absl::optional<int64> correlation_id;
305       // Filter out API calls for cuEventRecord/cuEventQuery/cuCtxSynchronize
306       // etc for now. TODO: find a better way to filter out only the memcpy and
307       // kernel launch events.
308       if (absl::StartsWith(event.Name(), "cu")) return;
309       event.ForEachStat([&](const XStatVisitor& stat) {
310         if (stat.Type() == StatType::kGroupId) {
311           group_id = stat.IntValue();
312         } else if (stat.Type() == StatType::kDeviceId) {
313           device_id = stat.IntOrUintValue();
314         } else if (stat.Type() == StatType::kCorrelationId) {
315           correlation_id = stat.IntValue();
316         }
317       });
318       if (group_id && device_id && correlation_id && *device_id >= 0 &&
319           *device_id < num_devices) {
320         // This is a launch event on a known device.
321         GroupLaunchInfo& group_launch_info =
322             per_device_launch_info[*device_id][*group_id];
323         Timespan& group_span = group_launch_info.timespan;
324         Timespan event_span = event.GetTimespan();
325         if (group_launch_info.num_launches) {  // Existing group.
326           group_span.ExpandToInclude(event_span);
327         } else {
328           group_span = event_span;
329         }
330         ++group_launch_info.num_launches;
331         group_launch_info.max_launch_time_ps = std::max(
332             group_launch_info.max_launch_time_ps, event_span.duration_ps());
333         group_launch_info.total_launch_time_ps += event_span.duration_ps();
334       }
335     });
336   });
337 
338   uint64 host_plane_start = GetStartTimestampNs(*host_trace);
339   for (int i = 0; i < num_devices; ++i) {
340     if (per_device_launch_info[i].empty()) continue;
341     uint64 device_plane_start = GetStartTimestampNs(*device_traces[i]);
342     XPlaneBuilder device_plane(device_traces[i]);
343     XLineBuilder launch_line =
344         device_plane.GetOrCreateLine(kThreadIdKernelLaunch);
345     launch_line.SetName(kKernelLaunchLineName);
346     launch_line.SetTimestampNs(std::min(device_plane_start, host_plane_start));
347     for (const auto& kv : per_device_launch_info[i]) {
348       int64_t group_id = kv.first;
349       const GroupLaunchInfo& group_info = kv.second;
350       if (auto group_metadata = gtl::FindOrNull(group_metadata_map, group_id)) {
351         XEventBuilder device_event =
352             launch_line.AddEvent(*device_plane.GetOrCreateEventMetadata(
353                 absl::StrCat("Launch Stats for ", group_metadata->name)));
354         device_event.SetTimestampNs(
355             host_plane_start + PicosToNanos(group_info.timespan.begin_ps()));
356         device_event.SetDurationPs(group_info.timespan.duration_ps());
357         device_event.AddStatValue(*device_plane.GetOrCreateStatMetadata(
358                                       GetStatTypeStr(StatType::kGroupId)),
359                                   group_id);
360         device_event.AddStatValue(
361             *device_plane.GetOrCreateStatMetadata("num_launches"),
362             group_info.num_launches);
363         device_event.AddStatValue(
364             *device_plane.GetOrCreateStatMetadata("max_launch_time_us"),
365             PicosToMicros(group_info.max_launch_time_ps));
366         device_event.AddStatValue(
367             *device_plane.GetOrCreateStatMetadata("avg_launch_time_us"),
368             PicosToMicros(group_info.total_launch_time_ps /
369                           group_info.num_launches));
370       }
371     }
372   }
373 }
374 
GenerateDerivedTimeLines(const GroupMetadataMap & group_metadata_map,XSpace * space,bool step_info_only)375 void GenerateDerivedTimeLines(const GroupMetadataMap& group_metadata_map,
376                               XSpace* space, bool step_info_only) {
377   // TODO(profiler): Once we capture HLO protos for xla/gpu, we should use that
378   // to look up tensorflow op name from hlo_module/hlo_op.
379   auto dummy_symbol_resolver = [](absl::string_view hlo_module,
380                                   absl::string_view hlo_op) {
381     return tensorflow::profiler::Symbol();
382   };
383   std::vector<XPlane*> device_traces =
384       FindMutablePlanesWithPrefix(space, kGpuPlanePrefix);
385   for (XPlane* plane : device_traces) {
386     DeriveEventsFromAnnotations(dummy_symbol_resolver, group_metadata_map,
387                                 plane, step_info_only);
388   }
389 }
390 
391 }  // namespace profiler
392 }  // namespace tensorflow
393