/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_
#define TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_

#include <functional>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {

// -------------------------------------------------------------------------- //
// MLIR passes running on Tensorflow function graphs (Tensorflow V2).
// -------------------------------------------------------------------------- //

// Disabled - skip execution of the pass.
// Enabled - execute the pass, propagate errors to the caller if any.
// ShadowEnabled - execute the pass in a shadow mode. The pass should not commit
// any changes to the MLIR module it's processing. Failures are not propagated
// to the caller.
// FallbackEnabled - execute the pass and commit all the changes to the MLIR
// module in case of success. Do not commit any changes in case of failures,
// let the rest of the pipeline run.
40 enum class MlirOptimizationPassState { 41 Disabled, 42 Enabled, 43 ShadowEnabled, 44 FallbackEnabled 45 }; 46 47 // An API for registering MLIR ModulePass with the Tensorflow runtime. These 48 // passes are running only for function graphs built by Tensorflow V2 and 49 // instantiated by the process_function_library_runtime (see 50 // FunctionOptimizationPass for details). 51 class MlirOptimizationPass { 52 public: 53 virtual ~MlirOptimizationPass() = default; 54 virtual llvm::StringRef name() const = 0; 55 56 // Returns an enum value: 57 // Enabled if the pass is enabled for the given graph with specified config. 58 // Disabled if the pass is disabled. 59 // ShadowEnabled if the pass needs to be executed in shadow mode. 60 // 61 // When the pass is ShadowEnabled, the pass is executed for metrics collection 62 // and reporting purposes only, but none of the changes it makes to the MLIR 63 // module will be committed. 64 // `device_set` can be nullptr if the devices information is not 65 // available or no device specific filtering is required. 66 virtual MlirOptimizationPassState GetPassState( 67 const DeviceSet* device_set, const ConfigProto& config_proto, 68 const Graph& graph) const = 0; 69 70 virtual Status Run(const ConfigProto& config_proto, mlir::ModuleOp module, 71 const Graph& graph) = 0; 72 }; 73 74 class MlirOptimizationPassRegistry { 75 public: 76 struct PassRegistration { 77 int priority; 78 std::unique_ptr<MlirOptimizationPass> pass; 79 }; 80 81 struct PriorityComparator { operatorPriorityComparator82 bool operator()(const PassRegistration& x, 83 const PassRegistration& y) const { 84 return x.priority < y.priority; 85 } 86 }; 87 88 using Passes = std::set<PassRegistration, PriorityComparator>; 89 90 // Returns the global registry of MLIR optimization passes. 91 static MlirOptimizationPassRegistry& Global(); 92 93 // Register optimization `pass` with the given `priority`. 
Add(int priority,std::unique_ptr<MlirOptimizationPass> pass)94 void Add(int priority, std::unique_ptr<MlirOptimizationPass> pass) { 95 auto inserted = passes_.insert({priority, std::move(pass)}); 96 CHECK(inserted.second) 97 << "Pass priority must be unique. " 98 << "Previously registered pass with the same priority: " 99 << inserted.first->pass->name().str(); 100 } 101 102 // Free the memory allocated for all passes. ClearPasses()103 void ClearPasses() { passes_.clear(); } 104 passes()105 const Passes& passes() const { return passes_; } 106 107 private: 108 Passes passes_; 109 }; 110 111 // Function optimization pass that runs all MLIR passes registered in 112 // MlirOptimizationPassRegistry. 113 class MlirFunctionOptimizationPass : public FunctionOptimizationPass { 114 public: 115 explicit MlirFunctionOptimizationPass( 116 const MlirOptimizationPassRegistry* registry = 117 &MlirOptimizationPassRegistry::Global()) registry_(registry)118 : registry_(registry) {} 119 120 // Executes all of the underlying registered MlirOptimizationPasses. 121 // 122 // The MlirFunctionOptimizationPass will be executed in fully shadow mode if 123 // all of the underlying registered MlirOptimizationPasses are ShadowEnabled. 124 // In this case, no changes should be done to the original TF graph and no 125 // failures propagated back to the user. Failures during the conversion 126 // of TF graph to MLIR module and back will be treated as a soft 127 // failures, e.g., relevant stats will be recorded and no error returned 128 // back to the caller. 129 // 130 // In case some of the passes are shadow enabled while others are enabled, 131 // failures in the enabled passes will be treated as real errors and 132 // propagated back to the caller. Failure during the shadow pass execution 133 // is a soft failure. 
134 Status Run(const DeviceSet& device_set, const ConfigProto& config_proto, 135 std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def, 136 std::vector<std::string>* control_ret_node_names, 137 bool* control_rets_updated) override; 138 139 private: 140 const MlirOptimizationPassRegistry* registry_; 141 }; 142 143 // -------------------------------------------------------------------------- // 144 // MLIR passes running on Tensorflow V1 graphs. 145 // -------------------------------------------------------------------------- // 146 147 // An API for registering MLIR ModulePass with the Tensorflow runtime. These 148 // passes are running only for V1 graphs (legacy graphs) executed via Session 149 // runtime. Graph importer updates legacy graph behavior to V2 constructs (e.g. 150 // it raises control flow from Switch/Merge nodes to functional control flow 151 // with If/While operations). 152 class MlirV1CompatOptimizationPass { 153 public: 154 virtual ~MlirV1CompatOptimizationPass() = default; 155 virtual llvm::StringRef name() const = 0; 156 157 // Returns true if the pass is enabled for the given graph with specified 158 // config. `device_set` can be nullptr if the devices information is not 159 // available or no device specific filtering is required. 
160 virtual bool IsEnabled(const DeviceSet* device_set, 161 const ConfigProto& config_proto, 162 const Graph& graph) const = 0; 163 164 virtual Status Run(const GraphOptimizationPassOptions& options, 165 mlir::ModuleOp module) = 0; 166 }; 167 168 class MlirV1CompatOptimizationPassRegistry { 169 public: 170 struct PassRegistration { 171 int priority; 172 std::unique_ptr<MlirV1CompatOptimizationPass> pass; 173 }; 174 175 struct PriorityComparator { operatorPriorityComparator176 bool operator()(const PassRegistration& x, 177 const PassRegistration& y) const { 178 return x.priority < y.priority; 179 } 180 }; 181 182 using Passes = std::set<PassRegistration, PriorityComparator>; 183 184 // Returns the global registry of MLIR optimization passes. 185 static MlirV1CompatOptimizationPassRegistry& Global(); 186 Add(int priority,std::unique_ptr<MlirV1CompatOptimizationPass> pass)187 void Add(int priority, std::unique_ptr<MlirV1CompatOptimizationPass> pass) { 188 auto inserted = passes_.insert({priority, std::move(pass)}); 189 CHECK(inserted.second) 190 << "Pass priority must be unique. " 191 << "Previously registered pass with the same priority: " 192 << inserted.first->pass->name().str(); 193 } 194 passes()195 const Passes& passes() const { return passes_; } 196 197 private: 198 Passes passes_; 199 }; 200 201 class MlirV1CompatGraphOptimizationPass : public GraphOptimizationPass { 202 public: 203 explicit MlirV1CompatGraphOptimizationPass( 204 const MlirV1CompatOptimizationPassRegistry* registry = 205 &MlirV1CompatOptimizationPassRegistry::Global()) registry_(registry)206 : registry_(registry) {} 207 208 Status Run(const GraphOptimizationPassOptions& options) override; 209 210 private: 211 const MlirV1CompatOptimizationPassRegistry* registry_; 212 }; 213 214 // -------------------------------------------------------------------------- // 215 // Helper classes for static registration of MLIR (V1 Compat) passes in the 216 // corresponding registry. 
217 // -------------------------------------------------------------------------- // 218 219 namespace mlir_pass_registration { 220 221 class MlirOptimizationPassRegistration { 222 public: MlirOptimizationPassRegistration(int priority,std::unique_ptr<MlirOptimizationPass> pass)223 explicit MlirOptimizationPassRegistration( 224 int priority, std::unique_ptr<MlirOptimizationPass> pass) { 225 MlirOptimizationPassRegistry::Global().Add(priority, std::move(pass)); 226 } 227 }; 228 229 class MlirV1CompatOptimizationPassRegistration { 230 public: MlirV1CompatOptimizationPassRegistration(int priority,std::unique_ptr<MlirV1CompatOptimizationPass> pass)231 explicit MlirV1CompatOptimizationPassRegistration( 232 int priority, std::unique_ptr<MlirV1CompatOptimizationPass> pass) { 233 MlirV1CompatOptimizationPassRegistry::Global().Add(priority, 234 std::move(pass)); 235 } 236 }; 237 238 } // namespace mlir_pass_registration 239 240 } // namespace tensorflow 241 242 #endif // TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_ 243