/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass takes operations in the TensorFlow Lite dialect and
// optimizes them, producing operations in the same dialect.

#include <algorithm>
#include <climits>
#include <cstdint>
#include <functional>
#include <iterator>
#include <map>
#include <numeric>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {

//===----------------------------------------------------------------------===//
// The actual Optimize Pass.
namespace {
constexpr char kRelu[] = "RELU";
constexpr char kRelu6[] = "RELU6";
constexpr char kRelu1[] = "RELU_N1_TO_1";

// Returns true if the reduction over `axis` matches an L2 normalization: the
// reduction covers either the last dimension of `sq_op` or all of its
// dimensions.
bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
  if (sq_op.getType().cast<ShapedType>().getRank() - 1 ==
          *axis.getValues<int>().begin() ||
      *axis.getValues<int>().begin() == -1) {
    return true;
  }
  if (sq_op.getType().cast<ShapedType>().getRank() != axis.getNumElements()) {
    return false;
  }
  auto shape = sq_op.getType().cast<ShapedType>();
  SmallVector<int, 4> elems{axis.getValues<int>().begin(),
                            axis.getValues<int>().end()};
  for (int i = 0; i < shape.getRank(); ++i) {
    if (i != elems[i]) return false;
  }
  return true;
}

using ::llvm::cast;

// Optimize TFLite operations in functions.
struct Optimize : public PassWrapper<Optimize, FunctionPass> {
  void runOnFunction() override;
};

// Returns whether the given type `a` is broadcast-compatible with `b`.
bool IsBroadcastableElementsAttrAndType(Type a, Type b) {
  return OpTrait::util::getBroadcastedType(a, b) != Type();
}

// Returns whether the resultant type of any broadcastable operation with
// operands `a` and `b` matches `expected_output`. Returns false if `a` is not
// broadcast-compatible with `b`.
bool OperandsBroadcastToOutputType(Type a, Type b, Type expected_output) {
  Type output_element_type =
      expected_output.cast<ShapedType>().getElementType();
  Type broadcasted_type =
      OpTrait::util::getBroadcastedType(a, b, output_element_type);
  return broadcasted_type != Type() && broadcasted_type == expected_output;
}

// Returns whether `type1`'s dimensions are the same as the ending dimensions
// of `type2`. This is more restricted than broadcastable.
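// For example, a type with shape [1, 128] is a tail of one with shape
// [8, 1, 128], while [128, 1] is not (the trailing dimensions differ).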
bool IsTailOfShape(Type type1, Type type2) {
  auto tail_type = type1.dyn_cast<ShapedType>();
  auto full_type = type2.dyn_cast<ShapedType>();
  if (!tail_type || !full_type || !tail_type.hasRank() ||
      !full_type.hasRank() || tail_type.getRank() > full_type.getRank())
    return false;
  auto i1 = tail_type.getShape().rbegin(), e1 = tail_type.getShape().rend();
  auto i2 = full_type.getShape().rbegin();
  return std::equal(i1, e1, i2);
}

bool CanFuseConvOrDepthwiseConvShapes(const ArrayRef<int64_t> filter_shape,
                                      const ArrayRef<int64_t> elements_shape,
                                      bool is_depthwise) {
  // Make sure the constant tensor has a shape where all dimensions are 1
  // except the last one. It must also be of rank 0 (scalar), 1, or 4.
  const auto elements_rank = elements_shape.size();
  for (int i = 0; i < static_cast<int>(elements_shape.size()) - 1; ++i) {
    if (elements_shape[i] != 1) return false;
  }
  if (elements_rank != 1 && elements_rank != 0 && elements_rank != 4) {
    return false;
  }
  auto elements_depth = elements_shape.empty() ? 1 : elements_shape.back();
  // If the elements depth equals 1 (i.e., a scalar or a tensor with a single
  // element), the binary op can simply broadcast the elements.
  if (elements_depth == 1) {
    return true;
  }

  // In TFLite, Conv2D uses the OHWI format for its filter and DepthwiseConv
  // uses 1HWO, so the output-channel count is the first filter dimension for
  // conv and the last filter dimension for depthwise conv. Check that it
  // matches the depth of the constant.
  if (filter_shape.empty() ||
      (is_depthwise ? filter_shape.back() != elements_depth
                    : filter_shape[0] != elements_depth))
    return false;
  return true;
}

bool CanFuseConvOrDepthwiseConv(Value filter, Attribute val,
                                bool is_depthwise) {
  const auto elements = val.dyn_cast<DenseElementsAttr>();
  if (!elements) {
    return false;
  }
  const auto elements_shape = elements.getType().getShape();
  const auto filter_shape = filter.getType().cast<ShapedType>().getShape();
  return CanFuseConvOrDepthwiseConvShapes(filter_shape, elements_shape,
                                          is_depthwise);
}

bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val,
                                bool is_depthwise) {
  if (const auto elements = val.dyn_cast<DenseElementsAttr>()) {
    if (const auto filter_elements = filter.dyn_cast<DenseElementsAttr>()) {
      return CanFuseConvOrDepthwiseConvShapes(
          filter_elements.getType().getShape(), elements.getType().getShape(),
          is_depthwise);
    }
  }
  return false;
}

// Returns true if we can eliminate the GatherNdOp or ScatterNdOp: when the
// values in `indices` are 0 through n-1, the output tensor is identical to
// `params`.
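// For example, gathering with indices [[0], [1], [2]] from a `params` of
// shape [3, ...] reproduces `params` unchanged, so the op can be dropped.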
bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params,
                                              DenseIntElementsAttr indices) {
  auto params_type = params.getType().dyn_cast<RankedTensorType>();
  auto indices_type = indices.getType().dyn_cast<RankedTensorType>();
  // Checks that the shape of `params` is [n, ...] and the shape of `indices`
  // is [n, 1]. With 2-D `indices`, each entry selects one row of `params`, so
  // as long as `indices` enumerates the rows of `params` in order, the output
  // is identical to the input.
  if (!params_type || !indices_type || indices_type.getRank() != 2 ||
      indices_type.getDimSize(0) != params_type.getDimSize(0) ||
      indices_type.getDimSize(1) != 1)
    return false;

  // Checks that the values in `indices` are 0 through n-1, in order.
  int cur_value = 0;
  for (const auto &v : indices.getValues<APInt>()) {
    if (v.getSExtValue() != cur_value) return false;
    ++cur_value;
  }

  return true;
}

// Expands Attribute `a` to 4-D with all dimensions equal to 1 except one.
// Which dimension is kept depends on whether `is_depthwise` is true.
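// For example, a constant with 8 elements is reshaped to [8, 1, 1, 1] for
// Conv2D (OHWI filters) and to [1, 1, 1, 8] for DepthwiseConv2D (1HWO
// filters).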
ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) {
  auto elements = a.dyn_cast<DenseElementsAttr>();
  auto shape = elements.getType().getShape();
  if (!shape.empty()) {
    // Checks that elements are essentially 1d.
    assert(elements.getNumElements() == shape.back());
  }
  std::vector<int64_t> shape_data = {1, 1, 1, 1};
  const int vector_length = elements.getNumElements();
  if (is_depthwise)
    shape_data[3] = vector_length;
  else
    shape_data[0] = vector_length;
  auto new_shape =
      RankedTensorType::get(shape_data, elements.getType().getElementType());
  return elements.reshape(new_shape);
}

ElementsAttr ExpandTo4DForConv(Attribute a) {
  return ExpandTo4DForConvImpl(a, false);
}

ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
  return ExpandTo4DForConvImpl(a, true);
}

TypeAttr RescaleQtype(Type input, Attribute factor) {
  return quant::RescaleQuantizedType(input, factor);
}

// Returns the shape of a ranked tensor as a 1-D i32 constant attribute.
// Precondition: `output_val`'s type is a ranked tensor.
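// For example, for a value of type tensor<2x3xf32> this returns the attribute
// dense<[2, 3]> : tensor<2xi32>.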
DenseElementsAttr GetShape(Value output_val) {
  auto output_type = output_val.getType().cast<RankedTensorType>();
  auto shape_vector = output_type.getShape();
  std::vector<int32_t> shape;
  shape.reserve(shape_vector.size());
  for (auto shape_object : shape_vector) {
    shape.push_back(shape_object);
  }
  return mlir::DenseElementsAttr::get(
      RankedTensorType::get(
          {static_cast<int>(shape.size())},
          mlir::IntegerType::get(output_val.getContext(), 32)),
      llvm::makeArrayRef(shape));
}

static Type GetShapeStrippedType(TypeAttr type_attr) {
  auto type = type_attr.getValue();
  auto shaped_type = type.dyn_cast<ShapedType>();
  if (shaped_type) {
    return shaped_type.getElementType();
  } else {
    return type;
  }
}

// Returns `true` if reducing `axes` in `input` with `keep_dims=true` results in
// the specified `shape` and `false` otherwise.
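// For example, reducing an input of shape [2, 3, 4] over axes {1} with
// keep_dims=true produces shape [2, 1, 4], which this predicate accepts.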
static bool ShapeMatchesReduceWithKeepAxes(Value input,
                                           const mlir::Attribute &axes,
                                           const mlir::Attribute &shape) {
  RankedTensorType type = input.getType().dyn_cast_or_null<RankedTensorType>();
  if (!type) return false;

  DenseIntElementsAttr axes_attr =
      axes.dyn_cast_or_null<DenseIntElementsAttr>();
  DenseIntElementsAttr shape_attr =
      shape.dyn_cast_or_null<DenseIntElementsAttr>();
  if (!axes_attr || !shape_attr) return false;

  if (shape_attr.getNumElements() != type.getRank()) return false;

  llvm::SmallSet<uint64_t, 4> axes_set;
  for (auto a : axes_attr.getIntValues()) {
    axes_set.insert(a.getZExtValue());
  }

  auto type_shape = type.getShape();
  for (uint64_t i = 0; i < type.getRank(); ++i) {
    if (axes_set.contains(i)) {
      if (shape_attr.getValue<APInt>({i}) != 1) return false;
    } else {
      if (shape_attr.getValue<APInt>({i}) != type_shape[i]) return false;
    }
  }
  return true;
}

static bool FloatValueEquals(const Attribute &attr, double value) {
  auto fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>();
  if (!fp_attr) return false;

  if (fp_attr.isSplat()) {
    return fp_attr.getSplatValue<APFloat>().isExactlyValue(value);
  }
  return llvm::all_of(fp_attr.getFloatValues(), [value](const APFloat &f) {
    return f.isExactlyValue(value);
  });
}

#include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"

// Fuse Add with preceding FullyConnected.
// TODO(b/136285429): Move to tablegen when variadic is supported
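// The rewrite relies on the identity fc(x; filter, bias) + c ==
// fc(x; filter, bias + c), which holds when `c` is a scalar or a 1-D value of
// length `num_channels`, so the addend can be folded into the bias.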
struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
  using OpRewritePattern<TFL::AddOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::AddOp add_op,
                                PatternRewriter &rewriter) const override {
    // Match Add.
    DenseElementsAttr added_value;
    Value constant_val = add_op.rhs();
    if (!matchPattern(constant_val, m_Constant(&added_value))) return failure();

    // Match Fully Connected.
    auto fc_op =
        dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs().getDefiningOp());
    if (!fc_op) return failure();

    // Check if the constant RHS is either 0-D (scalar) or a 1-D tensor with
    // `{num_channels}` shape.
    auto constant_val_type = constant_val.getType().cast<TensorType>();

    // In the TFLite FullyConnected definition, the bias must be a 1-D tensor
    // where the number of elements is equal to the number of channels.
    // If the RHS is not 1-D or 0-D (which can be broadcast to 1-D), reject the
    // match.
    bool is_scalar_rhs = false;
    if (constant_val_type.getRank() == 0) {
      is_scalar_rhs = true;
    } else if (constant_val_type.getRank() != 1) {
      return failure();
    }

    Value filter = fc_op.filter();
    Value bias = fc_op.bias();
    ElementsAttr bias_value;
    const bool is_none_bias = bias.getType().isa<NoneType>();
    if (fc_op.fused_activation_function() != "NONE") return failure();

    if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value)))
      return failure();

    // Rewrite
    if (is_none_bias) {
      if (is_scalar_rhs) {
        // If `constant_val` is a scalar, we need the shape of the filter
        // to properly broadcast the scalar to a `{num_channels}` shape.

        // Get the number of channels if possible.
        auto filter_type = filter.getType().dyn_cast<RankedTensorType>();
        // The filter must be a 2-D tensor with `{num_channels, num_features}`
        // shape. The following check rejects unknown rank (-1).
        if (filter_type == nullptr || filter_type.getRank() != 2) {
          return failure();
        }
        int num_channels = filter_type.getShape()[0];

        // Create a zero tensor with shape {num_channels}; its type needs to
        // be the same as `constant_val`.
        // This is a way to gracefully handle a scalar tensor. The Add will
        // always be constant-folded away regardless of whether `constant_val`
        // is a scalar.
        RankedTensorType type = RankedTensorType::get(
            {num_channels}, constant_val_type.getElementType());
        auto attr = rewriter.getZeroAttr(type);
        bias = rewriter.create<ConstantOp>(add_op.getLoc(), type, attr);
        auto none_af = rewriter.getStringAttr("NONE");
        bias =
            rewriter.create<AddOp>(add_op.getLoc(), bias, constant_val, none_af)
                .output();
      } else {
        // If there is no pre-existing bias and `constant_val` is 1-D, simply
        // use `constant_val` as the bias.
        bias = constant_val;
      }
    } else {
      auto none_af = rewriter.getStringAttr("NONE");
      bias =
          rewriter.create<AddOp>(add_op.getLoc(), bias, constant_val, none_af)
              .output();
    }

    auto fc = rewriter.create<TFL::FullyConnectedOp>(
        FusedLoc::get({fc_op.getLoc(), add_op.getLoc()}, fc_op.getContext()),
        add_op.getType(),
        /*input=*/fc_op.input(),
        /*filter=*/filter,
        /*bias=*/bias,
        /*fused_activation_function=*/
        rewriter.getStringAttr(add_op.fused_activation_function()),
        /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
        /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
    rewriter.replaceOp(add_op, fc.output());

    return success();
  }
};

// Fuse a ReluX op (RELU, RELU6, RELU_N1_TO_1) into a preceding FullyConnected
// op by setting its fused activation function to `Act`.
// TODO(b/136285429): Move to tablegen when variadic is supported.
template <typename ReluXOp, char const *Act>
struct FuseFullyConnectedAndReluX : public OpRewritePattern<ReluXOp> {
  using OpRewritePattern<ReluXOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReluXOp relu_op,
                                PatternRewriter &rewriter) const override {
    Operation *input = relu_op.getOperand().getDefiningOp();
    if (!isa_and_nonnull<FullyConnectedOp>(input)) return failure();
    auto fully_connected_op = cast<FullyConnectedOp>(input);
    if (fully_connected_op.fused_activation_function() != "NONE")
      return failure();

    auto new_activation_func = rewriter.getStringAttr(Act);
    auto new_weights_format =
        rewriter.getStringAttr(fully_connected_op.weights_format());
    auto new_keep_num_dims =
        rewriter.getBoolAttr(fully_connected_op.keep_num_dims());
    auto fc = rewriter.create<FullyConnectedOp>(
        FusedLoc::get({fully_connected_op.getLoc(), relu_op.getLoc()},
                      relu_op.getContext()),
        relu_op.getType(), fully_connected_op.input(),
        fully_connected_op.filter(), fully_connected_op.bias(),
        new_activation_func, new_weights_format, new_keep_num_dims);
    rewriter.replaceOp(relu_op, fc.output());

    return success();
  }
};

// Fuse Mul with preceding FullyConnected.
// TODO(b/136285429): Move to tablegen when variadic is supported
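// The rewrite relies on fc(x; filter, bias) * c == fc(x; filter * c, bias * c)
// when `c` is effectively 1-D along the output-channel dimension, so the
// multiplier can be folded into the (constant) filter and bias.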
struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
  using OpRewritePattern<TFL::MulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::MulOp mul_op,
                                PatternRewriter &rewriter) const override {
    // If we are broadcasting on the lhs then don't fold the multiply as it
    // would increase the amount of compute done by the fully connected op.
    if (mul_op.lhs().getType() != mul_op.getType()) return failure();

    // Mul.
    DenseElementsAttr cst;
    Value constant_val = mul_op.rhs();
    if (!matchPattern(constant_val, m_Constant(&cst))) return failure();

    // Fully Connected.
    auto fc_op =
        dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs().getDefiningOp());
    if (!fc_op) return failure();
    Value filter = fc_op.filter();
    Value bias = fc_op.bias();
    ElementsAttr cst_tmp;
    if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure();
    if (!bias.getType().isa<NoneType>() &&
        !matchPattern(bias, m_Constant(&cst_tmp)))
      return failure();
    if (fc_op.fused_activation_function() != "NONE") return failure();

    // Only fuse the multiplier if all dimensions other than the depth
    // dimension are equal to 1, since otherwise
    // `matmul(x, filter) * cst != matmul(x, filter * cst)`
    // even if `filter` and `cst` are broadcastable.
    auto shape = cst.getType().getShape();
    if (!IsDimensionsDegenerateExceptLastOne(shape)) return failure();

    int64_t element_size = shape.empty() ? 1 : shape[shape.size() - 1];
    // Expand and transpose the multiplier since weights are using the
    // OHWI data format in TFLite.
    int64_t normalized_shape[2] = {element_size, 1};
    auto new_cst = cst.reshape(RankedTensorType::get(
        normalized_shape, cst.getType().getElementType()));
    Type new_type = new_cst.getType();
    if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
      return failure();
    }

    auto new_op =
        rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
    Value new_const_val = new_op.getResult();

    // Rewrite. Since the folder of TFL::MulOp cannot broadcast the operands,
    // TF::MulOp is used to fold the constant.
    // TODO(b/139192933): switch to the TFL constant folding
    auto new_filter =
        rewriter.create<TF::MulOp>(mul_op.getLoc(), filter, new_const_val).z();
    // If the bias isn't None, it needs to be multiplied as well.
    if (!bias.getType().isa<NoneType>()) {
      bias =
          rewriter.create<TF::MulOp>(mul_op.getLoc(), bias, constant_val).z();
    }

    auto fc = rewriter.create<TFL::FullyConnectedOp>(
        FusedLoc::get({fc_op.getLoc(), mul_op.getLoc()}, fc_op.getContext()),
        mul_op.getType(),
        /*input=*/fc_op.input(),
        /*filter=*/new_filter,
        /*bias=*/bias,
        /*fused_activation_function=*/
        rewriter.getStringAttr(mul_op.fused_activation_function()),
        /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
        /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
    rewriter.replaceOp(mul_op, fc.output());

    return success();
  }
};

// Fuse Mul with a preceding affine op. This is the C++ implementation of the
// following TableGen pattern, which cannot derive the result type of the
// TFL_DequantizeOp:
// def : Pat<(TFL_MulOp (TFL_Conv2DOp:$conv_output $input,
//                          (TFL_DequantizeOp (TFL_QuantizeOp
//                              (ConstantOp F32ElementsAttr:$filter), $qtype)),
//                          (ConstantOp F32ElementsAttr:$bias),
//                          $h_factor, $w_factor, TFL_AF_None,
//                          $padding, $stride_h, $stride_w),
//                      (ConstantOp F32ElementsAttr:$value), $act_fn),
//           (TFL_Conv2DOp $input,
//                      (TFL_DequantizeOp (TFL_QuantizeOp
//                          (TFL_MulOp (ConstantOp $filter),
//                                     (ConstantOp (ExpandTo4DForConv $value)),
//                                     TFL_AF_None),
//                          (RescaleQtype $qtype, $value))),
//                      (TFL_MulOp (ConstantOp $bias), (ConstantOp $value),
//                          TFL_AF_None),
//                      $h_factor, $w_factor, $act_fn,
//                      $padding, $stride_h, $stride_w),
//           [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
//            (HasOneUse $conv_output),
//            (IsPerAxisQuantization $qtype), // per-axis quantization
//           ]>;
template <typename AffineOpType>
struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern<TFL::MulOp> {
  using OpRewritePattern<TFL::MulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::MulOp mul_op,
                                PatternRewriter &rewriter) const override {
    // Mul. A 1-D rhs is required for batch normalization.
    DenseElementsAttr gamma_cst;
    Value gamma = mul_op.rhs();
    if (!matchPattern(gamma, m_Constant(&gamma_cst))) return failure();
    if (gamma_cst.getType().getRank() != 1) return failure();

    // Affine op
    Operation *mul_op_lhs = mul_op.lhs().getDefiningOp();
    auto fc_op = dyn_cast_or_null<AffineOpType>(mul_op_lhs);
    if (!fc_op) return failure();
    Value filter = fc_op.filter();
    Value bias = fc_op.bias();

    // QDQs
    auto dq_op = dyn_cast_or_null<TFL::DequantizeOp>(filter.getDefiningOp());
    if (!dq_op) return failure();
    auto q_op =
        dyn_cast_or_null<TFL::QuantizeOp>(dq_op.input().getDefiningOp());
    if (!q_op) return failure();
    filter = q_op.input();

    // Weight constant.
    ElementsAttr cst_tmp;
    if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure();
    if (!bias.getType().isa<NoneType>() &&
        !matchPattern(bias, m_Constant(&cst_tmp)))
      return failure();
    if (fc_op.fused_activation_function() != "NONE") return failure();

    // Broadcast the constant operand of Mul if it isn't compatible with the
    // filter input. We only support broadcasting the operand along the depth
    // dimension, when the operand's depth is 1.
    rewriter.setInsertionPoint(q_op);
    Location loc = fc_op.getLoc();
    Value broadcasted_gamma;
    if (isa<TFL::Conv2DOp>(mul_op_lhs)) {
      auto mul_rhs = ExpandTo4DForConv(gamma_cst);
      broadcasted_gamma = rewriter.create<ConstOp>(loc, mul_rhs);
    } else if (isa<TFL::DepthwiseConv2DOp>(mul_op_lhs)) {
      auto mul_rhs = ExpandTo4DForDepthwiseConv(gamma_cst);
      broadcasted_gamma = rewriter.create<ConstOp>(loc, mul_rhs);
    } else {
      return failure();
    }

    // Rewrite the filter constant. Since the folder of TFL::MulOp cannot
    // broadcast the operands, TF::MulOp is used to fold the constant.
    auto new_filter =
        rewriter.create<TF::MulOp>(loc, filter, broadcasted_gamma).z();
    // Update the scale in the quantize op.
    auto new_qtype = RescaleQtype(q_op.qtype(), gamma_cst);
    if (!new_qtype) return failure();
    rewriter.replaceOpWithNewOp<TFL::QuantizeOp>(q_op, new_qtype.getValue(),
                                                 new_filter, new_qtype);

    // If the bias isn't None, it needs to be multiplied as well.
    if (!bias.getType().isa<NoneType>()) {
      rewriter.setInsertionPoint(fc_op);
      auto new_bias = rewriter.create<TF::MulOp>(loc, bias, gamma);
      fc_op.getOperation()->replaceUsesOfWith(bias, new_bias);
    }

    // Remove the trailing mul op.
    mul_op.replaceAllUsesWith(fc_op.getResult());
    return success();
  }
};

using FuseConv2DAndMulWithQDQs = FuseAffinOpAndMulWithQDQs<TFL::Conv2DOp>;
using FuseDepthwiseConv2DAndMulWithQDQs =
    FuseAffinOpAndMulWithQDQs<TFL::DepthwiseConv2DOp>;

// Fuse a binary op with the following affine operation.
template <typename AffineOpType>
struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
  using OpRewritePattern<AffineOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineOpType fc_op,
                                PatternRewriter &rewriter) const override {
    // Binary op.
    Operation *binary_op = fc_op.input().getDefiningOp();
    if (!binary_op || binary_op->getNumOperands() != 2) return failure();
    // We only handle the cases where the RHS is a scalar.
    // TODO(fengliuai): Currently the canonicalizer pass cannot guarantee that
    // the constant operands are on the RHS; we need to consider an LHS
    // constant operand if necessary.
    DenseFPElementsAttr cst;
    if (!matchPattern(binary_op->getOperand(1), m_Constant(&cst)))
      return failure();
    if (cst.getNumElements() != 1) return failure();
    APFloat cst_value = *cst.float_value_begin();

    // Affine op.
    Value filter = fc_op.filter();
    Value bias = fc_op.bias();
    DenseFPElementsAttr filter_cst, bias_cst;
    if (!matchPattern(filter, m_Constant(&filter_cst))) {
      // The filter may be quantized; in that case read the real constant from
      // the input of the quantize op.
      auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter.getDefiningOp());
      if (!dq) return failure();
      auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input().getDefiningOp());
      if (!q || !matchPattern(q.input(), m_Constant(&filter_cst))) {
        return failure();
      }
      filter = q.input();
    }
    if (!bias.getType().isa<NoneType>() &&
        !matchPattern(bias, m_Constant(&bias_cst)))
      return failure();
    ShapedType filter_type = filter_cst.getType();

    if (llvm::isa<AddOp, SubOp>(binary_op)) {
      auto padding = fc_op->template getAttrOfType<StringAttr>("padding");
      if (padding && padding.getValue() != "VALID") return failure();

      // The fusion of add/sub is actually applying the following
      // transformation:
      //   w * (x + c) + b => w * x + (w * c + b)
      // so we have to update the bias.
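      // For example, with a scalar c and a conv filter w of shape [O, H, W, I],
      // the new bias for output channel k becomes
      //   b[k] + c * (sum of w[k, :, :, :]),
      // which is what the per-slice accumulation below computes.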
      if (llvm::isa<SubOp>(binary_op)) cst_value.changeSign();

      auto bias_and_slice =
          GetBiasDimAndSliceSize(filter_type.getShape(), fc_op);
      int64_t bias_size = bias_and_slice.first;
      int64_t slice_size = bias_and_slice.second;
      ShapedType new_bias_type =
          RankedTensorType::get({bias_size}, filter_type.getElementType());

      // The new bias should be a 1-D tensor with length equal to the bias
      // dimension of the weight.
      SmallVector<APFloat, 4> new_bias_values;
      if (bias.getType().isa<NoneType>()) {  // none bias, a list of zeros
        new_bias_values.resize(bias_size, APFloat(0.0));
      } else if (bias_cst.getNumElements() == 1) {  // scalar bias, broadcast it
        new_bias_values.resize(bias_size, *bias_cst.float_value_begin());
      } else if (bias_cst.getNumElements() == bias_size) {  // 1-d bias, copy it
        new_bias_values.insert(new_bias_values.begin(),
                               bias_cst.float_value_begin(),
                               bias_cst.float_value_end());
      } else {
        return failure();
      }

      int64_t flatten_index = 0;
      for (auto fp_it = filter_cst.float_value_begin(),
                fp_end = filter_cst.float_value_end();
           fp_it != fp_end; ++fp_it) {
        int bias_index = (flatten_index++ / slice_size) % bias_size;

        new_bias_values[bias_index] =
            new_bias_values[bias_index] + *fp_it * cst_value;
      }
      auto new_bias = DenseFPElementsAttr::get(new_bias_type, new_bias_values);
      auto new_bias_op =
          rewriter.create<ConstOp>(fc_op.getLoc(), new_bias_type, new_bias);
      fc_op.setOperand(0, binary_op->getOperand(0));
      fc_op.setOperand(2, new_bias_op);
    } else if (llvm::isa<MulOp, DivOp>(binary_op)) {
      // The fusion of mul/div is actually applying the following
      // transformation (where ' stands for either mul or div):
      //   w * (x ' c) + b => (w ' c) * x + b
      // so we have to update the weight.
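      // For example, tfl.conv_2d(x * 2.0, w, b) can be rewritten as
      // tfl.conv_2d(x, w * 2.0, b), and a division by c rewrites the filter
      // as w / c instead.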
      bool is_mul = llvm::isa<MulOp>(binary_op);
      auto new_filter =
          filter_cst.mapValues(filter_type.getElementType(), [&](APFloat it) {
            return (is_mul ? it * cst_value : it / cst_value).bitcastToAPInt();
          });
      // We recreate the constant op in case it is shared by other ops. This
      // might increase the model size.
      auto new_filter_op = rewriter.create<ConstOp>(
          fc_op.getLoc(), filter.getType(), new_filter);
      fc_op.setOperand(0, binary_op->getOperand(0));
      if (fc_op.filter() != filter) {
        // This filter goes through quantize and dequantize ops, so we just
        // need to point the quantize op at the new weight.
        filter.replaceAllUsesWith(new_filter_op);
      } else {
        // This filter doesn't go through quantize and dequantize ops, so we
        // update the weight of the affine op directly.
        fc_op.setOperand(1, new_filter_op);
      }
    } else {
      return failure();
    }
    return success();
  }

 private:
  // Returns the length of the channel dimension and the slice size per
  // position along the channel dimension. tfl.conv2d and tfl.fully_connected
  // have a leading channel dimension, but tfl.depthwise_conv2d has a trailing
  // channel dimension. This utility derives that information from the op's
  // filter shape and properties.
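  // For example, a tfl.conv_2d filter of shape [O, H, W, I] has channel
  // dimension index 0, so the bias size is O and the slice size is H * W * I.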
  static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
      ArrayRef<int64_t> filter_shape, AffineOpType op) {
    // The channel dimension index is specified as an op property.
    auto channel_index_iter = filter_shape.begin();
    std::advance(channel_index_iter, op.GetChannelDimIndex());
    // The slice size is the number of elements in the dimensions after the
    // channel dimension.
    int64_t slice_size =
        std::accumulate(std::next(channel_index_iter), filter_shape.end(), 1,
                        std::multiplies<int64_t>());
    return {*channel_index_iter, slice_size};
  }
};

// If the operand to a broadcastable op is a splat constant, try to replace it
// with a 0-d constant, e.g. before this optimization,
//   %cst = constant dense<1.0> : tensor<16x16x4xf32>
//   %0 = "tfl.conv_2d"...
//   %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<16x16x4xf32>)
// After this optimization:
//   %cst = constant dense<1.0> : tensor<f32>
//   %0 = "tfl.conv_2d"...
//   %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<f32>)
// This pattern can enable more fusing opportunities when the binary op is
// following conv ops.
template <typename BinaryOpType>
struct ScalarizeSplatConstantForBroadcastableOps
    : public OpRewritePattern<BinaryOpType> {
  using OpRewritePattern<BinaryOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(BinaryOpType binary_op,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr splat_elements_attr;
    if (!IsScalarizableSplatConstant(binary_op.rhs(), &splat_elements_attr)) {
      return failure();
    }

    constexpr int kSplatOperandIndex = 1;
    auto result_type =
        binary_op.getResult().getType().template cast<ShapedType>();
    mlir::Value non_splat_operand =
        binary_op.getOperand(1 - kSplatOperandIndex);
    auto non_splat_operand_type =
        non_splat_operand.getType().cast<ShapedType>();
    // If the other operand's shape does not equal the result shape, then we
    // cannot scalarize the splat constant because the result shape relies on
    // the splat constant op's shape for broadcasting.
    if (!non_splat_operand_type.hasStaticShape() ||
        non_splat_operand_type.getShape() != result_type.getShape() ||
        non_splat_operand_type.getRank() > 4) {
      return failure();
    }

    // If the non-splat operand is not produced by a fusable affine op, there
    // is no need to apply this transformation.
    if (!CanFuseAffineOp(non_splat_operand.getDefiningOp(), binary_op)) {
      return failure();
    }

    // Creates a new scalar constant op using the splat value.
    mlir::Value splat_operand = binary_op.getOperand(kSplatOperandIndex);
    auto scalar_elements_attr = DenseElementsAttr::get(
        RankedTensorType::get({},
                              splat_elements_attr.getType().getElementType()),
        splat_elements_attr.getSplatValue());

    auto scalar_constant_op = rewriter.create<ConstantOp>(
        splat_operand.getLoc(), scalar_elements_attr.getType(),
        scalar_elements_attr);

    binary_op.setOperand(kSplatOperandIndex, scalar_constant_op);
    return success();
  }

 private:
  // Returns true if this value is a splat constant op which can be scalarized.
  // Also returns the elements attr if this value is indeed a splat constant.
  bool IsScalarizableSplatConstant(mlir::Value value,
                                   DenseElementsAttr *elements_attr) const {
    if (!matchPattern(value, m_Constant(elements_attr))) {
      return false;
    }
    auto element_type = value.getType().cast<ShapedType>().getElementType();
    // Ignore per-axis quantized constants because after converting to a
    // scalar, we would lose the per-axis quantization parameters.
    if (element_type.isa<quant::UniformQuantizedPerAxisType>()) {
      return false;
    }
    if (IsScalar(value)) {
      return false;
    }
    return elements_attr->isSplat();
  }

  // Returns true if this value has a scalar (single-element, statically
  // shaped) type.
  bool IsScalar(mlir::Value value) const {
    auto type = value.getType().dyn_cast<ShapedType>();
    if (!type) {
      return false;
    }
    if (!type.hasStaticShape()) {
      return false;
    }
    return type.getNumElements() == 1;
  }

  // Returns true if we can fuse an affine op with the consuming binary op.
  bool CanFuseAffineOp(Operation *affine_op, Operation *binary_op) const {
    if (!isa_and_nonnull<TFL::Conv2DOp, TFL::DepthwiseConv2DOp,
                         TFL::FullyConnectedOp>(affine_op)) {
      return false;
    }
    DenseElementsAttr value;
    // Check that the bias is a constant if it is not none.
    Value bias = affine_op->getOperand(2);
    if (!bias.getType().isa<NoneType>() &&
        !matchPattern(bias, m_Constant(&value))) {
      return false;
    }
    // If the binary op is mul/div, also check that the filter is constant.
    if (isa<TFL::MulOp, TFL::DivOp>(binary_op) &&
        !matchPattern(affine_op->getOperand(1), m_Constant(&value))) {
      return false;
    }

    // We can only fuse F32/BF16.
    auto is_fusable_type = [](Type t) {
      Type element_type = t;
      if (auto shaped_type = t.dyn_cast<ShapedType>()) {
        element_type = shaped_type.getElementType();
      }
      return element_type.isBF16() || element_type.isF32();
    };
    for (Type t : binary_op->getOperandTypes()) {
      if (!is_fusable_type(t)) {
        return false;
      }
    }

    return true;
  }
};

using ScalarizeSplatConstantForSub =
    ScalarizeSplatConstantForBroadcastableOps<TFL::SubOp>;
using ScalarizeSplatConstantForAdd =
    ScalarizeSplatConstantForBroadcastableOps<TFL::AddOp>;
using ScalarizeSplatConstantForMul =
    ScalarizeSplatConstantForBroadcastableOps<TFL::MulOp>;
using ScalarizeSplatConstantForDiv =
    ScalarizeSplatConstantForBroadcastableOps<TFL::DivOp>;

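// Converts a transpose that does not change the relative order of dimensions
// with size greater than one into a reshape, since such a transpose only
// moves around unit dimensions. For example, transposing a tensor of shape
// [1, 5, 1, 10] with permutation [2, 0, 1, 3] yields shape [1, 1, 5, 10] and
// is equivalent to a reshape.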
struct ConvertTrivialTransposeOpToReshapeOp
    : public OpRewritePattern<TFL::TransposeOp> {
  using OpRewritePattern<TFL::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::TransposeOp transpose_op,
                                PatternRewriter &rewriter) const override {
    auto input_type = transpose_op.input().getType().cast<ShapedType>();
    auto output_type = transpose_op.output().getType().cast<ShapedType>();
    // We can only tell whether the transformation is safe if the input and
    // output shapes are fully known and the permutation is a constant.
    if (!input_type.hasStaticShape() || !output_type.hasStaticShape())
      return failure();
    Value perm = transpose_op.perm();
    DenseElementsAttr perm_values_attr;
    if (!matchPattern(perm, m_Constant(&perm_values_attr))) return failure();

    auto input_shape = input_type.getShape();
    SmallVector<int64_t, 8> perm_values;
    for (const auto &dim : perm_values_attr.getIntValues())
      perm_values.push_back(dim.getSExtValue());

    // This should never happen unless the input graph is malformed.
    if (input_shape.size() != perm_values.size()) {
      transpose_op.emitError(
          "TransposeOP has inconsistent input and perm values.");
    }

    SmallVector<int, 8> old_major_index_ordering;
    SmallVector<int, 8> new_major_index_ordering;
    for (int i = 0, end = input_shape.size(); i < end; i++) {
      if (input_shape[i] != 1) {
        old_major_index_ordering.push_back(i);
      }

      if (input_shape[perm_values[i]] != 1) {
        new_major_index_ordering.push_back(perm_values[i]);
      }
    }
    if (old_major_index_ordering != new_major_index_ordering) {
      return failure();
    }

    // Rewrite.
    Location loc = transpose_op.getLoc();

    SmallVector<int32_t, 8> output_shape_values;
    for (auto dim : output_type.getShape()) {
      output_shape_values.push_back(dim);
    }
    auto type = mlir::RankedTensorType::get(output_shape_values.size(),
                                            rewriter.getIntegerType(32));
    auto new_shape_attr =
        mlir::DenseIntElementsAttr::get(type, output_shape_values);
    auto new_shape = rewriter.create<TF::ConstOp>(loc, new_shape_attr);

    rewriter.replaceOpWithNewOp<TFL::ReshapeOp>(
        transpose_op, transpose_op.output().getType(), transpose_op.input(),
        new_shape);

    return success();
  }
};

using FuseBinaryOpToFollowingFullyConnected =
    FuseBinaryOpToFollowingAffineOp<FullyConnectedOp>;
using FuseBinaryOpToFollowingDepthwiseConv2D =
    FuseBinaryOpToFollowingAffineOp<DepthwiseConv2DOp>;
using FuseBinaryOpToFollowingConv2D = FuseBinaryOpToFollowingAffineOp<Conv2DOp>;

void Optimize::runOnFunction() {
  OwningRewritePatternList patterns;
  auto *ctx = &getContext();
  auto func = getFunction();

  // Binary ops can potentially be fused together first (e.g. into hard_swish),
  // so we run those patterns first and then fuse the binary ops with the
  // following ops in a second pattern match.
  TFL::populateWithGenerated(ctx, patterns);
  patterns.insert<FuseFullyConnectedAndAdd,
                  FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
                  FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
                  FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>,
                  FuseFullyConnectedAndMul>(ctx);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));

  // Fuse the binary ops with the following ops.
  OwningRewritePatternList phase_2_patterns;
  TFL::populateWithGenerated(ctx, phase_2_patterns);
  phase_2_patterns.insert<
      ScalarizeSplatConstantForAdd, ScalarizeSplatConstantForSub,
      ScalarizeSplatConstantForMul, ScalarizeSplatConstantForDiv,
      FuseFullyConnectedAndAdd, FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
      FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
      FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>,
      FuseFullyConnectedAndMul, FuseBinaryOpToFollowingConv2D,
      FuseBinaryOpToFollowingDepthwiseConv2D,
      FuseBinaryOpToFollowingFullyConnected, FuseConv2DAndMulWithQDQs,
      FuseDepthwiseConv2DAndMulWithQDQs, ConvertTrivialTransposeOpToReshapeOp>(
      ctx);
  (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
}

}  // namespace

// Creates an instance of the TensorFlow Lite dialect Optimize pass.
std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass() {
  return std::make_unique<Optimize>();
}

static PassRegistration<Optimize> pass(
    "tfl-optimize", "Optimize within the TensorFlow Lite dialect");

}  // namespace TFL
}  // namespace mlir