1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This transformation pass takes operations in the TensorFlow Lite dialect and
17 // optimizes them into other TensorFlow Lite dialect operations.
18
19 #include <algorithm>
20 #include <climits>
21 #include <cstdint>
22 #include <functional>
23 #include <iterator>
24 #include <map>
25 #include <numeric>
26 #include <utility>
27
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/ArrayRef.h"
31 #include "llvm/ADT/None.h"
32 #include "llvm/ADT/Optional.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/SmallSet.h"
35 #include "llvm/ADT/SmallVector.h"
36 #include "llvm/ADT/StringRef.h"
37 #include "llvm/ADT/StringSwitch.h"
38 #include "llvm/Support/Casting.h"
39 #include "llvm/Support/raw_ostream.h"
40 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
41 #include "mlir/IR/Attributes.h" // from @llvm-project
42 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
43 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
44 #include "mlir/IR/MLIRContext.h" // from @llvm-project
45 #include "mlir/IR/Matchers.h" // from @llvm-project
46 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
47 #include "mlir/IR/Value.h" // from @llvm-project
48 #include "mlir/Pass/Pass.h" // from @llvm-project
49 #include "mlir/Support/LLVM.h" // from @llvm-project
50 #include "mlir/Support/LogicalResult.h" // from @llvm-project
51 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
52 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
53 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
54 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
55 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
56 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
57 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
58 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
59
60 namespace mlir {
61 namespace TFL {
62
63 //===----------------------------------------------------------------------===//
64 // The actual Optimize Pass.
65 namespace {
66 constexpr char kRelu[] = "RELU";
67 constexpr char kRelu6[] = "RELU6";
68 constexpr char kRelu1[] = "RELU_N1_TO_1";
69
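// Returns true if reducing `axis` of `sq_op` covers the last dimension (the
// axis is the last dimension or -1) or reduces over all dimensions; this is the
// reduction shape expected when matching an L2 normalization pattern.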
70 bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
71 if (axis.getNumElements() == 0) {
72 return false;
73 }
74 if (sq_op.getType().cast<ShapedType>().getRank() - 1 ==
75 *axis.getValues<int>().begin() ||
76 *axis.getValues<int>().begin() == -1) {
77 return true;
78 }
79 if (sq_op.getType().cast<ShapedType>().getRank() != axis.getNumElements()) {
80 return false;
81 }
82 auto shape = sq_op.getType().cast<ShapedType>();
83 SmallVector<int, 4> elems{axis.getValues<int>().begin(),
84 axis.getValues<int>().end()};
85 for (int i = 0; i < shape.getRank(); ++i) {
86 if (i != elems[i]) return false;
87 }
88 return true;
89 }
90
91 using ::llvm::cast;
92
93 // Optimize TFLite operations in functions.
94 class OptimizePass : public PassWrapper<OptimizePass, FunctionPass> {
95 public:
96 OptimizePass() = default;
97 OptimizePass(const OptimizePass &) {}
98 explicit OptimizePass(bool enable_canonicalization) {
99 enable_canonicalization_ = enable_canonicalization;
100 }
101
102 StringRef getArgument() const final {
103 // This is the argument used to refer to the pass in
104 // the textual format (on the commandline for example).
105 return "tfl-optimize";
106 }
107 StringRef getDescription() const final {
108 // This is a brief description of the pass.
109 return "Optimize within the TensorFlow Lite dialect";
110 }
111
112 void runOnFunction() override;
113
114 private:
115 Option<bool> enable_canonicalization_{
116 *this, "enable-canonicalization",
117 llvm::cl::desc("Enable canonicalization during optimization pass."),
118 llvm::cl::init(false)};
119 };
120
121 // Returns whether the given type `a` is broadcast-compatible with `b`.
122 bool IsBroadcastableElementsAttrAndType(Type a, Type b) {
123 return OpTrait::util::getBroadcastedType(a, b) != Type();
124 }
125
126 // Returns whether the resultant type of any broadcastable operation with
127 // operands `a` and `b` matches `expected_output`. Returns false if `a` is not
128 // broadcast-compatible with `b`.
129 bool OperandsBroadcastToOutputType(Type a, Type b, Type expected_output) {
130 Type output_element_type =
131 expected_output.cast<ShapedType>().getElementType();
132 Type broadcasted_type =
133 OpTrait::util::getBroadcastedType(a, b, output_element_type);
134 return broadcasted_type != Type() && broadcasted_type == expected_output;
135 }
136
137 // Returns true if the dimensions of `type1` are the same as the trailing
138 // dimensions of `type2`. This is more restricted than broadcastable.
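// For example, tensor<1x128xf32> is a tail of tensor<8x1x128xf32>, while
// tensor<2x128xf32> is not.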
139 bool IsTailOfShape(Type type1, Type type2) {
140 auto tail_type = type1.dyn_cast<ShapedType>();
141 auto full_type = type2.dyn_cast<ShapedType>();
142 if (!tail_type || !full_type || !tail_type.hasRank() ||
143 !full_type.hasRank() || tail_type.getRank() > full_type.getRank())
144 return false;
145 auto i1 = tail_type.getShape().rbegin(), e1 = tail_type.getShape().rend();
146 auto i2 = full_type.getShape().rbegin();
147 return std::equal(i1, e1, i2);
148 }
149
150 bool CanFuseConvOrDepthwiseConvShapes(const ArrayRef<int64_t> filter_shape,
151 const ArrayRef<int64_t> elements_shape,
152 bool is_depthwise) {
153 // Make sure the val tensor has a shape where all dimensions are 1 except the
154 // last one.
155 // Also, the val tensor must be of rank 0 (scalar), 1, or 4.
156 const auto elements_rank = elements_shape.size();
157 for (int i = 0; i < static_cast<int>(elements_shape.size()) - 1; ++i) {
158 if (elements_shape[i] != 1) return false;
159 }
160 if (elements_rank != 1 && elements_rank != 0 && elements_rank != 4) {
161 return false;
162 }
163 auto elements_depth = elements_shape.empty() ? 1 : elements_shape.back();
164 // If the elements depth equals 1 (i.e., scalar or tensor with 1 element), then
165 // we can let the binary op broadcast the elements.
166 if (elements_depth == 1) {
167 return true;
168 }
169
170 // In TFLite, Conv2D uses the OHWI filter format and DepthwiseConv uses 1HWO.
171 // For conv:
172 //   Check that the first (output-channel) dimension of the filter equals the elements depth.
173 // For depthwise conv:
174 //   Check that the last dimension of the filter equals the elements depth.
175 if (filter_shape.empty() ||
176 (is_depthwise ? filter_shape.back() != elements_depth
177 : filter_shape[0] != elements_depth))
178 return false;
179 return true;
180 }
181
182 bool CanFuseConvOrDepthwiseConv(Value filter, Attribute val,
183 bool is_depthwise) {
184 const auto elements = val.dyn_cast<DenseElementsAttr>();
185 if (!elements) {
186 return false;
187 }
188 const auto elements_shape = elements.getType().getShape();
189 const auto filter_shape = filter.getType().cast<ShapedType>().getShape();
190 return CanFuseConvOrDepthwiseConvShapes(filter_shape, elements_shape,
191 is_depthwise);
192 }
193
194 bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val,
195 bool is_depthwise) {
196 if (const auto elements = val.dyn_cast<DenseElementsAttr>()) {
197 if (const auto filter_elements = filter.dyn_cast<DenseElementsAttr>()) {
198 return CanFuseConvOrDepthwiseConvShapes(
199 filter_elements.getType().getShape(), elements.getType().getShape(),
200 is_depthwise);
201 }
202 }
203 return false;
204 }
205
206 // Returns true if we can eliminate the GatherNdOp or ScatterNdOp. When the
207 // values of `indices` are from 0 to n-1, the output tensor is identical to
208 // `params`.
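// For illustration: with `params` of shape [3, 2] and `indices` equal to
// [[0], [1], [2]], the gather/scatter just reproduces `params`, so the op can
// be removed.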
209 bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params,
210 DenseIntElementsAttr indices) {
211 auto params_type = params.getType().dyn_cast<RankedTensorType>();
212 auto indices_type = indices.getType().dyn_cast<RankedTensorType>();
213 // Checks that the shape of `params` is [n, ...] and the shape of `indices` is
214 // [n, 1]. 2D `indices` selects along the first dimension of `params`; as long
215 // as the indices enumerate that dimension in order, the output is identical to the input.
216 if (!params_type || !indices_type || indices_type.getRank() != 2 ||
217 indices_type.getDimSize(0) != params_type.getDimSize(0) ||
218 indices_type.getDimSize(1) != 1)
219 return false;
220
221 // Checks the value in `indices` is from 0 to n-1.
222 int cur_value = 0;
223 for (const auto &v : indices.getValues<APInt>()) {
224 if (v.getSExtValue() != cur_value) return false;
225 ++cur_value;
226 }
227
228 return true;
229 }
230
231 // Returns true if we can eliminate the SliceOp. When the values of `begin` are
232 // all 0s and `size[i]` is equal to either -1 or `input.shape[i]`
233 // for each dim i, the output tensor is identical to `input`.
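// For illustration: slicing an input of shape [4, 8] with `begin` = [0, 0] and
// `size` = [4, -1] (or [-1, 8]) returns the whole input, so the SliceOp can be
// removed.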
234 bool CanOptimizeIdentitySliceOp(Value input, Attribute begin, Attribute size) {
235 // Checks if `begin` and `size` are i32 or i64.
236 auto begin_attr = begin.dyn_cast<DenseIntElementsAttr>();
237 auto size_attr = size.dyn_cast<DenseIntElementsAttr>();
238 if (!begin_attr || !size_attr) {
239 return false;
240 }
241
242 auto begin_elem_ty = begin_attr.getType().getElementType();
243 if (!begin_elem_ty.isInteger(32) && !begin_elem_ty.isInteger(64)) {
244 return false;
245 }
246 auto size_elem_ty = size_attr.getType().getElementType();
247 if (!size_elem_ty.isInteger(32) && !size_elem_ty.isInteger(64)) {
248 return false;
249 }
250
251 // Checks if `input` is ranked and its rank is equal to number of elements in
252 // `begin` and `size`.
253 auto input_ty = input.getType().cast<ShapedType>();
254 if (!input_ty.hasRank()) {
255 return false;
256 }
257
258 int64_t rank = input_ty.getRank();
259 if (rank != begin_attr.getNumElements() ||
260 rank != size_attr.getNumElements()) {
261 return false;
262 }
263
264 // Checks if `begin` is all 0s, and `size[i]` is equal to either -1 or
265 // `input.shape[i]`.
266 for (uint64_t i = 0; i < rank; ++i) {
267 if (begin_attr.getValue<APInt>({i}).getSExtValue() != 0) return false;
268 int64_t si = size_attr.getValue<APInt>({i}).getSExtValue();
269 if (si != -1 && si != input_ty.getDimSize(i)) return false;
270 }
271
272 return true;
273 }
274
275 // Expands Attribute 'a' to 4D with all dimensions 1 except one.
276 // Which dimension is expanded depends on whether 'is_depthwise' is true or false.
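// For example, a 1-D attribute with N elements is reshaped to [N, 1, 1, 1] for
// regular conv (OHWI filters) and to [1, 1, 1, N] for depthwise conv (1HWO
// filters).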
277 ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) {
278 auto elements = a.dyn_cast<DenseElementsAttr>();
279 auto shape = elements.getType().getShape();
280 if (!shape.empty()) {
281 // Checks that elements are essentially 1d.
282 assert(elements.getNumElements() == shape.back());
283 }
284 std::vector<int64_t> shape_data = {1, 1, 1, 1};
285 const int vector_length = elements.getNumElements();
286 if (is_depthwise)
287 shape_data[3] = vector_length;
288 else
289 shape_data[0] = vector_length;
290 auto new_shape =
291 RankedTensorType::get(shape_data, elements.getType().getElementType());
292 return elements.reshape(new_shape);
293 }
294
295 ElementsAttr ExpandTo4DForConv(Attribute a) {
296 return ExpandTo4DForConvImpl(a, false);
297 }
298
299 ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
300 return ExpandTo4DForConvImpl(a, true);
301 }
302
303 TypeAttr RescaleQtype(Type input, Attribute factor) {
304 return quant::RescaleQuantizedType(input, factor);
305 }
306
307 // Returns the shape of a ranked tensor as a DenseElementsAttr.
308 // Precondition: output_val's type is a ranked tensor.
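// For example, a value of type tensor<2x3x4xf32> yields
// dense<[2, 3, 4]> : tensor<3xi32>.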
309 DenseElementsAttr GetShape(Value output_val) {
310 auto output_type = output_val.getType().cast<RankedTensorType>();
311 auto shape_vector = output_type.getShape();
312 std::vector<int32_t> shape;
313 shape.reserve(shape_vector.size());
314 for (auto shape_object : shape_vector) {
315 shape.push_back(shape_object);
316 }
317 return mlir::DenseElementsAttr::get(
318 RankedTensorType::get(
319 {static_cast<int>(shape.size())},
320 mlir::IntegerType::get(output_val.getContext(), 32)),
321 llvm::makeArrayRef(shape));
322 }
323
324 static Type GetShapeStrippedType(TypeAttr type_attr) {
325 auto type = type_attr.getValue();
326 auto shaped_type = type.dyn_cast<ShapedType>();
327 if (shaped_type) {
328 return shaped_type.getElementType();
329 } else {
330 return type;
331 }
332 }
333
334 // Returns `true` if reducing `axes` in `input` with `keep_dims=true` results in
335 // the specified `shape` and `false` otherwise.
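// For example, for an input of type tensor<2x3x4xf32> and `axes` = [1],
// reducing with keep_dims=true produces shape [2, 1, 4], so that `shape`
// matches; a `shape` of [2, 4] would not.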
336 static bool ShapeMatchesReduceWithKeepAxes(Value input,
337 const mlir::Attribute &axes,
338 const mlir::Attribute &shape) {
339 RankedTensorType type = input.getType().dyn_cast_or_null<RankedTensorType>();
340 if (!type) return false;
341
342 DenseIntElementsAttr axes_attr =
343 axes.dyn_cast_or_null<DenseIntElementsAttr>();
344 DenseIntElementsAttr shape_attr =
345 shape.dyn_cast_or_null<DenseIntElementsAttr>();
346 if (!axes_attr || !shape_attr) return false;
347
348 if (shape_attr.getNumElements() != type.getRank()) return false;
349
350 llvm::SmallSet<uint64_t, 4> axes_set;
351 for (auto a : axes_attr.getIntValues()) {
352 axes_set.insert(a.getZExtValue());
353 }
354
355 auto type_shape = type.getShape();
356 for (uint64_t i = 0; i < type.getRank(); ++i) {
357 if (axes_set.contains(i)) {
358 if (shape_attr.getValue<APInt>({i}) != 1) return false;
359 } else {
360 if (shape_attr.getValue<APInt>({i}) != type_shape[i]) return false;
361 }
362 }
363 return true;
364 }
365
366 static bool FloatValueEquals(const Attribute &attr, double value) {
367 auto fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>();
368 if (!fp_attr) return false;
369
370 if (fp_attr.isSplat()) {
371 return fp_attr.getSplatValue<APFloat>().isExactlyValue(value);
372 }
373 return llvm::all_of(fp_attr.getFloatValues(), [value](const APFloat &f) {
374 return f.isExactlyValue(value);
375 });
376 }
377
378 // Returns true if the value's element type is F32.
379 bool IsF32Value(Value value) {
380 return value.getType().cast<ShapedType>().getElementType().isF32();
381 }
382
383 // Returns the number of elements in attr if it is a DenseElementsAttr, or 1
384 // otherwise, as a scalar int32 Attribute.
385 Attribute GetNumElementsOrOne(Attribute attr) {
386 const auto dense_attr = attr.dyn_cast_or_null<DenseElementsAttr>();
387 int32_t num_elements = dense_attr ? dense_attr.getNumElements() : 1;
388
389 OpBuilder builder(attr.getContext());
390
391 return DenseIntElementsAttr::get(
392 RankedTensorType::get({}, builder.getI32Type()),
393 {llvm::APInt(32, num_elements, true)});
394 }
395
396 // Returns true if attr is a DenseIntElementsAttr whose last element equals 1.
397 bool IsLastElementEqualsOne(Attribute attr) {
398 const auto ints = attr.dyn_cast_or_null<DenseIntElementsAttr>();
399 if (!ints) return false;
400 if (ints.empty()) return false;
401 const auto last_element_index = ints.getNumElements() - 1;
402 const auto iterator = ints.getIntValues().begin();
403 const APInt last_element = iterator[last_element_index];
404 return last_element == 1;
405 }
406
407 // Returns true if attr is a DenseIntElementsAttr of int32 or int64 values that
408 // forms an incrementing sequence from 0 to N-1.
409 //
410 // If such a value is used in an Equal operator, it can be replaced with OneHot.
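// For example, dense<[0, 1, 2, 3]> : tensor<4xi32> qualifies, while
// dense<[0, 2, 1]> : tensor<3xi32> or any non-1-D attribute does not.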
411 bool IsOneHotIndexAttribute(Attribute attr) {
412 const auto dense_attr = attr.dyn_cast_or_null<DenseIntElementsAttr>();
413 if (!dense_attr) {
414 return false;
415 }
416 auto index_type = dense_attr.getType();
417 const auto index_elem_bits = index_type.getElementTypeBitWidth();
418 if (index_elem_bits != 32 && index_elem_bits != 64) {
419 return false;
420 }
421 if (index_type.getRank() != 1) {
422 return false;
423 }
424 const auto elems = dense_attr.getIntValues().begin();
425 for (int i = 0; i < dense_attr.getNumElements(); ++i) {
426 if (i != elems[i]) {
427 return false;
428 }
429 }
430 return true;
431 }
432
433 // Converts an Attribute with a single value of float or integral type to an
434 // Attribute holding a single value of float type. If attr has no elements, the
435 // result is 0.0f.
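// For example, dense<5> : tensor<i32> becomes dense<5.0> : tensor<f32>; a
// float attribute is returned unchanged.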
436 Attribute ConvertSingleElementAttrToFloatAttr(Attribute attr) {
437 const auto dense_fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>();
438 if (dense_fp_attr) {
439 // Already float => return
440 return dense_fp_attr;
441 }
442
443 OpBuilder builder(attr.getContext());
444
445 const auto dense_int_attr = attr.dyn_cast<DenseIntElementsAttr>();
446 const auto int_values = dense_int_attr.getIntValues();
447 float float_val = 0.0f;
448 if (!int_values.empty()) {
449 const APInt apint_val = *int_values.begin();
450 if (dense_int_attr.getType().getElementType().isSignedInteger()) {
451 // Get the sign-extended value (=>int64) if the type is signed.
452 float_val = apint_val.getSExtValue();
453 } else {
454 // Get the zero-extended value (=>uint64) if unsigned or signless.
455 float_val = apint_val.getZExtValue();
456 }
457 }
458 return DenseFPElementsAttr::get(
459 RankedTensorType::get({}, builder.getF32Type()),
460 {llvm::APFloat(float_val)});
461 }
462
463 #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"
464
465 // Fuse Add with the preceding FullyConnected.
466 // TODO(b/136285429): Move to tablegen when variadic is supported
467 struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
468 using OpRewritePattern<TFL::AddOp>::OpRewritePattern;
469
470 LogicalResult matchAndRewrite(TFL::AddOp add_op,
471 PatternRewriter &rewriter) const override {
472 // Match Add.
473 DenseElementsAttr added_value;
474 Value constant_val = add_op.rhs();
475 if (!matchPattern(constant_val, m_Constant(&added_value))) return failure();
476
477 // Match Fully Connected.
478 auto fc_op =
479 dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs().getDefiningOp());
480 if (!fc_op) return failure();
481
482 // Check if the constant RHS is either 0D (scalar) or a 1D tensor with
483 // `{num_channels}` shape.
484 auto constant_val_type = constant_val.getType().cast<TensorType>();
485
486 // In the TFLite FullyConnected definition, bias must be a 1D tensor where
487 // the number of elements is equal to the number of channels.
488 // If it's not 1D or 0D (which can be broadcast to 1D), reject the
489 // matching.
490 bool is_scalar_rhs = false;
491 if (constant_val_type.getRank() == 0) {
492 is_scalar_rhs = true;
493 } else if (constant_val_type.getRank() != 1) {
494 return failure();
495 }
496
497 Value filter = fc_op.filter();
498 Value bias = fc_op.bias();
499 ElementsAttr bias_value;
500 const bool is_none_bias = bias.getType().isa<NoneType>();
501 if (fc_op.fused_activation_function() != "NONE") return failure();
502
503 if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value)))
504 return failure();
505
506 // Rewrite
507 if (is_none_bias) {
508 if (is_scalar_rhs) {
509 // If the `constant_val` is a scalar, we need the shape of the filter
510 // to properly broadcast the scalar to the `{num_channels}` shape.
511
512 // Get the number of channels if possible.
513 auto filter_type = filter.getType().dyn_cast<RankedTensorType>();
514 // Filter must be a `2D` tensor with `{num_channels, num_features}`
515 // shape. The following check is rejecting unknown rank (-1).
516 if (filter_type == nullptr || filter_type.getRank() != 2) {
517 return failure();
518 }
519 int num_channels = filter_type.getShape()[0];
520
521 // Create a zero tensor with shape {num_channels}; its type needs to
522 // be the same as constant_val's.
523 // This is a way to gracefully handle a scalar tensor. The Add will always
524 // be constant-folded away regardless of whether `constant_val` is a scalar
525 // or not.
526 RankedTensorType type = RankedTensorType::get(
527 {num_channels}, constant_val_type.getElementType());
528 auto attr = rewriter.getZeroAttr(type);
529 bias = rewriter.create<ConstantOp>(add_op.getLoc(), type, attr);
530 auto none_af = rewriter.getStringAttr("NONE");
531 bias =
532 rewriter.create<AddOp>(add_op.getLoc(), bias, constant_val, none_af)
533 .output();
534 } else {
535 // If there is no pre-existing bias and the `constant_val` is 1D, simply
536 // use `constant_val` as the bias.
537 bias = constant_val;
538 }
539 } else {
540 auto none_af = rewriter.getStringAttr("NONE");
541 bias =
542 rewriter.create<AddOp>(add_op.getLoc(), bias, constant_val, none_af)
543 .output();
544 }
545
546 auto fc = rewriter.create<TFL::FullyConnectedOp>(
547 FusedLoc::get(fc_op.getContext(), {fc_op.getLoc(), add_op.getLoc()}),
548 add_op.getType(),
549 /*input=*/fc_op.input(),
550 /*filter=*/filter,
551 /*bias=*/bias,
552 /*fused_activation_function=*/
553 rewriter.getStringAttr(add_op.fused_activation_function()),
554 /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
555 /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
556 rewriter.replaceOp(add_op, fc.output());
557
558 return success();
559 }
560 };
561
562 // Replace ..
563 // FC(Add(lhs, rhs), filter, bias)
564 // .. with ..
565 // FC(lhs, filter, FC(rhs, filter, bias))
566 // .. if rhs, filter, and bias are all constants.
567 // The second FC will be constant folded to a single vector.
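// The rewrite relies on the algebraic identity (a sketch, writing the FC
// matmul as x*W): FC(lhs + rhs) = (lhs + rhs)*W + bias = lhs*W + (rhs*W + bias)
// = FC(lhs, filter, FC(rhs, filter, bias)).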
568 // TODO(b/136285429): Move to tablegen when variadic is supported
569 struct FuseAddAndFullyConnected
570 : public OpRewritePattern<TFL::FullyConnectedOp> {
571 using OpRewritePattern<TFL::FullyConnectedOp>::OpRewritePattern;
572
573 LogicalResult matchAndRewrite(TFL::FullyConnectedOp fc_op,
574 PatternRewriter &rewriter) const override {
575 // This only works with default format.
576 if (fc_op.weights_format() != "DEFAULT") return failure();
577
578 // Match Add.
579 auto add_op = dyn_cast_or_null<TFL::AddOp>(fc_op.input().getDefiningOp());
580 if (!add_op) return failure();
581 if (add_op.fused_activation_function() != "NONE") return failure();
582
583 // Don't match adds where the added constant is not 1D.
584 {
585 auto addend_shape = add_op.rhs().getType().cast<ShapedType>();
586 if (!addend_shape.hasStaticShape()) return failure();
587 if (addend_shape.getShape().size() != 1) return failure();
588 }
589
590 // Calculate new bias. Generate a new FC; it will be constant folded.
591 auto old_bias = fc_op.bias();
592 if (!old_bias || old_bias.getType().isa<NoneType>()) {
593 // TODO(b/180752069): Figure out new bias' type when old bias is empty.
594 return failure();
595 }
596
597 // The FC relies on constant folding, which is implemented only for F32.
598 // Check that the types are F32.
599 {
600 if (!IsF32Value(add_op.rhs()) || !IsF32Value(fc_op.filter()) ||
601 !IsF32Value(old_bias))
602 return failure();
603 }
604
605 auto new_bias = rewriter.create<TFL::FullyConnectedOp>(
606 fc_op.getLoc(), old_bias.getType(),
607 /*input=*/add_op.rhs(),
608 /*filter=*/fc_op.filter(),
609 /*bias=*/old_bias,
610 /*fused_activation_function=*/rewriter.getStringAttr("NONE"),
611 /*weights_format=*/rewriter.getStringAttr("DEFAULT"),
612 /*keep_num_dims=*/rewriter.getBoolAttr(true));
613
614 // Create the updated FC.
615 auto new_fc = rewriter.create<TFL::FullyConnectedOp>(
616 FusedLoc::get(add_op.getContext(), {add_op.getLoc(), fc_op.getLoc()}),
617 fc_op.output().getTypes(),
618 /*input=*/add_op.lhs(),
619 /*filter=*/fc_op.filter(),
620 /*bias=*/*new_bias.output().begin(),
621 /*fused_activation_function=*/
622 rewriter.getStringAttr(fc_op.fused_activation_function()),
623 /*weights_format=*/rewriter.getStringAttr("DEFAULT"),
624 /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
625 rewriter.replaceOp(fc_op.getOperation(), new_fc.output());
626
627 return success();
628 }
629 };
630
631 // Replace ..
632 // FC(Mul(lhs, rhs), filter, bias)
633 // .. with ..
634 // FC(lhs, Mul(filter, rhs), bias)
635 // .. if rhs, filter, and bias are all constants.
636 // The generated Mul will be constant folded to a single matrix.
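// The rewrite relies on the identity (a sketch, with the 1-D rhs broadcast
// along the feature dimension): FC(lhs * rhs) = (lhs * rhs)*W + bias
// = lhs*(W * rhs) + bias, so the multiplier can be folded into the filter.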
637 struct FuseMulAndFullyConnected
638 : public OpRewritePattern<TFL::FullyConnectedOp> {
639 using OpRewritePattern<TFL::FullyConnectedOp>::OpRewritePattern;
640
641 LogicalResult matchAndRewrite(TFL::FullyConnectedOp fc_op,
642 PatternRewriter &rewriter) const override {
643 // This only works with default format.
644 if (fc_op.weights_format() != "DEFAULT") return failure();
645
646 // Match Mul.
647 auto mul_op = dyn_cast_or_null<TFL::MulOp>(fc_op.input().getDefiningOp());
648 if (!mul_op) return failure();
649 if (mul_op.fused_activation_function() != "NONE") return failure();
650
651 // Don't match muls where the multiplier constant is not 1D.
652 {
653 auto multiplier_shape = mul_op.rhs().getType().cast<ShapedType>();
654 if (!multiplier_shape.hasStaticShape()) return failure();
655 if (multiplier_shape.getShape().size() != 1) return failure();
656 }
657
658 // We rely on constant folding, implemented only for F32. Check types.
659 if (!IsF32Value(mul_op.rhs()) || !IsF32Value(fc_op.filter())) {
660 return failure();
661 }
662
663 auto location =
664 FusedLoc::get(mul_op.getContext(), {mul_op.getLoc(), fc_op.getLoc()});
665
666 auto new_filter = rewriter.create<TFL::MulOp>(
667 location,
668 /*lhs=*/fc_op.filter(),
669 /*rhs=*/mul_op.rhs(),
670 /*fused_activation_function=*/rewriter.getStringAttr("NONE"));
671 // Create the updated FC.
672 auto new_fc = rewriter.create<TFL::FullyConnectedOp>(
673 location, fc_op.output().getTypes(),
674 /*input=*/mul_op.lhs(),
675 /*filter=*/new_filter,
676 /*bias=*/fc_op.bias(),
677 /*fused_activation_function=*/
678 rewriter.getStringAttr(fc_op.fused_activation_function()),
679 /*weights_format=*/rewriter.getStringAttr("DEFAULT"),
680 /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
681 rewriter.replaceOp(fc_op.getOperation(), new_fc.output());
682
683 return success();
684 }
685 };
686
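// Fuse ReluX (Relu, Relu6, Relu1) with the preceding FullyConnected by setting
// the FC's fused_activation_function attribute, provided it is currently "NONE".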
687 // TODO(b/136285429): Move to tablegen when variadic is supported.
688 template <typename ReluXOp, char const *Act>
689 struct FuseFullyConnectedAndReluX : public OpRewritePattern<ReluXOp> {
690 using OpRewritePattern<ReluXOp>::OpRewritePattern;
691
692 LogicalResult matchAndRewrite(ReluXOp relu_op,
693 PatternRewriter &rewriter) const override {
694 Operation *input = relu_op.getOperand().getDefiningOp();
695 if (!isa_and_nonnull<FullyConnectedOp>(input)) return failure();
696 auto fully_connected_op = cast<FullyConnectedOp>(input);
697 if (fully_connected_op.fused_activation_function() != "NONE")
698 return failure();
699
700 auto new_activation_func = rewriter.getStringAttr(Act);
701 auto new_weights_format =
702 rewriter.getStringAttr(fully_connected_op.weights_format());
703 auto new_keep_num_dims =
704 rewriter.getBoolAttr(fully_connected_op.keep_num_dims());
705 auto fc = rewriter.create<FullyConnectedOp>(
706 FusedLoc::get(relu_op.getContext(),
707 {fully_connected_op.getLoc(), relu_op.getLoc()}),
708 relu_op.getType(), fully_connected_op.input(),
709 fully_connected_op.filter(), fully_connected_op.bias(),
710 new_activation_func, new_weights_format, new_keep_num_dims);
711 rewriter.replaceOp(relu_op, fc.output());
712
713 return success();
714 }
715 };
716
717 // Fuse Mul with the preceding FullyConnected.
718 // TODO(b/136285429): Move to tablegen when variadic is supported
719 struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
720 using OpRewritePattern<TFL::MulOp>::OpRewritePattern;
721
722 LogicalResult matchAndRewrite(TFL::MulOp mul_op,
723 PatternRewriter &rewriter) const override {
724 // If we are broadcasting on the lhs then don't fold the multiply as it
725 // would increase the amount of compute done by the fully connected op.
726 if (mul_op.lhs().getType() != mul_op.getType()) return failure();
727
728 // Mul.
729 DenseElementsAttr cst;
730 Value constant_val = mul_op.rhs();
731 if (!matchPattern(constant_val, m_Constant(&cst))) return failure();
732
733 // Fully Connected.
734 auto fc_op =
735 dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs().getDefiningOp());
736 if (!fc_op) return failure();
737 Value filter = fc_op.filter();
738 Value bias = fc_op.bias();
739 ElementsAttr cst_tmp;
740 if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure();
741 if (!bias.getType().isa<NoneType>() &&
742 !matchPattern(bias, m_Constant(&cst_tmp)))
743 return failure();
744 if (fc_op.fused_activation_function() != "NONE") return failure();
745
746 // Only fuse multiplier if all dimensions other than the depth dimension
747 // are equal to 1 since otherwise
748 // `matmul(x, filter) * cst != matmul(x, filter * cst)`
749 // even if `filter` and `cst` are broadcastable.
750 auto shape = cst.getType().getShape();
751 if (!IsDimensionsDegenerateExceptLastOne(shape)) return failure();
752
753 int64_t element_size = shape.empty() ? 1 : shape[shape.size() - 1];
754 // Expand and transpose the multiplier since weights are using the
755 // OHWI data format in TFLite.
756 int64_t normalized_shape[2] = {element_size, 1};
757 auto new_cst = cst.reshape(RankedTensorType::get(
758 normalized_shape, cst.getType().getElementType()));
759 Type new_type = new_cst.getType();
760 if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
761 return failure();
762 }
763
764 auto new_op =
765 rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
766 Value new_const_val = new_op.getResult();
767
768 // Rewrite. Since the folder of TFL::MulOp couldn't broadcast the operands,
769 // TF::MulOp is used to fold the constant.
770 // TODO(b/139192933): switch to the TFL constant folding
771 auto new_filter =
772 rewriter.create<TF::MulOp>(mul_op.getLoc(), filter, new_const_val).z();
773 // If bias isn't None, it needs to be multiplied as well.
774 if (!bias.getType().isa<NoneType>()) {
775 bias =
776 rewriter.create<TF::MulOp>(mul_op.getLoc(), bias, constant_val).z();
777 }
778
779 auto fc = rewriter.create<TFL::FullyConnectedOp>(
780 FusedLoc::get(fc_op.getContext(), {fc_op.getLoc(), mul_op.getLoc()}),
781 mul_op.getType(),
782 /*input=*/fc_op.input(),
783 /*filter=*/new_filter,
784 /*bias=*/bias,
785 /*fused_activation_function=*/
786 rewriter.getStringAttr(mul_op.fused_activation_function()),
787 /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
788 /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
789 rewriter.replaceOp(mul_op, fc.output());
790
791 return success();
792 }
793 };
794
795 // Fuse Mul with the preceding affine ops. This is a C++ implementation of the
796 // following TableGen pattern, which cannot derive the result type of
797 // the TFL_DequantizeOp.
798 // def : Pat<(TFL_MulOp (TFL_Conv2DOp:$conv_output $input,
799 // (TFL_DequantizeOp (TFL_QuantizeOp
800 // (ConstantOp F32ElementsAttr:$filter), $qtype)),
801 // (ConstantOp F32ElementsAttr:$bias),
802 // $h_factor, $w_factor, TFL_AF_None,
803 // $padding, $stride_h, $stride_w),
804 // (ConstantOp F32ElementsAttr:$value), $act_fn),
805 // (TFL_Conv2DOp $input,
806 // (TFL_DequantizeOp (TFL_QuantizeOp
807 // (TFL_MulOp (ConstantOp $filter),
808 // (ConstantOp (ExpandTo4DForConv $value)),
809 // TFL_AF_None),
810 // (RescaleQtype $qtype, $value))),
811 // (TFL_MulOp (ConstantOp $bias), (ConstantOp $value),
812 // TFL_AF_None),
813 // $h_factor, $w_factor, $act_fn,
814 // $padding, $stride_h, $stride_w),
815 // [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
816 // (HasOneUse $conv_output),
817 // (IsPerAxisQuantization $qtype), // per-axis quantization
818 // ]>;
819 template <typename AffineOpType>
820 struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern<TFL::MulOp> {
821 using OpRewritePattern<TFL::MulOp>::OpRewritePattern;
822
823 LogicalResult matchAndRewrite(TFL::MulOp mul_op,
824 PatternRewriter &rewriter) const override {
825 // Mul. A 1-D rhs is required for batch normalization.
826 DenseElementsAttr gamma_cst;
827 Value gamma = mul_op.rhs();
828 if (!matchPattern(gamma, m_Constant(&gamma_cst))) return failure();
829 if (gamma_cst.getType().getRank() != 1) return failure();
830
831 // Affine op
832 Operation *mul_op_lhs = mul_op.lhs().getDefiningOp();
833 auto fc_op = dyn_cast_or_null<AffineOpType>(mul_op_lhs);
834 if (!fc_op) return failure();
835 Value filter = fc_op.filter();
836 Value bias = fc_op.bias();
837
838 // QDQs
839 auto dq_op = dyn_cast_or_null<TFL::DequantizeOp>(filter.getDefiningOp());
840 if (!dq_op) return failure();
841 auto q_op =
842 dyn_cast_or_null<TFL::QuantizeOp>(dq_op.input().getDefiningOp());
843 if (!q_op) return failure();
844 filter = q_op.input();
845
846 // weight constant
847 ElementsAttr cst_tmp;
848 if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure();
849 if (!bias.getType().isa<NoneType>() &&
850 !matchPattern(bias, m_Constant(&cst_tmp)))
851 return failure();
852 if (fc_op.fused_activation_function() != "NONE") return failure();
853
854 // Broadcast the constant operand of Mul if it isn't compatible with the
855 // filter input. We only support broadcasting the operand along the depth
856 // dimension, when the operand's depth is 1.
857 rewriter.setInsertionPoint(q_op);
858 Location loc = fc_op.getLoc();
859 Value broadcasted_gamma;
860 if (isa<TFL::Conv2DOp>(mul_op_lhs)) {
861 auto mul_rhs = ExpandTo4DForConv(gamma_cst);
862 broadcasted_gamma = rewriter.create<ConstOp>(loc, mul_rhs);
863 } else if (isa<TFL::DepthwiseConv2DOp>(mul_op_lhs)) {
864 auto mul_rhs = ExpandTo4DForDepthwiseConv(gamma_cst);
865 broadcasted_gamma = rewriter.create<ConstOp>(loc, mul_rhs);
866 } else {
867 return failure();
868 }
869
870 // Make sure that the fused bias will be a 1D tensor.
871 if (isa<TFL::DepthwiseConv2DOp>(mul_op_lhs)) {
872 auto gamma_shape = gamma.getType().cast<ShapedType>();
873 if (!gamma_shape.hasRank() || gamma_shape.getRank() != 1) {
874 return failure();
875 }
876 }
877
878 // Rewrite filter constant. Since the folder of TFL::MulOp couldn't
879 // broadcast the operands, TF::MulOp is used to fold the constant.
880 auto new_filter =
881 rewriter.create<TF::MulOp>(loc, filter, broadcasted_gamma).z();
882 // Update the scale in the quantize op.
883 auto new_qtype = RescaleQtype(q_op.qtype(), gamma_cst);
884 if (!new_qtype) return failure();
885 rewriter.replaceOpWithNewOp<TFL::QuantizeOp>(q_op, new_qtype.getValue(),
886 new_filter, new_qtype);
887
888 // If bias isn't None, it needs to be multiplied as well.
889 if (!bias.getType().isa<NoneType>()) {
890 rewriter.setInsertionPoint(fc_op);
891 auto new_bias = rewriter.create<TF::MulOp>(loc, bias, gamma);
892 fc_op.getOperation()->replaceUsesOfWith(bias, new_bias);
893 }
894
895 // Remove the trailing mul op.
896 mul_op.replaceAllUsesWith(fc_op.getResult());
897 return success();
898 }
899 };
900
901 using FuseConv2DAndMulWithQDQs = FuseAffinOpAndMulWithQDQs<TFL::Conv2DOp>;
902 using FuseDepthwiseConv2DAndMulWithQDQs =
903 FuseAffinOpAndMulWithQDQs<TFL::DepthwiseConv2DOp>;
904
905 // Fuse Binary Op with following Affine operation.
906 template <typename AffineOpType>
907 struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
908 using OpRewritePattern<AffineOpType>::OpRewritePattern;
909
910 LogicalResult matchAndRewrite(AffineOpType fc_op,
911 PatternRewriter &rewriter) const override {
912 // Binary op.
913 Operation *binary_op = fc_op.input().getDefiningOp();
914 if (!binary_op || binary_op->getNumOperands() != 2) return failure();
915 // We only handle the case where the RHS is a scalar.
916 // TODO(fengliuai): Currently the canonicalizer pass can't guarantee that
917 // the constant operands are on the RHS, so we need to consider an LHS
918 // constant operand if necessary.
919 DenseFPElementsAttr cst;
920 if (!matchPattern(binary_op->getOperand(1), m_Constant(&cst)))
921 return failure();
922 if (cst.getNumElements() != 1) return failure();
923 APFloat cst_value = *cst.float_value_begin();
924
925 // Affine op.
926 Value filter = fc_op.filter();
927 Value bias = fc_op.bias();
928 DenseFPElementsAttr filter_cst, bias_cst;
929 if (!matchPattern(filter, m_Constant(&filter_cst))) {
930 // The filter may be quantized; in that case we should use the real constant.
931 auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter.getDefiningOp());
932 if (!dq) return failure();
933 auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input().getDefiningOp());
934 if (!q || !matchPattern(q.input(), m_Constant(&filter_cst))) {
935 return failure();
936 }
937 filter = q.input();
938 }
939 if (!bias.getType().isa<NoneType>() &&
940 !matchPattern(bias, m_Constant(&bias_cst)))
941 return failure();
942 auto binary_op_activation_func =
943 binary_op->template getAttrOfType<StringAttr>(
944 "fused_activation_function");
945 if (!binary_op_activation_func ||
946 binary_op_activation_func.getValue() != "NONE")
947 return failure();
948 ShapedType filter_type = filter_cst.getType();
949
950 if (llvm::isa<AddOp, SubOp>(binary_op)) {
951 auto padding = fc_op->template getAttrOfType<StringAttr>("padding");
952 if (padding && padding.getValue() != "VALID") return failure();
953
954 // The fusion of add/sub is actually applying the following
955 // transformation:
956 // w * (x + c) + b => w * x + (w * c + b)
957 // so we have to update the bias.
958 if (llvm::isa<SubOp>(binary_op)) cst_value.changeSign();
959
960 auto bias_and_slice =
961 GetBiasDimAndSliceSize(filter_type.getShape(), fc_op);
962 int64_t bias_size = bias_and_slice.first;
963 int64_t slice_size = bias_and_slice.second;
964 ShapedType new_bias_type =
965 RankedTensorType::get({bias_size}, filter_type.getElementType());
966
967 // The new bias should be a 1-D tensor with length equal to the bias
968 // dimension of the weight.
969 SmallVector<APFloat, 4> new_bias_values;
970 if (bias.getType().isa<NoneType>()) { // none bias, a list of zeros
971 new_bias_values.resize(bias_size,
972 APFloat::getZero(cst_value.getSemantics()));
973 } else if (bias_cst.getNumElements() == 1) { // scalar bias, broadcast it
974 new_bias_values.resize(bias_size, *bias_cst.float_value_begin());
975 } else if (bias_cst.getNumElements() == bias_size) { // 1-d bias, copy it
976 new_bias_values.insert(new_bias_values.begin(),
977 bias_cst.float_value_begin(),
978 bias_cst.float_value_end());
979 } else {
980 return failure();
981 }
982
983 int64_t flatten_index = 0;
984 for (auto fp_it = filter_cst.float_value_begin(),
985 fp_end = filter_cst.float_value_end();
986 fp_it != fp_end; ++fp_it) {
987 int bias_index = (flatten_index++ / slice_size) % bias_size;
988
989 new_bias_values[bias_index] =
990 new_bias_values[bias_index] + *fp_it * cst_value;
991 }
992 auto new_bias = DenseFPElementsAttr::get(new_bias_type, new_bias_values);
993 auto new_bias_op =
994 rewriter.create<ConstOp>(fc_op.getLoc(), new_bias_type, new_bias);
995 fc_op.setOperand(0, binary_op->getOperand(0));
996 fc_op.setOperand(2, new_bias_op);
997 } else if (llvm::isa<MulOp, DivOp>(binary_op)) {
998 // The fusion of mul/div is actually applying the following
999 // transformation:
1000 // w * (x ' c) + b => (w ' c) x + b
1001 // so we have to update the weight.
1002 bool is_mul = llvm::isa<MulOp>(binary_op);
1003 auto new_filter =
1004 filter_cst.mapValues(filter_type.getElementType(), [&](APFloat it) {
1005 return (is_mul ? it * cst_value : it / cst_value).bitcastToAPInt();
1006 });
1007 // We recreate the constant op in case it is shared by the other ops. This
1008 // might increase the model size.
1009 auto new_filter_op = rewriter.create<ConstOp>(
1010 fc_op.getLoc(), filter.getType(), new_filter);
1011 fc_op.setOperand(0, binary_op->getOperand(0));
1012 if (fc_op.filter() != filter) {
1013 // This filter goes through quantize and dequantize ops. Then we just
1014 // need to update the weight to the quantize op.
1015 filter.replaceAllUsesWith(new_filter_op);
1016 } else {
1017 // This filter doesn't go through quantize and dequantize ops, so
1018 // we update the weight of the affine op directly.
1019 fc_op.setOperand(1, new_filter_op);
1020 }
1021 } else {
1022 return failure();
1023 }
1024 return success();
1025 }
1026
1027 private:
1028 // Returns the length of the channel dimension and also the slice size for
1029 // each position in the channel dimension. tfl.conv2d and
1030 // tfl.fully_connected have a leading channel dimension, but tfl.depthwise_conv2d
1031 // has a trailing channel dimension. This function provides a utility to
1032 // derive that information from the op properties.
1033 static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
1034 ArrayRef<int64_t> filter_shape, AffineOpType op) {
1035 // Channel dimension index is specified as op property
1036 auto channel_index_iter = filter_shape.begin();
1037 std::advance(channel_index_iter, op.GetChannelDimIndex());
1038 // The slice size is the product of the dimensions after the channel dimension.
1039 int64_t slice_size =
1040 std::accumulate(std::next(channel_index_iter), filter_shape.end(), 1,
1041 std::multiplies<int64_t>());
1042 return {*channel_index_iter, slice_size};
1043 }
1044 };
1045
1046 // If the operand to a broadcastable op is a splat constant, try to replace it
1047 // with a 0-d constant, e.g. before this optimization,
1048 // %cst = constant dense<1.0> : tensor<16x16x4xf32>
1049 // %0 = "tfl.conv_2d"...
1050 // %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<16x16x4xf32>)
1051 // After this optimization:
1052 // %cst = constant dense<1.0> : tensor<f32>
1053 // %0 = "tfl.conv_2d"...
1054 // %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<f32>)
1055 // This pattern can enable more fusing opportunities when the binary op is
1056 // following conv ops.
1057 template <typename BinaryOpType>
1058 struct ScalarizeSplatConstantForBroadcastableOps
1059 : public OpRewritePattern<BinaryOpType> {
1060 using OpRewritePattern<BinaryOpType>::OpRewritePattern;
1061
1062 LogicalResult matchAndRewrite(BinaryOpType binary_op,
1063 PatternRewriter &rewriter) const override {
1064 DenseElementsAttr splat_elements_attr;
1065 if (!IsScalarizableSplatConstant(binary_op.rhs(), &splat_elements_attr)) {
1066 return failure();
1067 }
1068
1069 constexpr int kSplatOperandIndex = 1;
1070 auto result_type =
1071 binary_op.getResult().getType().template cast<ShapedType>();
1072 mlir::Value non_splat_operand =
1073 binary_op.getOperand(1 - kSplatOperandIndex);
1074 auto non_splat_operand_type =
1075 non_splat_operand.getType().cast<ShapedType>();
1076 // If the other operand's shape does not equal the result shape, then we
1077 // cannot scalarize the splat constant because the result shape relies on
1078 // the splat constant op's shape for broadcasting.
1079 if (!non_splat_operand_type.hasStaticShape() ||
1080 non_splat_operand_type.getShape() != result_type.getShape() ||
1081 non_splat_operand_type.getRank() > 4) {
1082 return failure();
1083 }
1084
1085 // If the non-splat operand is not a fusable affine op, there is no need to
1086 // apply this transformation.
1087 if (!CanFuseAffineOp(non_splat_operand.getDefiningOp(), binary_op)) {
1088 return failure();
1089 }
1090
1091 // Creates a new scalar constant op using the splat value.
1092 mlir::Value splat_operand = binary_op.getOperand(kSplatOperandIndex);
1093 auto scalar_elements_attr = DenseElementsAttr::get(
1094 RankedTensorType::get({},
1095 splat_elements_attr.getType().getElementType()),
1096 splat_elements_attr.getSplatValue());
1097
1098 auto scalar_constant_op = rewriter.create<ConstantOp>(
1099 splat_operand.getLoc(), scalar_elements_attr.getType(),
1100 scalar_elements_attr);
1101
1102 binary_op.setOperand(kSplatOperandIndex, scalar_constant_op);
1103 return success();
1104 }
1105
1106 private:
1107 // Returns true if this value is a splat constant op which can be scalarized.
1108 // Also returns the elements attr if this value is indeed a splat constant.
1109 bool IsScalarizableSplatConstant(mlir::Value value,
1110 DenseElementsAttr *elements_attr) const {
1111 if (!matchPattern(value, m_Constant(elements_attr))) {
1112 return false;
1113 }
1114 auto element_type = value.getType().cast<ShapedType>().getElementType();
1115 // Ignore per-axis quantized constants because after converting to scalar,
1116 // we will lose the per-axis quantization parameters.
1117 if (element_type.isa<quant::UniformQuantizedPerAxisType>()) {
1118 return false;
1119 }
1120 if (IsScalar(value)) {
1121 return false;
1122 }
1123 return elements_attr->isSplat();
1124 }
1125
1126 // Returns true if this value has a scalar shaped type.
1127 bool IsScalar(mlir::Value value) const {
1128 auto type = value.getType().dyn_cast<ShapedType>();
1129 if (!type) {
1130 return false;
1131 }
1132 if (!type.hasStaticShape()) {
1133 return false;
1134 }
1135 return type.getNumElements() == 1;
1136 }
1137
1138 // Returns true if we can fuse an affine op with the consuming binary op.
1139 bool CanFuseAffineOp(Operation *affine_op, Operation *binary_op) const {
1140 if (!isa_and_nonnull<TFL::Conv2DOp, TFL::DepthwiseConv2DOp,
1141 TFL::FullyConnectedOp>(affine_op)) {
1142 return false;
1143 }
1144 DenseElementsAttr value;
1145 // Check that the bias is constant if it is not none.
1146 Value bias = affine_op->getOperand(2);
1147 if (!bias.getType().isa<NoneType>() &&
1148 !matchPattern(bias, m_Constant(&value))) {
1149 return false;
1150 }
1151 // If the binary op is mul/div, also check that filter is constant.
1152 if (isa<TFL::MulOp, TFL::DivOp>(binary_op) &&
1153 !matchPattern(affine_op->getOperand(1), m_Constant(&value))) {
1154 return false;
1155 }
1156
1157 // We can only fuse F32/BF16.
1158 auto is_fusable_type = [](Type t) {
1159 Type element_type = t;
1160 if (auto shaped_type = t.dyn_cast<ShapedType>()) {
1161 element_type = shaped_type.getElementType();
1162 }
1163 return element_type.isBF16() || element_type.isF32();
1164 };
1165 for (Type t : binary_op->getOperandTypes()) {
1166 if (!is_fusable_type(t)) {
1167 return false;
1168 }
1169 }
1170
1171 return true;
1172 }
1173 };
1174
1175 using ScalarizeSplatConstantForSub =
1176 ScalarizeSplatConstantForBroadcastableOps<TFL::SubOp>;
1177 using ScalarizeSplatConstantForAdd =
1178 ScalarizeSplatConstantForBroadcastableOps<TFL::AddOp>;
1179 using ScalarizeSplatConstantForMul =
1180 ScalarizeSplatConstantForBroadcastableOps<TFL::MulOp>;
1181 using ScalarizeSplatConstantForDiv =
1182 ScalarizeSplatConstantForBroadcastableOps<TFL::DivOp>;
1183
1184 struct ConvertTrivialTransposeOpToReshapeOp
1185 : public OpRewritePattern<TFL::TransposeOp> {
1186 using OpRewritePattern<TFL::TransposeOp>::OpRewritePattern;
1187
1188 LogicalResult matchAndRewrite(TFL::TransposeOp transpose_op,
1189 PatternRewriter &rewriter) const override {
1190 auto input_type = transpose_op.input().getType().cast<ShapedType>();
1191 auto output_type = transpose_op.output().getType().cast<ShapedType>();
1192 // It's possible to know if the transformation is safe only if the input
1193 // & output shapes are fully known and permutation is a constant.
1194 if (!input_type.hasStaticShape() || !output_type.hasStaticShape())
1195 return failure();
1196 Value perm = transpose_op.perm();
1197 DenseElementsAttr perm_values_attr;
1198 if (!matchPattern(perm, m_Constant(&perm_values_attr))) return failure();
1199
1200 auto input_shape = input_type.getShape();
1201 SmallVector<int64_t, 8> perm_values;
1202 for (const auto &dim : perm_values_attr.getIntValues())
1203 perm_values.push_back(dim.getSExtValue());
1204
1205 // This should never happen unless the input graph is malformed.
1206 if (input_shape.size() != perm_values.size()) {
1207 transpose_op.emitError(
1208 "TransposeOP has inconsistent input and perm values.");
1209 }
1210
1211 SmallVector<int, 8> old_major_index_ordering;
1212 SmallVector<int, 8> new_major_index_ordering;
1213 for (int i = 0, end = input_shape.size(); i < end; i++) {
1214 if (input_shape[i] != 1) {
1215 old_major_index_ordering.push_back(i);
1216 }
1217
1218 if (input_shape[perm_values[i]] != 1) {
1219 new_major_index_ordering.push_back(perm_values[i]);
1220 }
1221 }
1222 if (old_major_index_ordering != new_major_index_ordering) {
1223 return failure();
1224 }
1225
1226 // Rewrite.
1227 Location loc = transpose_op.getLoc();
1228
1229 SmallVector<int32_t, 8> output_shape_values;
1230 for (auto dim : output_type.getShape()) {
1231 output_shape_values.push_back(dim);
1232 }
1233 auto type = mlir::RankedTensorType::get(output_shape_values.size(),
1234 rewriter.getIntegerType(32));
1235 auto new_shape_attr =
1236 mlir::DenseIntElementsAttr::get(type, output_shape_values);
1237 auto new_shape = rewriter.create<TF::ConstOp>(loc, new_shape_attr);
1238
1239 rewriter.replaceOpWithNewOp<TFL::ReshapeOp>(
1240 transpose_op, transpose_op.output().getType(), transpose_op.input(),
1241 new_shape);
1242
1243 return success();
1244 }
1245 };
1246
1247 // Remove Reshape before FullyConnected when `keep_num_dims=false` and Reshape
1248 // does not alter the last dimension as FullyConnected will collapse all other
1249 // dimensions into a single dimension. For example,
1250 //
1251 // %shape = constant dense<[1, 128, 64]> : tensor<3xi32>
1252 // %reshape = tfl.reshape(%input, %shape) // %input: tensor<128x64xf32>
1253 // %fc = tfl.fully_connected(%reshape, %filter, %bias)
1254 // {keep_num_dims = false, weights_format = "DEFAULT"}
1255 //
1256 // can be canonicalized to
1257 //
1258 // %fc = tfl.fully_connected(%input, %filter, %bias)
1259 // {keep_num_dims = false, weights_format = "DEFAULT"}
1260 struct RemoveReshapeBeforeFullyConnected
1261 : public OpRewritePattern<TFL::FullyConnectedOp> {
1262 using OpRewritePattern<TFL::FullyConnectedOp>::OpRewritePattern;
1263
1264 LogicalResult matchAndRewrite(TFL::FullyConnectedOp fully_connected_op,
1265 PatternRewriter &) const override {
1266 auto input = fully_connected_op.input();
1267 auto input_ty = input.getType().dyn_cast<ShapedType>();
1268 auto output_ty = fully_connected_op.output()[0]
1269 .getType()
1270 .template dyn_cast<ShapedType>();
1271 if (!input_ty.hasStaticShape() ||
1272 fully_connected_op.weights_format() != "DEFAULT" ||
1273 fully_connected_op.keep_num_dims() || !output_ty.hasStaticShape() ||
1274 output_ty.getRank() != 2) {
1275 return failure();
1276 }
1277
1278 auto reshape_op = input.getDefiningOp<TFL::ReshapeOp>();
1279 if (!reshape_op) return failure();
1280
1281 // Check if the last dimension does not change after reshape.
1282 auto reshape_input = reshape_op.input();
1283 auto reshape_input_ty = reshape_input.getType().dyn_cast<ShapedType>();
1284 if (!reshape_input_ty.hasStaticShape() || input_ty.getRank() == 0 ||
1285 reshape_input_ty.getRank() == 0 ||
1286 input_ty.getDimSize(input_ty.getRank() - 1) !=
1287 reshape_input_ty.getDimSize(reshape_input_ty.getRank() - 1)) {
1288 return failure();
1289 }
1290
1291 // Connect the input to the input of the reshape.
1292 fully_connected_op.setOperand(0, reshape_input);
1293 return success();
1294 }
1295 };
1296
1297 // Remove Reshape after FullyConnected when `keep_num_dims=false`, the Reshape
1298 // does not alter the last dimension, and it restores the batch dimensions
1299 // collapsed by the FullyConnected op due to `keep_num_dims=false`. For example,
1300 //
1301 // // %input: tensor<4x16x32xf32>
1302 // %fc = tfl.fully_connected(%input, %filter, %bias)
1303 // {keep_num_dims = false, weights_format = "DEFAULT"}
1304 // %shape = constant dense<[4, 16, 32]> : tensor<3xi32>
1305 // %rs = tfl.reshape(%fc, %shape)
1306 //
1307 // can be canonicalized to
1308 //
1309 // %fc = tfl.fully_connected(%input, %filter, %bias)
1310 // {keep_num_dims = true, weights_format = "DEFAULT"}
1311 struct RemoveReshapeAfterFullyConnected
1312 : public OpRewritePattern<TFL::ReshapeOp> {
1313 using OpRewritePattern::OpRewritePattern;
1314
1315 LogicalResult matchAndRewrite(TFL::ReshapeOp reshape_op,
1316 PatternRewriter &rewriter) const override {
1317 auto fully_connected_op = llvm::dyn_cast_or_null<TFL::FullyConnectedOp>(
1318 reshape_op.input().getDefiningOp());
1319 if (!fully_connected_op || fully_connected_op.getNumResults() != 1 ||
1320 fully_connected_op.weights_format() != "DEFAULT" ||
1321 fully_connected_op.keep_num_dims())
1322 return failure();
1323 if (!reshape_op.input().hasOneUse()) return failure();
1324
1325 auto input_shape = fully_connected_op.input().getType().cast<ShapedType>();
1326 auto output_shape = fully_connected_op.getType(0).cast<ShapedType>();
1327 auto reshape_shape = reshape_op.getType().cast<ShapedType>();
1328 if (!input_shape.hasStaticShape() || !output_shape.hasStaticShape() ||
1329 !reshape_shape.hasStaticShape())
1330 return failure();
1331
1332 // Check that the reshape doesn't modify the last dimension and it restores
1333 // the input (batch) dimension with the exception of the feature (last)
1334 // dimension.
1335 if (output_shape.getShape().empty() || reshape_shape.getShape().empty() ||
1336 output_shape.getShape().back() != reshape_shape.getShape().back() ||
1337 input_shape.getShape().drop_back() !=
1338 reshape_shape.getShape().drop_back())
1339 return failure();
1340
1341 llvm::SmallVector<Type, 1> output_type{reshape_op.getType()};
1342 rewriter.replaceOpWithNewOp<TFL::FullyConnectedOp>(
1343 reshape_op, output_type, fully_connected_op.input(),
1344 fully_connected_op.filter(), fully_connected_op.bias(),
1345 fully_connected_op.fused_activation_function(),
1346 fully_connected_op.weights_format(), /*keep_num_dims=*/true);
1347 return success();
1348 }
1349 };
1350
1351 // Fuses Unpack with the following Concatenation into a Reshape if the output
1352 // type has a static shape and the activation function is none. For example:
1353 //
1354 // // %input: tensor<1x3x2xf32>
1355 // %unpack:3 = "tfl.unpack"(%input) {axis = 1 : i32, num = 3 : i32}
1356 // %res = "tfl.concatenation"(%unpack#0, %unpack#1, %unpack#2)
1357 // {axis = -1 : i32, fused_activation_function = "NONE"}
1358 //
1359 // can be optimized to
1360 //
1361 // %cst = constant dense<[1, 6]> : tensor<2xi32>
1362 // %res = "tfl.reshape"(%input, %cst)
1363 struct FuseUnpackAndConcatToReshape
1364 : public OpRewritePattern<TFL::ConcatenationOp> {
1365 using OpRewritePattern::OpRewritePattern;
1366
1367 LogicalResult matchAndRewrite(TFL::ConcatenationOp concat_op,
1368 PatternRewriter &rewriter) const override {
1369 if (concat_op.fused_activation_function() != "NONE") {
1370 return failure();
1371 }
1372
1373 // Checks all operands come from the same unpack op.
1374 auto first_operand = concat_op.values().front();
1375 auto unpack_op =
1376 dyn_cast_or_null<TFL::UnpackOp>(first_operand.getDefiningOp());
1377 if (!unpack_op || unpack_op.getNumResults() != concat_op.getNumOperands()) {
1378 return failure();
1379 }
1380 for (auto &index_and_value : llvm::enumerate(concat_op.values())) {
1381 if (index_and_value.value() !=
1382 unpack_op.getResult(index_and_value.index())) {
1383 return failure();
1384 }
1385 }
1386
1387 auto output_type = concat_op.getType().cast<ShapedType>();
1388 if (!output_type.hasStaticShape()) {
1389 return failure();
1390 }
1391
1392 auto new_shape_array = output_type.getShape();
1393 // This is to work around the unnecessary i64 -> i32 cast.
1394 SmallVector<int32_t, 4> new_shape_array_i32;
1395 for (auto size : new_shape_array) {
1396 new_shape_array_i32.push_back(static_cast<int32_t>(size));
1397 }
1398 auto new_shape = rewriter.create<TFL::ConstOp>(
1399 concat_op.getLoc(),
1400 DenseIntElementsAttr::get(
1401 RankedTensorType::get(new_shape_array_i32.size(),
1402 rewriter.getIntegerType(32)),
1403 new_shape_array_i32));
1404
1405 rewriter.replaceOpWithNewOp<TFL::ReshapeOp>(concat_op, output_type,
1406 unpack_op.input(), new_shape);
1407 return success();
1408 }
1409 };
1410
1411 using FuseBinaryOpToFollowingFullyConnected =
1412 FuseBinaryOpToFollowingAffineOp<FullyConnectedOp>;
1413 using FuseBinaryOpToFollowingDepthwiseConv2D =
1414 FuseBinaryOpToFollowingAffineOp<DepthwiseConv2DOp>;
1415 using FuseBinaryOpToFollowingConv2D = FuseBinaryOpToFollowingAffineOp<Conv2DOp>;
1416
1417 // Adds canonicalization patterns to the list of patterns.
1418 void AddCanonicalizationPatterns(MLIRContext *context,
1419 OwningRewritePatternList *patterns) {
1420 for (auto *op : context->getRegisteredOperations())
1421 op->getCanonicalizationPatterns(*patterns, context);
1422 }
1423
1424 void OptimizePass::runOnFunction() {
1425 OwningRewritePatternList patterns(&getContext());
1426 auto *ctx = &getContext();
1427 auto func = getFunction();
1428
1429 // Merge reshapes into fully connected ops before we start moving them past
1430 // binary ops.
1431 OwningRewritePatternList phase_0_patterns(&getContext());
1432 phase_0_patterns.insert<RemoveReshapeAfterFullyConnected,
1433 RemoveReshapeBeforeFullyConnected>(ctx);
1434 (void)applyPatternsAndFoldGreedily(func, std::move(phase_0_patterns));
1435
1436 // The binary ops might first be fused together (e.g., into hard_swish), so
1437 // we explore those fusions first and then fuse the binary ops with the
1438 // following ops in a second pattern match.
1439 TFL::populateWithGenerated(patterns);
1440 patterns.insert<FuseFullyConnectedAndAdd, FuseAddAndFullyConnected,
1441 FuseFullyConnectedAndMul, FuseMulAndFullyConnected,
1442 FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
1443 FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
1444 FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>>(ctx);
1445 if (enable_canonicalization_) AddCanonicalizationPatterns(ctx, &patterns);
1446 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
1447
1448 // Fuse the binary ops with the following ops.
1449 OwningRewritePatternList phase_2_patterns(&getContext());
1450 TFL::populateWithGenerated(phase_2_patterns);
1451 phase_2_patterns.insert<
1452 ScalarizeSplatConstantForAdd, ScalarizeSplatConstantForSub,
1453 ScalarizeSplatConstantForMul, ScalarizeSplatConstantForDiv,
1454 FuseFullyConnectedAndAdd, FuseAddAndFullyConnected,
1455 FuseFullyConnectedAndMul, FuseMulAndFullyConnected,
1456 FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
1457 FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
1458 FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>,
1459 FuseBinaryOpToFollowingConv2D, FuseBinaryOpToFollowingDepthwiseConv2D,
1460 FuseBinaryOpToFollowingFullyConnected, FuseConv2DAndMulWithQDQs,
1461 FuseDepthwiseConv2DAndMulWithQDQs, ConvertTrivialTransposeOpToReshapeOp,
1462 RemoveReshapeAfterFullyConnected, RemoveReshapeBeforeFullyConnected,
1463 FuseUnpackAndConcatToReshape>(ctx);
1464 if (enable_canonicalization_)
1465 AddCanonicalizationPatterns(ctx, &phase_2_patterns);
1466 (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
1467 }
1468 } // namespace
1469
1470 // Creates an instance of the TensorFlow Lite dialect Optimize pass.
1471 std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass(
1472 bool enable_canonicalization) {
1473 return std::make_unique<OptimizePass>(enable_canonicalization);
1474 }
1475
1476 static PassRegistration<OptimizePass> pass;
1477
1478 } // namespace TFL
1479 } // namespace mlir
1480