/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h"

#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h"
#include "tensorflow/core/profiler/utils/group_events.h"
#include "tensorflow/core/profiler/utils/html_utils.h"
#include "tensorflow/core/profiler/utils/tf_op_utils.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/timespan.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"

namespace tensorflow {
namespace profiler {

// Threshold above which an input pipeline call is considered slow: 50 us
// (expressed in picoseconds), from
// https://www.tensorflow.org/guide/data_performance_analysis.
const int64_t kSlowCallThresholdPs = 50 * 1000000;

namespace {

// Returns true if the given iterator event is for a root iterator.
bool IsRootIteratorEvent(const XEventVisitor& iterator_event) {
  std::vector<absl::string_view> split_result =
      absl::StrSplit(iterator_event.Name(), "::");
  // Iterator event names are prefixed with "Iterator". A root iterator's name
  // contains no parent iterators, so it splits into exactly two parts (e.g.,
  // "Iterator::Prefetch").
  return split_result.size() == 2;
}

// Returns true if the given iterator event name is for an async iterator.
bool IsAsyncIterator(absl::string_view iterator_event_name) {
  static auto* kAsyncIterators = new absl::flat_hash_set<absl::string_view>(
      {"Prefetch", "ParallelInterleave", "ParallelMap", "ParseExample",
       "MapAndBatch", "DataService", "LegacyParallelInterleave",
       "ParallelBatch"});
  return kAsyncIterators->contains(iterator_event_name);
}

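// Fills `metadata` with the iterator's id, parent id (if present), name,
// long name, and whether it is asynchronous, all derived from `event`.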
void SetIteratorMetadata(int64_t id, const XEventVisitor& event,
                         IteratorMetadata* metadata) {
  metadata->set_id(id);
  auto parent_id_stat = event.GetStat(StatType::kParentId);
  if (parent_id_stat.has_value()) {
    metadata->set_parent_id(parent_id_stat->IntValue());
  }
  metadata->set_name(IteratorName(event.Name()));
  metadata->set_long_name(event.Name().data(), event.Name().size());
  metadata->set_is_async(IsAsyncIterator(metadata->name()));
  // TODO(b/161831651): Set params.
}

// Returns the parent iterator's id if the parent is the root of a device
// input pipeline, i.e., if `event` is a device input pipeline's second
// iterator; otherwise returns absl::nullopt.
absl::optional<int64> FindDeviceInputPipeline(const XEventVisitor& event) {
  if (event.Type() == HostEventType::kDeviceInputPipelineSecondIterator) {
    auto parent_id_stat = event.GetStat(StatType::kParentId);
    if (parent_id_stat.has_value()) return parent_id_stat->IntValue();
  }
  return absl::nullopt;
}

// Processes EventForest to do the following:
// (1) set iterator metadata
// (2) find root iterator events
// (3) find device input pipeline ids
void ProcessEventForest(const EventForest& event_forest,
                        absl::flat_hash_set<int64>* device_input_pipeline_ids,
                        absl::flat_hash_map<int64, std::vector<EventNode*>>*
                            root_iterator_event_map,
                        TfDataStats* tf_data_stats) {
  const EventNodeMap& event_node_map = event_forest.GetEventNodeMap();
  auto iterator_event_list =
      gtl::FindOrNull(event_node_map, HostEventType::kIterator);
  if (!iterator_event_list) return;
  for (const auto& iterator_event : *iterator_event_list) {
    const XEventVisitor& iterator_event_visitor =
        iterator_event->GetEventVisitor();
    auto iterator_id_stat = iterator_event_visitor.GetStat(StatType::kStepId);
    if (!iterator_id_stat.has_value()) continue;
    int64_t iterator_id = iterator_id_stat->IntValue();
    auto result = tf_data_stats->mutable_iterator_metadata()->insert(
        {iterator_id, IteratorMetadata()});
    IteratorMetadata& metadata = result.first->second;
    if (result.second) {
      // First time processing this iterator.
      SetIteratorMetadata(iterator_id, iterator_event_visitor, &metadata);
    }
    if (IsRootIteratorEvent(iterator_event_visitor)) {
      // Record root iterator events.
      (*root_iterator_event_map)[iterator_id].push_back(iterator_event.get());
    }
  }
  auto device_input_pipeline_second_iterator_events = gtl::FindOrNull(
      event_node_map, HostEventType::kDeviceInputPipelineSecondIterator);
  if (!device_input_pipeline_second_iterator_events) return;
  for (const auto& iterator_event :
       *device_input_pipeline_second_iterator_events) {
    const XEventVisitor& iterator_event_visitor =
        iterator_event->GetEventVisitor();
    auto iterator_id_stat = iterator_event_visitor.GetStat(StatType::kStepId);
    if (!iterator_id_stat.has_value()) continue;
    int64_t iterator_id = iterator_id_stat->IntValue();
    auto result = tf_data_stats->mutable_iterator_metadata()->insert(
        {iterator_id, IteratorMetadata()});
    IteratorMetadata& metadata = result.first->second;
    if (result.second) {
      // First time processing this iterator.
      SetIteratorMetadata(iterator_id, iterator_event_visitor, &metadata);
      // Find and record device input pipeline ids.
      absl::optional<int64> device_input_pipeline_id =
          FindDeviceInputPipeline(iterator_event_visitor);
      if (device_input_pipeline_id.has_value()) {
        device_input_pipeline_ids->insert(*device_input_pipeline_id);
      }
    }
  }
}

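// Fills `metadata` with the input pipeline's id, type, and a human-readable
// name of the form "Host:<name_id>" or "Device:<name_id>".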
void SetInputPipelineMetadata(int64_t id, int64_t name_id,
                              bool is_device_input_pipeline,
                              InputPipelineMetadata* metadata) {
  constexpr absl::string_view kHostInputPipelinePrefix = "Host:";
  constexpr absl::string_view kDeviceInputPipelinePrefix = "Device:";
  metadata->set_id(id);
  if (is_device_input_pipeline) {
    metadata->set_type(InputPipelineMetadata::DEVICE);
    metadata->set_name(absl::StrCat(kDeviceInputPipelinePrefix, name_id));
  } else {
    metadata->set_type(InputPipelineMetadata::HOST);
    metadata->set_name(absl::StrCat(kHostInputPipelinePrefix, name_id));
  }
}

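// Recursively accumulates stats for `iterator_event` and its tf.data child
// events into `input_pipeline_stat`: per-iterator duration, self time
// (duration minus time spent in child iterators), number of calls, and
// whether the iterator was blocking the root iterator's call.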
void ProcessIteratorEvent(const EventNode& iterator_event,
                          InputPipelineStat* input_pipeline_stat,
                          bool is_blocking) {
  const XEventVisitor& visitor = iterator_event.GetEventVisitor();
  auto iterator_id_stat = visitor.GetStat(StatType::kStepId);
  if (!iterator_id_stat.has_value()) return;
  int64_t iterator_id = iterator_id_stat->IntValue();
  auto result = input_pipeline_stat->mutable_iterator_stats()->insert(
      {iterator_id, IteratorStat()});
  IteratorStat& iterator_stat = result.first->second;
  if (result.second) {
    iterator_stat.set_id(iterator_id);
    iterator_stat.set_start_time_ps(visitor.TimestampPs());
  }
  iterator_stat.set_duration_ps(iterator_stat.duration_ps() +
                                visitor.DurationPs());
  int64_t self_time_ps = visitor.DurationPs();
  Timespan self_time_span = visitor.GetTimespan();
  for (EventNode* child : iterator_event.GetChildren()) {
    const XEventVisitor& child_visitor = child->GetEventVisitor();
    if (ParseTfOpFullname(child_visitor.Name()).category == Category::kTfData) {
      int64_t overlap_duration_ps =
          self_time_span.OverlappedDurationPs(child_visitor.GetTimespan());
      // A child is only considered blocking if it overlaps this event in time
      // while this event itself was blocking.
      ProcessIteratorEvent(*child, input_pipeline_stat,
                           is_blocking && overlap_duration_ps > 0);
      // Note: Assume no overlap between child events.
      self_time_ps -= overlap_duration_ps;
    }
  }
  iterator_stat.set_self_time_ps(iterator_stat.self_time_ps() + self_time_ps);
  iterator_stat.set_is_blocking(iterator_stat.is_blocking() || is_blocking);
  iterator_stat.set_num_calls(iterator_stat.num_calls() + 1);
}

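// Records the blocking iterator with the largest self time as the bottleneck
// of this input pipeline execution.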
void SetBottleneckIteratorId(InputPipelineStat* input_pipeline_stat) {
  int64_t bottleneck_iterator_id = 0;
  int64_t max_self_time = 0;
  for (const auto& pair : input_pipeline_stat->iterator_stats()) {
    const auto& id = pair.first;
    const auto& iterator_stat = pair.second;
    if (iterator_stat.is_blocking() &&
        iterator_stat.self_time_ps() > max_self_time) {
      bottleneck_iterator_id = id;
      max_self_time = iterator_stat.self_time_ps();
    }
  }
  input_pipeline_stat->set_bottleneck_iterator_id(bottleneck_iterator_id);
  input_pipeline_stat->set_bottleneck_iterator_latency_ps(max_self_time);
}

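// Builds per-input-pipeline statistics from the root iterator events:
// min/avg/max latency, the number of slow (> 50 us) calls, and the bottleneck
// iterator of each recorded call.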
void ProcessInputPipelines(
    const absl::flat_hash_set<int64>& device_input_pipeline_ids,
    absl::flat_hash_map<int64, std::vector<EventNode*>>*
        root_iterator_event_map,
    TfDataStats* tf_data_stats) {
  auto* input_pipelines = tf_data_stats->mutable_input_pipelines();
  int64_t num_host_input_pipelines = 0;
  int64_t num_device_input_pipelines = 0;
  for (auto& id_and_events : *root_iterator_event_map) {
    auto& root_iterator_id = id_and_events.first;
    auto& root_iterator_events = id_and_events.second;
    // Sort the root iterator events by duration in descending order so that
    // stats(0) is the slowest execution of this input pipeline.
    absl::c_sort(root_iterator_events,
                 [](const EventNode* lhs, const EventNode* rhs) {
                   return lhs->GetEventVisitor().DurationPs() >
                          rhs->GetEventVisitor().DurationPs();
                 });
    auto result =
        input_pipelines->insert({root_iterator_id, InputPipelineStats()});
    InputPipelineStats& input_pipeline_stats = result.first->second;
    InputPipelineMetadata* metadata = input_pipeline_stats.mutable_metadata();
    if (result.second) {
      bool is_device_input_pipeline =
          device_input_pipeline_ids.contains(root_iterator_id);
      int64_t name_id = is_device_input_pipeline ? num_device_input_pipelines++
                                                 : num_host_input_pipelines++;
      SetInputPipelineMetadata(root_iterator_id, name_id,
                               is_device_input_pipeline, metadata);
    }
    int64_t sum_latency_ps = 0;
    int64_t min_latency_ps = INT64_MAX;
    int64_t max_latency_ps = 0;
    int64_t num_slow_calls = 0;
    for (const EventNode* root_iterator_event : root_iterator_events) {
      InputPipelineStat* stat = input_pipeline_stats.add_stats();
      ProcessIteratorEvent(*root_iterator_event, stat,
                           /*is_blocking=*/true);
      SetBottleneckIteratorId(stat);
      int64_t latency_ps = root_iterator_event->GetEventVisitor().DurationPs();
      sum_latency_ps += latency_ps;
      min_latency_ps = std::min(min_latency_ps, latency_ps);
      max_latency_ps = std::max(max_latency_ps, latency_ps);
      if (latency_ps > kSlowCallThresholdPs) num_slow_calls++;
    }
    input_pipeline_stats.set_avg_latency_ps(sum_latency_ps /
                                            root_iterator_events.size());
    input_pipeline_stats.set_min_latency_ps(min_latency_ps);
    input_pipeline_stats.set_max_latency_ps(max_latency_ps);
    input_pipeline_stats.set_num_slow_calls(num_slow_calls);
  }
}

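// Collects the slowest recorded execution of every host input pipeline
// (device input pipelines are skipped) and emits one bottleneck analysis per
// pipeline, ordered from slowest to fastest across all hosts.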
void SetBottleneckAnalysis(CombinedTfDataStats* combined_tf_data_stats) {
  struct InputPipeline {
    InputPipeline(absl::string_view host_name,
                  absl::string_view input_pipeline_name,
                  int64_t max_latency_ps, absl::string_view iterator_name,
                  absl::string_view iterator_long_name,
                  int64_t iterator_latency_ps)
        : host_name(host_name),
          input_pipeline_name(input_pipeline_name),
          max_latency_ps(max_latency_ps),
          iterator_name(iterator_name),
          iterator_long_name(iterator_long_name),
          iterator_latency_ps(iterator_latency_ps) {}
    absl::string_view host_name;
    absl::string_view input_pipeline_name;
    int64_t max_latency_ps;
    absl::string_view iterator_name;
    absl::string_view iterator_long_name;
    int64_t iterator_latency_ps;

    // Sort in descending order of max latency so that the slowest input
    // pipelines come first.
    bool operator<(const InputPipeline& rhs) const {
      return max_latency_ps > rhs.max_latency_ps;
    }
  };
  std::vector<InputPipeline> slow_input_pipelines;
  for (const auto& host_name_and_tf_data_stats :
       combined_tf_data_stats->tf_data_stats()) {
    absl::string_view host_name = host_name_and_tf_data_stats.first;
    const TfDataStats& tf_data_stats = host_name_and_tf_data_stats.second;
    for (const auto& id_and_stats : tf_data_stats.input_pipelines()) {
      const InputPipelineStats& input_pipeline_stats = id_and_stats.second;
      if (input_pipeline_stats.metadata().type() ==
          InputPipelineMetadata::DEVICE) {
        // Ignore device input pipelines.
        continue;
      }
      // `input_pipeline_stats.stats` is sorted by duration in descending
      // order, so the first element is the slowest execution of this input
      // pipeline.
      const InputPipelineStat& input_pipeline_stat =
          input_pipeline_stats.stats(0);
      const IteratorMetadata& metadata = tf_data_stats.iterator_metadata().at(
          input_pipeline_stat.bottleneck_iterator_id());
      slow_input_pipelines.emplace_back(
          host_name, input_pipeline_stats.metadata().name(),
          input_pipeline_stats.max_latency_ps(), metadata.name(),
          metadata.long_name(),
          input_pipeline_stat.bottleneck_iterator_latency_ps());
    }
  }
  std::sort(slow_input_pipelines.begin(), slow_input_pipelines.end());
  for (const auto& input_pipeline : slow_input_pipelines) {
    TfDataBottleneckAnalysis* bottleneck_analysis =
        combined_tf_data_stats->add_bottleneck_analysis();
    bottleneck_analysis->set_host(input_pipeline.host_name.data(),
                                  input_pipeline.host_name.size());
    bottleneck_analysis->set_input_pipeline(
        input_pipeline.input_pipeline_name.data(),
        input_pipeline.input_pipeline_name.size());
    bottleneck_analysis->set_max_latency_ps(input_pipeline.max_latency_ps);
    bottleneck_analysis->set_iterator_name(input_pipeline.iterator_name.data(),
                                           input_pipeline.iterator_name.size());
    bottleneck_analysis->set_iterator_long_name(
        input_pipeline.iterator_long_name.data(),
        input_pipeline.iterator_long_name.size());
    bottleneck_analysis->set_iterator_latency_ps(
        input_pipeline.iterator_latency_ps);
  }
}

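// Returns an HTML snippet with tuning suggestions for the given bottleneck
// type, linking to the relevant tf.data performance guides.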
std::string GetSuggestion(BottleneckType type) {
  constexpr absl::string_view kPlaybookLink =
      "https://www.tensorflow.org/guide/data_performance_analysis";
  constexpr absl::string_view kPlaybookSourceDatasetLink =
      "https://www.tensorflow.org/guide/"
      "data_performance_analysis#source_datasets";
  constexpr absl::string_view kPlaybookCpuUtilizationLink =
      "https://www.tensorflow.org/guide/"
      "data_performance_analysis#3_are_you_reaching_high_cpu_utilization";
  constexpr absl::string_view kPlaybookTransformationLink =
      "https://www.tensorflow.org/guide/"
      "data_performance_analysis#transformation_datasets";
  constexpr absl::string_view kTfGuideParallelDataExtractionLink =
      "https://www.tensorflow.org/guide/"
      "data_performance#parallelizing_data_extraction";
  constexpr absl::string_view kTfGuideParallelTransformationLink =
      "https://www.tensorflow.org/guide/"
      "data_performance#parallelizing_data_transformation";
  constexpr absl::string_view kTfGuideCacheLink =
      "https://www.tensorflow.org/guide/data_performance#caching";
  constexpr absl::string_view kTfDataServiceLink =
      "https://www.tensorflow.org/api_docs/python/tf/data/experimental/"
      "service?version=nightly";
  switch (type) {
    case BottleneckType::kSlowSource:
      return absl::StrFormat(
          "1. Check the locality of the host and input data. Ideally, they "
          "should be in the same cell (or very close, like the same "
          "region).<br/>"
          "2. Parallelize reading from this dataset source. See %s and %s for "
          "more details.<br/>",
          AnchorElement(kPlaybookSourceDatasetLink, "here"),
          AnchorElement(kTfGuideParallelDataExtractionLink, "here"));
    case BottleneckType::kSlowDataService:
      return absl::StrFormat(
          "1. Fetching data from tf.data service took a while. Profile the "
          "tf.data service worker to analyze the issue further.<br/>"
          "2. See %s for more details on tf.data service.<br/>"
          "3. See %s for other suggestions.",
          AnchorElement(kTfDataServiceLink, "this"),
          AnchorElement(kPlaybookLink, "this"));
    case BottleneckType::kSlowRemoteSource:
      return absl::StrFormat(
          "1. The remote data source is slow. Profile its host to analyze the "
          "issue further.<br/>"
          "2. See %s for other suggestions.",
          AnchorElement(kPlaybookLink, "this"));
    case BottleneckType::kSlowTransformationWithParallelVersion:
      return absl::StrFormat(
          "1. Parallelize this transformation by setting "
          "<code>num_parallel_calls=tf.data.experimental.AUTOTUNE</code>. See "
          "%s for more details.<br/>"
          "2. Consider adding <code>cache</code> after this transformation if "
          "your data fits into memory and it is appropriate (e.g., there is "
          "no randomness in upstream transformations like "
          "<code>shuffle</code>). See %s for more details.<br/>"
          "3. Find more resources %s.",
          AnchorElement(kTfGuideParallelTransformationLink, "this"),
          AnchorElement(kTfGuideCacheLink, "this"),
          AnchorElement(kPlaybookTransformationLink, "here"));
    case BottleneckType::kSlowTransformationWithoutParallelVersion:
      return absl::StrFormat(
          "1. This transformation is inherently sequential. Add outer "
          "parallelism by running multiple copies of the input pipeline over "
          "sharded inputs and combining the results. See %s for more "
          "details.<br/>"
          "2. Consider adding <code>cache</code> after this transformation if "
          "your data fits into memory and it is appropriate (e.g., there is "
          "no randomness in upstream transformations like "
          "<code>shuffle</code>). See %s for more details.<br/>"
          "3. Find more resources %s.",
          AnchorElement(kPlaybookTransformationLink, "this"),
          AnchorElement(kTfGuideCacheLink, "this"),
          AnchorElement(kPlaybookCpuUtilizationLink, "here"));
    default:
      return absl::StrFormat("See %s for suggestions.",
                             AnchorElement(kPlaybookLink, "this"));
  }
}

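// Attaches a suggestion to each bottleneck analysis based on the type of its
// bottleneck iterator.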
void SetSuggestion(CombinedTfDataStats* combined_tf_data_stats) {
  for (TfDataBottleneckAnalysis& bottleneck_analysis :
       *combined_tf_data_stats->mutable_bottleneck_analysis()) {
    bottleneck_analysis.set_suggestion(
        GetSuggestion(GetBottleneckType(bottleneck_analysis.iterator_name())));
  }
}

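// Marks the profile as input bound if the slowest input pipeline exceeds the
// 50 us threshold and writes a matching human-readable summary.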
void SetSummary(CombinedTfDataStats* combined_tf_data_stats) {
  int64_t max_latency_ps = 0;
  if (combined_tf_data_stats->bottleneck_analysis_size()) {
    max_latency_ps =
        combined_tf_data_stats->bottleneck_analysis().at(0).max_latency_ps();
  }
  if (max_latency_ps > kSlowCallThresholdPs) {
    combined_tf_data_stats->set_is_input_bound(true);
    combined_tf_data_stats->set_summary(
        "Your profile has a tf.data input pipeline slower than 50 us. For "
        "each slow input pipeline, the analysis below shows its bottleneck "
        "and a suggestion on how to fix it.");
  } else if (max_latency_ps > 0) {
    combined_tf_data_stats->set_is_input_bound(false);
    combined_tf_data_stats->set_summary(
        "Your profile does not have any tf.data input pipeline slower than 50 "
        "us. Your job could still be input bound if this profile didn't "
        "capture all workers.");
  } else {
    combined_tf_data_stats->set_is_input_bound(false);
    combined_tf_data_stats->set_summary(
        "No tf.data activity captured in your profile. If your job uses "
        "tf.data, try to capture a longer profile.");
  }
}

}  // namespace

BottleneckType GetBottleneckType(absl::string_view bottleneck_iterator_name) {
  static auto* kBottleneckTypeMap = new absl::flat_hash_map<absl::string_view,
                                                            BottleneckType>(
      {// Read from storage.
       {"TFRecord", BottleneckType::kSlowSource},
       {"SSTable", BottleneckType::kSlowSource},
       {"RecordIO", BottleneckType::kSlowSource},
       {"Spanner", BottleneckType::kSlowSource},
       {"TFColumn", BottleneckType::kSlowSource},
       {"SleepwalkRemoteDataset", BottleneckType::kSlowSource},
       {"TextLine", BottleneckType::kSlowSource},
       {"StitchedTimelineDataset", BottleneckType::kSlowSource},
       {"DateKeyDataset", BottleneckType::kSlowSource},
       {"CapacitorProto", BottleneckType::kSlowSource},
       {"LMDB", BottleneckType::kSlowSource},
       {"ExternalDataset", BottleneckType::kSlowSource},
       {"PearModel", BottleneckType::kSlowSource},
       {"FixedLengthRecordV2", BottleneckType::kSlowSource},
       // Read from local memory.
       {"FromTensor", BottleneckType::kSlowSource},
       {"TensorSlice", BottleneckType::kSlowSource},
       {"Generator", BottleneckType::kSlowSource},
       {"SyntheticDatasetOp", BottleneckType::kSlowSource},
       // tf.data service.
       {"DataService", BottleneckType::kSlowDataService},
       // Read from remote memory.
       {"GuzzlerDataGuzzlerRemoteDataset", BottleneckType::kSlowRemoteSource},
       {"ReverbDataset", BottleneckType::kSlowRemoteSource},
       {"DatasetSampleGame", BottleneckType::kSlowRemoteSource},
       {"Courier", BottleneckType::kSlowRemoteSource},
       {"ReverbEpisodeDataset", BottleneckType::kSlowRemoteSource},
       // Transformations with parallel version.
       {"Map", BottleneckType::kSlowTransformationWithParallelVersion},
       {"Interleave", BottleneckType::kSlowTransformationWithParallelVersion},
       // Transformations without parallel version.
       {"Filter", BottleneckType::kSlowTransformationWithoutParallelVersion},
       {"Batch", BottleneckType::kSlowTransformationWithoutParallelVersion},
       {"Unbatch", BottleneckType::kSlowTransformationWithoutParallelVersion}});
  if (auto type =
          gtl::FindOrNull(*kBottleneckTypeMap, bottleneck_iterator_name)) {
    return *type;
  }
  return BottleneckType::kOther;
}

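// Converts one host's XPlane into TfDataStats: groups the host events into an
// EventForest, connects tf.data events, and derives iterator metadata and
// input pipeline statistics for that host.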
void CombinedTfDataStatsBuilder::Add(absl::string_view host_name,
                                     XPlane* host_plane) {
  TfDataStats& tf_data_stats =
      (*combined_tf_data_stats_
            ->mutable_tf_data_stats())[std::string(host_name)];
  EventForest event_forest;
  event_forest.AddPlanes(CreateTfXPlaneVisitor, {host_plane});
  event_forest.ConnectEvents();
  event_forest.ConnectTfDataEvents();
  absl::flat_hash_set<int64> device_input_pipeline_ids;
  absl::flat_hash_map<int64, std::vector<EventNode*>> root_iterator_event_map;
  ProcessEventForest(event_forest, &device_input_pipeline_ids,
                     &root_iterator_event_map, &tf_data_stats);
  ProcessInputPipelines(device_input_pipeline_ids, &root_iterator_event_map,
                        &tf_data_stats);
}

void CombinedTfDataStatsBuilder::Finalize() {
  SetBottleneckAnalysis(combined_tf_data_stats_);
  if (generate_suggestion_) SetSuggestion(combined_tf_data_stats_);
  SetSummary(combined_tf_data_stats_);
}

}  // namespace profiler
}  // namespace tensorflow