1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This transformation pass prepares for legalization to the TFLite dialect by
17 // converting operations in TensorFlow dialect into operations that can be
18 // legalized to TensorFlow Lite dialect with simple replacements.  The newly
19 // created operations are in the TensorFlow dialect if the operation can be
20 // represented using a TensorFlow op.  Otherwise, a TensorFlow Lite dialect op
21 // is used.  For example, Conv2D in TFLite, which uses the OHWI data format for
22 // filters, is not supported in TensorFlow because TensorFlow requires filters
23 // in the HWIO data format.
24 //
25 // The motivation for preparing for the TFLite legalization before the actual
26 // legalization is to exploit constant-folding opportunities in any newly
27 // created ops by leveraging the constant-folding support for TensorFlow ops.
28 // This way TFLite can be used as a serialization format only and does not
29 // require access to the TFLite runtime for optimizations, as required by the
30 // TFLite team.
31 
32 #include <climits>
33 #include <cstdint>
34 
35 #include "absl/memory/memory.h"
36 #include "absl/numeric/bits.h"
37 #include "llvm/ADT/ArrayRef.h"
38 #include "llvm/ADT/STLExtras.h"
39 #include "llvm/ADT/StringSwitch.h"
40 #include "llvm/Support/Casting.h"
41 #include "llvm/Support/Debug.h"
42 #include "mlir/Analysis/LoopAnalysis.h"  // from @llvm-project
43 #include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
44 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
45 #include "mlir/Dialect/Quant/UniformSupport.h"  // from @llvm-project
46 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
47 #include "mlir/IR/Attributes.h"  // from @llvm-project
48 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
49 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
50 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
51 #include "mlir/IR/Operation.h"  // from @llvm-project
52 #include "mlir/Pass/Pass.h"  // from @llvm-project
53 #include "mlir/Support/LLVM.h"  // from @llvm-project
54 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
55 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
56 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
57 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
58 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
59 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
60 #include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
61 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
62 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
63 #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
64 #include "tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h"
65 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
66 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
67 #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
68 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
69 #include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
70 #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"
71 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
72 
73 #define DEBUG_TYPE "tf-tfl-legalization"
74 
75 namespace mlir {
76 namespace TFL {
77 //===----------------------------------------------------------------------===//
78 // The actual PrepareTF Pass.
79 //
80 // TODO(hinsu): Add and use TensorFlow dialect ops for the ops created in this
81 // pass.
82 namespace {
83 
84 // Prepare TF operations in functions for subsequent legalization.
85 class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
86  public:
87   PrepareTFPass() = default;
88   PrepareTFPass(const PrepareTFPass &) {}
89   explicit PrepareTFPass(bool unfold_batch_matmul,
90                          bool allow_bf16_and_f16_type_legalization) {
91     unfold_batch_matmul_ = unfold_batch_matmul;
92     allow_bf16_and_f16_type_legalization_ =
93         allow_bf16_and_f16_type_legalization;
94   }
95 
96   StringRef getArgument() const final {
97     // This is the argument used to refer to the pass in
98     // the textual format (on the commandline for example).
99     return "tfl-prepare-tf";
100   }
101   StringRef getDescription() const final {
102     // This is a brief description of the pass.
103     return "Prepare TF for legalization to TensorFlow Lite dialect";
104   }
105 
106   void runOnFunction() override;
107 
108   void getDependentDialects(DialectRegistry &registry) const override {
109     registry.insert<mhlo::MhloDialect, quant::QuantizationDialect,
110                     TFL::TensorFlowLiteDialect>();
111   }
112 
113  private:
114   Option<bool> unfold_batch_matmul_{
115       *this, "tfl-unfold-batch-matmul",
116       llvm::cl::desc("Unfold BatchMatMul into individual MatMul ops."),
117       llvm::cl::init(true)};
118 
119   Option<bool> allow_bf16_and_f16_type_legalization_{
120       *this, "tfl-allow-bf16-and-f16-type-legalization",
121       llvm::cl::desc("Allow bf16 type legalization."), llvm::cl::init(false)};
122 };
123 
124 // Transient state for preserving data from match to rewrite
125 struct ConvertTFConvOpMatchState {
126   IntegerAttr dilation_height_factor;
127   IntegerAttr dilation_width_factor;
128   StringAttr padding;
129   IntegerAttr stride_height;
130   IntegerAttr stride_width;
131 };
132 
133 // Templated class for declaring a converter from some TensorFlow convolution
134 // op into its counterpart in TensorFlow Lite.
135 //
136 // The `ConcreteType` deriving from this template must provide the following
137 // method for constructing TensorFlow Lite op:
138 //
139 //   TFL::[op] createTFLOp(ConvertTFConvOpMatchState *state,
140 //                         PatternRewriter &rewriter, Location loc,
141 //                         Type result_type, Value input,
142 //                         Value filter, Value bias) const;
143 //
144 // And also the following method for getting the dimension for bias tensor:
145 //
146 //  int64_t getBiasDim(ArrayRef<int64_t> filterShape) const;
147 template <typename ConcreteType, typename TFConvOpType>
148 class ConvertTFConvOp : public RewritePattern {
149  public:
150   ConvertTFConvOp(MLIRContext *context,
151                   bool allow_bf16_and_f16_type_legalization)
152       : RewritePattern(TFConvOpType::getOperationName(), 1, context),
153         intAttrOne(Builder(context).getI32IntegerAttr(1)),
154         allow_bf16_and_f16_type_legalization_(
155             allow_bf16_and_f16_type_legalization) {}
156 
157   LogicalResult matchAndRewrite(Operation *op,
158                                 PatternRewriter &rewriter) const override {
159     // Assumes TensorFlow convolution op is already verified to be
160     // in valid form.
161 
162     // Match a TFConvOpType under the following conditions:
163     // * The 'T' attribute must exist and be of value DT_FLOAT.
164     // * The 'data_format' attribute must exist and be of value "NHWC".
165     // * The 'strides' attribute must exist and be of the form [1, X, Y, 1].
166     // * The 'dilations' attribute is optional, but it must be of the form
167     //   [1, X, Y, 1] if it exists.
168 
169     TFConvOpType tf_op = cast<TFConvOpType>(op);
170 
171     if (!TFTypeIsFloat32Tensor(tf_op.input()) &&
172         !(allow_bf16_and_f16_type_legalization_ &&
173           TFTypeIsBFloat16OrHalfTensor(tf_op.input())))
174       return failure();
175 
176     if (!TFDataFormatIsNHWC(op)) return failure();
177 
178     IntegerAttr height, width;
179     if (!TFIntListIs1XY1(op, "strides", &height, &width)) return failure();
180 
181     ConvertTFConvOpMatchState state;
182     state.stride_height = height;
183     state.stride_width = width;
184 
185     if (TFIntListIs1XY1(op, "dilations", &height, &width)) {
186       state.dilation_height_factor = height;
187       state.dilation_width_factor = width;
188     } else {
189       // If the 'dilations' attribute is missing, we use the default value (1)
190       // for both dilation height and width factor.
191       state.dilation_height_factor = intAttrOne;
192       state.dilation_width_factor = intAttrOne;
193     }
194 
195     if (!TFPaddingIsSameOrValid(op, &state.padding)) return failure();
196 
197     // Additionally, we require the filter operand to be of 4-D tensor type so
198     // that we can extract info from the shape (e.g., for constructing bias
199     // tensor, for setting depth_multiplier attribute, etc.).
200     auto filter = tf_op.filter();
201     auto filter_type = filter.getType().template dyn_cast<RankedTensorType>();
202     if (!filter_type || filter_type.getRank() != 4 ||
203         !filter_type.hasStaticShape())
204       return failure();
205 
206     // TensorFlow convolution op only has two inputs, while the TFLite one has
207     // three, with the bias vector marked as optional. However, TOCO has a
208     // dedicated pass, EnsureBiasVectors, to create default bias vectors for all
209     // those missing. So we model TFLite convolution op as requiring three
210     // inputs to achieve the legalization task of EnsureBiasVectors. This
211     // requires the filter tensor to have a static shape.
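    // For example (illustrative): for an f32 Conv2D filter of shape
    // [3, 3, 8, 16], a splat-zero bias of type tensor<16xf32> is created below.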
212 
213     // TODO(antiagainst): also handle the case of tf.Add(tf.[op], <bias>)
214 
215     // Get a splat zero tensor with the expected dimension for the bias tensor
216     auto elem_type = filter_type.getElementType();
217     auto bias_dim = static_cast<const ConcreteType *>(this)->getBiasDim(
218         filter_type.getShape());
219     auto bias_type = RankedTensorType::get({bias_dim}, elem_type);
220     auto bias_attr = rewriter.getZeroAttr(bias_type);
221     auto bias =
222         rewriter.create<TF::ConstOp>(op->getLoc(), bias_type, bias_attr);
223 
224     auto conv_op = static_cast<const ConcreteType *>(this)->createTFLOp(
225         &state, rewriter, op->getLoc(), tf_op.getType(), tf_op.input(), filter,
226         bias);
227 
228     rewriter.replaceOp(op, conv_op.getResult());
229     return success();
230   }
231 
232   const IntegerAttr intAttrOne;
233 
234  private:
235   bool allow_bf16_and_f16_type_legalization_;
236 };
237 
238 class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
239  public:
240   using BaseType = ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp>;
241 
242   ConvertTFConv2D(MLIRContext *context, bool allow_bf16_type_legalization)
243       : BaseType(context, allow_bf16_type_legalization) {}
244 
245   int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
246     return filterShape.back();
247   }
248 
249   TFL::Conv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
250                             PatternRewriter &rewriter, Location loc,
251                             Type result_type, Value input, Value filter,
252                             Value bias) const {
253     filter = legalizeFilter(rewriter, loc, filter);
254     return rewriter.create<TFL::Conv2DOp>(
255         loc, result_type, input, filter, bias,
256         /*dilation_h_factor=*/state->dilation_height_factor,
257         /*dilation_w_factor=*/state->dilation_width_factor,
258         /*fused_activation_function=*/rewriter.getStringAttr("NONE"),
259         /*padding=*/state->padding,
260         /*stride_h=*/state->stride_height,
261         /*stride_w=*/state->stride_width);
262   }
263 
264  private:
265   // Legalize the given filter by converting it from TensorFlow filter data
266   // format HWIO to TFLite Conv2D op filter data format OHWI and return Value
267   // for the converted filter.  Requires that filter is verified by the match
268   // method that it is a 4-D RankedTensorType.
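  // For example (illustrative): a 4-D HWIO filter of shape [3, 3, 8, 16] is
  // transposed with permutation [3, 0, 1, 2] into an OHWI filter of shape
  // [16, 3, 3, 8].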
269   Value legalizeFilter(PatternRewriter &rewriter, Location loc,
270                        Value filter) const {
271     // Create a constant op for HWIO to OHWI transpose permutation.
272     SmallVector<int, 4> perm = {3, 0, 1, 2};
273     auto perm_type = RankedTensorType::get({static_cast<int>(perm.size())},
274                                            rewriter.getIntegerType(32));
275     auto perm_attr =
276         DenseElementsAttr::get(perm_type, llvm::makeArrayRef<int>(perm));
277     auto perm_op = rewriter.create<TF::ConstOp>(loc, perm_type, perm_attr);
278 
279     // Create tensor type for the transpose result.
280     auto filter_type = filter.getType().cast<RankedTensorType>();
281     auto result_shape =
282         llvm::to_vector<4>(llvm::map_range(perm, [filter_type](int64_t dim) {
283           return filter_type.getDimSize(dim);
284         }));
285     auto elem_type = filter_type.getElementType();
286     auto result_type = RankedTensorType::get(result_shape, elem_type);
287 
288     return rewriter.create<TF::TransposeOp>(loc, result_type, filter, perm_op);
289   }
290 };
291 
292 class ConvertTFDepthwiseConv2dNative
293     : public ConvertTFConvOp<ConvertTFDepthwiseConv2dNative,
294                              TF::DepthwiseConv2dNativeOp> {
295  public:
296   using BaseType = ConvertTFConvOp<ConvertTFDepthwiseConv2dNative,
297                                    TF::DepthwiseConv2dNativeOp>;
298 
299   ConvertTFDepthwiseConv2dNative(MLIRContext *context,
300                                  bool allow_bf16_type_legalization)
301       : BaseType(context, allow_bf16_type_legalization) {}
302 
303   int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
304     return filterShape[2] * filterShape[3];
305   }
306 
307   TFL::DepthwiseConv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
308                                      PatternRewriter &rewriter, Location loc,
309                                      Type result_type, Value input,
310                                      Value filter, Value bias) const {
311     // Compared to tfl.conv_2d, tfl.depthwise_conv_2d has an additional
312     // 'depth_multiplier' attribute. However, tf.DepthwiseConv2dNative does not
313     // have a corresponding 'depth_multiplier' attribute; the multiplier is the
314     // fourth dimension in the 4-D filter tensor. We query the multiplier from
315     // tf.DepthwiseConv2dNative and set it as the attribute value accordingly.
316     auto multiplier = filter.getType().cast<RankedTensorType>().getDimSize(3);
317 
318     filter = legalizeFilter(rewriter, loc, filter);
319     return rewriter.create<TFL::DepthwiseConv2DOp>(
320         loc, result_type, input, filter, bias,
321         /*dilation_h_factor=*/state->dilation_height_factor,
322         /*dilation_w_factor=*/state->dilation_width_factor,
323         /*fused_activation_function=*/rewriter.getStringAttr("NONE"),
324         /*padding=*/state->padding,
325         /*stride_h=*/state->stride_height,
326         /*stride_w=*/state->stride_width,
327         /*depth_multiplier=*/rewriter.getI32IntegerAttr(multiplier));
328   }
329 
330  private:
331   /// Legalize the given filter by converting it from TensorFlow filter data
332   /// format to TFLite DepthwiseConv2D op filter data format and return Value
333   /// for the converted filter.  TensorFlow filter data format is
334   /// [filter_height, filter_width, in_channels, channel_multiplier] and TFLite
335   /// filter data format is [1, filter_height, filter_width, out_channels].
336   /// Requires that filter is verified by the match method that it is a 4-D
337   /// RankedTensorType.
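  /// For example (illustrative): a [3, 3, 8, 2] filter (channel_multiplier 2)
  /// is reshaped to [1, 3, 3, 16].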
338   Value legalizeFilter(PatternRewriter &rewriter, Location loc,
339                        Value filter) const {
340     auto filter_type = filter.getType().cast<RankedTensorType>();
341     auto filterShape = filter_type.getShape();
342     SmallVector<int64_t, 4> result_shape = {1, filterShape[0], filterShape[1],
343                                             filterShape[2] * filterShape[3]};
344     auto elem_type = filter_type.getElementType();
345     auto result_type = RankedTensorType::get(result_shape, elem_type);
346     // The TensorFlow Lite `Reshape` op currently only supports int32 shape tensors.
347     auto shape_type = RankedTensorType::get({4}, rewriter.getIntegerType(32));
348     SmallVector<Attribute, 4> result_shape_data(4);
349     for (int i = 0; i < 4; ++i) {
350       result_shape_data[i] =
351           rewriter.getI32IntegerAttr(static_cast<int32_t>(result_shape[i]));
352     }
353     auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data);
354     auto shape = rewriter.create<TF::ConstOp>(loc, shape_type, shape_attr);
355 
356     return rewriter.create<TF::ReshapeOp>(loc, result_type, filter, shape);
357   }
358 };
359 
360 // StridedSlice can have complicated attributes like begin_mask, end_mask,
361 // ellipsis_mask, new_axis_mask, and shrink_axis_mask. These masks complicate
362 // the strided_slice computation logic; we can simplify the logic by inserting
363 // a reshape op to pad the inputs so that the strided_slice becomes easier
364 // to handle.
365 //
366 // So the graph transformation looks like below:
367 //   original_input -> strided_slice -> output
368 //      (transforms)
369 //   original_input -> reshape -> strided_slice -> output
370 //
371 // And the new shape is computed based on the masks.
372 //
373 // An example for new_axis_mask: say the new_axis_mask is 9, which represents
374 // [1 0 0 1]; that means we're inserting new axes at dims 0 & 3, so if the
375 // original shape is [2, 3], we reshape it into [1, 2, 3, 1].
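// A rough sketch of the rewrite for that example (illustrative): the input is
// first reshaped to [1, 2, 3, 1], the bits of new_axis_mask are folded into
// begin_mask/end_mask, and the strided_slice is re-created with
// new_axis_mask = 0.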
376 struct ConvertTFStridedSlice : public RewritePattern {
377   explicit ConvertTFStridedSlice(MLIRContext *context)
378       : RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {}
379 
380   LogicalResult RewriteNewAxisMask(Operation *op,
381                                    PatternRewriter &rewriter) const {
382     TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
383     uint64_t new_axis_mask = strided_slice_op.new_axis_mask();
384 
385     if (strided_slice_op.ellipsis_mask() != 0) {
386       // Ellipsis mask should have been lowered away prior to invoking this
387       // function.
388       op->emitError() << "encountered a logical error";
389       return failure();
390     }
391 
392     // Insert a new reshape op.
393     Value original_input = strided_slice_op.input();
394     RankedTensorType original_input_type =
395         original_input.getType().dyn_cast<RankedTensorType>();
396     if (!original_input_type) {
397       return failure();
398     }
399 
400     const ArrayRef<int64_t> &original_input_shape =
401         original_input_type.getShape();
402     SmallVector<int64_t, 4> revised_shape;
403     int index = 0;
404     const int original_input_rank = original_input_shape.size();
405     while (index < original_input_rank || new_axis_mask) {
406       if (new_axis_mask & 1) {
407         revised_shape.emplace_back(1);
408       } else {
409         revised_shape.emplace_back(original_input_shape[index++]);
410       }
411       new_axis_mask >>= 1;
412     }
413 
414     if (failed(TF::VerifyShapeOfReshapeOp(revised_shape))) return failure();
415 
416     const int dim_size = revised_shape.size();
417     Location loc = strided_slice_op.getLoc();
418     auto shape_type =
419         RankedTensorType::get({dim_size}, rewriter.getIntegerType(32));
420     SmallVector<Attribute, 4> result_shape_data(dim_size);
421     for (int i = 0; i < dim_size; ++i) {
422       result_shape_data[i] =
423           rewriter.getI32IntegerAttr(static_cast<int32_t>(revised_shape[i]));
424     }
425 
426     auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data);
427     auto shape = rewriter.create<ConstantOp>(loc, shape_type, shape_attr);
428     auto revised_output_type = RankedTensorType::get(
429         revised_shape, original_input_type.getElementType());
430     TF::ReshapeOp reshape = rewriter.create<TF::ReshapeOp>(
431         loc, revised_output_type, original_input, shape);
432 
433     // Replace the original strided_slice.
434     uint64_t revised_begin_mask = strided_slice_op.begin_mask();
435     uint64_t revised_end_mask = strided_slice_op.end_mask();
436     // Since we expand the dims, we need to apply them to the begin_mask &
437     // end_mask.
438     revised_begin_mask |= strided_slice_op.new_axis_mask();
439     revised_end_mask |= strided_slice_op.new_axis_mask();
440 
441     // Enforce operator precedence.
442     uint64_t revised_shrink_axis_mask =
443         strided_slice_op.shrink_axis_mask() & ~strided_slice_op.new_axis_mask();
444 
445     auto attribute_type = rewriter.getIntegerType(64);
446     rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
447         op, strided_slice_op.getType(), reshape, strided_slice_op.begin(),
448         strided_slice_op.end(), strided_slice_op.strides(),
449         rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
450         rewriter.getIntegerAttr(attribute_type, revised_end_mask),
451         rewriter.getIntegerAttr(attribute_type,
452                                 strided_slice_op.ellipsis_mask()),
453         rewriter.getI64IntegerAttr(0),
454         rewriter.getIntegerAttr(attribute_type, revised_shrink_axis_mask));
455     return success();
456   }
457 
458   LogicalResult RewriteEllipsisMask(Operation *op,
459                                     PatternRewriter &rewriter) const {
460     TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
461 
462     uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask();
463     uint64_t shrink_axis_mask = strided_slice_op.shrink_axis_mask();
464     uint64_t new_axis_mask = strided_slice_op.new_axis_mask();
465 
466     // Enforce operator precedence.
467     shrink_axis_mask &= ~ellipsis_mask;
468     new_axis_mask &= ~ellipsis_mask;
469 
470     DenseIntElementsAttr begin_dense_elem_attr;
471     Value begin = strided_slice_op.begin();
472     auto begin_ranked_attr_type = begin.getType().dyn_cast<RankedTensorType>();
473     if (!begin_ranked_attr_type ||
474         !matchPattern(begin, m_Constant(&begin_dense_elem_attr))) {
475       return failure();
476     }
477 
478     DenseIntElementsAttr end_dense_elem_attr;
479     Value end = strided_slice_op.end();
480     auto end_ranked_attr_type = end.getType().dyn_cast<RankedTensorType>();
481     if (!end_ranked_attr_type ||
482         !matchPattern(end, m_Constant(&end_dense_elem_attr))) {
483       return failure();
484     }
485 
486     DenseIntElementsAttr stride_dense_elem_attr;
487     Value stride = strided_slice_op.strides();
488     auto stride_ranked_attr_type =
489         stride.getType().dyn_cast<RankedTensorType>();
490     if (!stride_ranked_attr_type ||
491         !matchPattern(stride, m_Constant(&stride_dense_elem_attr))) {
492       return failure();
493     }
494 
495     Value input = strided_slice_op.input();
496     RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
497     if (!input_type) {
498       return failure();
499     }
500     const ArrayRef<int64_t> input_shape = input_type.getShape();
501 
502     const int input_size = input_shape.size();
503 
504     RankedTensorType begin_type = begin.getType().cast<RankedTensorType>();
505     const ArrayRef<int64_t> begin_shape = begin_type.getShape();
506     const int begin_dim = begin_shape.size();
507 
508     if (begin_dim != 1) return failure();
509 
510     // The ellipsis fill might exceed the current output shape because we are
511     // also taking into account any to-be-inserted new axes.
512     const int ellipsis_filled_dim_size =
513         input_size - begin_shape[0] + 1 + absl::popcount(new_axis_mask);
514 
515     int64_t begin_mask = strided_slice_op.begin_mask();
516     int64_t end_mask = strided_slice_op.end_mask();
517     int64_t revised_begin_mask = 0;
518     int64_t revised_end_mask = 0;
519     int64_t revised_shrink_axis_mask = 0;
520     int64_t revised_new_axis_mask = 0;
521 
522     SmallVector<int32_t, 4> padded_begin;
523     SmallVector<int32_t, 4> padded_end;
524     SmallVector<int32_t, 4> padded_stride;
525 
526     // Before the ellipsis.
527     int index = 0;
528     int new_index = 0;
529     while (((ellipsis_mask >> index) & 1) == 0) {
530       padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
531       padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
532       padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
533       if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
534       if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
535       if ((shrink_axis_mask >> index) & 1)
536         revised_shrink_axis_mask |= (1 << new_index);
537 
538       if ((new_axis_mask >> index) & 1)
539         revised_new_axis_mask |= (1 << new_index);
540 
541       ++index;
542       ++new_index;
543     }
544 
545     // Ellipsis.
546     for (; new_index < index + ellipsis_filled_dim_size; ++new_index) {
547       revised_begin_mask |= (1 << new_index);
548       revised_end_mask |= (1 << new_index);
549 
550       // Mimic the begin/end/strides mask behavior.
551       padded_begin.push_back(0);
552       padded_end.push_back(0);
553       padded_stride.push_back(1);
554     }
555 
556     // Account for ellipsis mask.
557     ++index;
558 
559     // After the ellipsis.
560     for (; index < begin_shape[0];) {
561       padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
562       padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
563       padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
564 
565       if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
566       if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
567       if ((shrink_axis_mask >> index) & 1)
568         revised_shrink_axis_mask |= (1 << new_index);
569       if ((new_axis_mask >> index) & 1)
570         revised_new_axis_mask |= (1 << new_index);
571 
572       ++index;
573       ++new_index;
574     }
575 
576     auto attribute_type = rewriter.getIntegerType(64);
577 
578     int full_dim_count = padded_begin.size();
579     auto type =
580         RankedTensorType::get({full_dim_count}, rewriter.getIntegerType(32));
581 
582     auto begin_attr = DenseElementsAttr::get<int32_t>(type, padded_begin);
583     auto begin_op = rewriter.create<ConstantOp>(op->getLoc(), type, begin_attr);
584     auto end_attr = DenseElementsAttr::get<int32_t>(type, padded_end);
585     auto end_op = rewriter.create<ConstantOp>(op->getLoc(), type, end_attr);
586     auto stride_attr = DenseElementsAttr::get<int32_t>(type, padded_stride);
587     auto stride_op =
588         rewriter.create<ConstantOp>(op->getLoc(), type, stride_attr);
589 
590     rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
591         op, strided_slice_op.getType(), input, begin_op.getResult(),
592         end_op.getResult(), stride_op.getResult(),
593         rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
594         rewriter.getIntegerAttr(attribute_type, revised_end_mask),
595         /*ellipsis_mask=*/rewriter.getI64IntegerAttr(0),
596         rewriter.getIntegerAttr(attribute_type, revised_new_axis_mask),
597         rewriter.getIntegerAttr(attribute_type, revised_shrink_axis_mask));
598 
599     return success();
600   }
601 
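  // Copies the values of `dense_elem_attr` into `val` and `padded_val`, then
  // pads `padded_val` out to the length of `padding_val`; for each padded
  // dimension the corresponding bit in `mask` (if provided) is set.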
602   void PadStridedSliceAttributeArray(DenseIntElementsAttr dense_elem_attr,
603                                      SmallVectorImpl<int32_t> &val,
604                                      SmallVectorImpl<int32_t> &padded_val,
605                                      ArrayRef<int32_t> padding_val,
606                                      int *mask) const {
607     for (const auto &idx : dense_elem_attr.getIntValues()) {
608       val.push_back(idx.getSExtValue());
609       padded_val.push_back(idx.getSExtValue());
610     }
611     int attr_dim_count = val.size();
612     int full_dim_count = padding_val.size();
613     for (int i = attr_dim_count; i < full_dim_count; ++i) {
614       padded_val.push_back(padding_val[i]);
615       if (mask) *mask |= 1 << i;
616     }
617   }
618 
619   LogicalResult matchAndRewrite(Operation *op,
620                                 PatternRewriter &rewriter) const override {
621     TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
622 
623     // Handle ellipsis mask.
624     if (strided_slice_op.ellipsis_mask() != 0) {
625       return RewriteEllipsisMask(strided_slice_op, rewriter);
626     }
627 
628     // Handle new axis mask.
629     if (strided_slice_op.new_axis_mask() != 0) {
630       return RewriteNewAxisMask(strided_slice_op, rewriter);
631     }
632 
633     auto ranked_input_type =
634         strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
635     if (!ranked_input_type) {
636       return failure();
637     }
638 
639     auto begin_attr = strided_slice_op.begin();
640     auto end_attr = strided_slice_op.end();
641     auto strides_attr = strided_slice_op.strides();
642 
643     auto begin_attr_type = begin_attr.getType().dyn_cast<RankedTensorType>();
644     auto end_attr_type = end_attr.getType().dyn_cast<RankedTensorType>();
645     auto strides_attr_type =
646         strides_attr.getType().dyn_cast<RankedTensorType>();
647 
648     DenseIntElementsAttr begin_elem_attr;
649     DenseIntElementsAttr end_elem_attr;
650     DenseIntElementsAttr strides_elem_attr;
651 
652     if (!begin_attr_type ||
653         !matchPattern(begin_attr, m_Constant(&begin_elem_attr))) {
654       return failure();
655     }
656     if (!end_attr_type || !matchPattern(end_attr, m_Constant(&end_elem_attr))) {
657       return failure();
658     }
659     if (!strides_attr_type ||
660         !matchPattern(strides_attr, m_Constant(&strides_elem_attr))) {
661       return failure();
662     }
663 
664     SmallVector<int32_t, 4> begin, end, strides;
665     SmallVector<int32_t, 4> padded_begin, padded_end, padded_strides;
666 
667     int num_input_dims = ranked_input_type.getRank();
668     SmallVector<int32_t, 4> padding_begin(num_input_dims, 0);
669     auto input_shape = ranked_input_type.getShape();
670     SmallVector<int32_t, 4> padding_end(input_shape.begin(), input_shape.end());
671     SmallVector<int32_t, 4> padding_strides(num_input_dims, 1);
672 
673     int begin_mask = strided_slice_op.begin_mask();
674     int end_mask = strided_slice_op.end_mask();
675 
676     PadStridedSliceAttributeArray(begin_elem_attr, begin, padded_begin,
677                                   padding_begin, &begin_mask);
678     PadStridedSliceAttributeArray(end_elem_attr, end, padded_end, padding_end,
679                                   &end_mask);
680     PadStridedSliceAttributeArray(strides_elem_attr, strides, padded_strides,
681                                   padding_strides, nullptr);
682 
683     if (begin == padded_begin && end == padded_end &&
684         strides == padded_strides &&
685         begin_mask == strided_slice_op.begin_mask() &&
686         end_mask == strided_slice_op.end_mask()) {
687       return failure();
688     }
689 
690     auto begin_end_type =
691         RankedTensorType::get({num_input_dims}, rewriter.getIntegerType(32));
692     auto new_begin_attr = rewriter.create<ConstantOp>(
693         op->getLoc(), begin_end_type,
694         DenseElementsAttr::get<int32_t>(begin_end_type, padded_begin));
695     auto new_end_attr = rewriter.create<ConstantOp>(
696         op->getLoc(), begin_end_type,
697         DenseElementsAttr::get<int32_t>(begin_end_type, padded_end));
698     auto strides_type =
699         RankedTensorType::get({static_cast<long>(padded_strides.size())},
700                               rewriter.getIntegerType(32));
701     auto new_strides_attr = rewriter.create<ConstantOp>(
702         op->getLoc(), strides_type,
703         DenseElementsAttr::get<int32_t>(strides_type, padded_strides));
704 
705     auto attribute_type = rewriter.getIntegerType(64);
706     rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
707         op, strided_slice_op.output().getType(), strided_slice_op.input(),
708         new_begin_attr, new_end_attr, new_strides_attr,
709         rewriter.getIntegerAttr(attribute_type, begin_mask),
710         rewriter.getIntegerAttr(attribute_type, end_mask),
711         rewriter.getIntegerAttr(attribute_type,
712                                 strided_slice_op.ellipsis_mask()),
713         rewriter.getIntegerAttr(attribute_type,
714                                 strided_slice_op.new_axis_mask()),
715         rewriter.getIntegerAttr(attribute_type,
716                                 strided_slice_op.shrink_axis_mask()));
717 
718     return success();
719   }
720 };
721 
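// Lowers tf.BroadcastTo into a tf.Fill of ones with the target shape,
// multiplied by the original input via tf.Mul's implicit broadcasting. Only
// applies to small-rank outputs with f32/bf16/i32/i16 element types.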
722 struct ConvertTFBroadcastTo : public RewritePattern {
723   explicit ConvertTFBroadcastTo(MLIRContext *context)
724       : RewritePattern(TF::BroadcastToOp::getOperationName(), 1, context) {}
725 
726   LogicalResult matchAndRewrite(Operation *op,
727                                 PatternRewriter &rewriter) const override {
728     auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
729     auto input_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
730     auto output_type = tf_broadcast_to_op.output().getType().cast<ShapedType>();
731     auto shape_type = tf_broadcast_to_op.shape().getType().cast<ShapedType>();
732     Type element_type = input_type.getElementType();
733 
734     // Allow lowering when low-dimensional inputs are given and the element type
735     // is F32/BF16 or I32/I16.
736     if (!((output_type.hasRank() && output_type.getRank() <= 4) ||
737           (shape_type.hasStaticShape() && shape_type.getRank() == 1 &&
738            shape_type.getDimSize(0) <= 4)))
739       return failure();
740 
741     if (!(element_type.isa<BFloat16Type, Float32Type>() ||
742           element_type.isInteger(32) || element_type.isInteger(16)))
743       return failure();
744 
745     auto status_or_const_op =
746         CreateConstOpWithSingleValue(&rewriter, op->getLoc(), input_type, 1);
747     if (!status_or_const_op.ok()) {
748       return failure();
749     }
750 
751     auto tf_fill_op = rewriter.create<TF::FillOp>(
752         op->getLoc(), output_type, tf_broadcast_to_op.shape(),
753         status_or_const_op.ValueOrDie());
754 
755     auto mul_op = rewriter.create<TF::MulOp>(
756         op->getLoc(), output_type, tf_broadcast_to_op.input(), tf_fill_op);
757     rewriter.replaceOp(op, mul_op.getResult());
758     return success();
759   }
760 };
761 
762 // The pattern below is equivalent to the DRR rule that follows.
763 // The checks are dependent on generated values, so we can't add
764 // the checks on intermediate values; ideally we should find equivalent
765 // checks that guarantee the resultant ops are valid.
766 // The extra conditions are the broadcasting conditions.
767 //
768 // The pattern lowers FusedBatchNormV3 to arithmetic ops.
769 // Specifically, it performs the following calculation:
770 //
771 //   (x - mean) * scale / sqrt(variance + epsilon) + offset
772 //
773 // Let multiplier = scale / sqrt(variance + epsilon),
774 // to compute
775 //   (x - mean) * scale / sqrt(variance + epsilon) + offset,
776 // is then to compute
777 //   (x * multiplier) + (offset - mean * multiplier).
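//
// (Derivation sketch: (x - mean) * multiplier + offset
//    = x * multiplier - mean * multiplier + offset
//    = (x * multiplier) + (offset - mean * multiplier).)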
778 //
779 // def : Pattern<
780 //     (TF_FusedBatchNormV3Op:$root
781 //         $x, $scale, $offset, $mean, $variance,
782 //         F32Attr:$epsilon, $exponential_avg_factor,
783 //         $data_format, FalseBoolAttr:$is_training),
784 //     [(TF_AddOp
785 //         (TF_MulOp
786 //             $x,
787 //             (TF_MulOp:$multiplier
788 //                 $scale,
789 //                 (TF_RsqrtOp
790 //                     (TF_AddOp $variance,
791 //                               (TF_ConstOp $epsilon))))),
792 //         (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))),
793 //    // We already guaranteed that the last five results have no use so it does
794 //    // not matter what value we provide here for replacement.
795 //      /*batch_mean=*/(replaceWithValue $x),
796 //      /*batch_variance=*/(replaceWithValue $x),
797 //      /*reserve_space_1=*/(replaceWithValue $x),
798 //      /*reserve_space_2=*/(replaceWithValue $x),
799 //      /*reserve_space_3=*/(replaceWithValue $x)],
800 //     [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
801 //      (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
802 //      (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>;
803 //
804 // When is_training is set to true, the given variance and mean are not used.
805 // In the above calculation, they are replaced by new values. These new mean
806 // and variance are calculated as follows:
807 // new_mean = mean(x, axis=[0, 1, 2])
808 // new_variance = mean(squared_difference(x, new_mean), axis=[0, 1, 2])
809 //
810 // The DRR rule for the is_training == true case is as follows:
811 // def : Pattern<
812 //     (TF_FusedBatchNormV3Op:$root
813 //         $x, $scale, $offset, $mean, $variance,
814 //         F32Attr:$epsilon, $exponential_avg_factor,
815 //         $data_format, FalseBoolAttr:$is_training),
816 //     [(TF_AddOp
817 //         (TF_MulOp
818 //             $x,
819 //             (TF_MulOp:$multiplier
820 //                 $scale,
821 //                 (TF_RsqrtOp
822 //                     (TF_AddOp
823 //                         (TF_MeanOp
824 //                             (TF_SquaredDifferenceOp $x, $new_mean),
825 //                             (TF_ConstOp [0,1,2])),
826 //                         (TF_ConstOp $epsilon))))),
827 //         (TF_SubOp
828 //             $offset,
829 //             (TF_MulOp
830 //                 (TF_MeanOp $x, (TF_ConstOp [0,1,2])),
831 //                 $multiplier))),
832 //    // We already guaranteed that the last five results have no use so it does
833 //    // not matter what value we provide here for replacement.
834 //      /*batch_mean=*/(replaceWithValue $x),
835 //      /*batch_variance=*/(replaceWithValue $x),
836 //      /*reserve_space_1=*/(replaceWithValue $x),
837 //      /*reserve_space_2=*/(replaceWithValue $x),
838 //      /*reserve_space_3=*/(replaceWithValue $x)],
839 //     [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
840 //      (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
841 //      (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>;
842 
843 struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
844   explicit FusedBatchNormV3Pat(::mlir::MLIRContext *context)
845       : ::mlir::RewritePattern(
846             "tf.FusedBatchNormV3", 1, context,
847             {"tf.Add", "tf.Const", "tf.Mul", "tf.Rsqrt", "tf.Sub"}) {}
848 
849   ::mlir::LogicalResult matchAndRewrite(
850       ::mlir::Operation *fused_batch_norm,
851       ::mlir::PatternRewriter &rewriter) const override {
852     // Variables for capturing values and attributes used for creating ops
853     Operation::operand_range mean(fused_batch_norm->getOperands());
854     ::mlir::FloatAttr exponential_avg_factor;
855     ::mlir::TF::FusedBatchNormV3Op root;
856     Operation::operand_range offset(fused_batch_norm->getOperands());
857     Operation::operand_range x(fused_batch_norm->getOperands());
858     Operation::operand_range scale(fused_batch_norm->getOperands());
859     Operation::operand_range variance(fused_batch_norm->getOperands());
860     ::mlir::FloatAttr epsilon;
861     ::mlir::BoolAttr is_training;
862 
863     // Match
864     auto fused_batch_norm_op =
865         dyn_cast_or_null<::mlir::TF::FusedBatchNormV3Op>(fused_batch_norm);
866     root = fused_batch_norm_op;
867     x = fused_batch_norm_op.getODSOperands(0);
868     scale = fused_batch_norm_op.getODSOperands(1);
869     offset = fused_batch_norm_op.getODSOperands(2);
870     mean = fused_batch_norm_op.getODSOperands(3);
871     variance = fused_batch_norm_op.getODSOperands(4);
872 
873     ::mlir::Value mean_value = (*mean.begin());
874     ::mlir::Value variance_value = (*variance.begin());
875 
876     if (!TFTypeIsFloat32Tensor(fused_batch_norm_op.x())) return failure();
877 
878     {
879       epsilon =
880           fused_batch_norm_op->getAttrOfType<::mlir::FloatAttr>("epsilon");
881       if (!epsilon)
882         epsilon = rewriter.getFloatAttr(rewriter.getF32Type(), 0.0001f);
883 
884       if (!(((epsilon.isa<::mlir::FloatAttr>())) &&
885             ((epsilon.cast<::mlir::FloatAttr>().getType().isF32())))) {
886         return rewriter.notifyMatchFailure(
887             fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
888               diag << "op 'tf.FusedBatchNormV3' attribute 'epsilon' failed to "
889                       "satisfy constraint: 32-bit float attribute";
890             });
891       }
892     }
893     {
894       exponential_avg_factor =
895           fused_batch_norm_op->getAttrOfType<::mlir::FloatAttr>(
896               "exponential_avg_factor");
897       if (!exponential_avg_factor)
898         exponential_avg_factor =
899             rewriter.getFloatAttr(rewriter.getF32Type(), 1.0f);
900     }
901     if (!TFDataFormatIsNHWC(fused_batch_norm_op) &&
902         !TFDataFormatIsNDHWC(fused_batch_norm_op))
903       return failure();
904 
905     if (!(((*root.getODSResults(1).begin()).use_empty()))) {
906       return rewriter.notifyMatchFailure(
907           fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
908             diag << "entities '' failed to satisfy constraint: has no use";
909           });
910     }
911 
912     if (!(((*root.getODSResults(2).begin()).use_empty()))) {
913       return rewriter.notifyMatchFailure(
914           fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
915             diag << "entities '' failed to satisfy constraint: has no use";
916           });
917     }
918 
919     if (!(((*root.getODSResults(3).begin()).use_empty()))) {
920       return rewriter.notifyMatchFailure(
921           fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
922             diag << "entities '' failed to satisfy constraint: has no use";
923           });
924     }
925 
926     if (!(((*root.getODSResults(4).begin()).use_empty()))) {
927       return rewriter.notifyMatchFailure(
928           fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
929             diag << "entities '' failed to satisfy constraint: has no use";
930           });
931     }
932 
933     if (!(((*root.getODSResults(5).begin()).use_empty()))) {
934       return rewriter.notifyMatchFailure(
935           fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
936             diag << "entities '' failed to satisfy constraint: has no use";
937           });
938     }
939 
940     is_training =
941         fused_batch_norm_op->getAttrOfType<::mlir::BoolAttr>("is_training");
942     auto odsLoc = rewriter.getFusedLoc({fused_batch_norm->getLoc()});
943 
944     // We need to make sure input and output shapes are compatible.
945     int64_t last_dim = -1;
946     {
947       auto is_last_dim_compatible = [](const Value &v, int64_t &last_dim) {
948         auto v_type = v.getType().dyn_cast_or_null<RankedTensorType>();
949         if (!v_type) return true;
950         int64_t v_last_dim = v_type.getDimSize(v_type.getRank() - 1);
951         if (v_last_dim == -1) return true;
952         if (last_dim != -1 && v_last_dim != last_dim) return false;
953         last_dim = v_last_dim;
954         return true;
955       };
956 
957       if (!is_last_dim_compatible(*x.begin(), last_dim) ||
958           !is_last_dim_compatible(*scale.begin(), last_dim) ||
959           !is_last_dim_compatible(*offset.begin(), last_dim)) {
960         return rewriter.notifyMatchFailure(
961             fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
962               diag << "Shapes of scale and offset should be 1D and "
963                       "compatible with x";
964             });
965       }
966 
967       if (!is_training.getValue()) {
968         if (!is_last_dim_compatible(mean_value, last_dim) ||
969             !is_last_dim_compatible(variance_value, last_dim)) {
970           return rewriter.notifyMatchFailure(
971               fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
972                 diag << "Shapes of mean and variance should be 1D and "
973                         "compatible with x";
974               });
975         }
976       }
977 
978       // Check if output shape and input shape are compatible.
979       auto x_type = (*x.begin()).getType();
980       auto y_type = (*root.getODSResults(0).begin()).getType();
981       if (!OpTrait::util::getBroadcastedType(x_type, y_type)) {
982         return rewriter.notifyMatchFailure(
983             fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
984               diag << "Shapes of x and the first output should be compatible";
985             });
986       }
987     }
988 
989     // For training, mean and variance are calculated from the input values.
990     if (is_training.getValue()) {
991       auto input_type = fused_batch_norm_op.x()
992                             .getType()
993                             .dyn_cast_or_null<RankedTensorType>();
994       if (!input_type || input_type.getRank() != 4) {
995         return rewriter.notifyMatchFailure(
996             fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
997               diag << "op 'tf.FusedBatchNormV3' that has 'is_training' equals "
998                       "True is only supported with input of rank 4";
999             });
1000       }
1001 
1002       ::mlir::TF::ConstOp reduce_dim_op;
1003       {
1004         auto reduce_dim_type =
1005             ::mlir::RankedTensorType::get({3}, rewriter.getIntegerType(32));
1006         ::mlir::SmallVector<int32_t, 3> reduce_dim_values = {0, 1, 2};
1007         reduce_dim_op = rewriter.create<TF::ConstOp>(
1008             odsLoc, ::mlir::DenseIntElementsAttr::get(reduce_dim_type,
1009                                                       reduce_dim_values));
1010       }
1011 
1012       auto new_mean_type =
1013           ::mlir::RankedTensorType::get({last_dim}, rewriter.getF32Type());
1014       ::mlir::TF::MeanOp mean_op_1;
1015       {
1016         ::mlir::Value x_value = (*x.begin());
1017         mean_op_1 = rewriter.create<TF::MeanOp>(
1018             odsLoc, new_mean_type, x_value, reduce_dim_op,
1019             /*keep_dims=*/rewriter.getBoolAttr(false));
1020       }
1021 
1022       ::mlir::TF::SquaredDifferenceOp square_diff_op;
1023       {
1024         ::mlir::Value tblgen_value_0 = (*x.begin());
1025         ::mlir::Value tblgen_value_1 = (*mean_op_1.getODSResults(0).begin());
1026         // If x has shape of [b, h, w, c], the result of mean_op_1 will have
1027         // shape of [c]. Therefore, their shapes are always compatible.
1028         square_diff_op = rewriter.create<::mlir::TF::SquaredDifferenceOp>(
1029             odsLoc, tblgen_value_0, tblgen_value_1);
1030       }
1031 
1032       ::mlir::TF::MeanOp mean_op_2;
1033       {
1034         ::mlir::Value input_value = (*square_diff_op.getODSResults(0).begin());
1035         mean_op_2 = rewriter.create<TF::MeanOp>(
1036             odsLoc, new_mean_type, input_value, reduce_dim_op,
1037             /*keep_dims=*/rewriter.getBoolAttr(false));
1038       }
1039 
1040       mean_value = (*mean_op_1.getODSResults(0).begin());
1041       variance_value = (*mean_op_2.getODSResults(0).begin());
1042     }  // End is_training equals true if.
1043 
1044     ::llvm::SmallVector<::mlir::Value, 4> replace_values;
1045     ::mlir::TF::ConstOp epsilon_const_op;
1046     {
1047       epsilon_const_op =
1048           rewriter.create<::mlir::TF::ConstOp>(odsLoc,
1049                                                /*value=*/epsilon);
1050     }
1051     ::mlir::TF::AddOp add_op_1;
1052     {
1053       ::mlir::Value epsilon_value =
1054           (*epsilon_const_op.getODSResults(0).begin());
1055       // Adding a constant; no need to check broadcastability.
1056       add_op_1 = rewriter.create<::mlir::TF::AddOp>(odsLoc,
1057                                                     /*x=*/variance_value,
1058                                                     /*y=*/epsilon_value);
1059     }
1060     ::mlir::TF::RsqrtOp rsqrt_op;
1061     {
1062       ::mlir::SmallVector<::mlir::Value, 4> tblgen_values;
1063       ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs;
1064       tblgen_values.push_back((*add_op_1.getODSResults(0).begin()));
1065       rsqrt_op = rewriter.create<::mlir::TF::RsqrtOp>(odsLoc, tblgen_values,
1066                                                       tblgen_attrs);
1067     }
1068     ::mlir::TF::MulOp multiplier;
1069     {
1070       ::mlir::Value tblgen_value_0 = (*scale.begin());
1071       ::mlir::Value tblgen_value_1 = (*rsqrt_op.getODSResults(0).begin());
1072       multiplier = rewriter.create<::mlir::TF::MulOp>(odsLoc,
1073                                                       /*x=*/tblgen_value_0,
1074                                                       /*y=*/tblgen_value_1);
1075     }
1076     ::mlir::TF::MulOp mul_op_1;
1077     {
1078       ::mlir::Value tblgen_value_0 = (*x.begin());
1079       ::mlir::Value tblgen_value_1 = (*multiplier.getODSResults(0).begin());
1080       mul_op_1 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
1081                                                     /*x=*/tblgen_value_0,
1082                                                     /*y=*/tblgen_value_1);
1083     }
1084     ::mlir::TF::MulOp mul_op_2;
1085     {
1086       ::mlir::Value multiplier_value = (*multiplier.getODSResults(0).begin());
1087       mul_op_2 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
1088                                                     /*x=*/mean_value,
1089                                                     /*y=*/multiplier_value);
1090     }
1091     ::mlir::TF::SubOp sub_op;
1092     {
1093       ::mlir::Value tblgen_value_0 = (*offset.begin());
1094       ::mlir::Value tblgen_value_1 = (*mul_op_2.getODSResults(0).begin());
1095       sub_op = rewriter.create<::mlir::TF::SubOp>(odsLoc,
1096                                                   /*x=*/tblgen_value_0,
1097                                                   /*y=*/tblgen_value_1);
1098     }
1099     ::mlir::TF::AddOp add_op_2;
1100     {
1101       ::mlir::SmallVector<::mlir::Value, 4> tblgen_values;
1102       ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs;
1103       tblgen_values.push_back((*mul_op_1.getODSResults(0).begin()));
1104       tblgen_values.push_back((*sub_op.getODSResults(0).begin()));
1105       ::mlir::SmallVector<::mlir::Type, 4> tblgen_types;
1106       for (auto v : fused_batch_norm_op.getODSResults(0)) {
1107         tblgen_types.push_back(v.getType());
1108       }
1109       add_op_2 = rewriter.create<::mlir::TF::AddOp>(
1110           odsLoc, tblgen_types, tblgen_values, tblgen_attrs);
1111     }
1112     for (auto v :
1113          ::llvm::SmallVector<::mlir::Value, 4>{add_op_2.getODSResults(0)}) {
1114       replace_values.push_back(v);
1115     }
1116     for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1117       replace_values.push_back(v);
1118     }
1119     for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1120       replace_values.push_back(v);
1121     }
1122     for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1123       replace_values.push_back(v);
1124     }
1125     for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1126       replace_values.push_back(v);
1127     }
1128     for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1129       replace_values.push_back(v);
1130     }
1131     rewriter.replaceOp(fused_batch_norm, replace_values);
1132     return success();
1133   };
1134 };
1135 
1136 #include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
1137 
1138 // Returns success if all the operations in the `op`'s regions including `op`
1139 // itself are legal in a TFLite pipeline.
1140 LogicalResult ValidateOp(Operation *op) {
1141   bool has_illegal_ops = false;
1142   op->walk([&](Operation *op) {
1143     if (isa<TF::VariableV2Op>(op)) {
1144       has_illegal_ops = true;
1145       op->emitOpError() << "is illegal in a TFLite pipeline";
1146     }
1147   });
1148 
1149   return failure(has_illegal_ops);
1150 }
1151 
1152 // Converts a set of TF2XLA ops into pure TF ops for future legalizations as
1153 // TF2XLA ops aren't supported by later stages.
1154 LogicalResult ConvertTf2XlaOps(FuncOp func, MLIRContext *context) {
1155   ConversionTarget target(*context);
1156   target.addLegalDialect<StandardOpsDialect>();
1157   target.addLegalDialect<TF::TensorFlowDialect>();
1158   target.addLegalOp<ModuleOp>();
1159   target.addLegalOp<FuncOp>();
1160   target.addIllegalOp<TF::XlaConvOp>();
1161   target.addIllegalOp<TF::XlaGatherOp>();
1162 
1163   OwningRewritePatternList patterns(context);
1164   mhlo::PopulateLegalizeTfWithTf2XlaPatterns("XLA_CPU_JIT", patterns, context);
1165   mhlo::PopulateLegalizeTfPatterns(context, &patterns);
1166   TF::PopulateLegalizeHloToTfPatterns(&patterns, context);
1167   mhlo::GatherOp::getCanonicalizationPatterns(patterns, context);
1168 
1169   return applyPartialConversion(func, target, std::move(patterns));
1170 }
1171 
1172 // Convert rfft to rfft2d.
1173 // The transformation pattern looks like below:
1174 //
1175 //    input     fft_len
1176 //     \      /
1177 //     rfft
1178 //
1179 //     ||
1180 //     \/
1181 //
1182 //   input       fft_len
1183 //    \            /
1184 //   expand_dim    concat with [1] at the front
1185 //      \         /
1186 //     rfft_2d
1187 //       |
1188 //     squeeze
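//
// For example (illustrative): an input of shape [8, 128] with fft_length [64]
// is expanded to shape [8, 1, 128] with fft_length [1, 64]; the RFFT2D result
// of shape [8, 1, 33] is then squeezed back to [8, 33].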
1189 struct ConvertRfftToRfft2d : public RewritePattern {
1190   explicit ConvertRfftToRfft2d(MLIRContext *context)
1191       : RewritePattern(TF::RFFTOp::getOperationName(), 1, context) {}
1192 
1193   LogicalResult matchAndRewrite(Operation *op,
1194                                 PatternRewriter &rewriter) const override {
1195     auto rfft_op = dyn_cast<TF::RFFTOp>(op);
1196 
1197     auto input = rfft_op.input();
1198     auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
1199     if (!input_type) return failure();
1200     auto fft_len = rfft_op.fft_length();
1201     auto fft_len_type = fft_len.getType().dyn_cast_or_null<ShapedType>();
1202     if (!fft_len_type) return failure();
1203 
1204     auto output_type =
1205         rfft_op.getResult().getType().dyn_cast_or_null<RankedTensorType>();
1206     if (!output_type) return failure();
1207 
1208     // Expanded inputs.
1209     // Insert at -2 location.
1210     auto one_ele_type =
1211         mlir::RankedTensorType::get({1}, rewriter.getIntegerType(32));
1212     auto minus_two = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(),
1213                                                   one_ele_type, -2);
1214 
1215     SmallVector<int64_t, 4> expanded_input_shape;
1216     SmallVector<int64_t, 4> expanded_output_shape;
1217     int expanded_rank = input_type.getRank() + 1;
1218     int r = 0;
1219     for (int i = 0; i < expanded_rank; ++i) {
1220       if (i == expanded_rank - 2) {
1221         expanded_input_shape.push_back(1);
1222         expanded_output_shape.push_back(1);
1223       } else {
1224         expanded_input_shape.push_back(input_type.getDimSize(r));
1225         expanded_output_shape.push_back(output_type.getDimSize(r));
1226         r++;
1227       }
1228     }
1229 
1230     auto expaned_input_type = mlir::RankedTensorType::get(
1231         expanded_input_shape, input_type.getElementType());
1232     TF::ExpandDimsOp expanded_input = rewriter.create<TF::ExpandDimsOp>(
1233         rfft_op.getLoc(), expaned_input_type, input, minus_two->getResult());
1234 
1235     // Expanded fft_len.
1236     auto one_attr = mlir::DenseIntElementsAttr::get(one_ele_type, {1});
1237 
1238     auto one = rewriter.create<TF::ConstOp>(rfft_op.getLoc(), one_attr);
1239 
1240     auto zero = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(),
1241                                              one_ele_type, 0);
1242 
1243     auto expanded_fft_len_type =
1244         mlir::RankedTensorType::get({2}, fft_len_type.getElementType());
1245 
1246     TF::ConcatV2Op expanded_fft_len = rewriter.create<TF::ConcatV2Op>(
1247         rfft_op.getLoc(), expanded_fft_len_type,
1248         SmallVector<Value, 2>({one.getResult(), fft_len}), zero->getResult());
1249 
1250     // Insert the rfft_2d.
1251     auto rfft2d_out_type = mlir::RankedTensorType::get(
1252         expanded_output_shape, output_type.getElementType());
1253     TF::RFFT2DOp rfft2d = rewriter.create<TF::RFFT2DOp>(
1254         rfft_op.getLoc(), rfft2d_out_type, expanded_input.getResult(),
1255         expanded_fft_len.getResult());
1256 
1257     // Insert the squeeze op.
1258     auto squeeze_dim = rewriter.getI64ArrayAttr({-2});
1259     TF::SqueezeOp squeeze = rewriter.create<TF::SqueezeOp>(
1260         rfft_op.getLoc(), output_type, rfft2d.getResult(), squeeze_dim);
1261 
1262     rewriter.replaceOp(op, squeeze.getResult());
1263 
1264     return success();
1265   }
1266 };
1267 
1268 void PrepareTFPass::runOnFunction() {
1269   MLIRContext *ctx = &getContext();
1270   OwningRewritePatternList patterns(ctx);
1271   OwningRewritePatternList phase_2_patterns(ctx);
1272   auto func = getFunction();
1273 
1274   // Check for ops that are illegal in a TFLite pipeline (e.g. training-only ops), since
1275   // PrepareTFPass is the very first TFLite pass in the pipeline.
1276   // TODO(jingpu): It might be better to split this check into its own pass
1277   // to make things more modular.
1278   if (failed(ValidateOp(func))) {
1279     func.emitError() << "tfl-prepare-tf pass failed.";
1280     signalPassFailure();
1281     return;
1282   }
1283 
1284   if (failed(ConvertTf2XlaOps(func, ctx))) {
1285     signalPassFailure();
1286     return;
1287   }
1288 
1289   // This pattern will try to identify and optimize for dilated convolution.
1290   // e.g. Patterns like "SpaceToBatchND -> Conv2D -> BatchToSpaceND" will be
1291   // replaced with a single Conv op with dilation parameter.
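  // For example (illustrative): SpaceToBatchND with block_shape [2, 2],
  // followed by a stride-1 Conv2D and a matching BatchToSpaceND, folds into a
  // single Conv2D with dilations [1, 2, 2, 1].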
1292   patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>, FusedBatchNormV3Pat,
1293                   ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(ctx);
1294 
1295   TFL::populateWithGenerated(patterns);
1296   // TODO(karimnosseir): Split to separate pass probably after
1297   // deciding on long term plan for this optimization.
1298   // This will allow optimizing any TF_Mul->TF_Conv in the graph,
1299   // including those expanded from FusedBatchNorm. We need to do this
1300   // before converting TF_Conv to TFL_Conv.
1301   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
1302 
1303   // Remove the wrapper of the tf.FakeQuant* ops and also insert the
1304   // tfl.quantize and tfl.dequantize to preserve the quantization parameters.
1305   // This is done after the first round of optimization to make sure all the
1306   // min/max operands of the tf.FakeQuant* are constants to be matched. The
1307   // following round of optimization will fold the unwrapped
1308   // tf.FakeQuant* ops with the weight constants.
1309   if (failed(ConvertFakeQuantOps(func, ctx))) {
1310     signalPassFailure();
1311     return;
1312   }
1313 
1314   // Load the generated patterns again, so the new quantization pass-through
1315   // will be applied.
1316   TFL::populateWithGenerated(phase_2_patterns);
1317   if (unfold_batch_matmul_) {
1318     phase_2_patterns.insert<TF::ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
1319                             TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>,
1320                             TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV3Op>>(
1321         ctx);
1322   }
1323   phase_2_patterns.insert<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo,
1324                           ConvertTFStridedSlice, ConvertRfftToRfft2d>(ctx);
1325   phase_2_patterns.insert<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative>(
1326       ctx, allow_bf16_and_f16_type_legalization_);
1327 
1328   (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
1329 }
1330 
1331 }  // namespace
1332 
1333 // Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
1334 std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
1335     bool unfold_batch_matmul, bool allow_bf16_type_legalization) {
1336   return std::make_unique<PrepareTFPass>(unfold_batch_matmul,
1337                                          allow_bf16_type_legalization);
1338 }
1339 
1340 static PassRegistration<PrepareTFPass> pass;
1341 
1342 }  // namespace TFL
1343 }  // namespace mlir
1344