/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.h"

#include <limits>
#include <memory>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"

namespace mlir {
namespace TFL {
namespace tac {
namespace {

// ================== Common ========================

// Converts any IntegerAttr to an IntegerAttr of i32 type.
// The value is unchanged in the new attribute, but if the value is out of
// the bounds of i32, the function returns failure.
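// For example (illustrative values): an IntegerAttr holding 5 : i64 becomes
// 5 : i32, while one holding 2^40 : i64 makes this function return failure.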
LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) {
  if (attr.getType().isInteger(/*width=*/32)) {
    *attr_i32 = attr;
    return success();
  }

  int64_t value = attr.getInt();
  if (value > std::numeric_limits<int>::max() ||
      value < std::numeric_limits<int>::min()) {
    return failure();
  }

  *attr_i32 = IntegerAttr::get(
      IntegerType::get(attr.getContext(), /*width=*/32), value);
  return success();
}

// Inserts a TFL::ReshapeOp that reshapes `input` to `new_shape_array` with
// the given element type.
TFL::ReshapeOp InsertReshapeOp(Location loc, Value input, Type element_type,
                               llvm::ArrayRef<int64_t> new_shape_array,
                               OpBuilder* builder) {
  auto reshape_shape_type = mlir::RankedTensorType::get(
      new_shape_array.size(), builder->getIntegerType(32));

  // This is to work around the unnecessary i64 -> i32 cast. :(
  // TODO(renjieliu): Revisit this later.
  SmallVector<int32_t, 4> new_shape_array_i32;
  for (auto size : new_shape_array) {
    new_shape_array_i32.push_back(static_cast<int32_t>(size));
  }
  auto new_shape_attr =
      mlir::DenseIntElementsAttr::get(reshape_shape_type, new_shape_array_i32);

  auto new_shape = builder->create<TFL::ConstOp>(loc, new_shape_attr);

  auto reshape_out_type = RankedTensorType::get(new_shape_array, element_type);
  return builder->create<TFL::ReshapeOp>(loc, reshape_out_type, input,
                                         new_shape);
}

// Ensures `op` has a real bias operand at `bias_idx`: if the bias is None,
// creates a zero bias of the appropriate shape and wires it in.
LogicalResult EnsureBias(Operation* op, int bias_idx,
                         PatternRewriter& rewriter) {
  auto bias = op->getOperand(bias_idx);

  if (!bias.getType().isa<NoneType>()) return failure();

  // Proceed to create a zero bias.
  auto output = op->getResult(0);
  auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
  if (!output_type) return failure();

  // The bias should be a vector whose size equals the last output dim.
  int num_units = output_type.getDimSize(output_type.getRank() - 1);
  auto bias_type =
      mlir::RankedTensorType::get({num_units}, output_type.getElementType());

  mlir::DenseElementsAttr bias_attr;
  if (output_type.getElementType().isF32()) {
    float val = 0.0;
    bias_attr = mlir::DenseFPElementsAttr::get(bias_type, val);
  } else {
    // TODO(renjieliu): Refactor this and share the logic with
    // CreateConstOpWithSingleValue. Also, make sure it works with QConst.
    return failure();
  }

  auto zero_bias = rewriter.create<TFL::ConstOp>(op->getLoc(), bias_attr);
  op->setOperand(bias_idx, zero_bias);

  return success();
}

// Pads the constant values of `input_op` at the front with `value_to_pad`
// until the result has `pad_dimensions` elements, and materializes it as a
// new i32 constant.
TF::ConstOp PadConstValues(Operation* input_op, int value_to_pad,
                           int pad_dimensions, Location loc,
                           OpBuilder* builder) {
  if (input_op == nullptr) return nullptr;

  mlir::DenseIntElementsAttr attr;
  if (!matchPattern(input_op, m_Constant(&attr))) {
    return nullptr;
  }

  auto value_shape_type = mlir::RankedTensorType::get(
      {pad_dimensions}, builder->getIntegerType(32));

  SmallVector<int32_t, 4> value_i32;
  value_i32.reserve(pad_dimensions);
  for (int i = 0; i < pad_dimensions - attr.getNumElements(); ++i) {
    value_i32.push_back(value_to_pad);
  }
  for (const auto& size : attr) {
    value_i32.push_back(static_cast<int32_t>(size.getSExtValue()));
  }
  auto new_value_i32_attr =
      mlir::DenseIntElementsAttr::get(value_shape_type, value_i32);

  return builder->create<TF::ConstOp>(loc, new_value_i32_attr);
}

// Rewrites the outputs of `split_op` as TFL::SliceOp results that slice
// `input` along `split_dim` into `num_splits` pieces.
SmallVector<Value, 4> SliceOutputs(Operation* split_op, Value input,
                                   RankedTensorType input_type, int split_dim,
                                   int num_splits, PatternRewriter* rewriter) {
  SmallVector<Value, 4> slice_outputs;
  int begin = 0;
  for (int i = 0; i < num_splits; ++i) {
    // Create slice op.
    // Populate begin & size.
    SmallVector<int32_t, 4> slice_begin;
    SmallVector<int32_t, 4> slice_size;
    auto current_output = split_op->getResult(i);
    auto current_output_type =
        current_output.getType().cast<RankedTensorType>();
    for (int d = 0; d < input_type.getRank(); ++d) {
      if (d == split_dim) {
        // Split dimension.
        slice_begin.push_back(begin);
        int size = current_output_type.getDimSize(d);
        slice_size.push_back(size);
        begin += size;
      } else {
        slice_begin.push_back(0);
        // -1 means all elements.
        slice_size.push_back(-1);
      }
    }

    auto slice_type = mlir::RankedTensorType::get(slice_begin.size(),
                                                  rewriter->getIntegerType(32));
    auto slice_begin_attr =
        mlir::DenseIntElementsAttr::get(slice_type, slice_begin);
    auto slice_size_attr =
        mlir::DenseIntElementsAttr::get(slice_type, slice_size);

    auto slice_begin_const =
        rewriter->create<TFL::ConstOp>(split_op->getLoc(), slice_begin_attr);
    auto slice_size_const =
        rewriter->create<TFL::ConstOp>(split_op->getLoc(), slice_size_attr);

    auto slice_op = rewriter->create<TFL::SliceOp>(
        split_op->getLoc(), current_output_type, input, slice_begin_const,
        slice_size_const);

    // Rewire output.
    slice_outputs.push_back(slice_op.getResult());
  }
  return slice_outputs;
}

}  // namespace

// ================== Pack ========================

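// As an illustrative sketch (the value names and shapes below are invented
// for exposition, not taken from a real model), this pattern rewrites
//   %0 = "tfl.pack"(%a, %b) {axis = 0 : i32, values_count = 2 : i32}
//       : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x2x3xf32>
// into roughly
//   %c = "tfl.concatenation"(%a, %b)
//       {axis = 0 : i32, fused_activation_function = "NONE"}
//       : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<4x3xf32>
//   %0 = "tfl.reshape"(%c, %shape)
//       : (tensor<4x3xf32>, tensor<3xi32>) -> tensor<2x2x3xf32>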
LogicalResult LowerPackIntoConcatReshape::matchAndRewrite(
    TFL::PackOp pack_op, PatternRewriter& rewriter) const {
  // All pack op inputs should have the same ranked tensor type.
  SmallVector<Value, 5> pack_inputs(pack_op.values());
  auto input_type = pack_inputs[0].getType().dyn_cast<RankedTensorType>();
  if (!input_type) return failure();

  // Figure out output shapes.
  SmallVector<int64_t, 4> concat_out_shape;
  SmallVector<int64_t, 4> pack_out_shape;

  const int rank = input_type.getRank();
  int pack_axis = pack_op.axis();
  int count = pack_inputs.size();
  if (pack_axis < 0) {
    pack_axis += rank;
  }

  // Concat out shape.
  for (int i = 0; i < rank; ++i) {
    int dim_size = input_type.getDimSize(i);
    if (i == pack_axis) {
      dim_size *= count;
    }
    concat_out_shape.push_back(dim_size);
  }

  // Pack out shape.
  int j = 0;
  for (int i = 0; i < rank + 1; ++i) {
    if (i == pack_axis) {
      pack_out_shape.push_back(count);
    } else {
      pack_out_shape.push_back(input_type.getDimSize(j));
      j++;
    }
  }

  if (failed(TF::VerifyShapeOfReshapeOp(pack_out_shape))) return failure();

  // Insert the concat op.
  auto concat_out_type =
      RankedTensorType::get(concat_out_shape, input_type.getElementType());
  auto concat_op = rewriter.create<TFL::ConcatenationOp>(
      pack_op.getLoc(), concat_out_type, pack_inputs, pack_op.axis(), "NONE");

  auto reshape_op =
      InsertReshapeOp(pack_op.getLoc(), concat_op, input_type.getElementType(),
                      pack_out_shape, &rewriter);

  // Rewire output & get rid of the pack op.
  rewriter.replaceOp(pack_op, reshape_op.getResult());
  return success();
}

// ================== squared_difference ========================

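// Illustrative sketch (names invented): since
// squared_difference(x, y) == (x - y) * (x - y), an op like
//   %0 = tfl.squared_difference(%x, %y) : tensor<4xf32>
// is rewritten into a TF sub followed by a TF mul of the difference with
// itself, both with the same result type.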
LogicalResult SquaredDifference::matchAndRewrite(
    TFL::SquaredDifferenceOp squared_diff_op, PatternRewriter& rewriter) const {
  auto x = squared_diff_op.lhs();
  auto y = squared_diff_op.rhs();
  auto x_type = x.getType().dyn_cast<RankedTensorType>();
  auto y_type = y.getType().dyn_cast<RankedTensorType>();
  if (!x_type || !y_type) return failure();
  if (x_type.getShape() != y_type.getShape()) return failure();

  auto result_type = squared_diff_op.getType();
  if (!result_type) return failure();

  auto sub_op =
      rewriter.create<TF::SubOp>(squared_diff_op.getLoc(), result_type, x, y);
  auto mul_op =
      rewriter.create<TF::MulOp>(squared_diff_op.getLoc(), result_type,
                                 sub_op.getResult(), sub_op.getResult());
  rewriter.replaceOp(squared_diff_op, mul_op.getResult());

  return success();
}

// ================== split ========================

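// Illustrative sketch (shapes invented): splitting a tensor<4x6xf32> into
// 2 pieces along dim 0,
//   %0:2 = "tfl.split"(%dim, %in) {num_splits = 2 : i32}
// becomes two tfl.slice ops with constant begin/size operands:
//   slice(begin = [0, 0], size = [2, -1]) and
//   slice(begin = [2, 0], size = [2, -1]),
// where -1 in a size means "all elements in that dim".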
LogicalResult UnrollSplit::matchAndRewrite(TFL::SplitOp split_op,
                                           PatternRewriter& rewriter) const {
  auto num_splits = split_op.num_splits();
  auto input = split_op.value();
  auto input_type = input.getType().dyn_cast<RankedTensorType>();
  if (input_type == nullptr || !input_type.hasStaticShape()) return failure();

  for (auto result : split_op.getResults()) {
    auto result_type = result.getType().dyn_cast<RankedTensorType>();
    if (result_type == nullptr) return failure();
  }

  auto output = split_op.getResult(0);
  auto output_type = output.getType().cast<RankedTensorType>();

  // TODO(renjieliu): change to use split_dim when we raise the constants
  // as well.
  int split_dim = -1;
  for (int d = 0; d < input_type.getRank(); ++d) {
    if (input_type.getDimSize(d) != output_type.getDimSize(d)) split_dim = d;
  }

  const SmallVector<Value, 4>& slice_outputs = SliceOutputs(
      split_op, input, input_type, split_dim, num_splits, &rewriter);
  rewriter.replaceOp(split_op, slice_outputs);
  return success();
}

// ================== splitV ========================

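// Illustrative sketch (shapes invented): a split_v of tensor<5x2xf32> with
// size_splits = [2, 3] on dim 0 becomes
//   slice(begin = [0, 0], size = [2, -1]) and
//   slice(begin = [2, 0], size = [3, -1]).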
LogicalResult UnrollSplitV::matchAndRewrite(TFL::SplitVOp splitv_op,
                                            PatternRewriter& rewriter) const {
  // We need to make sure both the splits & the split dim are constants.
  auto splits = splitv_op.size_splits().getDefiningOp();
  mlir::DenseIntElementsAttr splits_attr;
  if (!splits || !matchPattern(splits, m_Constant(&splits_attr)))
    return failure();

  auto split_dim = splitv_op.split_dim().getDefiningOp();
  mlir::ElementsAttr split_dim_attr;
  if (!split_dim || !matchPattern(split_dim, m_Constant(&split_dim_attr)))
    return failure();

  auto input = splitv_op.value();
  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
  if (!input_type || !input_type.hasRank()) return failure();

  for (auto result : splitv_op.getResults()) {
    auto result_type = result.getType().dyn_cast<RankedTensorType>();
    if (result_type == nullptr) return failure();
  }

  const int rank = input_type.getRank();

  IntegerAttr dim_int = ExtractSingleElementAsInteger(split_dim_attr);

  // The "axis" operand could be an i64 tensor. Resolve it here.
  IntegerAttr dim_i32;
  if (failed(ConvertToI32Attr(dim_int, &dim_i32))) return failure();

  int dim = dim_i32.getInt();
  if (dim < 0) dim += rank;

  const SmallVector<Value, 4>& slice_outputs = SliceOutputs(
      splitv_op, input, input_type, dim, splitv_op.num_splits(), &rewriter);
  rewriter.replaceOp(splitv_op, slice_outputs);

  return success();
}

// ================== conv_2d ========================

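// Illustrative sketch (shapes invented): a tfl.conv_2d producing
// tensor<1x32x32x16xf32> whose bias operand is None gets a dense f32 zero
// constant of shape [16] wired in as its bias (see EnsureBias above).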
LogicalResult EnsureBiasForConv2d::matchAndRewrite(
    TFL::Conv2DOp conv_op, PatternRewriter& rewriter) const {
  return EnsureBias(conv_op, 2, rewriter);
}

// ================== slice ============================

// If a slice op has fewer than 4 dimensions, pad it to 4d.
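// Illustrative sketch (shapes invented): a 2-d slice of tensor<4x3xf32> with
// begin = [1, 0] and size = [2, 3] is rewritten as a reshape to
// tensor<1x1x4x3xf32>, a slice with begin = [0, 0, 1, 0] (0-padded) and
// size = [1, 1, 2, 3] (1-padded), then a reshape back to tensor<2x3xf32>.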
LogicalResult PadSlice::matchAndRewrite(TFL::SliceOp slice_op,
                                        PatternRewriter& rewriter) const {
  // We have to know the shape of the input, as well as the begin/size.
  // Also, begin and size have to be constants.
  auto input = slice_op.input();
  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
  if (!input_type || !input_type.hasStaticShape()) return failure();

  if (input_type.getRank() >= 4) return failure();

  auto begin = slice_op.begin();
  auto begin_type = begin.getType().dyn_cast_or_null<RankedTensorType>();
  if (!begin_type || !begin_type.hasStaticShape()) return failure();

  auto size = slice_op.size();
  auto size_type = size.getType().dyn_cast_or_null<RankedTensorType>();
  if (!size_type || !size_type.hasStaticShape()) return failure();

  auto output_type = slice_op.getType().dyn_cast_or_null<RankedTensorType>();
  if (!output_type || !output_type.hasStaticShape()) return failure();

  // Pad 0s in front of the begin.
  TF::ConstOp new_begin =
      PadConstValues(begin.getDefiningOp(), 0, 4, slice_op.getLoc(), &rewriter);
  if (!new_begin) return failure();

  // Pad 1s in front of the size.
  TF::ConstOp new_size =
      PadConstValues(size.getDefiningOp(), 1, 4, slice_op.getLoc(), &rewriter);
  if (!new_size) return failure();

  // Reshape the input to 4d.
  SmallVector<int64_t, 4> new_shape;
  const int current_rank = input_type.getRank();
  for (int i = 0; i < 4 - current_rank; ++i) {
    new_shape.push_back(1);
  }
  for (auto size : input_type.getShape()) {
    new_shape.push_back(size);
  }

  auto reshape_op =
      InsertReshapeOp(slice_op.getLoc(), input, input_type.getElementType(),
                      new_shape, &rewriter);

  // Replace with the new slice op.
  SmallVector<int64_t, 4> new_output_shape;
  for (int i = 0; i < 4 - current_rank; ++i) {
    new_output_shape.push_back(1);
  }
  for (auto size : output_type.getShape()) {
    new_output_shape.push_back(size);
  }

  RankedTensorType new_output_type =
      RankedTensorType::get(new_output_shape, output_type.getElementType());

  auto new_slice = rewriter.create<TFL::SliceOp>(
      slice_op.getLoc(), new_output_type, reshape_op, new_begin, new_size);

  // Append a reshape at the bottom.
  auto output_reshape_op = InsertReshapeOp(slice_op.getLoc(), new_slice,
                                           output_type.getElementType(),
                                           output_type.getShape(), &rewriter);
  rewriter.replaceOp(slice_op, output_reshape_op.getResult());

  return success();
}

// ================== fully_connected ========================

// TFL fully_connected basically computes:
//   Weight * Input + bias
// where the input layout is [..., depth], the weight layout is
// [output, depth], and the bias is [output].
//
// Conv2d, on the other hand, uses:
//   Filter: [output, filter_h, filter_w, depth]
//   Input: [NHWC]
//   Bias: [output]
//
// So to perform the transform, we need to insert a few reshape ops:
//
//  Input   weight   bias
//   \      /      /
//       FC
//       |
//     output
//
//     |
//    \/
//
//  Input   weight
//   |        |
//  Reshape  Reshape  bias
//  |         |      /
//     conv
//      |
//     reshape
//      |
//    output
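// Illustrative shapes (invented for exposition): with input tensor<2x8xf32>
// (depth = 8, flat_size = 16, so width = 2) and weight tensor<4x8xf32>
// (output = 4), the input is reshaped to tensor<1x1x2x8xf32>, the weight to
// tensor<4x1x1x8xf32>, the 1x1 conv produces tensor<1x1x2x4xf32>, and a
// final reshape restores the fc output shape tensor<2x4xf32>.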
LogicalResult FullyConnectedToConv::matchAndRewrite(
    TFL::FullyConnectedOp fc_op, PatternRewriter& rewriter) const {
  // We have to know the shape of the input.
  auto input = fc_op.input();
  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
  if (!input_type || !input_type.hasStaticShape()) return failure();

  // We have to know the shape of the weight.
  auto weight = fc_op.filter();
  auto weight_type = weight.getType().dyn_cast_or_null<RankedTensorType>();
  if (!weight_type || !weight_type.hasStaticShape()) return failure();

  // We have to know the shape of the output as well.
  auto output = fc_op.getResult(0);
  auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
  if (!output_type || !output_type.hasStaticShape()) return failure();

  // Insert a reshape after the input.
  // Since the input may be more than 2-d, we compute the flat size of the
  // input and then reshape it into [1, 1, flat_size / depth, depth].
  const int depth = input_type.getDimSize(input_type.getRank() - 1);
  const int flat_size = input_type.getNumElements();
  const int width = flat_size / depth;
  SmallVector<int64_t, 4> input_new_shape({1, 1, width, depth});
  auto reshaped_input =
      InsertReshapeOp(fc_op.getLoc(), input, input_type.getElementType(),
                      input_new_shape, &rewriter);

  // Insert a reshape after the weight.
  // We will reshape the weight into [output, 1, 1, depth].
  const int output_size = weight_type.getDimSize(0);
  SmallVector<int64_t, 4> weight_new_shape({output_size, 1, 1, depth});
  auto reshaped_weight =
      InsertReshapeOp(fc_op.getLoc(), weight, weight_type.getElementType(),
                      weight_new_shape, &rewriter);

  // Replace the fc with conv.
  // The output will be [1, 1, width, output].
  auto conv_output_type = RankedTensorType::get({1, 1, width, output_size},
                                                output_type.getElementType());
  auto conv = rewriter.create<TFL::Conv2DOp>(
      fc_op.getLoc(), conv_output_type, reshaped_input, reshaped_weight,
      fc_op.bias(), rewriter.getI32IntegerAttr(1),
      rewriter.getI32IntegerAttr(1), fc_op.fused_activation_functionAttr(),
      rewriter.getStringAttr("VALID"), rewriter.getI32IntegerAttr(1),
      rewriter.getI32IntegerAttr(1));

  // Insert a reshape after the conv.
  auto reshaped_conv =
      InsertReshapeOp(fc_op.getLoc(), conv, output_type.getElementType(),
                      output_type.getShape(), &rewriter);

  rewriter.replaceOp(fc_op, reshaped_conv.getResult());

  return success();
}

// ================== concat ============================

// If a concat op has fewer than 4 dimensions, pad it to 4d.
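// Illustrative sketch (shapes invented): concatenating two tensor<2x3xf32>
// values on axis 1 becomes: reshape each input to tensor<1x1x2x3xf32>,
// concatenate on axis 3 (the axis shifted by the prepended dims) into
// tensor<1x1x2x6xf32>, then reshape back to tensor<2x6xf32>.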
LogicalResult PadConcat::matchAndRewrite(TFL::ConcatenationOp concat_op,
                                         PatternRewriter& rewriter) const {
  int rank = -1;
  for (auto input : concat_op.values()) {
    auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
    if (!input_type || !input_type.hasStaticShape()) return failure();

    rank = input_type.getRank();
  }

  auto output_type = concat_op.getType().dyn_cast_or_null<RankedTensorType>();
  if (!output_type || !output_type.hasStaticShape()) return failure();

  if (rank >= 4) return failure();

  // All values should have the same rank.
  // We will insert a reshape op after every input.
  SmallVector<Value, 4> reshape_ops;
  for (auto input : concat_op.values()) {
    auto input_type = input.getType().cast<RankedTensorType>();
    // Get the new shape.
    SmallVector<int64_t, 4> new_shape;
    for (int i = 0; i < 4 - rank; ++i) {
      new_shape.push_back(1);
    }
    for (auto size : input_type.getShape()) {
      new_shape.push_back(size);
    }

    auto reshape_op =
        InsertReshapeOp(concat_op.getLoc(), input, input_type.getElementType(),
                        new_shape, &rewriter);
    reshape_ops.push_back(reshape_op.getResult());
  }

  // Deal with the axis.
  // We don't need to handle axis < 0, since negative axes count from the
  // back and are unaffected by the leading dims we prepend.
  int32_t axis = concat_op.axis();
  if (axis >= 0) {
    axis += (4 - rank);
  }

  // Replace with the new concat op.
  SmallVector<int64_t, 4> new_output_shape;
  for (int i = 0; i < 4 - rank; ++i) {
    new_output_shape.push_back(1);
  }
  for (auto size : output_type.getShape()) {
    new_output_shape.push_back(size);
  }

  RankedTensorType new_output_type =
      RankedTensorType::get(new_output_shape, output_type.getElementType());

  auto new_concat = rewriter.create<TFL::ConcatenationOp>(
      concat_op.getLoc(), new_output_type, reshape_ops, axis,
      concat_op.fused_activation_function());

  // Append a reshape at the bottom.
  auto output_reshape_op = InsertReshapeOp(concat_op.getLoc(), new_concat,
                                           output_type.getElementType(),
                                           output_type.getShape(), &rewriter);
  rewriter.replaceOp(concat_op, output_reshape_op.getResult());

  return success();
}

// ================== mean ========================

// Currently NNAPI does not support a mean op with different input & output
// scales (quantized cases), while in TFLite avg_pool ensures that the input
// & output have the same scale.
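// Illustrative sketch (shapes invented): a mean over axes [1, 2] of
// tensor<1x8x8x16xf32> becomes a tfl.average_pool_2d with filter_height = 8,
// filter_width = 8 and "VALID" padding, producing tensor<1x1x1x16xf32>; if
// keep_dims is false, a trailing reshape yields tensor<1x16xf32>.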
LogicalResult ReduceMeanToAvgPool::matchAndRewrite(
    TFL::MeanOp mean_op, PatternRewriter& rewriter) const {
  auto input = mean_op.input();
  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
  // Only 4d is supported here.
  if (!input_type || input_type.getRank() != 4) return failure();

  // The axes have to be [1, 2].
  DenseElementsAttr axis_const;
  if (!matchPattern(mean_op.axis(), m_Constant(&axis_const))) return failure();
  if (axis_const.size() != 2) return failure();
  auto axis_values = axis_const.getValues<APInt>();
  int i = 1;
  for (auto axis_value : axis_values) {
    if (axis_value != i++) return failure();
  }

  auto output = mean_op.output();
  auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
  if (!output_type) return failure();

  auto input_quantized_type =
      quant::QuantizedType::getQuantizedElementType(input_type);
  auto output_quantized_type =
      quant::QuantizedType::getQuantizedElementType(output_type);
  // If both the input & output types are non-quantized, they will both be
  // nullptr.
  if (input_quantized_type != output_quantized_type) {
    return failure();
  }

  int batch = input_type.getDimSize(0);
  int height = input_type.getDimSize(1);
  int width = input_type.getDimSize(2);
  int channel = input_type.getDimSize(3);

  auto avg_pool_output_type = RankedTensorType::get(
      {batch, 1, 1, channel}, input_type.getElementType());
  auto avg_pool = rewriter.create<TFL::AveragePool2DOp>(
      mean_op.getLoc(), avg_pool_output_type, input,
      rewriter.getI32IntegerAttr(height), rewriter.getI32IntegerAttr(width),
      rewriter.getStringAttr("VALID"), rewriter.getI32IntegerAttr(1),
      rewriter.getI32IntegerAttr(1), rewriter.getStringAttr("NONE"));

  auto value_to_replace = avg_pool.getResult();

  // If keep_dims is not set, we need to insert a reshape after the average
  // pool.
  if (!mean_op.keep_dims()) {
    // Insert the reshape.
    SmallVector<int64_t, 2> new_shape({batch, channel});
    auto reshape_op =
        InsertReshapeOp(mean_op.getLoc(), avg_pool.getResult(),
                        input_type.getElementType(), new_shape, &rewriter);
    value_to_replace = reshape_op.getResult();
  }

  rewriter.replaceOp(mean_op, value_to_replace);
  return success();
}

// Inserts a "requant" op after the mean op if the mean has different scales
// for its input & output.
// Please note: this is NOT a mathematically-equivalent transformation and it
// may lose accuracy, so it needs to be used very carefully.
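// Illustrative sketch (quantized types invented): a mean from
// tensor<1x8x8x16x!quant.uniform<u8:f32, 0.5>> to
// tensor<1x16x!quant.uniform<u8:f32, 0.25>> is rewritten as a mean that
// keeps the input's quantized element type, followed by a tfl.quantize that
// requantizes the result to the original output type.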
LogicalResult InsertRequantForReduceMean::matchAndRewrite(
    TFL::MeanOp mean_op, PatternRewriter& rewriter) const {
  auto input = mean_op.input();
  auto input_type = input.getType().dyn_cast_or_null<ShapedType>();
  if (!input_type) return failure();

  // Only need to do this for quantized input.
  auto input_quantized_type =
      quant::QuantizedType::getQuantizedElementType(input_type);
  if (!input_quantized_type) return failure();

  auto output = mean_op.output();
  auto output_type = output.getType().dyn_cast_or_null<ShapedType>();
  if (!output_type) return failure();
  auto output_quantized_type =
      quant::QuantizedType::getQuantizedElementType(output_type);

  // If the quantized type is the same, we don't need to do anything.
  if (input_quantized_type == output_quantized_type) return failure();

  auto new_output_type =
      RankedTensorType::get(output_type.getShape(), input_quantized_type);
  auto new_mean_op =
      rewriter.create<TFL::MeanOp>(mean_op->getLoc(), new_output_type, input,
                                   mean_op.axis(), mean_op.keep_dims());

  // Insert a requant op.
  rewriter.replaceOpWithNewOp<TFL::QuantizeOp>(
      mean_op, output_type, new_mean_op, mlir::TypeAttr::get(output_type));
  return success();
}

}  // namespace tac
}  // namespace TFL
}  // namespace mlir