• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This transformation pass prepares for legalization to the TFLite dialect by
17 // converting operations in TensorFlow dialect into operations that can be
18 // legalized to TensorFlow Lite dialect with simple replacements.  The newly
19 // created operations are in the TensorFlow dialect if the operation can be
20 // represented using a TensorFlow op.  Otherwise, TensorFlow Lite dialect op is
21 // used.  For example, Conv2D in TFLite which uses OHWI data format for filters
22 // is not supported in TensorFlow because TensorFlow requires filters in the
23 // HWIO data format.
24 //
25 // Motivation to prepare for the TFLite legalization before the actual
26 // legalization is to exploit constant folding opportunities in any newly
27 // created ops by leveraging constant folding support for the TensorFlow ops.
28 // This way TFLite can be used as a serialization format only and does not
29 // require access to the TFLite runtime for optimizations as required by the
30 // TFLite team.
31 
32 #include <climits>
33 #include <cstdint>
34 
35 #include "absl/memory/memory.h"
36 #include "llvm/ADT/ArrayRef.h"
37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/ADT/StringSwitch.h"
39 #include "llvm/Support/Casting.h"
40 #include "llvm/Support/Debug.h"
41 #include "mlir/Analysis/LoopAnalysis.h"  // from @llvm-project
42 #include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
43 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
44 #include "mlir/Dialect/Quant/UniformSupport.h"  // from @llvm-project
45 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
46 #include "mlir/IR/Attributes.h"  // from @llvm-project
47 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
48 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
49 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
50 #include "mlir/Pass/Pass.h"  // from @llvm-project
51 #include "mlir/Support/LLVM.h"  // from @llvm-project
52 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
53 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
54 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
55 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
56 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
57 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
58 #include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
59 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
60 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
61 #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
62 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
63 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
64 #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
65 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
66 #include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
67 #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"
68 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
69 
70 #define DEBUG_TYPE "tf-tfl-legalization"
71 
72 namespace mlir {
73 namespace TFL {
74 //===----------------------------------------------------------------------===//
75 // The actual PrepareTF Pass.
76 //
77 // TODO(hinsu): Add and use TensorFlow dialect ops for the ops created in this
78 // pass.
79 namespace {
80 
81 // Prepare TF operations in functions for subsequent legalization.
82 class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
83  public:
84   PrepareTFPass() = default;
PrepareTFPass(const PrepareTFPass &)85   PrepareTFPass(const PrepareTFPass &) {}
PrepareTFPass(bool unfold_batch_matmul,bool allow_bf16_and_f16_type_legalization)86   explicit PrepareTFPass(bool unfold_batch_matmul,
87                          bool allow_bf16_and_f16_type_legalization) {
88     unfold_batch_matmul_ = unfold_batch_matmul;
89     allow_bf16_and_f16_type_legalization_ =
90         allow_bf16_and_f16_type_legalization;
91   }
92   void runOnFunction() override;
93 
getDependentDialects(DialectRegistry & registry) const94   void getDependentDialects(DialectRegistry &registry) const override {
95     registry.insert<mhlo::MhloDialect, quant::QuantizationDialect,
96                     TFL::TensorFlowLiteDialect>();
97   }
98 
99  private:
100   Option<bool> unfold_batch_matmul_{
101       *this, "tfl-unfold-batch-matmul",
102       llvm::cl::desc("Unfold BatchMatMul into individual MatMul ops."),
103       llvm::cl::init(true)};
104 
105   Option<bool> allow_bf16_and_f16_type_legalization_{
106       *this, "tfl-allow-bf16-and-f16-type-legalization",
107       llvm::cl::desc("Allow bf16 type legalization."), llvm::cl::init(false)};
108 };
109 
110 template <class TFFakeQuantOp>
111 struct FetchConstantMinMaxInputs {
112   using AttrType = DenseFPElementsAttr;
operator ()mlir::TFL::__anonf6df4be00111::FetchConstantMinMaxInputs113   bool operator()(TFFakeQuantOp tf_op, AttrType &min_value,
114                   AttrType &max_value) const {
115     Value min = tf_op.min(), max = tf_op.max();
116 
117     // TODO: incomplete  neither IdentityN ops
118     // nor chains of Identity* (not rare) are handled
119     if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp()))
120       min = id1.input();
121     if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp()))
122       max = id2.input();
123     if (!matchPattern(min, m_Constant(&min_value))) {
124       return false;
125     }
126     if (!matchPattern(max, m_Constant(&max_value))) {
127       return false;
128     }
129     return true;  // Succesfully matched and fetched.
130   }
131 };
132 
133 template <class TFFakeQuantOp>
134 struct FetchMinMaxAttrs {
135   using AttrType = FloatAttr;
operator ()mlir::TFL::__anonf6df4be00111::FetchMinMaxAttrs136   bool operator()(TFFakeQuantOp tf_op, AttrType &min_value,
137                   AttrType &max_value) const {
138     min_value = tf_op.minAttr();
139     max_value = tf_op.maxAttr();
140     return true;  // Succesfully matched and fetched.
141   }
142 };
143 
144 // TODO(fengliuai): move this rule to PreparePatterns.td
145 // TODO(fengliuai): reuse the quantization/tensorflow/tf_to_quant pass.
// TODO(b/140968741): propagate the sign from the command line. Currently all
// the FakeQuant ops are assumed to target UINT8, but the per-channel kernel is
// actually INT8.
// Inserts a "tfl.quantize" and "tfl.dequantize" op pair (QDQs) after the
// tf.FakeQuantWithMinMax{Vars|VarsPerChannel|Args}Op
151 // to be constant folded. Since the constant
152 // folding logic will use a "std.constant" op to replace the
153 // "tf.FakeQuantWithMinMaxVarsOp", the "tfl.quantize" op is used to preserve
154 // the quantization parameters as a TypeAttr and "tfl.dequantize" op used to
155 // convert the output type to the next op. Here are the transformations:
156 //
// input   min cst       max cst          input   min cst       max cst
//  \       |             |                \       |             |
//   \  (tf.Identity) (tf.Identity)   =>    \  (tf.Identity) (tf.Identity)
//    \     |             |                  \     |             |
//       tf.FakeQuantWithMinMaxVars       tf.FakeQuantWithMinMaxVars
//                   |                                 |
//                                                tfl.quantize
//                                                     |
//                                                tfl.dequantize
//                                                     |
// If the input is a constant, the result pattern will eventually be converted
// to
//
//            quant-emulated input
//                   |
//               tfl.quantize
//                   |
//              tfl.dequantize
//                   |
174 //                   |
175 //
176 //
177 // Warns if the (most likely unwanted, currently not quite correctly handled)
178 // case of back-to-back tf.FakeQuant occurs
179 //
180 //             tf.FakeQuant*
181 //                   |
182 //             tf.FakeQuant*
183 //
// tf.Identity / tf.IdentityN ops between the tf.FakeQuant* ops need no special
// treatment: they are already eliminated before the rewrites / check are
// applied.
187 //
188 
189 template <typename TFFakeQuantOp, bool PerAxis, class FetchMinMax>
190 struct InsertTFLQuantOpsAfterTFFakeQuantOp
191     : public OpRewritePattern<TFFakeQuantOp> {
192   using BaseType =
193       InsertTFLQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis, FetchMinMax>;
194 
InsertTFLQuantOpsAfterTFFakeQuantOpmlir::TFL::__anonf6df4be00111::InsertTFLQuantOpsAfterTFFakeQuantOp195   explicit InsertTFLQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis,
196                                                FetchMinMax>(MLIRContext *ctx)
197       : OpRewritePattern<TFFakeQuantOp>(ctx) {}
198 
199   FetchMinMax fetchMinMax;
200 
201   using FetchAttrType = typename FetchMinMax::AttrType;
matchAndRewritemlir::TFL::__anonf6df4be00111::InsertTFLQuantOpsAfterTFFakeQuantOp202   LogicalResult matchAndRewrite(TFFakeQuantOp tf_op,
203                                 PatternRewriter &rewriter) const override {
204     // We don't want to insert quantize/dequantize if the quantize op exists.
205     auto res = tf_op.outputs();
206     if (!res.hasOneUse() || isa<QuantizeOp>(*res.user_begin())) {
207       return failure();
208     }
209 
210     // Extract the min/max constant values from the operands. We also consider
211     // a special case that there are tf.Identity ops between the min/max
212     // constants and the tf.FakeQuantWithMinMaxVarsOp.
213 
214     FetchAttrType min_value, max_value;
215     if (!fetchMinMax(tf_op, min_value, max_value)) {
216       return failure();
217     }
218 
219     int quant_dim = -1;
220     if (PerAxis) {
221       // This is a special case that the quant_dim is the last dimensions.
222       quant_dim = res.getType().template cast<ShapedType>().getRank() - 1;
223     }
224     // Use the min/max from the operands and the num_bits and narrow_range
225     // attribute to create the quantization parameter for the new quantize op.
226     rewriter.setInsertionPointAfter(tf_op.getOperation());
227     IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.num_bits());
228     BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range());
229     Type res_type = tf_op.getType();
230     TypeAttr qtype = quant::GetQuantizedTypeAttr(
231         rewriter, res_type, min_value, max_value, quant_dim, num_bits,
232         narrow_range, /*is_signed=*/false);
233     if (!qtype) {
234       return failure();
235     }
236 
237     // Finally, use the quantization parameter to create the quantize and
238     // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
239     // and its users.
240     Value value = tf_op.outputs();
241     auto quantize = rewriter.create<TFL::QuantizeOp>(
242         tf_op.getLoc(), qtype.getValue(), value, qtype);
243     auto dequantize = rewriter.create<TFL::DequantizeOp>(
244         tf_op.getLoc(), res_type, quantize.output());
245     value.replaceAllUsesWith(dequantize);
246     quantize.getOperation()->replaceUsesOfWith(dequantize, value);
247 
248     return success();
249   }
250 };
251 
252 //
253 // Three instances of the rule to cover the three different types of
254 // TF::FakeQuant operators
255 //
256 using PreparePerTensorFakeQuant = InsertTFLQuantOpsAfterTFFakeQuantOp<
257     TF::FakeQuantWithMinMaxVarsOp, /*PerAxis=*/false,
258     FetchConstantMinMaxInputs<TF::FakeQuantWithMinMaxVarsOp>>;
259 
260 using PreparePerChannelFakeQuant = InsertTFLQuantOpsAfterTFFakeQuantOp<
261     TF::FakeQuantWithMinMaxVarsPerChannelOp, /*PerAxis=*/true,
262     FetchConstantMinMaxInputs<TF::FakeQuantWithMinMaxVarsPerChannelOp>>;
263 
264 using PreparePerTensorFakeQuantWithMinMaxArgs =
265     InsertTFLQuantOpsAfterTFFakeQuantOp<
266         TF::FakeQuantWithMinMaxArgsOp, /*PerAxis=*/false,
267         FetchMinMaxAttrs<TF::FakeQuantWithMinMaxArgsOp>>;
268 
269 // Transient state for preserving data from match to rewrite
270 struct ConvertTFConvOpMatchState {
271   IntegerAttr dilation_height_factor;
272   IntegerAttr dilation_width_factor;
273   StringAttr padding;
274   IntegerAttr stride_height;
275   IntegerAttr stride_width;
276 };
277 
278 // Templated class for declaring a converter from some TensorFlow convolution
279 // op into its counterpart in TensorFlow Lite.
280 //
281 // The `ConcreteType` deriving from this template must provide the following
282 // method for constructing TensorFlow Lite op:
283 //
284 //   TFL::[op] createTFLOp(ConvertTFConvOpMatchState *state,
285 //                         PatternRewriter &rewriter, Location loc,
286 //                         Type result_type, Value input,
287 //                         Value filter, Value bias) const;
288 //
289 // And also the following method for getting the dimension for bias tensor:
290 //
291 //  int64_t getBiasDim(ArrayRef<int64_t> filterShape) const;
292 template <typename ConcreteType, typename TFConvOpType>
293 class ConvertTFConvOp : public RewritePattern {
294  public:
ConvertTFConvOp(MLIRContext * context,bool allow_bf16_and_f16_type_legalization)295   ConvertTFConvOp(MLIRContext *context,
296                   bool allow_bf16_and_f16_type_legalization)
297       : RewritePattern(TFConvOpType::getOperationName(), 1, context),
298         intAttrOne(Builder(context).getI32IntegerAttr(1)),
299         allow_bf16_and_f16_type_legalization_(
300             allow_bf16_and_f16_type_legalization) {}
301 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const302   LogicalResult matchAndRewrite(Operation *op,
303                                 PatternRewriter &rewriter) const override {
304     // Assumes TensorFlow convolution op is already verified to be
305     // in valid form.
306 
307     // Match a TFConvOpType under the following conditions:
308     // * The 'T' attribute must exist and be of value DT_FLOAT.
309     // * The 'data_format' attribute must exist and be of value "NHWC".
310     // * The 'strides' attribute must exist and is of the form [1, X, Y, 1].
311     // * The 'dilations' attribute is optional, but it must be of the form
312     //   [1, X, Y, 1] if exists.
313 
314     TFConvOpType tf_op = cast<TFConvOpType>(op);
315 
316     if (!TFTypeIsFloat32Tensor(tf_op.input()) &&
317         !(allow_bf16_and_f16_type_legalization_ &&
318           TFTypeIsBFloat16OrHalfTensor(tf_op.input())))
319       return failure();
320 
321     if (!TFDataFormatIsNHWC(op)) return failure();
322 
323     IntegerAttr height, width;
324     if (!TFIntListIs1XY1(op, "strides", &height, &width)) return failure();
325 
326     ConvertTFConvOpMatchState state;
327     state.stride_height = height;
328     state.stride_width = width;
329 
330     if (TFIntListIs1XY1(op, "dilations", &height, &width)) {
331       state.dilation_height_factor = height;
332       state.dilation_width_factor = width;
333     } else {
334       // If the 'dilations' attribute is missing, we use the default value (1)
335       // for both dilation height and width factor.
336       state.dilation_height_factor = intAttrOne;
337       state.dilation_width_factor = intAttrOne;
338     }
339 
340     if (!TFPaddingIsSameOrValid(op, &state.padding)) return failure();
341 
342     // Additionally, we require the filter operand to be of 4-D tensor type so
343     // that we can extract info from the shape (e.g., for constructing bias
344     // tensor, for setting depth_multiplier attribute, etc.).
345     auto filter = tf_op.filter();
346     auto filter_type = filter.getType().template dyn_cast<RankedTensorType>();
347     if (!filter_type || filter_type.getRank() != 4 ||
348         !filter_type.hasStaticShape())
349       return failure();
350 
351     // TensorFlow convolution op only has two inputs, while the TFLite one has
352     // three, with the bias vector marked as optional. However, TOCO has a
353     // dedicated pass, EnsureBiasVectors, to create default bias vectors for all
354     // those missing. So we model TFLite convolution op as requiring three
355     // inputs to achieve the legalization task of EnsureBiasVector. this
356     // requires the filter tensor to have static shape.
357 
358     // TODO(antiagainst): also handle the case of tf.Add(tf.[op], <bias>)
359 
360     // Get a splat zero tensor with the expected dimension for the bias tensor
361     auto elem_type = filter_type.getElementType();
362     auto bias_dim = static_cast<const ConcreteType *>(this)->getBiasDim(
363         filter_type.getShape());
364     auto bias_type = RankedTensorType::get({bias_dim}, elem_type);
365     auto bias_attr = rewriter.getZeroAttr(bias_type);
366     auto bias =
367         rewriter.create<TF::ConstOp>(op->getLoc(), bias_type, bias_attr);
368 
369     auto conv_op = static_cast<const ConcreteType *>(this)->createTFLOp(
370         &state, rewriter, op->getLoc(), tf_op.getType(), tf_op.input(), filter,
371         bias);
372 
373     rewriter.replaceOp(op, conv_op.getResult());
374     return success();
375   }
376 
377   const IntegerAttr intAttrOne;
378 
379  private:
380   bool allow_bf16_and_f16_type_legalization_;
381 };
382 
383 class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
384  public:
385   using BaseType = ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp>;
386 
ConvertTFConv2D(MLIRContext * context,bool allow_bf16_type_legalization)387   ConvertTFConv2D(MLIRContext *context, bool allow_bf16_type_legalization)
388       : BaseType(context, allow_bf16_type_legalization) {}
389 
getBiasDim(ArrayRef<int64_t> filterShape) const390   int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
391     return filterShape.back();
392   }
393 
createTFLOp(ConvertTFConvOpMatchState * state,PatternRewriter & rewriter,Location loc,Type result_type,Value input,Value filter,Value bias) const394   TFL::Conv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
395                             PatternRewriter &rewriter, Location loc,
396                             Type result_type, Value input, Value filter,
397                             Value bias) const {
398     filter = legalizeFilter(rewriter, loc, filter);
399     return rewriter.create<TFL::Conv2DOp>(
400         loc, result_type, input, filter, bias,
401         /*dilation_h_factor=*/state->dilation_height_factor,
402         /*dilation_w_factor=*/state->dilation_width_factor,
403         /*fused_activation_function=*/rewriter.getStringAttr("NONE"),
404         /*padding=*/state->padding,
405         /*stride_h=*/state->stride_height,
406         /*stride_w=*/state->stride_width);
407   }
408 
409  private:
410   // Legalize the given filter by converting it from TensorFlow filter data
411   // format HWIO to TFLite Conv2D op filter data format OHWI and return Value
412   // for the converted filter.  Requires that filter is verified by the match
413   // method that it is a 4-D RankedTensorType.
legalizeFilter(PatternRewriter & rewriter,Location loc,Value filter) const414   Value legalizeFilter(PatternRewriter &rewriter, Location loc,
415                        Value filter) const {
416     // Create a constant op for HWIO to OHWI transpose permutation.
417     SmallVector<int, 4> perm = {3, 0, 1, 2};
418     auto perm_type = RankedTensorType::get({static_cast<int>(perm.size())},
419                                            rewriter.getIntegerType(32));
420     auto perm_attr =
421         DenseElementsAttr::get(perm_type, llvm::makeArrayRef<int>(perm));
422     auto perm_op = rewriter.create<TF::ConstOp>(loc, perm_type, perm_attr);
423 
424     // Create tensor type for the transpose result.
425     auto filter_type = filter.getType().cast<RankedTensorType>();
426     auto result_shape =
427         llvm::to_vector<4>(llvm::map_range(perm, [filter_type](int64_t dim) {
428           return filter_type.getDimSize(dim);
429         }));
430     auto elem_type = filter_type.getElementType();
431     auto result_type = RankedTensorType::get(result_shape, elem_type);
432 
433     return rewriter.create<TF::TransposeOp>(loc, result_type, filter, perm_op);
434   }
435 };
436 
437 class ConvertTFDepthwiseConv2dNative
438     : public ConvertTFConvOp<ConvertTFDepthwiseConv2dNative,
439                              TF::DepthwiseConv2dNativeOp> {
440  public:
441   using BaseType = ConvertTFConvOp<ConvertTFDepthwiseConv2dNative,
442                                    TF::DepthwiseConv2dNativeOp>;
443 
ConvertTFDepthwiseConv2dNative(MLIRContext * context,bool allow_bf16_type_legalization)444   ConvertTFDepthwiseConv2dNative(MLIRContext *context,
445                                  bool allow_bf16_type_legalization)
446       : BaseType(context, allow_bf16_type_legalization) {}
447 
getBiasDim(ArrayRef<int64_t> filterShape) const448   int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
449     return filterShape[2] * filterShape[3];
450   }
451 
createTFLOp(ConvertTFConvOpMatchState * state,PatternRewriter & rewriter,Location loc,Type result_type,Value input,Value filter,Value bias) const452   TFL::DepthwiseConv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
453                                      PatternRewriter &rewriter, Location loc,
454                                      Type result_type, Value input,
455                                      Value filter, Value bias) const {
456     // Compared to tfl.conv_2d, tfl.depthwise_conv_2d has an additional
457     // 'depth_multiplier' attribute. However, tf.DepthwiseConv2dNative does not
458     // have a corresponding 'depth_multiplier' attribute; the multiplier is the
459     // fourth dimension in the 4-D filter tensor. We query the multiplier from
460     // tf.DepthwiseConv2dNative and set it as the attribute value accordingly.
461     auto multiplier = filter.getType().cast<RankedTensorType>().getDimSize(3);
462 
463     filter = legalizeFilter(rewriter, loc, filter);
464     return rewriter.create<TFL::DepthwiseConv2DOp>(
465         loc, result_type, input, filter, bias,
466         /*dilation_h_factor=*/state->dilation_height_factor,
467         /*dilation_w_factor=*/state->dilation_width_factor,
468         /*fused_activation_function=*/rewriter.getStringAttr("NONE"),
469         /*padding=*/state->padding,
470         /*stride_h=*/state->stride_height,
471         /*stride_w=*/state->stride_width,
472         /*depth_multiplier=*/rewriter.getI32IntegerAttr(multiplier));
473   }
474 
475  private:
476   /// Legalize the given filter by converting it from TensorFlow filter data
477   /// format to TFLite DepthwiseConv2D op filter data format and return Value
478   /// for the converted filter.  TensorFlow filter data format is
479   /// [filter_height, filter_width, in_channels, channel_multiplier] and TFLite
480   /// filter data format is [1, filter_height, filter_width, out_channels].
481   /// Requires that filter is verified by the match method that it is a 4-D
482   /// RankedTensorType.
legalizeFilter(PatternRewriter & rewriter,Location loc,Value filter) const483   Value legalizeFilter(PatternRewriter &rewriter, Location loc,
484                        Value filter) const {
485     auto filter_type = filter.getType().cast<RankedTensorType>();
486     auto filterShape = filter_type.getShape();
487     SmallVector<int64_t, 4> result_shape = {1, filterShape[0], filterShape[1],
488                                             filterShape[2] * filterShape[3]};
489     auto elem_type = filter_type.getElementType();
490     auto result_type = RankedTensorType::get(result_shape, elem_type);
491     // TensorFlow Lite `Reshape` op only support int32 shape tensor currently.
492     auto shape_type = RankedTensorType::get({4}, rewriter.getIntegerType(32));
493     SmallVector<Attribute, 4> result_shape_data(4);
494     for (int i = 0; i < 4; ++i) {
495       result_shape_data[i] =
496           rewriter.getI32IntegerAttr(static_cast<int32_t>(result_shape[i]));
497     }
498     auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data);
499     auto shape = rewriter.create<TF::ConstOp>(loc, shape_type, shape_attr);
500 
501     return rewriter.create<TF::ReshapeOp>(loc, result_type, filter, shape);
502   }
503 };
504 
505 // StridedSlice can have complicated attributes like begin_axis_mask,
506 // end_axis_mask, ellipsis_axis_mask, new_axis_mask, shrink_axis_mask. These
507 // masks will complicate the strided_slice computation logic, we can simplify
508 // the logic by inserting a reshape op to pad the inputs so strided_slice can
509 // be easier to handle.
510 //
// So the graph may look like below:
512 //   original_input -> strided_slice -> output
513 //      (transforms)
514 //   original_input -> reshape -> strided_slice -> output
515 //
516 // And the new shape is computed based on the masks.
517 //
518 // An example for new_axis_mask. say the new_axis_mask is 9 which represents
519 // [1 0 0 1], and that means we're inserting two new axes at 0 & 3 dim, so
520 // if original shape is [2, 3], now we reshape that into [1, 2, 3, 1].
521 struct ConvertTFStridedSlice : public RewritePattern {
ConvertTFStridedSlicemlir::TFL::__anonf6df4be00111::ConvertTFStridedSlice522   explicit ConvertTFStridedSlice(MLIRContext *context)
523       : RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {}
524 
RewriteNewAxisMaskmlir::TFL::__anonf6df4be00111::ConvertTFStridedSlice525   LogicalResult RewriteNewAxisMask(Operation *op,
526                                    PatternRewriter &rewriter) const {
527     TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
528     uint64_t new_axis_mask = strided_slice_op.new_axis_mask();
529 
530     // Insert a new reshape op.
531     Value original_input = strided_slice_op.input();
532     RankedTensorType original_input_type =
533         original_input.getType().dyn_cast<RankedTensorType>();
534     if (!original_input_type) {
535       return failure();
536     }
537 
538     const ArrayRef<int64_t> &original_input_shape =
539         original_input_type.getShape();
540     SmallVector<int64_t, 4> revised_shape;
541     int index = 0;
542     const int original_input_rank = original_input_shape.size();
543     while (index < original_input_rank || new_axis_mask) {
544       if (new_axis_mask & 1) {
545         revised_shape.emplace_back(1);
546       } else {
547         revised_shape.emplace_back(original_input_shape[index++]);
548       }
549       new_axis_mask >>= 1;
550     }
551 
552     if (failed(TF::VerifyShapeOfReshapeOp(revised_shape))) return failure();
553 
554     const int dim_size = revised_shape.size();
555     Location loc = strided_slice_op.getLoc();
556     auto shape_type =
557         RankedTensorType::get({dim_size}, rewriter.getIntegerType(32));
558     SmallVector<Attribute, 4> result_shape_data(dim_size);
559     for (int i = 0; i < dim_size; ++i) {
560       result_shape_data[i] =
561           rewriter.getI32IntegerAttr(static_cast<int32_t>(revised_shape[i]));
562     }
563 
564     auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data);
565     auto shape = rewriter.create<ConstantOp>(loc, shape_type, shape_attr);
566     auto revised_output_type = RankedTensorType::get(
567         revised_shape, original_input_type.getElementType());
568     TF::ReshapeOp reshape = rewriter.create<TF::ReshapeOp>(
569         loc, revised_output_type, original_input, shape);
570 
571     // Replace the original strided_slice.
572     uint64_t revised_begin_mask = strided_slice_op.begin_mask();
573     uint64_t revised_end_mask = strided_slice_op.end_mask();
574     // Since we expand the dims, we need to apply them to the begin_mask &
575     // end_mask.
576     revised_begin_mask |= strided_slice_op.new_axis_mask();
577     revised_end_mask |= strided_slice_op.new_axis_mask();
578 
579     auto attribute_type = rewriter.getIntegerType(64);
580     rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
581         op, strided_slice_op.getType(), reshape, strided_slice_op.begin(),
582         strided_slice_op.end(), strided_slice_op.strides(),
583         rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
584         rewriter.getIntegerAttr(attribute_type, revised_end_mask),
585         rewriter.getIntegerAttr(attribute_type,
586                                 strided_slice_op.ellipsis_mask()),
587         rewriter.getI64IntegerAttr(0),
588         rewriter.getIntegerAttr(attribute_type,
589                                 strided_slice_op.shrink_axis_mask()));
590     return success();
591   }
592 
RewriteEllipsisMaskmlir::TFL::__anonf6df4be00111::ConvertTFStridedSlice593   LogicalResult RewriteEllipsisMask(Operation *op,
594                                     PatternRewriter &rewriter) const {
595     TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
596 
597     uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask();
598     uint64_t shrink_axis_mask = strided_slice_op.shrink_axis_mask();
599 
600     // Enforce operator precedence.
601     shrink_axis_mask &= ~ellipsis_mask;
602 
603     DenseIntElementsAttr begin_dense_elem_attr;
604     Value begin = strided_slice_op.begin();
605     auto begin_ranked_attr_type = begin.getType().dyn_cast<RankedTensorType>();
606     if (!begin_ranked_attr_type ||
607         !matchPattern(begin, m_Constant(&begin_dense_elem_attr))) {
608       return failure();
609     }
610 
611     DenseIntElementsAttr end_dense_elem_attr;
612     Value end = strided_slice_op.end();
613     auto end_ranked_attr_type = end.getType().dyn_cast<RankedTensorType>();
614     if (!end_ranked_attr_type ||
615         !matchPattern(end, m_Constant(&end_dense_elem_attr))) {
616       return failure();
617     }
618 
619     DenseIntElementsAttr stride_dense_elem_attr;
620     Value stride = strided_slice_op.strides();
621     auto stride_ranked_attr_type =
622         stride.getType().dyn_cast<RankedTensorType>();
623     if (!stride_ranked_attr_type ||
624         !matchPattern(stride, m_Constant(&stride_dense_elem_attr))) {
625       return failure();
626     }
627 
628     Value input = strided_slice_op.input();
629     RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
630     if (!input_type) {
631       return failure();
632     }
633     const ArrayRef<int64_t> input_shape = input_type.getShape();
634 
635     const int input_size = input_shape.size();
636 
637     RankedTensorType begin_type = begin.getType().cast<RankedTensorType>();
638     const ArrayRef<int64_t> begin_shape = begin_type.getShape();
639     const int begin_dim = begin_shape.size();
640 
641     if (begin_dim != 1) return failure();
642 
643     const int ellipsis_filled_dim_size = input_size - begin_shape[0] + 1;
644 
645     int64_t begin_mask = strided_slice_op.begin_mask();
646     int64_t end_mask = strided_slice_op.end_mask();
647     int64_t revised_begin_mask = 0;
648     int64_t revised_end_mask = 0;
649     int64_t revised_shrink_axis_mask = 0;
650 
651     SmallVector<int32_t, 4> padded_begin;
652     SmallVector<int32_t, 4> padded_end;
653     SmallVector<int32_t, 4> padded_stride;
654 
655     // Before the ellipsis.
656     int index = 0;
657     int new_index = 0;
658     while (((ellipsis_mask >> index) & 1) == 0) {
659       padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
660       padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
661       padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
662       if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
663       if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
664       if ((shrink_axis_mask >> index) & 1)
665         revised_shrink_axis_mask |= (1 << new_index);
666       ++index;
667       ++new_index;
668     }
669 
670     // Ellipsis.
671     for (; new_index < index + ellipsis_filled_dim_size; ++new_index) {
672       revised_begin_mask |= (1 << new_index);
673       revised_end_mask |= (1 << new_index);
674 
675       // Mimic the begin/end/strides mask behavior.
676       padded_begin.push_back(0);
677       padded_end.push_back(0);
678       padded_stride.push_back(1);
679     }
680 
681     // Account for ellipsis mask.
682     ++index;
683 
684     // After the ellipsis.
685     for (; index < begin_shape[0];) {
686       padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
687       padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
688       padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
689 
690       if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
691       if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
692       if ((shrink_axis_mask >> index) & 1)
693         revised_shrink_axis_mask |= (1 << new_index);
694 
695       ++index;
696       ++new_index;
697     }
698 
699     auto attribute_type = rewriter.getIntegerType(64);
700 
701     int full_dim_count = padded_begin.size();
702     auto type =
703         RankedTensorType::get({full_dim_count}, rewriter.getIntegerType(32));
704 
705     auto begin_attr = DenseElementsAttr::get<int32_t>(type, padded_begin);
706     auto begin_op = rewriter.create<ConstantOp>(op->getLoc(), type, begin_attr);
707     auto end_attr = DenseElementsAttr::get<int32_t>(type, padded_end);
708     auto end_op = rewriter.create<ConstantOp>(op->getLoc(), type, end_attr);
709     auto stride_attr = DenseElementsAttr::get<int32_t>(type, padded_stride);
710     auto stride_op =
711         rewriter.create<ConstantOp>(op->getLoc(), type, stride_attr);
712 
713     rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
714         op, strided_slice_op.getType(), input, begin_op.getResult(),
715         end_op.getResult(), stride_op.getResult(),
716         rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
717         rewriter.getIntegerAttr(attribute_type, revised_end_mask),
718         /*ellipsis_mask=*/rewriter.getI64IntegerAttr(0),
719         rewriter.getIntegerAttr(attribute_type,
720                                 strided_slice_op.new_axis_mask()),
721         rewriter.getIntegerAttr(attribute_type, revised_shrink_axis_mask));
722     return success();
723   }
724 
PadStridedSliceAttributeArraymlir::TFL::__anonf6df4be00111::ConvertTFStridedSlice725   void PadStridedSliceAttributeArray(DenseIntElementsAttr dense_elem_attr,
726                                      SmallVectorImpl<int32_t> &val,
727                                      SmallVectorImpl<int32_t> &padded_val,
728                                      ArrayRef<int32_t> padding_val,
729                                      int *mask) const {
730     for (const auto &idx : dense_elem_attr.getIntValues()) {
731       val.push_back(idx.getSExtValue());
732       padded_val.push_back(idx.getSExtValue());
733     }
734     int attr_dim_count = val.size();
735     int full_dim_count = padding_val.size();
736     for (int i = attr_dim_count; i < full_dim_count; ++i) {
737       padded_val.push_back(padding_val[i]);
738       if (mask) *mask |= 1 << i;
739     }
740   }
741 
matchAndRewritemlir::TFL::__anonf6df4be00111::ConvertTFStridedSlice742   LogicalResult matchAndRewrite(Operation *op,
743                                 PatternRewriter &rewriter) const override {
744     TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
745 
746     // Handle new axis mask.
747     if (strided_slice_op.new_axis_mask() != 0) {
748       // We currently don't handle simultaneous shrink_ and new_axis masks.
749       if (!strided_slice_op.shrink_axis_mask()) {
750         return RewriteNewAxisMask(strided_slice_op, rewriter);
751       }
752     }
753 
754     // Handle ellipsis mask.
755     if (strided_slice_op.ellipsis_mask() != 0) {
756       return RewriteEllipsisMask(strided_slice_op, rewriter);
757     }
758 
759     auto ranked_input_type =
760         strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
761     if (!ranked_input_type) {
762       return failure();
763     }
764 
765     auto begin_attr = strided_slice_op.begin();
766     auto end_attr = strided_slice_op.end();
767     auto strides_attr = strided_slice_op.strides();
768 
769     auto begin_attr_type = begin_attr.getType().dyn_cast<RankedTensorType>();
770     auto end_attr_type = end_attr.getType().dyn_cast<RankedTensorType>();
771     auto strides_attr_type =
772         strides_attr.getType().dyn_cast<RankedTensorType>();
773 
774     DenseIntElementsAttr begin_elem_attr;
775     DenseIntElementsAttr end_elem_attr;
776     DenseIntElementsAttr strides_elem_attr;
777 
778     if (!begin_attr_type ||
779         !matchPattern(begin_attr, m_Constant(&begin_elem_attr))) {
780       return failure();
781     }
782     if (!end_attr_type || !matchPattern(end_attr, m_Constant(&end_elem_attr))) {
783       return failure();
784     }
785     if (!strides_attr_type ||
786         !matchPattern(strides_attr, m_Constant(&strides_elem_attr))) {
787       return failure();
788     }
789 
790     SmallVector<int32_t, 4> begin, end, strides;
791     SmallVector<int32_t, 4> padded_begin, padded_end, padded_strides;
792 
793     int num_input_dims = ranked_input_type.getRank();
794     SmallVector<int32_t, 4> padding_begin(num_input_dims, 0);
795     auto input_shape = ranked_input_type.getShape();
796     SmallVector<int32_t, 4> padding_end(input_shape.begin(), input_shape.end());
797     SmallVector<int32_t, 4> padding_strides(num_input_dims, 1);
798 
799     int begin_mask = strided_slice_op.begin_mask();
800     int end_mask = strided_slice_op.end_mask();
801 
802     PadStridedSliceAttributeArray(begin_elem_attr, begin, padded_begin,
803                                   padding_begin, &begin_mask);
804     PadStridedSliceAttributeArray(end_elem_attr, end, padded_end, padding_end,
805                                   &end_mask);
806     PadStridedSliceAttributeArray(strides_elem_attr, strides, padded_strides,
807                                   padding_strides, nullptr);
808 
809     if (begin == padded_begin && end == padded_end &&
810         strides == padded_strides &&
811         begin_mask == strided_slice_op.begin_mask() &&
812         end_mask == strided_slice_op.end_mask()) {
813       return failure();
814     }
815 
816     auto begin_end_type =
817         RankedTensorType::get({num_input_dims}, rewriter.getIntegerType(32));
818     auto new_begin_attr = rewriter.create<ConstantOp>(
819         op->getLoc(), begin_end_type,
820         DenseElementsAttr::get<int32_t>(begin_end_type, padded_begin));
821     auto new_end_attr = rewriter.create<ConstantOp>(
822         op->getLoc(), begin_end_type,
823         DenseElementsAttr::get<int32_t>(begin_end_type, padded_end));
824     auto strides_type =
825         RankedTensorType::get({static_cast<long>(padded_strides.size())},
826                               rewriter.getIntegerType(32));
827     auto new_strides_attr = rewriter.create<ConstantOp>(
828         op->getLoc(), strides_type,
829         DenseElementsAttr::get<int32_t>(strides_type, padded_strides));
830 
831     auto attribute_type = rewriter.getIntegerType(64);
832     rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
833         op, strided_slice_op.output().getType(), strided_slice_op.input(),
834         new_begin_attr, new_end_attr, new_strides_attr,
835         rewriter.getIntegerAttr(attribute_type, begin_mask),
836         rewriter.getIntegerAttr(attribute_type, end_mask),
837         rewriter.getIntegerAttr(attribute_type,
838                                 strided_slice_op.ellipsis_mask()),
839         rewriter.getIntegerAttr(attribute_type,
840                                 strided_slice_op.new_axis_mask()),
841         rewriter.getIntegerAttr(attribute_type,
842                                 strided_slice_op.shrink_axis_mask()));
843 
844     return success();
845   }
846 };
847 
848 struct ConvertTFBroadcastTo : public RewritePattern {
ConvertTFBroadcastTomlir::TFL::__anonf6df4be00111::ConvertTFBroadcastTo849   explicit ConvertTFBroadcastTo(MLIRContext *context)
850       : RewritePattern(TF::BroadcastToOp::getOperationName(), 1, context) {}
851 
matchAndRewritemlir::TFL::__anonf6df4be00111::ConvertTFBroadcastTo852   LogicalResult matchAndRewrite(Operation *op,
853                                 PatternRewriter &rewriter) const override {
854     auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
855     auto input_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
856     auto output_type = tf_broadcast_to_op.output().getType().cast<ShapedType>();
857     auto shape_type = tf_broadcast_to_op.shape().getType().cast<ShapedType>();
858     Type element_type = input_type.getElementType();
859 
860     // Allow lowering when low dimension inputs are given and its type is F32 or
861     // I32.
862     if (!((output_type.hasRank() && output_type.getRank() <= 4) ||
863           (shape_type.hasStaticShape() && shape_type.getRank() == 1 &&
864            shape_type.getDimSize(0) <= 4)))
865       return failure();
866 
867     if (!(element_type.isa<BFloat16Type, Float32Type>() ||
868           element_type.isInteger(32) || element_type.isInteger(16)))
869       return failure();
870 
871     auto status_or_const_op =
872         CreateConstOpWithSingleValue(&rewriter, op->getLoc(), input_type, 1);
873     if (!status_or_const_op.ok()) {
874       return failure();
875     }
876 
877     auto tf_fill_op = rewriter.create<TF::FillOp>(
878         op->getLoc(), output_type, tf_broadcast_to_op.shape(),
879         status_or_const_op.ValueOrDie());
880 
881     auto mul_op = rewriter.create<TF::MulOp>(
882         op->getLoc(), output_type, tf_broadcast_to_op.input(), tf_fill_op);
883     rewriter.replaceOp(op, mul_op.getResult());
884     return success();
885   }
886 };
887 
888 // The below pattern is equivalent to the DRR rule below
889 // The checks are dependent on generated values, so we can't add
890 // the checks on intermediate values, ideally we should find equivalent
891 // checks that guarantees the resultant ops are valid.
892 // The extra conditions are the broadcasting conditions.
893 //
// This pattern lowers FusedBatchNormV3 to arithmetic ops.
895 // Specifically, performs the following calculation:
896 //
897 //   (x - mean) * scale / sqrt(variance + epsilon) + offset
898 //
899 // Let multiplier = scale / sqrt(variance + epsilon),
900 // to compute
901 //   (x - mean) * scale / sqrt(variance + epsilon) + offset,
902 // is then to compute
903 //   (x * multiplier) + (offset - mean * multiplier).
904 //
905 // def : Pattern<
906 //     (TF_FusedBatchNormV3Op:$root
907 //         $x, $scale, $offset, $mean, $variance,
908 //         F32Attr:$epsilon, $exponential_avg_factor,
909 //         $data_format, FalseBoolAttr:$is_training),
910 //     [(TF_AddOp
911 //         (TF_MulOp
912 //             $x,
913 //             (TF_MulOp:$multiplier
914 //                 $scale,
915 //                 (TF_RsqrtOp
916 //                     (TF_AddOp $variance,
917 //                               (TF_ConstOp $epsilon))))),
918 //         (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))),
919 //    // We already guaranteed that the last five results have no use so it does
920 //    // not matter what value we provide here for replacement.
921 //      /*batch_mean=*/(replaceWithValue $x),
922 //      /*batch_variance=*/(replaceWithValue $x),
923 //      /*reserve_space_1=*/(replaceWithValue $x),
924 //      /*reserve_space_2=*/(replaceWithValue $x),
925 //      /*reserve_space_3=*/(replaceWithValue $x)],
926 //     [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
927 //      (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
928 //      (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>;
929 //
930 // When is_training is set to true, the given variance and mean are not used.
931 // In above calculation, they are replaced by new values. These new mean and
932 // variance are calculated as following:
933 // rest_size = shape(x)[0] * shape(x)[1] * shape(x)[2]
934 // new_mean = sum(x, axis=[0, 1, 2]) / rest_size
935 // new_variance = sum(squared_difference(x, new_mean), axis=[0, 1, 2])
936 //                / rest_size
937 //
// The DRR rule for the case where is_training equals true is as follows:
939 // def : Pattern<
940 //     (TF_FusedBatchNormV3Op:$root
941 //         $x, $scale, $offset, $mean, $variance,
942 //         F32Attr:$epsilon, $exponential_avg_factor,
//         $data_format, TrueBoolAttr:$is_training),
944 //     [(TF_AddOp
945 //         (TF_MulOp
946 //             $x,
947 //             (TF_MulOp:$multiplier
948 //                 $scale,
949 //                 (TF_RsqrtOp
950 //                     (TF_AddOp
951 //                         (TF_DivOp:$new_variance
952 //                             (TF_SumOp
953 //                                 (TF_SquaredDifferenceOp $x, $new_mean),
954 //                                 (TF_ConstOp [0,1,2])),
955 //                             $rest_size),
956 //                         (TF_ConstOp $epsilon))))),
957 //         (TF_SubOp
958 //             $offset,
959 //             (TF_MulOp
960 //                 (TF_DivOp:$new_mean
961 //                     (TF_SumOp $x, (TF_ConstOp [0,1,2])),
962 //                     (TF_ProdOp:$rest_size
963 //                         (TF_SliceOp
964 //                             (TF_ShapeOp $x),
965 //                             (TF_ConstOp 0),
966 //                             (TF_ConstOp 3)))),
967 //                 $multiplier))),
968 //    // We already guaranteed that the last five results have no use so it does
969 //    // not matter what value we provide here for replacement.
970 //      /*batch_mean=*/(replaceWithValue $x),
971 //      /*batch_variance=*/(replaceWithValue $x),
972 //      /*reserve_space_1=*/(replaceWithValue $x),
973 //      /*reserve_space_2=*/(replaceWithValue $x),
974 //      /*reserve_space_3=*/(replaceWithValue $x)],
975 //     [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
976 //      (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
977 //      (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>;
978 
979 struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
FusedBatchNormV3Patmlir::TFL::__anonf6df4be00111::FusedBatchNormV3Pat980   explicit FusedBatchNormV3Pat(::mlir::MLIRContext *context)
981       : ::mlir::RewritePattern(
982             "tf.FusedBatchNormV3",
983             {"tf.Add", "tf.Const", "tf.Mul", "tf.Rsqrt", "tf.Sub"}, 1,
984             context) {}
985 
matchAndRewritemlir::TFL::__anonf6df4be00111::FusedBatchNormV3Pat986   ::mlir::LogicalResult matchAndRewrite(
987       ::mlir::Operation *fused_batch_norm,
988       ::mlir::PatternRewriter &rewriter) const override {
989     // Variables for capturing values and attributes used for creating ops
990     Operation::operand_range mean(fused_batch_norm->getOperands());
991     ::mlir::FloatAttr exponential_avg_factor;
992     ::mlir::TF::FusedBatchNormV3Op root;
993     Operation::operand_range offset(fused_batch_norm->getOperands());
994     Operation::operand_range x(fused_batch_norm->getOperands());
995     Operation::operand_range scale(fused_batch_norm->getOperands());
996     Operation::operand_range variance(fused_batch_norm->getOperands());
997     ::mlir::FloatAttr epsilon;
998     ::mlir::BoolAttr is_training;
999 
1000     // Match
1001     auto fused_batch_norm_op =
1002         dyn_cast_or_null<::mlir::TF::FusedBatchNormV3Op>(fused_batch_norm);
1003     root = fused_batch_norm_op;
1004     x = fused_batch_norm_op.getODSOperands(0);
1005     scale = fused_batch_norm_op.getODSOperands(1);
1006     offset = fused_batch_norm_op.getODSOperands(2);
1007     mean = fused_batch_norm_op.getODSOperands(3);
1008     variance = fused_batch_norm_op.getODSOperands(4);
1009 
1010     ::mlir::Value mean_value = (*mean.begin());
1011     ::mlir::Value variance_value = (*variance.begin());
1012 
1013     if (!TFTypeIsFloat32Tensor(fused_batch_norm_op.x())) return failure();
1014 
1015     {
1016       epsilon =
1017           fused_batch_norm_op->getAttrOfType<::mlir::FloatAttr>("epsilon");
1018       if (!epsilon)
1019         epsilon = rewriter.getFloatAttr(rewriter.getF32Type(), 0.0001f);
1020 
1021       if (!(((epsilon.isa<::mlir::FloatAttr>())) &&
1022             ((epsilon.cast<::mlir::FloatAttr>().getType().isF32())))) {
1023         return rewriter.notifyMatchFailure(
1024             fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1025               diag << "op 'tf.FusedBatchNormV3' attribute 'epsilon' failed to "
1026                       "satisfy constraint: 32-bit float attribute";
1027             });
1028       }
1029     }
1030     {
1031       exponential_avg_factor =
1032           fused_batch_norm_op->getAttrOfType<::mlir::FloatAttr>(
1033               "exponential_avg_factor");
1034       if (!exponential_avg_factor)
1035         exponential_avg_factor =
1036             rewriter.getFloatAttr(rewriter.getF32Type(), 1.0f);
1037     }
1038     if (!TFDataFormatIsNHWC(fused_batch_norm_op) &&
1039         !TFDataFormatIsNDHWC(fused_batch_norm_op))
1040       return failure();
1041 
1042     if (!(((*root.getODSResults(1).begin()).use_empty()))) {
1043       return rewriter.notifyMatchFailure(
1044           fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1045             diag << "entities '' failed to satisfy constraint: has no use";
1046           });
1047     }
1048 
1049     if (!(((*root.getODSResults(2).begin()).use_empty()))) {
1050       return rewriter.notifyMatchFailure(
1051           fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1052             diag << "entities '' failed to satisfy constraint: has no use";
1053           });
1054     }
1055 
1056     if (!(((*root.getODSResults(3).begin()).use_empty()))) {
1057       return rewriter.notifyMatchFailure(
1058           fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1059             diag << "entities '' failed to satisfy constraint: has no use";
1060           });
1061     }
1062 
1063     if (!(((*root.getODSResults(4).begin()).use_empty()))) {
1064       return rewriter.notifyMatchFailure(
1065           fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1066             diag << "entities '' failed to satisfy constraint: has no use";
1067           });
1068     }
1069 
1070     if (!(((*root.getODSResults(5).begin()).use_empty()))) {
1071       return rewriter.notifyMatchFailure(
1072           fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1073             diag << "entities '' failed to satisfy constraint: has no use";
1074           });
1075     }
1076 
1077     is_training =
1078         fused_batch_norm_op->getAttrOfType<::mlir::BoolAttr>("is_training");
1079     auto odsLoc = rewriter.getFusedLoc({fused_batch_norm->getLoc()});
1080 
1081     // We need to make sure input and output shapes are compatible.
1082     {
1083       int64_t last_dim = -1;
1084       auto is_last_dim_compatible = [](const Value &v, int64_t &last_dim) {
1085         auto v_type = v.getType().dyn_cast_or_null<RankedTensorType>();
1086         if (!v_type) return true;
1087         int64_t v_last_dim = v_type.getDimSize(v_type.getRank() - 1);
1088         if (v_last_dim == -1) return true;
1089         if (last_dim != -1 && v_last_dim != last_dim) return false;
1090         last_dim = v_last_dim;
1091         return true;
1092       };
1093 
1094       if (!is_last_dim_compatible(*x.begin(), last_dim) ||
1095           !is_last_dim_compatible(*scale.begin(), last_dim) ||
1096           !is_last_dim_compatible(*offset.begin(), last_dim)) {
1097         return rewriter.notifyMatchFailure(
1098             fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1099               diag << "Shapes of scale and offset should be 1D and "
1100                       "compatible with x";
1101             });
1102       }
1103 
1104       if (!is_training.getValue()) {
1105         if (!is_last_dim_compatible(mean_value, last_dim) ||
1106             !is_last_dim_compatible(variance_value, last_dim)) {
1107           return rewriter.notifyMatchFailure(
1108               fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1109                 diag << "Shapes of mean and variance should be 1D and "
1110                         "compatible with x";
1111               });
1112         }
1113       }
1114 
1115       // Check if output shape and input shape are compatible.
1116       auto x_type = (*x.begin()).getType();
1117       auto y_type = (*root.getODSResults(0).begin()).getType();
1118       if (!OpTrait::util::getBroadcastedType(x_type, y_type)) {
1119         return rewriter.notifyMatchFailure(
1120             fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1121               diag << "Shapes of x and the first output should be compatible";
1122             });
1123       }
1124     }
1125 
1126     // For training, mean and variance is calculated from input values.
1127     if (is_training.getValue()) {
1128       auto input_type = fused_batch_norm_op.x()
1129                             .getType()
1130                             .dyn_cast_or_null<RankedTensorType>();
1131       if (!input_type || input_type.getRank() != 4 ||
1132           !input_type.hasStaticShape()) {
1133         return rewriter.notifyMatchFailure(
1134             fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
1135               diag << "op 'tf.FusedBatchNormV3' that has 'is_training' equals "
1136                       "True is only supported with static input shape";
1137             });
1138       }
1139 
1140       ::mlir::TF::ConstOp reduce_dim_op;
1141       {
1142         auto reduce_dim_type =
1143             ::mlir::RankedTensorType::get({3}, rewriter.getIntegerType(64));
1144         ::mlir::SmallVector<int64_t, 3> reduce_dim_values = {0, 1, 2};
1145         reduce_dim_op = rewriter.create<TF::ConstOp>(
1146             odsLoc, ::mlir::DenseIntElementsAttr::get(reduce_dim_type,
1147                                                       reduce_dim_values));
1148       }
1149 
1150       ::mlir::TF::ConstOp rest_size_inv_op;
1151       {
1152         int64_t rest_size = input_type.getDimSize(0) *
1153                             input_type.getDimSize(1) * input_type.getDimSize(2);
1154         auto rest_size_inv_type =
1155             ::mlir::RankedTensorType::get({1}, rewriter.getF32Type());
1156         auto rest_size_inv_attr = ::mlir::DenseFPElementsAttr::get(
1157             rest_size_inv_type, {1.0f / rest_size});
1158         rest_size_inv_op =
1159             rewriter.create<::mlir::TF::ConstOp>(odsLoc, rest_size_inv_attr);
1160       }
1161 
1162       ::mlir::TF::SumOp sum_op_1;
1163       {
1164         ::mlir::Value x_value = (*x.begin());
1165         sum_op_1 = rewriter.create<TF::SumOp>(
1166             odsLoc, x_value, reduce_dim_op,
1167             /*keep_dims=*/rewriter.getBoolAttr(false));
1168       }
1169 
1170       ::mlir::TF::MulOp mul_op_1;
1171       {
1172         ::mlir::Value tblgen_value_0 = (*sum_op_1.getODSResults(0).begin());
1173         ::mlir::Value tblgen_value_1 =
1174             (*rest_size_inv_op.getODSResults(0).begin());
1175         mul_op_1 = rewriter.create<::mlir::TF::MulOp>(odsLoc, tblgen_value_0,
1176                                                       tblgen_value_1);
1177       }
1178 
1179       ::mlir::TF::SquaredDifferenceOp square_diff_op;
1180       {
1181         ::mlir::Value tblgen_value_0 = (*x.begin());
1182         ::mlir::Value tblgen_value_1 = (*mul_op_1.getODSResults(0).begin());
1183         // If x has shape of [b, h, w, c], the result of mul_op_1 will have
1184         // shape of [c]. Therefore, their shapes are always compatible.
1185         square_diff_op = rewriter.create<::mlir::TF::SquaredDifferenceOp>(
1186             odsLoc, tblgen_value_0, tblgen_value_1);
1187       }
1188 
1189       ::mlir::TF::SumOp sum_op_2;
1190       {
1191         ::mlir::Value input_value = (*square_diff_op.getODSResults(0).begin());
1192         sum_op_2 = rewriter.create<TF::SumOp>(
1193             odsLoc, input_value, reduce_dim_op,
1194             /*keep_dims=*/rewriter.getBoolAttr(false));
1195       }
1196 
1197       ::mlir::TF::MulOp mul_op_2;
1198       {
1199         ::mlir::Value tblgen_value_0 = (*sum_op_2.getODSResults(0).begin());
1200         ::mlir::Value tblgen_value_1 =
1201             (*rest_size_inv_op.getODSResults(0).begin());
1202         mul_op_2 = rewriter.create<::mlir::TF::MulOp>(odsLoc, tblgen_value_0,
1203                                                       tblgen_value_1);
1204       }
1205 
1206       mean_value = (*mul_op_1.getODSResults(0).begin());
1207       variance_value = (*mul_op_2.getODSResults(0).begin());
1208     }  // End is_training equals true if.
1209 
1210     ::llvm::SmallVector<::mlir::Value, 4> replace_values;
1211     ::mlir::TF::ConstOp epsilon_const_op;
1212     {
1213       epsilon_const_op =
1214           rewriter.create<::mlir::TF::ConstOp>(odsLoc,
1215                                                /*value=*/epsilon);
1216     }
1217     ::mlir::TF::AddOp add_op_1;
1218     {
1219       ::mlir::Value epsilon_value =
1220           (*epsilon_const_op.getODSResults(0).begin());
1221       // Multiplying with a constant, no need to check broadcastibility.
1222       add_op_1 = rewriter.create<::mlir::TF::AddOp>(odsLoc,
1223                                                     /*x=*/variance_value,
1224                                                     /*y=*/epsilon_value);
1225     }
1226     ::mlir::TF::RsqrtOp rsqrt_op;
1227     {
1228       ::mlir::SmallVector<::mlir::Value, 4> tblgen_values;
1229       ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs;
1230       tblgen_values.push_back((*add_op_1.getODSResults(0).begin()));
1231       rsqrt_op = rewriter.create<::mlir::TF::RsqrtOp>(odsLoc, tblgen_values,
1232                                                       tblgen_attrs);
1233     }
1234     ::mlir::TF::MulOp multiplier;
1235     {
1236       ::mlir::Value tblgen_value_0 = (*scale.begin());
1237       ::mlir::Value tblgen_value_1 = (*rsqrt_op.getODSResults(0).begin());
1238       multiplier = rewriter.create<::mlir::TF::MulOp>(odsLoc,
1239                                                       /*x=*/tblgen_value_0,
1240                                                       /*y=*/tblgen_value_1);
1241     }
1242     ::mlir::TF::MulOp mul_op_1;
1243     {
1244       ::mlir::Value tblgen_value_0 = (*x.begin());
1245       ::mlir::Value tblgen_value_1 = (*multiplier.getODSResults(0).begin());
1246       mul_op_1 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
1247                                                     /*x=*/tblgen_value_0,
1248                                                     /*y=*/tblgen_value_1);
1249     }
1250     ::mlir::TF::MulOp mul_op_2;
1251     {
1252       ::mlir::Value multiplier_value = (*multiplier.getODSResults(0).begin());
1253       mul_op_2 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
1254                                                     /*x=*/mean_value,
1255                                                     /*y=*/multiplier_value);
1256     }
1257     ::mlir::TF::SubOp sub_op;
1258     {
1259       ::mlir::Value tblgen_value_0 = (*offset.begin());
1260       ::mlir::Value tblgen_value_1 = (*mul_op_2.getODSResults(0).begin());
1261       sub_op = rewriter.create<::mlir::TF::SubOp>(odsLoc,
1262                                                   /*x=*/tblgen_value_0,
1263                                                   /*y=*/tblgen_value_1);
1264     }
1265     ::mlir::TF::AddOp add_op_2;
1266     {
1267       ::mlir::SmallVector<::mlir::Value, 4> tblgen_values;
1268       ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs;
1269       tblgen_values.push_back((*mul_op_1.getODSResults(0).begin()));
1270       tblgen_values.push_back((*sub_op.getODSResults(0).begin()));
1271       ::mlir::SmallVector<::mlir::Type, 4> tblgen_types;
1272       for (auto v : fused_batch_norm_op.getODSResults(0)) {
1273         tblgen_types.push_back(v.getType());
1274       }
1275       add_op_2 = rewriter.create<::mlir::TF::AddOp>(
1276           odsLoc, tblgen_types, tblgen_values, tblgen_attrs);
1277     }
1278     for (auto v :
1279          ::llvm::SmallVector<::mlir::Value, 4>{add_op_2.getODSResults(0)}) {
1280       replace_values.push_back(v);
1281     }
1282     for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1283       replace_values.push_back(v);
1284     }
1285     for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1286       replace_values.push_back(v);
1287     }
1288     for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1289       replace_values.push_back(v);
1290     }
1291     for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1292       replace_values.push_back(v);
1293     }
1294     for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1295       replace_values.push_back(v);
1296     }
1297     rewriter.replaceOp(fused_batch_norm, replace_values);
1298     return success();
1299   };
1300 };
1301 
1302 #include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
1303 
1304 // Returns success if all the operations in the `op`'s regions including `op`
1305 // itself are legal in a TFLite pipeline.
ValidateOp(Operation * op)1306 LogicalResult ValidateOp(Operation *op) {
1307   bool has_illegal_ops = false;
1308   op->walk([&](Operation *op) {
1309     if (isa<TF::VariableV2Op>(op)) {
1310       has_illegal_ops = true;
1311       op->emitOpError() << "is illegal in a TFLite pipeline";
1312     }
1313   });
1314 
1315   return failure(has_illegal_ops);
1316 }
1317 
1318 // Converts a set of TF2XLA ops into pure TF ops for future legalizations as
1319 // TF2XLA ops aren't supported by later stages.
ConvertTf2XlaOps(FuncOp func,MLIRContext * context)1320 LogicalResult ConvertTf2XlaOps(FuncOp func, MLIRContext *context) {
1321   ConversionTarget target(*context);
1322   target.addLegalDialect<StandardOpsDialect>();
1323   target.addLegalDialect<TF::TensorFlowDialect>();
1324   target.addLegalOp<ModuleOp>();
1325   target.addLegalOp<FuncOp>();
1326   target.addIllegalOp<TF::XlaConvOp>();
1327   target.addIllegalOp<TF::XlaGatherOp>();
1328 
1329   OwningRewritePatternList patterns;
1330   mhlo::PopulateLegalizeTfWithTf2XlaPatterns("XLA_CPU_JIT", patterns);
1331   mhlo::PopulateLegalizeTfPatterns(context, &patterns);
1332   TF::PopulateLegalizeHloToTfPatterns(&patterns, context);
1333   mhlo::GatherOp::getCanonicalizationPatterns(patterns, context);
1334 
1335   return applyPartialConversion(func, target, std::move(patterns));
1336 }
1337 
1338 // Convert rfft to rfft2d.
1339 // The transformation pattern looks like below:
1340 //
1341 //    input     fft_len
1342 //     \      /
1343 //     rfft
1344 //
1345 //     ||
1346 //     \/
1347 //
1348 //   input       fft_len
1349 //    \            /
1350 //   expand_dim    concat with [1] at the front
1351 //      \         /
1352 //     rfft_2d
1353 //       |
1354 //     squeeze
1355 struct ConvertRfftToRfft2d : public RewritePattern {
ConvertRfftToRfft2dmlir::TFL::__anonf6df4be00111::ConvertRfftToRfft2d1356   explicit ConvertRfftToRfft2d(MLIRContext *context)
1357       : RewritePattern(TF::RFFTOp::getOperationName(), 1, context) {}
1358 
matchAndRewritemlir::TFL::__anonf6df4be00111::ConvertRfftToRfft2d1359   LogicalResult matchAndRewrite(Operation *op,
1360                                 PatternRewriter &rewriter) const override {
1361     auto rfft_op = dyn_cast<TF::RFFTOp>(op);
1362 
1363     auto input = rfft_op.input();
1364     auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
1365     if (!input_type) return failure();
1366     auto fft_len = rfft_op.fft_length();
1367     auto fft_len_type = fft_len.getType().dyn_cast_or_null<ShapedType>();
1368     if (!fft_len_type) return failure();
1369 
1370     auto output_type =
1371         rfft_op.getResult().getType().dyn_cast_or_null<RankedTensorType>();
1372     if (!output_type) return failure();
1373 
1374     // Expanded inputs.
1375     // Insert at -2 location.
1376     auto one_ele_type =
1377         mlir::RankedTensorType::get({1}, rewriter.getIntegerType(32));
1378     auto minus_two = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(),
1379                                                   one_ele_type, -2);
1380 
1381     SmallVector<int64_t, 4> expanded_input_shape;
1382     SmallVector<int64_t, 4> expanded_output_shape;
1383     int expanded_rank = input_type.getRank() + 1;
1384     int r = 0;
1385     for (int i = 0; i < expanded_rank; ++i) {
1386       if (i == expanded_rank - 2) {
1387         expanded_input_shape.push_back(1);
1388         expanded_output_shape.push_back(1);
1389       } else {
1390         expanded_input_shape.push_back(input_type.getDimSize(r));
1391         expanded_output_shape.push_back(output_type.getDimSize(r));
1392         r++;
1393       }
1394     }
1395 
1396     auto expaned_input_type = mlir::RankedTensorType::get(
1397         expanded_input_shape, input_type.getElementType());
1398     TF::ExpandDimsOp expanded_input = rewriter.create<TF::ExpandDimsOp>(
1399         rfft_op.getLoc(), expaned_input_type, input, minus_two->getResult());
1400 
1401     // Expanded fft_len.
1402     auto one_attr = mlir::DenseIntElementsAttr::get(one_ele_type, {1});
1403 
1404     auto one = rewriter.create<TF::ConstOp>(rfft_op.getLoc(), one_attr);
1405 
1406     auto zero = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(),
1407                                              one_ele_type, 0);
1408 
1409     auto expanded_fft_len_type =
1410         mlir::RankedTensorType::get({2}, fft_len_type.getElementType());
1411 
1412     TF::ConcatV2Op expanded_fft_len = rewriter.create<TF::ConcatV2Op>(
1413         rfft_op.getLoc(), expanded_fft_len_type,
1414         SmallVector<Value, 2>({one.getResult(), fft_len}), zero->getResult());
1415 
1416     // Insert the rfft_2d.
1417     auto rfft2d_out_type = mlir::RankedTensorType::get(
1418         expanded_output_shape, output_type.getElementType());
1419     TF::RFFT2DOp rfft2d = rewriter.create<TF::RFFT2DOp>(
1420         rfft_op.getLoc(), rfft2d_out_type, expanded_input.getResult(),
1421         expanded_fft_len.getResult());
1422 
1423     // Insert the squeeze op.
1424     auto squeeze_dim = rewriter.getI64ArrayAttr({-2});
1425     TF::SqueezeOp squeeze = rewriter.create<TF::SqueezeOp>(
1426         rfft_op.getLoc(), output_type, rfft2d.getResult(), squeeze_dim);
1427 
1428     rewriter.replaceOp(op, squeeze.getResult());
1429 
1430     return success();
1431   }
1432 };
1433 
runOnFunction()1434 void PrepareTFPass::runOnFunction() {
1435   OwningRewritePatternList patterns, phase_2_patterns;
1436   auto func = getFunction();
1437   MLIRContext *ctx = &getContext();
1438 
1439   // Check illegal ops in a TFLite pipeline (e.g. trainning only ops) , since
1440   // PrepareTFPass is the very first TFLite pass in the pipeline.
1441   // TODO(jingpu): It might be better to split this check into its own pass
1442   // to make things more modular.
1443   if (failed(ValidateOp(func))) {
1444     func.emitError() << "tfl-prepare-tf pass failed.";
1445     signalPassFailure();
1446     return;
1447   }
1448 
1449   if (failed(ConvertTf2XlaOps(func, ctx))) {
1450     signalPassFailure();
1451     return;
1452   }
1453 
1454   // This pattern was intented to uses TFL QDQs to preserve the quantization
1455   // parameters from the TF Quant ops, thus this pattern should run with the
1456   // first `applyPatternsGreedily` method, which would otherwise removes the
1457   // TF FakeQuant ops by the constant folding.
1458   patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant,
1459                   PreparePerTensorFakeQuantWithMinMaxArgs>(ctx);
1460 
1461   // This pattern will try to identify and optimize for dilated convolution.
1462   // e.g. Patterns like "SpaceToBatchND -> Conv2D -> BatchToSpaceND" will be
1463   // replaced with a single Conv op with dilation parameter.
1464   patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>, FusedBatchNormV3Pat,
1465                   ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(ctx);
1466 
1467   TFL::populateWithGenerated(ctx, patterns);
1468   // TODO(karimnosseir): Split to separate pass probably after
1469   // deciding on long term plan for this optimization.
1470   // This will allow optimizing any TF_Mul->TF_Conv in the graph
1471   // and any expanded from FusedBatchNorm. We need to do this
1472   // before converting TF_Conv to TFL_Conv
1473   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
1474 
1475   // Load the generated pattern again, so new quantization pass-through
1476   // will be applied.
1477   TFL::populateWithGenerated(ctx, phase_2_patterns);
1478   if (unfold_batch_matmul_) {
1479     phase_2_patterns.insert<TF::ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
1480                             TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(
1481         ctx);
1482   }
1483   phase_2_patterns.insert<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo,
1484                           ConvertTFStridedSlice, ConvertRfftToRfft2d>(ctx);
1485   phase_2_patterns.insert<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative>(
1486       ctx, allow_bf16_and_f16_type_legalization_);
1487 
1488   (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
1489 }
1490 
1491 }  // namespace
1492 
1493 // Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
CreatePrepareTFPass(bool unfold_batch_matmul,bool allow_bf16_type_legalization)1494 std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
1495     bool unfold_batch_matmul, bool allow_bf16_type_legalization) {
1496   return std::make_unique<PrepareTFPass>(unfold_batch_matmul,
1497                                          allow_bf16_type_legalization);
1498 }
1499 
1500 static PassRegistration<PrepareTFPass> pass(
1501     "tfl-prepare-tf", "Prepare TF for legalization to TensorFlow Lite dialect");
1502 
1503 }  // namespace TFL
1504 }  // namespace mlir
1505