• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h"
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/container/flat_hash_set.h"
20 #include "absl/strings/str_format.h"
21 #include "absl/strings/str_split.h"
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/platform/protobuf.h"
25 #include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h"
26 #include "tensorflow/core/profiler/utils/group_events.h"
27 #include "tensorflow/core/profiler/utils/html_utils.h"
28 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
29 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
30 #include "tensorflow/core/profiler/utils/timespan.h"
31 #include "tensorflow/core/profiler/utils/xplane_schema.h"
32 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
33 
34 namespace tensorflow {
35 namespace profiler {
36 
// Threshold above which a GetNext call is considered slow: 50 us, expressed
// in picoseconds (1 us = 10^6 ps).
// From https://www.tensorflow.org/guide/data_performance_analysis
const int64_t kSlowCallThresholdPs = 50 * 1000000;
39 
40 namespace {
41 
42 // Returns true if the given iterator event is for a root iterator.
IsRootIteratorEvent(const XEventVisitor & iterator_event)43 bool IsRootIteratorEvent(const XEventVisitor& iterator_event) {
44   std::vector<absl::string_view> split_result =
45       absl::StrSplit(iterator_event.Name(), "::");
46   // The root iterator's name contains only its own name (no parent
47   // information).
48   return split_result.size() == 2;
49 }
50 
51 // Returns true if the given iterator event name is for an async iterator.
IsAsyncIterator(absl::string_view iterator_event_name)52 bool IsAsyncIterator(absl::string_view iterator_event_name) {
53   static auto* kAsyncIterators = new absl::flat_hash_set<absl::string_view>(
54       {"Prefetch", "ParallelInterleave", "ParallelMap", "ParseExample",
55        "MapAndBatch", "DataService", "LegacyParallelInterleave",
56        "ParallelBatch"});
57   return kAsyncIterators->contains(iterator_event_name);
58 }
59 
SetIteratorMetadata(int64_t id,const XEventVisitor & event,IteratorMetadata * metadata)60 void SetIteratorMetadata(int64_t id, const XEventVisitor& event,
61                          IteratorMetadata* metadata) {
62   metadata->set_id(id);
63   auto parent_id_stat = event.GetStat(StatType::kParentId);
64   if (parent_id_stat.has_value()) {
65     metadata->set_parent_id(parent_id_stat->IntValue());
66   }
67   metadata->set_name(IteratorName(event.Name()));
68   metadata->set_long_name(event.Name().data(), event.Name().size());
69   metadata->set_is_async(IsAsyncIterator(metadata->name()));
70   // TODO(b/161831651): Set params.
71 }
72 
73 // Returns the parent iterator's id if it is a root of a device input
74 // pipeline.
FindDeviceInputPipeline(const XEventVisitor & event)75 absl::optional<int64> FindDeviceInputPipeline(const XEventVisitor& event) {
76   if (event.Type() == HostEventType::kDeviceInputPipelineSecondIterator) {
77     auto parent_id_stat = event.GetStat(StatType::kParentId);
78     if (parent_id_stat.has_value()) return parent_id_stat->IntValue();
79   }
80   return absl::nullopt;
81 }
82 
83 // Processes EventForest to do the following:
84 // (1) set iterator metadata
85 // (2) find root iterator events
86 // (3) find device input pipeline ids
ProcessEventForest(const EventForest & event_forest,absl::flat_hash_set<int64> * device_input_pipeline_ids,absl::flat_hash_map<int64,std::vector<EventNode * >> * root_iterator_event_map,TfDataStats * tf_data_stats)87 void ProcessEventForest(const EventForest& event_forest,
88                         absl::flat_hash_set<int64>* device_input_pipeline_ids,
89                         absl::flat_hash_map<int64, std::vector<EventNode*>>*
90                             root_iterator_event_map,
91                         TfDataStats* tf_data_stats) {
92   const EventNodeMap& event_node_map = event_forest.GetEventNodeMap();
93   auto iterator_event_list =
94       gtl::FindOrNull(event_node_map, HostEventType::kIterator);
95   if (!iterator_event_list) return;
96   for (const auto& iterator_event : *iterator_event_list) {
97     const XEventVisitor& iterator_event_visitor =
98         iterator_event->GetEventVisitor();
99     auto iterator_id_stat = iterator_event_visitor.GetStat(StatType::kStepId);
100     if (!iterator_id_stat.has_value()) continue;
101     int64_t iterator_id = iterator_id_stat->IntValue();
102     auto result = tf_data_stats->mutable_iterator_metadata()->insert(
103         {iterator_id, IteratorMetadata()});
104     IteratorMetadata& metadata = result.first->second;
105     if (result.second) {
106       // First time processing this iterator.
107       SetIteratorMetadata(iterator_id, iterator_event_visitor, &metadata);
108     }
109     if (IsRootIteratorEvent(iterator_event_visitor)) {
110       // Record root iterator events.
111       (*root_iterator_event_map)[iterator_id].push_back(iterator_event.get());
112     }
113   }
114   auto device_input_pipeline_second_iterator_events = gtl::FindOrNull(
115       event_node_map, HostEventType::kDeviceInputPipelineSecondIterator);
116   if (!device_input_pipeline_second_iterator_events) return;
117   for (const auto& iterator_event :
118        *device_input_pipeline_second_iterator_events) {
119     const XEventVisitor& iterator_event_visitor =
120         iterator_event->GetEventVisitor();
121     auto iterator_id_stat = iterator_event_visitor.GetStat(StatType::kStepId);
122     if (!iterator_id_stat.has_value()) continue;
123     int64_t iterator_id = iterator_id_stat->IntValue();
124     auto result = tf_data_stats->mutable_iterator_metadata()->insert(
125         {iterator_id, IteratorMetadata()});
126     IteratorMetadata& metadata = result.first->second;
127     if (result.second) {
128       // First time processing this iterator.
129       SetIteratorMetadata(iterator_id, iterator_event_visitor, &metadata);
130       // Find and record device input pipeline ids.
131       absl::optional<int64> device_input_pipeline_id =
132           FindDeviceInputPipeline(iterator_event_visitor);
133       if (device_input_pipeline_id.has_value()) {
134         device_input_pipeline_ids->insert(*device_input_pipeline_id);
135       }
136     }
137   }
138 }
139 
SetInputPipelineMetadata(int64_t id,int64_t name_id,bool is_device_input_pipeline,InputPipelineMetadata * metadata)140 void SetInputPipelineMetadata(int64_t id, int64_t name_id,
141                               bool is_device_input_pipeline,
142                               InputPipelineMetadata* metadata) {
143   constexpr absl::string_view kHostInputPipelinePrefix = "Host:";
144   constexpr absl::string_view kDeviceInputPipelinePrefix = "Device:";
145   metadata->set_id(id);
146   if (is_device_input_pipeline) {
147     metadata->set_type(InputPipelineMetadata::DEVICE);
148     metadata->set_name(absl::StrCat(kDeviceInputPipelinePrefix, name_id));
149   } else {
150     metadata->set_type(InputPipelineMetadata::HOST);
151     metadata->set_name(absl::StrCat(kHostInputPipelinePrefix, name_id));
152   }
153 }
154 
// Recursively accumulates per-iterator stats for one execution of the
// iterator (sub)tree rooted at `iterator_event` into `input_pipeline_stat`.
// `is_blocking` is true if this iterator's execution blocked its parent (and
// transitively the root) for the duration of their overlap.
void ProcessIteratorEvent(const EventNode& iterator_event,
                          InputPipelineStat* input_pipeline_stat,
                          bool is_blocking) {
  const XEventVisitor& visitor = iterator_event.GetEventVisitor();
  auto iterator_id_stat = visitor.GetStat(StatType::kStepId);
  if (!iterator_id_stat.has_value()) return;
  int64_t iterator_id = iterator_id_stat->IntValue();
  auto result = input_pipeline_stat->mutable_iterator_stats()->insert(
      {iterator_id, IteratorStat()});
  IteratorStat& iterator_stat = result.first->second;
  if (result.second) {
    // First event seen for this iterator in this execution.
    iterator_stat.set_id(iterator_id);
    iterator_stat.set_start_time_ps(visitor.TimestampPs());
  }
  iterator_stat.set_duration_ps(iterator_stat.duration_ps() +
                                visitor.DurationPs());
  // Self time starts as the full event duration; time overlapped by tf.data
  // children is subtracted below.
  int64_t self_time_ps = visitor.DurationPs();
  Timespan self_time_span = visitor.GetTimespan();
  for (EventNode* child : iterator_event.GetChildren()) {
    const XEventVisitor& child_visitor = child->GetEventVisitor();
    if (ParseTfOpFullname(child_visitor.Name()).category == Category::kTfData) {
      int64_t overlap_duration_ps =
          self_time_span.OverlappedDurationPs(child_visitor.GetTimespan());
      // A child blocks this iterator only if their timespans overlap; a
      // non-overlapping child ran asynchronously (e.g., in a prefetch
      // thread), so the nonzero overlap acts as the child's is_blocking flag.
      ProcessIteratorEvent(*child, input_pipeline_stat,
                           is_blocking && overlap_duration_ps);
      // Note: Assume no overlap between child events.
      self_time_ps -= overlap_duration_ps;
    }
  }
  iterator_stat.set_self_time_ps(iterator_stat.self_time_ps() + self_time_ps);
  iterator_stat.set_is_blocking(iterator_stat.is_blocking() || is_blocking);
  iterator_stat.set_num_calls(iterator_stat.num_calls() + 1);
}
188 
SetBottleneckIteratorId(InputPipelineStat * input_pipeline_stat)189 void SetBottleneckIteratorId(InputPipelineStat* input_pipeline_stat) {
190   int64_t bottleneck_iterator_id = 0;
191   int64_t max_self_time = 0;
192   for (const auto& pair : input_pipeline_stat->iterator_stats()) {
193     const auto& id = pair.first;
194     const auto& iterator_stat = pair.second;
195     if (iterator_stat.is_blocking() &&
196         iterator_stat.self_time_ps() > max_self_time) {
197       bottleneck_iterator_id = id;
198       max_self_time = iterator_stat.self_time_ps();
199     }
200   }
201   input_pipeline_stat->set_bottleneck_iterator_id(bottleneck_iterator_id);
202   input_pipeline_stat->set_bottleneck_iterator_latency_ps(max_self_time);
203 }
204 
// Aggregates the recorded root iterator executions into per-pipeline
// InputPipelineStats: assigns pipeline metadata (host/device naming),
// computes per-execution iterator stats, and summarizes latencies.
void ProcessInputPipelines(
    const absl::flat_hash_set<int64>& device_input_pipeline_ids,
    absl::flat_hash_map<int64, std::vector<EventNode*>>*
        root_iterator_event_map,
    TfDataStats* tf_data_stats) {
  auto* input_pipelines = tf_data_stats->mutable_input_pipelines();
  // Counters used to generate sequential "Host:N" / "Device:N" names.
  int64_t num_host_input_pipelines = 0;
  int64_t num_device_input_pipelines = 0;
  for (auto& id_and_events : *root_iterator_event_map) {
    auto& root_iterator_id = id_and_events.first;
    auto& root_iterator_events = id_and_events.second;
    // Sort executions by descending latency so stats(0) is the slowest
    // execution; SetBottleneckAnalysis relies on this ordering.
    absl::c_sort(root_iterator_events,
                 [](const EventNode* lhs, const EventNode* rhs) {
                   return lhs->GetEventVisitor().DurationPs() >
                          rhs->GetEventVisitor().DurationPs();
                 });
    auto result =
        input_pipelines->insert({root_iterator_id, InputPipelineStats()});
    InputPipelineStats& input_pipeline_stats = result.first->second;
    InputPipelineMetadata* metadata = input_pipeline_stats.mutable_metadata();
    if (result.second) {
      // New pipeline: name it with the next counter of the matching type.
      bool is_device_input_pipeline =
          device_input_pipeline_ids.contains(root_iterator_id);
      int64_t name_id = is_device_input_pipeline ? num_device_input_pipelines++
                                                 : num_host_input_pipelines++;
      SetInputPipelineMetadata(root_iterator_id, name_id,
                               is_device_input_pipeline, metadata);
    }
    int64_t sum_latency_ps = 0;
    int64_t min_latency_ps = INT64_MAX;
    int64_t max_latency_ps = 0;
    int64_t num_slow_calls = 0;
    for (const EventNode* root_iterator_event : root_iterator_events) {
      InputPipelineStat* stat = input_pipeline_stats.add_stats();
      // The root iterator always blocks its consumer by definition.
      ProcessIteratorEvent(*root_iterator_event, stat,
                           /*is_blocking*/ true);
      SetBottleneckIteratorId(stat);
      int64_t latency_ps = root_iterator_event->GetEventVisitor().DurationPs();
      sum_latency_ps += latency_ps;
      min_latency_ps = std::min(min_latency_ps, latency_ps);
      max_latency_ps = std::max(max_latency_ps, latency_ps);
      if (latency_ps > kSlowCallThresholdPs) num_slow_calls++;
    }
    // root_iterator_events is non-empty: map entries are only created by
    // push_back in ProcessEventForest, so the division is safe.
    input_pipeline_stats.set_avg_latency_ps(sum_latency_ps /
                                            root_iterator_events.size());
    input_pipeline_stats.set_min_latency_ps(min_latency_ps);
    input_pipeline_stats.set_max_latency_ps(max_latency_ps);
    input_pipeline_stats.set_num_slow_calls(num_slow_calls);
  }
}
255 
// For each host input pipeline, takes its slowest recorded execution and the
// bottleneck iterator within it, then emits one bottleneck-analysis entry per
// pipeline, ordered slowest-first. Device input pipelines are skipped.
void SetBottleneckAnalysis(CombinedTfDataStats* combined_tf_data_stats) {
  // Flat record of one slow input pipeline, used only for sorting below.
  struct InputPipeline {
    InputPipeline(absl::string_view host_name,
                  absl::string_view input_pipeline_name, int64_t max_latency_ps,
                  absl::string_view iterator_name,
                  absl::string_view iterator_long_name,
                  int64_t iterator_latency_ps)
        : host_name(host_name),
          input_pipeline_name(input_pipeline_name),
          max_latency_ps(max_latency_ps),
          iterator_name(iterator_name),
          iterator_long_name(iterator_long_name),
          iterator_latency_ps(iterator_latency_ps) {}
    // Views into strings owned by `combined_tf_data_stats`; valid only while
    // it is alive and not mutated.
    absl::string_view host_name;
    absl::string_view input_pipeline_name;
    int64 max_latency_ps;
    absl::string_view iterator_name;
    absl::string_view iterator_long_name;
    int64 iterator_latency_ps;

    // Intentionally inverted so std::sort orders by descending max latency.
    bool operator<(const InputPipeline& rhs) const {
      return max_latency_ps > rhs.max_latency_ps;
    }
  };
  std::vector<InputPipeline> slow_input_pipelines;
  for (const auto& host_name_and_tf_data_stats :
       combined_tf_data_stats->tf_data_stats()) {
    absl::string_view host_name = host_name_and_tf_data_stats.first;
    const TfDataStats& tf_data_stats = host_name_and_tf_data_stats.second;
    for (const auto& id_and_stats : tf_data_stats.input_pipelines()) {
      const InputPipelineStats& input_pipeline_stats = id_and_stats.second;
      if (input_pipeline_stats.metadata().type() ==
          InputPipelineMetadata::DEVICE) {
        // Ignore device input pipelines.
        continue;
      }
      // Choose the slowest execution trace of the input pipeline.
      // `input_pipeline_stats.stats` is already sorted so choose the first one.
      const InputPipelineStat& input_pipeline_stat =
          input_pipeline_stats.stats(0);
      const IteratorMetadata& metadata = tf_data_stats.iterator_metadata().at(
          input_pipeline_stat.bottleneck_iterator_id());
      slow_input_pipelines.emplace_back(
          host_name, input_pipeline_stats.metadata().name(),
          input_pipeline_stats.max_latency_ps(), metadata.name(),
          metadata.long_name(),
          input_pipeline_stat.bottleneck_iterator_latency_ps());
    }
  }
  std::sort(slow_input_pipelines.begin(), slow_input_pipelines.end());
  for (const auto& input_pipeline : slow_input_pipelines) {
    TfDataBottleneckAnalysis* bottleneck_analysis =
        combined_tf_data_stats->add_bottleneck_analysis();
    bottleneck_analysis->set_host(input_pipeline.host_name.data(),
                                  input_pipeline.host_name.size());
    bottleneck_analysis->set_input_pipeline(
        input_pipeline.input_pipeline_name.data(),
        input_pipeline.input_pipeline_name.size());
    bottleneck_analysis->set_max_latency_ps(input_pipeline.max_latency_ps);
    bottleneck_analysis->set_iterator_name(input_pipeline.iterator_name.data(),
                                           input_pipeline.iterator_name.size());
    bottleneck_analysis->set_iterator_long_name(
        input_pipeline.iterator_long_name.data(),
        input_pipeline.iterator_long_name.size());
    bottleneck_analysis->set_iterator_latency_ps(
        input_pipeline.iterator_latency_ps);
  }
}
324 
// Returns an HTML suggestion string (with anchor links into the tf.data
// performance guides) for the given bottleneck type. Strings are rendered
// directly in the tool UI, so they must remain valid HTML fragments.
std::string GetSuggestion(BottleneckType type) {
  constexpr absl::string_view kPlaybookLink =
      "https://www.tensorflow.org/guide/data_performance_analysis";
  constexpr absl::string_view kPlaybookSourceDatasetLink =
      "https://www.tensorflow.org/guide/"
      "data_performance_analysis#source_datasets";
  constexpr absl::string_view kPlaybookCpuUtilizationLink =
      "https://www.tensorflow.org/guide/"
      "data_performance_analysis#3_are_you_reaching_high_cpu_utilization";
  constexpr absl::string_view kPlaybookTransformationLink =
      "https://www.tensorflow.org/guide/"
      "data_performance_analysis#transformation_datasets";
  constexpr absl::string_view kTfGuideParallelDataExtractionLink =
      "https://www.tensorflow.org/guide/"
      "data_performance#parallelizing_data_extraction";
  constexpr absl::string_view kTfGuideParallelTransformationLink =
      "https://www.tensorflow.org/guide/"
      "data_performance#parallelizing_data_transformation";
  constexpr absl::string_view kTfGuideCacheLink =
      "https://www.tensorflow.org/guide/data_performance#caching";
  constexpr absl::string_view kTfDataServiceLink =
      "https://www.tensorflow.org/api_docs/python/tf/data/experimental/"
      "service?version=nightly";
  switch (type) {
    case BottleneckType::kSlowSource:
      return absl::StrFormat(
          "1. Check the locality of a host and input data. Ideally, they "
          "should be in the same cell (or very close, like the same "
          "region).<br/>"
          "2. Parallelize reading from this dataset source. See %s and %s for "
          "more details.<br/>",
          AnchorElement(kPlaybookSourceDatasetLink, "here"),
          AnchorElement(kTfGuideParallelDataExtractionLink, "here"));
    case BottleneckType::kSlowDataService:
      return absl::StrFormat(
          "1. Fetching data from tf.data service took a while. Profile the "
          "tf.data service worker to analyze the issue further.<br/>"
          "2. See %s for more details on tf.data service.<br/>"
          "3. See %s for other suggestions.",
          AnchorElement(kTfDataServiceLink, "this"),
          AnchorElement(kPlaybookLink, "this"));
    case BottleneckType::kSlowRemoteSource:
      return absl::StrFormat(
          "1. The remote data source is slow. Profile its host to analyze the "
          "issue further.<br/>"
          "2. See %s for other suggestions.",
          AnchorElement(kPlaybookLink, "this"));
    case BottleneckType::kSlowTransformationWithParallelVersion:
      return absl::StrFormat(
          "1. Parallelize this transformation by setting "
          "<code>num_parallel_calls=tf.data.experimental.AUTOTUNE</code>. See "
          "%s for more details.<br/>"
          "2. Consider adding <code>cache</code> after this transformation if "
          "your data fits into memory and it is appropriate (e.g., there is no "
          "randomness in upstream transformations like <code>shuffle</code>). "
          "See %s for more details.<br/>"
          "3. Find more resources %s.",
          AnchorElement(kTfGuideParallelTransformationLink, "this"),
          AnchorElement(kTfGuideCacheLink, "this"),
          AnchorElement(kPlaybookTransformationLink, "here"));
    case BottleneckType::kSlowTransformationWithoutParallelVersion:
      return absl::StrFormat(
          "1. This transformation is inherently sequential. Add outer "
          "parallelism by running multiple copies of the input pipeline over "
          "sharded inputs and combining the results. See %s for more "
          "details.<br/>"
          "2. Consider adding <code>cache</code> after this transformation if "
          "your data fits into memory and it is appropriate (e.g., there is no "
          "randomness in upstream transformations like <code>shuffle</code>). "
          "See %s for more details.<br/>"
          "3. Find more resources %s.",
          AnchorElement(kPlaybookTransformationLink, "this"),
          AnchorElement(kTfGuideCacheLink, "this"),
          AnchorElement(kPlaybookCpuUtilizationLink, "here"));
    default:
      // Covers BottleneckType::kOther and any future types.
      return absl::StrFormat("See %s for suggestions.",
                             AnchorElement(kPlaybookLink, "this"));
  }
}
404 
SetSuggestion(CombinedTfDataStats * combined_tf_data_stats)405 void SetSuggestion(CombinedTfDataStats* combined_tf_data_stats) {
406   for (TfDataBottleneckAnalysis& bottleneck_analysis :
407        *combined_tf_data_stats->mutable_bottleneck_analysis()) {
408     bottleneck_analysis.set_suggestion(
409         GetSuggestion(GetBottleneckType(bottleneck_analysis.iterator_name())));
410   }
411 }
412 
SetSummary(CombinedTfDataStats * combined_tf_data_stats)413 void SetSummary(CombinedTfDataStats* combined_tf_data_stats) {
414   int64_t max_latency_ps = 0;
415   if (combined_tf_data_stats->bottleneck_analysis_size()) {
416     max_latency_ps =
417         combined_tf_data_stats->bottleneck_analysis().at(0).max_latency_ps();
418   }
419   if (max_latency_ps > kSlowCallThresholdPs) {
420     combined_tf_data_stats->set_is_input_bound(true);
421     combined_tf_data_stats->set_summary(
422         "Your profile has a tf.data input pipeline slower than 50 us. For each "
423         "slow input pipeline, below shows a bottleneck in the input pipeline "
424         "and a suggestion on how to fix it.");
425   } else if (max_latency_ps > 0) {
426     combined_tf_data_stats->set_is_input_bound(false);
427     combined_tf_data_stats->set_summary(
428         "Your profile does not have any tf.data input pipeline slower than 50 "
429         "us. Your job could be still input bound if this profile didn't "
430         "capture all workers.");
431   } else {
432     combined_tf_data_stats->set_is_input_bound(false);
433     combined_tf_data_stats->set_summary(
434         "No tf.data activitiy captured in your profile. If your job uses "
435         "tf.data, try to capture a longer profile.");
436   }
437 }
438 
439 }  // namespace
440 
// Classifies a bottleneck iterator by its (short) name into the bucket that
// drives the suggestion text. Unknown names map to kOther.
BottleneckType GetBottleneckType(absl::string_view bottleneck_iterator_name) {
  // Lazily-initialized, leaked on purpose (lives for the process lifetime).
  static auto* kBottleneckTypeMap = new absl::flat_hash_map<absl::string_view,
                                                            BottleneckType>(
      {// Read from storage.
       {"TFRecord", BottleneckType::kSlowSource},
       {"SSTable", BottleneckType::kSlowSource},
       {"RecordIO", BottleneckType::kSlowSource},
       {"Spanner", BottleneckType::kSlowSource},
       {"TFColumn", BottleneckType::kSlowSource},
       {"SleepwalkRemoteDataset", BottleneckType::kSlowSource},
       {"TextLine", BottleneckType::kSlowSource},
       {"StitchedTimelineDataset", BottleneckType::kSlowSource},
       {"DateKeyDataset", BottleneckType::kSlowSource},
       {"CapacitorProto", BottleneckType::kSlowSource},
       {"LMDB", BottleneckType::kSlowSource},
       {"ExternalDataset", BottleneckType::kSlowSource},
       {"PearModel", BottleneckType::kSlowSource},
       {"FixedLengthRecordV2", BottleneckType::kSlowSource},
       // Read from local memory.
       {"FromTensor", BottleneckType::kSlowSource},
       {"TensorSlice", BottleneckType::kSlowSource},
       {"Generator", BottleneckType::kSlowSource},
       {"SyntheticDatasetOp", BottleneckType::kSlowSource},
       // tf.data service.
       {"DataService", BottleneckType::kSlowDataService},
       // Read from remote memory.
       {"GuzzlerDataGuzzlerRemoteDataset", BottleneckType::kSlowRemoteSource},
       {"ReverbDataset", BottleneckType::kSlowRemoteSource},
       {"DatasetSampleGame", BottleneckType::kSlowRemoteSource},
       {"Courier", BottleneckType::kSlowRemoteSource},
       {"ReverbEpisodeDataset", BottleneckType::kSlowRemoteSource},
       // Transformations with parallel version.
       {"Map", BottleneckType::kSlowTransformationWithParallelVersion},
       {"Interleave", BottleneckType::kSlowTransformationWithParallelVersion},
       // Transformations without parallel version.
       {"Filter", BottleneckType::kSlowTransformationWithoutParallelVersion},
       {"Batch", BottleneckType::kSlowTransformationWithoutParallelVersion},
       {"Unbatch", BottleneckType::kSlowTransformationWithoutParallelVersion}});
  if (auto type =
          gtl::FindOrNull(*kBottleneckTypeMap, bottleneck_iterator_name)) {
    return *type;
  }
  return BottleneckType::kOther;
}
485 
// Converts one host's XPlane of tf.data events into TfDataStats and stores it
// in the combined result keyed by `host_name`.
void CombinedTfDataStatsBuilder::Add(absl::string_view host_name,
                                     XPlane* host_plane) {
  // operator[] creates the per-host entry if absent.
  TfDataStats& tf_data_stats =
      (*combined_tf_data_stats_
            ->mutable_tf_data_stats())[std::string(host_name)];
  // Build the event forest and connect generic plus tf.data-specific
  // producer/consumer events so iterator parent/child links are available.
  EventForest event_forest;
  event_forest.AddPlanes(CreateTfXPlaneVisitor, {host_plane});
  event_forest.ConnectEvents();
  event_forest.ConnectTfDataEvents();
  absl::flat_hash_set<int64> device_input_pipeline_ids;
  absl::flat_hash_map<int64, std::vector<EventNode*>> root_iterator_event_map;
  // First pass collects metadata, root iterator events, and device pipeline
  // ids; second pass aggregates them into per-pipeline stats.
  ProcessEventForest(event_forest, &device_input_pipeline_ids,
                     &root_iterator_event_map, &tf_data_stats);
  ProcessInputPipelines(device_input_pipeline_ids, &root_iterator_event_map,
                        &tf_data_stats);
}
502 
// Finalizes the combined stats after all hosts have been added: computes the
// cross-host bottleneck analysis, optionally attaches suggestions, and writes
// the overall summary. Order matters: suggestions and the summary both read
// the bottleneck-analysis entries produced by the first step.
void CombinedTfDataStatsBuilder::Finalize() {
  SetBottleneckAnalysis(combined_tf_data_stats_);
  if (generate_suggestion_) SetSuggestion(combined_tf_data_stats_);
  SetSummary(combined_tf_data_stats_);
}
508 
509 }  // namespace profiler
510 }  // namespace tensorflow
511