1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
17
18 #include <functional>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/strings/str_format.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/compiler/xla/service/dump.h"
25 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
26 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/core/platform/logging.h"
31
32 namespace xla {
33
34 namespace {
35
RecordPassStartMetadata(HloModule & module,const std::string & pass_name,const std::string & pipeline_name)36 void RecordPassStartMetadata(HloModule& module, const std::string& pass_name,
37 const std::string& pipeline_name) {
38 module.metadata()->RecordPassStart();
39 // An HloPassMetadata was just created so Status should always be OK.
40 TF_CHECK_OK(module.metadata()->set_current_pass_name(pass_name));
41 TF_CHECK_OK(module.metadata()->set_current_pass_pipeline_name(pipeline_name));
42 }
43
RecordPassStartMetadata(HloModuleGroup & module_group,const std::string & pass_name,const std::string & pipeline_name)44 void RecordPassStartMetadata(HloModuleGroup& module_group,
45 const std::string& pass_name,
46 const std::string& pipeline_name) {
47 for (HloModule* module : module_group.modules()) {
48 RecordPassStartMetadata(*module, pass_name, pipeline_name);
49 }
50 }
51
AttemptRecordPassEndMetadata(HloModule & module,const std::string & pass_name,bool module_changed)52 Status AttemptRecordPassEndMetadata(HloModule& module,
53 const std::string& pass_name,
54 bool module_changed) {
55 // Module id is set here instead of RecordPassStartMetadata because it may
56 // change in the middle of the pass, and we want the final id.
57 TF_RETURN_IF_ERROR(
58 module.metadata()->set_current_pass_module_id(module.unique_id()));
59 TF_RETURN_IF_ERROR(
60 module.metadata()->set_current_pass_module_changed(module_changed));
61 TF_RETURN_IF_ERROR(module.metadata()->RecordPassEnd());
62 return Status::OK();
63 }
64
RecordPassEndMetadata(HloModule & module,const std::string & pass_name,bool module_changed)65 void RecordPassEndMetadata(HloModule& module, const std::string& pass_name,
66 bool module_changed) {
67 Status status =
68 AttemptRecordPassEndMetadata(module, pass_name, module_changed);
69 if (!status.ok()) {
70 LOG(FATAL) << status;
71 }
72 }
73
AttemptRecordPassEndMetadata(HloModuleGroup & module_group,const std::string & pass_name,bool module_changed)74 Status AttemptRecordPassEndMetadata(HloModuleGroup& module_group,
75 const std::string& pass_name,
76 bool module_changed) {
77 for (HloModule* module : module_group.modules()) {
78 for (HloModule* other_module : module_group.modules()) {
79 TF_RETURN_IF_ERROR(
80 module->metadata()->add_current_pass_module_group_module_id(
81 other_module->unique_id()));
82 }
83 TF_RETURN_IF_ERROR(
84 AttemptRecordPassEndMetadata(*module, pass_name, module_changed));
85 }
86 return Status::OK();
87 }
88
RecordPassEndMetadata(HloModuleGroup & module_group,const std::string & pass_name,bool module_changed)89 void RecordPassEndMetadata(HloModuleGroup& module_group,
90 const std::string& pass_name, bool module_changed) {
91 Status status =
92 AttemptRecordPassEndMetadata(module_group, pass_name, module_changed);
93 if (!status.ok()) {
94 LOG(FATAL) << status;
95 }
96 }
97
SetInstructionMetadata(HloModule & module)98 void SetInstructionMetadata(HloModule& module) {
99 StatusOr<int64> pass_id = module.metadata()->current_pass_id();
100 if (!pass_id.ok()) {
101 LOG(FATAL) << pass_id.status();
102 }
103 for (xla::HloComputation* computation : module.computations()) {
104 for (xla::HloInstruction* instruction : computation->instructions()) {
105 if (instruction->metadata().creation_pass_id() == 0) {
106 instruction->set_creation_pass_id(*pass_id);
107 }
108 if (instruction->metadata().logical_creation_pass_id() == 0) {
109 instruction->set_logical_creation_pass_id(*pass_id);
110 }
111 }
112 }
113 }
114
SetInstructionMetadata(HloModuleGroup & module_group)115 void SetInstructionMetadata(HloModuleGroup& module_group) {
116 for (HloModule* module : module_group.modules()) {
117 SetInstructionMetadata(*module);
118 }
119 }
120
121 } // namespace
122
123 template <typename HloT>
RunInvariantCheckers(HloT * hlo,absl::string_view after_pass_name)124 Status HloPassPipeline::RunInvariantCheckers(
125 HloT* hlo, absl::string_view after_pass_name) {
126 for (auto& invariant_checker : invariant_checkers_) {
127 VLOG(1) << " Invariant checker " << invariant_checker->name();
128 StatusOr<bool> changed_status = RunHelper(invariant_checker.get(), hlo);
129 VLOG(1) << " Invariant checker done " << invariant_checker->name();
130 if (!changed_status.ok()) {
131 VLOG(2) << "Failed invariant check:";
132 XLA_VLOG_LINES(2, hlo->ToString());
133 return Status(changed_status.status().code(),
134 absl::StrCat(changed_status.status().error_message(),
135 "\n\nFailed after ", after_pass_name));
136 }
137 TF_RET_CHECK(!changed_status.ValueOrDie())
138 << "invariant checkers must not change the graph";
139 }
140 return Status::OK();
141 }
142
143 template <typename HloT>
RunPassesInternal(HloT * hlo,absl::Span<HloPassInterface * const> passes)144 StatusOr<bool> HloPassPipeline::RunPassesInternal(
145 HloT* hlo, absl::Span<HloPassInterface* const> passes) {
146 static constexpr absl::string_view kPipelineStart = "pipeline-start";
147 static constexpr absl::string_view kPipelineEnd = "pipeline-end";
148 std::string pipeline_name = std::string(name());
149
150 TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, kPipelineStart));
151
152 RecordPassStartMetadata(*hlo, std::string(kPipelineStart), pipeline_name);
153 SetInstructionMetadata(*hlo);
154 MaybeDumpHloAndSaveFilenames(*hlo,
155 /*after_pass_name=*/kPipelineStart,
156 /*before_pass_name=*/passes.empty()
157 ? kPipelineEnd
158 : passes.front()->name());
159 RecordPassEndMetadata(*hlo, std::string(kPipelineStart),
160 /*module_changed=*/false);
161
162 bool changed = false;
163 for (int i = 0; i < passes.size(); i++) {
164 HloPassInterface* pass = passes[i];
165 XLA_SCOPED_LOGGING_TIMER(absl::StrCat("HLO pass: ", pass->name()));
166 std::string pass_name = std::string(pass->name());
167 VLOG(1) << " HLO pass " << pass_name;
168 VLOG(2) << " Module hash " << hlo->Hash();
169 if (!pass->IsPassPipeline()) {
170 compilation_stats_->StartPass(pass_name);
171 }
172 RecordPassStartMetadata(*hlo, pass_name, pipeline_name);
173 TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
174 SetInstructionMetadata(*hlo);
175 MaybeDumpHloAndSaveFilenames(*hlo,
176 /*after_pass_name=*/pass_name,
177 /*before_pass_name=*/i + 1 >= passes.size()
178 ? kPipelineEnd
179 : passes[i + 1]->name());
180 RecordPassEndMetadata(*hlo, pass_name, pass_changed);
181 changed |= pass_changed;
182 if (pass_changed) {
183 VLOG(3) << " Pass caused changes " << pass->name();
184 }
185 TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass_name));
186 if (!pass->IsPassPipeline()) {
187 compilation_stats_->EndPass(pass_name);
188 }
189 }
190 return changed;
191 }
192
GetEnabledPasses(const DebugOptions & debug_options)193 std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
194 const DebugOptions& debug_options) {
195 if (debug_options.xla_disable_all_hlo_passes()) {
196 VLOG(1) << "*All* passes disabled by --xla_disable_all_hlo_passes.";
197 return {};
198 }
199
200 absl::flat_hash_set<string> disabled_pass_names(
201 debug_options.xla_disable_hlo_passes().begin(),
202 debug_options.xla_disable_hlo_passes().end());
203
204 absl::flat_hash_set<string> enabled_pass_names(
205 debug_options.xla_enable_hlo_passes_only().begin(),
206 debug_options.xla_enable_hlo_passes_only().end());
207
208 if (!disabled_pass_names.empty()) {
209 VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
210 << absl::StrJoin(disabled_pass_names, ", ");
211 }
212
213 if (!enabled_pass_names.empty()) {
214 VLOG(1) << "Passes enabled by --xla_enable_hlo_passes_only: "
215 << absl::StrJoin(enabled_pass_names, ", ");
216 }
217
218 CHECK(disabled_pass_names.empty() || enabled_pass_names.empty());
219
220 std::vector<HloPassInterface*> enabled_passes;
221 if (!enabled_pass_names.empty()) {
222 for (auto& pass : passes_) {
223 if (enabled_pass_names.contains(pass->name())) {
224 enabled_passes.push_back(pass.get());
225 }
226 }
227 } else {
228 for (auto& pass : passes_) {
229 if (!disabled_pass_names.contains(pass->name())) {
230 enabled_passes.push_back(pass.get());
231 }
232 }
233 }
234 return enabled_passes;
235 }
236
MaybeDumpHloAndSaveFilenames(HloModule & module,absl::string_view after_pass_name,absl::string_view before_pass_name)237 void HloPassPipeline::MaybeDumpHloAndSaveFilenames(
238 HloModule& module, absl::string_view after_pass_name,
239 absl::string_view before_pass_name) {
240 for (const std::string& filename : DumpHloModuleBetweenPassesIfEnabled(
241 name(), before_pass_name, after_pass_name, module)) {
242 Status status = module.metadata()->add_current_pass_dump_filename(filename);
243 if (!status.ok()) {
244 LOG(FATAL) << status;
245 }
246 }
247 }
248
MaybeDumpHloAndSaveFilenames(HloModuleGroup & module_group,absl::string_view after_pass_name,absl::string_view before_pass_name)249 void HloPassPipeline::MaybeDumpHloAndSaveFilenames(
250 HloModuleGroup& module_group, absl::string_view after_pass_name,
251 absl::string_view before_pass_name) {
252 for (HloModule* module : module_group.modules()) {
253 MaybeDumpHloAndSaveFilenames(*module, after_pass_name, before_pass_name);
254 }
255 }
256
Run(HloModule * module)257 StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
258 run_called_ = true;
259
260 VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": "
261 << name();
262
263 return RunPassesInternal(module,
264 GetEnabledPasses(module->config().debug_options()));
265 }
266
RunOnModuleGroup(HloModuleGroup * module_group)267 StatusOr<bool> HloPassPipeline::RunOnModuleGroup(HloModuleGroup* module_group) {
268 run_called_ = true;
269
270 VLOG(1) << "Running HLO pass pipeline on module group "
271 << module_group->name() << ": " << name();
272
273 if (module_group->modules().empty()) {
274 VLOG(1) << "Module group is empty. Nothing to do.";
275 return false;
276 }
277
278 return RunPassesInternal(
279 module_group,
280 GetEnabledPasses(module_group->module(0).config().debug_options()));
281 }
282
283 } // namespace xla
284