• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.h"
17 
18 #include <memory>
19 
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Casting.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
24 #include "mlir/IR/Attributes.h"  // from @llvm-project
25 #include "mlir/IR/Builders.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
27 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
28 #include "mlir/IR/Matchers.h"  // from @llvm-project
29 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
30 #include "mlir/IR/Value.h"  // from @llvm-project
31 #include "mlir/Pass/Pass.h"  // from @llvm-project
32 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
33 #include "mlir/Support/LLVM.h"  // from @llvm-project
34 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
35 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
36 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
37 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
39 #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"
40 
41 namespace mlir {
42 namespace TFL {
43 namespace tac {
44 namespace {
45 
46 // ================== Common ========================
47 
48 // Converts any IntegerAttr to an IntegerAttr of an i32 type.
49 // The value won't change in the new attribute, but if the value is out of
50 // the bound of i32, the function returns a failure.
ConvertToI32Attr(IntegerAttr attr,IntegerAttr * attr_i32)51 LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) {
52   if (attr.getType().isInteger(/*width=*/32)) {
53     *attr_i32 = attr;
54     return success();
55   }
56 
57   int64_t value = attr.getInt();
58   if (value > std::numeric_limits<int>::max() ||
59       value < std::numeric_limits<int>::min()) {
60     return failure();
61   }
62 
63   *attr_i32 = IntegerAttr::get(
64       IntegerType::get(attr.getContext(), /*width=*/32), value);
65   return success();
66 }
67 
InsertReshapeOp(Location loc,Value input,Type element_type,llvm::ArrayRef<int64_t> new_shape_array,OpBuilder * builder)68 TFL::ReshapeOp InsertReshapeOp(Location loc, Value input, Type element_type,
69                                llvm::ArrayRef<int64_t> new_shape_array,
70                                OpBuilder* builder) {
71   auto reshape_shape_type = mlir::RankedTensorType::get(
72       new_shape_array.size(), builder->getIntegerType(32));
73 
74   // This is to workaround the unnecessary cast i64 -> i32. :(
75   // TODO(renjieliu): Revisit this later.
76   SmallVector<int32_t, 4> new_shape_array_i32;
77   for (auto size : new_shape_array) {
78     new_shape_array_i32.push_back(static_cast<int32_t>(size));
79   }
80   auto new_shape_attr =
81       mlir::DenseIntElementsAttr::get(reshape_shape_type, new_shape_array_i32);
82 
83   auto new_shape = builder->create<TFL::ConstOp>(loc, new_shape_attr);
84 
85   auto reshape_out_type = RankedTensorType::get(new_shape_array, element_type);
86   return builder->create<TFL::ReshapeOp>(loc, reshape_out_type, input,
87                                          new_shape);
88 }
89 
EnsureBias(Operation * op,int bias_idx,PatternRewriter & rewriter)90 LogicalResult EnsureBias(Operation* op, int bias_idx,
91                          PatternRewriter& rewriter) {
92   auto bias = op->getOperand(bias_idx);
93 
94   if (!bias.getType().isa<NoneType>()) return failure();
95 
96   // Proceed to create a zero bias.
97   auto output = op->getResult(0);
98   auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
99   if (!output_type) return failure();
100 
101   // bias should be a vector sized of the last output dim.
102   int num_units = output_type.getDimSize(output_type.getRank() - 1);
103   auto bias_type =
104       mlir::RankedTensorType::get({num_units}, output_type.getElementType());
105 
106   mlir::DenseElementsAttr bias_attr;
107   if (output_type.getElementType().isF32()) {
108     float val = 0.0;
109     bias_attr = mlir::DenseFPElementsAttr::get(bias_type, val);
110   } else {
111     // TODO(renjieliu): Refactor this and share the logic with
112     // CreateConstOpWithSingleValue. Also, make sure it works with QConst.
113     return failure();
114   }
115 
116   auto zero_bias = rewriter.create<TFL::ConstOp>(op->getLoc(), bias_attr);
117   op->setOperand(bias_idx, zero_bias);
118 
119   return success();
120 }
121 
PadConstValues(Operation * input_op,int value_to_pad,int pad_dimensions,Location loc,OpBuilder * builder)122 TF::ConstOp PadConstValues(Operation* input_op, int value_to_pad,
123                            int pad_dimensions, Location loc,
124                            OpBuilder* builder) {
125   if (input_op == nullptr) return nullptr;
126 
127   mlir::DenseIntElementsAttr attr;
128   if (!matchPattern(input_op, m_Constant(&attr))) {
129     return nullptr;
130   }
131 
132   auto value_shape_type = mlir::RankedTensorType::get(
133       {pad_dimensions}, builder->getIntegerType(32));
134 
135   SmallVector<int32_t, 4> value_i32;
136   value_i32.reserve(pad_dimensions);
137   for (int i = 0; i < pad_dimensions - attr.getNumElements(); ++i) {
138     value_i32.push_back(value_to_pad);
139   }
140   for (const auto& size : attr) {
141     value_i32.push_back(static_cast<int32_t>(size.getSExtValue()));
142   }
143   auto new_value_i32_attr =
144       mlir::DenseIntElementsAttr::get(value_shape_type, value_i32);
145 
146   return builder->create<TF::ConstOp>(loc, new_value_i32_attr);
147 }
148 
// Unrolls a multi-result split-style op into one TFL::SliceOp per result,
// slicing `input` along `split_dim`. The extent of each slice along the
// split dimension is read from the corresponding result's type; every other
// dimension is taken in full (size -1).
SmallVector<Value, 4> SliceOutputs(Operation* split_op, Value input,
                                   RankedTensorType input_type, int split_dim,
                                   int num_splits, PatternRewriter* rewriter) {
  SmallVector<Value, 4> slice_outputs;
  // Running offset along the split dimension.
  int begin = 0;
  for (int i = 0; i < num_splits; ++i) {
    // Create slice op.
    // Populate begin & size.
    SmallVector<int32_t, 4> slice_begin;
    SmallVector<int32_t, 4> slice_size;
    auto current_output = split_op->getResult(i);
    auto current_output_type =
        current_output.getType().cast<RankedTensorType>();
    for (int d = 0; d < input_type.getRank(); ++d) {
      if (d == split_dim) {
        // Split dimension: start at the running offset, take this result's
        // extent, and advance the offset for the next slice.
        slice_begin.push_back(begin);
        int size = current_output_type.getDimSize(d);
        slice_size.push_back(size);
        begin += size;
      } else {
        slice_begin.push_back(0);
        // -1 means every elements.
        slice_size.push_back(-1);
      }
    }

    // Materialize begin/size as 1-D i32 constant tensors for TFL::SliceOp.
    auto slice_type = mlir::RankedTensorType::get(slice_begin.size(),
                                                  rewriter->getIntegerType(32));
    auto slice_begin_attr =
        mlir::DenseIntElementsAttr::get(slice_type, slice_begin);
    auto slice_size_attr =
        mlir::DenseIntElementsAttr::get(slice_type, slice_size);

    auto slice_begin_const =
        rewriter->create<TFL::ConstOp>(split_op->getLoc(), slice_begin_attr);
    auto slice_size_const =
        rewriter->create<TFL::ConstOp>(split_op->getLoc(), slice_size_attr);

    auto slice_op = rewriter->create<TFL::SliceOp>(
        split_op->getLoc(), current_output_type, input, slice_begin_const,
        slice_size_const);

    // Rewire output.
    slice_outputs.push_back(slice_op.getResult());
  }
  return slice_outputs;
}
197 
198 }  // namespace
199 
200 // ================== Pack ========================
201 
matchAndRewrite(TFL::PackOp pack_op,PatternRewriter & rewriter) const202 LogicalResult LowerPackIntoConcatReshape::matchAndRewrite(
203     TFL::PackOp pack_op, PatternRewriter& rewriter) const {
204   // Pack op should have same shape type.
205   SmallVector<Value, 5> pack_inputs(pack_op.values());
206   auto input_type = pack_inputs[0].getType().dyn_cast<RankedTensorType>();
207   if (!input_type) return failure();
208 
209   // Figure out output shapes.
210   SmallVector<int64_t, 4> concat_out_shape;
211   SmallVector<int64_t, 4> pack_out_shape;
212 
213   const int rank = input_type.getRank();
214   int pack_axis = pack_op.axis();
215   int count = pack_inputs.size();
216   if (pack_axis < 0) {
217     pack_axis += rank;
218   }
219 
220   // Concat out shape.
221   for (int i = 0; i < rank; ++i) {
222     int dim_size = input_type.getDimSize(i);
223     if (i == pack_axis) {
224       dim_size *= count;
225     }
226     concat_out_shape.push_back(dim_size);
227   }
228 
229   // Pack out shape.
230   int j = 0;
231   for (int i = 0; i < rank + 1; ++i) {
232     if (i == pack_axis) {
233       pack_out_shape.push_back(count);
234     } else {
235       pack_out_shape.push_back(input_type.getDimSize(j));
236       j++;
237     }
238   }
239 
240   if (failed(TF::VerifyShapeOfReshapeOp(pack_out_shape))) return failure();
241 
242   // Insert the concat op.
243   auto concat_out_type =
244       RankedTensorType::get(concat_out_shape, input_type.getElementType());
245   auto concat_op = rewriter.create<TFL::ConcatenationOp>(
246       pack_op.getLoc(), concat_out_type, pack_inputs, pack_op.axis(), "NONE");
247 
248   auto reshape_op =
249       InsertReshapeOp(pack_op.getLoc(), concat_op, input_type.getElementType(),
250                       pack_out_shape, &rewriter);
251 
252   // Rewire output & get rid of the pack op.
253   rewriter.replaceOp(pack_op, reshape_op.getResult());
254   return success();
255 }
256 
257 // ================== squared_difference ========================
258 
matchAndRewrite(TFL::SquaredDifferenceOp squared_diff_op,PatternRewriter & rewriter) const259 LogicalResult SquaredDifference::matchAndRewrite(
260     TFL::SquaredDifferenceOp squared_diff_op, PatternRewriter& rewriter) const {
261   auto x = squared_diff_op.lhs();
262   auto y = squared_diff_op.rhs();
263   auto x_type = x.getType().dyn_cast<RankedTensorType>();
264   auto y_type = y.getType().dyn_cast<RankedTensorType>();
265   if (!x_type || !y_type) return failure();
266   if (x_type.getShape() != y_type.getShape()) return failure();
267 
268   auto result_type = squared_diff_op.getType();
269   if (!result_type) return failure();
270 
271   auto sub_op =
272       rewriter.create<TF::SubOp>(squared_diff_op.getLoc(), result_type, x, y);
273   auto mul_op =
274       rewriter.create<TF::MulOp>(squared_diff_op.getLoc(), result_type,
275                                  sub_op.getResult(), sub_op.getResult());
276   rewriter.replaceOp(squared_diff_op, mul_op.getResult());
277 
278   return success();
279 }
280 
281 // ================== split ========================
282 
// Unrolls tfl.split into one tfl.slice per result.
LogicalResult UnrollSplit::matchAndRewrite(TFL::SplitOp split_op,
                                           PatternRewriter& rewriter) const {
  auto num_splits = split_op.num_splits();
  auto input = split_op.value();
  auto input_type = input.getType().dyn_cast<RankedTensorType>();
  if (input_type == nullptr || !input_type.hasStaticShape()) return failure();

  // Slice extents are read from the result types, so each must be ranked.
  for (auto result : split_op.getResults()) {
    auto result_type = result.getType().dyn_cast<RankedTensorType>();
    if (result_type == nullptr) return failure();
  }

  auto output = split_op.getResult(0);
  auto output_type = output.getType().cast<RankedTensorType>();

  // Infer the split dimension by comparing input & output shapes rather than
  // reading the split_dim operand.
  // TODO(renjieliu): change to use split_dim when we raise the constants
  // as well.
  int split_dim = -1;
  for (int d = 0; d < input_type.getRank(); ++d) {
    if (input_type.getDimSize(d) != output_type.getDimSize(d)) split_dim = d;
  }

  const SmallVector<Value, 4>& slice_outputs = SliceOutputs(
      split_op, input, input_type, split_dim, num_splits, &rewriter);
  rewriter.replaceOp(split_op, slice_outputs);
  return success();
}
310 
311 // ================== splitV ========================
312 
// Unrolls tfl.split_v into one tfl.slice per result. Requires the
// size_splits and split_dim operands to be constants.
LogicalResult UnrollSplitV::matchAndRewrite(TFL::SplitVOp splitv_op,
                                            PatternRewriter& rewriter) const {
  // We need to make sure both splits & split dim are constants.
  auto splits = splitv_op.size_splits().getDefiningOp();
  mlir::DenseIntElementsAttr splits_attr;
  if (!splits || !matchPattern(splits, m_Constant(&splits_attr)))
    return failure();

  auto split_dim = splitv_op.split_dim().getDefiningOp();
  mlir::ElementsAttr split_dim_attr;
  if (!split_dim || !matchPattern(split_dim, m_Constant(&split_dim_attr)))
    return failure();

  auto input = splitv_op.value();
  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
  if (!input_type || !input_type.hasRank()) return failure();

  // Slice extents are read from the result types, so each must be ranked.
  for (auto result : splitv_op.getResults()) {
    auto result_type = result.getType().dyn_cast<RankedTensorType>();
    if (result_type == nullptr) return failure();
  }

  const int rank = input_type.getRank();

  IntegerAttr dim_int = ExtractSingleElementAsInteger(split_dim_attr);

  // "axis" operand could be a i64 tensor. Resolve it here.
  IntegerAttr dim_i32;
  if (failed(ConvertToI32Attr(dim_int, &dim_i32))) return failure();

  // Normalize a negative split dimension.
  int dim = dim_i32.getInt();
  if (dim < 0) dim += rank;

  const SmallVector<Value, 4>& slice_outputs = SliceOutputs(
      splitv_op, input, input_type, dim, splitv_op.num_splits(), &rewriter);
  rewriter.replaceOp(splitv_op, slice_outputs);

  return success();
}
352 
353 // ================== conv_2d ========================
354 
// Ensures tfl.conv_2d has a real (non-none) bias by delegating to the
// shared zero-bias helper; the bias is operand #2 of conv_2d.
LogicalResult EnsureBiasForConv2d::matchAndRewrite(
    TFL::Conv2DOp conv_op, PatternRewriter& rewriter) const {
  return EnsureBias(conv_op, /*bias_idx=*/2, rewriter);
}
359 
360 // ================== slice ============================
361 
362 // If a slice op has < 4d dimension, will pad it to 4d.
LogicalResult PadSlice::matchAndRewrite(TFL::SliceOp slice_op,
                                        PatternRewriter& rewriter) const {
  // We have to know the shape of the input, as well as the begin/size.
  // also, begin and size have to be constants.
  auto input = slice_op.input();
  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
  if (!input_type || !input_type.hasStaticShape()) return failure();

  // Already 4d (or more): nothing to pad.
  if (input_type.getRank() >= 4) return failure();

  auto begin = slice_op.begin();
  auto begin_type = begin.getType().dyn_cast_or_null<RankedTensorType>();
  if (!begin_type || !begin_type.hasStaticShape()) return failure();

  auto size = slice_op.size();
  auto size_type = size.getType().dyn_cast_or_null<RankedTensorType>();
  if (!size_type || !size_type.hasStaticShape()) return failure();

  auto output_type = slice_op.getType().dyn_cast_or_null<RankedTensorType>();
  if (!output_type || !output_type.hasStaticShape()) return failure();

  // Pad 0s in front of the begin (begin 0 covers the new unit dims).
  TF::ConstOp new_begin =
      PadConstValues(begin.getDefiningOp(), 0, 4, slice_op.getLoc(), &rewriter);
  if (!new_begin) return failure();

  // Pad 1s in front of the size (each new unit dim is taken in full).
  TF::ConstOp new_size =
      PadConstValues(size.getDefiningOp(), 1, 4, slice_op.getLoc(), &rewriter);
  if (!new_size) return failure();

  // Reshape the input to 4d by prepending 1s to its shape.
  SmallVector<int64_t, 4> new_shape;
  const int current_rank = input_type.getRank();
  for (int i = 0; i < 4 - current_rank; ++i) {
    new_shape.push_back(1);
  }
  for (auto size : input_type.getShape()) {
    new_shape.push_back(size);
  }

  auto reshape_op =
      InsertReshapeOp(slice_op.getLoc(), input, input_type.getElementType(),
                      new_shape, &rewriter);

  // Replace with the new slice op, whose output shape is the original
  // output shape prepended with 1s as well.
  SmallVector<int64_t, 4> new_output_shape;
  for (int i = 0; i < 4 - current_rank; ++i) {
    new_output_shape.push_back(1);
  }
  for (auto size : output_type.getShape()) {
    new_output_shape.push_back(size);
  }

  RankedTensorType new_output_type =
      RankedTensorType::get(new_output_shape, output_type.getElementType());

  auto new_slice = rewriter.create<TFL::SliceOp>(
      slice_op.getLoc(), new_output_type, reshape_op, new_begin, new_size);

  // Append a reshape at the bottom to restore the original output shape.
  auto output_reshape_op = InsertReshapeOp(slice_op.getLoc(), new_slice,
                                           output_type.getElementType(),
                                           output_type.getShape(), &rewriter);
  rewriter.replaceOp(slice_op, output_reshape_op.getResult());

  return success();
}
431 
432 // ================== fully_connected ========================
433 
434 // TFL fully_connected basically does:
435 // Weight * Input + bias.
436 // Input layout is : [..., depth]
437 // Weight layout is : [output, depth]
438 // Bias is [output].
439 //
440 // While conv2d is:
441 // Filter: [NHWC]
442 // Input is also: [NHWC]
443 // Bias is [N]
444 //
445 // So to perform the transform, we need to insert a few reshape ops:
446 //
447 //  Input   weight   bias
448 //   \      /      /
449 //       FC
450 //       |
451 //     output
452 //
453 //     |
454 //    \/
455 //
456 //  Input   weight
457 //   |        |
458 //  Reshape  Reshape  bias
459 //  |         |      /
460 //     conv
461 //      |
462 //     reshape
463 //      |
464 //    output
LogicalResult FullyConnectedToConv::matchAndRewrite(
    TFL::FullyConnectedOp fc_op, PatternRewriter& rewriter) const {
  // We have to know the shape of the input.
  auto input = fc_op.input();
  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
  if (!input_type || !input_type.hasStaticShape()) return failure();

  // We have to know the shape of the weight.
  auto weight = fc_op.filter();
  auto weight_type = weight.getType().dyn_cast_or_null<RankedTensorType>();
  if (!weight_type || !weight_type.hasStaticShape()) return failure();

  // We have to know the shape of the output as well.
  auto output = fc_op.getResult(0);
  auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
  if (!output_type || !output_type.hasStaticShape()) return failure();

  // Insert a reshape after the input.
  // Since the input maybe more than 2-d, we may collect the flat size of the
  // input then reshape into [1, 1, flat_size / depth, depth].
  const int depth = input_type.getDimSize(input_type.getRank() - 1);
  const int flat_size = input_type.getNumElements();
  const int width = flat_size / depth;
  SmallVector<int64_t, 4> input_new_shape({1, 1, width, depth});
  auto reshaped_input =
      InsertReshapeOp(fc_op.getLoc(), input, input_type.getElementType(),
                      input_new_shape, &rewriter);

  // Insert a reshape after the weight.
  // We will reshape the weight into [output, 1, 1, depth]
  const int output_size = weight_type.getDimSize(0);
  SmallVector<int64_t, 2> weight_new_shape({output_size, 1, 1, depth});
  auto reshaped_weight =
      InsertReshapeOp(fc_op.getLoc(), weight, weight_type.getElementType(),
                      weight_new_shape, &rewriter);

  // Replace the fc with conv.
  // The output would be [1, 1, width, output].
  // All strides/dilations are 1 and padding is VALID; the fused activation
  // function is carried over from the fc op.
  auto conv_output_type = RankedTensorType::get({1, 1, width, output_size},
                                                output_type.getElementType());
  auto conv = rewriter.create<TFL::Conv2DOp>(
      fc_op.getLoc(), conv_output_type, reshaped_input, reshaped_weight,
      fc_op.bias(), rewriter.getI32IntegerAttr(1),
      rewriter.getI32IntegerAttr(1), fc_op.fused_activation_functionAttr(),
      rewriter.getStringAttr("VALID"), rewriter.getI32IntegerAttr(1),
      rewriter.getI32IntegerAttr(1));

  // Insert a shape after the conv to restore the original output shape.
  auto reshaped_conv =
      InsertReshapeOp(fc_op.getLoc(), conv, output_type.getElementType(),
                      output_type.getShape(), &rewriter);

  rewriter.replaceOp(fc_op, reshaped_conv.getResult());

  return success();
}
521 
522 // ================== concat ============================
523 
524 // If a concat op has < 4d dimension, will pad it to 4d.
LogicalResult PadConcat::matchAndRewrite(TFL::ConcatenationOp concat_op,
                                         PatternRewriter& rewriter) const {
  // All inputs must be ranked with static shapes; `rank` ends up holding the
  // rank of the last input (all inputs of a valid concat share one rank).
  int rank = -1;
  for (auto input : concat_op.values()) {
    auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
    if (!input_type || !input_type.hasStaticShape()) return failure();

    rank = input_type.getRank();
  }

  auto output_type = concat_op.getType().dyn_cast_or_null<RankedTensorType>();
  if (!output_type || !output_type.hasStaticShape()) return failure();

  // Already 4d (or more): nothing to pad.
  if (rank >= 4) return failure();

  // All values should have the same rank.
  // We will insert a reshape op after every input to prepend 1s up to 4d.
  SmallVector<Value, 4> reshape_ops;
  for (auto input : concat_op.values()) {
    auto input_type = input.getType().cast<RankedTensorType>();
    // Get the new shape.
    SmallVector<int64_t, 4> new_shape;
    for (int i = 0; i < 4 - rank; ++i) {
      new_shape.push_back(1);
    }
    for (auto size : input_type.getShape()) {
      new_shape.push_back(size);
    }

    auto reshape_op =
        InsertReshapeOp(concat_op.getLoc(), input, input_type.getElementType(),
                        new_shape, &rewriter);
    reshape_ops.push_back(reshape_op.getResult());
  }

  // Deal with the axis.
  // We don't need to handle axis < 0, since it's counting reversely.
  // A non-negative axis must be shifted by the number of prepended 1s.
  int32_t axis = concat_op.axis();
  if (axis >= 0) {
    axis += (4 - rank);
  }

  // Replace with the new concat op whose output shape is the original one
  // prepended with 1s.
  SmallVector<int64_t, 4> new_output_shape;
  for (int i = 0; i < 4 - rank; ++i) {
    new_output_shape.push_back(1);
  }
  for (auto size : output_type.getShape()) {
    new_output_shape.push_back(size);
  }

  RankedTensorType new_output_type =
      RankedTensorType::get(new_output_shape, output_type.getElementType());

  auto new_concat = rewriter.create<TFL::ConcatenationOp>(
      concat_op.getLoc(), new_output_type, reshape_ops, axis,
      concat_op.fused_activation_function());

  // Append a reshape at the bottom to restore the original output shape.
  auto output_reshape_op = InsertReshapeOp(concat_op.getLoc(), new_concat,
                                           output_type.getElementType(),
                                           output_type.getShape(), &rewriter);
  rewriter.replaceOp(concat_op, output_reshape_op.getResult());

  return success();
}
591 
592 }  // namespace tac
593 }  // namespace TFL
594 }  // namespace mlir
595