1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
17
18 #include <string>
19
20 #include "llvm/ADT/Optional.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "mlir/IR/Attributes.h" // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
24 #include "mlir/Pass/Pass.h" // from @llvm-project
25 #include "mlir/Pass/PassManager.h" // from @llvm-project
26 #include "mlir/Transforms/Passes.h" // from @llvm-project
27 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
28 #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
29 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
30 #include "tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
33 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
34 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
35 #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
36
namespace mlir {
/// Create a pass to convert from the TFExecutor to the TF control dialect.
/// NOTE(review): declared here (defined elsewhere) — not referenced in this
/// file's visible pipelines; presumably kept for external users. Confirm
/// before removing.
std::unique_ptr<OperationPass<FuncOp>>
CreateTFExecutorToControlDialectConversion();
}  // namespace mlir
42
43 namespace tensorflow {
namespace {
// Data layout supported by TFLite. Fed to the layout optimization pipeline
// (force_data_format) so tensors are transposed into NHWC before legalization.
const char kTFLiteDataLayout[] = "NHWC";
}  // namespace
48
AddQuantizationPasses(const mlir::TFL::QuantizationSpecs & quant_specs,mlir::OpPassManager * pass_manager)49 void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
50 mlir::OpPassManager* pass_manager) {
51 pass_manager->addNestedPass<mlir::FuncOp>(
52 mlir::TFL::CreatePrepareQuantizePass(quant_specs));
53 if (quant_specs.default_ranges.first.hasValue() ||
54 quant_specs.default_ranges.second.hasValue()) {
55 pass_manager->addNestedPass<mlir::FuncOp>(
56 mlir::TFL::CreateDefaultQuantParamsPass(
57 quant_specs.default_ranges.first.getValueOr(0.0),
58 quant_specs.default_ranges.second.getValueOr(0.0),
59 quant_specs.IsSignedInferenceType()));
60 }
61 pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateQuantizePass(
62 quant_specs.verify_numeric, quant_specs.whole_model_verify));
63 bool emit_quant_adaptor_ops =
64 quant_specs.inference_type != quant_specs.inference_input_type;
65 pass_manager->addNestedPass<mlir::FuncOp>(
66 mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
67 }
68
// Populates `pass_manager` with the full TensorFlow -> TFLite conversion
// pipeline. The relative order of passes here is load-bearing (several inline
// notes below call out ordering constraints); do not reorder casually.
// `session`, when present, is used to freeze read-only variables into
// constants (attempted twice: before and after resource-op decomposition).
void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags,
                                const toco::TocoFlags& toco_flags,
                                const mlir::TFL::PassConfig& pass_config,
                                mlir::OpPassManager* pass_manager,
                                llvm::Optional<tensorflow::Session*> session) {
  // This pass wraps all the tf.FakeQuant ops in a custom op so they are not
  // folded before being converted to tfl.quantize and tfl.dequantize ops.
  auto wrapped_ops = mlir::TFL::AllTfFakeQuantOps();
  pass_manager->addNestedPass<mlir::FuncOp>(
      mlir::TFL::CreateRaiseCustomOpsPass(wrapped_ops));

  // Standard TF cleanup pipeline, with the inliner disabled here — inlining
  // happens explicitly later, after composite-function preparation.
  mlir::TF::StandardPipelineOptions standard_pipeline_options;
  standard_pipeline_options.enable_inliner = false;
  standard_pipeline_options.form_clusters = pass_config.form_clusters;
  mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
  pass_manager->addNestedPass<mlir::FuncOp>(
      mlir::TF::CreateDeviceIndexSelectorPass());

  // Add canonicalize pass to remove no-op session initializer pass.
  pass_manager->addPass(mlir::createCanonicalizerPass());

  if (pass_config.guarantee_all_funcs_one_use) {
    pass_manager->addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
  }
  if (pass_config.shape_inference) {
    pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
  }

  // TODO(b/149099381): Remove after handling WhileRegion in favor of later
  // instance.
  if (session.hasValue()) {
    // First freeze attempt: convert read-only variables to constants using
    // values from the provided session.
    pass_manager->addPass(
        mlir::tf_saved_model::CreateFreezeVariablesPass(session.getValue()));
  }

  // Keep this pass after the shape inference pass, which couldn't do shape
  // inference for non-tf ops.
  if (!pass_config.quant_specs.serialized_quant_stats.empty()) {
    pass_manager->addPass(
        mlir::quant::CreateImportQuantStatsPassForTFControlDialect(
            pass_config.quant_specs.serialized_quant_stats));
  }

  pass_manager->addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());

  // The conversion pipeline has to follow the following orders:
  // 1) Saved model related optimization like decompose resource ops
  // 2) Convert composite functions like lstm/rnns, along with proper function
  //    inlining & dce.
  // 3) Lower static tensor list pass.

  // This decomposes resource ops like ResourceGather into read-variable op
  // followed by gather. This is used when the saved model import path is used
  // during which resources don't get frozen in the python layer.
  pass_manager->addNestedPass<mlir::FuncOp>(
      mlir::TFDevice::CreateDecomposeResourceOpsPass());

  // Try freezing again read only vars post resource decomposition.
  if (session.hasValue()) {
    pass_manager->addPass(
        mlir::tf_saved_model::CreateFreezeVariablesPass(session.getValue()));
  }

  // Note:
  // We need to fuse composite ops before LowerStaticTensorList pass.
  // The tensorflow list is not supported right now by that pass.
  // Enable fusing composite ops that can be lowered to built-in TFLite ops.
  if (pass_config.emit_builtin_tflite_ops) {
    pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
  }

  pass_manager->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());

  pass_manager->addPass(mlir::createInlinerPass());
  pass_manager->addPass(mlir::createSymbolDCEPass());

  if (pass_config.lower_tensor_list_ops) {
    // TODO(haoliang): Add this pass by default.
    // Tensor-list ops are allowed to pass through unlowered when select TF
    // ops (flex) are enabled, since the TF runtime can execute them.
    pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass(
        /*allow_tensorlist_pass_through=*/toco_flags.force_select_tf_ops() ||
        toco_flags.enable_select_tf_ops()));
  }

  // This pass does resource analysis of saved model global tensors and marks
  // those deemed read-only as immutable.
  pass_manager->addPass(
      mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());

  if (pass_config.shape_inference) {
    // Add a shape inference pass to optimize away the unnecessary casts.
    pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
  }

  // Legalize while early to allow further constant folding.
  // TODO(jpienaar): This may not actually matter as we do canonicalization
  // after the legalize below, for now it needs to be below the above passes
  // that work on TF dialect and before inliner so that the function calls in
  // body and cond are inlined for optimization.
  if (pass_config.legalize_tf_while) {
    pass_manager->addPass(mlir::TFL::CreateLegalizeTFWhilePass());
  }

  // Add function inlining pass. Both TF and TFLite dialects are opted into
  // function inliner interface.
  pass_manager->addPass(mlir::createInlinerPass());

  // Canonicalization includes const folding, which is utilized here to
  // optimize away ops that can't get constant folded after PrepareTF pass.
  // For example, tf.Conv2D is split into tf.Transpose and tfl.Conv2D.
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
  // This pass does dead code elimination based on symbol visibility.
  pass_manager->addPass(mlir::createSymbolDCEPass());

  if (!pass_config.disable_variable_freezing) {
    // This pass 'freezes' immutable global tensors and inlines them as tf
    // constant ops.
    pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass(
        /*allow_mutable_tensors=*/pass_config.enable_tflite_variables));
  }

  if (!model_flags.saved_model_dir().empty()) {
    // This pass 'freezes' tf saved model asset ops and inlines as string
    // values in a format of the tf constant op.
    pass_manager->addPass(mlir::tf_saved_model::CreateFreezeAssetsPass(
        model_flags.saved_model_dir()));
  }

  // The below passes only make sense if Builtin TFLite ops are enabled
  // for emission.
  if (pass_config.emit_builtin_tflite_ops) {
    // Run shape inference after variables are converted to constants.
    if (pass_config.shape_inference) {
      pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
    }
    // Force layout supported by TFLite, this will transpose the data
    // to match 'kTFLiteDataLayout'.
    mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options;
    layout_optimization_options.force_data_format = kTFLiteDataLayout;
    layout_optimization_options.skip_fold_transpose_in_ops = true;
    mlir::TF::CreateLayoutOptimizationPipeline(
        pass_manager->nest<mlir::FuncOp>(), layout_optimization_options);
    // Prepare for TFLite dialect, rerun canonicalization, and then legalize to
    // the TFLite dialect.
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreatePrepareTFPass(
        pass_config.unfold_batch_matmul,
        /*allow_bf16_and_f16_type_legalization=*/!pass_config
            .runtime_verification));
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
    if (pass_config.shape_inference) {
      // Add a shape inference pass to optimize away the unnecessary casts.
      // This also fixes the unranked shapes due to TF ops constant folding.
      // TODO(fengliuai): remove this pass if TableGen patterns have a better
      // way to control the shapes for the intermediate results.
      pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
    }

    // Inline function calls that are left in the graph after folding
    // functional control flow ops (IfOp, CaseOp).
    pass_manager->addPass(mlir::createInlinerPass());

    // This pass removes the asset file dependencies in hash table use cases.
    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TF::CreateInitTextFileToImportPass(
            model_flags.saved_model_dir()));

    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
    if (pass_config.enable_tflite_variables) {
      // Analyze which variables can be legalized to TFLite native variable
      // ops, then legalize them.
      pass_manager->addPass(mlir::TFL::CreateAnalyzeVariablesPass());
      pass_manager->addPass(mlir::TFL::CreateLegalizeVariablesPass());
    }
    pass_manager->addPass(mlir::TFL::CreateLegalizeHashTablesPass());
    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true));
    // This pass operates on TensorFlow ops but is triggered after legalization
    // so that it can target constants introduced once TensorFlow Identity ops
    // are removed during legalization.
    pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
    // With no wrapped ops, this instance only un-wraps previously raised
    // custom ops rather than raising new ones.
    std::vector<std::string> empty_wrapped_ops({});
    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TFL::CreateRaiseCustomOpsPass(empty_wrapped_ops));
    pass_manager->addPass(mlir::createSymbolDCEPass());
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());

    // Run quantization after all the floating point model conversion is
    // completed.
    if (pass_config.quant_specs.RunPropagationAndRewriteQuantizationPasses()) {
      AddQuantizationPasses(pass_config.quant_specs, pass_manager);
    }

    pass_manager->addPass(mlir::createCanonicalizerPass());

    // This pass should always be at the end of the model
    // conversion (even after quantization). Some TFL ops like unidirectional
    // sequence lstm will have stateful operands and some optimization passes
    // will merge those operands if they have identical values & types.
    // However, it's not desired by TFL. This pass serves as a "fix" pass to
    // split the merged inputs until we have 1st class variable support or
    // reuse tf.variable to model this.
    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TFL::CreateSplitMergedOperandsPass());

    // Add CallOnceOp when there is a session initializer function in tf saved
    // model dialect.
    pass_manager->addPass(
        mlir::TFL::CreateInsertCallOnceOpFromSessionInitializerPass());
  }
  if (pass_config.unfold_large_splat_constant) {
    pass_manager->addPass(mlir::TFL::CreateUnfoldLargeSplatConstantPass());
  }
}
282
AddTFToTFLConversionPasses(const mlir::TFL::PassConfig & pass_config,mlir::OpPassManager * pass_manager,llvm::Optional<tensorflow::Session * > session)283 void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
284 mlir::OpPassManager* pass_manager,
285 llvm::Optional<tensorflow::Session*> session) {
286 const toco::ModelFlags model_flags;
287 const toco::TocoFlags toco_flags;
288 AddTFToTFLConversionPasses(model_flags, toco_flags, pass_config, pass_manager,
289 session);
290 }
291
292 } // namespace tensorflow
293
294 namespace mlir {
295 namespace TFL {
296
// Options for the "tfl-standard-pipeline" registered below. Currently empty:
// the pipeline's behavior is still configured via external flags.
struct StandardPipelineOptions
    : public PassPipelineOptions<StandardPipelineOptions> {
  // TODO(b/150915052): All the tf_tfl_translate_cl flags should
  // move inside this.
};
302
303 // NOLINTNEXTLINE
304 // This creates the standard pass pipeline for TF->TFLite. This
305 // represents a std configuration for TFLite, for use with APIs like
306 // tensorflow/python/pywrap_mlir.py::experimental_run_pass_pipeline
307 // This does not yet include quantization passes.
// Builds the standard (non-quantized) TF -> TFLite pipeline into `pm`.
// A simplified, fixed-configuration counterpart of
// AddTFToTFLConversionPasses above; pass order is significant.
// `options` is currently unused (see StandardPipelineOptions).
void CreateTFLStandardPipeline(OpPassManager& pm,
                               const StandardPipelineOptions& options) {
  OpPassManager& func_pm = pm.nest<FuncOp>();

  // tf_executor dialect passes - Cleaning up the IR.
  mlir::TF::StandardPipelineOptions standard_pipeline_options;
  mlir::TF::CreateTFStandardPipeline(func_pm, standard_pipeline_options);

  // This is needed for control flow support with TF TensorList.
  pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass(
      /*allow_tensorlist_pass_through=*/false));

  // Saved model pass to mark global tensors immutable.
  pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
  // Op fusion pass.
  pm.addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());

  // Legalize tf.While to tfl.While before inlining.
  pm.addNestedPass<mlir::FuncOp>(mlir::TFL::CreateLegalizeTFWhilePass());

  pm.addPass(mlir::createInlinerPass());

  // Canonicalize, CSE etc.
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
  // DCE for private symbols.
  pm.addPass(mlir::createSymbolDCEPass());

  // Freeze global tensors into tf constant ops.
  pm.addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());

  // TFLite dialect passes: prepare, canonicalize, then legalize.
  pm.addPass(mlir::TFL::CreatePrepareTFPass(
      /*unfold_batch_matmul=*/true,
      /*allow_bf16_and_f16_type_legalization=*/false));
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pm.addPass(
      mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
  pm.addPass(mlir::TFL::CreateLegalizeHashTablesPass());
  pm.addPass(mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true));
  pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
  pm.addPass(mlir::createSymbolDCEPass());

  // Canonicalize, CSE etc.
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pm.addNestedPass<mlir::tf_saved_model::SessionInitializerOp>(
      mlir::createCanonicalizerPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());

  // Pass for stateful operands like LSTM.
  pm.addPass(mlir::TFL::CreateSplitMergedOperandsPass());

  // Outline tfl.While op bodies/conditions into separate functions.
  pm.addPass(mlir::TFL::CreateWhileOutlinePass());

  // Final check that the produced ops are valid for the TFL runtime.
  pm.addNestedPass<mlir::FuncOp>(mlir::TFL::CreateRuntimeVerifyPass());
}
363
// Registers a pass pipeline for the standard TFL passes under the
// "tfl-standard-pipeline" name, making it available to pass-pipeline
// command-line tooling. Registration happens at static-initialization time.
static mlir::PassPipelineRegistration<StandardPipelineOptions> pipeline(
    "tfl-standard-pipeline",
    "Run the standard passes involved in transforming/optimizing the TF "
    "program to TFLite after "
    "importing into MLIR.",
    CreateTFLStandardPipeline);
371
372 } // namespace TFL
373 } // namespace mlir
374