• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/algorithm/container.h"
24 #include "tensorflow/compiler/xla/service/hlo_execution_profile_data.pb.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/hlo_module.h"
27 #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/util.h"
30 
31 namespace xla {
HloProfileIndexMap(const HloModule & module,absl::Span<const std::string> extra_metrics)32 HloProfileIndexMap::HloProfileIndexMap(
33     const HloModule& module, absl::Span<const std::string> extra_metrics) {
34   size_t current_profile_index = 0;
35   for (xla::HloComputation* computation : module.MakeComputationPostOrder()) {
36     InsertOrDie(&computation_to_profile_idx_, computation,
37                 current_profile_index++);
38     for (const HloInstruction* instruction : computation->instructions()) {
39       // For simplicity we track all instructions here, but we could skip
40       // non-executing instructions like constants and parameters.
41       InsertOrDie(&instruction_to_profile_idx_, instruction,
42                   current_profile_index++);
43     }
44   }
45   for (const std::string& key : extra_metrics) {
46     InsertOrDie(&extra_metric_to_profile_idx_, key, current_profile_index++);
47   }
48 }
49 
CreateHloProfilePrinterData(const HloProfileIndexMap & hlo_profile_index_map,const HloCostAnalysis & cost_analysis,const std::string & entry_computation_name)50 std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData(
51     const HloProfileIndexMap& hlo_profile_index_map,
52     const HloCostAnalysis& cost_analysis,
53     const std::string& entry_computation_name) {
54   using HloComputationInfo = HloProfilePrinterData::HloComputationInfo;
55   using HloInstructionInfo = HloProfilePrinterData::HloInstructionInfo;
56 
57   size_t profile_counters_size = hlo_profile_index_map.total_count();
58 
59   std::unique_ptr<HloProfilePrinterData> profile_printer_data =
60       std::make_unique<HloProfilePrinterData>();
61   profile_printer_data->set_profile_counters_size(profile_counters_size);
62   profile_printer_data->mutable_computation_infos()->Reserve(
63       hlo_profile_index_map.computation_count());
64 
65   const auto& computation_to_profile_idx_map =
66       hlo_profile_index_map.computation_to_profile_idx();
67 
68   // computation_to_profile_idx_map's order is not deterministic so create a
69   // deterministic computation_and_profile_idx_list so that we end up with a
70   // deterministic HloProfilePrinterData protobuf.
71 
72   std::vector<std::pair<const HloComputation*, int64_t>>
73       computation_and_profile_idx_list(computation_to_profile_idx_map.begin(),
74                                        computation_to_profile_idx_map.end());
75 
76   // The profile indices were computed deterministically in
77   // HloProfileIndexMap::HloProfileIndexMap.
78   absl::c_sort(computation_and_profile_idx_list,
79                [](const std::pair<const HloComputation*, int64_t>& left,
80                   const std::pair<const HloComputation*, int64_t>& right) {
81                  return left.second < right.second;
82                });
83 
84   for (const auto& pair : computation_and_profile_idx_list) {
85     CHECK_LT(pair.second, profile_counters_size);
86     const HloComputation* computation = pair.first;
87     HloComputationInfo* computation_info =
88         profile_printer_data->add_computation_infos();
89 
90     computation_info->set_name(computation->name());
91     computation_info->set_profile_index(pair.second);
92     computation_info->mutable_instruction_infos()->Reserve(
93         computation->instruction_count());
94 
95     for (const HloInstruction* hlo : computation->instructions()) {
96       HloInstructionInfo* instruction_info =
97           computation_info->add_instruction_infos();
98       instruction_info->set_long_name(hlo->ToString());
99       instruction_info->set_short_name(hlo->ToString(
100           HloPrintOptions().set_compact_operands(true).set_print_operand_names(
101               false)));
102       instruction_info->set_category(hlo->ToCategory());
103       instruction_info->set_flop_count(cost_analysis.flop_count(*hlo));
104       instruction_info->set_transcendental_count(
105           cost_analysis.transcendental_count(*hlo));
106       instruction_info->set_bytes_accessed(cost_analysis.bytes_accessed(*hlo));
107       instruction_info->set_optimal_seconds(
108           cost_analysis.optimal_seconds(*hlo));
109       instruction_info->set_profile_index(
110           hlo_profile_index_map.GetProfileIndexFor(*hlo));
111     }
112   }
113 
114   // Add extra metrics if any.
115   for (const auto& pair : hlo_profile_index_map.extra_metric_to_profile_idx()) {
116     profile_printer_data->mutable_extra_metrics()->insert(
117         {pair.first, pair.second});
118   }
119 
120   profile_printer_data->set_entry_computation(entry_computation_name);
121 
122   return profile_printer_data;
123 }
124 
HloExecutionProfile(const HloProfilePrinterData * hlo_profile_printer_data,const HloProfileIndexMap * hlo_profile_index_map)125 HloExecutionProfile::HloExecutionProfile(
126     const HloProfilePrinterData* hlo_profile_printer_data,
127     const HloProfileIndexMap* hlo_profile_index_map)
128     : hlo_profile_printer_data_(*hlo_profile_printer_data),
129       hlo_profile_index_map_(*hlo_profile_index_map),
130       profile_counters_(
131           /*count=*/hlo_profile_index_map_.total_count(),
132           /*value=*/0) {}
133 
SetCyclesTakenBy(const HloInstruction * hlo,uint64_t cycles_taken)134 void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo,
135                                            uint64_t cycles_taken) {
136   SetCyclesTakenBy(hlo_profile_index_map_.GetProfileIndexFor(*hlo),
137                    cycles_taken);
138 }
139 
SetCyclesTakenBy(size_t index,uint64_t cycles_taken)140 void HloExecutionProfile::SetCyclesTakenBy(size_t index,
141                                            uint64_t cycles_taken) {
142   profile_counters_[index] = cycles_taken;
143 }
144 
GetCyclesTakenBy(const HloInstruction & hlo) const145 uint64_t HloExecutionProfile::GetCyclesTakenBy(
146     const HloInstruction& hlo) const {
147   return GetCyclesTakenBy(hlo_profile_index_map_.GetProfileIndexFor(hlo));
148 }
149 
GetCyclesTakenBy(size_t index) const150 uint64_t HloExecutionProfile::GetCyclesTakenBy(size_t index) const {
151   return profile_counters_[index];
152 }
153 
ToProto() const154 HloExecutionProfileData HloExecutionProfile::ToProto() const {
155   HloExecutionProfileData hlo_execution_profile_data;
156   for (const auto& counter : profile_counters_) {
157     hlo_execution_profile_data.add_profile_counters(counter);
158   }
159   *(hlo_execution_profile_data.mutable_printer_data()) =
160       hlo_profile_printer_data_;
161   return hlo_execution_profile_data;
162 }
163 
164 }  // namespace xla
165