/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.h"

#include <limits>
#include <memory>

20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Casting.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
24 #include "mlir/IR/Attributes.h" // from @llvm-project
25 #include "mlir/IR/Builders.h" // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
27 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
28 #include "mlir/IR/Matchers.h" // from @llvm-project
29 #include "mlir/IR/PatternMatch.h" // from @llvm-project
30 #include "mlir/IR/Value.h" // from @llvm-project
31 #include "mlir/Pass/Pass.h" // from @llvm-project
32 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
33 #include "mlir/Support/LLVM.h" // from @llvm-project
34 #include "mlir/Support/LogicalResult.h" // from @llvm-project
35 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
36 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
37 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
39 #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"
40
41 namespace mlir {
42 namespace TFL {
43 namespace tac {
44 namespace {
45
46 // ================== Common ========================
47
// Converts any IntegerAttr to an IntegerAttr of i32 type. The value is
// unchanged in the new attribute, but if it is out of the range of i32, the
// function returns failure.
LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) {
  if (attr.getType().isInteger(/*width=*/32)) {
    *attr_i32 = attr;
    return success();
  }

  int64_t value = attr.getInt();
  if (value > std::numeric_limits<int>::max() ||
      value < std::numeric_limits<int>::min()) {
    return failure();
  }

  *attr_i32 = IntegerAttr::get(
      IntegerType::get(attr.getContext(), /*width=*/32), value);
  return success();
}

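// Inserts a tfl.reshape of `input` to `new_shape_array` at `loc` and returns
// the new reshape op. The shape is materialized as an i32 constant, so the
// entries of `new_shape_array` must fit in i32.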
TFL::ReshapeOp InsertReshapeOp(Location loc, Value input, Type element_type,
                               llvm::ArrayRef<int64_t> new_shape_array,
                               OpBuilder* builder) {
  auto reshape_shape_type = mlir::RankedTensorType::get(
      new_shape_array.size(), builder->getIntegerType(32));

  // This is to work around the unnecessary i64 -> i32 cast. :(
  // TODO(renjieliu): Revisit this later.
  SmallVector<int32_t, 4> new_shape_array_i32;
  for (auto size : new_shape_array) {
    new_shape_array_i32.push_back(static_cast<int32_t>(size));
  }
  auto new_shape_attr =
      mlir::DenseIntElementsAttr::get(reshape_shape_type, new_shape_array_i32);

  auto new_shape = builder->create<TFL::ConstOp>(loc, new_shape_attr);

  auto reshape_out_type = RankedTensorType::get(new_shape_array, element_type);
  return builder->create<TFL::ReshapeOp>(loc, reshape_out_type, input,
                                         new_shape);
}

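// If `op`'s operand at `bias_idx` is None, replaces it with a zero-filled
// constant bias sized to the last output dimension; returns failure when
// there is nothing to rewrite. Only f32 element types are currently handled.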
LogicalResult EnsureBias(Operation* op, int bias_idx,
                         PatternRewriter& rewriter) {
  auto bias = op->getOperand(bias_idx);

  if (!bias.getType().isa<NoneType>()) return failure();

  // Proceed to create a zero bias.
  auto output = op->getResult(0);
  auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
  if (!output_type) return failure();

  // Bias should be a vector sized to the last output dim.
  int num_units = output_type.getDimSize(output_type.getRank() - 1);
  auto bias_type =
      mlir::RankedTensorType::get({num_units}, output_type.getElementType());

  mlir::DenseElementsAttr bias_attr;
  if (output_type.getElementType().isF32()) {
    float val = 0.0;
    bias_attr = mlir::DenseFPElementsAttr::get(bias_type, val);
  } else {
    // TODO(renjieliu): Refactor this and share the logic with
    // CreateConstOpWithSingleValue. Also, make sure it works with QConst.
    return failure();
  }

  auto zero_bias = rewriter.create<TFL::ConstOp>(op->getLoc(), bias_attr);
  op->setOperand(bias_idx, zero_bias);

  return success();
}

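// Pads the constant fed by `input_op` with `value_to_pad` in front so that it
// has `pad_dimensions` elements, materialized as an i32 tf.Const. E.g.
// (illustrative values) [2, 3] with value_to_pad = 0 and pad_dimensions = 4
// becomes [0, 0, 2, 3].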
TF::ConstOp PadConstValues(Operation* input_op, int value_to_pad,
                           int pad_dimensions, Location loc,
                           OpBuilder* builder) {
  if (input_op == nullptr) return nullptr;

  mlir::DenseIntElementsAttr attr;
  if (!matchPattern(input_op, m_Constant(&attr))) {
    return nullptr;
  }

  auto value_shape_type = mlir::RankedTensorType::get(
      {pad_dimensions}, builder->getIntegerType(32));

  SmallVector<int32_t, 4> value_i32;
  value_i32.reserve(pad_dimensions);
  for (int i = 0; i < pad_dimensions - attr.getNumElements(); ++i) {
    value_i32.push_back(value_to_pad);
  }
  for (const auto& size : attr) {
    value_i32.push_back(static_cast<int32_t>(size.getSExtValue()));
  }
  auto new_value_i32_attr =
      mlir::DenseIntElementsAttr::get(value_shape_type, value_i32);

  return builder->create<TF::ConstOp>(loc, new_value_i32_attr);
}

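// Unrolls `split_op` into one tfl.slice per result: each slice starts at the
// running offset along `split_dim` and spans that result's size in that
// dimension, taking all elements in every other dimension. Returns the slice
// outputs in result order.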
SmallVector<Value, 4> SliceOutputs(Operation* split_op, Value input,
                                   RankedTensorType input_type, int split_dim,
                                   int num_splits, PatternRewriter* rewriter) {
  SmallVector<Value, 4> slice_outputs;
  int begin = 0;
  for (int i = 0; i < num_splits; ++i) {
    // Create slice op.
    // Populate begin & size.
    SmallVector<int32_t, 4> slice_begin;
    SmallVector<int32_t, 4> slice_size;
    auto current_output = split_op->getResult(i);
    auto current_output_type =
        current_output.getType().cast<RankedTensorType>();
    for (int d = 0; d < input_type.getRank(); ++d) {
      if (d == split_dim) {
        // Split dimension.
        slice_begin.push_back(begin);
        int size = current_output_type.getDimSize(d);
        slice_size.push_back(size);
        begin += size;
      } else {
        slice_begin.push_back(0);
        // -1 means all elements.
        slice_size.push_back(-1);
      }
    }

    auto slice_type = mlir::RankedTensorType::get(slice_begin.size(),
                                                  rewriter->getIntegerType(32));
    auto slice_begin_attr =
        mlir::DenseIntElementsAttr::get(slice_type, slice_begin);
    auto slice_size_attr =
        mlir::DenseIntElementsAttr::get(slice_type, slice_size);

    auto slice_begin_const =
        rewriter->create<TFL::ConstOp>(split_op->getLoc(), slice_begin_attr);
    auto slice_size_const =
        rewriter->create<TFL::ConstOp>(split_op->getLoc(), slice_size_attr);

    auto slice_op = rewriter->create<TFL::SliceOp>(
        split_op->getLoc(), current_output_type, input, slice_begin_const,
        slice_size_const);

    // Rewire output.
    slice_outputs.push_back(slice_op.getResult());
  }
  return slice_outputs;
}

}  // namespace

// ================== Pack ========================

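// Lowers tfl.pack into tfl.concatenation + tfl.reshape. E.g. (illustrative
// shapes) packing two [2, 3] inputs on axis 0 concatenates them on axis 0
// into [4, 3], then reshapes the result to the pack output shape [2, 2, 3].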
LogicalResult LowerPackIntoConcatReshape::matchAndRewrite(
    TFL::PackOp pack_op, PatternRewriter& rewriter) const {
  // Pack op inputs should all have the same shape and type.
  SmallVector<Value, 5> pack_inputs(pack_op.values());
  auto input_type = pack_inputs[0].getType().dyn_cast<RankedTensorType>();
  if (!input_type) return failure();

  // Figure out output shapes.
  SmallVector<int64_t, 4> concat_out_shape;
  SmallVector<int64_t, 4> pack_out_shape;

  const int rank = input_type.getRank();
  int pack_axis = pack_op.axis();
  int count = pack_inputs.size();
  if (pack_axis < 0) {
    pack_axis += rank;
  }

  // Concat out shape.
  for (int i = 0; i < rank; ++i) {
    int dim_size = input_type.getDimSize(i);
    if (i == pack_axis) {
      dim_size *= count;
    }
    concat_out_shape.push_back(dim_size);
  }

  // Pack out shape.
  int j = 0;
  for (int i = 0; i < rank + 1; ++i) {
    if (i == pack_axis) {
      pack_out_shape.push_back(count);
    } else {
      pack_out_shape.push_back(input_type.getDimSize(j));
      j++;
    }
  }

  if (failed(TF::VerifyShapeOfReshapeOp(pack_out_shape))) return failure();

  // Insert the concat op.
  auto concat_out_type =
      RankedTensorType::get(concat_out_shape, input_type.getElementType());
  auto concat_op = rewriter.create<TFL::ConcatenationOp>(
      pack_op.getLoc(), concat_out_type, pack_inputs, pack_op.axis(), "NONE");

  auto reshape_op =
      InsertReshapeOp(pack_op.getLoc(), concat_op, input_type.getElementType(),
                      pack_out_shape, &rewriter);

  // Rewire output & get rid of the pack op.
  rewriter.replaceOp(pack_op, reshape_op.getResult());
  return success();
}

// ================== squared_difference ========================

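// Lowers squared_difference(x, y) into (x - y) * (x - y), i.e. a tf.Sub
// followed by a tf.Mul of the difference with itself.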
LogicalResult SquaredDifference::matchAndRewrite(
    TFL::SquaredDifferenceOp squared_diff_op,
    PatternRewriter& rewriter) const {
  auto x = squared_diff_op.lhs();
  auto y = squared_diff_op.rhs();
  auto x_type = x.getType().dyn_cast<RankedTensorType>();
  auto y_type = y.getType().dyn_cast<RankedTensorType>();
  if (!x_type || !y_type) return failure();
  if (x_type.getShape() != y_type.getShape()) return failure();

  auto result_type = squared_diff_op.getType();
  if (!result_type) return failure();

  auto sub_op =
      rewriter.create<TF::SubOp>(squared_diff_op.getLoc(), result_type, x, y);
  auto mul_op =
      rewriter.create<TF::MulOp>(squared_diff_op.getLoc(), result_type,
                                 sub_op.getResult(), sub_op.getResult());
  rewriter.replaceOp(squared_diff_op, mul_op.getResult());

  return success();
}

// ================== split ========================

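// Unrolls tfl.split into one tfl.slice per output. The split dimension is
// inferred as the dimension where the input and output sizes differ.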
LogicalResult UnrollSplit::matchAndRewrite(TFL::SplitOp split_op,
                                           PatternRewriter& rewriter) const {
  auto num_splits = split_op.num_splits();
  auto input = split_op.value();
  auto input_type = input.getType().dyn_cast<RankedTensorType>();
  if (input_type == nullptr || !input_type.hasStaticShape()) return failure();

  for (auto result : split_op.getResults()) {
    auto result_type = result.getType().dyn_cast<RankedTensorType>();
    if (result_type == nullptr) return failure();
  }

  auto output = split_op.getResult(0);
  auto output_type = output.getType().cast<RankedTensorType>();

  // TODO(renjieliu): change to use split_dim when we raise the constants
  // as well.
  int split_dim = -1;
  for (int d = 0; d < input_type.getRank(); ++d) {
    if (input_type.getDimSize(d) != output_type.getDimSize(d)) split_dim = d;
  }

  const SmallVector<Value, 4>& slice_outputs = SliceOutputs(
      split_op, input, input_type, split_dim, num_splits, &rewriter);
  rewriter.replaceOp(split_op, slice_outputs);
  return success();
}

// ================== splitV ========================

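// Unrolls tfl.split_v into one tfl.slice per output. Both the size splits and
// the split dimension must be constants.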
LogicalResult UnrollSplitV::matchAndRewrite(TFL::SplitVOp splitv_op,
                                            PatternRewriter& rewriter) const {
  // We need to make sure both splits & split dim are constants.
  auto splits = splitv_op.size_splits().getDefiningOp();
  mlir::DenseIntElementsAttr splits_attr;
  if (!splits || !matchPattern(splits, m_Constant(&splits_attr)))
    return failure();

  auto split_dim = splitv_op.split_dim().getDefiningOp();
  mlir::ElementsAttr split_dim_attr;
  if (!split_dim || !matchPattern(split_dim, m_Constant(&split_dim_attr)))
    return failure();

  auto input = splitv_op.value();
  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
  if (!input_type || !input_type.hasRank()) return failure();

  for (auto result : splitv_op.getResults()) {
    auto result_type = result.getType().dyn_cast<RankedTensorType>();
    if (result_type == nullptr) return failure();
  }

  const int rank = input_type.getRank();

  IntegerAttr dim_int = ExtractSingleElementAsInteger(split_dim_attr);

339 // "axis" operand could be a i64 tensor. Resolve it here.
  IntegerAttr dim_i32;
  if (failed(ConvertToI32Attr(dim_int, &dim_i32))) return failure();

  int dim = dim_i32.getInt();
  if (dim < 0) dim += rank;

  const SmallVector<Value, 4>& slice_outputs = SliceOutputs(
      splitv_op, input, input_type, dim, splitv_op.num_splits(), &rewriter);
  rewriter.replaceOp(splitv_op, slice_outputs);

  return success();
}

// ================== conv_2d ========================

LogicalResult EnsureBiasForConv2d::matchAndRewrite(
    TFL::Conv2DOp conv_op, PatternRewriter& rewriter) const {
  return EnsureBias(conv_op, 2, rewriter);
}

// ================== slice ============================

// If a slice op has rank < 4, pads it to 4d.
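// E.g. (illustrative shapes) a [5, 6] slice with begin = [0, 2] and
// size = [5, 4] becomes: reshape the input to [1, 1, 5, 6], slice with
// begin = [0, 0, 0, 2] and size = [1, 1, 5, 4], then reshape the result back
// to [5, 4].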
LogicalResult PadSlice::matchAndRewrite(TFL::SliceOp slice_op,
                                        PatternRewriter& rewriter) const {
  // We have to know the shape of the input, as well as the begin/size.
  // Also, begin and size have to be constants.
  auto input = slice_op.input();
  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
  if (!input_type || !input_type.hasStaticShape()) return failure();

  if (input_type.getRank() >= 4) return failure();

  auto begin = slice_op.begin();
  auto begin_type = begin.getType().dyn_cast_or_null<RankedTensorType>();
  if (!begin_type || !begin_type.hasStaticShape()) return failure();

  auto size = slice_op.size();
  auto size_type = size.getType().dyn_cast_or_null<RankedTensorType>();
  if (!size_type || !size_type.hasStaticShape()) return failure();

  auto output_type = slice_op.getType().dyn_cast_or_null<RankedTensorType>();
  if (!output_type || !output_type.hasStaticShape()) return failure();

  // Pad 0s in front of the begin.
  TF::ConstOp new_begin =
      PadConstValues(begin.getDefiningOp(), 0, 4, slice_op.getLoc(), &rewriter);
  if (!new_begin) return failure();

  // Pad 1s in front of the size.
  TF::ConstOp new_size =
      PadConstValues(size.getDefiningOp(), 1, 4, slice_op.getLoc(), &rewriter);
  if (!new_size) return failure();

  // Reshape the input to 4d.
  SmallVector<int64_t, 4> new_shape;
  const int current_rank = input_type.getRank();
  for (int i = 0; i < 4 - current_rank; ++i) {
    new_shape.push_back(1);
  }
  for (auto size : input_type.getShape()) {
    new_shape.push_back(size);
  }

  auto reshape_op =
      InsertReshapeOp(slice_op.getLoc(), input, input_type.getElementType(),
                      new_shape, &rewriter);

  // Replace with the new slice op.
  SmallVector<int64_t, 4> new_output_shape;
  for (int i = 0; i < 4 - current_rank; ++i) {
    new_output_shape.push_back(1);
  }
  for (auto size : output_type.getShape()) {
    new_output_shape.push_back(size);
  }

  RankedTensorType new_output_type =
      RankedTensorType::get(new_output_shape, output_type.getElementType());

  auto new_slice = rewriter.create<TFL::SliceOp>(
      slice_op.getLoc(), new_output_type, reshape_op, new_begin, new_size);

  // Append a reshape at the bottom.
  auto output_reshape_op = InsertReshapeOp(slice_op.getLoc(), new_slice,
                                           output_type.getElementType(),
                                           output_type.getShape(), &rewriter);
  rewriter.replaceOp(slice_op, output_reshape_op.getResult());

  return success();
}

// ================== fully_connected ========================

// TFL fully_connected basically computes:
//   output = input * weight^T + bias
// Input layout is: [..., depth]
// Weight layout is: [output, depth]
// Bias is: [output]
//
// While conv2d is:
// Filter: [NHWC]
// Input is also: [NHWC]
// Bias is [N]
//
// So to perform the transform, we need to insert a few reshape ops:
//
//  Input   weight   bias
//     \       |      /
//            FC
//             |
//          output
//
//             |
//            \/
//
//  Input    weight
//    |         |
// Reshape   Reshape   bias
//    |         |       /
//           conv
//             |
//          reshape
//             |
//          output
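// E.g. (illustrative shapes) an fc with input [2, 3, 4] and weight [5, 4]
// (depth = 4, width = 24 / 4 = 6, output = 5): reshape the input to
// [1, 1, 6, 4] and the weight to [5, 1, 1, 4], run a 1x1 conv producing
// [1, 1, 6, 5], then reshape back to the fc output shape [2, 3, 5].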
LogicalResult FullyConnectedToConv::matchAndRewrite(
    TFL::FullyConnectedOp fc_op, PatternRewriter& rewriter) const {
  // We have to know the shape of the input.
  auto input = fc_op.input();
  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
  if (!input_type || !input_type.hasStaticShape()) return failure();

  // We have to know the shape of the weight.
  auto weight = fc_op.filter();
  auto weight_type = weight.getType().dyn_cast_or_null<RankedTensorType>();
  if (!weight_type || !weight_type.hasStaticShape()) return failure();

  // We have to know the shape of the output as well.
  auto output = fc_op.getResult(0);
  auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
  if (!output_type || !output_type.hasStaticShape()) return failure();

  // Insert a reshape after the input.
  // Since the input may be more than 2-d, we collect the flat size of the
  // input and reshape it into [1, 1, flat_size / depth, depth].
  const int depth = input_type.getDimSize(input_type.getRank() - 1);
  const int flat_size = input_type.getNumElements();
  const int width = flat_size / depth;
  SmallVector<int64_t, 4> input_new_shape({1, 1, width, depth});
  auto reshaped_input =
      InsertReshapeOp(fc_op.getLoc(), input, input_type.getElementType(),
                      input_new_shape, &rewriter);

  // Insert a reshape after the weight.
  // We will reshape the weight into [output, 1, 1, depth].
  const int output_size = weight_type.getDimSize(0);
  SmallVector<int64_t, 4> weight_new_shape({output_size, 1, 1, depth});
  auto reshaped_weight =
      InsertReshapeOp(fc_op.getLoc(), weight, weight_type.getElementType(),
                      weight_new_shape, &rewriter);

  // Replace the fc with conv.
  // The output would be [1, 1, width, output].
  auto conv_output_type = RankedTensorType::get({1, 1, width, output_size},
                                                output_type.getElementType());
  auto conv = rewriter.create<TFL::Conv2DOp>(
      fc_op.getLoc(), conv_output_type, reshaped_input, reshaped_weight,
      fc_op.bias(), rewriter.getI32IntegerAttr(1),
      rewriter.getI32IntegerAttr(1), fc_op.fused_activation_functionAttr(),
      rewriter.getStringAttr("VALID"), rewriter.getI32IntegerAttr(1),
      rewriter.getI32IntegerAttr(1));

  // Insert a reshape after the conv.
  auto reshaped_conv =
      InsertReshapeOp(fc_op.getLoc(), conv, output_type.getElementType(),
                      output_type.getShape(), &rewriter);

  rewriter.replaceOp(fc_op, reshaped_conv.getResult());

  return success();
}

// ================== concat ============================

// If a concat op has rank < 4, pads it to 4d.
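// E.g. (illustrative shapes) concatenating two [2, 3] inputs on axis 1:
// reshape each input to [1, 1, 2, 3], concatenate on axis 3 to produce
// [1, 1, 2, 6], then reshape back to [2, 6].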
LogicalResult PadConcat::matchAndRewrite(TFL::ConcatenationOp concat_op,
                                         PatternRewriter& rewriter) const {
  int rank = -1;
  for (auto input : concat_op.values()) {
    auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
    if (!input_type || !input_type.hasStaticShape()) return failure();

    rank = input_type.getRank();
  }

  auto output_type = concat_op.getType().dyn_cast_or_null<RankedTensorType>();
  if (!output_type || !output_type.hasStaticShape()) return failure();

  if (rank >= 4) return failure();

  // All values should have the same rank.
  // We will insert a reshape op after every input.
  SmallVector<Value, 4> reshape_ops;
  for (auto input : concat_op.values()) {
    auto input_type = input.getType().cast<RankedTensorType>();
    // Get the new shape.
    SmallVector<int64_t, 4> new_shape;
    for (int i = 0; i < 4 - rank; ++i) {
      new_shape.push_back(1);
    }
    for (auto size : input_type.getShape()) {
      new_shape.push_back(size);
    }

    auto reshape_op =
        InsertReshapeOp(concat_op.getLoc(), input, input_type.getElementType(),
                        new_shape, &rewriter);
    reshape_ops.push_back(reshape_op.getResult());
  }

  // Deal with the axis.
  // We don't need to adjust a negative axis, since it counts from the back.
  int32_t axis = concat_op.axis();
  if (axis >= 0) {
    axis += (4 - rank);
  }

  // Replace with the new concat op.
  SmallVector<int64_t, 4> new_output_shape;
  for (int i = 0; i < 4 - rank; ++i) {
    new_output_shape.push_back(1);
  }
  for (auto size : output_type.getShape()) {
    new_output_shape.push_back(size);
  }

  RankedTensorType new_output_type =
      RankedTensorType::get(new_output_shape, output_type.getElementType());

  auto new_concat = rewriter.create<TFL::ConcatenationOp>(
      concat_op.getLoc(), new_output_type, reshape_ops, axis,
      concat_op.fused_activation_function());

  // Append a reshape at the bottom.
  auto output_reshape_op = InsertReshapeOp(concat_op.getLoc(), new_concat,
                                           output_type.getElementType(),
                                           output_type.getShape(), &rewriter);
  rewriter.replaceOp(concat_op, output_reshape_op.getResult());

  return success();
}

}  // namespace tac
}  // namespace TFL
}  // namespace mlir