/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/profiler/utils/group_events.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"

namespace tensorflow {
namespace profiler {
namespace {

// Creates stat metadata for the stats which may be added by grouping.
void CreateStatMetadata(XPlane* plane) {
  XPlaneBuilder builder(plane);
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId));
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kStepName));
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kIsEager));
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kSelectedGroupIds));
}

// Returns the event type if the event is a KernelLaunch or KernelExecute
// event, i.e., if it carries a correlation_id stat.
absl::optional<int64> GetKernelEventType(bool is_host_plane,
                                         const EventNode& event) {
  if (event.GetEventVisitor().GetStat(StatType::kCorrelationId).has_value()) {
    return is_host_plane ? HostEventType::kKernelLaunch
                         : HostEventType::kKernelExecute;
  }
  return absl::nullopt;
}

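// Returns the type of the given event, falling back to GetKernelEventType for
// kernel events that XPlaneVisitor cannot type.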
int64 GetEventType(bool is_host_plane, const EventNode& event) {
  if (absl::optional<int64> event_type = event.GetEventVisitor().Type()) {
    return *event_type;
  } else if (absl::optional<int64> kernel_event_type =
                 GetKernelEventType(is_host_plane, event)) {
    // KernelLaunch and KernelExecute event types are not supported by
    // XPlaneVisitor and should be checked separately.
    // TODO(b/148346217): Make XPlaneVisitor support KernelLaunch and
    // KernelExecute event types.
    return *kernel_event_type;
  } else {
    return HostEventType::kUnknownHostEventType;
  }
}

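// Registers the event as a producer and/or consumer in the context group
// keyed by its context type and id.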
void SetContextGroup(EventNode* event, ContextGroupMap* context_groups) {
  auto producer = event->GetProducerContext();
  if (producer.has_value()) {
    ((*context_groups)[producer->type][producer->id])
        .producers.push_back(event);
  }
  auto consumer = event->GetConsumerContext();
  if (consumer.has_value()) {
    ((*context_groups)[consumer->type][consumer->id])
        .consumers.push_back(event);
  }
}

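// Connects every producer event to every consumer event within each context
// group.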
void ConnectContextGroups(const ContextGroupMap& context_groups) {
  for (auto& type_id_group : context_groups) {
    for (auto& id_group : type_id_group.second) {
      const ContextGroup& group = id_group.second;
      for (EventNode* parent : group.producers) {
        for (EventNode* child : group.consumers) {
          parent->AddChild(child);
        }
      }
    }
  }
}

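// Creates a virtual XEvent that carries only the given step_id and iter_num
// stats.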
std::unique_ptr<XEvent> CreateVirtualEvent(const XStat& step_id_stat,
                                           const XStat& iter_num_stat) {
  auto virtual_event = absl::make_unique<XEvent>();
  *virtual_event->add_stats() = step_id_stat;
  *virtual_event->add_stats() = iter_num_stat;
  return virtual_event;
}

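// Returns true if the event has a FunctionRun child.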
bool HasFunctionRun(EventNode* event_node) {
  for (EventNode* child : event_node->GetChildren()) {
    if (child->GetEventVisitor().Type() == HostEventType::kFunctionRun) {
      return true;
    }
  }
  return false;
}

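// Returns true for framework-generated root events whose names should not be
// exposed as step names.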
bool IsImplicitRootEvent(const XEventVisitor& event) {
  static const auto* const kImplicitRootEvents = new absl::flat_hash_set<int64>{
      HostEventType::kFunctionRun, HostEventType::kSessionRun,
      HostEventType::kRunGraph, HostEventType::kExecutorStateProcess};
  return event.Type().has_value() &&
         kImplicitRootEvents->contains(*event.Type());
}

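// Propagates group_id from the given root event through its subtree and
// records the group name in group_metadata_map.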
void ProcessRootEvent(int64_t group_id, EventNode* root_event,
                      GroupMetadataMap* group_metadata_map) {
  root_event->PropagateGroupId(group_id, group_metadata_map);
  std::string group_name = root_event->GetGroupName();
  // TODO(jihochoi): change event name instead.
  if (!IsImplicitRootEvent(root_event->GetEventVisitor())) {
    // Add the `step_name` stat for the user-defined root events only. When an
    // XEvent is converted to a trace event, the trace event name is set to the
    // `step_name` stat's value if present.
    root_event->AddStepName(group_name);
  }
  (*group_metadata_map)[group_id].name = std::move(group_name);
}

struct ContextTypeAndId {
  int type;
  uint64 id;
};

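// Derives the producer context for legacy traces that lack explicit
// producer_type/producer_id stats, using the step_id or function_step_id stat.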
absl::optional<ContextTypeAndId> GetLegacyProducerContext(
    const XEventVisitor& event) {
  absl::optional<ContextTypeAndId> type_and_id;
  absl::optional<int64> event_type = event.Type();
  if (event_type.has_value()) {
    switch (*event_type) {
      case HostEventType::kTraceContext:
      case HostEventType::kFunctionRun:
      case HostEventType::kSessionRun:
      case HostEventType::kRunGraph: {
        absl::optional<XStatVisitor> stat = event.GetStat(StatType::kStepId);
        if (stat.has_value()) {
          type_and_id = {static_cast<int>(ContextType::kTfExecutor),
                         static_cast<uint64>(stat->IntValue())};
        }
        break;
      }
      case HostEventType::kCallOp:
      case HostEventType::kNumericalGradientOpEvalRight:
      case HostEventType::kNumericalGradientOpEvalLeft:
      case HostEventType::kSymbolicGradientOp:
      case HostEventType::kRemoteCallOp:
      case HostEventType::kIfOp:
      case HostEventType::kCaseOp:
      case HostEventType::kPartitionedCallOp: {
        // TODO(b/154510598): Fix handling of the loop ops.
        // case HostEventType::kWhileOpEvalCond:
        // case HostEventType::kWhileOpStartBody:
        // case HostEventType::kForOp:
        // case HostEventType::kParallelForOp:
        // case HostEventType::kForeverOp:
        absl::optional<XStatVisitor> stat =
            event.GetStat(StatType::kFunctionStepId);
        if (stat.has_value()) {
          type_and_id = {static_cast<int>(ContextType::kTfExecutor),
                         static_cast<uint64>(stat->IntValue())};
        }
        break;
      }
      default:
        break;
    }
  }
  return type_and_id;
}

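// Derives the consumer context for legacy traces that lack explicit
// consumer_type/consumer_id stats, using the step_id stat.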
absl::optional<ContextTypeAndId> GetLegacyConsumerContext(
    const XEventVisitor& event) {
  absl::optional<ContextTypeAndId> type_and_id;
  absl::optional<int64> event_type = event.Type();
  if (event_type.has_value()) {
    switch (*event_type) {
      case HostEventType::kExecutorStateProcess:
      case HostEventType::kExecutorDoneCallback:
      case HostEventType::kRunGraphDone: {
        absl::optional<XStatVisitor> stat = event.GetStat(StatType::kStepId);
        if (stat.has_value()) {
          type_and_id = {static_cast<int>(ContextType::kTfExecutor),
                         static_cast<uint64>(stat->IntValue())};
        }
        break;
      }
      default:
        break;
    }
  }
  return type_and_id;
}

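// Returns true if the event type is a root event in legacy traces.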
bool IsLegacyRootEvent(const XEventVisitor& event) {
  static const auto* const kRootEvents = new absl::flat_hash_set<int64>{
      HostEventType::kTraceContext, HostEventType::kFunctionRun,
      HostEventType::kSessionRun, HostEventType::kRunGraph};
  return event.Type().has_value() && kRootEvents->contains(*event.Type());
}

using Comparator = std::function<bool(const EventNode*)>;

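// Performs a breadth-first search upward through the ancestors of <node>
// (including <node> itself if <include_self> is set) and returns the first
// ancestor that satisfies <comparator>, or nullptr if none does.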
const EventNode* FindParentWithComparator(const Comparator& comparator,
                                          const EventNode* node,
                                          bool include_self) {
  std::queue<const EventNode*> nodes;
  absl::flat_hash_set<const EventNode*> seen = {node};
  if (include_self) {
    nodes.push(node);
  } else {
    for (const EventNode* parent : node->GetParents()) {
      nodes.push(parent);
      seen.insert(parent);
    }
  }
  while (!nodes.empty()) {
    const EventNode* node = nodes.front();
    nodes.pop();
    if (comparator(node)) return node;
    for (const EventNode* parent : node->GetParents()) {
      if (seen.contains(parent)) continue;
      nodes.push(parent);
      seen.insert(parent);
    }
  }
  return nullptr;
}

// Returns true if the profile contains JAX-related events.
bool HasJaxEvent(const EventNodeMap& event_node_map) {
  return event_node_map.contains(HostEventType::kExecuteOnLocalDevices);
}

bool IsIteratorEventType(absl::optional<int64> event_type) {
  return event_type == HostEventType::kIterator ||
         event_type == HostEventType::kDeviceInputPipelineSecondIterator;
}

}  // namespace

// Returns true if TF's loop ops exist in the given XSpace's metadata.
bool CheckLoopOp(const XSpace& space) {
  for (const XPlane& plane : space.planes()) {
    for (const auto& event_metadata : plane.event_metadata()) {
      absl::optional<int64> event_type =
          FindHostEventType(event_metadata.second.name());
      if (!event_type.has_value()) continue;
      switch (*event_type) {
        case HostEventType::kWhileOpEvalCond:
        case HostEventType::kWhileOpStartBody:
        case HostEventType::kForOp:
        case HostEventType::kParallelForOp:
        case HostEventType::kForeverOp:
          return true;
        default:
          break;
      }
    }
  }
  return false;
}

EventNode::EventNode(const XPlaneVisitor* plane, XLine* raw_line,
                     XEvent* raw_event)
    : plane_(plane),
      visitor_(plane, raw_line, raw_event),
      raw_line_(raw_line),
      raw_event_(raw_event) {
  absl::optional<int> producer_type;
  absl::optional<uint64> producer_id;
  absl::optional<int> consumer_type;
  absl::optional<uint64> consumer_id;

  visitor_.ForEachStat([&](const XStatVisitor& stat) {
    if (!stat.Type().has_value()) return;
    switch (*stat.Type()) {
      case StatType::kProducerType:
        producer_type = stat.IntValue();
        break;
      case StatType::kProducerId:
        producer_id = stat.IntOrUintValue();
        break;
      case StatType::kConsumerType:
        consumer_type = stat.IntValue();
        break;
      case StatType::kConsumerId:
        consumer_id = stat.IntOrUintValue();
        break;
      case StatType::kIsRoot:
        root_level_ = stat.IntValue();
        break;
      case StatType::kIsAsync:
        is_async_ = stat.IntValue();
        break;
      default:
        break;
    }
  });

  // Support legacy traces.
  if (!producer_type.has_value() || !producer_id.has_value()) {
    if (auto producer_context = GetLegacyProducerContext(visitor_)) {
      producer_type = producer_context->type;
      producer_id = producer_context->id;
    }
  }
  if (!consumer_type.has_value() || !consumer_id.has_value()) {
    if (auto consumer_context = GetLegacyConsumerContext(visitor_)) {
      consumer_type = consumer_context->type;
      consumer_id = consumer_context->id;
    }
  }
  root_level_ = root_level_ ? root_level_ : IsLegacyRootEvent(visitor_);

  if (producer_type.has_value() && producer_id.has_value()) {
    producer_context_ = {*producer_type, *producer_id};
  }
  if (consumer_type.has_value() && consumer_id.has_value()) {
    consumer_context_ = {*consumer_type, *consumer_id};
  }
}

EventNode::EventNode(const EventNode& event_node)
    : EventNode(event_node.plane_, event_node.raw_line_,
                event_node.raw_event_) {}

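// Searches this event and its ancestors (breadth-first) for the first stat of
// the given type.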
absl::optional<XStatVisitor> EventNode::GetContextStat(
    int64_t stat_type) const {
  std::queue<const EventNode*> nodes;
  absl::flat_hash_set<const EventNode*> seen = {this};
  nodes.push(this);
  while (!nodes.empty()) {
    const EventNode* node = nodes.front();
    nodes.pop();
    if (absl::optional<XStatVisitor> stat = node->visitor_.GetStat(stat_type)) {
      return stat;
    }
    for (const EventNode* parent : node->GetParents()) {
      if (seen.contains(parent)) continue;
      nodes.push(parent);
      seen.insert(parent);
    }
  }
  return absl::nullopt;
}

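// Builds the group name from the graph type (or the event name) and the step
// or iteration number found in the event's context.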
std::string EventNode::GetGroupName() const {
  std::string name;
  if (absl::optional<XStatVisitor> stat =
          GetContextStat(StatType::kGraphType)) {
    absl::StrAppend(&name, stat->StrOrRefValue(), " ");
  } else if (!(IsImplicitRootEvent(visitor_))) {
    absl::StrAppend(&name, GetEventVisitor().Name(), " ");
  }
  int64_t step_num = group_id_.value_or(0);
  if (absl::optional<XStatVisitor> stat = GetContextStat(StatType::kIterNum)) {
    step_num = stat->IntValue();
  } else if (absl::optional<XStatVisitor> stat =
                 GetContextStat(StatType::kStepNum)) {
    step_num = stat->IntValue();
  }
  absl::StrAppend(&name, step_num);
  return name;
}

XStat* EventNode::FindOrAddStatByType(int64_t stat_type) {
  const XStatMetadata* stat_metadata = plane_->GetStatMetadataByType(stat_type);
  DCHECK(stat_metadata != nullptr);
  return FindOrAddMutableStat(*stat_metadata, raw_event_);
}

void EventNode::SetGroupId(int64_t group_id) {
  group_id_ = group_id;
  FindOrAddStatByType(StatType::kGroupId)->set_int64_value(group_id);
}

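// Propagates group_id through this event's subtree (breadth-first over
// children). If a descendant already belongs to a different group, records the
// parent/child relationship between the two groups instead of regrouping it.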
void EventNode::PropagateGroupId(int64_t group_id,
                                 GroupMetadataMap* group_metadata_map) {
  std::queue<EventNode*> nodes;
  absl::flat_hash_set<EventNode*> seen = {this};
  nodes.push(this);
  while (!nodes.empty()) {
    EventNode* node = nodes.front();
    nodes.pop();
    absl::optional<int64> node_group_id = node->GetGroupId();
    if (node_group_id.has_value()) {
      if (*node_group_id != group_id) {
        (*group_metadata_map)[group_id].children.insert(*node_group_id);
        (*group_metadata_map)[*node_group_id].parents.insert(group_id);
      }
    } else {
      node->SetGroupId(group_id);
      for (EventNode* child : node->GetChildren()) {
        if (seen.contains(child)) continue;
        nodes.push(child);
        seen.insert(child);
      }
    }
  }
}

void EventNode::AddStepName(absl::string_view step_name) {
  FindOrAddStatByType(StatType::kStepName)
      ->set_str_value(step_name.data(), step_name.size());
}

void EventNode::AddSelectedGroupIds(
    const GroupMetadataMap& group_metadata_map) {
  const auto& group_metadata = group_metadata_map.at(*group_id_);
  std::vector<int64> group_ids;
  group_ids.reserve(1 + group_metadata.parents.size() +
                    group_metadata.children.size());
  group_ids.push_back(*group_id_);
  group_ids.insert(group_ids.end(), group_metadata.parents.begin(),
                   group_metadata.parents.end());
  group_ids.insert(group_ids.end(), group_metadata.children.begin(),
                   group_metadata.children.end());
  FindOrAddStatByType(StatType::kSelectedGroupIds)
      ->set_str_value(
          absl::StrCat("?selected_group_ids=", absl::StrJoin(group_ids, ",")));
}

void EventNode::SetIsEager(bool is_eager) {
  FindOrAddStatByType(StatType::kIsEager)->set_int64_value(is_eager ? 1 : 0);
}

bool EventNode::IsEager() {
  // It is eagerly executed if its trace context includes the
  // EagerKernelExecute event (which may execute an op eagerly or through the
  // TF executor) but not the TF executor event.
  return FindParent(HostEventType::kExecutorStateProcess) == nullptr &&
         FindParent(HostEventType::kEagerKernelExecute) != nullptr;
}

const EventNode* EventNode::FindParent(int64_t event_type) const {
  return FindParentWithComparator(
      [event_type](const EventNode* node) {
        return node->GetEventVisitor().Type() == event_type;
      },
      this, /*include_self=*/true);
}

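// Creates an EventNode for every XEvent in the plane and nests events on the
// same line by timespan inclusion. Also registers each event's producer and
// consumer contexts for the later ConnectContextGroups pass.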
void EventForest::ConnectIntraThread(XPlane* plane, XPlaneVisitor* visitor,
                                     ContextGroupMap* context_groups) {
  // TODO(b/149095099): avoid string comparison.
  bool is_host_plane = (visitor->Name() == kHostThreadsPlaneName);
  for (auto& line : *plane->mutable_lines()) {
    std::vector<EventNode*> parent_nodes;
    for (auto& event : *line.mutable_events()) {
      auto cur_node = absl::make_unique<EventNode>(visitor, &line, &event);
      // Update `context_groups` for `ConnectContextGroups`.
      SetContextGroup(cur_node.get(), context_groups);
      // Async events are ignored when processing the nesting relationship.
      if (cur_node->IsAsync()) continue;
      while (!parent_nodes.empty()) {
        EventNode* parent_node = parent_nodes.back();
        if (parent_node->GetEventVisitor().GetTimespan().Includes(
                cur_node->GetEventVisitor().GetTimespan())) {
          parent_node->AddChild(cur_node.get());
          break;
        } else {
          parent_nodes.pop_back();
        }
      }
      parent_nodes.push_back(cur_node.get());
      // event_node_map_ keeps cur_node alive.
      event_node_map_[GetEventType(is_host_plane, *cur_node)].push_back(
          std::move(cur_node));
    }
  }
}

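// Connects parent and child events across threads. For each
// InterThreadConnectInfo, parent events are indexed by the values of their
// context stats; a child event whose stats match an indexed parent becomes
// its child.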
void EventForest::ConnectInterThread(
    const std::vector<InterThreadConnectInfo>& connect_info_list) {
  for (const auto& connect_info : connect_info_list) {
    absl::flat_hash_map<std::vector<uint64>, EventNode*> connect_map;
    const std::vector<int64>& parent_stat_types =
        connect_info.parent_stat_types;
    const std::vector<int64>* child_stat_types =
        &connect_info.child_stat_types;
    if (child_stat_types->empty()) {
      child_stat_types = &parent_stat_types;
    }
    if (auto parent_event_node_list =
            gtl::FindOrNull(event_node_map_, connect_info.parent_event_type)) {
      for (const auto& parent_event_node : *parent_event_node_list) {
        std::vector<uint64> stats;
        for (auto stat_type : parent_stat_types) {
          absl::optional<XStatVisitor> stat =
              parent_event_node->GetContextStat(stat_type);
          if (!stat) break;
          stats.push_back(stat->IntOrUintValue());
        }
        if (stats.size() == parent_stat_types.size()) {
          connect_map[stats] = parent_event_node.get();
        }
      }
    }
    if (auto child_event_node_list =
            gtl::FindOrNull(event_node_map_, connect_info.child_event_type)) {
      for (const auto& child_event_node : *child_event_node_list) {
        std::vector<uint64> stats;
        for (auto stat_type : *child_stat_types) {
          absl::optional<XStatVisitor> stat =
              child_event_node->GetContextStat(stat_type);
          if (!stat) break;
          stats.push_back(stat->IntOrUintValue());
        }
        if (stats.size() == child_stat_types->size()) {
          if (auto parent_event_node = gtl::FindPtrOrNull(connect_map, stats)) {
            parent_event_node->AddChild(child_event_node.get());
          }
        }
      }
    }
  }
}

// Returns whether a root event needs grouping.
bool RootNeedsGrouping(const EventNode* root) {
  // No grouping is needed if it is already grouped.
  if (root->GetGroupId().has_value()) return false;
  // If there is a parent node with the same root level, skip grouping at
  // <root> and later apply grouping at the parent node.
  // If there is a parent node with a different root level, apply grouping at
  // <root>, and later apply grouping at the parent node. Root events with
  // different levels are grouped separately.
  const EventNode* root_parent = FindParentWithComparator(
      [root](const EventNode* parent) {
        return parent->RootLevel() == root->RootLevel();
      },
      root,
      /*include_self=*/false);
  return root_parent == nullptr;
}

// Sorts root events based on root level and timestamp.
void SortRootEventList(EventList* event_list) {
  absl::c_sort(*event_list, [](const EventNode* e1, const EventNode* e2) {
    // If two root events have the same root level, the root event with an
    // earlier timestamp will be processed first. Otherwise, the event with a
    // larger root level will be processed first.
    return e1->RootLevel() == e2->RootLevel()
               ? *e1 < *e2
               : e1->RootLevel() > e2->RootLevel();
  });
}

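// Creates a group for each root event and propagates its group id. TF loop
// iterations take precedence in non-JAX profiles; otherwise, every ungrouped
// root-level event that is not a tf.data step starts a group.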
void EventForest::CreateEventGroups() {
  // Create a group for each TF loop iteration in non-JAX profiles.
  int64_t group_id = 0;
  if (!HasJaxEvent(event_node_map_) && !tf_loop_root_events_.empty()) {
    for (EventNode* root_event : tf_loop_root_events_) {
      ProcessRootEvent(group_id++, root_event, &group_metadata_map_);
    }
    return;
  }

  // Iterate over all events and collect all root events.
  EventList root_events;
  for (const auto& typed_events : event_node_map_) {
    for (const auto& event : typed_events.second) {
      if (!event->RootLevel()) continue;
      absl::optional<XStatVisitor> step_id_stat =
          event->GetEventVisitor().GetStat(StatType::kStepId);
      // Skip root events associated with tf.data.
      if (step_id_stat && tf_data_step_ids_.contains(step_id_stat->IntValue()))
        continue;
      root_events.push_back(event.get());
    }
  }

  SortRootEventList(&root_events);

  for (EventNode* root_event : root_events) {
    if (RootNeedsGrouping(root_event) &&
        // Ignore legacy TF root events for JAX profiles.
        (!HasJaxEvent(event_node_map_) ||
         !IsLegacyRootEvent(root_event->GetEventVisitor()))) {
      ProcessRootEvent(group_id++, root_event, &group_metadata_map_);
    }
  }
}

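// Sets the is_eager stat on every KernelExecute event according to whether it
// was launched from an eager context.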
void EventForest::MarkEagerlyExecutedGpuKernels() {
  auto kernel_execute_event_node_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kKernelExecute);
  if (!kernel_execute_event_node_list) return;
  for (auto& kernel_execute_event_node : *kernel_execute_event_node_list) {
    kernel_execute_event_node->SetIsEager(kernel_execute_event_node->IsEager());
  }
}

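// Sets the is_eager stat on every TfOpRun event according to whether it was
// executed from an eager context.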
void EventForest::MarkEagerlyExecutedCpuTfOps() {
  auto tf_op_run_event_node_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kTfOpRun);
  if (!tf_op_run_event_node_list) return;
  for (auto& tf_op_run_event_node : *tf_op_run_event_node_list) {
    tf_op_run_event_node->SetIsEager(tf_op_run_event_node->IsEager());
  }
}

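// Collects the step ids of tf.data function runs so that the corresponding
// executor events can be excluded from step grouping.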
void EventForest::ProcessTfDataSteps() {
  const int64 tf_data_event_types[] = {
      HostEventType::kTfDataCapturedFunctionRun,
      HostEventType::kTfDataCapturedFunctionRunAsync,
      HostEventType::kTfDataCapturedFunctionRunInstantiated,
      HostEventType::kTfDataCapturedFunctionRunWithBorrowedArgs};
  for (const int64_t tf_data_event_type : tf_data_event_types) {
    auto tf_data_events = gtl::FindOrNull(event_node_map_, tf_data_event_type);
    if (!tf_data_events) continue;
    for (const auto& tf_data_event : *tf_data_events) {
      absl::optional<XStatVisitor> step_id_stat =
          tf_data_event->GetEventVisitor().GetStat(StatType::kStepId);
      if (!step_id_stat) continue;
      tf_data_step_ids_.insert(step_id_stat->IntValue());
    }
  }
}

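// Identifies TF while-loop iterations by bucketing executor events by
// (step_id, iter_num), then registers the earliest event of each iteration as
// a root and attaches the iteration's remaining events to it.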
void EventForest::ProcessTensorFlowLoop() {
  struct TensorFlowLoopIteration {
    EventNode* first_event = nullptr;
    std::vector<EventNode*> events;
  };
  using TensorFlowLoop =
      absl::flat_hash_map<int64 /*iter_num*/, TensorFlowLoopIteration>;
  absl::flat_hash_map<int64 /*step_id*/, TensorFlowLoop> tf_loops;

  // Group the TF executor events by TF function/session (step_id) and
  // iter_num.
  auto executor_event_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kExecutorStateProcess);
  if (!executor_event_list) return;
  for (auto& executor_event : *executor_event_list) {
    absl::optional<XStatVisitor> step_id_stat =
        executor_event->GetEventVisitor().GetStat(StatType::kStepId);
    absl::optional<XStatVisitor> iter_num_stat =
        executor_event->GetEventVisitor().GetStat(StatType::kIterNum);
    if (!step_id_stat || !iter_num_stat) continue;
    int64_t step_id = step_id_stat->IntValue();
    // Skip tf.data events.
    if (tf_data_step_ids_.contains(step_id)) continue;
    TensorFlowLoop& tf_loop = tf_loops[step_id];
    TensorFlowLoopIteration& iteration = tf_loop[iter_num_stat->IntValue()];
    if (!iteration.first_event || *executor_event < *iteration.first_event) {
      iteration.first_event = executor_event.get();
    }
    iteration.events.push_back(executor_event.get());
  }

  std::vector<const TensorFlowLoopIteration*> iters;
  for (const auto& step_id_and_tf_loop : tf_loops) {
    const TensorFlowLoop& tf_loop = step_id_and_tf_loop.second;
    // Filter out TF functions/sessions without loops.
    if (tf_loop.size() == 1 && tf_loop.contains(0)) continue;
    for (const auto& iter_num_and_iter : tf_loop) {
      iters.push_back(&iter_num_and_iter.second);
    }
  }

  // Sort the iterations by the timestamp of their first event.
  absl::c_sort(iters, [](const auto& iter1, const auto& iter2) {
    return *iter1->first_event < *iter2->first_event;
  });

  // Register the first event of each iteration as a root event and add the
  // iteration's other events as its children.
  for (const TensorFlowLoopIteration* iter : iters) {
    EventNode* root_event = iter->first_event;
    tf_loop_root_events_.push_back(root_event);
    for (EventNode* event : iter->events) {
      if (event == root_event) continue;
      root_event->AddChild(event);
    }
  }
}

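// Marks each EagerKernelExecute event with a FunctionRun child as a root and
// attaches the non-function eager ops that follow it as children.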
void EventForest::ProcessWorker() {
  auto eager_kernel_execute_event_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kEagerKernelExecute);
  if (!eager_kernel_execute_event_list) return;
  // The last EagerKernelExecute with a FunctionRun child.
  EventNode* root_event = nullptr;
  for (auto& eager_kernel_execute_event : *eager_kernel_execute_event_list) {
    if (HasFunctionRun(eager_kernel_execute_event.get())) {
      // A function op becomes a new root.
      root_event = eager_kernel_execute_event.get();
      root_event->SetRootLevel(1);
    } else if (root_event) {
      // Add non-function eager ops as children.
      root_event->AddChild(eager_kernel_execute_event.get());
    }
  }
}

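// Copies the model_id stat of SessionRun and TfrtModelRun events into the
// metadata of their groups.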
void EventForest::ProcessModelIds() {
  const int64 model_id_event_type_list[] = {HostEventType::kSessionRun,
                                            HostEventType::kTfrtModelRun};
  for (const int64_t event_type : model_id_event_type_list) {
    auto event_list = gtl::FindOrNull(event_node_map_, event_type);
    if (!event_list) continue;
    for (const auto& event : *event_list) {
      auto group_id = event->GetGroupId();
      if (!group_id.has_value()) continue;
      absl::optional<XStatVisitor> model_id =
          event->GetEventVisitor().GetStat(StatType::kModelId);
      if (!model_id.has_value()) continue;
      group_metadata_map_[*group_id].model_id = model_id->ToString();
    }
  }
}

void EventForest::AddPlane(
    const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
    XPlane* plane) {
  CreateStatMetadata(plane);
  planes_.push_back({plane, visitor_factory(plane)});
}

void EventForest::AddSpace(
    const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
    XSpace* space) {
  for (XPlane& plane : *space->mutable_planes()) {
    AddPlane(visitor_factory, &plane);
  }
}

void EventForest::AddPlanes(
    const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
    const std::vector<XPlane*>& planes) {
  for (XPlane* plane : planes) {
    AddPlane(visitor_factory, plane);
  }
}

void EventForest::ConnectEvents(
    const std::vector<InterThreadConnectInfo>& connect_info_list) {
  ContextGroupMap context_groups;
  for (auto& plane_visitor : planes_) {
    ConnectIntraThread(plane_visitor.first, &plane_visitor.second,
                       &context_groups);
  }
  ConnectInterThread(connect_info_list);
  ConnectContextGroups(context_groups);
}

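// Connects tf.data producer and consumer iterator events: producer iterators
// are indexed by (iterator_id, element_id), and each consumer iterator with a
// matching key adopts them as children.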
void EventForest::ConnectTfDataEvents() {
  absl::flat_hash_map<std::pair<int64 /*iterator_id*/, int64 /*element_id*/>,
                      std::vector<EventNode*>>
      produce_iterator_map;
  uint64 num_producers = 0;
  for (HostEventType event_type :
       {HostEventType::kPrefetchProduce,
        HostEventType::kParallelInterleaveProduce,
        HostEventType::kParallelMapProduce, HostEventType::kMapAndBatchProduce,
        HostEventType::kParseExampleProduce,
        HostEventType::kParallelBatchProduce}) {
    auto produce_event_list = gtl::FindOrNull(event_node_map_, event_type);
    if (!produce_event_list) continue;
    VLOG(1) << produce_event_list->size() << " "
            << GetHostEventTypeStr(event_type) << " events found.";
    for (auto& produce_event : *produce_event_list) {
      absl::optional<XStatVisitor> element_id =
          produce_event->GetEventVisitor().GetStat(StatType::kElementId);
      if (!element_id.has_value()) continue;
      for (EventNode* produce_iterator : produce_event->GetChildren()) {
        if (IsIteratorEventType(produce_iterator->GetEventVisitor().Type())) {
          absl::optional<XStatVisitor> iterator_id =
              produce_iterator->GetEventVisitor().GetStat(StatType::kParentId);
          if (!iterator_id.has_value()) break;
          produce_iterator_map[{iterator_id->IntValue(),
                                element_id->IntValue()}]
              .push_back(produce_iterator);
          ++num_producers;
          break;
        }
      }
    }
  }
  VLOG(1) << num_producers << " producer iterators found.";
  uint64 num_matched = 0;
  for (HostEventType event_type :
       {HostEventType::kPrefetchConsume,
        HostEventType::kParallelInterleaveConsume,
        HostEventType::kParallelMapConsume, HostEventType::kMapAndBatchConsume,
        HostEventType::kParseExampleConsume,
        HostEventType::kParallelBatchConsume}) {
    auto consume_event_list = gtl::FindOrNull(event_node_map_, event_type);
    if (!consume_event_list) continue;
    VLOG(1) << consume_event_list->size() << " "
            << GetHostEventTypeStr(event_type) << " events found.";
    for (auto& consume_event : *consume_event_list) {
      absl::optional<XStatVisitor> element_id =
          consume_event->GetEventVisitor().GetStat(StatType::kElementId);
      if (!element_id.has_value()) continue;
      if (consume_event->GetParents().empty()) continue;
      // A consume event is nested by its consume iterator and has no other
      // parents.
      EventNode* consume_iterator = consume_event->GetParents().at(0);
      if (!consume_iterator ||
          !IsIteratorEventType(consume_iterator->GetEventVisitor().Type())) {
        continue;
      }
      absl::optional<XStatVisitor> iterator_id =
          consume_iterator->GetEventVisitor().GetStat(StatType::kStepId);
      if (!iterator_id.has_value()) continue;
      if (auto produce_iterators = gtl::FindOrNull(
              produce_iterator_map, std::make_pair(iterator_id->IntValue(),
                                                   element_id->IntValue()))) {
        for (EventNode* produce_iterator : *produce_iterators) {
          consume_iterator->AddChild(produce_iterator);
          ++num_matched;
        }
      }
    }
  }
  VLOG(1) << num_matched << " consumer iterators matched.";
}

void EventForest::GroupEvents() {
  ProcessTfDataSteps();
  ProcessTensorFlowLoop();
  ProcessWorker();
  CreateEventGroups();
  MarkEagerlyExecutedGpuKernels();
  MarkEagerlyExecutedCpuTfOps();
  ProcessModelIds();
}

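// Returns the default cross-thread connections: TF executor events to the
// iterator ops they invoke (keyed by step_id and iter_num), and kernel
// launches to kernel executions (keyed by correlation_id).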
std::vector<InterThreadConnectInfo> CreateInterThreadConnectInfoList() {
  std::vector<InterThreadConnectInfo> connect_info_list = {
      {HostEventType::kExecutorStateProcess,
       HostEventType::kIteratorGetNextOp,
       {StatType::kStepId, StatType::kIterNum}},
      {HostEventType::kExecutorStateProcess,
       HostEventType::kIteratorGetNextAsOptionalOp,
       {StatType::kStepId, StatType::kIterNum}},
      {HostEventType::kKernelLaunch,
       HostEventType::kKernelExecute,
       {StatType::kCorrelationId}}};
  return connect_info_list;
}

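// Groups the TF events in the space into steps using the given forest.
// Profiles containing TF loop ops are not grouped yet.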
void GroupTfEvents(XSpace* space, EventForest* event_forest) {
  if (CheckLoopOp(*space)) {
    // TODO(b/154510598): Support TF's loop ops.
    return;
  }
  std::vector<InterThreadConnectInfo> connect_info_list =
      CreateInterThreadConnectInfoList();
  event_forest->AddSpace(CreateTfXPlaneVisitor, space);
  event_forest->ConnectEvents(connect_info_list);
  event_forest->GroupEvents();
}

void GroupTfEvents(XSpace* space) {
  EventForest event_forest;
  GroupTfEvents(space, &event_forest);
}

}  // namespace profiler
}  // namespace tensorflow