• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
17 
18 #include <functional>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/strings/str_format.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/compiler/xla/service/dump.h"
25 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
26 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/core/platform/logging.h"
31 
32 namespace xla {
33 
34 template <typename HloT>
RunInvariantCheckers(HloT * hlo,absl::string_view after_pass_name)35 Status HloPassPipeline::RunInvariantCheckers(
36     HloT* hlo, absl::string_view after_pass_name) {
37   for (auto& invariant_checker : invariant_checkers_) {
38     VLOG(1) << "    Invariant checker " << invariant_checker->name();
39     StatusOr<bool> changed_status = RunHelper(invariant_checker.get(), hlo);
40     VLOG(1) << "    Invariant checker done " << invariant_checker->name();
41     if (!changed_status.ok()) {
42       VLOG(2) << "Failed invariant check:";
43       XLA_VLOG_LINES(2, hlo->ToString());
44       return Status(changed_status.status().code(),
45                     absl::StrCat(changed_status.status().error_message(),
46                                  "\n\nFailed after ", after_pass_name));
47     }
48     TF_RET_CHECK(!changed_status.ValueOrDie())
49         << "invariant checkers must not change the graph";
50   }
51   return Status::OK();
52 }
53 
54 template <typename HloT>
RunPassesInternal(HloT * hlo,absl::Span<HloPassInterface * const> passes)55 StatusOr<bool> HloPassPipeline::RunPassesInternal(
56     HloT* hlo, absl::Span<HloPassInterface* const> passes) {
57   string last_pass_name = "pipeline-start";
58   TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name));
59   bool changed = false;
60   for (HloPassInterface* pass : passes) {
61     VLOG(1) << "  HLO pass " << pass->name();
62     MaybeDumpHlo(*hlo,
63                  /*after_pass_name=*/last_pass_name,
64                  /*before_pass_name=*/pass->name());
65     TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
66     changed |= pass_changed;
67     TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass->name()));
68     last_pass_name = string(pass->name());
69   }
70   MaybeDumpHlo(*hlo,
71                /*after_pass_name=*/last_pass_name,
72                /*before_pass_name=*/"pipeline-end");
73   return changed;
74 }
75 
GetEnabledPasses(const DebugOptions & debug_options)76 std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
77     const DebugOptions& debug_options) {
78   auto repeated_field = debug_options.xla_disable_hlo_passes();
79   absl::flat_hash_set<string> disabled_pass_names(repeated_field.begin(),
80                                                   repeated_field.end());
81   if (debug_options.xla_disable_all_hlo_passes()) {
82     VLOG(1) << "*All* passes disabled by --xla_disable_all_hlo_passes.";
83     return {};
84   }
85 
86   if (!disabled_pass_names.empty()) {
87     VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
88             << absl::StrJoin(disabled_pass_names, ", ");
89   }
90 
91   std::vector<HloPassInterface*> enabled_passes;
92   for (auto& pass : passes_) {
93     if (!disabled_pass_names.contains(pass->name())) {
94       enabled_passes.push_back(pass.get());
95     }
96   }
97   return enabled_passes;
98 }
99 
MaybeDumpHlo(const HloModule & module,absl::string_view after_pass_name,absl::string_view before_pass_name)100 void HloPassPipeline::MaybeDumpHlo(const HloModule& module,
101                                    absl::string_view after_pass_name,
102                                    absl::string_view before_pass_name) {
103   DumpHloModuleBetweenPassesIfEnabled(name(), before_pass_name, after_pass_name,
104                                       module);
105 }
106 
MaybeDumpHlo(const HloModuleGroup & module_group,absl::string_view after_pass_name,absl::string_view before_pass_name)107 void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group,
108                                    absl::string_view after_pass_name,
109                                    absl::string_view before_pass_name) {
110   for (const HloModule* module : module_group.modules()) {
111     MaybeDumpHlo(*module, after_pass_name, before_pass_name);
112   }
113 }
114 
Run(HloModule * module)115 StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
116   run_called_ = true;
117 
118   VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": "
119           << name();
120 
121   return RunPassesInternal(module,
122                            GetEnabledPasses(module->config().debug_options()));
123 }
124 
RunOnModuleGroup(HloModuleGroup * module_group)125 StatusOr<bool> HloPassPipeline::RunOnModuleGroup(HloModuleGroup* module_group) {
126   run_called_ = true;
127 
128   VLOG(1) << "Running HLO pass pipeline on module group "
129           << module_group->name() << ": " << name();
130 
131   if (module_group->modules().empty()) {
132     VLOG(1) << "Module group is empty. Nothing to do.";
133     return false;
134   }
135 
136   return RunPassesInternal(
137       module_group,
138       GetEnabledPasses(module_group->module(0).config().debug_options()));
139 }
140 
141 }  // namespace xla
142