/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"

#include "llvm/ADT/Optional.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"

namespace mlir {
/// Create a pass to convert from the TFExecutor to the TF control dialect.
std::unique_ptr<OperationPass<FuncOp>>
CreateTFExecutorToControlDialectConversion();
}  // namespace mlir

namespace tensorflow {
namespace {
// Data layout supported by TFLite.
const char kTFLiteDataLayout[] = "NHWC";
}  // namespace

void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
                           mlir::OpPassManager* pass_manager) {
  pass_manager->addNestedPass<mlir::FuncOp>(
      mlir::TFL::CreatePrepareQuantizePass(quant_specs));
  if (quant_specs.default_ranges.first.hasValue() ||
      quant_specs.default_ranges.second.hasValue()) {
    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TFL::CreateDefaultQuantParamsPass(
            quant_specs.default_ranges.first.getValueOr(0.0),
            quant_specs.default_ranges.second.getValueOr(0.0),
            quant_specs.IsSignedInferenceType()));
  }
  pass_manager->addNestedPass<mlir::FuncOp>(
      mlir::TFL::CreateQuantizePass(quant_specs.verify_numeric));
  bool emit_quant_adaptor_ops =
      quant_specs.inference_type != quant_specs.inference_input_type;
  pass_manager->addNestedPass<mlir::FuncOp>(
      mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
}

void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
                                mlir::OpPassManager* pass_manager,
                                llvm::Optional<tensorflow::Session*> session) {
  mlir::TF::StandardPipelineOptions standard_pipeline_options;
  standard_pipeline_options.enable_inliner = false;
  standard_pipeline_options.form_clusters = pass_config.form_clusters;
  mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
  pass_manager->addNestedPass<mlir::FuncOp>(
      mlir::TF::CreateDeviceIndexSelectorPass());

  // Add a canonicalize pass to remove the no-op session initializer pass.
  pass_manager->addPass(mlir::createCanonicalizerPass());

  if (pass_config.shape_inference) {
    pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
  }

  if (session.hasValue()) {
    // Add a pass that converts reference variables to resource variables.
    pass_manager->addPass(
        mlir::TF::
            CreateConvertReadonlyReferenceVariablesToResourceVariablesPass());

    // Add a pass that promotes resource variables to function arguments.
    pass_manager->addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass());

    // Add a pass that creates global tensors and converts the function
    // arguments to the tf_saved_model.bound_input arguments.
    pass_manager->addPass(
        mlir::tf_saved_model::CreateLiftVariablesPass(session.getValue()));
  }

  // Keep this pass after the shape inference pass, which cannot do shape
  // inference for non-TF ops.
  if (!pass_config.quant_specs.serialized_quant_stats.empty()) {
    pass_manager->addPass(
        mlir::quant::CreateImportQuantStatsPassForTFControlDialect(
            pass_config.quant_specs.serialized_quant_stats));
  }

  pass_manager->addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());

  // The conversion pipeline has to follow this order:
  // 1) Saved-model-related optimizations, like decomposing resource ops.
  // 2) Convert composite functions like lstm/rnns, along with proper function
  //    inlining & dce.
  // 3) Lower static tensor list ops.

  // This decomposes resource ops like ResourceGather into a read-variable op
  // followed by a gather. This is used when the saved model import path is
  // used, during which resources don't get frozen in the Python layer.
  pass_manager->addNestedPass<mlir::FuncOp>(
      mlir::TFDevice::CreateDecomposeResourceOpsPass());

  // Note:
  // We need to fuse composite ops before the LowerStaticTensorList pass, since
  // TensorFlow list ops are not supported by that pass right now.
  // Enable fusing composite ops that can be lowered to built-in TFLite ops.
  if (pass_config.emit_builtin_tflite_ops) {
    pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
  }

  pass_manager->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());

  pass_manager->addPass(mlir::createInlinerPass());
  pass_manager->addPass(mlir::createSymbolDCEPass());

  if (pass_config.lower_tensor_list_ops) {
    // TODO(haoliang): Add this pass by default.
    pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
  }

  // This pass does resource analysis of saved model global tensors and marks
  // those deemed read-only as immutable.
  pass_manager->addPass(
      mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());

  if (pass_config.shape_inference) {
    // Add a shape inference pass to optimize away the unnecessary casts.
    pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
  }

  // Legalize while early to allow further constant folding.
  // TODO(jpienaar): This may not actually matter as we do canonicalization
  // after the legalize below; for now it needs to be below the above passes
  // that work on the TF dialect and before the inliner so that the function
  // calls in body and cond are inlined for optimization.
  if (pass_config.legalize_tf_while) {
    pass_manager->addPass(mlir::TFL::CreateLegalizeTFWhilePass());
  }

  // Add a function inlining pass. Both TF and TFLite dialects are opted into
  // the function inliner interface.
  pass_manager->addPass(mlir::createInlinerPass());

  // TODO(jpienaar): Revise post dialect constants.
  pass_manager->addNestedPass<mlir::FuncOp>(
      mlir::TF::CreateDecodeConstantPass());
  // Canonicalization includes const folding, which is utilized here to
  // optimize away ops that can't get constant folded after the PrepareTF pass.
  // For example, tf.Conv2D is split into tf.Transpose and tfl.Conv2D.
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
  // This pass does dead code elimination based on symbol visibility.
  pass_manager->addPass(mlir::createSymbolDCEPass());

  if (!pass_config.disable_variable_freezing) {
    // This pass 'freezes' immutable global tensors and inlines them as tf
    // constant ops.
    pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass(
        /*allow_mutable_tensors=*/pass_config.enable_tflite_variables));
  }

  // The passes below only make sense if builtin TFLite ops are enabled for
  // emission.
  if (pass_config.emit_builtin_tflite_ops) {
    // Run shape inference after variables are converted to constants.
    if (pass_config.shape_inference) {
      pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
    }
    // Force the layout supported by TFLite; this will transpose the data
    // to match 'kTFLiteDataLayout'.
    mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options;
    layout_optimization_options.force_data_format = kTFLiteDataLayout;
    layout_optimization_options.skip_fold_transpose_in_ops = true;
    mlir::TF::CreateLayoutOptimizationPipeline(
        pass_manager->nest<mlir::FuncOp>(), layout_optimization_options);
    // Prepare for TFLite dialect, rerun canonicalization, and then legalize to
    // the TFLite dialect.
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreatePrepareTFPass(
        pass_config.unfold_batch_matmul,
        /*allow_bf16_and_f16_type_legalization=*/!pass_config
            .runtime_verification));
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
    if (pass_config.shape_inference) {
      // Add a shape inference pass to optimize away the unnecessary casts.
      // This also fixes the unranked shapes due to TF ops constant folding.
      // TODO(fengliuai): remove this pass if TableGen patterns have a better
      // way to control the shapes for the intermediate results.
      pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
    }

    // Inline function calls that are left in the graph after folding
    // functional control flow ops (IfOp, CaseOp).
    pass_manager->addPass(mlir::createInlinerPass());

    // This pass removes the asset file dependencies in hash table use cases.
    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TF::CreateInitTextFileToImportPass());

    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
    if (pass_config.enable_tflite_variables) {
      pass_manager->addPass(mlir::TFL::CreateInitializeVariablesPass());
      pass_manager->addPass(mlir::TFL::CreateLegalizeVariablesPass());
      pass_manager->addPass(mlir::TFL::CreateRemoveArgsAndGlobalTensors());
    }
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateOptimizePass());
    // This pass operates on TensorFlow ops but is triggered after legalization
    // so that it can target constants introduced once TensorFlow Identity ops
    // are removed during legalization.
    pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TFL::CreateRaiseCustomOpsPass());
    pass_manager->addPass(mlir::createSymbolDCEPass());
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());

    // Run quantization after all of the floating point model conversion is
    // completed.
    if (pass_config.quant_specs.RunPropagationAndRewriteQuantizationPasses()) {
      AddQuantizationPasses(pass_config.quant_specs, pass_manager);
    }

    // This pass should always be at the end of the model
    // conversion (even after quantization). Some TFL ops like unidirectional
    // sequence lstm have stateful operands, and some optimization passes
    // will merge those operands if they have identical values & types.
    // However, this is not desired by TFL. This pass serves as a "fix" to
    // split the merged inputs until we have first-class variable support or
    // reuse tf.variable to model this.
    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TFL::CreateSplitMergedOperandsPass());

    // Add CallOnceOp when there is a session initializer function in the tf
    // saved model dialect.
    pass_manager->addPass(
        mlir::TFL::CreateInsertCallOnceOpFromSessionInitializerPass());
  }
}
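
// Illustrative usage sketch (not part of the upstream file; caller names such
// as `module` are assumed): a converter typically builds a PassManager, fills
// in a PassConfig, schedules the pipeline above, and then runs it on the
// imported module:
//
//   mlir::PassManager pm(module.getContext());
//   mlir::TFL::QuantizationSpecs quant_specs;
//   mlir::TFL::PassConfig pass_config(quant_specs);
//   pass_config.lower_tensor_list_ops = true;
//   tensorflow::AddTFToTFLConversionPasses(pass_config, &pm,
//                                          /*session=*/llvm::None);
//   if (failed(pm.run(module))) {
//     // Conversion failed; report the error to the caller.
//   }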

}  // namespace tensorflow

namespace mlir {
namespace TFL {

struct StandardPipelineOptions
    : public PassPipelineOptions<StandardPipelineOptions> {
  // TODO(b/150915052): All the tf_tfl_translate_cl flags should
  // move inside this.
};

// NOLINTNEXTLINE
// This creates the standard pass pipeline for TF->TFLite. This represents a
// standard configuration for TFLite, for use with APIs like
// tensorflow/python/pywrap_mlir.py::experimental_run_pass_pipeline.
// This does not yet include quantization passes.
void CreateTFLStandardPipeline(OpPassManager& pm,
                               const StandardPipelineOptions& options) {
  OpPassManager& func_pm = pm.nest<FuncOp>();

  // tf_executor dialect passes - cleaning up the IR.
  mlir::TF::StandardPipelineOptions standard_pipeline_options;
  mlir::TF::CreateTFStandardPipeline(func_pm, standard_pipeline_options);

  // This is needed for control flow support with TF TensorList.
  pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass());

  // Saved model pass to mark global tensors immutable.
  pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
  // Op fusion pass.
  pm.addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());

  pm.addNestedPass<mlir::FuncOp>(mlir::TFL::CreateLegalizeTFWhilePass());

  pm.addPass(mlir::createInlinerPass());

  // Canonicalize, CSE etc.
  pm.addPass(mlir::TF::CreateDecodeConstantPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
  // DCE for private symbols.
  pm.addPass(mlir::createSymbolDCEPass());

  // Freeze global tensors.
  pm.addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());

  // TFLite dialect passes.
  pm.addPass(mlir::TFL::CreatePrepareTFPass(
      /*unfold_batch_matmul=*/true,
      /*allow_bf16_and_f16_type_legalization=*/false));
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pm.addPass(
      mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
  pm.addPass(mlir::TFL::CreateOptimizePass());
  pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
  pm.addPass(mlir::createSymbolDCEPass());

  // Canonicalize, CSE etc.
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pm.addNestedPass<mlir::tf_saved_model::SessionInitializerOp>(
      mlir::createCanonicalizerPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());

  // Pass for stateful operands like LSTM.
  pm.addPass(mlir::TFL::CreateSplitMergedOperandsPass());

  pm.addPass(mlir::TFL::CreateWhileOutlinePass());

  pm.addNestedPass<mlir::FuncOp>(mlir::TFL::CreateRuntimeVerifyPass());
}

// Registers a pass pipeline for the standard TFL passes.
static mlir::PassPipelineRegistration<StandardPipelineOptions> pipeline(
    "tfl-standard-pipeline",
    "Run the standard passes involved in transforming/optimizing the TF "
    "program to TFLite after importing into MLIR.",
    CreateTFLStandardPipeline);
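
// A minimal sketch of how this pipeline might be used (assumed caller;
// `module` is a placeholder, not defined in this file). It can be scheduled
// programmatically,
//
//   mlir::PassManager pm(module.getContext());
//   mlir::TFL::StandardPipelineOptions options;
//   mlir::TFL::CreateTFLStandardPipeline(pm, options);
//   (void)pm.run(module);
//
// or selected by name in opt-style tools through the "tfl-standard-pipeline"
// registration above.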

}  // namespace TFL
}  // namespace mlir