• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
17 
18 #include <string>
19 
20 #include "llvm/ADT/Optional.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "mlir/IR/Attributes.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
24 #include "mlir/Pass/Pass.h"  // from @llvm-project
25 #include "mlir/Pass/PassManager.h"  // from @llvm-project
26 #include "mlir/Transforms/Passes.h"  // from @llvm-project
27 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
28 #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
29 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
30 #include "tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
33 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
34 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
35 #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
36 
namespace mlir {
/// Create a pass to convert from the TFExecutor to the TF control dialect.
/// Forward declaration only; the pass is defined elsewhere in the MLIR
/// TensorFlow lowering library.
std::unique_ptr<OperationPass<FuncOp>>
CreateTFExecutorToControlDialectConversion();
}  // namespace mlir
42 
43 namespace tensorflow {
namespace {
// Data layout supported by TFLite. Used below to force the layout
// optimization pipeline to transpose tensors into this format.
const char kTFLiteDataLayout[] = "NHWC";
}  // namespace
48 
AddQuantizationPasses(const mlir::TFL::QuantizationSpecs & quant_specs,mlir::OpPassManager * pass_manager)49 void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
50                            mlir::OpPassManager* pass_manager) {
51   pass_manager->addNestedPass<mlir::FuncOp>(
52       mlir::TFL::CreatePrepareQuantizePass(quant_specs));
53   if (quant_specs.default_ranges.first.hasValue() ||
54       quant_specs.default_ranges.second.hasValue()) {
55     pass_manager->addNestedPass<mlir::FuncOp>(
56         mlir::TFL::CreateDefaultQuantParamsPass(
57             quant_specs.default_ranges.first.getValueOr(0.0),
58             quant_specs.default_ranges.second.getValueOr(0.0),
59             quant_specs.IsSignedInferenceType()));
60   }
61   pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateQuantizePass(
62       quant_specs.verify_numeric, quant_specs.whole_model_verify));
63   bool emit_quant_adaptor_ops =
64       quant_specs.inference_type != quant_specs.inference_input_type;
65   pass_manager->addNestedPass<mlir::FuncOp>(
66       mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
67 }
68 
// Populates `pass_manager` with the complete TF -> TFLite conversion
// pipeline. `model_flags` and `toco_flags` carry converter-level settings
// (saved model directory, select-TF-ops options); `pass_config` toggles
// individual passes; when `session` is present it is used to freeze
// read-only variables into constants. Pass ordering below is significant.
void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags,
                                const toco::TocoFlags& toco_flags,
                                const mlir::TFL::PassConfig& pass_config,
                                mlir::OpPassManager* pass_manager,
                                llvm::Optional<tensorflow::Session*> session) {
  // This pass wraps all the tf.FakeQuant ops in a custom op so they are not
  // folded before being converted to tfl.quantize and tfl.dequantize ops.
  auto wrapped_ops = mlir::TFL::AllTfFakeQuantOps();
  pass_manager->addNestedPass<mlir::FuncOp>(
      mlir::TFL::CreateRaiseCustomOpsPass(wrapped_ops));

  // Standard TF cleanup pipeline. The inliner is disabled here because
  // inlining is run explicitly at the appropriate points further below.
  mlir::TF::StandardPipelineOptions standard_pipeline_options;
  standard_pipeline_options.enable_inliner = false;
  standard_pipeline_options.form_clusters = pass_config.form_clusters;
  mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
  pass_manager->addNestedPass<mlir::FuncOp>(
      mlir::TF::CreateDeviceIndexSelectorPass());

  // Add canonicalize pass to remove no-op session initializer pass.
  pass_manager->addPass(mlir::createCanonicalizerPass());

  if (pass_config.guarantee_all_funcs_one_use) {
    pass_manager->addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
  }
  if (pass_config.shape_inference) {
    pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
  }

  // TODO(b/149099381): Remove after handling WhileRegion in favor of later
  // instance.
  if (session.hasValue()) {
    pass_manager->addPass(
        mlir::tf_saved_model::CreateFreezeVariablesPass(session.getValue()));
  }

  // Keep this pass after the shape inference pass, which couldn't do shape
  // inference for non-tf ops.
  if (!pass_config.quant_specs.serialized_quant_stats.empty()) {
    pass_manager->addPass(
        mlir::quant::CreateImportQuantStatsPassForTFControlDialect(
            pass_config.quant_specs.serialized_quant_stats));
  }

  pass_manager->addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());

  // The conversion pipeline has to follow the following orders:
  // 1) Saved model related optimization like decompose resource ops
  // 2) Convert composite functions like lstm/rnns, along with proper function
  // inlining & dce.
  // 3) Lower static tensor list pass.

  // This decomposes resource ops like ResourceGather into read-variable op
  // followed by gather. This is used when the saved model import path is used
  // during which resources dont get frozen in the python layer.
  pass_manager->addNestedPass<mlir::FuncOp>(
      mlir::TFDevice::CreateDecomposeResourceOpsPass());

  // Try freezing again read only vars post resource decomposition.
  if (session.hasValue()) {
    pass_manager->addPass(
        mlir::tf_saved_model::CreateFreezeVariablesPass(session.getValue()));
  }

  // Note:
  // We need to fuse composite ops before LowerStaticTensorList pass.
  // The tensorflow list is not supported right now by that pass.
  // Enable fusing composite ops that can be lowered to built-in TFLite ops.
  if (pass_config.emit_builtin_tflite_ops) {
    pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
  }

  pass_manager->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());

  pass_manager->addPass(mlir::createInlinerPass());
  pass_manager->addPass(mlir::createSymbolDCEPass());

  if (pass_config.lower_tensor_list_ops) {
    // TODO(haoliang): Add this pass by default.
    // Tensorlist ops are allowed to pass through unlowered only when select
    // TF ops are enabled, since the TF runtime can then execute them.
    pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass(
        /*allow_tensorlist_pass_through=*/toco_flags.force_select_tf_ops() ||
        toco_flags.enable_select_tf_ops()));
  }

  // This pass does resource analysis of saved model global tensors and marks
  // those deemed read-only as immutable.
  pass_manager->addPass(
      mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());

  if (pass_config.shape_inference) {
    // Add a shape inference pass to optimize away the unnecessary casts.
    pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
  }

  // Legalize while early to allow further constant folding.
  // TODO(jpienaar): This may not actually matter as we do canonicalization
  // after the legalize below, for now it needs to be below the above passes
  // that work on TF dialect and before inliner so that the function calls in
  // body and cond are inlined for optimization.
  if (pass_config.legalize_tf_while) {
    pass_manager->addPass(mlir::TFL::CreateLegalizeTFWhilePass());
  }

  // Add function inlining pass. Both TF and TFLite dialects are opted into
  // function inliner interface.
  pass_manager->addPass(mlir::createInlinerPass());

  // Canonicalization includes const folding, which is utilized here to optimize
  // away ops that can't get constant folded after PrepareTF pass. For example,
  // tf.Conv2D is split into tf.Transpose and tfl.Conv2D.
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
  // This pass does dead code elimination based on symbol visibility.
  pass_manager->addPass(mlir::createSymbolDCEPass());

  if (!pass_config.disable_variable_freezing) {
    // This pass 'freezes' immutable global tensors and inlines them as tf
    // constant ops.
    pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass(
        /*allow_mutable_tensors=*/pass_config.enable_tflite_variables));
  }

  if (!model_flags.saved_model_dir().empty()) {
    // This pass 'freezes' tf saved model asset ops and inlines as string values
    // in a format of the tf constant op.
    pass_manager->addPass(mlir::tf_saved_model::CreateFreezeAssetsPass(
        model_flags.saved_model_dir()));
  }

  // The below passes only make sense if Builtin TFLite ops are enabled
  // for emission.
  if (pass_config.emit_builtin_tflite_ops) {
    // Run shape inference after variables are converted to constants.
    if (pass_config.shape_inference) {
      pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
    }
    // Force layout supported by TFLite, this will transpose the data
    // to match 'kTFLiteDataLayout'
    mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options;
    layout_optimization_options.force_data_format = kTFLiteDataLayout;
    layout_optimization_options.skip_fold_transpose_in_ops = true;
    mlir::TF::CreateLayoutOptimizationPipeline(
        pass_manager->nest<mlir::FuncOp>(), layout_optimization_options);
    // Prepare for TFLite dialect, rerun canonicalization, and then legalize to
    // the TFLite dialect.
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreatePrepareTFPass(
        pass_config.unfold_batch_matmul,
        /*allow_bf16_and_f16_type_legalization=*/!pass_config
            .runtime_verification));
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
    if (pass_config.shape_inference) {
      // Add a shape inference pass to optimize away the unnecessary casts.
      // This also fixes the unranked shapes due to TF ops constant folding.
      // TODO(fengliuai): remove this pass if TableGen patterns have a better
      // to control the shapes for the intermediate results.
      pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
    }

    // Inline function calls that left in the graph after folding functional
    // control flow ops (IfOp, CaseOp).
    pass_manager->addPass(mlir::createInlinerPass());

    // This pass removes the asset file dependencies in hash table use cases.
    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TF::CreateInitTextFileToImportPass(
            model_flags.saved_model_dir()));

    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
    if (pass_config.enable_tflite_variables) {
      pass_manager->addPass(mlir::TFL::CreateAnalyzeVariablesPass());
      pass_manager->addPass(mlir::TFL::CreateLegalizeVariablesPass());
    }
    pass_manager->addPass(mlir::TFL::CreateLegalizeHashTablesPass());
    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true));
    // This pass operates on TensorFlow ops but is triggered after legalization
    // so that it can target constants introduced once TensorFlow Identity ops
    // are removed during legalization.
    pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
    // Raise custom ops with an empty allowlist: no TF ops are wrapped here.
    std::vector<std::string> empty_wrapped_ops({});
    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TFL::CreateRaiseCustomOpsPass(empty_wrapped_ops));
    pass_manager->addPass(mlir::createSymbolDCEPass());
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());

    // Run quantization after all the floating point model conversion is
    // completed.
    if (pass_config.quant_specs.RunPropagationAndRewriteQuantizationPasses()) {
      AddQuantizationPasses(pass_config.quant_specs, pass_manager);
    }

    pass_manager->addPass(mlir::createCanonicalizerPass());

    // This pass should be always at the end of the model
    // conversion (even after quantization). Some TFL ops like unidirectional
    // sequence lstm will have stateful operands and some optimization passes
    // will merge those operands if they have identical values & types. However,
    // it's not desired by TFL. This pass serves as a "fix" pass to split the
    // merged inputs until we have 1st class variable support or reuse
    // tf.variable to model this.
    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TFL::CreateSplitMergedOperandsPass());

    // Add CallOnceOp when there is a session initializer function in tf saved
    // model dialect.
    pass_manager->addPass(
        mlir::TFL::CreateInsertCallOnceOpFromSessionInitializerPass());
  }
  if (pass_config.unfold_large_splat_constant) {
    pass_manager->addPass(mlir::TFL::CreateUnfoldLargeSplatConstantPass());
  }
}
282 
AddTFToTFLConversionPasses(const mlir::TFL::PassConfig & pass_config,mlir::OpPassManager * pass_manager,llvm::Optional<tensorflow::Session * > session)283 void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
284                                 mlir::OpPassManager* pass_manager,
285                                 llvm::Optional<tensorflow::Session*> session) {
286   const toco::ModelFlags model_flags;
287   const toco::TocoFlags toco_flags;
288   AddTFToTFLConversionPasses(model_flags, toco_flags, pass_config, pass_manager,
289                              session);
290 }
291 
292 }  // namespace tensorflow
293 
294 namespace mlir {
295 namespace TFL {
296 
// Options for the "tfl-standard-pipeline" pass pipeline. Currently empty;
// the pipeline's behavior is still controlled by global command-line flags.
struct StandardPipelineOptions
    : public PassPipelineOptions<StandardPipelineOptions> {
  // TODO(b/150915052): All the tf_tfl_translate_cl flags should
  // move inside this.
};
302 
303 // NOLINTNEXTLINE
304 // This creates the standard pass pipeline for TF->TFLite. This
305 // represents a std configuration for TFLite, for use with APIs like
306 // tensorflow/python/pywrap_mlir.py::experimental_run_pass_pipeline
307 // This does not yet include quantization passes.
CreateTFLStandardPipeline(OpPassManager & pm,const StandardPipelineOptions & options)308 void CreateTFLStandardPipeline(OpPassManager& pm,
309                                const StandardPipelineOptions& options) {
310   OpPassManager& func_pm = pm.nest<FuncOp>();
311 
312   // tf_executor dialect passes - Cleaning up the IR.
313   mlir::TF::StandardPipelineOptions standard_pipeline_options;
314   mlir::TF::CreateTFStandardPipeline(func_pm, standard_pipeline_options);
315 
316   // This is needed for control flow support with TF TensorList.
317   pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass(
318       /*allow_tensorlist_pass_through=*/false));
319 
320   // Saved model pass to mark global tensors immutable.
321   pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
322   // Op fusion pass.
323   pm.addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
324 
325   pm.addNestedPass<mlir::FuncOp>(mlir::TFL::CreateLegalizeTFWhilePass());
326 
327   pm.addPass(mlir::createInlinerPass());
328 
329   // Canonicalize, CSE etc.
330   pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
331   pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
332   // DCE for private symbols.
333   pm.addPass(mlir::createSymbolDCEPass());
334 
335   // freeze global tensors.
336   pm.addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());
337 
338   // TFLite dialect passes.
339   pm.addPass(mlir::TFL::CreatePrepareTFPass(
340       /*unfold_batch_matmul=*/true,
341       /*allow_bf16_and_f16_type_legalization=*/false));
342   pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
343   pm.addPass(
344       mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
345   pm.addPass(mlir::TFL::CreateLegalizeHashTablesPass());
346   pm.addPass(mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true));
347   pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
348   pm.addPass(mlir::createSymbolDCEPass());
349 
350   // Canonicalize, CSE etc.
351   pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
352   pm.addNestedPass<mlir::tf_saved_model::SessionInitializerOp>(
353       mlir::createCanonicalizerPass());
354   pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
355 
356   // Pass for stateful operands like LSTM.
357   pm.addPass(mlir::TFL::CreateSplitMergedOperandsPass());
358 
359   pm.addPass(mlir::TFL::CreateWhileOutlinePass());
360 
361   pm.addNestedPass<mlir::FuncOp>(mlir::TFL::CreateRuntimeVerifyPass());
362 }
363 
// Registers the "tfl-standard-pipeline" pass pipeline so it can be invoked
// from mlir-opt-style tools; dispatches to CreateTFLStandardPipeline above.
static mlir::PassPipelineRegistration<StandardPipelineOptions> pipeline(
    "tfl-standard-pipeline",
    "Run the standard passes involved in transforming/optimizing the TF "
    "program to TFLite after "
    "importing into MLIR.",
    CreateTFLStandardPipeline);
371 
372 }  // namespace TFL
373 }  // namespace mlir
374