/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass takes operations in the TensorFlow Lite dialect
// and optimizes them into equivalent operations in the same dialect.

#include <algorithm>
#include <climits>
#include <cstdint>
#include <functional>
#include <iterator>
#include <map>
#include <numeric>
#include <utility>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {

//===----------------------------------------------------------------------===//
// The actual Optimize Pass.
namespace {
constexpr char kRelu[] = "RELU";
constexpr char kRelu6[] = "RELU6";
constexpr char kRelu1[] = "RELU_N1_TO_1";

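// Returns true if reducing the squared values of `sq_op` over `axis` matches
// an L2 normalization: the reduction axis is either the last dimension
// (written as rank - 1 or as -1) or the axes cover all dimensions in order.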
bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
  if (axis.getNumElements() == 0) {
    return false;
  }
  if (sq_op.getType().cast<ShapedType>().getRank() - 1 ==
          *axis.getValues<int>().begin() ||
      *axis.getValues<int>().begin() == -1) {
    return true;
  }
  if (sq_op.getType().cast<ShapedType>().getRank() != axis.getNumElements()) {
    return false;
  }
  auto shape = sq_op.getType().cast<ShapedType>();
  SmallVector<int, 4> elems{axis.getValues<int>().begin(),
                            axis.getValues<int>().end()};
  for (int i = 0; i < shape.getRank(); ++i) {
    if (i != elems[i]) return false;
  }
  return true;
}

using ::llvm::cast;

// Optimize TFLite operations in functions.
class OptimizePass : public PassWrapper<OptimizePass, FunctionPass> {
 public:
  OptimizePass() = default;
  OptimizePass(const OptimizePass &) {}
  explicit OptimizePass(bool enable_canonicalization) {
    enable_canonicalization_ = enable_canonicalization;
  }

  StringRef getArgument() const final {
    // This is the argument used to refer to the pass in
    // the textual format (on the command line for example).
    return "tfl-optimize";
  }
  StringRef getDescription() const final {
    // This is a brief description of the pass.
    return "Optimize within the TensorFlow Lite dialect";
  }

  void runOnFunction() override;

 private:
  Option<bool> enable_canonicalization_{
      *this, "enable-canonicalization",
      llvm::cl::desc("Enable canonicalization during optimization pass."),
      llvm::cl::init(false)};
};

// Returns whether the given type `a` is broadcast-compatible with `b`.
bool IsBroadcastableElementsAttrAndType(Type a, Type b) {
  return OpTrait::util::getBroadcastedType(a, b) != Type();
}

// Returns whether the resultant type of any broadcastable operation with
// operands `a` and `b` matches `expected_output`. Returns false if `a` is not
// broadcast-compatible with `b`.
bool OperandsBroadcastToOutputType(Type a, Type b, Type expected_output) {
  Type output_element_type =
      expected_output.cast<ShapedType>().getElementType();
  Type broadcasted_type =
      OpTrait::util::getBroadcastedType(a, b, output_element_type);
  return broadcasted_type != Type() && broadcasted_type == expected_output;
}

// Returns whether `type1`'s dimensions are the same as the trailing dimensions
// of `type2`. This is more restrictive than broadcast compatibility.
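// For example, a type of shape [1, 2] is a tail of one of shape [5, 1, 2],
// but a type of shape [2, 5] is not.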
bool IsTailOfShape(Type type1, Type type2) {
  auto tail_type = type1.dyn_cast<ShapedType>();
  auto full_type = type2.dyn_cast<ShapedType>();
  if (!tail_type || !full_type || !tail_type.hasRank() ||
      !full_type.hasRank() || tail_type.getRank() > full_type.getRank())
    return false;
  auto i1 = tail_type.getShape().rbegin(), e1 = tail_type.getShape().rend();
  auto i2 = full_type.getShape().rbegin();
  return std::equal(i1, e1, i2);
}

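// Returns true if a constant with shape `elements_shape` can be fused into a
// convolution or depthwise convolution whose filter has shape `filter_shape`:
// the constant must be effectively 1-D (rank 0, 1, or 4 with all leading
// dimensions equal to 1), and its depth must be 1 or match the filter's
// output channel count.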
bool CanFuseConvOrDepthwiseConvShapes(const ArrayRef<int64_t> filter_shape,
                                      const ArrayRef<int64_t> elements_shape,
                                      bool is_depthwise) {
  // Make sure the elements tensor has a shape where all dimensions are 1
  // except the last one. Also, the elements tensor must be of rank 0
  // (scalar), 1, or 4.
  const auto elements_rank = elements_shape.size();
  for (int i = 0; i < static_cast<int>(elements_shape.size()) - 1; ++i) {
    if (elements_shape[i] != 1) return false;
  }
  if (elements_rank != 1 && elements_rank != 0 && elements_rank != 4) {
    return false;
  }
  auto elements_depth = elements_shape.empty() ? 1 : elements_shape.back();
  // If the elements depth equals 1 (i.e., a scalar or a tensor with one
  // element), then the binary op can simply broadcast the elements.
  if (elements_depth == 1) {
    return true;
  }

  // In TFLite, Conv2D uses the OHWI format for its filter and DepthwiseConv
  // uses 1HWO, so check the elements depth against the first filter dimension
  // for conv and against the last filter dimension for depthwise conv.
  if (filter_shape.empty() ||
      (is_depthwise ? filter_shape.back() != elements_depth
                    : filter_shape[0] != elements_depth))
    return false;
  return true;
}

bool CanFuseConvOrDepthwiseConv(Value filter, Attribute val,
                                bool is_depthwise) {
  const auto elements = val.dyn_cast<DenseElementsAttr>();
  if (!elements) {
    return false;
  }
  const auto elements_shape = elements.getType().getShape();
  const auto filter_shape = filter.getType().cast<ShapedType>().getShape();
  return CanFuseConvOrDepthwiseConvShapes(filter_shape, elements_shape,
                                          is_depthwise);
}

bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val,
                                bool is_depthwise) {
  if (const auto elements = val.dyn_cast<DenseElementsAttr>()) {
    if (const auto filter_elements = filter.dyn_cast<DenseElementsAttr>()) {
      return CanFuseConvOrDepthwiseConvShapes(
          filter_elements.getType().getShape(), elements.getType().getShape(),
          is_depthwise);
    }
  }
  return false;
}

// Returns true if we can eliminate the GatherNdOp or ScatterNdOp. When the
// values of `indices` run from 0 to n-1, the output tensor is identical to
// `params`.
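// For example, a GatherNd with `params` of shape [3, 2] and `indices` equal
// to [[0], [1], [2]] simply returns `params`.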
bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params,
                                              DenseIntElementsAttr indices) {
  auto params_type = params.getType().dyn_cast<RankedTensorType>();
  auto indices_type = indices.getType().dyn_cast<RankedTensorType>();
  // Checks that the shape of `params` is [n, ...] and the shape of `indices`
  // is [n, 1]. As long as the indices iterate over the first dimension of
  // `params` in order, the output is identical to the input.
  if (!params_type || !indices_type || indices_type.getRank() != 2 ||
      indices_type.getDimSize(0) != params_type.getDimSize(0) ||
      indices_type.getDimSize(1) != 1)
    return false;

  // Checks that the values in `indices` run from 0 to n-1.
  int cur_value = 0;
  for (const auto &v : indices.getValues<APInt>()) {
    if (v.getSExtValue() != cur_value) return false;
    ++cur_value;
  }

  return true;
}

// Returns true if we can eliminate the SliceOp. When the values of `begin` are
// all 0s and `size[i]` is equal to either -1 or `input.shape[i]`
// for each dim i, the output tensor is identical to `input`.
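// For example, a Slice of a tensor<4x8xf32> with begin = [0, 0] and
// size = [4, -1] (or [4, 8]) returns the input unchanged.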
bool CanOptimizeIdentitySliceOp(Value input, Attribute begin, Attribute size) {
  // Checks if `begin` and `size` are i32 or i64.
  auto begin_attr = begin.dyn_cast<DenseIntElementsAttr>();
  auto size_attr = size.dyn_cast<DenseIntElementsAttr>();
  if (!begin_attr || !size_attr) {
    return false;
  }

  auto begin_elem_ty = begin_attr.getType().getElementType();
  if (!begin_elem_ty.isInteger(32) && !begin_elem_ty.isInteger(64)) {
    return false;
  }
  auto size_elem_ty = size_attr.getType().getElementType();
  if (!size_elem_ty.isInteger(32) && !size_elem_ty.isInteger(64)) {
    return false;
  }

  // Checks if `input` is ranked and its rank is equal to the number of
  // elements in `begin` and `size`.
  auto input_ty = input.getType().cast<ShapedType>();
  if (!input_ty.hasRank()) {
    return false;
  }

  int64_t rank = input_ty.getRank();
  if (rank != begin_attr.getNumElements() ||
      rank != size_attr.getNumElements()) {
    return false;
  }

  // Checks if `begin` is all 0s, and `size[i]` is equal to either -1 or
  // `input.shape[i]`.
  for (uint64_t i = 0; i < rank; ++i) {
    if (begin_attr.getValue<APInt>({i}).getSExtValue() != 0) return false;
    int64_t si = size_attr.getValue<APInt>({i}).getSExtValue();
    if (si != -1 && si != input_ty.getDimSize(i)) return false;
  }

  return true;
}

// Expands attribute `a` to 4D with all dimensions equal to 1 except one;
// which dimension is kept depends on whether `is_depthwise` is true.
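// For example, a 1-D attribute with 8 elements is reshaped to [8, 1, 1, 1]
// for Conv2D (OHWI filters) and to [1, 1, 1, 8] for DepthwiseConv2D (1HWO
// filters).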
ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) {
  auto elements = a.dyn_cast<DenseElementsAttr>();
  auto shape = elements.getType().getShape();
  if (!shape.empty()) {
    // Checks that elements are essentially 1d.
    assert(elements.getNumElements() == shape.back());
  }
  std::vector<int64_t> shape_data = {1, 1, 1, 1};
  const int vector_length = elements.getNumElements();
  if (is_depthwise)
    shape_data[3] = vector_length;
  else
    shape_data[0] = vector_length;
  auto new_shape =
      RankedTensorType::get(shape_data, elements.getType().getElementType());
  return elements.reshape(new_shape);
}

ElementsAttr ExpandTo4DForConv(Attribute a) {
  return ExpandTo4DForConvImpl(a, false);
}

ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
  return ExpandTo4DForConvImpl(a, true);
}

TypeAttr RescaleQtype(Type input, Attribute factor) {
  return quant::RescaleQuantizedType(input, factor);
}

// Returns the shape of a ranked tensor as a 1-D int32 attribute.
// Precondition: output_val's type is a ranked tensor.
DenseElementsAttr GetShape(Value output_val) {
  auto output_type = output_val.getType().cast<RankedTensorType>();
  auto shape_vector = output_type.getShape();
  std::vector<int32_t> shape;
  shape.reserve(shape_vector.size());
  for (auto shape_object : shape_vector) {
    shape.push_back(shape_object);
  }
  return mlir::DenseElementsAttr::get(
      RankedTensorType::get(
          {static_cast<int>(shape.size())},
          mlir::IntegerType::get(output_val.getContext(), 32)),
      llvm::makeArrayRef(shape));
}

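// Returns the element type if `type_attr` holds a ShapedType, and the held
// type itself otherwise.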
static Type GetShapeStrippedType(TypeAttr type_attr) {
  auto type = type_attr.getValue();
  auto shaped_type = type.dyn_cast<ShapedType>();
  if (shaped_type) {
    return shaped_type.getElementType();
  } else {
    return type;
  }
}

// Returns `true` if reducing `axes` in `input` with `keep_dims=true` results
// in the specified `shape` and `false` otherwise.
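// For example, reducing a tensor<2x3x4xf32> over axes [1] with keep_dims=true
// yields the shape [2, 1, 4].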
static bool ShapeMatchesReduceWithKeepAxes(Value input,
                                           const mlir::Attribute &axes,
                                           const mlir::Attribute &shape) {
  RankedTensorType type = input.getType().dyn_cast_or_null<RankedTensorType>();
  if (!type) return false;

  DenseIntElementsAttr axes_attr =
      axes.dyn_cast_or_null<DenseIntElementsAttr>();
  DenseIntElementsAttr shape_attr =
      shape.dyn_cast_or_null<DenseIntElementsAttr>();
  if (!axes_attr || !shape_attr) return false;

  if (shape_attr.getNumElements() != type.getRank()) return false;

  llvm::SmallSet<uint64_t, 4> axes_set;
  for (auto a : axes_attr.getIntValues()) {
    axes_set.insert(a.getZExtValue());
  }

  auto type_shape = type.getShape();
  for (uint64_t i = 0; i < type.getRank(); ++i) {
    if (axes_set.contains(i)) {
      if (shape_attr.getValue<APInt>({i}) != 1) return false;
    } else {
      if (shape_attr.getValue<APInt>({i}) != type_shape[i]) return false;
    }
  }
  return true;
}

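// Returns true if `attr` is a dense floating-point attribute in which every
// element is exactly `value`.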
static bool FloatValueEquals(const Attribute &attr, double value) {
  auto fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>();
  if (!fp_attr) return false;

  if (fp_attr.isSplat()) {
    return fp_attr.getSplatValue<APFloat>().isExactlyValue(value);
  }
  return llvm::all_of(fp_attr.getFloatValues(), [value](const APFloat &f) {
    return f.isExactlyValue(value);
  });
}

// Returns true if the value's element type is F32.
bool IsF32Value(Value value) {
  return value.getType().cast<ShapedType>().getElementType().isF32();
}

// Returns the number of elements in attr if it is a DenseElementsAttr, 1
// otherwise, as a 0-d int32 Attribute.
Attribute GetNumElementsOrOne(Attribute attr) {
  const auto dense_attr = attr.dyn_cast_or_null<DenseElementsAttr>();
  int32_t num_elements = dense_attr ? dense_attr.getNumElements() : 1;

  OpBuilder builder(attr.getContext());

  return DenseIntElementsAttr::get(
      RankedTensorType::get({}, builder.getI32Type()),
      {llvm::APInt(32, num_elements, true)});
}

// Returns true if attr is a DenseIntElementsAttr whose last element equals 1.
bool IsLastElementEqualsOne(Attribute attr) {
  const auto ints = attr.dyn_cast_or_null<DenseIntElementsAttr>();
  if (!ints) return false;
  if (ints.empty()) return false;
  const auto last_element_index = ints.getNumElements() - 1;
  const auto iterator = ints.getIntValues().begin();
  const APInt last_element = iterator[last_element_index];
  return last_element == 1;
}

// Returns true if attr is a 1-D DenseIntElementsAttr of int32 or int64 values
// that form an incrementing sequence from 0 to N-1.
//
// If such a value is used in an Equal operator, it can be replaced with OneHot.
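// For example, dense<[0, 1, 2, 3]> : tensor<4xi32> qualifies, while
// dense<[0, 2, 1]> does not.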
bool IsOneHotIndexAttribute(Attribute attr) {
  const auto dense_attr = attr.dyn_cast_or_null<DenseIntElementsAttr>();
  if (!dense_attr) {
    return false;
  }
  auto index_type = dense_attr.getType();
  const auto index_elem_bits = index_type.getElementTypeBitWidth();
  if (index_elem_bits != 32 && index_elem_bits != 64) {
    return false;
  }
  if (index_type.getRank() != 1) {
    return false;
  }
  const auto elems = dense_attr.getIntValues().begin();
  for (int i = 0; i < dense_attr.getNumElements(); ++i) {
    if (i != elems[i]) {
      return false;
    }
  }
  return true;
}

// Converts an Attribute with a single value of float or integral type to an
// Attribute holding a single value of float type. If attr has no elements, the
// result is 0.0f.
Attribute ConvertSingleElementAttrToFloatAttr(Attribute attr) {
  const auto dense_fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>();
  if (dense_fp_attr) {
    // Already float => return
    return dense_fp_attr;
  }

  OpBuilder builder(attr.getContext());

  const auto dense_int_attr = attr.dyn_cast<DenseIntElementsAttr>();
  const auto int_values = dense_int_attr.getIntValues();
  float float_val = 0.0f;
  if (!int_values.empty()) {
    const APInt apint_val = *int_values.begin();
    if (dense_int_attr.getType().getElementType().isSignedInteger()) {
      // Get the sign-extended value (=>int64) if the type is signed.
      float_val = apint_val.getSExtValue();
    } else {
      // Get the zero-extended value (=>uint64) if unsigned or signless.
      float_val = apint_val.getZExtValue();
    }
  }
  return DenseFPElementsAttr::get(
      RankedTensorType::get({}, builder.getF32Type()),
      {llvm::APFloat(float_val)});
}

#include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"

// Fuse Add with preceding FullyConnected.
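// This is valid because FC(x) + c = x * W^T + (b + c): the added constant can
// be folded into the bias (after broadcasting it to the number of output
// channels when it is a scalar).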
// TODO(b/136285429): Move to tablegen when variadic is supported
struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
  using OpRewritePattern<TFL::AddOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::AddOp add_op,
                                PatternRewriter &rewriter) const override {
    // Match Add.
    DenseElementsAttr added_value;
    Value constant_val = add_op.rhs();
    if (!matchPattern(constant_val, m_Constant(&added_value))) return failure();

    // Match Fully Connected.
    auto fc_op =
        dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs().getDefiningOp());
    if (!fc_op) return failure();

    // Check if the constant RHS is either 0D (scalar), or a 1D with
    // `{num_channels}` shape.
    auto constant_val_type = constant_val.getType().cast<TensorType>();

    // In the TFLite FullyConnected definition, bias must be a 1D tensor where
    // the number of elements is equal to the number of channels.
    // If it's not 1D or 0D (which can be broadcasted to 1D), reject the
    // matching.
    bool is_scalar_rhs = false;
    if (constant_val_type.getRank() == 0) {
      is_scalar_rhs = true;
    } else if (constant_val_type.getRank() != 1) {
      return failure();
    }

    Value filter = fc_op.filter();
    Value bias = fc_op.bias();
    ElementsAttr bias_value;
    const bool is_none_bias = bias.getType().isa<NoneType>();
    if (fc_op.fused_activation_function() != "NONE") return failure();

    if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value)))
      return failure();

    // Rewrite
    if (is_none_bias) {
      if (is_scalar_rhs) {
        // If the `constant_val` is scalar, we must use the shape of the filter
        // to properly broadcast the scalar to `{num_channels}` shape.

        // Get the number of channels if possible.
        auto filter_type = filter.getType().dyn_cast<RankedTensorType>();
        // Filter must be a `2D` tensor with `{num_channels, num_features}`
        // shape. The following check is rejecting unknown rank (-1).
        if (filter_type == nullptr || filter_type.getRank() != 2) {
          return failure();
        }
        int num_channels = filter_type.getShape()[0];

        // Create a zero tensor with shape {num_channels}; its type needs to
        // be the same as that of constant_val. This is a way to gracefully
        // handle scalar tensors: the Add will always be constant-folded away
        // regardless of whether `constant_val` is a scalar.
        RankedTensorType type = RankedTensorType::get(
            {num_channels}, constant_val_type.getElementType());
        auto attr = rewriter.getZeroAttr(type);
        bias = rewriter.create<ConstantOp>(add_op.getLoc(), type, attr);
        auto none_af = rewriter.getStringAttr("NONE");
        bias =
            rewriter.create<AddOp>(add_op.getLoc(), bias, constant_val, none_af)
                .output();
      } else {
        // If there is no pre-existing bias and the `constant_val` is 1D,
        // simply use `constant_val` as the bias.
        bias = constant_val;
      }
    } else {
      auto none_af = rewriter.getStringAttr("NONE");
      bias =
          rewriter.create<AddOp>(add_op.getLoc(), bias, constant_val, none_af)
              .output();
    }

    auto fc = rewriter.create<TFL::FullyConnectedOp>(
        FusedLoc::get(fc_op.getContext(), {fc_op.getLoc(), add_op.getLoc()}),
        add_op.getType(),
        /*input=*/fc_op.input(),
        /*filter=*/filter,
        /*bias=*/bias,
        /*fused_activation_function=*/
        rewriter.getStringAttr(add_op.fused_activation_function()),
        /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
        /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
    rewriter.replaceOp(add_op, fc.output());

    return success();
  }
};

// Replace ..
// FC(Add(lhs, rhs), filter, bias)
// .. with ..
// FC(lhs, filter, FC(rhs, filter, bias))
// .. if rhs, filter, and bias are all constants.
// The second FC will be constant folded to a single vector.
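// This is valid because FC(x + c) = (x + c) * W^T + b = x * W^T + (c * W^T + b),
// and c * W^T + b is exactly FC(c, W, b).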
// TODO(b/136285429): Move to tablegen when variadic is supported
struct FuseAddAndFullyConnected
    : public OpRewritePattern<TFL::FullyConnectedOp> {
  using OpRewritePattern<TFL::FullyConnectedOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::FullyConnectedOp fc_op,
                                PatternRewriter &rewriter) const override {
    // This only works with default format.
    if (fc_op.weights_format() != "DEFAULT") return failure();

    // Match Add.
    auto add_op = dyn_cast_or_null<TFL::AddOp>(fc_op.input().getDefiningOp());
    if (!add_op) return failure();
    if (add_op.fused_activation_function() != "NONE") return failure();

    // Don't match adds where the added constant is not 1D.
    {
      auto addend_shape = add_op.rhs().getType().cast<ShapedType>();
      if (!addend_shape.hasStaticShape()) return failure();
      if (addend_shape.getShape().size() != 1) return failure();
    }

    // Calculate new bias.  Generate a new FC; it will be constant folded.
    auto old_bias = fc_op.bias();
    if (!old_bias || old_bias.getType().isa<NoneType>()) {
      // TODO(b/180752069): Figure out new bias' type when old bias is empty.
      return failure();
    }

    // The FC relies on constant folding, which is implemented on F32. Checks
    // types to be F32.
    {
      if (!IsF32Value(add_op.rhs()) || !IsF32Value(fc_op.filter()) ||
          !IsF32Value(old_bias))
        return failure();
    }

    auto new_bias = rewriter.create<TFL::FullyConnectedOp>(
        fc_op.getLoc(), old_bias.getType(),
        /*input=*/add_op.rhs(),
        /*filter=*/fc_op.filter(),
        /*bias=*/old_bias,
        /*fused_activation_function=*/rewriter.getStringAttr("NONE"),
        /*weights_format=*/rewriter.getStringAttr("DEFAULT"),
        /*keep_num_dims=*/rewriter.getBoolAttr(true));

    // Create the updated FC.
    auto new_fc = rewriter.create<TFL::FullyConnectedOp>(
        FusedLoc::get(add_op.getContext(), {add_op.getLoc(), fc_op.getLoc()}),
        fc_op.output().getTypes(),
        /*input=*/add_op.lhs(),
        /*filter=*/fc_op.filter(),
        /*bias=*/*new_bias.output().begin(),
        /*fused_activation_function=*/
        rewriter.getStringAttr(fc_op.fused_activation_function()),
        /*weights_format=*/rewriter.getStringAttr("DEFAULT"),
        /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
    rewriter.replaceOp(fc_op.getOperation(), new_fc.output());

    return success();
  }
};

// Replace ..
// FC(Mul(lhs, rhs), filter, bias)
// .. with ..
// FC(lhs, Mul(filter, rhs), bias)
// .. if rhs, filter, and bias are all constants.
// The generated Mul will be constant folded to a single matrix.
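// This is valid because (x * c) . W^T = x . (W * c)^T when c is a 1-D
// constant broadcast along the last (input feature) dimension of W.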
struct FuseMulAndFullyConnected
    : public OpRewritePattern<TFL::FullyConnectedOp> {
  using OpRewritePattern<TFL::FullyConnectedOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::FullyConnectedOp fc_op,
                                PatternRewriter &rewriter) const override {
    // This only works with default format.
    if (fc_op.weights_format() != "DEFAULT") return failure();

    // Match Mul.
    auto mul_op = dyn_cast_or_null<TFL::MulOp>(fc_op.input().getDefiningOp());
    if (!mul_op) return failure();
    if (mul_op.fused_activation_function() != "NONE") return failure();

    // Don't match muls where the multiplier constant is not 1D.
    {
      auto multiplier_shape = mul_op.rhs().getType().cast<ShapedType>();
      if (!multiplier_shape.hasStaticShape()) return failure();
      if (multiplier_shape.getShape().size() != 1) return failure();
    }

    // We rely on constant folding, implemented only for F32. Check types.
    if (!IsF32Value(mul_op.rhs()) || !IsF32Value(fc_op.filter())) {
      return failure();
    }

    auto location =
        FusedLoc::get(mul_op.getContext(), {mul_op.getLoc(), fc_op.getLoc()});

    auto new_filter = rewriter.create<TFL::MulOp>(
        location,
        /*lhs=*/fc_op.filter(),
        /*rhs=*/mul_op.rhs(),
        /*fused_activation_function=*/rewriter.getStringAttr("NONE"));
    // Create the updated FC.
    auto new_fc = rewriter.create<TFL::FullyConnectedOp>(
        location, fc_op.output().getTypes(),
        /*input=*/mul_op.lhs(),
        /*filter=*/new_filter,
        /*bias=*/fc_op.bias(),
        /*fused_activation_function=*/
        rewriter.getStringAttr(fc_op.fused_activation_function()),
        /*weights_format=*/rewriter.getStringAttr("DEFAULT"),
        /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
    rewriter.replaceOp(fc_op.getOperation(), new_fc.output());

    return success();
  }
};

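// Fuses a following ReluX activation into a preceding FullyConnected op by
// setting the FC's fused_activation_function attribute to `Act`.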
// TODO(b/136285429): Move to tablegen when variadic is supported.
template <typename ReluXOp, char const *Act>
struct FuseFullyConnectedAndReluX : public OpRewritePattern<ReluXOp> {
  using OpRewritePattern<ReluXOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReluXOp relu_op,
                                PatternRewriter &rewriter) const override {
    Operation *input = relu_op.getOperand().getDefiningOp();
    if (!isa_and_nonnull<FullyConnectedOp>(input)) return failure();
    auto fully_connected_op = cast<FullyConnectedOp>(input);
    if (fully_connected_op.fused_activation_function() != "NONE")
      return failure();

    auto new_activation_func = rewriter.getStringAttr(Act);
    auto new_weights_format =
        rewriter.getStringAttr(fully_connected_op.weights_format());
    auto new_keep_num_dims =
        rewriter.getBoolAttr(fully_connected_op.keep_num_dims());
    auto fc = rewriter.create<FullyConnectedOp>(
        FusedLoc::get(relu_op.getContext(),
                      {fully_connected_op.getLoc(), relu_op.getLoc()}),
        relu_op.getType(), fully_connected_op.input(),
        fully_connected_op.filter(), fully_connected_op.bias(),
        new_activation_func, new_weights_format, new_keep_num_dims);
    rewriter.replaceOp(relu_op, fc.output());

    return success();
  }
};

// Fuse Mul with preceding FullyConnected.
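// This is valid because FC(x) * c = x * (W * c)^T + b * c when c is a
// per-output-channel constant, so both the filter and the bias can absorb it.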
// TODO(b/136285429): Move to tablegen when variadic is supported
struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
  using OpRewritePattern<TFL::MulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::MulOp mul_op,
                                PatternRewriter &rewriter) const override {
    // If we are broadcasting on the lhs then don't fold the multiply as it
    // would increase the amount of compute done by the fully connected op.
    if (mul_op.lhs().getType() != mul_op.getType()) return failure();

    // Mul.
    DenseElementsAttr cst;
    Value constant_val = mul_op.rhs();
    if (!matchPattern(constant_val, m_Constant(&cst))) return failure();

    // Fully Connected.
    auto fc_op =
        dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs().getDefiningOp());
    if (!fc_op) return failure();
    Value filter = fc_op.filter();
    Value bias = fc_op.bias();
    ElementsAttr cst_tmp;
    if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure();
    if (!bias.getType().isa<NoneType>() &&
        !matchPattern(bias, m_Constant(&cst_tmp)))
      return failure();
    if (fc_op.fused_activation_function() != "NONE") return failure();

    // Only fuse the multiplier if all dimensions other than the depth
    // dimension are equal to 1 since otherwise
    // `matmul(x, filter) * cst != matmul(x, filter * cst)`
    // even if `filter` and `cst` are broadcastable.
    auto shape = cst.getType().getShape();
    if (!IsDimensionsDegenerateExceptLastOne(shape)) return failure();

    int64_t element_size = shape.empty() ? 1 : shape[shape.size() - 1];
    // Expand and transpose the multiplier since weights are using the
    // OHWI data format in TFLite.
    int64_t normalized_shape[2] = {element_size, 1};
    auto new_cst = cst.reshape(RankedTensorType::get(
        normalized_shape, cst.getType().getElementType()));
    Type new_type = new_cst.getType();
    if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
      return failure();
    }

    auto new_op =
        rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
    Value new_const_val = new_op.getResult();

    // Rewrite. Since the folder of TFL::MulOp couldn't broadcast the operands,
    // TF::MulOp is used to fold the constant.
    // TODO(b/139192933): switch to the TFL constant folding
    auto new_filter =
        rewriter.create<TF::MulOp>(mul_op.getLoc(), filter, new_const_val).z();
    // If bias isn't None, it needs to be multiplied as well.
    if (!bias.getType().isa<NoneType>()) {
      bias =
          rewriter.create<TF::MulOp>(mul_op.getLoc(), bias, constant_val).z();
    }

    auto fc = rewriter.create<TFL::FullyConnectedOp>(
        FusedLoc::get(fc_op.getContext(), {fc_op.getLoc(), mul_op.getLoc()}),
        mul_op.getType(),
        /*input=*/fc_op.input(),
        /*filter=*/new_filter,
        /*bias=*/bias,
        /*fused_activation_function=*/
        rewriter.getStringAttr(mul_op.fused_activation_function()),
        /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
        /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
    rewriter.replaceOp(mul_op, fc.output());

    return success();
  }
};

// Fuse Mul with a preceding affine op. This is a C++ implementation of the
// following TableGen pattern, which cannot derive the result type of
// the TFL_DequantizeOp.
// def : Pat<(TFL_MulOp (TFL_Conv2DOp:$conv_output $input,
//                          (TFL_DequantizeOp (TFL_QuantizeOp
//                              (ConstantOp F32ElementsAttr:$filter), $qtype)),
//                          (ConstantOp F32ElementsAttr:$bias),
//                          $h_factor, $w_factor, TFL_AF_None,
//                          $padding, $stride_h, $stride_w),
//                      (ConstantOp F32ElementsAttr:$value), $act_fn),
//           (TFL_Conv2DOp $input,
//                      (TFL_DequantizeOp (TFL_QuantizeOp
//                          (TFL_MulOp (ConstantOp $filter),
//                                     (ConstantOp (ExpandTo4DForConv $value)),
//                                      TFL_AF_None),
//                          (RescaleQtype $qtype, $value))),
//                      (TFL_MulOp (ConstantOp $bias), (ConstantOp $value),
//                          TFL_AF_None),
//                      $h_factor, $w_factor, $act_fn,
//                      $padding, $stride_h, $stride_w),
//         [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
//          (HasOneUse $conv_output),
//          (IsPerAxisQuantization $qtype), // per-axis quantization
//         ]>;
template <typename AffineOpType>
struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern<TFL::MulOp> {
  using OpRewritePattern<TFL::MulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::MulOp mul_op,
                                PatternRewriter &rewriter) const override {
    // Mul. A 1-D rhs is required for batch normalization.
    DenseElementsAttr gamma_cst;
    Value gamma = mul_op.rhs();
    if (!matchPattern(gamma, m_Constant(&gamma_cst))) return failure();
    if (gamma_cst.getType().getRank() != 1) return failure();

    // Affine op
    Operation *mul_op_lhs = mul_op.lhs().getDefiningOp();
    auto fc_op = dyn_cast_or_null<AffineOpType>(mul_op_lhs);
    if (!fc_op) return failure();
    Value filter = fc_op.filter();
    Value bias = fc_op.bias();

    // QDQs
    auto dq_op = dyn_cast_or_null<TFL::DequantizeOp>(filter.getDefiningOp());
    if (!dq_op) return failure();
    auto q_op =
        dyn_cast_or_null<TFL::QuantizeOp>(dq_op.input().getDefiningOp());
    if (!q_op) return failure();
    filter = q_op.input();

    // weight constant
    ElementsAttr cst_tmp;
    if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure();
    if (!bias.getType().isa<NoneType>() &&
        !matchPattern(bias, m_Constant(&cst_tmp)))
      return failure();
    if (fc_op.fused_activation_function() != "NONE") return failure();

    // Broadcast the constant operand of Mul if it isn't compatible with the
    // filter input. We only support broadcasting the operand along the depth
    // dimension, when the operand's depth is 1.
    rewriter.setInsertionPoint(q_op);
    Location loc = fc_op.getLoc();
    Value broadcasted_gamma;
    if (isa<TFL::Conv2DOp>(mul_op_lhs)) {
      auto mul_rhs = ExpandTo4DForConv(gamma_cst);
      broadcasted_gamma = rewriter.create<ConstOp>(loc, mul_rhs);
    } else if (isa<TFL::DepthwiseConv2DOp>(mul_op_lhs)) {
      auto mul_rhs = ExpandTo4DForDepthwiseConv(gamma_cst);
      broadcasted_gamma = rewriter.create<ConstOp>(loc, mul_rhs);
    } else {
      return failure();
    }

    // Make sure that the fused bias will be a 1D tensor.
    if (isa<TFL::DepthwiseConv2DOp>(mul_op_lhs)) {
      auto gamma_shape = gamma.getType().cast<ShapedType>();
      if (!gamma_shape.hasRank() || gamma_shape.getRank() != 1) {
        return failure();
      }
    }

    // Rewrite the filter constant. Since the folder of TFL::MulOp couldn't
    // broadcast the operands, TF::MulOp is used to fold the constant.
    auto new_filter =
        rewriter.create<TF::MulOp>(loc, filter, broadcasted_gamma).z();
    // Update the scale in the quantize op.
    auto new_qtype = RescaleQtype(q_op.qtype(), gamma_cst);
    if (!new_qtype) return failure();
    rewriter.replaceOpWithNewOp<TFL::QuantizeOp>(q_op, new_qtype.getValue(),
                                                 new_filter, new_qtype);

    // If bias isn't None, it needs to be multiplied as well.
    if (!bias.getType().isa<NoneType>()) {
      rewriter.setInsertionPoint(fc_op);
      auto new_bias = rewriter.create<TF::MulOp>(loc, bias, gamma);
      fc_op.getOperation()->replaceUsesOfWith(bias, new_bias);
    }

    // Remove the trailing mul op.
    mul_op.replaceAllUsesWith(fc_op.getResult());
    return success();
  }
};

using FuseConv2DAndMulWithQDQs = FuseAffinOpAndMulWithQDQs<TFL::Conv2DOp>;
using FuseDepthwiseConv2DAndMulWithQDQs =
    FuseAffinOpAndMulWithQDQs<TFL::DepthwiseConv2DOp>;

// Fuse a binary op with the following affine operation.
template <typename AffineOpType>
struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
  using OpRewritePattern<AffineOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineOpType fc_op,
                                PatternRewriter &rewriter) const override {
    // Binary op.
    Operation *binary_op = fc_op.input().getDefiningOp();
    if (!binary_op || binary_op->getNumOperands() != 2) return failure();
    // We only handle the case where the RHS is a scalar.
    // TODO(fengliuai): Currently the canonicalizer pass can't guarantee that
    // the constant operands are on the RHS; we need to consider an LHS
    // constant operand if necessary.
    DenseFPElementsAttr cst;
    if (!matchPattern(binary_op->getOperand(1), m_Constant(&cst)))
      return failure();
    if (cst.getNumElements() != 1) return failure();
    APFloat cst_value = *cst.float_value_begin();

    // Affine op.
    Value filter = fc_op.filter();
    Value bias = fc_op.bias();
    DenseFPElementsAttr filter_cst, bias_cst;
    if (!matchPattern(filter, m_Constant(&filter_cst))) {
      // The filter may be quantized; in that case, resolve it to the
      // underlying constant.
      auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter.getDefiningOp());
      if (!dq) return failure();
      auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input().getDefiningOp());
      if (!q || !matchPattern(q.input(), m_Constant(&filter_cst))) {
        return failure();
      }
      filter = q.input();
    }
    if (!bias.getType().isa<NoneType>() &&
        !matchPattern(bias, m_Constant(&bias_cst)))
      return failure();
    auto binary_op_activation_func =
        binary_op->template getAttrOfType<StringAttr>(
            "fused_activation_function");
    if (!binary_op_activation_func ||
        binary_op_activation_func.getValue() != "NONE")
      return failure();
    ShapedType filter_type = filter_cst.getType();

    if (llvm::isa<AddOp, SubOp>(binary_op)) {
      auto padding = fc_op->template getAttrOfType<StringAttr>("padding");
      if (padding && padding.getValue() != "VALID") return failure();

      // The fusion of add/sub is actually applying the following
      // transformation:
      // w * (x + c) + b => w * x + (w * c + b)
      // so we have to update the bias.
      if (llvm::isa<SubOp>(binary_op)) cst_value.changeSign();

      auto bias_and_slice =
          GetBiasDimAndSliceSize(filter_type.getShape(), fc_op);
      int64_t bias_size = bias_and_slice.first;
      int64_t slice_size = bias_and_slice.second;
      ShapedType new_bias_type =
          RankedTensorType::get({bias_size}, filter_type.getElementType());

      // The new bias should be a 1-D tensor with length equal to the bias
      // dimension of the weight.
      SmallVector<APFloat, 4> new_bias_values;
      if (bias.getType().isa<NoneType>()) {  // none bias, a list of zeros
        new_bias_values.resize(bias_size,
                               APFloat::getZero(cst_value.getSemantics()));
      } else if (bias_cst.getNumElements() == 1) {  // scalar bias, broadcast it
        new_bias_values.resize(bias_size, *bias_cst.float_value_begin());
      } else if (bias_cst.getNumElements() == bias_size) {  // 1-d bias, copy it
        new_bias_values.insert(new_bias_values.begin(),
                               bias_cst.float_value_begin(),
                               bias_cst.float_value_end());
      } else {
        return failure();
      }

      int64_t flatten_index = 0;
      for (auto fp_it = filter_cst.float_value_begin(),
                fp_end = filter_cst.float_value_end();
           fp_it != fp_end; ++fp_it) {
        int bias_index = (flatten_index++ / slice_size) % bias_size;

        new_bias_values[bias_index] =
            new_bias_values[bias_index] + *fp_it * cst_value;
      }
      auto new_bias = DenseFPElementsAttr::get(new_bias_type, new_bias_values);
      auto new_bias_op =
          rewriter.create<ConstOp>(fc_op.getLoc(), new_bias_type, new_bias);
      fc_op.setOperand(0, binary_op->getOperand(0));
      fc_op.setOperand(2, new_bias_op);
    } else if (llvm::isa<MulOp, DivOp>(binary_op)) {
      // The fusion of mul/div is actually applying the following
      // transformation:
      // w * (x ' c) + b => (w ' c) x + b
      // so we have to update the weight.
      bool is_mul = llvm::isa<MulOp>(binary_op);
      auto new_filter =
          filter_cst.mapValues(filter_type.getElementType(), [&](APFloat it) {
            return (is_mul ? it * cst_value : it / cst_value).bitcastToAPInt();
          });
      // We recreate the constant op in case it is shared by other ops. This
      // might increase the model size.
      auto new_filter_op = rewriter.create<ConstOp>(
          fc_op.getLoc(), filter.getType(), new_filter);
      fc_op.setOperand(0, binary_op->getOperand(0));
      if (fc_op.filter() != filter) {
        // This filter goes through quantize and dequantize ops, so we just
        // need to update the weight fed to the quantize op.
        filter.replaceAllUsesWith(new_filter_op);
      } else {
        // This filter doesn't go through quantize and dequantize ops, so we
        // update the weight of the affine op directly.
        fc_op.setOperand(1, new_filter_op);
      }
    } else {
      return failure();
    }
    return success();
  }

 private:
  // Returns the length of the channel dimension and the slice size for each
  // position in the channel dimension. tfl.conv2d and tfl.fully_connected
  // have a leading channel dimension, while tfl.depthwise_conv2d has a
  // trailing channel dimension. This utility derives that information from
  // the op's properties.
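  // For example, a tfl.conv2d filter of shape [16, 3, 3, 8] (OHWI) with
  // channel dimension index 0 yields {16, 3 * 3 * 8} = {16, 72}.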
  static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
      ArrayRef<int64_t> filter_shape, AffineOpType op) {
    // Channel dimension index is specified as an op property.
    auto channel_index_iter = filter_shape.begin();
    std::advance(channel_index_iter, op.GetChannelDimIndex());
    // The slice size is the size of the data in the higher dimensions.
    int64_t slice_size =
        std::accumulate(std::next(channel_index_iter), filter_shape.end(), 1,
                        std::multiplies<int64_t>());
    return {*channel_index_iter, slice_size};
  }
};

// If the operand to a broadcastable op is a splat constant, try to replace it
// with a 0-d constant, e.g. before this optimization,
//   %cst = constant dense<1.0> : tensor<16x16x4xf32>
//   %0 = "tfl.conv_2d"...
//   %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<16x16x4xf32>)
// After this optimization:
//   %cst = constant dense<1.0> : tensor<f32>
//   %0 = "tfl.conv_2d"...
//   %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<f32>)
// This pattern can enable more fusing opportunities when the binary op is
// following conv ops.
template <typename BinaryOpType>
struct ScalarizeSplatConstantForBroadcastableOps
    : public OpRewritePattern<BinaryOpType> {
  using OpRewritePattern<BinaryOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(BinaryOpType binary_op,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr splat_elements_attr;
    if (!IsScalarizableSplatConstant(binary_op.rhs(), &splat_elements_attr)) {
      return failure();
    }

    constexpr int kSplatOperandIndex = 1;
    auto result_type =
        binary_op.getResult().getType().template cast<ShapedType>();
    mlir::Value non_splat_operand =
        binary_op.getOperand(1 - kSplatOperandIndex);
    auto non_splat_operand_type =
        non_splat_operand.getType().cast<ShapedType>();
    // If the other operand's shape does not equal the result shape, then we
    // cannot scalarize the splat constant because the result shape relies on
    // the splat constant op's shape for broadcasting.
    if (!non_splat_operand_type.hasStaticShape() ||
        non_splat_operand_type.getShape() != result_type.getShape() ||
        non_splat_operand_type.getRank() > 4) {
      return failure();
    }

    // If the non-splat operand is not a fusable affine op, then there is no
    // need to apply this transformation.
    if (!CanFuseAffineOp(non_splat_operand.getDefiningOp(), binary_op)) {
      return failure();
    }

    // Creates a new scalar constant op using the splat value.
    mlir::Value splat_operand = binary_op.getOperand(kSplatOperandIndex);
    auto scalar_elements_attr = DenseElementsAttr::get(
        RankedTensorType::get({},
                              splat_elements_attr.getType().getElementType()),
        splat_elements_attr.getSplatValue());

    auto scalar_constant_op = rewriter.create<ConstantOp>(
        splat_operand.getLoc(), scalar_elements_attr.getType(),
        scalar_elements_attr);

    binary_op.setOperand(kSplatOperandIndex, scalar_constant_op);
    return success();
  }

 private:
  // Returns true if this value is a splat constant op which can be scalarized.
  // Also returns the elements attr if this value is indeed a splat constant.
  bool IsScalarizableSplatConstant(mlir::Value value,
                                   DenseElementsAttr *elements_attr) const {
    if (!matchPattern(value, m_Constant(elements_attr))) {
      return false;
    }
    auto element_type = value.getType().cast<ShapedType>().getElementType();
    // Ignore per-axis quantized constants because after converting to a
    // scalar, we would lose the per-axis quantization parameters.
    if (element_type.isa<quant::UniformQuantizedPerAxisType>()) {
      return false;
    }
    if (IsScalar(value)) {
      return false;
    }
    return elements_attr->isSplat();
  }

  // Returns true if this value has a scalar shaped type.
  bool IsScalar(mlir::Value value) const {
    auto type = value.getType().dyn_cast<ShapedType>();
    if (!type) {
      return false;
    }
    if (!type.hasStaticShape()) {
      return false;
    }
    return type.getNumElements() == 1;
  }

  // Returns true if we can fuse an affine op with the consuming binary op.
  bool CanFuseAffineOp(Operation *affine_op, Operation *binary_op) const {
    if (!isa_and_nonnull<TFL::Conv2DOp, TFL::DepthwiseConv2DOp,
                         TFL::FullyConnectedOp>(affine_op)) {
      return false;
    }
    DenseElementsAttr value;
    // Check that the bias is a constant if it is not none.
    Value bias = affine_op->getOperand(2);
    if (!bias.getType().isa<NoneType>() &&
        !matchPattern(bias, m_Constant(&value))) {
      return false;
    }
    // If the binary op is mul/div, also check that the filter is constant.
    if (isa<TFL::MulOp, TFL::DivOp>(binary_op) &&
        !matchPattern(affine_op->getOperand(1), m_Constant(&value))) {
      return false;
    }

    // We can only fuse F32/BF16.
    auto is_fusable_type = [](Type t) {
      Type element_type = t;
      if (auto shaped_type = t.dyn_cast<ShapedType>()) {
        element_type = shaped_type.getElementType();
      }
      return element_type.isBF16() || element_type.isF32();
    };
    for (Type t : binary_op->getOperandTypes()) {
      if (!is_fusable_type(t)) {
        return false;
      }
    }

    return true;
  }
};

using ScalarizeSplatConstantForSub =
    ScalarizeSplatConstantForBroadcastableOps<TFL::SubOp>;
using ScalarizeSplatConstantForAdd =
    ScalarizeSplatConstantForBroadcastableOps<TFL::AddOp>;
using ScalarizeSplatConstantForMul =
    ScalarizeSplatConstantForBroadcastableOps<TFL::MulOp>;
using ScalarizeSplatConstantForDiv =
    ScalarizeSplatConstantForBroadcastableOps<TFL::DivOp>;

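// Converts a transpose that only moves dimensions of size 1 (i.e. one that
// preserves the relative order of all non-degenerate dimensions) into an
// equivalent reshape. For example, transposing a tensor<1x4x1x8xf32> with
// perm = [2, 0, 1, 3] keeps the 4 and 8 dimensions in order, so it can be
// expressed as a reshape to tensor<1x1x4x8xf32>.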
struct ConvertTrivialTransposeOpToReshapeOp
    : public OpRewritePattern<TFL::TransposeOp> {
  using OpRewritePattern<TFL::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::TransposeOp transpose_op,
                                PatternRewriter &rewriter) const override {
    auto input_type = transpose_op.input().getType().cast<ShapedType>();
    auto output_type = transpose_op.output().getType().cast<ShapedType>();
    // It's possible to know if the transformation is safe only if the input
    // & output shapes are fully known and the permutation is a constant.
    if (!input_type.hasStaticShape() || !output_type.hasStaticShape())
      return failure();
    Value perm = transpose_op.perm();
    DenseElementsAttr perm_values_attr;
    if (!matchPattern(perm, m_Constant(&perm_values_attr))) return failure();

    auto input_shape = input_type.getShape();
    SmallVector<int64_t, 8> perm_values;
    for (const auto &dim : perm_values_attr.getIntValues())
      perm_values.push_back(dim.getSExtValue());

    // This should never happen unless the input graph is malformed.
    if (input_shape.size() != perm_values.size()) {
      transpose_op.emitError(
          "TransposeOp has inconsistent input and perm values.");
    }

    SmallVector<int, 8> old_major_index_ordering;
    SmallVector<int, 8> new_major_index_ordering;
    for (int i = 0, end = input_shape.size(); i < end; i++) {
      if (input_shape[i] != 1) {
        old_major_index_ordering.push_back(i);
      }

      if (input_shape[perm_values[i]] != 1) {
        new_major_index_ordering.push_back(perm_values[i]);
      }
    }
    if (old_major_index_ordering != new_major_index_ordering) {
      return failure();
    }

    // Rewrite.
    Location loc = transpose_op.getLoc();

    SmallVector<int32_t, 8> output_shape_values;
    for (auto dim : output_type.getShape()) {
      output_shape_values.push_back(dim);
    }
    auto type = mlir::RankedTensorType::get(output_shape_values.size(),
                                            rewriter.getIntegerType(32));
    auto new_shape_attr =
        mlir::DenseIntElementsAttr::get(type, output_shape_values);
    auto new_shape = rewriter.create<TF::ConstOp>(loc, new_shape_attr);

    rewriter.replaceOpWithNewOp<TFL::ReshapeOp>(
        transpose_op, transpose_op.output().getType(), transpose_op.input(),
        new_shape);

    return success();
  }
};

1247 // Remove Reshape before FullyConnected when `keep_num_dims=false` and Reshape
1248 // does not alter the last dimension as FullyConnected will collapse all other
1249 // dimensions into a single dimension. For example,
1250 //
1251 //   %shape = constant dense<[1, 128, 64]> : tensor<3xi32>
1252 //   %reshape = tfl.reshape(%input, %shape) // %input: tensor<128x64xf32>
1253 //   %fc = tfl.fully_connected(%reshape, %filter, %bias)
1254 //           {keep_num_dims = false, weights_format = "DEFAULT"}
1255 //
1256 // can be canonicalized to
1257 //
1258 //   %fc = tfl.fully_connected(%input, %filter, %bias)
1259 //           {keep_num_dims = false, weights_format = "DEFAULT"}
1260 struct RemoveReshapeBeforeFullyConnected
1261     : public OpRewritePattern<TFL::FullyConnectedOp> {
1262   using OpRewritePattern<TFL::FullyConnectedOp>::OpRewritePattern;
1263 
  LogicalResult matchAndRewrite(TFL::FullyConnectedOp fully_connected_op,
                                PatternRewriter &) const override {
    auto input = fully_connected_op.input();
    auto input_ty = input.getType().dyn_cast<ShapedType>();
    auto output_ty = fully_connected_op.output()[0]
                         .getType()
                         .template dyn_cast<ShapedType>();
    if (!input_ty || !output_ty || !input_ty.hasStaticShape() ||
        fully_connected_op.weights_format() != "DEFAULT" ||
        fully_connected_op.keep_num_dims() || !output_ty.hasStaticShape() ||
        output_ty.getRank() != 2) {
      return failure();
    }

    auto reshape_op = input.getDefiningOp<TFL::ReshapeOp>();
    if (!reshape_op) return failure();

    // Check that the last dimension does not change after the reshape.
    auto reshape_input = reshape_op.input();
    auto reshape_input_ty = reshape_input.getType().dyn_cast<ShapedType>();
    if (!reshape_input_ty || !reshape_input_ty.hasStaticShape() ||
        input_ty.getRank() == 0 || reshape_input_ty.getRank() == 0 ||
        input_ty.getDimSize(input_ty.getRank() - 1) !=
            reshape_input_ty.getDimSize(reshape_input_ty.getRank() - 1)) {
      return failure();
    }

    // Bypass the reshape: use its input directly as the fully connected
    // op's input.
    fully_connected_op.setOperand(0, reshape_input);
    return success();
  }
};

// Remove Reshape after FullyConnected when `keep_num_dims=false`, the Reshape
// does not alter the last dimension, and it restores the batch dimensions
// collapsed by the FullyConnected op due to `keep_num_dims=false`. For
// example,
//
//   // %input: tensor<4x16x32xf32>
//   %fc = tfl.fully_connected(%input, %filter, %bias)
//           {keep_num_dims = false, weights_format = "DEFAULT"}
//   %shape = constant dense<[4, 16, 32]> : tensor<3xi32>
//   %rs = tfl.reshape(%fc, %shape)
//
// can be canonicalized to
//
//   %fc = tfl.fully_connected(%input, %filter, %bias)
//           {keep_num_dims = true, weights_format = "DEFAULT"}
struct RemoveReshapeAfterFullyConnected
    : public OpRewritePattern<TFL::ReshapeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::ReshapeOp reshape_op,
                                PatternRewriter &rewriter) const override {
    auto fully_connected_op = llvm::dyn_cast_or_null<TFL::FullyConnectedOp>(
        reshape_op.input().getDefiningOp());
    if (!fully_connected_op || fully_connected_op.getNumResults() != 1 ||
        fully_connected_op.weights_format() != "DEFAULT" ||
        fully_connected_op.keep_num_dims())
      return failure();
    if (!reshape_op.input().hasOneUse()) return failure();

    auto input_shape = fully_connected_op.input().getType().cast<ShapedType>();
    auto output_shape = fully_connected_op.getType(0).cast<ShapedType>();
    auto reshape_shape = reshape_op.getType().cast<ShapedType>();
    if (!input_shape.hasStaticShape() || !output_shape.hasStaticShape() ||
        !reshape_shape.hasStaticShape())
      return failure();

    // Check that the reshape keeps the last (feature) dimension of the
    // fully connected output and restores all of the input's leading (batch)
    // dimensions, i.e. it exactly undoes the collapsing performed under
    // `keep_num_dims = false`.
    if (output_shape.getShape().empty() || reshape_shape.getShape().empty() ||
        output_shape.getShape().back() != reshape_shape.getShape().back() ||
        input_shape.getShape().drop_back() !=
            reshape_shape.getShape().drop_back())
      return failure();

    llvm::SmallVector<Type, 1> output_type{reshape_op.getType()};
    rewriter.replaceOpWithNewOp<TFL::FullyConnectedOp>(
        reshape_op, output_type, fully_connected_op.input(),
        fully_connected_op.filter(), fully_connected_op.bias(),
        fully_connected_op.fused_activation_function(),
        fully_connected_op.weights_format(), /*keep_num_dims=*/true);
    return success();
  }
};

// Fuses Unpack with the following Concatenation into a Reshape if the output
// type has a static shape and the activation function is none. For example:
//
//   // %input: tensor<1x3x2xf32>
//   %unpack:3 = "tfl.unpack"(%input) {axis = 1 : i32, num = 3 : i32}
//   %res = "tfl.concatenation"(%unpack#0, %unpack#1, %unpack#2)
//        {axis = -1 : i32, fused_activation_function = "NONE"}
//
// can be optimized to
//
//   %cst = constant dense<[1, 6]> : tensor<2xi32>
//   %res = "tfl.reshape"(%input, %cst)
struct FuseUnpackAndConcatToReshape
    : public OpRewritePattern<TFL::ConcatenationOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::ConcatenationOp concat_op,
                                PatternRewriter &rewriter) const override {
    if (concat_op.fused_activation_function() != "NONE") {
      return failure();
    }

    // Check that all operands come, in order, from the same unpack op.
    auto first_operand = concat_op.values().front();
    auto unpack_op =
        dyn_cast_or_null<TFL::UnpackOp>(first_operand.getDefiningOp());
    if (!unpack_op || unpack_op.getNumResults() != concat_op.getNumOperands()) {
      return failure();
    }
    for (auto &index_and_value : llvm::enumerate(concat_op.values())) {
      if (index_and_value.value() !=
          unpack_op.getResult(index_and_value.index())) {
        return failure();
      }
    }

    auto output_type = concat_op.getType().cast<ShapedType>();
    if (!output_type.hasStaticShape()) {
      return failure();
    }

    auto new_shape_array = output_type.getShape();
    // Build the shape as i32 values to work around an unnecessary
    // i64 -> i32 cast.
    SmallVector<int32_t, 4> new_shape_array_i32;
    for (auto size : new_shape_array) {
      new_shape_array_i32.push_back(static_cast<int32_t>(size));
    }
    auto new_shape = rewriter.create<TFL::ConstOp>(
        concat_op.getLoc(),
        DenseIntElementsAttr::get(
            RankedTensorType::get(new_shape_array_i32.size(),
                                  rewriter.getIntegerType(32)),
            new_shape_array_i32));

    rewriter.replaceOpWithNewOp<TFL::ReshapeOp>(concat_op, output_type,
                                                unpack_op.input(), new_shape);
    return success();
  }
};

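// Instantiations of the pattern that fuses a binary op with a constant
// operand into the following affine op's filter and/or bias. As an
// illustrative sketch (for the add-before-fully-connected case),
//
//   %add = tfl.add(%input, %cst)
//   %fc = tfl.fully_connected(%add, %filter, %bias) ...
//
// can be folded into a tfl.fully_connected on %input whose bias is
// %bias + %filter * %cst, since W(x + c) + b == Wx + (Wc + b).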
using FuseBinaryOpToFollowingFullyConnected =
    FuseBinaryOpToFollowingAffineOp<FullyConnectedOp>;
using FuseBinaryOpToFollowingDepthwiseConv2D =
    FuseBinaryOpToFollowingAffineOp<DepthwiseConv2DOp>;
using FuseBinaryOpToFollowingConv2D = FuseBinaryOpToFollowingAffineOp<Conv2DOp>;

// Adds canonicalization patterns to the list of patterns.
void AddCanonicalizationPatterns(MLIRContext *context,
                                 OwningRewritePatternList *patterns) {
  for (auto *op : context->getRegisteredOperations())
    op->getCanonicalizationPatterns(*patterns, context);
}

void OptimizePass::runOnFunction() {
  OwningRewritePatternList patterns(&getContext());
  auto *ctx = &getContext();
  auto func = getFunction();

  // Merge reshapes into fully connected ops before we start moving them past
  // binary ops.
  OwningRewritePatternList phase_0_patterns(&getContext());
  phase_0_patterns.insert<RemoveReshapeAfterFullyConnected,
                          RemoveReshapeBeforeFullyConnected>(ctx);
  (void)applyPatternsAndFoldGreedily(func, std::move(phase_0_patterns));

  // Some binary ops may first fuse together into larger compound ops such as
  // hard_swish, so we run those fusions first, and only then fuse the
  // remaining binary ops with the following ops in a second pattern match.
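  // (For reference, hard_swish(x) = x * relu6(x + 3) / 6, so the generated
  // patterns can collapse that add/relu6/mul/div chain into a single op.)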
  TFL::populateWithGenerated(patterns);
  patterns.insert<FuseFullyConnectedAndAdd, FuseAddAndFullyConnected,
                  FuseFullyConnectedAndMul, FuseMulAndFullyConnected,
                  FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
                  FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
                  FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>>(ctx);
  if (enable_canonicalization_) AddCanonicalizationPatterns(ctx, &patterns);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));

  // Fuse the binary ops with the following ops.
  OwningRewritePatternList phase_2_patterns(&getContext());
  TFL::populateWithGenerated(phase_2_patterns);
  phase_2_patterns.insert<
      ScalarizeSplatConstantForAdd, ScalarizeSplatConstantForSub,
      ScalarizeSplatConstantForMul, ScalarizeSplatConstantForDiv,
      FuseFullyConnectedAndAdd, FuseAddAndFullyConnected,
      FuseFullyConnectedAndMul, FuseMulAndFullyConnected,
      FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
      FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
      FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>,
      FuseBinaryOpToFollowingConv2D, FuseBinaryOpToFollowingDepthwiseConv2D,
      FuseBinaryOpToFollowingFullyConnected, FuseConv2DAndMulWithQDQs,
      FuseDepthwiseConv2DAndMulWithQDQs, ConvertTrivialTransposeOpToReshapeOp,
      RemoveReshapeAfterFullyConnected, RemoveReshapeBeforeFullyConnected,
      FuseUnpackAndConcatToReshape>(ctx);
  if (enable_canonicalization_)
    AddCanonicalizationPatterns(ctx, &phase_2_patterns);
  (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
}
}  // namespace

// Creates an instance of the TensorFlow Lite dialect Optimize pass.
std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass(
    bool enable_canonicalization) {
  return std::make_unique<OptimizePass>(enable_canonicalization);
}

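// Registers the pass with the global pass registry; passes registered this
// way can be selected on pass-pipeline command lines via their declared
// argument.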
static PassRegistration<OptimizePass> pass;

}  // namespace TFL
}  // namespace mlir