1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This transformation pass takes operations in the TensorFlowLite dialect and
17 // optimizes them into more efficient operations in the same dialect.
18 
19 #include <algorithm>
20 #include <climits>
21 #include <cstdint>
22 #include <functional>
23 #include <iterator>
24 #include <map>
25 #include <numeric>
26 
27 #include "llvm/ADT/APFloat.h"
28 #include "llvm/ADT/APInt.h"
29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/None.h"
31 #include "llvm/ADT/Optional.h"
32 #include "llvm/ADT/SmallSet.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringRef.h"
35 #include "llvm/ADT/StringSwitch.h"
36 #include "llvm/Support/Casting.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
39 #include "mlir/IR/Attributes.h"  // from @llvm-project
40 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
41 #include "mlir/IR/Matchers.h"  // from @llvm-project
42 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
43 #include "mlir/IR/Value.h"  // from @llvm-project
44 #include "mlir/Pass/Pass.h"  // from @llvm-project
45 #include "mlir/Support/LLVM.h"  // from @llvm-project
46 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
47 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
48 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
49 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
50 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
51 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
52 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
53 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
54 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
55 
56 namespace mlir {
57 namespace TFL {
58 
59 //===----------------------------------------------------------------------===//
60 // The actual Optimize Pass.
61 namespace {
62 constexpr char kRelu[] = "RELU";
63 constexpr char kRelu6[] = "RELU6";
64 constexpr char kRelu1[] = "RELU_N1_TO_1";
65 
66 bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
67   if (sq_op.getType().cast<ShapedType>().getRank() - 1 ==
68           *axis.getValues<int>().begin() ||
69       *axis.getValues<int>().begin() == -1) {
70     return true;
71   }
72   if (sq_op.getType().cast<ShapedType>().getRank() != axis.getNumElements()) {
73     return false;
74   }
75   auto shape = sq_op.getType().cast<ShapedType>();
76   SmallVector<int, 4> elems{axis.getValues<int>().begin(),
77                             axis.getValues<int>().end()};
78   for (int i = 0; i < shape.getRank(); ++i) {
79     if (i != elems[i]) return false;
80   }
81   return true;
82 }
83 
84 using ::llvm::cast;
85 
86 // Optimize TFLite operations in functions.
87 struct Optimize : public PassWrapper<Optimize, FunctionPass> {
88   void runOnFunction() override;
89 };
90 
91 // Returns whether the given type `a` is broadcast-compatible with `b`.
92 bool IsBroadcastableElementsAttrAndType(Type a, Type b) {
93   return OpTrait::util::getBroadcastedType(a, b) != Type();
94 }
95 
96 // Returns whether the resultant type of any broadcastable operation with
97 // operands `a` and `b` matches `expected_output`. Returns false if `a` is not
98 // broadcast-compatible with `b`.
99 bool OperandsBroadcastToOutputType(Type a, Type b, Type expected_output) {
100   Type output_element_type =
101       expected_output.cast<ShapedType>().getElementType();
102   Type broadcasted_type =
103       OpTrait::util::getBroadcastedType(a, b, output_element_type);
104   return broadcasted_type != Type() && broadcasted_type == expected_output;
105 }
106 
107 // Returns whether the dimensions of `type1` are the same as the trailing
108 // dimensions of `type2`. This is more restrictive than broadcast compatibility.
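// For example (illustrative shapes, not taken from this file's tests):
//   IsTailOfShape(tensor<1x128xf32>, tensor<8x1x128xf32>) returns true, while
//   IsTailOfShape(tensor<128x1xf32>, tensor<8x1x128xf32>) returns false.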
109 bool IsTailOfShape(Type type1, Type type2) {
110   auto tail_type = type1.dyn_cast<ShapedType>();
111   auto full_type = type2.dyn_cast<ShapedType>();
112   if (!tail_type || !full_type || !tail_type.hasRank() ||
113       !full_type.hasRank() || tail_type.getRank() > full_type.getRank())
114     return false;
115   auto i1 = tail_type.getShape().rbegin(), e1 = tail_type.getShape().rend();
116   auto i2 = full_type.getShape().rbegin();
117   return std::equal(i1, e1, i2);
118 }
119 
120 bool CanFuseConvOrDepthwiseConvShapes(const ArrayRef<int64_t> filter_shape,
121                                       const ArrayRef<int64_t> elements_shape,
122                                       bool is_depthwise) {
123   // Make sure the elements tensor has a shape where all dimensions are 1
124   // except the last one.
125   // Also, the elements tensor must be of rank 1, 4, or 0 (scalar).
126   const auto elements_rank = elements_shape.size();
127   for (int i = 0; i < static_cast<int>(elements_shape.size()) - 1; ++i) {
128     if (elements_shape[i] != 1) return false;
129   }
130   if (elements_rank != 1 && elements_rank != 0 && elements_rank != 4) {
131     return false;
132   }
133   auto elements_depth = elements_shape.empty() ? 1 : elements_shape.back();
134   // If the elements depth equals 1 (i.e., a scalar or a tensor with a single
135   // element), then we can let the binary op broadcast the elements.
136   if (elements_depth == 1) {
137     return true;
138   }
139 
140   // In TFLite, Conv2D uses the OHWI filter format and DepthwiseConv uses 1HWO.
141   // For conv:
142   // check whether the first dimension of the filter (output channels) equals
143   // the elements depth.
144   // For depthwise conv: check whether the last dimension of the filter does.
145   if (filter_shape.empty() ||
146       (is_depthwise ? filter_shape.back() != elements_depth
147                     : filter_shape[0] != elements_depth))
148     return false;
149   return true;
150 }
151 
152 bool CanFuseConvOrDepthwiseConv(Value filter, Attribute val,
153                                 bool is_depthwise) {
154   const auto elements = val.dyn_cast<DenseElementsAttr>();
155   if (!elements) {
156     return false;
157   }
158   const auto elements_shape = elements.getType().getShape();
159   const auto filter_shape = filter.getType().cast<ShapedType>().getShape();
160   return CanFuseConvOrDepthwiseConvShapes(filter_shape, elements_shape,
161                                           is_depthwise);
162 }
163 
164 bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val,
165                                 bool is_depthwise) {
166   if (const auto elements = val.dyn_cast<DenseElementsAttr>()) {
167     if (const auto filter_elements = filter.dyn_cast<DenseElementsAttr>()) {
168       return CanFuseConvOrDepthwiseConvShapes(
169           filter_elements.getType().getShape(), elements.getType().getShape(),
170           is_depthwise);
171     }
172   }
173   return false;
174 }
175 
176 // Returns true if we can eliminate the GatherNdOp or ScatterNdOp. When the
177 // values of `indices` are from 0 to n-1, the output tensor is identical to
178 // `params`.
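// For instance (an illustrative case, not from this file): with `params` of
// shape [3, 4] and `indices` = [[0], [1], [2]], tfl.gather_nd returns `params`
// unchanged, so the gather can be removed.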
179 bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params,
180                                               DenseIntElementsAttr indices) {
181   auto params_type = params.getType().dyn_cast<RankedTensorType>();
182   auto indices_type = indices.getType().dyn_cast<RankedTensorType>();
183   // Checks that the shape of `params` is [n, ...] and that of `indices` is
184   // [n, 1]. Such 2-D `indices` index only the first dimension of `params`; as
185   // long as they enumerate it in order, the output is identical to the input.
186   if (!params_type || !indices_type || indices_type.getRank() != 2 ||
187       indices_type.getDimSize(0) != params_type.getDimSize(0) ||
188       indices_type.getDimSize(1) != 1)
189     return false;
190 
191   // Checks that the values in `indices` are 0 to n-1, in order.
192   int cur_value = 0;
193   for (const auto &v : indices.getValues<APInt>()) {
194     if (v.getSExtValue() != cur_value) return false;
195     ++cur_value;
196   }
197 
198   return true;
199 }
200 
201 // Expands Attribute 'a' to 4D with all dimensions equal to 1 except one.
202 // Which dimension is kept depends on whether 'is_depthwise' is true or false.
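// E.g. (illustrative): a length-8 vector becomes shape [8, 1, 1, 1] for Conv2D
// (OHWI filters) and [1, 1, 1, 8] for DepthwiseConv2D (1HWO filters).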
203 ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) {
204   auto elements = a.dyn_cast<DenseElementsAttr>();
205   auto shape = elements.getType().getShape();
206   if (!shape.empty()) {
207     // Checks that elements are essentially 1d.
208     assert(elements.getNumElements() == shape.back());
209   }
210   std::vector<int64_t> shape_data = {1, 1, 1, 1};
211   const int vector_length = elements.getNumElements();
212   if (is_depthwise)
213     shape_data[3] = vector_length;
214   else
215     shape_data[0] = vector_length;
216   auto new_shape =
217       RankedTensorType::get(shape_data, elements.getType().getElementType());
218   return elements.reshape(new_shape);
219 }
220 
221 ElementsAttr ExpandTo4DForConv(Attribute a) {
222   return ExpandTo4DForConvImpl(a, false);
223 }
224 
225 ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
226   return ExpandTo4DForConvImpl(a, true);
227 }
228 
229 TypeAttr RescaleQtype(Type input, Attribute factor) {
230   return quant::RescaleQuantizedType(input, factor);
231 }
232 
233 // Returns the shape of a ranked tensor as a DenseElementsAttr.
234 // Precondition: output_val's type is a ranked tensor.
235 DenseElementsAttr GetShape(Value output_val) {
236   auto output_type = output_val.getType().cast<RankedTensorType>();
237   auto shape_vector = output_type.getShape();
238   std::vector<int32_t> shape;
239   shape.reserve(shape_vector.size());
240   for (auto shape_object : shape_vector) {
241     shape.push_back(shape_object);
242   }
243   return mlir::DenseElementsAttr::get(
244       RankedTensorType::get(
245           {static_cast<int>(shape.size())},
246           mlir::IntegerType::get(output_val.getContext(), 32)),
247       llvm::makeArrayRef(shape));
248 }
249 
250 static Type GetShapeStrippedType(TypeAttr type_attr) {
251   auto type = type_attr.getValue();
252   auto shaped_type = type.dyn_cast<ShapedType>();
253   if (shaped_type) {
254     return shaped_type.getElementType();
255   } else {
256     return type;
257   }
258 }
259 
260 // Returns `true` if reducing `axes` in `input` with `keep_dims=true` results in
261 // the specified `shape` and `false` otherwise.
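// E.g. (illustrative): for an input of type tensor<2x3x4xf32>, axes = [1] and
// shape = [2, 1, 4] match, whereas shape = [2, 4] does not, because keep_dims
// retains the reduced dimension as 1.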
262 static bool ShapeMatchesReduceWithKeepAxes(Value input,
263                                            const mlir::Attribute &axes,
264                                            const mlir::Attribute &shape) {
265   RankedTensorType type = input.getType().dyn_cast_or_null<RankedTensorType>();
266   if (!type) return false;
267 
268   DenseIntElementsAttr axes_attr =
269       axes.dyn_cast_or_null<DenseIntElementsAttr>();
270   DenseIntElementsAttr shape_attr =
271       shape.dyn_cast_or_null<DenseIntElementsAttr>();
272   if (!axes_attr || !shape_attr) return false;
273 
274   if (shape_attr.getNumElements() != type.getRank()) return false;
275 
276   llvm::SmallSet<uint64_t, 4> axes_set;
277   for (auto a : axes_attr.getIntValues()) {
278     axes_set.insert(a.getZExtValue());
279   }
280 
281   auto type_shape = type.getShape();
282   for (uint64_t i = 0; i < type.getRank(); ++i) {
283     if (axes_set.contains(i)) {
284       if (shape_attr.getValue<APInt>({i}) != 1) return false;
285     } else {
286       if (shape_attr.getValue<APInt>({i}) != type_shape[i]) return false;
287     }
288   }
289   return true;
290 }
291 
292 static bool FloatValueEquals(const Attribute &attr, double value) {
293   auto fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>();
294   if (!fp_attr) return false;
295 
296   if (fp_attr.isSplat()) {
297     return fp_attr.getSplatValue<APFloat>().isExactlyValue(value);
298   }
299   return llvm::all_of(fp_attr.getFloatValues(), [value](const APFloat &f) {
300     return f.isExactlyValue(value);
301   });
302 }
303 
304 #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"
305 
306 // Fuse Add with the preceding FullyConnected.
307 // TODO(b/136285429): Move to tablegen when variadic is supported
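// Schematically (hand-written illustration, not IR produced by this file):
//   %fc  = tfl.fully_connected(%input, %filter, %bias)   // activation NONE
//   %out = tfl.add(%fc, %cst)                             // activation ACT
// becomes
//   %new_bias = tfl.add(%bias, %cst)                      // activation NONE
//   %out = tfl.fully_connected(%input, %filter, %new_bias)  // activation ACT
// (when the original bias is none and %cst is a scalar, a {num_channels} zero
// tensor is created first so the scalar can be broadcast into a 1-D bias).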
308 struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
309   using OpRewritePattern<TFL::AddOp>::OpRewritePattern;
310 
311   LogicalResult matchAndRewrite(TFL::AddOp add_op,
312                                 PatternRewriter &rewriter) const override {
313     // Match Add.
314     DenseElementsAttr added_value;
315     Value constant_val = add_op.rhs();
316     if (!matchPattern(constant_val, m_Constant(&added_value))) return failure();
317 
318     // Match Fully Connected.
319     auto fc_op =
320         dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs().getDefiningOp());
321     if (!fc_op) return failure();
322 
323     // Check if the constant RHS is either 0D (scalar), or 1D with a
324     // `{num_channels}` shape.
325     auto constant_val_type = constant_val.getType().cast<TensorType>();
326 
327     // In the TFLite FullyConnected definition, the bias must be a 1D tensor
328     // where the number of elements is equal to the number of channels.
329     // If it's not 1D or 0D (which can be broadcast to 1D), reject the
330     // matching.
331     bool is_scalar_rhs = false;
332     if (constant_val_type.getRank() == 0) {
333       is_scalar_rhs = true;
334     } else if (constant_val_type.getRank() != 1) {
335       return failure();
336     }
337 
338     Value filter = fc_op.filter();
339     Value bias = fc_op.bias();
340     ElementsAttr bias_value;
341     const bool is_none_bias = bias.getType().isa<NoneType>();
342     if (fc_op.fused_activation_function() != "NONE") return failure();
343 
344     if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value)))
345       return failure();
346 
347     // Rewrite
348     if (is_none_bias) {
349       if (is_scalar_rhs) {
350         // If the `constant_val` is a scalar, we need the shape of the filter
351         // to properly broadcast the scalar to a `{num_channels}` shape.
352 
353         // Get the number of channels if possible.
354         auto filter_type = filter.getType().dyn_cast<RankedTensorType>();
355         // Filter must be a `2D` tensor with `{num_channels, num_features}`
356         // shape. The following check is rejecting unknown rank (-1).
357         if (filter_type == nullptr || filter_type.getRank() != 2) {
358           return failure();
359         }
360         int num_channels = filter_type.getShape()[0];
361 
362         // Create a zero tensor with shape {num_channels}; its type needs to
363         // be the same as constant_val's.
364         // This is a way to gracefully handle a scalar tensor. The Add will
365         // always be constant-folded away regardless of whether `constant_val`
366         // is a scalar.
367         RankedTensorType type = RankedTensorType::get(
368             {num_channels}, constant_val_type.getElementType());
369         auto attr = rewriter.getZeroAttr(type);
370         bias = rewriter.create<ConstantOp>(add_op.getLoc(), type, attr);
371         auto none_af = rewriter.getStringAttr("NONE");
372         bias =
373             rewriter.create<AddOp>(add_op.getLoc(), bias, constant_val, none_af)
374                 .output();
375       } else {
376         // If there is no pre-existing bias and the `constant_val` is 1D,
377         // simply use `constant_val` as the bias.
378         bias = constant_val;
379       }
380     } else {
381       auto none_af = rewriter.getStringAttr("NONE");
382       bias =
383           rewriter.create<AddOp>(add_op.getLoc(), bias, constant_val, none_af)
384               .output();
385     }
386 
387     auto fc = rewriter.create<TFL::FullyConnectedOp>(
388         FusedLoc::get({fc_op.getLoc(), add_op.getLoc()}, fc_op.getContext()),
389         add_op.getType(),
390         /*input=*/fc_op.input(),
391         /*filter=*/filter,
392         /*bias=*/bias,
393         /*fused_activation_function=*/
394         rewriter.getStringAttr(add_op.fused_activation_function()),
395         /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
396         /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
397     rewriter.replaceOp(add_op, fc.output());
398 
399     return success();
400   }
401 };
402 
403 // TODO(b/136285429): Move to tablegen when variadic is supported.
404 template <typename ReluXOp, char const *Act>
405 struct FuseFullyConnectedAndReluX : public OpRewritePattern<ReluXOp> {
406   using OpRewritePattern<ReluXOp>::OpRewritePattern;
407 
408   LogicalResult matchAndRewrite(ReluXOp relu_op,
409                                 PatternRewriter &rewriter) const override {
410     Operation *input = relu_op.getOperand().getDefiningOp();
411     if (!isa_and_nonnull<FullyConnectedOp>(input)) return failure();
412     auto fully_connected_op = cast<FullyConnectedOp>(input);
413     if (fully_connected_op.fused_activation_function() != "NONE")
414       return failure();
415 
416     auto new_activation_func = rewriter.getStringAttr(Act);
417     auto new_weights_format =
418         rewriter.getStringAttr(fully_connected_op.weights_format());
419     auto new_keep_num_dims =
420         rewriter.getBoolAttr(fully_connected_op.keep_num_dims());
421     auto fc = rewriter.create<FullyConnectedOp>(
422         FusedLoc::get({fully_connected_op.getLoc(), relu_op.getLoc()},
423                       relu_op.getContext()),
424         relu_op.getType(), fully_connected_op.input(),
425         fully_connected_op.filter(), fully_connected_op.bias(),
426         new_activation_func, new_weights_format, new_keep_num_dims);
427     rewriter.replaceOp(relu_op, fc.output());
428 
429     return success();
430   }
431 };
432 
433 // Fuse Mul with the preceding FullyConnected.
434 // TODO(b/136285429): Move to tablegen when variadic is supported
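// Schematically (illustration only, not IR produced by this file):
//   tfl.mul(tfl.fully_connected(%x, %w, %b), %cst)
// becomes
//   tfl.fully_connected(%x, %w * %cst, %b * %cst)
// provided %cst is effectively 1-D along the output-channel dimension.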
435 struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
436   using OpRewritePattern<TFL::MulOp>::OpRewritePattern;
437 
438   LogicalResult matchAndRewrite(TFL::MulOp mul_op,
439                                 PatternRewriter &rewriter) const override {
440     // If we are broadcasting on the lhs then don't fold the multiply as it
441     // would increase the amount of compute done by the fully connected op.
442     if (mul_op.lhs().getType() != mul_op.getType()) return failure();
443 
444     // Mul.
445     DenseElementsAttr cst;
446     Value constant_val = mul_op.rhs();
447     if (!matchPattern(constant_val, m_Constant(&cst))) return failure();
448 
449     // Fully Connected.
450     auto fc_op =
451         dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs().getDefiningOp());
452     if (!fc_op) return failure();
453     Value filter = fc_op.filter();
454     Value bias = fc_op.bias();
455     ElementsAttr cst_tmp;
456     if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure();
457     if (!bias.getType().isa<NoneType>() &&
458         !matchPattern(bias, m_Constant(&cst_tmp)))
459       return failure();
460     if (fc_op.fused_activation_function() != "NONE") return failure();
461 
462     // Only fuse multiplier if all dimensions other than the depth dimension
463     // are equal to 1 since otherwise
464     // `matmul(x, filter) * cst != matmul(x, filter * cst)`
465     // even if `filter` and `cst` are broadcastable.
466     auto shape = cst.getType().getShape();
467     if (!IsDimensionsDegenerateExceptLastOne(shape)) return failure();
468 
469     int64_t element_size = shape.empty() ? 1 : shape[shape.size() - 1];
470     // Expand and transpose the multiplier since the weights use the
471     // OHWI data format in TFLite.
472     int64_t normalized_shape[2] = {element_size, 1};
473     auto new_cst = cst.reshape(RankedTensorType::get(
474         normalized_shape, cst.getType().getElementType()));
475     Type new_type = new_cst.getType();
476     if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
477       return failure();
478     }
479 
480     auto new_op =
481         rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
482     Value new_const_val = new_op.getResult();
483 
484     // Rewrite. Since the folder of TFL::MulOp cannot broadcast the operands,
485     // TF::MulOp is used to fold the constant.
486     // TODO(b/139192933): switch to the TFL constant folding
487     auto new_filter =
488         rewriter.create<TF::MulOp>(mul_op.getLoc(), filter, new_const_val).z();
489     // If bias isn't None, it needs to be multiplied as well.
490     if (!bias.getType().isa<NoneType>()) {
491       bias =
492           rewriter.create<TF::MulOp>(mul_op.getLoc(), bias, constant_val).z();
493     }
494 
495     auto fc = rewriter.create<TFL::FullyConnectedOp>(
496         FusedLoc::get({fc_op.getLoc(), mul_op.getLoc()}, fc_op.getContext()),
497         mul_op.getType(),
498         /*input=*/fc_op.input(),
499         /*filter=*/new_filter,
500         /*bias=*/bias,
501         /*fused_activation_function=*/
502         rewriter.getStringAttr(mul_op.fused_activation_function()),
503         /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
504         /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
505     rewriter.replaceOp(mul_op, fc.output());
506 
507     return success();
508   }
509 };
510 
511 // Fuse Mul with the preceding affine ops. This is a C++ implementation of the
512 // following TableGen pattern, which cannot derive the result type of
513 // the TFL_DequantizeOp.
514 // def : Pat<(TFL_MulOp (TFL_Conv2DOp:$conv_output $input,
515 //                          (TFL_DequantizeOp (TFL_QuantizeOp
516 //                              (ConstantOp F32ElementsAttr:$filter), $qtype)),
517 //                          (ConstantOp F32ElementsAttr:$bias),
518 //                          $h_factor, $w_factor, TFL_AF_None,
519 //                          $padding, $stride_h, $stride_w),
520 //                      (ConstantOp F32ElementsAttr:$value), $act_fn),
521 //           (TFL_Conv2DOp $input,
522 //                      (TFL_DequantizeOp (TFL_QuantizeOp
523 //                          (TFL_MulOp (ConstantOp $filter),
524 //                                     (ConstantOp (ExpandTo4DForConv $value)),
525 //                                      TFL_AF_None),
526 //                          (RescaleQtype $qtype, $value))),
527 //                      (TFL_MulOp (ConstantOp $bias), (ConstantOp $value),
528 //                          TFL_AF_None),
529 //                      $h_factor, $w_factor, $act_fn,
530 //                      $padding, $stride_h, $stride_w),
531 //         [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
532 //          (HasOneUse $conv_output),
533 //          (IsPerAxisQuantization $qtype), // per-axis quantization
534 //         ]>;
535 template <typename AffineOpType>
536 struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern<TFL::MulOp> {
537   using OpRewritePattern<TFL::MulOp>::OpRewritePattern;
538 
539   LogicalResult matchAndRewrite(TFL::MulOp mul_op,
540                                 PatternRewriter &rewriter) const override {
541     // Mul. A 1-D rhs is required for batch normalization.
542     DenseElementsAttr gamma_cst;
543     Value gamma = mul_op.rhs();
544     if (!matchPattern(gamma, m_Constant(&gamma_cst))) return failure();
545     if (gamma_cst.getType().getRank() != 1) return failure();
546 
547     // Affine op
548     Operation *mul_op_lhs = mul_op.lhs().getDefiningOp();
549     auto fc_op = dyn_cast_or_null<AffineOpType>(mul_op_lhs);
550     if (!fc_op) return failure();
551     Value filter = fc_op.filter();
552     Value bias = fc_op.bias();
553 
554     // QDQs
555     auto dq_op = dyn_cast_or_null<TFL::DequantizeOp>(filter.getDefiningOp());
556     if (!dq_op) return failure();
557     auto q_op =
558         dyn_cast_or_null<TFL::QuantizeOp>(dq_op.input().getDefiningOp());
559     if (!q_op) return failure();
560     filter = q_op.input();
561 
562     // weight constant
563     ElementsAttr cst_tmp;
564     if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure();
565     if (!bias.getType().isa<NoneType>() &&
566         !matchPattern(bias, m_Constant(&cst_tmp)))
567       return failure();
568     if (fc_op.fused_activation_function() != "NONE") return failure();
569 
570     // Broadcast the constant operand of Mul if it isn't compatible with the
571     // filter input. We only support broadcasting the operand along the depth
572     // dimension, when the operand's depth is 1.
573     rewriter.setInsertionPoint(q_op);
574     Location loc = fc_op.getLoc();
575     Value broadcasted_gamma;
576     if (isa<TFL::Conv2DOp>(mul_op_lhs)) {
577       auto mul_rhs = ExpandTo4DForConv(gamma_cst);
578       broadcasted_gamma = rewriter.create<ConstOp>(loc, mul_rhs);
579     } else if (isa<TFL::DepthwiseConv2DOp>(mul_op_lhs)) {
580       auto mul_rhs = ExpandTo4DForDepthwiseConv(gamma_cst);
581       broadcasted_gamma = rewriter.create<ConstOp>(loc, mul_rhs);
582     } else {
583       return failure();
584     }
585 
586     // Rewrite the filter constant. Since the folder of TFL::MulOp cannot
587     // broadcast the operands, TF::MulOp is used to fold the constant.
588     auto new_filter =
589         rewriter.create<TF::MulOp>(loc, filter, broadcasted_gamma).z();
590     // Update the scale in the quantize op.
591     auto new_qtype = RescaleQtype(q_op.qtype(), gamma_cst);
592     if (!new_qtype) return failure();
593     rewriter.replaceOpWithNewOp<TFL::QuantizeOp>(q_op, new_qtype.getValue(),
594                                                  new_filter, new_qtype);
595 
596     // If bias isn't None, it needs to be multiplied as well.
597     if (!bias.getType().isa<NoneType>()) {
598       rewriter.setInsertionPoint(fc_op);
599       auto new_bias = rewriter.create<TF::MulOp>(loc, bias, gamma);
600       fc_op.getOperation()->replaceUsesOfWith(bias, new_bias);
601     }
602 
603     // Remove the trailing mul op.
604     mul_op.replaceAllUsesWith(fc_op.getResult());
605     return success();
606   }
607 };
608 
609 using FuseConv2DAndMulWithQDQs = FuseAffinOpAndMulWithQDQs<TFL::Conv2DOp>;
610 using FuseDepthwiseConv2DAndMulWithQDQs =
611     FuseAffinOpAndMulWithQDQs<TFL::DepthwiseConv2DOp>;
612 
613 // Fuse a binary op with the following affine operation.
614 template <typename AffineOpType>
615 struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
616   using OpRewritePattern<AffineOpType>::OpRewritePattern;
617 
618   LogicalResult matchAndRewrite(AffineOpType fc_op,
619                                 PatternRewriter &rewriter) const override {
620     // Binary op.
621     Operation *binary_op = fc_op.input().getDefiningOp();
622     if (!binary_op || binary_op->getNumOperands() != 2) return failure();
623     // We only handle the case where the RHS is a scalar.
624     // TODO(fengliuai): Currently the canonicalizer pass cannot guarantee that
625     // the constant operand is on the RHS; we need to consider an LHS constant
626     // operand if necessary.
627     DenseFPElementsAttr cst;
628     if (!matchPattern(binary_op->getOperand(1), m_Constant(&cst)))
629       return failure();
630     if (cst.getNumElements() != 1) return failure();
631     APFloat cst_value = *cst.float_value_begin();
632 
633     // Affine op.
634     Value filter = fc_op.filter();
635     Value bias = fc_op.bias();
636     DenseFPElementsAttr filter_cst, bias_cst;
637     if (!matchPattern(filter, m_Constant(&filter_cst))) {
638       // The filter may be quantized; in that case, track back to the real constant.
639       auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter.getDefiningOp());
640       if (!dq) return failure();
641       auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input().getDefiningOp());
642       if (!q || !matchPattern(q.input(), m_Constant(&filter_cst))) {
643         return failure();
644       }
645       filter = q.input();
646     }
647     if (!bias.getType().isa<NoneType>() &&
648         !matchPattern(bias, m_Constant(&bias_cst)))
649       return failure();
650     ShapedType filter_type = filter_cst.getType();
651 
652     if (llvm::isa<AddOp, SubOp>(binary_op)) {
653       auto padding = fc_op->template getAttrOfType<StringAttr>("padding");
654       if (padding && padding.getValue() != "VALID") return failure();
655 
656       // The fusion of add/sub is actually applying the following
657       // transformation:
658       // w * (x + c) + b => w * x + (w * c + b)
659       // so we have to update the bias.
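      // As a concrete (illustrative) example: for a fully connected op with
      // filter w = [[1, 2], [3, 4]], bias b = [1, 1] and an added constant
      // c = 0.5, the new bias is [1 + (1 + 2) * 0.5, 1 + (3 + 4) * 0.5] =
      // [2.5, 4.5]; the loop below adds w * c into the bias per channel.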
660       if (llvm::isa<SubOp>(binary_op)) cst_value.changeSign();
661 
662       auto bias_and_slice =
663           GetBiasDimAndSliceSize(filter_type.getShape(), fc_op);
664       int64_t bias_size = bias_and_slice.first;
665       int64_t slice_size = bias_and_slice.second;
666       ShapedType new_bias_type =
667           RankedTensorType::get({bias_size}, filter_type.getElementType());
668 
669       // The new bias should be a 1-D tensor with length equal to the bias
670       // dimension of the weight.
671       SmallVector<APFloat, 4> new_bias_values;
672       if (bias.getType().isa<NoneType>()) {  // none bias, a list of zeros
673         new_bias_values.resize(bias_size, APFloat(0.0));
674       } else if (bias_cst.getNumElements() == 1) {  // scalar bias, broadcast it
675         new_bias_values.resize(bias_size, *bias_cst.float_value_begin());
676       } else if (bias_cst.getNumElements() == bias_size) {  // 1-d bias, copy it
677         new_bias_values.insert(new_bias_values.begin(),
678                                bias_cst.float_value_begin(),
679                                bias_cst.float_value_end());
680       } else {
681         return failure();
682       }
683 
684       int64_t flatten_index = 0;
685       for (auto fp_it = filter_cst.float_value_begin(),
686                 fp_end = filter_cst.float_value_end();
687            fp_it != fp_end; ++fp_it) {
688         int bias_index = (flatten_index++ / slice_size) % bias_size;
689 
690         new_bias_values[bias_index] =
691             new_bias_values[bias_index] + *fp_it * cst_value;
692       }
693       auto new_bias = DenseFPElementsAttr::get(new_bias_type, new_bias_values);
694       auto new_bias_op =
695           rewriter.create<ConstOp>(fc_op.getLoc(), new_bias_type, new_bias);
696       fc_op.setOperand(0, binary_op->getOperand(0));
697       fc_op.setOperand(2, new_bias_op);
698     } else if (llvm::isa<MulOp, DivOp>(binary_op)) {
699       // The fusion of mul/div is actually applying the following
700       // transformation, where ' stands for mul or div:
701       // w * (x ' c) + b => (w ' c) * x + b
702       // so we have to update the weight.
703       bool is_mul = llvm::isa<MulOp>(binary_op);
704       auto new_filter =
705           filter_cst.mapValues(filter_type.getElementType(), [&](APFloat it) {
706             return (is_mul ? it * cst_value : it / cst_value).bitcastToAPInt();
707           });
708       // We recreate the constant op in case it is shared by the other ops. This
709       // might increase the model size.
710       auto new_filter_op = rewriter.create<ConstOp>(
711           fc_op.getLoc(), filter.getType(), new_filter);
712       fc_op.setOperand(0, binary_op->getOperand(0));
713       if (fc_op.filter() != filter) {
714         // This filter goes through quantize and dequantize ops, so we just
715         // need to update the weight feeding the quantize op.
716         filter.replaceAllUsesWith(new_filter_op);
717       } else {
718         // This filter doesn't go through quantize and dequantize ops, so
719         // we update the weight of the affine op directly.
720         fc_op.setOperand(1, new_filter_op);
721       }
722     } else {
723       return failure();
724     }
725     return success();
726   }
727 
728  private:
729   // Returns the length of the channel dimension and also the slice size for
730   // each position along the channel dimension. tfl.conv2d and
731   // tfl.fully_connected have a leading channel dimension, while
732   // tfl.depthwise_conv2d has a trailing channel dimension. This utility
733   // derives both values from the op's properties.
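  // E.g. (illustrative): for a tfl.conv2d filter of shape [8, 3, 3, 16]
  // (OHWI, channel dimension index 0) this returns {8, 3 * 3 * 16} = {8, 144}.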
734   static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
735       ArrayRef<int64_t> filter_shape, AffineOpType op) {
736     // Channel dimension index is specified as op property
737     auto channel_index_iter = filter_shape.begin();
738     std::advance(channel_index_iter, op.GetChannelDimIndex());
739     // The slice size is the size of the data in the higher (trailing) dimensions.
740     int64_t slice_size =
741         std::accumulate(std::next(channel_index_iter), filter_shape.end(), 1,
742                         std::multiplies<int64_t>());
743     return {*channel_index_iter, slice_size};
744   }
745 };
746 
747 // If the operand to a broadcastable op is a splat constant, try to replace it
748 // with a 0-d constant, e.g. before this optimization,
749 //   %cst = constant dense<1.0> : tensor<16x16x4xf32>
750 //   %0 = "tfl.conv_2d"...
751 //   %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<16x16x4xf32>)
752 // After this optimization:
753 //   %cst = constant dense<1.0> : tensor<f32>
754 //   %0 = "tfl.conv_2d"...
755 //   %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<f32>)
756 // This pattern can enable more fusing opportunities when the binary op is
757 // following conv ops.
758 template <typename BinaryOpType>
759 struct ScalarizeSplatConstantForBroadcastableOps
760     : public OpRewritePattern<BinaryOpType> {
761   using OpRewritePattern<BinaryOpType>::OpRewritePattern;
762 
763   LogicalResult matchAndRewrite(BinaryOpType binary_op,
764                                 PatternRewriter &rewriter) const override {
765     DenseElementsAttr splat_elements_attr;
766     if (!IsScalarizableSplatConstant(binary_op.rhs(), &splat_elements_attr)) {
767       return failure();
768     }
769 
770     constexpr int kSplatOperandIndex = 1;
771     auto result_type =
772         binary_op.getResult().getType().template cast<ShapedType>();
773     mlir::Value non_splat_operand =
774         binary_op.getOperand(1 - kSplatOperandIndex);
775     auto non_splat_operand_type =
776         non_splat_operand.getType().cast<ShapedType>();
777     // If the other operand's shape does not equal the result shape, then we
778     // cannot scalarize the splat constant because the result shape relies on
779     // the splat constant op's shape for broadcasting.
780     if (!non_splat_operand_type.hasStaticShape() ||
781         non_splat_operand_type.getShape() != result_type.getShape() ||
782         non_splat_operand_type.getRank() > 4) {
783       return failure();
784     }
785 
786     // If the non-splat operand is not a fusable affine op, there is no need to
787     // apply this transformation.
788     if (!CanFuseAffineOp(non_splat_operand.getDefiningOp(), binary_op)) {
789       return failure();
790     }
791 
792     // Creates a new scalar constant op using the splat value.
793     mlir::Value splat_operand = binary_op.getOperand(kSplatOperandIndex);
794     auto scalar_elements_attr = DenseElementsAttr::get(
795         RankedTensorType::get({},
796                               splat_elements_attr.getType().getElementType()),
797         splat_elements_attr.getSplatValue());
798 
799     auto scalar_constant_op = rewriter.create<ConstantOp>(
800         splat_operand.getLoc(), scalar_elements_attr.getType(),
801         scalar_elements_attr);
802 
803     binary_op.setOperand(kSplatOperandIndex, scalar_constant_op);
804     return success();
805   }
806 
807  private:
808   // Returns true if this value is a splat constant op which can be scalarized.
809   // Also returns the elements attr if this value is indeed a splat constant.
810   bool IsScalarizableSplatConstant(mlir::Value value,
811                                    DenseElementsAttr *elements_attr) const {
812     if (!matchPattern(value, m_Constant(elements_attr))) {
813       return false;
814     }
815     auto element_type = value.getType().cast<ShapedType>().getElementType();
816     // Ignore per-axis quantized constants because after converting to a scalar,
817     // we would lose the per-axis quantization parameters.
818     if (element_type.isa<quant::UniformQuantizedPerAxisType>()) {
819       return false;
820     }
821     if (IsScalar(value)) {
822       return false;
823     }
824     return elements_attr->isSplat();
825   }
826 
827   // Returns true if this value has a scalar shaped type.
828   bool IsScalar(mlir::Value value) const {
829     auto type = value.getType().dyn_cast<ShapedType>();
830     if (!type) {
831       return false;
832     }
833     if (!type.hasStaticShape()) {
834       return false;
835     }
836     return type.getNumElements() == 1;
837   }
838 
839   // Returns true if we can fuse an affine op with the consuming binary op.
840   bool CanFuseAffineOp(Operation *affine_op, Operation *binary_op) const {
841     if (!isa_and_nonnull<TFL::Conv2DOp, TFL::DepthwiseConv2DOp,
842                          TFL::FullyConnectedOp>(affine_op)) {
843       return false;
844     }
845     DenseElementsAttr value;
846     // Check that the bias is a constant if it is not none.
847     Value bias = affine_op->getOperand(2);
848     if (!bias.getType().isa<NoneType>() &&
849         !matchPattern(bias, m_Constant(&value))) {
850       return false;
851     }
852     // If the binary op is mul/div, also check that filter is constant.
853     if (isa<TFL::MulOp, TFL::DivOp>(binary_op) &&
854         !matchPattern(affine_op->getOperand(1), m_Constant(&value))) {
855       return false;
856     }
857 
858     // We can only fuse F32/BF16.
859     auto is_fusable_type = [](Type t) {
860       Type element_type = t;
861       if (auto shaped_type = t.dyn_cast<ShapedType>()) {
862         element_type = shaped_type.getElementType();
863       }
864       return element_type.isBF16() || element_type.isF32();
865     };
866     for (Type t : binary_op->getOperandTypes()) {
867       if (!is_fusable_type(t)) {
868         return false;
869       }
870     }
871 
872     return true;
873   }
874 };
875 
876 using ScalarizeSplatConstantForSub =
877     ScalarizeSplatConstantForBroadcastableOps<TFL::SubOp>;
878 using ScalarizeSplatConstantForAdd =
879     ScalarizeSplatConstantForBroadcastableOps<TFL::AddOp>;
880 using ScalarizeSplatConstantForMul =
881     ScalarizeSplatConstantForBroadcastableOps<TFL::MulOp>;
882 using ScalarizeSplatConstantForDiv =
883     ScalarizeSplatConstantForBroadcastableOps<TFL::DivOp>;
884 
885 struct ConvertTrivialTransposeOpToReshapeOp
886     : public OpRewritePattern<TFL::TransposeOp> {
887   using OpRewritePattern<TFL::TransposeOp>::OpRewritePattern;
888 
889   LogicalResult matchAndRewrite(TFL::TransposeOp transpose_op,
890                                 PatternRewriter &rewriter) const override {
891     auto input_type = transpose_op.input().getType().cast<ShapedType>();
892     auto output_type = transpose_op.output().getType().cast<ShapedType>();
893     // It's only possible to know whether the transformation is safe if the
894     // input & output shapes are fully known and the permutation is a constant.
895     if (!input_type.hasStaticShape() || !output_type.hasStaticShape())
896       return failure();
897     Value perm = transpose_op.perm();
898     DenseElementsAttr perm_values_attr;
899     if (!matchPattern(perm, m_Constant(&perm_values_attr))) return failure();
900 
901     auto input_shape = input_type.getShape();
902     SmallVector<int64_t, 8> perm_values;
903     for (const auto &dim : perm_values_attr.getIntValues())
904       perm_values.push_back(dim.getSExtValue());
905 
906     // This should never happen unless the input graph is malformed.
907     if (input_shape.size() != perm_values.size()) {
908       transpose_op.emitError(
909           "TransposeOP has inconsistent input and perm values.");
910     }
911 
912     SmallVector<int, 8> old_major_index_ordering;
913     SmallVector<int, 8> new_major_index_ordering;
914     for (int i = 0, end = input_shape.size(); i < end; i++) {
915       if (input_shape[i] != 1) {
916         old_major_index_ordering.push_back(i);
917       }
918 
919       if (input_shape[perm_values[i]] != 1) {
920         new_major_index_ordering.push_back(perm_values[i]);
921       }
922     }
923     if (old_major_index_ordering != new_major_index_ordering) {
924       return failure();
925     }
926 
927     // Rewrite.
928     Location loc = transpose_op.getLoc();
929 
930     SmallVector<int32_t, 8> output_shape_values;
931     for (auto dim : output_type.getShape()) {
932       output_shape_values.push_back(dim);
933     }
934     auto type = mlir::RankedTensorType::get(output_shape_values.size(),
935                                             rewriter.getIntegerType(32));
936     auto new_shape_attr =
937         mlir::DenseIntElementsAttr::get(type, output_shape_values);
938     auto new_shape = rewriter.create<TF::ConstOp>(loc, new_shape_attr);
939 
940     rewriter.replaceOpWithNewOp<TFL::ReshapeOp>(
941         transpose_op, transpose_op.output().getType(), transpose_op.input(),
942         new_shape);
943 
944     return success();
945   }
946 };
947 
948 using FuseBinaryOpToFollowingFullyConnected =
949     FuseBinaryOpToFollowingAffineOp<FullyConnectedOp>;
950 using FuseBinaryOpToFollowingDepthwiseConv2D =
951     FuseBinaryOpToFollowingAffineOp<DepthwiseConv2DOp>;
952 using FuseBinaryOpToFollowingConv2D = FuseBinaryOpToFollowingAffineOp<Conv2DOp>;
953 
954 void Optimize::runOnFunction() {
955   OwningRewritePatternList patterns;
956   auto *ctx = &getContext();
957   auto func = getFunction();
958 
959   // The binary ops might themselves be fused together (e.g. into hard_swish),
960   // so we explore those fusions first and then fuse the binary ops with the
961   // following ops in a second pattern match.
962   TFL::populateWithGenerated(ctx, patterns);
963   patterns.insert<FuseFullyConnectedAndAdd,
964                   FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
965                   FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
966                   FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>,
967                   FuseFullyConnectedAndMul>(ctx);
968   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
969 
970   // Fuse the binary ops with the following ops.
971   OwningRewritePatternList phase_2_patterns;
972   TFL::populateWithGenerated(ctx, phase_2_patterns);
973   phase_2_patterns.insert<
974       ScalarizeSplatConstantForAdd, ScalarizeSplatConstantForSub,
975       ScalarizeSplatConstantForMul, ScalarizeSplatConstantForDiv,
976       FuseFullyConnectedAndAdd, FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
977       FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
978       FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>,
979       FuseFullyConnectedAndMul, FuseBinaryOpToFollowingConv2D,
980       FuseBinaryOpToFollowingDepthwiseConv2D,
981       FuseBinaryOpToFollowingFullyConnected, FuseConv2DAndMulWithQDQs,
982       FuseDepthwiseConv2DAndMulWithQDQs, ConvertTrivialTransposeOpToReshapeOp>(
983       ctx);
984   (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
985 }
986 
987 }  // namespace
988 
989 // Creates an instance of the TensorFlow Lite dialect Optimize pass.
990 std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass() {
991   return std::make_unique<Optimize>();
992 }
993 
994 static PassRegistration<Optimize> pass(
995     "tfl-optimize", "Optimize within the TensorFlow Lite dialect");
996 
997 }  // namespace TFL
998 }  // namespace mlir
999