1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_ 18 19 #include "absl/container/flat_hash_set.h" 20 #include "absl/strings/string_view.h" 21 #include "tensorflow/compiler/xla/service/hlo_module.h" 22 #include "tensorflow/compiler/xla/service/hlo_module_group.h" 23 #include "tensorflow/compiler/xla/status_macros.h" 24 #include "tensorflow/compiler/xla/statusor.h" 25 #include "tensorflow/compiler/xla/types.h" 26 27 namespace xla { 28 29 // Base class for HLO passes. These are used with the HloPassPipeline to 30 // organize a sequence of passes. An HLO pass should not extend this class 31 // directly; it should extend HloModulePass or HloModuleGroupPass. 32 class HloPassInterface { 33 public: 34 // Struct that holds states of pass runs across multiple iterations. 35 struct RunState { 36 // The current iteration number. 37 int iteration = 0; 38 // Set of all changed computations from all pass runs using this state. 39 absl::flat_hash_set<HloComputation*> changed; 40 // Set of changed computation from previous iteration. 41 absl::flat_hash_set<HloComputation*> changed_last_iteration; 42 // Set of changed computation from current iteration. 43 absl::flat_hash_set<HloComputation*> changed_this_iteration; 44 45 RunState() = default; RunStateRunState46 explicit RunState(HloModule* module) 47 : changed_last_iteration(module->computations().begin(), 48 module->computations().end()) {} 49 50 // Transition to the next iteration. 51 // 52 // Depending on the pass implmentation, one iteration includes all the work 53 // done between two IncrementIteration calls, there can be arbitrary number 54 // of passes that ran arbitrary times with this state. IncrementIterationRunState55 void IncrementIteration() { 56 using std::swap; 57 changed.insert(changed_this_iteration.begin(), 58 changed_this_iteration.end()); 59 swap(changed_last_iteration, changed_this_iteration); 60 changed_this_iteration.clear(); 61 ++iteration; 62 } 63 }; 64 virtual ~HloPassInterface() = default; 65 virtual absl::string_view name() const = 0; 66 67 // Run the pass on the given HLO module with specified execution_threads. 68 // Empty execution_threads list means all execution_threads are included. 69 // Returns whether it modified the module. Note that due to C++ inheritance 70 // hides overloaded function, Run(HloModule* module) is not a member function 71 // of a subclass unless it's explicitly brought to the subclass besides 72 // implementing the virtual version, for instance, 73 // 74 // class MyNewPass : public HloModulePass { 75 // public: 76 // MyNewPass(); 77 // absl::string_view name() const override { return "my-new-pass"; } 78 // 79 // using HloPassInterface::Run; 80 // StatusOr<bool> Run( 81 // HloModule* module, 82 // const absl::flat_hash_set<absl::string_view>& execution_threads) 83 // override; 84 // }; 85 // Run(HloModule * module)86 StatusOr<bool> Run(HloModule* module) { 87 return Run(module, /*execution_threads=*/{}); 88 } 89 virtual StatusOr<bool> Run( 90 HloModule* module, 91 const absl::flat_hash_set<absl::string_view>& execution_threads) = 0; 92 93 // Run the pass on computation on changed computations from last iteration in 94 // given HLO module for specified execution_threads, with caller provided 95 // RunState which holds the state information across multiple iterations. 96 // 97 // NOTE: This is a temporary default implementation that conservatively treats 98 // all computations as changed. Eventually all passes should override this 99 // method instead of Run() and Run() will call into this method instead. RunOnChangedComputations(HloModule * module,RunState * run_state,const absl::flat_hash_set<absl::string_view> & execution_threads)100 virtual Status RunOnChangedComputations( 101 HloModule* module, RunState* run_state, 102 const absl::flat_hash_set<absl::string_view>& execution_threads) { 103 TF_ASSIGN_OR_RETURN(bool changed, Run(module, execution_threads)); 104 if (changed) { 105 auto computations = module->computations(execution_threads); 106 run_state->changed_this_iteration.insert(computations.begin(), 107 computations.end()); 108 } 109 return OkStatus(); 110 } 111 112 // Run the pass on the given HLO module group for specified 113 // `execution_threads`. Empty `execution_threads` list means all execution 114 // threads are included. Returns whether it modified the module group. 115 // Ideally, the module group variant would be named "Run" as well, but C++ 116 // does not handle overloaded virtual methods well. 117 // 118 // Note that due to C++ inheritance hides overloaded function, 119 // RunOnModuleGroup(HloModuleGroup* module_group) is not a member function of 120 // a subclass unless it's explicitly brought to the subclass besides 121 // implementing the virtual version, for instance, 122 // 123 // class MyNewPass : public HloModuleGroupPass { 124 // public: 125 // MyNewPass(); 126 // absl::string_view name() const override { return "my-new-pass"; } 127 // 128 // using HloPassInterface::RunOnModuleGroup; 129 // StatusOr<bool> RunOnModuleGroup( 130 // HloModuleGroup* module_group, 131 // const absl::flat_hash_set<absl::string_view>& execution_threads) 132 // override; 133 // }; 134 // RunOnModuleGroup(HloModuleGroup * module_group)135 StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) { 136 return RunOnModuleGroup(module_group, /*execution_threads=*/{}); 137 } 138 virtual StatusOr<bool> RunOnModuleGroup( 139 HloModuleGroup* module_group, 140 const absl::flat_hash_set<absl::string_view>& execution_threads) = 0; 141 IsPassPipeline()142 virtual bool IsPassPipeline() { return false; } 143 }; 144 145 // Base class for passes which are module-scoped. 146 class HloModulePass : public HloPassInterface { 147 public: 148 // Runs the pass on a module group by iterating through each module in the 149 // group. RunOnModuleGroup(HloModuleGroup * module_group,const absl::flat_hash_set<absl::string_view> & execution_threads)150 StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group, 151 const absl::flat_hash_set<absl::string_view>& 152 execution_threads) override { 153 bool changed = false; 154 for (HloModule* module : module_group->modules()) { 155 TF_ASSIGN_OR_RETURN(bool module_changed, Run(module, execution_threads)); 156 changed |= module_changed; 157 } 158 return changed; 159 }; 160 161 // Update the layout of a Shape to one that is supported by a given backend. 162 // One can call this function after modifying the Shape in case that modifying 163 // the Shape requires changes to the layout for the given Backend. 164 // 165 // TODO(b/129084868): Make this Backend dependent instead of requiring 166 // deriving from the pass and overriding this function. UpdateLayout(Shape * shape)167 virtual void UpdateLayout(Shape* shape) {} 168 }; 169 170 // Base class for passes which are module-group scoped. These passes cannot run 171 // on an HLO module. 172 class HloModuleGroupPass : public HloPassInterface { 173 public: Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)174 StatusOr<bool> Run(HloModule* module, 175 const absl::flat_hash_set<absl::string_view>& 176 execution_threads) override { 177 return InternalError("Module group pass cannot be run on a module"); 178 } 179 }; 180 181 } // namespace xla 182 183 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_ 184