/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.h"
17
18 #include <memory>
19
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Casting.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
24 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
25 #include "mlir/IR/Attributes.h" // from @llvm-project
26 #include "mlir/IR/Builders.h" // from @llvm-project
27 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
30 #include "mlir/IR/Matchers.h" // from @llvm-project
31 #include "mlir/IR/PatternMatch.h" // from @llvm-project
32 #include "mlir/IR/Value.h" // from @llvm-project
33 #include "mlir/Pass/Pass.h" // from @llvm-project
34 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
35 #include "mlir/Support/LLVM.h" // from @llvm-project
36 #include "mlir/Support/LogicalResult.h" // from @llvm-project
37 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
38 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
39 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
40 #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
41 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
43 #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"
44
45 namespace mlir {
46 namespace TFL {
47 namespace tac {
48 namespace {
49
50 // ================== Common ========================
51
// Converts any IntegerAttr to an IntegerAttr of i32 type.
// The value is unchanged in the new attribute, but if the value is out of
// the bounds of i32, the function returns a failure.
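// For example, an i64 IntegerAttr holding 5 converts to an i32 IntegerAttr
// holding 5, while an i64 IntegerAttr holding 1 << 40 returns failure.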
LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) {
  if (attr.getType().isInteger(/*width=*/32)) {
    *attr_i32 = attr;
    return success();
  }

  int64_t value = attr.getInt();
  if (value > std::numeric_limits<int>::max() ||
      value < std::numeric_limits<int>::min()) {
    return failure();
  }

  *attr_i32 = IntegerAttr::get(
      IntegerType::get(attr.getContext(), /*width=*/32), value);
  return success();
}

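// Creates a tfl.reshape of `input` to `new_shape_array`, materializing the
// new shape as an i32 constant. For example, reshaping a tensor<4x3xf32>
// with new_shape_array = {2, 2, 3} yields a tensor<2x2x3xf32>.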
TFL::ReshapeOp InsertReshapeOp(Location loc, Value input, Type element_type,
                               llvm::ArrayRef<int64_t> new_shape_array,
                               OpBuilder* builder) {
  auto reshape_shape_type = mlir::RankedTensorType::get(
      new_shape_array.size(), builder->getIntegerType(32));

  // This is to work around the unnecessary i64 -> i32 cast. :(
  // TODO(renjieliu): Revisit this later.
  SmallVector<int32_t, 4> new_shape_array_i32;
  for (auto size : new_shape_array) {
    new_shape_array_i32.push_back(static_cast<int32_t>(size));
  }
  auto new_shape_attr =
      mlir::DenseIntElementsAttr::get(reshape_shape_type, new_shape_array_i32);

  auto new_shape = builder->create<TFL::ConstOp>(loc, new_shape_attr);

  auto reshape_out_type = RankedTensorType::get(new_shape_array, element_type);
  return builder->create<TFL::ReshapeOp>(loc, reshape_out_type, input,
                                         new_shape);
}

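// If the op's bias operand (at bias_idx) is None, replaces it with a zero
// constant shaped to the last output dimension, e.g. an op with output type
// tensor<1x8x16xf32> gets a tensor<16xf32> zero bias. Only float32 outputs
// are currently handled.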
LogicalResult EnsureBias(Operation* op, int bias_idx,
                         PatternRewriter& rewriter) {
  auto bias = op->getOperand(bias_idx);

  if (!bias.getType().isa<NoneType>()) return failure();

  // Proceed to create a zero bias.
  auto output = op->getResult(0);
  auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
  if (!output_type) return failure();

  // The bias should be a vector sized to the last output dim.
  int num_units = output_type.getDimSize(output_type.getRank() - 1);
  auto bias_type =
      mlir::RankedTensorType::get({num_units}, output_type.getElementType());

  mlir::DenseElementsAttr bias_attr;
  if (output_type.getElementType().isF32()) {
    float val = 0.0;
    bias_attr = mlir::DenseFPElementsAttr::get(bias_type, val);
  } else {
    // TODO(renjieliu): Refactor this and share the logic with
    // CreateConstOpWithSingleValue. Also, make sure it works with QConst.
    return failure();
  }

  auto zero_bias = rewriter.create<TFL::ConstOp>(op->getLoc(), bias_attr);
  op->setOperand(bias_idx, zero_bias);

  return success();
}

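// Left-pads the constant defined by `input_op` with `value_to_pad` until it
// has `pad_dimensions` elements, returning a new i32 constant. For example,
// padding [2, 3] with 0 up to 4 dimensions yields [0, 0, 2, 3].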
TF::ConstOp PadConstValues(Operation* input_op, int value_to_pad,
                           int pad_dimensions, Location loc,
                           OpBuilder* builder) {
  if (input_op == nullptr) return nullptr;

  mlir::DenseIntElementsAttr attr;
  if (!matchPattern(input_op, m_Constant(&attr))) {
    return nullptr;
  }

  auto value_shape_type = mlir::RankedTensorType::get(
      {pad_dimensions}, builder->getIntegerType(32));

  SmallVector<int32_t, 4> value_i32;
  value_i32.reserve(pad_dimensions);
  for (int i = 0; i < pad_dimensions - attr.getNumElements(); ++i) {
    value_i32.push_back(value_to_pad);
  }
  for (const auto& size : attr) {
    value_i32.push_back(static_cast<int32_t>(size.getSExtValue()));
  }
  auto new_value_i32_attr =
      mlir::DenseIntElementsAttr::get(value_shape_type, value_i32);

  return builder->create<TF::ConstOp>(loc, new_value_i32_attr);
}

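// Emits one tfl.slice per result of `split_op`, advancing `begin` along
// `split_dim` by each result's size. Non-split dimensions use begin 0 and
// size -1 (the full extent).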
SmallVector<Value, 4> SliceOutputs(Operation* split_op, Value input,
                                   RankedTensorType input_type, int split_dim,
                                   int num_splits, PatternRewriter* rewriter) {
  SmallVector<Value, 4> slice_outputs;
  int begin = 0;
  for (int i = 0; i < num_splits; ++i) {
    // Create slice op.
    // Populate begin & size.
    SmallVector<int32_t, 4> slice_begin;
    SmallVector<int32_t, 4> slice_size;
    auto current_output = split_op->getResult(i);
    auto current_output_type =
        current_output.getType().cast<RankedTensorType>();
    for (int d = 0; d < input_type.getRank(); ++d) {
      if (d == split_dim) {
        // Split dimension.
        slice_begin.push_back(begin);
        int size = current_output_type.getDimSize(d);
        slice_size.push_back(size);
        begin += size;
      } else {
        slice_begin.push_back(0);
        // -1 means all elements.
        slice_size.push_back(-1);
      }
    }

    auto slice_type = mlir::RankedTensorType::get(slice_begin.size(),
                                                  rewriter->getIntegerType(32));
    auto slice_begin_attr =
        mlir::DenseIntElementsAttr::get(slice_type, slice_begin);
    auto slice_size_attr =
        mlir::DenseIntElementsAttr::get(slice_type, slice_size);

    auto slice_begin_const =
        rewriter->create<TFL::ConstOp>(split_op->getLoc(), slice_begin_attr);
    auto slice_size_const =
        rewriter->create<TFL::ConstOp>(split_op->getLoc(), slice_size_attr);

    auto slice_op = rewriter->create<TFL::SliceOp>(
        split_op->getLoc(), current_output_type, input, slice_begin_const,
        slice_size_const);

    // Rewire output.
    slice_outputs.push_back(slice_op.getResult());
  }
  return slice_outputs;
}

}  // namespace

// ================== Pack ========================

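// Illustrative example (shapes chosen for exposition): packing two
// tensor<2x3xf32> values along axis 0,
//   %r = tfl.pack(%a, %b) {axis = 0, values_count = 2}
// is lowered to a concatenation along the same axis plus a reshape:
//   %c = tfl.concatenation(%a, %b) {axis = 0}  // tensor<4x3xf32>
//   %r = tfl.reshape(%c, [2, 2, 3])            // tensor<2x2x3xf32>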
LogicalResult LowerPackIntoConcatReshape::matchAndRewrite(
    TFL::PackOp pack_op, PatternRewriter& rewriter) const {
  // All pack op inputs should have the same shape type.
  SmallVector<Value, 5> pack_inputs(pack_op.values());
  auto input_type = pack_inputs[0].getType().dyn_cast<RankedTensorType>();
  if (!input_type) return failure();

  // Figure out output shapes.
  SmallVector<int64_t, 4> concat_out_shape;
  SmallVector<int64_t, 4> pack_out_shape;

  const int rank = input_type.getRank();
  int pack_axis = pack_op.axis();
  int count = pack_inputs.size();
  if (pack_axis < 0) {
    pack_axis += rank;
  }

  // Concat out shape.
  for (int i = 0; i < rank; ++i) {
    int dim_size = input_type.getDimSize(i);
    if (i == pack_axis) {
      dim_size *= count;
    }
    concat_out_shape.push_back(dim_size);
  }

  // Pack out shape.
  int j = 0;
  for (int i = 0; i < rank + 1; ++i) {
    if (i == pack_axis) {
      pack_out_shape.push_back(count);
    } else {
      pack_out_shape.push_back(input_type.getDimSize(j));
      j++;
    }
  }

  if (failed(TF::VerifyShapeOfReshapeOp(pack_out_shape))) return failure();

  // Insert the concat op.
  auto concat_out_type =
      RankedTensorType::get(concat_out_shape, input_type.getElementType());
  auto concat_op = rewriter.create<TFL::ConcatenationOp>(
      pack_op.getLoc(), concat_out_type, pack_inputs, pack_op.axis(), "NONE");

  auto reshape_op =
      InsertReshapeOp(pack_op.getLoc(), concat_op, input_type.getElementType(),
                      pack_out_shape, &rewriter);

  // Rewire output & get rid of the pack op.
  rewriter.replaceOp(pack_op, reshape_op.getResult());
  return success();
}

// ================== squared_difference ========================

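// Illustrative example: squared_difference(x, y) is rewritten as
//   %d = tf.Sub(%x, %y)
//   %r = tf.Mul(%d, %d)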
LogicalResult SquaredDifference::matchAndRewrite(
    TFL::SquaredDifferenceOp squared_diff_op, PatternRewriter& rewriter) const {
  auto x = squared_diff_op.lhs();
  auto y = squared_diff_op.rhs();
  auto x_type = x.getType().dyn_cast<RankedTensorType>();
  auto y_type = y.getType().dyn_cast<RankedTensorType>();
  if (!x_type || !y_type) return failure();
  if (x_type.getShape() != y_type.getShape()) return failure();

  auto result_type = squared_diff_op.getType();
  if (!result_type) return failure();

  auto sub_op =
      rewriter.create<TF::SubOp>(squared_diff_op.getLoc(), result_type, x, y);
  auto mul_op =
      rewriter.create<TF::MulOp>(squared_diff_op.getLoc(), result_type,
                                 sub_op.getResult(), sub_op.getResult());
  rewriter.replaceOp(squared_diff_op, mul_op.getResult());

  return success();
}

// ================== split ========================

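// Illustrative example: splitting a tensor<4x6xf32> into two pieces along
// dimension 1 unrolls into two slices:
//   slice 0: begin = [0, 0], size = [-1, 3]
//   slice 1: begin = [0, 3], size = [-1, 3]
// where -1 keeps the full extent of the non-split dimension.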
LogicalResult UnrollSplit::matchAndRewrite(TFL::SplitOp split_op,
                                           PatternRewriter& rewriter) const {
  auto num_splits = split_op.num_splits();
  auto input = split_op.value();
  auto input_type = input.getType().dyn_cast<RankedTensorType>();
  if (input_type == nullptr || !input_type.hasStaticShape()) return failure();

  for (auto result : split_op.getResults()) {
    auto result_type = result.getType().dyn_cast<RankedTensorType>();
    if (result_type == nullptr) return failure();
  }

  auto output = split_op.getResult(0);
  auto output_type = output.getType().cast<RankedTensorType>();

  // TODO(renjieliu): change to use split_dim when we raise the constants
  // as well.
  int split_dim = -1;
  for (int d = 0; d < input_type.getRank(); ++d) {
    if (input_type.getDimSize(d) != output_type.getDimSize(d)) split_dim = d;
  }

  const SmallVector<Value, 4>& slice_outputs = SliceOutputs(
      split_op, input, input_type, split_dim, num_splits, &rewriter);
  rewriter.replaceOp(split_op, slice_outputs);
  return success();
}

// ================== splitV ========================

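// Illustrative example: a split_v of a tensor<6x4xf32> with
// size_splits = [2, 4] along dimension 0 unrolls into two slices:
//   slice 0: begin = [0, 0], size = [2, -1]
//   slice 1: begin = [2, 0], size = [4, -1]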
LogicalResult UnrollSplitV::matchAndRewrite(TFL::SplitVOp splitv_op,
                                            PatternRewriter& rewriter) const {
  // We need to make sure both the splits & the split dim are constants.
  auto splits = splitv_op.size_splits().getDefiningOp();
  mlir::DenseIntElementsAttr splits_attr;
  if (!splits || !matchPattern(splits, m_Constant(&splits_attr)))
    return failure();

  auto split_dim = splitv_op.split_dim().getDefiningOp();
  mlir::ElementsAttr split_dim_attr;
  if (!split_dim || !matchPattern(split_dim, m_Constant(&split_dim_attr)))
    return failure();

  auto input = splitv_op.value();
  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
  if (!input_type || !input_type.hasRank()) return failure();

  for (auto result : splitv_op.getResults()) {
    auto result_type = result.getType().dyn_cast<RankedTensorType>();
    if (result_type == nullptr) return failure();
  }

  const int rank = input_type.getRank();

  IntegerAttr dim_int = ExtractSingleElementAsInteger(split_dim_attr);

  // The "axis" operand could be an i64 tensor. Resolve it here.
  IntegerAttr dim_i32;
  if (failed(ConvertToI32Attr(dim_int, &dim_i32))) return failure();

  int dim = dim_i32.getInt();
  if (dim < 0) dim += rank;

  const SmallVector<Value, 4>& slice_outputs = SliceOutputs(
      splitv_op, input, input_type, dim, splitv_op.num_splits(), &rewriter);
  rewriter.replaceOp(splitv_op, slice_outputs);

  return success();
}

// ================== conv_2d ========================

LogicalResult EnsureBiasForConv2d::matchAndRewrite(
    TFL::Conv2DOp conv_op, PatternRewriter& rewriter) const {
  return EnsureBias(conv_op, 2, rewriter);
}

// ================== slice ============================

// If a slice op has rank < 4, pad its operands and result to 4-D.
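// Illustrative example: a 2-D slice with begin = [1, 2] and size = [2, 3]
// becomes a reshape of the input to rank 4, a 4-D slice with
// begin = [0, 0, 1, 2] and size = [1, 1, 2, 3], and a reshape back to the
// original output shape.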
LogicalResult PadSlice::matchAndRewrite(TFL::SliceOp slice_op,
                                        PatternRewriter& rewriter) const {
  // We have to know the shape of the input, as well as the begin/size.
  // Also, begin and size have to be constants.
  auto input = slice_op.input();
  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
  if (!input_type || !input_type.hasStaticShape()) return failure();

  if (input_type.getRank() >= 4) return failure();

  auto begin = slice_op.begin();
  auto begin_type = begin.getType().dyn_cast_or_null<RankedTensorType>();
  if (!begin_type || !begin_type.hasStaticShape()) return failure();

  auto size = slice_op.size();
  auto size_type = size.getType().dyn_cast_or_null<RankedTensorType>();
  if (!size_type || !size_type.hasStaticShape()) return failure();

  auto output_type = slice_op.getType().dyn_cast_or_null<RankedTensorType>();
  if (!output_type || !output_type.hasStaticShape()) return failure();

  // Pad 0s in front of the begin.
  TF::ConstOp new_begin =
      PadConstValues(begin.getDefiningOp(), 0, 4, slice_op.getLoc(), &rewriter);
  if (!new_begin) return failure();

  // Pad 1s in front of the size.
  TF::ConstOp new_size =
      PadConstValues(size.getDefiningOp(), 1, 4, slice_op.getLoc(), &rewriter);
  if (!new_size) return failure();

  // Reshape the input to 4-D.
  SmallVector<int64_t, 4> new_shape;
  const int current_rank = input_type.getRank();
  for (int i = 0; i < 4 - current_rank; ++i) {
    new_shape.push_back(1);
  }
  for (auto size : input_type.getShape()) {
    new_shape.push_back(size);
  }

  auto reshape_op =
      InsertReshapeOp(slice_op.getLoc(), input, input_type.getElementType(),
                      new_shape, &rewriter);

  // Replace with the new slice op.
  SmallVector<int64_t, 4> new_output_shape;
  for (int i = 0; i < 4 - current_rank; ++i) {
    new_output_shape.push_back(1);
  }
  for (auto size : output_type.getShape()) {
    new_output_shape.push_back(size);
  }

  RankedTensorType new_output_type =
      RankedTensorType::get(new_output_shape, output_type.getElementType());

  auto new_slice = rewriter.create<TFL::SliceOp>(
      slice_op.getLoc(), new_output_type, reshape_op, new_begin, new_size);

  // Append a reshape at the bottom.
  auto output_reshape_op = InsertReshapeOp(slice_op.getLoc(), new_slice,
                                           output_type.getElementType(),
                                           output_type.getShape(), &rewriter);
  rewriter.replaceOp(slice_op, output_reshape_op.getResult());

  return success();
}

// ================== fully_connected ========================

// TFL fully_connected basically does:
//   Weight * Input + bias
// Input layout is: [..., depth]
// Weight layout is: [output, depth]
// Bias is [output].
//
// While conv2d is:
// Filter: [NHWC]
// Input is also: [NHWC]
// Bias is [N]
//
// So to perform the transform, we need to insert a few reshape ops:
//
//  Input   weight   bias
//     \      /      /
//          FC
//          |
//        output
//
//          |
//          \/
//
//  Input    weight
//    |        |
//  Reshape  Reshape   bias
//    |        |       /
//         conv
//          |
//        reshape
//          |
//        output
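//
// Illustrative shapes: an input tensor<2x3x10xf32> (depth = 10, flat size =
// 60) is reshaped to [1, 1, 6, 10], a weight tensor<8x10xf32> is reshaped to
// [8, 1, 1, 10], the resulting 1x1 conv produces [1, 1, 6, 8], and a final
// reshape restores the fully-connected output shape [2, 3, 8].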
LogicalResult FullyConnectedToConv::matchAndRewrite(
    TFL::FullyConnectedOp fc_op, PatternRewriter& rewriter) const {
  // We have to know the shape of the input.
  auto input = fc_op.input();
  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
  if (!input_type || !input_type.hasStaticShape()) return failure();

  // We have to know the shape of the weight.
  auto weight = fc_op.filter();
  auto weight_type = weight.getType().dyn_cast_or_null<RankedTensorType>();
  if (!weight_type || !weight_type.hasStaticShape()) return failure();

  // We have to know the shape of the output as well.
  auto output = fc_op.getResult(0);
  auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
  if (!output_type || !output_type.hasStaticShape()) return failure();

  // Insert a reshape after the input.
  // Since the input may be more than 2-D, we compute the flat size of the
  // input and then reshape it into [1, 1, flat_size / depth, depth].
  const int depth = input_type.getDimSize(input_type.getRank() - 1);
  const int flat_size = input_type.getNumElements();
  const int width = flat_size / depth;
  SmallVector<int64_t, 4> input_new_shape({1, 1, width, depth});
  auto reshaped_input =
      InsertReshapeOp(fc_op.getLoc(), input, input_type.getElementType(),
                      input_new_shape, &rewriter);

  // Insert a reshape after the weight.
  // We will reshape the weight into [output, 1, 1, depth].
  const int output_size = weight_type.getDimSize(0);
  SmallVector<int64_t, 4> weight_new_shape({output_size, 1, 1, depth});
  auto reshaped_weight =
      InsertReshapeOp(fc_op.getLoc(), weight, weight_type.getElementType(),
                      weight_new_shape, &rewriter);

  // Replace the fc with conv.
  // The output would be [1, 1, width, output].
  auto conv_output_type = RankedTensorType::get({1, 1, width, output_size},
                                                output_type.getElementType());
  auto conv = rewriter.create<TFL::Conv2DOp>(
      fc_op.getLoc(), conv_output_type, reshaped_input, reshaped_weight,
      fc_op.bias(), rewriter.getI32IntegerAttr(1),
      rewriter.getI32IntegerAttr(1), fc_op.fused_activation_functionAttr(),
      rewriter.getStringAttr("VALID"), rewriter.getI32IntegerAttr(1),
      rewriter.getI32IntegerAttr(1));

  // Insert a reshape after the conv.
  auto reshaped_conv =
      InsertReshapeOp(fc_op.getLoc(), conv, output_type.getElementType(),
                      output_type.getShape(), &rewriter);

  rewriter.replaceOp(fc_op, reshaped_conv.getResult());

  return success();
}

// ================== concat ============================

// If a concat op has rank < 4, pad its operands and result to 4-D.
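// Illustrative example: concatenating tensor<2x3xf32> and tensor<2x5xf32>
// along axis 1 becomes reshapes to tensor<1x1x2x3xf32> and
// tensor<1x1x2x5xf32>, a 4-D concatenation along axis 3, and a reshape of
// the result back to tensor<2x8xf32>.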
LogicalResult PadConcat::matchAndRewrite(TFL::ConcatenationOp concat_op,
                                         PatternRewriter& rewriter) const {
  int rank = -1;
  for (auto input : concat_op.values()) {
    auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
    if (!input_type || !input_type.hasStaticShape()) return failure();

    rank = input_type.getRank();
  }

  auto output_type = concat_op.getType().dyn_cast_or_null<RankedTensorType>();
  if (!output_type || !output_type.hasStaticShape()) return failure();

  if (rank >= 4) return failure();

  // All values should have the same rank.
  // We will insert a reshape op after every input.
  SmallVector<Value, 4> reshape_ops;
  for (auto input : concat_op.values()) {
    auto input_type = input.getType().cast<RankedTensorType>();
    // Get the new shape.
    SmallVector<int64_t, 4> new_shape;
    for (int i = 0; i < 4 - rank; ++i) {
      new_shape.push_back(1);
    }
    for (auto size : input_type.getShape()) {
      new_shape.push_back(size);
    }

    auto reshape_op =
        InsertReshapeOp(concat_op.getLoc(), input, input_type.getElementType(),
                        new_shape, &rewriter);
    reshape_ops.push_back(reshape_op.getResult());
  }

  // Deal with the axis.
  // We don't need to handle axis < 0, since it counts from the back.
  int32_t axis = concat_op.axis();
  if (axis >= 0) {
    axis += (4 - rank);
  }

  // Replace with the new concat op.
  SmallVector<int64_t, 4> new_output_shape;
  for (int i = 0; i < 4 - rank; ++i) {
    new_output_shape.push_back(1);
  }
  for (auto size : output_type.getShape()) {
    new_output_shape.push_back(size);
  }

  RankedTensorType new_output_type =
      RankedTensorType::get(new_output_shape, output_type.getElementType());

  auto new_concat = rewriter.create<TFL::ConcatenationOp>(
      concat_op.getLoc(), new_output_type, reshape_ops, axis,
      concat_op.fused_activation_function());

  // Append a reshape at the bottom.
  auto output_reshape_op = InsertReshapeOp(concat_op.getLoc(), new_concat,
                                           output_type.getElementType(),
                                           output_type.getShape(), &rewriter);
  rewriter.replaceOp(concat_op, output_reshape_op.getResult());

  return success();
}

// ================== mean ========================

// Currently NNAPI does not support the mean op with different input & output
// scales (quantized cases), while in TFLite avg_pool ensures the input and
// output have the same scale.
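// Illustrative example: a mean over axes [1, 2] of a tensor<1x8x8x32xf32>
// becomes an average pool with an 8x8 filter producing tensor<1x1x1x32xf32>,
// followed by a reshape to tensor<1x32xf32> when keep_dims is false.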
LogicalResult ReduceMeanToAvgPool::matchAndRewrite(
    TFL::MeanOp mean_op, PatternRewriter& rewriter) const {
  auto input = mean_op.input();
  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
  // Only 4-D input is supported here.
  if (!input_type || input_type.getRank() != 4) return failure();

  // The axes have to be [1, 2].
  DenseElementsAttr axis_const;
  if (!matchPattern(mean_op.axis(), m_Constant(&axis_const))) return failure();
  if (axis_const.size() != 2) return failure();
  auto axis_values = axis_const.getValues<APInt>();
  int i = 1;
  for (auto axis_value : axis_values) {
    if (axis_value != i++) return failure();
  }

  auto output = mean_op.output();
  auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
  if (!output_type) return failure();

  auto input_quantized_type =
      quant::QuantizedType::getQuantizedElementType(input_type);
  auto output_quantized_type =
      quant::QuantizedType::getQuantizedElementType(output_type);
  // If both the input & output types are non-quantized, they will both be
  // nullptr.
  if (input_quantized_type != output_quantized_type) {
    return failure();
  }

  int batch = input_type.getDimSize(0);
  int height = input_type.getDimSize(1);
  int width = input_type.getDimSize(2);
  int channel = input_type.getDimSize(3);

  auto avg_pool_output_type = RankedTensorType::get(
      {batch, 1, 1, channel}, input_type.getElementType());
  auto avg_pool = rewriter.create<TFL::AveragePool2DOp>(
      mean_op.getLoc(), avg_pool_output_type, input,
      rewriter.getI32IntegerAttr(height), rewriter.getI32IntegerAttr(width),
      rewriter.getStringAttr("VALID"), rewriter.getI32IntegerAttr(1),
      rewriter.getI32IntegerAttr(1), rewriter.getStringAttr("NONE"));

  auto value_to_replace = avg_pool.getResult();

  // If keep_dims is false, we need to insert a reshape after the average
  // pool.
  if (!mean_op.keep_dims()) {
    // Insert the reshape.
    SmallVector<int64_t, 2> new_shape({batch, channel});
    auto reshape_op =
        InsertReshapeOp(mean_op.getLoc(), avg_pool.getResult(),
                        input_type.getElementType(), new_shape, &rewriter);
    value_to_replace = reshape_op.getResult();
  }

  rewriter.replaceOp(mean_op, value_to_replace);
  return success();
}

// Insert a "requant" op after the mean op if the mean has different scales
// for input & output.
// Please note: this is NOT a mathematically equivalent transformation and it
// may lose accuracy, so it must be used very carefully.
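// Illustrative example: a mean whose input element type is quantized with
// scale 0.5 but whose output uses scale 0.25 is rewritten to produce the
// input's quantized type, followed by a tfl.quantize (requant) to the
// original output type.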
LogicalResult InsertRequantForReduceMean::matchAndRewrite(
    TFL::MeanOp mean_op, PatternRewriter& rewriter) const {
  auto input = mean_op.input();
  auto input_type = input.getType().dyn_cast_or_null<ShapedType>();
  if (!input_type) return failure();

  // Only need to do this for quantized input.
  auto input_quantized_type =
      quant::QuantizedType::getQuantizedElementType(input_type);
  if (!input_quantized_type) return failure();

  auto output = mean_op.output();
  auto output_type = output.getType().dyn_cast_or_null<ShapedType>();
  if (!output_type) return failure();
  auto output_quantized_type =
      quant::QuantizedType::getQuantizedElementType(output_type);

  // If the quantized type is the same, we don't need to do anything.
  if (input_quantized_type == output_quantized_type) return failure();

  auto new_output_type =
      RankedTensorType::get(output_type.getShape(), input_quantized_type);
  auto new_mean_op =
      rewriter.create<TFL::MeanOp>(mean_op->getLoc(), new_output_type, input,
                                   mean_op.axis(), mean_op.keep_dims());

  // Insert a requant op.
  rewriter.replaceOpWithNewOp<TFL::QuantizeOp>(
      mean_op, output_type, new_mean_op, mlir::TypeAttr::get(output_type));
  return success();
}

}  // namespace tac
}  // namespace TFL
}  // namespace mlir