• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_
17 #define TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_
18 
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
25 
26 namespace tensorflow {
27 
28 // -------------------------------------------------------------------------- //
29 // MLIR passes running on Tensorflow function graphs (Tensorflow V2).
30 // -------------------------------------------------------------------------- //
31 
32 // Disabled - skip execution of the pass.
33 // Enabled - execute the pass, propagate errors to the caller if any.
34 // ShadowEnabled - execute the pass in a shadow mode. The pass should not commit
35 //   any changes to the MLIR module it's processing. Failures are not propagated
36 //   to the caller.
37 // FallbackEnabled - execute the pass and commit all the changes to the MLIR
38 //   module in case of success. Do not commit any changes in case of failures,
39 //   let the rest of the pipeline run.
enum class MlirOptimizationPassState {
  Disabled,         // Skip execution of the pass.
  Enabled,          // Execute the pass; propagate errors to the caller.
  ShadowEnabled,    // Execute without committing changes; errors are not
                    // propagated to the caller.
  FallbackEnabled   // Commit changes on success; on failure commit nothing
                    // and let the rest of the pipeline run.
};
46 
47 // An API for registering MLIR ModulePass with the Tensorflow runtime. These
48 // passes are running only for function graphs built by Tensorflow V2 and
49 // instantiated by the process_function_library_runtime (see
50 // FunctionOptimizationPass for details).
class MlirOptimizationPass {
 public:
  virtual ~MlirOptimizationPass() = default;

  // Human-readable name of the pass, used for logging and error reporting.
  virtual llvm::StringRef name() const = 0;

  // Returns an enum value:
  //   Enabled if the pass is enabled for the given graph with specified config.
  //   Disabled if the pass is disabled.
  //   ShadowEnabled if the pass needs to be executed in shadow mode.
  //
  // When the pass is ShadowEnabled, the pass is executed for metrics collection
  // and reporting purposes only, but none of the changes it makes to the MLIR
  // module will be committed.
  // `device_set` can be nullptr if the devices information is not
  // available or no device specific filtering is required.
  virtual MlirOptimizationPassState GetPassState(
      const DeviceSet* device_set, const ConfigProto& config_proto,
      const Graph& graph) const = 0;

  // Executes the pass on `module`. `graph` is presumably the TF graph the
  // module was converted from — confirm with the caller in the .cc file.
  // A non-OK status signals failure; how that failure is handled depends on
  // the pass state (see MlirOptimizationPassState above).
  virtual Status Run(const ConfigProto& config_proto, mlir::ModuleOp module,
                     const Graph& graph) = 0;
};
73 
74 class MlirOptimizationPassRegistry {
75  public:
76   struct PassRegistration {
77     int priority;
78     std::unique_ptr<MlirOptimizationPass> pass;
79   };
80 
81   struct PriorityComparator {
operatorPriorityComparator82     bool operator()(const PassRegistration& x,
83                     const PassRegistration& y) const {
84       return x.priority < y.priority;
85     }
86   };
87 
88   using Passes = std::set<PassRegistration, PriorityComparator>;
89 
90   // Returns the global registry of MLIR optimization passes.
91   static MlirOptimizationPassRegistry& Global();
92 
93   // Register optimization `pass` with the given `priority`.
Add(int priority,std::unique_ptr<MlirOptimizationPass> pass)94   void Add(int priority, std::unique_ptr<MlirOptimizationPass> pass) {
95     auto inserted = passes_.insert({priority, std::move(pass)});
96     CHECK(inserted.second)
97         << "Pass priority must be unique. "
98         << "Previously registered pass with the same priority: "
99         << inserted.first->pass->name().str();
100   }
101 
102   // Free the memory allocated for all passes.
ClearPasses()103   void ClearPasses() { passes_.clear(); }
104 
passes()105   const Passes& passes() const { return passes_; }
106 
107  private:
108   Passes passes_;
109 };
110 
111 // Function optimization pass that runs all MLIR passes registered in
112 // MlirOptimizationPassRegistry.
class MlirFunctionOptimizationPass : public FunctionOptimizationPass {
 public:
  // By default the pass reads from the process-wide registry; a different
  // registry can be injected (e.g. for testing).
  explicit MlirFunctionOptimizationPass(
      const MlirOptimizationPassRegistry* registry =
          &MlirOptimizationPassRegistry::Global())
      : registry_(registry) {}

  // Executes all of the underlying registered MlirOptimizationPasses.
  //
  // The MlirFunctionOptimizationPass will be executed in fully shadow mode if
  // all of the underlying registered MlirOptimizationPasses are ShadowEnabled.
  // In this case, no changes should be done to the original TF graph and no
  // failures propagated back to the user. Failures during the conversion
  // of TF graph to MLIR module and back will be treated as a soft
  // failures, e.g., relevant stats will be recorded and no error returned
  // back to the caller.
  //
  // In case some of the passes are shadow enabled while others are enabled,
  // failures in the enabled passes will be treated as real errors and
  // propagated back to the caller. Failure during the shadow pass execution
  // is a soft failure.
  Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
             std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
             std::vector<std::string>* control_ret_node_names,
             bool* control_rets_updated) override;

 private:
  // Not owned. Expected to outlive this pass; the default is a process-wide
  // singleton.
  const MlirOptimizationPassRegistry* registry_;
};
142 
143 // -------------------------------------------------------------------------- //
144 // MLIR passes running on Tensorflow V1 graphs.
145 // -------------------------------------------------------------------------- //
146 
147 // An API for registering MLIR ModulePass with the Tensorflow runtime. These
148 // passes are running only for V1 graphs (legacy graphs) executed via Session
149 // runtime. Graph importer updates legacy graph behavior to V2 constructs (e.g.
150 // it raises control flow from Switch/Merge nodes to functional control flow
151 // with If/While operations).
class MlirV1CompatOptimizationPass {
 public:
  virtual ~MlirV1CompatOptimizationPass() = default;

