/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"

#include "llvm/ADT/Optional.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"

namespace mlir {
/// Create a pass to convert from the TFExecutor to the TF control dialect.
std::unique_ptr<OperationPass<FuncOp>>
CreateTFExecutorToControlDialectConversion();
}  // namespace mlir

namespace tensorflow {
namespace {
// Data layout supported by TFLite.
const char kTFLiteDataLayout[] = "NHWC";
}  // namespace

void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
                           mlir::OpPassManager* pass_manager) {
  pass_manager->addNestedPass<mlir::FuncOp>(
      mlir::TFL::CreatePrepareQuantizePass(quant_specs));
  if (quant_specs.default_ranges.first.hasValue() ||
      quant_specs.default_ranges.second.hasValue()) {
    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TFL::CreateDefaultQuantParamsPass(
            quant_specs.default_ranges.first.getValueOr(0.0),
            quant_specs.default_ranges.second.getValueOr(0.0),
            quant_specs.IsSignedInferenceType()));
  }
  pass_manager->addNestedPass<mlir::FuncOp>(
      mlir::TFL::CreateQuantizePass(quant_specs.verify_numeric));
  bool emit_quant_adaptor_ops =
      quant_specs.inference_type != quant_specs.inference_input_type;
  pass_manager->addNestedPass<mlir::FuncOp>(
      mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
}

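// Illustrative usage sketch for AddQuantizationPasses (hypothetical caller
// code, not invoked anywhere in this file). The `context` and `module`
// identifiers are placeholders assumed to be set up by the caller, and the
// chosen inference type is only an example.
//
//   mlir::MLIRContext context;
//   mlir::PassManager pm(&context);
//   mlir::TFL::QuantizationSpecs quant_specs;
//   quant_specs.inference_type = tensorflow::DT_QINT8;  // example setting
//   tensorflow::AddQuantizationPasses(quant_specs, &pm);
//   if (failed(pm.run(module))) {
//     // Handle pass failure.
//   }
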
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
                                mlir::OpPassManager* pass_manager,
                                llvm::Optional<tensorflow::Session*> session) {
  mlir::TF::StandardPipelineOptions standard_pipeline_options;
  standard_pipeline_options.enable_inliner = false;
  standard_pipeline_options.form_clusters = pass_config.form_clusters;
  mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
  pass_manager->addNestedPass<mlir::FuncOp>(
      mlir::TF::CreateDeviceIndexSelectorPass());

  // Add a canonicalize pass to remove no-op session initializers.
  pass_manager->addPass(mlir::createCanonicalizerPass());

  if (pass_config.shape_inference) {
    pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
  }

  if (session.hasValue()) {
    // Add a pass that converts reference variables to resource variables.
    pass_manager->addPass(
        mlir::TF::
            CreateConvertReadonlyReferenceVariablesToResourceVariablesPass());

    // Add a pass that promotes resource variables to function arguments.
    pass_manager->addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass());

    // Add a pass that creates global tensors and converts the function
    // arguments to tf_saved_model.bound_input arguments.
    pass_manager->addPass(
        mlir::tf_saved_model::CreateLiftVariablesPass(session.getValue()));
  }

  // Keep this pass after the shape inference pass, which cannot infer shapes
  // for non-TF ops.
  if (!pass_config.quant_specs.serialized_quant_stats.empty()) {
    pass_manager->addPass(
        mlir::quant::CreateImportQuantStatsPassForTFControlDialect(
            pass_config.quant_specs.serialized_quant_stats));
  }

  pass_manager->addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());

  // The conversion pipeline has to follow this order:
  // 1) Saved model related optimizations such as decomposing resource ops.
  // 2) Conversion of composite functions such as LSTMs/RNNs, along with proper
  //    function inlining & DCE.
  // 3) The lower static tensor list pass.

  // This decomposes resource ops like ResourceGather into a read-variable op
  // followed by a gather. This is used on the saved model import path, where
  // resources don't get frozen in the Python layer.
  pass_manager->addNestedPass<mlir::FuncOp>(
      mlir::TFDevice::CreateDecomposeResourceOpsPass());

  // Note:
  // Composite ops need to be fused before the LowerStaticTensorList pass,
  // which does not currently support the TensorFlow list type.
  // Enable fusing composite ops that can be lowered to built-in TFLite ops.
  if (pass_config.emit_builtin_tflite_ops) {
    pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
  }

  pass_manager->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());

  pass_manager->addPass(mlir::createInlinerPass());
  pass_manager->addPass(mlir::createSymbolDCEPass());

  if (pass_config.lower_tensor_list_ops) {
    // TODO(haoliang): Add this pass by default.
    pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
  }

  // This pass does resource analysis of saved model global tensors and marks
  // those deemed read-only as immutable.
  pass_manager->addPass(
      mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());

  if (pass_config.shape_inference) {
    // Add a shape inference pass to optimize away the unnecessary casts.
    pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
  }

  // Legalize while early to allow further constant folding.
  // TODO(jpienaar): This may not actually matter as we do canonicalization
  // after the legalization below; for now it needs to come after the above
  // passes that work on the TF dialect and before the inliner so that the
  // function calls in the body and cond are inlined for optimization.
  if (pass_config.legalize_tf_while) {
    pass_manager->addPass(mlir::TFL::CreateLegalizeTFWhilePass());
  }

  // Add function inlining pass. Both TF and TFLite dialects are opted into
  // the function inliner interface.
  pass_manager->addPass(mlir::createInlinerPass());

  // TODO(jpienaar): Revise post dialect constants.
  pass_manager->addNestedPass<mlir::FuncOp>(
      mlir::TF::CreateDecodeConstantPass());
  // Canonicalization includes const folding, which is utilized here to
  // optimize away ops that cannot get constant folded after the PrepareTF
  // pass. For example, tf.Conv2D is split into tf.Transpose and tfl.Conv2D.
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
  // This pass does dead code elimination based on symbol visibility.
  pass_manager->addPass(mlir::createSymbolDCEPass());

  if (!pass_config.disable_variable_freezing) {
    // This pass 'freezes' immutable global tensors and inlines them as tf
    // constant ops.
    pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass(
        /*allow_mutable_tensors=*/pass_config.enable_tflite_variables));
  }

  // The passes below only make sense if builtin TFLite ops are enabled for
  // emission.
  if (pass_config.emit_builtin_tflite_ops) {
    // Run shape inference after variables are converted to constants.
    if (pass_config.shape_inference) {
      pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
    }
    // Force the layout supported by TFLite; this will transpose the data
    // to match 'kTFLiteDataLayout'.
    mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options;
    layout_optimization_options.force_data_format = kTFLiteDataLayout;
    layout_optimization_options.skip_fold_transpose_in_ops = true;
    mlir::TF::CreateLayoutOptimizationPipeline(
        pass_manager->nest<mlir::FuncOp>(), layout_optimization_options);
    // Prepare for TFLite dialect, rerun canonicalization, and then legalize to
    // the TFLite dialect.
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreatePrepareTFPass(
        pass_config.unfold_batch_matmul,
        /*allow_bf16_and_f16_type_legalization=*/!pass_config
            .runtime_verification));
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
    if (pass_config.shape_inference) {
      // Add a shape inference pass to optimize away the unnecessary casts.
      // This also fixes the unranked shapes due to TF ops constant folding.
      // TODO(fengliuai): remove this pass if TableGen patterns have a better
      // way to control the shapes for the intermediate results.
      pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
    }

    // Inline function calls that are left in the graph after folding
    // functional control flow ops (IfOp, CaseOp).
    pass_manager->addPass(mlir::createInlinerPass());

    // This pass removes the asset file dependencies in hash table use cases.
    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TF::CreateInitTextFileToImportPass());

    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
    if (pass_config.enable_tflite_variables) {
      pass_manager->addPass(mlir::TFL::CreateInitializeVariablesPass());
      pass_manager->addPass(mlir::TFL::CreateLegalizeVariablesPass());
      pass_manager->addPass(mlir::TFL::CreateRemoveArgsAndGlobalTensors());
    }
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateOptimizePass());
    // This pass operates on TensorFlow ops but is triggered after legalization
    // so that it can target constants introduced once TensorFlow Identity ops
    // are removed during legalization.
    pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TFL::CreateRaiseCustomOpsPass());
    pass_manager->addPass(mlir::createSymbolDCEPass());
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());

    // Run quantization after all the floating point model conversion is
    // completed.
    if (pass_config.quant_specs.RunPropagationAndRewriteQuantizationPasses()) {
      AddQuantizationPasses(pass_config.quant_specs, pass_manager);
    }

    // This pass should always be at the end of the model conversion (even
    // after quantization). Some TFL ops, like the unidirectional sequence
    // LSTM, have stateful operands, and some optimization passes will merge
    // those operands if they have identical values & types. However, this is
    // not desired by TFL. This pass serves as a "fix" pass to split the merged
    // inputs until we have first-class variable support or reuse tf.variable
    // to model this.
    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TFL::CreateSplitMergedOperandsPass());

    // Add CallOnceOp when there is a session initializer function in the
    // tf_saved_model dialect.
    pass_manager->addPass(
        mlir::TFL::CreateInsertCallOnceOpFromSessionInitializerPass());
  }
}

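// Illustrative usage sketch for AddTFToTFLConversionPasses (hypothetical
// caller code, not part of this file): builds a pass manager, configures a
// PassConfig, and runs the full TF->TFLite conversion. `context`, `module`,
// and `quant_specs` are placeholders assumed to be set up by the caller.
//
//   mlir::TFL::PassConfig pass_config(quant_specs);
//   pass_config.lower_tensor_list_ops = true;
//   mlir::PassManager pm(&context);
//   tensorflow::AddTFToTFLConversionPasses(pass_config, &pm,
//                                          /*session=*/llvm::None);
//   if (failed(pm.run(module))) {
//     // Handle conversion failure.
//   }
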
}  // namespace tensorflow

namespace mlir {
namespace TFL {

struct StandardPipelineOptions
    : public PassPipelineOptions<StandardPipelineOptions> {
  // TODO(b/150915052): All the tf_tfl_translate_cl flags should
  // move inside this.
};

// NOLINTNEXTLINE
// This creates the standard pass pipeline for TF->TFLite. This
// represents a standard configuration for TFLite, for use with APIs like
// tensorflow/python/pywrap_mlir.py::experimental_run_pass_pipeline.
// This does not yet include quantization passes.
void CreateTFLStandardPipeline(OpPassManager& pm,
                               const StandardPipelineOptions& options) {
  OpPassManager& func_pm = pm.nest<FuncOp>();

  // tf_executor dialect passes - cleaning up the IR.
  mlir::TF::StandardPipelineOptions standard_pipeline_options;
  mlir::TF::CreateTFStandardPipeline(func_pm, standard_pipeline_options);

  // This is needed for control flow support with TF TensorList.
  pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass());

  // Saved model pass to mark global tensors immutable.
  pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
  // Op fusion pass.
  pm.addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());

  pm.addNestedPass<mlir::FuncOp>(mlir::TFL::CreateLegalizeTFWhilePass());

  pm.addPass(mlir::createInlinerPass());

  // Canonicalize, CSE etc.
  pm.addPass(mlir::TF::CreateDecodeConstantPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
  // DCE for private symbols.
  pm.addPass(mlir::createSymbolDCEPass());

  // Freeze global tensors.
  pm.addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());

  // TFLite dialect passes.
  pm.addPass(mlir::TFL::CreatePrepareTFPass(
      /*unfold_batch_matmul=*/true, /*allow_bf16_type_legalization=*/false));
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pm.addPass(
      mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
  pm.addPass(mlir::TFL::CreateOptimizePass());
  pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
  pm.addPass(mlir::createSymbolDCEPass());

  // Canonicalize, CSE etc.
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pm.addNestedPass<mlir::tf_saved_model::SessionInitializerOp>(
      mlir::createCanonicalizerPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());

  // Pass for stateful operands like LSTM.
  pm.addPass(mlir::TFL::CreateSplitMergedOperandsPass());

  pm.addPass(mlir::TFL::CreateWhileOutlinePass());

  pm.addNestedPass<mlir::FuncOp>(mlir::TFL::CreateRuntimeVerifyPass());
}

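// Illustrative usage sketch for CreateTFLStandardPipeline (hypothetical caller
// code; `context` and `module` are placeholders assumed to be created by the
// caller):
//
//   mlir::PassManager pm(&context);
//   mlir::TFL::StandardPipelineOptions options;
//   mlir::TFL::CreateTFLStandardPipeline(pm, options);
//   if (failed(pm.run(module))) {
//     // Handle pipeline failure.
//   }
//
// The same pipeline is also reachable by its registered name
// ("tfl-standard-pipeline") through the registration below.
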
// Registers a pass pipeline for the standard TFL passes.
static mlir::PassPipelineRegistration<StandardPipelineOptions> pipeline(
    "tfl-standard-pipeline",
    "Run the standard passes involved in transforming/optimizing the TF "
    "program to TFLite after importing into MLIR.",
    CreateTFLStandardPipeline);

}  // namespace TFL
}  // namespace mlir