/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/profiler/utils/group_events.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
namespace tensorflow {
namespace profiler {
namespace {

// Creates stat metadata for the stats which may be added by grouping.
void CreateStatMetadata(XPlane* plane) {
  XPlaneBuilder builder(plane);
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId));
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kStepName));
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kIsEager));
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kSelectedGroupIds));
}

// Returns the event type if the event is a KernelLaunch or KernelExecute
// event, identified by the presence of a correlation_id stat.
absl::optional<int64> GetKernelEventType(bool is_host_plane,
                                         const EventNode& event) {
  if (event.GetEventVisitor().GetStat(StatType::kCorrelationId).has_value()) {
    return is_host_plane ? HostEventType::kKernelLaunch
                         : HostEventType::kKernelExecute;
  }
  return absl::nullopt;
}

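// Returns the type of the given event. Kernel events and legacy events whose
// names carry appended arguments are not recognized by XPlaneVisitor and are
// resolved here by inspecting the event itself.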
int64 GetEventType(bool is_host_plane, const EventNode& event) {
  if (absl::optional<int64> event_type = event.GetEventVisitor().Type()) {
    return *event_type;
  } else if (absl::optional<int64> kernel_event_type =
                 GetKernelEventType(is_host_plane, event)) {
    // KernelLaunch and KernelExecute event types are not supported by
    // XPlaneVisitor and should be checked separately.
    // TODO(b/148346217): Make XPlaneVisitor support KernelLaunch and
    // KernelExecute event types.
    return *kernel_event_type;
  } else {
    absl::string_view name = event.GetEventVisitor().Name();
    // Legacy event names are appended with arguments; match on the prefix.
    if (absl::StartsWith(name, "BatchingSessionRun")) {
      return HostEventType::kBatchingSessionRun;
    } else if (absl::StartsWith(name, "ProcessBatch")) {
      return HostEventType::kProcessBatch;
    }
    return HostEventType::kUnknownHostEventType;
  }
}

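// Registers the event as a producer and/or consumer in `context_groups`,
// keyed by context type and id, for later cross-thread connection.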
void SetContextGroup(EventNode* event, ContextGroupMap* context_groups) {
  auto producer = event->GetProducerContext();
  if (producer.has_value()) {
    ((*context_groups)[producer->type][producer->id])
        .producers.push_back(event);
  }
  auto consumer = event->GetConsumerContext();
  if (consumer.has_value()) {
    ((*context_groups)[consumer->type][consumer->id])
        .consumers.push_back(event);
  }
}

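// Connects each producer event to every consumer event in the same context
// group by adding parent-child edges.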
void ConnectContextGroups(const ContextGroupMap& context_groups) {
  for (auto& type_id_group : context_groups) {
    for (auto& id_group : type_id_group.second) {
      const ContextGroup& group = id_group.second;
      for (EventNode* parent : group.producers) {
        for (EventNode* child : group.consumers) {
          parent->AddChild(child);
        }
      }
    }
  }
}

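// Creates a virtual XEvent that carries only the given step_id and iter_num
// stats.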
std::unique_ptr<XEvent> CreateVirtualEvent(const XStat& step_id_stat,
                                           const XStat& iter_num_stat) {
  auto virtual_event = absl::make_unique<XEvent>();
  *virtual_event->add_stats() = step_id_stat;
  *virtual_event->add_stats() = iter_num_stat;
  return virtual_event;
}

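// Returns true if the event node has a FunctionRun event among its direct
// children.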
bool HasFunctionRun(EventNode* event_node) {
  for (EventNode* child : event_node->GetChildren()) {
    if (child->GetEventVisitor().Type() == HostEventType::kFunctionRun) {
      return true;
    }
  }
  return false;
}

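// Returns true if the event is one of the framework-generated root event
// types, as opposed to a user-defined root event.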
bool IsImplicitRootEvent(const XEventVisitor& event) {
  static const auto* const kImplicitRootEvents = new absl::flat_hash_set<int64>{
      HostEventType::kFunctionRun, HostEventType::kSessionRun,
      HostEventType::kRunGraph, HostEventType::kExecutorStateProcess};
  return event.Type().has_value() &&
         kImplicitRootEvents->contains(*event.Type());
}

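// Propagates `group_id` from the given root event through its subtree and
// records the group in `group_metadata_map`, optionally naming the step.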
void ProcessRootEvent(int64 group_id, bool set_step_name, EventNode* root_event,
                      GroupMetadataMap* group_metadata_map) {
  root_event->PropagateGroupId(group_id, group_metadata_map);
  if (!set_step_name) {
    // Step names are not necessary for inference profiles, but the group_id
    // is still added to group_metadata_map so that the number of groups can
    // be counted.
    group_metadata_map->emplace(group_id, GroupMetadata());
    return;
  }
  std::string group_name = root_event->GetGroupName();
  // TODO(jihochoi): change event name instead.
  if (!IsImplicitRootEvent(root_event->GetEventVisitor())) {
    // Add the `step_name` stat for user-defined root events only. When an
    // XEvent is converted to a trace event, the trace event name is set to the
    // `step_name` stat's value if present.
    root_event->AddStepName(group_name);
  }
  (*group_metadata_map)[group_id].name = std::move(group_name);
}

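// Returns true if the event node runs under a tf.data captured function.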
bool IsTfDataEvent(const EventNode& event_node) {
  return event_node.FindParent(HostEventType::kTfDataCapturedFunctionRun) ||
         event_node.FindParent(
             HostEventType::kTfDataCapturedFunctionRunAsync) ||
         event_node.FindParent(
             HostEventType::kTfDataCapturedFunctionRunInstantiated) ||
         event_node.FindParent(
             HostEventType::kTfDataCapturedFunctionRunWithBorrowedArgs);
}

struct ContextTypeAndId {
  int type;
  uint64 id;
};

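// Infers a producer context for legacy events that predate the explicit
// producer/consumer stats, using their step id stats instead.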
absl::optional<ContextTypeAndId> GetLegacyProducerContext(
    const XEventVisitor& event) {
  absl::optional<ContextTypeAndId> type_and_id;
  absl::optional<int64> event_type = event.Type();
  if (event_type.has_value()) {
    switch (*event_type) {
      case HostEventType::kTraceContext:
      case HostEventType::kFunctionRun:
      case HostEventType::kSessionRun:
      case HostEventType::kRunGraph: {
        absl::optional<XStatVisitor> stat = event.GetStat(StatType::kStepId);
        if (stat.has_value()) {
          type_and_id = {static_cast<int>(ContextType::kTfExecutor),
                         static_cast<uint64>(stat->IntValue())};
        }
        break;
      }
      case HostEventType::kCallOp:
      case HostEventType::kNumericalGradientOpEvalRight:
      case HostEventType::kNumericalGradientOpEvalLeft:
      case HostEventType::kSymbolicGradientOp:
      case HostEventType::kRemoteCallOp:
      case HostEventType::kIfOp:
      case HostEventType::kCaseOp:
      case HostEventType::kPartitionedCallOp: {
        // TODO(b/154510598): Fix handling of the loop ops.
        // case HostEventType::kWhileOpEvalCond:
        // case HostEventType::kWhileOpStartBody:
        // case HostEventType::kForOp:
        // case HostEventType::kParallelForOp:
        // case HostEventType::kForeverOp:
        absl::optional<XStatVisitor> stat =
            event.GetStat(StatType::kFunctionStepId);
        if (stat.has_value()) {
          type_and_id = {static_cast<int>(ContextType::kTfExecutor),
                         static_cast<uint64>(stat->IntValue())};
        }
        break;
      }
      default:
        break;
    }
  }
  return type_and_id;
}

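// Infers a consumer context for legacy events, mirroring
// GetLegacyProducerContext above.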
absl::optional<ContextTypeAndId> GetLegacyConsumerContext(
    const XEventVisitor& event) {
  absl::optional<ContextTypeAndId> type_and_id;
  absl::optional<int64> event_type = event.Type();
  if (event_type.has_value()) {
    switch (*event_type) {
      case HostEventType::kExecutorStateProcess:
      case HostEventType::kExecutorDoneCallback:
      case HostEventType::kRunGraphDone: {
        absl::optional<XStatVisitor> stat = event.GetStat(StatType::kStepId);
        if (stat.has_value()) {
          type_and_id = {static_cast<int>(ContextType::kTfExecutor),
                         static_cast<uint64>(stat->IntValue())};
        }
        break;
      }
      default:
        break;
    }
  }
  return type_and_id;
}

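// Returns true if the event is one of the legacy event types that are treated
// as root events even without an explicit is_root stat.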
bool IsLegacyRootEvent(const XEventVisitor& event) {
  static const auto* const kRootEvents = new absl::flat_hash_set<int64>{
      HostEventType::kTraceContext, HostEventType::kFunctionRun,
      HostEventType::kSessionRun, HostEventType::kRunGraph};
  return event.Type().has_value() && kRootEvents->contains(*event.Type());
}

using Comparator = std::function<bool(const EventNode*)>;

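// Breadth-first searches the ancestors of `node` (including `node` itself if
// `include_self` is set) and returns the first one matching `comparator`, or
// nullptr if none matches.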
const EventNode* FindParentWithComparator(const Comparator& comparator,
                                          const EventNode* node,
                                          bool include_self) {
  std::queue<const EventNode*> nodes;
  absl::flat_hash_set<const EventNode*> seen = {node};
  if (include_self) {
    nodes.push(node);
  } else {
    for (const EventNode* parent : node->GetParents()) {
      nodes.push(parent);
      seen.insert(parent);
    }
  }
  while (!nodes.empty()) {
    const EventNode* node = nodes.front();
    nodes.pop();
    if (comparator(node)) return node;
    for (const EventNode* parent : node->GetParents()) {
      if (seen.contains(parent)) continue;
      nodes.push(parent);
      seen.insert(parent);
    }
  }
  return nullptr;
}

// Returns true if none of the event's ancestors is a root event.
bool IsTopRoot(const EventNode* event) {
  // If it is already grouped, it is not a top root.
  if (event->GetGroupId().has_value()) return false;
  const EventNode* root_parent = FindParentWithComparator(
      [](const EventNode* node) { return node->IsRoot(); }, event,
      /*include_self=*/false);
  return root_parent == nullptr;
}

void SortEventList(EventList* event_list) {
  absl::c_sort(*event_list, [](const EventNode* e1, const EventNode* e2) {
    return e1->GetEventVisitor().TimestampPs() <
           e2->GetEventVisitor().TimestampPs();
  });
}

// Returns true if the profile contains JAX-related events.
bool HasJaxEvent(const EventNodeMap& event_node_map) {
  return event_node_map.contains(HostEventType::kExecuteOnLocalDevices);
}

bool IsIteratorEventType(absl::optional<int64> event_type) {
  return event_type == HostEventType::kIterator ||
         event_type == HostEventType::kDeviceInputPipelineSecondIterator;
}

}  // namespace

// Returns true if TF's loop ops exist in the given XSpace's metadata.
bool CheckLoopOp(const XSpace& space) {
  for (const XPlane& plane : space.planes()) {
    for (const auto& event_metadata : plane.event_metadata()) {
      absl::optional<int64> event_type =
          FindHostEventType(event_metadata.second.name());
      if (!event_type.has_value()) continue;
      switch (*event_type) {
        case HostEventType::kWhileOpEvalCond:
        case HostEventType::kWhileOpStartBody:
        case HostEventType::kForOp:
        case HostEventType::kParallelForOp:
        case HostEventType::kForeverOp:
          return true;
        default:
          break;
      }
    }
  }
  return false;
}

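// Builds an EventNode for a raw XEvent, extracting its producer/consumer
// context and its root/async flags from the event's stats, with fallbacks for
// legacy traces that lack the explicit stats.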
EventNode::EventNode(const XPlaneVisitor* plane, XLine* raw_line,
                     XEvent* raw_event)
    : plane_(plane),
      visitor_(plane, raw_line, raw_event),
      raw_line_(raw_line),
      raw_event_(raw_event) {
  absl::optional<int> producer_type;
  absl::optional<uint64> producer_id;
  absl::optional<int> consumer_type;
  absl::optional<uint64> consumer_id;

  visitor_.ForEachStat([&](const XStatVisitor& stat) {
    if (!stat.Type().has_value()) return;
    switch (*stat.Type()) {
      case StatType::kProducerType:
        producer_type = stat.IntValue();
        break;
      case StatType::kProducerId:
        producer_id = stat.IntOrUintValue();
        break;
      case StatType::kConsumerType:
        consumer_type = stat.IntValue();
        break;
      case StatType::kConsumerId:
        consumer_id = stat.IntOrUintValue();
        break;
      case StatType::kIsRoot:
        is_root_ = stat.IntValue();
        break;
      case StatType::kIsAsync:
        is_async_ = stat.IntValue();
        break;
      default:
        break;
    }
  });

  // Support legacy traces.
  if (!producer_type.has_value() || !producer_id.has_value()) {
    if (auto producer_context = GetLegacyProducerContext(visitor_)) {
      producer_type = producer_context->type;
      producer_id = producer_context->id;
    }
  }
  if (!consumer_type.has_value() || !consumer_id.has_value()) {
    if (auto consumer_context = GetLegacyConsumerContext(visitor_)) {
      consumer_type = consumer_context->type;
      consumer_id = consumer_context->id;
    }
  }
  is_root_ = is_root_ || IsLegacyRootEvent(visitor_);

  if (producer_type.has_value() && producer_id.has_value()) {
    producer_context_ = {*producer_type, *producer_id};
  }
  if (consumer_type.has_value() && consumer_id.has_value()) {
    consumer_context_ = {*consumer_type, *consumer_id};
  }
}

EventNode::EventNode(const EventNode& event_node)
    : EventNode(event_node.plane_, event_node.raw_line_,
                event_node.raw_event_) {}

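// Breadth-first searches this node and then its ancestors for the first
// occurrence of the given stat type.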
absl::optional<XStatVisitor> EventNode::GetContextStat(int64 stat_type) const {
  std::queue<const EventNode*> nodes;
  absl::flat_hash_set<const EventNode*> seen = {this};
  nodes.push(this);
  while (!nodes.empty()) {
    const EventNode* node = nodes.front();
    nodes.pop();
    if (absl::optional<XStatVisitor> stat = node->visitor_.GetStat(stat_type)) {
      return stat;
    }
    for (const EventNode* parent : node->GetParents()) {
      if (seen.contains(parent)) continue;
      nodes.push(parent);
      seen.insert(parent);
    }
  }
  return absl::nullopt;
}

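// Derives a human-readable group name from the graph type (or the event name)
// and the iteration or step number found in the event's context.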
std::string EventNode::GetGroupName() const {
  std::string name;
  if (absl::optional<XStatVisitor> stat =
          GetContextStat(StatType::kGraphType)) {
    absl::StrAppend(&name, stat->StrOrRefValue(), " ");
  } else if (!(IsImplicitRootEvent(visitor_))) {
    absl::StrAppend(&name, GetEventVisitor().Name(), " ");
  }
  int64 step_num = group_id_.value_or(0);
  if (absl::optional<XStatVisitor> stat = GetContextStat(StatType::kIterNum)) {
    step_num = stat->IntValue();
  } else if (absl::optional<XStatVisitor> stat =
                 GetContextStat(StatType::kStepNum)) {
    step_num = stat->IntValue();
  }
  absl::StrAppend(&name, step_num);
  return name;
}

XStat* EventNode::FindOrAddStatByType(int64 stat_type) {
  const XStatMetadata* stat_metadata = plane_->GetStatMetadataByType(stat_type);
  DCHECK(stat_metadata != nullptr);
  return FindOrAddMutableStat(*stat_metadata, raw_event_);
}

void EventNode::SetGroupId(int64 group_id) {
  group_id_ = group_id;
  FindOrAddStatByType(StatType::kGroupId)->set_int64_value(group_id);
}

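// Assigns `group_id` to this node and, breadth-first, to its ungrouped
// descendants. When a descendant already belongs to a different group, a
// parent/child relation between the two groups is recorded instead of
// recursing further.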
void EventNode::PropagateGroupId(int64 group_id,
                                 GroupMetadataMap* group_metadata_map) {
  std::queue<EventNode*> nodes;
  absl::flat_hash_set<EventNode*> seen = {this};
  nodes.push(this);
  while (!nodes.empty()) {
    EventNode* node = nodes.front();
    nodes.pop();
    absl::optional<int64> node_group_id = node->GetGroupId();
    if (node_group_id.has_value()) {
      if (*node_group_id != group_id) {
        (*group_metadata_map)[group_id].children.insert(*node_group_id);
        (*group_metadata_map)[*node_group_id].parents.insert(group_id);
      }
    } else {
      node->SetGroupId(group_id);
      for (EventNode* child : node->GetChildren()) {
        if (seen.contains(child)) continue;
        nodes.push(child);
        seen.insert(child);
      }
    }
  }
}

void EventNode::AddStepName(absl::string_view step_name) {
  FindOrAddStatByType(StatType::kStepName)
      ->set_str_value(step_name.data(), step_name.size());
}

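// Records this group's id together with the ids of its parent and child
// groups as a URL query fragment, e.g. "?selected_group_ids=1,2,3".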
void EventNode::AddSelectedGroupIds(
    const GroupMetadataMap& group_metadata_map) {
  const auto& group_metadata = group_metadata_map.at(*group_id_);
  std::vector<int64> group_ids;
  group_ids.reserve(1 + group_metadata.parents.size() +
                    group_metadata.children.size());
  group_ids.push_back(*group_id_);
  group_ids.insert(group_ids.end(), group_metadata.parents.begin(),
                   group_metadata.parents.end());
  group_ids.insert(group_ids.end(), group_metadata.children.begin(),
                   group_metadata.children.end());
  FindOrAddStatByType(StatType::kSelectedGroupIds)
      ->set_str_value(
          absl::StrCat("?selected_group_ids=", absl::StrJoin(group_ids, ",")));
}

void EventNode::SetIsEager(bool is_eager) {
  FindOrAddStatByType(StatType::kIsEager)->set_int64_value(is_eager ? 1 : 0);
}

bool EventNode::IsEager() {
  // An op is eagerly executed if its trace context includes the
  // EagerKernelExecute event (which may execute an op eagerly or through the
  // TF executor) but not the TF executor event.
  return FindParent(HostEventType::kExecutorStateProcess) == nullptr &&
         FindParent(HostEventType::kEagerKernelExecute) != nullptr;
}

const EventNode* EventNode::FindParent(int64 event_type) const {
  return FindParentWithComparator(
      [event_type](const EventNode* node) {
        return node->GetEventVisitor().Type() == event_type;
      },
      this, /*include_self=*/true);
}

bool EventNode::StartsBefore(const EventNode& other) const {
  return GetEventVisitor().TimestampPs() <=
         other.GetEventVisitor().TimestampPs();
}

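// Creates an EventNode for every event in the plane, establishes intra-thread
// nesting (an event whose timespan includes another's becomes its parent),
// and registers producer/consumer context groups and root events.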
void EventForest::ConnectIntraThread(XPlane* plane, XPlaneVisitor* visitor,
                                     ContextGroupMap* context_groups) {
  // TODO(b/149095099): avoid string comparison.
  bool is_host_plane = (visitor->Name() == kHostThreadsPlaneName);
  for (auto& line : *plane->mutable_lines()) {
    std::vector<EventNode*> parent_nodes;
    for (auto& event : *line.mutable_events()) {
      auto cur_node = absl::make_unique<EventNode>(visitor, &line, &event);
      // Update `context_groups` for `ConnectInterThread`.
      SetContextGroup(cur_node.get(), context_groups);
      // Update `root_events_` for `CreateEventGroup`.
      if (cur_node->IsRoot()) root_events_.push_back(cur_node.get());
      // Async events are ignored when processing the nesting relationship.
      if (cur_node->IsAsync()) continue;
      while (!parent_nodes.empty()) {
        EventNode* parent_node = parent_nodes.back();
        if (parent_node->GetEventVisitor().GetTimespan().Includes(
                cur_node->GetEventVisitor().GetTimespan())) {
          parent_node->AddChild(cur_node.get());
          break;
        } else {
          parent_nodes.pop_back();
        }
      }
      parent_nodes.push_back(cur_node.get());
      // event_node_map_ keeps cur_node alive.
      event_node_map_[GetEventType(is_host_plane, *cur_node)].push_back(
          std::move(cur_node));
    }
  }
}

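// For each InterThreadConnectInfo, indexes parent events by the values of the
// parent stat types and then connects every child event whose child stat
// values match, e.g. linking a KernelLaunch to its KernelExecute via
// correlation_id.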
void EventForest::ConnectInterThread(
    const std::vector<InterThreadConnectInfo>& connect_info_list) {
  for (const auto& connect_info : connect_info_list) {
    absl::flat_hash_map<std::vector<uint64>, EventNode*> connect_map;
    const std::vector<int64>& parent_stat_types =
        connect_info.parent_stat_types;
    const std::vector<int64>* child_stat_types = &connect_info.child_stat_types;
    if (child_stat_types->empty()) {
      child_stat_types = &parent_stat_types;
    }
    if (auto parent_event_node_list =
            gtl::FindOrNull(event_node_map_, connect_info.parent_event_type)) {
      for (const auto& parent_event_node : *parent_event_node_list) {
        std::vector<uint64> stats;
        for (auto stat_type : parent_stat_types) {
          absl::optional<XStatVisitor> stat =
              parent_event_node->GetContextStat(stat_type);
          if (!stat) break;
          stats.push_back(stat->IntOrUintValue());
        }
        if (stats.size() == parent_stat_types.size()) {
          connect_map[stats] = parent_event_node.get();
        }
      }
    }
    if (auto child_event_node_list =
            gtl::FindOrNull(event_node_map_, connect_info.child_event_type)) {
      for (const auto& child_event_node : *child_event_node_list) {
        std::vector<uint64> stats;
        for (auto stat_type : *child_stat_types) {
          absl::optional<XStatVisitor> stat =
              child_event_node->GetContextStat(stat_type);
          if (!stat) break;
          stats.push_back(stat->IntOrUintValue());
        }
        if (stats.size() == child_stat_types->size()) {
          if (auto parent_event_node = gtl::FindPtrOrNull(connect_map, stats)) {
            parent_event_node->AddChild(child_event_node.get());
          }
        }
      }
    }
  }
}

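// Marks all events of the user-specified event types as root events.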
void EventForest::ProcessUserDefinedRootEvents(
    const std::vector<int64 /*EventType*/>& user_defined_root_event_types) {
  for (int64 user_defined_root_event_type : user_defined_root_event_types) {
    if (auto root_events =
            gtl::FindOrNull(event_node_map_, user_defined_root_event_type)) {
      for (const auto& root_event : *root_events) {
        root_event->SetIsRoot(true);
        root_events_.push_back(root_event.get());
      }
    }
  }
}

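// Creates the event groups: one per batch and per request for inference
// profiles, one per TF loop iteration when loops are present, and otherwise
// one per top-level root event.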
void EventForest::CreateEventGroups() {
  // Handle inference batching profiles.
  if (event_node_map_.contains(HostEventType::kProcessBatch)) {
    // Assign group_id per batch.
    for (const auto& process_batch_node :
         event_node_map_[HostEventType::kProcessBatch]) {
      ProcessRootEvent(next_group_id_++, /*set_step_name=*/false,
                       process_batch_node.get(), &group_metadata_map_);
    }
    HostEventType request_event_type =
        event_node_map_.contains(HostEventType::kBatchingSessionRun)
            ? HostEventType::kBatchingSessionRun
            : HostEventType::kSessionRun;
    if (auto request_events =
            gtl::FindOrNull(event_node_map_, request_event_type)) {
      // Assign group_id per request.
      for (const auto& request_event : *request_events) {
        ProcessRootEvent(next_group_id_++, /*set_step_name=*/false,
                         request_event.get(), &group_metadata_map_);
        // Also, set a helper stat for selected_group_ids.
        request_event->AddSelectedGroupIds(group_metadata_map_);
      }
    }
    // Set a helper stat for selected_group_ids per batch.
    for (const auto& process_batch_node :
         event_node_map_[HostEventType::kProcessBatch]) {
      process_batch_node->AddSelectedGroupIds(group_metadata_map_);
    }
    return;
  }
  // Create a group for each TF loop iteration in non-JAX profiles.
  if (!HasJaxEvent(event_node_map_) && !tf_loop_root_events_.empty()) {
    for (EventNode* root_event : tf_loop_root_events_) {
      ProcessRootEvent(next_group_id_++, /*set_step_name=*/true, root_event,
                       &group_metadata_map_);
    }
    return;
  }
  SortEventList(&root_events_);
  // Create a group for each top root event while ignoring TF's legacy root
  // events for JAX profiles.
  for (EventNode* root_event : root_events_) {
    if (IsTopRoot(root_event) &&
        (!HasJaxEvent(event_node_map_) ||
         !IsLegacyRootEvent(root_event->GetEventVisitor()))) {
      ProcessRootEvent(next_group_id_++, /*set_step_name=*/true, root_event,
                       &group_metadata_map_);
    }
  }
}

void EventForest::MarkEagerlyExecutedGpuKernels() {
  auto kernel_execute_event_node_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kKernelExecute);
  if (!kernel_execute_event_node_list) return;
  for (auto& kernel_execute_event_node : *kernel_execute_event_node_list) {
    kernel_execute_event_node->SetIsEager(kernel_execute_event_node->IsEager());
  }
}

void EventForest::MarkEagerlyExecutedCpuTfOps() {
  auto tf_op_run_event_node_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kTfOpRun);
  if (!tf_op_run_event_node_list) return;
  for (auto& tf_op_run_event_node : *tf_op_run_event_node_list) {
    tf_op_run_event_node->SetIsEager(tf_op_run_event_node->IsEager());
  }
}

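// Identifies TF while-loop iterations from the TF executor events, grouping
// them by (step_id, iter_num), and registers the first event of each
// iteration as a root event for that iteration.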
void EventForest::ProcessTensorFlowLoop() {
  struct TensorFlowLoopIteration {
    EventNode* first_event = nullptr;
    std::vector<EventNode*> events;
  };
  using TensorFlowLoop = std::map<int64 /*iter_num*/, TensorFlowLoopIteration>;
  absl::flat_hash_map<int64 /*step_id*/, TensorFlowLoop> tf_loops;

  // Sort the TF executor events by TF function/session (step_id) and iter_num.
  auto executor_event_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kExecutorStateProcess);
  if (!executor_event_list) return;
  for (auto& executor_event : *executor_event_list) {
    if (IsTfDataEvent(*executor_event)) continue;
    absl::optional<XStatVisitor> step_id_stat =
        executor_event->GetContextStat(StatType::kStepId);
    absl::optional<XStatVisitor> iter_num_stat =
        executor_event->GetContextStat(StatType::kIterNum);
    if (!step_id_stat || !iter_num_stat) continue;
    int64 step_id = step_id_stat->IntValue();
    TensorFlowLoop& tf_loop = tf_loops[step_id];
    TensorFlowLoopIteration& iteration = tf_loop[iter_num_stat->IntValue()];
    if (!iteration.first_event ||
        executor_event->StartsBefore(*iteration.first_event)) {
      iteration.first_event = executor_event.get();
    }
    iteration.events.push_back(executor_event.get());
  }

  // Sort the TF loops by start time.
  std::map<int64 /*start_time*/, int64 /*step_id*/> sorted_tf_loops;
  for (const auto& step_id_and_tf_loop : tf_loops) {
    auto& iterations = step_id_and_tf_loop.second;
    // Filter out TF functions/sessions without loops.
    if (iterations.size() == 1 && iterations.count(0)) continue;
    int64 start_time = iterations.cbegin()
                           ->second.first_event->GetEventVisitor()
                           .TimestampPs();
    DCHECK_EQ(sorted_tf_loops.count(start_time), 0);
    sorted_tf_loops[start_time] = step_id_and_tf_loop.first;
  }

  // Register the first event of each iteration as a root event. Also, add the
  // other events of the iteration as children of the root event.
  bool next_group_id_updated = false;
  for (const auto& start_time_and_step_id : sorted_tf_loops) {
    TensorFlowLoop& tf_loop = tf_loops[start_time_and_step_id.second];
    for (auto& iter_num_and_iteration : tf_loop) {
      if (!next_group_id_updated) {
        // Set next_group_id_ to the first iter_num of the first TF loop. This
        // is necessary later when selecting the intersection of the steps from
        // multiple hosts.
        next_group_id_ = iter_num_and_iteration.first;
        next_group_id_updated = true;
      }
      TensorFlowLoopIteration& iteration = iter_num_and_iteration.second;
      EventNode* root_event = iteration.first_event;
      tf_loop_root_events_.push_back(root_event);
      for (EventNode* event : iteration.events) {
        if (event == root_event) continue;
        root_event->AddChild(event);
      }
    }
  }
}

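// Registers each EagerKernelExecute event that runs a FunctionRun as a root
// event, and attaches the non-function eager ops that follow it as children
// of that root.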
void EventForest::ProcessWorker() {
  auto eager_kernel_execute_event_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kEagerKernelExecute);
  if (!eager_kernel_execute_event_list) return;
  // The last EagerKernelExecute with a FunctionRun child.
  EventNode* root_event = nullptr;
  for (auto& eager_kernel_execute_event : *eager_kernel_execute_event_list) {
    if (HasFunctionRun(eager_kernel_execute_event.get())) {
      // A function op becomes a new root.
      root_event = eager_kernel_execute_event.get();
      root_event->SetIsRoot(true);
      root_events_.push_back(root_event);
    } else if (root_event) {
      // Add subsequent non-function eager ops as children of the latest root.
      root_event->AddChild(eager_kernel_execute_event.get());
    }
  }
}

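// Copies the model id stat of each grouped SessionRun event into the
// corresponding group's metadata.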
void EventForest::ProcessModelIds() {
  auto session_run_event_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kSessionRun);
  if (!session_run_event_list) return;
  for (const auto& session_run_event : *session_run_event_list) {
    auto group_id = session_run_event->GetGroupId();
    if (!group_id.has_value()) continue;
    absl::optional<XStatVisitor> model_id =
        session_run_event->GetEventVisitor().GetStat(StatType::kModelId);
    if (!model_id.has_value()) continue;
    group_metadata_map_[*group_id].model_id = model_id->ToString();
  }
}

void EventForest::AddPlane(
    const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
    XPlane* plane) {
  CreateStatMetadata(plane);
  planes_.push_back({plane, visitor_factory(plane)});
}

void EventForest::AddSpace(
    const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
    XSpace* space) {
  for (XPlane& plane : *space->mutable_planes()) {
    AddPlane(visitor_factory, &plane);
  }
}

void EventForest::AddPlanes(
    const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
    const std::vector<XPlane*>& planes) {
  for (XPlane* plane : planes) {
    AddPlane(visitor_factory, plane);
  }
}

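// Connects events within each plane (nesting), across threads (matching
// context stats), and across producer/consumer context groups.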
void EventForest::ConnectEvents(
    const std::vector<InterThreadConnectInfo>& connect_info_list) {
  ContextGroupMap context_groups;
  for (auto& plane_visitor : planes_) {
    ConnectIntraThread(plane_visitor.first, &plane_visitor.second,
                       &context_groups);
  }
  ConnectInterThread(connect_info_list);
  ConnectContextGroups(context_groups);
}

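// Connects tf.data consumer iterator events to the producer iterator events
// that produced the same element, matching on (iterator_id, element_id).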
void EventForest::ConnectTfDataEvents() {
  absl::flat_hash_map<std::pair<int64 /*iterator_id*/, int64 /*element_id*/>,
                      std::vector<EventNode*>>
      produce_iterator_map;
  uint64 num_producers = 0;
  for (HostEventType event_type :
       {HostEventType::kPrefetchProduce,
        HostEventType::kParallelInterleaveProduce,
        HostEventType::kParallelMapProduce, HostEventType::kMapAndBatchProduce,
        HostEventType::kParseExampleProduce}) {
    auto produce_event_list = gtl::FindOrNull(event_node_map_, event_type);
    if (!produce_event_list) continue;
    VLOG(1) << produce_event_list->size() << " "
            << GetHostEventTypeStr(event_type) << " events found.";
    for (auto& produce_event : *produce_event_list) {
      absl::optional<XStatVisitor> element_id =
          produce_event->GetEventVisitor().GetStat(StatType::kElementId);
      if (!element_id.has_value()) continue;
      for (EventNode* produce_iterator : produce_event->GetChildren()) {
        if (IsIteratorEventType(produce_iterator->GetEventVisitor().Type())) {
          absl::optional<XStatVisitor> iterator_id =
              produce_iterator->GetEventVisitor().GetStat(StatType::kParentId);
          if (!iterator_id.has_value()) break;
          produce_iterator_map[{iterator_id->IntValue(),
                                element_id->IntValue()}]
              .push_back(produce_iterator);
          ++num_producers;
          break;
        }
      }
    }
  }
  VLOG(1) << num_producers << " producer iterators found.";
  uint64 num_matched = 0;
  for (HostEventType event_type :
       {HostEventType::kPrefetchConsume,
        HostEventType::kParallelInterleaveConsume,
        HostEventType::kParallelMapConsume, HostEventType::kMapAndBatchConsume,
        HostEventType::kParseExampleConsume}) {
    auto consume_event_list = gtl::FindOrNull(event_node_map_, event_type);
    if (!consume_event_list) continue;
    VLOG(1) << consume_event_list->size() << " "
            << GetHostEventTypeStr(event_type) << " events found.";
    for (auto& consume_event : *consume_event_list) {
      absl::optional<XStatVisitor> element_id =
          consume_event->GetEventVisitor().GetStat(StatType::kElementId);
      if (!element_id.has_value()) continue;
      if (consume_event->GetParents().empty()) continue;
      // A consume_event is nested by its consume_iterator and has no other
      // parents.
      EventNode* consume_iterator = consume_event->GetParents().at(0);
      if (!consume_iterator ||
          !IsIteratorEventType(consume_iterator->GetEventVisitor().Type())) {
        continue;
      }
      absl::optional<XStatVisitor> iterator_id =
          consume_iterator->GetEventVisitor().GetStat(StatType::kStepId);
      if (!iterator_id.has_value()) continue;
      if (auto produce_iterators = gtl::FindOrNull(
              produce_iterator_map, std::make_pair(iterator_id->IntValue(),
                                                   element_id->IntValue()))) {
        for (EventNode* produce_iterator : *produce_iterators) {
          consume_iterator->AddChild(produce_iterator);
          ++num_matched;
        }
      }
    }
  }
  VLOG(1) << num_matched << " consumer iterators matched.";
}

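// Runs the full grouping pipeline over the connected events.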
void EventForest::GroupEvents(
    const std::vector<int64>& user_defined_root_event_types) {
  ProcessTensorFlowLoop();
  ProcessWorker();
  ProcessUserDefinedRootEvents(user_defined_root_event_types);
  CreateEventGroups();
  MarkEagerlyExecutedGpuKernels();
  MarkEagerlyExecutedCpuTfOps();
  ProcessModelIds();
}

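// Returns the default inter-thread connections: TF executor events to
// iterator ops (matched on step_id and iter_num) and kernel launches to
// kernel executions (matched on correlation_id).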
std::vector<InterThreadConnectInfo> CreateInterThreadConnectInfoList() {
  std::vector<InterThreadConnectInfo> connect_info_list = {
      {HostEventType::kExecutorStateProcess,
       HostEventType::kIteratorGetNextOp,
       {StatType::kStepId, StatType::kIterNum}},
      {HostEventType::kExecutorStateProcess,
       HostEventType::kIteratorGetNextAsOptionalOp,
       {StatType::kStepId, StatType::kIterNum}},
      {HostEventType::kKernelLaunch,
       HostEventType::kKernelExecute,
       {StatType::kCorrelationId}}};
  return connect_info_list;
}

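// Groups the TF events in `space` in place, unless TF's loop ops are present
// (grouping them is not yet supported): adds all planes to the forest,
// connects related events, and assigns group ids and step names.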
void GroupTfEvents(XSpace* space, EventForest* event_forest) {
  if (CheckLoopOp(*space)) {
    // TODO(b/154510598): Support TF's loop ops.
    return;
  }
  std::vector<InterThreadConnectInfo> connect_info_list =
      CreateInterThreadConnectInfoList();
  event_forest->AddSpace(CreateTfXPlaneVisitor, space);
  event_forest->ConnectEvents(connect_info_list);
  event_forest->GroupEvents();
}

void GroupTfEvents(XSpace* space) {
  EventForest event_forest;
  GroupTfEvents(space, &event_forest);
}

}  // namespace profiler
}  // namespace tensorflow