/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/profiler/utils/group_events.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"

namespace tensorflow {
namespace profiler {
namespace {

// Creates stat metadata for the stats which may be added by grouping.
void CreateStatMetadata(XPlane* plane) {
  XPlaneBuilder builder(plane);
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId));
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kStepName));
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kIsEager));
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kSelectedGroupIds));
}

// Returns the event type if it is a KernelLaunch or KernelExecute event.
absl::optional<int64> GetKernelEventType(bool is_host_plane,
                                         const EventNode& event) {
  if (event.GetEventVisitor().GetStat(StatType::kCorrelationId).has_value()) {
    return is_host_plane ? HostEventType::kKernelLaunch
                         : HostEventType::kKernelExecute;
  }
  return absl::nullopt;
}

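// Returns the type of the given event, resolving KernelLaunch/KernelExecute
// events and legacy event names that XPlaneVisitor does not recognize.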
int64 GetEventType(bool is_host_plane, const EventNode& event) {
  if (absl::optional<int64> event_type = event.GetEventVisitor().Type()) {
    return *event_type;
  } else if (absl::optional<int64> kernel_event_type =
                 GetKernelEventType(is_host_plane, event)) {
    // KernelLaunch and KernelExecute event types are not supported by
    // XPlaneVisitor and should be checked separately.
    // TODO(b/148346217): Make XPlaneVisitor support KernelLaunch and
    // KernelExecute event types.
    return *kernel_event_type;
  } else {
    absl::string_view name = event.GetEventVisitor().Name();
    // Legacy event names have their arguments appended to the name.
    if (absl::StartsWith(name, "BatchingSessionRun")) {
      return HostEventType::kBatchingSessionRun;
    } else if (absl::StartsWith(name, "ProcessBatch")) {
      return HostEventType::kProcessBatch;
    }
    return HostEventType::kUnknownHostEventType;
  }
}

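// Adds the given event to the producer and/or consumer lists of its context
// groups, keyed by context type and id.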
void SetContextGroup(EventNode* event, ContextGroupMap* context_groups) {
  auto producer = event->GetProducerContext();
  if (producer.has_value()) {
    ((*context_groups)[producer->type][producer->id])
        .producers.push_back(event);
  }
  auto consumer = event->GetConsumerContext();
  if (consumer.has_value()) {
    ((*context_groups)[consumer->type][consumer->id])
        .consumers.push_back(event);
  }
}

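// Connects each producer event to every consumer event within the same
// context group.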
void ConnectContextGroups(const ContextGroupMap& context_groups) {
  for (auto& type_id_group : context_groups) {
    for (auto& id_group : type_id_group.second) {
      const ContextGroup& group = id_group.second;
      for (EventNode* parent : group.producers) {
        for (EventNode* child : group.consumers) {
          parent->AddChild(child);
        }
      }
    }
  }
}

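// Creates a virtual XEvent that carries only the given step_id and iter_num
// stats.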
std::unique_ptr<XEvent> CreateVirtualEvent(const XStat& step_id_stat,
                                           const XStat& iter_num_stat) {
  auto virtual_event = absl::make_unique<XEvent>();
  *virtual_event->add_stats() = step_id_stat;
  *virtual_event->add_stats() = iter_num_stat;
  return virtual_event;
}

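// Returns true if the given event has a child of type FunctionRun.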
bool HasFunctionRun(EventNode* event_node) {
  for (EventNode* child : event_node->GetChildren()) {
    if (child->GetEventVisitor().Type() == HostEventType::kFunctionRun) {
      return true;
    }
  }
  return false;
}

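// Returns true if the event is an implicit root event, i.e., a
// framework-generated event type rather than a user-defined root.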
bool IsImplicitRootEvent(const XEventVisitor& event) {
  static const auto* const kImplicitRootEvents = new absl::flat_hash_set<int64>{
      HostEventType::kFunctionRun, HostEventType::kSessionRun,
      HostEventType::kRunGraph, HostEventType::kExecutorStateProcess};
  return event.Type().has_value() &&
         kImplicitRootEvents->contains(*event.Type());
}

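// Propagates group_id from the root event to its subtree and records the
// group's metadata. Adds a step_name stat for user-defined root events when
// set_step_name is true.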
void ProcessRootEvent(int64 group_id, bool set_step_name, EventNode* root_event,
                      GroupMetadataMap* group_metadata_map) {
  root_event->PropagateGroupId(group_id, group_metadata_map);
  if (!set_step_name) {
    // Step names are not necessary for inference profiles, but an entry for
    // group_id is still added to group_metadata_map so that the number of
    // groups can be counted.
    group_metadata_map->emplace(group_id, GroupMetadata());
    return;
  }
  std::string group_name = root_event->GetGroupName();
  // TODO(jihochoi): change event name instead.
  if (!IsImplicitRootEvent(root_event->GetEventVisitor())) {
    // Add the `step_name` stat for the user-defined root events only. When an
    // XEvent is converted to a trace event, the trace event name is set to the
    // `step_name` stat's value if present.
    root_event->AddStepName(group_name);
  }
  (*group_metadata_map)[group_id].name = std::move(group_name);
}

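// Returns true if the event is nested within a tf.data captured function
// event.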
bool IsTfDataEvent(const EventNode& event_node) {
  return event_node.FindParent(HostEventType::kTfDataCapturedFunctionRun) ||
         event_node.FindParent(
             HostEventType::kTfDataCapturedFunctionRunAsync) ||
         event_node.FindParent(
             HostEventType::kTfDataCapturedFunctionRunInstantiated) ||
         event_node.FindParent(
             HostEventType::kTfDataCapturedFunctionRunWithBorrowedArgs);
}

struct ContextTypeAndId {
  int type;
  uint64 id;
};

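// Derives the producer context (type and id) from legacy events that do not
// carry producer stats.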
absl::optional<ContextTypeAndId> GetLegacyProducerContext(
    const XEventVisitor& event) {
  absl::optional<ContextTypeAndId> type_and_id;
  absl::optional<int64> event_type = event.Type();
  if (event_type.has_value()) {
    switch (*event_type) {
      case HostEventType::kTraceContext:
      case HostEventType::kFunctionRun:
      case HostEventType::kSessionRun:
      case HostEventType::kRunGraph: {
        absl::optional<XStatVisitor> stat = event.GetStat(StatType::kStepId);
        if (stat.has_value()) {
          type_and_id = {static_cast<int>(ContextType::kTfExecutor),
                         static_cast<uint64>(stat->IntValue())};
        }
        break;
      }
      case HostEventType::kCallOp:
      case HostEventType::kNumericalGradientOpEvalRight:
      case HostEventType::kNumericalGradientOpEvalLeft:
      case HostEventType::kSymbolicGradientOp:
      case HostEventType::kRemoteCallOp:
      case HostEventType::kIfOp:
      case HostEventType::kCaseOp:
      case HostEventType::kPartitionedCallOp: {
        // TODO(b/154510598): Fix handling of the loop ops.
        // case HostEventType::kWhileOpEvalCond:
        // case HostEventType::kWhileOpStartBody:
        // case HostEventType::kForOp:
        // case HostEventType::kParallelForOp:
        // case HostEventType::kForeverOp:
        absl::optional<XStatVisitor> stat =
            event.GetStat(StatType::kFunctionStepId);
        if (stat.has_value()) {
          type_and_id = {static_cast<int>(ContextType::kTfExecutor),
                         static_cast<uint64>(stat->IntValue())};
        }
        break;
      }
      default:
        break;
    }
  }
  return type_and_id;
}

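// Derives the consumer context (type and id) from legacy events that do not
// carry consumer stats.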
absl::optional<ContextTypeAndId> GetLegacyConsumerContext(
    const XEventVisitor& event) {
  absl::optional<ContextTypeAndId> type_and_id;
  absl::optional<int64> event_type = event.Type();
  if (event_type.has_value()) {
    switch (*event_type) {
      case HostEventType::kExecutorStateProcess:
      case HostEventType::kExecutorDoneCallback:
      case HostEventType::kRunGraphDone: {
        absl::optional<XStatVisitor> stat = event.GetStat(StatType::kStepId);
        if (stat.has_value()) {
          type_and_id = {static_cast<int>(ContextType::kTfExecutor),
                         static_cast<uint64>(stat->IntValue())};
        }
        break;
      }
      default:
        break;
    }
  }
  return type_and_id;
}

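// Returns true if the event is one of the legacy root event types.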
bool IsLegacyRootEvent(const XEventVisitor& event) {
  static const auto* const kRootEvents = new absl::flat_hash_set<int64>{
      HostEventType::kTraceContext, HostEventType::kFunctionRun,
      HostEventType::kSessionRun, HostEventType::kRunGraph};
  return event.Type().has_value() && kRootEvents->contains(*event.Type());
}

using Comparator = std::function<bool(const EventNode*)>;

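// Performs a breadth-first search up the parent links and returns the first
// node (optionally including `node` itself) that satisfies `comparator`, or
// nullptr if none does.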
const EventNode* FindParentWithComparator(const Comparator& comparator,
                                          const EventNode* node,
                                          bool include_self) {
  std::queue<const EventNode*> nodes;
  absl::flat_hash_set<const EventNode*> seen = {node};
  if (include_self) {
    nodes.push(node);
  } else {
    for (const EventNode* parent : node->GetParents()) {
      nodes.push(parent);
      seen.insert(parent);
    }
  }
  while (!nodes.empty()) {
    const EventNode* node = nodes.front();
    nodes.pop();
    if (comparator(node)) return node;
    for (const EventNode* parent : node->GetParents()) {
      if (seen.contains(parent)) continue;
      nodes.push(parent);
      seen.insert(parent);
    }
  }
  return nullptr;
}

// Returns true if none of its ancestors is a root event.
bool IsTopRoot(const EventNode* event) {
  // If it is already grouped, it is not a top root.
  if (event->GetGroupId().has_value()) return false;
  const EventNode* root_parent = FindParentWithComparator(
      [](const EventNode* node) { return node->IsRoot(); }, event,
      /*include_self=*/false);
  return root_parent == nullptr;
}

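// Sorts the events in `event_list` by timestamp in ascending order.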
void SortEventList(EventList* event_list) {
  absl::c_sort(*event_list, [](const EventNode* e1, const EventNode* e2) {
    return e1->GetEventVisitor().TimestampPs() <
           e2->GetEventVisitor().TimestampPs();
  });
}

// Returns true if it has JAX-related events.
bool HasJaxEvent(const EventNodeMap& event_node_map) {
  return event_node_map.contains(HostEventType::kExecuteOnLocalDevices);
}

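// Returns true if the event type corresponds to a tf.data iterator.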
bool IsIteratorEventType(absl::optional<int64> event_type) {
  return event_type == HostEventType::kIterator ||
         event_type == HostEventType::kDeviceInputPipelineSecondIterator;
}

}  // namespace

// Returns true if TF's loop ops exist in the given XSpace's metadata.
bool CheckLoopOp(const XSpace& space) {
  for (const XPlane& plane : space.planes()) {
    for (const auto& event_metadata : plane.event_metadata()) {
      absl::optional<int64> event_type =
          FindHostEventType(event_metadata.second.name());
      if (!event_type.has_value()) continue;
      switch (*event_type) {
        case HostEventType::kWhileOpEvalCond:
        case HostEventType::kWhileOpStartBody:
        case HostEventType::kForOp:
        case HostEventType::kParallelForOp:
        case HostEventType::kForeverOp:
          return true;
        default:
          break;
      }
    }
  }
  return false;
}

EventNode::EventNode(const XPlaneVisitor* plane, XLine* raw_line,
                     XEvent* raw_event)
    : plane_(plane),
      visitor_(plane, raw_line, raw_event),
      raw_line_(raw_line),
      raw_event_(raw_event) {
  absl::optional<int> producer_type;
  absl::optional<uint64> producer_id;
  absl::optional<int> consumer_type;
  absl::optional<uint64> consumer_id;

  visitor_.ForEachStat([&](const XStatVisitor& stat) {
    if (!stat.Type().has_value()) return;
    switch (*stat.Type()) {
      case StatType::kProducerType:
        producer_type = stat.IntValue();
        break;
      case StatType::kProducerId:
        producer_id = stat.IntOrUintValue();
        break;
      case StatType::kConsumerType:
        consumer_type = stat.IntValue();
        break;
      case StatType::kConsumerId:
        consumer_id = stat.IntOrUintValue();
        break;
      case StatType::kIsRoot:
        is_root_ = stat.IntValue();
        break;
      case StatType::kIsAsync:
        is_async_ = stat.IntValue();
        break;
      default:
        break;
    }
  });

  // Support legacy traces.
  if (!producer_type.has_value() || !producer_id.has_value()) {
    if (auto producer_context = GetLegacyProducerContext(visitor_)) {
      producer_type = producer_context->type;
      producer_id = producer_context->id;
    }
  }
  if (!consumer_type.has_value() || !consumer_id.has_value()) {
    if (auto consumer_context = GetLegacyConsumerContext(visitor_)) {
      consumer_type = consumer_context->type;
      consumer_id = consumer_context->id;
    }
  }
  is_root_ = is_root_ || IsLegacyRootEvent(visitor_);

  if (producer_type.has_value() && producer_id.has_value()) {
    producer_context_ = {*producer_type, *producer_id};
  }
  if (consumer_type.has_value() && consumer_id.has_value()) {
    consumer_context_ = {*consumer_type, *consumer_id};
  }
}

EventNode::EventNode(const EventNode& event_node)
    : EventNode(event_node.plane_, event_node.raw_line_,
                event_node.raw_event_) {}

absl::optional<XStatVisitor> EventNode::GetContextStat(int64 stat_type) const {
  std::queue<const EventNode*> nodes;
  absl::flat_hash_set<const EventNode*> seen = {this};
  nodes.push(this);
  while (!nodes.empty()) {
    const EventNode* node = nodes.front();
    nodes.pop();
    if (absl::optional<XStatVisitor> stat = node->visitor_.GetStat(stat_type)) {
      return stat;
    }
    for (const EventNode* parent : node->GetParents()) {
      if (seen.contains(parent)) continue;
      nodes.push(parent);
      seen.insert(parent);
    }
  }
  return absl::nullopt;
}

std::string EventNode::GetGroupName() const {
  std::string name;
  if (absl::optional<XStatVisitor> stat =
          GetContextStat(StatType::kGraphType)) {
    absl::StrAppend(&name, stat->StrOrRefValue(), " ");
  } else if (!(IsImplicitRootEvent(visitor_))) {
    absl::StrAppend(&name, GetEventVisitor().Name(), " ");
  }
  int64 step_num = group_id_.value_or(0);
  if (absl::optional<XStatVisitor> stat = GetContextStat(StatType::kIterNum)) {
    step_num = stat->IntValue();
  } else if (absl::optional<XStatVisitor> stat =
                 GetContextStat(StatType::kStepNum)) {
    step_num = stat->IntValue();
  }
  absl::StrAppend(&name, step_num);
  return name;
}

XStat* EventNode::FindOrAddStatByType(int64 stat_type) {
  const XStatMetadata* stat_metadata = plane_->GetStatMetadataByType(stat_type);
  DCHECK(stat_metadata != nullptr);
  return FindOrAddMutableStat(*stat_metadata, raw_event_);
}

void EventNode::SetGroupId(int64 group_id) {
  group_id_ = group_id;
  FindOrAddStatByType(StatType::kGroupId)->set_int64_value(group_id);
}

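// Assigns the given group_id to this event and to all of its descendants that
// are not grouped yet, and records parent/child relationships between
// different groups.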
void EventNode::PropagateGroupId(int64 group_id,
                                 GroupMetadataMap* group_metadata_map) {
  std::queue<EventNode*> nodes;
  absl::flat_hash_set<EventNode*> seen = {this};
  nodes.push(this);
  while (!nodes.empty()) {
    EventNode* node = nodes.front();
    nodes.pop();
    absl::optional<int64> node_group_id = node->GetGroupId();
    if (node_group_id.has_value()) {
      if (*node_group_id != group_id) {
        (*group_metadata_map)[group_id].children.insert(*node_group_id);
        (*group_metadata_map)[*node_group_id].parents.insert(group_id);
      }
    } else {
      node->SetGroupId(group_id);
      for (EventNode* child : node->GetChildren()) {
        if (seen.contains(child)) continue;
        nodes.push(child);
        seen.insert(child);
      }
    }
  }
}

void EventNode::AddStepName(absl::string_view step_name) {
  FindOrAddStatByType(StatType::kStepName)
      ->set_str_value(step_name.data(), step_name.size());
}

void EventNode::AddSelectedGroupIds(
    const GroupMetadataMap& group_metadata_map) {
  const auto& group_metadata = group_metadata_map.at(*group_id_);
  std::vector<int64> group_ids;
  group_ids.reserve(1 + group_metadata.parents.size() +
                    group_metadata.children.size());
  group_ids.push_back(*group_id_);
  group_ids.insert(group_ids.end(), group_metadata.parents.begin(),
                   group_metadata.parents.end());
  group_ids.insert(group_ids.end(), group_metadata.children.begin(),
                   group_metadata.children.end());
  FindOrAddStatByType(StatType::kSelectedGroupIds)
      ->set_str_value(
          absl::StrCat("?selected_group_ids=", absl::StrJoin(group_ids, ",")));
}

void EventNode::SetIsEager(bool is_eager) {
  FindOrAddStatByType(StatType::kIsEager)->set_int64_value(is_eager ? 1 : 0);
}

bool EventNode::IsEager() {
  // It is eagerly executed if its trace context includes the
  // EagerKernelExecute event (which may execute an op eagerly or through the
  // TF executor) but not the TF executor event.
  return FindParent(HostEventType::kExecutorStateProcess) == nullptr &&
         FindParent(HostEventType::kEagerKernelExecute) != nullptr;
}

const EventNode* EventNode::FindParent(int64 event_type) const {
  return FindParentWithComparator(
      [event_type](const EventNode* node) {
        return node->GetEventVisitor().Type() == event_type;
      },
      this, /*include_self=*/true);
}

bool EventNode::StartsBefore(const EventNode& other) const {
  return GetEventVisitor().TimestampPs() <=
         other.GetEventVisitor().TimestampPs();
}

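// Creates an EventNode for each event in the plane and nests the events on the
// same line by timespan inclusion. Also collects context groups and root
// events for later passes.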
void EventForest::ConnectIntraThread(XPlane* plane, XPlaneVisitor* visitor,
                                     ContextGroupMap* context_groups) {
  // TODO(b/149095099): avoid string comparison.
  bool is_host_plane = (visitor->Name() == kHostThreadsPlaneName);
  for (auto& line : *plane->mutable_lines()) {
    std::vector<EventNode*> parent_nodes;
    for (auto& event : *line.mutable_events()) {
      auto cur_node = absl::make_unique<EventNode>(visitor, &line, &event);
      // Update `context_groups` for `ConnectInterThread`.
      SetContextGroup(cur_node.get(), context_groups);
      // Update `root_events_` for `CreateEventGroups`.
      if (cur_node->IsRoot()) root_events_.push_back(cur_node.get());
      // Async events are ignored when processing the nesting relationship.
      if (cur_node->IsAsync()) continue;
      while (!parent_nodes.empty()) {
        EventNode* parent_node = parent_nodes.back();
        if (parent_node->GetEventVisitor().GetTimespan().Includes(
                cur_node->GetEventVisitor().GetTimespan())) {
          parent_node->AddChild(cur_node.get());
          break;
        } else {
          parent_nodes.pop_back();
        }
      }
      parent_nodes.push_back(cur_node.get());
      // event_node_map_ keeps cur_node alive.
      event_node_map_[GetEventType(is_host_plane, *cur_node)].push_back(
          std::move(cur_node));
    }
  }
}

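// Connects parent and child events across threads by matching the context
// stats listed in each InterThreadConnectInfo.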
void EventForest::ConnectInterThread(
    const std::vector<InterThreadConnectInfo>& connect_info_list) {
  for (const auto& connect_info : connect_info_list) {
    absl::flat_hash_map<std::vector<uint64>, EventNode*> connect_map;
    const std::vector<int64>& parent_stat_types =
        connect_info.parent_stat_types;
    const std::vector<int64>* child_stat_types = &connect_info.child_stat_types;
    if (child_stat_types->empty()) {
      child_stat_types = &parent_stat_types;
    }
    if (auto parent_event_node_list =
            gtl::FindOrNull(event_node_map_, connect_info.parent_event_type)) {
      for (const auto& parent_event_node : *parent_event_node_list) {
        std::vector<uint64> stats;
        for (auto stat_type : parent_stat_types) {
          absl::optional<XStatVisitor> stat =
              parent_event_node->GetContextStat(stat_type);
          if (!stat) break;
          stats.push_back(stat->IntOrUintValue());
        }
        if (stats.size() == parent_stat_types.size()) {
          connect_map[stats] = parent_event_node.get();
        }
      }
    }
    if (auto child_event_node_list =
            gtl::FindOrNull(event_node_map_, connect_info.child_event_type)) {
      for (const auto& child_event_node : *child_event_node_list) {
        std::vector<uint64> stats;
        for (auto stat_type : *child_stat_types) {
          absl::optional<XStatVisitor> stat =
              child_event_node->GetContextStat(stat_type);
          if (!stat) break;
          stats.push_back(stat->IntOrUintValue());
        }
        if (stats.size() == child_stat_types->size()) {
          if (auto parent_event_node = gtl::FindPtrOrNull(connect_map, stats)) {
            parent_event_node->AddChild(child_event_node.get());
          }
        }
      }
    }
  }
}

void EventForest::ProcessUserDefinedRootEvents(
    const std::vector<int64 /*EventType*/>& user_defined_root_event_types) {
  for (int64 user_defined_root_event_type : user_defined_root_event_types) {
    if (auto root_events =
            gtl::FindOrNull(event_node_map_, user_defined_root_event_type)) {
      for (const auto& root_event : *root_events) {
        root_event->SetIsRoot(true);
        root_events_.push_back(root_event.get());
      }
    }
  }
}

void EventForest::CreateEventGroups() {
  // Handle inference batching profiles.
  if (event_node_map_.contains(HostEventType::kProcessBatch)) {
    // Assign group_id per batch.
    for (const auto& process_batch_node :
         event_node_map_[HostEventType::kProcessBatch]) {
      ProcessRootEvent(next_group_id_++, /*set_step_name=*/false,
                       process_batch_node.get(), &group_metadata_map_);
    }
    HostEventType request_event_type =
        event_node_map_.contains(HostEventType::kBatchingSessionRun)
            ? HostEventType::kBatchingSessionRun
            : HostEventType::kSessionRun;
    if (auto request_events =
            gtl::FindOrNull(event_node_map_, request_event_type)) {
      // Assign group_id per request.
      for (const auto& request_event : *request_events) {
        ProcessRootEvent(next_group_id_++, /*set_step_name=*/false,
                         request_event.get(), &group_metadata_map_);
        // Also, set a helper stat for selected_group_ids.
        request_event->AddSelectedGroupIds(group_metadata_map_);
      }
    }
    // Set a helper stat for selected_group_ids per batch.
    for (const auto& process_batch_node :
         event_node_map_[HostEventType::kProcessBatch]) {
      process_batch_node->AddSelectedGroupIds(group_metadata_map_);
    }
    return;
  }
  // Create a group for each TF loop iteration in non-JAX profiles.
  if (!HasJaxEvent(event_node_map_) && !tf_loop_root_events_.empty()) {
    for (EventNode* root_event : tf_loop_root_events_) {
      ProcessRootEvent(next_group_id_++, /*set_step_name=*/true, root_event,
                       &group_metadata_map_);
    }
    return;
  }
  SortEventList(&root_events_);
  // Create a group for each top root event while ignoring TF's legacy root
  // events for JAX profiles.
  for (EventNode* root_event : root_events_) {
    if (IsTopRoot(root_event) &&
        (!HasJaxEvent(event_node_map_) ||
         !IsLegacyRootEvent(root_event->GetEventVisitor()))) {
      ProcessRootEvent(next_group_id_++, /*set_step_name=*/true, root_event,
                       &group_metadata_map_);
    }
  }
}

void EventForest::MarkEagerlyExecutedGpuKernels() {
  auto kernel_execute_event_node_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kKernelExecute);
  if (!kernel_execute_event_node_list) return;
  for (auto& kernel_execute_event_node : *kernel_execute_event_node_list) {
    kernel_execute_event_node->SetIsEager(kernel_execute_event_node->IsEager());
  }
}

void EventForest::MarkEagerlyExecutedCpuTfOps() {
  auto tf_op_run_event_node_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kTfOpRun);
  if (!tf_op_run_event_node_list) return;
  for (auto& tf_op_run_event_node : *tf_op_run_event_node_list) {
    tf_op_run_event_node->SetIsEager(tf_op_run_event_node->IsEager());
  }
}

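// Groups the TF executor events into loop iterations keyed by step_id and
// iter_num, and registers the first event of each iteration as a root event.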
void EventForest::ProcessTensorFlowLoop() {
  struct TensorFlowLoopIteration {
    EventNode* first_event = nullptr;
    std::vector<EventNode*> events;
  };
  using TensorFlowLoop = std::map<int64 /*iter_num*/, TensorFlowLoopIteration>;
  absl::flat_hash_map<int64 /*step_id*/, TensorFlowLoop> tf_loops;

  // Sort the TF executor events by TF function/session (step_id) and iter_num.
  auto executor_event_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kExecutorStateProcess);
  if (!executor_event_list) return;
  for (auto& executor_event : *executor_event_list) {
    if (IsTfDataEvent(*executor_event)) continue;
    absl::optional<XStatVisitor> step_id_stat =
        executor_event->GetContextStat(StatType::kStepId);
    absl::optional<XStatVisitor> iter_num_stat =
        executor_event->GetContextStat(StatType::kIterNum);
    if (!step_id_stat || !iter_num_stat) continue;
    int64 step_id = step_id_stat->IntValue();
    TensorFlowLoop& tf_loop = tf_loops[step_id];
    TensorFlowLoopIteration& iteration = tf_loop[iter_num_stat->IntValue()];
    if (!iteration.first_event ||
        executor_event->StartsBefore(*iteration.first_event)) {
      iteration.first_event = executor_event.get();
    }
    iteration.events.push_back(executor_event.get());
  }

  // Sort the TF loops by start time.
  std::map<int64 /*start_time*/, int64 /*step_id*/> sorted_tf_loops;
  for (const auto& step_id_and_tf_loop : tf_loops) {
    auto& iterations = step_id_and_tf_loop.second;
    // Filter out TF functions/sessions without loops.
    if (iterations.size() == 1 && iterations.count(0)) continue;
    int64 start_time = iterations.cbegin()
                           ->second.first_event->GetEventVisitor()
                           .TimestampPs();
    DCHECK_EQ(sorted_tf_loops.count(start_time), 0);
    sorted_tf_loops[start_time] = step_id_and_tf_loop.first;
  }

  // Register the first event of each iteration as a root event. Also, add the
  // other events of the iteration as children of the root event.
  bool next_group_id_updated = false;
  for (const auto& start_time_and_step_id : sorted_tf_loops) {
    TensorFlowLoop& tf_loop = tf_loops[start_time_and_step_id.second];
    for (auto& iter_num_and_iteration : tf_loop) {
      if (!next_group_id_updated) {
        // Set next_group_id_ to the first iter_num of the first TF loop. This
        // is necessary later when selecting the intersection of the steps from
        // multiple hosts.
        next_group_id_ = iter_num_and_iteration.first;
        next_group_id_updated = true;
      }
      TensorFlowLoopIteration& iteration = iter_num_and_iteration.second;
      EventNode* root_event = iteration.first_event;
      tf_loop_root_events_.push_back(root_event);
      for (EventNode* event : iteration.events) {
        if (event == root_event) continue;
        root_event->AddChild(event);
      }
    }
  }
}

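// For eager (worker) profiles, turns each EagerKernelExecute event with a
// FunctionRun child into a root and attaches the subsequent non-function
// eager ops to it.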
void EventForest::ProcessWorker() {
  auto eager_kernel_execute_event_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kEagerKernelExecute);
  if (!eager_kernel_execute_event_list) return;
  // The last EagerKernelExecute with a FunctionRun child.
  EventNode* root_event = nullptr;
  for (auto& eager_kernel_execute_event : *eager_kernel_execute_event_list) {
    if (HasFunctionRun(eager_kernel_execute_event.get())) {
      // A function op becomes a new root.
      root_event = eager_kernel_execute_event.get();
      root_event->SetIsRoot(true);
      root_events_.push_back(root_event);
    } else if (root_event) {
      // Add non-function eager ops as children of the root.
      root_event->AddChild(eager_kernel_execute_event.get());
    }
  }
}

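// Copies the model_id stat of each grouped SessionRun event into the metadata
// of its group.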
void EventForest::ProcessModelIds() {
  auto session_run_event_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kSessionRun);
  if (!session_run_event_list) return;
  for (const auto& session_run_event : *session_run_event_list) {
    auto group_id = session_run_event->GetGroupId();
    if (!group_id.has_value()) continue;
    absl::optional<XStatVisitor> model_id =
        session_run_event->GetEventVisitor().GetStat(StatType::kModelId);
    if (!model_id.has_value()) continue;
    group_metadata_map_[*group_id].model_id = model_id->ToString();
  }
}

void EventForest::AddPlane(
    const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
    XPlane* plane) {
  CreateStatMetadata(plane);
  planes_.push_back({plane, visitor_factory(plane)});
}

void EventForest::AddSpace(
    const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
    XSpace* space) {
  for (XPlane& plane : *space->mutable_planes()) {
    AddPlane(visitor_factory, &plane);
  }
}

void EventForest::AddPlanes(
    const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
    const std::vector<XPlane*>& planes) {
  for (XPlane* plane : planes) {
    AddPlane(visitor_factory, plane);
  }
}

void EventForest::ConnectEvents(
    const std::vector<InterThreadConnectInfo>& connect_info_list) {
  ContextGroupMap context_groups;
  for (auto& plane_visitor : planes_) {
    ConnectIntraThread(plane_visitor.first, &plane_visitor.second,
                       &context_groups);
  }
  ConnectInterThread(connect_info_list);
  ConnectContextGroups(context_groups);
}

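// Connects tf.data producer and consumer iterator events that share the same
// iterator_id and element_id.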
void EventForest::ConnectTfDataEvents() {
  absl::flat_hash_map<std::pair<int64 /*iterator_id*/, int64 /*element_id*/>,
                      std::vector<EventNode*>>
      produce_iterator_map;
  uint64 num_producers = 0;
  for (HostEventType event_type :
       {HostEventType::kPrefetchProduce,
        HostEventType::kParallelInterleaveProduce,
        HostEventType::kParallelMapProduce, HostEventType::kMapAndBatchProduce,
        HostEventType::kParseExampleProduce}) {
    auto produce_event_list = gtl::FindOrNull(event_node_map_, event_type);
    if (!produce_event_list) continue;
    VLOG(1) << produce_event_list->size() << " "
            << GetHostEventTypeStr(event_type) << " events found.";
    for (auto& produce_event : *produce_event_list) {
      absl::optional<XStatVisitor> element_id =
          produce_event->GetEventVisitor().GetStat(StatType::kElementId);
      if (!element_id.has_value()) continue;
      for (EventNode* produce_iterator : produce_event->GetChildren()) {
        if (IsIteratorEventType(produce_iterator->GetEventVisitor().Type())) {
          absl::optional<XStatVisitor> iterator_id =
              produce_iterator->GetEventVisitor().GetStat(StatType::kParentId);
          if (!iterator_id.has_value()) break;
          produce_iterator_map[{iterator_id->IntValue(),
                                element_id->IntValue()}]
              .push_back(produce_iterator);
          ++num_producers;
          break;
        }
      }
    }
  }
  VLOG(1) << num_producers << " producer iterators found.";
  uint64 num_matched = 0;
  for (HostEventType event_type :
       {HostEventType::kPrefetchConsume,
        HostEventType::kParallelInterleaveConsume,
        HostEventType::kParallelMapConsume, HostEventType::kMapAndBatchConsume,
        HostEventType::kParseExampleConsume}) {
    auto consume_event_list = gtl::FindOrNull(event_node_map_, event_type);
    if (!consume_event_list) continue;
    VLOG(1) << consume_event_list->size() << " "
            << GetHostEventTypeStr(event_type) << " events found.";
    for (auto& consume_event : *consume_event_list) {
      absl::optional<XStatVisitor> element_id =
          consume_event->GetEventVisitor().GetStat(StatType::kElementId);
      if (!element_id.has_value()) continue;
      if (consume_event->GetParents().empty()) continue;
      // consume_event is nested by consume_iterator and does not have other
      // parents.
      EventNode* consume_iterator = consume_event->GetParents().at(0);
      if (!consume_iterator ||
          !IsIteratorEventType(consume_iterator->GetEventVisitor().Type())) {
        continue;
      }
      absl::optional<XStatVisitor> iterator_id =
          consume_iterator->GetEventVisitor().GetStat(StatType::kStepId);
      if (!iterator_id.has_value()) continue;
      if (auto produce_iterators = gtl::FindOrNull(
              produce_iterator_map, std::make_pair(iterator_id->IntValue(),
                                                   element_id->IntValue()))) {
        for (EventNode* produce_iterator : *produce_iterators) {
          consume_iterator->AddChild(produce_iterator);
          ++num_matched;
        }
      }
    }
  }
  VLOG(1) << num_matched << " consumer iterators matched.";
}

void EventForest::GroupEvents(
    const std::vector<int64>& user_defined_root_event_types) {
  ProcessTensorFlowLoop();
  ProcessWorker();
  ProcessUserDefinedRootEvents(user_defined_root_event_types);
  CreateEventGroups();
  MarkEagerlyExecutedGpuKernels();
  MarkEagerlyExecutedCpuTfOps();
  ProcessModelIds();
}

std::vector<InterThreadConnectInfo> CreateInterThreadConnectInfoList() {
  std::vector<InterThreadConnectInfo> connect_info_list = {
      {HostEventType::kExecutorStateProcess,
       HostEventType::kIteratorGetNextOp,
       {StatType::kStepId, StatType::kIterNum}},
      {HostEventType::kExecutorStateProcess,
       HostEventType::kIteratorGetNextAsOptionalOp,
       {StatType::kStepId, StatType::kIterNum}},
      {HostEventType::kKernelLaunch,
       HostEventType::kKernelExecute,
       {StatType::kCorrelationId}}};
  return connect_info_list;
}

void GroupTfEvents(XSpace* space, EventForest* event_forest) {
  if (CheckLoopOp(*space)) {
    // TODO(b/154510598): Support TF's loop ops.
    return;
  }
  std::vector<InterThreadConnectInfo> connect_info_list =
      CreateInterThreadConnectInfoList();
  event_forest->AddSpace(CreateTfXPlaneVisitor, space);
  event_forest->ConnectEvents(connect_info_list);
  event_forest->GroupEvents();
}

void GroupTfEvents(XSpace* space) {
  EventForest event_forest;
  GroupTfEvents(space, &event_forest);
}

}  // namespace profiler
}  // namespace tensorflow