/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h"

#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h"
#include "tensorflow/core/profiler/utils/group_events.h"
#include "tensorflow/core/profiler/utils/html_utils.h"
#include "tensorflow/core/profiler/utils/tf_op_utils.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/timespan.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"

namespace tensorflow {
namespace profiler {

// 50 us, from https://www.tensorflow.org/guide/data_performance_analysis.
// 1 us = 10^6 ps, so 50 us = 50 * 1,000,000 ps.
const int64_t kSlowCallThresholdPs = 50 * 1000000;

namespace {

// Returns true if the given iterator event is for a root iterator.
bool IsRootIteratorEvent(const XEventVisitor& iterator_event) {
  std::vector<absl::string_view> split_result =
      absl::StrSplit(iterator_event.Name(), "::");
  // The root iterator's name contains only its own name (no ancestor
  // information), so splitting it on "::" yields exactly two parts: the
  // "Iterator" prefix and the iterator's own name.
  return split_result.size() == 2;
}
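
// For example, "Iterator::Prefetch" splits into {"Iterator", "Prefetch"} and
// is therefore a root, while "Iterator::Prefetch::Range" splits into three
// parts and is not.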

// Returns true if the given iterator event name is for an async iterator.
bool IsAsyncIterator(absl::string_view iterator_event_name) {
  static auto* kAsyncIterators = new absl::flat_hash_set<absl::string_view>(
      {"Prefetch", "ParallelInterleave", "ParallelMap", "ParseExample",
       "MapAndBatch", "DataService", "LegacyParallelInterleave",
       "ParallelBatch"});
  return kAsyncIterators->contains(iterator_event_name);
}
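
// This matches against the short iterator name produced by IteratorName()
// (e.g., "Prefetch"), not the full event name: IsAsyncIterator("Prefetch")
// returns true, while IsAsyncIterator("Iterator::Prefetch") returns false.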

void SetIteratorMetadata(int64_t id, const XEventVisitor& event,
                         IteratorMetadata* metadata) {
  metadata->set_id(id);
  auto parent_id_stat = event.GetStat(StatType::kParentId);
  if (parent_id_stat.has_value()) {
    metadata->set_parent_id(parent_id_stat->IntValue());
  }
  metadata->set_name(IteratorName(event.Name()));
  metadata->set_long_name(event.Name().data(), event.Name().size());
  metadata->set_is_async(IsAsyncIterator(metadata->name()));
  // TODO(b/161831651): Set params.
}
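
// Sketch of the resulting metadata, assuming IteratorName() strips the
// ancestry prefix from the full event name: an event named
// "Iterator::Prefetch::Range" would get long_name "Iterator::Prefetch::Range",
// name "Range", and is_async false ("Range" is not in the async iterator set).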

// Returns the parent iterator's id if it is a root of a device input
// pipeline.
absl::optional<int64_t> FindDeviceInputPipeline(const XEventVisitor& event) {
  if (event.Type() == HostEventType::kDeviceInputPipelineSecondIterator) {
    auto parent_id_stat = event.GetStat(StatType::kParentId);
    if (parent_id_stat.has_value()) return parent_id_stat->IntValue();
  }
  return absl::nullopt;
}

// Processes EventForest to do the following:
// (1) set iterator metadata
// (2) find root iterator events
// (3) find device input pipeline ids
void ProcessEventForest(
    const EventForest& event_forest,
    absl::flat_hash_set<int64_t>* device_input_pipeline_ids,
    absl::flat_hash_map<int64_t, std::vector<const EventNode*>>*
        root_iterator_event_map,
    TfDataStats* tf_data_stats) {
  const EventNodeMap& event_node_map = event_forest.GetEventNodeMap();
  auto* iterator_event_list =
      gtl::FindOrNull(event_node_map, HostEventType::kIterator);
  if (!iterator_event_list) return;
  for (const EventNode& iterator_event : *iterator_event_list) {
    const XEventVisitor& iterator_event_visitor =
        iterator_event.GetEventVisitor();
    auto iterator_id_stat = iterator_event_visitor.GetStat(StatType::kStepId);
    if (!iterator_id_stat.has_value()) continue;
    int64_t iterator_id = iterator_id_stat->IntValue();
    auto result = tf_data_stats->mutable_iterator_metadata()->insert(
        {iterator_id, IteratorMetadata()});
    IteratorMetadata& metadata = result.first->second;
    if (result.second) {
      // First time processing this iterator.
      SetIteratorMetadata(iterator_id, iterator_event_visitor, &metadata);
    }
    if (IsRootIteratorEvent(iterator_event_visitor)) {
      // Record root iterator events.
      (*root_iterator_event_map)[iterator_id].push_back(&iterator_event);
    }
  }
  auto* device_input_pipeline_second_iterator_events = gtl::FindOrNull(
      event_node_map, HostEventType::kDeviceInputPipelineSecondIterator);
  if (!device_input_pipeline_second_iterator_events) return;
  for (const EventNode& iterator_event :
       *device_input_pipeline_second_iterator_events) {
    const XEventVisitor& iterator_event_visitor =
        iterator_event.GetEventVisitor();
    auto iterator_id_stat = iterator_event_visitor.GetStat(StatType::kStepId);
    if (!iterator_id_stat.has_value()) continue;
    int64_t iterator_id = iterator_id_stat->IntValue();
    auto result = tf_data_stats->mutable_iterator_metadata()->insert(
        {iterator_id, IteratorMetadata()});
    IteratorMetadata& metadata = result.first->second;
    if (result.second) {
      // First time processing this iterator.
      SetIteratorMetadata(iterator_id, iterator_event_visitor, &metadata);
      // Find and record device input pipeline ids.
      absl::optional<int64_t> device_input_pipeline_id =
          FindDeviceInputPipeline(iterator_event_visitor);
      if (device_input_pipeline_id.has_value()) {
        device_input_pipeline_ids->insert(*device_input_pipeline_id);
      }
    }
  }
}

void SetInputPipelineMetadata(int64_t id, int64_t name_id,
                              bool is_device_input_pipeline,
                              InputPipelineMetadata* metadata) {
  constexpr absl::string_view kHostInputPipelinePrefix = "Host:";
  constexpr absl::string_view kDeviceInputPipelinePrefix = "Device:";
  metadata->set_id(id);
  if (is_device_input_pipeline) {
    metadata->set_type(InputPipelineMetadata::DEVICE);
    metadata->set_name(absl::StrCat(kDeviceInputPipelinePrefix, name_id));
  } else {
    metadata->set_type(InputPipelineMetadata::HOST);
    metadata->set_name(absl::StrCat(kHostInputPipelinePrefix, name_id));
  }
}
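
// For example, with id 123 as the third host input pipeline seen (name_id 2),
// the metadata is named "Host:2"; device input pipelines are named
// "Device:<name_id>" analogously.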

void ProcessIteratorEvent(const EventNode& iterator_event,
                          InputPipelineStat* input_pipeline_stat,
                          bool is_blocking) {
  const XEventVisitor& visitor = iterator_event.GetEventVisitor();
  auto iterator_id_stat = visitor.GetStat(StatType::kStepId);
  if (!iterator_id_stat.has_value()) return;
  int64_t iterator_id = iterator_id_stat->IntValue();
  auto result = input_pipeline_stat->mutable_iterator_stats()->insert(
      {iterator_id, IteratorStat()});
  IteratorStat& iterator_stat = result.first->second;
  if (result.second) {
    iterator_stat.set_id(iterator_id);
    iterator_stat.set_start_time_ps(visitor.TimestampPs());
  }
  iterator_stat.set_duration_ps(iterator_stat.duration_ps() +
                                visitor.DurationPs());
  int64_t self_time_ps = visitor.DurationPs();
  Timespan self_time_span = visitor.GetTimespan();
  for (EventNode* child : iterator_event.GetChildren()) {
    const XEventVisitor& child_visitor = child->GetEventVisitor();
    if (ParseTfOpFullname(child_visitor.Name()).category == Category::kTfData) {
      int64_t overlap_duration_ps =
          self_time_span.OverlappedDurationPs(child_visitor.GetTimespan());
      // A child can only block its parent if it overlaps the parent's span.
      ProcessIteratorEvent(*child, input_pipeline_stat,
                           is_blocking && overlap_duration_ps > 0);
      // Note: Assume no overlap between child events.
      self_time_ps -= overlap_duration_ps;
    }
  }
  iterator_stat.set_self_time_ps(iterator_stat.self_time_ps() + self_time_ps);
  iterator_stat.set_is_blocking(iterator_stat.is_blocking() || is_blocking);
  iterator_stat.set_num_calls(iterator_stat.num_calls() + 1);
}
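
// Worked example: if a root Prefetch call spans 100 us and its only tf.data
// child (say, a Range call) overlaps 60 us of that span, Prefetch accumulates
// 100 - 60 = 40 us of self time, and the recursion marks Range as blocking
// only when Prefetch itself is blocking and the overlap is nonzero.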

void SetBottleneckIteratorId(InputPipelineStat* input_pipeline_stat) {
  int64_t bottleneck_iterator_id = 0;
  int64_t max_self_time = 0;
  for (const auto& pair : input_pipeline_stat->iterator_stats()) {
    const auto& id = pair.first;
    const auto& iterator_stat = pair.second;
    if (iterator_stat.is_blocking() &&
        iterator_stat.self_time_ps() > max_self_time) {
      bottleneck_iterator_id = id;
      max_self_time = iterator_stat.self_time_ps();
    }
  }
  input_pipeline_stat->set_bottleneck_iterator_id(bottleneck_iterator_id);
  input_pipeline_stat->set_bottleneck_iterator_latency_ps(max_self_time);
}
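
// For example, given blocking iterators with self times {Map: 30 us,
// Range: 45 us} and a non-blocking Prefetch with 80 us, the bottleneck is
// Range; non-blocking iterators are skipped regardless of self time.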

void ProcessInputPipelines(
    const absl::flat_hash_set<int64_t>& device_input_pipeline_ids,
    absl::flat_hash_map<int64_t, std::vector<const EventNode*>>*
        root_iterator_event_map,
    TfDataStats* tf_data_stats) {
  auto* input_pipelines = tf_data_stats->mutable_input_pipelines();
  int64_t num_host_input_pipelines = 0;
  int64_t num_device_input_pipelines = 0;
  for (auto& id_and_events : *root_iterator_event_map) {
    auto& root_iterator_id = id_and_events.first;
    auto& root_iterator_events = id_and_events.second;
    // Sort the root iterator events by descending duration so that stats(0)
    // is the slowest execution trace.
    absl::c_sort(root_iterator_events,
                 [](const EventNode* lhs, const EventNode* rhs) {
                   return lhs->GetEventVisitor().DurationPs() >
                          rhs->GetEventVisitor().DurationPs();
                 });
    auto result =
        input_pipelines->insert({root_iterator_id, InputPipelineStats()});
    InputPipelineStats& input_pipeline_stats = result.first->second;
    InputPipelineMetadata* metadata = input_pipeline_stats.mutable_metadata();
    if (result.second) {
      bool is_device_input_pipeline =
          device_input_pipeline_ids.contains(root_iterator_id);
      int64_t name_id = is_device_input_pipeline ? num_device_input_pipelines++
                                                 : num_host_input_pipelines++;
      SetInputPipelineMetadata(root_iterator_id, name_id,
                               is_device_input_pipeline, metadata);
    }
    int64_t sum_latency_ps = 0;
    int64_t min_latency_ps = INT64_MAX;
    int64_t max_latency_ps = 0;
    int64_t num_slow_calls = 0;
    for (const EventNode* root_iterator_event : root_iterator_events) {
      InputPipelineStat* stat = input_pipeline_stats.add_stats();
      ProcessIteratorEvent(*root_iterator_event, stat,
                           /*is_blocking=*/true);
      SetBottleneckIteratorId(stat);
      int64_t latency_ps = root_iterator_event->GetEventVisitor().DurationPs();
      sum_latency_ps += latency_ps;
      min_latency_ps = std::min(min_latency_ps, latency_ps);
      max_latency_ps = std::max(max_latency_ps, latency_ps);
      if (latency_ps > kSlowCallThresholdPs) num_slow_calls++;
    }
    // `root_iterator_events` is non-empty by construction, so the division is
    // safe.
    input_pipeline_stats.set_avg_latency_ps(sum_latency_ps /
                                            root_iterator_events.size());
    input_pipeline_stats.set_min_latency_ps(min_latency_ps);
    input_pipeline_stats.set_max_latency_ps(max_latency_ps);
    input_pipeline_stats.set_num_slow_calls(num_slow_calls);
  }
}
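
// For example, three root iterator calls of 10, 40, and 70 us yield
// avg 40 us, min 10 us, max 70 us, and one slow call (only 70 us exceeds the
// 50 us threshold).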

void SetBottleneckAnalysis(CombinedTfDataStats* combined_tf_data_stats) {
  struct InputPipeline {
    InputPipeline(absl::string_view host_name,
                  absl::string_view input_pipeline_name, int64_t max_latency_ps,
                  absl::string_view iterator_name,
                  absl::string_view iterator_long_name,
                  int64_t iterator_latency_ps)
        : host_name(host_name),
          input_pipeline_name(input_pipeline_name),
          max_latency_ps(max_latency_ps),
          iterator_name(iterator_name),
          iterator_long_name(iterator_long_name),
          iterator_latency_ps(iterator_latency_ps) {}
    absl::string_view host_name;
    absl::string_view input_pipeline_name;
    int64_t max_latency_ps;
    absl::string_view iterator_name;
    absl::string_view iterator_long_name;
    int64_t iterator_latency_ps;

    // Inverted comparison so that sorting orders input pipelines by
    // descending max latency (slowest first).
    bool operator<(const InputPipeline& rhs) const {
      return max_latency_ps > rhs.max_latency_ps;
    }
  };
  std::vector<InputPipeline> slow_input_pipelines;
  for (const auto& host_name_and_tf_data_stats :
       combined_tf_data_stats->tf_data_stats()) {
    absl::string_view host_name = host_name_and_tf_data_stats.first;
    const TfDataStats& tf_data_stats = host_name_and_tf_data_stats.second;
    for (const auto& id_and_stats : tf_data_stats.input_pipelines()) {
      const InputPipelineStats& input_pipeline_stats = id_and_stats.second;
      if (input_pipeline_stats.metadata().type() ==
          InputPipelineMetadata::DEVICE) {
        // Ignore device input pipelines.
        continue;
      }
      // Choose the slowest execution trace of the input pipeline.
      // `input_pipeline_stats.stats` is sorted by descending duration, so the
      // first element is the slowest.
      const InputPipelineStat& input_pipeline_stat =
          input_pipeline_stats.stats(0);
      const IteratorMetadata& metadata = tf_data_stats.iterator_metadata().at(
          input_pipeline_stat.bottleneck_iterator_id());
      slow_input_pipelines.emplace_back(
          host_name, input_pipeline_stats.metadata().name(),
          input_pipeline_stats.max_latency_ps(), metadata.name(),
          metadata.long_name(),
          input_pipeline_stat.bottleneck_iterator_latency_ps());
    }
  }
  std::sort(slow_input_pipelines.begin(), slow_input_pipelines.end());
  for (const auto& input_pipeline : slow_input_pipelines) {
    TfDataBottleneckAnalysis* bottleneck_analysis =
        combined_tf_data_stats->add_bottleneck_analysis();
    bottleneck_analysis->set_host(input_pipeline.host_name.data(),
                                  input_pipeline.host_name.size());
    bottleneck_analysis->set_input_pipeline(
        input_pipeline.input_pipeline_name.data(),
        input_pipeline.input_pipeline_name.size());
    bottleneck_analysis->set_max_latency_ps(input_pipeline.max_latency_ps);
    bottleneck_analysis->set_iterator_name(input_pipeline.iterator_name.data(),
                                           input_pipeline.iterator_name.size());
    bottleneck_analysis->set_iterator_long_name(
        input_pipeline.iterator_long_name.data(),
        input_pipeline.iterator_long_name.size());
    bottleneck_analysis->set_iterator_latency_ps(
        input_pipeline.iterator_latency_ps);
  }
}

std::string GetSuggestion(BottleneckType type) {
  constexpr absl::string_view kPlaybookLink =
      "https://www.tensorflow.org/guide/data_performance_analysis";
  constexpr absl::string_view kPlaybookSourceDatasetLink =
      "https://www.tensorflow.org/guide/"
      "data_performance_analysis#source_datasets";
  constexpr absl::string_view kPlaybookCpuUtilizationLink =
      "https://www.tensorflow.org/guide/"
      "data_performance_analysis#3_are_you_reaching_high_cpu_utilization";
  constexpr absl::string_view kPlaybookTransformationLink =
      "https://www.tensorflow.org/guide/"
      "data_performance_analysis#transformation_datasets";
  constexpr absl::string_view kTfGuideParallelDataExtractionLink =
      "https://www.tensorflow.org/guide/"
      "data_performance#parallelizing_data_extraction";
  constexpr absl::string_view kTfGuideParallelTransformationLink =
      "https://www.tensorflow.org/guide/"
      "data_performance#parallelizing_data_transformation";
  constexpr absl::string_view kTfGuideCacheLink =
      "https://www.tensorflow.org/guide/data_performance#caching";
  constexpr absl::string_view kTfDataServiceLink =
      "https://www.tensorflow.org/api_docs/python/tf/data/experimental/"
      "service?version=nightly";
  switch (type) {
    case BottleneckType::kSlowSource:
      return absl::StrFormat(
          "1. Check the locality of a host and input data. Ideally, they "
          "should be in the same cell (or very close, like the same "
          "region).<br/>"
          "2. Parallelize reading from this dataset source. See %s and %s for "
          "more details.<br/>",
          AnchorElement(kPlaybookSourceDatasetLink, "here"),
          AnchorElement(kTfGuideParallelDataExtractionLink, "here"));
    case BottleneckType::kSlowDataService:
      return absl::StrFormat(
          "1. Fetching data from tf.data service took a while. Profile the "
          "tf.data service worker to analyze the issue further.<br/>"
          "2. See %s for more details on tf.data service.<br/>"
          "3. See %s for other suggestions.",
          AnchorElement(kTfDataServiceLink, "this"),
          AnchorElement(kPlaybookLink, "this"));
    case BottleneckType::kSlowRemoteSource:
      return absl::StrFormat(
          "1. The remote data source is slow. Profile its host to analyze the "
          "issue further.<br/>"
          "2. See %s for other suggestions.",
          AnchorElement(kPlaybookLink, "this"));
    case BottleneckType::kSlowTransformationWithParallelVersion:
      return absl::StrFormat(
          "1. Parallelize this transformation by setting "
          "<code>num_parallel_calls=tf.data.experimental.AUTOTUNE</code>. See "
          "%s for more details.<br/>"
          "2. Consider adding <code>cache</code> after this transformation if "
          "your data fits into memory and it is appropriate (e.g., there is no "
          "randomness in upstream transformations like <code>shuffle</code>). "
          "See %s for more details.<br/>"
          "3. Find more resources %s.",
          AnchorElement(kTfGuideParallelTransformationLink, "this"),
          AnchorElement(kTfGuideCacheLink, "this"),
          AnchorElement(kPlaybookTransformationLink, "here"));
    case BottleneckType::kSlowTransformationWithoutParallelVersion:
      return absl::StrFormat(
          "1. This transformation is inherently sequential. Add outer "
          "parallelism by running multiple copies of the input pipeline over "
          "sharded inputs and combining the results. See %s for more "
          "details.<br/>"
          "2. Consider adding <code>cache</code> after this transformation if "
          "your data fits into memory and it is appropriate (e.g., there is no "
          "randomness in upstream transformations like <code>shuffle</code>). "
          "See %s for more details.<br/>"
          "3. Find more resources %s.",
          AnchorElement(kPlaybookTransformationLink, "this"),
          AnchorElement(kTfGuideCacheLink, "this"),
          AnchorElement(kPlaybookCpuUtilizationLink, "here"));
    default:
      return absl::StrFormat("See %s for suggestions.",
                             AnchorElement(kPlaybookLink, "this"));
  }
}

void SetSuggestion(CombinedTfDataStats* combined_tf_data_stats) {
  for (TfDataBottleneckAnalysis& bottleneck_analysis :
       *combined_tf_data_stats->mutable_bottleneck_analysis()) {
    bottleneck_analysis.set_suggestion(
        GetSuggestion(GetBottleneckType(bottleneck_analysis.iterator_name())));
  }
}

void SetSummary(CombinedTfDataStats* combined_tf_data_stats) {
  int64_t max_latency_ps = 0;
  if (combined_tf_data_stats->bottleneck_analysis_size()) {
    max_latency_ps =
        combined_tf_data_stats->bottleneck_analysis().at(0).max_latency_ps();
  }
  if (max_latency_ps > kSlowCallThresholdPs) {
    combined_tf_data_stats->set_is_input_bound(true);
    combined_tf_data_stats->set_summary(
        "Your profile has a tf.data input pipeline slower than 50 us. For "
        "each slow input pipeline, the analysis below shows its bottleneck "
        "and a suggestion on how to fix it.");
  } else if (max_latency_ps > 0) {
    combined_tf_data_stats->set_is_input_bound(false);
    combined_tf_data_stats->set_summary(
        "Your profile does not have any tf.data input pipeline slower than "
        "50 us. Your job could still be input bound if this profile didn't "
        "capture all workers.");
  } else {
    combined_tf_data_stats->set_is_input_bound(false);
    combined_tf_data_stats->set_summary(
        "No tf.data activity captured in your profile. If your job uses "
        "tf.data, try to capture a longer profile.");
  }
}
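
// For example, with the 50 us threshold: a slowest input pipeline of 60 us
// marks the profile as input bound, 10 us does not, and no bottleneck
// analysis at all means no tf.data activity was captured.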

}  // namespace

BottleneckType GetBottleneckType(absl::string_view bottleneck_iterator_name) {
  static auto* kBottleneckTypeMap = new absl::flat_hash_map<absl::string_view,
                                                            BottleneckType>(
      {// Read from storage.
       {"TFRecord", BottleneckType::kSlowSource},
       {"SSTable", BottleneckType::kSlowSource},
       {"RecordIO", BottleneckType::kSlowSource},
       {"Spanner", BottleneckType::kSlowSource},
       {"TFColumn", BottleneckType::kSlowSource},
       {"SleepwalkRemoteDataset", BottleneckType::kSlowSource},
       {"TextLine", BottleneckType::kSlowSource},
       {"StitchedTimelineDataset", BottleneckType::kSlowSource},
       {"DateKeyDataset", BottleneckType::kSlowSource},
       {"CapacitorProto", BottleneckType::kSlowSource},
       {"LMDB", BottleneckType::kSlowSource},
       {"ExternalDataset", BottleneckType::kSlowSource},
       {"PearModel", BottleneckType::kSlowSource},
       {"FixedLengthRecordV2", BottleneckType::kSlowSource},
       // Read from local memory.
       {"FromTensor", BottleneckType::kSlowSource},
       {"TensorSlice", BottleneckType::kSlowSource},
       {"Generator", BottleneckType::kSlowSource},
       {"SyntheticDatasetOp", BottleneckType::kSlowSource},
       // tf.data service.
       {"DataService", BottleneckType::kSlowDataService},
       // Read from remote memory.
       {"GuzzlerDataGuzzlerRemoteDataset", BottleneckType::kSlowRemoteSource},
       {"ReverbDataset", BottleneckType::kSlowRemoteSource},
       {"DatasetSampleGame", BottleneckType::kSlowRemoteSource},
       {"Courier", BottleneckType::kSlowRemoteSource},
       {"ReverbEpisodeDataset", BottleneckType::kSlowRemoteSource},
       // Transformations with parallel version.
       {"Map", BottleneckType::kSlowTransformationWithParallelVersion},
       {"Interleave", BottleneckType::kSlowTransformationWithParallelVersion},
       // Transformations without parallel version.
       {"Filter", BottleneckType::kSlowTransformationWithoutParallelVersion},
       {"Batch", BottleneckType::kSlowTransformationWithoutParallelVersion},
       {"Unbatch", BottleneckType::kSlowTransformationWithoutParallelVersion}});
  if (auto type =
          gtl::FindOrNull(*kBottleneckTypeMap, bottleneck_iterator_name)) {
    return *type;
  }
  return BottleneckType::kOther;
}
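
// For example, GetBottleneckType("Map") returns
// BottleneckType::kSlowTransformationWithParallelVersion, while a name not in
// the map, such as "Shuffle", falls through to BottleneckType::kOther.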

void CombinedTfDataStatsBuilder::Add(absl::string_view host_name,
                                     XPlane* host_plane) {
  TfDataStats& tf_data_stats =
      (*combined_tf_data_stats_
            ->mutable_tf_data_stats())[std::string(host_name)];
  EventForest event_forest;
  event_forest.AddPlanes(CreateTfXPlaneVisitor, {host_plane});
  event_forest.ConnectEvents();
  event_forest.ConnectTfDataEvents();
  absl::flat_hash_set<int64_t> device_input_pipeline_ids;
  absl::flat_hash_map<int64_t, std::vector<const EventNode*>>
      root_iterator_event_map;
  ProcessEventForest(event_forest, &device_input_pipeline_ids,
                     &root_iterator_event_map, &tf_data_stats);
  ProcessInputPipelines(device_input_pipeline_ids, &root_iterator_event_map,
                        &tf_data_stats);
}

void CombinedTfDataStatsBuilder::Finalize() {
  SetBottleneckAnalysis(combined_tf_data_stats_);
  if (generate_suggestion_) SetSuggestion(combined_tf_data_stats_);
  SetSummary(combined_tf_data_stats_);
}
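
// Typical usage, as a minimal sketch (assuming `host_plane` points to the
// host XPlane of each profiled host):
//
//   CombinedTfDataStats combined_tf_data_stats;
//   CombinedTfDataStatsBuilder builder(&combined_tf_data_stats);
//   builder.Add("host_0", host_plane);  // Repeat for each host.
//   builder.Finalize();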

}  // namespace profiler
}  // namespace tensorflow