1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_PIPELINE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_PIPELINE_H_ 18 19 #include <algorithm> 20 #include <memory> 21 #include <string> 22 #include <vector> 23 24 #include "absl/memory/memory.h" 25 #include "absl/strings/str_cat.h" 26 #include "tensorflow/compiler/xla/service/compilation_stats.h" 27 #include "tensorflow/compiler/xla/service/hlo_module.h" 28 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 29 #include "tensorflow/compiler/xla/statusor.h" 30 #include "tensorflow/compiler/xla/types.h" 31 #include "tensorflow/core/platform/macros.h" 32 33 namespace xla { 34 35 class PhaseOrderPipeline; 36 37 // Pipeline of HLO passes. 38 class HloPassPipeline : public HloPassInterface { 39 public: 40 explicit HloPassPipeline(const string& name, 41 CompilationStats* compilation_stats = nullptr) name_(name)42 : name_(name), compilation_stats_(compilation_stats) { 43 if (compilation_stats == nullptr) { 44 empty_compilation_stats_ = CompilationStats::MakeNoopStats(); 45 compilation_stats_ = empty_compilation_stats_.get(); 46 } 47 } name()48 absl::string_view name() const override { return name_; } 49 50 // Add a pass to the pipeline. It should be called with the arguments for the 51 // pass constructor: 52 // 53 // pipeline.AddPass<FooPass>(constructor_arg1, constructor_arg2); 54 // 55 // Returns a reference to the added pass. 56 template <typename T, typename... Args> AddPass(Args &&...args)57 T& AddPass(Args&&... args) { 58 CHECK(!run_called_) << "AddPass cannot be called after Run"; 59 auto pass = new T(std::forward<Args>(args)...); 60 passes_.push_back(std::unique_ptr<T>(pass)); 61 return *pass; 62 } 63 64 // Add an invariant-checking pass to the pipeline. It will be run before and 65 // after each HLO pass. The invariant checking pass must not mutate the graph 66 // (it is required to always return "false" from its Run() method). 67 template <typename T, typename... Args> AddInvariantChecker(Args &&...args)68 T& AddInvariantChecker(Args&&... args) { 69 CHECK(!run_called_) << "AddInvariantChecker cannot be called after Run"; 70 auto pass = new T(std::forward<Args>(args)...); 71 invariant_checkers_.push_back(std::unique_ptr<T>(pass)); 72 return *pass; 73 } 74 75 // Add an invariant-checking pass to the pipeline on debug builds only. 76 template <typename T, typename... Args> AddInvariantCheckerDebug(Args &&...args)77 void AddInvariantCheckerDebug(Args&&... args) { 78 #ifndef NDEBUG 79 AddInvariantChecker<T>(std::forward<Args>(args)...); 80 #endif // NDEBUG 81 } 82 83 StatusOr<bool> Run(HloModule* module) override; 84 StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override; 85 IsPassPipeline()86 bool IsPassPipeline() override { return true; } 87 88 // Return size of passes_. PassesSize()89 int PassesSize() { return passes_.size(); } 90 // Return reference to pass specified by index. GetPass(int index)91 HloPassInterface& GetPass(int index) { return *passes_[index]; } 92 93 private: 94 // Returns the set of passes which are enabled. DebugOptions can selectively 95 // disable passes via --xla_disable_hlo_passes flag. 96 std::vector<HloPassInterface*> GetEnabledPasses( 97 const DebugOptions& debug_options); 98 99 // Maybe dumps the given module or module group depending on flag values 100 // contained in DebugOptions of module config. If it is dumped, saves the 101 // filenames of the dumps into module metadata. 102 void MaybeDumpHloAndSaveFilenames(HloModuleGroup& module_group, 103 absl::string_view after_pass_name, 104 absl::string_view before_pass_name); 105 void MaybeDumpHloAndSaveFilenames(HloModule& module, 106 absl::string_view after_pass_name, 107 absl::string_view before_pass_name); 108 109 // Runs the invariant checker on the given HLO. HloT can be either HloModule 110 // or HloModuleGroup. 111 template <typename HloT> 112 Status RunInvariantCheckers(HloT* hlo, absl::string_view after_pass_name); 113 114 // Helper which runs the given pass on the given HLO. HloT can be either 115 // HloModule or HloModuleGroup. 116 template <typename HloT> 117 StatusOr<bool> RunPassesInternal(HloT* hlo, 118 const DebugOptions& debug_options); 119 120 // Helpers which run the given passes on the given HLO construct. These 121 // helpers enable templating of the core of the pipeline logic by providing 122 // HloModule and HloModuleGroup specific methods with the same name. RunHelper(HloPassInterface * pass,HloModule * module)123 static StatusOr<bool> RunHelper(HloPassInterface* pass, HloModule* module) { 124 TF_ASSIGN_OR_RETURN(bool changed, pass->Run(module)); 125 module->Cleanup(); 126 return changed; 127 } RunHelper(HloPassInterface * pass,HloModuleGroup * module_group)128 static StatusOr<bool> RunHelper(HloPassInterface* pass, 129 HloModuleGroup* module_group) { 130 TF_ASSIGN_OR_RETURN(bool changed, pass->RunOnModuleGroup(module_group)); 131 module_group->Cleanup(); 132 return changed; 133 } 134 135 const string name_; 136 std::vector<std::unique_ptr<HloPassInterface>> passes_; 137 std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_; 138 bool run_called_ = false; 139 140 CompilationStats* compilation_stats_; 141 // Default stats instance for when one is not passed in the constructor. 142 // Use via compilation_stats_, not directly. 143 std::unique_ptr<CompilationStats> empty_compilation_stats_; 144 145 // Allow PhaseOrderPipeline to modify private passes_ member in order to 146 // perform PhaseOrdering. 147 friend class ::xla::PhaseOrderPipeline; 148 }; 149 150 } // namespace xla 151 152 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_PIPELINE_H_ 153