1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This transformation pass prepares for legalization to the TFLite dialect by
17 // converting operations in TensorFlow dialect into operations that can be
18 // legalized to TensorFlow Lite dialect with simple replacements. The newly
19 // created operations are in the TensorFlow dialect if the operation can be
20 // represented using a TensorFlow op. Otherwise, a TensorFlow Lite dialect op
21 // is used. For example, TFLite's Conv2D op, which uses the OHWI data format
22 // for filters, has no TensorFlow equivalent because TensorFlow requires
23 // filters in the HWIO data format.
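// For example, a 3x3 convolution filter with 16 input channels and 32 output
// channels is a tensor<3x3x16x32xf32> (HWIO) in TensorFlow, whereas the
// corresponding TFLite Conv2D filter is laid out as tensor<32x3x3x16xf32>
// (OHWI).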
24 //
25 // The motivation for preparing for the TFLite legalization before the actual
26 // legalization is to exploit constant folding opportunities in any newly
27 // created ops by leveraging constant folding support for the TensorFlow ops.
28 // This way TFLite can be used as a serialization format only, and the
29 // optimizations do not require access to the TFLite runtime, as required by
30 // the TFLite team.
31
32 #include <climits>
33 #include <cstdint>
34
35 #include "absl/memory/memory.h"
36 #include "llvm/ADT/ArrayRef.h"
37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/ADT/StringSwitch.h"
39 #include "llvm/Support/Casting.h"
40 #include "llvm/Support/Debug.h"
41 #include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project
42 #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
43 #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
44 #include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
45 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
46 #include "mlir/IR/Attributes.h" // from @llvm-project
47 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
48 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
49 #include "mlir/IR/MLIRContext.h" // from @llvm-project
50 #include "mlir/Pass/Pass.h" // from @llvm-project
51 #include "mlir/Support/LLVM.h" // from @llvm-project
52 #include "mlir/Support/LogicalResult.h" // from @llvm-project
53 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
54 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
55 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
56 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
57 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
58 #include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
59 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
60 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
61 #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
62 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
63 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
64 #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
65 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
66 #include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
67 #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"
68 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
69
70 #define DEBUG_TYPE "tf-tfl-legalization"
71
72 namespace mlir {
73 namespace TFL {
74 //===----------------------------------------------------------------------===//
75 // The actual PrepareTF Pass.
76 //
77 // TODO(hinsu): Add and use TensorFlow dialect ops for the ops created in this
78 // pass.
79 namespace {
80
81 // Prepare TF operations in functions for subsequent legalization.
82 class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
83 public:
84 PrepareTFPass() = default;
85 PrepareTFPass(const PrepareTFPass &) {}
86 explicit PrepareTFPass(bool unfold_batch_matmul,
87 bool allow_bf16_and_f16_type_legalization) {
88 unfold_batch_matmul_ = unfold_batch_matmul;
89 allow_bf16_and_f16_type_legalization_ =
90 allow_bf16_and_f16_type_legalization;
91 }
92 void runOnFunction() override;
93
94 void getDependentDialects(DialectRegistry &registry) const override {
95 registry.insert<mhlo::MhloDialect, quant::QuantizationDialect,
96 TFL::TensorFlowLiteDialect>();
97 }
98
99 private:
100 Option<bool> unfold_batch_matmul_{
101 *this, "tfl-unfold-batch-matmul",
102 llvm::cl::desc("Unfold BatchMatMul into individual MatMul ops."),
103 llvm::cl::init(true)};
104
105 Option<bool> allow_bf16_and_f16_type_legalization_{
106 *this, "tfl-allow-bf16-and-f16-type-legalization",
107 llvm::cl::desc("Allow bf16 and f16 type legalization."), llvm::cl::init(false)};
108 };
109
110 template <class TFFakeQuantOp>
111 struct FetchConstantMinMaxInputs {
112 using AttrType = DenseFPElementsAttr;
113 bool operator()(TFFakeQuantOp tf_op, AttrType &min_value,
114 AttrType &max_value) const {
115 Value min = tf_op.min(), max = tf_op.max();
116
117 // TODO: This matching is incomplete; neither IdentityN ops nor chains of
118 // Identity* ops (which are not rare) are handled.
119 if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp()))
120 min = id1.input();
121 if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp()))
122 max = id2.input();
123 if (!matchPattern(min, m_Constant(&min_value))) {
124 return false;
125 }
126 if (!matchPattern(max, m_Constant(&max_value))) {
127 return false;
128 }
129 return true;  // Successfully matched and fetched.
130 }
131 };
132
133 template <class TFFakeQuantOp>
134 struct FetchMinMaxAttrs {
135 using AttrType = FloatAttr;
136 bool operator()(TFFakeQuantOp tf_op, AttrType &min_value,
137 AttrType &max_value) const {
138 min_value = tf_op.minAttr();
139 max_value = tf_op.maxAttr();
140 return true;  // Successfully matched and fetched.
141 }
142 };
143
144 // TODO(fengliuai): move this rule to PreparePatterns.td
145 // TODO(fengliuai): reuse the quantization/tensorflow/tf_to_quant pass.
146 // TODO(b/140968741): propagate the sign from the command line. Currently all
147 // FakeQuant ops are assumed to target UINT8, but the per-channel kernel is
148 // actually INT8.
149 // Inserts a "tfl.quantize" and "tfl.dequantize" op pair (QDQs) after the
150 // tf.FakeQuantWithMinMax{Vars|VarsPerChannel|Args}Op
151 // so that it can be constant folded. Since the constant
152 // folding logic will use a "std.constant" op to replace the
153 // "tf.FakeQuantWithMinMaxVarsOp", the "tfl.quantize" op is used to preserve
154 // the quantization parameters as a TypeAttr, and the "tfl.dequantize" op is
155 // used to convert the output type for the next op. Here are the transformations:
156 //
157 // input min cst max cst input min cst max cst
158 // \ | | \ | |
159 // \ (tf.Identity) (tf.Identity) => \ (tf.Identity) (tf.Identity)
160 // \ | | \ | |
161 // tf.FakeQuantWithMinMaxVars tf.FakeQuantWithMinMaxVars
162 // | |
163 //                     tfl.quantize
164 //                           |
165 //                    tfl.dequantize
166 //                           |
167 // If the input is a constant, the result pattern will eventually be converted to
168 //
169 //            quant-emulated input
170 //                   |
171 //               tfl.quantize
172 //                   |
173 //              tfl.dequantize
174 // |
175 //
176 //
177 // Warns if the (most likely unwanted, currently not quite correctly handled)
178 // case of back-to-back tf.FakeQuant occurs
179 //
180 // tf.FakeQuant*
181 // |
182 // tf.FakeQuant*
183 //
184 // tf.Identity / tf.IdentityN ops between the tf.FakeQuant* ops need no special
185 // treatment; they are already eliminated before the rewrites / check is
186 // applied.
187 //
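// As a rough, hypothetical sketch (the types and quantization parameters below
// are made up, not taken from a real test), the inserted pair looks like:
//
//   %q = "tfl.quantize"(%fq) {qtype = tensor<8x!quant.uniform<u8:f32, 0.1:128>>}
//       : (tensor<8xf32>) -> tensor<8x!quant.uniform<u8:f32, 0.1:128>>
//   %dq = "tfl.dequantize"(%q)
//       : (tensor<8x!quant.uniform<u8:f32, 0.1:128>>) -> tensor<8xf32>
//
// where %fq is the tf.FakeQuant* result and every original user of %fq is
// redirected to %dq.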
188
189 template <typename TFFakeQuantOp, bool PerAxis, class FetchMinMax>
190 struct InsertTFLQuantOpsAfterTFFakeQuantOp
191 : public OpRewritePattern<TFFakeQuantOp> {
192 using BaseType =
193 InsertTFLQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis, FetchMinMax>;
194
195 explicit InsertTFLQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis,
196 FetchMinMax>(MLIRContext *ctx)
197 : OpRewritePattern<TFFakeQuantOp>(ctx) {}
198
199 FetchMinMax fetchMinMax;
200
201 using FetchAttrType = typename FetchMinMax::AttrType;
202 LogicalResult matchAndRewrite(TFFakeQuantOp tf_op,
203 PatternRewriter &rewriter) const override {
204 // We don't want to insert quantize/dequantize if the quantize op exists.
205 auto res = tf_op.outputs();
206 if (!res.hasOneUse() || isa<QuantizeOp>(*res.user_begin())) {
207 return failure();
208 }
209
210 // Extract the min/max constant values from the operands. We also consider
211 // a special case that there are tf.Identity ops between the min/max
212 // constants and the tf.FakeQuantWithMinMaxVarsOp.
213
214 FetchAttrType min_value, max_value;
215 if (!fetchMinMax(tf_op, min_value, max_value)) {
216 return failure();
217 }
218
219 int quant_dim = -1;
220 if (PerAxis) {
221 // This is a special case where quant_dim is the last dimension.
222 quant_dim = res.getType().template cast<ShapedType>().getRank() - 1;
223 }
224 // Use the min/max from the operands and the num_bits and narrow_range
225 // attribute to create the quantization parameter for the new quantize op.
226 rewriter.setInsertionPointAfter(tf_op.getOperation());
227 IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.num_bits());
228 BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range());
229 Type res_type = tf_op.getType();
230 TypeAttr qtype = quant::GetQuantizedTypeAttr(
231 rewriter, res_type, min_value, max_value, quant_dim, num_bits,
232 narrow_range, /*is_signed=*/false);
233 if (!qtype) {
234 return failure();
235 }
236
237 // Finally, use the quantization parameter to create the quantize and
238 // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
239 // and its users.
240 Value value = tf_op.outputs();
241 auto quantize = rewriter.create<TFL::QuantizeOp>(
242 tf_op.getLoc(), qtype.getValue(), value, qtype);
243 auto dequantize = rewriter.create<TFL::DequantizeOp>(
244 tf_op.getLoc(), res_type, quantize.output());
245 value.replaceAllUsesWith(dequantize);
246 quantize.getOperation()->replaceUsesOfWith(dequantize, value);
247
248 return success();
249 }
250 };
251
252 //
253 // Three instances of the rule to cover the three different types of
254 // TF::FakeQuant operators
255 //
256 using PreparePerTensorFakeQuant = InsertTFLQuantOpsAfterTFFakeQuantOp<
257 TF::FakeQuantWithMinMaxVarsOp, /*PerAxis=*/false,
258 FetchConstantMinMaxInputs<TF::FakeQuantWithMinMaxVarsOp>>;
259
260 using PreparePerChannelFakeQuant = InsertTFLQuantOpsAfterTFFakeQuantOp<
261 TF::FakeQuantWithMinMaxVarsPerChannelOp, /*PerAxis=*/true,
262 FetchConstantMinMaxInputs<TF::FakeQuantWithMinMaxVarsPerChannelOp>>;
263
264 using PreparePerTensorFakeQuantWithMinMaxArgs =
265 InsertTFLQuantOpsAfterTFFakeQuantOp<
266 TF::FakeQuantWithMinMaxArgsOp, /*PerAxis=*/false,
267 FetchMinMaxAttrs<TF::FakeQuantWithMinMaxArgsOp>>;
268
269 // Transient state for preserving data from match to rewrite
270 struct ConvertTFConvOpMatchState {
271 IntegerAttr dilation_height_factor;
272 IntegerAttr dilation_width_factor;
273 StringAttr padding;
274 IntegerAttr stride_height;
275 IntegerAttr stride_width;
276 };
277
278 // Templated class for declaring a converter from some TensorFlow convolution
279 // op into its counterpart in TensorFlow Lite.
280 //
281 // The `ConcreteType` deriving from this template must provide the following
282 // method for constructing TensorFlow Lite op:
283 //
284 // TFL::[op] createTFLOp(ConvertTFConvOpMatchState *state,
285 // PatternRewriter &rewriter, Location loc,
286 // Type result_type, Value input,
287 // Value filter, Value bias) const;
288 //
289 // And also the following method for getting the dimension for bias tensor:
290 //
291 // int64_t getBiasDim(ArrayRef<int64_t> filterShape) const;
292 template <typename ConcreteType, typename TFConvOpType>
293 class ConvertTFConvOp : public RewritePattern {
294 public:
295 ConvertTFConvOp(MLIRContext *context,
296 bool allow_bf16_and_f16_type_legalization)
297 : RewritePattern(TFConvOpType::getOperationName(), 1, context),
298 intAttrOne(Builder(context).getI32IntegerAttr(1)),
299 allow_bf16_and_f16_type_legalization_(
300 allow_bf16_and_f16_type_legalization) {}
301
302 LogicalResult matchAndRewrite(Operation *op,
303 PatternRewriter &rewriter) const override {
304 // Assumes TensorFlow convolution op is already verified to be
305 // in valid form.
306
307 // Match a TFConvOpType under the following conditions:
308 // * The 'T' attribute must exist and be of value DT_FLOAT.
309 // * The 'data_format' attribute must exist and be of value "NHWC".
310 //  * The 'strides' attribute must exist and be of the form [1, X, Y, 1].
311 //  * The 'dilations' attribute is optional, but it must be of the form
312 //    [1, X, Y, 1] if it exists.
313
314 TFConvOpType tf_op = cast<TFConvOpType>(op);
315
316 if (!TFTypeIsFloat32Tensor(tf_op.input()) &&
317 !(allow_bf16_and_f16_type_legalization_ &&
318 TFTypeIsBFloat16OrHalfTensor(tf_op.input())))
319 return failure();
320
321 if (!TFDataFormatIsNHWC(op)) return failure();
322
323 IntegerAttr height, width;
324 if (!TFIntListIs1XY1(op, "strides", &height, &width)) return failure();
325
326 ConvertTFConvOpMatchState state;
327 state.stride_height = height;
328 state.stride_width = width;
329
330 if (TFIntListIs1XY1(op, "dilations", &height, &width)) {
331 state.dilation_height_factor = height;
332 state.dilation_width_factor = width;
333 } else {
334 // If the 'dilations' attribute is missing, we use the default value (1)
335 // for both dilation height and width factor.
336 state.dilation_height_factor = intAttrOne;
337 state.dilation_width_factor = intAttrOne;
338 }
339
340 if (!TFPaddingIsSameOrValid(op, &state.padding)) return failure();
341
342 // Additionally, we require the filter operand to be of 4-D tensor type so
343 // that we can extract info from the shape (e.g., for constructing bias
344 // tensor, for setting depth_multiplier attribute, etc.).
345 auto filter = tf_op.filter();
346 auto filter_type = filter.getType().template dyn_cast<RankedTensorType>();
347 if (!filter_type || filter_type.getRank() != 4 ||
348 !filter_type.hasStaticShape())
349 return failure();
350
351 // TensorFlow convolution op only has two inputs, while the TFLite one has
352 // three, with the bias vector marked as optional. However, TOCO has a
353 // dedicated pass, EnsureBiasVectors, to create default bias vectors for all
354 // those missing. So we model the TFLite convolution op as requiring three
355 // inputs to achieve the legalization task of EnsureBiasVectors. This
356 // requires the filter tensor to have a static shape.
357
358 // TODO(antiagainst): also handle the case of tf.Add(tf.[op], <bias>)
359
360 // Get a splat zero tensor with the expected dimension for the bias tensor
361 auto elem_type = filter_type.getElementType();
362 auto bias_dim = static_cast<const ConcreteType *>(this)->getBiasDim(
363 filter_type.getShape());
364 auto bias_type = RankedTensorType::get({bias_dim}, elem_type);
365 auto bias_attr = rewriter.getZeroAttr(bias_type);
366 auto bias =
367 rewriter.create<TF::ConstOp>(op->getLoc(), bias_type, bias_attr);
368
369 auto conv_op = static_cast<const ConcreteType *>(this)->createTFLOp(
370 &state, rewriter, op->getLoc(), tf_op.getType(), tf_op.input(), filter,
371 bias);
372
373 rewriter.replaceOp(op, conv_op.getResult());
374 return success();
375 }
376
377 const IntegerAttr intAttrOne;
378
379 private:
380 bool allow_bf16_and_f16_type_legalization_;
381 };
382
383 class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
384 public:
385 using BaseType = ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp>;
386
387 ConvertTFConv2D(MLIRContext *context, bool allow_bf16_type_legalization)
388 : BaseType(context, allow_bf16_type_legalization) {}
389
390 int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
391 return filterShape.back();
392 }
393
394 TFL::Conv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
395 PatternRewriter &rewriter, Location loc,
396 Type result_type, Value input, Value filter,
397 Value bias) const {
398 filter = legalizeFilter(rewriter, loc, filter);
399 return rewriter.create<TFL::Conv2DOp>(
400 loc, result_type, input, filter, bias,
401 /*dilation_h_factor=*/state->dilation_height_factor,
402 /*dilation_w_factor=*/state->dilation_width_factor,
403 /*fused_activation_function=*/rewriter.getStringAttr("NONE"),
404 /*padding=*/state->padding,
405 /*stride_h=*/state->stride_height,
406 /*stride_w=*/state->stride_width);
407 }
408
409 private:
410 // Legalize the given filter by converting it from TensorFlow filter data
411 // format HWIO to TFLite Conv2D op filter data format OHWI and return Value
412 // for the converted filter. Requires that filter is verified by the match
413 // method that it is a 4-D RankedTensorType.
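// For example, an HWIO filter of type tensor<3x3x16x32xf32> is transposed with
// permutation [3, 0, 1, 2] into an OHWI filter of type tensor<32x3x3x16xf32>.
// (Illustrative shapes only.)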
414 Value legalizeFilter(PatternRewriter &rewriter, Location loc,
415 Value filter) const {
416 // Create a constant op for HWIO to OHWI transpose permutation.
417 SmallVector<int, 4> perm = {3, 0, 1, 2};
418 auto perm_type = RankedTensorType::get({static_cast<int>(perm.size())},
419 rewriter.getIntegerType(32));
420 auto perm_attr =
421 DenseElementsAttr::get(perm_type, llvm::makeArrayRef<int>(perm));
422 auto perm_op = rewriter.create<TF::ConstOp>(loc, perm_type, perm_attr);
423
424 // Create tensor type for the transpose result.
425 auto filter_type = filter.getType().cast<RankedTensorType>();
426 auto result_shape =
427 llvm::to_vector<4>(llvm::map_range(perm, [filter_type](int64_t dim) {
428 return filter_type.getDimSize(dim);
429 }));
430 auto elem_type = filter_type.getElementType();
431 auto result_type = RankedTensorType::get(result_shape, elem_type);
432
433 return rewriter.create<TF::TransposeOp>(loc, result_type, filter, perm_op);
434 }
435 };
436
437 class ConvertTFDepthwiseConv2dNative
438 : public ConvertTFConvOp<ConvertTFDepthwiseConv2dNative,
439 TF::DepthwiseConv2dNativeOp> {
440 public:
441 using BaseType = ConvertTFConvOp<ConvertTFDepthwiseConv2dNative,
442 TF::DepthwiseConv2dNativeOp>;
443
444 ConvertTFDepthwiseConv2dNative(MLIRContext *context,
445 bool allow_bf16_type_legalization)
446 : BaseType(context, allow_bf16_type_legalization) {}
447
448 int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
449 return filterShape[2] * filterShape[3];
450 }
451
452 TFL::DepthwiseConv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
453 PatternRewriter &rewriter, Location loc,
454 Type result_type, Value input,
455 Value filter, Value bias) const {
456 // Compared to tfl.conv_2d, tfl.depthwise_conv_2d has an additional
457 // 'depth_multiplier' attribute. However, tf.DepthwiseConv2dNative does not
458 // have a corresponding 'depth_multiplier' attribute; the multiplier is the
459 // fourth dimension in the 4-D filter tensor. We query the multiplier from
460 // tf.DepthwiseConv2dNative and set it as the attribute value accordingly.
461 auto multiplier = filter.getType().cast<RankedTensorType>().getDimSize(3);
462
463 filter = legalizeFilter(rewriter, loc, filter);
464 return rewriter.create<TFL::DepthwiseConv2DOp>(
465 loc, result_type, input, filter, bias,
466 /*dilation_h_factor=*/state->dilation_height_factor,
467 /*dilation_w_factor=*/state->dilation_width_factor,
468 /*fused_activation_function=*/rewriter.getStringAttr("NONE"),
469 /*padding=*/state->padding,
470 /*stride_h=*/state->stride_height,
471 /*stride_w=*/state->stride_width,
472 /*depth_multiplier=*/rewriter.getI32IntegerAttr(multiplier));
473 }
474
475 private:
476 /// Legalize the given filter by converting it from TensorFlow filter data
477 /// format to TFLite DepthwiseConv2D op filter data format and return Value
478 /// for the converted filter. TensorFlow filter data format is
479 /// [filter_height, filter_width, in_channels, channel_multiplier] and TFLite
480 /// filter data format is [1, filter_height, filter_width, out_channels].
481 /// Requires that filter is verified by the match method that it is a 4-D
482 /// RankedTensorType.
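/// For example (illustrative shapes only), a TensorFlow filter of type
/// tensor<3x3x8x2xf32> (channel_multiplier = 2) is reshaped into a TFLite
/// filter of type tensor<1x3x3x16xf32>.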
483 Value legalizeFilter(PatternRewriter &rewriter, Location loc,
484 Value filter) const {
485 auto filter_type = filter.getType().cast<RankedTensorType>();
486 auto filterShape = filter_type.getShape();
487 SmallVector<int64_t, 4> result_shape = {1, filterShape[0], filterShape[1],
488 filterShape[2] * filterShape[3]};
489 auto elem_type = filter_type.getElementType();
490 auto result_type = RankedTensorType::get(result_shape, elem_type);
491 // The TensorFlow Lite `Reshape` op currently only supports int32 shape tensors.
492 auto shape_type = RankedTensorType::get({4}, rewriter.getIntegerType(32));
493 SmallVector<Attribute, 4> result_shape_data(4);
494 for (int i = 0; i < 4; ++i) {
495 result_shape_data[i] =
496 rewriter.getI32IntegerAttr(static_cast<int32_t>(result_shape[i]));
497 }
498 auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data);
499 auto shape = rewriter.create<TF::ConstOp>(loc, shape_type, shape_attr);
500
501 return rewriter.create<TF::ReshapeOp>(loc, result_type, filter, shape);
502 }
503 };
504
505 // StridedSlice can have complicated attributes like begin_mask, end_mask,
506 // ellipsis_mask, new_axis_mask, and shrink_axis_mask. These masks complicate
507 // the strided_slice computation logic; we can simplify the logic by inserting
508 // a reshape op to pad the inputs so that strided_slice becomes easier to
509 // handle.
510 //
511 // So the graph may look like below:
512 // original_input -> strided_slice -> output
513 // (transforms)
514 // original_input -> reshape -> strided_slice -> output
515 //
516 // And the new shape is computed based on the masks.
517 //
518 // An example for new_axis_mask: say the new_axis_mask is 9, which represents
519 // [1 0 0 1]. That means we are inserting two new axes at dims 0 & 3, so if
520 // the original shape is [2, 3], we reshape it into [1, 2, 3, 1].
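// As an illustrative sketch only (hypothetical shapes, not from a real test),
// that [2, 3] example is roughly rewritten into:
//
//   %shape = constant dense<[1, 2, 3, 1]> : tensor<4xi32>
//   %r = "tf.Reshape"(%input, %shape)
//       : (tensor<2x3xf32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
//   %s = "tf.StridedSlice"(%r, ...)  // new_axis_mask cleared; the expanded
//                                    // dims are folded into begin/end masks.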
521 struct ConvertTFStridedSlice : public RewritePattern {
522 explicit ConvertTFStridedSlice(MLIRContext *context)
523 : RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {}
524
525 LogicalResult RewriteNewAxisMask(Operation *op,
526 PatternRewriter &rewriter) const {
527 TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
528 uint64_t new_axis_mask = strided_slice_op.new_axis_mask();
529
530 // Insert a new reshape op.
531 Value original_input = strided_slice_op.input();
532 RankedTensorType original_input_type =
533 original_input.getType().dyn_cast<RankedTensorType>();
534 if (!original_input_type) {
535 return failure();
536 }
537
538 const ArrayRef<int64_t> &original_input_shape =
539 original_input_type.getShape();
540 SmallVector<int64_t, 4> revised_shape;
541 int index = 0;
542 const int original_input_rank = original_input_shape.size();
543 while (index < original_input_rank || new_axis_mask) {
544 if (new_axis_mask & 1) {
545 revised_shape.emplace_back(1);
546 } else {
547 revised_shape.emplace_back(original_input_shape[index++]);
548 }
549 new_axis_mask >>= 1;
550 }
551
552 if (failed(TF::VerifyShapeOfReshapeOp(revised_shape))) return failure();
553
554 const int dim_size = revised_shape.size();
555 Location loc = strided_slice_op.getLoc();
556 auto shape_type =
557 RankedTensorType::get({dim_size}, rewriter.getIntegerType(32));
558 SmallVector<Attribute, 4> result_shape_data(dim_size);
559 for (int i = 0; i < dim_size; ++i) {
560 result_shape_data[i] =
561 rewriter.getI32IntegerAttr(static_cast<int32_t>(revised_shape[i]));
562 }
563
564 auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data);
565 auto shape = rewriter.create<ConstantOp>(loc, shape_type, shape_attr);
566 auto revised_output_type = RankedTensorType::get(
567 revised_shape, original_input_type.getElementType());
568 TF::ReshapeOp reshape = rewriter.create<TF::ReshapeOp>(
569 loc, revised_output_type, original_input, shape);
570
571 // Replace the original strided_slice.
572 uint64_t revised_begin_mask = strided_slice_op.begin_mask();
573 uint64_t revised_end_mask = strided_slice_op.end_mask();
574 // Since we expand the dims, we need to apply them to the begin_mask &
575 // end_mask.
576 revised_begin_mask |= strided_slice_op.new_axis_mask();
577 revised_end_mask |= strided_slice_op.new_axis_mask();
578
579 auto attribute_type = rewriter.getIntegerType(64);
580 rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
581 op, strided_slice_op.getType(), reshape, strided_slice_op.begin(),
582 strided_slice_op.end(), strided_slice_op.strides(),
583 rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
584 rewriter.getIntegerAttr(attribute_type, revised_end_mask),
585 rewriter.getIntegerAttr(attribute_type,
586 strided_slice_op.ellipsis_mask()),
587 rewriter.getI64IntegerAttr(0),
588 rewriter.getIntegerAttr(attribute_type,
589 strided_slice_op.shrink_axis_mask()));
590 return success();
591 }
592
593 LogicalResult RewriteEllipsisMask(Operation *op,
594 PatternRewriter &rewriter) const {
595 TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
596
597 uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask();
598 uint64_t shrink_axis_mask = strided_slice_op.shrink_axis_mask();
599
600 // Enforce operator precedence.
601 shrink_axis_mask &= ~ellipsis_mask;
602
603 DenseIntElementsAttr begin_dense_elem_attr;
604 Value begin = strided_slice_op.begin();
605 auto begin_ranked_attr_type = begin.getType().dyn_cast<RankedTensorType>();
606 if (!begin_ranked_attr_type ||
607 !matchPattern(begin, m_Constant(&begin_dense_elem_attr))) {
608 return failure();
609 }
610
611 DenseIntElementsAttr end_dense_elem_attr;
612 Value end = strided_slice_op.end();
613 auto end_ranked_attr_type = end.getType().dyn_cast<RankedTensorType>();
614 if (!end_ranked_attr_type ||
615 !matchPattern(end, m_Constant(&end_dense_elem_attr))) {
616 return failure();
617 }
618
619 DenseIntElementsAttr stride_dense_elem_attr;
620 Value stride = strided_slice_op.strides();
621 auto stride_ranked_attr_type =
622 stride.getType().dyn_cast<RankedTensorType>();
623 if (!stride_ranked_attr_type ||
624 !matchPattern(stride, m_Constant(&stride_dense_elem_attr))) {
625 return failure();
626 }
627
628 Value input = strided_slice_op.input();
629 RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
630 if (!input_type) {
631 return failure();
632 }
633 const ArrayRef<int64_t> input_shape = input_type.getShape();
634
635 const int input_size = input_shape.size();
636
637 RankedTensorType begin_type = begin.getType().cast<RankedTensorType>();
638 const ArrayRef<int64_t> begin_shape = begin_type.getShape();
639 const int begin_dim = begin_shape.size();
640
641 if (begin_dim != 1) return failure();
642
643 const int ellipsis_filled_dim_size = input_size - begin_shape[0] + 1;
644
645 int64_t begin_mask = strided_slice_op.begin_mask();
646 int64_t end_mask = strided_slice_op.end_mask();
647 int64_t revised_begin_mask = 0;
648 int64_t revised_end_mask = 0;
649 int64_t revised_shrink_axis_mask = 0;
650
651 SmallVector<int32_t, 4> padded_begin;
652 SmallVector<int32_t, 4> padded_end;
653 SmallVector<int32_t, 4> padded_stride;
654
655 // Before the ellipsis.
656 int index = 0;
657 int new_index = 0;
658 while (((ellipsis_mask >> index) & 1) == 0) {
659 padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
660 padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
661 padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
662 if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
663 if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
664 if ((shrink_axis_mask >> index) & 1)
665 revised_shrink_axis_mask |= (1 << new_index);
666 ++index;
667 ++new_index;
668 }
669
670 // Ellipsis.
671 for (; new_index < index + ellipsis_filled_dim_size; ++new_index) {
672 revised_begin_mask |= (1 << new_index);
673 revised_end_mask |= (1 << new_index);
674
675 // Mimic the begin/end/strides mask behavior.
676 padded_begin.push_back(0);
677 padded_end.push_back(0);
678 padded_stride.push_back(1);
679 }
680
681 // Account for ellipsis mask.
682 ++index;
683
684 // After the ellipsis.
685 for (; index < begin_shape[0];) {
686 padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
687 padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
688 padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
689
690 if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
691 if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
692 if ((shrink_axis_mask >> index) & 1)
693 revised_shrink_axis_mask |= (1 << new_index);
694
695 ++index;
696 ++new_index;
697 }
698
699 auto attribute_type = rewriter.getIntegerType(64);
700
701 int full_dim_count = padded_begin.size();
702 auto type =
703 RankedTensorType::get({full_dim_count}, rewriter.getIntegerType(32));
704
705 auto begin_attr = DenseElementsAttr::get<int32_t>(type, padded_begin);
706 auto begin_op = rewriter.create<ConstantOp>(op->getLoc(), type, begin_attr);
707 auto end_attr = DenseElementsAttr::get<int32_t>(type, padded_end);
708 auto end_op = rewriter.create<ConstantOp>(op->getLoc(), type, end_attr);
709 auto stride_attr = DenseElementsAttr::get<int32_t>(type, padded_stride);
710 auto stride_op =
711 rewriter.create<ConstantOp>(op->getLoc(), type, stride_attr);
712
713 rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
714 op, strided_slice_op.getType(), input, begin_op.getResult(),
715 end_op.getResult(), stride_op.getResult(),
716 rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
717 rewriter.getIntegerAttr(attribute_type, revised_end_mask),
718 /*ellipsis_mask=*/rewriter.getI64IntegerAttr(0),
719 rewriter.getIntegerAttr(attribute_type,
720 strided_slice_op.new_axis_mask()),
721 rewriter.getIntegerAttr(attribute_type, revised_shrink_axis_mask));
722 return success();
723 }
724
725 void PadStridedSliceAttributeArray(DenseIntElementsAttr dense_elem_attr,
726 SmallVectorImpl<int32_t> &val,
727 SmallVectorImpl<int32_t> &padded_val,
728 ArrayRef<int32_t> padding_val,
729 int *mask) const {
730 for (const auto &idx : dense_elem_attr.getIntValues()) {
731 val.push_back(idx.getSExtValue());
732 padded_val.push_back(idx.getSExtValue());
733 }
734 int attr_dim_count = val.size();
735 int full_dim_count = padding_val.size();
736 for (int i = attr_dim_count; i < full_dim_count; ++i) {
737 padded_val.push_back(padding_val[i]);
738 if (mask) *mask |= 1 << i;
739 }
740 }
741
742 LogicalResult matchAndRewrite(Operation *op,
743 PatternRewriter &rewriter) const override {
744 TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
745
746 // Handle new axis mask.
747 if (strided_slice_op.new_axis_mask() != 0) {
748 // We currently don't handle simultaneous shrink_ and new_axis masks.
749 if (!strided_slice_op.shrink_axis_mask()) {
750 return RewriteNewAxisMask(strided_slice_op, rewriter);
751 }
752 }
753
754 // Handle ellipsis mask.
755 if (strided_slice_op.ellipsis_mask() != 0) {
756 return RewriteEllipsisMask(strided_slice_op, rewriter);
757 }
758
759 auto ranked_input_type =
760 strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
761 if (!ranked_input_type) {
762 return failure();
763 }
764
765 auto begin_attr = strided_slice_op.begin();
766 auto end_attr = strided_slice_op.end();
767 auto strides_attr = strided_slice_op.strides();
768
769 auto begin_attr_type = begin_attr.getType().dyn_cast<RankedTensorType>();
770 auto end_attr_type = end_attr.getType().dyn_cast<RankedTensorType>();
771 auto strides_attr_type =
772 strides_attr.getType().dyn_cast<RankedTensorType>();
773
774 DenseIntElementsAttr begin_elem_attr;
775 DenseIntElementsAttr end_elem_attr;
776 DenseIntElementsAttr strides_elem_attr;
777
778 if (!begin_attr_type ||
779 !matchPattern(begin_attr, m_Constant(&begin_elem_attr))) {
780 return failure();
781 }
782 if (!end_attr_type || !matchPattern(end_attr, m_Constant(&end_elem_attr))) {
783 return failure();
784 }
785 if (!strides_attr_type ||
786 !matchPattern(strides_attr, m_Constant(&strides_elem_attr))) {
787 return failure();
788 }
789
790 SmallVector<int32_t, 4> begin, end, strides;
791 SmallVector<int32_t, 4> padded_begin, padded_end, padded_strides;
792
793 int num_input_dims = ranked_input_type.getRank();
794 SmallVector<int32_t, 4> padding_begin(num_input_dims, 0);
795 auto input_shape = ranked_input_type.getShape();
796 SmallVector<int32_t, 4> padding_end(input_shape.begin(), input_shape.end());
797 SmallVector<int32_t, 4> padding_strides(num_input_dims, 1);
798
799 int begin_mask = strided_slice_op.begin_mask();
800 int end_mask = strided_slice_op.end_mask();
801
802 PadStridedSliceAttributeArray(begin_elem_attr, begin, padded_begin,
803 padding_begin, &begin_mask);
804 PadStridedSliceAttributeArray(end_elem_attr, end, padded_end, padding_end,
805 &end_mask);
806 PadStridedSliceAttributeArray(strides_elem_attr, strides, padded_strides,
807 padding_strides, nullptr);
808
809 if (begin == padded_begin && end == padded_end &&
810 strides == padded_strides &&
811 begin_mask == strided_slice_op.begin_mask() &&
812 end_mask == strided_slice_op.end_mask()) {
813 return failure();
814 }
815
816 auto begin_end_type =
817 RankedTensorType::get({num_input_dims}, rewriter.getIntegerType(32));
818 auto new_begin_attr = rewriter.create<ConstantOp>(
819 op->getLoc(), begin_end_type,
820 DenseElementsAttr::get<int32_t>(begin_end_type, padded_begin));
821 auto new_end_attr = rewriter.create<ConstantOp>(
822 op->getLoc(), begin_end_type,
823 DenseElementsAttr::get<int32_t>(begin_end_type, padded_end));
824 auto strides_type =
825 RankedTensorType::get({static_cast<long>(padded_strides.size())},
826 rewriter.getIntegerType(32));
827 auto new_strides_attr = rewriter.create<ConstantOp>(
828 op->getLoc(), strides_type,
829 DenseElementsAttr::get<int32_t>(strides_type, padded_strides));
830
831 auto attribute_type = rewriter.getIntegerType(64);
832 rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
833 op, strided_slice_op.output().getType(), strided_slice_op.input(),
834 new_begin_attr, new_end_attr, new_strides_attr,
835 rewriter.getIntegerAttr(attribute_type, begin_mask),
836 rewriter.getIntegerAttr(attribute_type, end_mask),
837 rewriter.getIntegerAttr(attribute_type,
838 strided_slice_op.ellipsis_mask()),
839 rewriter.getIntegerAttr(attribute_type,
840 strided_slice_op.new_axis_mask()),
841 rewriter.getIntegerAttr(attribute_type,
842 strided_slice_op.shrink_axis_mask()));
843
844 return success();
845 }
846 };
847
848 struct ConvertTFBroadcastTo : public RewritePattern {
849 explicit ConvertTFBroadcastTo(MLIRContext *context)
850 : RewritePattern(TF::BroadcastToOp::getOperationName(), 1, context) {}
851
852 LogicalResult matchAndRewrite(Operation *op,
853 PatternRewriter &rewriter) const override {
854 auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
855 auto input_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
856 auto output_type = tf_broadcast_to_op.output().getType().cast<ShapedType>();
857 auto shape_type = tf_broadcast_to_op.shape().getType().cast<ShapedType>();
858 Type element_type = input_type.getElementType();
859
860 // Allow lowering when low dimensional inputs are given and the element type
861 // is bf16/f32 or i16/i32.
862 if (!((output_type.hasRank() && output_type.getRank() <= 4) ||
863 (shape_type.hasStaticShape() && shape_type.getRank() == 1 &&
864 shape_type.getDimSize(0) <= 4)))
865 return failure();
866
867 if (!(element_type.isa<BFloat16Type, Float32Type>() ||
868 element_type.isInteger(32) || element_type.isInteger(16)))
869 return failure();
870
871 auto status_or_const_op =
872 CreateConstOpWithSingleValue(&rewriter, op->getLoc(), input_type, 1);
873 if (!status_or_const_op.ok()) {
874 return failure();
875 }
876
877 auto tf_fill_op = rewriter.create<TF::FillOp>(
878 op->getLoc(), output_type, tf_broadcast_to_op.shape(),
879 status_or_const_op.ValueOrDie());
880
881 auto mul_op = rewriter.create<TF::MulOp>(
882 op->getLoc(), output_type, tf_broadcast_to_op.input(), tf_fill_op);
883 rewriter.replaceOp(op, mul_op.getResult());
884 return success();
885 }
886 };
887
888 // The pattern below is equivalent to the DRR rule that follows.
889 // The checks depend on generated values, so we can't add
890 // the checks on intermediate values; ideally we should find equivalent
891 // checks that guarantee the resultant ops are valid.
892 // The extra conditions are the broadcasting conditions.
893 //
894 // The pattern lowers FusedBatchNormV3 to arithmetic ops.
895 // Specifically, performs the following calculation:
896 //
897 // (x - mean) * scale / sqrt(variance + epsilon) + offset
898 //
899 // Let multiplier = scale / sqrt(variance + epsilon),
900 // to compute
901 // (x - mean) * scale / sqrt(variance + epsilon) + offset,
902 // is then to compute
903 // (x * multiplier) + (offset - mean * multiplier).
904 //
905 // def : Pattern<
906 // (TF_FusedBatchNormV3Op:$root
907 // $x, $scale, $offset, $mean, $variance,
908 // F32Attr:$epsilon, $exponential_avg_factor,
909 // $data_format, FalseBoolAttr:$is_training),
910 // [(TF_AddOp
911 // (TF_MulOp
912 // $x,
913 // (TF_MulOp:$multiplier
914 // $scale,
915 // (TF_RsqrtOp
916 // (TF_AddOp $variance,
917 // (TF_ConstOp $epsilon))))),
918 // (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))),
919 // // We already guaranteed that the last five results have no use so it does
920 // // not matter what value we provide here for replacement.
921 // /*batch_mean=*/(replaceWithValue $x),
922 // /*batch_variance=*/(replaceWithValue $x),
923 // /*reserve_space_1=*/(replaceWithValue $x),
924 // /*reserve_space_2=*/(replaceWithValue $x),
925 // /*reserve_space_3=*/(replaceWithValue $x)],
926 // [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
927 // (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
928 // (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>;
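// As a quick numeric sanity check of the algebra above (illustrative values
// only): with scale = 2, variance = 3, epsilon = 1, mean = 1, offset = 0.5 and
// x = 4, multiplier = 2 / sqrt(3 + 1) = 1, so the two forms agree:
// (x - mean) * multiplier + offset = (4 - 1) * 1 + 0.5 = 3.5, and
// x * multiplier + (offset - mean * multiplier) = 4 + (0.5 - 1) = 3.5.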
929 //
930 // When is_training is set to true, the given variance and mean are not used.
931 // In the above calculation, they are replaced by new values. These new mean
932 // and variance are calculated as follows:
933 // rest_size = shape(x)[0] * shape(x)[1] * shape(x)[2]
934 // new_mean = sum(x, axis=[0, 1, 2]) / rest_size
935 // new_variance = sum(squared_difference(x, new_mean), axis=[0, 1, 2])
936 // / rest_size
937 //
938 // The DRR rule for the is_training == true case is as follows:
939 // def : Pattern<
940 // (TF_FusedBatchNormV3Op:$root
941 // $x, $scale, $offset, $mean, $variance,
942 // F32Attr:$epsilon, $exponential_avg_factor,
943 // $data_format, FalseBoolAttr:$is_training),
944 // [(TF_AddOp
945 // (TF_MulOp
946 // $x,
947 // (TF_MulOp:$multiplier
948 // $scale,
949 // (TF_RsqrtOp
950 // (TF_AddOp
951 // (TF_DivOp:$new_variance
952 // (TF_SumOp
953 // (TF_SquaredDifferenceOp $x, $new_mean),
954 // (TF_ConstOp [0,1,2])),
955 // $rest_size),
956 // (TF_ConstOp $epsilon))))),
957 // (TF_SubOp
958 // $offset,
959 // (TF_MulOp
960 // (TF_DivOp:$new_mean
961 // (TF_SumOp $x, (TF_ConstOp [0,1,2])),
962 // (TF_ProdOp:$rest_size
963 // (TF_SliceOp
964 // (TF_ShapeOp $x),
965 // (TF_ConstOp 0),
966 // (TF_ConstOp 3)))),
967 // $multiplier))),
968 // // We already guaranteed that the last five results have no use so it does
969 // // not matter what value we provide here for replacement.
970 // /*batch_mean=*/(replaceWithValue $x),
971 // /*batch_variance=*/(replaceWithValue $x),
972 // /*reserve_space_1=*/(replaceWithValue $x),
973 // /*reserve_space_2=*/(replaceWithValue $x),
974 // /*reserve_space_3=*/(replaceWithValue $x)],
975 // [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
976 // (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
977 // (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>;
978
979 struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
980 explicit FusedBatchNormV3Pat(::mlir::MLIRContext *context)
981 : ::mlir::RewritePattern(
982 "tf.FusedBatchNormV3",
983 {"tf.Add", "tf.Const", "tf.Mul", "tf.Rsqrt", "tf.Sub"}, 1,
984 context) {}
985
986 ::mlir::LogicalResult matchAndRewrite(
987 ::mlir::Operation *fused_batch_norm,
988 ::mlir::PatternRewriter &rewriter) const override {
989 // Variables for capturing values and attributes used for creating ops
990 Operation::operand_range mean(fused_batch_norm->getOperands());
991 ::mlir::FloatAttr exponential_avg_factor;
992 ::mlir::TF::FusedBatchNormV3Op root;
993 Operation::operand_range offset(fused_batch_norm->getOperands());
994 Operation::operand_range x(fused_batch_norm->getOperands());
995 Operation::operand_range scale(fused_batch_norm->getOperands());
996 Operation::operand_range variance(fused_batch_norm->getOperands());
997 ::mlir::FloatAttr epsilon;
998 ::mlir::BoolAttr is_training;
999
1000 // Match
1001 auto fused_batch_norm_op =
1002 dyn_cast_or_null<::mlir::TF::FusedBatchNormV3Op>(fused_batch_norm);
1003 root = fused_batch_norm_op;
1004 x = fused_batch_norm_op.getODSOperands(0);
1005 scale = fused_batch_norm_op.getODSOperands(1);
1006 offset = fused_batch_norm_op.getODSOperands(2);
1007 mean = fused_batch_norm_op.getODSOperands(3);
1008 variance = fused_batch_norm_op.getODSOperands(4);
1009
1010 ::mlir::Value mean_value = (*mean.begin());
1011 ::mlir::Value variance_value = (*variance.begin());
1012
1013 if (!TFTypeIsFloat32Tensor(fused_batch_norm_op.x())) return failure();
1014
1015 {
1016 epsilon =
1017 fused_batch_norm_op->getAttrOfType<::mlir::FloatAttr>("epsilon");
1018 if (!epsilon)
1019 epsilon = rewriter.getFloatAttr(rewriter.getF32Type(), 0.0001f);
1020
1021 if (!(((epsilon.isa<::mlir::FloatAttr>())) &&
1022 ((epsilon.cast<::mlir::FloatAttr>().getType().isF32())))) {
1023 return rewriter.notifyMatchFailure(
1024 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1025 diag << "op 'tf.FusedBatchNormV3' attribute 'epsilon' failed to "
1026 "satisfy constraint: 32-bit float attribute";
1027 });
1028 }
1029 }
1030 {
1031 exponential_avg_factor =
1032 fused_batch_norm_op->getAttrOfType<::mlir::FloatAttr>(
1033 "exponential_avg_factor");
1034 if (!exponential_avg_factor)
1035 exponential_avg_factor =
1036 rewriter.getFloatAttr(rewriter.getF32Type(), 1.0f);
1037 }
1038 if (!TFDataFormatIsNHWC(fused_batch_norm_op) &&
1039 !TFDataFormatIsNDHWC(fused_batch_norm_op))
1040 return failure();
1041
1042 if (!(((*root.getODSResults(1).begin()).use_empty()))) {
1043 return rewriter.notifyMatchFailure(
1044 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1045 diag << "entities '' failed to satisfy constraint: has no use";
1046 });
1047 }
1048
1049 if (!(((*root.getODSResults(2).begin()).use_empty()))) {
1050 return rewriter.notifyMatchFailure(
1051 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1052 diag << "entities '' failed to satisfy constraint: has no use";
1053 });
1054 }
1055
1056 if (!(((*root.getODSResults(3).begin()).use_empty()))) {
1057 return rewriter.notifyMatchFailure(
1058 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1059 diag << "entities '' failed to satisfy constraint: has no use";
1060 });
1061 }
1062
1063 if (!(((*root.getODSResults(4).begin()).use_empty()))) {
1064 return rewriter.notifyMatchFailure(
1065 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1066 diag << "entities '' failed to satisfy constraint: has no use";
1067 });
1068 }
1069
1070 if (!(((*root.getODSResults(5).begin()).use_empty()))) {
1071 return rewriter.notifyMatchFailure(
1072 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1073 diag << "entities '' failed to satisfy constraint: has no use";
1074 });
1075 }
1076
1077 is_training =
1078 fused_batch_norm_op->getAttrOfType<::mlir::BoolAttr>("is_training");
1079 auto odsLoc = rewriter.getFusedLoc({fused_batch_norm->getLoc()});
1080
1081 // We need to make sure input and output shapes are compatible.
1082 {
1083 int64_t last_dim = -1;
1084 auto is_last_dim_compatible = [](const Value &v, int64_t &last_dim) {
1085 auto v_type = v.getType().dyn_cast_or_null<RankedTensorType>();
1086 if (!v_type) return true;
1087 int64_t v_last_dim = v_type.getDimSize(v_type.getRank() - 1);
1088 if (v_last_dim == -1) return true;
1089 if (last_dim != -1 && v_last_dim != last_dim) return false;
1090 last_dim = v_last_dim;
1091 return true;
1092 };
1093
1094 if (!is_last_dim_compatible(*x.begin(), last_dim) ||
1095 !is_last_dim_compatible(*scale.begin(), last_dim) ||
1096 !is_last_dim_compatible(*offset.begin(), last_dim)) {
1097 return rewriter.notifyMatchFailure(
1098 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1099 diag << "Shapes of scale and offset should be 1D and "
1100 "compatible with x";
1101 });
1102 }
1103
1104 if (!is_training.getValue()) {
1105 if (!is_last_dim_compatible(mean_value, last_dim) ||
1106 !is_last_dim_compatible(variance_value, last_dim)) {
1107 return rewriter.notifyMatchFailure(
1108 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1109 diag << "Shapes of mean and variance should be 1D and "
1110 "compatible with x";
1111 });
1112 }
1113 }
1114
1115 // Check if output shape and input shape are compatible.
1116 auto x_type = (*x.begin()).getType();
1117 auto y_type = (*root.getODSResults(0).begin()).getType();
1118 if (!OpTrait::util::getBroadcastedType(x_type, y_type)) {
1119 return rewriter.notifyMatchFailure(
1120 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1121 diag << "Shapes of x and the first output should be compatible";
1122 });
1123 }
1124 }
1125
1126 // For training, mean and variance are calculated from the input values.
1127 if (is_training.getValue()) {
1128 auto input_type = fused_batch_norm_op.x()
1129 .getType()
1130 .dyn_cast_or_null<RankedTensorType>();
1131 if (!input_type || input_type.getRank() != 4 ||
1132 !input_type.hasStaticShape()) {
1133 return rewriter.notifyMatchFailure(
1134 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1135 diag << "op 'tf.FusedBatchNormV3' that has 'is_training' equals "
1136 "True is only supported with static input shape";
1137 });
1138 }
1139
1140 ::mlir::TF::ConstOp reduce_dim_op;
1141 {
1142 auto reduce_dim_type =
1143 ::mlir::RankedTensorType::get({3}, rewriter.getIntegerType(64));
1144 ::mlir::SmallVector<int64_t, 3> reduce_dim_values = {0, 1, 2};
1145 reduce_dim_op = rewriter.create<TF::ConstOp>(
1146 odsLoc, ::mlir::DenseIntElementsAttr::get(reduce_dim_type,
1147 reduce_dim_values));
1148 }
1149
1150 ::mlir::TF::ConstOp rest_size_inv_op;
1151 {
1152 int64_t rest_size = input_type.getDimSize(0) *
1153 input_type.getDimSize(1) * input_type.getDimSize(2);
1154 auto rest_size_inv_type =
1155 ::mlir::RankedTensorType::get({1}, rewriter.getF32Type());
1156 auto rest_size_inv_attr = ::mlir::DenseFPElementsAttr::get(
1157 rest_size_inv_type, {1.0f / rest_size});
1158 rest_size_inv_op =
1159 rewriter.create<::mlir::TF::ConstOp>(odsLoc, rest_size_inv_attr);
1160 }
1161
1162 ::mlir::TF::SumOp sum_op_1;
1163 {
1164 ::mlir::Value x_value = (*x.begin());
1165 sum_op_1 = rewriter.create<TF::SumOp>(
1166 odsLoc, x_value, reduce_dim_op,
1167 /*keep_dims=*/rewriter.getBoolAttr(false));
1168 }
1169
1170 ::mlir::TF::MulOp mul_op_1;
1171 {
1172 ::mlir::Value tblgen_value_0 = (*sum_op_1.getODSResults(0).begin());
1173 ::mlir::Value tblgen_value_1 =
1174 (*rest_size_inv_op.getODSResults(0).begin());
1175 mul_op_1 = rewriter.create<::mlir::TF::MulOp>(odsLoc, tblgen_value_0,
1176 tblgen_value_1);
1177 }
1178
1179 ::mlir::TF::SquaredDifferenceOp square_diff_op;
1180 {
1181 ::mlir::Value tblgen_value_0 = (*x.begin());
1182 ::mlir::Value tblgen_value_1 = (*mul_op_1.getODSResults(0).begin());
1183 // If x has shape of [b, h, w, c], the result of mul_op_1 will have
1184 // shape of [c]. Therefore, their shapes are always compatible.
1185 square_diff_op = rewriter.create<::mlir::TF::SquaredDifferenceOp>(
1186 odsLoc, tblgen_value_0, tblgen_value_1);
1187 }
1188
1189 ::mlir::TF::SumOp sum_op_2;
1190 {
1191 ::mlir::Value input_value = (*square_diff_op.getODSResults(0).begin());
1192 sum_op_2 = rewriter.create<TF::SumOp>(
1193 odsLoc, input_value, reduce_dim_op,
1194 /*keep_dims=*/rewriter.getBoolAttr(false));
1195 }
1196
1197 ::mlir::TF::MulOp mul_op_2;
1198 {
1199 ::mlir::Value tblgen_value_0 = (*sum_op_2.getODSResults(0).begin());
1200 ::mlir::Value tblgen_value_1 =
1201 (*rest_size_inv_op.getODSResults(0).begin());
1202 mul_op_2 = rewriter.create<::mlir::TF::MulOp>(odsLoc, tblgen_value_0,
1203 tblgen_value_1);
1204 }
1205
1206 mean_value = (*mul_op_1.getODSResults(0).begin());
1207 variance_value = (*mul_op_2.getODSResults(0).begin());
1208 } // End is_training equals true if.
1209
1210 ::llvm::SmallVector<::mlir::Value, 4> replace_values;
1211 ::mlir::TF::ConstOp epsilon_const_op;
1212 {
1213 epsilon_const_op =
1214 rewriter.create<::mlir::TF::ConstOp>(odsLoc,
1215 /*value=*/epsilon);
1216 }
1217 ::mlir::TF::AddOp add_op_1;
1218 {
1219 ::mlir::Value epsilon_value =
1220 (*epsilon_const_op.getODSResults(0).begin());
1221 // Adding a constant; no need to check broadcastability.
1222 add_op_1 = rewriter.create<::mlir::TF::AddOp>(odsLoc,
1223 /*x=*/variance_value,
1224 /*y=*/epsilon_value);
1225 }
1226 ::mlir::TF::RsqrtOp rsqrt_op;
1227 {
1228 ::mlir::SmallVector<::mlir::Value, 4> tblgen_values;
1229 ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs;
1230 tblgen_values.push_back((*add_op_1.getODSResults(0).begin()));
1231 rsqrt_op = rewriter.create<::mlir::TF::RsqrtOp>(odsLoc, tblgen_values,
1232 tblgen_attrs);
1233 }
1234 ::mlir::TF::MulOp multiplier;
1235 {
1236 ::mlir::Value tblgen_value_0 = (*scale.begin());
1237 ::mlir::Value tblgen_value_1 = (*rsqrt_op.getODSResults(0).begin());
1238 multiplier = rewriter.create<::mlir::TF::MulOp>(odsLoc,
1239 /*x=*/tblgen_value_0,
1240 /*y=*/tblgen_value_1);
1241 }
1242 ::mlir::TF::MulOp mul_op_1;
1243 {
1244 ::mlir::Value tblgen_value_0 = (*x.begin());
1245 ::mlir::Value tblgen_value_1 = (*multiplier.getODSResults(0).begin());
1246 mul_op_1 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
1247 /*x=*/tblgen_value_0,
1248 /*y=*/tblgen_value_1);
1249 }
1250 ::mlir::TF::MulOp mul_op_2;
1251 {
1252 ::mlir::Value multiplier_value = (*multiplier.getODSResults(0).begin());
1253 mul_op_2 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
1254 /*x=*/mean_value,
1255 /*y=*/multiplier_value);
1256 }
1257 ::mlir::TF::SubOp sub_op;
1258 {
1259 ::mlir::Value tblgen_value_0 = (*offset.begin());
1260 ::mlir::Value tblgen_value_1 = (*mul_op_2.getODSResults(0).begin());
1261 sub_op = rewriter.create<::mlir::TF::SubOp>(odsLoc,
1262 /*x=*/tblgen_value_0,
1263 /*y=*/tblgen_value_1);
1264 }
1265 ::mlir::TF::AddOp add_op_2;
1266 {
1267 ::mlir::SmallVector<::mlir::Value, 4> tblgen_values;
1268 ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs;
1269 tblgen_values.push_back((*mul_op_1.getODSResults(0).begin()));
1270 tblgen_values.push_back((*sub_op.getODSResults(0).begin()));
1271 ::mlir::SmallVector<::mlir::Type, 4> tblgen_types;
1272 for (auto v : fused_batch_norm_op.getODSResults(0)) {
1273 tblgen_types.push_back(v.getType());
1274 }
1275 add_op_2 = rewriter.create<::mlir::TF::AddOp>(
1276 odsLoc, tblgen_types, tblgen_values, tblgen_attrs);
1277 }
1278 for (auto v :
1279 ::llvm::SmallVector<::mlir::Value, 4>{add_op_2.getODSResults(0)}) {
1280 replace_values.push_back(v);
1281 }
1282 for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1283 replace_values.push_back(v);
1284 }
1285 for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1286 replace_values.push_back(v);
1287 }
1288 for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1289 replace_values.push_back(v);
1290 }
1291 for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1292 replace_values.push_back(v);
1293 }
1294 for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1295 replace_values.push_back(v);
1296 }
1297 rewriter.replaceOp(fused_batch_norm, replace_values);
1298 return success();
1299 };
1300 };
1301
1302 #include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
1303
1304 // Returns success if all the operations in the `op`'s regions including `op`
1305 // itself are legal in a TFLite pipeline.
1306 LogicalResult ValidateOp(Operation *op) {
1307 bool has_illegal_ops = false;
1308 op->walk([&](Operation *op) {
1309 if (isa<TF::VariableV2Op>(op)) {
1310 has_illegal_ops = true;
1311 op->emitOpError() << "is illegal in a TFLite pipeline";
1312 }
1313 });
1314
1315 return failure(has_illegal_ops);
1316 }
1317
1318 // Converts a set of TF2XLA ops into pure TF ops for future legalizations as
1319 // TF2XLA ops aren't supported by later stages.
1320 LogicalResult ConvertTf2XlaOps(FuncOp func, MLIRContext *context) {
1321 ConversionTarget target(*context);
1322 target.addLegalDialect<StandardOpsDialect>();
1323 target.addLegalDialect<TF::TensorFlowDialect>();
1324 target.addLegalOp<ModuleOp>();
1325 target.addLegalOp<FuncOp>();
1326 target.addIllegalOp<TF::XlaConvOp>();
1327 target.addIllegalOp<TF::XlaGatherOp>();
1328
1329 OwningRewritePatternList patterns;
1330 mhlo::PopulateLegalizeTfWithTf2XlaPatterns("XLA_CPU_JIT", patterns);
1331 mhlo::PopulateLegalizeTfPatterns(context, &patterns);
1332 TF::PopulateLegalizeHloToTfPatterns(&patterns, context);
1333 mhlo::GatherOp::getCanonicalizationPatterns(patterns, context);
1334
1335 return applyPartialConversion(func, target, std::move(patterns));
1336 }
1337
1338 // Convert rfft to rfft2d.
1339 // The transformation pattern looks like below:
1340 //
1341 // input fft_len
1342 // \ /
1343 // rfft
1344 //
1345 // ||
1346 // \/
1347 //
1348 // input fft_len
1349 // \ /
1350 // expand_dim concat with [1] at the front
1351 // \ /
1352 // rfft_2d
1353 // |
1354 // squeeze
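// For example (illustrative shapes only), an RFFT over a tensor<4x8xf32> input
// with fft_length = [8] becomes: ExpandDims to tensor<4x1x8xf32>, fft_length
// concatenated with a leading 1 into [1, 8], RFFT2D producing
// tensor<4x1x5xcomplex<f32>>, and a final Squeeze back to
// tensor<4x5xcomplex<f32>>.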
1355 struct ConvertRfftToRfft2d : public RewritePattern {
1356 explicit ConvertRfftToRfft2d(MLIRContext *context)
1357 : RewritePattern(TF::RFFTOp::getOperationName(), 1, context) {}
1358
1359 LogicalResult matchAndRewrite(Operation *op,
1360 PatternRewriter &rewriter) const override {
1361 auto rfft_op = dyn_cast<TF::RFFTOp>(op);
1362
1363 auto input = rfft_op.input();
1364 auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
1365 if (!input_type) return failure();
1366 auto fft_len = rfft_op.fft_length();
1367 auto fft_len_type = fft_len.getType().dyn_cast_or_null<ShapedType>();
1368 if (!fft_len_type) return failure();
1369
1370 auto output_type =
1371 rfft_op.getResult().getType().dyn_cast_or_null<RankedTensorType>();
1372 if (!output_type) return failure();
1373
1374 // Expanded inputs.
1375 // Insert at -2 location.
1376 auto one_ele_type =
1377 mlir::RankedTensorType::get({1}, rewriter.getIntegerType(32));
1378 auto minus_two = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(),
1379 one_ele_type, -2);
1380
1381 SmallVector<int64_t, 4> expanded_input_shape;
1382 SmallVector<int64_t, 4> expanded_output_shape;
1383 int expanded_rank = input_type.getRank() + 1;
1384 int r = 0;
1385 for (int i = 0; i < expanded_rank; ++i) {
1386 if (i == expanded_rank - 2) {
1387 expanded_input_shape.push_back(1);
1388 expanded_output_shape.push_back(1);
1389 } else {
1390 expanded_input_shape.push_back(input_type.getDimSize(r));
1391 expanded_output_shape.push_back(output_type.getDimSize(r));
1392 r++;
1393 }
1394 }
1395
1396 auto expanded_input_type = mlir::RankedTensorType::get(
1397 expanded_input_shape, input_type.getElementType());
1398 TF::ExpandDimsOp expanded_input = rewriter.create<TF::ExpandDimsOp>(
1399 rfft_op.getLoc(), expanded_input_type, input, minus_two->getResult());
1400
1401 // Expanded fft_len.
1402 auto one_attr = mlir::DenseIntElementsAttr::get(one_ele_type, {1});
1403
1404 auto one = rewriter.create<TF::ConstOp>(rfft_op.getLoc(), one_attr);
1405
1406 auto zero = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(),
1407 one_ele_type, 0);
1408
1409 auto expanded_fft_len_type =
1410 mlir::RankedTensorType::get({2}, fft_len_type.getElementType());
1411
1412 TF::ConcatV2Op expanded_fft_len = rewriter.create<TF::ConcatV2Op>(
1413 rfft_op.getLoc(), expanded_fft_len_type,
1414 SmallVector<Value, 2>({one.getResult(), fft_len}), zero->getResult());
1415
1416 // Insert the rfft_2d.
1417 auto rfft2d_out_type = mlir::RankedTensorType::get(
1418 expanded_output_shape, output_type.getElementType());
1419 TF::RFFT2DOp rfft2d = rewriter.create<TF::RFFT2DOp>(
1420 rfft_op.getLoc(), rfft2d_out_type, expanded_input.getResult(),
1421 expanded_fft_len.getResult());
1422
1423 // Insert the squeeze op.
1424 auto squeeze_dim = rewriter.getI64ArrayAttr({-2});
1425 TF::SqueezeOp squeeze = rewriter.create<TF::SqueezeOp>(
1426 rfft_op.getLoc(), output_type, rfft2d.getResult(), squeeze_dim);
1427
1428 rewriter.replaceOp(op, squeeze.getResult());
1429
1430 return success();
1431 }
1432 };
1433
1434 void PrepareTFPass::runOnFunction() {
1435 OwningRewritePatternList patterns, phase_2_patterns;
1436 auto func = getFunction();
1437 MLIRContext *ctx = &getContext();
1438
1439 // Check for ops illegal in a TFLite pipeline (e.g. training-only ops), since
1440 // PrepareTFPass is the very first TFLite pass in the pipeline.
1441 // TODO(jingpu): It might be better to split this check into its own pass
1442 // to make things more modular.
1443 if (failed(ValidateOp(func))) {
1444 func.emitError() << "tfl-prepare-tf pass failed.";
1445 signalPassFailure();
1446 return;
1447 }
1448
1449 if (failed(ConvertTf2XlaOps(func, ctx))) {
1450 signalPassFailure();
1451 return;
1452 }
1453
1454 // This pattern is intended to use TFL QDQs to preserve the quantization
1455 // parameters from the TF FakeQuant ops, so it should run with the
1456 // first `applyPatternsAndFoldGreedily` call, which would otherwise remove the
1457 // TF FakeQuant ops through constant folding.
1458 patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant,
1459 PreparePerTensorFakeQuantWithMinMaxArgs>(ctx);
1460
1461 // This pattern will try to identify and optimize for dilated convolution.
1462 // e.g. patterns like "SpaceToBatchND -> Conv2D -> BatchToSpaceND" will be
1463 // replaced with a single Conv op with a dilation parameter.
1464 patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>, FusedBatchNormV3Pat,
1465 ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(ctx);
1466
1467 TFL::populateWithGenerated(ctx, patterns);
1468 // TODO(karimnosseir): Split into a separate pass, probably after
1469 // deciding on a long term plan for this optimization.
1470 // This will allow optimizing any TF_Mul->TF_Conv in the graph,
1471 // including those expanded from FusedBatchNorm. We need to do this
1472 // before converting TF_Conv to TFL_Conv.
1473 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
1474
1475 // Load the generated patterns again, so the new quantization pass-through
1476 // will be applied.
1477 TFL::populateWithGenerated(ctx, phase_2_patterns);
1478 if (unfold_batch_matmul_) {
1479 phase_2_patterns.insert<TF::ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
1480 TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(
1481 ctx);
1482 }
1483 phase_2_patterns.insert<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo,
1484 ConvertTFStridedSlice, ConvertRfftToRfft2d>(ctx);
1485 phase_2_patterns.insert<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative>(
1486 ctx, allow_bf16_and_f16_type_legalization_);
1487
1488 (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
1489 }
1490
1491 } // namespace
1492
1493 // Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
1494 std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
1495 bool unfold_batch_matmul, bool allow_bf16_type_legalization) {
1496 return std::make_unique<PrepareTFPass>(unfold_batch_matmul,
1497 allow_bf16_type_legalization);
1498 }
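
// A minimal sketch of adding this pass to a pass manager programmatically
// (`pm` is an assumed module-level mlir::PassManager created elsewhere):
//
//   pm.addNestedPass<mlir::FuncOp>(mlir::TFL::CreatePrepareTFPass(
//       /*unfold_batch_matmul=*/true,
//       /*allow_bf16_type_legalization=*/false));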
1499
1500 static PassRegistration<PrepareTFPass> pass(
1501 "tfl-prepare-tf", "Prepare TF for legalization to TensorFlow Lite dialect");
1502
1503 } // namespace TFL
1504 } // namespace mlir
1505