• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_
18 
19 #include "absl/container/flat_hash_set.h"
20 #include "absl/strings/string_view.h"
21 #include "tensorflow/compiler/xla/service/hlo_module.h"
22 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/compiler/xla/statusor.h"
25 #include "tensorflow/compiler/xla/types.h"
26 
27 namespace xla {
28 
29 // Base class for HLO passes. These are used with the HloPassPipeline to
30 // organize a sequence of passes. An HLO pass should not extend this class
31 // directly; it should extend HloModulePass or HloModuleGroupPass.
32 class HloPassInterface {
33  public:
34   // Struct that holds states of pass runs across multiple iterations.
35   struct RunState {
36     // The current iteration number.
37     int iteration = 0;
38     // Set of all changed computations from all pass runs using this state.
39     absl::flat_hash_set<HloComputation*> changed;
40     // Set of changed computation from previous iteration.
41     absl::flat_hash_set<HloComputation*> changed_last_iteration;
42     // Set of changed computation from current iteration.
43     absl::flat_hash_set<HloComputation*> changed_this_iteration;
44 
45     RunState() = default;
RunStateRunState46     explicit RunState(HloModule* module)
47         : changed_last_iteration(module->computations().begin(),
48                                  module->computations().end()) {}
49 
50     // Transition to the next iteration.
51     //
52     // Depending on the pass implmentation, one iteration includes all the work
53     // done between two IncrementIteration calls, there can be arbitrary number
54     // of passes that ran arbitrary times with this state.
IncrementIterationRunState55     void IncrementIteration() {
56       using std::swap;
57       changed.insert(changed_this_iteration.begin(),
58                      changed_this_iteration.end());
59       swap(changed_last_iteration, changed_this_iteration);
60       changed_this_iteration.clear();
61       ++iteration;
62     }
63   };
64   virtual ~HloPassInterface() = default;
65   virtual absl::string_view name() const = 0;
66 
67   // Run the pass on the given HLO module with specified execution_threads.
68   // Empty execution_threads list means all execution_threads are included.
69   // Returns whether it modified the module. Note that due to C++ inheritance
70   // hides overloaded function, Run(HloModule* module) is not a member function
71   // of a subclass unless it's explicitly brought to the subclass besides
72   // implementing the virtual version, for instance,
73   //
74   //   class MyNewPass : public HloModulePass {
75   //    public:
76   //      MyNewPass();
77   //      absl::string_view name() const override { return "my-new-pass"; }
78   //
79   //      using HloPassInterface::Run;
80   //      StatusOr<bool> Run(
81   //        HloModule* module,
82   //        const absl::flat_hash_set<absl::string_view>& execution_threads)
83   //        override;
84   //   };
85   //
Run(HloModule * module)86   StatusOr<bool> Run(HloModule* module) {
87     return Run(module, /*execution_threads=*/{});
88   }
89   virtual StatusOr<bool> Run(
90       HloModule* module,
91       const absl::flat_hash_set<absl::string_view>& execution_threads) = 0;
92 
93   // Run the pass on computation on changed computations from last iteration in
94   // given HLO module for specified execution_threads, with caller provided
95   // RunState which holds the state information across multiple iterations.
96   //
97   // NOTE: This is a temporary default implementation that conservatively treats
98   // all computations as changed. Eventually all passes should override this
99   // method instead of Run() and Run() will call into this method instead.
RunOnChangedComputations(HloModule * module,RunState * run_state,const absl::flat_hash_set<absl::string_view> & execution_threads)100   virtual Status RunOnChangedComputations(
101       HloModule* module, RunState* run_state,
102       const absl::flat_hash_set<absl::string_view>& execution_threads) {
103     TF_ASSIGN_OR_RETURN(bool changed, Run(module, execution_threads));
104     if (changed) {
105       auto computations = module->computations(execution_threads);
106       run_state->changed_this_iteration.insert(computations.begin(),
107                                                computations.end());
108     }
109     return OkStatus();
110   }
111 
112   // Run the pass on the given HLO module group for specified
113   // `execution_threads`. Empty `execution_threads` list means all execution
114   // threads are included. Returns whether it modified the module group.
115   // Ideally, the module group variant would be named "Run" as well, but C++
116   // does not handle overloaded virtual methods well.
117   //
118   // Note that due to C++ inheritance hides overloaded function,
119   // RunOnModuleGroup(HloModuleGroup* module_group) is not a member function of
120   // a subclass unless it's explicitly brought to the subclass besides
121   // implementing the virtual version, for instance,
122   //
123   //   class MyNewPass : public HloModuleGroupPass {
124   //    public:
125   //      MyNewPass();
126   //      absl::string_view name() const override { return "my-new-pass"; }
127   //
128   //      using HloPassInterface::RunOnModuleGroup;
129   //      StatusOr<bool> RunOnModuleGroup(
130   //        HloModuleGroup* module_group,
131   //        const absl::flat_hash_set<absl::string_view>& execution_threads)
132   //        override;
133   //   };
134   //
RunOnModuleGroup(HloModuleGroup * module_group)135   StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) {
136     return RunOnModuleGroup(module_group, /*execution_threads=*/{});
137   }
138   virtual StatusOr<bool> RunOnModuleGroup(
139       HloModuleGroup* module_group,
140       const absl::flat_hash_set<absl::string_view>& execution_threads) = 0;
141 
IsPassPipeline()142   virtual bool IsPassPipeline() { return false; }
143 };
144 
145 // Base class for passes which are module-scoped.
146 class HloModulePass : public HloPassInterface {
147  public:
148   // Runs the pass on a module group by iterating through each module in the
149   // group.
RunOnModuleGroup(HloModuleGroup * module_group,const absl::flat_hash_set<absl::string_view> & execution_threads)150   StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group,
151                                   const absl::flat_hash_set<absl::string_view>&
152                                       execution_threads) override {
153     bool changed = false;
154     for (HloModule* module : module_group->modules()) {
155       TF_ASSIGN_OR_RETURN(bool module_changed, Run(module, execution_threads));
156       changed |= module_changed;
157     }
158     return changed;
159   };
160 
161   // Update the layout of a Shape to one that is supported by a given backend.
162   // One can call this function after modifying the Shape in case that modifying
163   // the Shape requires changes to the layout for the given Backend.
164   //
165   // TODO(b/129084868): Make this Backend dependent instead of requiring
166   // deriving from the pass and overriding this function.
UpdateLayout(Shape * shape)167   virtual void UpdateLayout(Shape* shape) {}
168 };
169 
170 // Base class for passes which are module-group scoped. These passes cannot run
171 // on an HLO module.
172 class HloModuleGroupPass : public HloPassInterface {
173  public:
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)174   StatusOr<bool> Run(HloModule* module,
175                      const absl::flat_hash_set<absl::string_view>&
176                          execution_threads) override {
177     return InternalError("Module group pass cannot be run on a module");
178   }
179 };
180 
181 }  // namespace xla
182 
183 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_
184