  // Human-readable name of the pass, used for logging and error reporting.
  virtual llvm::StringRef name() const = 0;

  // Returns true if the pass is enabled for the given graph with specified
  // config. `device_set` can be nullptr if the devices information is not
  // available or no device specific filtering is required.
  virtual bool IsEnabled(const DeviceSet* device_set,
                         const ConfigProto& config_proto,
                         const Graph& graph) const = 0;

  // Executes the pass on `module`. A non-OK status signals failure to the
  // caller.
  virtual Status Run(const GraphOptimizationPassOptions& options,
                     mlir::ModuleOp module) = 0;
};
167 
168 class MlirV1CompatOptimizationPassRegistry {
169  public:
170   struct PassRegistration {
171     int priority;
172     std::unique_ptr<MlirV1CompatOptimizationPass> pass;
173   };
174 
175   struct PriorityComparator {
operatorPriorityComparator176     bool operator()(const PassRegistration& x,
177                     const PassRegistration& y) const {
178       return x.priority < y.priority;
179     }
180   };
181 
182   using Passes = std::set<PassRegistration, PriorityComparator>;
183 
184   // Returns the global registry of MLIR optimization passes.
185   static MlirV1CompatOptimizationPassRegistry& Global();
186 
Add(int priority,std::unique_ptr<MlirV1CompatOptimizationPass> pass)187   void Add(int priority, std::unique_ptr<MlirV1CompatOptimizationPass> pass) {
188     auto inserted = passes_.insert({priority, std::move(pass)});
189     CHECK(inserted.second)
190         << "Pass priority must be unique. "
191         << "Previously registered pass with the same priority: "
192         << inserted.first->pass->name().str();
193   }
194 
passes()195   const Passes& passes() const { return passes_; }
196 
197  private:
198   Passes passes_;
199 };
200 
class MlirV1CompatGraphOptimizationPass : public GraphOptimizationPass {
 public:
  // By default the pass reads from the process-wide registry; a different
  // registry can be injected (e.g. for testing).
  explicit MlirV1CompatGraphOptimizationPass(
      const MlirV1CompatOptimizationPassRegistry* registry =
          &MlirV1CompatOptimizationPassRegistry::Global())
      : registry_(registry) {}

  // See GraphOptimizationPass::Run; implementation lives in the .cc file.
  Status Run(const GraphOptimizationPassOptions& options) override;

 private:
  // Not owned. Expected to outlive this pass; the default is a process-wide
  // singleton.
  const MlirV1CompatOptimizationPassRegistry* registry_;
};
213 
214 // -------------------------------------------------------------------------- //
215 // Helper classes for static registration of MLIR (V1 Compat) passes in the
216 // corresponding registry.
217 // -------------------------------------------------------------------------- //
218 
219 namespace mlir_pass_registration {
220 
// Registers an MlirOptimizationPass in the global registry as a side effect
// of construction (typically instantiated as a file-scope static).
class MlirOptimizationPassRegistration {
 public:
  explicit MlirOptimizationPassRegistration(
      int priority, std::unique_ptr<MlirOptimizationPass> pass) {
    MlirOptimizationPassRegistry::Global().Add(priority, std::move(pass));
  }
};
228 
// Registers an MlirV1CompatOptimizationPass in the global V1-compat registry
// as a side effect of construction (typically instantiated as a file-scope
// static).
class MlirV1CompatOptimizationPassRegistration {
 public:
  explicit MlirV1CompatOptimizationPassRegistration(
      int priority, std::unique_ptr<MlirV1CompatOptimizationPass> pass) {
    MlirV1CompatOptimizationPassRegistry::Global().Add(priority,
                                                       std::move(pass));
  }
};
237 
238 }  // namespace mlir_pass_registration
239 
240 }  // namespace tensorflow
241 
242 #endif  // TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_
243