1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
17
18 #include <functional>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/strings/str_format.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/compiler/xla/service/dump.h"
25 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
26 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/core/platform/logging.h"
31
32 namespace xla {
33
34 template <typename HloT>
RunInvariantCheckers(HloT * hlo,absl::string_view after_pass_name)35 Status HloPassPipeline::RunInvariantCheckers(
36 HloT* hlo, absl::string_view after_pass_name) {
37 for (auto& invariant_checker : invariant_checkers_) {
38 VLOG(1) << " Invariant checker " << invariant_checker->name();
39 StatusOr<bool> changed_status = RunHelper(invariant_checker.get(), hlo);
40 VLOG(1) << " Invariant checker done " << invariant_checker->name();
41 if (!changed_status.ok()) {
42 VLOG(2) << "Failed invariant check:";
43 XLA_VLOG_LINES(2, hlo->ToString());
44 return Status(changed_status.status().code(),
45 absl::StrCat(changed_status.status().error_message(),
46 "\n\nFailed after ", after_pass_name));
47 }
48 TF_RET_CHECK(!changed_status.ValueOrDie())
49 << "invariant checkers must not change the graph";
50 }
51 return Status::OK();
52 }
53
54 template <typename HloT>
RunPassesInternal(HloT * hlo,absl::Span<HloPassInterface * const> passes)55 StatusOr<bool> HloPassPipeline::RunPassesInternal(
56 HloT* hlo, absl::Span<HloPassInterface* const> passes) {
57 string last_pass_name = "pipeline-start";
58 TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name));
59 bool changed = false;
60 for (HloPassInterface* pass : passes) {
61 VLOG(1) << " HLO pass " << pass->name();
62 MaybeDumpHlo(*hlo,
63 /*after_pass_name=*/last_pass_name,
64 /*before_pass_name=*/pass->name());
65 TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
66 changed |= pass_changed;
67 TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass->name()));
68 last_pass_name = string(pass->name());
69 }
70 MaybeDumpHlo(*hlo,
71 /*after_pass_name=*/last_pass_name,
72 /*before_pass_name=*/"pipeline-end");
73 return changed;
74 }
75
GetEnabledPasses(const DebugOptions & debug_options)76 std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
77 const DebugOptions& debug_options) {
78 auto repeated_field = debug_options.xla_disable_hlo_passes();
79 absl::flat_hash_set<string> disabled_pass_names(repeated_field.begin(),
80 repeated_field.end());
81 if (debug_options.xla_disable_all_hlo_passes()) {
82 VLOG(1) << "*All* passes disabled by --xla_disable_all_hlo_passes.";
83 return {};
84 }
85
86 if (!disabled_pass_names.empty()) {
87 VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
88 << absl::StrJoin(disabled_pass_names, ", ");
89 }
90
91 std::vector<HloPassInterface*> enabled_passes;
92 for (auto& pass : passes_) {
93 if (!disabled_pass_names.contains(pass->name())) {
94 enabled_passes.push_back(pass.get());
95 }
96 }
97 return enabled_passes;
98 }
99
MaybeDumpHlo(const HloModule & module,absl::string_view after_pass_name,absl::string_view before_pass_name)100 void HloPassPipeline::MaybeDumpHlo(const HloModule& module,
101 absl::string_view after_pass_name,
102 absl::string_view before_pass_name) {
103 DumpHloModuleBetweenPassesIfEnabled(name(), before_pass_name, after_pass_name,
104 module);
105 }
106
MaybeDumpHlo(const HloModuleGroup & module_group,absl::string_view after_pass_name,absl::string_view before_pass_name)107 void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group,
108 absl::string_view after_pass_name,
109 absl::string_view before_pass_name) {
110 for (const HloModule* module : module_group.modules()) {
111 MaybeDumpHlo(*module, after_pass_name, before_pass_name);
112 }
113 }
114
Run(HloModule * module)115 StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
116 run_called_ = true;
117
118 VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": "
119 << name();
120
121 return RunPassesInternal(module,
122 GetEnabledPasses(module->config().debug_options()));
123 }
124
RunOnModuleGroup(HloModuleGroup * module_group)125 StatusOr<bool> HloPassPipeline::RunOnModuleGroup(HloModuleGroup* module_group) {
126 run_called_ = true;
127
128 VLOG(1) << "Running HLO pass pipeline on module group "
129 << module_group->name() << ": " << name();
130
131 if (module_group->modules().empty()) {
132 VLOG(1) << "Module group is empty. Nothing to do.";
133 return false;
134 }
135
136 return RunPassesInternal(
137 module_group,
138 GetEnabledPasses(module_group->module(0).config().debug_options()));
139 }
140
141 } // namespace xla
142