• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "tensorflow/core/platform/logging.h"
20 #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
21 
22 namespace tensorflow {
23 namespace profiler {
24 namespace {
25 
26 using OperationType = OpMetrics::MemoryAccessed::OperationType;
27 
CombinePrecisionStats(const PrecisionStats & src,PrecisionStats * dst)28 void CombinePrecisionStats(const PrecisionStats& src, PrecisionStats* dst) {
29   dst->set_compute_16bit_ps(src.compute_16bit_ps() + dst->compute_16bit_ps());
30   dst->set_compute_32bit_ps(src.compute_32bit_ps() + dst->compute_32bit_ps());
31 }
32 
33 }  // namespace
34 
CopyOpMetricsMetadata(const OpMetrics & src,OpMetrics * dst)35 void CopyOpMetricsMetadata(const OpMetrics& src, OpMetrics* dst) {
36   DCHECK(dst != nullptr);
37   DCHECK_EQ(src.hlo_module_id(), dst->hlo_module_id());
38   DCHECK_EQ(src.name(), dst->name());
39   if (dst->long_name().empty()) {
40     dst->set_long_name(src.long_name());
41   }
42   if (dst->category().empty()) {
43     dst->set_category(src.category());
44   }
45   if (dst->provenance().empty()) {
46     dst->set_provenance(src.provenance());
47   }
48   if (dst->deduplicated_name().empty()) {
49     dst->set_deduplicated_name(src.deduplicated_name());
50   }
51   if (!dst->has_layout() && src.has_layout()) {
52     *dst->mutable_layout() = src.layout();
53   }
54   if (!dst->has_children() && src.has_children()) {
55     *dst->mutable_children() = src.children();
56   }
57 }
58 
CombineOpMetrics(const OpMetrics & src,OpMetrics * dst)59 void CombineOpMetrics(const OpMetrics& src, OpMetrics* dst) {
60   DCHECK(dst != nullptr);
61   if (dst->occurrences() == 0) {
62     dst->set_min_time_ps(src.min_time_ps());
63   } else {
64     dst->set_min_time_ps(std::min(src.min_time_ps(), dst->min_time_ps()));
65   }
66   dst->set_is_eager(dst->is_eager() || src.is_eager());
67   dst->set_occurrences(src.occurrences() + dst->occurrences());
68   dst->set_time_ps(src.time_ps() + dst->time_ps());
69   dst->set_self_time_ps(src.self_time_ps() + dst->self_time_ps());
70   dst->set_flops(src.flops() + dst->flops());
71   dst->set_bytes_accessed(src.bytes_accessed() + dst->bytes_accessed());
72   CombineMemoryAccessedBreakdown(src.memory_accessed_breakdown(),
73                                  dst->mutable_memory_accessed_breakdown());
74   dst->set_dma_stall_ps(src.dma_stall_ps() + dst->dma_stall_ps());
75 }
76 
CombineMemoryAccessedBreakdown(const protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed> & src,protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed> * dst)77 void CombineMemoryAccessedBreakdown(
78     const protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed>& src,
79     protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed>* dst) {
80   if (src.empty()) return;
81   absl::flat_hash_map<std::pair<uint64 /*memory_space*/, OperationType>,
82                       OpMetrics_MemoryAccessed*>
83       dst_memory_accessed_map;
84   for (auto& dst_memory_accessed : *dst) {
85     dst_memory_accessed_map[{dst_memory_accessed.memory_space(),
86                              dst_memory_accessed.operation_type()}] =
87         &dst_memory_accessed;
88   }
89   for (const auto& src_memory_accessed : src) {
90     uint64 memory_space = src_memory_accessed.memory_space();
91     OperationType operation_type = src_memory_accessed.operation_type();
92     auto*& dst_memory_accessed =
93         dst_memory_accessed_map[{memory_space, operation_type}];
94     if (dst_memory_accessed == nullptr) {
95       dst_memory_accessed = dst->Add();
96       dst_memory_accessed->set_memory_space(memory_space);
97       dst_memory_accessed->set_operation_type(operation_type);
98     }
99     dst_memory_accessed->set_bytes_accessed(
100         src_memory_accessed.bytes_accessed() +
101         dst_memory_accessed->bytes_accessed());
102   }
103 }
104 
Combine(const OpMetricsDb & src)105 void OpMetricsDbCombiner::Combine(const OpMetricsDb& src) {
106   OpMetricsDb* dst = db();
107   dst->set_total_host_infeed_enq_duration_ps(
108       src.total_host_infeed_enq_duration_ps() +
109       dst->total_host_infeed_enq_duration_ps());
110   dst->set_total_host_infeed_enq_start_timestamp_ps_diff(
111       src.total_host_infeed_enq_start_timestamp_ps_diff() +
112       dst->total_host_infeed_enq_start_timestamp_ps_diff());
113   dst->set_total_time_ps(src.total_time_ps() + dst->total_time_ps());
114   dst->set_total_op_time_ps(src.total_op_time_ps() + dst->total_op_time_ps());
115   CombinePrecisionStats(src.precision_stats(), dst->mutable_precision_stats());
116 
117   for (const auto& src_metrics : src.metrics_db()) {
118     auto* dst_metrics = LookupOrInsertNewOpMetrics(src_metrics.hlo_module_id(),
119                                                    src_metrics.name());
120     CopyOpMetricsMetadata(src_metrics, dst_metrics);
121     CombineOpMetrics(src_metrics, dst_metrics);
122   }
123 }
124 
125 }  // namespace profiler
126 }  // namespace tensorflow
127