1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_ 18 19 #include "absl/container/flat_hash_set.h" 20 #include "tensorflow/compiler/xla/service/hlo_module.h" 21 #include "tensorflow/compiler/xla/service/hlo_module_group.h" 22 #include "tensorflow/compiler/xla/status_macros.h" 23 #include "tensorflow/compiler/xla/statusor.h" 24 #include "tensorflow/compiler/xla/types.h" 25 #include "tensorflow/core/platform/macros.h" 26 27 namespace xla { 28 29 // Base class for HLO passes. These are used with the HloPassPipeline to 30 // organize a sequence of passes. An HLO pass should not extend this class 31 // directly; it should extend HloModulePass or HloModuleGroupPass. 32 class HloPassInterface { 33 public: 34 // Struct that holds states of pass runs across multiple iterations. 35 struct RunState { 36 // The current iteration number. 37 int iteration = 0; 38 // Set of all changed computations from all pass runs using this state. 39 absl::flat_hash_set<HloComputation*> changed; 40 // Set of changed computation from previous iteration. 41 absl::flat_hash_set<HloComputation*> changed_last_iteration; 42 // Set of changed computation from current iteration. 43 absl::flat_hash_set<HloComputation*> changed_this_iteration; 44 45 RunState() = default; RunStateRunState46 explicit RunState(HloModule* module) 47 : changed_last_iteration(module->computations().begin(), 48 module->computations().end()) {} 49 50 // Transition to the next iteration. 51 // 52 // Depending on the pass implmentation, one iteration includes all the work 53 // done between two IncrementIteration calls, there can be arbitrary number 54 // of passes that ran arbitrary times with this state. IncrementIterationRunState55 void IncrementIteration() { 56 using std::swap; 57 changed.insert(changed_this_iteration.begin(), 58 changed_this_iteration.end()); 59 swap(changed_last_iteration, changed_this_iteration); 60 changed_this_iteration.clear(); 61 ++iteration; 62 } 63 }; 64 virtual ~HloPassInterface() = default; 65 virtual absl::string_view name() const = 0; 66 67 // Run the pass on the given HLO module. Returns whether it modified the 68 // module. 69 virtual StatusOr<bool> Run(HloModule* module) = 0; 70 71 // Run the pass on computation on changed computations from last iteration in 72 // given HLO module, with caller provided RunState which holds the state 73 // information across multiple iterations. 74 // 75 // NOTE: This is a temporary default implementation that conservatively treats 76 // all computations as changed. Eventually all passes should override this 77 // method instead of Run() and Run() will call into this method instead. RunOnChangedComputations(HloModule * module,RunState * run_state)78 virtual Status RunOnChangedComputations(HloModule* module, 79 RunState* run_state) { 80 TF_ASSIGN_OR_RETURN(bool changed, Run(module)); 81 if (changed) { 82 auto computations = module->computations(); 83 run_state->changed_this_iteration.insert(computations.begin(), 84 computations.end()); 85 } 86 return Status::OK(); 87 } 88 89 // Run the pass on the given HLO module group. Returns whether it modified the 90 // module group. Ideally, the module group variant would be named "Run" as 91 // well, but C++ does not handle overloaded virtual methods well. 92 virtual StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) = 0; 93 IsPassPipeline()94 virtual bool IsPassPipeline() { return false; } 95 }; 96 97 // Base class for passes which are module-scoped. 98 class HloModulePass : public HloPassInterface { 99 public: 100 // Runs the pass on a module group by iterating through each module in the 101 // group. RunOnModuleGroup(HloModuleGroup * module_group)102 StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override { 103 bool changed = false; 104 for (HloModule* module : module_group->modules()) { 105 TF_ASSIGN_OR_RETURN(bool module_changed, Run(module)); 106 changed |= module_changed; 107 } 108 return changed; 109 }; 110 111 // Update the layout of a Shape to one that is supported by a given backend. 112 // One can call this function after modifying the Shape in case that modifying 113 // the Shape requires changes to the layout for the given Backend. 114 // 115 // TODO(b/129084868): Make this Backend dependent instead of requiring 116 // deriving from the pass and overriding this function. UpdateLayout(Shape * shape)117 virtual void UpdateLayout(Shape* shape) {} 118 }; 119 120 // Base class for passes which are module-group scoped. These passes cannot run 121 // on an HLO module. 122 class HloModuleGroupPass : public HloPassInterface { 123 public: Run(HloModule * module)124 StatusOr<bool> Run(HloModule* module) override { 125 return InternalError("Module group pass cannot be run on a module"); 126 } 127 }; 128 129 } // namespace xla 130 131 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_ 132