/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/profiler/utils/group_events.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"

namespace tensorflow {
namespace profiler {
namespace {

// Creates stat metadata for the stats which may be added by grouping.
void CreateStatMetadata(XPlane* plane) {
  XPlaneBuilder builder(plane);
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId));
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kStepName));
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kIsEager));
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kSelectedGroupIds));
}

// Returns the event type if the event is a KernelLaunch or KernelExecute
// event.
absl::optional<int64> GetKernelEventType(bool is_host_plane,
                                         const EventNode& event) {
  if (event.GetEventVisitor().GetStat(StatType::kCorrelationId).has_value()) {
    return is_host_plane ? HostEventType::kKernelLaunch
                         : HostEventType::kKernelExecute;
  }
  return absl::nullopt;
}

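// Returns the type of the given event, resolving KernelLaunch/KernelExecute
// events that XPlaneVisitor cannot type; defaults to kUnknownHostEventType.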
int64 GetEventType(bool is_host_plane, const EventNode& event) {
  if (absl::optional<int64> event_type = event.GetEventVisitor().Type()) {
    return *event_type;
  } else if (absl::optional<int64> kernel_event_type =
                 GetKernelEventType(is_host_plane, event)) {
    // KernelLaunch and KernelExecute event types are not supported by
    // XPlaneVisitor and should be checked separately.
    // TODO(b/148346217): Make XPlaneVisitor support KernelLaunch and
    // KernelExecute event types.
    return *kernel_event_type;
  } else {
    return HostEventType::kUnknownHostEventType;
  }
}

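// Registers the event as a producer and/or consumer in the context group
// identified by its (context type, context id) pair.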
void SetContextGroup(EventNode* event, ContextGroupMap* context_groups) {
  auto producer = event->GetProducerContext();
  if (producer.has_value()) {
    ((*context_groups)[producer->type][producer->id])
        .producers.push_back(event);
  }
  auto consumer = event->GetConsumerContext();
  if (consumer.has_value()) {
    ((*context_groups)[consumer->type][consumer->id])
        .consumers.push_back(event);
  }
}

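// Within each context group, connects every producer event to every consumer
// event via parent/child links.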
void ConnectContextGroups(const ContextGroupMap& context_groups) {
  for (auto& type_id_group : context_groups) {
    for (auto& id_group : type_id_group.second) {
      const ContextGroup& group = id_group.second;
      for (EventNode* parent : group.producers) {
        for (EventNode* child : group.consumers) {
          parent->AddChild(child);
        }
      }
    }
  }
}

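// Creates a virtual XEvent that carries only the given step_id and iter_num
// stats.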
std::unique_ptr<XEvent> CreateVirtualEvent(const XStat& step_id_stat,
                                           const XStat& iter_num_stat) {
  auto virtual_event = absl::make_unique<XEvent>();
  *virtual_event->add_stats() = step_id_stat;
  *virtual_event->add_stats() = iter_num_stat;
  return virtual_event;
}

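// Returns true if the event has a FunctionRun child event.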
bool HasFunctionRun(EventNode* event_node) {
  for (EventNode* child : event_node->GetChildren()) {
    if (child->GetEventVisitor().Type() == HostEventType::kFunctionRun) {
      return true;
    }
  }
  return false;
}

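// Returns true if the event is an implicit root generated by the framework
// itself, as opposed to a user-defined root event.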
bool IsImplicitRootEvent(const XEventVisitor& event) {
  static const auto* const kImplicitRootEvents = new absl::flat_hash_set<int64>{
      HostEventType::kFunctionRun, HostEventType::kSessionRun,
      HostEventType::kRunGraph, HostEventType::kExecutorStateProcess};
  return event.Type().has_value() &&
         kImplicitRootEvents->contains(*event.Type());
}

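// Assigns group_id to the given root event, propagates it to all reachable
// events, and records the group name in group_metadata_map.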
void ProcessRootEvent(int64_t group_id, EventNode* root_event,
                      GroupMetadataMap* group_metadata_map) {
  root_event->PropagateGroupId(group_id, group_metadata_map);
  std::string group_name = root_event->GetGroupName();
  // TODO(jihochoi): change event name instead.
  if (!IsImplicitRootEvent(root_event->GetEventVisitor())) {
    // Add the `step_name` stat for the user-defined root events only. When an
    // XEvent is converted to a trace event, the trace event name is set to the
    // `step_name` stat's value if present.
    root_event->AddStepName(group_name);
  }
  (*group_metadata_map)[group_id].name = std::move(group_name);
}

struct ContextTypeAndId {
  int type;
  uint64 id;
};

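// Derives a producer context from legacy event conventions, for traces that
// predate the explicit producer/consumer context stats.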
absl::optional<ContextTypeAndId> GetLegacyProducerContext(
    const XEventVisitor& event) {
  absl::optional<ContextTypeAndId> type_and_id;
  absl::optional<int64> event_type = event.Type();
  if (event_type.has_value()) {
    switch (*event_type) {
      case HostEventType::kTraceContext:
      case HostEventType::kFunctionRun:
      case HostEventType::kSessionRun:
      case HostEventType::kRunGraph: {
        absl::optional<XStatVisitor> stat = event.GetStat(StatType::kStepId);
        if (stat.has_value()) {
          type_and_id = {static_cast<int>(ContextType::kTfExecutor),
                         static_cast<uint64>(stat->IntValue())};
        }
        break;
      }
      case HostEventType::kCallOp:
      case HostEventType::kNumericalGradientOpEvalRight:
      case HostEventType::kNumericalGradientOpEvalLeft:
      case HostEventType::kSymbolicGradientOp:
      case HostEventType::kRemoteCallOp:
      case HostEventType::kIfOp:
      case HostEventType::kCaseOp:
      case HostEventType::kPartitionedCallOp: {
        // TODO(b/154510598): Fix handling of the loop ops.
        // case HostEventType::kWhileOpEvalCond:
        // case HostEventType::kWhileOpStartBody:
        // case HostEventType::kForOp:
        // case HostEventType::kParallelForOp:
        // case HostEventType::kForeverOp:
        absl::optional<XStatVisitor> stat =
            event.GetStat(StatType::kFunctionStepId);
        if (stat.has_value()) {
          type_and_id = {static_cast<int>(ContextType::kTfExecutor),
                         static_cast<uint64>(stat->IntValue())};
        }
        break;
      }
      default:
        break;
    }
  }
  return type_and_id;
}

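// Derives a consumer context from legacy event conventions; the counterpart
// of GetLegacyProducerContext above.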
absl::optional<ContextTypeAndId> GetLegacyConsumerContext(
    const XEventVisitor& event) {
  absl::optional<ContextTypeAndId> type_and_id;
  absl::optional<int64> event_type = event.Type();
  if (event_type.has_value()) {
    switch (*event_type) {
      case HostEventType::kExecutorStateProcess:
      case HostEventType::kExecutorDoneCallback:
      case HostEventType::kRunGraphDone: {
        absl::optional<XStatVisitor> stat = event.GetStat(StatType::kStepId);
        if (stat.has_value()) {
          type_and_id = {static_cast<int>(ContextType::kTfExecutor),
                         static_cast<uint64>(stat->IntValue())};
        }
        break;
      }
      default:
        break;
    }
  }
  return type_and_id;
}

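// Returns true if the event is a root event under legacy trace conventions.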
bool IsLegacyRootEvent(const XEventVisitor& event) {
  static const auto* const kRootEvents = new absl::flat_hash_set<int64>{
      HostEventType::kTraceContext, HostEventType::kFunctionRun,
      HostEventType::kSessionRun, HostEventType::kRunGraph};
  return event.Type().has_value() && kRootEvents->contains(*event.Type());
}

using Comparator = std::function<bool(const EventNode*)>;

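// Breadth-first searches the ancestors of <node> (optionally including <node>
// itself) and returns the first one satisfying <comparator>, or nullptr if
// none matches.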
const EventNode* FindParentWithComparator(const Comparator& comparator,
                                          const EventNode* node,
                                          bool include_self) {
  std::queue<const EventNode*> nodes;
  absl::flat_hash_set<const EventNode*> seen = {node};
  if (include_self) {
    nodes.push(node);
  } else {
    for (const EventNode* parent : node->GetParents()) {
      nodes.push(parent);
      seen.insert(parent);
    }
  }
  while (!nodes.empty()) {
    const EventNode* node = nodes.front();
    nodes.pop();
    if (comparator(node)) return node;
    for (const EventNode* parent : node->GetParents()) {
      if (seen.contains(parent)) continue;
      nodes.push(parent);
      seen.insert(parent);
    }
  }
  return nullptr;
}

// Returns true if the event node map contains JAX-related events.
bool HasJaxEvent(const EventNodeMap& event_node_map) {
  return event_node_map.contains(HostEventType::kExecuteOnLocalDevices);
}

bool IsIteratorEventType(absl::optional<int64> event_type) {
  return event_type == HostEventType::kIterator ||
         event_type == HostEventType::kDeviceInputPipelineSecondIterator;
}

}  // namespace

// Returns true if TF's loop ops exist in the given XSpace's metadata.
bool CheckLoopOp(const XSpace& space) {
  for (const XPlane& plane : space.planes()) {
    for (const auto& event_metadata : plane.event_metadata()) {
      absl::optional<int64> event_type =
          FindHostEventType(event_metadata.second.name());
      if (!event_type.has_value()) continue;
      switch (*event_type) {
        case HostEventType::kWhileOpEvalCond:
        case HostEventType::kWhileOpStartBody:
        case HostEventType::kForOp:
        case HostEventType::kParallelForOp:
        case HostEventType::kForeverOp:
          return true;
        default:
          break;
      }
    }
  }
  return false;
}

EventNode::EventNode(const XPlaneVisitor* plane, XLine* raw_line,
                     XEvent* raw_event)
    : plane_(plane),
      visitor_(plane, raw_line, raw_event),
      raw_line_(raw_line),
      raw_event_(raw_event) {
  absl::optional<int> producer_type;
  absl::optional<uint64> producer_id;
  absl::optional<int> consumer_type;
  absl::optional<uint64> consumer_id;

  visitor_.ForEachStat([&](const XStatVisitor& stat) {
    if (!stat.Type().has_value()) return;
    switch (*stat.Type()) {
      case StatType::kProducerType:
        producer_type = stat.IntValue();
        break;
      case StatType::kProducerId:
        producer_id = stat.IntOrUintValue();
        break;
      case StatType::kConsumerType:
        consumer_type = stat.IntValue();
        break;
      case StatType::kConsumerId:
        consumer_id = stat.IntOrUintValue();
        break;
      case StatType::kIsRoot:
        root_level_ = stat.IntValue();
        break;
      case StatType::kIsAsync:
        is_async_ = stat.IntValue();
        break;
      default:
        break;
    }
  });

  // Support legacy traces.
  if (!producer_type.has_value() || !producer_id.has_value()) {
    if (auto producer_context = GetLegacyProducerContext(visitor_)) {
      producer_type = producer_context->type;
      producer_id = producer_context->id;
    }
  }
  if (!consumer_type.has_value() || !consumer_id.has_value()) {
    if (auto consumer_context = GetLegacyConsumerContext(visitor_)) {
      consumer_type = consumer_context->type;
      consumer_id = consumer_context->id;
    }
  }
  root_level_ = root_level_ ? root_level_ : IsLegacyRootEvent(visitor_);

  if (producer_type.has_value() && producer_id.has_value()) {
    producer_context_ = {*producer_type, *producer_id};
  }
  if (consumer_type.has_value() && consumer_id.has_value()) {
    consumer_context_ = {*consumer_type, *consumer_id};
  }
}

EventNode::EventNode(const EventNode& event_node)
    : EventNode(event_node.plane_, event_node.raw_line_,
                event_node.raw_event_) {}

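// Breadth-first searches this event and its ancestors and returns the first
// stat of the given type found.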
absl::optional<XStatVisitor> EventNode::GetContextStat(
    int64_t stat_type) const {
  std::queue<const EventNode*> nodes;
  absl::flat_hash_set<const EventNode*> seen = {this};
  nodes.push(this);
  while (!nodes.empty()) {
    const EventNode* node = nodes.front();
    nodes.pop();
    if (absl::optional<XStatVisitor> stat = node->visitor_.GetStat(stat_type)) {
      return stat;
    }
    for (const EventNode* parent : node->GetParents()) {
      if (seen.contains(parent)) continue;
      nodes.push(parent);
      seen.insert(parent);
    }
  }
  return absl::nullopt;
}

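// Builds a human-readable group name of the form
// "<graph type or event name> <step number>".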
std::string EventNode::GetGroupName() const {
  std::string name;
  if (absl::optional<XStatVisitor> stat =
          GetContextStat(StatType::kGraphType)) {
    absl::StrAppend(&name, stat->StrOrRefValue(), " ");
  } else if (!(IsImplicitRootEvent(visitor_))) {
    absl::StrAppend(&name, GetEventVisitor().Name(), " ");
  }
  int64_t step_num = group_id_.value_or(0);
  if (absl::optional<XStatVisitor> stat = GetContextStat(StatType::kIterNum)) {
    step_num = stat->IntValue();
  } else if (absl::optional<XStatVisitor> stat =
                 GetContextStat(StatType::kStepNum)) {
    step_num = stat->IntValue();
  }
  absl::StrAppend(&name, step_num);
  return name;
}

XStat* EventNode::FindOrAddStatByType(int64_t stat_type) {
  const XStatMetadata* stat_metadata = plane_->GetStatMetadataByType(stat_type);
  DCHECK(stat_metadata != nullptr);
  return FindOrAddMutableStat(*stat_metadata, raw_event_);
}

void EventNode::SetGroupId(int64_t group_id) {
  group_id_ = group_id;
  FindOrAddStatByType(StatType::kGroupId)->set_int64_value(group_id);
}

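// Breadth-first traversal over children: assigns group_id to ungrouped
// events; when an already grouped event is reached, records a parent/child
// relationship between the two groups instead of descending further.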
void EventNode::PropagateGroupId(int64_t group_id,
                                 GroupMetadataMap* group_metadata_map) {
  std::queue<EventNode*> nodes;
  absl::flat_hash_set<EventNode*> seen = {this};
  nodes.push(this);
  while (!nodes.empty()) {
    EventNode* node = nodes.front();
    nodes.pop();
    absl::optional<int64> node_group_id = node->GetGroupId();
    if (node_group_id.has_value()) {
      if (*node_group_id != group_id) {
        (*group_metadata_map)[group_id].children.insert(*node_group_id);
        (*group_metadata_map)[*node_group_id].parents.insert(group_id);
      }
    } else {
      node->SetGroupId(group_id);
      for (EventNode* child : node->GetChildren()) {
        if (seen.contains(child)) continue;
        nodes.push(child);
        seen.insert(child);
      }
    }
  }
}

void EventNode::AddStepName(absl::string_view step_name) {
  FindOrAddStatByType(StatType::kStepName)
      ->set_str_value(step_name.data(), step_name.size());
}

void EventNode::AddSelectedGroupIds(
    const GroupMetadataMap& group_metadata_map) {
  const auto& group_metadata = group_metadata_map.at(*group_id_);
  std::vector<int64> group_ids;
  group_ids.reserve(1 + group_metadata.parents.size() +
                    group_metadata.children.size());
  group_ids.push_back(*group_id_);
  group_ids.insert(group_ids.end(), group_metadata.parents.begin(),
                   group_metadata.parents.end());
  group_ids.insert(group_ids.end(), group_metadata.children.begin(),
                   group_metadata.children.end());
  FindOrAddStatByType(StatType::kSelectedGroupIds)
      ->set_str_value(
          absl::StrCat("?selected_group_ids=", absl::StrJoin(group_ids, ",")));
}

void EventNode::SetIsEager(bool is_eager) {
  FindOrAddStatByType(StatType::kIsEager)->set_int64_value(is_eager ? 1 : 0);
}

bool EventNode::IsEager() {
  // It is eagerly executed if its trace context includes the EagerKernelExecute
  // event (which may execute an op eagerly or through the TF executor) but not
  // the TF executor event.
  return FindParent(HostEventType::kExecutorStateProcess) == nullptr &&
         FindParent(HostEventType::kEagerKernelExecute) != nullptr;
}

const EventNode* EventNode::FindParent(int64_t event_type) const {
  return FindParentWithComparator(
      [event_type](const EventNode* node) {
        return node->GetEventVisitor().Type() == event_type;
      },
      this, /*include_self=*/true);
}

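// Creates an EventNode for each XEvent on the plane. Events on the same line
// are nested by timespan inclusion, and producer/consumer contexts are
// collected for ConnectContextGroups.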
void EventForest::ConnectIntraThread(XPlane* plane, XPlaneVisitor* visitor,
                                     ContextGroupMap* context_groups) {
  // TODO(b/149095099): avoid string comparison.
  bool is_host_plane = (visitor->Name() == kHostThreadsPlaneName);
  for (auto& line : *plane->mutable_lines()) {
    std::vector<EventNode*> parent_nodes;
    for (auto& event : *line.mutable_events()) {
      auto cur_node = absl::make_unique<EventNode>(visitor, &line, &event);
      // Update `context_groups` for `ConnectContextGroups`.
      SetContextGroup(cur_node.get(), context_groups);
      // Async events are ignored when processing the nesting relationship.
      if (cur_node->IsAsync()) continue;
      while (!parent_nodes.empty()) {
        EventNode* parent_node = parent_nodes.back();
        if (parent_node->GetEventVisitor().GetTimespan().Includes(
                cur_node->GetEventVisitor().GetTimespan())) {
          parent_node->AddChild(cur_node.get());
          break;
        } else {
          parent_nodes.pop_back();
        }
      }
      parent_nodes.push_back(cur_node.get());
      // event_node_map_ keeps cur_node alive.
      event_node_map_[GetEventType(is_host_plane, *cur_node)].push_back(
          std::move(cur_node));
    }
  }
}

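// Connects parent and child events across threads by matching the values of
// the configured context stats.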
void EventForest::ConnectInterThread(
    const std::vector<InterThreadConnectInfo>& connect_info_list) {
  for (const auto& connect_info : connect_info_list) {
    absl::flat_hash_map<std::vector<uint64>, EventNode*> connect_map;
    const std::vector<int64>& parent_stat_types =
        connect_info.parent_stat_types;
    const std::vector<int64>* child_stat_types = &connect_info.child_stat_types;
    if (child_stat_types->empty()) {
      child_stat_types = &parent_stat_types;
    }
    if (auto parent_event_node_list =
            gtl::FindOrNull(event_node_map_, connect_info.parent_event_type)) {
      for (const auto& parent_event_node : *parent_event_node_list) {
        std::vector<uint64> stats;
        for (auto stat_type : parent_stat_types) {
          absl::optional<XStatVisitor> stat =
              parent_event_node->GetContextStat(stat_type);
          if (!stat) break;
          stats.push_back(stat->IntOrUintValue());
        }
        if (stats.size() == parent_stat_types.size()) {
          connect_map[stats] = parent_event_node.get();
        }
      }
    }
    if (auto child_event_node_list =
            gtl::FindOrNull(event_node_map_, connect_info.child_event_type)) {
      for (const auto& child_event_node : *child_event_node_list) {
        std::vector<uint64> stats;
        for (auto stat_type : *child_stat_types) {
          absl::optional<XStatVisitor> stat =
              child_event_node->GetContextStat(stat_type);
          if (!stat) break;
          stats.push_back(stat->IntOrUintValue());
        }
        if (stats.size() == child_stat_types->size()) {
          if (auto parent_event_node = gtl::FindPtrOrNull(connect_map, stats)) {
            parent_event_node->AddChild(child_event_node.get());
          }
        }
      }
    }
  }
}

// Returns whether a root event needs grouping.
bool RootNeedsGrouping(const EventNode* root) {
  // No grouping is needed if it is already grouped.
  if (root->GetGroupId().has_value()) return false;
  // If there is a parent node with the same root level, skip grouping at
  // <root> and later apply grouping at the parent node.
  // If there is a parent node with a different root level, apply grouping at
  // <root>, and later apply grouping at the parent node. Root events with
  // different levels are grouped separately.
  const EventNode* root_parent = FindParentWithComparator(
      [root](const EventNode* parent) {
        return parent->RootLevel() == root->RootLevel();
      },
      root,
      /*include_self=*/false);
  return root_parent == nullptr;
}

// Sorts root events based on root level and timestamp.
void SortRootEventList(EventList* event_list) {
  absl::c_sort(*event_list, [](const EventNode* e1, const EventNode* e2) {
    // If two root events have the same root level, the root event with an
    // earlier timestamp will be processed first. Otherwise, the event with a
    // larger root level will be processed first.
    return e1->RootLevel() == e2->RootLevel()
               ? *e1 < *e2
               : e1->RootLevel() > e2->RootLevel();
  });
}

void EventForest::CreateEventGroups() {
  // Create a group for each TF loop iteration in non-JAX profiles.
  int64_t group_id = 0;
  if (!HasJaxEvent(event_node_map_) && !tf_loop_root_events_.empty()) {
    for (EventNode* root_event : tf_loop_root_events_) {
      ProcessRootEvent(group_id++, root_event, &group_metadata_map_);
    }
    return;
  }

  // Iterate over all events and collect all root events.
  EventList root_events;
  for (const auto& typed_events : event_node_map_) {
    for (const auto& event : typed_events.second) {
      if (!event->RootLevel()) continue;
      absl::optional<XStatVisitor> step_id_stat =
          event->GetEventVisitor().GetStat(StatType::kStepId);
      // Skip root events associated with tf.data.
      if (step_id_stat && tf_data_step_ids_.contains(step_id_stat->IntValue()))
        continue;
      root_events.push_back(event.get());
    }
  }

  SortRootEventList(&root_events);

  for (EventNode* root_event : root_events) {
    if (RootNeedsGrouping(root_event) &&
        // Ignore legacy TF root events in JAX profiles.
        (!HasJaxEvent(event_node_map_) ||
         !IsLegacyRootEvent(root_event->GetEventVisitor()))) {
      ProcessRootEvent(group_id++, root_event, &group_metadata_map_);
    }
  }
}

void EventForest::MarkEagerlyExecutedGpuKernels() {
  auto kernel_execute_event_node_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kKernelExecute);
  if (!kernel_execute_event_node_list) return;
  for (auto& kernel_execute_event_node : *kernel_execute_event_node_list) {
    kernel_execute_event_node->SetIsEager(kernel_execute_event_node->IsEager());
  }
}

void EventForest::MarkEagerlyExecutedCpuTfOps() {
  auto tf_op_run_event_node_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kTfOpRun);
  if (!tf_op_run_event_node_list) return;
  for (auto& tf_op_run_event_node : *tf_op_run_event_node_list) {
    tf_op_run_event_node->SetIsEager(tf_op_run_event_node->IsEager());
  }
}

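// Collects the step_ids of tf.data captured-function runs so that their
// executor events can be excluded from TF loop and root-event grouping.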
void EventForest::ProcessTfDataSteps() {
  const int64 tf_data_event_types[] = {
      HostEventType::kTfDataCapturedFunctionRun,
      HostEventType::kTfDataCapturedFunctionRunAsync,
      HostEventType::kTfDataCapturedFunctionRunInstantiated,
      HostEventType::kTfDataCapturedFunctionRunWithBorrowedArgs};
  for (const int64_t tf_data_event_type : tf_data_event_types) {
    auto tf_data_events = gtl::FindOrNull(event_node_map_, tf_data_event_type);
    if (!tf_data_events) continue;
    for (const auto& tf_data_event : *tf_data_events) {
      absl::optional<XStatVisitor> step_id_stat =
          tf_data_event->GetEventVisitor().GetStat(StatType::kStepId);
      if (!step_id_stat) continue;
      tf_data_step_ids_.insert(step_id_stat->IntValue());
    }
  }
}

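// Groups TF executor events into loop iterations keyed by (step_id, iter_num)
// and makes the first event of each iteration the root of that iteration.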
void EventForest::ProcessTensorFlowLoop() {
  struct TensorFlowLoopIteration {
    EventNode* first_event = nullptr;
    std::vector<EventNode*> events;
  };
  using TensorFlowLoop =
      absl::flat_hash_map<int64 /*iter_num*/, TensorFlowLoopIteration>;
  absl::flat_hash_map<int64 /*step_id*/, TensorFlowLoop> tf_loops;

  // Sort the TF executor events by TF function/session (step_id) and iter_num.
  auto executor_event_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kExecutorStateProcess);
  if (!executor_event_list) return;
  for (auto& executor_event : *executor_event_list) {
    absl::optional<XStatVisitor> step_id_stat =
        executor_event->GetEventVisitor().GetStat(StatType::kStepId);
    absl::optional<XStatVisitor> iter_num_stat =
        executor_event->GetEventVisitor().GetStat(StatType::kIterNum);
    if (!step_id_stat || !iter_num_stat) continue;
    int64_t step_id = step_id_stat->IntValue();
    // Skip tf.data events.
    if (tf_data_step_ids_.contains(step_id)) continue;
    TensorFlowLoop& tf_loop = tf_loops[step_id];
    TensorFlowLoopIteration& iteration = tf_loop[iter_num_stat->IntValue()];
    if (!iteration.first_event || *executor_event < *iteration.first_event) {
      iteration.first_event = executor_event.get();
    }
    iteration.events.push_back(executor_event.get());
  }

  std::vector<const TensorFlowLoopIteration*> iters;
  for (const auto& step_id_and_tf_loop : tf_loops) {
    const TensorFlowLoop& tf_loop = step_id_and_tf_loop.second;
    // Filter out TF functions/sessions without loops.
    if (tf_loop.size() == 1 && tf_loop.contains(0)) continue;
    for (const auto& iter_num_and_iter : tf_loop) {
      iters.push_back(&iter_num_and_iter.second);
    }
  }

  // Sort iterations based on the timestamp of the first event in each
  // iteration.
  absl::c_sort(iters, [](const auto& iter1, const auto& iter2) {
    return *iter1->first_event < *iter2->first_event;
  });

  // Register the first event of each iteration as a root event, and add the
  // other events of the iteration as its children.
  for (const TensorFlowLoopIteration* iter : iters) {
    EventNode* root_event = iter->first_event;
    tf_loop_root_events_.push_back(root_event);
    for (EventNode* event : iter->events) {
      if (event == root_event) continue;
      root_event->AddChild(event);
    }
  }
}

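// Treats each EagerKernelExecute event with a FunctionRun child as a root and
// attaches the non-function eager ops that follow it as its children.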
void EventForest::ProcessWorker() {
  auto eager_kernel_execute_event_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kEagerKernelExecute);
  if (!eager_kernel_execute_event_list) return;
  // The last EagerKernelExecute with a FunctionRun child.
  EventNode* root_event = nullptr;
  for (auto& eager_kernel_execute_event : *eager_kernel_execute_event_list) {
    if (HasFunctionRun(eager_kernel_execute_event.get())) {
      // A function op becomes a new root.
      root_event = eager_kernel_execute_event.get();
      root_event->SetRootLevel(1);
    } else if (root_event) {
      // Add non-function eager ops as children.
      root_event->AddChild(eager_kernel_execute_event.get());
    }
  }
}

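// Copies the model_id stat of SessionRun/TfrtModelRun events into the
// metadata of their groups.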
void EventForest::ProcessModelIds() {
  const int64 model_id_event_type_list[] = {HostEventType::kSessionRun,
                                            HostEventType::kTfrtModelRun};
  for (const int64_t event_type : model_id_event_type_list) {
    auto event_list = gtl::FindOrNull(event_node_map_, event_type);
    if (!event_list) continue;
    for (const auto& event : *event_list) {
      auto group_id = event->GetGroupId();
      if (!group_id.has_value()) continue;
      absl::optional<XStatVisitor> model_id =
          event->GetEventVisitor().GetStat(StatType::kModelId);
      if (!model_id.has_value()) continue;
      group_metadata_map_[*group_id].model_id = model_id->ToString();
    }
  }
}

void EventForest::AddPlane(
    const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
    XPlane* plane) {
  CreateStatMetadata(plane);
  planes_.push_back({plane, visitor_factory(plane)});
}

void EventForest::AddSpace(
    const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
    XSpace* space) {
  for (XPlane& plane : *space->mutable_planes()) {
    AddPlane(visitor_factory, &plane);
  }
}

void EventForest::AddPlanes(
    const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
    const std::vector<XPlane*>& planes) {
  for (XPlane* plane : planes) {
    AddPlane(visitor_factory, plane);
  }
}

void EventForest::ConnectEvents(
    const std::vector<InterThreadConnectInfo>& connect_info_list) {
  ContextGroupMap context_groups;
  for (auto& plane_visitor : planes_) {
    ConnectIntraThread(plane_visitor.first, &plane_visitor.second,
                       &context_groups);
  }
  ConnectInterThread(connect_info_list);
  ConnectContextGroups(context_groups);
}

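// Matches tf.data produce and consume events by (iterator_id, element_id) and
// connects each consume iterator to the matching produce iterators.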
void EventForest::ConnectTfDataEvents() {
  absl::flat_hash_map<std::pair<int64 /*iterator_id*/, int64 /*element_id*/>,
                      std::vector<EventNode*>>
      produce_iterator_map;
  uint64 num_producers = 0;
  for (HostEventType event_type :
       {HostEventType::kPrefetchProduce,
        HostEventType::kParallelInterleaveProduce,
        HostEventType::kParallelMapProduce, HostEventType::kMapAndBatchProduce,
        HostEventType::kParseExampleProduce,
        HostEventType::kParallelBatchProduce}) {
    auto produce_event_list = gtl::FindOrNull(event_node_map_, event_type);
    if (!produce_event_list) continue;
    VLOG(1) << produce_event_list->size() << " "
            << GetHostEventTypeStr(event_type) << " events found.";
    for (auto& produce_event : *produce_event_list) {
      absl::optional<XStatVisitor> element_id =
          produce_event->GetEventVisitor().GetStat(StatType::kElementId);
      if (!element_id.has_value()) continue;
      for (EventNode* produce_iterator : produce_event->GetChildren()) {
        if (IsIteratorEventType(produce_iterator->GetEventVisitor().Type())) {
          absl::optional<XStatVisitor> iterator_id =
              produce_iterator->GetEventVisitor().GetStat(StatType::kParentId);
          if (!iterator_id.has_value()) break;
          produce_iterator_map[{iterator_id->IntValue(),
                                element_id->IntValue()}]
              .push_back(produce_iterator);
          ++num_producers;
          break;
        }
      }
    }
  }
  VLOG(1) << num_producers << " producer iterators found.";
  uint64 num_matched = 0;
  for (HostEventType event_type :
       {HostEventType::kPrefetchConsume,
        HostEventType::kParallelInterleaveConsume,
        HostEventType::kParallelMapConsume, HostEventType::kMapAndBatchConsume,
        HostEventType::kParseExampleConsume,
        HostEventType::kParallelBatchConsume}) {
    auto consume_event_list = gtl::FindOrNull(event_node_map_, event_type);
    if (!consume_event_list) continue;
    VLOG(1) << consume_event_list->size() << " "
            << GetHostEventTypeStr(event_type) << " events found.";
    for (auto& consume_event : *consume_event_list) {
      absl::optional<XStatVisitor> element_id =
          consume_event->GetEventVisitor().GetStat(StatType::kElementId);
      if (!element_id.has_value()) continue;
      if (consume_event->GetParents().empty()) continue;
      // The consume event is nested within the consume iterator and has no
      // other parents.
      EventNode* consume_iterator = consume_event->GetParents().at(0);
      if (!consume_iterator ||
          !IsIteratorEventType(consume_iterator->GetEventVisitor().Type())) {
        continue;
      }
      absl::optional<XStatVisitor> iterator_id =
          consume_iterator->GetEventVisitor().GetStat(StatType::kStepId);
      if (!iterator_id.has_value()) continue;
      if (auto produce_iterators = gtl::FindOrNull(
              produce_iterator_map, std::make_pair(iterator_id->IntValue(),
                                                   element_id->IntValue()))) {
        for (EventNode* produce_iterator : *produce_iterators) {
          consume_iterator->AddChild(produce_iterator);
          ++num_matched;
        }
      }
    }
  }
  VLOG(1) << num_matched << " consumer iterators matched.";
}

void EventForest::GroupEvents() {
  ProcessTfDataSteps();
  ProcessTensorFlowLoop();
  ProcessWorker();
  CreateEventGroups();
  MarkEagerlyExecutedGpuKernels();
  MarkEagerlyExecutedCpuTfOps();
  ProcessModelIds();
}

std::vector<InterThreadConnectInfo> CreateInterThreadConnectInfoList() {
  std::vector<InterThreadConnectInfo> connect_info_list = {
      {HostEventType::kExecutorStateProcess,
       HostEventType::kIteratorGetNextOp,
       {StatType::kStepId, StatType::kIterNum}},
      {HostEventType::kExecutorStateProcess,
       HostEventType::kIteratorGetNextAsOptionalOp,
       {StatType::kStepId, StatType::kIterNum}},
      {HostEventType::kKernelLaunch,
       HostEventType::kKernelExecute,
       {StatType::kCorrelationId}}};
  return connect_info_list;
}

void GroupTfEvents(XSpace* space, EventForest* event_forest) {
  if (CheckLoopOp(*space)) {
    // TODO(b/154510598): Support TF's loop ops.
    return;
  }
  std::vector<InterThreadConnectInfo> connect_info_list =
      CreateInterThreadConnectInfoList();
  event_forest->AddSpace(CreateTfXPlaneVisitor, space);
  event_forest->ConnectEvents(connect_info_list);
  event_forest->GroupEvents();
}

void GroupTfEvents(XSpace* space) {
  EventForest event_forest;
  GroupTfEvents(space, &event_forest);
}

}  // namespace profiler
}  // namespace tensorflow