1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Legalize TensorFlow Lite to TOSA
17
18 #include <climits>
19 #include <cmath>
20 #include <cstddef>
21 #include <cstdint>
22 #include <fstream>
23 #include <iterator>
24 #include <limits>
25 #include <numeric>
26 #include <string>
27 #include <unordered_set>
28
29 #include "llvm/ADT/ArrayRef.h"
30 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
31 #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
32 #include "mlir/Dialect/Traits.h" // from @llvm-project
33 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
34 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
35 #include "mlir/IR/Matchers.h" // from @llvm-project
36 #include "mlir/IR/PatternMatch.h" // from @llvm-project
37 #include "mlir/Support/LLVM.h" // from @llvm-project
38 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
39 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
40 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
41 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
42 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
43 #include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
44
45 #define PASS_NAME "tosa-legalize-tfl"
46 #define DEBUG_TYPE PASS_NAME
47 #define HARDSWISH_EXPLICIT_RESCALING false
48
49 namespace mlir {
50 namespace tosa {
51 namespace {
52
53 // Performs lowering to TOSA dialect.
54 class LegalizeTFL : public TosaLegalizeTFLPassBase<LegalizeTFL> {
55 public:
56 LegalizeTFL() = default;
LegalizeTFL(ArrayRef<std::string> disabled_patterns,ArrayRef<std::string> enabled_patterns)57 explicit LegalizeTFL(ArrayRef<std::string> disabled_patterns,
58 ArrayRef<std::string> enabled_patterns) {
59 this->disabled_patterns_ = disabled_patterns;
60 this->enabled_patterns_ = enabled_patterns;
61 }
62 void runOnOperation() override;
63 LogicalResult initialize(MLIRContext* context) override;
64
65 private:
66 FrozenRewritePatternSet frozen_patterns_;
67 };
68
69 #include "tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.inc"
70
71 // Input from tfl.conv2d takes 64 bits a bias, while tosa.conv2d expects 48
72 // bits. Need to do a customized truncate here instead of tablegen to handle
73 // attribute with negative value.
74 struct ConvertConstantOp : public RewritePattern {
ConvertConstantOpmlir::tosa::__anon447d85b10111::ConvertConstantOp75 explicit ConvertConstantOp(MLIRContext* context)
76 : RewritePattern(arith::ConstantOp::getOperationName(), 1, context) {}
77 LogicalResult matchAndRewrite(Operation* op,
78 PatternRewriter& rewriter) const override;
79 };
80
81 #define DECL_CONVERT_OP(tfl_op) \
82 struct ConvertTFL##tfl_op##Op : public RewritePattern { \
83 explicit ConvertTFL##tfl_op##Op(MLIRContext* context) \
84 : RewritePattern(TFL::tfl_op##Op::getOperationName(), 1, context) {} \
85 LogicalResult matchAndRewrite(Operation* op, \
86 PatternRewriter& rewriter) const override; \
87 }
88 DECL_CONVERT_OP(Relu);
89 DECL_CONVERT_OP(Relu1);
90 DECL_CONVERT_OP(Relu6);
91 DECL_CONVERT_OP(Equal);
92 DECL_CONVERT_OP(NotEqual);
93 DECL_CONVERT_OP(Greater);
94 DECL_CONVERT_OP(GreaterEqual);
95 DECL_CONVERT_OP(Add);
96 DECL_CONVERT_OP(Sub);
97 DECL_CONVERT_OP(Mul);
98 DECL_CONVERT_OP(Square);
99 DECL_CONVERT_OP(SquaredDifference);
100 DECL_CONVERT_OP(Round);
101 DECL_CONVERT_OP(Div);
102 DECL_CONVERT_OP(Maximum);
103 DECL_CONVERT_OP(Minimum);
104 DECL_CONVERT_OP(FloorMod);
105 DECL_CONVERT_OP(FloorDiv);
106 DECL_CONVERT_OP(AddN);
107 DECL_CONVERT_OP(AveragePool2D);
108 DECL_CONVERT_OP(MaxPool2D);
109 DECL_CONVERT_OP(Concatenation);
110 DECL_CONVERT_OP(Reshape);
111 DECL_CONVERT_OP(Rank);
112 DECL_CONVERT_OP(Shape);
113 DECL_CONVERT_OP(ExpandDims);
114 DECL_CONVERT_OP(Squeeze);
115 DECL_CONVERT_OP(Fill);
116 DECL_CONVERT_OP(Elu);
117 DECL_CONVERT_OP(Softmax);
118 DECL_CONVERT_OP(LogSoftmax);
119 DECL_CONVERT_OP(Sqrt);
120 DECL_CONVERT_OP(L2Normalization);
121 DECL_CONVERT_OP(ReduceAny);
122 DECL_CONVERT_OP(ReduceMax);
123 DECL_CONVERT_OP(ReduceMin);
124 DECL_CONVERT_OP(Mean);
125 DECL_CONVERT_OP(ReduceProd);
126 DECL_CONVERT_OP(Sum);
127 DECL_CONVERT_OP(Conv2D);
128 DECL_CONVERT_OP(TransposeConv);
129 DECL_CONVERT_OP(DepthwiseConv2D);
130 DECL_CONVERT_OP(FullyConnected);
131 DECL_CONVERT_OP(BatchMatMul);
132 DECL_CONVERT_OP(Split);
133 DECL_CONVERT_OP(SplitV);
134 DECL_CONVERT_OP(Pack);
135 DECL_CONVERT_OP(Unpack);
136 DECL_CONVERT_OP(Transpose);
137 DECL_CONVERT_OP(Tile);
138 DECL_CONVERT_OP(Slice);
139 DECL_CONVERT_OP(StridedSlice);
140 DECL_CONVERT_OP(HardSwish);
141 DECL_CONVERT_OP(ZerosLike);
142 DECL_CONVERT_OP(Less);
143 DECL_CONVERT_OP(LessEqual);
144 DECL_CONVERT_OP(Pad);
145 DECL_CONVERT_OP(PadV2);
146 DECL_CONVERT_OP(ResizeBilinear);
147 DECL_CONVERT_OP(ResizeNearestNeighbor);
148 DECL_CONVERT_OP(Select);
149 DECL_CONVERT_OP(SelectV2);
150 DECL_CONVERT_OP(SpaceToBatchNd);
151 DECL_CONVERT_OP(BatchToSpaceNd);
152 DECL_CONVERT_OP(SpaceToDepth);
153 DECL_CONVERT_OP(DepthToSpace);
154 DECL_CONVERT_OP(Sin);
155 DECL_CONVERT_OP(Cos);
156 DECL_CONVERT_OP(Logistic);
157 DECL_CONVERT_OP(Tanh);
158 DECL_CONVERT_OP(PRelu);
159 DECL_CONVERT_OP(LeakyRelu);
160 DECL_CONVERT_OP(Neg);
161 DECL_CONVERT_OP(Yield);
162 DECL_CONVERT_OP(Custom);
163 DECL_CONVERT_OP(ReverseV2);
164 DECL_CONVERT_OP(Quantize);
165 DECL_CONVERT_OP(Dequantize);
166 DECL_CONVERT_OP(Const);
167 DECL_CONVERT_OP(QConst);
168 DECL_CONVERT_OP(Gather);
169 DECL_CONVERT_OP(GatherNd);
170 DECL_CONVERT_OP(SparseToDense);
171 DECL_CONVERT_OP(OneHot);
172 DECL_CONVERT_OP(ArgMax);
173 DECL_CONVERT_OP(FakeQuant);
174
175 #undef DECL_CONVERT_OP
176
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const177 LogicalResult ConvertTFLReluOp::matchAndRewrite(
178 Operation* op, PatternRewriter& rewriter) const {
179 auto tfl_relu_op = cast<TFL::ReluOp>(op);
180
181 ShapedType input_type = tfl_relu_op.x().getType().dyn_cast<ShapedType>();
182 ShapedType output_type =
183 tfl_relu_op.getResult().getType().dyn_cast<ShapedType>();
184 // Not a ranked tensor output
185 if (!input_type || !output_type) return failure();
186
187 bool input_is_qtype =
188 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
189 bool output_is_qtype =
190 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
191
192 if (input_is_qtype != output_is_qtype) {
193 return op->emitOpError(
194 "ConvertTFLReluOp: input/output tensor should "
195 "be all quantized or all floating-point.");
196 }
197
198 int64_t clamp_min = 0;
199 Value clamp_in = tfl_relu_op.x();
200
201 if (output_is_qtype) {
202 UniformQuantizedType input_qtype =
203 input_type.getElementType()
204 .dyn_cast<mlir::quant::UniformQuantizedType>();
205 UniformQuantizedType output_qtype =
206 output_type.getElementType()
207 .dyn_cast<mlir::quant::UniformQuantizedType>();
208
209 clamp_min = output_qtype.getZeroPoint();
210 clamp_in =
211 buildRescale(rewriter, op, output_type, tfl_relu_op.x(),
212 input_qtype.getScale() / output_qtype.getScale(),
213 input_qtype.getZeroPoint(), output_qtype.getZeroPoint(),
214 /*double_round=*/false, /*scale32=*/true);
215 }
216
217 CreateReplaceOpAndInfer<tosa::ClampOp>(
218 rewriter, op, output_type, clamp_in,
219 rewriter.getI64IntegerAttr(clamp_min),
220 rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
221 rewriter.getF32FloatAttr(0.0f),
222 rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
223
224 return success();
225 }
226
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const227 LogicalResult ConvertTFLRelu1Op::matchAndRewrite(
228 Operation* op, PatternRewriter& rewriter) const {
229 auto tfl_relu1_op = cast<TFL::Relu1Op>(op);
230
231 ShapedType input_type = tfl_relu1_op.x().getType().dyn_cast<ShapedType>();
232 ShapedType output_type =
233 tfl_relu1_op.getResult().getType().dyn_cast<ShapedType>();
234 // Not a ranked tensor output
235 if (!input_type || !output_type) return failure();
236
237 bool input_is_qtype =
238 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
239 bool output_is_qtype =
240 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
241
242 if (input_is_qtype != output_is_qtype) {
243 return op->emitOpError(
244 "ConvertTFLRelu1Op: input/output tensor should "
245 "be all quantized or all floating-point.");
246 }
247
248 int64_t clamp_min = -1;
249 int64_t clamp_max = 1;
250 Value clamp_in = tfl_relu1_op.x();
251
252 if (output_is_qtype && input_is_qtype) {
253 UniformQuantizedType input_qtype =
254 input_type.getElementType()
255 .dyn_cast<mlir::quant::UniformQuantizedType>();
256 UniformQuantizedType output_qtype =
257 output_type.getElementType()
258 .dyn_cast<mlir::quant::UniformQuantizedType>();
259
260 clamp_min = output_qtype.getZeroPoint() -
261 std::llround(1.0f / output_qtype.getScale());
262
263 clamp_max = std::llround(1.0f / output_qtype.getScale()) +
264 output_qtype.getZeroPoint();
265
266 clamp_in =
267 buildRescale(rewriter, op, output_type, tfl_relu1_op.x(),
268 input_qtype.getScale() / output_qtype.getScale(),
269 input_qtype.getZeroPoint(), output_qtype.getZeroPoint(),
270 /*double_round=*/false, /*scale32=*/true);
271 }
272
273 CreateReplaceOpAndInfer<tosa::ClampOp>(rewriter, op, output_type, clamp_in,
274 rewriter.getI64IntegerAttr(clamp_min),
275 rewriter.getI64IntegerAttr(clamp_max),
276 rewriter.getF32FloatAttr(-1.0f),
277 rewriter.getF32FloatAttr(1.0f));
278
279 return success();
280 }
281
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const282 LogicalResult ConvertTFLRelu6Op::matchAndRewrite(
283 Operation* op, PatternRewriter& rewriter) const {
284 auto tfl_relu6_op = cast<TFL::Relu6Op>(op);
285
286 ShapedType input_type = tfl_relu6_op.x().getType().dyn_cast<ShapedType>();
287 ShapedType output_type =
288 tfl_relu6_op.getResult().getType().dyn_cast<ShapedType>();
289 // Not a ranked tensor output
290 if (!input_type || !output_type) return failure();
291
292 bool input_is_qtype =
293 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
294 bool output_is_qtype =
295 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
296
297 if (input_is_qtype != output_is_qtype) {
298 return op->emitOpError(
299 "ConvertTFLRelu6Op: input/output tensor should "
300 "be all quantized or all floating-point.");
301 }
302
303 int64_t clamp_min = 0;
304 int64_t clamp_max = 6;
305 Value clamp_in = tfl_relu6_op.x();
306
307 if (output_is_qtype && input_is_qtype) {
308 UniformQuantizedType input_qtype =
309 input_type.getElementType()
310 .dyn_cast<mlir::quant::UniformQuantizedType>();
311 UniformQuantizedType output_qtype =
312 output_type.getElementType()
313 .dyn_cast<mlir::quant::UniformQuantizedType>();
314
315 clamp_min = output_qtype.getZeroPoint();
316 clamp_max = std::llround(6.0f / output_qtype.getScale()) +
317 output_qtype.getZeroPoint();
318
319 clamp_in =
320 buildRescale(rewriter, op, output_type, tfl_relu6_op.x(),
321 input_qtype.getScale() / output_qtype.getScale(),
322 input_qtype.getZeroPoint(), output_qtype.getZeroPoint(),
323 /*double_round=*/false, /*scale32=*/true);
324 }
325
326 CreateReplaceOpAndInfer<tosa::ClampOp>(rewriter, op, output_type, clamp_in,
327 rewriter.getI64IntegerAttr(clamp_min),
328 rewriter.getI64IntegerAttr(clamp_max),
329 rewriter.getF32FloatAttr(0.0f),
330 rewriter.getF32FloatAttr(6.0f));
331
332 return success();
333 }
334
prepareMatchAndRewriteComparison(Operation * op,mlir::OperandRange operands,PatternRewriter & rewriter,llvm::SmallVectorImpl<Value> & newOperands)335 static LogicalResult prepareMatchAndRewriteComparison(
336 Operation* op, mlir::OperandRange operands, PatternRewriter& rewriter,
337 llvm::SmallVectorImpl<Value>& newOperands) {
338 Value x = operands[0];
339 Value y = operands[1];
340 Value result = op->getResult(0);
341
342 ShapedType input_x_type = x.getType().dyn_cast<ShapedType>();
343 ShapedType input_y_type = y.getType().dyn_cast<ShapedType>();
344 ShapedType output_type = result.getType().dyn_cast<ShapedType>();
345 // Not a shaped tensor output
346 if (!input_x_type || !input_y_type || !output_type) return failure();
347
348 bool input_x_is_qtype =
349 input_x_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
350 bool input_y_is_qtype =
351 input_y_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
352 bool output_is_qtype =
353 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
354
355 if (input_x_is_qtype != input_y_is_qtype ||
356 input_y_is_qtype != output_is_qtype) {
357 return op->emitOpError(
358 "ConvertTFLEqualOp: input/output tensor should "
359 "be all quantized or all floating-point.");
360 }
361
362 if (!output_is_qtype && !input_x_is_qtype && !input_y_is_qtype) {
363 newOperands.push_back(x);
364 newOperands.push_back(y);
365 return success();
366 }
367
368 UniformQuantizedType input_x_qtype =
369 input_x_type.getElementType()
370 .dyn_cast<mlir::quant::UniformQuantizedType>();
371 UniformQuantizedType input_y_qtype =
372 input_y_type.getElementType()
373 .dyn_cast<mlir::quant::UniformQuantizedType>();
374
375 if (input_x_qtype.getScale() != input_y_qtype.getScale() ||
376 input_x_qtype.getZeroPoint() != input_y_qtype.getZeroPoint()) {
377 return op->emitOpError(
378 "ConvertTFLEqualOp: input_x and input_y scale/zp "
379 "must be the same");
380 }
381
382 x = buildRescaleToInt32(rewriter, op, x, 1.0f, input_x_qtype.getZeroPoint());
383 y = buildRescaleToInt32(rewriter, op, y, 1.0f, input_y_qtype.getZeroPoint());
384
385 newOperands.push_back(x);
386 newOperands.push_back(y);
387 return success();
388 }
389
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const390 LogicalResult ConvertTFLEqualOp::matchAndRewrite(
391 Operation* op, PatternRewriter& rewriter) const {
392 llvm::SmallVector<Value, 2> newOperands;
393 LogicalResult status = prepareMatchAndRewriteComparison(
394 op, op->getOperands(), rewriter, newOperands);
395 if (status.failed()) return failure();
396
397 CreateReplaceOpAndInfer<tosa::EqualOp>(
398 rewriter, op, op->getResult(0).getType(), newOperands[0], newOperands[1]);
399
400 return success();
401 }
402
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const403 LogicalResult ConvertTFLNotEqualOp::matchAndRewrite(
404 Operation* op, PatternRewriter& rewriter) const {
405 llvm::SmallVector<Value, 2> newOperands;
406 LogicalResult status = prepareMatchAndRewriteComparison(
407 op, op->getOperands(), rewriter, newOperands);
408 if (status.failed()) return failure();
409
410 auto equal_op = CreateOpAndInfer<tosa::EqualOp>(
411 rewriter, op->getLoc(), op->getResult(0).getType(), newOperands[0],
412 newOperands[1]);
413
414 CreateReplaceOpAndInfer<tosa::LogicalNotOp>(
415 rewriter, op, op->getResult(0).getType(), equal_op);
416
417 return success();
418 }
419
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const420 LogicalResult ConvertTFLGreaterOp::matchAndRewrite(
421 Operation* op, PatternRewriter& rewriter) const {
422 llvm::SmallVector<Value, 2> newOperands;
423 LogicalResult status = prepareMatchAndRewriteComparison(
424 op, op->getOperands(), rewriter, newOperands);
425 if (status.failed()) return failure();
426
427 CreateReplaceOpAndInfer<tosa::GreaterOp>(
428 rewriter, op, op->getResult(0).getType(), newOperands[0], newOperands[1]);
429
430 return success();
431 }
432
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const433 LogicalResult ConvertTFLGreaterEqualOp::matchAndRewrite(
434 Operation* op, PatternRewriter& rewriter) const {
435 llvm::SmallVector<Value, 2> newOperands;
436 LogicalResult status = prepareMatchAndRewriteComparison(
437 op, op->getOperands(), rewriter, newOperands);
438 if (status.failed()) return failure();
439
440 CreateReplaceOpAndInfer<tosa::GreaterEqualOp>(
441 rewriter, op, op->getResult(0).getType(), newOperands[0], newOperands[1]);
442
443 return success();
444 }
445
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const446 LogicalResult ConvertTFLLessOp::matchAndRewrite(
447 Operation* op, PatternRewriter& rewriter) const {
448 llvm::SmallVector<Value, 2> newOperands;
449 LogicalResult status = prepareMatchAndRewriteComparison(
450 op, op->getOperands(), rewriter, newOperands);
451 if (status.failed()) return failure();
452
453 CreateReplaceOpAndInfer<tosa::GreaterOp>(
454 rewriter, op, op->getResult(0).getType(), newOperands[1], newOperands[0]);
455 return success();
456 }
457
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const458 LogicalResult ConvertTFLLessEqualOp::matchAndRewrite(
459 Operation* op, PatternRewriter& rewriter) const {
460 llvm::SmallVector<Value, 2> newOperands;
461 LogicalResult status = prepareMatchAndRewriteComparison(
462 op, op->getOperands(), rewriter, newOperands);
463 if (status.failed()) return failure();
464
465 // Swapping the args handles the greater/less difference.
466 CreateReplaceOpAndInfer<tosa::GreaterEqualOp>(
467 rewriter, op, op->getResult(0).getType(), newOperands[1], newOperands[0]);
468
469 return success();
470 }
471
472 template <typename TflOp, typename TosaOp>
matchAndRewriteAddSub(Operation * op,mlir::OperandRange operands,PatternRewriter & rewriter)473 static LogicalResult matchAndRewriteAddSub(Operation* op,
474 mlir::OperandRange operands,
475 PatternRewriter& rewriter) {
476 auto tfl_add_op = cast<TflOp>(op);
477
478 ShapedType input_lhs_type =
479 tfl_add_op.lhs().getType().template dyn_cast<ShapedType>();
480 ShapedType input_rhs_type =
481 tfl_add_op.rhs().getType().template dyn_cast<ShapedType>();
482 ShapedType output_type =
483 tfl_add_op.getResult().getType().template dyn_cast<ShapedType>();
484 // Not a ranked tensor output
485 if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
486
487 bool input_lhs_is_qtype =
488 input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
489 bool input_rhs_is_qtype =
490 input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
491 bool output_is_qtype =
492 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
493
494 if (input_lhs_is_qtype != output_is_qtype ||
495 input_rhs_is_qtype != output_is_qtype) {
496 return op->emitOpError(
497 "ConvertTFLAddOp: input/output tensor should "
498 "be all quantized or all floating-point.");
499 }
500
501 Value output;
502 if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
503 ShapedType rescale_type = output_type.clone(rewriter.getI32Type());
504 UniformQuantizedType input_lhs_qtype =
505 input_lhs_type.getElementType()
506 .dyn_cast<mlir::quant::UniformQuantizedType>();
507 UniformQuantizedType input_rhs_qtype =
508 input_rhs_type.getElementType()
509 .dyn_cast<mlir::quant::UniformQuantizedType>();
510 UniformQuantizedType output_qtype =
511 output_type.getElementType()
512 .dyn_cast<mlir::quant::UniformQuantizedType>();
513
514 // Following quantization described in tensorflow/lite/kernels/add.cc
515 // In details it does:
516 // 1. Rescale inputs to scale = 2.0 x max(lhs.scale, rhs.scale)
517 // 2. Extra left shift to input to increase precision
518 // Where input_shift = 20 if input is 8-bit
519 // input_shift = 15 if input is 16-bit
520 double in_lhs_scale = input_lhs_qtype.getScale();
521 double in_rhs_scale = input_rhs_qtype.getScale();
522 double output_scale = output_qtype.getScale();
523 double max_scale_2x = 2.0 * std::max(in_lhs_scale, in_rhs_scale);
524
525 const int32_t SHIFT_8_BIT = 20;
526 const int32_t SHIFT_16_BIT = 15;
527
528 int32_t input_shift = (output_qtype.getStorageTypeIntegralWidth() == 16)
529 ? SHIFT_16_BIT
530 : SHIFT_8_BIT;
531
532 double lhs_rescale_scale =
533 static_cast<double>(1 << input_shift) * in_lhs_scale / max_scale_2x;
534 double rhs_rescale_scale =
535 static_cast<double>(1 << input_shift) * in_rhs_scale / max_scale_2x;
536 double output_rescale_scale =
537 max_scale_2x / (output_scale * static_cast<double>(1 << input_shift));
538
539 Value op1_rescale_lhs =
540 buildRescaleToInt32(rewriter, op, tfl_add_op.lhs(), lhs_rescale_scale,
541 input_lhs_qtype.getZeroPoint());
542 Value op2_rescale_rhs =
543 buildRescaleToInt32(rewriter, op, tfl_add_op.rhs(), rhs_rescale_scale,
544 input_rhs_qtype.getZeroPoint());
545 auto op3_add_op1_op2 = CreateOpAndInfer<TosaOp>(
546 rewriter, op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
547 Value op4_rescale_op3 = buildRescaleFromInt32(
548 rewriter, op, output_type, op3_add_op1_op2.getResult(),
549 output_rescale_scale, output_qtype.getZeroPoint());
550 output = op4_rescale_op3;
551 } else {
552 auto op1_add_in =
553 CreateOpAndInfer<TosaOp>(rewriter, op->getLoc(), output_type,
554 tfl_add_op.lhs(), tfl_add_op.rhs());
555
556 output = op1_add_in.getResult();
557 }
558
559 auto fused_activation_fn = tfl_add_op.fused_activation_functionAttr();
560
561 if (fused_activation_fn) {
562 llvm::Optional<Value> fused_activation_val =
563 convertFusedActivation(rewriter, op, output, fused_activation_fn);
564
565 if (!fused_activation_val) return failure();
566
567 rewriter.replaceOp(op, {fused_activation_val.getValue()});
568 return success();
569 }
570
571 rewriter.replaceOp(op, {output});
572 return success();
573 }
574
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const575 LogicalResult ConvertTFLAddOp::matchAndRewrite(
576 Operation* op, PatternRewriter& rewriter) const {
577 return matchAndRewriteAddSub<TFL::AddOp, tosa::AddOp>(op, op->getOperands(),
578 rewriter);
579 }
580
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const581 LogicalResult ConvertTFLSubOp::matchAndRewrite(
582 Operation* op, PatternRewriter& rewriter) const {
583 return matchAndRewriteAddSub<TFL::SubOp, tosa::SubOp>(op, op->getOperands(),
584 rewriter);
585 }
586
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const587 LogicalResult ConvertTFLMulOp::matchAndRewrite(
588 Operation* op, PatternRewriter& rewriter) const {
589 auto tfl_mul_op = cast<TFL::MulOp>(op);
590
591 llvm::Optional<Value> result = convertMultiplyOp(
592 rewriter, op, tfl_mul_op.getResult(), tfl_mul_op.lhs(), tfl_mul_op.rhs());
593
594 if (!result) return failure();
595
596 auto fused_activation_fn = tfl_mul_op.fused_activation_functionAttr();
597
598 if (fused_activation_fn) {
599 llvm::Optional<Value> fused_activation_val = convertFusedActivation(
600 rewriter, op, result.getValue(), fused_activation_fn);
601
602 if (!fused_activation_val) return failure();
603
604 rewriter.replaceOp(op, {fused_activation_val.getValue()});
605 return success();
606 }
607
608 rewriter.replaceOp(op, {result.getValue()});
609 return success();
610 }
611
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const612 LogicalResult ConvertTFLSquareOp::matchAndRewrite(
613 Operation* op, PatternRewriter& rewriter) const {
614 auto tfl_square_op = cast<TFL::SquareOp>(op);
615
616 llvm::Optional<Value> result =
617 convertMultiplyOp(rewriter, op, tfl_square_op.getResult(),
618 tfl_square_op.x(), tfl_square_op.x());
619
620 if (!result) return failure();
621
622 rewriter.replaceOp(op, {result.getValue()});
623 return success();
624 }
625
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const626 LogicalResult ConvertTFLSquaredDifferenceOp::matchAndRewrite(
627 Operation* op, PatternRewriter& rewriter) const {
628 auto tfl_squared_op = cast<TFL::SquaredDifferenceOp>(op);
629
630 llvm::Optional<Value> result =
631 convertSquaredDifferenceOp(rewriter, op, tfl_squared_op.getResult(),
632 tfl_squared_op.lhs(), tfl_squared_op.rhs());
633
634 if (!result) return failure();
635
636 rewriter.replaceOp(op, {result.getValue()});
637 return success();
638 }
639
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const640 LogicalResult ConvertTFLRoundOp::matchAndRewrite(
641 Operation* op, PatternRewriter& rewriter) const {
642 auto tfl_round_op = cast<TFL::RoundOp>(op);
643
644 ShapedType input_type = tfl_round_op.x().getType().dyn_cast<ShapedType>();
645 if (!input_type) {
646 return op->emitOpError("Round: input not shaped tensor type");
647 }
648
649 if (input_type.getElementType().isa<FloatType>()) {
650 llvm::Optional<Value> result = convertRoundOp(
651 rewriter, op, tfl_round_op.getResult(), tfl_round_op.x());
652
653 if (!result) return failure();
654
655 rewriter.replaceOp(op, {result.getValue()});
656 return success();
657
658 } else {
659 // Round on int is nonsensical. Instead, replace uses of result with the
660 // input.
661 tfl_round_op.replaceAllUsesWith(tfl_round_op.x());
662 return success();
663 }
664 }
665
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const666 LogicalResult ConvertTFLDivOp::matchAndRewrite(
667 Operation* op, PatternRewriter& rewriter) const {
668 auto tfl_div_op = cast<TFL::DivOp>(op);
669
670 ShapedType output_type =
671 tfl_div_op.getResult().getType().dyn_cast<ShapedType>();
672 // Not a ranked tensor output
673 if (!output_type) return failure();
674
675 auto fused_activation_fn = tfl_div_op.fused_activation_functionAttr();
676
677 Type element_type = output_type.getElementType();
678 Value div_op;
679 if (element_type.isa<IntegerType>()) {
680 div_op = CreateOpAndInfer<tosa::DivOp>(rewriter, op->getLoc(), output_type,
681 tfl_div_op.lhs(), tfl_div_op.rhs())
682 .getResult();
683 } else {
684 auto reciprocal_op = CreateOpAndInfer<tosa::ReciprocalOp>(
685 rewriter, op->getLoc(), tfl_div_op.rhs().getType(), tfl_div_op.rhs());
686 div_op = CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
687 tfl_div_op.lhs(),
688 reciprocal_op.getResult(), 0)
689 .getResult();
690 }
691
692 if (fused_activation_fn) {
693 llvm::Optional<Value> fused_activation_val =
694 convertFusedActivation(rewriter, op, div_op, fused_activation_fn);
695
696 if (!fused_activation_val) return failure();
697
698 rewriter.replaceOp(op, {fused_activation_val.getValue()});
699 return success();
700 }
701
702 rewriter.replaceOp(op, {div_op});
703
704 return success();
705 }
706
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const707 LogicalResult ConvertTFLMaximumOp::matchAndRewrite(
708 Operation* op, PatternRewriter& rewriter) const {
709 auto tfl_max_op = cast<TFL::MaximumOp>(op);
710
711 ShapedType input_lhs_type = tfl_max_op.lhs().getType().dyn_cast<ShapedType>();
712 ShapedType input_rhs_type = tfl_max_op.rhs().getType().dyn_cast<ShapedType>();
713 ShapedType output_type =
714 tfl_max_op.getResult().getType().dyn_cast<ShapedType>();
715
716 // Not a shaped tensor output
717 if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
718
719 bool input_lhs_is_qtype =
720 input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
721 bool input_rhs_is_qtype =
722 input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
723 bool output_is_qtype =
724 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
725
726 if (input_lhs_is_qtype != output_is_qtype ||
727 input_rhs_is_qtype != output_is_qtype) {
728 return op->emitOpError(
729 "ConvertTFLMaximumOp: input/output tensor should "
730 "be all quantized or all floating-point.");
731 }
732
733 Value output;
734 if (output_is_qtype) {
735 ShapedType rescale_type = output_type.clone(rewriter.getI32Type());
736
737 Value op1_rescale_lhs =
738 buildRescaleToInt32(rewriter, op, tfl_max_op.lhs(), 1.0f, 0);
739 Value op2_rescale_rhs =
740 buildRescaleToInt32(rewriter, op, tfl_max_op.rhs(), 1.0f, 0);
741 auto op3_max_op1_op2 = CreateOpAndInfer<tosa::MaximumOp>(
742 rewriter, op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
743 Value op4_rescale_op3 = buildRescaleFromInt32(
744 rewriter, op, output_type, op3_max_op1_op2.getResult(), 1.0f, 0);
745
746 output = op4_rescale_op3;
747 } else {
748 auto op1_max_in =
749 CreateOpAndInfer<tosa::MaximumOp>(rewriter, op->getLoc(), output_type,
750 tfl_max_op.lhs(), tfl_max_op.rhs());
751
752 output = op1_max_in.getResult();
753 }
754
755 rewriter.replaceOp(op, {output});
756
757 return success();
758 }
759
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const760 LogicalResult ConvertTFLMinimumOp::matchAndRewrite(
761 Operation* op, PatternRewriter& rewriter) const {
762 auto tfl_min_op = cast<TFL::MinimumOp>(op);
763
764 ShapedType input_lhs_type = tfl_min_op.lhs().getType().dyn_cast<ShapedType>();
765 ShapedType input_rhs_type = tfl_min_op.rhs().getType().dyn_cast<ShapedType>();
766 ShapedType output_type =
767 tfl_min_op.getResult().getType().dyn_cast<ShapedType>();
768 // Not a shaped tensor output
769 if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
770
771 bool input_lhs_is_qtype =
772 input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
773 bool input_rhs_is_qtype =
774 input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
775 bool output_is_qtype =
776 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
777
778 if (input_lhs_is_qtype != output_is_qtype ||
779 input_rhs_is_qtype != output_is_qtype) {
780 return op->emitOpError(
781 "ConvertTFLMinimumOp: input/output tensor should "
782 "be all quantized or all floating-point.");
783 }
784
785 Value output;
786 if (output_is_qtype) {
787 ShapedType rescale_type = output_type.clone(rewriter.getI32Type());
788
789 Value op1_rescale_lhs =
790 buildRescaleToInt32(rewriter, op, tfl_min_op.lhs(), 1.0f, 0);
791 Value op2_rescale_rhs =
792 buildRescaleToInt32(rewriter, op, tfl_min_op.rhs(), 1.0f, 0);
793 auto op3_min_op1_op2 = CreateOpAndInfer<tosa::MinimumOp>(
794 rewriter, op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
795 Value op4_rescale_op3 = buildRescaleFromInt32(
796 rewriter, op, output_type, op3_min_op1_op2.getResult(), 1.0f, 0);
797
798 output = op4_rescale_op3;
799 } else {
800 auto op1_min_in =
801 CreateOpAndInfer<tosa::MinimumOp>(rewriter, op->getLoc(), output_type,
802 tfl_min_op.lhs(), tfl_min_op.rhs());
803
804 output = op1_min_in.getResult();
805 }
806
807 rewriter.replaceOp(op, {output});
808
809 return success();
810 }
811
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const812 LogicalResult ConvertTFLFloorDivOp::matchAndRewrite(
813 Operation* op, PatternRewriter& rewriter) const {
814 auto tfl_floordiv_op = cast<TFL::FloorDivOp>(op);
815
816 llvm::Optional<Value> result =
817 convertFloorDivOp(rewriter, op, tfl_floordiv_op.getResult(),
818 tfl_floordiv_op.lhs(), tfl_floordiv_op.rhs());
819
820 if (!result) return failure();
821
822 rewriter.replaceOp(op, {result.getValue()});
823
824 return success();
825 }
826
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const827 LogicalResult ConvertTFLFloorModOp::matchAndRewrite(
828 Operation* op, PatternRewriter& rewriter) const {
829 auto tfl_floormod_op = cast<TFL::FloorModOp>(op);
830
831 llvm::Optional<Value> result =
832 convertFloorModOp(rewriter, op, tfl_floormod_op.getResult(),
833 tfl_floormod_op.lhs(), tfl_floormod_op.rhs());
834
835 if (!result) return failure();
836
837 rewriter.replaceOp(op, {result.getValue()});
838
839 return success();
840 }
841
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const842 LogicalResult ConvertTFLAddNOp::matchAndRewrite(
843 Operation* op, PatternRewriter& rewriter) const {
844 auto tfl_addn_op = cast<TFL::AddNOp>(op);
845
846 ShapedType output_type =
847 tfl_addn_op.getResult().getType().dyn_cast<ShapedType>();
848 // Not a shaped output
849 if (!output_type) return failure();
850
851 SmallVector<Value> inputs(tfl_addn_op.inputs());
852
853 assert(inputs.size() >= 2);
854
855 auto newOp = CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(),
856 output_type, inputs[0], inputs[1]);
857 for (int i = 2; i < inputs.size(); i++) {
858 newOp = CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(), output_type,
859 inputs[i], newOp.getResult());
860 }
861
862 rewriter.replaceOp(op, {newOp.getResult()});
863
864 return success();
865 }
866
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const867 LogicalResult ConvertTFLAveragePool2DOp::matchAndRewrite(
868 Operation* op, PatternRewriter& rewriter) const {
869 auto tfl_avgpool_op = cast<TFL::AveragePool2DOp>(op);
870
871 ShapedType input_type =
872 tfl_avgpool_op.input().getType().dyn_cast<ShapedType>();
873 ShapedType output_type =
874 tfl_avgpool_op.getResult().getType().dyn_cast<ShapedType>();
875 // Not a shaped output
876 if (!output_type) return failure();
877
878 // Kernels and strides are dimensionally ordered
879 SmallVector<int64_t, 4> i64array({1, 1, 1, 1});
880 ArrayAttr kernel_size;
881 ArrayAttr stride;
882 ArrayAttr pad;
883 {
884 int64_t kernel_h = tfl_avgpool_op.filter_height();
885 int64_t kernel_w = tfl_avgpool_op.filter_width();
886 kernel_size = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
887 // i64array is formatted as NHWC now
888 i64array[1] = kernel_h;
889 i64array[2] = kernel_w;
890 }
891 {
892 int64_t stride_h = tfl_avgpool_op.stride_h();
893 int64_t stride_w = tfl_avgpool_op.stride_w();
894 stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
895 }
896 {
897 tensorflow::Padding tf_pad;
898 if (!GetPaddingFromString(tfl_avgpool_op.padding().str(), &tf_pad).ok())
899 return failure();
900
901 // Pooling has no non-unit dilation
902 ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});
903
904 RankedTensorType filter_type = RankedTensorType::get(
905 llvm::makeArrayRef(i64array), rewriter.getIntegerType(64));
906
907 // TFLite doesn't support explicit padding
908 if (!getPaddingValuesFromPadType(
909 tf_pad,
910 tensorflow::FORMAT_NHWC, // TFLite only supports this
911 1, // tensorflow::FORMAT_OHWI,
912 input_type, filter_type, stride, dilation, rewriter, pad))
913 return failure();
914 }
915
916 auto average_etype = input_type.getElementType();
917 auto average_type = output_type.clone(average_etype);
918
919 Value result;
920 if (average_etype.isa<quant::UniformQuantizedType>()) {
921 // TensorFlow Lite doesn't use the zero point when calculating
922 // quantized average pool, while TOSA does. Force the TOSA
923 // zero_points to zero to ensure that the calculations match
924
925 auto quant_attr = rewriter.getAttr<tosa::UnaryOpQuantizationAttr>(
926 /*input_zp=*/0, /*output_zp=*/0);
927 result = CreateOpAndInfer<tosa::AvgPool2dOp>(
928 rewriter, op->getLoc(), average_type, tfl_avgpool_op.input(),
929 kernel_size, stride, pad, quant_attr);
930 } else {
931 result = CreateOpAndInfer<tosa::AvgPool2dOp>(
932 rewriter, op->getLoc(), average_type, tfl_avgpool_op.input(),
933 kernel_size, stride, pad);
934 }
935 if (average_type != output_type) {
936 result = CreateOpAndInfer<tosa::CastOp>(rewriter, op->getLoc(), output_type,
937 result);
938 }
939
940 rewriter.replaceOp(op, result);
941 return success();
942 }
943
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const944 LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite(
945 Operation* op, PatternRewriter& rewriter) const {
946 auto tfl_maxpool_op = cast<TFL::MaxPool2DOp>(op);
947
948 ShapedType input_type =
949 tfl_maxpool_op.input().getType().dyn_cast<ShapedType>();
950 ShapedType output_type =
951 tfl_maxpool_op.getResult().getType().dyn_cast<ShapedType>();
952 // Not a shaped type
953 if (!output_type) return failure();
954
955 // Kernels and strides are dimensionally ordered
956 SmallVector<int64_t, 4> i64array({1, 1, 1, 1});
957 ArrayAttr kernel_size;
958 ArrayAttr stride;
959 ArrayAttr pad;
960 {
961 int64_t kernel_h = tfl_maxpool_op.filter_height();
962 int64_t kernel_w = tfl_maxpool_op.filter_width();
963 kernel_size = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
964 // i64array is formatted as NHWC now
965 i64array[1] = kernel_h;
966 i64array[2] = kernel_w;
967 }
968 {
969 int64_t stride_h = tfl_maxpool_op.stride_h();
970 int64_t stride_w = tfl_maxpool_op.stride_w();
971 stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
972 }
973 {
974 tensorflow::Padding tf_pad;
975 if (!GetPaddingFromString(tfl_maxpool_op.padding().str(), &tf_pad).ok())
976 return failure();
977
978 // Pooling has no non-unit dilation
979 ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});
980
981 RankedTensorType filter_type =
982 RankedTensorType::get(i64array, rewriter.getIntegerType(64));
983
984 // TFLite doesn't support explicit padding
985 if (!getPaddingValuesFromPadType(
986 tf_pad,
987 tensorflow::FORMAT_NHWC, // TFLite only supports this
988 1, // tensorflow::FORMAT_OHWI,
989 input_type, filter_type, stride, dilation, rewriter, pad))
990 return failure();
991 }
992
993 CreateReplaceOpAndInfer<tosa::MaxPool2dOp>(rewriter, op, output_type,
994 tfl_maxpool_op.input(),
995 kernel_size, stride, pad);
996 return success();
997 }
998
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const999 LogicalResult ConvertTFLConv2DOp::matchAndRewrite(
1000 Operation* op, PatternRewriter& rewriter) const {
1001 auto tfl_conv2d_op = cast<TFL::Conv2DOp>(op);
1002
1003 RankedTensorType input_type =
1004 tfl_conv2d_op.input().getType().dyn_cast<RankedTensorType>();
1005 RankedTensorType filter_type =
1006 tfl_conv2d_op.filter().getType().dyn_cast<RankedTensorType>();
1007 ShapedType output_type =
1008 tfl_conv2d_op.getResult().getType().dyn_cast<ShapedType>();
1009 // Not a ranked tensor output
1010 if (!input_type) return failure();
1011 if (!output_type) return failure();
1012 if (!filter_type) return failure();
1013
1014 bool input_is_qtype =
1015 input_type.getElementType().isa<mlir::quant::QuantizedType>();
1016 bool filter_is_qtype =
1017 filter_type.getElementType().isa<mlir::quant::QuantizedType>();
1018 bool output_is_qtype =
1019 output_type.getElementType().isa<mlir::quant::QuantizedType>();
1020
1021 if ((input_is_qtype != filter_is_qtype) ||
1022 (input_is_qtype != output_is_qtype)) {
1023 return op->emitOpError(
1024 "ConvertTFLConv2DOp: input/filter/output tensor should "
1025 "be all quantized or all floating-point.");
1026 }
1027
1028 ArrayAttr pad;
1029 ArrayAttr stride;
1030 ArrayAttr dilation;
1031 {
1032 int64_t stride_h = tfl_conv2d_op.stride_h();
1033 int64_t stride_w = tfl_conv2d_op.stride_w();
1034 stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
1035 }
1036 {
1037 int64_t dilation_h = tfl_conv2d_op.dilation_h_factor();
1038 int64_t dilation_w = tfl_conv2d_op.dilation_w_factor();
1039 dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
1040 }
1041 {
1042 tensorflow::Padding tf_pad;
1043 if (!GetPaddingFromString(tfl_conv2d_op.padding().str(), &tf_pad).ok())
1044 return failure();
1045
1046 // TFLite doesn't support explicit padding
1047 if (!getPaddingValuesFromPadType(
1048 tf_pad,
1049 tensorflow::FORMAT_NHWC, // TFLite only supports this
1050 1, // tensorflow::FORMAT_OHWI,
1051 input_type, filter_type, stride, dilation, rewriter, pad))
1052 return failure();
1053 }
1054
1055 Value unquantized_bias = tfl_conv2d_op.bias();
1056 Type bias_ety =
1057 output_is_qtype ? rewriter.getI32Type() : output_type.getElementType();
1058 if (unquantized_bias)
1059 bias_ety = unquantized_bias.getType().cast<ShapedType>().getElementType();
1060
1061 auto a1_conv2d_op = CreateOpAndInfer<tosa::Conv2DOp>(
1062 rewriter, op->getLoc(), output_type.clone(bias_ety),
1063 tfl_conv2d_op.input(), tfl_conv2d_op.filter(), unquantized_bias, pad,
1064 stride, dilation);
1065
1066 Value conv2d_output;
1067 if (input_is_qtype) {
1068 conv2d_output =
1069 buildRescaleOpConvOutput(rewriter, op, a1_conv2d_op.getResult(),
1070 input_type, filter_type, output_type);
1071 } else {
1072 conv2d_output = a1_conv2d_op.getResult();
1073 }
1074
1075 auto fused_activation_fn = tfl_conv2d_op.fused_activation_functionAttr();
1076
1077 if (fused_activation_fn) {
1078 llvm::Optional<Value> fused_activation_val = convertFusedActivation(
1079 rewriter, op, conv2d_output, fused_activation_fn);
1080
1081 if (!fused_activation_val) return failure();
1082
1083 rewriter.replaceOp(op, {fused_activation_val.getValue()});
1084 return success();
1085 }
1086
1087 rewriter.replaceOp(op, {conv2d_output});
1088
1089 return success();
1090 }
1091
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1092 LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite(
1093 Operation* op, PatternRewriter& rewriter) const {
1094 auto tfl_conv_op = cast<TFL::TransposeConvOp>(op);
1095
1096 ShapedType input_type = tfl_conv_op.input().getType().dyn_cast<ShapedType>();
1097 ShapedType filter_type =
1098 tfl_conv_op.weights().getType().dyn_cast<ShapedType>();
1099 ShapedType output_type =
1100 tfl_conv_op.getResult().getType().dyn_cast<ShapedType>();
1101 // Not a ranked tensor output
1102 if (!input_type) return failure();
1103 if (!output_type) return failure();
1104 if (!filter_type) return failure();
1105
1106 bool input_is_qtype =
1107 input_type.getElementType().isa<mlir::quant::QuantizedType>();
1108 bool filter_is_qtype =
1109 filter_type.getElementType().isa<mlir::quant::QuantizedType>();
1110 bool output_is_qtype =
1111 output_type.getElementType().isa<mlir::quant::QuantizedType>();
1112
1113 if ((input_is_qtype != filter_is_qtype) ||
1114 (input_is_qtype != output_is_qtype)) {
1115 return op->emitOpError(
1116 "ConvertTFLConv2DOp: input/filter/output tensor should "
1117 "be all quantized or all floating-point.");
1118 }
1119
1120 ArrayAttr stride;
1121 ArrayAttr outpad;
1122 ArrayAttr output_shape;
1123 {
1124 int64_t stride_h = tfl_conv_op.stride_h();
1125 int64_t stride_w = tfl_conv_op.stride_w();
1126 stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
1127 }
1128
1129 {
1130 tensorflow::Padding tf_pad;
1131 if (!GetPaddingFromString(tfl_conv_op.padding().str(), &tf_pad).ok())
1132 return failure();
1133
1134 if (!getTransposeConv2dPaddingValues(
1135 tf_pad,
1136 tensorflow::FORMAT_NHWC, // TFLite only supports this
1137 1, // tensorflow::FORMAT_OHWI,
1138 input_type, filter_type, output_type, stride, rewriter, outpad))
1139 return failure();
1140 }
1141 {
1142 ElementsAttr output_shape_elems;
1143 // Match from input_size tensor first
1144 if (matchPattern(tfl_conv_op.output_shape(),
1145 m_Constant(&output_shape_elems))) {
1146 SmallVector<int64_t> shape_vec;
1147 for (int i = 0; i < output_shape_elems.getNumElements(); i++)
1148 shape_vec.push_back(
1149 output_shape_elems.getValues<APInt>()[i].getSExtValue());
1150 output_shape = rewriter.getI64ArrayAttr(shape_vec);
1151 } else if (output_type.hasRank()) {
1152 // Use output tensor's shape otherwise
1153 output_shape = rewriter.getI64ArrayAttr(output_type.getShape());
1154 } else {
1155 // TODO(suderman): Figure out rankless shape propagation.
1156 return failure();
1157 }
1158 }
1159
1160 int output_channel = 0;
1161 // TODO(suderman): We need to figure out how to guarantee output channel
1162 // propagation.
1163 if (output_type.hasRank()) {
1164 output_channel = output_type.getDimSize(3);
1165 } else if (filter_type.hasRank()) {
1166 output_channel = filter_type.getDimSize(0);
1167 } else {
1168 return failure();
1169 }
1170
1171 llvm::Optional<Value> zero_bias;
1172 if (input_is_qtype) {
1173 uint32_t input_bits = input_type.getElementType()
1174 .dyn_cast<mlir::quant::QuantizedType>()
1175 .getStorageTypeIntegralWidth();
1176 uint32_t weight_bits = filter_type.getElementType()
1177 .dyn_cast<mlir::quant::QuantizedType>()
1178 .getStorageTypeIntegralWidth();
1179
1180 if (input_bits == 16 && weight_bits == 8) {
1181 SmallVector<APInt> vec(output_channel, APInt(48, 0, true));
1182 zero_bias = getConstTensor<APInt>(rewriter, op, vec, {output_channel});
1183 } else {
1184 SmallVector<int32_t> vec(output_channel, 0);
1185 zero_bias = getConstTensor<int32_t>(rewriter, op, vec, {output_channel});
1186 }
1187 } else {
1188 SmallVector<float> vec(output_channel, 0.0f);
1189 zero_bias = getConstTensor<float>(rewriter, op, vec, {output_channel});
1190 }
1191
1192 if (!zero_bias) return failure();
1193 Type bias_ety = zero_bias->getType().cast<ShapedType>().getElementType();
1194
1195 auto a1_conv2d_op = CreateOpAndInfer<tosa::TransposeConv2DOp>(
1196 rewriter, op->getLoc(), output_type.clone(bias_ety), tfl_conv_op.input(),
1197 tfl_conv_op.weights(), zero_bias.getValue(), outpad, stride,
1198 output_shape);
1199
1200 Value conv2d_output;
1201 if (input_is_qtype) {
1202 conv2d_output =
1203 buildRescaleOpConvOutput(rewriter, op, a1_conv2d_op.getResult(),
1204 input_type, filter_type, output_type);
1205 } else {
1206 conv2d_output = a1_conv2d_op.getResult();
1207 }
1208
1209 rewriter.replaceOp(op, {conv2d_output});
1210
1211 return success();
1212 }
1213
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1214 LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite(
1215 Operation* op, PatternRewriter& rewriter) const {
1216 auto tfl_conv2d_op = cast<TFL::DepthwiseConv2DOp>(op);
1217
1218 ShapedType input_type =
1219 tfl_conv2d_op.input().getType().dyn_cast<ShapedType>();
1220 ShapedType filter_type =
1221 tfl_conv2d_op.filter().getType().dyn_cast<ShapedType>();
1222 ShapedType output_type =
1223 tfl_conv2d_op.getResult().getType().dyn_cast<ShapedType>();
1224 // Not a shaped output
1225 if (!input_type) return failure();
1226 if (!output_type) return failure();
1227 if (!filter_type) return failure();
1228
1229 bool input_is_qtype =
1230 input_type.getElementType().isa<mlir::quant::QuantizedType>();
1231 bool filter_is_qtype =
1232 filter_type.getElementType().isa<mlir::quant::QuantizedType>();
1233 bool output_is_qtype =
1234 output_type.getElementType().isa<mlir::quant::QuantizedType>();
1235
1236 if ((input_is_qtype != filter_is_qtype) ||
1237 (input_is_qtype != output_is_qtype)) {
1238 return op->emitOpError(
1239 "ConvertTFLConv2DOp: input/filter/output tensor should "
1240 "be all quantized or all floating-point.");
1241 }
1242
1243 // We need the filter shape to compute the transpose.
1244 if (!filter_type.hasRank()) return failure();
1245 auto filter_shape = filter_type.getShape();
1246 // Operator depthwiseConv2D
1247 // TFLite orders the depthwiseConv2D filter in IHWO, while TOSA orders
1248 // filter in HWIO
1249 //
1250 // The lowering reorders the filter.
1251 //
1252 // a1_transpose = tosa.transpose(filter, {1, 2, 3, 0}) // HWIO
1253 // a2_reshape = tosa.reshape(filter, H, W, depth_multiplier, I /
1254 // depth_multiplier)
1255 // a3_transpose_conv2d = tosa.transpose_conv2d(input, a2_reshape, padding,
1256 // stride, dilation)
1257
1258 ArrayAttr pad;
1259 ArrayAttr stride;
1260 ArrayAttr dilation;
1261 auto depth_multiplier = tfl_conv2d_op.depth_multiplierAttr();
1262
1263 {
1264 int64_t stride_h = tfl_conv2d_op.stride_h();
1265 int64_t stride_w = tfl_conv2d_op.stride_w();
1266 stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
1267 }
1268 {
1269 int64_t dilation_h = tfl_conv2d_op.dilation_h_factor();
1270 int64_t dilation_w = tfl_conv2d_op.dilation_w_factor();
1271 dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
1272 }
1273 {
1274 tensorflow::Padding tf_pad;
1275 if (!GetPaddingFromString(tfl_conv2d_op.padding().str(), &tf_pad).ok())
1276 return failure();
1277
1278 if (!getPaddingValuesFromPadType(
1279 tf_pad,
1280 tensorflow::FORMAT_NHWC, // TFLite only supports this
1281 1, // tensorflow::FORMAT_OHWI,
1282 input_type, filter_type, stride, dilation, rewriter, pad))
1283 return failure();
1284 }
1285
1286 SmallVector<int64_t, 4> a1_transpose_dims;
1287 a1_transpose_dims.push_back(filter_shape[1]);
1288 a1_transpose_dims.push_back(filter_shape[2]);
1289 a1_transpose_dims.push_back(filter_shape[3]);
1290 a1_transpose_dims.push_back(filter_shape[0]);
1291
1292 SmallVector<int64_t, 4> a2_reshape_dims;
1293 a2_reshape_dims.push_back(a1_transpose_dims[0]);
1294 a2_reshape_dims.push_back(a1_transpose_dims[1]);
1295 a2_reshape_dims.push_back(a1_transpose_dims[2] / depth_multiplier.getInt());
1296 a2_reshape_dims.push_back(depth_multiplier.getInt());
1297
1298 llvm::Optional<Value> a1_filter_transpose_perms = getConstTensor<int32_t>(
1299 rewriter, op, /*vec=*/{1, 2, 3, 0}, /*shape=*/{4});
1300
1301 if (!a1_filter_transpose_perms) return failure();
1302
1303 auto a1_filter_transpose_op = CreateOpAndInfer<tosa::TransposeOp>(
1304 rewriter, op->getLoc(),
1305 RankedTensorType::get(ArrayRef<int64_t>(a1_transpose_dims),
1306 filter_type.getElementType()),
1307 tfl_conv2d_op.filter(), a1_filter_transpose_perms.getValue());
1308
1309 auto a2_filter_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
1310 rewriter, op->getLoc(),
1311 RankedTensorType::get(ArrayRef<int64_t>(a2_reshape_dims),
1312 filter_type.getElementType()),
1313 a1_filter_transpose_op.getResult(),
1314 rewriter.getI64ArrayAttr(a2_reshape_dims));
1315
1316 Value unquantized_bias = tfl_conv2d_op.bias();
1317 Type bias_ety =
1318 output_is_qtype ? rewriter.getI32Type() : output_type.getElementType();
1319 if (unquantized_bias)
1320 bias_ety = unquantized_bias.getType().cast<ShapedType>().getElementType();
1321
1322 auto a3_depthwise_conv2d_op = CreateOpAndInfer<tosa::DepthwiseConv2DOp>(
1323 rewriter, op->getLoc(), output_type.clone(bias_ety),
1324 tfl_conv2d_op.input(), a2_filter_reshape_op.getResult(), unquantized_bias,
1325 pad, stride, dilation);
1326
1327 Value conv2d_output;
1328 if (input_is_qtype) {
1329 conv2d_output = buildRescaleOpConvOutput(
1330 rewriter, op, a3_depthwise_conv2d_op.getResult(), input_type,
1331 filter_type, output_type);
1332 } else {
1333 conv2d_output = a3_depthwise_conv2d_op.getResult();
1334 }
1335
1336 auto fused_activation_fn = tfl_conv2d_op.fused_activation_functionAttr();
1337
1338 if (fused_activation_fn) {
1339 llvm::Optional<Value> fused_activation_val = convertFusedActivation(
1340 rewriter, op, conv2d_output, fused_activation_fn);
1341
1342 if (!fused_activation_val) return failure();
1343
1344 rewriter.replaceOp(op, {fused_activation_val.getValue()});
1345 return success();
1346 }
1347
1348 rewriter.replaceOp(op, {conv2d_output});
1349
1350 return success();
1351 }
1352
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1353 LogicalResult ConvertTFLBatchMatMulOp::matchAndRewrite(
1354 Operation* op, PatternRewriter& rewriter) const {
1355 auto tfl_mm_op = cast<TFL::BatchMatMulOp>(op);
1356 auto result_ty = tfl_mm_op.getType().cast<ShapedType>();
1357 Value lhs = tfl_mm_op.x();
1358 Value rhs = tfl_mm_op.y();
1359 RankedTensorType lhs_ty = lhs.getType().dyn_cast<RankedTensorType>();
1360 RankedTensorType rhs_ty = rhs.getType().dyn_cast<RankedTensorType>();
1361 bool transpose_lhs = tfl_mm_op.adj_x();
1362 bool transpose_rhs = tfl_mm_op.adj_y();
1363
1364 if (!lhs_ty || !rhs_ty) return failure();
1365
1366 bool lhs_is_qtype = lhs_ty.getElementType().isa<mlir::quant::QuantizedType>();
1367 bool rhs_is_qtype = rhs_ty.getElementType().isa<mlir::quant::QuantizedType>();
1368 bool result_is_qtype =
1369 result_ty.getElementType().isa<mlir::quant::QuantizedType>();
1370
1371 if ((lhs_is_qtype != rhs_is_qtype) || (lhs_is_qtype != result_is_qtype)) {
1372 return op->emitOpError(
1373 "ConvertTFLBatchMatMulOp: lhs/rhs/output tensor should "
1374 "be all quantized or all floating-point.");
1375 }
1376
1377 auto batch_dims = lhs_ty.getShape().drop_back(2);
1378 if (batch_dims.size() > 1) {
1379 int64_t N = 1;
1380 for (auto d : batch_dims) {
1381 N = N < 0 || d < 0 ? -1 : N * d;
1382 }
1383
1384 llvm::SmallVector<int64_t> new_lhs_shape{N};
1385 llvm::SmallVector<int64_t> new_rhs_shape{N};
1386 auto lhs_shape_end = lhs_ty.getShape().take_back(2);
1387 auto rhs_shape_end = rhs_ty.getShape().take_back(2);
1388
1389 new_lhs_shape.append(lhs_shape_end.begin(), lhs_shape_end.end());
1390 new_rhs_shape.append(rhs_shape_end.begin(), rhs_shape_end.end());
1391
1392 lhs = CreateOpAndInfer<tosa::ReshapeOp>(
1393 rewriter, op->getLoc(),
1394 UnrankedTensorType::get(lhs_ty.getElementType()), lhs,
1395 rewriter.getI64ArrayAttr(new_lhs_shape));
1396 rhs = CreateOpAndInfer<tosa::ReshapeOp>(
1397 rewriter, op->getLoc(),
1398 UnrankedTensorType::get(rhs_ty.getElementType()), rhs,
1399 rewriter.getI64ArrayAttr(new_rhs_shape));
1400 lhs_ty = lhs.getType().cast<RankedTensorType>();
1401 rhs_ty = rhs.getType().cast<RankedTensorType>();
1402 }
1403
1404 if (transpose_lhs) {
1405 Value perms =
1406 getConstTensor<int32_t>(rewriter, op, /*vec=*/{0, 2, 1}, /*shape=*/{3})
1407 .getValue();
1408 Type output_type = UnrankedTensorType::get(lhs_ty.getElementType());
1409 lhs = CreateOpAndInfer<tosa::TransposeOp>(rewriter, op->getLoc(),
1410 output_type, lhs, perms)
1411 .getResult();
1412 }
1413
1414 if (transpose_rhs) {
1415 Value perms =
1416 getConstTensor<int32_t>(rewriter, op, /*vec=*/{0, 2, 1}, /*shape=*/{3})
1417 .getValue();
1418 Type output_type = UnrankedTensorType::get(rhs_ty.getElementType());
1419 rhs = CreateOpAndInfer<tosa::TransposeOp>(rewriter, op->getLoc(),
1420 output_type, rhs, perms)
1421 .getResult();
1422 }
1423
1424 auto matmul =
1425 CreateOpAndInfer<tosa::MatMulOp>(
1426 rewriter, op->getLoc(),
1427 UnrankedTensorType::get(result_ty.getElementType()), lhs, rhs)
1428 .getResult();
1429
1430 // Conditionally reshape rank back to expected rank.
1431 auto matmul_ty = matmul.getType().cast<RankedTensorType>();
1432 if (batch_dims.size() > 1) {
1433 llvm::SmallVector<int64_t> new_shape{};
1434 for (auto d : batch_dims) {
1435 new_shape.push_back(d);
1436 }
1437
1438 for (auto d : matmul_ty.getShape().take_back(2)) {
1439 new_shape.push_back(d);
1440 }
1441
1442 matmul = CreateOpAndInfer<tosa::ReshapeOp>(
1443 rewriter, op->getLoc(),
1444 UnrankedTensorType::get(matmul_ty.getElementType()), matmul,
1445 rewriter.getI64ArrayAttr(new_shape));
1446 }
1447
1448 if (lhs_is_qtype) {
1449 matmul = buildRescaleOpConvOutput(rewriter, op, matmul, lhs_ty, rhs_ty,
1450 result_ty);
1451 }
1452
1453 rewriter.replaceOp(op, matmul);
1454
1455 return success();
1456 }
1457
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1458 LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite(
1459 Operation* op, PatternRewriter& rewriter) const {
1460 auto tfl_fc_op = cast<TFL::FullyConnectedOp>(op);
1461
1462 ShapedType output_type =
1463 tfl_fc_op.getResult(0).getType().dyn_cast<ShapedType>();
1464 // Not a ranked tensor output
1465 if (!output_type) return failure();
1466
1467 RankedTensorType input_type =
1468 tfl_fc_op.input().getType().dyn_cast<RankedTensorType>();
1469 RankedTensorType filter_type =
1470 tfl_fc_op.filter().getType().dyn_cast<RankedTensorType>();
1471 RankedTensorType bias_type =
1472 tfl_fc_op.bias().getType().dyn_cast<RankedTensorType>();
1473 if (!input_type || !filter_type) return failure();
1474
1475 bool input_is_qtype =
1476 input_type.getElementType().isa<mlir::quant::QuantizedType>();
1477 bool filter_is_qtype =
1478 filter_type.getElementType().isa<mlir::quant::QuantizedType>();
1479 bool output_is_qtype =
1480 output_type.getElementType().isa<mlir::quant::QuantizedType>();
1481
1482 if ((input_is_qtype != filter_is_qtype) ||
1483 (input_is_qtype != output_is_qtype)) {
1484 return op->emitOpError(
1485 "ConvertTFLFullyConnectedOp: input/filter/output tensor should "
1486 "be all quantized or all floating-point.");
1487 }
1488
1489 Value input_val = tfl_fc_op.input();
1490
1491 // tfl.fully_connected() can takes various dimension tensor as input
1492 // need to reshape it to rank 2 tensor, which tosa.fully_connected only
1493 // supports if input tensor is rank 4. It's not always reshaping to (dim[0] *
1494 // dim[1], dim[2] * dim[3]).
1495
1496 // In some networks it's reshaping to (dim[0], dim[1] * dim[2] * dim[3]) so a
1497 // more general way to determine the reshape's shape is by looking at filter's
1498 // shape[1].
1499 if (input_type.getRank() != 2) {
1500 int64_t num_elems = filter_type.getShape()[1];
1501 int64_t num_batch = input_type.getNumElements() / num_elems;
1502 SmallVector<int64_t, 2> shape_vals({num_batch, num_elems});
1503
1504 RankedTensorType reshape_type =
1505 RankedTensorType::get(shape_vals, input_type.getElementType());
1506 auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
1507 rewriter, op->getLoc(), reshape_type, tfl_fc_op.input(),
1508 rewriter.getI64ArrayAttr(shape_vals));
1509
1510 input_val = reshape_op.getResult();
1511 }
1512
1513 Value bias_val;
1514 if (!bias_type) {
1515 // For some matmuls, the bias may actually be a "UnitType" which has no
1516 // value. TOSA requires bias to be an array of output_channel_count values,
1517 // so create a constant of the appropriate number and type of zeros.
1518 SmallVector<int64_t, 1> bias_shape({filter_type.getShape()[0]});
1519 RankedTensorType new_bias_type;
1520
1521 DenseElementsAttr bias_attr;
1522 if (input_type.getElementType().isa<FloatType>()) {
1523 SmallVector<float> bias_arr(bias_shape[0]);
1524
1525 for (int i = 0; i < bias_shape[0]; i++) {
1526 bias_arr[i] = 0.0;
1527 }
1528 new_bias_type =
1529 RankedTensorType::get(bias_shape, input_type.getElementType());
1530 bias_attr =
1531 DenseElementsAttr::get(new_bias_type, llvm::makeArrayRef(bias_arr));
1532 } else {
1533 SmallVector<int32_t> bias_arr(bias_shape[0]);
1534
1535 for (int i = 0; i < bias_shape[0]; i++) {
1536 bias_arr[i] = 0;
1537 }
1538 if (!input_is_qtype) {
1539 return op->emitOpError(
1540 "ConvertTFLFullyConnectedOp: input must be quantized type if it's "
1541 "not float type.");
1542 }
1543 auto input_qtype =
1544 input_type.getElementType().cast<mlir::quant::QuantizedType>();
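      // Note: the width chosen below follows the TOSA accumulator rules:
      // int16 activations accumulate into a 48-bit result and int8
      // activations into 32 bits, so the synthesized zero bias must match
      // that accumulator width.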
1545 Type new_bias_ety = input_qtype.getStorageTypeIntegralWidth() == 16
1546 ? rewriter.getIntegerType(48)
1547 : rewriter.getI32Type();
1548 new_bias_type = RankedTensorType::get(bias_shape, new_bias_ety);
1549 bias_attr =
1550 DenseElementsAttr::get(new_bias_type, llvm::makeArrayRef(bias_arr));
1551 }
1552 auto bias_op = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
1553 new_bias_type, bias_attr);
1554 bias_val = bias_op.getResult();
1555 bias_type = new_bias_type;
1556 } else {
1557 bias_val = tfl_fc_op.bias();
1558 }
1559
1560 Type bias_ety = bias_val.getType().cast<ShapedType>().getElementType();
1561
1562 auto fc_op = CreateOpAndInfer<tosa::FullyConnectedOp>(
1563 rewriter, op->getLoc(), UnrankedTensorType::get(bias_ety), input_val,
1564 tfl_fc_op.filter(), bias_val);
1565
1566 Value fc_output;
1567 if (input_is_qtype) {
1568 fc_output = buildRescaleOpConvOutput(
1569 rewriter, op, fc_op.getResult(), input_type, filter_type,
1570 UnrankedTensorType::get(output_type.getElementType()));
1571 } else {
1572 fc_output = fc_op.getResult();
1573 }
1574
1575 // If we know the output rank, we need to ensure the output shape is correct.
1576 ShapedType fc_type = fc_output.getType().cast<ShapedType>();
1577 if (output_type.hasRank()) {
1578 llvm::SmallVector<int64_t> output_shape;
1579
1580 fc_output = CreateOpAndInfer<tosa::ReshapeOp>(
1581 rewriter, op->getLoc(),
1582 UnrankedTensorType::get(fc_type.getElementType()), fc_output,
1583 rewriter.getI64ArrayAttr(output_type.getShape()));
1584 }
1585
1586 auto fused_activation_fn = tfl_fc_op.fused_activation_functionAttr();
1587
1588 if (fused_activation_fn) {
1589 llvm::Optional<Value> fused_activation_val =
1590 convertFusedActivation(rewriter, op, fc_output, fused_activation_fn);
1591
1592 if (!fused_activation_val) return failure();
1593
1594 rewriter.replaceOp(op, {fused_activation_val.getValue()});
1595 return success();
1596 }
1597
1598 rewriter.replaceOp(op, {fc_output});
1599
1600 return success();
1601 }
1602
1603 LogicalResult ConvertTFLConcatenationOp::matchAndRewrite(
1604 Operation* op, PatternRewriter& rewriter) const {
1605 auto tfl_concat_op = cast<TFL::ConcatenationOp>(op);
1606
1607 SmallVector<Value> values(tfl_concat_op.values());
1608
1609 IntegerAttr axis_attr;
1610 {
1611 auto tmpAttr = tfl_concat_op.axisAttr();
1612 if (!tmpAttr) {
1613 tmpAttr = rewriter.getI64IntegerAttr(0);
1614 }
1615 axis_attr = tmpAttr;
1616 }
1617 int32_t axis = axis_attr.getInt();
1618
1619 llvm::Optional<Value> result =
1620 convertConcatV2Op(rewriter, op, tfl_concat_op.getResult(), values, axis);
1621
1622 if (!result) return failure();
1623
1624 rewriter.replaceOp(op, {result.getValue()});
1625 return success();
1626 }
1627
1628 LogicalResult ConvertTFLReshapeOp::matchAndRewrite(
1629 Operation* op, PatternRewriter& rewriter) const {
1630 auto tfl_reshape_op = cast<TFL::ReshapeOp>(op);
1631
1632 ShapedType output_type =
1633 tfl_reshape_op.getResult().getType().dyn_cast<ShapedType>();
1634 // Not a shaped tensor output
1635 if (!output_type) return failure();
1636
1637 SmallVector<int64_t> shape_vals;
1638
1639 // Either the output type needs to be ranked or we need a constant input
1640 // to compute the output rank.
1641 ElementsAttr shape_attr;
1642 if (!matchPattern(tfl_reshape_op.shape(), m_Constant(&shape_attr))) {
1643 if (!output_type.hasRank()) return failure();
1644 shape_vals.resize(output_type.getRank(), -1);
1645 } else {
1646 for (auto dim : shape_attr.getValues<int32_t>()) shape_vals.push_back(dim);
1647 }
1648
1649 // Propagate the agreement between the output shape and constant value.
1650 if (output_type.hasRank()) {
1651 if (output_type.getRank() != shape_vals.size()) return failure();
1652 for (int i = 0; i < output_type.getRank(); i++) {
1653 if (shape_vals[i] == -1) shape_vals[i] = output_type.getDimSize(i);
1654 }
1655 }
1656
1657 // We cannot handle more than 1 dynamic dimension.
1658 int64_t dynamic_count = 0;
1659 for (auto val : shape_vals)
1660 if (val == -1) dynamic_count++;
1661 if (dynamic_count > 1) return failure();
1662
1663 ArrayAttr new_shape_attr = rewriter.getI64ArrayAttr(shape_vals);
1664 CreateReplaceOpAndInfer<tosa::ReshapeOp>(
1665 rewriter, op, output_type, tfl_reshape_op.input(), new_shape_attr);
1666 return success();
1667 }
1668
1669 LogicalResult ConvertTFLRankOp::matchAndRewrite(
1670 Operation* op, PatternRewriter& rewriter) const {
1671 auto tfl_rank_op = cast<TFL::RankOp>(op);
1672
1673 RankedTensorType input_type =
1674 tfl_rank_op.input().getType().dyn_cast<RankedTensorType>();
1675 if (!input_type) return failure();
1676
1677 int32_t rank = input_type.getRank();
1678
1679 RankedTensorType rank_type =
1680 RankedTensorType::get({1}, rewriter.getIntegerType(32));
1681 auto rank_attr = DenseElementsAttr::get(rank_type, {rank});
1682 auto rank_const = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
1683 rank_type, rank_attr);
1684
1685 rewriter.replaceOp(op, {rank_const.getResult()});
1686
1687 return success();
1688 }
1689
1690 LogicalResult ConvertTFLShapeOp::matchAndRewrite(
1691 Operation* op, PatternRewriter& rewriter) const {
1692 auto tfl_shape_op = cast<TFL::ShapeOp>(op);
1693
1694 RankedTensorType output_type =
1695 tfl_shape_op.getResult().getType().dyn_cast<RankedTensorType>();
1696 // Not a ranked tensor output
1697 if (!output_type) return failure();
1698
1699 RankedTensorType input_type =
1700 tfl_shape_op.input().getType().dyn_cast<RankedTensorType>();
1701 if (!input_type || !input_type.hasStaticShape())
1702 return rewriter.notifyMatchFailure(op, "input shape not static");
1703
1704 auto input_shape = input_type.getShape();
1705
1706 SmallVector<int32_t> shape_arr;
1707 for (int i = 0; i < input_shape.size(); i++) {
1708 shape_arr.emplace_back(input_shape[i]);
1709 }
1710
1711 RankedTensorType shape_type = RankedTensorType::get(
1712 {static_cast<int32_t>(shape_arr.size())}, rewriter.getIntegerType(32));
1713 auto shape_attr =
1714 DenseElementsAttr::get(shape_type, llvm::makeArrayRef(shape_arr));
1715 auto shape_const = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
1716 shape_type, shape_attr);
1717
1718 rewriter.replaceOp(op, {shape_const.getResult()});
1719
1720 return success();
1721 }
1722
1723 LogicalResult ConvertTFLExpandDimsOp::matchAndRewrite(
1724 Operation* op, PatternRewriter& rewriter) const {
1725 auto tfl_expanddims_op = cast<TFL::ExpandDimsOp>(op);
1726
1727 llvm::Optional<Value> result =
1728 convertExpandDimsOp(rewriter, op, tfl_expanddims_op.getResult(),
1729 tfl_expanddims_op.input(), tfl_expanddims_op.dim());
1730
1731 if (!result) return failure();
1732
1733 rewriter.replaceOp(op, {result.getValue()});
1734
1735 return success();
1736 }
1737
1738 LogicalResult ConvertTFLSqueezeOp::matchAndRewrite(
1739 Operation* op, PatternRewriter& rewriter) const {
1740 auto tfl_squeeze_op = cast<TFL::SqueezeOp>(op);
1741
1742 // Copy squeeze_dims into int32_t array
1743 auto squeeze_dims_attr = tfl_squeeze_op.squeeze_dimsAttr();
1744 SmallVector<int32_t> squeeze_dims;
1745 for (auto& squeeze_dim : squeeze_dims_attr) {
1746 squeeze_dims.emplace_back(squeeze_dim.dyn_cast<IntegerAttr>().getInt());
1747 }
1748
1749 llvm::Optional<Value> result =
1750 convertSqueezeOp(rewriter, op, tfl_squeeze_op.getResult(),
1751 tfl_squeeze_op.input(), squeeze_dims);
1752
1753 if (!result) return failure();
1754
1755 rewriter.replaceOp(op, {result.getValue()});
1756
1757 return success();
1758 }
1759
1760 LogicalResult ConvertTFLFillOp::matchAndRewrite(
1761 Operation* op, PatternRewriter& rewriter) const {
1762 auto tfl_fill_op = cast<TFL::FillOp>(op);
1763
1764 RankedTensorType output_type =
1765 tfl_fill_op.getResult().getType().dyn_cast<RankedTensorType>();
1766 // Not a ranked tensor output
1767 if (!output_type) return failure();
1768
1769 ElementsAttr dims_elems;
1770 if (!matchPattern(tfl_fill_op.dims(), m_Constant(&dims_elems)))
1771 return failure();
1772 SmallVector<int64_t> dims_vals;
1773 uint32_t total_size = 1;
1774 for (int i = 0; i < dims_elems.getNumElements(); i++) {
1775 dims_vals.push_back(dims_elems.getValues<APInt>()[i].getSExtValue());
1776 total_size *= dims_vals[i];
1777 }
1778
1779 ElementsAttr value_elem;
1780 if (!matchPattern(tfl_fill_op.input(), m_Constant(&value_elem)))
1781 return failure();
1782
1783 RankedTensorType fill_type = RankedTensorType::get(
1784 ArrayRef<int64_t>(dims_vals), value_elem.getType().getElementType());
1785 DenseElementsAttr fill_attr;
1786
1787 // Convert to a compatible zero type.
1788 if (value_elem.getType().getElementType().isa<FloatType>()) {
1789 SmallVector<float> fill_arr(
1790 total_size, value_elem.getValues<APFloat>()[0].convertToFloat());
1791 fill_attr = DenseElementsAttr::get(fill_type, llvm::makeArrayRef(fill_arr));
1792 } else {
1793 SmallVector<int32_t> fill_arr(
1794 total_size, value_elem.getValues<APInt>()[0].getLimitedValue());
1795 fill_attr = DenseElementsAttr::get(fill_type, llvm::makeArrayRef(fill_arr));
1796 }
1797 auto fill_const_op = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
1798 fill_type, fill_attr);
1799 rewriter.replaceOp(op, {fill_const_op.getResult()});
1800
1801 return success();
1802 }
1803
1804 LogicalResult ConvertTFLReduceAnyOp::matchAndRewrite(
1805 Operation* op, PatternRewriter& rewriter) const {
1806 auto tfl_any_op = cast<TFL::ReduceAnyOp>(op);
1807
1808 RankedTensorType output_type =
1809 tfl_any_op.getResult().getType().dyn_cast<RankedTensorType>();
1810 if (!output_type) return failure();
1811
1812 ElementsAttr axes_elems;
1813 if (!matchPattern(tfl_any_op.reduction_indices(), m_Constant(&axes_elems)))
1814 return failure();
1815
1816 llvm::Optional<Value> result = convertReduceAnyOp(
1817 rewriter, op, output_type, tfl_any_op.input(), axes_elems);
1818
1819 if (!result) return failure();
1820
1821 rewriter.replaceOp(op, {result.getValue()});
1822
1823 return success();
1824 }
1825
1826 LogicalResult ConvertTFLReduceMaxOp::matchAndRewrite(
1827 Operation* op, PatternRewriter& rewriter) const {
1828 auto tfl_max_op = cast<TFL::ReduceMaxOp>(op);
1829
1830 RankedTensorType output_type =
1831 tfl_max_op.getResult().getType().dyn_cast<RankedTensorType>();
1832 if (!output_type) return failure();
1833
1834 ElementsAttr axes_elems;
1835 if (!matchPattern(tfl_max_op.axes(), m_Constant(&axes_elems)))
1836 return failure();
1837
1838 llvm::Optional<Value> result = convertReduceMaxOp(
1839 rewriter, op, output_type, tfl_max_op.input(), axes_elems);
1840
1841 if (!result) return failure();
1842
1843 rewriter.replaceOp(op, {result.getValue()});
1844
1845 return success();
1846 }
1847
1848 LogicalResult ConvertTFLReduceMinOp::matchAndRewrite(
1849 Operation* op, PatternRewriter& rewriter) const {
1850 auto tfl_min_op = cast<TFL::ReduceMinOp>(op);
1851
1852 RankedTensorType output_type =
1853 tfl_min_op.getResult().getType().dyn_cast<RankedTensorType>();
1854 if (!output_type) return failure();
1855
1856 ElementsAttr axes_elems;
1857 if (!matchPattern(tfl_min_op.axes(), m_Constant(&axes_elems)))
1858 return failure();
1859
1860 llvm::Optional<Value> result = convertReduceMinOp(
1861 rewriter, op, output_type, tfl_min_op.input(), axes_elems);
1862
1863 if (!result) return failure();
1864
1865 rewriter.replaceOp(op, {result.getValue()});
1866
1867 return success();
1868 }
1869
1870 LogicalResult ConvertTFLReduceProdOp::matchAndRewrite(
1871 Operation* op, PatternRewriter& rewriter) const {
1872 auto tfl_prod_op = cast<TFL::ReduceProdOp>(op);
1873
1874 RankedTensorType output_type =
1875 tfl_prod_op.getResult().getType().dyn_cast<RankedTensorType>();
1876 if (!output_type) return failure();
1877
1878 ElementsAttr axes_elems;
1879 if (!matchPattern(tfl_prod_op.axes(), m_Constant(&axes_elems)))
1880 return failure();
1881
1882 llvm::Optional<Value> result = convertReduceProdOp(
1883 rewriter, op, output_type, tfl_prod_op.input(), axes_elems);
1884
1885 if (!result) return failure();
1886
1887 rewriter.replaceOp(op, {result.getValue()});
1888
1889 return success();
1890 }
1891
1892 LogicalResult ConvertTFLMeanOp::matchAndRewrite(
1893 Operation* op, PatternRewriter& rewriter) const {
1894 auto tfl_mean_op = cast<TFL::MeanOp>(op);
1895
1896 RankedTensorType output_type =
1897 tfl_mean_op.getResult().getType().dyn_cast<RankedTensorType>();
1898 if (!output_type) return failure();
1899
1900 ElementsAttr axes_elems;
1901 if (!matchPattern(tfl_mean_op.axis(), m_Constant(&axes_elems)))
1902 return failure();
1903
1904 llvm::Optional<Value> result = convertReduceMeanOp(
1905 rewriter, op, output_type, tfl_mean_op.input(), axes_elems);
1906
1907 if (!result) return failure();
1908
1909 rewriter.replaceOp(op, {result.getValue()});
1910
1911 return success();
1912 }
1913
1914 LogicalResult ConvertTFLSumOp::matchAndRewrite(
1915 Operation* op, PatternRewriter& rewriter) const {
1916 auto tfl_sum_op = cast<TFL::SumOp>(op);
1917
1918 RankedTensorType output_type =
1919 tfl_sum_op.getResult().getType().dyn_cast<RankedTensorType>();
1920 if (!output_type) return failure();
1921
1922 ElementsAttr axes_elems;
1923 if (!matchPattern(tfl_sum_op.axes(), m_Constant(&axes_elems)))
1924 return failure();
1925
1926 llvm::Optional<Value> result = convertReduceSumOp(
1927 rewriter, op, output_type, tfl_sum_op.input(), axes_elems);
1928
1929 if (!result) return failure();
1930
1931 rewriter.replaceOp(op, {result.getValue()});
1932
1933 return success();
1934 }
1935
1936 LogicalResult ConvertTFLEluOp::matchAndRewrite(
1937 Operation* op, PatternRewriter& rewriter) const {
1938 auto tfl_elu_op = cast<TFL::EluOp>(op);
1939
1940 llvm::Optional<Value> result =
1941 convertEluOp(rewriter, op, tfl_elu_op.getResult(), tfl_elu_op.x());
1942
1943 if (!result) return failure();
1944
1945 rewriter.replaceOp(op, {result.getValue()});
1946
1947 return success();
1948 }
1949
1950 LogicalResult ConvertTFLSoftmaxOp::matchAndRewrite(
1951 Operation* op, PatternRewriter& rewriter) const {
1952 auto tfl_softmax_op = cast<TFL::SoftmaxOp>(op);
1953
1954 llvm::Optional<Value> result = convertSoftmaxOp(
1955 rewriter, op, tfl_softmax_op.getResult(), tfl_softmax_op.input(),
1956 tfl_softmax_op.betaAttr().getValueAsDouble());
1957
1958 if (!result) return failure();
1959
1960 rewriter.replaceOp(op, {result.getValue()});
1961
1962 return success();
1963 }
1964
1965 LogicalResult ConvertTFLSqrtOp::matchAndRewrite(
1966 Operation* op, PatternRewriter& rewriter) const {
1967 auto tfl_rsqrt_op = cast<TFL::SqrtOp>(op);
1968 auto rsqrt = CreateOpAndInfer<tosa::RsqrtOp>(
1969 rewriter, op->getLoc(), tfl_rsqrt_op.getType(), tfl_rsqrt_op.x());
1970
1971 CreateReplaceOpAndInfer<tosa::ReciprocalOp>(rewriter, op, rsqrt.getType(),
1972 rsqrt);
1973
1974 return success();
1975 }
1976
1977 LogicalResult ConvertTFLL2NormalizationOp::matchAndRewrite(
1978 Operation* op, PatternRewriter& rewriter) const {
1979 auto tfl_l2norm_op = cast<TFL::L2NormalizationOp>(op);
1980 auto input = tfl_l2norm_op.input();
1981 auto input_ty = input.getType().cast<ShapedType>();
1982 auto loc = op->getLoc();
1983
1984 if (!input_ty.hasRank()) return failure();
1985
1986 if (input_ty.getElementType().isF32()) {
1987 auto shift = rewriter.getIntegerAttr(rewriter.getI32Type(), 0);
1988 auto result_ty = UnrankedTensorType::get(input_ty.getElementType());
1989 auto mul = CreateOpAndInfer<tosa::MulOp>(rewriter, loc, result_ty, input,
1990 input, shift);
1991 auto sum = CreateOpAndInfer<tosa::ReduceSumOp>(
1992 rewriter, loc, result_ty, mul,
1993 rewriter.getI64IntegerAttr(input_ty.getRank() - 1));
1994
1995 SmallVector<float> min(1, sqrt(std::numeric_limits<float>::min()));
1996 Value min_val = getConstTensor<float>(rewriter, op, min, {}).getValue();
1997 auto max = CreateOpAndInfer<tosa::MaximumOp>(rewriter, loc, result_ty, sum,
1998 min_val);
1999 auto rsqrt = CreateOpAndInfer<tosa::RsqrtOp>(rewriter, loc, result_ty, max)
2000 .getResult();
2001 auto result = CreateOpAndInfer<tosa::MulOp>(rewriter, loc, result_ty, rsqrt,
2002 input, shift)
2003 .getResult();
2004
2005 auto fused_activation_fn = tfl_l2norm_op.fused_activation_functionAttr();
2006
2007 if (fused_activation_fn) {
2008 llvm::Optional<Value> fused_activation_val =
2009 convertFusedActivation(rewriter, op, result, fused_activation_fn);
2010 if (!fused_activation_val) return failure();
2011 result = fused_activation_val.getValue();
2012 }
2013
2014 rewriter.replaceOp(op, result);
2015 return success();
2016 }
2017
2018 return failure();
2019 }
2020
2021 LogicalResult ConvertTFLLogSoftmaxOp::matchAndRewrite(
2022 Operation* op, PatternRewriter& rewriter) const {
2023 auto tfl_logsoftmax_op = cast<TFL::LogSoftmaxOp>(op);
2024
2025 llvm::Optional<Value> result = convertLogSoftmaxOp(
2026 rewriter, op, tfl_logsoftmax_op.getResult(), tfl_logsoftmax_op.input());
2027
2028 if (!result) return failure();
2029
2030 rewriter.replaceOp(op, {result.getValue()});
2031
2032 return success();
2033 }
2034
2035 LogicalResult ConvertTFLSliceOp::matchAndRewrite(
2036 Operation* op, PatternRewriter& rewriter) const {
2037 auto tfl_slice_op = cast<TFL::SliceOp>(op);
2038
2039 ShapedType output_type =
2040 tfl_slice_op.getResult().getType().dyn_cast<ShapedType>();
2041 // Not a shaped tensor output
2042 if (!output_type) return failure();
2043
2044 ElementsAttr begin_elems, size_elems;
2045
2046 SmallVector<int64_t> begin_vals, size_vals;
2047
2048 if (!matchPattern(tfl_slice_op.begin(), m_Constant(&begin_elems)) ||
2049 !matchPattern(tfl_slice_op.size(), m_Constant(&size_elems))) {
2050 return failure();
2051 }
2052
2053 for (int i = 0; i < begin_elems.getNumElements(); i++)
2054 begin_vals.push_back(begin_elems.getValues<APInt>()[i].getSExtValue());
2055
2056 for (int i = 0; i < size_elems.getNumElements(); i++)
2057 size_vals.push_back(size_elems.getValues<APInt>()[i].getSExtValue());
2058
2059 ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
2060 ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
2061
2062 CreateReplaceOpAndInfer<tosa::SliceOp>(rewriter, op, output_type,
2063 tfl_slice_op.input(), begin, size);
2064 return success();
2065 }
2066
2067 LogicalResult ConvertTFLTileOp::matchAndRewrite(
2068 Operation* op, PatternRewriter& rewriter) const {
2069 auto tfl_tile_op = cast<TFL::TileOp>(op);
2070
2071 ShapedType output_type =
2072 tfl_tile_op.getResult().getType().dyn_cast<ShapedType>();
2073   // Not a shaped tensor output
2074 if (!output_type) return failure();
2075
2076 ElementsAttr multiples_elems;
2077 if (!matchPattern(tfl_tile_op.multiples(), m_Constant(&multiples_elems)))
2078 return failure();
2079 SmallVector<int64_t> multiples_vals;
2080 for (int i = 0; i < multiples_elems.getNumElements(); i++)
2081 multiples_vals.push_back(
2082 multiples_elems.getValues<APInt>()[i].getSExtValue());
2083
2084 ArrayAttr multiples_attr = rewriter.getI64ArrayAttr(multiples_vals);
2085 CreateReplaceOpAndInfer<tosa::TileOp>(rewriter, op, output_type,
2086 tfl_tile_op.input(), multiples_attr);
2087
2088 return success();
2089 }
2090
2091 LogicalResult ConvertTFLTransposeOp::matchAndRewrite(
2092 Operation* op, PatternRewriter& rewriter) const {
2093 auto tfl_transpose_op = cast<TFL::TransposeOp>(op);
2094
2095 Type output_type = tfl_transpose_op.getResult().getType();
2096 CreateReplaceOpAndInfer<tosa::TransposeOp>(rewriter, op, output_type,
2097 tfl_transpose_op.input(),
2098 tfl_transpose_op.perm());
2099
2100 return success();
2101 }
2102
2103 LogicalResult ConvertTFLPackOp::matchAndRewrite(
2104 Operation* op, PatternRewriter& rewriter) const {
2105 auto tfl_pack_op = cast<TFL::PackOp>(op);
2106
2107 SmallVector<Value> inputs(tfl_pack_op.values());
2108 assert(!inputs.empty());
2109
2110 IntegerAttr axis_attr;
2111 {
2112 auto tmpAttr = tfl_pack_op.axisAttr();
2113 if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
2114 axis_attr = tmpAttr;
2115 }
2116 int32_t axis_i32 = axis_attr.getInt();
2117
2118 llvm::Optional<Value> result =
2119 convertPackOp(rewriter, op, tfl_pack_op.getResult(), inputs, axis_i32);
2120
2121 if (!result) return failure();
2122
2123 rewriter.replaceOp(op, {result.getValue()});
2124
2125 return success();
2126 }
2127
2128 LogicalResult ConvertTFLUnpackOp::matchAndRewrite(
2129 Operation* op, PatternRewriter& rewriter) const {
2130 auto tfl_unpack_op = cast<TFL::UnpackOp>(op);
2131
2132 IntegerAttr axis_attr;
2133 {
2134 auto tmpAttr = tfl_unpack_op.axisAttr();
2135 if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
2136 axis_attr = tmpAttr;
2137 }
2138 int32_t axis_i32 = axis_attr.getInt();
2139
2140 llvm::Optional<SmallVector<Value>> results =
2141 convertUnpackOp(rewriter, op, tfl_unpack_op.input(), axis_i32);
2142
2143 if (!results) return failure();
2144
2145 rewriter.replaceOp(op, results.getValue());
2146
2147 return success();
2148 }
2149
2150 // Splits the input into num_split equal parts along split_dim.
2151 LogicalResult ConvertTFLSplitOp::matchAndRewrite(
2152 Operation* op, PatternRewriter& rewriter) const {
2153 auto tfl_split_op = cast<TFL::SplitOp>(op);
2154
2155 // Get the number of splits
2156 int32_t num_split = -1;
2157 auto numSplitAttr = tfl_split_op.num_splitsAttr();
2158 if (numSplitAttr) {
2159 num_split = numSplitAttr.getInt();
2160 } else {
2161 return failure();
2162 }
2163
2164 // Get the axis
2165 ElementsAttr axisAttrElems;
2166 if (!matchPattern(tfl_split_op.split_dim(), m_Constant(&axisAttrElems))) {
2167 return op->emitOpError("Cannot read split_dim elems");
2168 }
2169
2170 // The axis/split_dim parameter is stored as a 0D tensor instead of
2171 // an integer attribute in TFLite MLIR.
2172 int32_t axis = axisAttrElems.getValues<APInt>()[0].getSExtValue();
2173
2174 llvm::Optional<SmallVector<Value>> results =
2175 convertSplitOp(rewriter, op, tfl_split_op.getResult(0),
2176 tfl_split_op.value(), num_split, axis);
2177
2178 if (!results) return failure();
2179
2180 rewriter.replaceOp(op, results.getValue());
2181
2182 return success();
2183 }
2184
2185 // Splits the input along split_dim into parts whose sizes are given by size_splits.
2186 LogicalResult ConvertTFLSplitVOp::matchAndRewrite(
2187 Operation* op, PatternRewriter& rewriter) const {
2188 auto tfl_splitv_op = cast<TFL::SplitVOp>(op);
2189
2190 // Get the size_splits array
2191 SmallVector<int32_t> size_split;
2192 ElementsAttr size_split_elems;
2193 if (!matchPattern(tfl_splitv_op.size_splits(),
2194 m_Constant(&size_split_elems))) {
2195 return failure();
2196 }
2197
2198 for (int i = 0; i < size_split_elems.getNumElements(); i++) {
2199 size_split.push_back(size_split_elems.getValues<APInt>()[i].getSExtValue());
2200 }
2201
2202 // Get the axis
2203 ElementsAttr axisAttrElems;
2204 if (!matchPattern(tfl_splitv_op.split_dim(), m_Constant(&axisAttrElems))) {
2205 return op->emitOpError("Cannot read split_dim elems");
2206 }
2207
2208 // The axis/split_dim parameter is stored as a 0D tensor instead of
2209 // an integer attribute in TFLite MLIR.
2210 int32_t axis = axisAttrElems.getValues<APInt>()[0].getSExtValue();
2211
2212 llvm::Optional<SmallVector<Value>> results =
2213 convertSplitVOp(rewriter, op, tfl_splitv_op.getResult(0),
2214 tfl_splitv_op.value(), size_split, axis);
2215
2216 if (!results) return failure();
2217
2218 rewriter.replaceOp(op, results.getValue());
2219
2220 return success();
2221 }
2222
2223 LogicalResult ConvertTFLPadOp::matchAndRewrite(
2224 Operation* op, PatternRewriter& rewriter) const {
2225 auto tfl_pad_op = cast<TFL::PadOp>(op);
2226
2227 ShapedType output_type =
2228 tfl_pad_op.getResult().getType().dyn_cast<ShapedType>();
2229   // Not a shaped tensor output
2230 if (!output_type) return failure();
2231
2232 auto pad_op =
2233 CreateOpAndInfer<tosa::PadOp>(rewriter, op->getLoc(), output_type,
2234 tfl_pad_op.input(), tfl_pad_op.padding());
2235
2236 rewriter.replaceOp(op, {pad_op.getResult()});
2237 return success();
2238 }
2239
2240 LogicalResult ConvertTFLPadV2Op::matchAndRewrite(
2241 Operation* op, PatternRewriter& rewriter) const {
2242 auto tfl_pad_op = cast<TFL::PadV2Op>(op);
2243
2244 Value input = tfl_pad_op.input();
2245 Value padding = tfl_pad_op.padding();
2246 Value constant_value = tfl_pad_op.constant_values();
2247
2248 CreateReplaceOpAndInfer<tosa::PadOp>(rewriter, op, tfl_pad_op.getType(),
2249 input, padding, constant_value);
2250
2251 return success();
2252 }
2253
2254 LogicalResult ConvertTFLResizeBilinearOp::matchAndRewrite(
2255 Operation* op, PatternRewriter& rewriter) const {
2256 auto tfl_resize_op = cast<TFL::ResizeBilinearOp>(op);
2257
2258 RankedTensorType output_type =
2259 tfl_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
2260 // Not a ranked tensor output
2261 if (!output_type) return failure();
2262
2263 llvm::Optional<Value> result = convertResizeOp(
2264 rewriter, op, output_type, tfl_resize_op.input(), StringRef("BILINEAR"),
2265 tfl_resize_op.align_cornersAttr().getValue(),
2266 tfl_resize_op.half_pixel_centersAttr().getValue());
2267
2268 if (!result) return failure();
2269
2270 rewriter.replaceOp(op, {result.getValue()});
2271
2272 return success();
2273 }
2274
2275 LogicalResult ConvertTFLResizeNearestNeighborOp::matchAndRewrite(
2276 Operation* op, PatternRewriter& rewriter) const {
2277 auto tfl_resize_op = cast<TFL::ResizeNearestNeighborOp>(op);
2278
2279 RankedTensorType output_type =
2280 tfl_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
2281 // Not a ranked tensor output
2282 if (!output_type) return failure();
2283
2284 llvm::Optional<Value> result =
2285 convertResizeOp(rewriter, op, output_type, tfl_resize_op.input(),
2286 StringRef("NEAREST_NEIGHBOR"),
2287 tfl_resize_op.align_cornersAttr().getValue(),
2288 tfl_resize_op.half_pixel_centersAttr().getValue());
2289
2290 if (!result) return failure();
2291
2292 rewriter.replaceOp(op, {result.getValue()});
2293
2294 return success();
2295 }
2296
2297 LogicalResult ConvertTFLSelectOp::matchAndRewrite(
2298 Operation* op, PatternRewriter& rewriter) const {
2299 auto tfl_sel_op = cast<TFL::SelectOp>(op);
2300
2301 llvm::Optional<Value> result =
2302 convertSelectOp(rewriter, op, tfl_sel_op.getResult(),
2303 tfl_sel_op.condition(), tfl_sel_op.x(), tfl_sel_op.y());
2304 if (!result) return failure();
2305
2306 rewriter.replaceOp(op, {result.getValue()});
2307
2308 return success();
2309 }
2310
2311 LogicalResult ConvertTFLSelectV2Op::matchAndRewrite(
2312 Operation* op, PatternRewriter& rewriter) const {
2313 auto tfl_sel_op = cast<TFL::SelectV2Op>(op);
2314
2315 llvm::Optional<Value> result =
2316 convertSelectOp(rewriter, op, tfl_sel_op.getResult(),
2317 tfl_sel_op.condition(), tfl_sel_op.x(), tfl_sel_op.y());
2318 if (!result) return failure();
2319
2320 rewriter.replaceOp(op, {result.getValue()});
2321
2322 return success();
2323 }
2324
2325 LogicalResult ConvertTFLSpaceToBatchNdOp::matchAndRewrite(
2326 Operation* op, PatternRewriter& rewriter) const {
2327 auto tfl_s2b_op = cast<TFL::SpaceToBatchNdOp>(op);
2328 llvm::Optional<Value> result = convertSpaceToBatchNDOp(
2329 rewriter, op, tfl_s2b_op.getResult(), tfl_s2b_op.input(),
2330 tfl_s2b_op.block_shape(), tfl_s2b_op.paddings());
2331
2332 if (!result) return failure();
2333
2334 rewriter.replaceOp(op, {result.getValue()});
2335
2336 return success();
2337 }
2338
2339 LogicalResult ConvertTFLBatchToSpaceNdOp::matchAndRewrite(
2340 Operation* op, PatternRewriter& rewriter) const {
2341 auto tfl_b2s_op = cast<TFL::BatchToSpaceNdOp>(op);
2342
2343 llvm::Optional<Value> result = convertBatchToSpaceNDOp(
2344 rewriter, op, tfl_b2s_op.getResult(), tfl_b2s_op.input(),
2345 tfl_b2s_op.block_shape(), tfl_b2s_op.indices());
2346
2347 if (!result) return failure();
2348
2349 rewriter.replaceOp(op, {result.getValue()});
2350
2351 return success();
2352 }
2353
2354 LogicalResult ConvertTFLSpaceToDepthOp::matchAndRewrite(
2355 Operation* op, PatternRewriter& rewriter) const {
2356 auto tfl_s2d_op = cast<TFL::SpaceToDepthOp>(op);
2357
2358 auto block_size_attr = tfl_s2d_op.block_sizeAttr();
2359 llvm::Optional<Value> result = convertSpaceToDepthOp(
2360 rewriter, op, tfl_s2d_op.getResult(), tfl_s2d_op.input(), block_size_attr,
2361 rewriter.getStringAttr("NHWC"));
2362
2363 if (!result) return failure();
2364
2365 rewriter.replaceOp(op, {result.getValue()});
2366
2367 return success();
2368 }
2369
2370 LogicalResult ConvertTFLDepthToSpaceOp::matchAndRewrite(
2371 Operation* op, PatternRewriter& rewriter) const {
2372 auto tfl_d2s_op = cast<TFL::DepthToSpaceOp>(op);
2373
2374 auto block_size_attr = tfl_d2s_op.block_sizeAttr();
2375 llvm::Optional<Value> result = convertDepthToSpaceOp(
2376 rewriter, op, tfl_d2s_op.getResult(), tfl_d2s_op.input(), block_size_attr,
2377 rewriter.getStringAttr("NHWC"));
2378
2379 if (!result) return failure();
2380
2381 rewriter.replaceOp(op, {result.getValue()});
2382
2383 return success();
2384 }
2385
2386 LogicalResult ConvertTFLStridedSliceOp::matchAndRewrite(
2387 Operation* op, PatternRewriter& rewriter) const {
2388 auto tfl_ss_op = cast<TFL::StridedSliceOp>(op);
2389
2390 llvm::Optional<Value> result = convertStridedSliceOp(
2391 rewriter, op, tfl_ss_op.getResult(), tfl_ss_op.input(), tfl_ss_op.begin(),
2392 tfl_ss_op.end(), tfl_ss_op.strides(), tfl_ss_op.begin_maskAttr().getInt(),
2393 tfl_ss_op.end_maskAttr().getInt(), tfl_ss_op.ellipsis_maskAttr().getInt(),
2394 tfl_ss_op.new_axis_maskAttr().getInt(),
2395 tfl_ss_op.shrink_axis_maskAttr().getInt());
2396 if (!result) return failure();
2397
2398 rewriter.replaceOp(op, {result.getValue()});
2399
2400 return success();
2401 }
2402
2403 LogicalResult ConvertTFLZerosLikeOp::matchAndRewrite(
2404 Operation* op, PatternRewriter& rewriter) const {
2405 auto tfl_zeroslike_op = cast<TFL::ZerosLikeOp>(op);
2406
2407 llvm::Optional<Value> result = convertZerosLikeOp(
2408 rewriter, op, tfl_zeroslike_op.getResult(), tfl_zeroslike_op.input());
2409
2410 if (!result) return failure();
2411
2412 rewriter.replaceOp(op, {result.getValue()});
2413
2414 return success();
2415 }
2416
2417 LogicalResult ConvertTFLHardSwishOp::matchAndRewrite(
2418 Operation* op, PatternRewriter& rewriter) const {
2419 auto tfl_hardswish_op = cast<TFL::HardSwishOp>(op);
2420 RankedTensorType output_type =
2421 tfl_hardswish_op.getResult().getType().dyn_cast<RankedTensorType>();
2422 // Not a ranked tensor output
2423 if (!output_type) return failure();
2424
2425 RankedTensorType input_type =
2426 tfl_hardswish_op.input().getType().dyn_cast<RankedTensorType>();
2427   // Not a ranked tensor input
2428 if (!input_type) return failure();
2429
2430 // TFL hardswish: f(x) -> (x * relu6(x+3))/6
2431
2432 if (input_type.getElementType().isa<mlir::quant::QuantizedType>() &&
2433 output_type.getElementType().isa<mlir::quant::QuantizedType>()) {
2434 // Should match TFLite reference numerical behavior
2435 mlir::quant::UniformQuantizedType input_qtype =
2436 input_type.getElementType()
2437 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
2438 mlir::quant::UniformQuantizedType output_qtype =
2439 output_type.getElementType()
2440 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
2441
2442 auto hardswish_func = [](double v) -> double {
2443 double w = v + 3.0;
2444 w = w < 0.0 ? 0.0 : w > 6.0 ? 6.0 : w;
2445 return v * w / 6.0;
2446 };
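    // Note: getTosaConst8bitTable (from legalize_utils) is expected to build
    // a 256-entry lookup table by dequantizing every possible int8 input with
    // the input scale/zero point, applying hardswish_func, and requantizing
    // with the output scale/zero point.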
2447
2448 if (input_qtype.getStorageTypeIntegralWidth() == 8) {
2449 // Implement with 8-bit table lookup.
2450 Value table_const = getTosaConst8bitTable(
2451 rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(),
2452 output_qtype.getScale(), output_qtype.getZeroPoint(), hardswish_func);
2453
2454 CreateReplaceOpAndInfer<tosa::TableOp>(
2455 rewriter, op, output_type, tfl_hardswish_op.input(), table_const);
2456 }
2457
2458 } else {
2459 // op1 = constop(3)
2460 // op2 = add(x, op1)
2461 // op3 = clamp(op2, 0, 6)
2462 // op4 = mul(x, op3)
2463 // op5 = reciprocal(6)
2464 // op6 = mul (op4, op5)
2465
2466 Value op1_value = getTosaConstTensorSingleF32(rewriter, op, 3.0);
2467
2468 auto op2_add_x_op1 =
2469 CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(), output_type,
2470 tfl_hardswish_op.input(), op1_value);
2471
2472 auto op3_relu_op2_6 = CreateOpAndInfer<tosa::ClampOp>(
2473 rewriter, op->getLoc(), output_type, op2_add_x_op1.getResult(),
2474 rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(0),
2475 rewriter.getF32FloatAttr(0.0f), rewriter.getF32FloatAttr(6.0f));
2476
2477 auto op4_mul_x_op3 = CreateOpAndInfer<tosa::MulOp>(
2478 rewriter, op->getLoc(), output_type, tfl_hardswish_op.input(),
2479 op3_relu_op2_6.getResult(), 0);
2480
2481 auto const_6 = getTosaConstTensorSingleF32(rewriter, op, 6.0);
2482 auto op5_reciprocal_6 = CreateOpAndInfer<tosa::ReciprocalOp>(
2483 rewriter, op->getLoc(), const_6.getType(), const_6);
2484
2485 auto op6_mul_op4_op5 = CreateOpAndInfer<tosa::MulOp>(
2486 rewriter, op->getLoc(), output_type, op4_mul_x_op3.getResult(),
2487 op5_reciprocal_6.getResult(), 0);
2488
2489 rewriter.replaceOp(op, {op6_mul_op4_op5.getResult()});
2490 }
2491
2492 return success();
2493 }
2494
2495 LogicalResult ConvertTFLSinOp::matchAndRewrite(
2496 Operation* op, PatternRewriter& rewriter) const {
2497 auto tfl_sin_op = cast<TFL::SinOp>(op);
2498 Location loc = op->getLoc();
2499 Value input = tfl_sin_op.x();
2500 RankedTensorType input_ty = input.getType().dyn_cast<RankedTensorType>();
2501 ShapedType output_ty =
2502 tfl_sin_op.getResult().getType().dyn_cast<ShapedType>();
2503
2504   if (!input_ty || !output_ty) return failure();
2505
2506   Type input_ety = input_ty.getElementType();
2507   Type output_ety = output_ty.getElementType();
2508
2509 if (input_ety != output_ety) {
2510 return rewriter.notifyMatchFailure(
2511 op, "ConvertTFLSinOp: input/output element type must match");
2512 }
2513
2514 bool input_is_fp = input_ty.getElementType().isF32();
2515 bool output_is_fp = output_ty.getElementType().isF32();
2516
2517 if (!input_is_fp || !output_is_fp) {
2518 return rewriter.notifyMatchFailure(
2519 op, "ConvertTFLSinOp: input/result must be fp32.");
2520 }
2521
2522 // To perform a sin operation we remap the sin domain to be over a single
2523 // period of the function, remapping to the domain of the table function.
2524 // We then remap the range of the table function to map to the range of the
2525 // sin operation.
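  // In effect (sketch of the steps below): sin(x) is evaluated as
  //   table(int16((((x / (2*pi)) mod 1) * 2 - 1) * int16_max))
  // where the table holds one period of sin scaled to +/-int16_max.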
2526
2527 // 1. Normalize the period of the domain from [0, 2π) to [0, 1).
2528 auto fp_scalar_ty = RankedTensorType::get({}, rewriter.getF32Type());
2529 Value fp_scale = rewriter.create<tosa::ConstOp>(
2530 loc, fp_scalar_ty,
2531 DenseElementsAttr::get(fp_scalar_ty, {static_cast<float>(0.5 / M_PI)}));
2532
2533 // 2. Remap the periodic behavior of the domain to line up within [0, 1).
2534 Value fp_scaled = CreateOpAndInfer<tosa::MulOp>(
2535 rewriter, loc, input_ty, input, fp_scale, rewriter.getI32IntegerAttr(0));
2536 auto floored =
2537 CreateOpAndInfer<tosa::FloorOp>(rewriter, loc, input_ty, fp_scaled);
2538 auto repeated = CreateOpAndInfer<tosa::SubOp>(rewriter, loc, input_ty,
2539 fp_scaled, floored);
2540
2541   // 3. Scale and translate the normalized domain to the table domain. This
2542   // includes translating and scaling to [-int16_max, int16_max] and casting
2543   // to an i16.
2544 Value one = rewriter.create<tosa::ConstOp>(
2545 loc, fp_scalar_ty, DenseElementsAttr::get(fp_scalar_ty, {1.0f}));
2546
2547 Value two = rewriter.create<tosa::ConstOp>(
2548 loc, fp_scalar_ty, DenseElementsAttr::get(fp_scalar_ty, {2.0f}));
2549 auto scale_up = CreateOpAndInfer<tosa::MulOp>(
2550 rewriter, loc, input_ty, repeated, two, rewriter.getI32IntegerAttr(0));
2551 auto translate =
2552 CreateOpAndInfer<tosa::SubOp>(rewriter, loc, input_ty, scale_up, one);
2553
2554 Value int_limit = rewriter.create<tosa::ConstOp>(
2555 loc, fp_scalar_ty,
2556 DenseElementsAttr::get(
2557 fp_scalar_ty,
2558 {static_cast<float>(std::numeric_limits<int16_t>::max())}));
2559 auto int_scaled =
2560 CreateOpAndInfer<tosa::MulOp>(rewriter, loc, input_ty, translate,
2561 int_limit, rewriter.getI32IntegerAttr(0));
2562
2563 auto int16_ty = input_ty.clone(rewriter.getIntegerType(16));
2564 auto casted =
2565 CreateOpAndInfer<tosa::CastOp>(rewriter, loc, int16_ty, int_scaled);
2566
2567 // 4. Compute the lookup table using the range of [-255, 255] for sin.
2568 llvm::SmallVector<int16_t> values;
2569 const int num_values = 513;
2570 values.resize(num_values, 0);
2571   // First and last values should be 0.
2572 for (int i = 1; i < num_values - 1; ++i)
2573 values[i] = std::numeric_limits<int16_t>::max() *
2574 sin(static_cast<float>(i) * 2.0 * M_PI / (num_values - 1.0));
2575
2576 auto table_ty =
2577 RankedTensorType::get({num_values}, rewriter.getIntegerType(16));
2578 Value table = rewriter.create<tosa::ConstOp>(
2579 loc, table_ty,
2580 DenseElementsAttr::get(table_ty, llvm::makeArrayRef(values)));
2581
2582 auto table_result_ty = input_ty.clone(rewriter.getIntegerType(32));
2583 auto table_result = CreateOpAndInfer<tosa::TableOp>(
2584 rewriter, loc, table_result_ty, casted, table);
2585
2586   // 5. The range of the table output is a 23-bit two's complement value.
2587   // Normalize the range by casting to an fp32 and dividing by 2^22.
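  // (Explanatory note) The 16-bit TABLE lookup interpolates between adjacent
  // entries and carries 7 extra fractional bits, so table values of
  // +/-int16_max map to outputs of roughly +/-2^22; the division below
  // rescales this back to sin's natural [-1, 1] range.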
2588 auto table_result_fp =
2589 CreateOpAndInfer<CastOp>(rewriter, loc, input_ty, table_result);
2590 auto output_scale = rewriter.create<ConstOp>(
2591 loc, fp_scalar_ty,
2592 DenseElementsAttr::get(
2593 fp_scalar_ty,
2594 {static_cast<float>(1.0 / static_cast<float>(1 << 22))}));
2595 CreateReplaceOpAndInfer<MulOp>(rewriter, op, output_ty, table_result_fp,
2596 output_scale, rewriter.getI32IntegerAttr(0));
2597 return success();
2598 }
2599
2600 LogicalResult ConvertTFLCosOp::matchAndRewrite(
2601 Operation* op, PatternRewriter& rewriter) const {
2602 auto tfl_cos_op = cast<TFL::CosOp>(op);
2603 Value input = tfl_cos_op.x();
2604 RankedTensorType input_ty = input.getType().dyn_cast<RankedTensorType>();
2605 ShapedType output_ty =
2606 tfl_cos_op.getResult().getType().dyn_cast<ShapedType>();
2607
2608 if (!input_ty || !output_ty) return failure();
2609
2610 bool input_is_fp = input_ty.getElementType().isa<mlir::FloatType>();
2611 bool output_is_fp = output_ty.getElementType().isa<mlir::FloatType>();
2612
2613 if (!input_is_fp || !output_is_fp) {
2614 return rewriter.notifyMatchFailure(
2615 op, "ConvertTFLCosOp: input/result must be fp.");
2616 }
2617
2618 // Replace with the equivalent sin operation:
2619 // cos(x) = sin(x + π / 2).
2620 auto fp_scalar_ty = RankedTensorType::get({}, rewriter.getF32Type());
2621 auto pi_2 = rewriter.create<ConstOp>(
2622 op->getLoc(), fp_scalar_ty,
2623 DenseElementsAttr::get(fp_scalar_ty, {static_cast<float>(M_PI_2)}));
2624 auto offset = rewriter.create<AddOp>(op->getLoc(), input_ty, input, pi_2);
2625
2626 CreateReplaceOpAndInfer<TFL::SinOp>(rewriter, op, output_ty, offset);
2627 return success();
2628 }
2629
2630 LogicalResult ConvertTFLLogisticOp::matchAndRewrite(
2631 Operation* op, PatternRewriter& rewriter) const {
2632 auto tfl_logistic_op = cast<TFL::LogisticOp>(op);
2633
2634 ShapedType output_type =
2635 tfl_logistic_op.getResult().getType().dyn_cast<ShapedType>();
2636 RankedTensorType input_type =
2637 tfl_logistic_op.x().getType().dyn_cast<RankedTensorType>();
2638 if (!input_type || !output_type) return failure();
2639
2640 bool input_is_qtype =
2641 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2642 bool output_is_qtype =
2643 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2644
2645 if (input_is_qtype != output_is_qtype) {
2646 return op->emitOpError(
2647 "ConvertTFLLogisticOp: input/output tensor should "
2648 "be all quantized or all floating-point.");
2649 }
2650
2651 if (input_is_qtype) {
2652 ShapedType int32_type = output_type.clone(rewriter.getIntegerType(32));
2653 mlir::quant::UniformQuantizedType input_qtype =
2654 input_type.getElementType()
2655 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
2656 mlir::quant::UniformQuantizedType output_qtype =
2657 output_type.getElementType()
2658 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
2659
2660 auto sigmoid_func = [](double x) -> double {
2661 return 1.0 / (1.0 + std::exp(-x));
2662 };
2663
2664 if (input_qtype.getStorageTypeIntegralWidth() == 8) {
2665 Value table_const = getTosaConst8bitTable(
2666 rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(),
2667 output_qtype.getScale(), output_qtype.getZeroPoint(), sigmoid_func);
2668
2669 CreateReplaceOpAndInfer<tosa::TableOp>(rewriter, op, output_type,
2670 tfl_logistic_op.x(), table_const);
2671 } else { // int16
2672 if (input_qtype.getZeroPoint() != 0 || output_qtype.getZeroPoint() != 0) {
2673 op->emitOpError(
2674             "ConvertTFLLogisticOp: input/output zeropoint should be 0 in 16-bit "
2675 "mode");
2676 return failure();
2677 }
2678 double input_min = -32768 * input_qtype.getScale();
2679 double input_max = 32767 * input_qtype.getScale();
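      // These bounds map the full int16 storage range ([-32768, 32767]
      // quantized codes) to real values, which is the input domain the
      // generated 16-bit table covers.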
2680
2681 // Generate table with gen_lut() in
2682 // tensorflow/lite/kernels/internal/common.h
2683 Value table_const = getTosaConst16bitTable(rewriter, op, sigmoid_func,
2684 input_min, input_max);
2685
2686 auto op1_table_in = CreateOpAndInfer<tosa::TableOp>(
2687 rewriter, op->getLoc(), int32_type, tfl_logistic_op.x(), table_const);
2688
2689 Value op2_rescale_op1 =
2690 buildRescale(rewriter, op, output_type, op1_table_in.getResult(),
2691 1.0 / 128.0, 0, 0, false, true);
2692
2693 rewriter.replaceOp(op, {op2_rescale_op1});
2694 }
2695 } else {
2696 CreateReplaceOpAndInfer<tosa::SigmoidOp>(rewriter, op, output_type,
2697 tfl_logistic_op.x());
2698 }
2699
2700 return success();
2701 }
2702
2703 LogicalResult ConvertTFLTanhOp::matchAndRewrite(
2704 Operation* op, PatternRewriter& rewriter) const {
2705 auto tfl_tanh_op = cast<TFL::TanhOp>(op);
2706 ShapedType output_type =
2707 tfl_tanh_op.getResult().getType().dyn_cast<ShapedType>();
2708 RankedTensorType input_type =
2709 tfl_tanh_op.input().getType().dyn_cast<RankedTensorType>();
2710 if (!input_type || !output_type) return failure();
2711
2712 bool input_is_qtype =
2713 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2714 bool output_is_qtype =
2715 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2716
2717 if (input_is_qtype != output_is_qtype) {
2718 return op->emitOpError(
2719 "ConvertTFLTanhOp: input/output tensor should "
2720 "be all quantized or all floating-point.");
2721 }
2722
2723 if (input_is_qtype) {
2724 ShapedType int32_type = output_type.clone(rewriter.getIntegerType(32));
2725 mlir::quant::UniformQuantizedType input_qtype =
2726 input_type.getElementType()
2727 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
2728 mlir::quant::UniformQuantizedType output_qtype =
2729 output_type.getElementType()
2730 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
2731
2732 auto tanh_func = [](double x) -> double {
2733 x = std::exp(-2.0 * x);
2734 return (1.0 - x) / (1.0 + x);
2735 };
2736
2737 if (input_qtype.getStorageTypeIntegralWidth() == 8) {
2738 Value table_const = getTosaConst8bitTable(
2739 rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(),
2740 output_qtype.getScale(), output_qtype.getZeroPoint(), tanh_func);
2741
2742 CreateReplaceOpAndInfer<tosa::TableOp>(rewriter, op, output_type,
2743 tfl_tanh_op.input(), table_const);
2744 } else { // int16
2745 if (input_qtype.getZeroPoint() != 0 || output_qtype.getZeroPoint() != 0) {
2746 op->emitOpError(
2747 "ConvertTFLLogistic: input/output zeropoint should be 0 in 16-bit "
2748 "mode");
2749 return failure();
2750 }
2751 double input_min = -32768 * input_qtype.getScale();
2752 double input_max = 32767 * input_qtype.getScale();
2753
2754 // Generate table with gen_lut() in
2755 // tensorflow/lite/kernels/internal/common.h
2756 Value table_const =
2757 getTosaConst16bitTable(rewriter, op, tanh_func, input_min, input_max);
2758
2759 auto op1_table_in = CreateOpAndInfer<tosa::TableOp>(
2760 rewriter, op->getLoc(), int32_type, tfl_tanh_op.input(), table_const);
2761
2762 Value op2_rescale_op1 =
2763 buildRescale(rewriter, op, output_type, op1_table_in.getResult(),
2764 1.0 / 128.0, 0, 0, false, true);
2765
2766 rewriter.replaceOp(op, {op2_rescale_op1});
2767 }
2768
2769 } else {
2770 CreateReplaceOpAndInfer<tosa::TanhOp>(rewriter, op, output_type,
2771 tfl_tanh_op.input());
2772 }
2773
2774 return success();
2775 }
2776
2777 static LogicalResult LegalizeFloatingPointPrelu(Operation* op,
2778 PatternRewriter& rewriter,
2779 Value input, Value alpha,
2780 ShapedType output_type) {
2781 Value const_zero = getTosaConstTensorSingleF32(rewriter, op, 0.0);
2782
2783 auto mul = CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
2784 input, alpha, 0);
2785
2786 auto ge = CreateOpAndInfer<tosa::GreaterEqualOp>(
2787 rewriter, op->getLoc(), output_type.clone(rewriter.getIntegerType(1)),
2788 input, const_zero);
2789
2790 CreateReplaceOpAndInfer<tosa::SelectOp>(rewriter, op, output_type, ge, input,
2791 mul);
2792
2793 return success();
2794 }
2795
2796 LogicalResult ConvertTFLPReluOp::matchAndRewrite(
2797 Operation* op, PatternRewriter& rewriter) const {
2798 auto tfl_prelu_op = cast<TFL::PReluOp>(op);
2799
2800 ShapedType input_type = tfl_prelu_op.input().getType().dyn_cast<ShapedType>();
2801 ShapedType output_type =
2802 tfl_prelu_op.getResult().getType().dyn_cast<ShapedType>();
2803 if (!input_type || !output_type)
2804 return rewriter.notifyMatchFailure(op,
2805 "input or output is not a ShapedType");
2806
2807 bool output_is_qtype =
2808 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2809
2810 // TODO(jennik): Handle the quantized case.
2811 if (output_is_qtype)
2812 return rewriter.notifyMatchFailure(
2813 op, "only floating point currently supported");
2814
2815 return LegalizeFloatingPointPrelu(op, rewriter, tfl_prelu_op.input(),
2816 tfl_prelu_op.alpha(), output_type);
2817 }
2818
2819 LogicalResult ConvertTFLLeakyReluOp::matchAndRewrite(
2820 Operation* op, PatternRewriter& rewriter) const {
2821 auto tfl_leakyrelu_op = cast<TFL::LeakyReluOp>(op);
2822 RankedTensorType input_type =
2823 tfl_leakyrelu_op.input().getType().dyn_cast<RankedTensorType>();
2824
2825 ShapedType output_type =
2826 tfl_leakyrelu_op.getResult().getType().dyn_cast<ShapedType>();
2827
2828 if (!input_type || !output_type) return failure();
2829
2830 bool output_is_qtype =
2831 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2832
2833 // Implement LeakyRelu as element-wise:
2834 // out = x > 0 ? x : alpha * x
2835 //
2836 // In TOSA ops:
2837 //
2838 // const_zero = constant(0)
2839 // op1 = mul(x, alpha)
2840 // op2 = greater_equal(x, const_zero)
2841   //   out = select(op2, x, op1)
2842 //
2843 // If alpha can be constrained to 0.0 <= alpha <= 1.0, then
2844 // an alternative simpler lowering could be implemented with:
2845 //
2846   //   max(mul(x, alpha), x)
2847 //
2848 // But this alternative is not robust unless alpha meets those constraints.
2849
2850 FloatAttr tmpAttr = tfl_leakyrelu_op.alphaAttr();
2851 // There is disagreement between the MLIR .td defaults and TF
2852 // documentation on 0.2 vs 0.3, but 0.2 will be used here.
2853 double alpha = 0.2;
2854
2855 if (tmpAttr) {
2856 alpha = tmpAttr.getValueAsDouble();
2857 }
2858
2859 if (output_is_qtype) {
2860 // op1 = rescale(input)
2861 // rescaled_alpha = (alpha << alpha_shift) // Remains within int32 range
2862 // op2 = mul(rescaled_input, rescaled_alpha, alpha_shift)
2863 // op3 = greater_equal(op1, 0)
2864 // op4 = select(op3, op1, op2)
2865 // out = rescale(op4)
2866 ShapedType rescale_type = output_type.clone(rewriter.getI32Type());
2867
2868 UniformQuantizedType input_qtype =
2869 input_type.getElementType().cast<UniformQuantizedType>();
2870
2871 UniformQuantizedType output_qtype =
2872 output_type.getElementType().cast<UniformQuantizedType>();
2873
2874 double scale_alpha =
2875 input_qtype.getScale() * alpha / output_qtype.getScale();
2876 double scale_identity = input_qtype.getScale() / output_qtype.getScale();
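    // scale_alpha folds the input dequantization, the alpha multiply, and the
    // output requantization into a single RESCALE multiplier; scale_identity
    // does the same without alpha, so both select operands end up in the
    // output quantization space.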
2877
2878 Value op1_rescale_in =
2879 buildRescaleToInt32(rewriter, op, tfl_leakyrelu_op.input(), 1.0,
2880 input_qtype.getZeroPoint());
2881
2882 Value const_zero = getTosaConstTensorSingleI32(rewriter, op, 0);
2883 auto op2_ge = CreateOpAndInfer<tosa::GreaterEqualOp>(
2884 rewriter, op->getLoc(), rescale_type.clone(rewriter.getI1Type()),
2885 op1_rescale_in, const_zero);
2886
2887 Value op3_rescale_alpha_in = buildRescale(
2888 rewriter, op, output_type, tfl_leakyrelu_op.input(), scale_alpha,
2889 input_qtype.getZeroPoint(), output_qtype.getZeroPoint(), true, true);
2890
2891 Value op4_rescale_identity_in = buildRescale(
2892 rewriter, op, output_type, tfl_leakyrelu_op.input(), scale_identity,
2893 input_qtype.getZeroPoint(), output_qtype.getZeroPoint(), true, true);
2894
2895 CreateReplaceOpAndInfer<tosa::SelectOp>(rewriter, op, output_type, op2_ge,
2896 op4_rescale_identity_in,
2897 op3_rescale_alpha_in);
2898
2899 return success();
2900
2901 } else {
2902 return LegalizeFloatingPointPrelu(
2903 op, rewriter, tfl_leakyrelu_op.input(),
2904 getTosaConstTensorSingleF32(rewriter, op, alpha), output_type);
2905 }
2906 }
2907
2908 LogicalResult ConvertTFLNegOp::matchAndRewrite(
2909 Operation* op, PatternRewriter& rewriter) const {
2910 auto tfl_neg_op = cast<TFL::NegOp>(op);
2911 ShapedType output_type =
2912 tfl_neg_op.getResult().getType().dyn_cast<ShapedType>();
2913 if (!output_type) return failure();
2914
2915 CreateReplaceOpAndInfer<tosa::NegateOp>(rewriter, op, output_type,
2916 tfl_neg_op.x());
2917
2918 return success();
2919 }
2920
2921 LogicalResult ConvertTFLYieldOp::matchAndRewrite(
2922 Operation* op, PatternRewriter& rewriter) const {
2923 rewriter.replaceOpWithNewOp<tosa::YieldOp>(op, op->getResultTypes(),
2924 op->getOperands());
2925
2926 return success();
2927 }
2928
2929 LogicalResult ConvertTFLCustomOp::matchAndRewrite(
2930 Operation* op, PatternRewriter& rewriter) const {
2931 auto tfl_custom_op = cast<TFL::CustomOp>(op);
2932 rewriter.replaceOpWithNewOp<tosa::CustomOp>(
2933 op, op->getResultTypes(), tfl_custom_op.custom_code(), op->getOperands());
2934
2935 return success();
2936 }
2937
2938 LogicalResult ConvertTFLReverseV2Op::matchAndRewrite(
2939 Operation* op, PatternRewriter& rewriter) const {
2940 auto tfl_reverse_op = cast<TFL::ReverseV2Op>(op);
2941
2942 RankedTensorType input_type =
2943 tfl_reverse_op.input().getType().dyn_cast<RankedTensorType>();
2944 RankedTensorType output_type =
2945 tfl_reverse_op.getResult().getType().dyn_cast<RankedTensorType>();
2946 if (!input_type || !output_type) return failure();
2947
2948 ElementsAttr axis_elems;
2949 if (!matchPattern(tfl_reverse_op.axis(), m_Constant(&axis_elems)))
2950 return failure();
2951
2952 auto input_rank = input_type.getShape().size();
2953 Value val = tfl_reverse_op.input();
2954 if (axis_elems.getNumElements() == 0) {
2955 auto identity_op = CreateOpAndInfer<tosa::IdentityOp>(
2956 rewriter, op->getLoc(), output_type, val);
2957 val = identity_op.getResult();
2958 } else {
2959 for (int i = 0; i < axis_elems.getNumElements(); i++) {
2960 int64_t axis_val = axis_elems.getValues<APInt>()[i].getSExtValue();
2961 if (axis_val < 0) axis_val += input_rank;
2962 auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
2963 auto reverse_op = CreateOpAndInfer<tosa::ReverseOp>(
2964 rewriter, op->getLoc(), output_type, val, axis_attr);
2965
2966 val = reverse_op.getResult();
2967 }
2968 }
2969
2970 rewriter.replaceOp(op, {val});
2971
2972 return success();
2973 }
2974
2975 LogicalResult ConvertTFLQuantizeOp::matchAndRewrite(
2976 Operation* op, PatternRewriter& rewriter) const {
2977 auto tfl_quantize_op = cast<TFL::QuantizeOp>(op);
2978
2979 RankedTensorType input_type =
2980 tfl_quantize_op.input().getType().dyn_cast<RankedTensorType>();
2981 ShapedType output_type =
2982 tfl_quantize_op.getResult().getType().dyn_cast<ShapedType>();
2983 if (!input_type || !output_type) return failure();
2984
2985 ShapedType qtype =
2986 tfl_quantize_op.getResult().getType().dyn_cast<ShapedType>();
2987 if (!qtype) return failure();
2988
2989 UniformQuantizedType element_type =
2990 qtype.getElementType().dyn_cast<UniformQuantizedType>();
2991 if (!element_type) return failure();
2992
2993 UniformQuantizedType input_element_type =
2994 input_type.getElementType().dyn_cast<UniformQuantizedType>();
2995
2996 // If input is already a quantized type, this is basically a RESCALE (or
2997 // tensorflow::ops::Requantize)
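  // Requantize math (for reference): S_in * (q_in - Z_in) = S_out * (q_out -
  // Z_out), so q_out = (S_in / S_out) * (q_in - Z_in) + Z_out, which is what
  // a single RESCALE with scale = S_in / S_out computes.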
  if (input_element_type) {
    double rescale_scale =
        input_element_type.getScale() / element_type.getScale();
    Value rescale_op =
        buildRescale(rewriter, op, output_type, tfl_quantize_op.input(),
                     rescale_scale, input_element_type.getZeroPoint(),
                     element_type.getZeroPoint(), true, true);

    rewriter.replaceOp(op, {rescale_op});
    return success();
  } else {
    double scale = 1 / element_type.getScale();
    int64_t zp = element_type.getZeroPoint();
    int64_t num_bits = element_type.getStorageTypeIntegralWidth();
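    // TOSA quantized tensors use signed storage, so an unsigned zero point
    // (e.g. uint8 in [0, 255]) is re-centered into the signed range by
    // subtracting 2^(num_bits - 1).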
    zp = element_type.isSigned() ? zp : zp - (1 << (num_bits - 1));

    llvm::Optional<Value> result = convertQuantizeOp(
        rewriter, op, output_type, tfl_quantize_op.input(), scale, zp);

    if (!result) return failure();

    rewriter.replaceOp(op, {result.getValue()});

    return success();
  }
}

LogicalResult ConvertTFLDequantizeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_dequantize_op = cast<TFL::DequantizeOp>(op);

  ShapedType output_type =
      tfl_dequantize_op.getResult().getType().dyn_cast<ShapedType>();
  // Not a shaped tensor output.
  if (!output_type) return failure();

  RankedTensorType qtype =
      tfl_dequantize_op.input().getType().dyn_cast<RankedTensorType>();
  if (!qtype) return failure();

  Type element_type = qtype.getElementType();
  if (element_type.isa<FloatType>()) {
    CreateReplaceOpAndInfer<tosa::CastOp>(rewriter, op, output_type,
                                          tfl_dequantize_op.input());
    return success();
  }

  if (auto eq_ty = element_type.dyn_cast<quant::UniformQuantizedType>()) {
    double scale = eq_ty.getScale();
    int64_t zp = eq_ty.getZeroPoint();
    int64_t num_bits = eq_ty.getStorageTypeIntegralWidth();
    zp = eq_ty.isSigned() ? zp : zp - (1 << (num_bits - 1));

    llvm::Optional<Value> result = convertDequantizeOp(
        rewriter, op, output_type, tfl_dequantize_op.input(), scale, zp, 0);

    if (!result) return failure();

    rewriter.replaceOp(op, {result.getValue()});
    return success();
  }

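  // Per-axis (per-channel) quantized types carry one scale and zero point per
  // slice along the quantized dimension, so they are passed to
  // convertDequantizeOp as vectors.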
  if (quant::UniformQuantizedPerAxisType eq_ty =
          element_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
    SmallVector<float> zps;
    for (auto zp : eq_ty.getZeroPoints()) {
      int64_t num_bits = eq_ty.getStorageTypeIntegralWidth();
      zps.push_back(eq_ty.isSigned() ? zp : zp - (1 << (num_bits - 1)));
    }

    SmallVector<float> scales;
    for (auto scale : eq_ty.getScales()) {
      scales.push_back(scale);
    }

    llvm::Optional<Value> result = convertDequantizeOp(
        rewriter, op, output_type, tfl_dequantize_op.input(), scales, zps,
        eq_ty.getQuantizedDimension());

    if (!result) return failure();

    rewriter.replaceOp(op, {result.getValue()});
    return success();
  }

  return failure();
}

LogicalResult ConvertTFLConstOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_const_op = cast<TFL::ConstOp>(op);

  ShapedType output_type =
      tfl_const_op.getResult().getType().dyn_cast<ShapedType>();
  if (!output_type) return failure();

  ElementsAttr elements = tfl_const_op.value();
  Type element_type = elements.getType().getElementType();
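  // A quantized result type stores its payload as plain storage integers in
  // the attribute, so use the attribute's element type to keep the generated
  // tosa.const consistent with the data it holds.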
  if (output_type.getElementType().isa<quant::QuantizedType>()) {
    output_type = RankedTensorType::get(output_type.getShape(), element_type);
  }

  // If the output shape is unranked we can extract the result shape from the
  // attribute shape. This occurs because some TFLite folders create constants
  // with unranked shapes.
  if (!output_type.hasRank()) {
    output_type = elements.getType().cast<ShapedType>().clone(element_type);
  }

  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output_type, elements);

  return success();
}

LogicalResult ConvertTFLQConstOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_qconst_op = cast<TFL::QConstOp>(op);

  ShapedType output_type =
      tfl_qconst_op.getResult().getType().dyn_cast<ShapedType>();
  if (!output_type) return failure();

  ElementsAttr elements = tfl_qconst_op.value();

  // If the output shape is unranked we can extract the result shape from the
  // attribute shape. This occurs because some TFLite folders create constants
  // with unranked shapes.
  if (!output_type.hasRank()) {
    output_type = elements.getType().cast<ShapedType>().clone(
        output_type.getElementType());
  }

  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output_type, elements);

  return success();
}

LogicalResult ConvertConstantOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_const_op = cast<arith::ConstantOp>(op);

  ShapedType output_type =
      tfl_const_op.getResult().getType().dyn_cast<ShapedType>();
  // Not a shaped tensor output.
  if (!output_type) return failure();

  ElementsAttr attr = tfl_const_op.getValueAttr().dyn_cast<ElementsAttr>();

  auto e_type = output_type.getElementType();
  // TOSA only supports integers up to 48 bits, so wider source types are not
  // representable. 64-bit constants are truncated to 48 bits here.
  if (e_type.isInteger(64)) {
    e_type = rewriter.getIntegerType(48);
    attr = attr.cast<DenseIntOrFPElementsAttr>().mapValues(
        e_type, [](const APInt& x) -> APInt { return x.trunc(48); });
  }

  if (!output_type.hasRank()) {
    if (auto attr_type = attr.getType().dyn_cast<ShapedType>()) {
      output_type = attr_type.clone(e_type);
    }
  }

  output_type = output_type.clone(e_type);
  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output_type, attr);

  return success();
}

LogicalResult ConvertTFLGatherOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_gather_op = cast<TFL::GatherOp>(op);

  int32_t axis = tfl_gather_op.axisAttr().getInt();
  int32_t batch_dims = 0;
  if (auto batch_attr = tfl_gather_op.batch_dimsAttr()) {
    batch_dims = static_cast<int32_t>(batch_attr.getInt());
  }

  llvm::Optional<Value> result = convertGatherOp(
      rewriter, op, tfl_gather_op.getResult(), tfl_gather_op.params(),
      tfl_gather_op.indices(), batch_dims, axis);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLGatherNdOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_gathernd_op = cast<TFL::GatherNdOp>(op);

  llvm::Optional<Value> result =
      convertGatherNdOp(rewriter, op, tfl_gathernd_op.getResult(),
                        tfl_gathernd_op.params(), tfl_gathernd_op.indices());

  if (!result) return failure();
  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLSparseToDenseOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_sparse_to_dense_op = cast<TFL::SparseToDenseOp>(op);
  auto indices = tfl_sparse_to_dense_op.sparse_indices();
  auto values = tfl_sparse_to_dense_op.sparse_values();
  auto default_value = tfl_sparse_to_dense_op.default_value();
  auto indices_ty = indices.getType().cast<ShapedType>();
  auto indices_ety = indices_ty.getElementType();
  auto values_ty = values.getType().cast<ShapedType>();
  auto result_ty =
      tfl_sparse_to_dense_op.getResult().getType().cast<ShapedType>();
  auto result_ety = result_ty.getElementType();
  auto loc = op->getLoc();

  if (!result_ty.hasStaticShape()) return failure();
  auto result_rank = result_ty.getRank();

  // Generate the default-value tensor that we scatter into; note that
  // result_ty must be statically shaped for this to work.
  ElementsAttr default_value_attr;
  if (!matchPattern(default_value, m_Constant(&default_value_attr)))
    return failure();

  if (!default_value_attr.isSplat()) return failure();

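  // tosa.scatter works on a [N, K, C] values tensor, so the whole result is
  // flattened into a single batch (N = 1, K = number of result elements,
  // C = 1) that starts out filled with the splat default value.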
  ShapedType scatter_ty =
      RankedTensorType::get({1, result_ty.getNumElements(), 1}, result_ety);

  Value default_const = rewriter.create<tosa::ConstOp>(
      loc, scatter_ty,
      DenseElementsAttr::get(scatter_ty,
                             default_value_attr.getSplatValue<APInt>().sext(
                                 result_ety.getIntOrFloatBitWidth())));

  // Compute the row-major strides used to flatten each N-D sparse index into
  // a linear index into the flattened result; e.g. a [2, 3, 4] result gives
  // strides [12, 4, 1].
  llvm::SmallVector<int32_t> multiply_constant_ints;
  multiply_constant_ints.resize(result_rank, 1);
  for (int i = result_rank - 1; i > 0; i--) {
    multiply_constant_ints[i - 1] =
        result_ty.getDimSize(i) * multiply_constant_ints[i];
  }

  indices_ety = rewriter.getI32Type();
  indices_ty = RankedTensorType::get(indices_ty.getShape(), indices_ety);
  indices = CreateOpAndInfer<tosa::CastOp>(rewriter, loc, indices_ty, indices);

  auto multiply_constant_type =
      RankedTensorType::get({result_rank}, indices_ety);
  auto multiply_constant_attr = DenseElementsAttr::get(
      multiply_constant_type, llvm::makeArrayRef(multiply_constant_ints));
  Value multiply_constant = CreateOpAndInfer<tosa::ConstOp>(
      rewriter, loc, multiply_constant_type, multiply_constant_attr);

  Value multiply_op = CreateOpAndInfer<tosa::MulOp>(
      rewriter, loc, indices_ty, indices, multiply_constant, 0);

  Value reduce_op = CreateOpAndInfer<tosa::ReduceSumOp>(
      rewriter, loc, UnrankedTensorType::get(indices_ety), multiply_op,
      rewriter.getI64IntegerAttr(1));

  auto values_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, loc, UnrankedTensorType::get(result_ety), values,
      rewriter.getI64ArrayAttr(
          ArrayRef<int64_t>{1, values_ty.getDimSize(0), 1}));

  auto index_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, loc, UnrankedTensorType::get(indices_ety), reduce_op,
      rewriter.getI64ArrayAttr(ArrayRef<int64_t>{1, indices_ty.getDimSize(0)}));

  auto scatter = CreateOpAndInfer<tosa::ScatterOp>(
      rewriter, loc, UnrankedTensorType::get(result_ety), default_const,
      index_reshape_op, values_reshape_op);

  CreateReplaceOpAndInfer<tosa::ReshapeOp>(
      rewriter, op, result_ty, scatter,
      rewriter.getI64ArrayAttr(result_ty.getShape()));

  return success();
}

LogicalResult ConvertTFLOneHotOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_one_hot_op = cast<TFL::OneHotOp>(op);

  ElementsAttr depth_elems;
  if (!matchPattern(tfl_one_hot_op.depth(), m_Constant(&depth_elems)))
    return failure();
  int32_t depth = depth_elems.getValues<APInt>()[0].getSExtValue();

  IntegerAttr axisAttr = tfl_one_hot_op.axisAttr();
  int32_t axis = axisAttr.getInt();

  llvm::Optional<Value> result = convertOneHotOp(
      rewriter, op, tfl_one_hot_op.getResult(), tfl_one_hot_op.indices(),
      tfl_one_hot_op.on_value(), tfl_one_hot_op.off_value(), depth, axis);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLArgMaxOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto arg_max_op = cast<TFL::ArgMaxOp>(op);

  ElementsAttr dim_elems;
  if (!matchPattern(arg_max_op.dim(), m_Constant(&dim_elems))) return failure();

  int32_t dim = dim_elems.getValues<APInt>()[0].getSExtValue();
  CreateReplaceOpAndInfer<tosa::ArgMaxOp>(
      rewriter, op, arg_max_op.getType(), arg_max_op.input(),
      rewriter.getIntegerAttr(rewriter.getI64Type(), dim));

  return success();
}

LogicalResult ConvertTFLFakeQuantOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto fakequant_op = cast<TFL::FakeQuantOp>(op);

  ShapedType output_type =
      fakequant_op.getResult().getType().dyn_cast<ShapedType>();
  // Not a shaped tensor output.
  if (!output_type) return failure();

  llvm::Optional<Value> result =
      convertFakeQuantOp(rewriter, op, output_type, fakequant_op.input(),
                         fakequant_op.minAttr().getValueAsDouble(),
                         fakequant_op.maxAttr().getValueAsDouble(),
                         fakequant_op.num_bitsAttr().getInt(),
                         fakequant_op.narrow_rangeAttr().getValue());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult LegalizeTFL::initialize(MLIRContext* context) {
  RewritePatternSet patterns(context);
  mlir::tosa::populateLegalizeTFLPatterns(context, patterns);
  frozen_patterns_ = FrozenRewritePatternSet(
      std::move(patterns), this->disabled_patterns_, this->enabled_patterns_);
  return success();
}

void LegalizeTFL::runOnOperation() {
  if (ApplyPatternsWithShapeResolution(getOperation(), this->frozen_patterns_)
          .failed()) {
    signalPassFailure();
  }
}

}  // namespace

void populateLegalizeTFLPatterns(MLIRContext* ctx,
                                 RewritePatternSet& patterns) {
#define DEF_PATTERN_INSERT(PAT) \
  patterns.addWithLabel<Convert##PAT##Op>({#PAT}, ctx);

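  // Every pattern is registered under a label matching its name so individual
  // legalizations can be enabled or disabled through the pass's
  // disabled_patterns_ / enabled_patterns_ options (applied when the
  // FrozenRewritePatternSet is built in LegalizeTFL::initialize).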
  DEF_PATTERN_INSERT(TFLAbs);
  DEF_PATTERN_INSERT(TFLCeil);
  DEF_PATTERN_INSERT(TFLFloor);
  DEF_PATTERN_INSERT(TFLExp);
  DEF_PATTERN_INSERT(TFLLog);
  DEF_PATTERN_INSERT(TFLRsqrt);
  DEF_PATTERN_INSERT(TFLLogicalNot);
  DEF_PATTERN_INSERT(TFLCast);

  DEF_PATTERN_INSERT(QuantStat);

  DEF_PATTERN_INSERT(TFLLogicalAnd);
  DEF_PATTERN_INSERT(TFLLogicalOr);
  DEF_PATTERN_INSERT(TFLPow);

  DEF_PATTERN_INSERT(TFLRelu);
  DEF_PATTERN_INSERT(TFLRelu1);
  DEF_PATTERN_INSERT(TFLRelu6);
  DEF_PATTERN_INSERT(TFLEqual);
  DEF_PATTERN_INSERT(TFLNotEqual);
  DEF_PATTERN_INSERT(TFLGreater);
  DEF_PATTERN_INSERT(TFLGreaterEqual);
  DEF_PATTERN_INSERT(TFLAdd);
  DEF_PATTERN_INSERT(TFLSub);
  DEF_PATTERN_INSERT(TFLMul);
  DEF_PATTERN_INSERT(TFLSquare);
  DEF_PATTERN_INSERT(TFLSquaredDifference);
  DEF_PATTERN_INSERT(TFLRound);
  DEF_PATTERN_INSERT(TFLDiv);
  DEF_PATTERN_INSERT(TFLMaximum);
  DEF_PATTERN_INSERT(TFLMinimum);
  DEF_PATTERN_INSERT(TFLFloorMod);
  DEF_PATTERN_INSERT(TFLFloorDiv);
  DEF_PATTERN_INSERT(TFLAddN);
  DEF_PATTERN_INSERT(TFLAveragePool2D);
  DEF_PATTERN_INSERT(TFLMaxPool2D);
  DEF_PATTERN_INSERT(TFLConcatenation);
  DEF_PATTERN_INSERT(TFLReshape);
  DEF_PATTERN_INSERT(TFLRank);
  DEF_PATTERN_INSERT(TFLShape);
  DEF_PATTERN_INSERT(TFLExpandDims);
  DEF_PATTERN_INSERT(TFLSqueeze);
  DEF_PATTERN_INSERT(TFLFill);
  DEF_PATTERN_INSERT(TFLElu);
  DEF_PATTERN_INSERT(TFLSoftmax);
  DEF_PATTERN_INSERT(TFLLogSoftmax);
  DEF_PATTERN_INSERT(TFLSqrt);
  DEF_PATTERN_INSERT(TFLL2Normalization);
  DEF_PATTERN_INSERT(TFLReduceAny);
  DEF_PATTERN_INSERT(TFLReduceMax);
  DEF_PATTERN_INSERT(TFLReduceMin);
  DEF_PATTERN_INSERT(TFLMean);
  DEF_PATTERN_INSERT(TFLReduceProd);
  DEF_PATTERN_INSERT(TFLSum);
  DEF_PATTERN_INSERT(TFLConv2D);
  DEF_PATTERN_INSERT(TFLTransposeConv);
  DEF_PATTERN_INSERT(TFLDepthwiseConv2D);
  DEF_PATTERN_INSERT(TFLFullyConnected);
  DEF_PATTERN_INSERT(TFLBatchMatMul);
  DEF_PATTERN_INSERT(TFLSplit);
  DEF_PATTERN_INSERT(TFLSplitV);
  DEF_PATTERN_INSERT(TFLPack);
  DEF_PATTERN_INSERT(TFLUnpack);
  DEF_PATTERN_INSERT(TFLTranspose);
  DEF_PATTERN_INSERT(TFLTile);
  DEF_PATTERN_INSERT(TFLSlice);
  DEF_PATTERN_INSERT(TFLStridedSlice);
  DEF_PATTERN_INSERT(TFLHardSwish);
  DEF_PATTERN_INSERT(TFLZerosLike);
  DEF_PATTERN_INSERT(TFLLess);
  DEF_PATTERN_INSERT(TFLLessEqual);
  DEF_PATTERN_INSERT(TFLPad);
  DEF_PATTERN_INSERT(TFLPadV2);
  DEF_PATTERN_INSERT(TFLResizeBilinear);
  DEF_PATTERN_INSERT(TFLResizeNearestNeighbor);
  DEF_PATTERN_INSERT(TFLSelect);
  DEF_PATTERN_INSERT(TFLSelectV2);
  DEF_PATTERN_INSERT(TFLSpaceToBatchNd);
  DEF_PATTERN_INSERT(TFLBatchToSpaceNd);
  DEF_PATTERN_INSERT(TFLSpaceToDepth);
  DEF_PATTERN_INSERT(TFLDepthToSpace);
  DEF_PATTERN_INSERT(TFLSin);
  DEF_PATTERN_INSERT(TFLCos);
  DEF_PATTERN_INSERT(TFLLogistic);
  DEF_PATTERN_INSERT(TFLTanh);
  DEF_PATTERN_INSERT(TFLPRelu);
  DEF_PATTERN_INSERT(TFLLeakyRelu);
  DEF_PATTERN_INSERT(TFLNeg);
  DEF_PATTERN_INSERT(TFLYield);
  DEF_PATTERN_INSERT(TFLCustom);
  DEF_PATTERN_INSERT(TFLReverseV2);
  DEF_PATTERN_INSERT(TFLQuantize);
  DEF_PATTERN_INSERT(TFLDequantize);
  DEF_PATTERN_INSERT(TFLConst);
  DEF_PATTERN_INSERT(TFLQConst);
  DEF_PATTERN_INSERT(TFLGather);
  DEF_PATTERN_INSERT(TFLGatherNd);
  DEF_PATTERN_INSERT(TFLSparseToDense);
  DEF_PATTERN_INSERT(Constant);
  DEF_PATTERN_INSERT(TFLOneHot);
  DEF_PATTERN_INSERT(TFLArgMax);
  DEF_PATTERN_INSERT(TFLFakeQuant);
}

// Creates an instance of the TensorFlow Lite dialect LegalizeTFL pass.
std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeTFLPass(
    ArrayRef<std::string> disabled_patterns,
    ArrayRef<std::string> enabled_patterns) {
  return std::make_unique<LegalizeTFL>(disabled_patterns, enabled_patterns);
}

}  // namespace tosa
}  // namespace mlir