1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This transformation pass prepares for legalization to the TFLite dialect by
// converting operations in the TensorFlow dialect into operations that can be
// legalized to the TensorFlow Lite dialect with simple replacements. The newly
// created operations are in the TensorFlow dialect if the operation can be
// represented using a TensorFlow op; otherwise, a TensorFlow Lite dialect op
// is used. For example, TFLite's Conv2D uses the OHWI data format for filters,
// which cannot be expressed with TensorFlow ops because TensorFlow requires
// filters in the HWIO data format.
//
// The motivation for preparing for TFLite legalization before the actual
// legalization is to exploit constant-folding opportunities in any newly
// created ops by leveraging the constant-folding support for TensorFlow ops.
// This way TFLite can be used purely as a serialization format and, as
// required by the TFLite team, optimizations do not need access to the TFLite
// runtime.
31
32 #include <climits>
33 #include <cstdint>
34
35 #include "absl/memory/memory.h"
36 #include "absl/numeric/bits.h"
37 #include "llvm/ADT/ArrayRef.h"
38 #include "llvm/ADT/STLExtras.h"
39 #include "llvm/ADT/StringSwitch.h"
40 #include "llvm/Support/Casting.h"
41 #include "llvm/Support/Debug.h"
42 #include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project
43 #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
44 #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
45 #include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
46 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
47 #include "mlir/IR/Attributes.h" // from @llvm-project
48 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
49 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
50 #include "mlir/IR/MLIRContext.h" // from @llvm-project
51 #include "mlir/IR/Operation.h" // from @llvm-project
52 #include "mlir/Pass/Pass.h" // from @llvm-project
53 #include "mlir/Support/LLVM.h" // from @llvm-project
54 #include "mlir/Support/LogicalResult.h" // from @llvm-project
55 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
56 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
57 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
58 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
59 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
60 #include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
61 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
62 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
63 #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
64 #include "tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h"
65 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
66 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
67 #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
68 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
69 #include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
70 #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"
71 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
72
73 #define DEBUG_TYPE "tf-tfl-legalization"
74
75 namespace mlir {
76 namespace TFL {
77 //===----------------------------------------------------------------------===//
78 // The actual PrepareTF Pass.
79 //
80 // TODO(hinsu): Add and use TensorFlow dialect ops for the ops created in this
81 // pass.
82 namespace {
83
84 // Prepare TF operations in functions for subsequent legalization.
85 class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
86 public:
87 PrepareTFPass() = default;
PrepareTFPass(const PrepareTFPass &) {}
explicit PrepareTFPass(bool unfold_batch_matmul,
90 bool allow_bf16_and_f16_type_legalization) {
91 unfold_batch_matmul_ = unfold_batch_matmul;
92 allow_bf16_and_f16_type_legalization_ =
93 allow_bf16_and_f16_type_legalization;
94 }
95
StringRef getArgument() const final {
97 // This is the argument used to refer to the pass in
98 // the textual format (on the commandline for example).
99 return "tfl-prepare-tf";
100 }
StringRef getDescription() const final {
102 // This is a brief description of the pass.
103 return "Prepare TF for legalization to TensorFlow Lite dialect";
104 }
105
106 void runOnFunction() override;
107
void getDependentDialects(DialectRegistry &registry) const override {
109 registry.insert<mhlo::MhloDialect, quant::QuantizationDialect,
110 TFL::TensorFlowLiteDialect>();
111 }
112
113 private:
114 Option<bool> unfold_batch_matmul_{
115 *this, "tfl-unfold-batch-matmul",
116 llvm::cl::desc("Unfold BatchMatMul into individual MatMul ops."),
117 llvm::cl::init(true)};
118
119 Option<bool> allow_bf16_and_f16_type_legalization_{
120 *this, "tfl-allow-bf16-and-f16-type-legalization",
121 llvm::cl::desc("Allow bf16 type legalization."), llvm::cl::init(false)};
122 };
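// For reference, in a textual pass pipeline the pass and its options above can
// be selected with syntax along the lines of
//   tfl-prepare-tf{tfl-unfold-batch-matmul=true}
// (illustrative only; the exact driver invocation depends on the tool).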
123
124 // Transient state for preserving data from match to rewrite
125 struct ConvertTFConvOpMatchState {
126 IntegerAttr dilation_height_factor;
127 IntegerAttr dilation_width_factor;
128 StringAttr padding;
129 IntegerAttr stride_height;
130 IntegerAttr stride_width;
131 };
132
133 // Templated class for declaring a converter from some TensorFlow convolution
134 // op into its counterpart in TensorFlow Lite.
135 //
// The `ConcreteType` deriving from this template must provide the following
// method for constructing the TensorFlow Lite op:
138 //
139 // TFL::[op] createTFLOp(ConvertTFConvOpMatchState *state,
140 // PatternRewriter &rewriter, Location loc,
141 // Type result_type, Value input,
142 // Value filter, Value bias) const;
143 //
// It must also provide the following method for getting the bias tensor's dimension:
145 //
146 // int64_t getBiasDim(ArrayRef<int64_t> filterShape) const;
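//
// For instance, in the concrete converters below, ConvertTFConv2D returns the
// last filter dimension (out_channels) as the bias size and transposes the
// filter from HWIO to OHWI, while ConvertTFDepthwiseConv2dNative returns
// in_channels * channel_multiplier and reshapes the filter instead.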
147 template <typename ConcreteType, typename TFConvOpType>
148 class ConvertTFConvOp : public RewritePattern {
149 public:
ConvertTFConvOp(MLIRContext *context,
151 bool allow_bf16_and_f16_type_legalization)
152 : RewritePattern(TFConvOpType::getOperationName(), 1, context),
153 intAttrOne(Builder(context).getI32IntegerAttr(1)),
154 allow_bf16_and_f16_type_legalization_(
155 allow_bf16_and_f16_type_legalization) {}
156
LogicalResult matchAndRewrite(Operation *op,
158 PatternRewriter &rewriter) const override {
159 // Assumes TensorFlow convolution op is already verified to be
160 // in valid form.
161
162 // Match a TFConvOpType under the following conditions:
163 // * The 'T' attribute must exist and be of value DT_FLOAT.
164 // * The 'data_format' attribute must exist and be of value "NHWC".
// * The 'strides' attribute must exist and be of the form [1, X, Y, 1].
// * The 'dilations' attribute is optional, but it must be of the form
//   [1, X, Y, 1] if it exists.
168
169 TFConvOpType tf_op = cast<TFConvOpType>(op);
170
171 if (!TFTypeIsFloat32Tensor(tf_op.input()) &&
172 !(allow_bf16_and_f16_type_legalization_ &&
173 TFTypeIsBFloat16OrHalfTensor(tf_op.input())))
174 return failure();
175
176 if (!TFDataFormatIsNHWC(op)) return failure();
177
178 IntegerAttr height, width;
179 if (!TFIntListIs1XY1(op, "strides", &height, &width)) return failure();
180
181 ConvertTFConvOpMatchState state;
182 state.stride_height = height;
183 state.stride_width = width;
184
185 if (TFIntListIs1XY1(op, "dilations", &height, &width)) {
186 state.dilation_height_factor = height;
187 state.dilation_width_factor = width;
188 } else {
189 // If the 'dilations' attribute is missing, we use the default value (1)
190 // for both dilation height and width factor.
191 state.dilation_height_factor = intAttrOne;
192 state.dilation_width_factor = intAttrOne;
193 }
194
195 if (!TFPaddingIsSameOrValid(op, &state.padding)) return failure();
196
197 // Additionally, we require the filter operand to be of 4-D tensor type so
198 // that we can extract info from the shape (e.g., for constructing bias
199 // tensor, for setting depth_multiplier attribute, etc.).
200 auto filter = tf_op.filter();
201 auto filter_type = filter.getType().template dyn_cast<RankedTensorType>();
202 if (!filter_type || filter_type.getRank() != 4 ||
203 !filter_type.hasStaticShape())
204 return failure();
205
// The TensorFlow convolution op only has two inputs, while the TFLite one has
// three, with the bias vector marked as optional. However, TOCO has a
// dedicated pass, EnsureBiasVectors, to create default bias vectors for all
// those missing. So we model the TFLite convolution op as requiring three
// inputs to achieve the legalization task of EnsureBiasVectors. This
// requires the filter tensor to have a static shape.
212
213 // TODO(antiagainst): also handle the case of tf.Add(tf.[op], <bias>)
214
215 // Get a splat zero tensor with the expected dimension for the bias tensor
216 auto elem_type = filter_type.getElementType();
217 auto bias_dim = static_cast<const ConcreteType *>(this)->getBiasDim(
218 filter_type.getShape());
219 auto bias_type = RankedTensorType::get({bias_dim}, elem_type);
220 auto bias_attr = rewriter.getZeroAttr(bias_type);
221 auto bias =
222 rewriter.create<TF::ConstOp>(op->getLoc(), bias_type, bias_attr);
223
224 auto conv_op = static_cast<const ConcreteType *>(this)->createTFLOp(
225 &state, rewriter, op->getLoc(), tf_op.getType(), tf_op.input(), filter,
226 bias);
227
228 rewriter.replaceOp(op, conv_op.getResult());
229 return success();
230 }
231
232 const IntegerAttr intAttrOne;
233
234 private:
235 bool allow_bf16_and_f16_type_legalization_;
236 };
237
238 class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
239 public:
240 using BaseType = ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp>;
241
ConvertTFConv2D(MLIRContext *context, bool allow_bf16_type_legalization)
243 : BaseType(context, allow_bf16_type_legalization) {}
244
int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
246 return filterShape.back();
247 }
248
TFL::Conv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
250 PatternRewriter &rewriter, Location loc,
251 Type result_type, Value input, Value filter,
252 Value bias) const {
253 filter = legalizeFilter(rewriter, loc, filter);
254 return rewriter.create<TFL::Conv2DOp>(
255 loc, result_type, input, filter, bias,
256 /*dilation_h_factor=*/state->dilation_height_factor,
257 /*dilation_w_factor=*/state->dilation_width_factor,
258 /*fused_activation_function=*/rewriter.getStringAttr("NONE"),
259 /*padding=*/state->padding,
260 /*stride_h=*/state->stride_height,
261 /*stride_w=*/state->stride_width);
262 }
263
264 private:
// Legalizes the given filter by converting it from the TensorFlow filter data
// format HWIO to the TFLite Conv2D op filter data format OHWI and returns the
// Value for the converted filter. Requires that the match method has already
// verified that the filter is a 4-D RankedTensorType.
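// Illustrative example (hypothetical shapes): a filter of type
// tensor<3x3x8x16xf32> (HWIO) becomes tensor<16x3x3x8xf32> (OHWI) via a
// tf.Transpose with permutation [3, 0, 1, 2].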
Value legalizeFilter(PatternRewriter &rewriter, Location loc,
270 Value filter) const {
271 // Create a constant op for HWIO to OHWI transpose permutation.
272 SmallVector<int, 4> perm = {3, 0, 1, 2};
273 auto perm_type = RankedTensorType::get({static_cast<int>(perm.size())},
274 rewriter.getIntegerType(32));
275 auto perm_attr =
276 DenseElementsAttr::get(perm_type, llvm::makeArrayRef<int>(perm));
277 auto perm_op = rewriter.create<TF::ConstOp>(loc, perm_type, perm_attr);
278
279 // Create tensor type for the transpose result.
280 auto filter_type = filter.getType().cast<RankedTensorType>();
281 auto result_shape =
282 llvm::to_vector<4>(llvm::map_range(perm, [filter_type](int64_t dim) {
283 return filter_type.getDimSize(dim);
284 }));
285 auto elem_type = filter_type.getElementType();
286 auto result_type = RankedTensorType::get(result_shape, elem_type);
287
288 return rewriter.create<TF::TransposeOp>(loc, result_type, filter, perm_op);
289 }
290 };
291
292 class ConvertTFDepthwiseConv2dNative
293 : public ConvertTFConvOp<ConvertTFDepthwiseConv2dNative,
294 TF::DepthwiseConv2dNativeOp> {
295 public:
296 using BaseType = ConvertTFConvOp<ConvertTFDepthwiseConv2dNative,
297 TF::DepthwiseConv2dNativeOp>;
298
ConvertTFDepthwiseConv2dNative(MLIRContext *context,
300 bool allow_bf16_type_legalization)
301 : BaseType(context, allow_bf16_type_legalization) {}
302
int64_t getBiasDim(ArrayRef<int64_t> filterShape) const {
304 return filterShape[2] * filterShape[3];
305 }
306
TFL::DepthwiseConv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
308 PatternRewriter &rewriter, Location loc,
309 Type result_type, Value input,
310 Value filter, Value bias) const {
311 // Compared to tfl.conv_2d, tfl.depthwise_conv_2d has an additional
312 // 'depth_multiplier' attribute. However, tf.DepthwiseConv2dNative does not
313 // have a corresponding 'depth_multiplier' attribute; the multiplier is the
314 // fourth dimension in the 4-D filter tensor. We query the multiplier from
315 // tf.DepthwiseConv2dNative and set it as the attribute value accordingly.
316 auto multiplier = filter.getType().cast<RankedTensorType>().getDimSize(3);
317
318 filter = legalizeFilter(rewriter, loc, filter);
319 return rewriter.create<TFL::DepthwiseConv2DOp>(
320 loc, result_type, input, filter, bias,
321 /*dilation_h_factor=*/state->dilation_height_factor,
322 /*dilation_w_factor=*/state->dilation_width_factor,
323 /*fused_activation_function=*/rewriter.getStringAttr("NONE"),
324 /*padding=*/state->padding,
325 /*stride_h=*/state->stride_height,
326 /*stride_w=*/state->stride_width,
327 /*depth_multiplier=*/rewriter.getI32IntegerAttr(multiplier));
328 }
329
330 private:
331 /// Legalize the given filter by converting it from TensorFlow filter data
332 /// format to TFLite DepthwiseConv2D op filter data format and return Value
333 /// for the converted filter. TensorFlow filter data format is
334 /// [filter_height, filter_width, in_channels, channel_multiplier] and TFLite
335 /// filter data format is [1, filter_height, filter_width, out_channels].
/// Requires that the match method has already verified that the filter is a
/// 4-D RankedTensorType.
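/// Illustrative example (hypothetical shapes): a filter of type
/// tensor<3x3x8x2xf32> ([filter_height, filter_width, in_channels,
/// channel_multiplier]) is reshaped to tensor<1x3x3x16xf32>, since
/// out_channels = 8 * 2 = 16.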
Value legalizeFilter(PatternRewriter &rewriter, Location loc,
339 Value filter) const {
340 auto filter_type = filter.getType().cast<RankedTensorType>();
341 auto filterShape = filter_type.getShape();
342 SmallVector<int64_t, 4> result_shape = {1, filterShape[0], filterShape[1],
343 filterShape[2] * filterShape[3]};
344 auto elem_type = filter_type.getElementType();
345 auto result_type = RankedTensorType::get(result_shape, elem_type);
// The TensorFlow Lite `Reshape` op currently only supports int32 shape tensors.
347 auto shape_type = RankedTensorType::get({4}, rewriter.getIntegerType(32));
348 SmallVector<Attribute, 4> result_shape_data(4);
349 for (int i = 0; i < 4; ++i) {
350 result_shape_data[i] =
351 rewriter.getI32IntegerAttr(static_cast<int32_t>(result_shape[i]));
352 }
353 auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data);
354 auto shape = rewriter.create<TF::ConstOp>(loc, shape_type, shape_attr);
355
356 return rewriter.create<TF::ReshapeOp>(loc, result_type, filter, shape);
357 }
358 };
359
// StridedSlice can have complicated attributes like begin_mask, end_mask,
// ellipsis_mask, new_axis_mask, and shrink_axis_mask. These masks complicate
// the strided_slice computation logic; we can simplify it by inserting a
// reshape op to pad the inputs so the strided_slice is easier to handle.
//
// So the graph is transformed like below:
// original_input -> strided_slice -> output
// (transforms)
// original_input -> reshape -> strided_slice -> output
//
// And the new shape is computed based on the masks.
//
// An example for new_axis_mask: say new_axis_mask is 9, which in binary is
// [1 0 0 1]; that means we are inserting two new axes at dims 0 and 3, so if
// the original shape is [2, 3], we reshape it into [1, 2, 3, 1].
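//
// Continuing that example, RewriteNewAxisMask below also ORs new_axis_mask
// into begin_mask and end_mask so the inserted unit dims take their full
// range, and clears any shrink_axis_mask bits that overlap new_axis_mask.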
376 struct ConvertTFStridedSlice : public RewritePattern {
explicit ConvertTFStridedSlice(MLIRContext *context)
378 : RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {}
379
LogicalResult RewriteNewAxisMask(Operation *op,
381 PatternRewriter &rewriter) const {
382 TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
383 uint64_t new_axis_mask = strided_slice_op.new_axis_mask();
384
385 if (strided_slice_op.ellipsis_mask() != 0) {
386 // Ellipsis mask should have been lowered-away prior to invoking this
387 // function.
388 op->emitError() << "encountered a logical error";
389 return failure();
390 }
391
392 // Insert a new reshape op.
393 Value original_input = strided_slice_op.input();
394 RankedTensorType original_input_type =
395 original_input.getType().dyn_cast<RankedTensorType>();
396 if (!original_input_type) {
397 return failure();
398 }
399
400 const ArrayRef<int64_t> &original_input_shape =
401 original_input_type.getShape();
402 SmallVector<int64_t, 4> revised_shape;
403 int index = 0;
404 const int original_input_rank = original_input_shape.size();
405 while (index < original_input_rank || new_axis_mask) {
406 if (new_axis_mask & 1) {
407 revised_shape.emplace_back(1);
408 } else {
409 revised_shape.emplace_back(original_input_shape[index++]);
410 }
411 new_axis_mask >>= 1;
412 }
413
414 if (failed(TF::VerifyShapeOfReshapeOp(revised_shape))) return failure();
415
416 const int dim_size = revised_shape.size();
417 Location loc = strided_slice_op.getLoc();
418 auto shape_type =
419 RankedTensorType::get({dim_size}, rewriter.getIntegerType(32));
420 SmallVector<Attribute, 4> result_shape_data(dim_size);
421 for (int i = 0; i < dim_size; ++i) {
422 result_shape_data[i] =
423 rewriter.getI32IntegerAttr(static_cast<int32_t>(revised_shape[i]));
424 }
425
426 auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data);
427 auto shape = rewriter.create<ConstantOp>(loc, shape_type, shape_attr);
428 auto revised_output_type = RankedTensorType::get(
429 revised_shape, original_input_type.getElementType());
430 TF::ReshapeOp reshape = rewriter.create<TF::ReshapeOp>(
431 loc, revised_output_type, original_input, shape);
432
433 // Replace the original strided_slice.
434 uint64_t revised_begin_mask = strided_slice_op.begin_mask();
435 uint64_t revised_end_mask = strided_slice_op.end_mask();
436 // Since we expand the dims, we need to apply them to the begin_mask &
437 // end_mask.
438 revised_begin_mask |= strided_slice_op.new_axis_mask();
439 revised_end_mask |= strided_slice_op.new_axis_mask();
440
441 // Enforce operator precedence.
442 uint64_t revised_shrink_axis_mask =
443 strided_slice_op.shrink_axis_mask() & ~strided_slice_op.new_axis_mask();
444
445 auto attribute_type = rewriter.getIntegerType(64);
446 rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
447 op, strided_slice_op.getType(), reshape, strided_slice_op.begin(),
448 strided_slice_op.end(), strided_slice_op.strides(),
449 rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
450 rewriter.getIntegerAttr(attribute_type, revised_end_mask),
451 rewriter.getIntegerAttr(attribute_type,
452 strided_slice_op.ellipsis_mask()),
453 rewriter.getI64IntegerAttr(0),
454 rewriter.getIntegerAttr(attribute_type, revised_shrink_axis_mask));
455 return success();
456 }
457
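// Rewrites a strided_slice with a non-zero ellipsis_mask by materializing the
// dims covered by the ellipsis: each such dim gets begin/end/strides padded
// with 0/0/1 and its begin_mask/end_mask bits set, so the resulting op needs
// no ellipsis_mask. For example (hypothetical shapes), with an input of rank 4
// and begin/end/strides of length 3, the ellipsis expands into
// 4 - 3 + 1 = 2 explicit dims (plus one per new axis, if any).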
LogicalResult RewriteEllipsisMask(Operation *op,
459 PatternRewriter &rewriter) const {
460 TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
461
462 uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask();
463 uint64_t shrink_axis_mask = strided_slice_op.shrink_axis_mask();
464 uint64_t new_axis_mask = strided_slice_op.new_axis_mask();
465
466 // Enforce operator precedence.
467 shrink_axis_mask &= ~ellipsis_mask;
468 new_axis_mask &= ~ellipsis_mask;
469
470 DenseIntElementsAttr begin_dense_elem_attr;
471 Value begin = strided_slice_op.begin();
472 auto begin_ranked_attr_type = begin.getType().dyn_cast<RankedTensorType>();
473 if (!begin_ranked_attr_type ||
474 !matchPattern(begin, m_Constant(&begin_dense_elem_attr))) {
475 return failure();
476 }
477
478 DenseIntElementsAttr end_dense_elem_attr;
479 Value end = strided_slice_op.end();
480 auto end_ranked_attr_type = end.getType().dyn_cast<RankedTensorType>();
481 if (!end_ranked_attr_type ||
482 !matchPattern(end, m_Constant(&end_dense_elem_attr))) {
483 return failure();
484 }
485
486 DenseIntElementsAttr stride_dense_elem_attr;
487 Value stride = strided_slice_op.strides();
488 auto stride_ranked_attr_type =
489 stride.getType().dyn_cast<RankedTensorType>();
490 if (!stride_ranked_attr_type ||
491 !matchPattern(stride, m_Constant(&stride_dense_elem_attr))) {
492 return failure();
493 }
494
495 Value input = strided_slice_op.input();
496 RankedTensorType input_type = input.getType().dyn_cast<RankedTensorType>();
497 if (!input_type) {
498 return failure();
499 }
500 const ArrayRef<int64_t> input_shape = input_type.getShape();
501
502 const int input_size = input_shape.size();
503
504 RankedTensorType begin_type = begin.getType().cast<RankedTensorType>();
505 const ArrayRef<int64_t> begin_shape = begin_type.getShape();
506 const int begin_dim = begin_shape.size();
507
508 if (begin_dim != 1) return failure();
509
// The ellipsis fill might exceed the current output shape because we are
// also taking into account any to-be-inserted new axes.
512 const int ellipsis_filled_dim_size =
513 input_size - begin_shape[0] + 1 + absl::popcount(new_axis_mask);
514
515 int64_t begin_mask = strided_slice_op.begin_mask();
516 int64_t end_mask = strided_slice_op.end_mask();
517 int64_t revised_begin_mask = 0;
518 int64_t revised_end_mask = 0;
519 int64_t revised_shrink_axis_mask = 0;
520 int64_t revised_new_axis_mask = 0;
521
522 SmallVector<int32_t, 4> padded_begin;
523 SmallVector<int32_t, 4> padded_end;
524 SmallVector<int32_t, 4> padded_stride;
525
526 // Before the ellipsis.
527 int index = 0;
528 int new_index = 0;
529 while (((ellipsis_mask >> index) & 1) == 0) {
530 padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
531 padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
532 padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
533 if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
534 if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
535 if ((shrink_axis_mask >> index) & 1)
536 revised_shrink_axis_mask |= (1 << new_index);
537
538 if ((new_axis_mask >> index) & 1)
539 revised_new_axis_mask |= (1 << new_index);
540
541 ++index;
542 ++new_index;
543 }
544
545 // Ellipsis.
546 for (; new_index < index + ellipsis_filled_dim_size; ++new_index) {
547 revised_begin_mask |= (1 << new_index);
548 revised_end_mask |= (1 << new_index);
549
550 // Mimic the begin/end/strides mask behavior.
551 padded_begin.push_back(0);
552 padded_end.push_back(0);
553 padded_stride.push_back(1);
554 }
555
556 // Account for ellipsis mask.
557 ++index;
558
559 // After the ellipsis.
560 for (; index < begin_shape[0];) {
561 padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
562 padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
563 padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
564
565 if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
566 if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
567 if ((shrink_axis_mask >> index) & 1)
568 revised_shrink_axis_mask |= (1 << new_index);
569 if ((new_axis_mask >> index) & 1)
570 revised_new_axis_mask |= (1 << new_index);
571
572 ++index;
573 ++new_index;
574 }
575
576 auto attribute_type = rewriter.getIntegerType(64);
577
578 int full_dim_count = padded_begin.size();
579 auto type =
580 RankedTensorType::get({full_dim_count}, rewriter.getIntegerType(32));
581
582 auto begin_attr = DenseElementsAttr::get<int32_t>(type, padded_begin);
583 auto begin_op = rewriter.create<ConstantOp>(op->getLoc(), type, begin_attr);
584 auto end_attr = DenseElementsAttr::get<int32_t>(type, padded_end);
585 auto end_op = rewriter.create<ConstantOp>(op->getLoc(), type, end_attr);
586 auto stride_attr = DenseElementsAttr::get<int32_t>(type, padded_stride);
587 auto stride_op =
588 rewriter.create<ConstantOp>(op->getLoc(), type, stride_attr);
589
590 rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
591 op, strided_slice_op.getType(), input, begin_op.getResult(),
592 end_op.getResult(), stride_op.getResult(),
593 rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
594 rewriter.getIntegerAttr(attribute_type, revised_end_mask),
595 /*ellipsis_mask=*/rewriter.getI64IntegerAttr(0),
596 rewriter.getIntegerAttr(attribute_type, revised_new_axis_mask),
597 rewriter.getIntegerAttr(attribute_type, revised_shrink_axis_mask));
598
599 return success();
600 }
601
void PadStridedSliceAttributeArray(DenseIntElementsAttr dense_elem_attr,
603 SmallVectorImpl<int32_t> &val,
604 SmallVectorImpl<int32_t> &padded_val,
605 ArrayRef<int32_t> padding_val,
606 int *mask) const {
607 for (const auto &idx : dense_elem_attr.getIntValues()) {
608 val.push_back(idx.getSExtValue());
609 padded_val.push_back(idx.getSExtValue());
610 }
611 int attr_dim_count = val.size();
612 int full_dim_count = padding_val.size();
613 for (int i = attr_dim_count; i < full_dim_count; ++i) {
614 padded_val.push_back(padding_val[i]);
615 if (mask) *mask |= 1 << i;
616 }
617 }
618
LogicalResult matchAndRewrite(Operation *op,
620 PatternRewriter &rewriter) const override {
621 TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
622
623 // Handle ellipsis mask.
624 if (strided_slice_op.ellipsis_mask() != 0) {
625 return RewriteEllipsisMask(strided_slice_op, rewriter);
626 }
627
628 // Handle new axis mask.
629 if (strided_slice_op.new_axis_mask() != 0) {
630 return RewriteNewAxisMask(strided_slice_op, rewriter);
631 }
632
633 auto ranked_input_type =
634 strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
635 if (!ranked_input_type) {
636 return failure();
637 }
638
639 auto begin_attr = strided_slice_op.begin();
640 auto end_attr = strided_slice_op.end();
641 auto strides_attr = strided_slice_op.strides();
642
643 auto begin_attr_type = begin_attr.getType().dyn_cast<RankedTensorType>();
644 auto end_attr_type = end_attr.getType().dyn_cast<RankedTensorType>();
645 auto strides_attr_type =
646 strides_attr.getType().dyn_cast<RankedTensorType>();
647
648 DenseIntElementsAttr begin_elem_attr;
649 DenseIntElementsAttr end_elem_attr;
650 DenseIntElementsAttr strides_elem_attr;
651
652 if (!begin_attr_type ||
653 !matchPattern(begin_attr, m_Constant(&begin_elem_attr))) {
654 return failure();
655 }
656 if (!end_attr_type || !matchPattern(end_attr, m_Constant(&end_elem_attr))) {
657 return failure();
658 }
659 if (!strides_attr_type ||
660 !matchPattern(strides_attr, m_Constant(&strides_elem_attr))) {
661 return failure();
662 }
663
664 SmallVector<int32_t, 4> begin, end, strides;
665 SmallVector<int32_t, 4> padded_begin, padded_end, padded_strides;
666
667 int num_input_dims = ranked_input_type.getRank();
668 SmallVector<int32_t, 4> padding_begin(num_input_dims, 0);
669 auto input_shape = ranked_input_type.getShape();
670 SmallVector<int32_t, 4> padding_end(input_shape.begin(), input_shape.end());
671 SmallVector<int32_t, 4> padding_strides(num_input_dims, 1);
672
673 int begin_mask = strided_slice_op.begin_mask();
674 int end_mask = strided_slice_op.end_mask();
675
676 PadStridedSliceAttributeArray(begin_elem_attr, begin, padded_begin,
677 padding_begin, &begin_mask);
678 PadStridedSliceAttributeArray(end_elem_attr, end, padded_end, padding_end,
679 &end_mask);
680 PadStridedSliceAttributeArray(strides_elem_attr, strides, padded_strides,
681 padding_strides, nullptr);
682
683 if (begin == padded_begin && end == padded_end &&
684 strides == padded_strides &&
685 begin_mask == strided_slice_op.begin_mask() &&
686 end_mask == strided_slice_op.end_mask()) {
687 return failure();
688 }
689
690 auto begin_end_type =
691 RankedTensorType::get({num_input_dims}, rewriter.getIntegerType(32));
692 auto new_begin_attr = rewriter.create<ConstantOp>(
693 op->getLoc(), begin_end_type,
694 DenseElementsAttr::get<int32_t>(begin_end_type, padded_begin));
695 auto new_end_attr = rewriter.create<ConstantOp>(
696 op->getLoc(), begin_end_type,
697 DenseElementsAttr::get<int32_t>(begin_end_type, padded_end));
698 auto strides_type =
699 RankedTensorType::get({static_cast<long>(padded_strides.size())},
700 rewriter.getIntegerType(32));
701 auto new_strides_attr = rewriter.create<ConstantOp>(
702 op->getLoc(), strides_type,
703 DenseElementsAttr::get<int32_t>(strides_type, padded_strides));
704
705 auto attribute_type = rewriter.getIntegerType(64);
706 rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
707 op, strided_slice_op.output().getType(), strided_slice_op.input(),
708 new_begin_attr, new_end_attr, new_strides_attr,
709 rewriter.getIntegerAttr(attribute_type, begin_mask),
710 rewriter.getIntegerAttr(attribute_type, end_mask),
711 rewriter.getIntegerAttr(attribute_type,
712 strided_slice_op.ellipsis_mask()),
713 rewriter.getIntegerAttr(attribute_type,
714 strided_slice_op.new_axis_mask()),
715 rewriter.getIntegerAttr(attribute_type,
716 strided_slice_op.shrink_axis_mask()));
717
718 return success();
719 }
720 };
721
722 struct ConvertTFBroadcastTo : public RewritePattern {
explicit ConvertTFBroadcastTo(MLIRContext *context)
724 : RewritePattern(TF::BroadcastToOp::getOperationName(), 1, context) {}
725
LogicalResult matchAndRewrite(Operation *op,
727 PatternRewriter &rewriter) const override {
728 auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
729 auto input_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
730 auto output_type = tf_broadcast_to_op.output().getType().cast<ShapedType>();
731 auto shape_type = tf_broadcast_to_op.shape().getType().cast<ShapedType>();
732 Type element_type = input_type.getElementType();
733
// Allow lowering when low-dimensional inputs are given and the element type
// is BF16, F32, I32, or I16.
736 if (!((output_type.hasRank() && output_type.getRank() <= 4) ||
737 (shape_type.hasStaticShape() && shape_type.getRank() == 1 &&
738 shape_type.getDimSize(0) <= 4)))
739 return failure();
740
741 if (!(element_type.isa<BFloat16Type, Float32Type>() ||
742 element_type.isInteger(32) || element_type.isInteger(16)))
743 return failure();
744
745 auto status_or_const_op =
746 CreateConstOpWithSingleValue(&rewriter, op->getLoc(), input_type, 1);
747 if (!status_or_const_op.ok()) {
748 return failure();
749 }
750
751 auto tf_fill_op = rewriter.create<TF::FillOp>(
752 op->getLoc(), output_type, tf_broadcast_to_op.shape(),
753 status_or_const_op.ValueOrDie());
754
755 auto mul_op = rewriter.create<TF::MulOp>(
756 op->getLoc(), output_type, tf_broadcast_to_op.input(), tf_fill_op);
757 rewriter.replaceOp(op, mul_op.getResult());
758 return success();
759 }
760 };
761
// The pattern below is equivalent to the DRR rule further down. The checks
// depend on generated values, so we can't add checks on intermediate values;
// ideally we should find equivalent checks that guarantee the resultant ops
// are valid. The extra conditions are the broadcasting conditions.
//
// The pattern lowers FusedBatchNormV3 to arithmetic ops.
769 // Specifically, performs the following calculation:
770 //
771 // (x - mean) * scale / sqrt(variance + epsilon) + offset
772 //
// Let multiplier = scale / sqrt(variance + epsilon); computing
// (x - mean) * scale / sqrt(variance + epsilon) + offset
// is then equivalent to computing
// (x * multiplier) + (offset - mean * multiplier).
778 //
779 // def : Pattern<
780 // (TF_FusedBatchNormV3Op:$root
781 // $x, $scale, $offset, $mean, $variance,
782 // F32Attr:$epsilon, $exponential_avg_factor,
783 // $data_format, FalseBoolAttr:$is_training),
784 // [(TF_AddOp
785 // (TF_MulOp
786 // $x,
787 // (TF_MulOp:$multiplier
788 // $scale,
789 // (TF_RsqrtOp
790 // (TF_AddOp $variance,
791 // (TF_ConstOp $epsilon))))),
792 // (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))),
793 // // We already guaranteed that the last five results have no use so it does
794 // // not matter what value we provide here for replacement.
795 // /*batch_mean=*/(replaceWithValue $x),
796 // /*batch_variance=*/(replaceWithValue $x),
797 // /*reserve_space_1=*/(replaceWithValue $x),
798 // /*reserve_space_2=*/(replaceWithValue $x),
799 // /*reserve_space_3=*/(replaceWithValue $x)],
800 // [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
801 // (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
802 // (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>;
803 //
// When is_training is set to true, the given variance and mean are not used.
// In the above calculation, they are replaced by new values. These new mean
// and variance are calculated as follows:
// new_mean = mean(x, axis=[0, 1, 2])
// new_variance = mean(squared_difference(x, new_mean), axis=[0, 1, 2])
//
// The DRR rule for the is_training == true case is as follows:
811 // def : Pattern<
812 // (TF_FusedBatchNormV3Op:$root
813 // $x, $scale, $offset, $mean, $variance,
814 // F32Attr:$epsilon, $exponential_avg_factor,
815 // $data_format, FalseBoolAttr:$is_training),
816 // [(TF_AddOp
817 // (TF_MulOp
818 // $x,
819 // (TF_MulOp:$multiplier
820 // $scale,
821 // (TF_RsqrtOp
822 // (TF_AddOp
823 // (TF_MeanOp
824 // (TF_SquaredDifferenceOp $x, $new_mean),
825 // (TF_ConstOp [0,1,2])),
826 // (TF_ConstOp $epsilon))))),
827 // (TF_SubOp
828 // $offset,
829 // (TF_MulOp
830 // (TF_MeanOp $x, (TF_ConstOp [0,1,2])),
831 // $multiplier))),
832 // // We already guaranteed that the last five results have no use so it does
833 // // not matter what value we provide here for replacement.
834 // /*batch_mean=*/(replaceWithValue $x),
835 // /*batch_variance=*/(replaceWithValue $x),
836 // /*reserve_space_1=*/(replaceWithValue $x),
837 // /*reserve_space_2=*/(replaceWithValue $x),
838 // /*reserve_space_3=*/(replaceWithValue $x)],
839 // [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
840 // (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
841 // (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>;
842
843 struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
explicit FusedBatchNormV3Pat(::mlir::MLIRContext *context)
845 : ::mlir::RewritePattern(
846 "tf.FusedBatchNormV3", 1, context,
847 {"tf.Add", "tf.Const", "tf.Mul", "tf.Rsqrt", "tf.Sub"}) {}
848
::mlir::LogicalResult matchAndRewrite(
850 ::mlir::Operation *fused_batch_norm,
851 ::mlir::PatternRewriter &rewriter) const override {
852 // Variables for capturing values and attributes used for creating ops
853 Operation::operand_range mean(fused_batch_norm->getOperands());
854 ::mlir::FloatAttr exponential_avg_factor;
855 ::mlir::TF::FusedBatchNormV3Op root;
856 Operation::operand_range offset(fused_batch_norm->getOperands());
857 Operation::operand_range x(fused_batch_norm->getOperands());
858 Operation::operand_range scale(fused_batch_norm->getOperands());
859 Operation::operand_range variance(fused_batch_norm->getOperands());
860 ::mlir::FloatAttr epsilon;
861 ::mlir::BoolAttr is_training;
862
863 // Match
864 auto fused_batch_norm_op =
865 dyn_cast_or_null<::mlir::TF::FusedBatchNormV3Op>(fused_batch_norm);
866 root = fused_batch_norm_op;
867 x = fused_batch_norm_op.getODSOperands(0);
868 scale = fused_batch_norm_op.getODSOperands(1);
869 offset = fused_batch_norm_op.getODSOperands(2);
870 mean = fused_batch_norm_op.getODSOperands(3);
871 variance = fused_batch_norm_op.getODSOperands(4);
872
873 ::mlir::Value mean_value = (*mean.begin());
874 ::mlir::Value variance_value = (*variance.begin());
875
876 if (!TFTypeIsFloat32Tensor(fused_batch_norm_op.x())) return failure();
877
878 {
879 epsilon =
880 fused_batch_norm_op->getAttrOfType<::mlir::FloatAttr>("epsilon");
881 if (!epsilon)
882 epsilon = rewriter.getFloatAttr(rewriter.getF32Type(), 0.0001f);
883
884 if (!(((epsilon.isa<::mlir::FloatAttr>())) &&
885 ((epsilon.cast<::mlir::FloatAttr>().getType().isF32())))) {
886 return rewriter.notifyMatchFailure(
887 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
888 diag << "op 'tf.FusedBatchNormV3' attribute 'epsilon' failed to "
889 "satisfy constraint: 32-bit float attribute";
890 });
891 }
892 }
893 {
894 exponential_avg_factor =
895 fused_batch_norm_op->getAttrOfType<::mlir::FloatAttr>(
896 "exponential_avg_factor");
897 if (!exponential_avg_factor)
898 exponential_avg_factor =
899 rewriter.getFloatAttr(rewriter.getF32Type(), 1.0f);
900 }
901 if (!TFDataFormatIsNHWC(fused_batch_norm_op) &&
902 !TFDataFormatIsNDHWC(fused_batch_norm_op))
903 return failure();
904
905 if (!(((*root.getODSResults(1).begin()).use_empty()))) {
906 return rewriter.notifyMatchFailure(
907 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
908 diag << "entities '' failed to satisfy constraint: has no use";
909 });
910 }
911
912 if (!(((*root.getODSResults(2).begin()).use_empty()))) {
913 return rewriter.notifyMatchFailure(
914 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
915 diag << "entities '' failed to satisfy constraint: has no use";
916 });
917 }
918
919 if (!(((*root.getODSResults(3).begin()).use_empty()))) {
920 return rewriter.notifyMatchFailure(
921 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
922 diag << "entities '' failed to satisfy constraint: has no use";
923 });
924 }
925
926 if (!(((*root.getODSResults(4).begin()).use_empty()))) {
927 return rewriter.notifyMatchFailure(
928 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
929 diag << "entities '' failed to satisfy constraint: has no use";
930 });
931 }
932
933 if (!(((*root.getODSResults(5).begin()).use_empty()))) {
934 return rewriter.notifyMatchFailure(
935 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
936 diag << "entities '' failed to satisfy constraint: has no use";
937 });
938 }
939
940 is_training =
941 fused_batch_norm_op->getAttrOfType<::mlir::BoolAttr>("is_training");
942 auto odsLoc = rewriter.getFusedLoc({fused_batch_norm->getLoc()});
943
944 // We need to make sure input and output shapes are compatible.
945 int64_t last_dim = -1;
946 {
947 auto is_last_dim_compatible = [](const Value &v, int64_t &last_dim) {
948 auto v_type = v.getType().dyn_cast_or_null<RankedTensorType>();
949 if (!v_type) return true;
950 int64_t v_last_dim = v_type.getDimSize(v_type.getRank() - 1);
951 if (v_last_dim == -1) return true;
952 if (last_dim != -1 && v_last_dim != last_dim) return false;
953 last_dim = v_last_dim;
954 return true;
955 };
956
957 if (!is_last_dim_compatible(*x.begin(), last_dim) ||
958 !is_last_dim_compatible(*scale.begin(), last_dim) ||
959 !is_last_dim_compatible(*offset.begin(), last_dim)) {
960 return rewriter.notifyMatchFailure(
961 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
962 diag << "Shapes of scale and offset should be 1D and "
963 "compatible with x";
964 });
965 }
966
967 if (!is_training.getValue()) {
968 if (!is_last_dim_compatible(mean_value, last_dim) ||
969 !is_last_dim_compatible(variance_value, last_dim)) {
970 return rewriter.notifyMatchFailure(
971 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
972 diag << "Shapes of mean and variance should be 1D and "
973 "compatible with x";
974 });
975 }
976 }
977
978 // Check if output shape and input shape are compatible.
979 auto x_type = (*x.begin()).getType();
980 auto y_type = (*root.getODSResults(0).begin()).getType();
981 if (!OpTrait::util::getBroadcastedType(x_type, y_type)) {
982 return rewriter.notifyMatchFailure(
983 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
984 diag << "Shapes of x and the first output should be compatible";
985 });
986 }
987 }
988
// For training, mean and variance are calculated from the input values.
990 if (is_training.getValue()) {
991 auto input_type = fused_batch_norm_op.x()
992 .getType()
993 .dyn_cast_or_null<RankedTensorType>();
994 if (!input_type || input_type.getRank() != 4) {
995 return rewriter.notifyMatchFailure(
996 fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
997 diag << "op 'tf.FusedBatchNormV3' that has 'is_training' equals "
998 "True is only supported with input of rank 4";
999 });
1000 }
1001
1002 ::mlir::TF::ConstOp reduce_dim_op;
1003 {
1004 auto reduce_dim_type =
1005 ::mlir::RankedTensorType::get({3}, rewriter.getIntegerType(32));
1006 ::mlir::SmallVector<int32_t, 3> reduce_dim_values = {0, 1, 2};
1007 reduce_dim_op = rewriter.create<TF::ConstOp>(
1008 odsLoc, ::mlir::DenseIntElementsAttr::get(reduce_dim_type,
1009 reduce_dim_values));
1010 }
1011
1012 auto new_mean_type =
1013 ::mlir::RankedTensorType::get({last_dim}, rewriter.getF32Type());
1014 ::mlir::TF::MeanOp mean_op_1;
1015 {
1016 ::mlir::Value x_value = (*x.begin());
1017 mean_op_1 = rewriter.create<TF::MeanOp>(
1018 odsLoc, new_mean_type, x_value, reduce_dim_op,
1019 /*keep_dims=*/rewriter.getBoolAttr(false));
1020 }
1021
1022 ::mlir::TF::SquaredDifferenceOp square_diff_op;
1023 {
1024 ::mlir::Value tblgen_value_0 = (*x.begin());
1025 ::mlir::Value tblgen_value_1 = (*mean_op_1.getODSResults(0).begin());
1026 // If x has shape of [b, h, w, c], the result of mean_op_1 will have
1027 // shape of [c]. Therefore, their shapes are always compatible.
1028 square_diff_op = rewriter.create<::mlir::TF::SquaredDifferenceOp>(
1029 odsLoc, tblgen_value_0, tblgen_value_1);
1030 }
1031
1032 ::mlir::TF::MeanOp mean_op_2;
1033 {
1034 ::mlir::Value input_value = (*square_diff_op.getODSResults(0).begin());
1035 mean_op_2 = rewriter.create<TF::MeanOp>(
1036 odsLoc, new_mean_type, input_value, reduce_dim_op,
1037 /*keep_dims=*/rewriter.getBoolAttr(false));
1038 }
1039
1040 mean_value = (*mean_op_1.getODSResults(0).begin());
1041 variance_value = (*mean_op_2.getODSResults(0).begin());
} // End of the is_training == true case.
1043
1044 ::llvm::SmallVector<::mlir::Value, 4> replace_values;
1045 ::mlir::TF::ConstOp epsilon_const_op;
1046 {
1047 epsilon_const_op =
1048 rewriter.create<::mlir::TF::ConstOp>(odsLoc,
1049 /*value=*/epsilon);
1050 }
1051 ::mlir::TF::AddOp add_op_1;
1052 {
1053 ::mlir::Value epsilon_value =
1054 (*epsilon_const_op.getODSResults(0).begin());
// Adding a scalar constant; no need to check broadcastability.
1056 add_op_1 = rewriter.create<::mlir::TF::AddOp>(odsLoc,
1057 /*x=*/variance_value,
1058 /*y=*/epsilon_value);
1059 }
1060 ::mlir::TF::RsqrtOp rsqrt_op;
1061 {
1062 ::mlir::SmallVector<::mlir::Value, 4> tblgen_values;
1063 ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs;
1064 tblgen_values.push_back((*add_op_1.getODSResults(0).begin()));
1065 rsqrt_op = rewriter.create<::mlir::TF::RsqrtOp>(odsLoc, tblgen_values,
1066 tblgen_attrs);
1067 }
1068 ::mlir::TF::MulOp multiplier;
1069 {
1070 ::mlir::Value tblgen_value_0 = (*scale.begin());
1071 ::mlir::Value tblgen_value_1 = (*rsqrt_op.getODSResults(0).begin());
1072 multiplier = rewriter.create<::mlir::TF::MulOp>(odsLoc,
1073 /*x=*/tblgen_value_0,
1074 /*y=*/tblgen_value_1);
1075 }
1076 ::mlir::TF::MulOp mul_op_1;
1077 {
1078 ::mlir::Value tblgen_value_0 = (*x.begin());
1079 ::mlir::Value tblgen_value_1 = (*multiplier.getODSResults(0).begin());
1080 mul_op_1 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
1081 /*x=*/tblgen_value_0,
1082 /*y=*/tblgen_value_1);
1083 }
1084 ::mlir::TF::MulOp mul_op_2;
1085 {
1086 ::mlir::Value multiplier_value = (*multiplier.getODSResults(0).begin());
1087 mul_op_2 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
1088 /*x=*/mean_value,
1089 /*y=*/multiplier_value);
1090 }
1091 ::mlir::TF::SubOp sub_op;
1092 {
1093 ::mlir::Value tblgen_value_0 = (*offset.begin());
1094 ::mlir::Value tblgen_value_1 = (*mul_op_2.getODSResults(0).begin());
1095 sub_op = rewriter.create<::mlir::TF::SubOp>(odsLoc,
1096 /*x=*/tblgen_value_0,
1097 /*y=*/tblgen_value_1);
1098 }
1099 ::mlir::TF::AddOp add_op_2;
1100 {
1101 ::mlir::SmallVector<::mlir::Value, 4> tblgen_values;
1102 ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs;
1103 tblgen_values.push_back((*mul_op_1.getODSResults(0).begin()));
1104 tblgen_values.push_back((*sub_op.getODSResults(0).begin()));
1105 ::mlir::SmallVector<::mlir::Type, 4> tblgen_types;
1106 for (auto v : fused_batch_norm_op.getODSResults(0)) {
1107 tblgen_types.push_back(v.getType());
1108 }
1109 add_op_2 = rewriter.create<::mlir::TF::AddOp>(
1110 odsLoc, tblgen_types, tblgen_values, tblgen_attrs);
1111 }
1112 for (auto v :
1113 ::llvm::SmallVector<::mlir::Value, 4>{add_op_2.getODSResults(0)}) {
1114 replace_values.push_back(v);
1115 }
1116 for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1117 replace_values.push_back(v);
1118 }
1119 for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1120 replace_values.push_back(v);
1121 }
1122 for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1123 replace_values.push_back(v);
1124 }
1125 for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1126 replace_values.push_back(v);
1127 }
1128 for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
1129 replace_values.push_back(v);
1130 }
1131 rewriter.replaceOp(fused_batch_norm, replace_values);
1132 return success();
1133 };
1134 };
1135
1136 #include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
1137
1138 // Returns success if all the operations in the `op`'s regions including `op`
1139 // itself are legal in a TFLite pipeline.
LogicalResult ValidateOp(Operation *op) {
1141 bool has_illegal_ops = false;
1142 op->walk([&](Operation *op) {
1143 if (isa<TF::VariableV2Op>(op)) {
1144 has_illegal_ops = true;
1145 op->emitOpError() << "is illegal in a TFLite pipeline";
1146 }
1147 });
1148
1149 return failure(has_illegal_ops);
1150 }
1151
1152 // Converts a set of TF2XLA ops into pure TF ops for future legalizations as
1153 // TF2XLA ops aren't supported by later stages.
LogicalResult ConvertTf2XlaOps(FuncOp func, MLIRContext *context) {
1155 ConversionTarget target(*context);
1156 target.addLegalDialect<StandardOpsDialect>();
1157 target.addLegalDialect<TF::TensorFlowDialect>();
1158 target.addLegalOp<ModuleOp>();
1159 target.addLegalOp<FuncOp>();
1160 target.addIllegalOp<TF::XlaConvOp>();
1161 target.addIllegalOp<TF::XlaGatherOp>();
1162
1163 OwningRewritePatternList patterns(context);
1164 mhlo::PopulateLegalizeTfWithTf2XlaPatterns("XLA_CPU_JIT", patterns, context);
1165 mhlo::PopulateLegalizeTfPatterns(context, &patterns);
1166 TF::PopulateLegalizeHloToTfPatterns(&patterns, context);
1167 mhlo::GatherOp::getCanonicalizationPatterns(patterns, context);
1168
1169 return applyPartialConversion(func, target, std::move(patterns));
1170 }
1171
// Converts rfft to rfft2d.
// The transformation pattern looks like the following:
1174 //
1175 // input fft_len
1176 // \ /
1177 // rfft
1178 //
1179 // ||
1180 // \/
1181 //
1182 // input fft_len
1183 // \ /
1184 // expand_dim concat with [1] at the front
1185 // \ /
1186 // rfft_2d
1187 // |
1188 // squeeze
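//
// Illustrative example (hypothetical shapes): input tensor<8x128xf32> with
// fft_length [64] is expanded to tensor<8x1x128xf32> with fft_length [1, 64];
// RFFT2D then produces tensor<8x1x33xcomplex<f32>>, and squeezing dim -2
// yields the original rfft result type tensor<8x33xcomplex<f32>>.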
1189 struct ConvertRfftToRfft2d : public RewritePattern {
explicit ConvertRfftToRfft2d(MLIRContext *context)
1191 : RewritePattern(TF::RFFTOp::getOperationName(), 1, context) {}
1192
LogicalResult matchAndRewrite(Operation *op,
1194 PatternRewriter &rewriter) const override {
1195 auto rfft_op = dyn_cast<TF::RFFTOp>(op);
1196
1197 auto input = rfft_op.input();
1198 auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
1199 if (!input_type) return failure();
1200 auto fft_len = rfft_op.fft_length();
1201 auto fft_len_type = fft_len.getType().dyn_cast_or_null<ShapedType>();
1202 if (!fft_len_type) return failure();
1203
1204 auto output_type =
1205 rfft_op.getResult().getType().dyn_cast_or_null<RankedTensorType>();
1206 if (!output_type) return failure();
1207
1208 // Expanded inputs.
1209 // Insert at -2 location.
1210 auto one_ele_type =
1211 mlir::RankedTensorType::get({1}, rewriter.getIntegerType(32));
1212 auto minus_two = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(),
1213 one_ele_type, -2);
1214
1215 SmallVector<int64_t, 4> expanded_input_shape;
1216 SmallVector<int64_t, 4> expanded_output_shape;
1217 int expanded_rank = input_type.getRank() + 1;
1218 int r = 0;
1219 for (int i = 0; i < expanded_rank; ++i) {
1220 if (i == expanded_rank - 2) {
1221 expanded_input_shape.push_back(1);
1222 expanded_output_shape.push_back(1);
1223 } else {
1224 expanded_input_shape.push_back(input_type.getDimSize(r));
1225 expanded_output_shape.push_back(output_type.getDimSize(r));
1226 r++;
1227 }
1228 }
1229
auto expanded_input_type = mlir::RankedTensorType::get(
expanded_input_shape, input_type.getElementType());
TF::ExpandDimsOp expanded_input = rewriter.create<TF::ExpandDimsOp>(
rfft_op.getLoc(), expanded_input_type, input, minus_two->getResult());
1234
1235 // Expanded fft_len.
1236 auto one_attr = mlir::DenseIntElementsAttr::get(one_ele_type, {1});
1237
1238 auto one = rewriter.create<TF::ConstOp>(rfft_op.getLoc(), one_attr);
1239
1240 auto zero = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(),
1241 one_ele_type, 0);
1242
1243 auto expanded_fft_len_type =
1244 mlir::RankedTensorType::get({2}, fft_len_type.getElementType());
1245
1246 TF::ConcatV2Op expanded_fft_len = rewriter.create<TF::ConcatV2Op>(
1247 rfft_op.getLoc(), expanded_fft_len_type,
1248 SmallVector<Value, 2>({one.getResult(), fft_len}), zero->getResult());
1249
1250 // Insert the rfft_2d.
1251 auto rfft2d_out_type = mlir::RankedTensorType::get(
1252 expanded_output_shape, output_type.getElementType());
1253 TF::RFFT2DOp rfft2d = rewriter.create<TF::RFFT2DOp>(
1254 rfft_op.getLoc(), rfft2d_out_type, expanded_input.getResult(),
1255 expanded_fft_len.getResult());
1256
1257 // Insert the squeeze op.
1258 auto squeeze_dim = rewriter.getI64ArrayAttr({-2});
1259 TF::SqueezeOp squeeze = rewriter.create<TF::SqueezeOp>(
1260 rfft_op.getLoc(), output_type, rfft2d.getResult(), squeeze_dim);
1261
1262 rewriter.replaceOp(op, squeeze.getResult());
1263
1264 return success();
1265 }
1266 };
1267
void PrepareTFPass::runOnFunction() {
1269 MLIRContext *ctx = &getContext();
1270 OwningRewritePatternList patterns(ctx);
1271 OwningRewritePatternList phase_2_patterns(ctx);
1272 auto func = getFunction();
1273
// Check for ops that are illegal in a TFLite pipeline (e.g. training-only
// ops), since PrepareTFPass is the very first TFLite pass in the pipeline.
1276 // TODO(jingpu): It might be better to split this check into its own pass
1277 // to make things more modular.
1278 if (failed(ValidateOp(func))) {
1279 func.emitError() << "tfl-prepare-tf pass failed.";
1280 signalPassFailure();
1281 return;
1282 }
1283
1284 if (failed(ConvertTf2XlaOps(func, ctx))) {
1285 signalPassFailure();
1286 return;
1287 }
1288
1289 // This pattern will try to identify and optimize for dilated convolution.
1290 // e.g. Patterns like "SpaceToBatchND -> Conv2D -> BatchToSpaceND" will be
1291 // replaced with a single Conv op with dilation parameter.
1292 patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>, FusedBatchNormV3Pat,
1293 ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(ctx);
1294
1295 TFL::populateWithGenerated(patterns);
1296 // TODO(karimnosseir): Split to separate pass probably after
1297 // deciding on long term plan for this optimization.
// This will allow optimizing any TF_Mul->TF_Conv in the graph, including
// those expanded from FusedBatchNorm. We need to do this before converting
// TF_Conv to TFL_Conv.
1301 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
1302
1303 // Remove the wrapper of the tf.FakeQuant* ops and also insert the
1304 // tfl.quantize and tfl.dequantize to preserve the quantization parameters.
1305 // This is done after the first round of optimization to make sure all the
1306 // min/max operands of the tf.FakeQuant* are constants to be matched. The
// following round of optimization will fold the unwrapped
1308 // tf.FakeQuant* ops with the weight constants.
1309 if (failed(ConvertFakeQuantOps(func, ctx))) {
1310 signalPassFailure();
1311 return;
1312 }
1313
1314 // Load the generated pattern again, so new quantization pass-through
1315 // will be applied.
1316 TFL::populateWithGenerated(phase_2_patterns);
1317 if (unfold_batch_matmul_) {
1318 phase_2_patterns.insert<TF::ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
1319 TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>,
1320 TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV3Op>>(
1321 ctx);
1322 }
1323 phase_2_patterns.insert<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo,
1324 ConvertTFStridedSlice, ConvertRfftToRfft2d>(ctx);
1325 phase_2_patterns.insert<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative>(
1326 ctx, allow_bf16_and_f16_type_legalization_);
1327
1328 (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
1329 }
1330
1331 } // namespace
1332
1333 // Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
1335 bool unfold_batch_matmul, bool allow_bf16_type_legalization) {
1336 return std::make_unique<PrepareTFPass>(unfold_batch_matmul,
1337 allow_bf16_type_legalization);
1338 }
1339
1340 static PassRegistration<PrepareTFPass> pass;
1341
1342 } // namespace TFL
1343 } // namespace mlir
1344