1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Legalize TensorFlow Lite to TOSA
17 
18 #include <climits>
19 #include <cmath>
20 #include <cstddef>
21 #include <cstdint>
22 #include <fstream>
23 #include <iterator>
24 #include <limits>
25 #include <numeric>
26 #include <string>
27 #include <unordered_set>
28 
29 #include "llvm/ADT/ArrayRef.h"
30 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
31 #include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
32 #include "mlir/Dialect/Traits.h"  // from @llvm-project
33 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
34 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
35 #include "mlir/IR/Matchers.h"  // from @llvm-project
36 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
37 #include "mlir/Support/LLVM.h"  // from @llvm-project
38 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
39 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
40 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
41 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
42 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
43 #include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
44 
45 #define PASS_NAME "tosa-legalize-tfl"
46 #define DEBUG_TYPE PASS_NAME
47 #define HARDSWISH_EXPLICIT_RESCALING false
48 
49 namespace mlir {
50 namespace tosa {
51 namespace {
52 
53 // Performs lowering to TOSA dialect.
54 class LegalizeTFL : public TosaLegalizeTFLPassBase<LegalizeTFL> {
55  public:
56   LegalizeTFL() = default;
57   explicit LegalizeTFL(ArrayRef<std::string> disabled_patterns,
58                        ArrayRef<std::string> enabled_patterns) {
59     this->disabled_patterns_ = disabled_patterns;
60     this->enabled_patterns_ = enabled_patterns;
61   }
62   void runOnOperation() override;
63   LogicalResult initialize(MLIRContext* context) override;
64 
65  private:
66   FrozenRewritePatternSet frozen_patterns_;
67 };
68 
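// Pulls in the TableGen-generated TFL-to-TOSA rewrite patterns; the
// hand-written patterns below cover ops that need custom lowering logic.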
69 #include "tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.inc"
70 
71 // Input from tfl.conv2d takes a 64-bit bias, while tosa.conv2d expects a
72 // 48-bit bias. A customized truncation is needed here instead of tablegen to
73 // handle attributes with negative values.
74 struct ConvertConstantOp : public RewritePattern {
75   explicit ConvertConstantOp(MLIRContext* context)
76       : RewritePattern(arith::ConstantOp::getOperationName(), 1, context) {}
77   LogicalResult matchAndRewrite(Operation* op,
78                                 PatternRewriter& rewriter) const override;
79 };
80 
81 #define DECL_CONVERT_OP(tfl_op)                                              \
82   struct ConvertTFL##tfl_op##Op : public RewritePattern {                    \
83     explicit ConvertTFL##tfl_op##Op(MLIRContext* context)                    \
84         : RewritePattern(TFL::tfl_op##Op::getOperationName(), 1, context) {} \
85     LogicalResult matchAndRewrite(Operation* op,                             \
86                                   PatternRewriter& rewriter) const override; \
87   }
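// Declares one ConvertTFL<Op>Op RewritePattern per TFL op legalized in this
// file; the corresponding matchAndRewrite bodies follow after the list.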
88 DECL_CONVERT_OP(Relu);
89 DECL_CONVERT_OP(Relu1);
90 DECL_CONVERT_OP(Relu6);
91 DECL_CONVERT_OP(Equal);
92 DECL_CONVERT_OP(NotEqual);
93 DECL_CONVERT_OP(Greater);
94 DECL_CONVERT_OP(GreaterEqual);
95 DECL_CONVERT_OP(Add);
96 DECL_CONVERT_OP(Sub);
97 DECL_CONVERT_OP(Mul);
98 DECL_CONVERT_OP(Square);
99 DECL_CONVERT_OP(SquaredDifference);
100 DECL_CONVERT_OP(Round);
101 DECL_CONVERT_OP(Div);
102 DECL_CONVERT_OP(Maximum);
103 DECL_CONVERT_OP(Minimum);
104 DECL_CONVERT_OP(FloorMod);
105 DECL_CONVERT_OP(FloorDiv);
106 DECL_CONVERT_OP(AddN);
107 DECL_CONVERT_OP(AveragePool2D);
108 DECL_CONVERT_OP(MaxPool2D);
109 DECL_CONVERT_OP(Concatenation);
110 DECL_CONVERT_OP(Reshape);
111 DECL_CONVERT_OP(Rank);
112 DECL_CONVERT_OP(Shape);
113 DECL_CONVERT_OP(ExpandDims);
114 DECL_CONVERT_OP(Squeeze);
115 DECL_CONVERT_OP(Fill);
116 DECL_CONVERT_OP(Elu);
117 DECL_CONVERT_OP(Softmax);
118 DECL_CONVERT_OP(LogSoftmax);
119 DECL_CONVERT_OP(Sqrt);
120 DECL_CONVERT_OP(L2Normalization);
121 DECL_CONVERT_OP(ReduceAny);
122 DECL_CONVERT_OP(ReduceMax);
123 DECL_CONVERT_OP(ReduceMin);
124 DECL_CONVERT_OP(Mean);
125 DECL_CONVERT_OP(ReduceProd);
126 DECL_CONVERT_OP(Sum);
127 DECL_CONVERT_OP(Conv2D);
128 DECL_CONVERT_OP(TransposeConv);
129 DECL_CONVERT_OP(DepthwiseConv2D);
130 DECL_CONVERT_OP(FullyConnected);
131 DECL_CONVERT_OP(BatchMatMul);
132 DECL_CONVERT_OP(Split);
133 DECL_CONVERT_OP(SplitV);
134 DECL_CONVERT_OP(Pack);
135 DECL_CONVERT_OP(Unpack);
136 DECL_CONVERT_OP(Transpose);
137 DECL_CONVERT_OP(Tile);
138 DECL_CONVERT_OP(Slice);
139 DECL_CONVERT_OP(StridedSlice);
140 DECL_CONVERT_OP(HardSwish);
141 DECL_CONVERT_OP(ZerosLike);
142 DECL_CONVERT_OP(Less);
143 DECL_CONVERT_OP(LessEqual);
144 DECL_CONVERT_OP(Pad);
145 DECL_CONVERT_OP(PadV2);
146 DECL_CONVERT_OP(ResizeBilinear);
147 DECL_CONVERT_OP(ResizeNearestNeighbor);
148 DECL_CONVERT_OP(Select);
149 DECL_CONVERT_OP(SelectV2);
150 DECL_CONVERT_OP(SpaceToBatchNd);
151 DECL_CONVERT_OP(BatchToSpaceNd);
152 DECL_CONVERT_OP(SpaceToDepth);
153 DECL_CONVERT_OP(DepthToSpace);
154 DECL_CONVERT_OP(Sin);
155 DECL_CONVERT_OP(Cos);
156 DECL_CONVERT_OP(Logistic);
157 DECL_CONVERT_OP(Tanh);
158 DECL_CONVERT_OP(PRelu);
159 DECL_CONVERT_OP(LeakyRelu);
160 DECL_CONVERT_OP(Neg);
161 DECL_CONVERT_OP(Yield);
162 DECL_CONVERT_OP(Custom);
163 DECL_CONVERT_OP(ReverseV2);
164 DECL_CONVERT_OP(Quantize);
165 DECL_CONVERT_OP(Dequantize);
166 DECL_CONVERT_OP(Const);
167 DECL_CONVERT_OP(QConst);
168 DECL_CONVERT_OP(Gather);
169 DECL_CONVERT_OP(GatherNd);
170 DECL_CONVERT_OP(SparseToDense);
171 DECL_CONVERT_OP(OneHot);
172 DECL_CONVERT_OP(ArgMax);
173 DECL_CONVERT_OP(FakeQuant);
174 
175 #undef DECL_CONVERT_OP
176 
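// The ReLU family (Relu/Relu1/Relu6) lowers to tosa.clamp. For quantized
// tensors the input is first rescaled into the output quantization space and
// the clamp bounds are expressed in quantized units derived from the output
// scale and zero point.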
177 LogicalResult ConvertTFLReluOp::matchAndRewrite(
178     Operation* op, PatternRewriter& rewriter) const {
179   auto tfl_relu_op = cast<TFL::ReluOp>(op);
180 
181   ShapedType input_type = tfl_relu_op.x().getType().dyn_cast<ShapedType>();
182   ShapedType output_type =
183       tfl_relu_op.getResult().getType().dyn_cast<ShapedType>();
184   // Not a shaped tensor output
185   if (!input_type || !output_type) return failure();
186 
187   bool input_is_qtype =
188       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
189   bool output_is_qtype =
190       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
191 
192   if (input_is_qtype != output_is_qtype) {
193     return op->emitOpError(
194         "ConvertTFLReluOp: input/output tensor should "
195         "be all quantized or all floating-point.");
196   }
197 
198   int64_t clamp_min = 0;
199   Value clamp_in = tfl_relu_op.x();
200 
201   if (output_is_qtype) {
202     UniformQuantizedType input_qtype =
203         input_type.getElementType()
204             .dyn_cast<mlir::quant::UniformQuantizedType>();
205     UniformQuantizedType output_qtype =
206         output_type.getElementType()
207             .dyn_cast<mlir::quant::UniformQuantizedType>();
208 
209     clamp_min = output_qtype.getZeroPoint();
210     clamp_in =
211         buildRescale(rewriter, op, output_type, tfl_relu_op.x(),
212                      input_qtype.getScale() / output_qtype.getScale(),
213                      input_qtype.getZeroPoint(), output_qtype.getZeroPoint(),
214                      /*double_round=*/false, /*scale32=*/true);
215   }
216 
217   CreateReplaceOpAndInfer<tosa::ClampOp>(
218       rewriter, op, output_type, clamp_in,
219       rewriter.getI64IntegerAttr(clamp_min),
220       rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
221       rewriter.getF32FloatAttr(0.0f),
222       rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
223 
224   return success();
225 }
226 
227 LogicalResult ConvertTFLRelu1Op::matchAndRewrite(
228     Operation* op, PatternRewriter& rewriter) const {
229   auto tfl_relu1_op = cast<TFL::Relu1Op>(op);
230 
231   ShapedType input_type = tfl_relu1_op.x().getType().dyn_cast<ShapedType>();
232   ShapedType output_type =
233       tfl_relu1_op.getResult().getType().dyn_cast<ShapedType>();
234   // Not a shaped tensor output
235   if (!input_type || !output_type) return failure();
236 
237   bool input_is_qtype =
238       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
239   bool output_is_qtype =
240       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
241 
242   if (input_is_qtype != output_is_qtype) {
243     return op->emitOpError(
244         "ConvertTFLRelu1Op: input/output tensor should "
245         "be all quantized or all floating-point.");
246   }
247 
248   int64_t clamp_min = -1;
249   int64_t clamp_max = 1;
250   Value clamp_in = tfl_relu1_op.x();
251 
252   if (output_is_qtype && input_is_qtype) {
253     UniformQuantizedType input_qtype =
254         input_type.getElementType()
255             .dyn_cast<mlir::quant::UniformQuantizedType>();
256     UniformQuantizedType output_qtype =
257         output_type.getElementType()
258             .dyn_cast<mlir::quant::UniformQuantizedType>();
259 
260     clamp_min = output_qtype.getZeroPoint() -
261                 std::llround(1.0f / output_qtype.getScale());
262 
263     clamp_max = std::llround(1.0f / output_qtype.getScale()) +
264                 output_qtype.getZeroPoint();
265 
266     clamp_in =
267         buildRescale(rewriter, op, output_type, tfl_relu1_op.x(),
268                      input_qtype.getScale() / output_qtype.getScale(),
269                      input_qtype.getZeroPoint(), output_qtype.getZeroPoint(),
270                      /*double_round=*/false, /*scale32=*/true);
271   }
272 
273   CreateReplaceOpAndInfer<tosa::ClampOp>(rewriter, op, output_type, clamp_in,
274                                          rewriter.getI64IntegerAttr(clamp_min),
275                                          rewriter.getI64IntegerAttr(clamp_max),
276                                          rewriter.getF32FloatAttr(-1.0f),
277                                          rewriter.getF32FloatAttr(1.0f));
278 
279   return success();
280 }
281 
282 LogicalResult ConvertTFLRelu6Op::matchAndRewrite(
283     Operation* op, PatternRewriter& rewriter) const {
284   auto tfl_relu6_op = cast<TFL::Relu6Op>(op);
285 
286   ShapedType input_type = tfl_relu6_op.x().getType().dyn_cast<ShapedType>();
287   ShapedType output_type =
288       tfl_relu6_op.getResult().getType().dyn_cast<ShapedType>();
289   // Not a shaped tensor output
290   if (!input_type || !output_type) return failure();
291 
292   bool input_is_qtype =
293       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
294   bool output_is_qtype =
295       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
296 
297   if (input_is_qtype != output_is_qtype) {
298     return op->emitOpError(
299         "ConvertTFLRelu6Op: input/output tensor should "
300         "be all quantized or all floating-point.");
301   }
302 
303   int64_t clamp_min = 0;
304   int64_t clamp_max = 6;
305   Value clamp_in = tfl_relu6_op.x();
306 
307   if (output_is_qtype && input_is_qtype) {
308     UniformQuantizedType input_qtype =
309         input_type.getElementType()
310             .dyn_cast<mlir::quant::UniformQuantizedType>();
311     UniformQuantizedType output_qtype =
312         output_type.getElementType()
313             .dyn_cast<mlir::quant::UniformQuantizedType>();
314 
315     clamp_min = output_qtype.getZeroPoint();
316     clamp_max = std::llround(6.0f / output_qtype.getScale()) +
317                 output_qtype.getZeroPoint();
318 
319     clamp_in =
320         buildRescale(rewriter, op, output_type, tfl_relu6_op.x(),
321                      input_qtype.getScale() / output_qtype.getScale(),
322                      input_qtype.getZeroPoint(), output_qtype.getZeroPoint(),
323                      /*double_round=*/false, /*scale32=*/true);
324   }
325 
326   CreateReplaceOpAndInfer<tosa::ClampOp>(rewriter, op, output_type, clamp_in,
327                                          rewriter.getI64IntegerAttr(clamp_min),
328                                          rewriter.getI64IntegerAttr(clamp_max),
329                                          rewriter.getF32FloatAttr(0.0f),
330                                          rewriter.getF32FloatAttr(6.0f));
331 
332   return success();
333 }
334 
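// Shared helper for the comparison ops (Equal, NotEqual, Greater, etc.):
// validates the operand/result types and, for quantized inputs with matching
// scale and zero point, rescales both operands to int32 before the TOSA
// comparison op is emitted.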
335 static LogicalResult prepareMatchAndRewriteComparison(
336     Operation* op, mlir::OperandRange operands, PatternRewriter& rewriter,
337     llvm::SmallVectorImpl<Value>& newOperands) {
338   Value x = operands[0];
339   Value y = operands[1];
340   Value result = op->getResult(0);
341 
342   ShapedType input_x_type = x.getType().dyn_cast<ShapedType>();
343   ShapedType input_y_type = y.getType().dyn_cast<ShapedType>();
344   ShapedType output_type = result.getType().dyn_cast<ShapedType>();
345   // Not a shaped tensor output
346   if (!input_x_type || !input_y_type || !output_type) return failure();
347 
348   bool input_x_is_qtype =
349       input_x_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
350   bool input_y_is_qtype =
351       input_y_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
352   bool output_is_qtype =
353       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
354 
355   if (input_x_is_qtype != input_y_is_qtype ||
356       input_y_is_qtype != output_is_qtype) {
357     return op->emitOpError(
358         "ConvertTFLEqualOp: input/output tensor should "
359         "be all quantized or all floating-point.");
360   }
361 
362   if (!output_is_qtype && !input_x_is_qtype && !input_y_is_qtype) {
363     newOperands.push_back(x);
364     newOperands.push_back(y);
365     return success();
366   }
367 
368   UniformQuantizedType input_x_qtype =
369       input_x_type.getElementType()
370           .dyn_cast<mlir::quant::UniformQuantizedType>();
371   UniformQuantizedType input_y_qtype =
372       input_y_type.getElementType()
373           .dyn_cast<mlir::quant::UniformQuantizedType>();
374 
375   if (input_x_qtype.getScale() != input_y_qtype.getScale() ||
376       input_x_qtype.getZeroPoint() != input_y_qtype.getZeroPoint()) {
377     return op->emitOpError(
378         "ConvertTFLEqualOp: input_x and input_y scale/zp "
379         "must be the same");
380   }
381 
382   x = buildRescaleToInt32(rewriter, op, x, 1.0f, input_x_qtype.getZeroPoint());
383   y = buildRescaleToInt32(rewriter, op, y, 1.0f, input_y_qtype.getZeroPoint());
384 
385   newOperands.push_back(x);
386   newOperands.push_back(y);
387   return success();
388 }
389 
390 LogicalResult ConvertTFLEqualOp::matchAndRewrite(
391     Operation* op, PatternRewriter& rewriter) const {
392   llvm::SmallVector<Value, 2> newOperands;
393   LogicalResult status = prepareMatchAndRewriteComparison(
394       op, op->getOperands(), rewriter, newOperands);
395   if (status.failed()) return failure();
396 
397   CreateReplaceOpAndInfer<tosa::EqualOp>(
398       rewriter, op, op->getResult(0).getType(), newOperands[0], newOperands[1]);
399 
400   return success();
401 }
402 
403 LogicalResult ConvertTFLNotEqualOp::matchAndRewrite(
404     Operation* op, PatternRewriter& rewriter) const {
405   llvm::SmallVector<Value, 2> newOperands;
406   LogicalResult status = prepareMatchAndRewriteComparison(
407       op, op->getOperands(), rewriter, newOperands);
408   if (status.failed()) return failure();
409 
410   auto equal_op = CreateOpAndInfer<tosa::EqualOp>(
411       rewriter, op->getLoc(), op->getResult(0).getType(), newOperands[0],
412       newOperands[1]);
413 
414   CreateReplaceOpAndInfer<tosa::LogicalNotOp>(
415       rewriter, op, op->getResult(0).getType(), equal_op);
416 
417   return success();
418 }
419 
420 LogicalResult ConvertTFLGreaterOp::matchAndRewrite(
421     Operation* op, PatternRewriter& rewriter) const {
422   llvm::SmallVector<Value, 2> newOperands;
423   LogicalResult status = prepareMatchAndRewriteComparison(
424       op, op->getOperands(), rewriter, newOperands);
425   if (status.failed()) return failure();
426 
427   CreateReplaceOpAndInfer<tosa::GreaterOp>(
428       rewriter, op, op->getResult(0).getType(), newOperands[0], newOperands[1]);
429 
430   return success();
431 }
432 
433 LogicalResult ConvertTFLGreaterEqualOp::matchAndRewrite(
434     Operation* op, PatternRewriter& rewriter) const {
435   llvm::SmallVector<Value, 2> newOperands;
436   LogicalResult status = prepareMatchAndRewriteComparison(
437       op, op->getOperands(), rewriter, newOperands);
438   if (status.failed()) return failure();
439 
440   CreateReplaceOpAndInfer<tosa::GreaterEqualOp>(
441       rewriter, op, op->getResult(0).getType(), newOperands[0], newOperands[1]);
442 
443   return success();
444 }
445 
446 LogicalResult ConvertTFLLessOp::matchAndRewrite(
447     Operation* op, PatternRewriter& rewriter) const {
448   llvm::SmallVector<Value, 2> newOperands;
449   LogicalResult status = prepareMatchAndRewriteComparison(
450       op, op->getOperands(), rewriter, newOperands);
451   if (status.failed()) return failure();
452 
453   CreateReplaceOpAndInfer<tosa::GreaterOp>(
454       rewriter, op, op->getResult(0).getType(), newOperands[1], newOperands[0]);
455   return success();
456 }
457 
458 LogicalResult ConvertTFLLessEqualOp::matchAndRewrite(
459     Operation* op, PatternRewriter& rewriter) const {
460   llvm::SmallVector<Value, 2> newOperands;
461   LogicalResult status = prepareMatchAndRewriteComparison(
462       op, op->getOperands(), rewriter, newOperands);
463   if (status.failed()) return failure();
464 
465   // Swapping the args handles the greater/less difference.
466   CreateReplaceOpAndInfer<tosa::GreaterEqualOp>(
467       rewriter, op, op->getResult(0).getType(), newOperands[1], newOperands[0]);
468 
469   return success();
470 }
471 
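// Shared lowering for tfl.add and tfl.sub. Quantized inputs follow the TFLite
// add kernel recipe: rescale both operands to a common scale with an extra
// left shift, add/subtract in int32, then rescale back to the output type.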
472 template <typename TflOp, typename TosaOp>
473 static LogicalResult matchAndRewriteAddSub(Operation* op,
474                                            mlir::OperandRange operands,
475                                            PatternRewriter& rewriter) {
476   auto tfl_add_op = cast<TflOp>(op);
477 
478   ShapedType input_lhs_type =
479       tfl_add_op.lhs().getType().template dyn_cast<ShapedType>();
480   ShapedType input_rhs_type =
481       tfl_add_op.rhs().getType().template dyn_cast<ShapedType>();
482   ShapedType output_type =
483       tfl_add_op.getResult().getType().template dyn_cast<ShapedType>();
484   // Not a shaped tensor output
485   if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
486 
487   bool input_lhs_is_qtype =
488       input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
489   bool input_rhs_is_qtype =
490       input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
491   bool output_is_qtype =
492       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
493 
494   if (input_lhs_is_qtype != output_is_qtype ||
495       input_rhs_is_qtype != output_is_qtype) {
496     return op->emitOpError(
497         "ConvertTFLAddOp: input/output tensor should "
498         "be all quantized or all floating-point.");
499   }
500 
501   Value output;
502   if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
503     ShapedType rescale_type = output_type.clone(rewriter.getI32Type());
504     UniformQuantizedType input_lhs_qtype =
505         input_lhs_type.getElementType()
506             .dyn_cast<mlir::quant::UniformQuantizedType>();
507     UniformQuantizedType input_rhs_qtype =
508         input_rhs_type.getElementType()
509             .dyn_cast<mlir::quant::UniformQuantizedType>();
510     UniformQuantizedType output_qtype =
511         output_type.getElementType()
512             .dyn_cast<mlir::quant::UniformQuantizedType>();
513 
514     // Following quantization described in tensorflow/lite/kernels/add.cc
515     // In detail, it:
516     // 1. Rescales the inputs to scale = 2.0 x max(lhs.scale, rhs.scale).
517     // 2. Applies an extra left shift to the inputs to increase precision,
518     //    where input_shift = 20 if the input is 8-bit and
519     //    input_shift = 15 if the input is 16-bit.
520     double in_lhs_scale = input_lhs_qtype.getScale();
521     double in_rhs_scale = input_rhs_qtype.getScale();
522     double output_scale = output_qtype.getScale();
523     double max_scale_2x = 2.0 * std::max(in_lhs_scale, in_rhs_scale);
524 
525     const int32_t SHIFT_8_BIT = 20;
526     const int32_t SHIFT_16_BIT = 15;
527 
528     int32_t input_shift = (output_qtype.getStorageTypeIntegralWidth() == 16)
529                               ? SHIFT_16_BIT
530                               : SHIFT_8_BIT;
531 
532     double lhs_rescale_scale =
533         static_cast<double>(1 << input_shift) * in_lhs_scale / max_scale_2x;
534     double rhs_rescale_scale =
535         static_cast<double>(1 << input_shift) * in_rhs_scale / max_scale_2x;
536     double output_rescale_scale =
537         max_scale_2x / (output_scale * static_cast<double>(1 << input_shift));
538 
539     Value op1_rescale_lhs =
540         buildRescaleToInt32(rewriter, op, tfl_add_op.lhs(), lhs_rescale_scale,
541                             input_lhs_qtype.getZeroPoint());
542     Value op2_rescale_rhs =
543         buildRescaleToInt32(rewriter, op, tfl_add_op.rhs(), rhs_rescale_scale,
544                             input_rhs_qtype.getZeroPoint());
545     auto op3_add_op1_op2 = CreateOpAndInfer<TosaOp>(
546         rewriter, op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
547     Value op4_rescale_op3 = buildRescaleFromInt32(
548         rewriter, op, output_type, op3_add_op1_op2.getResult(),
549         output_rescale_scale, output_qtype.getZeroPoint());
550     output = op4_rescale_op3;
551   } else {
552     auto op1_add_in =
553         CreateOpAndInfer<TosaOp>(rewriter, op->getLoc(), output_type,
554                                  tfl_add_op.lhs(), tfl_add_op.rhs());
555 
556     output = op1_add_in.getResult();
557   }
558 
559   auto fused_activation_fn = tfl_add_op.fused_activation_functionAttr();
560 
561   if (fused_activation_fn) {
562     llvm::Optional<Value> fused_activation_val =
563         convertFusedActivation(rewriter, op, output, fused_activation_fn);
564 
565     if (!fused_activation_val) return failure();
566 
567     rewriter.replaceOp(op, {fused_activation_val.getValue()});
568     return success();
569   }
570 
571   rewriter.replaceOp(op, {output});
572   return success();
573 }
574 
575 LogicalResult ConvertTFLAddOp::matchAndRewrite(
576     Operation* op, PatternRewriter& rewriter) const {
577   return matchAndRewriteAddSub<TFL::AddOp, tosa::AddOp>(op, op->getOperands(),
578                                                         rewriter);
579 }
580 
581 LogicalResult ConvertTFLSubOp::matchAndRewrite(
582     Operation* op, PatternRewriter& rewriter) const {
583   return matchAndRewriteAddSub<TFL::SubOp, tosa::SubOp>(op, op->getOperands(),
584                                                         rewriter);
585 }
586 
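// Many of the element-wise lowerings below delegate to shared conversion
// helpers (convertMultiplyOp, convertSquaredDifferenceOp, convertRoundOp,
// convertFloorDivOp, convertFloorModOp, ...) from legalize_common.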
587 LogicalResult ConvertTFLMulOp::matchAndRewrite(
588     Operation* op, PatternRewriter& rewriter) const {
589   auto tfl_mul_op = cast<TFL::MulOp>(op);
590 
591   llvm::Optional<Value> result = convertMultiplyOp(
592       rewriter, op, tfl_mul_op.getResult(), tfl_mul_op.lhs(), tfl_mul_op.rhs());
593 
594   if (!result) return failure();
595 
596   auto fused_activation_fn = tfl_mul_op.fused_activation_functionAttr();
597 
598   if (fused_activation_fn) {
599     llvm::Optional<Value> fused_activation_val = convertFusedActivation(
600         rewriter, op, result.getValue(), fused_activation_fn);
601 
602     if (!fused_activation_val) return failure();
603 
604     rewriter.replaceOp(op, {fused_activation_val.getValue()});
605     return success();
606   }
607 
608   rewriter.replaceOp(op, {result.getValue()});
609   return success();
610 }
611 
612 LogicalResult ConvertTFLSquareOp::matchAndRewrite(
613     Operation* op, PatternRewriter& rewriter) const {
614   auto tfl_square_op = cast<TFL::SquareOp>(op);
615 
616   llvm::Optional<Value> result =
617       convertMultiplyOp(rewriter, op, tfl_square_op.getResult(),
618                         tfl_square_op.x(), tfl_square_op.x());
619 
620   if (!result) return failure();
621 
622   rewriter.replaceOp(op, {result.getValue()});
623   return success();
624 }
625 
626 LogicalResult ConvertTFLSquaredDifferenceOp::matchAndRewrite(
627     Operation* op, PatternRewriter& rewriter) const {
628   auto tfl_squared_op = cast<TFL::SquaredDifferenceOp>(op);
629 
630   llvm::Optional<Value> result =
631       convertSquaredDifferenceOp(rewriter, op, tfl_squared_op.getResult(),
632                                  tfl_squared_op.lhs(), tfl_squared_op.rhs());
633 
634   if (!result) return failure();
635 
636   rewriter.replaceOp(op, {result.getValue()});
637   return success();
638 }
639 
640 LogicalResult ConvertTFLRoundOp::matchAndRewrite(
641     Operation* op, PatternRewriter& rewriter) const {
642   auto tfl_round_op = cast<TFL::RoundOp>(op);
643 
644   ShapedType input_type = tfl_round_op.x().getType().dyn_cast<ShapedType>();
645   if (!input_type) {
646     return op->emitOpError("Round: input not shaped tensor type");
647   }
648 
649   if (input_type.getElementType().isa<FloatType>()) {
650     llvm::Optional<Value> result = convertRoundOp(
651         rewriter, op, tfl_round_op.getResult(), tfl_round_op.x());
652 
653     if (!result) return failure();
654 
655     rewriter.replaceOp(op, {result.getValue()});
656     return success();
657 
658   } else {
659     // Round on int is nonsensical. Instead, replace uses of result with the
660     // input.
661     tfl_round_op.replaceAllUsesWith(tfl_round_op.x());
662     return success();
663   }
664 }
665 
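// tfl.div: integer element types lower to tosa.div; floating-point uses
// tosa.reciprocal followed by tosa.mul. Any fused activation is legalized
// afterwards.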
666 LogicalResult ConvertTFLDivOp::matchAndRewrite(
667     Operation* op, PatternRewriter& rewriter) const {
668   auto tfl_div_op = cast<TFL::DivOp>(op);
669 
670   ShapedType output_type =
671       tfl_div_op.getResult().getType().dyn_cast<ShapedType>();
672   // Not a shaped tensor output
673   if (!output_type) return failure();
674 
675   auto fused_activation_fn = tfl_div_op.fused_activation_functionAttr();
676 
677   Type element_type = output_type.getElementType();
678   Value div_op;
679   if (element_type.isa<IntegerType>()) {
680     div_op = CreateOpAndInfer<tosa::DivOp>(rewriter, op->getLoc(), output_type,
681                                            tfl_div_op.lhs(), tfl_div_op.rhs())
682                  .getResult();
683   } else {
684     auto reciprocal_op = CreateOpAndInfer<tosa::ReciprocalOp>(
685         rewriter, op->getLoc(), tfl_div_op.rhs().getType(), tfl_div_op.rhs());
686     div_op = CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
687                                            tfl_div_op.lhs(),
688                                            reciprocal_op.getResult(), 0)
689                  .getResult();
690   }
691 
692   if (fused_activation_fn) {
693     llvm::Optional<Value> fused_activation_val =
694         convertFusedActivation(rewriter, op, div_op, fused_activation_fn);
695 
696     if (!fused_activation_val) return failure();
697 
698     rewriter.replaceOp(op, {fused_activation_val.getValue()});
699     return success();
700   }
701 
702   rewriter.replaceOp(op, {div_op});
703 
704   return success();
705 }
706 
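// tfl.maximum and tfl.minimum (below) share the same structure: quantized
// operands are rescaled to int32 with unit scale and zero zero-point, the
// element-wise op runs in int32, and the result is rescaled back.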
707 LogicalResult ConvertTFLMaximumOp::matchAndRewrite(
708     Operation* op, PatternRewriter& rewriter) const {
709   auto tfl_max_op = cast<TFL::MaximumOp>(op);
710 
711   ShapedType input_lhs_type = tfl_max_op.lhs().getType().dyn_cast<ShapedType>();
712   ShapedType input_rhs_type = tfl_max_op.rhs().getType().dyn_cast<ShapedType>();
713   ShapedType output_type =
714       tfl_max_op.getResult().getType().dyn_cast<ShapedType>();
715 
716   // Not a shaped tensor output
717   if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
718 
719   bool input_lhs_is_qtype =
720       input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
721   bool input_rhs_is_qtype =
722       input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
723   bool output_is_qtype =
724       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
725 
726   if (input_lhs_is_qtype != output_is_qtype ||
727       input_rhs_is_qtype != output_is_qtype) {
728     return op->emitOpError(
729         "ConvertTFLMaximumOp: input/output tensor should "
730         "be all quantized or all floating-point.");
731   }
732 
733   Value output;
734   if (output_is_qtype) {
735     ShapedType rescale_type = output_type.clone(rewriter.getI32Type());
736 
737     Value op1_rescale_lhs =
738         buildRescaleToInt32(rewriter, op, tfl_max_op.lhs(), 1.0f, 0);
739     Value op2_rescale_rhs =
740         buildRescaleToInt32(rewriter, op, tfl_max_op.rhs(), 1.0f, 0);
741     auto op3_max_op1_op2 = CreateOpAndInfer<tosa::MaximumOp>(
742         rewriter, op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
743     Value op4_rescale_op3 = buildRescaleFromInt32(
744         rewriter, op, output_type, op3_max_op1_op2.getResult(), 1.0f, 0);
745 
746     output = op4_rescale_op3;
747   } else {
748     auto op1_max_in =
749         CreateOpAndInfer<tosa::MaximumOp>(rewriter, op->getLoc(), output_type,
750                                           tfl_max_op.lhs(), tfl_max_op.rhs());
751 
752     output = op1_max_in.getResult();
753   }
754 
755   rewriter.replaceOp(op, {output});
756 
757   return success();
758 }
759 
760 LogicalResult ConvertTFLMinimumOp::matchAndRewrite(
761     Operation* op, PatternRewriter& rewriter) const {
762   auto tfl_min_op = cast<TFL::MinimumOp>(op);
763 
764   ShapedType input_lhs_type = tfl_min_op.lhs().getType().dyn_cast<ShapedType>();
765   ShapedType input_rhs_type = tfl_min_op.rhs().getType().dyn_cast<ShapedType>();
766   ShapedType output_type =
767       tfl_min_op.getResult().getType().dyn_cast<ShapedType>();
768   // Not a shaped tensor output
769   if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
770 
771   bool input_lhs_is_qtype =
772       input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
773   bool input_rhs_is_qtype =
774       input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
775   bool output_is_qtype =
776       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
777 
778   if (input_lhs_is_qtype != output_is_qtype ||
779       input_rhs_is_qtype != output_is_qtype) {
780     return op->emitOpError(
781         "ConvertTFLMinimumOp: input/output tensor should "
782         "be all quantized or all floating-point.");
783   }
784 
785   Value output;
786   if (output_is_qtype) {
787     ShapedType rescale_type = output_type.clone(rewriter.getI32Type());
788 
789     Value op1_rescale_lhs =
790         buildRescaleToInt32(rewriter, op, tfl_min_op.lhs(), 1.0f, 0);
791     Value op2_rescale_rhs =
792         buildRescaleToInt32(rewriter, op, tfl_min_op.rhs(), 1.0f, 0);
793     auto op3_min_op1_op2 = CreateOpAndInfer<tosa::MinimumOp>(
794         rewriter, op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
795     Value op4_rescale_op3 = buildRescaleFromInt32(
796         rewriter, op, output_type, op3_min_op1_op2.getResult(), 1.0f, 0);
797 
798     output = op4_rescale_op3;
799   } else {
800     auto op1_min_in =
801         CreateOpAndInfer<tosa::MinimumOp>(rewriter, op->getLoc(), output_type,
802                                           tfl_min_op.lhs(), tfl_min_op.rhs());
803 
804     output = op1_min_in.getResult();
805   }
806 
807   rewriter.replaceOp(op, {output});
808 
809   return success();
810 }
811 
812 LogicalResult ConvertTFLFloorDivOp::matchAndRewrite(
813     Operation* op, PatternRewriter& rewriter) const {
814   auto tfl_floordiv_op = cast<TFL::FloorDivOp>(op);
815 
816   llvm::Optional<Value> result =
817       convertFloorDivOp(rewriter, op, tfl_floordiv_op.getResult(),
818                         tfl_floordiv_op.lhs(), tfl_floordiv_op.rhs());
819 
820   if (!result) return failure();
821 
822   rewriter.replaceOp(op, {result.getValue()});
823 
824   return success();
825 }
826 
827 LogicalResult ConvertTFLFloorModOp::matchAndRewrite(
828     Operation* op, PatternRewriter& rewriter) const {
829   auto tfl_floormod_op = cast<TFL::FloorModOp>(op);
830 
831   llvm::Optional<Value> result =
832       convertFloorModOp(rewriter, op, tfl_floormod_op.getResult(),
833                         tfl_floormod_op.lhs(), tfl_floormod_op.rhs());
834 
835   if (!result) return failure();
836 
837   rewriter.replaceOp(op, {result.getValue()});
838 
839   return success();
840 }
841 
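// tfl.add_n is expanded into a left-to-right chain of binary tosa.add ops.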
842 LogicalResult ConvertTFLAddNOp::matchAndRewrite(
843     Operation* op, PatternRewriter& rewriter) const {
844   auto tfl_addn_op = cast<TFL::AddNOp>(op);
845 
846   ShapedType output_type =
847       tfl_addn_op.getResult().getType().dyn_cast<ShapedType>();
848   // Not a shaped output
849   if (!output_type) return failure();
850 
851   SmallVector<Value> inputs(tfl_addn_op.inputs());
852 
853   assert(inputs.size() >= 2);
854 
855   auto newOp = CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(),
856                                              output_type, inputs[0], inputs[1]);
857   for (int i = 2; i < inputs.size(); i++) {
858     newOp = CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(), output_type,
859                                           inputs[i], newOp.getResult());
860   }
861 
862   rewriter.replaceOp(op, {newOp.getResult()});
863 
864   return success();
865 }
866 
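// tfl.average_pool_2d lowers to tosa.avg_pool2d. SAME/VALID padding is
// converted to explicit TOSA pad values; quantized inputs get a quantization
// attribute with both zero points forced to 0, and the result is cast if the
// accumulator element type differs from the output type.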
867 LogicalResult ConvertTFLAveragePool2DOp::matchAndRewrite(
868     Operation* op, PatternRewriter& rewriter) const {
869   auto tfl_avgpool_op = cast<TFL::AveragePool2DOp>(op);
870 
871   ShapedType input_type =
872       tfl_avgpool_op.input().getType().dyn_cast<ShapedType>();
873   ShapedType output_type =
874       tfl_avgpool_op.getResult().getType().dyn_cast<ShapedType>();
875   // Not a shaped output
876   if (!output_type) return failure();
877 
878   // Kernels and strides are dimensionally ordered
879   SmallVector<int64_t, 4> i64array({1, 1, 1, 1});
880   ArrayAttr kernel_size;
881   ArrayAttr stride;
882   ArrayAttr pad;
883   {
884     int64_t kernel_h = tfl_avgpool_op.filter_height();
885     int64_t kernel_w = tfl_avgpool_op.filter_width();
886     kernel_size = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
887     // i64array is formatted as NHWC now
888     i64array[1] = kernel_h;
889     i64array[2] = kernel_w;
890   }
891   {
892     int64_t stride_h = tfl_avgpool_op.stride_h();
893     int64_t stride_w = tfl_avgpool_op.stride_w();
894     stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
895   }
896   {
897     tensorflow::Padding tf_pad;
898     if (!GetPaddingFromString(tfl_avgpool_op.padding().str(), &tf_pad).ok())
899       return failure();
900 
901     // Pooling has no non-unit dilation
902     ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});
903 
904     RankedTensorType filter_type = RankedTensorType::get(
905         llvm::makeArrayRef(i64array), rewriter.getIntegerType(64));
906 
907     // TFLite doesn't support explicit padding
908     if (!getPaddingValuesFromPadType(
909             tf_pad,
910             tensorflow::FORMAT_NHWC,  // TFLite only supports this
911             1,                        // tensorflow::FORMAT_OHWI,
912             input_type, filter_type, stride, dilation, rewriter, pad))
913       return failure();
914   }
915 
916   auto average_etype = input_type.getElementType();
917   auto average_type = output_type.clone(average_etype);
918 
919   Value result;
920   if (average_etype.isa<quant::UniformQuantizedType>()) {
921     // TensorFlow Lite doesn't use the zero point when calculating
922     // quantized average pool, while TOSA does. Force the TOSA
923     // zero_points to zero to ensure that the calculations match.
924 
925     auto quant_attr = rewriter.getAttr<tosa::UnaryOpQuantizationAttr>(
926         /*input_zp=*/0, /*output_zp=*/0);
927     result = CreateOpAndInfer<tosa::AvgPool2dOp>(
928         rewriter, op->getLoc(), average_type, tfl_avgpool_op.input(),
929         kernel_size, stride, pad, quant_attr);
930   } else {
931     result = CreateOpAndInfer<tosa::AvgPool2dOp>(
932         rewriter, op->getLoc(), average_type, tfl_avgpool_op.input(),
933         kernel_size, stride, pad);
934   }
935   if (average_type != output_type) {
936     result = CreateOpAndInfer<tosa::CastOp>(rewriter, op->getLoc(), output_type,
937                                             result);
938   }
939 
940   rewriter.replaceOp(op, result);
941   return success();
942 }
943 
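// tfl.max_pool_2d lowers directly to tosa.max_pool2d with explicit pad values
// computed from the SAME/VALID padding attribute.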
944 LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite(
945     Operation* op, PatternRewriter& rewriter) const {
946   auto tfl_maxpool_op = cast<TFL::MaxPool2DOp>(op);
947 
948   ShapedType input_type =
949       tfl_maxpool_op.input().getType().dyn_cast<ShapedType>();
950   ShapedType output_type =
951       tfl_maxpool_op.getResult().getType().dyn_cast<ShapedType>();
952   // Not a shaped type
953   if (!output_type) return failure();
954 
955   // Kernels and strides are dimensionally ordered
956   SmallVector<int64_t, 4> i64array({1, 1, 1, 1});
957   ArrayAttr kernel_size;
958   ArrayAttr stride;
959   ArrayAttr pad;
960   {
961     int64_t kernel_h = tfl_maxpool_op.filter_height();
962     int64_t kernel_w = tfl_maxpool_op.filter_width();
963     kernel_size = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
964     // i64array is formatted as NHWC now
965     i64array[1] = kernel_h;
966     i64array[2] = kernel_w;
967   }
968   {
969     int64_t stride_h = tfl_maxpool_op.stride_h();
970     int64_t stride_w = tfl_maxpool_op.stride_w();
971     stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
972   }
973   {
974     tensorflow::Padding tf_pad;
975     if (!GetPaddingFromString(tfl_maxpool_op.padding().str(), &tf_pad).ok())
976       return failure();
977 
978     // Pooling has no non-unit dilation
979     ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});
980 
981     RankedTensorType filter_type =
982         RankedTensorType::get(i64array, rewriter.getIntegerType(64));
983 
984     // TFLite doesn't support explicit padding
985     if (!getPaddingValuesFromPadType(
986             tf_pad,
987             tensorflow::FORMAT_NHWC,  // TFLite only supports this
988             1,                        // tensorflow::FORMAT_OHWI,
989             input_type, filter_type, stride, dilation, rewriter, pad))
990       return failure();
991   }
992 
993   CreateReplaceOpAndInfer<tosa::MaxPool2dOp>(rewriter, op, output_type,
994                                              tfl_maxpool_op.input(),
995                                              kernel_size, stride, pad);
996   return success();
997 }
998 
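// tfl.conv_2d lowers to tosa.conv2d emitted with the bias element type; for
// quantized graphs the result is rescaled back to the output quantization
// before any fused activation is legalized.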
999 LogicalResult ConvertTFLConv2DOp::matchAndRewrite(
1000     Operation* op, PatternRewriter& rewriter) const {
1001   auto tfl_conv2d_op = cast<TFL::Conv2DOp>(op);
1002 
1003   RankedTensorType input_type =
1004       tfl_conv2d_op.input().getType().dyn_cast<RankedTensorType>();
1005   RankedTensorType filter_type =
1006       tfl_conv2d_op.filter().getType().dyn_cast<RankedTensorType>();
1007   ShapedType output_type =
1008       tfl_conv2d_op.getResult().getType().dyn_cast<ShapedType>();
1009   // Input/filter not ranked or output not shaped
1010   if (!input_type) return failure();
1011   if (!output_type) return failure();
1012   if (!filter_type) return failure();
1013 
1014   bool input_is_qtype =
1015       input_type.getElementType().isa<mlir::quant::QuantizedType>();
1016   bool filter_is_qtype =
1017       filter_type.getElementType().isa<mlir::quant::QuantizedType>();
1018   bool output_is_qtype =
1019       output_type.getElementType().isa<mlir::quant::QuantizedType>();
1020 
1021   if ((input_is_qtype != filter_is_qtype) ||
1022       (input_is_qtype != output_is_qtype)) {
1023     return op->emitOpError(
1024         "ConvertTFLConv2DOp: input/filter/output tensor should "
1025         "be all quantized or all floating-point.");
1026   }
1027 
1028   ArrayAttr pad;
1029   ArrayAttr stride;
1030   ArrayAttr dilation;
1031   {
1032     int64_t stride_h = tfl_conv2d_op.stride_h();
1033     int64_t stride_w = tfl_conv2d_op.stride_w();
1034     stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
1035   }
1036   {
1037     int64_t dilation_h = tfl_conv2d_op.dilation_h_factor();
1038     int64_t dilation_w = tfl_conv2d_op.dilation_w_factor();
1039     dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
1040   }
1041   {
1042     tensorflow::Padding tf_pad;
1043     if (!GetPaddingFromString(tfl_conv2d_op.padding().str(), &tf_pad).ok())
1044       return failure();
1045 
1046     // TFLite doesn't support explicit padding
1047     if (!getPaddingValuesFromPadType(
1048             tf_pad,
1049             tensorflow::FORMAT_NHWC,  // TFLite only supports this
1050             1,                        // tensorflow::FORMAT_OHWI,
1051             input_type, filter_type, stride, dilation, rewriter, pad))
1052       return failure();
1053   }
1054 
1055   Value unquantized_bias = tfl_conv2d_op.bias();
1056   Type bias_ety =
1057       output_is_qtype ? rewriter.getI32Type() : output_type.getElementType();
1058   if (unquantized_bias)
1059     bias_ety = unquantized_bias.getType().cast<ShapedType>().getElementType();
1060 
1061   auto a1_conv2d_op = CreateOpAndInfer<tosa::Conv2DOp>(
1062       rewriter, op->getLoc(), output_type.clone(bias_ety),
1063       tfl_conv2d_op.input(), tfl_conv2d_op.filter(), unquantized_bias, pad,
1064       stride, dilation);
1065 
1066   Value conv2d_output;
1067   if (input_is_qtype) {
1068     conv2d_output =
1069         buildRescaleOpConvOutput(rewriter, op, a1_conv2d_op.getResult(),
1070                                  input_type, filter_type, output_type);
1071   } else {
1072     conv2d_output = a1_conv2d_op.getResult();
1073   }
1074 
1075   auto fused_activation_fn = tfl_conv2d_op.fused_activation_functionAttr();
1076 
1077   if (fused_activation_fn) {
1078     llvm::Optional<Value> fused_activation_val = convertFusedActivation(
1079         rewriter, op, conv2d_output, fused_activation_fn);
1080 
1081     if (!fused_activation_val) return failure();
1082 
1083     rewriter.replaceOp(op, {fused_activation_val.getValue()});
1084     return success();
1085   }
1086 
1087   rewriter.replaceOp(op, {conv2d_output});
1088 
1089   return success();
1090 }
1091 
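// tfl.transpose_conv lowers to tosa.transpose_conv2d. TOSA requires a bias
// operand, so a zero bias constant is synthesized (48-bit for 16x8
// quantization, i32 otherwise, f32 for float); the output shape comes from
// the constant output_shape operand or, failing that, the result type.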
1092 LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite(
1093     Operation* op, PatternRewriter& rewriter) const {
1094   auto tfl_conv_op = cast<TFL::TransposeConvOp>(op);
1095 
1096   ShapedType input_type = tfl_conv_op.input().getType().dyn_cast<ShapedType>();
1097   ShapedType filter_type =
1098       tfl_conv_op.weights().getType().dyn_cast<ShapedType>();
1099   ShapedType output_type =
1100       tfl_conv_op.getResult().getType().dyn_cast<ShapedType>();
1101   // Not a shaped tensor output
1102   if (!input_type) return failure();
1103   if (!output_type) return failure();
1104   if (!filter_type) return failure();
1105 
1106   bool input_is_qtype =
1107       input_type.getElementType().isa<mlir::quant::QuantizedType>();
1108   bool filter_is_qtype =
1109       filter_type.getElementType().isa<mlir::quant::QuantizedType>();
1110   bool output_is_qtype =
1111       output_type.getElementType().isa<mlir::quant::QuantizedType>();
1112 
1113   if ((input_is_qtype != filter_is_qtype) ||
1114       (input_is_qtype != output_is_qtype)) {
1115     return op->emitOpError(
1116         "ConvertTFLTransposeConvOp: input/filter/output tensor should "
1117         "be all quantized or all floating-point.");
1118   }
1119 
1120   ArrayAttr stride;
1121   ArrayAttr outpad;
1122   ArrayAttr output_shape;
1123   {
1124     int64_t stride_h = tfl_conv_op.stride_h();
1125     int64_t stride_w = tfl_conv_op.stride_w();
1126     stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
1127   }
1128 
1129   {
1130     tensorflow::Padding tf_pad;
1131     if (!GetPaddingFromString(tfl_conv_op.padding().str(), &tf_pad).ok())
1132       return failure();
1133 
1134     if (!getTransposeConv2dPaddingValues(
1135             tf_pad,
1136             tensorflow::FORMAT_NHWC,  // TFLite only supports this
1137             1,                        // tensorflow::FORMAT_OHWI,
1138             input_type, filter_type, output_type, stride, rewriter, outpad))
1139       return failure();
1140   }
1141   {
1142     ElementsAttr output_shape_elems;
1143     // Match from input_size tensor first
1144     if (matchPattern(tfl_conv_op.output_shape(),
1145                      m_Constant(&output_shape_elems))) {
1146       SmallVector<int64_t> shape_vec;
1147       for (int i = 0; i < output_shape_elems.getNumElements(); i++)
1148         shape_vec.push_back(
1149             output_shape_elems.getValues<APInt>()[i].getSExtValue());
1150       output_shape = rewriter.getI64ArrayAttr(shape_vec);
1151     } else if (output_type.hasRank()) {
1152       // Use output tensor's shape otherwise
1153       output_shape = rewriter.getI64ArrayAttr(output_type.getShape());
1154     } else {
1155       // TODO(suderman): Figure out rankless shape propagation.
1156       return failure();
1157     }
1158   }
1159 
1160   int output_channel = 0;
1161   // TODO(suderman): We need to figure out how to guarantee output channel
1162   // propagation.
1163   if (output_type.hasRank()) {
1164     output_channel = output_type.getDimSize(3);
1165   } else if (filter_type.hasRank()) {
1166     output_channel = filter_type.getDimSize(0);
1167   } else {
1168     return failure();
1169   }
1170 
1171   llvm::Optional<Value> zero_bias;
1172   if (input_is_qtype) {
1173     uint32_t input_bits = input_type.getElementType()
1174                               .dyn_cast<mlir::quant::QuantizedType>()
1175                               .getStorageTypeIntegralWidth();
1176     uint32_t weight_bits = filter_type.getElementType()
1177                                .dyn_cast<mlir::quant::QuantizedType>()
1178                                .getStorageTypeIntegralWidth();
1179 
1180     if (input_bits == 16 && weight_bits == 8) {
1181       SmallVector<APInt> vec(output_channel, APInt(48, 0, true));
1182       zero_bias = getConstTensor<APInt>(rewriter, op, vec, {output_channel});
1183     } else {
1184       SmallVector<int32_t> vec(output_channel, 0);
1185       zero_bias = getConstTensor<int32_t>(rewriter, op, vec, {output_channel});
1186     }
1187   } else {
1188     SmallVector<float> vec(output_channel, 0.0f);
1189     zero_bias = getConstTensor<float>(rewriter, op, vec, {output_channel});
1190   }
1191 
1192   if (!zero_bias) return failure();
1193   Type bias_ety = zero_bias->getType().cast<ShapedType>().getElementType();
1194 
1195   auto a1_conv2d_op = CreateOpAndInfer<tosa::TransposeConv2DOp>(
1196       rewriter, op->getLoc(), output_type.clone(bias_ety), tfl_conv_op.input(),
1197       tfl_conv_op.weights(), zero_bias.getValue(), outpad, stride,
1198       output_shape);
1199 
1200   Value conv2d_output;
1201   if (input_is_qtype) {
1202     conv2d_output =
1203         buildRescaleOpConvOutput(rewriter, op, a1_conv2d_op.getResult(),
1204                                  input_type, filter_type, output_type);
1205   } else {
1206     conv2d_output = a1_conv2d_op.getResult();
1207   }
1208 
1209   rewriter.replaceOp(op, {conv2d_output});
1210 
1211   return success();
1212 }
1213 
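// tfl.depthwise_conv_2d lowers to a transpose and reshape of the filter
// followed by tosa.depthwise_conv2d; see the step-by-step comment inside the
// function.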
1214 LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite(
1215     Operation* op, PatternRewriter& rewriter) const {
1216   auto tfl_conv2d_op = cast<TFL::DepthwiseConv2DOp>(op);
1217 
1218   ShapedType input_type =
1219       tfl_conv2d_op.input().getType().dyn_cast<ShapedType>();
1220   ShapedType filter_type =
1221       tfl_conv2d_op.filter().getType().dyn_cast<ShapedType>();
1222   ShapedType output_type =
1223       tfl_conv2d_op.getResult().getType().dyn_cast<ShapedType>();
1224   // Not a shaped output
1225   if (!input_type) return failure();
1226   if (!output_type) return failure();
1227   if (!filter_type) return failure();
1228 
1229   bool input_is_qtype =
1230       input_type.getElementType().isa<mlir::quant::QuantizedType>();
1231   bool filter_is_qtype =
1232       filter_type.getElementType().isa<mlir::quant::QuantizedType>();
1233   bool output_is_qtype =
1234       output_type.getElementType().isa<mlir::quant::QuantizedType>();
1235 
1236   if ((input_is_qtype != filter_is_qtype) ||
1237       (input_is_qtype != output_is_qtype)) {
1238     return op->emitOpError(
1239         "ConvertTFLDepthwiseConv2DOp: input/filter/output tensor should "
1240         "be all quantized or all floating-point.");
1241   }
1242 
1243   // We need the filter shape to compute the transpose.
1244   if (!filter_type.hasRank()) return failure();
1245   auto filter_shape = filter_type.getShape();
1246   // Operator depthwiseConv2D
1247   // TFLite orders the depthwiseConv2D filter in IHWO, while TOSA orders
1248   // filter in HWIO
1249   //
1250   // The lowering reorders the filter.
1251   //
1252   // a1_transpose = tosa.transpose(filter, {1, 2, 3, 0})   // HWIO
1253   // a2_reshape = tosa.reshape(a1_transpose, H, W, O / depth_multiplier,
1254   // depth_multiplier)
1255   // a3_depthwise_conv2d = tosa.depthwise_conv2d(input, a2_reshape, bias,
1256   // pad, stride, dilation)
1257 
1258   ArrayAttr pad;
1259   ArrayAttr stride;
1260   ArrayAttr dilation;
1261   auto depth_multiplier = tfl_conv2d_op.depth_multiplierAttr();
1262 
1263   {
1264     int64_t stride_h = tfl_conv2d_op.stride_h();
1265     int64_t stride_w = tfl_conv2d_op.stride_w();
1266     stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
1267   }
1268   {
1269     int64_t dilation_h = tfl_conv2d_op.dilation_h_factor();
1270     int64_t dilation_w = tfl_conv2d_op.dilation_w_factor();
1271     dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
1272   }
1273   {
1274     tensorflow::Padding tf_pad;
1275     if (!GetPaddingFromString(tfl_conv2d_op.padding().str(), &tf_pad).ok())
1276       return failure();
1277 
1278     if (!getPaddingValuesFromPadType(
1279             tf_pad,
1280             tensorflow::FORMAT_NHWC,  // TFLite only supports this
1281             1,                        // tensorflow::FORMAT_OHWI,
1282             input_type, filter_type, stride, dilation, rewriter, pad))
1283       return failure();
1284   }
1285 
1286   SmallVector<int64_t, 4> a1_transpose_dims;
1287   a1_transpose_dims.push_back(filter_shape[1]);
1288   a1_transpose_dims.push_back(filter_shape[2]);
1289   a1_transpose_dims.push_back(filter_shape[3]);
1290   a1_transpose_dims.push_back(filter_shape[0]);
1291 
1292   SmallVector<int64_t, 4> a2_reshape_dims;
1293   a2_reshape_dims.push_back(a1_transpose_dims[0]);
1294   a2_reshape_dims.push_back(a1_transpose_dims[1]);
1295   a2_reshape_dims.push_back(a1_transpose_dims[2] / depth_multiplier.getInt());
1296   a2_reshape_dims.push_back(depth_multiplier.getInt());
1297 
1298   llvm::Optional<Value> a1_filter_transpose_perms = getConstTensor<int32_t>(
1299       rewriter, op, /*vec=*/{1, 2, 3, 0}, /*shape=*/{4});
1300 
1301   if (!a1_filter_transpose_perms) return failure();
1302 
1303   auto a1_filter_transpose_op = CreateOpAndInfer<tosa::TransposeOp>(
1304       rewriter, op->getLoc(),
1305       RankedTensorType::get(ArrayRef<int64_t>(a1_transpose_dims),
1306                             filter_type.getElementType()),
1307       tfl_conv2d_op.filter(), a1_filter_transpose_perms.getValue());
1308 
1309   auto a2_filter_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
1310       rewriter, op->getLoc(),
1311       RankedTensorType::get(ArrayRef<int64_t>(a2_reshape_dims),
1312                             filter_type.getElementType()),
1313       a1_filter_transpose_op.getResult(),
1314       rewriter.getI64ArrayAttr(a2_reshape_dims));
1315 
1316   Value unquantized_bias = tfl_conv2d_op.bias();
1317   Type bias_ety =
1318       output_is_qtype ? rewriter.getI32Type() : output_type.getElementType();
1319   if (unquantized_bias)
1320     bias_ety = unquantized_bias.getType().cast<ShapedType>().getElementType();
1321 
1322   auto a3_depthwise_conv2d_op = CreateOpAndInfer<tosa::DepthwiseConv2DOp>(
1323       rewriter, op->getLoc(), output_type.clone(bias_ety),
1324       tfl_conv2d_op.input(), a2_filter_reshape_op.getResult(), unquantized_bias,
1325       pad, stride, dilation);
1326 
1327   Value conv2d_output;
1328   if (input_is_qtype) {
1329     conv2d_output = buildRescaleOpConvOutput(
1330         rewriter, op, a3_depthwise_conv2d_op.getResult(), input_type,
1331         filter_type, output_type);
1332   } else {
1333     conv2d_output = a3_depthwise_conv2d_op.getResult();
1334   }
1335 
1336   auto fused_activation_fn = tfl_conv2d_op.fused_activation_functionAttr();
1337 
1338   if (fused_activation_fn) {
1339     llvm::Optional<Value> fused_activation_val = convertFusedActivation(
1340         rewriter, op, conv2d_output, fused_activation_fn);
1341 
1342     if (!fused_activation_val) return failure();
1343 
1344     rewriter.replaceOp(op, {fused_activation_val.getValue()});
1345     return success();
1346   }
1347 
1348   rewriter.replaceOp(op, {conv2d_output});
1349 
1350   return success();
1351 }
1352 
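// tfl.batch_matmul lowers to tosa.matmul. Extra leading batch dimensions are
// flattened into a single batch dimension, adjointed operands are transposed,
// and the result is reshaped back to the original batch shape.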
1353 LogicalResult ConvertTFLBatchMatMulOp::matchAndRewrite(
1354     Operation* op, PatternRewriter& rewriter) const {
1355   auto tfl_mm_op = cast<TFL::BatchMatMulOp>(op);
1356   auto result_ty = tfl_mm_op.getType().cast<ShapedType>();
1357   Value lhs = tfl_mm_op.x();
1358   Value rhs = tfl_mm_op.y();
1359   RankedTensorType lhs_ty = lhs.getType().dyn_cast<RankedTensorType>();
1360   RankedTensorType rhs_ty = rhs.getType().dyn_cast<RankedTensorType>();
1361   bool transpose_lhs = tfl_mm_op.adj_x();
1362   bool transpose_rhs = tfl_mm_op.adj_y();
1363 
1364   if (!lhs_ty || !rhs_ty) return failure();
1365 
1366   bool lhs_is_qtype = lhs_ty.getElementType().isa<mlir::quant::QuantizedType>();
1367   bool rhs_is_qtype = rhs_ty.getElementType().isa<mlir::quant::QuantizedType>();
1368   bool result_is_qtype =
1369       result_ty.getElementType().isa<mlir::quant::QuantizedType>();
1370 
1371   if ((lhs_is_qtype != rhs_is_qtype) || (lhs_is_qtype != result_is_qtype)) {
1372     return op->emitOpError(
1373         "ConvertTFLBatchMatMulOp: lhs/rhs/output tensor should "
1374         "be all quantized or all floating-point.");
1375   }
1376 
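  // tosa.matmul only accepts rank-3 operands of shape [N, H, W]. If the
  // TFLite batch matmul has more than one leading batch dimension, collapse
  // them into a single N here and reshape back afterwards. For example
  // (illustrative): lhs [2, 3, 4, 5] x rhs [2, 3, 5, 6] is flattened to
  // [6, 4, 5] x [6, 5, 6], and the [6, 4, 6] result is later restored to
  // [2, 3, 4, 6].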
1377   auto batch_dims = lhs_ty.getShape().drop_back(2);
1378   if (batch_dims.size() > 1) {
1379     int64_t N = 1;
1380     for (auto d : batch_dims) {
1381       N = N < 0 || d < 0 ? -1 : N * d;
1382     }
1383 
1384     llvm::SmallVector<int64_t> new_lhs_shape{N};
1385     llvm::SmallVector<int64_t> new_rhs_shape{N};
1386     auto lhs_shape_end = lhs_ty.getShape().take_back(2);
1387     auto rhs_shape_end = rhs_ty.getShape().take_back(2);
1388 
1389     new_lhs_shape.append(lhs_shape_end.begin(), lhs_shape_end.end());
1390     new_rhs_shape.append(rhs_shape_end.begin(), rhs_shape_end.end());
1391 
1392     lhs = CreateOpAndInfer<tosa::ReshapeOp>(
1393         rewriter, op->getLoc(),
1394         UnrankedTensorType::get(lhs_ty.getElementType()), lhs,
1395         rewriter.getI64ArrayAttr(new_lhs_shape));
1396     rhs = CreateOpAndInfer<tosa::ReshapeOp>(
1397         rewriter, op->getLoc(),
1398         UnrankedTensorType::get(rhs_ty.getElementType()), rhs,
1399         rewriter.getI64ArrayAttr(new_rhs_shape));
1400     lhs_ty = lhs.getType().cast<RankedTensorType>();
1401     rhs_ty = rhs.getType().cast<RankedTensorType>();
1402   }
1403 
1404   if (transpose_lhs) {
1405     Value perms =
1406         getConstTensor<int32_t>(rewriter, op, /*vec=*/{0, 2, 1}, /*shape=*/{3})
1407             .getValue();
1408     Type output_type = UnrankedTensorType::get(lhs_ty.getElementType());
1409     lhs = CreateOpAndInfer<tosa::TransposeOp>(rewriter, op->getLoc(),
1410                                               output_type, lhs, perms)
1411               .getResult();
1412   }
1413 
1414   if (transpose_rhs) {
1415     Value perms =
1416         getConstTensor<int32_t>(rewriter, op, /*vec=*/{0, 2, 1}, /*shape=*/{3})
1417             .getValue();
1418     Type output_type = UnrankedTensorType::get(rhs_ty.getElementType());
1419     rhs = CreateOpAndInfer<tosa::TransposeOp>(rewriter, op->getLoc(),
1420                                               output_type, rhs, perms)
1421               .getResult();
1422   }
1423 
1424   auto matmul =
1425       CreateOpAndInfer<tosa::MatMulOp>(
1426           rewriter, op->getLoc(),
1427           UnrankedTensorType::get(result_ty.getElementType()), lhs, rhs)
1428           .getResult();
1429 
1430   // Conditionally reshape back to the expected rank.
1431   auto matmul_ty = matmul.getType().cast<RankedTensorType>();
1432   if (batch_dims.size() > 1) {
1433     llvm::SmallVector<int64_t> new_shape{};
1434     for (auto d : batch_dims) {
1435       new_shape.push_back(d);
1436     }
1437 
1438     for (auto d : matmul_ty.getShape().take_back(2)) {
1439       new_shape.push_back(d);
1440     }
1441 
1442     matmul = CreateOpAndInfer<tosa::ReshapeOp>(
1443         rewriter, op->getLoc(),
1444         UnrankedTensorType::get(matmul_ty.getElementType()), matmul,
1445         rewriter.getI64ArrayAttr(new_shape));
1446   }
1447 
1448   if (lhs_is_qtype) {
1449     matmul = buildRescaleOpConvOutput(rewriter, op, matmul, lhs_ty, rhs_ty,
1450                                       result_ty);
1451   }
1452 
1453   rewriter.replaceOp(op, matmul);
1454 
1455   return success();
1456 }
1457 
1458 LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite(
1459     Operation* op, PatternRewriter& rewriter) const {
1460   auto tfl_fc_op = cast<TFL::FullyConnectedOp>(op);
1461 
1462   ShapedType output_type =
1463       tfl_fc_op.getResult(0).getType().dyn_cast<ShapedType>();
1464   // Not a shaped tensor output
1465   if (!output_type) return failure();
1466 
1467   RankedTensorType input_type =
1468       tfl_fc_op.input().getType().dyn_cast<RankedTensorType>();
1469   RankedTensorType filter_type =
1470       tfl_fc_op.filter().getType().dyn_cast<RankedTensorType>();
1471   RankedTensorType bias_type =
1472       tfl_fc_op.bias().getType().dyn_cast<RankedTensorType>();
1473   if (!input_type || !filter_type) return failure();
1474 
1475   bool input_is_qtype =
1476       input_type.getElementType().isa<mlir::quant::QuantizedType>();
1477   bool filter_is_qtype =
1478       filter_type.getElementType().isa<mlir::quant::QuantizedType>();
1479   bool output_is_qtype =
1480       output_type.getElementType().isa<mlir::quant::QuantizedType>();
1481 
1482   if ((input_is_qtype != filter_is_qtype) ||
1483       (input_is_qtype != output_is_qtype)) {
1484     return op->emitOpError(
1485         "ConvertTFLFullyConnectedOp: input/filter/output tensor should "
1486         "be all quantized or all floating-point.");
1487   }
1488 
1489   Value input_val = tfl_fc_op.input();
1490 
1491   // tfl.fully_connected() accepts input tensors of various ranks, but
1492   // tosa.fully_connected only supports rank-2 inputs, so reshape to rank 2
1493   // first. For a rank-4 input the reshape is not always
1494   // (dim[0] * dim[1], dim[2] * dim[3]).
1495 
1496   // In some networks it is (dim[0], dim[1] * dim[2] * dim[3]), so the more
1497   // general way to determine the reshaped shape is to take the inner
1498   // dimension from the filter's shape[1].
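  // For example (illustrative): an input of shape [2, 4, 4, 8] with a filter
  // of shape [10, 128] is reshaped to [(2 * 4 * 4 * 8) / 128, 128] = [2, 128].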
1499   if (input_type.getRank() != 2) {
1500     int64_t num_elems = filter_type.getShape()[1];
1501     int64_t num_batch = input_type.getNumElements() / num_elems;
1502     SmallVector<int64_t, 2> shape_vals({num_batch, num_elems});
1503 
1504     RankedTensorType reshape_type =
1505         RankedTensorType::get(shape_vals, input_type.getElementType());
1506     auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
1507         rewriter, op->getLoc(), reshape_type, tfl_fc_op.input(),
1508         rewriter.getI64ArrayAttr(shape_vals));
1509 
1510     input_val = reshape_op.getResult();
1511   }
1512 
1513   Value bias_val;
1514   if (!bias_type) {
1515     // For some matmuls, the bias may actually be a "UnitType" which has no
1516     // value. TOSA requires bias to be an array of output_channel_count values,
1517     // so create a constant of the appropriate number and type of zeros.
1518     SmallVector<int64_t, 1> bias_shape({filter_type.getShape()[0]});
1519     RankedTensorType new_bias_type;
1520 
1521     DenseElementsAttr bias_attr;
1522     if (input_type.getElementType().isa<FloatType>()) {
1523       SmallVector<float> bias_arr(bias_shape[0]);
1524 
1525       for (int i = 0; i < bias_shape[0]; i++) {
1526         bias_arr[i] = 0.0;
1527       }
1528       new_bias_type =
1529           RankedTensorType::get(bias_shape, input_type.getElementType());
1530       bias_attr =
1531           DenseElementsAttr::get(new_bias_type, llvm::makeArrayRef(bias_arr));
1532     } else {
1533       SmallVector<int32_t> bias_arr(bias_shape[0]);
1534 
1535       for (int i = 0; i < bias_shape[0]; i++) {
1536         bias_arr[i] = 0;
1537       }
1538       if (!input_is_qtype) {
1539         return op->emitOpError(
1540             "ConvertTFLFullyConnectedOp: input must be quantized type if it's "
1541             "not float type.");
1542       }
1543       auto input_qtype =
1544           input_type.getElementType().cast<mlir::quant::QuantizedType>();
1545       Type new_bias_ety = input_qtype.getStorageTypeIntegralWidth() == 16
1546                               ? rewriter.getIntegerType(48)
1547                               : rewriter.getI32Type();
1548       new_bias_type = RankedTensorType::get(bias_shape, new_bias_ety);
1549       bias_attr =
1550           DenseElementsAttr::get(new_bias_type, llvm::makeArrayRef(bias_arr));
1551     }
1552     auto bias_op = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
1553                                                    new_bias_type, bias_attr);
1554     bias_val = bias_op.getResult();
1555     bias_type = new_bias_type;
1556   } else {
1557     bias_val = tfl_fc_op.bias();
1558   }
1559 
1560   Type bias_ety = bias_val.getType().cast<ShapedType>().getElementType();
1561 
1562   auto fc_op = CreateOpAndInfer<tosa::FullyConnectedOp>(
1563       rewriter, op->getLoc(), UnrankedTensorType::get(bias_ety), input_val,
1564       tfl_fc_op.filter(), bias_val);
1565 
1566   Value fc_output;
1567   if (input_is_qtype) {
1568     fc_output = buildRescaleOpConvOutput(
1569         rewriter, op, fc_op.getResult(), input_type, filter_type,
1570         UnrankedTensorType::get(output_type.getElementType()));
1571   } else {
1572     fc_output = fc_op.getResult();
1573   }
1574 
1575   // If we know the output rank, we need to ensure the output shape is correct.
1576   ShapedType fc_type = fc_output.getType().cast<ShapedType>();
1577   if (output_type.hasRank()) {
1580     fc_output = CreateOpAndInfer<tosa::ReshapeOp>(
1581         rewriter, op->getLoc(),
1582         UnrankedTensorType::get(fc_type.getElementType()), fc_output,
1583         rewriter.getI64ArrayAttr(output_type.getShape()));
1584   }
1585 
1586   auto fused_activation_fn = tfl_fc_op.fused_activation_functionAttr();
1587 
1588   if (fused_activation_fn) {
1589     llvm::Optional<Value> fused_activation_val =
1590         convertFusedActivation(rewriter, op, fc_output, fused_activation_fn);
1591 
1592     if (!fused_activation_val) return failure();
1593 
1594     rewriter.replaceOp(op, {fused_activation_val.getValue()});
1595     return success();
1596   }
1597 
1598   rewriter.replaceOp(op, {fc_output});
1599 
1600   return success();
1601 }
1602 
1603 LogicalResult ConvertTFLConcatenationOp::matchAndRewrite(
1604     Operation* op, PatternRewriter& rewriter) const {
1605   auto tfl_concat_op = cast<TFL::ConcatenationOp>(op);
1606 
1607   SmallVector<Value> values(tfl_concat_op.values());
1608 
1609   IntegerAttr axis_attr;
1610   {
1611     auto tmpAttr = tfl_concat_op.axisAttr();
1612     if (!tmpAttr) {
1613       tmpAttr = rewriter.getI64IntegerAttr(0);
1614     }
1615     axis_attr = tmpAttr;
1616   }
1617   int32_t axis = axis_attr.getInt();
1618 
1619   llvm::Optional<Value> result =
1620       convertConcatV2Op(rewriter, op, tfl_concat_op.getResult(), values, axis);
1621 
1622   if (!result) return failure();
1623 
1624   rewriter.replaceOp(op, {result.getValue()});
1625   return success();
1626 }
1627 
1628 LogicalResult ConvertTFLReshapeOp::matchAndRewrite(
1629     Operation* op, PatternRewriter& rewriter) const {
1630   auto tfl_reshape_op = cast<TFL::ReshapeOp>(op);
1631 
1632   ShapedType output_type =
1633       tfl_reshape_op.getResult().getType().dyn_cast<ShapedType>();
1634   // Not a shaped tensor output
1635   if (!output_type) return failure();
1636 
1637   SmallVector<int64_t> shape_vals;
1638 
1639   // Either the output type needs to be ranked or we need a constant input
1640   // to compute the output rank.
1641   ElementsAttr shape_attr;
1642   if (!matchPattern(tfl_reshape_op.shape(), m_Constant(&shape_attr))) {
1643     if (!output_type.hasRank()) return failure();
1644     shape_vals.resize(output_type.getRank(), -1);
1645   } else {
1646     for (auto dim : shape_attr.getValues<int32_t>()) shape_vals.push_back(dim);
1647   }
1648 
1649   // Propagate the agreement between the output shape and constant value.
1650   if (output_type.hasRank()) {
1651     if (output_type.getRank() != shape_vals.size()) return failure();
1652     for (int i = 0; i < output_type.getRank(); i++) {
1653       if (shape_vals[i] == -1) shape_vals[i] = output_type.getDimSize(i);
1654     }
1655   }
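  // For example (illustrative): a shape constant of [2, -1] with an output
  // type of tensor<2x6xf32> resolves shape_vals to [2, 6].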
1656 
1657   // We cannot handle more than 1 dynamic dimension.
1658   int64_t dynamic_count = 0;
1659   for (auto val : shape_vals)
1660     if (val == -1) dynamic_count++;
1661   if (dynamic_count > 1) return failure();
1662 
1663   ArrayAttr new_shape_attr = rewriter.getI64ArrayAttr(shape_vals);
1664   CreateReplaceOpAndInfer<tosa::ReshapeOp>(
1665       rewriter, op, output_type, tfl_reshape_op.input(), new_shape_attr);
1666   return success();
1667 }
1668 
1669 LogicalResult ConvertTFLRankOp::matchAndRewrite(
1670     Operation* op, PatternRewriter& rewriter) const {
1671   auto tfl_rank_op = cast<TFL::RankOp>(op);
1672 
1673   RankedTensorType input_type =
1674       tfl_rank_op.input().getType().dyn_cast<RankedTensorType>();
1675   if (!input_type) return failure();
1676 
1677   int32_t rank = input_type.getRank();
1678 
1679   RankedTensorType rank_type =
1680       RankedTensorType::get({1}, rewriter.getIntegerType(32));
1681   auto rank_attr = DenseElementsAttr::get(rank_type, {rank});
1682   auto rank_const = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
1683                                                     rank_type, rank_attr);
1684 
1685   rewriter.replaceOp(op, {rank_const.getResult()});
1686 
1687   return success();
1688 }
1689 
1690 LogicalResult ConvertTFLShapeOp::matchAndRewrite(
1691     Operation* op, PatternRewriter& rewriter) const {
1692   auto tfl_shape_op = cast<TFL::ShapeOp>(op);
1693 
1694   RankedTensorType output_type =
1695       tfl_shape_op.getResult().getType().dyn_cast<RankedTensorType>();
1696   // Not a ranked tensor output
1697   if (!output_type) return failure();
1698 
1699   RankedTensorType input_type =
1700       tfl_shape_op.input().getType().dyn_cast<RankedTensorType>();
1701   if (!input_type || !input_type.hasStaticShape())
1702     return rewriter.notifyMatchFailure(op, "input shape not static");
1703 
1704   auto input_shape = input_type.getShape();
1705 
1706   SmallVector<int32_t> shape_arr;
1707   for (int i = 0; i < input_shape.size(); i++) {
1708     shape_arr.emplace_back(input_shape[i]);
1709   }
1710 
1711   RankedTensorType shape_type = RankedTensorType::get(
1712       {static_cast<int32_t>(shape_arr.size())}, rewriter.getIntegerType(32));
1713   auto shape_attr =
1714       DenseElementsAttr::get(shape_type, llvm::makeArrayRef(shape_arr));
1715   auto shape_const = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
1716                                                      shape_type, shape_attr);
1717 
1718   rewriter.replaceOp(op, {shape_const.getResult()});
1719 
1720   return success();
1721 }
1722 
1723 LogicalResult ConvertTFLExpandDimsOp::matchAndRewrite(
1724     Operation* op, PatternRewriter& rewriter) const {
1725   auto tfl_expanddims_op = cast<TFL::ExpandDimsOp>(op);
1726 
1727   llvm::Optional<Value> result =
1728       convertExpandDimsOp(rewriter, op, tfl_expanddims_op.getResult(),
1729                           tfl_expanddims_op.input(), tfl_expanddims_op.dim());
1730 
1731   if (!result) return failure();
1732 
1733   rewriter.replaceOp(op, {result.getValue()});
1734 
1735   return success();
1736 }
1737 
1738 LogicalResult ConvertTFLSqueezeOp::matchAndRewrite(
1739     Operation* op, PatternRewriter& rewriter) const {
1740   auto tfl_squeeze_op = cast<TFL::SqueezeOp>(op);
1741 
1742   // Copy squeeze_dims into int32_t array
1743   auto squeeze_dims_attr = tfl_squeeze_op.squeeze_dimsAttr();
1744   SmallVector<int32_t> squeeze_dims;
1745   for (auto& squeeze_dim : squeeze_dims_attr) {
1746     squeeze_dims.emplace_back(squeeze_dim.dyn_cast<IntegerAttr>().getInt());
1747   }
1748 
1749   llvm::Optional<Value> result =
1750       convertSqueezeOp(rewriter, op, tfl_squeeze_op.getResult(),
1751                        tfl_squeeze_op.input(), squeeze_dims);
1752 
1753   if (!result) return failure();
1754 
1755   rewriter.replaceOp(op, {result.getValue()});
1756 
1757   return success();
1758 }
1759 
1760 LogicalResult ConvertTFLFillOp::matchAndRewrite(
1761     Operation* op, PatternRewriter& rewriter) const {
1762   auto tfl_fill_op = cast<TFL::FillOp>(op);
1763 
1764   RankedTensorType output_type =
1765       tfl_fill_op.getResult().getType().dyn_cast<RankedTensorType>();
1766   // Not a ranked tensor output
1767   if (!output_type) return failure();
1768 
1769   ElementsAttr dims_elems;
1770   if (!matchPattern(tfl_fill_op.dims(), m_Constant(&dims_elems)))
1771     return failure();
1772   SmallVector<int64_t> dims_vals;
1773   uint32_t total_size = 1;
1774   for (int i = 0; i < dims_elems.getNumElements(); i++) {
1775     dims_vals.push_back(dims_elems.getValues<APInt>()[i].getSExtValue());
1776     total_size *= dims_vals[i];
1777   }
1778 
1779   ElementsAttr value_elem;
1780   if (!matchPattern(tfl_fill_op.input(), m_Constant(&value_elem)))
1781     return failure();
1782 
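  // With both the dims and the fill value constant, the whole op folds into a
  // single tosa.const of shape dims_vals splatted with that value.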
1783   RankedTensorType fill_type = RankedTensorType::get(
1784       ArrayRef<int64_t>(dims_vals), value_elem.getType().getElementType());
1785   DenseElementsAttr fill_attr;
1786 
1787   // Convert to a compatible zero type.
1788   if (value_elem.getType().getElementType().isa<FloatType>()) {
1789     SmallVector<float> fill_arr(
1790         total_size, value_elem.getValues<APFloat>()[0].convertToFloat());
1791     fill_attr = DenseElementsAttr::get(fill_type, llvm::makeArrayRef(fill_arr));
1792   } else {
1793     SmallVector<int32_t> fill_arr(
1794         total_size, value_elem.getValues<APInt>()[0].getLimitedValue());
1795     fill_attr = DenseElementsAttr::get(fill_type, llvm::makeArrayRef(fill_arr));
1796   }
1797   auto fill_const_op = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
1798                                                        fill_type, fill_attr);
1799   rewriter.replaceOp(op, {fill_const_op.getResult()});
1800 
1801   return success();
1802 }
1803 
1804 LogicalResult ConvertTFLReduceAnyOp::matchAndRewrite(
1805     Operation* op, PatternRewriter& rewriter) const {
1806   auto tfl_any_op = cast<TFL::ReduceAnyOp>(op);
1807 
1808   RankedTensorType output_type =
1809       tfl_any_op.getResult().getType().dyn_cast<RankedTensorType>();
1810   if (!output_type) return failure();
1811 
1812   ElementsAttr axes_elems;
1813   if (!matchPattern(tfl_any_op.reduction_indices(), m_Constant(&axes_elems)))
1814     return failure();
1815 
1816   llvm::Optional<Value> result = convertReduceAnyOp(
1817       rewriter, op, output_type, tfl_any_op.input(), axes_elems);
1818 
1819   if (!result) return failure();
1820 
1821   rewriter.replaceOp(op, {result.getValue()});
1822 
1823   return success();
1824 }
1825 
1826 LogicalResult ConvertTFLReduceMaxOp::matchAndRewrite(
1827     Operation* op, PatternRewriter& rewriter) const {
1828   auto tfl_max_op = cast<TFL::ReduceMaxOp>(op);
1829 
1830   RankedTensorType output_type =
1831       tfl_max_op.getResult().getType().dyn_cast<RankedTensorType>();
1832   if (!output_type) return failure();
1833 
1834   ElementsAttr axes_elems;
1835   if (!matchPattern(tfl_max_op.axes(), m_Constant(&axes_elems)))
1836     return failure();
1837 
1838   llvm::Optional<Value> result = convertReduceMaxOp(
1839       rewriter, op, output_type, tfl_max_op.input(), axes_elems);
1840 
1841   if (!result) return failure();
1842 
1843   rewriter.replaceOp(op, {result.getValue()});
1844 
1845   return success();
1846 }
1847 
1848 LogicalResult ConvertTFLReduceMinOp::matchAndRewrite(
1849     Operation* op, PatternRewriter& rewriter) const {
1850   auto tfl_min_op = cast<TFL::ReduceMinOp>(op);
1851 
1852   RankedTensorType output_type =
1853       tfl_min_op.getResult().getType().dyn_cast<RankedTensorType>();
1854   if (!output_type) return failure();
1855 
1856   ElementsAttr axes_elems;
1857   if (!matchPattern(tfl_min_op.axes(), m_Constant(&axes_elems)))
1858     return failure();
1859 
1860   llvm::Optional<Value> result = convertReduceMinOp(
1861       rewriter, op, output_type, tfl_min_op.input(), axes_elems);
1862 
1863   if (!result) return failure();
1864 
1865   rewriter.replaceOp(op, {result.getValue()});
1866 
1867   return success();
1868 }
1869 
1870 LogicalResult ConvertTFLReduceProdOp::matchAndRewrite(
1871     Operation* op, PatternRewriter& rewriter) const {
1872   auto tfl_prod_op = cast<TFL::ReduceProdOp>(op);
1873 
1874   RankedTensorType output_type =
1875       tfl_prod_op.getResult().getType().dyn_cast<RankedTensorType>();
1876   if (!output_type) return failure();
1877 
1878   ElementsAttr axes_elems;
1879   if (!matchPattern(tfl_prod_op.axes(), m_Constant(&axes_elems)))
1880     return failure();
1881 
1882   llvm::Optional<Value> result = convertReduceProdOp(
1883       rewriter, op, output_type, tfl_prod_op.input(), axes_elems);
1884 
1885   if (!result) return failure();
1886 
1887   rewriter.replaceOp(op, {result.getValue()});
1888 
1889   return success();
1890 }
1891 
1892 LogicalResult ConvertTFLMeanOp::matchAndRewrite(
1893     Operation* op, PatternRewriter& rewriter) const {
1894   auto tfl_mean_op = cast<TFL::MeanOp>(op);
1895 
1896   RankedTensorType output_type =
1897       tfl_mean_op.getResult().getType().dyn_cast<RankedTensorType>();
1898   if (!output_type) return failure();
1899 
1900   ElementsAttr axes_elems;
1901   if (!matchPattern(tfl_mean_op.axis(), m_Constant(&axes_elems)))
1902     return failure();
1903 
1904   llvm::Optional<Value> result = convertReduceMeanOp(
1905       rewriter, op, output_type, tfl_mean_op.input(), axes_elems);
1906 
1907   if (!result) return failure();
1908 
1909   rewriter.replaceOp(op, {result.getValue()});
1910 
1911   return success();
1912 }
1913 
1914 LogicalResult ConvertTFLSumOp::matchAndRewrite(
1915     Operation* op, PatternRewriter& rewriter) const {
1916   auto tfl_sum_op = cast<TFL::SumOp>(op);
1917 
1918   RankedTensorType output_type =
1919       tfl_sum_op.getResult().getType().dyn_cast<RankedTensorType>();
1920   if (!output_type) return failure();
1921 
1922   ElementsAttr axes_elems;
1923   if (!matchPattern(tfl_sum_op.axes(), m_Constant(&axes_elems)))
1924     return failure();
1925 
1926   llvm::Optional<Value> result = convertReduceSumOp(
1927       rewriter, op, output_type, tfl_sum_op.input(), axes_elems);
1928 
1929   if (!result) return failure();
1930 
1931   rewriter.replaceOp(op, {result.getValue()});
1932 
1933   return success();
1934 }
1935 
1936 LogicalResult ConvertTFLEluOp::matchAndRewrite(
1937     Operation* op, PatternRewriter& rewriter) const {
1938   auto tfl_elu_op = cast<TFL::EluOp>(op);
1939 
1940   llvm::Optional<Value> result =
1941       convertEluOp(rewriter, op, tfl_elu_op.getResult(), tfl_elu_op.x());
1942 
1943   if (!result) return failure();
1944 
1945   rewriter.replaceOp(op, {result.getValue()});
1946 
1947   return success();
1948 }
1949 
1950 LogicalResult ConvertTFLSoftmaxOp::matchAndRewrite(
1951     Operation* op, PatternRewriter& rewriter) const {
1952   auto tfl_softmax_op = cast<TFL::SoftmaxOp>(op);
1953 
1954   llvm::Optional<Value> result = convertSoftmaxOp(
1955       rewriter, op, tfl_softmax_op.getResult(), tfl_softmax_op.input(),
1956       tfl_softmax_op.betaAttr().getValueAsDouble());
1957 
1958   if (!result) return failure();
1959 
1960   rewriter.replaceOp(op, {result.getValue()});
1961 
1962   return success();
1963 }
1964 
1965 LogicalResult ConvertTFLSqrtOp::matchAndRewrite(
1966     Operation* op, PatternRewriter& rewriter) const {
1967   auto tfl_sqrt_op = cast<TFL::SqrtOp>(op);
1968   auto rsqrt = CreateOpAndInfer<tosa::RsqrtOp>(
1969       rewriter, op->getLoc(), tfl_sqrt_op.getType(), tfl_sqrt_op.x());
1970 
1971   CreateReplaceOpAndInfer<tosa::ReciprocalOp>(rewriter, op, rsqrt.getType(),
1972                                               rsqrt);
1973 
1974   return success();
1975 }
1976 
1977 LogicalResult ConvertTFLL2NormalizationOp::matchAndRewrite(
1978     Operation* op, PatternRewriter& rewriter) const {
1979   auto tfl_l2norm_op = cast<TFL::L2NormalizationOp>(op);
1980   auto input = tfl_l2norm_op.input();
1981   auto input_ty = input.getType().cast<ShapedType>();
1982   auto loc = op->getLoc();
1983 
1984   if (!input_ty.hasRank()) return failure();
1985 
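  // For fp32 this computes y = x * rsqrt(max(sum(x * x), sqrt(FLT_MIN)))
  // along the last axis, i.e. L2 normalization with the denominator clamped
  // away from zero.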
1986   if (input_ty.getElementType().isF32()) {
1987     auto shift = rewriter.getIntegerAttr(rewriter.getI32Type(), 0);
1988     auto result_ty = UnrankedTensorType::get(input_ty.getElementType());
1989     auto mul = CreateOpAndInfer<tosa::MulOp>(rewriter, loc, result_ty, input,
1990                                              input, shift);
1991     auto sum = CreateOpAndInfer<tosa::ReduceSumOp>(
1992         rewriter, loc, result_ty, mul,
1993         rewriter.getI64IntegerAttr(input_ty.getRank() - 1));
1994 
1995     SmallVector<float> min(1, sqrt(std::numeric_limits<float>::min()));
1996     Value min_val = getConstTensor<float>(rewriter, op, min, {}).getValue();
1997     auto max = CreateOpAndInfer<tosa::MaximumOp>(rewriter, loc, result_ty, sum,
1998                                                  min_val);
1999     auto rsqrt = CreateOpAndInfer<tosa::RsqrtOp>(rewriter, loc, result_ty, max)
2000                      .getResult();
2001     auto result = CreateOpAndInfer<tosa::MulOp>(rewriter, loc, result_ty, rsqrt,
2002                                                 input, shift)
2003                       .getResult();
2004 
2005     auto fused_activation_fn = tfl_l2norm_op.fused_activation_functionAttr();
2006 
2007     if (fused_activation_fn) {
2008       llvm::Optional<Value> fused_activation_val =
2009           convertFusedActivation(rewriter, op, result, fused_activation_fn);
2010       if (!fused_activation_val) return failure();
2011       result = fused_activation_val.getValue();
2012     }
2013 
2014     rewriter.replaceOp(op, result);
2015     return success();
2016   }
2017 
2018   return failure();
2019 }
2020 
2021 LogicalResult ConvertTFLLogSoftmaxOp::matchAndRewrite(
2022     Operation* op, PatternRewriter& rewriter) const {
2023   auto tfl_logsoftmax_op = cast<TFL::LogSoftmaxOp>(op);
2024 
2025   llvm::Optional<Value> result = convertLogSoftmaxOp(
2026       rewriter, op, tfl_logsoftmax_op.getResult(), tfl_logsoftmax_op.input());
2027 
2028   if (!result) return failure();
2029 
2030   rewriter.replaceOp(op, {result.getValue()});
2031 
2032   return success();
2033 }
2034 
2035 LogicalResult ConvertTFLSliceOp::matchAndRewrite(
2036     Operation* op, PatternRewriter& rewriter) const {
2037   auto tfl_slice_op = cast<TFL::SliceOp>(op);
2038 
2039   ShapedType output_type =
2040       tfl_slice_op.getResult().getType().dyn_cast<ShapedType>();
2041   // Not a shaped tensor output
2042   if (!output_type) return failure();
2043 
2044   ElementsAttr begin_elems, size_elems;
2045 
2046   SmallVector<int64_t> begin_vals, size_vals;
2047 
2048   if (!matchPattern(tfl_slice_op.begin(), m_Constant(&begin_elems)) ||
2049       !matchPattern(tfl_slice_op.size(), m_Constant(&size_elems))) {
2050     return failure();
2051   }
2052 
2053   for (int i = 0; i < begin_elems.getNumElements(); i++)
2054     begin_vals.push_back(begin_elems.getValues<APInt>()[i].getSExtValue());
2055 
2056   for (int i = 0; i < size_elems.getNumElements(); i++)
2057     size_vals.push_back(size_elems.getValues<APInt>()[i].getSExtValue());
2058 
2059   ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
2060   ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
2061 
2062   CreateReplaceOpAndInfer<tosa::SliceOp>(rewriter, op, output_type,
2063                                          tfl_slice_op.input(), begin, size);
2064   return success();
2065 }
2066 
2067 LogicalResult ConvertTFLTileOp::matchAndRewrite(
2068     Operation* op, PatternRewriter& rewriter) const {
2069   auto tfl_tile_op = cast<TFL::TileOp>(op);
2070 
2071   ShapedType output_type =
2072       tfl_tile_op.getResult().getType().dyn_cast<ShapedType>();
2073   // Not a shaped tensor output
2074   if (!output_type) return failure();
2075 
2076   ElementsAttr multiples_elems;
2077   if (!matchPattern(tfl_tile_op.multiples(), m_Constant(&multiples_elems)))
2078     return failure();
2079   SmallVector<int64_t> multiples_vals;
2080   for (int i = 0; i < multiples_elems.getNumElements(); i++)
2081     multiples_vals.push_back(
2082         multiples_elems.getValues<APInt>()[i].getSExtValue());
2083 
2084   ArrayAttr multiples_attr = rewriter.getI64ArrayAttr(multiples_vals);
2085   CreateReplaceOpAndInfer<tosa::TileOp>(rewriter, op, output_type,
2086                                         tfl_tile_op.input(), multiples_attr);
2087 
2088   return success();
2089 }
2090 
2091 LogicalResult ConvertTFLTransposeOp::matchAndRewrite(
2092     Operation* op, PatternRewriter& rewriter) const {
2093   auto tfl_transpose_op = cast<TFL::TransposeOp>(op);
2094 
2095   Type output_type = tfl_transpose_op.getResult().getType();
2096   CreateReplaceOpAndInfer<tosa::TransposeOp>(rewriter, op, output_type,
2097                                              tfl_transpose_op.input(),
2098                                              tfl_transpose_op.perm());
2099 
2100   return success();
2101 }
2102 
2103 LogicalResult ConvertTFLPackOp::matchAndRewrite(
2104     Operation* op, PatternRewriter& rewriter) const {
2105   auto tfl_pack_op = cast<TFL::PackOp>(op);
2106 
2107   SmallVector<Value> inputs(tfl_pack_op.values());
2108   assert(!inputs.empty());
2109 
2110   IntegerAttr axis_attr;
2111   {
2112     auto tmpAttr = tfl_pack_op.axisAttr();
2113     if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
2114     axis_attr = tmpAttr;
2115   }
2116   int32_t axis_i32 = axis_attr.getInt();
2117 
2118   llvm::Optional<Value> result =
2119       convertPackOp(rewriter, op, tfl_pack_op.getResult(), inputs, axis_i32);
2120 
2121   if (!result) return failure();
2122 
2123   rewriter.replaceOp(op, {result.getValue()});
2124 
2125   return success();
2126 }
2127 
2128 LogicalResult ConvertTFLUnpackOp::matchAndRewrite(
2129     Operation* op, PatternRewriter& rewriter) const {
2130   auto tfl_unpack_op = cast<TFL::UnpackOp>(op);
2131 
2132   IntegerAttr axis_attr;
2133   {
2134     auto tmpAttr = tfl_unpack_op.axisAttr();
2135     if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
2136     axis_attr = tmpAttr;
2137   }
2138   int32_t axis_i32 = axis_attr.getInt();
2139 
2140   llvm::Optional<SmallVector<Value>> results =
2141       convertUnpackOp(rewriter, op, tfl_unpack_op.input(), axis_i32);
2142 
2143   if (!results) return failure();
2144 
2145   rewriter.replaceOp(op, results.getValue());
2146 
2147   return success();
2148 }
2149 
2150 // Splits into num_split equal parts along split_dim
2151 LogicalResult ConvertTFLSplitOp::matchAndRewrite(
2152     Operation* op, PatternRewriter& rewriter) const {
2153   auto tfl_split_op = cast<TFL::SplitOp>(op);
2154 
2155   // Get the number of splits
2156   int32_t num_split = -1;
2157   auto numSplitAttr = tfl_split_op.num_splitsAttr();
2158   if (numSplitAttr) {
2159     num_split = numSplitAttr.getInt();
2160   } else {
2161     return failure();
2162   }
2163 
2164   // Get the axis
2165   ElementsAttr axisAttrElems;
2166   if (!matchPattern(tfl_split_op.split_dim(), m_Constant(&axisAttrElems))) {
2167     return op->emitOpError("Cannot read split_dim elems");
2168   }
2169 
2170   // The axis/split_dim parameter is stored as a 0D tensor instead of
2171   // an integer attribute in TFLite MLIR.
2172   int32_t axis = axisAttrElems.getValues<APInt>()[0].getSExtValue();
2173 
2174   llvm::Optional<SmallVector<Value>> results =
2175       convertSplitOp(rewriter, op, tfl_split_op.getResult(0),
2176                      tfl_split_op.value(), num_split, axis);
2177 
2178   if (!results) return failure();
2179 
2180   rewriter.replaceOp(op, results.getValue());
2181 
2182   return success();
2183 }
2184 
2185 // Splits into parts whose sizes are given by size_splits, along split_dim
2186 LogicalResult ConvertTFLSplitVOp::matchAndRewrite(
2187     Operation* op, PatternRewriter& rewriter) const {
2188   auto tfl_splitv_op = cast<TFL::SplitVOp>(op);
2189 
2190   // Get the size_splits array
2191   SmallVector<int32_t> size_split;
2192   ElementsAttr size_split_elems;
2193   if (!matchPattern(tfl_splitv_op.size_splits(),
2194                     m_Constant(&size_split_elems))) {
2195     return failure();
2196   }
2197 
2198   for (int i = 0; i < size_split_elems.getNumElements(); i++) {
2199     size_split.push_back(size_split_elems.getValues<APInt>()[i].getSExtValue());
2200   }
2201 
2202   // Get the axis
2203   ElementsAttr axisAttrElems;
2204   if (!matchPattern(tfl_splitv_op.split_dim(), m_Constant(&axisAttrElems))) {
2205     return op->emitOpError("Cannot read split_dim elems");
2206   }
2207 
2208   // The axis/split_dim parameter is stored as a 0D tensor instead of
2209   // an integer attribute in TFLite MLIR.
2210   int32_t axis = axisAttrElems.getValues<APInt>()[0].getSExtValue();
2211 
2212   llvm::Optional<SmallVector<Value>> results =
2213       convertSplitVOp(rewriter, op, tfl_splitv_op.getResult(0),
2214                       tfl_splitv_op.value(), size_split, axis);
2215 
2216   if (!results) return failure();
2217 
2218   rewriter.replaceOp(op, results.getValue());
2219 
2220   return success();
2221 }
2222 
2223 LogicalResult ConvertTFLPadOp::matchAndRewrite(
2224     Operation* op, PatternRewriter& rewriter) const {
2225   auto tfl_pad_op = cast<TFL::PadOp>(op);
2226 
2227   ShapedType output_type =
2228       tfl_pad_op.getResult().getType().dyn_cast<ShapedType>();
2229   // Not a shaped tensor output
2230   if (!output_type) return failure();
2231 
2232   auto pad_op =
2233       CreateOpAndInfer<tosa::PadOp>(rewriter, op->getLoc(), output_type,
2234                                     tfl_pad_op.input(), tfl_pad_op.padding());
2235 
2236   rewriter.replaceOp(op, {pad_op.getResult()});
2237   return success();
2238 }
2239 
2240 LogicalResult ConvertTFLPadV2Op::matchAndRewrite(
2241     Operation* op, PatternRewriter& rewriter) const {
2242   auto tfl_pad_op = cast<TFL::PadV2Op>(op);
2243 
2244   Value input = tfl_pad_op.input();
2245   Value padding = tfl_pad_op.padding();
2246   Value constant_value = tfl_pad_op.constant_values();
2247 
2248   CreateReplaceOpAndInfer<tosa::PadOp>(rewriter, op, tfl_pad_op.getType(),
2249                                        input, padding, constant_value);
2250 
2251   return success();
2252 }
2253 
2254 LogicalResult ConvertTFLResizeBilinearOp::matchAndRewrite(
2255     Operation* op, PatternRewriter& rewriter) const {
2256   auto tfl_resize_op = cast<TFL::ResizeBilinearOp>(op);
2257 
2258   RankedTensorType output_type =
2259       tfl_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
2260   // Not a ranked tensor output
2261   if (!output_type) return failure();
2262 
2263   llvm::Optional<Value> result = convertResizeOp(
2264       rewriter, op, output_type, tfl_resize_op.input(), StringRef("BILINEAR"),
2265       tfl_resize_op.align_cornersAttr().getValue(),
2266       tfl_resize_op.half_pixel_centersAttr().getValue());
2267 
2268   if (!result) return failure();
2269 
2270   rewriter.replaceOp(op, {result.getValue()});
2271 
2272   return success();
2273 }
2274 
2275 LogicalResult ConvertTFLResizeNearestNeighborOp::matchAndRewrite(
2276     Operation* op, PatternRewriter& rewriter) const {
2277   auto tfl_resize_op = cast<TFL::ResizeNearestNeighborOp>(op);
2278 
2279   RankedTensorType output_type =
2280       tfl_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
2281   // Not a ranked tensor output
2282   if (!output_type) return failure();
2283 
2284   llvm::Optional<Value> result =
2285       convertResizeOp(rewriter, op, output_type, tfl_resize_op.input(),
2286                       StringRef("NEAREST_NEIGHBOR"),
2287                       tfl_resize_op.align_cornersAttr().getValue(),
2288                       tfl_resize_op.half_pixel_centersAttr().getValue());
2289 
2290   if (!result) return failure();
2291 
2292   rewriter.replaceOp(op, {result.getValue()});
2293 
2294   return success();
2295 }
2296 
2297 LogicalResult ConvertTFLSelectOp::matchAndRewrite(
2298     Operation* op, PatternRewriter& rewriter) const {
2299   auto tfl_sel_op = cast<TFL::SelectOp>(op);
2300 
2301   llvm::Optional<Value> result =
2302       convertSelectOp(rewriter, op, tfl_sel_op.getResult(),
2303                       tfl_sel_op.condition(), tfl_sel_op.x(), tfl_sel_op.y());
2304   if (!result) return failure();
2305 
2306   rewriter.replaceOp(op, {result.getValue()});
2307 
2308   return success();
2309 }
2310 
2311 LogicalResult ConvertTFLSelectV2Op::matchAndRewrite(
2312     Operation* op, PatternRewriter& rewriter) const {
2313   auto tfl_sel_op = cast<TFL::SelectV2Op>(op);
2314 
2315   llvm::Optional<Value> result =
2316       convertSelectOp(rewriter, op, tfl_sel_op.getResult(),
2317                       tfl_sel_op.condition(), tfl_sel_op.x(), tfl_sel_op.y());
2318   if (!result) return failure();
2319 
2320   rewriter.replaceOp(op, {result.getValue()});
2321 
2322   return success();
2323 }
2324 
2325 LogicalResult ConvertTFLSpaceToBatchNdOp::matchAndRewrite(
2326     Operation* op, PatternRewriter& rewriter) const {
2327   auto tfl_s2b_op = cast<TFL::SpaceToBatchNdOp>(op);
2328   llvm::Optional<Value> result = convertSpaceToBatchNDOp(
2329       rewriter, op, tfl_s2b_op.getResult(), tfl_s2b_op.input(),
2330       tfl_s2b_op.block_shape(), tfl_s2b_op.paddings());
2331 
2332   if (!result) return failure();
2333 
2334   rewriter.replaceOp(op, {result.getValue()});
2335 
2336   return success();
2337 }
2338 
2339 LogicalResult ConvertTFLBatchToSpaceNdOp::matchAndRewrite(
2340     Operation* op, PatternRewriter& rewriter) const {
2341   auto tfl_b2s_op = cast<TFL::BatchToSpaceNdOp>(op);
2342 
2343   llvm::Optional<Value> result = convertBatchToSpaceNDOp(
2344       rewriter, op, tfl_b2s_op.getResult(), tfl_b2s_op.input(),
2345       tfl_b2s_op.block_shape(), tfl_b2s_op.indices());
2346 
2347   if (!result) return failure();
2348 
2349   rewriter.replaceOp(op, {result.getValue()});
2350 
2351   return success();
2352 }
2353 
2354 LogicalResult ConvertTFLSpaceToDepthOp::matchAndRewrite(
2355     Operation* op, PatternRewriter& rewriter) const {
2356   auto tfl_s2d_op = cast<TFL::SpaceToDepthOp>(op);
2357 
2358   auto block_size_attr = tfl_s2d_op.block_sizeAttr();
2359   llvm::Optional<Value> result = convertSpaceToDepthOp(
2360       rewriter, op, tfl_s2d_op.getResult(), tfl_s2d_op.input(), block_size_attr,
2361       rewriter.getStringAttr("NHWC"));
2362 
2363   if (!result) return failure();
2364 
2365   rewriter.replaceOp(op, {result.getValue()});
2366 
2367   return success();
2368 }
2369 
2370 LogicalResult ConvertTFLDepthToSpaceOp::matchAndRewrite(
2371     Operation* op, PatternRewriter& rewriter) const {
2372   auto tfl_d2s_op = cast<TFL::DepthToSpaceOp>(op);
2373 
2374   auto block_size_attr = tfl_d2s_op.block_sizeAttr();
2375   llvm::Optional<Value> result = convertDepthToSpaceOp(
2376       rewriter, op, tfl_d2s_op.getResult(), tfl_d2s_op.input(), block_size_attr,
2377       rewriter.getStringAttr("NHWC"));
2378 
2379   if (!result) return failure();
2380 
2381   rewriter.replaceOp(op, {result.getValue()});
2382 
2383   return success();
2384 }
2385 
2386 LogicalResult ConvertTFLStridedSliceOp::matchAndRewrite(
2387     Operation* op, PatternRewriter& rewriter) const {
2388   auto tfl_ss_op = cast<TFL::StridedSliceOp>(op);
2389 
2390   llvm::Optional<Value> result = convertStridedSliceOp(
2391       rewriter, op, tfl_ss_op.getResult(), tfl_ss_op.input(), tfl_ss_op.begin(),
2392       tfl_ss_op.end(), tfl_ss_op.strides(), tfl_ss_op.begin_maskAttr().getInt(),
2393       tfl_ss_op.end_maskAttr().getInt(), tfl_ss_op.ellipsis_maskAttr().getInt(),
2394       tfl_ss_op.new_axis_maskAttr().getInt(),
2395       tfl_ss_op.shrink_axis_maskAttr().getInt());
2396   if (!result) return failure();
2397 
2398   rewriter.replaceOp(op, {result.getValue()});
2399 
2400   return success();
2401 }
2402 
2403 LogicalResult ConvertTFLZerosLikeOp::matchAndRewrite(
2404     Operation* op, PatternRewriter& rewriter) const {
2405   auto tfl_zeroslike_op = cast<TFL::ZerosLikeOp>(op);
2406 
2407   llvm::Optional<Value> result = convertZerosLikeOp(
2408       rewriter, op, tfl_zeroslike_op.getResult(), tfl_zeroslike_op.input());
2409 
2410   if (!result) return failure();
2411 
2412   rewriter.replaceOp(op, {result.getValue()});
2413 
2414   return success();
2415 }
2416 
2417 LogicalResult ConvertTFLHardSwishOp::matchAndRewrite(
2418     Operation* op, PatternRewriter& rewriter) const {
2419   auto tfl_hardswish_op = cast<TFL::HardSwishOp>(op);
2420   RankedTensorType output_type =
2421       tfl_hardswish_op.getResult().getType().dyn_cast<RankedTensorType>();
2422   // Not a ranked tensor output
2423   if (!output_type) return failure();
2424 
2425   RankedTensorType input_type =
2426       tfl_hardswish_op.input().getType().dyn_cast<RankedTensorType>();
2427   // Not a ranked tensor input
2428   if (!input_type) return failure();
2429 
2430   // TFL hardswish: f(x) -> (x * relu6(x+3))/6
2431 
2432   if (input_type.getElementType().isa<mlir::quant::QuantizedType>() &&
2433       output_type.getElementType().isa<mlir::quant::QuantizedType>()) {
2434     // Should match TFLite reference numerical behavior
2435     mlir::quant::UniformQuantizedType input_qtype =
2436         input_type.getElementType()
2437             .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
2438     mlir::quant::UniformQuantizedType output_qtype =
2439         output_type.getElementType()
2440             .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
2441 
2442     auto hardswish_func = [](double v) -> double {
2443       double w = v + 3.0;
2444       w = w < 0.0 ? 0.0 : w > 6.0 ? 6.0 : w;
2445       return v * w / 6.0;
2446     };
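    // For 8-bit quantized tensors the dequantize -> hardswish -> requantize
    // chain is folded into a single 256-entry TOSA TABLE lookup built from
    // hardswish_func and the input/output quantization parameters.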
2447 
2448     if (input_qtype.getStorageTypeIntegralWidth() == 8) {
2449       // Implement with 8-bit table lookup.
2450       Value table_const = getTosaConst8bitTable(
2451           rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(),
2452           output_qtype.getScale(), output_qtype.getZeroPoint(), hardswish_func);
2453 
2454       CreateReplaceOpAndInfer<tosa::TableOp>(
2455           rewriter, op, output_type, tfl_hardswish_op.input(), table_const);
2456     } else {
      // Only the 8-bit quantized case is lowered here; bail out so the
      // pattern does not report success without rewriting the op.
      return rewriter.notifyMatchFailure(
          op, "only 8-bit quantized hardswish lowering is supported");
    }
2457 
2458   } else {
2459     // op1 = constop(3)
2460     // op2 = add(x, op1)
2461     // op3 = clamp(op2, 0, 6)
2462     // op4 = mul(x, op3)
2463     // op5 = reciprocal(6)
2464     // op6 = mul (op4, op5)
2465 
2466     Value op1_value = getTosaConstTensorSingleF32(rewriter, op, 3.0);
2467 
2468     auto op2_add_x_op1 =
2469         CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(), output_type,
2470                                       tfl_hardswish_op.input(), op1_value);
2471 
2472     auto op3_relu_op2_6 = CreateOpAndInfer<tosa::ClampOp>(
2473         rewriter, op->getLoc(), output_type, op2_add_x_op1.getResult(),
2474         rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(0),
2475         rewriter.getF32FloatAttr(0.0f), rewriter.getF32FloatAttr(6.0f));
2476 
2477     auto op4_mul_x_op3 = CreateOpAndInfer<tosa::MulOp>(
2478         rewriter, op->getLoc(), output_type, tfl_hardswish_op.input(),
2479         op3_relu_op2_6.getResult(), 0);
2480 
2481     auto const_6 = getTosaConstTensorSingleF32(rewriter, op, 6.0);
2482     auto op5_reciprocal_6 = CreateOpAndInfer<tosa::ReciprocalOp>(
2483         rewriter, op->getLoc(), const_6.getType(), const_6);
2484 
2485     auto op6_mul_op4_op5 = CreateOpAndInfer<tosa::MulOp>(
2486         rewriter, op->getLoc(), output_type, op4_mul_x_op3.getResult(),
2487         op5_reciprocal_6.getResult(), 0);
2488 
2489     rewriter.replaceOp(op, {op6_mul_op4_op5.getResult()});
2490   }
2491 
2492   return success();
2493 }
2494 
2495 LogicalResult ConvertTFLSinOp::matchAndRewrite(
2496     Operation* op, PatternRewriter& rewriter) const {
2497   auto tfl_sin_op = cast<TFL::SinOp>(op);
2498   Location loc = op->getLoc();
2499   Value input = tfl_sin_op.x();
2500   RankedTensorType input_ty = input.getType().dyn_cast<RankedTensorType>();
2501   ShapedType output_ty =
2502       tfl_sin_op.getResult().getType().dyn_cast<ShapedType>();
2503 
2504   if (!input_ty || !output_ty) return failure();
2505 
2506   Type input_ety = input_ty.getElementType();
2507   Type output_ety = output_ty.getElementType();
2508 
2509   if (input_ety != output_ety) {
2510     return rewriter.notifyMatchFailure(
2511         op, "ConvertTFLSinOp: input/output element type must match");
2512   }
2513 
2514   bool input_is_fp = input_ty.getElementType().isF32();
2515   bool output_is_fp = output_ty.getElementType().isF32();
2516 
2517   if (!input_is_fp || !output_is_fp) {
2518     return rewriter.notifyMatchFailure(
2519         op, "ConvertTFLSinOp: input/result must be fp32.");
2520   }
2521 
2522   // To implement sin, we fold the input into a single period of the
2523   // function and map that interval onto the domain of a table lookup.
2524   // The table's output range is then mapped back onto the range of the
2525   // sin operation.
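  // Roughly (illustrative): sin(x) ~= table(int16((2 * frac(x / (2 * pi)) - 1)
  // * 32767)) / 2^22, where table() is the 513-entry lookup built below.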
2526 
2527   // 1. Normalize the period of the domain from [0, 2π) to [0, 1).
2528   auto fp_scalar_ty = RankedTensorType::get({}, rewriter.getF32Type());
2529   Value fp_scale = rewriter.create<tosa::ConstOp>(
2530       loc, fp_scalar_ty,
2531       DenseElementsAttr::get(fp_scalar_ty, {static_cast<float>(0.5 / M_PI)}));
2532 
2533   // 2. Remap the periodic behavior of the domain to line up within [0, 1).
2534   Value fp_scaled = CreateOpAndInfer<tosa::MulOp>(
2535       rewriter, loc, input_ty, input, fp_scale, rewriter.getI32IntegerAttr(0));
2536   auto floored =
2537       CreateOpAndInfer<tosa::FloorOp>(rewriter, loc, input_ty, fp_scaled);
2538   auto repeated = CreateOpAndInfer<tosa::SubOp>(rewriter, loc, input_ty,
2539                                                 fp_scaled, floored);
2540 
2541   // 3. Scale and translate the normalized domain to the table domain. This
2542   // includes a translating and scaling to [-int16_max, int16_max] and casting
2543   // to an i16.
2544   Value one = rewriter.create<tosa::ConstOp>(
2545       loc, fp_scalar_ty, DenseElementsAttr::get(fp_scalar_ty, {1.0f}));
2546 
2547   Value two = rewriter.create<tosa::ConstOp>(
2548       loc, fp_scalar_ty, DenseElementsAttr::get(fp_scalar_ty, {2.0f}));
2549   auto scale_up = CreateOpAndInfer<tosa::MulOp>(
2550       rewriter, loc, input_ty, repeated, two, rewriter.getI32IntegerAttr(0));
2551   auto translate =
2552       CreateOpAndInfer<tosa::SubOp>(rewriter, loc, input_ty, scale_up, one);
2553 
2554   Value int_limit = rewriter.create<tosa::ConstOp>(
2555       loc, fp_scalar_ty,
2556       DenseElementsAttr::get(
2557           fp_scalar_ty,
2558           {static_cast<float>(std::numeric_limits<int16_t>::max())}));
2559   auto int_scaled =
2560       CreateOpAndInfer<tosa::MulOp>(rewriter, loc, input_ty, translate,
2561                                     int_limit, rewriter.getI32IntegerAttr(0));
2562 
2563   auto int16_ty = input_ty.clone(rewriter.getIntegerType(16));
2564   auto casted =
2565       CreateOpAndInfer<tosa::CastOp>(rewriter, loc, int16_ty, int_scaled);
2566 
2567   // 4. Compute the 513-entry lookup table covering one full period of sin.
2568   llvm::SmallVector<int16_t> values;
2569   const int num_values = 513;
2570   values.resize(num_values, 0);
2571   // First and last values should be 0.
2572   for (int i = 1; i < num_values - 1; ++i)
2573     values[i] = std::numeric_limits<int16_t>::max() *
2574                 sin(static_cast<float>(i) * 2.0 * M_PI / (num_values - 1.0));
2575 
2576   auto table_ty =
2577       RankedTensorType::get({num_values}, rewriter.getIntegerType(16));
2578   Value table = rewriter.create<tosa::ConstOp>(
2579       loc, table_ty,
2580       DenseElementsAttr::get(table_ty, llvm::makeArrayRef(values)));
2581 
2582   auto table_result_ty = input_ty.clone(rewriter.getIntegerType(32));
2583   auto table_result = CreateOpAndInfer<tosa::TableOp>(
2584       rewriter, loc, table_result_ty, casted, table);
2585 
2586   // 5. The table result is a 23-bit two's complement value. Normalize the
2587   // range by casting to fp32 and dividing by 2^22.
2588   auto table_result_fp =
2589       CreateOpAndInfer<CastOp>(rewriter, loc, input_ty, table_result);
2590   auto output_scale = rewriter.create<ConstOp>(
2591       loc, fp_scalar_ty,
2592       DenseElementsAttr::get(
2593           fp_scalar_ty,
2594           {static_cast<float>(1.0 / static_cast<float>(1 << 22))}));
2595   CreateReplaceOpAndInfer<MulOp>(rewriter, op, output_ty, table_result_fp,
2596                                  output_scale, rewriter.getI32IntegerAttr(0));
2597   return success();
2598 }
2599 
2600 LogicalResult ConvertTFLCosOp::matchAndRewrite(
2601     Operation* op, PatternRewriter& rewriter) const {
2602   auto tfl_cos_op = cast<TFL::CosOp>(op);
2603   Value input = tfl_cos_op.x();
2604   RankedTensorType input_ty = input.getType().dyn_cast<RankedTensorType>();
2605   ShapedType output_ty =
2606       tfl_cos_op.getResult().getType().dyn_cast<ShapedType>();
2607 
2608   if (!input_ty || !output_ty) return failure();
2609 
2610   bool input_is_fp = input_ty.getElementType().isa<mlir::FloatType>();
2611   bool output_is_fp = output_ty.getElementType().isa<mlir::FloatType>();
2612 
2613   if (!input_is_fp || !output_is_fp) {
2614     return rewriter.notifyMatchFailure(
2615         op, "ConvertTFLCosOp: input/result must be fp.");
2616   }
2617 
2618   // Replace with the equivalent sin operation:
2619   //   cos(x) = sin(x + π / 2).
2620   auto fp_scalar_ty = RankedTensorType::get({}, rewriter.getF32Type());
2621   auto pi_2 = rewriter.create<ConstOp>(
2622       op->getLoc(), fp_scalar_ty,
2623       DenseElementsAttr::get(fp_scalar_ty, {static_cast<float>(M_PI_2)}));
2624   auto offset = rewriter.create<AddOp>(op->getLoc(), input_ty, input, pi_2);
2625 
2626   CreateReplaceOpAndInfer<TFL::SinOp>(rewriter, op, output_ty, offset);
2627   return success();
2628 }
2629 
2630 LogicalResult ConvertTFLLogisticOp::matchAndRewrite(
2631     Operation* op, PatternRewriter& rewriter) const {
2632   auto tfl_logistic_op = cast<TFL::LogisticOp>(op);
2633 
2634   ShapedType output_type =
2635       tfl_logistic_op.getResult().getType().dyn_cast<ShapedType>();
2636   RankedTensorType input_type =
2637       tfl_logistic_op.x().getType().dyn_cast<RankedTensorType>();
2638   if (!input_type || !output_type) return failure();
2639 
2640   bool input_is_qtype =
2641       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2642   bool output_is_qtype =
2643       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2644 
2645   if (input_is_qtype != output_is_qtype) {
2646     return op->emitOpError(
2647         "ConvertTFLLogisticOp: input/output tensor should "
2648         "be all quantized or all floating-point.");
2649   }
2650 
2651   if (input_is_qtype) {
2652     ShapedType int32_type = output_type.clone(rewriter.getIntegerType(32));
2653     mlir::quant::UniformQuantizedType input_qtype =
2654         input_type.getElementType()
2655             .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
2656     mlir::quant::UniformQuantizedType output_qtype =
2657         output_type.getElementType()
2658             .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
2659 
2660     auto sigmoid_func = [](double x) -> double {
2661       return 1.0 / (1.0 + std::exp(-x));
2662     };
2663 
2664     if (input_qtype.getStorageTypeIntegralWidth() == 8) {
2665       Value table_const = getTosaConst8bitTable(
2666           rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(),
2667           output_qtype.getScale(), output_qtype.getZeroPoint(), sigmoid_func);
2668 
2669       CreateReplaceOpAndInfer<tosa::TableOp>(rewriter, op, output_type,
2670                                              tfl_logistic_op.x(), table_const);
2671     } else {  // int16
2672       if (input_qtype.getZeroPoint() != 0 || output_qtype.getZeroPoint() != 0) {
2673         op->emitOpError(
2674             "ConvertTFLLogistic: input/output zeropoint should be 0 in 16-bit "
2675             "mode");
2676         return failure();
2677       }
2678       double input_min = -32768 * input_qtype.getScale();
2679       double input_max = 32767 * input_qtype.getScale();
2680 
2681       // Generate table with gen_lut() in
2682       // tensorflow/lite/kernels/internal/common.h
2683       Value table_const = getTosaConst16bitTable(rewriter, op, sigmoid_func,
2684                                                  input_min, input_max);
2685 
2686       auto op1_table_in = CreateOpAndInfer<tosa::TableOp>(
2687           rewriter, op->getLoc(), int32_type, tfl_logistic_op.x(), table_const);
2688 
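      // The 16-bit TOSA TABLE result carries 7 extra fractional bits (it is
      // scaled by 128), so rescale by 1 / 128 to return to the output
      // quantization.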
2689       Value op2_rescale_op1 =
2690           buildRescale(rewriter, op, output_type, op1_table_in.getResult(),
2691                        1.0 / 128.0, 0, 0, false, true);
2692 
2693       rewriter.replaceOp(op, {op2_rescale_op1});
2694     }
2695   } else {
2696     CreateReplaceOpAndInfer<tosa::SigmoidOp>(rewriter, op, output_type,
2697                                              tfl_logistic_op.x());
2698   }
2699 
2700   return success();
2701 }
2702 
2703 LogicalResult ConvertTFLTanhOp::matchAndRewrite(
2704     Operation* op, PatternRewriter& rewriter) const {
2705   auto tfl_tanh_op = cast<TFL::TanhOp>(op);
2706   ShapedType output_type =
2707       tfl_tanh_op.getResult().getType().dyn_cast<ShapedType>();
2708   RankedTensorType input_type =
2709       tfl_tanh_op.input().getType().dyn_cast<RankedTensorType>();
2710   if (!input_type || !output_type) return failure();
2711 
2712   bool input_is_qtype =
2713       input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2714   bool output_is_qtype =
2715       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2716 
2717   if (input_is_qtype != output_is_qtype) {
2718     return op->emitOpError(
2719         "ConvertTFLTanhOp: input/output tensor should "
2720         "be all quantized or all floating-point.");
2721   }
2722 
2723   if (input_is_qtype) {
2724     ShapedType int32_type = output_type.clone(rewriter.getIntegerType(32));
2725     mlir::quant::UniformQuantizedType input_qtype =
2726         input_type.getElementType()
2727             .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
2728     mlir::quant::UniformQuantizedType output_qtype =
2729         output_type.getElementType()
2730             .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
2731 
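    // The lambda below uses the identity
    //   tanh(x) = (e^x - e^(-x)) / (e^x + e^(-x)) = (1 - e^(-2x)) / (1 + e^(-2x)),
    // so only a single std::exp evaluation is needed per table entry.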
2732     auto tanh_func = [](double x) -> double {
2733       x = std::exp(-2.0 * x);
2734       return (1.0 - x) / (1.0 + x);
2735     };
2736 
2737     if (input_qtype.getStorageTypeIntegralWidth() == 8) {
2738       Value table_const = getTosaConst8bitTable(
2739           rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(),
2740           output_qtype.getScale(), output_qtype.getZeroPoint(), tanh_func);
2741 
2742       CreateReplaceOpAndInfer<tosa::TableOp>(rewriter, op, output_type,
2743                                              tfl_tanh_op.input(), table_const);
2744     } else {  // int16
2745       if (input_qtype.getZeroPoint() != 0 || output_qtype.getZeroPoint() != 0) {
2746         op->emitOpError(
2747             "ConvertTFLLogistic: input/output zeropoint should be 0 in 16-bit "
2748             "mode");
2749         return failure();
2750       }
2751       double input_min = -32768 * input_qtype.getScale();
2752       double input_max = 32767 * input_qtype.getScale();
2753 
2754       // Generate table with gen_lut() in
2755       // tensorflow/lite/kernels/internal/common.h
2756       Value table_const =
2757           getTosaConst16bitTable(rewriter, op, tanh_func, input_min, input_max);
2758 
2759       auto op1_table_in = CreateOpAndInfer<tosa::TableOp>(
2760           rewriter, op->getLoc(), int32_type, tfl_tanh_op.input(), table_const);
2761 
2762       Value op2_rescale_op1 =
2763           buildRescale(rewriter, op, output_type, op1_table_in.getResult(),
2764                        1.0 / 128.0, 0, 0, false, true);
2765 
2766       rewriter.replaceOp(op, {op2_rescale_op1});
2767     }
2768 
2769   } else {
2770     CreateReplaceOpAndInfer<tosa::TanhOp>(rewriter, op, output_type,
2771                                           tfl_tanh_op.input());
2772   }
2773 
2774   return success();
2775 }
2776 
2777 static LogicalResult LegalizeFloatingPointPrelu(Operation* op,
2778                                                 PatternRewriter& rewriter,
2779                                                 Value input, Value alpha,
2780                                                 ShapedType output_type) {
2781   Value const_zero = getTosaConstTensorSingleF32(rewriter, op, 0.0);
2782 
2783   auto mul = CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
2784                                            input, alpha, 0);
2785 
2786   auto ge = CreateOpAndInfer<tosa::GreaterEqualOp>(
2787       rewriter, op->getLoc(), output_type.clone(rewriter.getIntegerType(1)),
2788       input, const_zero);
2789 
2790   CreateReplaceOpAndInfer<tosa::SelectOp>(rewriter, op, output_type, ge, input,
2791                                           mul);
2792 
2793   return success();
2794 }
2795 
2796 LogicalResult ConvertTFLPReluOp::matchAndRewrite(
2797     Operation* op, PatternRewriter& rewriter) const {
2798   auto tfl_prelu_op = cast<TFL::PReluOp>(op);
2799 
2800   ShapedType input_type = tfl_prelu_op.input().getType().dyn_cast<ShapedType>();
2801   ShapedType output_type =
2802       tfl_prelu_op.getResult().getType().dyn_cast<ShapedType>();
2803   if (!input_type || !output_type)
2804     return rewriter.notifyMatchFailure(op,
2805                                        "input or output is not a ShapedType");
2806 
2807   bool output_is_qtype =
2808       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2809 
2810   // TODO(jennik): Handle the quantized case.
2811   if (output_is_qtype)
2812     return rewriter.notifyMatchFailure(
2813         op, "only floating point currently supported");
2814 
2815   return LegalizeFloatingPointPrelu(op, rewriter, tfl_prelu_op.input(),
2816                                     tfl_prelu_op.alpha(), output_type);
2817 }
2818 
2819 LogicalResult ConvertTFLLeakyReluOp::matchAndRewrite(
2820     Operation* op, PatternRewriter& rewriter) const {
2821   auto tfl_leakyrelu_op = cast<TFL::LeakyReluOp>(op);
2822   RankedTensorType input_type =
2823       tfl_leakyrelu_op.input().getType().dyn_cast<RankedTensorType>();
2824 
2825   ShapedType output_type =
2826       tfl_leakyrelu_op.getResult().getType().dyn_cast<ShapedType>();
2827 
2828   if (!input_type || !output_type) return failure();
2829 
2830   bool output_is_qtype =
2831       output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2832 
2833   // Implement LeakyRelu as element-wise:
2834   //   out = x > 0 ? x : alpha * x
2835   //
2836   // In TOSA ops:
2837   //
2838   //   const_zero = constant(0)
2839   //   op1 = mul(x, alpha)
2840   //   op2 = greater_equal(x, const_zero)
2841   //   out = select(op2, x, op1)
2842   //
2843   // If alpha can be constrained to 0.0 <= alpha <= 1.0, then
2844   // an alternative simpler lowering could be implemented with:
2845   //
2846   //   max(mul(x, alpha), x)
2847   //
2848   // But this alternative is not robust unless alpha meets those constraints.
2849 
2850   FloatAttr tmpAttr = tfl_leakyrelu_op.alphaAttr();
2851   // There is disagreement between the MLIR .td defaults and TF
2852   // documentation on 0.2 vs 0.3, but 0.2 will be used here.
2853   double alpha = 0.2;
2854 
2855   if (tmpAttr) {
2856     alpha = tmpAttr.getValueAsDouble();
2857   }
2858 
2859   if (output_is_qtype) {
2860     // op1 = rescale(input) to int32
2861     // op2 = greater_equal(op1, 0)
2862     // op3 = rescale(input, scale_alpha)    // alpha * x branch
2863     // op4 = rescale(input, scale_identity) // identity (x) branch
2864     // out = select(op2, op4, op3)
2865     // where scale_alpha and scale_identity fold the input/output scales.
2866     ShapedType rescale_type = output_type.clone(rewriter.getI32Type());
2867 
2868     UniformQuantizedType input_qtype =
2869         input_type.getElementType().cast<UniformQuantizedType>();
2870 
2871     UniformQuantizedType output_qtype =
2872         output_type.getElementType().cast<UniformQuantizedType>();
2873 
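    // Sketch of where these factors come from: with in_real =
    // in_scale * (in_q - in_zp) and out_real = out_scale * (out_q - out_zp),
    // the negative branch needs out_real = alpha * in_real and the positive
    // branch needs out_real = in_real, giving the two rescale multipliers
    // alpha * in_scale / out_scale and in_scale / out_scale respectively.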
2874     double scale_alpha =
2875         input_qtype.getScale() * alpha / output_qtype.getScale();
2876     double scale_identity = input_qtype.getScale() / output_qtype.getScale();
2877 
2878     Value op1_rescale_in =
2879         buildRescaleToInt32(rewriter, op, tfl_leakyrelu_op.input(), 1.0,
2880                             input_qtype.getZeroPoint());
2881 
2882     Value const_zero = getTosaConstTensorSingleI32(rewriter, op, 0);
2883     auto op2_ge = CreateOpAndInfer<tosa::GreaterEqualOp>(
2884         rewriter, op->getLoc(), rescale_type.clone(rewriter.getI1Type()),
2885         op1_rescale_in, const_zero);
2886 
2887     Value op3_rescale_alpha_in = buildRescale(
2888         rewriter, op, output_type, tfl_leakyrelu_op.input(), scale_alpha,
2889         input_qtype.getZeroPoint(), output_qtype.getZeroPoint(), true, true);
2890 
2891     Value op4_rescale_identity_in = buildRescale(
2892         rewriter, op, output_type, tfl_leakyrelu_op.input(), scale_identity,
2893         input_qtype.getZeroPoint(), output_qtype.getZeroPoint(), true, true);
2894 
2895     CreateReplaceOpAndInfer<tosa::SelectOp>(rewriter, op, output_type, op2_ge,
2896                                             op4_rescale_identity_in,
2897                                             op3_rescale_alpha_in);
2898 
2899     return success();
2900 
2901   } else {
2902     return LegalizeFloatingPointPrelu(
2903         op, rewriter, tfl_leakyrelu_op.input(),
2904         getTosaConstTensorSingleF32(rewriter, op, alpha), output_type);
2905   }
2906 }
2907 
2908 LogicalResult ConvertTFLNegOp::matchAndRewrite(
2909     Operation* op, PatternRewriter& rewriter) const {
2910   auto tfl_neg_op = cast<TFL::NegOp>(op);
2911   ShapedType output_type =
2912       tfl_neg_op.getResult().getType().dyn_cast<ShapedType>();
2913   if (!output_type) return failure();
2914 
2915   CreateReplaceOpAndInfer<tosa::NegateOp>(rewriter, op, output_type,
2916                                           tfl_neg_op.x());
2917 
2918   return success();
2919 }
2920 
2921 LogicalResult ConvertTFLYieldOp::matchAndRewrite(
2922     Operation* op, PatternRewriter& rewriter) const {
2923   rewriter.replaceOpWithNewOp<tosa::YieldOp>(op, op->getResultTypes(),
2924                                              op->getOperands());
2925 
2926   return success();
2927 }
2928 
2929 LogicalResult ConvertTFLCustomOp::matchAndRewrite(
2930     Operation* op, PatternRewriter& rewriter) const {
2931   auto tfl_custom_op = cast<TFL::CustomOp>(op);
2932   rewriter.replaceOpWithNewOp<tosa::CustomOp>(
2933       op, op->getResultTypes(), tfl_custom_op.custom_code(), op->getOperands());
2934 
2935   return success();
2936 }
2937 
2938 LogicalResult ConvertTFLReverseV2Op::matchAndRewrite(
2939     Operation* op, PatternRewriter& rewriter) const {
2940   auto tfl_reverse_op = cast<TFL::ReverseV2Op>(op);
2941 
2942   RankedTensorType input_type =
2943       tfl_reverse_op.input().getType().dyn_cast<RankedTensorType>();
2944   RankedTensorType output_type =
2945       tfl_reverse_op.getResult().getType().dyn_cast<RankedTensorType>();
2946   if (!input_type || !output_type) return failure();
2947 
2948   ElementsAttr axis_elems;
2949   if (!matchPattern(tfl_reverse_op.axis(), m_Constant(&axis_elems)))
2950     return failure();
2951 
2952   auto input_rank = input_type.getShape().size();
2953   Value val = tfl_reverse_op.input();
2954   if (axis_elems.getNumElements() == 0) {
2955     auto identity_op = CreateOpAndInfer<tosa::IdentityOp>(
2956         rewriter, op->getLoc(), output_type, val);
2957     val = identity_op.getResult();
2958   } else {
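    // Lower a multi-axis reverse into a chain of single-axis tosa.reverse ops,
    // e.g. (hypothetically) axes [0, 1] become reverse(reverse(x, 0), 1).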
2959     for (int i = 0; i < axis_elems.getNumElements(); i++) {
2960       int64_t axis_val = axis_elems.getValues<APInt>()[i].getSExtValue();
2961       if (axis_val < 0) axis_val += input_rank;
2962       auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
2963       auto reverse_op = CreateOpAndInfer<tosa::ReverseOp>(
2964           rewriter, op->getLoc(), output_type, val, axis_attr);
2965 
2966       val = reverse_op.getResult();
2967     }
2968   }
2969 
2970   rewriter.replaceOp(op, {val});
2971 
2972   return success();
2973 }
2974 
2975 LogicalResult ConvertTFLQuantizeOp::matchAndRewrite(
2976     Operation* op, PatternRewriter& rewriter) const {
2977   auto tfl_quantize_op = cast<TFL::QuantizeOp>(op);
2978 
2979   RankedTensorType input_type =
2980       tfl_quantize_op.input().getType().dyn_cast<RankedTensorType>();
2981   ShapedType output_type =
2982       tfl_quantize_op.getResult().getType().dyn_cast<ShapedType>();
2983   if (!input_type || !output_type) return failure();
2984 
2985   ShapedType qtype =
2986       tfl_quantize_op.getResult().getType().dyn_cast<ShapedType>();
2987   if (!qtype) return failure();
2988 
2989   UniformQuantizedType element_type =
2990       qtype.getElementType().dyn_cast<UniformQuantizedType>();
2991   if (!element_type) return failure();
2992 
2993   UniformQuantizedType input_element_type =
2994       input_type.getElementType().dyn_cast<UniformQuantizedType>();
2995 
2996   // If input is already a quantized type, this is basically a RESCALE (or
2997   // tensorflow::ops::Requantize)
2998   if (input_element_type) {
2999     double rescale_scale =
3000         input_element_type.getScale() / element_type.getScale();
3001     Value rescale_op =
3002         buildRescale(rewriter, op, output_type, tfl_quantize_op.input(),
3003                      rescale_scale, input_element_type.getZeroPoint(),
3004                      element_type.getZeroPoint(), true, true);
3005 
3006     rewriter.replaceOp(op, {rescale_op});
3007     return success();
3008   } else {
3009     double scale = 1 / element_type.getScale();
3010     int64_t zp = element_type.getZeroPoint();
3011     int64_t num_bits = element_type.getStorageTypeIntegralWidth();
3012     zp = element_type.isSigned() ? zp : zp - (1 << (num_bits - 1));
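    // The zero point adjustment above is, presumably, mapping an unsigned
    // storage type onto the signed representation TOSA works with: e.g. a
    // uint8 zero point of 128 becomes an int8 zero point of 0 (128 - 2^7).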
3013 
3014     llvm::Optional<Value> result = convertQuantizeOp(
3015         rewriter, op, output_type, tfl_quantize_op.input(), scale, zp);
3016 
3017     if (!result) return failure();
3018 
3019     rewriter.replaceOp(op, {result.getValue()});
3020 
3021     return success();
3022   }
3023 }
3024 
3025 LogicalResult ConvertTFLDequantizeOp::matchAndRewrite(
3026     Operation* op, PatternRewriter& rewriter) const {
3027   auto tfl_dequantize_op = cast<TFL::DequantizeOp>(op);
3028 
3029   ShapedType output_type =
3030       tfl_dequantize_op.getResult().getType().dyn_cast<ShapedType>();
3031   // Not a shaped tensor output
3032   if (!output_type) return failure();
3033 
3034   RankedTensorType qtype =
3035       tfl_dequantize_op.input().getType().dyn_cast<RankedTensorType>();
3036   if (!qtype) return failure();
3037 
3038   Type element_type = qtype.getElementType();
3039   if (element_type.isa<FloatType>()) {
3040     CreateReplaceOpAndInfer<tosa::CastOp>(rewriter, op, output_type,
3041                                           tfl_dequantize_op.input());
3042     return success();
3043   }
3044 
3045   if (auto eq_ty = element_type.dyn_cast<quant::UniformQuantizedType>()) {
3046     double scale = eq_ty.getScale();
3047     int64_t zp = eq_ty.getZeroPoint();
3048     int64_t num_bits = eq_ty.getStorageTypeIntegralWidth();
3049     zp = eq_ty.isSigned() ? zp : zp - (1 << (num_bits - 1));
3050 
3051     llvm::Optional<Value> result = convertDequantizeOp(
3052         rewriter, op, output_type, tfl_dequantize_op.input(), scale, zp, 0);
3053 
3054     if (!result) return failure();
3055 
3056     rewriter.replaceOp(op, {result.getValue()});
3057     return success();
3058   }
3059 
3060   if (quant::UniformQuantizedPerAxisType eq_ty =
3061           element_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
3062     SmallVector<float> zps;
3063     for (auto zp : eq_ty.getZeroPoints()) {
3064       int64_t num_bits = eq_ty.getStorageTypeIntegralWidth();
3065       zps.push_back(eq_ty.isSigned() ? zp : zp - (1 << (num_bits - 1)));
3066     }
3067 
3068     SmallVector<float> scales;
3069     for (auto scale : eq_ty.getScales()) {
3070       scales.push_back(scale);
3071     }
3072 
3073     llvm::Optional<Value> result = convertDequantizeOp(
3074         rewriter, op, output_type, tfl_dequantize_op.input(), scales, zps,
3075         eq_ty.getQuantizedDimension());
3076 
3077     if (!result) return failure();
3078 
3079     rewriter.replaceOp(op, {result.getValue()});
3080     return success();
3081   }
3082 
3083   return failure();
3084 }
3085 
3086 LogicalResult ConvertTFLConstOp::matchAndRewrite(
3087     Operation* op, PatternRewriter& rewriter) const {
3088   auto tfl_const_op = cast<TFL::ConstOp>(op);
3089 
3090   ShapedType output_type =
3091       tfl_const_op.getResult().getType().dyn_cast<ShapedType>();
3092   if (!output_type) return failure();
3093 
3094   ElementsAttr elements = tfl_const_op.value();
3095   Type element_type = elements.getType().getElementType();
3096   if (output_type.getElementType().isa<quant::QuantizedType>()) {
3097     output_type = RankedTensorType::get(output_type.getShape(), element_type);
3098   }
3099 
3100   // If the output shape is unranked, we can extract the result shape from the
3101   // attribute's shape. This occurs because some TFLite folders create constants
3102   // with unranked shapes.
3103   if (!output_type.hasRank()) {
3104     output_type = elements.getType().cast<ShapedType>().clone(element_type);
3105   }
3106 
3107   rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output_type, elements);
3108 
3109   return success();
3110 }
3111 
3112 LogicalResult ConvertTFLQConstOp::matchAndRewrite(
3113     Operation* op, PatternRewriter& rewriter) const {
3114   auto tfl_qconst_op = cast<TFL::QConstOp>(op);
3115 
3116   ShapedType output_type =
3117       tfl_qconst_op.getResult().getType().dyn_cast<ShapedType>();
3118   if (!output_type) return failure();
3119 
3120   ElementsAttr elements = tfl_qconst_op.value();
3121 
3122   // If the output shape is unranked, we can extract the result shape from the
3123   // attribute's shape. This occurs because some TFLite folders create constants
3124   // with unranked shapes.
3125   if (!output_type.hasRank()) {
3126     output_type = elements.getType().cast<ShapedType>().clone(
3127         output_type.getElementType());
3128   }
3129 
3130   rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output_type, elements);
3131 
3132   return success();
3133 }
3134 
3135 LogicalResult ConvertConstantOp::matchAndRewrite(
3136     Operation* op, PatternRewriter& rewriter) const {
3137   auto tfl_const_op = cast<arith::ConstantOp>(op);
3138 
3139   ShapedType output_type =
3140       tfl_const_op.getResult().getType().dyn_cast<ShapedType>();
3141   // Not a shaped tensor output
3142   if (!output_type) return failure();
3143 
3144   ElementsAttr attr = tfl_const_op.getValueAttr().dyn_cast<ElementsAttr>();
3145 
3146   auto e_type = output_type.getElementType();
3147   // TOSA only supports integers up to 48 bits.
3148   // If the source type is wider than that, it is not directly representable,
3149   // so 64-bit data is truncated to 48 bits here.
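  // For instance, an i64 splat of 5 simply becomes an i48 splat of 5, while a
  // value wider than 48 bits keeps only its low 48 bits after APInt::trunc.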
3150   if (e_type.isInteger(64)) {
3151     e_type = rewriter.getIntegerType(48);
3152     attr = attr.cast<DenseIntOrFPElementsAttr>().mapValues(
3153         e_type, [](const APInt& x) -> APInt { return x.trunc(48); });
3154   }
3155 
3156   if (!output_type.hasRank()) {
3157     if (auto attr_type = attr.getType().dyn_cast<ShapedType>()) {
3158       output_type = attr_type.clone(e_type);
3159     }
3160   }
3161 
3162   output_type = output_type.clone(e_type);
3163   rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output_type, attr);
3164 
3165   return success();
3166 }
3167 
3168 LogicalResult ConvertTFLGatherOp::matchAndRewrite(
3169     Operation* op, PatternRewriter& rewriter) const {
3170   auto tfl_gather_op = cast<TFL::GatherOp>(op);
3171 
3172   int32_t axis = tfl_gather_op.axisAttr().getInt();
3173   int32_t batch_dims = 0;
3174   if (auto batch_attr = tfl_gather_op.batch_dimsAttr()) {
3175     batch_dims = static_cast<int32_t>(batch_attr.getInt());
3176   }
3177 
3178   llvm::Optional<Value> result = convertGatherOp(
3179       rewriter, op, tfl_gather_op.getResult(), tfl_gather_op.params(),
3180       tfl_gather_op.indices(), batch_dims, axis);
3181 
3182   if (!result) return failure();
3183 
3184   rewriter.replaceOp(op, {result.getValue()});
3185 
3186   return success();
3187 }
3188 
3189 LogicalResult ConvertTFLGatherNdOp::matchAndRewrite(
3190     Operation* op, PatternRewriter& rewriter) const {
3191   auto tfl_gathernd_op = cast<TFL::GatherNdOp>(op);
3192 
3193   llvm::Optional<Value> result =
3194       convertGatherNdOp(rewriter, op, tfl_gathernd_op.getResult(),
3195                         tfl_gathernd_op.params(), tfl_gathernd_op.indices());
3196 
3197   if (!result) return failure();
3198   rewriter.replaceOp(op, {result.getValue()});
3199 
3200   return success();
3201 }
3202 
3203 LogicalResult ConvertTFLSparseToDenseOp::matchAndRewrite(
3204     Operation* op, PatternRewriter& rewriter) const {
3205   auto tfl_sparse_to_dense_op = cast<TFL::SparseToDenseOp>(op);
3206   auto indices = tfl_sparse_to_dense_op.sparse_indices();
3207   auto values = tfl_sparse_to_dense_op.sparse_values();
3208   auto default_value = tfl_sparse_to_dense_op.default_value();
3209   auto indices_ty = indices.getType().cast<ShapedType>();
3210   auto indices_ety = indices_ty.getElementType();
3211   auto values_ty = values.getType().cast<ShapedType>();
3212   auto result_ty =
3213       tfl_sparse_to_dense_op.getResult().getType().cast<ShapedType>();
3214   auto result_ety = result_ty.getElementType();
3215   auto loc = op->getLoc();
3216 
3217   if (!result_ty.hasStaticShape()) return failure();
3218   auto result_rank = result_ty.getRank();
3219 
3220   // Generate the default-valued tensor that values will be scattered into.
3221   // Note that result_ty needs to be a statically shaped tensor.
3222   ElementsAttr default_value_attr;
3223   if (!matchPattern(default_value, m_Constant(&default_value_attr)))
3224     return failure();
3225 
3226   if (!default_value_attr.isSplat()) return failure();
3227 
3228   ShapedType scatter_ty =
3229       RankedTensorType::get({1, result_ty.getNumElements(), 1}, result_ety);
3230 
3231   Value default_const = rewriter.create<tosa::ConstOp>(
3232       loc, scatter_ty,
3233       DenseElementsAttr::get(scatter_ty,
3234                              default_value_attr.getSplatValue<APInt>().sext(
3235                                  result_ety.getIntOrFloatBitWidth())));
3236 
3237   // Compute the row-major strides used to flatten each N-D index into a linear index.
3238   llvm::SmallVector<int32_t> multiply_constant_ints;
3239   multiply_constant_ints.resize(result_rank, 1);
3240   for (int i = result_rank - 1; i > 0; i--) {
3241     multiply_constant_ints[i - 1] =
3242         result_ty.getDimSize(i) * multiply_constant_ints[i];
3243   }
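  // As a hypothetical example, a result shape of [2, 3, 4] yields multipliers
  // [12, 4, 1], so an index (i, j, k) flattens to 12*i + 4*j + k.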
3244 
3245   indices_ety = rewriter.getI32Type();
3246   indices_ty = RankedTensorType::get(indices_ty.getShape(), indices_ety);
3247   indices = CreateOpAndInfer<tosa::CastOp>(rewriter, loc, indices_ty, indices);
3248 
3249   auto multiply_constant_type =
3250       RankedTensorType::get({result_rank}, indices_ety);
3251   auto multiply_constant_attr = DenseElementsAttr::get(
3252       multiply_constant_type, llvm::makeArrayRef(multiply_constant_ints));
3253   Value multiply_constant = CreateOpAndInfer<tosa::ConstOp>(
3254       rewriter, loc, multiply_constant_type, multiply_constant_attr);
3255 
3256   Value multiply_op = CreateOpAndInfer<tosa::MulOp>(
3257       rewriter, loc, indices_ty, indices, multiply_constant, 0);
3258 
3259   Value reduce_op = CreateOpAndInfer<tosa::ReduceSumOp>(
3260       rewriter, loc, UnrankedTensorType::get(indices_ety), multiply_op,
3261       rewriter.getI64IntegerAttr(1));
3262 
3263   auto values_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
3264       rewriter, loc, UnrankedTensorType::get(result_ety), values,
3265       rewriter.getI64ArrayAttr(
3266           ArrayRef<int64_t>{1, values_ty.getDimSize(0), 1}));
3267 
3268   auto index_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
3269       rewriter, loc, UnrankedTensorType::get(indices_ety), reduce_op,
3270       rewriter.getI64ArrayAttr(ArrayRef<int64_t>{1, indices_ty.getDimSize(0)}));
3271 
3272   auto scatter = CreateOpAndInfer<tosa::ScatterOp>(
3273       rewriter, loc, UnrankedTensorType::get(result_ety), default_const,
3274       index_reshape_op, values_reshape_op);
3275 
3276   CreateReplaceOpAndInfer<tosa::ReshapeOp>(
3277       rewriter, op, result_ty, scatter,
3278       rewriter.getI64ArrayAttr(result_ty.getShape()));
3279 
3280   return success();
3281 }
3282 
3283 LogicalResult ConvertTFLOneHotOp::matchAndRewrite(
3284     Operation* op, PatternRewriter& rewriter) const {
3285   auto tfl_one_hot_op = cast<TFL::OneHotOp>(op);
3286 
3287   ElementsAttr depth_elems;
3288   if (!matchPattern(tfl_one_hot_op.depth(), m_Constant(&depth_elems)))
3289     return failure();
3290   int32_t depth = depth_elems.getValues<APInt>()[0].getSExtValue();
3291 
3292   IntegerAttr axisAttr = tfl_one_hot_op.axisAttr();
3293   int32_t axis = axisAttr.getInt();
3294 
3295   llvm::Optional<Value> result = convertOneHotOp(
3296       rewriter, op, tfl_one_hot_op.getResult(), tfl_one_hot_op.indices(),
3297       tfl_one_hot_op.on_value(), tfl_one_hot_op.off_value(), depth, axis);
3298 
3299   if (!result) return failure();
3300 
3301   rewriter.replaceOp(op, {result.getValue()});
3302 
3303   return success();
3304 }
3305 
3306 LogicalResult ConvertTFLArgMaxOp::matchAndRewrite(
3307     Operation* op, PatternRewriter& rewriter) const {
3308   auto arg_max_op = cast<TFL::ArgMaxOp>(op);
3309 
3310   ElementsAttr dim_elems;
3311   if (!matchPattern(arg_max_op.dim(), m_Constant(&dim_elems))) return failure();
3312 
3313   int32_t dim = dim_elems.getValues<APInt>()[0].getSExtValue();
3314   CreateReplaceOpAndInfer<tosa::ArgMaxOp>(
3315       rewriter, op, arg_max_op.getType(), arg_max_op.input(),
3316       rewriter.getIntegerAttr(rewriter.getI64Type(), dim));
3317 
3318   return success();
3319 }
3320 
3321 LogicalResult ConvertTFLFakeQuantOp::matchAndRewrite(
3322     Operation* op, PatternRewriter& rewriter) const {
3323   auto fakequant_op = cast<TFL::FakeQuantOp>(op);
3324 
3325   ShapedType output_type =
3326       fakequant_op.getResult().getType().dyn_cast<ShapedType>();
3327   // Not a shaped tensor output
3328   if (!output_type) return failure();
3329 
3330   llvm::Optional<Value> result =
3331       convertFakeQuantOp(rewriter, op, output_type, fakequant_op.input(),
3332                          fakequant_op.minAttr().getValueAsDouble(),
3333                          fakequant_op.maxAttr().getValueAsDouble(),
3334                          fakequant_op.num_bitsAttr().getInt(),
3335                          fakequant_op.narrow_rangeAttr().getValue());
3336 
3337   if (!result) return failure();
3338 
3339   rewriter.replaceOp(op, {result.getValue()});
3340 
3341   return success();
3342 }
3343 
3344 LogicalResult LegalizeTFL::initialize(MLIRContext* context) {
3345   RewritePatternSet patterns(context);
3346   mlir::tosa::populateLegalizeTFLPatterns(context, patterns);
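  // The pattern labels registered via addWithLabel (e.g. "TFLConv2D") are what
  // the disabled_patterns_/enabled_patterns_ filters below are matched against.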
3347   frozen_patterns_ = FrozenRewritePatternSet(
3348       std::move(patterns), this->disabled_patterns_, this->enabled_patterns_);
3349   return success();
3350 }
3351 
3352 void LegalizeTFL::runOnOperation() {
3353   if (ApplyPatternsWithShapeResolution(getOperation(), this->frozen_patterns_)
3354           .failed()) {
3355     signalPassFailure();
3356   }
3357 }
3358 
3359 }  // namespace
3360 
3361 void populateLegalizeTFLPatterns(MLIRContext* ctx,
3362                                  RewritePatternSet& patterns) {
3363 #define DEF_PATTERN_INSERT(PAT) \
3364   patterns.addWithLabel<Convert##PAT##Op>({#PAT}, ctx);
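// For reference, DEF_PATTERN_INSERT(TFLAbs) expands to
//   patterns.addWithLabel<ConvertTFLAbsOp>({"TFLAbs"}, ctx);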
3365 
3366   DEF_PATTERN_INSERT(TFLAbs);
3367   DEF_PATTERN_INSERT(TFLCeil);
3368   DEF_PATTERN_INSERT(TFLFloor);
3369   DEF_PATTERN_INSERT(TFLExp);
3370   DEF_PATTERN_INSERT(TFLLog);
3371   DEF_PATTERN_INSERT(TFLRsqrt);
3372   DEF_PATTERN_INSERT(TFLLogicalNot);
3373   DEF_PATTERN_INSERT(TFLCast);
3374 
3375   DEF_PATTERN_INSERT(QuantStat);
3376 
3377   DEF_PATTERN_INSERT(TFLLogicalAnd);
3378   DEF_PATTERN_INSERT(TFLLogicalOr);
3379   DEF_PATTERN_INSERT(TFLPow);
3380 
3381   DEF_PATTERN_INSERT(TFLRelu);
3382   DEF_PATTERN_INSERT(TFLRelu1);
3383   DEF_PATTERN_INSERT(TFLRelu6);
3384   DEF_PATTERN_INSERT(TFLEqual);
3385   DEF_PATTERN_INSERT(TFLNotEqual);
3386   DEF_PATTERN_INSERT(TFLGreater);
3387   DEF_PATTERN_INSERT(TFLGreaterEqual);
3388   DEF_PATTERN_INSERT(TFLAdd);
3389   DEF_PATTERN_INSERT(TFLSub);
3390   DEF_PATTERN_INSERT(TFLMul);
3391   DEF_PATTERN_INSERT(TFLSquare);
3392   DEF_PATTERN_INSERT(TFLSquaredDifference);
3393   DEF_PATTERN_INSERT(TFLRound);
3394   DEF_PATTERN_INSERT(TFLDiv);
3395   DEF_PATTERN_INSERT(TFLMaximum);
3396   DEF_PATTERN_INSERT(TFLMinimum);
3397   DEF_PATTERN_INSERT(TFLFloorMod);
3398   DEF_PATTERN_INSERT(TFLFloorDiv);
3399   DEF_PATTERN_INSERT(TFLAddN);
3400   DEF_PATTERN_INSERT(TFLAveragePool2D);
3401   DEF_PATTERN_INSERT(TFLMaxPool2D);
3402   DEF_PATTERN_INSERT(TFLConcatenation);
3403   DEF_PATTERN_INSERT(TFLReshape);
3404   DEF_PATTERN_INSERT(TFLRank);
3405   DEF_PATTERN_INSERT(TFLShape);
3406   DEF_PATTERN_INSERT(TFLExpandDims);
3407   DEF_PATTERN_INSERT(TFLSqueeze);
3408   DEF_PATTERN_INSERT(TFLFill);
3409   DEF_PATTERN_INSERT(TFLElu);
3410   DEF_PATTERN_INSERT(TFLSoftmax);
3411   DEF_PATTERN_INSERT(TFLLogSoftmax);
3412   DEF_PATTERN_INSERT(TFLSqrt);
3413   DEF_PATTERN_INSERT(TFLL2Normalization);
3414   DEF_PATTERN_INSERT(TFLReduceAny);
3415   DEF_PATTERN_INSERT(TFLReduceMax);
3416   DEF_PATTERN_INSERT(TFLReduceMin);
3417   DEF_PATTERN_INSERT(TFLMean);
3418   DEF_PATTERN_INSERT(TFLReduceProd);
3419   DEF_PATTERN_INSERT(TFLSum);
3420   DEF_PATTERN_INSERT(TFLConv2D);
3421   DEF_PATTERN_INSERT(TFLTransposeConv);
3422   DEF_PATTERN_INSERT(TFLDepthwiseConv2D);
3423   DEF_PATTERN_INSERT(TFLFullyConnected);
3424   DEF_PATTERN_INSERT(TFLBatchMatMul);
3425   DEF_PATTERN_INSERT(TFLSplit);
3426   DEF_PATTERN_INSERT(TFLSplitV);
3427   DEF_PATTERN_INSERT(TFLPack);
3428   DEF_PATTERN_INSERT(TFLUnpack);
3429   DEF_PATTERN_INSERT(TFLTranspose);
3430   DEF_PATTERN_INSERT(TFLTile);
3431   DEF_PATTERN_INSERT(TFLSlice);
3432   DEF_PATTERN_INSERT(TFLStridedSlice);
3433   DEF_PATTERN_INSERT(TFLHardSwish);
3434   DEF_PATTERN_INSERT(TFLZerosLike);
3435   DEF_PATTERN_INSERT(TFLLess);
3436   DEF_PATTERN_INSERT(TFLLessEqual);
3437   DEF_PATTERN_INSERT(TFLPad);
3438   DEF_PATTERN_INSERT(TFLPadV2);
3439   DEF_PATTERN_INSERT(TFLResizeBilinear);
3440   DEF_PATTERN_INSERT(TFLResizeNearestNeighbor);
3441   DEF_PATTERN_INSERT(TFLSelect);
3442   DEF_PATTERN_INSERT(TFLSelectV2);
3443   DEF_PATTERN_INSERT(TFLSpaceToBatchNd);
3444   DEF_PATTERN_INSERT(TFLBatchToSpaceNd);
3445   DEF_PATTERN_INSERT(TFLSpaceToDepth);
3446   DEF_PATTERN_INSERT(TFLDepthToSpace);
3447   DEF_PATTERN_INSERT(TFLSin);
3448   DEF_PATTERN_INSERT(TFLCos);
3449   DEF_PATTERN_INSERT(TFLLogistic);
3450   DEF_PATTERN_INSERT(TFLTanh);
3451   DEF_PATTERN_INSERT(TFLPRelu);
3452   DEF_PATTERN_INSERT(TFLLeakyRelu);
3453   DEF_PATTERN_INSERT(TFLNeg);
3454   DEF_PATTERN_INSERT(TFLYield);
3455   DEF_PATTERN_INSERT(TFLCustom);
3456   DEF_PATTERN_INSERT(TFLReverseV2);
3457   DEF_PATTERN_INSERT(TFLQuantize);
3458   DEF_PATTERN_INSERT(TFLDequantize);
3459   DEF_PATTERN_INSERT(TFLConst);
3460   DEF_PATTERN_INSERT(TFLQConst);
3461   DEF_PATTERN_INSERT(TFLGather);
3462   DEF_PATTERN_INSERT(TFLGatherNd);
3463   DEF_PATTERN_INSERT(TFLSparseToDense);
3464   DEF_PATTERN_INSERT(Constant);
3465   DEF_PATTERN_INSERT(TFLOneHot);
3466   DEF_PATTERN_INSERT(TFLArgMax);
3467   DEF_PATTERN_INSERT(TFLFakeQuant);
3468 }
3469 
3470 // Creates an instance of the TensorFlow Lite dialect LegalizeTFL pass.
3471 std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeTFLPass(
3472     ArrayRef<std::string> disabled_patterns,
3473     ArrayRef<std::string> enabled_patterns) {
3474   return std::make_unique<LegalizeTFL>(disabled_patterns, enabled_patterns);
3475 }
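// A typical (hypothetical) use from a pass pipeline would look like:
//   pm.addNestedPass<mlir::func::FuncOp>(
//       mlir::tosa::createLegalizeTFLPass(/*disabled_patterns=*/{},
//                                         /*enabled_patterns=*/{}));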
3476 
3477 }  // namespace tosa
3478 }  // namespace mlir
3479