/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Legalize TensorFlow Lite to TOSA

#include <climits>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iterator>
#include <numeric>
#include <unordered_set>

#include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"

#define PASS_NAME "tosa-legalize-tfl"
#define DEBUG_TYPE PASS_NAME
#define HARDSWISH_EXPLICIT_RESCALING false

// Conditionally avoid converting some TFLite ops to TOSA.
// By default, all conversions will be invoked.
//
// The denylist file lists patterns which are not legalized from TFLite to
// TOSA.
llvm::cl::opt<std::string> tfl_tosa_denylist(
    "tfl-tosa-denylist",
    llvm::cl::desc("<a list of patterns not legalized from TFLite to TOSA>"),
    llvm::cl::init("transforms/tfl_tosa_denylist.txt"),
    llvm::cl::value_desc("pattern name"));
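// The expected file format is one pattern name per line; for example
// (illustrative names only, which must match registered pattern names):
//   ConvertTFLReluOp
//   ConvertTFLSoftmaxOp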

namespace mlir {

namespace tosa {

namespace {
// Performs lowering to TOSA dialect.
class LegalizeTFL : public PassWrapper<LegalizeTFL, FunctionPass> {
 public:
  explicit LegalizeTFL() {}
  void runOnFunction() override;
};

#include "tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.inc"

#define DECL_CONVERT_OP(tfl_op)                                              \
  struct ConvertTFL##tfl_op##Op : public RewritePattern {                    \
    explicit ConvertTFL##tfl_op##Op(MLIRContext* context)                    \
        : RewritePattern(TFL::tfl_op##Op::getOperationName(), 1, context) {} \
    LogicalResult matchAndRewrite(Operation* op,                             \
                                  PatternRewriter& rewriter) const override; \
  }
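// For example, DECL_CONVERT_OP(Relu) declares ConvertTFLReluOp, a
// RewritePattern with benefit 1 that matches tfl.relu; the corresponding
// matchAndRewrite is defined out-of-line below.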
DECL_CONVERT_OP(Relu);
DECL_CONVERT_OP(Relu6);
DECL_CONVERT_OP(Equal);
DECL_CONVERT_OP(NotEqual);
DECL_CONVERT_OP(Greater);
DECL_CONVERT_OP(GreaterEqual);
DECL_CONVERT_OP(Add);
DECL_CONVERT_OP(Sub);
DECL_CONVERT_OP(Mul);
DECL_CONVERT_OP(Square);
DECL_CONVERT_OP(SquaredDifference);
DECL_CONVERT_OP(Round);
DECL_CONVERT_OP(Div);
DECL_CONVERT_OP(Maximum);
DECL_CONVERT_OP(Minimum);
DECL_CONVERT_OP(FloorMod);
DECL_CONVERT_OP(FloorDiv);
DECL_CONVERT_OP(AddN);
DECL_CONVERT_OP(AveragePool2D);
DECL_CONVERT_OP(MaxPool2D);
DECL_CONVERT_OP(Concatenation);
DECL_CONVERT_OP(Reshape);
DECL_CONVERT_OP(Rank);
DECL_CONVERT_OP(Shape);
DECL_CONVERT_OP(ExpandDims);
DECL_CONVERT_OP(Squeeze);
DECL_CONVERT_OP(Fill);
DECL_CONVERT_OP(Elu);
DECL_CONVERT_OP(Softmax);
DECL_CONVERT_OP(LogSoftmax);
DECL_CONVERT_OP(ReduceAny);
DECL_CONVERT_OP(ReduceMax);
DECL_CONVERT_OP(ReduceMin);
DECL_CONVERT_OP(Mean);
DECL_CONVERT_OP(ReduceProd);
DECL_CONVERT_OP(Sum);
DECL_CONVERT_OP(Conv2D);
DECL_CONVERT_OP(TransposeConv);
DECL_CONVERT_OP(DepthwiseConv2D);
DECL_CONVERT_OP(FullyConnected);
DECL_CONVERT_OP(Split);
DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(Pack);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(Transpose);
DECL_CONVERT_OP(Tile);
DECL_CONVERT_OP(Slice);
DECL_CONVERT_OP(StridedSlice);
DECL_CONVERT_OP(HardSwish);
DECL_CONVERT_OP(ZerosLike);
DECL_CONVERT_OP(Less);
DECL_CONVERT_OP(LessEqual);
DECL_CONVERT_OP(Pad);
DECL_CONVERT_OP(ResizeBilinear);
DECL_CONVERT_OP(ResizeNearestNeighbor);
DECL_CONVERT_OP(Select);
DECL_CONVERT_OP(SelectV2);
DECL_CONVERT_OP(SpaceToBatchNd);
DECL_CONVERT_OP(BatchToSpaceNd);
DECL_CONVERT_OP(SpaceToDepth);
DECL_CONVERT_OP(DepthToSpace);
DECL_CONVERT_OP(Logistic);
DECL_CONVERT_OP(Tanh);
DECL_CONVERT_OP(PRelu);
DECL_CONVERT_OP(LeakyRelu);
DECL_CONVERT_OP(Neg);
DECL_CONVERT_OP(Yield);
DECL_CONVERT_OP(Custom);
DECL_CONVERT_OP(ReverseV2);
DECL_CONVERT_OP(Quantize);
DECL_CONVERT_OP(Dequantize);
DECL_CONVERT_OP(QConst);
#undef DECL_CONVERT_OP

LogicalResult ConvertTFLReluOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_relu_op = cast<TFL::ReluOp>(op);

  RankedTensorType input_type =
      tfl_relu_op.x().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_relu_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type || !output_type) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLReluOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype) {
    RankedTensorType rescale_type =
        RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
    UniformQuantizedType input_qtype =
        input_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType output_qtype =
        output_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();

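    // Quantized lowering: rescale the input to int32 (removing its zero
    // point), clamp with ReluN at the int32 maximum, then rescale into the
    // output type (adding its zero point).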
    Value op1_rescale_in = buildRescaleToInt32(
        rewriter, op, tfl_relu_op.x(), 1.0f, input_qtype.getZeroPoint());
    auto op2_relun_op1 = rewriter.create<tosa::ReluNOp>(
        op->getLoc(), rescale_type, op1_rescale_in,
        rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
        rewriter.getF32FloatAttr(0.0f));
    Value op3_rescale_op2 = buildRescaleFromInt32(
        rewriter, op, output_type, op2_relun_op1.getResult(), 1.0f,
        output_qtype.getZeroPoint());

    output = op3_rescale_op2;
  } else {
    auto op1_relun_in = rewriter.create<tosa::ReluNOp>(
        op->getLoc(), output_type, tfl_relu_op.x(),
        rewriter.getI64IntegerAttr(0),
        rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));

    output = op1_relun_in.getResult();
  }

  rewriter.replaceOp(op, {output});
  return success();
}

LogicalResult ConvertTFLRelu6Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_relu6_op = cast<TFL::Relu6Op>(op);

  RankedTensorType input_type =
      tfl_relu6_op.x().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_relu6_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type || !output_type) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLRelu6Op: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype && input_is_qtype) {
    RankedTensorType rescale_type =
        RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
    UniformQuantizedType input_qtype =
        input_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType output_qtype =
        output_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    int64_t rescaled_6 = std::llround(6.0f / input_qtype.getScale()) +
                         input_qtype.getZeroPoint();
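    // For example (illustrative values): with scale = 6/255 and zero point 0,
    // rescaled_6 = llround(6.0 / (6 / 255.0)) + 0 = 255, i.e. the quantized
    // representation of 6.0 used as the ReluN clamp limit.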

    Value op1_rescale_in = buildRescaleToInt32(
        rewriter, op, tfl_relu6_op.x(), 1.0f, input_qtype.getZeroPoint());
    auto op2_relun_op1 = rewriter.create<tosa::ReluNOp>(
        op->getLoc(), rescale_type, op1_rescale_in,
        rewriter.getI64IntegerAttr(rescaled_6), rewriter.getF32FloatAttr(0.0f));
    Value op3_rescale_op2 = buildRescaleFromInt32(
        rewriter, op, output_type, op2_relun_op1.getResult(), 1.0f,
        output_qtype.getZeroPoint());

    output = op3_rescale_op2;
  } else {
    auto op1_relun_in = rewriter.create<tosa::ReluNOp>(
        op->getLoc(), output_type, tfl_relu6_op.x(),
        rewriter.getI64IntegerAttr(0), rewriter.getF32FloatAttr(6.0f));

    output = op1_relun_in.getResult();
  }

  rewriter.replaceOp(op, {output});
  return success();
}

// TODO: Use a utility function for common code in comparison ops.
LogicalResult ConvertTFLEqualOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_equal_op = cast<TFL::EqualOp>(op);

  RankedTensorType input_x_type =
      tfl_equal_op.x().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_y_type =
      tfl_equal_op.y().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_x_type || !input_y_type || !output_type) return failure();

  bool input_x_is_qtype =
      input_x_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_y_is_qtype =
      input_y_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_x_is_qtype != output_is_qtype ||
      input_y_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLEqualOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype && input_x_is_qtype && input_y_is_qtype) {
    UniformQuantizedType input_x_qtype =
        input_x_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType input_y_qtype =
        input_y_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();

    if (input_x_qtype.getScale() != input_y_qtype.getScale() ||
        input_x_qtype.getZeroPoint() != input_y_qtype.getZeroPoint()) {
      return op->emitOpError(
          "ConvertTFLEqualOp: input_x and input_y scale/zp "
          "must be the same");
    }

    Value op1_rescale_x = buildRescaleToInt32(
        rewriter, op, tfl_equal_op.x(), 1.0f, input_x_qtype.getZeroPoint());
    Value op2_rescale_y = buildRescaleToInt32(
        rewriter, op, tfl_equal_op.y(), 1.0f, input_y_qtype.getZeroPoint());
    auto op3_equal_op1_op2 = rewriter.create<tosa::EqualOp>(
        op->getLoc(), output_type, op1_rescale_x, op2_rescale_y);

    output = op3_equal_op1_op2.getResult();
  } else {
    auto op1_equal_in = rewriter.create<tosa::EqualOp>(
        op->getLoc(), output_type, tfl_equal_op.x(), tfl_equal_op.y());

    output = op1_equal_in.getResult();
  }

  rewriter.replaceOp(op, {output});
  return success();
}

LogicalResult ConvertTFLNotEqualOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_not_equal_op = cast<TFL::NotEqualOp>(op);

  RankedTensorType input_lhs_type =
      tfl_not_equal_op.lhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_rhs_type =
      tfl_not_equal_op.rhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_not_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_lhs_type || !input_rhs_type || !output_type) return failure();

  bool input_lhs_is_qtype =
      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_rhs_is_qtype =
      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_lhs_is_qtype != output_is_qtype ||
      input_rhs_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLNotEqualOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
    UniformQuantizedType input_lhs_qtype =
        input_lhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType input_rhs_qtype =
        input_rhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();

    if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
        input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
      return op->emitOpError(
          "ConvertTFLNotEqualOp: input_lhs and input_rhs scale/zp "
          "must be the same");
    }

    Value op1_rescale_lhs =
        buildRescaleToInt32(rewriter, op, tfl_not_equal_op.lhs(), 1.0f,
                            input_lhs_qtype.getZeroPoint());
    Value op2_rescale_rhs =
        buildRescaleToInt32(rewriter, op, tfl_not_equal_op.rhs(), 1.0f,
                            input_rhs_qtype.getZeroPoint());
    auto op3_equal_op1_op2 = rewriter.create<tosa::EqualOp>(
        op->getLoc(), output_type, op1_rescale_lhs, op2_rescale_rhs);
    auto op4_not_op3 = rewriter.create<tosa::LogicalNotOp>(
        op->getLoc(), output_type, op3_equal_op1_op2.getResult());

    output = op4_not_op3.getResult();
  } else {
    auto op1_equal_in = rewriter.create<tosa::EqualOp>(
        op->getLoc(), output_type, tfl_not_equal_op.lhs(),
        tfl_not_equal_op.rhs());
    auto op2_not_op1 = rewriter.create<tosa::LogicalNotOp>(
        op->getLoc(), output_type, op1_equal_in.getResult());

    output = op2_not_op1.getResult();
  }

  rewriter.replaceOp(op, {output});
  return success();
}

LogicalResult ConvertTFLGreaterOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_greater_op = cast<TFL::GreaterOp>(op);

  RankedTensorType input_lhs_type =
      tfl_greater_op.lhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_rhs_type =
      tfl_greater_op.rhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_greater_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_lhs_type || !input_rhs_type || !output_type) return failure();

  bool input_lhs_is_qtype =
      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_rhs_is_qtype =
      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_lhs_is_qtype != output_is_qtype ||
      input_rhs_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLGreaterOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
    UniformQuantizedType input_lhs_qtype =
        input_lhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType input_rhs_qtype =
        input_rhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();

    if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
        input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
      return op->emitOpError(
          "ConvertTFLGreaterOp: input_lhs and input_rhs scale/zp "
          "must be the same");
    }

    Value op1_rescale_lhs =
        buildRescaleToInt32(rewriter, op, tfl_greater_op.lhs(), 1.0f,
                            input_lhs_qtype.getZeroPoint());
    Value op2_rescale_rhs =
        buildRescaleToInt32(rewriter, op, tfl_greater_op.rhs(), 1.0f,
                            input_rhs_qtype.getZeroPoint());
    auto op3_greater_op1_op2 = rewriter.create<tosa::GreaterOp>(
        op->getLoc(), output_type, op1_rescale_lhs, op2_rescale_rhs);

    output = op3_greater_op1_op2.getResult();
  } else {
    auto op1_greater_in = rewriter.create<tosa::GreaterOp>(
        op->getLoc(), output_type, tfl_greater_op.lhs(), tfl_greater_op.rhs());

    output = op1_greater_in.getResult();
  }

  rewriter.replaceOp(op, {output});
  return success();
}

LogicalResult ConvertTFLGreaterEqualOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_greater_equal_op = cast<TFL::GreaterEqualOp>(op);

  RankedTensorType input_lhs_type =
      tfl_greater_equal_op.lhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_rhs_type =
      tfl_greater_equal_op.rhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_greater_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_lhs_type || !input_rhs_type || !output_type) return failure();

  bool input_lhs_is_qtype =
      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_rhs_is_qtype =
      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_lhs_is_qtype != output_is_qtype ||
      input_rhs_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLGreaterEqualOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
    UniformQuantizedType input_lhs_qtype =
        input_lhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType input_rhs_qtype =
        input_rhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();

    if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
        input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
      return op->emitOpError(
          "ConvertTFLGreaterEqualOp: input_lhs and input_rhs scale/zp "
          "must be the same");
    }

    Value op1_rescale_lhs =
        buildRescaleToInt32(rewriter, op, tfl_greater_equal_op.lhs(), 1.0f,
                            input_lhs_qtype.getZeroPoint());
    Value op2_rescale_rhs =
        buildRescaleToInt32(rewriter, op, tfl_greater_equal_op.rhs(), 1.0f,
                            input_rhs_qtype.getZeroPoint());
    auto op3_greater_equal_op1_op2 = rewriter.create<tosa::GreaterEqualOp>(
        op->getLoc(), output_type, op1_rescale_lhs, op2_rescale_rhs);

    output = op3_greater_equal_op1_op2.getResult();
  } else {
    auto op1_greater_equal_in = rewriter.create<tosa::GreaterEqualOp>(
        op->getLoc(), output_type, tfl_greater_equal_op.lhs(),
        tfl_greater_equal_op.rhs());

    output = op1_greater_equal_in.getResult();
  }

  rewriter.replaceOp(op, {output});
  return success();
}

// TODO: Use a utility function for common code in elementwise binary ops.
LogicalResult ConvertTFLAddOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_add_op = cast<TFL::AddOp>(op);

  RankedTensorType input_lhs_type =
      tfl_add_op.lhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_rhs_type =
      tfl_add_op.rhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_add_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_lhs_type || !input_rhs_type || !output_type) return failure();

  bool input_lhs_is_qtype =
      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_rhs_is_qtype =
      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_lhs_is_qtype != output_is_qtype ||
      input_rhs_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLAddOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
    RankedTensorType rescale_type =
        RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
    UniformQuantizedType input_lhs_qtype =
        input_lhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType input_rhs_qtype =
        input_rhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType output_qtype =
        output_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();

    // Follows the quantization scheme described in
    // tensorflow/lite/kernels/add.cc. In detail, it:
    // 1. Rescales the inputs to scale = 2.0 x max(lhs.scale, rhs.scale)
    // 2. Applies an extra left shift to the inputs to increase precision,
    //    where input_shift = 20 if the input is 8-bit and
    //          input_shift = 15 if the input is 16-bit.
    // TODO: support 16-bit
    double in_lhs_scale = input_lhs_qtype.getScale();
    double in_rhs_scale = input_rhs_qtype.getScale();
    double output_scale = output_qtype.getScale();
    double max_scale_2x = 2.0 * std::max(in_lhs_scale, in_rhs_scale);

    const int32_t SHIFT_8_BIT = 20;
    int32_t input_shift = SHIFT_8_BIT;

    double lhs_rescale_scale =
        static_cast<double>(1 << input_shift) * in_lhs_scale / max_scale_2x;
    double rhs_rescale_scale =
        static_cast<double>(1 << input_shift) * in_rhs_scale / max_scale_2x;
    double output_rescale_scale =
        max_scale_2x / (output_scale * static_cast<double>(1 << input_shift));
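    // For example (illustrative): if in_lhs_scale == in_rhs_scale == s, then
    // max_scale_2x = 2 * s and both input rescale scales equal
    // (1 << 20) * s / (2 * s) = 2^19, while output_rescale_scale is
    // 2 * s / (output_scale * 2^20); the shift cancels, so the end-to-end
    // scaling is s / output_scale.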

    Value op1_rescale_lhs =
        buildRescaleToInt32(rewriter, op, tfl_add_op.lhs(), lhs_rescale_scale,
                            input_lhs_qtype.getZeroPoint());
    Value op2_rescale_rhs =
        buildRescaleToInt32(rewriter, op, tfl_add_op.rhs(), rhs_rescale_scale,
                            input_rhs_qtype.getZeroPoint());
    auto op3_add_op1_op2 = rewriter.create<tosa::AddOp>(
        op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
    Value op4_rescale_op3 = buildRescaleFromInt32(
        rewriter, op, output_type, op3_add_op1_op2.getResult(),
        output_rescale_scale, output_qtype.getZeroPoint());
    output = op4_rescale_op3;
  } else {
    auto op1_add_in = rewriter.create<tosa::AddOp>(
        op->getLoc(), output_type, tfl_add_op.lhs(), tfl_add_op.rhs());

    output = op1_add_in.getResult();
  }

  auto fused_activation_fn = tfl_add_op.fused_activation_functionAttr();

  if (fused_activation_fn) {
    llvm::Optional<Value> fused_activation_val =
        convertFusedActivation(rewriter, op, output, fused_activation_fn);

    if (!fused_activation_val) return failure();

    rewriter.replaceOp(op, {fused_activation_val.getValue()});
    return success();
  }

  rewriter.replaceOp(op, {output});
  return success();
}

LogicalResult ConvertTFLSubOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_sub_op = cast<TFL::SubOp>(op);

  RankedTensorType input_lhs_type =
      tfl_sub_op.lhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_rhs_type =
      tfl_sub_op.rhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_sub_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_lhs_type || !input_rhs_type || !output_type) return failure();

  bool input_lhs_is_qtype =
      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_rhs_is_qtype =
      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_lhs_is_qtype != output_is_qtype ||
      input_rhs_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLSubOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
    RankedTensorType rescale_type =
        RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
    UniformQuantizedType input_lhs_qtype =
        input_lhs_type.getElementType()
            .cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType input_rhs_qtype =
        input_rhs_type.getElementType()
            .cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType output_qtype =
        output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();

    // Follows the quantization scheme described in
    // tensorflow/lite/kernels/add.cc. In detail, it:
    // 1. Rescales the inputs to scale = 2.0 x max(lhs.scale, rhs.scale)
    // 2. Applies an extra left shift to the inputs to increase precision,
    //    where input_shift = 20 if the input is 8-bit and
    //          input_shift = 15 if the input is 16-bit.
    // TODO: support 16-bit
    double in_lhs_scale = input_lhs_qtype.getScale();
    double in_rhs_scale = input_rhs_qtype.getScale();
    double output_scale = output_qtype.getScale();
    double max_scale_2x = 2.0 * std::max(in_lhs_scale, in_rhs_scale);

    const int32_t SHIFT_8_BIT = 20;
    int32_t input_shift = SHIFT_8_BIT;

    double lhs_rescale_scale =
        static_cast<double>(1 << input_shift) * in_lhs_scale / max_scale_2x;
    double rhs_rescale_scale =
        static_cast<double>(1 << input_shift) * in_rhs_scale / max_scale_2x;
    double output_rescale_scale =
        max_scale_2x / (output_scale * static_cast<double>(1 << input_shift));

    Value op1_rescale_lhs =
        buildRescaleToInt32(rewriter, op, tfl_sub_op.lhs(), lhs_rescale_scale,
                            input_lhs_qtype.getZeroPoint());
    Value op2_rescale_rhs =
        buildRescaleToInt32(rewriter, op, tfl_sub_op.rhs(), rhs_rescale_scale,
                            input_rhs_qtype.getZeroPoint());
    auto op3_sub_op1_op2 = rewriter.create<tosa::SubOp>(
        op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
    Value op4_rescale_op3 = buildRescaleFromInt32(
        rewriter, op, output_type, op3_sub_op1_op2.getResult(),
        output_rescale_scale, output_qtype.getZeroPoint());
    output = op4_rescale_op3;
  } else {
    auto op1_sub_in = rewriter.create<tosa::SubOp>(
        op->getLoc(), output_type, tfl_sub_op.lhs(), tfl_sub_op.rhs());

    output = op1_sub_in.getResult();
  }

  auto fused_activation_fn = tfl_sub_op.fused_activation_functionAttr();

  if (fused_activation_fn) {
    llvm::Optional<Value> fused_activation_val =
        convertFusedActivation(rewriter, op, output, fused_activation_fn);

    if (!fused_activation_val) return failure();

    rewriter.replaceOp(op, {fused_activation_val.getValue()});
    return success();
  }

  rewriter.replaceOp(op, {output});
  return success();
}

LogicalResult ConvertTFLMulOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_mul_op = cast<TFL::MulOp>(op);

  llvm::Optional<Value> result = convertMultiplyOp(
      rewriter, op, tfl_mul_op.getResult(), tfl_mul_op.lhs(), tfl_mul_op.rhs());

  if (!result) return failure();

  auto fused_activation_fn = tfl_mul_op.fused_activation_functionAttr();

  if (fused_activation_fn) {
    llvm::Optional<Value> fused_activation_val = convertFusedActivation(
        rewriter, op, result.getValue(), fused_activation_fn);

    if (!fused_activation_val) return failure();

    rewriter.replaceOp(op, {fused_activation_val.getValue()});
    return success();
  }

  rewriter.replaceOp(op, {result.getValue()});
  return success();
}

LogicalResult ConvertTFLSquareOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_square_op = cast<TFL::SquareOp>(op);

  llvm::Optional<Value> result =
      convertMultiplyOp(rewriter, op, tfl_square_op.getResult(),
                        tfl_square_op.x(), tfl_square_op.x());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});
  return success();
}

LogicalResult ConvertTFLSquaredDifferenceOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_squared_op = cast<TFL::SquaredDifferenceOp>(op);

  llvm::Optional<Value> result =
      convertSquaredDifferenceOp(rewriter, op, tfl_squared_op.getResult(),
                                 tfl_squared_op.lhs(), tfl_squared_op.rhs());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});
  return success();
}

LogicalResult ConvertTFLRoundOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_round_op = cast<TFL::RoundOp>(op);

  RankedTensorType input_type =
      tfl_round_op.x().getType().dyn_cast<RankedTensorType>();
  if (!input_type) {
    return op->emitOpError("Round: input not ranked tensor type");
  }

  if (input_type.getElementType().isa<FloatType>()) {
    llvm::Optional<Value> result = convertRoundOp(
        rewriter, op, tfl_round_op.getResult(), tfl_round_op.x());

    if (!result) return failure();

    rewriter.replaceOp(op, {result.getValue()});
    return success();

  } else {
    // Round on int is nonsensical. Instead, replace uses of result with the
    // input.
    tfl_round_op.replaceAllUsesWith(tfl_round_op.x());
    return success();
  }
}

LogicalResult ConvertTFLDivOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_div_op = cast<TFL::DivOp>(op);

  RankedTensorType output_type =
      tfl_div_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  auto fused_activation_fn = tfl_div_op.fused_activation_functionAttr();

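  // tfl.div is lowered as lhs * reciprocal(rhs), i.e. a tosa.reciprocal
  // followed by a tosa.mul (with shift 0).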
  auto reciprocal_op = rewriter.create<tosa::ReciprocalOp>(
      op->getLoc(), output_type, tfl_div_op.rhs());
  auto mul_op =
      rewriter.create<tosa::MulOp>(op->getLoc(), output_type, tfl_div_op.lhs(),
                                   reciprocal_op.getResult(), 0);

  if (fused_activation_fn) {
    llvm::Optional<Value> fused_activation_val = convertFusedActivation(
        rewriter, op, mul_op.getResult(), fused_activation_fn);

    if (!fused_activation_val) return failure();

    rewriter.replaceOp(op, {fused_activation_val.getValue()});
    return success();
  }

  rewriter.replaceOp(op, {mul_op.getResult()});

  return success();
}

LogicalResult ConvertTFLMaximumOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_max_op = cast<TFL::MaximumOp>(op);

  RankedTensorType input_lhs_type =
      tfl_max_op.lhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_rhs_type =
      tfl_max_op.rhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_max_op.getResult().getType().dyn_cast<RankedTensorType>();

  // Not a ranked tensor output
  if (!input_lhs_type || !input_rhs_type || !output_type) return failure();

  bool input_lhs_is_qtype =
      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_rhs_is_qtype =
      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_lhs_is_qtype != output_is_qtype ||
      input_rhs_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLMaximumOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype) {
    RankedTensorType rescale_type =
        RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());

    Value op1_rescale_lhs =
        buildRescaleToInt32(rewriter, op, tfl_max_op.lhs(), 1.0f, 0);
    Value op2_rescale_rhs =
        buildRescaleToInt32(rewriter, op, tfl_max_op.rhs(), 1.0f, 0);
    auto op3_max_op1_op2 = rewriter.create<tosa::MaximumOp>(
        op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
    Value op4_rescale_op3 = buildRescaleFromInt32(
        rewriter, op, output_type, op3_max_op1_op2.getResult(), 1.0f, 0);

    output = op4_rescale_op3;
  } else {
    auto op1_max_in = rewriter.create<tosa::MaximumOp>(
        op->getLoc(), output_type, tfl_max_op.lhs(), tfl_max_op.rhs());

    output = op1_max_in.getResult();
  }

  rewriter.replaceOp(op, {output});

  return success();
}

LogicalResult ConvertTFLMinimumOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_min_op = cast<TFL::MinimumOp>(op);

  RankedTensorType input_lhs_type =
      tfl_min_op.lhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_rhs_type =
      tfl_min_op.rhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_min_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_lhs_type || !input_rhs_type || !output_type) return failure();

  bool input_lhs_is_qtype =
      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_rhs_is_qtype =
      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_lhs_is_qtype != output_is_qtype ||
      input_rhs_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLMinimumOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype) {
    RankedTensorType rescale_type =
        RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());

    Value op1_rescale_lhs =
        buildRescaleToInt32(rewriter, op, tfl_min_op.lhs(), 1.0f, 0);
    Value op2_rescale_rhs =
        buildRescaleToInt32(rewriter, op, tfl_min_op.rhs(), 1.0f, 0);
    auto op3_min_op1_op2 = rewriter.create<tosa::MinimumOp>(
        op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
    Value op4_rescale_op3 = buildRescaleFromInt32(
        rewriter, op, output_type, op3_min_op1_op2.getResult(), 1.0f, 0);

    output = op4_rescale_op3;
  } else {
    auto op1_min_in = rewriter.create<tosa::MinimumOp>(
        op->getLoc(), output_type, tfl_min_op.lhs(), tfl_min_op.rhs());

    output = op1_min_in.getResult();
  }

  rewriter.replaceOp(op, {output});

  return success();
}

LogicalResult ConvertTFLFloorDivOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_floordiv_op = cast<TFL::FloorDivOp>(op);

  llvm::Optional<Value> result =
      convertFloorDivOp(rewriter, op, tfl_floordiv_op.getResult(),
                        tfl_floordiv_op.lhs(), tfl_floordiv_op.rhs());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLFloorModOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_floormod_op = cast<TFL::FloorModOp>(op);

  llvm::Optional<Value> result =
      convertFloorModOp(rewriter, op, tfl_floormod_op.getResult(),
                        tfl_floormod_op.lhs(), tfl_floormod_op.rhs());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLAddNOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_addn_op = cast<TFL::AddNOp>(op);

  RankedTensorType output_type =
      tfl_addn_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  SmallVector<Value, 4> inputs(tfl_addn_op.inputs());

  assert(inputs.size() >= 2);

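  // Lower tfl.add_n as a left-to-right chain of binary tosa.add ops:
  // acc = in[0] + in[1], then acc = in[i] + acc for each remaining input.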
  auto newOp = rewriter.create<tosa::AddOp>(op->getLoc(), output_type,
                                            inputs[0], inputs[1]);
  for (int i = 2; i < inputs.size(); i++) {
    newOp = rewriter.create<tosa::AddOp>(op->getLoc(), output_type, inputs[i],
                                         newOp.getResult());
  }

  rewriter.replaceOp(op, {newOp.getResult()});

  return success();
}

LogicalResult ConvertTFLAveragePool2DOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_avgpool_op = cast<TFL::AveragePool2DOp>(op);

  RankedTensorType input_type =
      tfl_avgpool_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_avgpool_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor input/output
  if (!input_type || !output_type) return failure();

  // Kernel and stride attributes are ordered {height, width}
  SmallVector<int64_t, 4> i64array({1, 1, 1, 1});
  ArrayAttr kernel_size;
  ArrayAttr stride;
  ArrayAttr pad;
  {
    int64_t kernel_h = tfl_avgpool_op.filter_height();
    int64_t kernel_w = tfl_avgpool_op.filter_width();
    kernel_size = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
    // i64array holds the kernel shape in NHWC order
    i64array[1] = kernel_h;
    i64array[2] = kernel_w;
  }
  {
    int64_t stride_h = tfl_avgpool_op.stride_h();
    int64_t stride_w = tfl_avgpool_op.stride_w();
    stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tfl_avgpool_op.padding().str(), &tf_pad).ok())
      return failure();

    // Pooling has no non-unit dilation
    ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});

    RankedTensorType filter_type = RankedTensorType::get(
        llvm::makeArrayRef<int64_t>(i64array), rewriter.getIntegerType(64));

    // TFLite doesn't support explicit padding
    if (!getPaddingValuesFromPadType(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, stride, dilation, rewriter, pad))
      return failure();
  }

  rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
      op, output_type, tfl_avgpool_op.input(), kernel_size, stride, pad);
  return success();
}

LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_maxpool_op = cast<TFL::MaxPool2DOp>(op);

  RankedTensorType input_type =
      tfl_maxpool_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_maxpool_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor input/output
  if (!input_type || !output_type) return failure();

  // Kernel and stride attributes are ordered {height, width}
  SmallVector<int64_t, 4> i64array({1, 1, 1, 1});
  ArrayAttr kernel_size;
  ArrayAttr stride;
  ArrayAttr pad;
  {
    int64_t kernel_h = tfl_maxpool_op.filter_height();
    int64_t kernel_w = tfl_maxpool_op.filter_width();
    kernel_size = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
    // i64array holds the kernel shape in NHWC order
    i64array[1] = kernel_h;
    i64array[2] = kernel_w;
  }
  {
    int64_t stride_h = tfl_maxpool_op.stride_h();
    int64_t stride_w = tfl_maxpool_op.stride_w();
    stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tfl_maxpool_op.padding().str(), &tf_pad).ok())
      return failure();

    // Pooling has no non-unit dilation
    ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});

    RankedTensorType filter_type = RankedTensorType::get(
        llvm::makeArrayRef<int64_t>(i64array), rewriter.getIntegerType(64));

    // TFLite doesn't support explicit padding
    if (!getPaddingValuesFromPadType(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, stride, dilation, rewriter, pad))
      return failure();
  }

  rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
      op, output_type, tfl_maxpool_op.input(), kernel_size, stride, pad);
  return success();
}

LogicalResult ConvertTFLConv2DOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_conv2d_op = cast<TFL::Conv2DOp>(op);

  RankedTensorType input_type =
      tfl_conv2d_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType filter_type =
      tfl_conv2d_op.filter().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_conv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type) return failure();
  if (!output_type) return failure();
  if (!filter_type) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool filter_is_qtype =
      filter_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::QuantizedType>();

  if ((input_is_qtype != filter_is_qtype) ||
      (input_is_qtype != output_is_qtype)) {
    return op->emitOpError(
        "ConvertTFLConv2DOp: input/filter/output tensor should "
        "be all quantized or all floating-point.");
  }

  ArrayAttr pad;
  ArrayAttr stride;
  ArrayAttr dilation;
  {
    int64_t stride_h = tfl_conv2d_op.stride_h();
    int64_t stride_w = tfl_conv2d_op.stride_w();
    stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
  }
  {
    int64_t dilation_h = tfl_conv2d_op.dilation_h_factor();
    int64_t dilation_w = tfl_conv2d_op.dilation_w_factor();
    dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tfl_conv2d_op.padding().str(), &tf_pad).ok())
      return failure();

    // TFLite doesn't support explicit padding
    if (!getPaddingValuesFromPadType(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, stride, dilation, rewriter, pad))
      return failure();
  }

  Value unquantized_bias =
      getUnquantizedBias(rewriter, op, tfl_conv2d_op.bias());

  auto a1_conv2d_op = rewriter.create<tosa::Conv2DOp>(
      op->getLoc(), output_type, tfl_conv2d_op.input(), tfl_conv2d_op.filter(),
      unquantized_bias, pad, stride, dilation);

  Value conv2d_output;
  if (input_is_qtype) {
    conv2d_output =
        buildRescaleOpConvOutput(rewriter, op, a1_conv2d_op.getResult(),
                                 input_type, filter_type, output_type);
  } else {
    conv2d_output = a1_conv2d_op.getResult();
  }

  auto fused_activation_fn = tfl_conv2d_op.fused_activation_functionAttr();

  if (fused_activation_fn) {
    llvm::Optional<Value> fused_activation_val = convertFusedActivation(
        rewriter, op, conv2d_output, fused_activation_fn);

    if (!fused_activation_val) return failure();

    rewriter.replaceOp(op, {fused_activation_val.getValue()});
    return success();
  }

  rewriter.replaceOp(op, {conv2d_output});

  return success();
}

LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_conv_op = cast<TFL::TransposeConvOp>(op);

  RankedTensorType input_type =
      tfl_conv_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType filter_type =
      tfl_conv_op.weights().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_conv_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type) return failure();
  if (!output_type) return failure();
  if (!filter_type) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool filter_is_qtype =
      filter_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::QuantizedType>();

  if ((input_is_qtype != filter_is_qtype) ||
      (input_is_qtype != output_is_qtype)) {
    return op->emitOpError(
        "ConvertTFLTransposeConvOp: input/filter/output tensor should "
        "be all quantized or all floating-point.");
  }

  ArrayAttr stride;
  ArrayAttr dilation;
  ArrayAttr outpad;
  ArrayAttr output_shape;
  {
    int64_t stride_h = tfl_conv_op.stride_h();
    int64_t stride_w = tfl_conv_op.stride_w();
    stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
  }

  // tfl.transpose_conv doesn't support dilations
  dilation = rewriter.getI64ArrayAttr({1, 1});

  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tfl_conv_op.padding().str(), &tf_pad).ok())
      return failure();

    if (!getTransposeConv2dPaddingValues(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, output_type, stride, dilation, rewriter,
            outpad))
      return failure();
  }
  {
    ElementsAttr output_shape_elems;
    // Match from input_size tensor first
    if (matchPattern(tfl_conv_op.output_shape(),
                     m_Constant(&output_shape_elems))) {
      llvm::SmallVector<int64_t, 4> shape_vec;
      for (int i = 0; i < output_shape_elems.getNumElements(); i++)
        shape_vec.push_back(
            output_shape_elems.getValue<IntegerAttr>(i).getInt());
      output_shape = rewriter.getI64ArrayAttr(shape_vec);
    } else {
      // Use output tensor's shape otherwise
      output_shape = rewriter.getI64ArrayAttr(output_type.getShape());
    }
  }

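  // tfl.transpose_conv has no bias operand, but tosa.transpose_conv2d
  // requires one: synthesize an all-zero bias of the accumulator type
  // (int48 for the 16x8 quantized case, int32 for other quantized cases,
  // float otherwise).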
  Value zero_bias;
  if (input_is_qtype) {
    uint32_t input_bits = input_type.getElementType()
                              .dyn_cast<mlir::quant::QuantizedType>()
                              .getStorageTypeIntegralWidth();
    uint32_t weight_bits = filter_type.getElementType()
                               .dyn_cast<mlir::quant::QuantizedType>()
                               .getStorageTypeIntegralWidth();

    if (input_bits == 16 && weight_bits == 8) {
      SmallVector<int64_t, 8> zero_bias_vec(output_type.getShape()[3], 0);
      zero_bias = get1DConstTensorInt48(rewriter, op, zero_bias_vec);
    } else {
      SmallVector<int32_t, 8> zero_bias_vec(output_type.getShape()[3], 0);
      zero_bias =
          get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, zero_bias_vec);
    }
  } else {
    SmallVector<float, 8> zero_bias_vec(output_type.getShape()[3], 0.0f);
    zero_bias =
        get1DConstTensor<tosa::ConstOp, float>(rewriter, op, zero_bias_vec);
  }

  auto a1_conv2d_op = rewriter.create<tosa::TransposeConv2DOp>(
      op->getLoc(), output_type, tfl_conv_op.input(), tfl_conv_op.weights(),
      zero_bias, outpad, stride, dilation, output_shape);

  Value conv2d_output;
  if (input_is_qtype) {
    conv2d_output =
        buildRescaleOpConvOutput(rewriter, op, a1_conv2d_op.getResult(),
                                 input_type, filter_type, output_type);
  } else {
    conv2d_output = a1_conv2d_op.getResult();
  }

  rewriter.replaceOp(op, {conv2d_output});

  return success();
}

LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_conv2d_op = cast<TFL::DepthwiseConv2DOp>(op);

  RankedTensorType input_type =
      tfl_conv2d_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType filter_type =
      tfl_conv2d_op.filter().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_conv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type) return failure();
  if (!output_type) return failure();
  if (!filter_type) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool filter_is_qtype =
      filter_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::QuantizedType>();

  if ((input_is_qtype != filter_is_qtype) ||
      (input_is_qtype != output_is_qtype)) {
    return op->emitOpError(
        "ConvertTFLDepthwiseConv2DOp: input/filter/output tensor should "
        "be all quantized or all floating-point.");
  }

  auto filter_shape = filter_type.getShape();
  // Operator depthwiseConv2D
  // TFLite orders the depthwiseConv2D filter in IHWO, while TOSA orders the
  // filter in HWIO with the output channels split into
  // (input channels, depth_multiplier).
  //
  // The lowering reorders the filter:
  //
  // a1_transpose = tosa.transpose(filter, {1, 2, 3, 0})   // HWOI
  // a2_reshape = tosa.reshape(a1_transpose,
  //                           {H, W, O / depth_multiplier, depth_multiplier})
  // a3_depthwise_conv2d = tosa.depthwise_conv2d(input, a2_reshape, bias,
  //                                             padding, stride, dilation)

  ArrayAttr pad;
  ArrayAttr stride;
  ArrayAttr dilation;
  auto depth_multiplier = tfl_conv2d_op.depth_multiplierAttr();

  {
    int64_t stride_h = tfl_conv2d_op.stride_h();
    int64_t stride_w = tfl_conv2d_op.stride_w();
    stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
  }
  {
    int64_t dilation_h = tfl_conv2d_op.dilation_h_factor();
    int64_t dilation_w = tfl_conv2d_op.dilation_w_factor();
    dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tfl_conv2d_op.padding().str(), &tf_pad).ok())
      return failure();

    if (!getPaddingValuesFromPadType(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, stride, dilation, rewriter, pad))
      return failure();
  }

  llvm::SmallVector<int64_t, 4> a1_transpose_dims;
  a1_transpose_dims.push_back(filter_shape[1]);
  a1_transpose_dims.push_back(filter_shape[2]);
  a1_transpose_dims.push_back(filter_shape[3]);
  a1_transpose_dims.push_back(filter_shape[0]);

  llvm::SmallVector<int64_t, 4> a2_reshape_dims;
  a2_reshape_dims.push_back(a1_transpose_dims[0]);
  a2_reshape_dims.push_back(a1_transpose_dims[1]);
  a2_reshape_dims.push_back(a1_transpose_dims[2] / depth_multiplier.getInt());
  a2_reshape_dims.push_back(depth_multiplier.getInt());

  Value a1_filter_transpose_perms =
      get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, {1, 2, 3, 0});
  auto a1_filter_transpose_op = rewriter.create<tosa::TransposeOp>(
      op->getLoc(),
      RankedTensorType::get(ArrayRef<int64_t>(a1_transpose_dims),
                            filter_type.getElementType()),
      tfl_conv2d_op.filter(), a1_filter_transpose_perms);

  auto a2_filter_reshape_op = rewriter.create<tosa::ReshapeOp>(
      op->getLoc(),
      RankedTensorType::get(ArrayRef<int64_t>(a2_reshape_dims),
                            filter_type.getElementType()),
      a1_filter_transpose_op.getResult(),
      rewriter.getI64ArrayAttr(a2_reshape_dims));

  Value unquantized_bias =
      getUnquantizedBias(rewriter, op, tfl_conv2d_op.bias());

  auto a3_depthwise_conv2d_op = rewriter.create<tosa::DepthwiseConv2DOp>(
      op->getLoc(), output_type, tfl_conv2d_op.input(),
      a2_filter_reshape_op.getResult(), unquantized_bias, pad, stride,
      dilation);

  Value conv2d_output;
  if (input_is_qtype) {
    conv2d_output = buildRescaleOpConvOutput(
        rewriter, op, a3_depthwise_conv2d_op.getResult(), input_type,
        filter_type, output_type);
  } else {
    conv2d_output = a3_depthwise_conv2d_op.getResult();
  }

  auto fused_activation_fn = tfl_conv2d_op.fused_activation_functionAttr();

  if (fused_activation_fn) {
    llvm::Optional<Value> fused_activation_val = convertFusedActivation(
        rewriter, op, conv2d_output, fused_activation_fn);

    if (!fused_activation_val) return failure();

    rewriter.replaceOp(op, {fused_activation_val.getValue()});
    return success();
  }

  rewriter.replaceOp(op, {conv2d_output});

  return success();
}

LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_fc_op = cast<TFL::FullyConnectedOp>(op);

  RankedTensorType output_type =
      tfl_fc_op.getResult(0).getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  RankedTensorType input_type =
      tfl_fc_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType filter_type =
      tfl_fc_op.filter().getType().dyn_cast<RankedTensorType>();
  RankedTensorType bias_type =
      tfl_fc_op.bias().getType().dyn_cast<RankedTensorType>();
  if (!input_type || !filter_type) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool filter_is_qtype =
      filter_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::QuantizedType>();

  if ((input_is_qtype != filter_is_qtype) ||
      (input_is_qtype != output_is_qtype)) {
    return op->emitOpError(
        "ConvertTFLFullyConnectedOp: input/filter/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value input_val = tfl_fc_op.input();

  // tfl.fully_connected() can take an input tensor of any rank, while
  // tosa.fully_connected only supports rank-2 tensors, so the input may need
  // to be reshaped. For a rank-4 input the reshape is not always to
  // (dim[0] * dim[1], dim[2] * dim[3]); some networks reshape to
  // (dim[0], dim[1] * dim[2] * dim[3]). A more general way to determine the
  // reshaped shape is to take the inner dimension from the filter's shape[1].
  if (input_type.getRank() != 2) {
    int64_t num_elems = filter_type.getShape()[1];
    int64_t num_batch = input_type.getNumElements() / num_elems;
    SmallVector<int64_t, 2> shape_vals({num_batch, num_elems});

    RankedTensorType reshape_type = RankedTensorType::get(
        ArrayRef<int64_t>(shape_vals), input_type.getElementType());
    auto reshape_op = rewriter.create<tosa::ReshapeOp>(
        op->getLoc(), reshape_type, tfl_fc_op.input(),
        rewriter.getI64ArrayAttr(shape_vals));

    input_val = reshape_op.getResult();
  }
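  // For example (illustrative shapes): a rank-4 input of shape [1, 4, 4, 64]
  // with a filter of shape [10, 1024] yields num_elems = 1024 and
  // num_batch = (1 * 4 * 4 * 64) / 1024 = 1, so the input reshapes to
  // [1, 1024].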
1484
1485 Value bias_val;
1486 if (!bias_type) {
1487 // For some matmuls, the bias may actually be a "UnitType" which has no
1488 // value. TOSA requires bias to be an array of output_channel_count values,
1489 // so create a constant of the appropriate number and type of zeros.
1490 SmallVector<int64_t, 1> bias_shape({filter_type.getShape()[0]});
1491 RankedTensorType bias_type = RankedTensorType::get(
1492 ArrayRef<int64_t>(bias_shape), input_type.getElementType());
1493
1494 DenseElementsAttr bias_attr;
1495 if (input_type.getElementType().isa<FloatType>()) {
1496 SmallVector<float, 2> bias_arr(bias_shape[0]);
1497
1498 for (int i = 0; i < bias_shape[0]; i++) {
1499 bias_arr[i] = 0.0;
1500 }
1501       // TODO: an implicit cast was suggested instead of makeArrayRef,
1502       // but it triggers a build error.
1503 bias_attr = DenseElementsAttr::get(bias_type,
1504 llvm::makeArrayRef<float>(bias_arr));
1505 } else {
1506 SmallVector<int32_t, 2> bias_arr(bias_shape[0]);
1507
1508 for (int i = 0; i < bias_shape[0]; i++) {
1509 bias_arr[i] = 0;
1510 }
1511 bias_attr = DenseElementsAttr::get(bias_type,
1512 llvm::makeArrayRef<int32_t>(bias_arr));
1513 }
1514 auto bias_op =
1515 rewriter.create<tosa::ConstOp>(op->getLoc(), bias_type, bias_attr);
1516 bias_val = bias_op.getResult();
1517 } else {
1518 bias_val = getUnquantizedBias(rewriter, op, tfl_fc_op.bias());
1519 }
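  // E.g., in the zero-bias path above, a hypothetical float network with
  // filter tensor<4x16> gets a bias constant of shape tensor<4xf32>:
  //   %bias = "tosa.const"() {value = dense<0.0> : tensor<4xf32>}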
1520
1521 auto fc_op = rewriter.create<tosa::FullyConnectedOp>(
1522 op->getLoc(), output_type, input_val, tfl_fc_op.filter(), bias_val);
1523
1524 Value fc_output;
1525 if (input_is_qtype) {
1526 fc_output = buildRescaleOpConvOutput(rewriter, op, fc_op.getResult(),
1527 input_type, filter_type, output_type);
1528 } else {
1529 fc_output = fc_op.getResult();
1530 }
1531
1532 auto fused_activation_fn = tfl_fc_op.fused_activation_functionAttr();
1533
1534 if (fused_activation_fn) {
1535 llvm::Optional<Value> fused_activation_val =
1536 convertFusedActivation(rewriter, op, fc_output, fused_activation_fn);
1537
1538 if (!fused_activation_val) return failure();
1539
1540 rewriter.replaceOp(op, {fused_activation_val.getValue()});
1541 return success();
1542 }
1543
1544 rewriter.replaceOp(op, {fc_output});
1545
1546 return success();
1547 }
1548
1549 LogicalResult ConvertTFLConcatenationOp::matchAndRewrite(
1550 Operation* op, PatternRewriter& rewriter) const {
1551 auto tfl_concat_op = cast<TFL::ConcatenationOp>(op);
1552
1553 SmallVector<Value, 8> values(tfl_concat_op.values());
1554
1555 IntegerAttr axis_attr;
1556 {
1557 auto tmpAttr = tfl_concat_op.axisAttr();
1558 if (!tmpAttr) {
1559 tmpAttr = rewriter.getI64IntegerAttr(0);
1560 }
1561 axis_attr = tmpAttr;
1562 }
1563 int32_t axis = axis_attr.getInt();
1564
1565 llvm::Optional<Value> result =
1566 convertConcatV2Op(rewriter, op, tfl_concat_op.getResult(), values, axis);
1567
1568 if (!result) return failure();
1569
1570 rewriter.replaceOp(op, {result.getValue()});
1571 return success();
1572 }
1573
1574 LogicalResult ConvertTFLReshapeOp::matchAndRewrite(
1575 Operation* op, PatternRewriter& rewriter) const {
1576 auto tfl_reshape_op = cast<TFL::ReshapeOp>(op);
1577
1578 RankedTensorType output_type =
1579 tfl_reshape_op.getResult().getType().dyn_cast<RankedTensorType>();
1580 // Not a ranked tensor output
1581 if (!output_type) return failure();
1582
1583 SmallVector<int64_t, 8> shape_vals;
1584 for (int i = 0; i < output_type.getShape().size(); i++) {
1585 shape_vals.push_back(output_type.getShape()[i]);
1586 }
1587 ArrayAttr shape_attr = rewriter.getI64ArrayAttr(shape_vals);
1588
1589 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
1590 op, output_type, tfl_reshape_op.input(), shape_attr);
1591 return success();
1592 }
1593
1594 LogicalResult ConvertTFLRankOp::matchAndRewrite(
1595 Operation* op, PatternRewriter& rewriter) const {
1596 auto tfl_rank_op = cast<TFL::RankOp>(op);
1597
1598 RankedTensorType input_type =
1599 tfl_rank_op.input().getType().dyn_cast<RankedTensorType>();
1600 if (!input_type) return failure();
1601
1602 int32_t rank = input_type.getRank();
1603
1604 RankedTensorType rank_type =
1605 RankedTensorType::get({1}, rewriter.getIntegerType(32));
1606 auto rank_attr = DenseElementsAttr::get(rank_type, {rank});
1607 auto rank_const =
1608 rewriter.create<tosa::ConstOp>(op->getLoc(), rank_type, rank_attr);
1609
1610 rewriter.replaceOp(op, {rank_const.getResult()});
1611
1612 return success();
1613 }
1614
1615 LogicalResult ConvertTFLShapeOp::matchAndRewrite(
1616 Operation* op, PatternRewriter& rewriter) const {
1617 auto tfl_shape_op = cast<TFL::ShapeOp>(op);
1618
1619 RankedTensorType output_type =
1620 tfl_shape_op.getResult().getType().dyn_cast<RankedTensorType>();
1621 // Not a ranked tensor output
1622 if (!output_type) return failure();
1623
1624 RankedTensorType input_type =
1625 tfl_shape_op.input().getType().dyn_cast<RankedTensorType>();
1626 if (!input_type) return failure();
1627
1628 auto input_shape = input_type.getShape();
1629
1630 SmallVector<int32_t, 8> shape_arr;
1631 for (int i = 0; i < input_shape.size(); i++) {
1632 shape_arr.emplace_back(input_shape[i]);
1633 }
1634
1635 RankedTensorType shape_type = RankedTensorType::get(
1636 {static_cast<int32_t>(shape_arr.size())}, rewriter.getIntegerType(32));
1637 auto shape_attr = DenseElementsAttr::get(
1638 shape_type, llvm::makeArrayRef<int32_t>(shape_arr));
1639 auto shape_const =
1640 rewriter.create<tosa::ConstOp>(op->getLoc(), shape_type, shape_attr);
1641
1642 rewriter.replaceOp(op, {shape_const.getResult()});
1643
1644 return success();
1645 }
1646
1647 LogicalResult ConvertTFLExpandDimsOp::matchAndRewrite(
1648 Operation* op, PatternRewriter& rewriter) const {
1649 auto tfl_expanddims_op = cast<TFL::ExpandDimsOp>(op);
1650
1651 llvm::Optional<Value> result =
1652 convertExpandDimsOp(rewriter, op, tfl_expanddims_op.getResult(),
1653 tfl_expanddims_op.input(), tfl_expanddims_op.dim());
1654
1655 if (!result) return failure();
1656
1657 rewriter.replaceOp(op, {result.getValue()});
1658
1659 return success();
1660 }
1661
1662 LogicalResult ConvertTFLSqueezeOp::matchAndRewrite(
1663 Operation* op, PatternRewriter& rewriter) const {
1664 auto tfl_squeeze_op = cast<TFL::SqueezeOp>(op);
1665
1666 // Copy squeeze_dims into int32_t array
1667 auto squeeze_dims_attr = tfl_squeeze_op.squeeze_dimsAttr();
1668 SmallVector<int32_t, 8> squeeze_dims;
1669 for (auto& squeeze_dim : squeeze_dims_attr) {
1670 squeeze_dims.emplace_back(squeeze_dim.dyn_cast<IntegerAttr>().getInt());
1671 }
1672
1673 llvm::Optional<Value> result =
1674 convertSqueezeOp(rewriter, op, tfl_squeeze_op.getResult(),
1675 tfl_squeeze_op.input(), squeeze_dims);
1676
1677 if (!result) return failure();
1678
1679 rewriter.replaceOp(op, {result.getValue()});
1680
1681 return success();
1682 }
1683
1684 LogicalResult ConvertTFLFillOp::matchAndRewrite(
1685 Operation* op, PatternRewriter& rewriter) const {
1686 auto tfl_fill_op = cast<TFL::FillOp>(op);
1687
1688 RankedTensorType output_type =
1689 tfl_fill_op.getResult().getType().dyn_cast<RankedTensorType>();
1690 // Not a ranked tensor output
1691 if (!output_type) return failure();
1692
1693 ElementsAttr dims_elems;
1694 if (!matchPattern(tfl_fill_op.dims(), m_Constant(&dims_elems)))
1695 return failure();
1696 SmallVector<int64_t, 4> dims_vals;
1697 uint32_t total_size = 1;
1698 for (int i = 0; i < dims_elems.getNumElements(); i++) {
1699 dims_vals.push_back(dims_elems.getValue<IntegerAttr>(i).getInt());
1700 total_size *= dims_vals[i];
1701 }
1702
1703 ElementsAttr value_elem;
1704 if (!matchPattern(tfl_fill_op.input(), m_Constant(&value_elem)))
1705 return failure();
1706
1707 RankedTensorType fill_type = RankedTensorType::get(
1708 ArrayRef<int64_t>(dims_vals), value_elem.getType().getElementType());
1709 DenseElementsAttr fill_attr;
1710
1711 // Convert to a compatible zero type.
1712 if (value_elem.getType().getElementType().isa<FloatType>()) {
1713 llvm::SmallVector<float, 4> fill_arr(
1714 total_size,
1715 value_elem.getValue<FloatAttr>(0).getValue().convertToFloat());
1716 fill_attr =
1717 DenseElementsAttr::get(fill_type, llvm::makeArrayRef<float>(fill_arr));
1718 } else {
1719 llvm::SmallVector<int32_t, 4> fill_arr(
1720 total_size,
1721 value_elem.getValue<IntegerAttr>(0).getValue().getLimitedValue());
1722 fill_attr = DenseElementsAttr::get(fill_type,
1723 llvm::makeArrayRef<int32_t>(fill_arr));
1724 }
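  // Illustrative sketch, assuming hypothetical constant operands
  // dims = [2, 3] and value = 1.5f: the op folds to the single constant
  // built below, i.e.
  //   %fill = "tosa.const"() {value = dense<1.5> : tensor<2x3xf32>}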
1725 auto fill_const_op =
1726 rewriter.create<tosa::ConstOp>(op->getLoc(), fill_type, fill_attr);
1727 rewriter.replaceOp(op, {fill_const_op.getResult()});
1728
1729 return success();
1730 }
1731
1732 LogicalResult ConvertTFLReduceAnyOp::matchAndRewrite(
1733 Operation* op, PatternRewriter& rewriter) const {
1734 auto tfl_any_op = cast<TFL::ReduceAnyOp>(op);
1735
1736 RankedTensorType output_type =
1737 tfl_any_op.getResult().getType().dyn_cast<RankedTensorType>();
1738 if (!output_type) return failure();
1739
1740 ElementsAttr axes_elems;
1741 if (!matchPattern(tfl_any_op.reduction_indices(), m_Constant(&axes_elems)))
1742 return failure();
1743
1744 bool keep_dims = false;
1745 auto keep_dims_attr = tfl_any_op.keep_dimsAttr();
1746 if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1747
1748 llvm::Optional<Value> result = convertReduceAnyOp(
1749 rewriter, op, output_type, tfl_any_op.input(), axes_elems, keep_dims);
1750
1751 if (!result) return failure();
1752
1753 rewriter.replaceOp(op, {result.getValue()});
1754
1755 return success();
1756 }
1757
1758 LogicalResult ConvertTFLReduceMaxOp::matchAndRewrite(
1759 Operation* op, PatternRewriter& rewriter) const {
1760 auto tfl_max_op = cast<TFL::ReduceMaxOp>(op);
1761
1762 RankedTensorType output_type =
1763 tfl_max_op.getResult().getType().dyn_cast<RankedTensorType>();
1764 if (!output_type) return failure();
1765
1766 ElementsAttr axes_elems;
1767 if (!matchPattern(tfl_max_op.axes(), m_Constant(&axes_elems)))
1768 return failure();
1769
1770 bool keep_dims = false;
1771 auto keep_dims_attr = tfl_max_op.keep_dimsAttr();
1772 if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1773
1774 llvm::Optional<Value> result = convertReduceMaxOp(
1775 rewriter, op, output_type, tfl_max_op.input(), axes_elems, keep_dims);
1776
1777 if (!result) return failure();
1778
1779 rewriter.replaceOp(op, {result.getValue()});
1780
1781 return success();
1782 }
1783
1784 LogicalResult ConvertTFLReduceMinOp::matchAndRewrite(
1785 Operation* op, PatternRewriter& rewriter) const {
1786 auto tfl_min_op = cast<TFL::ReduceMinOp>(op);
1787
1788 RankedTensorType output_type =
1789 tfl_min_op.getResult().getType().dyn_cast<RankedTensorType>();
1790 if (!output_type) return failure();
1791
1792 ElementsAttr axes_elems;
1793 if (!matchPattern(tfl_min_op.axes(), m_Constant(&axes_elems)))
1794 return failure();
1795
1796 bool keep_dims = false;
1797 auto keep_dims_attr = tfl_min_op.keep_dimsAttr();
1798 if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1799
1800 llvm::Optional<Value> result = convertReduceMinOp(
1801 rewriter, op, output_type, tfl_min_op.input(), axes_elems, keep_dims);
1802
1803 if (!result) return failure();
1804
1805 rewriter.replaceOp(op, {result.getValue()});
1806
1807 return success();
1808 }
1809
1810 LogicalResult ConvertTFLReduceProdOp::matchAndRewrite(
1811 Operation* op, PatternRewriter& rewriter) const {
1812 auto tfl_prod_op = cast<TFL::ReduceProdOp>(op);
1813
1814 RankedTensorType output_type =
1815 tfl_prod_op.getResult().getType().dyn_cast<RankedTensorType>();
1816 if (!output_type) return failure();
1817
1818 ElementsAttr axes_elems;
1819 if (!matchPattern(tfl_prod_op.axes(), m_Constant(&axes_elems)))
1820 return failure();
1821
1822 bool keep_dims = false;
1823 auto keep_dims_attr = tfl_prod_op.keep_dimsAttr();
1824 if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1825
1826 llvm::Optional<Value> result = convertReduceProdOp(
1827 rewriter, op, output_type, tfl_prod_op.input(), axes_elems, keep_dims);
1828
1829 if (!result) return failure();
1830
1831 rewriter.replaceOp(op, {result.getValue()});
1832
1833 return success();
1834 }
1835
1836 LogicalResult ConvertTFLMeanOp::matchAndRewrite(
1837 Operation* op, PatternRewriter& rewriter) const {
1838 auto tfl_mean_op = cast<TFL::MeanOp>(op);
1839
1840 RankedTensorType output_type =
1841 tfl_mean_op.getResult().getType().dyn_cast<RankedTensorType>();
1842 if (!output_type) return failure();
1843
1844 ElementsAttr axes_elems;
1845 if (!matchPattern(tfl_mean_op.axis(), m_Constant(&axes_elems)))
1846 return failure();
1847
1848 bool keep_dims = false;
1849 auto keep_dims_attr = tfl_mean_op.keep_dimsAttr();
1850 if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1851
1852 llvm::Optional<Value> result = convertReduceMeanOp(
1853 rewriter, op, output_type, tfl_mean_op.input(), axes_elems, keep_dims);
1854
1855 if (!result) return failure();
1856
1857 rewriter.replaceOp(op, {result.getValue()});
1858
1859 return success();
1860 }
1861
1862 LogicalResult ConvertTFLSumOp::matchAndRewrite(
1863 Operation* op, PatternRewriter& rewriter) const {
1864 auto tfl_sum_op = cast<TFL::SumOp>(op);
1865
1866 RankedTensorType output_type =
1867 tfl_sum_op.getResult().getType().dyn_cast<RankedTensorType>();
1868 if (!output_type) return failure();
1869
1870 ElementsAttr axes_elems;
1871 if (!matchPattern(tfl_sum_op.axes(), m_Constant(&axes_elems)))
1872 return failure();
1873
1874 bool keep_dims = false;
1875 auto keep_dims_attr = tfl_sum_op.keep_dimsAttr();
1876 if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1877
1878 llvm::Optional<Value> result = convertReduceSumOp(
1879 rewriter, op, output_type, tfl_sum_op.input(), axes_elems, keep_dims);
1880
1881 if (!result) return failure();
1882
1883 rewriter.replaceOp(op, {result.getValue()});
1884
1885 return success();
1886 }
1887
1888 LogicalResult ConvertTFLEluOp::matchAndRewrite(
1889 Operation* op, PatternRewriter& rewriter) const {
1890 auto tfl_elu_op = cast<TFL::EluOp>(op);
1891
1892 llvm::Optional<Value> result =
1893 convertEluOp(rewriter, op, tfl_elu_op.getResult(), tfl_elu_op.x());
1894
1895 if (!result) return failure();
1896
1897 rewriter.replaceOp(op, {result.getValue()});
1898
1899 return success();
1900 }
1901
1902 LogicalResult ConvertTFLSoftmaxOp::matchAndRewrite(
1903 Operation* op, PatternRewriter& rewriter) const {
1904 auto tfl_softmax_op = cast<TFL::SoftmaxOp>(op);
1905
1906 llvm::Optional<Value> result = convertSoftmaxOp(
1907 rewriter, op, tfl_softmax_op.getResult(), tfl_softmax_op.input());
1908
1909 if (!result) return failure();
1910
1911 rewriter.replaceOp(op, {result.getValue()});
1912
1913 return success();
1914 }
1915
1916 LogicalResult ConvertTFLLogSoftmaxOp::matchAndRewrite(
1917 Operation* op, PatternRewriter& rewriter) const {
1918 auto tfl_logsoftmax_op = cast<TFL::LogSoftmaxOp>(op);
1919
1920 llvm::Optional<Value> result = convertLogSoftmaxOp(
1921 rewriter, op, tfl_logsoftmax_op.getResult(), tfl_logsoftmax_op.input());
1922
1923 if (!result) return failure();
1924
1925 rewriter.replaceOp(op, {result.getValue()});
1926
1927 return success();
1928 }
1929
1930 LogicalResult ConvertTFLSliceOp::matchAndRewrite(
1931 Operation* op, PatternRewriter& rewriter) const {
1932 auto tfl_slice_op = cast<TFL::SliceOp>(op);
1933
1934 RankedTensorType output_type =
1935 tfl_slice_op.getResult().getType().dyn_cast<RankedTensorType>();
1936 // Not a ranked tensor output
1937 if (!output_type) return failure();
1938
1939 ElementsAttr begin_elems, size_elems;
1940
1941 SmallVector<int64_t, 4> begin_vals, size_vals;
1942
1943 if (!matchPattern(tfl_slice_op.begin(), m_Constant(&begin_elems)) ||
1944 !matchPattern(tfl_slice_op.size(), m_Constant(&size_elems))) {
1945 return failure();
1946 }
1947
1948 for (int i = 0; i < begin_elems.getNumElements(); i++)
1949 begin_vals.push_back(begin_elems.getValue<IntegerAttr>(i).getInt());
1950
1951 for (int i = 0; i < size_elems.getNumElements(); i++)
1952 size_vals.push_back(size_elems.getValue<IntegerAttr>(i).getInt());
1953
1954 ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
1955 ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
1956
1957 rewriter.replaceOpWithNewOp<tosa::SliceOp>(op, output_type,
1958 tfl_slice_op.input(), begin, size);
1959 return success();
1960 }
1961
1962 LogicalResult ConvertTFLTileOp::matchAndRewrite(
1963 Operation* op, PatternRewriter& rewriter) const {
1964 auto tfl_tile_op = cast<TFL::TileOp>(op);
1965
1966 RankedTensorType output_type =
1967 tfl_tile_op.getResult().getType().dyn_cast<RankedTensorType>();
1968 // Not a ranked tensor output
1969 if (!output_type) return failure();
1970
1971 ElementsAttr multiples_elems;
1972 if (!matchPattern(tfl_tile_op.multiples(), m_Constant(&multiples_elems)))
1973 return failure();
1974 SmallVector<int64_t, 4> multiples_vals;
1975 for (int i = 0; i < multiples_elems.getNumElements(); i++)
1976 multiples_vals.push_back(multiples_elems.getValue<IntegerAttr>(i).getInt());
1977
1978 ArrayAttr multiples_attr = rewriter.getI64ArrayAttr(multiples_vals);
1979 rewriter.replaceOpWithNewOp<tosa::TileOp>(
1980 op, output_type, tfl_tile_op.input(), multiples_attr);
1981
1982 return success();
1983 }
1984
1985 LogicalResult ConvertTFLTransposeOp::matchAndRewrite(
1986 Operation* op, PatternRewriter& rewriter) const {
1987 auto tfl_transpose_op = cast<TFL::TransposeOp>(op);
1988
1989 RankedTensorType output_type =
1990 tfl_transpose_op.getResult().getType().dyn_cast<RankedTensorType>();
1991 // Not a ranked tensor output
1992 if (!output_type) return failure();
1993
1994 rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
1995 op, output_type, tfl_transpose_op.input(), tfl_transpose_op.perm());
1996
1997 return success();
1998 }
1999
2000 LogicalResult ConvertTFLPackOp::matchAndRewrite(
2001 Operation* op, PatternRewriter& rewriter) const {
2002 auto tfl_pack_op = cast<TFL::PackOp>(op);
2003
2004 SmallVector<Value, 8> inputs(tfl_pack_op.values());
2005 assert(inputs.size() >= 2);
2006
2007 IntegerAttr axis_attr;
2008 {
2009 auto tmpAttr = tfl_pack_op.axisAttr();
2010 if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
2011 axis_attr = tmpAttr;
2012 }
2013 int32_t axis_i32 = axis_attr.getInt();
2014
2015 llvm::Optional<Value> result =
2016 convertPackOp(rewriter, op, tfl_pack_op.getResult(), inputs, axis_i32);
2017
2018 if (!result) return failure();
2019
2020 rewriter.replaceOp(op, {result.getValue()});
2021
2022 return success();
2023 }
2024
2025 LogicalResult ConvertTFLUnpackOp::matchAndRewrite(
2026 Operation* op, PatternRewriter& rewriter) const {
2027 auto tfl_unpack_op = cast<TFL::UnpackOp>(op);
2028
2029 IntegerAttr axis_attr;
2030 {
2031 auto tmpAttr = tfl_unpack_op.axisAttr();
2032 if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
2033 axis_attr = tmpAttr;
2034 }
2035 int32_t axis_i32 = axis_attr.getInt();
2036
2037 llvm::Optional<ValueRange> results =
2038 convertUnpackOp(rewriter, op, tfl_unpack_op.input(), axis_i32);
2039
2040 if (!results) return failure();
2041
2042 rewriter.replaceOp(op, results.getValue());
2043
2044 return success();
2045 }
2046
2047 // Splits the input into num_split equal parts along split_dim.
2048 LogicalResult ConvertTFLSplitOp::matchAndRewrite(
2049 Operation* op, PatternRewriter& rewriter) const {
2050 auto tfl_split_op = cast<TFL::SplitOp>(op);
2051
2052 // Get the number of splits
2053 int32_t num_split = -1;
2054 auto numSplitAttr = tfl_split_op.num_splitsAttr();
2055 if (numSplitAttr) {
2056 num_split = numSplitAttr.getInt();
2057 } else {
2058 return failure();
2059 }
2060
2061 // Get the axis
2062 ElementsAttr axisAttrElems;
2063 if (!matchPattern(tfl_split_op.split_dim(), m_Constant(&axisAttrElems))) {
2064 return op->emitOpError("Cannot read split_dim elems");
2065 }
2066
2067 // The axis/split_dim parameter is stored as a 0D tensor instead of
2068 // an integer attribute in TFLite MLIR.
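  // E.g., a hypothetical split_dim operand defined as
  //   %split_dim = "tfl.pseudo_const"() {value = dense<1> : tensor<i32>}
  // yields axis = 1 below.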
2069 int32_t axis = axisAttrElems.getValue<IntegerAttr>({}).getInt();
2070
2071 llvm::Optional<ValueRange> results =
2072 convertSplitOp(rewriter, op, tfl_split_op.getResult(0),
2073 tfl_split_op.value(), num_split, axis);
2074
2075 if (!results) return failure();
2076
2077 rewriter.replaceOp(op, results.getValue());
2078
2079 return success();
2080 }
2081
2082 // Splits the input along split_dim into parts whose sizes are given by
2082 // size_splits.
2083 LogicalResult ConvertTFLSplitVOp::matchAndRewrite(
2084 Operation* op, PatternRewriter& rewriter) const {
2085 auto tfl_splitv_op = cast<TFL::SplitVOp>(op);
2086
2087 // Get the size_splits array
2088 SmallVector<int32_t, 4> size_split;
2089 ElementsAttr size_split_elems;
2090 if (!matchPattern(tfl_splitv_op.size_splits(),
2091 m_Constant(&size_split_elems))) {
2092 return failure();
2093 }
2094
2095 for (int i = 0; i < size_split_elems.getNumElements(); i++) {
2096 size_split.push_back(size_split_elems.getValue<IntegerAttr>(i).getInt());
2097 }
2098
2099 // Get the axis
2100 ElementsAttr axisAttrElems;
2101 if (!matchPattern(tfl_splitv_op.split_dim(), m_Constant(&axisAttrElems))) {
2102 return op->emitOpError("Cannot read split_dim elems");
2103 }
2104
2105 // The axis/split_dim parameter is stored as a 0D tensor instead of
2106 // an integer attribute in TFLite MLIR.
2107 int32_t axis = axisAttrElems.getValue<IntegerAttr>(0).getInt();
2108
2109 llvm::Optional<ValueRange> results =
2110 convertSplitVOp(rewriter, op, tfl_splitv_op.getResult(0),
2111 tfl_splitv_op.value(), size_split, axis);
2112
2113 if (!results) return failure();
2114
2115 rewriter.replaceOp(op, results.getValue());
2116
2117 return success();
2118 }
2119
2120 LogicalResult ConvertTFLLessOp::matchAndRewrite(
2121 Operation* op, PatternRewriter& rewriter) const {
2122 auto tfl_less_op = cast<TFL::LessOp>(op);
2123
2124 RankedTensorType input_lhs_type =
2125 tfl_less_op.lhs().getType().dyn_cast<RankedTensorType>();
2126 RankedTensorType input_rhs_type =
2127 tfl_less_op.rhs().getType().dyn_cast<RankedTensorType>();
2128 RankedTensorType output_type =
2129 tfl_less_op.getResult().getType().dyn_cast<RankedTensorType>();
2130 // Not a ranked tensor output
2131 if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
2132
2133 bool input_lhs_is_qtype =
2134 input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2135 bool input_rhs_is_qtype =
2136 input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2137 bool output_is_qtype =
2138 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2139
2140 if (input_lhs_is_qtype != output_is_qtype ||
2141 input_rhs_is_qtype != output_is_qtype) {
2142 return op->emitOpError(
2143 "ConvertTFLLessOp: input/output tensor should "
2144 "be all quantized or all floating-point.");
2145 }
2146
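  // LESS is lowered here via the identity
  //   less(lhs, rhs) == logical_not(greater_equal(lhs, rhs)),
  // rescaling quantized operands to int32 (zero point removed) first.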
2147 Value output;
2148 if (output_is_qtype) {
2149 UniformQuantizedType input_lhs_qtype =
2150 input_lhs_type.getElementType()
2151 .dyn_cast<mlir::quant::UniformQuantizedType>();
2152 UniformQuantizedType input_rhs_qtype =
2153 input_rhs_type.getElementType()
2154 .dyn_cast<mlir::quant::UniformQuantizedType>();
2155
2156 if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
2157 input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
2158 return op->emitOpError(
2159 "ConvertTFLLessOp: input_x and input_y scale/zp "
2160 "must be the same");
2161 }
2162
2163 Value op1_rescale_lhs = buildRescaleToInt32(
2164 rewriter, op, tfl_less_op.lhs(), 1.0f, input_lhs_qtype.getZeroPoint());
2165 Value op2_rescale_rhs = buildRescaleToInt32(
2166 rewriter, op, tfl_less_op.rhs(), 1.0f, input_rhs_qtype.getZeroPoint());
2167 auto op3_greater_equal_op1_op2 = rewriter.create<tosa::GreaterEqualOp>(
2168 op->getLoc(), output_type, op1_rescale_lhs, op2_rescale_rhs);
2169 auto op4_not_op3 = rewriter.create<tosa::LogicalNotOp>(
2170 op->getLoc(), output_type, op3_greater_equal_op1_op2.getResult());
2171
2172 output = op4_not_op3.getResult();
2173 } else {
2174 auto op1_greater_equal_in = rewriter.create<tosa::GreaterEqualOp>(
2175 op->getLoc(), output_type, tfl_less_op.lhs(), tfl_less_op.rhs());
2176 auto op2_not_op1 = rewriter.create<tosa::LogicalNotOp>(
2177 op->getLoc(), output_type, op1_greater_equal_in.getResult());
2178
2179 output = op2_not_op1.getResult();
2180 }
2181
2182 rewriter.replaceOp(op, {output});
2183 return success();
2184 }
2185
2186 LogicalResult ConvertTFLLessEqualOp::matchAndRewrite(
2187 Operation* op, PatternRewriter& rewriter) const {
2188 auto tfl_less_equal_op = cast<TFL::LessEqualOp>(op);
2189
2190 RankedTensorType input_lhs_type =
2191 tfl_less_equal_op.lhs().getType().dyn_cast<RankedTensorType>();
2192 RankedTensorType input_rhs_type =
2193 tfl_less_equal_op.rhs().getType().dyn_cast<RankedTensorType>();
2194 RankedTensorType output_type =
2195 tfl_less_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
2196 // Not a ranked tensor output
2197 if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
2198
2199 bool input_lhs_is_qtype =
2200 input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2201 bool input_rhs_is_qtype =
2202 input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2203 bool output_is_qtype =
2204 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2205
2206 if (input_lhs_is_qtype != output_is_qtype ||
2207 input_rhs_is_qtype != output_is_qtype) {
2208 return op->emitOpError(
2209 "ConvertTFLLessEqualOp: input/output tensor should "
2210 "be all quantized or all floating-point.");
2211 }
2212
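  // Same approach as LESS above, using the identity
  //   less_equal(lhs, rhs) == logical_not(greater(lhs, rhs)).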
2213 Value output;
2214 if (output_is_qtype) {
2215 UniformQuantizedType input_lhs_qtype =
2216 input_lhs_type.getElementType()
2217 .dyn_cast<mlir::quant::UniformQuantizedType>();
2218 UniformQuantizedType input_rhs_qtype =
2219 input_rhs_type.getElementType()
2220 .dyn_cast<mlir::quant::UniformQuantizedType>();
2221
2222 if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
2223 input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
2224 return op->emitOpError(
2225 "ConvertTFLLessEqualOp: input_x and input_y scale/zp "
2226 "must be the same");
2227 }
2228
2229 Value op1_rescale_lhs =
2230 buildRescaleToInt32(rewriter, op, tfl_less_equal_op.lhs(), 1.0f,
2231 input_lhs_qtype.getZeroPoint());
2232 Value op2_rescale_rhs =
2233 buildRescaleToInt32(rewriter, op, tfl_less_equal_op.rhs(), 1.0f,
2234 input_rhs_qtype.getZeroPoint());
2235 auto op3_greater_op1_op2 = rewriter.create<tosa::GreaterOp>(
2236 op->getLoc(), output_type, op1_rescale_lhs, op2_rescale_rhs);
2237 auto op4_not_op3 = rewriter.create<tosa::LogicalNotOp>(
2238 op->getLoc(), output_type, op3_greater_op1_op2.getResult());
2239
2240 output = op4_not_op3.getResult();
2241 } else {
2242 auto op1_greater_in = rewriter.create<tosa::GreaterOp>(
2243 op->getLoc(), output_type, tfl_less_equal_op.lhs(),
2244 tfl_less_equal_op.rhs());
2245 auto op2_not_op1 = rewriter.create<tosa::LogicalNotOp>(
2246 op->getLoc(), output_type, op1_greater_in.getResult());
2247
2248 output = op2_not_op1.getResult();
2249 }
2250
2251 rewriter.replaceOp(op, {output});
2252 return success();
2253 }
2254
2255 LogicalResult ConvertTFLPadOp::matchAndRewrite(
2256 Operation* op, PatternRewriter& rewriter) const {
2257 auto tfl_pad_op = cast<TFL::PadOp>(op);
2258
2259 RankedTensorType output_type =
2260 tfl_pad_op.getResult().getType().dyn_cast<RankedTensorType>();
2261 // Not a ranked tensor output
2262 if (!output_type) return failure();
2263
2264 auto pad_op = rewriter.create<tosa::PadOp>(
2265 op->getLoc(), output_type, tfl_pad_op.input(), tfl_pad_op.padding());
2266
2267 rewriter.replaceOp(op, {pad_op.getResult()});
2268 return success();
2269 }
2270
2271 LogicalResult ConvertTFLResizeBilinearOp::matchAndRewrite(
2272 Operation* op, PatternRewriter& rewriter) const {
2273 auto tfl_resize_op = cast<TFL::ResizeBilinearOp>(op);
2274
2275 RankedTensorType output_type =
2276 tfl_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
2277 // Not a ranked tensor output
2278 if (!output_type) return failure();
2279
2280 llvm::Optional<Value> result = convertResizeOp(
2281 rewriter, op, output_type, tfl_resize_op.input(), StringRef("BILINEAR"));
2282
2283 if (!result) return failure();
2284
2285 rewriter.replaceOp(op, {result.getValue()});
2286
2287 return success();
2288 }
2289
2290 LogicalResult ConvertTFLResizeNearestNeighborOp::matchAndRewrite(
2291 Operation* op, PatternRewriter& rewriter) const {
2292 auto tfl_resize_op = cast<TFL::ResizeNearestNeighborOp>(op);
2293
2294 RankedTensorType output_type =
2295 tfl_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
2296 // Not a ranked tensor output
2297 if (!output_type) return failure();
2298
2299 llvm::Optional<Value> result = convertResizeOp(
2300 rewriter, op, output_type, tfl_resize_op.input(), StringRef("NEAREST"));
2301
2302 if (!result) return failure();
2303
2304 rewriter.replaceOp(op, {result.getValue()});
2305
2306 return success();
2307 }
2308
2309 LogicalResult ConvertTFLSelectOp::matchAndRewrite(
2310 Operation* op, PatternRewriter& rewriter) const {
2311 auto tfl_sel_op = cast<TFL::SelectOp>(op);
2312
2313 llvm::Optional<Value> result =
2314 convertSelectOp(rewriter, op, tfl_sel_op.getResult(),
2315 tfl_sel_op.condition(), tfl_sel_op.x(), tfl_sel_op.y());
2316 if (!result) return failure();
2317
2318 rewriter.replaceOp(op, {result.getValue()});
2319
2320 return success();
2321 }
2322
2323 LogicalResult ConvertTFLSelectV2Op::matchAndRewrite(
2324 Operation* op, PatternRewriter& rewriter) const {
2325 auto tfl_sel_op = cast<TFL::SelectV2Op>(op);
2326
2327 llvm::Optional<Value> result =
2328 convertSelectOp(rewriter, op, tfl_sel_op.getResult(),
2329 tfl_sel_op.condition(), tfl_sel_op.x(), tfl_sel_op.y());
2330 if (!result) return failure();
2331
2332 rewriter.replaceOp(op, {result.getValue()});
2333
2334 return success();
2335 }
2336
2337 LogicalResult ConvertTFLSpaceToBatchNdOp::matchAndRewrite(
2338 Operation* op, PatternRewriter& rewriter) const {
2339 auto tfl_s2b_op = cast<TFL::SpaceToBatchNdOp>(op);
2340 llvm::Optional<Value> result = convertSpaceToBatchNDOp(
2341 rewriter, op, tfl_s2b_op.getResult(), tfl_s2b_op.input(),
2342 tfl_s2b_op.block_shape(), tfl_s2b_op.paddings());
2343
2344 if (!result) return failure();
2345
2346 rewriter.replaceOp(op, {result.getValue()});
2347
2348 return success();
2349 }
2350
2351 LogicalResult ConvertTFLBatchToSpaceNdOp::matchAndRewrite(
2352 Operation* op, PatternRewriter& rewriter) const {
2353 auto tfl_b2s_op = cast<TFL::BatchToSpaceNdOp>(op);
2354
2355 llvm::Optional<Value> result = convertBatchToSpaceNDOp(
2356 rewriter, op, tfl_b2s_op.getResult(), tfl_b2s_op.input(),
2357 tfl_b2s_op.block_shape(), tfl_b2s_op.indices());
2358
2359 if (!result) return failure();
2360
2361 rewriter.replaceOp(op, {result.getValue()});
2362
2363 return success();
2364 }
2365
2366 LogicalResult ConvertTFLSpaceToDepthOp::matchAndRewrite(
2367 Operation* op, PatternRewriter& rewriter) const {
2368 auto tfl_s2d_op = cast<TFL::SpaceToDepthOp>(op);
2369
2370 auto block_size_attr = tfl_s2d_op.block_sizeAttr();
2371 llvm::Optional<Value> result = convertSpaceToDepthOp(
2372 rewriter, op, tfl_s2d_op.getResult(), tfl_s2d_op.input(), block_size_attr,
2373 rewriter.getStringAttr("NHWC"));
2374
2375 if (!result) return failure();
2376
2377 rewriter.replaceOp(op, {result.getValue()});
2378
2379 return success();
2380 }
2381
2382 LogicalResult ConvertTFLDepthToSpaceOp::matchAndRewrite(
2383 Operation* op, PatternRewriter& rewriter) const {
2384 auto tfl_d2s_op = cast<TFL::DepthToSpaceOp>(op);
2385
2386 auto block_size_attr = tfl_d2s_op.block_sizeAttr();
2387 llvm::Optional<Value> result = convertDepthToSpaceOp(
2388 rewriter, op, tfl_d2s_op.getResult(), tfl_d2s_op.input(), block_size_attr,
2389 rewriter.getStringAttr("NHWC"));
2390
2391 if (!result) return failure();
2392
2393 rewriter.replaceOp(op, {result.getValue()});
2394
2395 return success();
2396 }
2397
2398 LogicalResult ConvertTFLStridedSliceOp::matchAndRewrite(
2399 Operation* op, PatternRewriter& rewriter) const {
2400 auto tfl_ss_op = cast<TFL::StridedSliceOp>(op);
2401
2402 llvm::Optional<Value> result = convertStridedSliceOp(
2403 rewriter, op, tfl_ss_op.getResult(), tfl_ss_op.input(), tfl_ss_op.begin(),
2404 tfl_ss_op.end(), tfl_ss_op.strides(), tfl_ss_op.begin_maskAttr().getInt(),
2405 tfl_ss_op.end_maskAttr().getInt(), tfl_ss_op.ellipsis_maskAttr().getInt(),
2406 tfl_ss_op.new_axis_maskAttr().getInt(),
2407 tfl_ss_op.shrink_axis_maskAttr().getInt());
2408 if (!result) return failure();
2409
2410 rewriter.replaceOp(op, {result.getValue()});
2411
2412 return success();
2413 }
2414
2415 LogicalResult ConvertTFLZerosLikeOp::matchAndRewrite(
2416 Operation* op, PatternRewriter& rewriter) const {
2417 auto tfl_zeroslike_op = cast<TFL::ZerosLikeOp>(op);
2418
2419 llvm::Optional<Value> result = convertZerosLikeOp(
2420 rewriter, op, tfl_zeroslike_op.getResult(), tfl_zeroslike_op.input());
2421
2422 if (!result) return failure();
2423
2424 rewriter.replaceOp(op, {result.getValue()});
2425
2426 return success();
2427 }
2428
2429 LogicalResult ConvertTFLHardSwishOp::matchAndRewrite(
2430 Operation* op, PatternRewriter& rewriter) const {
2431 auto tfl_hardswish_op = cast<TFL::HardSwishOp>(op);
2432 RankedTensorType output_type =
2433 tfl_hardswish_op.getResult().getType().dyn_cast<RankedTensorType>();
2434 // Not a ranked tensor output
2435 if (!output_type) return failure();
2436
2437 RankedTensorType input_type =
2438 tfl_hardswish_op.input().getType().dyn_cast<RankedTensorType>();
2439 // Not a ranked tensor output
2440 if (!input_type) return failure();
2441
2442 auto input_shape = input_type.getShape();
2443
2444 // TFL hardswish: f(x) -> (x * relu6(x+3))/6
2445
2446 // TODO: support 16-bit hardswish
2447 if (input_type.getElementType().isa<mlir::quant::QuantizedType>() &&
2448 output_type.getElementType().isa<mlir::quant::QuantizedType>()) {
2449     // TFLite reference:
2450     // tensorflow/lite/kernels/internal/reference/reference_ops.h
2451     // Note: there is a potential rounding issue in the TFLite reference.
2452 mlir::quant::UniformQuantizedType in_quant_type =
2453 input_type.getElementType()
2454 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
2455 mlir::quant::UniformQuantizedType out_quant_type =
2456 output_type.getElementType()
2457 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
2458
2459 UniformQuantizedType int16_element_qtype =
2460 mlir::quant::UniformQuantizedType::get(
2461 true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
2462 -32768, 32767);
2463 RankedTensorType bool_type =
2464 RankedTensorType::get(input_shape, rewriter.getI1Type());
2465 RankedTensorType int16_type =
2466 RankedTensorType::get(input_shape, int16_element_qtype);
2467 RankedTensorType int32_type =
2468 RankedTensorType::get(input_shape, rewriter.getI32Type());
2469
2470     // The table's real input range is [-4.0, 4.0].
2471     // Use the TABLE op to compute x * relu6(x+3) / 6.
2472 const double input_sample_grain = 1.0 / 64.0;
2473 auto hardswish_func = [input_sample_grain](int32_t x) -> int32_t {
2474 double v = static_cast<double>(x) * input_sample_grain;
2475 double w = v + 3.0;
2476 w = w < 0.0 ? 0.0 : w > 6.0 ? 6.0 : w;
2477 v = v * w / 6.0;
2478 return std::lround(32768.0 * v);
2479 };
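    // Sample table entries under this mapping, for illustration:
    //   hardswish_func(0)    == 0
    //   hardswish_func(64)   == lround(32768.0 * (1.0 * 4.0 / 6.0)) == 21845
    //   hardswish_func(-192) == 0   (relu6 clamps v + 3.0 to zero)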
2480
2481 Value table_const = getTosa1DConstTensorTable(rewriter, op, hardswish_func);
2482
2483     // Rescale input to 9.7 precision.
2484 Value op1_rescale_in =
2485 buildRescale(rewriter, op, int16_type, tfl_hardswish_op.input(),
2486 (in_quant_type.getScale() * 128.0) / input_sample_grain,
2487 in_quant_type.getZeroPoint(), 0);
2488
2489     // Table op; output is in 0.23 format.
2490 auto op2_table_op1 = rewriter.create<tosa::TableOp>(
2491 op->getLoc(), int32_type, op1_rescale_in, table_const);
2492
2493     // Scale the table output back to the quantized space.
2494 Value op3_rescale_op2 =
2495 buildRescale(rewriter, op, output_type, op2_table_op1.getResult(),
2496 1.0 / (128.0 * 32768.0 * out_quant_type.getScale()), 0,
2497 out_quant_type.getZeroPoint());
2498
2499 Value op4_rescale_in = buildRescale(rewriter, op, int32_type,
2500 tfl_hardswish_op.input(), 1.0, 0, 0);
2501
2502 // Get 3.0 in quantized space
2503 int32_t quantized_3 =
2504 static_cast<int32_t>(std::ceil(3.0 / in_quant_type.getScale())) +
2505 in_quant_type.getZeroPoint();
2506
2507 auto op5_ge_op4 = rewriter.create<tosa::GreaterEqualOp>(
2508 op->getLoc(), bool_type, op4_rescale_in,
2509 getTosaConstTensorSingleI32(rewriter, op, quantized_3));
2510
2511 auto op6_select_op5_op4_op3 = rewriter.create<tosa::SelectOp>(
2512 op->getLoc(), output_type, op5_ge_op4, tfl_hardswish_op.input(),
2513 op3_rescale_op2);
2514
2515 rewriter.replaceOp(op, {op6_select_op5_op4_op3});
2516
2517 return success();
2518
2519 } else {
2520 // op1 = constop(3)
2521 // op2 = add(x, op1)
2522 // op3 = reluN(op2, 6)
2523 // op4 = mul(x, op3)
2524 // op5 = reciprocal(6)
2525 // op6 = mul (op4, op5)
2526
2527 Value op1_value = getTosaConstTensorSingleF32(rewriter, op, 3.0);
2528
2529 auto op2_add_x_op1 = rewriter.create<tosa::AddOp>(
2530 op->getLoc(), output_type, tfl_hardswish_op.input(), op1_value);
2531
2532 auto op3_relu_op2_6 = rewriter.create<tosa::ReluNOp>(
2533 op->getLoc(), output_type, op2_add_x_op1.getResult(),
2534 rewriter.getI64IntegerAttr(0), rewriter.getF32FloatAttr(6.0));
2535
2536 auto op4_mul_x_op3 = rewriter.create<tosa::MulOp>(
2537 op->getLoc(), output_type, tfl_hardswish_op.input(),
2538 op3_relu_op2_6.getResult(), 0);
2539
2540 auto op5_reciprocal_6 = rewriter.create<tosa::ReciprocalOp>(
2541 op->getLoc(), output_type,
2542 getTosaConstTensorSingleF32(rewriter, op, 6.0));
2543
2544 auto op6_mul_op4_op5 = rewriter.create<tosa::MulOp>(
2545 op->getLoc(), output_type, op4_mul_x_op3.getResult(),
2546 op5_reciprocal_6.getResult(), 0);
2547
2548 rewriter.replaceOp(op, {op6_mul_op4_op5.getResult()});
2549
2550 return success();
2551 }
2552 }
2553
2554 LogicalResult ConvertTFLLogisticOp::matchAndRewrite(
2555 Operation* op, PatternRewriter& rewriter) const {
2556 auto tfl_logistic_op = cast<TFL::LogisticOp>(op);
2557
2558 RankedTensorType output_type =
2559 tfl_logistic_op.getResult().getType().dyn_cast<RankedTensorType>();
2560 RankedTensorType input_type =
2561 tfl_logistic_op.x().getType().dyn_cast<RankedTensorType>();
2562 if (!input_type || !output_type) return failure();
2563
2564 bool input_is_qtype =
2565 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2566 bool output_is_qtype =
2567 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2568
2569 if (input_is_qtype != output_is_qtype) {
2570 return op->emitOpError(
2571 "ConvertTFLLogisticOp: input/output tensor should "
2572 "be all quantized or all floating-point.");
2573 }
2574
2575 if (input_is_qtype) {
2576 UniformQuantizedType int16_element_qtype =
2577 mlir::quant::UniformQuantizedType::get(
2578 true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
2579 -32768, 32767);
2580 RankedTensorType int16_type =
2581 RankedTensorType::get(output_type.getShape(), int16_element_qtype);
2582 RankedTensorType int32_type = RankedTensorType::get(
2583 output_type.getShape(), rewriter.getIntegerType(32));
2584 mlir::quant::UniformQuantizedType input_qtype =
2585 input_type.getElementType()
2586 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
2587 mlir::quant::UniformQuantizedType output_qtype =
2588 output_type.getElementType()
2589 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
2590 const double input_sample_grain = 1.0 / 16.0;
2591 auto sigmoid_func = [input_sample_grain](int32_t x) -> int32_t {
2592 // Input range [-16.0, 16.0], output range [0.0, 1.0]
2593 double v = static_cast<double>(x) * input_sample_grain;
2594 v = 1.0 / (1.0 + std::exp(-v));
2595
2596 return std::lround(32768.0 * v);
2597 };
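    // Sample table entries under this mapping, for illustration:
    //   sigmoid_func(0)  == lround(32768.0 * 0.5)     == 16384
    //   sigmoid_func(16) == lround(32768.0 * 0.73106) == 23955   (v = 1.0)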
2598
2599 Value table_const = getTosa1DConstTensorTable(rewriter, op, sigmoid_func);
2600
2601 // Rescale input to 9.7 precision.
2602 Value op1_rescale_in =
2603 buildRescale(rewriter, op, int16_type, tfl_logistic_op.x(),
2604 (input_qtype.getScale() * 128.0) / input_sample_grain,
2605 input_qtype.getZeroPoint(), 0);
2606
2607 auto op2_table_op1 = rewriter.create<tosa::TableOp>(
2608 op->getLoc(), int32_type, op1_rescale_in, table_const);
2609
2610 double output_rescale_scale =
2611 1.0 / (output_qtype.getScale() * 32768.0 * 128.0);
2612
2613 Value op3_rescale_op2 =
2614 buildRescale(rewriter, op, output_type, op2_table_op1.getResult(),
2615 output_rescale_scale, 0, output_qtype.getZeroPoint());
2616
2617 rewriter.replaceOp(op, {op3_rescale_op2});
2618 } else {
2619 rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(op, output_type,
2620 tfl_logistic_op.x());
2621 }
2622
2623 return success();
2624 }
2625
2626 LogicalResult ConvertTFLTanhOp::matchAndRewrite(
2627 Operation* op, PatternRewriter& rewriter) const {
2628 auto tfl_tanh_op = cast<TFL::TanhOp>(op);
2629 RankedTensorType output_type =
2630 tfl_tanh_op.getResult().getType().dyn_cast<RankedTensorType>();
2631 RankedTensorType input_type =
2632 tfl_tanh_op.input().getType().dyn_cast<RankedTensorType>();
2633 if (!input_type || !output_type) return failure();
2634
2635 bool input_is_qtype =
2636 input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2637 bool output_is_qtype =
2638 output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
2639
2640 if (input_is_qtype != output_is_qtype) {
2641 return op->emitOpError(
2642 "ConvertTFLTanhOp: input/output tensor should "
2643 "be all quantized or all floating-point.");
2644 }
2645
2646 if (input_is_qtype) {
2647 UniformQuantizedType int16_element_qtype =
2648 mlir::quant::UniformQuantizedType::get(
2649 true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
2650 -32768, 32767);
2651 RankedTensorType int16_type =
2652 RankedTensorType::get(output_type.getShape(), int16_element_qtype);
2653 RankedTensorType int32_type = RankedTensorType::get(
2654 output_type.getShape(), rewriter.getIntegerType(32));
2655 mlir::quant::UniformQuantizedType input_qtype =
2656 input_type.getElementType()
2657 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
2658 mlir::quant::UniformQuantizedType output_qtype =
2659 output_type.getElementType()
2660 .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
2661 const double input_sample_grain = 1.0 / 32.0;
2662 auto tanh_func = [input_sample_grain](int32_t x) -> int32_t {
2663       // Input range [-16.0, 16.0], output range [-1.0, 1.0]
2664 double v = static_cast<double>(x) * input_sample_grain;
2665 v = std::exp(-2.0 * v);
2666 v = (1.0 - v) / (1.0 + v);
2667
2668 return std::lround(32768.0 * v);
2669 };
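    // Sample table entries under this mapping, for illustration:
    //   tanh_func(0)  == 0
    //   tanh_func(32) == lround(32768.0 * tanh(1.0)) == 24956   (v = 1.0)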
2670
2671 Value table_const = getTosa1DConstTensorTable(rewriter, op, tanh_func);
2672
2673 // Rescale input to 9.7 precision.
2674 Value op1_rescale_in =
2675 buildRescale(rewriter, op, int16_type, tfl_tanh_op.input(),
2676 (input_qtype.getScale() * 128.0) / input_sample_grain,
2677 input_qtype.getZeroPoint(), 0);
2678
2679 auto op2_table_op1 = rewriter.create<tosa::TableOp>(
2680 op->getLoc(), int32_type, op1_rescale_in, table_const);
2681
2682 double output_rescale_scale =
2683 1.0 / (output_qtype.getScale() * 32768.0 * 128.0);
2684
2685 Value op3_rescale_op2 =
2686 buildRescale(rewriter, op, output_type, op2_table_op1.getResult(),
2687 output_rescale_scale, 0, output_qtype.getZeroPoint());
2688
2689 rewriter.replaceOp(op, {op3_rescale_op2});
2690 } else {
2691 rewriter.replaceOpWithNewOp<tosa::TanhOp>(op, output_type,
2692 tfl_tanh_op.input());
2693 }
2694
2695 return success();
2696 }
2697
2698 LogicalResult ConvertTFLPReluOp::matchAndRewrite(
2699 Operation* op, PatternRewriter& rewriter) const {
2700 auto tfl_prelu_op = cast<TFL::PReluOp>(op);
2701 RankedTensorType output_type =
2702 tfl_prelu_op.getResult().getType().dyn_cast<RankedTensorType>();
2703 if (!output_type) return failure();
2704
2705 // TODO: add lowering with MUL + SELECT + RESCALE
2706
2707 return failure();
2708 }
2709
2710 LogicalResult ConvertTFLLeakyReluOp::matchAndRewrite(
2711 Operation* op, PatternRewriter& rewriter) const {
2712 auto tfl_leakyrelu_op = cast<TFL::LeakyReluOp>(op);
2713 RankedTensorType output_type =
2714 tfl_leakyrelu_op.getResult().getType().dyn_cast<RankedTensorType>();
2715 if (!output_type) return failure();
2716
2717 // TODO: add lowering with MUL + SELECT + RESCALE
2718
2719 return failure();
2720 }
2721
2722 LogicalResult ConvertTFLNegOp::matchAndRewrite(
2723 Operation* op, PatternRewriter& rewriter) const {
2724 auto tfl_neg_op = cast<TFL::NegOp>(op);
2725 RankedTensorType output_type =
2726 tfl_neg_op.getResult().getType().dyn_cast<RankedTensorType>();
2727 if (!output_type) return failure();
2728
2729 rewriter.replaceOpWithNewOp<tosa::NegateOp>(op, output_type, tfl_neg_op.x());
2730
2731 return success();
2732 }
2733
2734 LogicalResult ConvertTFLYieldOp::matchAndRewrite(
2735 Operation* op, PatternRewriter& rewriter) const {
2736 rewriter.replaceOpWithNewOp<tosa::YieldOp>(op, op->getResultTypes(),
2737 op->getOperands());
2738
2739 return success();
2740 }
2741
2742 LogicalResult ConvertTFLCustomOp::matchAndRewrite(
2743 Operation* op, PatternRewriter& rewriter) const {
2744 auto tfl_custom_op = cast<TFL::CustomOp>(op);
2745 rewriter.replaceOpWithNewOp<tosa::CustomOp>(
2746 op, op->getResultTypes(), tfl_custom_op.custom_code(), op->getOperands());
2747
2748 return success();
2749 }
2750
2751 LogicalResult ConvertTFLReverseV2Op::matchAndRewrite(
2752 Operation* op, PatternRewriter& rewriter) const {
2753 auto tfl_reverse_op = cast<TFL::ReverseV2Op>(op);
2754
2755 RankedTensorType input_type =
2756 tfl_reverse_op.input().getType().dyn_cast<RankedTensorType>();
2757 RankedTensorType output_type =
2758 tfl_reverse_op.getResult().getType().dyn_cast<RankedTensorType>();
2759 if (!input_type || !output_type) return failure();
2760
2761 ElementsAttr axis_elems;
2762 if (!matchPattern(tfl_reverse_op.axis(), m_Constant(&axis_elems)))
2763 return failure();
2764
2765 auto input_rank = input_type.getShape().size();
2766 Value val = tfl_reverse_op.input();
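  // A multi-axis reverse decomposes into a chain of single-axis tosa.reverse
  // ops; e.g., for a hypothetical axis = [0, 2]:
  //   %r0 = "tosa.reverse"(%input) {axis = 0}
  //   %r1 = "tosa.reverse"(%r0) {axis = 2}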
2767 if (axis_elems.getNumElements() == 0) {
2768 auto identity_op =
2769 rewriter.create<tosa::IdentityOp>(op->getLoc(), output_type, val);
2770 val = identity_op.getResult();
2771 } else {
2772 for (int i = 0; i < axis_elems.getNumElements(); i++) {
2773 int64_t axis_val = axis_elems.getValue<IntegerAttr>(i).getInt();
2774 if (axis_val < 0) axis_val += input_rank;
2775 auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
2776 auto reverse_op = rewriter.create<tosa::ReverseOp>(
2777 op->getLoc(), output_type, val, axis_attr);
2778
2779 val = reverse_op.getResult();
2780 }
2781 }
2782
2783 rewriter.replaceOp(op, {val});
2784
2785 return success();
2786 }
2787
2788 LogicalResult ConvertTFLQuantizeOp::matchAndRewrite(
2789 Operation* op, PatternRewriter& rewriter) const {
2790 auto tfl_quantize_op = cast<TFL::QuantizeOp>(op);
2791
2792 RankedTensorType input_type =
2793 tfl_quantize_op.input().getType().dyn_cast<RankedTensorType>();
2794 RankedTensorType output_type =
2795 tfl_quantize_op.getResult().getType().dyn_cast<RankedTensorType>();
2796
2797 if (!input_type || !output_type) return failure();
2798
2799 RankedTensorType qtype =
2800 tfl_quantize_op.qtypeAttr().getValue().dyn_cast<RankedTensorType>();
2801 if (!qtype) return failure();
2802
2803 UniformQuantizedType element_type =
2804 qtype.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
2805 if (!element_type) return failure();
2806
2807 UniformQuantizedType input_element_type =
2808 input_type.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
2809
2810 // If input is already a quantized type, this is basically a RESCALE (or
2811 // tensorflow::ops::Requantize)
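  // E.g., with hypothetical scales, an input scale of 0.5 and an output
  // scale of 0.25 give rescale_scale = 0.5 / 0.25 = 2.0, applied while
  // shifting from the input zero point to the output zero point.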
2812 if (input_element_type) {
2813 double rescale_scale =
2814 input_element_type.getScale() / element_type.getScale();
2815 Value rescale_op = buildRescale(
2816 rewriter, op, output_type, tfl_quantize_op.input(), rescale_scale,
2817 input_element_type.getZeroPoint(), element_type.getZeroPoint());
2818
2819 rewriter.replaceOp(op, {rescale_op});
2820 return success();
2821 } else {
2822 double scale = 1 / element_type.getScale();
2823 int64_t zp = element_type.getZeroPoint();
2824 int64_t num_bits = element_type.getStorageTypeIntegralWidth();
2825 zp = element_type.isSigned() ? zp : zp - (1 << (num_bits - 1));
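    // E.g., for a hypothetical unsigned 8-bit type (num_bits == 8) with
    // zp == 128, the signed-space zero point becomes 128 - 128 == 0.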
2826
2827 llvm::Optional<Value> result = convertQuantizeOp(
2828 rewriter, op, output_type, tfl_quantize_op.input(), scale, zp);
2829
2830 if (!result) return failure();
2831
2832 rewriter.replaceOp(op, {result.getValue()});
2833
2834 return success();
2835 }
2836 }
2837
2838 LogicalResult ConvertTFLDequantizeOp::matchAndRewrite(
2839 Operation* op, PatternRewriter& rewriter) const {
2840 auto tfl_dequantize_op = cast<TFL::DequantizeOp>(op);
2841
2842 RankedTensorType output_type =
2843 tfl_dequantize_op.getResult().getType().dyn_cast<RankedTensorType>();
2844 // Not a ranked tensor output
2845 if (!output_type) return failure();
2846
2847 RankedTensorType qtype =
2848 tfl_dequantize_op.input().getType().dyn_cast<RankedTensorType>();
2849 if (!qtype) return failure();
2850
2851 UniformQuantizedType element_type =
2852 qtype.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
2853 if (!element_type) return failure();
2854
2855 double scale = element_type.getScale();
2856 int64_t zp = element_type.getZeroPoint();
2857 int64_t num_bits = element_type.getStorageTypeIntegralWidth();
2858 zp = element_type.isSigned() ? zp : zp - (1 << (num_bits - 1));
2859
2860 llvm::Optional<Value> result = convertDequantizeOp(
2861 rewriter, op, output_type, tfl_dequantize_op.input(), scale, zp);
2862
2863 if (!result) return failure();
2864
2865 rewriter.replaceOp(op, {result.getValue()});
2866
2867 return success();
2868 }
2869
2870 LogicalResult ConvertTFLQConstOp::matchAndRewrite(
2871 Operation* op, PatternRewriter& rewriter) const {
2872 auto tfl_qconst_op = cast<TFL::QConstOp>(op);
2873
2874 RankedTensorType output_type =
2875 tfl_qconst_op.getResult().getType().dyn_cast<RankedTensorType>();
2876 // Not a ranked tensor output
2877 if (!output_type) return failure();
2878
2879 rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output_type,
2880 tfl_qconst_op.valueAttr());
2881
2882 return success();
2883 }
2884
2885 void LegalizeTFL::runOnFunction() {
2886 OwningRewritePatternList patterns;
2887 auto* ctx = &getContext();
2888 auto func = getFunction();
2889
2890 // Add the generated patterns to the list.
2891 populateWithGenerated(ctx, patterns);
2892
2893 #define DEF_PATTERN_INSERT(PAT) patterns.insert<Convert##PAT##Op>(ctx);
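 // For example, DEF_PATTERN_INSERT(TFLRelu) expands to:
 //   patterns.insert<ConvertTFLReluOp>(ctx);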
2894
2895 DEF_PATTERN_INSERT(TFLRelu);
2896 DEF_PATTERN_INSERT(TFLRelu6);
2897 DEF_PATTERN_INSERT(TFLEqual);
2898 DEF_PATTERN_INSERT(TFLNotEqual);
2899 DEF_PATTERN_INSERT(TFLGreater);
2900 DEF_PATTERN_INSERT(TFLGreaterEqual);
2901 DEF_PATTERN_INSERT(TFLAdd);
2902 DEF_PATTERN_INSERT(TFLSub);
2903 DEF_PATTERN_INSERT(TFLMul);
2904 DEF_PATTERN_INSERT(TFLSquare);
2905 DEF_PATTERN_INSERT(TFLDiv);
2906 DEF_PATTERN_INSERT(TFLMaximum);
2907 DEF_PATTERN_INSERT(TFLMinimum);
2908 DEF_PATTERN_INSERT(TFLFloorMod);
2909 DEF_PATTERN_INSERT(TFLFloorDiv);
2910 DEF_PATTERN_INSERT(TFLAddN);
2911 DEF_PATTERN_INSERT(TFLAveragePool2D);
2912 DEF_PATTERN_INSERT(TFLMaxPool2D);
2913 DEF_PATTERN_INSERT(TFLConcatenation);
2914 DEF_PATTERN_INSERT(TFLReshape);
2915 DEF_PATTERN_INSERT(TFLRank);
2916 DEF_PATTERN_INSERT(TFLShape);
2917 DEF_PATTERN_INSERT(TFLExpandDims);
2918 DEF_PATTERN_INSERT(TFLSqueeze);
2919 DEF_PATTERN_INSERT(TFLFill);
2920 DEF_PATTERN_INSERT(TFLElu);
2921 DEF_PATTERN_INSERT(TFLSoftmax);
2922 DEF_PATTERN_INSERT(TFLLogSoftmax);
2923 DEF_PATTERN_INSERT(TFLReduceAny);
2924 DEF_PATTERN_INSERT(TFLReduceMax);
2925 DEF_PATTERN_INSERT(TFLReduceMin);
2926 DEF_PATTERN_INSERT(TFLMean);
2927 DEF_PATTERN_INSERT(TFLReduceProd);
2928 DEF_PATTERN_INSERT(TFLSum);
2929 DEF_PATTERN_INSERT(TFLConv2D);
2930 DEF_PATTERN_INSERT(TFLTransposeConv);
2931 DEF_PATTERN_INSERT(TFLDepthwiseConv2D);
2932 DEF_PATTERN_INSERT(TFLFullyConnected);
2933 DEF_PATTERN_INSERT(TFLSplit);
2934 DEF_PATTERN_INSERT(TFLSplitV);
2935 DEF_PATTERN_INSERT(TFLPack);
2936 DEF_PATTERN_INSERT(TFLUnpack);
2937 DEF_PATTERN_INSERT(TFLTranspose);
2938 DEF_PATTERN_INSERT(TFLTile);
2939 DEF_PATTERN_INSERT(TFLSlice);
2940 DEF_PATTERN_INSERT(TFLStridedSlice);
2941 DEF_PATTERN_INSERT(TFLZerosLike);
2942 DEF_PATTERN_INSERT(TFLHardSwish);
2943 DEF_PATTERN_INSERT(TFLLess);
2944 DEF_PATTERN_INSERT(TFLLessEqual);
2945 DEF_PATTERN_INSERT(TFLPad);
2946 DEF_PATTERN_INSERT(TFLResizeBilinear);
2947 DEF_PATTERN_INSERT(TFLResizeNearestNeighbor);
2948 DEF_PATTERN_INSERT(TFLSelect);
2949 DEF_PATTERN_INSERT(TFLSelectV2);
2950 DEF_PATTERN_INSERT(TFLSpaceToBatchNd);
2951 DEF_PATTERN_INSERT(TFLBatchToSpaceNd);
2952 DEF_PATTERN_INSERT(TFLSpaceToDepth);
2953 DEF_PATTERN_INSERT(TFLDepthToSpace);
2954 DEF_PATTERN_INSERT(TFLLogistic);
2955 DEF_PATTERN_INSERT(TFLTanh);
2956 DEF_PATTERN_INSERT(TFLPRelu);
2957 DEF_PATTERN_INSERT(TFLLeakyRelu);
2958 DEF_PATTERN_INSERT(TFLNeg);
2959 DEF_PATTERN_INSERT(TFLYield);
2960 DEF_PATTERN_INSERT(TFLCustom);
2961 DEF_PATTERN_INSERT(TFLReverseV2);
2962 DEF_PATTERN_INSERT(TFLQuantize);
2963 DEF_PATTERN_INSERT(TFLDequantize);
2964 DEF_PATTERN_INSERT(TFLQConst);
2965 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
2966 }
2967 } // namespace
2968
2969 // Creates an instance of the TensorFlow Lite dialect LegalizeTFL pass.
2970 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFLPass() {
2971 return std::make_unique<LegalizeTFL>();
2972 }
2973
2974 static PassRegistration<LegalizeTFL> pass(
2975 PASS_NAME, "Legalize from TensorFlow Lite to TOSA dialect");
2976
2977 } // namespace tosa
2978 } // namespace mlir
2979