/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Legalize TensorFlow Lite to TOSA

#include <climits>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iterator>
#include <numeric>
#include <unordered_set>

#include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"

#define PASS_NAME "tosa-legalize-tfl"
#define DEBUG_TYPE PASS_NAME
#define HARDSWISH_EXPLICIT_RESCALING false

// Conditionally avoid converting some TFLite ops to TOSA.
// By default, all conversions will be invoked.
//
// The denylist file lists patterns which are not legalized from TFLite to
// TOSA.
llvm::cl::opt<std::string> tfl_tosa_denylist(
    "tfl-tosa-denylist",
    llvm::cl::desc("<a list of patterns not legalized from TFLite to TOSA>"),
    llvm::cl::init("transforms/tfl_tosa_denylist.txt"),
    llvm::cl::value_desc("pattern name"));
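// Since this is an llvm::cl option, it is parsed from the command line of
// whichever tool hosts this pass. A hypothetical invocation (the tool name
// and file path are illustrative, not prescribed by this file):
//   tf-opt --tosa-legalize-tfl --tfl-tosa-denylist=my_denylist.txt in.mlir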

namespace mlir {

namespace tosa {

namespace {
// Performs lowering to TOSA dialect.
class LegalizeTFL : public PassWrapper<LegalizeTFL, FunctionPass> {
 public:
  explicit LegalizeTFL() {}
  void runOnFunction() override;
};

#include "tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.inc"

#define DECL_CONVERT_OP(tfl_op)                                              \
  struct ConvertTFL##tfl_op##Op : public RewritePattern {                    \
    explicit ConvertTFL##tfl_op##Op(MLIRContext* context)                    \
        : RewritePattern(TFL::tfl_op##Op::getOperationName(), 1, context) {} \
    LogicalResult matchAndRewrite(Operation* op,                             \
                                  PatternRewriter& rewriter) const override; \
  }
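// For example, DECL_CONVERT_OP(Relu) declares a ConvertTFLReluOp pattern
// that matches TFL::ReluOp with benefit 1; the matchAndRewrite bodies are
// defined out-of-line below.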
DECL_CONVERT_OP(Relu);
DECL_CONVERT_OP(Relu6);
DECL_CONVERT_OP(Equal);
DECL_CONVERT_OP(NotEqual);
DECL_CONVERT_OP(Greater);
DECL_CONVERT_OP(GreaterEqual);
DECL_CONVERT_OP(Add);
DECL_CONVERT_OP(Sub);
DECL_CONVERT_OP(Mul);
DECL_CONVERT_OP(Square);
DECL_CONVERT_OP(SquaredDifference);
DECL_CONVERT_OP(Round);
DECL_CONVERT_OP(Div);
DECL_CONVERT_OP(Maximum);
DECL_CONVERT_OP(Minimum);
DECL_CONVERT_OP(FloorMod);
DECL_CONVERT_OP(FloorDiv);
DECL_CONVERT_OP(AddN);
DECL_CONVERT_OP(AveragePool2D);
DECL_CONVERT_OP(MaxPool2D);
DECL_CONVERT_OP(Concatenation);
DECL_CONVERT_OP(Reshape);
DECL_CONVERT_OP(Rank);
DECL_CONVERT_OP(Shape);
DECL_CONVERT_OP(ExpandDims);
DECL_CONVERT_OP(Squeeze);
DECL_CONVERT_OP(Fill);
DECL_CONVERT_OP(Elu);
DECL_CONVERT_OP(Softmax);
DECL_CONVERT_OP(LogSoftmax);
DECL_CONVERT_OP(ReduceAny);
DECL_CONVERT_OP(ReduceMax);
DECL_CONVERT_OP(ReduceMin);
DECL_CONVERT_OP(Mean);
DECL_CONVERT_OP(ReduceProd);
DECL_CONVERT_OP(Sum);
DECL_CONVERT_OP(Conv2D);
DECL_CONVERT_OP(TransposeConv);
DECL_CONVERT_OP(DepthwiseConv2D);
DECL_CONVERT_OP(FullyConnected);
DECL_CONVERT_OP(Split);
DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(Pack);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(Transpose);
DECL_CONVERT_OP(Tile);
DECL_CONVERT_OP(Slice);
DECL_CONVERT_OP(StridedSlice);
DECL_CONVERT_OP(HardSwish);
DECL_CONVERT_OP(ZerosLike);
DECL_CONVERT_OP(Less);
DECL_CONVERT_OP(LessEqual);
DECL_CONVERT_OP(Pad);
DECL_CONVERT_OP(ResizeBilinear);
DECL_CONVERT_OP(ResizeNearestNeighbor);
DECL_CONVERT_OP(Select);
DECL_CONVERT_OP(SelectV2);
DECL_CONVERT_OP(SpaceToBatchNd);
DECL_CONVERT_OP(BatchToSpaceNd);
DECL_CONVERT_OP(SpaceToDepth);
DECL_CONVERT_OP(DepthToSpace);
DECL_CONVERT_OP(Logistic);
DECL_CONVERT_OP(Tanh);
DECL_CONVERT_OP(PRelu);
DECL_CONVERT_OP(LeakyRelu);
DECL_CONVERT_OP(Neg);
DECL_CONVERT_OP(Yield);
DECL_CONVERT_OP(Custom);
DECL_CONVERT_OP(ReverseV2);
DECL_CONVERT_OP(Quantize);
DECL_CONVERT_OP(Dequantize);
DECL_CONVERT_OP(QConst);
#undef DECL_CONVERT_OP

LogicalResult ConvertTFLReluOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_relu_op = cast<TFL::ReluOp>(op);

  RankedTensorType input_type =
      tfl_relu_op.x().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_relu_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type || !output_type) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLReluOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype) {
    RankedTensorType rescale_type =
        RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
    UniformQuantizedType input_qtype =
        input_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType output_qtype =
        output_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();

    Value op1_rescale_in = buildRescaleToInt32(
        rewriter, op, tfl_relu_op.x(), 1.0f, input_qtype.getZeroPoint());
    auto op2_relun_op1 = rewriter.create<tosa::ReluNOp>(
        op->getLoc(), rescale_type, op1_rescale_in,
        rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
        rewriter.getF32FloatAttr(0.0f));
    Value op3_rescale_op2 = buildRescaleFromInt32(
        rewriter, op, output_type, op2_relun_op1.getResult(), 1.0f,
        output_qtype.getZeroPoint());

    output = op3_rescale_op2;
  } else {
    auto op1_relun_in = rewriter.create<tosa::ReluNOp>(
        op->getLoc(), output_type, tfl_relu_op.x(),
        rewriter.getI64IntegerAttr(0),
        rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));

    output = op1_relun_in.getResult();
  }

  rewriter.replaceOp(op, {output});
  return success();
}
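// The quantized path above follows a pattern used throughout this file:
// rescale the input into a zero-point-free int32 domain, apply the TOSA op
// there, then rescale back into the output's quantized domain.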

LogicalResult ConvertTFLRelu6Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_relu6_op = cast<TFL::Relu6Op>(op);

  RankedTensorType input_type =
      tfl_relu6_op.x().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_relu6_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type || !output_type) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLRelu6Op: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype && input_is_qtype) {
    RankedTensorType rescale_type =
        RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
    UniformQuantizedType input_qtype =
        input_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType output_qtype =
        output_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    int64_t rescaled_6 = std::llround(6.0f / input_qtype.getScale()) +
                         input_qtype.getZeroPoint();
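    // Worked example with hypothetical quantization parameters: an 8-bit
    // input with scale = 6.0/255 and zero point = -128 gives
    // rescaled_6 = llround(255.0) + (-128) = 127, the clip limit in the
    // rescaled int32 domain.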

    Value op1_rescale_in = buildRescaleToInt32(
        rewriter, op, tfl_relu6_op.x(), 1.0f, input_qtype.getZeroPoint());
    auto op2_relun_op1 = rewriter.create<tosa::ReluNOp>(
        op->getLoc(), rescale_type, op1_rescale_in,
        rewriter.getI64IntegerAttr(rescaled_6), rewriter.getF32FloatAttr(0.0f));
    Value op3_rescale_op2 = buildRescaleFromInt32(
        rewriter, op, output_type, op2_relun_op1.getResult(), 1.0f,
        output_qtype.getZeroPoint());

    output = op3_rescale_op2;
  } else {
    auto op1_relun_in = rewriter.create<tosa::ReluNOp>(
        op->getLoc(), output_type, tfl_relu6_op.x(),
        rewriter.getI64IntegerAttr(0), rewriter.getF32FloatAttr(6.0f));

    output = op1_relun_in.getResult();
  }

  rewriter.replaceOp(op, {output});
  return success();
}

// TODO: Use a utility function for common code in comparison ops.
LogicalResult ConvertTFLEqualOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_equal_op = cast<TFL::EqualOp>(op);

  RankedTensorType input_x_type =
      tfl_equal_op.x().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_y_type =
      tfl_equal_op.y().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_x_type || !input_y_type || !output_type) return failure();

  bool input_x_is_qtype =
      input_x_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_y_is_qtype =
      input_y_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_x_is_qtype != output_is_qtype ||
      input_y_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLEqualOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype && input_x_is_qtype && input_y_is_qtype) {
    UniformQuantizedType input_x_qtype =
        input_x_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType input_y_qtype =
        input_y_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();

    if (input_x_qtype.getScale() != input_y_qtype.getScale() ||
        input_x_qtype.getZeroPoint() != input_y_qtype.getZeroPoint()) {
      return op->emitOpError(
          "ConvertTFLEqualOp: input_x and input_y scale/zp "
          "must be the same");
    }

    Value op1_rescale_x = buildRescaleToInt32(
        rewriter, op, tfl_equal_op.x(), 1.0f, input_x_qtype.getZeroPoint());
    Value op2_rescale_y = buildRescaleToInt32(
        rewriter, op, tfl_equal_op.y(), 1.0f, input_y_qtype.getZeroPoint());
    auto op3_equal_op1_op2 = rewriter.create<tosa::EqualOp>(
        op->getLoc(), output_type, op1_rescale_x, op2_rescale_y);

    output = op3_equal_op1_op2.getResult();
  } else {
    auto op1_equal_in = rewriter.create<tosa::EqualOp>(
        op->getLoc(), output_type, tfl_equal_op.x(), tfl_equal_op.y());

    output = op1_equal_in.getResult();
  }

  rewriter.replaceOp(op, {output});
  return success();
}

LogicalResult ConvertTFLNotEqualOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_not_equal_op = cast<TFL::NotEqualOp>(op);

  RankedTensorType input_lhs_type =
      tfl_not_equal_op.lhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_rhs_type =
      tfl_not_equal_op.rhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_not_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_lhs_type || !input_rhs_type || !output_type) return failure();

  bool input_lhs_is_qtype =
      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_rhs_is_qtype =
      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_lhs_is_qtype != output_is_qtype ||
      input_rhs_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLNotEqualOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
    UniformQuantizedType input_lhs_qtype =
        input_lhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType input_rhs_qtype =
        input_rhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();

    if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
        input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
      return op->emitOpError(
          "ConvertTFLNotEqualOp: input_lhs and input_rhs scale/zp "
          "must be the same");
    }

    Value op1_rescale_lhs =
        buildRescaleToInt32(rewriter, op, tfl_not_equal_op.lhs(), 1.0f,
                            input_lhs_qtype.getZeroPoint());
    Value op2_rescale_rhs =
        buildRescaleToInt32(rewriter, op, tfl_not_equal_op.rhs(), 1.0f,
                            input_rhs_qtype.getZeroPoint());
    auto op3_equal_op1_op2 = rewriter.create<tosa::EqualOp>(
        op->getLoc(), output_type, op1_rescale_lhs, op2_rescale_rhs);
    auto op4_not_op3 = rewriter.create<tosa::LogicalNotOp>(
        op->getLoc(), output_type, op3_equal_op1_op2.getResult());

    output = op4_not_op3.getResult();
  } else {
    auto op1_equal_in = rewriter.create<tosa::EqualOp>(
        op->getLoc(), output_type, tfl_not_equal_op.lhs(),
        tfl_not_equal_op.rhs());
    auto op2_not_op1 = rewriter.create<tosa::LogicalNotOp>(
        op->getLoc(), output_type, op1_equal_in.getResult());

    output = op2_not_op1.getResult();
  }

  rewriter.replaceOp(op, {output});
  return success();
}

LogicalResult ConvertTFLGreaterOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_greater_op = cast<TFL::GreaterOp>(op);

  RankedTensorType input_lhs_type =
      tfl_greater_op.lhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_rhs_type =
      tfl_greater_op.rhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_greater_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_lhs_type || !input_rhs_type || !output_type) return failure();

  bool input_lhs_is_qtype =
      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_rhs_is_qtype =
      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_lhs_is_qtype != output_is_qtype ||
      input_rhs_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLGreaterOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
    UniformQuantizedType input_lhs_qtype =
        input_lhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType input_rhs_qtype =
        input_rhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();

    if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
        input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
      return op->emitOpError(
          "ConvertTFLGreaterOp: input_lhs and input_rhs scale/zp "
          "must be the same");
    }

    Value op1_rescale_lhs =
        buildRescaleToInt32(rewriter, op, tfl_greater_op.lhs(), 1.0f,
                            input_lhs_qtype.getZeroPoint());
    Value op2_rescale_rhs =
        buildRescaleToInt32(rewriter, op, tfl_greater_op.rhs(), 1.0f,
                            input_rhs_qtype.getZeroPoint());
    auto op3_greater_op1_op2 = rewriter.create<tosa::GreaterOp>(
        op->getLoc(), output_type, op1_rescale_lhs, op2_rescale_rhs);

    output = op3_greater_op1_op2.getResult();
  } else {
    auto op1_greater_in = rewriter.create<tosa::GreaterOp>(
        op->getLoc(), output_type, tfl_greater_op.lhs(), tfl_greater_op.rhs());

    output = op1_greater_in.getResult();
  }

  rewriter.replaceOp(op, {output});
  return success();
}

LogicalResult ConvertTFLGreaterEqualOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_greater_equal_op = cast<TFL::GreaterEqualOp>(op);

  RankedTensorType input_lhs_type =
      tfl_greater_equal_op.lhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_rhs_type =
      tfl_greater_equal_op.rhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_greater_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_lhs_type || !input_rhs_type || !output_type) return failure();

  bool input_lhs_is_qtype =
      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_rhs_is_qtype =
      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_lhs_is_qtype != output_is_qtype ||
      input_rhs_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLGreaterEqualOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
    UniformQuantizedType input_lhs_qtype =
        input_lhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType input_rhs_qtype =
        input_rhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();

    if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
        input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
      return op->emitOpError(
          "ConvertTFLGreaterEqualOp: input_lhs and input_rhs scale/zp "
          "must be the same");
    }

    Value op1_rescale_lhs =
        buildRescaleToInt32(rewriter, op, tfl_greater_equal_op.lhs(), 1.0f,
                            input_lhs_qtype.getZeroPoint());
    Value op2_rescale_rhs =
        buildRescaleToInt32(rewriter, op, tfl_greater_equal_op.rhs(), 1.0f,
                            input_rhs_qtype.getZeroPoint());
    auto op3_greater_equal_op1_op2 = rewriter.create<tosa::GreaterEqualOp>(
        op->getLoc(), output_type, op1_rescale_lhs, op2_rescale_rhs);

    output = op3_greater_equal_op1_op2.getResult();
  } else {
    auto op1_greater_equal_in = rewriter.create<tosa::GreaterEqualOp>(
        op->getLoc(), output_type, tfl_greater_equal_op.lhs(),
        tfl_greater_equal_op.rhs());

    output = op1_greater_equal_in.getResult();
  }

  rewriter.replaceOp(op, {output});
  return success();
}

// TODO: Use a utility function for common code in elementwise binary ops.
LogicalResult ConvertTFLAddOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_add_op = cast<TFL::AddOp>(op);

  RankedTensorType input_lhs_type =
      tfl_add_op.lhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_rhs_type =
      tfl_add_op.rhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_add_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_lhs_type || !input_rhs_type || !output_type) return failure();

  bool input_lhs_is_qtype =
      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_rhs_is_qtype =
      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_lhs_is_qtype != output_is_qtype ||
      input_rhs_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLAddOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
    RankedTensorType rescale_type =
        RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
    UniformQuantizedType input_lhs_qtype =
        input_lhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType input_rhs_qtype =
        input_rhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType output_qtype =
        output_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();

    // Following the quantization described in tensorflow/lite/kernels/add.cc
    // In detail, it does:
    // 1. Rescale inputs to scale = 2.0 x max(lhs.scale, rhs.scale)
    // 2. Extra left shift to input to increase precision
    // Where input_shift = 20 if input is 8-bit
    // input_shift = 15 if input is 16-bit
    // TODO: support 16-bit
    double in_lhs_scale = input_lhs_qtype.getScale();
    double in_rhs_scale = input_rhs_qtype.getScale();
    double output_scale = output_qtype.getScale();
    double max_scale_2x = 2.0 * std::max(in_lhs_scale, in_rhs_scale);

    const int32_t SHIFT_8_BIT = 20;
    int32_t input_shift = SHIFT_8_BIT;

    double lhs_rescale_scale =
        static_cast<double>(1 << input_shift) * in_lhs_scale / max_scale_2x;
    double rhs_rescale_scale =
        static_cast<double>(1 << input_shift) * in_rhs_scale / max_scale_2x;
    double output_rescale_scale =
        max_scale_2x / (output_scale * static_cast<double>(1 << input_shift));
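    // Worked example with hypothetical scales: lhs_scale = 0.1,
    // rhs_scale = 0.05, output_scale = 0.1 gives max_scale_2x = 0.2, so
    // lhs_rescale_scale = 2^20 * 0.1 / 0.2 = 524288.0,
    // rhs_rescale_scale = 2^20 * 0.05 / 0.2 = 262144.0, and
    // output_rescale_scale = 0.2 / (0.1 * 2^20) ~= 1.9e-6, which undoes the
    // precision-preserving left shift on the way back to 8 bits.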

    Value op1_rescale_lhs =
        buildRescaleToInt32(rewriter, op, tfl_add_op.lhs(), lhs_rescale_scale,
                            input_lhs_qtype.getZeroPoint());
    Value op2_rescale_rhs =
        buildRescaleToInt32(rewriter, op, tfl_add_op.rhs(), rhs_rescale_scale,
                            input_rhs_qtype.getZeroPoint());
    auto op3_add_op1_op2 = rewriter.create<tosa::AddOp>(
        op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
    Value op4_rescale_op3 = buildRescaleFromInt32(
        rewriter, op, output_type, op3_add_op1_op2.getResult(),
        output_rescale_scale, output_qtype.getZeroPoint());
    output = op4_rescale_op3;
  } else {
    auto op1_add_in = rewriter.create<tosa::AddOp>(
        op->getLoc(), output_type, tfl_add_op.lhs(), tfl_add_op.rhs());

    output = op1_add_in.getResult();
  }

  auto fused_activation_fn = tfl_add_op.fused_activation_functionAttr();

  if (fused_activation_fn) {
    llvm::Optional<Value> fused_activation_val =
        convertFusedActivation(rewriter, op, output, fused_activation_fn);

    if (!fused_activation_val) return failure();

    rewriter.replaceOp(op, {fused_activation_val.getValue()});
    return success();
  }

  rewriter.replaceOp(op, {output});
  return success();
}

LogicalResult ConvertTFLSubOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_sub_op = cast<TFL::SubOp>(op);

  RankedTensorType input_lhs_type =
      tfl_sub_op.lhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_rhs_type =
      tfl_sub_op.rhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_sub_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_lhs_type || !input_rhs_type || !output_type) return failure();

  bool input_lhs_is_qtype =
      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_rhs_is_qtype =
      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_lhs_is_qtype != output_is_qtype ||
      input_rhs_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLSubOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
    RankedTensorType rescale_type =
        RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
    UniformQuantizedType input_lhs_qtype =
        input_lhs_type.getElementType()
            .cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType input_rhs_qtype =
        input_rhs_type.getElementType()
            .cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType output_qtype =
        output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();

    // Following the quantization described in tensorflow/lite/kernels/add.cc
    // In detail, it does:
    // 1. Rescale inputs to scale = 2.0 x max(lhs.scale, rhs.scale)
    // 2. Extra left shift to input to increase precision
    // Where input_shift = 20 if input is 8-bit
    // input_shift = 15 if input is 16-bit
    // TODO: support 16-bit
    double in_lhs_scale = input_lhs_qtype.getScale();
    double in_rhs_scale = input_rhs_qtype.getScale();
    double output_scale = output_qtype.getScale();
    double max_scale_2x = 2.0 * std::max(in_lhs_scale, in_rhs_scale);

    const int32_t SHIFT_8_BIT = 20;
    int32_t input_shift = SHIFT_8_BIT;

    double lhs_rescale_scale =
        static_cast<double>(1 << input_shift) * in_lhs_scale / max_scale_2x;
    double rhs_rescale_scale =
        static_cast<double>(1 << input_shift) * in_rhs_scale / max_scale_2x;
    double output_rescale_scale =
        max_scale_2x / (output_scale * static_cast<double>(1 << input_shift));

    Value op1_rescale_lhs =
        buildRescaleToInt32(rewriter, op, tfl_sub_op.lhs(), lhs_rescale_scale,
                            input_lhs_qtype.getZeroPoint());
    Value op2_rescale_rhs =
        buildRescaleToInt32(rewriter, op, tfl_sub_op.rhs(), rhs_rescale_scale,
                            input_rhs_qtype.getZeroPoint());
    auto op3_sub_op1_op2 = rewriter.create<tosa::SubOp>(
        op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
    Value op4_rescale_op3 = buildRescaleFromInt32(
        rewriter, op, output_type, op3_sub_op1_op2.getResult(),
        output_rescale_scale, output_qtype.getZeroPoint());
    output = op4_rescale_op3;
  } else {
    auto op1_sub_in = rewriter.create<tosa::SubOp>(
        op->getLoc(), output_type, tfl_sub_op.lhs(), tfl_sub_op.rhs());

    output = op1_sub_in.getResult();
  }

  auto fused_activation_fn = tfl_sub_op.fused_activation_functionAttr();

  if (fused_activation_fn) {
    llvm::Optional<Value> fused_activation_val =
        convertFusedActivation(rewriter, op, output, fused_activation_fn);

    if (!fused_activation_val) return failure();

    rewriter.replaceOp(op, {fused_activation_val.getValue()});
    return success();
  }

  rewriter.replaceOp(op, {output});
  return success();
}

LogicalResult ConvertTFLMulOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_mul_op = cast<TFL::MulOp>(op);

  llvm::Optional<Value> result = convertMultiplyOp(
      rewriter, op, tfl_mul_op.getResult(), tfl_mul_op.lhs(), tfl_mul_op.rhs());

  if (!result) return failure();

  auto fused_activation_fn = tfl_mul_op.fused_activation_functionAttr();

  if (fused_activation_fn) {
    llvm::Optional<Value> fused_activation_val = convertFusedActivation(
        rewriter, op, result.getValue(), fused_activation_fn);

    if (!fused_activation_val) return failure();

    rewriter.replaceOp(op, {fused_activation_val.getValue()});
    return success();
  }

  rewriter.replaceOp(op, {result.getValue()});
  return success();
}

LogicalResult ConvertTFLSquareOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_square_op = cast<TFL::SquareOp>(op);

  llvm::Optional<Value> result =
      convertMultiplyOp(rewriter, op, tfl_square_op.getResult(),
                        tfl_square_op.x(), tfl_square_op.x());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});
  return success();
}

LogicalResult ConvertTFLSquaredDifferenceOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_squared_op = cast<TFL::SquaredDifferenceOp>(op);

  llvm::Optional<Value> result =
      convertSquaredDifferenceOp(rewriter, op, tfl_squared_op.getResult(),
                                 tfl_squared_op.lhs(), tfl_squared_op.rhs());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});
  return success();
}

LogicalResult ConvertTFLRoundOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_round_op = cast<TFL::RoundOp>(op);

  RankedTensorType input_type =
      tfl_round_op.x().getType().dyn_cast<RankedTensorType>();
  if (!input_type) {
    return op->emitOpError("Round: input not ranked tensor type");
  }

  if (input_type.getElementType().isa<FloatType>()) {
    llvm::Optional<Value> result = convertRoundOp(
        rewriter, op, tfl_round_op.getResult(), tfl_round_op.x());

    if (!result) return failure();

    rewriter.replaceOp(op, {result.getValue()});
    return success();

  } else {
    // Round on int is nonsensical. Instead, replace uses of result with the
    // input.
    tfl_round_op.replaceAllUsesWith(tfl_round_op.x());
    return success();
  }
}

LogicalResult ConvertTFLDivOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_div_op = cast<TFL::DivOp>(op);

  RankedTensorType output_type =
      tfl_div_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  auto fused_activation_fn = tfl_div_op.fused_activation_functionAttr();

  auto reciprocal_op = rewriter.create<tosa::ReciprocalOp>(
      op->getLoc(), output_type, tfl_div_op.rhs());
  auto mul_op =
      rewriter.create<tosa::MulOp>(op->getLoc(), output_type, tfl_div_op.lhs(),
                                   reciprocal_op.getResult(), 0);
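  // Division is lowered as lhs * reciprocal(rhs). The trailing 0 is
  // tosa.mul's shift operand; it stays zero here because no fixed-point
  // rescaling is applied.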

  if (fused_activation_fn) {
    llvm::Optional<Value> fused_activation_val = convertFusedActivation(
        rewriter, op, mul_op.getResult(), fused_activation_fn);

    if (!fused_activation_val) return failure();

    rewriter.replaceOp(op, {fused_activation_val.getValue()});
    return success();
  }

  rewriter.replaceOp(op, {mul_op.getResult()});

  return success();
}

LogicalResult ConvertTFLMaximumOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_max_op = cast<TFL::MaximumOp>(op);

  RankedTensorType input_lhs_type =
      tfl_max_op.lhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_rhs_type =
      tfl_max_op.rhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_max_op.getResult().getType().dyn_cast<RankedTensorType>();

  // Not a ranked tensor output
  if (!input_lhs_type || !input_rhs_type || !output_type) return failure();

  bool input_lhs_is_qtype =
      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_rhs_is_qtype =
      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_lhs_is_qtype != output_is_qtype ||
      input_rhs_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLMaximumOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype) {
    RankedTensorType rescale_type =
        RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());

    Value op1_rescale_lhs =
        buildRescaleToInt32(rewriter, op, tfl_max_op.lhs(), 1.0f, 0);
    Value op2_rescale_rhs =
        buildRescaleToInt32(rewriter, op, tfl_max_op.rhs(), 1.0f, 0);
    auto op3_max_op1_op2 = rewriter.create<tosa::MaximumOp>(
        op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
    Value op4_rescale_op3 = buildRescaleFromInt32(
        rewriter, op, output_type, op3_max_op1_op2.getResult(), 1.0f, 0);

    output = op4_rescale_op3;
  } else {
    auto op1_max_in = rewriter.create<tosa::MaximumOp>(
        op->getLoc(), output_type, tfl_max_op.lhs(), tfl_max_op.rhs());

    output = op1_max_in.getResult();
  }

  rewriter.replaceOp(op, {output});

  return success();
}

LogicalResult ConvertTFLMinimumOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_min_op = cast<TFL::MinimumOp>(op);

  RankedTensorType input_lhs_type =
      tfl_min_op.lhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_rhs_type =
      tfl_min_op.rhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_min_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_lhs_type || !input_rhs_type || !output_type) return failure();

  bool input_lhs_is_qtype =
      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_rhs_is_qtype =
      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_lhs_is_qtype != output_is_qtype ||
      input_rhs_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLMinimumOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value output;
  if (output_is_qtype) {
    RankedTensorType rescale_type =
        RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());

    Value op1_rescale_lhs =
        buildRescaleToInt32(rewriter, op, tfl_min_op.lhs(), 1.0f, 0);
    Value op2_rescale_rhs =
        buildRescaleToInt32(rewriter, op, tfl_min_op.rhs(), 1.0f, 0);
    auto op3_min_op1_op2 = rewriter.create<tosa::MinimumOp>(
        op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
    Value op4_rescale_op3 = buildRescaleFromInt32(
        rewriter, op, output_type, op3_min_op1_op2.getResult(), 1.0f, 0);

    output = op4_rescale_op3;
  } else {
    auto op1_min_in = rewriter.create<tosa::MinimumOp>(
        op->getLoc(), output_type, tfl_min_op.lhs(), tfl_min_op.rhs());

    output = op1_min_in.getResult();
  }

  rewriter.replaceOp(op, {output});

  return success();
}

LogicalResult ConvertTFLFloorDivOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_floordiv_op = cast<TFL::FloorDivOp>(op);

  llvm::Optional<Value> result =
      convertFloorDivOp(rewriter, op, tfl_floordiv_op.getResult(),
                        tfl_floordiv_op.lhs(), tfl_floordiv_op.rhs());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLFloorModOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_floormod_op = cast<TFL::FloorModOp>(op);

  llvm::Optional<Value> result =
      convertFloorModOp(rewriter, op, tfl_floormod_op.getResult(),
                        tfl_floormod_op.lhs(), tfl_floormod_op.rhs());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLAddNOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_addn_op = cast<TFL::AddNOp>(op);

  RankedTensorType output_type =
      tfl_addn_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  SmallVector<Value, 4> inputs(tfl_addn_op.inputs());

  assert(inputs.size() >= 2);

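  // Lower to a chain of binary tosa.add ops:
  // AddN(a, b, c, d) => add(d, add(c, add(a, b))).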
  auto newOp = rewriter.create<tosa::AddOp>(op->getLoc(), output_type,
                                            inputs[0], inputs[1]);
  for (int i = 2; i < inputs.size(); i++) {
    newOp = rewriter.create<tosa::AddOp>(op->getLoc(), output_type, inputs[i],
                                         newOp.getResult());
  }

  rewriter.replaceOp(op, {newOp.getResult()});

  return success();
}

LogicalResult ConvertTFLAveragePool2DOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_avgpool_op = cast<TFL::AveragePool2DOp>(op);

  RankedTensorType input_type =
      tfl_avgpool_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_avgpool_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not ranked tensor input/output
  if (!input_type || !output_type) return failure();

  // Kernels and strides are dimensionally ordered
  SmallVector<int64_t, 4> i64array({1, 1, 1, 1});
  ArrayAttr kernel_size;
  ArrayAttr stride;
  ArrayAttr pad;
  {
    int64_t kernel_h = tfl_avgpool_op.filter_height();
    int64_t kernel_w = tfl_avgpool_op.filter_width();
    kernel_size = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
    // i64array is formatted as NHWC now
    i64array[1] = kernel_h;
    i64array[2] = kernel_w;
  }
  {
    int64_t stride_h = tfl_avgpool_op.stride_h();
    int64_t stride_w = tfl_avgpool_op.stride_w();
    stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tfl_avgpool_op.padding().str(), &tf_pad).ok())
      return failure();

    // Pooling has no non-unit dilation
    ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});

    RankedTensorType filter_type = RankedTensorType::get(
        llvm::makeArrayRef<int64_t>(i64array), rewriter.getIntegerType(64));

    // TFLite doesn't support explicit padding
    if (!getPaddingValuesFromPadType(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, stride, dilation, rewriter, pad))
      return failure();
  }

  rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
      op, output_type, tfl_avgpool_op.input(), kernel_size, stride, pad);
  return success();
}

LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_maxpool_op = cast<TFL::MaxPool2DOp>(op);

  RankedTensorType input_type =
      tfl_maxpool_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_maxpool_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not ranked tensor input/output
  if (!input_type || !output_type) return failure();

  // Kernels and strides are dimensionally ordered
  SmallVector<int64_t, 4> i64array({1, 1, 1, 1});
  ArrayAttr kernel_size;
  ArrayAttr stride;
  ArrayAttr pad;
  {
    int64_t kernel_h = tfl_maxpool_op.filter_height();
    int64_t kernel_w = tfl_maxpool_op.filter_width();
    kernel_size = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
    // i64array is formatted as NHWC now
    i64array[1] = kernel_h;
    i64array[2] = kernel_w;
  }
  {
    int64_t stride_h = tfl_maxpool_op.stride_h();
    int64_t stride_w = tfl_maxpool_op.stride_w();
    stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tfl_maxpool_op.padding().str(), &tf_pad).ok())
      return failure();

    // Pooling has no non-unit dilation
    ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});

    RankedTensorType filter_type = RankedTensorType::get(
        llvm::makeArrayRef<int64_t>(i64array), rewriter.getIntegerType(64));

    // TFLite doesn't support explicit padding
    if (!getPaddingValuesFromPadType(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, stride, dilation, rewriter, pad))
      return failure();
  }

  rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
      op, output_type, tfl_maxpool_op.input(), kernel_size, stride, pad);
  return success();
}

LogicalResult ConvertTFLConv2DOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_conv2d_op = cast<TFL::Conv2DOp>(op);

  RankedTensorType input_type =
      tfl_conv2d_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType filter_type =
      tfl_conv2d_op.filter().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_conv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type) return failure();
  if (!output_type) return failure();
  if (!filter_type) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool filter_is_qtype =
      filter_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::QuantizedType>();

  if ((input_is_qtype != filter_is_qtype) ||
      (input_is_qtype != output_is_qtype)) {
    return op->emitOpError(
        "ConvertTFLConv2DOp: input/filter/output tensor should "
        "be all quantized or all floating-point.");
  }

  ArrayAttr pad;
  ArrayAttr stride;
  ArrayAttr dilation;
  {
    int64_t stride_h = tfl_conv2d_op.stride_h();
    int64_t stride_w = tfl_conv2d_op.stride_w();
    stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
  }
  {
    int64_t dilation_h = tfl_conv2d_op.dilation_h_factor();
    int64_t dilation_w = tfl_conv2d_op.dilation_w_factor();
    dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tfl_conv2d_op.padding().str(), &tf_pad).ok())
      return failure();

    // TFLite doesn't support explicit padding
    if (!getPaddingValuesFromPadType(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, stride, dilation, rewriter, pad))
      return failure();
  }

  Value unquantized_bias =
      getUnquantizedBias(rewriter, op, tfl_conv2d_op.bias());

  auto a1_conv2d_op = rewriter.create<tosa::Conv2DOp>(
      op->getLoc(), output_type, tfl_conv2d_op.input(), tfl_conv2d_op.filter(),
      unquantized_bias, pad, stride, dilation);

  Value conv2d_output;
  if (input_is_qtype) {
    conv2d_output =
        buildRescaleOpConvOutput(rewriter, op, a1_conv2d_op.getResult(),
                                 input_type, filter_type, output_type);
  } else {
    conv2d_output = a1_conv2d_op.getResult();
  }

  auto fused_activation_fn = tfl_conv2d_op.fused_activation_functionAttr();

  if (fused_activation_fn) {
    llvm::Optional<Value> fused_activation_val = convertFusedActivation(
        rewriter, op, conv2d_output, fused_activation_fn);

    if (!fused_activation_val) return failure();

    rewriter.replaceOp(op, {fused_activation_val.getValue()});
    return success();
  }

  rewriter.replaceOp(op, {conv2d_output});

  return success();
}

LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_conv_op = cast<TFL::TransposeConvOp>(op);

  RankedTensorType input_type =
      tfl_conv_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType filter_type =
      tfl_conv_op.weights().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_conv_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type) return failure();
  if (!output_type) return failure();
  if (!filter_type) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool filter_is_qtype =
      filter_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::QuantizedType>();

  if ((input_is_qtype != filter_is_qtype) ||
      (input_is_qtype != output_is_qtype)) {
    return op->emitOpError(
        "ConvertTFLTransposeConvOp: input/filter/output tensor should "
        "be all quantized or all floating-point.");
  }

  ArrayAttr stride;
  ArrayAttr dilation;
  ArrayAttr outpad;
  ArrayAttr output_shape;
  {
    int64_t stride_h = tfl_conv_op.stride_h();
    int64_t stride_w = tfl_conv_op.stride_w();
    stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
  }

  // tfl.transpose_conv doesn't support dilations
  dilation = rewriter.getI64ArrayAttr({1, 1});

  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tfl_conv_op.padding().str(), &tf_pad).ok())
      return failure();

    if (!getTransposeConv2dPaddingValues(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, output_type, stride, dilation, rewriter,
            outpad))
      return failure();
  }
  {
    ElementsAttr output_shape_elems;
    // Match from input_size tensor first
    if (matchPattern(tfl_conv_op.output_shape(),
                     m_Constant(&output_shape_elems))) {
      llvm::SmallVector<int64_t, 4> shape_vec;
      for (int i = 0; i < output_shape_elems.getNumElements(); i++)
        shape_vec.push_back(
            output_shape_elems.getValue<IntegerAttr>(i).getInt());
      output_shape = rewriter.getI64ArrayAttr(shape_vec);
    } else {
      // Use output tensor's shape otherwise
      output_shape = rewriter.getI64ArrayAttr(output_type.getShape());
    }
  }

  Value zero_bias;
  if (input_is_qtype) {
    uint32_t input_bits = input_type.getElementType()
                              .dyn_cast<mlir::quant::QuantizedType>()
                              .getStorageTypeIntegralWidth();
    uint32_t weight_bits = filter_type.getElementType()
                               .dyn_cast<mlir::quant::QuantizedType>()
                               .getStorageTypeIntegralWidth();

    if (input_bits == 16 && weight_bits == 8) {
      SmallVector<int64_t, 8> zero_bias_vec(output_type.getShape()[3], 0);
      zero_bias = get1DConstTensorInt48(rewriter, op, zero_bias_vec);
    } else {
      SmallVector<int32_t, 8> zero_bias_vec(output_type.getShape()[3], 0);
      zero_bias =
          get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, zero_bias_vec);
    }
  } else {
    SmallVector<float, 8> zero_bias_vec(output_type.getShape()[3], 0.0f);
    zero_bias =
        get1DConstTensor<tosa::ConstOp, float>(rewriter, op, zero_bias_vec);
  }

  auto a1_conv2d_op = rewriter.create<tosa::TransposeConv2DOp>(
      op->getLoc(), output_type, tfl_conv_op.input(), tfl_conv_op.weights(),
      zero_bias, outpad, stride, dilation, output_shape);

  Value conv2d_output;
  if (input_is_qtype) {
    conv2d_output =
        buildRescaleOpConvOutput(rewriter, op, a1_conv2d_op.getResult(),
                                 input_type, filter_type, output_type);
  } else {
    conv2d_output = a1_conv2d_op.getResult();
  }

  rewriter.replaceOp(op, {conv2d_output});

  return success();
}

LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_conv2d_op = cast<TFL::DepthwiseConv2DOp>(op);

  RankedTensorType input_type =
      tfl_conv2d_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType filter_type =
      tfl_conv2d_op.filter().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_conv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type) return failure();
  if (!output_type) return failure();
  if (!filter_type) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool filter_is_qtype =
      filter_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::QuantizedType>();

  if ((input_is_qtype != filter_is_qtype) ||
      (input_is_qtype != output_is_qtype)) {
    return op->emitOpError(
        "ConvertTFLDepthwiseConv2DOp: input/filter/output tensor should "
        "be all quantized or all floating-point.");
  }

  auto filter_shape = filter_type.getShape();
  // Operator depthwiseConv2D
  // TFLite orders the depthwiseConv2D filter in IHWO, while TOSA orders
  // filter in HWIO
  //
  // The lowering reorders the filter.
  //
  // a1_transpose = tosa.transpose(filter, {1, 2, 3, 0})   // HWIO
  // a2_reshape = tosa.reshape(filter, H, W, I / depth_multiplier,
  // depth_multiplier)
  // a3_depthwise_conv2d = tosa.depthwise_conv2d(input, a2_reshape, padding,
  // stride, dilation)

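  // Hypothetical shape walk-through: a TFLite filter of shape [1, 3, 3, 8]
  // with depth_multiplier = 2 becomes [3, 3, 8, 1] after the transpose and
  // [3, 3, 4, 2] after the reshape, i.e. [H, W, C, M] as
  // tosa.depthwise_conv2d expects.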
  ArrayAttr pad;
  ArrayAttr stride;
  ArrayAttr dilation;
  auto depth_multiplier = tfl_conv2d_op.depth_multiplierAttr();

  {
    int64_t stride_h = tfl_conv2d_op.stride_h();
    int64_t stride_w = tfl_conv2d_op.stride_w();
    stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
  }
  {
    int64_t dilation_h = tfl_conv2d_op.dilation_h_factor();
    int64_t dilation_w = tfl_conv2d_op.dilation_w_factor();
    dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tfl_conv2d_op.padding().str(), &tf_pad).ok())
      return failure();

    if (!getPaddingValuesFromPadType(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, stride, dilation, rewriter, pad))
      return failure();
  }

  llvm::SmallVector<int64_t, 4> a1_transpose_dims;
  a1_transpose_dims.push_back(filter_shape[1]);
  a1_transpose_dims.push_back(filter_shape[2]);
  a1_transpose_dims.push_back(filter_shape[3]);
  a1_transpose_dims.push_back(filter_shape[0]);

  llvm::SmallVector<int64_t, 4> a2_reshape_dims;
  a2_reshape_dims.push_back(a1_transpose_dims[0]);
  a2_reshape_dims.push_back(a1_transpose_dims[1]);
  a2_reshape_dims.push_back(a1_transpose_dims[2] / depth_multiplier.getInt());
  a2_reshape_dims.push_back(depth_multiplier.getInt());

  Value a1_filter_transpose_perms =
      get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, {1, 2, 3, 0});
  auto a1_filter_transpose_op = rewriter.create<tosa::TransposeOp>(
      op->getLoc(),
      RankedTensorType::get(ArrayRef<int64_t>(a1_transpose_dims),
                            filter_type.getElementType()),
      tfl_conv2d_op.filter(), a1_filter_transpose_perms);

  auto a2_filter_reshape_op = rewriter.create<tosa::ReshapeOp>(
      op->getLoc(),
      RankedTensorType::get(ArrayRef<int64_t>(a2_reshape_dims),
                            filter_type.getElementType()),
      a1_filter_transpose_op.getResult(),
      rewriter.getI64ArrayAttr(a2_reshape_dims));

  Value unquantized_bias =
      getUnquantizedBias(rewriter, op, tfl_conv2d_op.bias());

  auto a3_depthwise_conv2d_op = rewriter.create<tosa::DepthwiseConv2DOp>(
      op->getLoc(), output_type, tfl_conv2d_op.input(),
      a2_filter_reshape_op.getResult(), unquantized_bias, pad, stride,
      dilation);

  Value conv2d_output;
  if (input_is_qtype) {
    conv2d_output = buildRescaleOpConvOutput(
        rewriter, op, a3_depthwise_conv2d_op.getResult(), input_type,
        filter_type, output_type);
  } else {
    conv2d_output = a3_depthwise_conv2d_op.getResult();
  }

  auto fused_activation_fn = tfl_conv2d_op.fused_activation_functionAttr();

  if (fused_activation_fn) {
    llvm::Optional<Value> fused_activation_val = convertFusedActivation(
        rewriter, op, conv2d_output, fused_activation_fn);

    if (!fused_activation_val) return failure();

    rewriter.replaceOp(op, {fused_activation_val.getValue()});
    return success();
  }

  rewriter.replaceOp(op, {conv2d_output});

  return success();
}
1429 
LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_fc_op = cast<TFL::FullyConnectedOp>(op);

  RankedTensorType output_type =
      tfl_fc_op.getResult(0).getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  RankedTensorType input_type =
      tfl_fc_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType filter_type =
      tfl_fc_op.filter().getType().dyn_cast<RankedTensorType>();
  RankedTensorType bias_type =
      tfl_fc_op.bias().getType().dyn_cast<RankedTensorType>();
  if (!input_type || !filter_type) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool filter_is_qtype =
      filter_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::QuantizedType>();

  if ((input_is_qtype != filter_is_qtype) ||
      (input_is_qtype != output_is_qtype)) {
    return op->emitOpError(
        "ConvertTFLFullyConnectedOp: input/filter/output tensor should "
        "be all quantized or all floating-point.");
  }

  Value input_val = tfl_fc_op.input();

  // tfl.fully_connected() can take a tensor of any rank as input, but
  // tosa.fully_connected only supports rank-2 tensors, so other ranks must
  // first be reshaped to rank 2.  For a rank-4 input the reshape is not
  // always (dim[0] * dim[1], dim[2] * dim[3]); some networks reshape to
  // (dim[0], dim[1] * dim[2] * dim[3]).  A more general way to determine
  // the reshape's target shape is to read the filter's shape[1].
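  //
  // Worked example (illustrative values, not taken from the original source):
  // an input of shape [1, 4, 4, 8] (128 elements) with a filter of shape
  // [10, 32] gives num_elems = 32 and num_batch = 128 / 32 = 4, so the input
  // is reshaped to [4, 32] before the tosa.fully_connected.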
  if (input_type.getRank() != 2) {
    int64_t num_elems = filter_type.getShape()[1];
    int64_t num_batch = input_type.getNumElements() / num_elems;
    SmallVector<int64_t, 2> shape_vals({num_batch, num_elems});

    RankedTensorType reshape_type = RankedTensorType::get(
        ArrayRef<int64_t>(shape_vals), input_type.getElementType());
    auto reshape_op = rewriter.create<tosa::ReshapeOp>(
        op->getLoc(), reshape_type, tfl_fc_op.input(),
        rewriter.getI64ArrayAttr(shape_vals));

    input_val = reshape_op.getResult();
  }

  Value bias_val;
  if (!bias_type) {
    // For some matmuls, the bias may actually be a "UnitType" which has no
    // value. TOSA requires bias to be an array of output_channel_count values,
    // so create a constant of the appropriate number and type of zeros.
    SmallVector<int64_t, 1> bias_shape({filter_type.getShape()[0]});
    RankedTensorType bias_type = RankedTensorType::get(
        ArrayRef<int64_t>(bias_shape), input_type.getElementType());

    DenseElementsAttr bias_attr;
    if (input_type.getElementType().isa<FloatType>()) {
      SmallVector<float, 2> bias_arr(bias_shape[0]);

      for (int i = 0; i < bias_shape[0]; i++) {
        bias_arr[i] = 0.0;
      }
      // TODO: an implicit cast was suggested instead of makeArrayRef, but it
      // triggers a build error.
      bias_attr = DenseElementsAttr::get(bias_type,
                                         llvm::makeArrayRef<float>(bias_arr));
    } else {
      SmallVector<int32_t, 2> bias_arr(bias_shape[0]);

      for (int i = 0; i < bias_shape[0]; i++) {
        bias_arr[i] = 0;
      }
      bias_attr = DenseElementsAttr::get(bias_type,
                                         llvm::makeArrayRef<int32_t>(bias_arr));
    }
    auto bias_op =
        rewriter.create<tosa::ConstOp>(op->getLoc(), bias_type, bias_attr);
    bias_val = bias_op.getResult();
  } else {
    bias_val = getUnquantizedBias(rewriter, op, tfl_fc_op.bias());
  }

  auto fc_op = rewriter.create<tosa::FullyConnectedOp>(
      op->getLoc(), output_type, input_val, tfl_fc_op.filter(), bias_val);

  Value fc_output;
  if (input_is_qtype) {
    fc_output = buildRescaleOpConvOutput(rewriter, op, fc_op.getResult(),
                                         input_type, filter_type, output_type);
  } else {
    fc_output = fc_op.getResult();
  }

  auto fused_activation_fn = tfl_fc_op.fused_activation_functionAttr();

  if (fused_activation_fn) {
    llvm::Optional<Value> fused_activation_val =
        convertFusedActivation(rewriter, op, fc_output, fused_activation_fn);

    if (!fused_activation_val) return failure();

    rewriter.replaceOp(op, {fused_activation_val.getValue()});
    return success();
  }

  rewriter.replaceOp(op, {fc_output});

  return success();
}

LogicalResult ConvertTFLConcatenationOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_concat_op = cast<TFL::ConcatenationOp>(op);

  SmallVector<Value, 8> values(tfl_concat_op.values());

  IntegerAttr axis_attr;
  {
    auto tmpAttr = tfl_concat_op.axisAttr();
    if (!tmpAttr) {
      tmpAttr = rewriter.getI64IntegerAttr(0);
    }
    axis_attr = tmpAttr;
  }
  int32_t axis = axis_attr.getInt();

  llvm::Optional<Value> result =
      convertConcatV2Op(rewriter, op, tfl_concat_op.getResult(), values, axis);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});
  return success();
}

LogicalResult ConvertTFLReshapeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_reshape_op = cast<TFL::ReshapeOp>(op);

  RankedTensorType output_type =
      tfl_reshape_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  SmallVector<int64_t, 8> shape_vals;
  for (int i = 0; i < output_type.getShape().size(); i++) {
    shape_vals.push_back(output_type.getShape()[i]);
  }
  ArrayAttr shape_attr = rewriter.getI64ArrayAttr(shape_vals);

  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
      op, output_type, tfl_reshape_op.input(), shape_attr);
  return success();
}

LogicalResult ConvertTFLRankOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_rank_op = cast<TFL::RankOp>(op);

  RankedTensorType input_type =
      tfl_rank_op.input().getType().dyn_cast<RankedTensorType>();
  if (!input_type) return failure();

  int32_t rank = input_type.getRank();

  RankedTensorType rank_type =
      RankedTensorType::get({1}, rewriter.getIntegerType(32));
  auto rank_attr = DenseElementsAttr::get(rank_type, {rank});
  auto rank_const =
      rewriter.create<tosa::ConstOp>(op->getLoc(), rank_type, rank_attr);

  rewriter.replaceOp(op, {rank_const.getResult()});

  return success();
}

LogicalResult ConvertTFLShapeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_shape_op = cast<TFL::ShapeOp>(op);

  RankedTensorType output_type =
      tfl_shape_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  RankedTensorType input_type =
      tfl_shape_op.input().getType().dyn_cast<RankedTensorType>();
  if (!input_type) return failure();

  auto input_shape = input_type.getShape();

  SmallVector<int32_t, 8> shape_arr;
  for (int i = 0; i < input_shape.size(); i++) {
    shape_arr.emplace_back(input_shape[i]);
  }

  RankedTensorType shape_type = RankedTensorType::get(
      {static_cast<int32_t>(shape_arr.size())}, rewriter.getIntegerType(32));
  auto shape_attr = DenseElementsAttr::get(
      shape_type, llvm::makeArrayRef<int32_t>(shape_arr));
  auto shape_const =
      rewriter.create<tosa::ConstOp>(op->getLoc(), shape_type, shape_attr);

  rewriter.replaceOp(op, {shape_const.getResult()});

  return success();
}

LogicalResult ConvertTFLExpandDimsOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_expanddims_op = cast<TFL::ExpandDimsOp>(op);

  llvm::Optional<Value> result =
      convertExpandDimsOp(rewriter, op, tfl_expanddims_op.getResult(),
                          tfl_expanddims_op.input(), tfl_expanddims_op.dim());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLSqueezeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_squeeze_op = cast<TFL::SqueezeOp>(op);

  // Copy squeeze_dims into int32_t array
  auto squeeze_dims_attr = tfl_squeeze_op.squeeze_dimsAttr();
  SmallVector<int32_t, 8> squeeze_dims;
  for (auto& squeeze_dim : squeeze_dims_attr) {
    squeeze_dims.emplace_back(squeeze_dim.dyn_cast<IntegerAttr>().getInt());
  }

  llvm::Optional<Value> result =
      convertSqueezeOp(rewriter, op, tfl_squeeze_op.getResult(),
                       tfl_squeeze_op.input(), squeeze_dims);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLFillOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_fill_op = cast<TFL::FillOp>(op);

  RankedTensorType output_type =
      tfl_fill_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  ElementsAttr dims_elems;
  if (!matchPattern(tfl_fill_op.dims(), m_Constant(&dims_elems)))
    return failure();
  SmallVector<int64_t, 4> dims_vals;
  uint32_t total_size = 1;
  for (int i = 0; i < dims_elems.getNumElements(); i++) {
    dims_vals.push_back(dims_elems.getValue<IntegerAttr>(i).getInt());
    total_size *= dims_vals[i];
  }

  ElementsAttr value_elem;
  if (!matchPattern(tfl_fill_op.input(), m_Constant(&value_elem)))
    return failure();

  RankedTensorType fill_type = RankedTensorType::get(
      ArrayRef<int64_t>(dims_vals), value_elem.getType().getElementType());
  DenseElementsAttr fill_attr;

  // Materialize the fill value as a constant of a compatible element type.
  if (value_elem.getType().getElementType().isa<FloatType>()) {
    llvm::SmallVector<float, 4> fill_arr(
        total_size,
        value_elem.getValue<FloatAttr>(0).getValue().convertToFloat());
    fill_attr =
        DenseElementsAttr::get(fill_type, llvm::makeArrayRef<float>(fill_arr));
  } else {
    llvm::SmallVector<int32_t, 4> fill_arr(
        total_size,
        value_elem.getValue<IntegerAttr>(0).getValue().getLimitedValue());
    fill_attr = DenseElementsAttr::get(fill_type,
                                       llvm::makeArrayRef<int32_t>(fill_arr));
  }
  auto fill_const_op =
      rewriter.create<tosa::ConstOp>(op->getLoc(), fill_type, fill_attr);
  rewriter.replaceOp(op, {fill_const_op.getResult()});

  return success();
}

LogicalResult ConvertTFLReduceAnyOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_any_op = cast<TFL::ReduceAnyOp>(op);

  RankedTensorType output_type =
      tfl_any_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tfl_any_op.reduction_indices(), m_Constant(&axes_elems)))
    return failure();

  bool keep_dims = false;
  auto keep_dims_attr = tfl_any_op.keep_dimsAttr();
  if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();

  llvm::Optional<Value> result = convertReduceAnyOp(
      rewriter, op, output_type, tfl_any_op.input(), axes_elems, keep_dims);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLReduceMaxOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_max_op = cast<TFL::ReduceMaxOp>(op);

  RankedTensorType output_type =
      tfl_max_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tfl_max_op.axes(), m_Constant(&axes_elems)))
    return failure();

  bool keep_dims = false;
  auto keep_dims_attr = tfl_max_op.keep_dimsAttr();
  if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();

  llvm::Optional<Value> result = convertReduceMaxOp(
      rewriter, op, output_type, tfl_max_op.input(), axes_elems, keep_dims);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLReduceMinOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_min_op = cast<TFL::ReduceMinOp>(op);

  RankedTensorType output_type =
      tfl_min_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tfl_min_op.axes(), m_Constant(&axes_elems)))
    return failure();

  bool keep_dims = false;
  auto keep_dims_attr = tfl_min_op.keep_dimsAttr();
  if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();

  llvm::Optional<Value> result = convertReduceMinOp(
      rewriter, op, output_type, tfl_min_op.input(), axes_elems, keep_dims);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLReduceProdOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_prod_op = cast<TFL::ReduceProdOp>(op);

  RankedTensorType output_type =
      tfl_prod_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tfl_prod_op.axes(), m_Constant(&axes_elems)))
    return failure();

  bool keep_dims = false;
  auto keep_dims_attr = tfl_prod_op.keep_dimsAttr();
  if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();

  llvm::Optional<Value> result = convertReduceProdOp(
      rewriter, op, output_type, tfl_prod_op.input(), axes_elems, keep_dims);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLMeanOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_mean_op = cast<TFL::MeanOp>(op);

  RankedTensorType output_type =
      tfl_mean_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tfl_mean_op.axis(), m_Constant(&axes_elems)))
    return failure();

  bool keep_dims = false;
  auto keep_dims_attr = tfl_mean_op.keep_dimsAttr();
  if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();

  llvm::Optional<Value> result = convertReduceMeanOp(
      rewriter, op, output_type, tfl_mean_op.input(), axes_elems, keep_dims);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLSumOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_sum_op = cast<TFL::SumOp>(op);

  RankedTensorType output_type =
      tfl_sum_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tfl_sum_op.axes(), m_Constant(&axes_elems)))
    return failure();

  bool keep_dims = false;
  auto keep_dims_attr = tfl_sum_op.keep_dimsAttr();
  if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();

  llvm::Optional<Value> result = convertReduceSumOp(
      rewriter, op, output_type, tfl_sum_op.input(), axes_elems, keep_dims);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLEluOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_elu_op = cast<TFL::EluOp>(op);

  llvm::Optional<Value> result =
      convertEluOp(rewriter, op, tfl_elu_op.getResult(), tfl_elu_op.x());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLSoftmaxOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_softmax_op = cast<TFL::SoftmaxOp>(op);

  llvm::Optional<Value> result = convertSoftmaxOp(
      rewriter, op, tfl_softmax_op.getResult(), tfl_softmax_op.input());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLLogSoftmaxOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_logsoftmax_op = cast<TFL::LogSoftmaxOp>(op);

  llvm::Optional<Value> result = convertLogSoftmaxOp(
      rewriter, op, tfl_logsoftmax_op.getResult(), tfl_logsoftmax_op.input());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLSliceOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_slice_op = cast<TFL::SliceOp>(op);

  RankedTensorType output_type =
      tfl_slice_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  ElementsAttr begin_elems, size_elems;

  SmallVector<int64_t, 4> begin_vals, size_vals;

  if (!matchPattern(tfl_slice_op.begin(), m_Constant(&begin_elems)) ||
      !matchPattern(tfl_slice_op.size(), m_Constant(&size_elems))) {
    return failure();
  }

  for (int i = 0; i < begin_elems.getNumElements(); i++)
    begin_vals.push_back(begin_elems.getValue<IntegerAttr>(i).getInt());

  for (int i = 0; i < size_elems.getNumElements(); i++)
    size_vals.push_back(size_elems.getValue<IntegerAttr>(i).getInt());

  ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
  ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);

  rewriter.replaceOpWithNewOp<tosa::SliceOp>(op, output_type,
                                             tfl_slice_op.input(), begin, size);
  return success();
}

LogicalResult ConvertTFLTileOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_tile_op = cast<TFL::TileOp>(op);

  RankedTensorType output_type =
      tfl_tile_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  ElementsAttr multiples_elems;
  if (!matchPattern(tfl_tile_op.multiples(), m_Constant(&multiples_elems)))
    return failure();
  SmallVector<int64_t, 4> multiples_vals;
  for (int i = 0; i < multiples_elems.getNumElements(); i++)
    multiples_vals.push_back(multiples_elems.getValue<IntegerAttr>(i).getInt());

  ArrayAttr multiples_attr = rewriter.getI64ArrayAttr(multiples_vals);
  rewriter.replaceOpWithNewOp<tosa::TileOp>(
      op, output_type, tfl_tile_op.input(), multiples_attr);

  return success();
}

LogicalResult ConvertTFLTransposeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_transpose_op = cast<TFL::TransposeOp>(op);

  RankedTensorType output_type =
      tfl_transpose_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
      op, output_type, tfl_transpose_op.input(), tfl_transpose_op.perm());

  return success();
}

LogicalResult ConvertTFLPackOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_pack_op = cast<TFL::PackOp>(op);

  SmallVector<Value, 8> inputs(tfl_pack_op.values());
  assert(inputs.size() >= 2);

  IntegerAttr axis_attr;
  {
    auto tmpAttr = tfl_pack_op.axisAttr();
    if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
    axis_attr = tmpAttr;
  }
  int32_t axis_i32 = axis_attr.getInt();

  llvm::Optional<Value> result =
      convertPackOp(rewriter, op, tfl_pack_op.getResult(), inputs, axis_i32);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLUnpackOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_unpack_op = cast<TFL::UnpackOp>(op);

  IntegerAttr axis_attr;
  {
    auto tmpAttr = tfl_unpack_op.axisAttr();
    if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
    axis_attr = tmpAttr;
  }
  int32_t axis_i32 = axis_attr.getInt();

  llvm::Optional<ValueRange> results =
      convertUnpackOp(rewriter, op, tfl_unpack_op.input(), axis_i32);

  if (!results) return failure();

  rewriter.replaceOp(op, results.getValue());

  return success();
}

// Splits the input into num_split equal parts along split_dim.
LogicalResult ConvertTFLSplitOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_split_op = cast<TFL::SplitOp>(op);

  // Get the number of splits
  int32_t num_split = -1;
  auto numSplitAttr = tfl_split_op.num_splitsAttr();
  if (numSplitAttr) {
    num_split = numSplitAttr.getInt();
  } else {
    return failure();
  }

  // Get the axis
  ElementsAttr axisAttrElems;
  if (!matchPattern(tfl_split_op.split_dim(), m_Constant(&axisAttrElems))) {
    return op->emitOpError("Cannot read split_dim elems");
  }

  // The axis/split_dim parameter is stored as a 0D tensor instead of
  // an integer attribute in TFLite MLIR.
  int32_t axis = axisAttrElems.getValue<IntegerAttr>({}).getInt();

  llvm::Optional<ValueRange> results =
      convertSplitOp(rewriter, op, tfl_split_op.getResult(0),
                     tfl_split_op.value(), num_split, axis);

  if (!results) return failure();

  rewriter.replaceOp(op, results.getValue());

  return success();
}

// Splits the input along split_dim into parts whose sizes are given by
// size_splits.
LogicalResult ConvertTFLSplitVOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_splitv_op = cast<TFL::SplitVOp>(op);

  // Get the size_splits array
  SmallVector<int32_t, 4> size_split;
  ElementsAttr size_split_elems;
  if (!matchPattern(tfl_splitv_op.size_splits(),
                    m_Constant(&size_split_elems))) {
    return failure();
  }

  for (int i = 0; i < size_split_elems.getNumElements(); i++) {
    size_split.push_back(size_split_elems.getValue<IntegerAttr>(i).getInt());
  }

  // Get the axis
  ElementsAttr axisAttrElems;
  if (!matchPattern(tfl_splitv_op.split_dim(), m_Constant(&axisAttrElems))) {
    return op->emitOpError("Cannot read split_dim elems");
  }

  // The axis/split_dim parameter is stored as a 0D tensor instead of
  // an integer attribute in TFLite MLIR.
  int32_t axis = axisAttrElems.getValue<IntegerAttr>(0).getInt();

  llvm::Optional<ValueRange> results =
      convertSplitVOp(rewriter, op, tfl_splitv_op.getResult(0),
                      tfl_splitv_op.value(), size_split, axis);

  if (!results) return failure();

  rewriter.replaceOp(op, results.getValue());

  return success();
}

LogicalResult ConvertTFLLessOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_less_op = cast<TFL::LessOp>(op);

  RankedTensorType input_lhs_type =
      tfl_less_op.lhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_rhs_type =
      tfl_less_op.rhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_less_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_lhs_type || !input_rhs_type || !output_type) return failure();

  bool input_lhs_is_qtype =
      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_rhs_is_qtype =
      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_lhs_is_qtype != output_is_qtype ||
      input_rhs_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLLessOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

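  // Lowering note: TOSA has no dedicated "less" op, so less(lhs, rhs) is
  // emitted as logical_not(greater_equal(lhs, rhs)). On the quantized path,
  // both operands are first rescaled to int32 so the comparison happens in a
  // common domain.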
  Value output;
  if (output_is_qtype) {
    UniformQuantizedType input_lhs_qtype =
        input_lhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType input_rhs_qtype =
        input_rhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();

    if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
        input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
      return op->emitOpError(
          "ConvertTFLLessOp: lhs and rhs scale/zero-point "
          "must be the same");
    }

    Value op1_rescale_lhs = buildRescaleToInt32(
        rewriter, op, tfl_less_op.lhs(), 1.0f, input_lhs_qtype.getZeroPoint());
    Value op2_rescale_rhs = buildRescaleToInt32(
        rewriter, op, tfl_less_op.rhs(), 1.0f, input_rhs_qtype.getZeroPoint());
    auto op3_greater_equal_op1_op2 = rewriter.create<tosa::GreaterEqualOp>(
        op->getLoc(), output_type, op1_rescale_lhs, op2_rescale_rhs);
    auto op4_not_op3 = rewriter.create<tosa::LogicalNotOp>(
        op->getLoc(), output_type, op3_greater_equal_op1_op2.getResult());

    output = op4_not_op3.getResult();
  } else {
    auto op1_greater_equal_in = rewriter.create<tosa::GreaterEqualOp>(
        op->getLoc(), output_type, tfl_less_op.lhs(), tfl_less_op.rhs());
    auto op2_not_op1 = rewriter.create<tosa::LogicalNotOp>(
        op->getLoc(), output_type, op1_greater_equal_in.getResult());

    output = op2_not_op1.getResult();
  }

  rewriter.replaceOp(op, {output});
  return success();
}

LogicalResult ConvertTFLLessEqualOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_less_equal_op = cast<TFL::LessEqualOp>(op);

  RankedTensorType input_lhs_type =
      tfl_less_equal_op.lhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_rhs_type =
      tfl_less_equal_op.rhs().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_less_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_lhs_type || !input_rhs_type || !output_type) return failure();

  bool input_lhs_is_qtype =
      input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_rhs_is_qtype =
      input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_lhs_is_qtype != output_is_qtype ||
      input_rhs_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLLessEqualOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

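  // Lowering note: less_equal(lhs, rhs) is emitted as
  // logical_not(greater(lhs, rhs)), mirroring the "less" lowering above.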
  Value output;
  if (output_is_qtype) {
    UniformQuantizedType input_lhs_qtype =
        input_lhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType input_rhs_qtype =
        input_rhs_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();

    if (input_lhs_qtype.getScale() != input_rhs_qtype.getScale() ||
        input_lhs_qtype.getZeroPoint() != input_rhs_qtype.getZeroPoint()) {
      return op->emitOpError(
          "ConvertTFLLessEqualOp: lhs and rhs scale/zero-point "
          "must be the same");
    }

    Value op1_rescale_lhs =
        buildRescaleToInt32(rewriter, op, tfl_less_equal_op.lhs(), 1.0f,
                            input_lhs_qtype.getZeroPoint());
    Value op2_rescale_rhs =
        buildRescaleToInt32(rewriter, op, tfl_less_equal_op.rhs(), 1.0f,
                            input_rhs_qtype.getZeroPoint());
    auto op3_greater_op1_op2 = rewriter.create<tosa::GreaterOp>(
        op->getLoc(), output_type, op1_rescale_lhs, op2_rescale_rhs);
    auto op4_not_op3 = rewriter.create<tosa::LogicalNotOp>(
        op->getLoc(), output_type, op3_greater_op1_op2.getResult());

    output = op4_not_op3.getResult();
  } else {
    auto op1_greater_in = rewriter.create<tosa::GreaterOp>(
        op->getLoc(), output_type, tfl_less_equal_op.lhs(),
        tfl_less_equal_op.rhs());
    auto op2_not_op1 = rewriter.create<tosa::LogicalNotOp>(
        op->getLoc(), output_type, op1_greater_in.getResult());

    output = op2_not_op1.getResult();
  }

  rewriter.replaceOp(op, {output});
  return success();
}

LogicalResult ConvertTFLPadOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_pad_op = cast<TFL::PadOp>(op);

  RankedTensorType output_type =
      tfl_pad_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  auto pad_op = rewriter.create<tosa::PadOp>(
      op->getLoc(), output_type, tfl_pad_op.input(), tfl_pad_op.padding());

  rewriter.replaceOp(op, {pad_op.getResult()});
  return success();
}

LogicalResult ConvertTFLResizeBilinearOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_resize_op = cast<TFL::ResizeBilinearOp>(op);

  RankedTensorType output_type =
      tfl_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  llvm::Optional<Value> result = convertResizeOp(
      rewriter, op, output_type, tfl_resize_op.input(), StringRef("BILINEAR"));

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLResizeNearestNeighborOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_resize_op = cast<TFL::ResizeNearestNeighborOp>(op);

  RankedTensorType output_type =
      tfl_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  llvm::Optional<Value> result = convertResizeOp(
      rewriter, op, output_type, tfl_resize_op.input(), StringRef("NEAREST"));

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLSelectOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_sel_op = cast<TFL::SelectOp>(op);

  llvm::Optional<Value> result =
      convertSelectOp(rewriter, op, tfl_sel_op.getResult(),
                      tfl_sel_op.condition(), tfl_sel_op.x(), tfl_sel_op.y());
  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLSelectV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_sel_op = cast<TFL::SelectV2Op>(op);

  llvm::Optional<Value> result =
      convertSelectOp(rewriter, op, tfl_sel_op.getResult(),
                      tfl_sel_op.condition(), tfl_sel_op.x(), tfl_sel_op.y());
  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLSpaceToBatchNdOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_s2b_op = cast<TFL::SpaceToBatchNdOp>(op);
  llvm::Optional<Value> result = convertSpaceToBatchNDOp(
      rewriter, op, tfl_s2b_op.getResult(), tfl_s2b_op.input(),
      tfl_s2b_op.block_shape(), tfl_s2b_op.paddings());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLBatchToSpaceNdOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_b2s_op = cast<TFL::BatchToSpaceNdOp>(op);

  llvm::Optional<Value> result = convertBatchToSpaceNDOp(
      rewriter, op, tfl_b2s_op.getResult(), tfl_b2s_op.input(),
      tfl_b2s_op.block_shape(), tfl_b2s_op.indices());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLSpaceToDepthOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_s2d_op = cast<TFL::SpaceToDepthOp>(op);

  auto block_size_attr = tfl_s2d_op.block_sizeAttr();
  llvm::Optional<Value> result = convertSpaceToDepthOp(
      rewriter, op, tfl_s2d_op.getResult(), tfl_s2d_op.input(), block_size_attr,
      rewriter.getStringAttr("NHWC"));

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLDepthToSpaceOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_d2s_op = cast<TFL::DepthToSpaceOp>(op);

  auto block_size_attr = tfl_d2s_op.block_sizeAttr();
  llvm::Optional<Value> result = convertDepthToSpaceOp(
      rewriter, op, tfl_d2s_op.getResult(), tfl_d2s_op.input(), block_size_attr,
      rewriter.getStringAttr("NHWC"));

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLStridedSliceOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_ss_op = cast<TFL::StridedSliceOp>(op);

  llvm::Optional<Value> result = convertStridedSliceOp(
      rewriter, op, tfl_ss_op.getResult(), tfl_ss_op.input(), tfl_ss_op.begin(),
      tfl_ss_op.end(), tfl_ss_op.strides(), tfl_ss_op.begin_maskAttr().getInt(),
      tfl_ss_op.end_maskAttr().getInt(), tfl_ss_op.ellipsis_maskAttr().getInt(),
      tfl_ss_op.new_axis_maskAttr().getInt(),
      tfl_ss_op.shrink_axis_maskAttr().getInt());
  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLZerosLikeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_zeroslike_op = cast<TFL::ZerosLikeOp>(op);

  llvm::Optional<Value> result = convertZerosLikeOp(
      rewriter, op, tfl_zeroslike_op.getResult(), tfl_zeroslike_op.input());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLHardSwishOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_hardswish_op = cast<TFL::HardSwishOp>(op);
  RankedTensorType output_type =
      tfl_hardswish_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  RankedTensorType input_type =
      tfl_hardswish_op.input().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor input
  if (!input_type) return failure();

  auto input_shape = input_type.getShape();

  // TFL hardswish: f(x) -> (x * relu6(x + 3)) / 6

  // TODO: support 16-bit hardswish
  if (input_type.getElementType().isa<mlir::quant::QuantizedType>() &&
      output_type.getElementType().isa<mlir::quant::QuantizedType>()) {
    // TFLite reference:
    // tensorflow/lite/kernels/internal/reference/reference_ops.h
    // Note that there is a potential rounding issue in the TFLite reference
    // implementation.
    mlir::quant::UniformQuantizedType in_quant_type =
        input_type.getElementType()
            .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
    mlir::quant::UniformQuantizedType out_quant_type =
        output_type.getElementType()
            .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();

    UniformQuantizedType int16_element_qtype =
        mlir::quant::UniformQuantizedType::get(
            true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
            -32768, 32767);
    RankedTensorType bool_type =
        RankedTensorType::get(input_shape, rewriter.getI1Type());
    RankedTensorType int16_type =
        RankedTensorType::get(input_shape, int16_element_qtype);
    RankedTensorType int32_type =
        RankedTensorType::get(input_shape, rewriter.getI32Type());

    // Table's real input range is [-4.0, 4.0].
    // Use the TABLE op to compute x * relu6(x + 3) / 6.
    const double input_sample_grain = 1.0 / 64.0;
    auto hardswish_func = [input_sample_grain](int32_t x) -> int32_t {
      double v = static_cast<double>(x) * input_sample_grain;
      double w = v + 3.0;
      w = w < 0.0 ? 0.0 : w > 6.0 ? 6.0 : w;
      v = v * w / 6.0;
      return std::lround(32768.0 * v);
    };
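
    // Worked example for hardswish_func (illustrative values, not taken from
    // the original source): at table input x = 64, v = 64 / 64.0 = 1.0 and
    // w = relu6(1.0 + 3.0) = 4.0, so v = 1.0 * 4.0 / 6.0 ~= 0.6667 and the
    // table entry is lround(32768.0 * 0.6667) = 21845.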

    Value table_const = getTosa1DConstTensorTable(rewriter, op, hardswish_func);

    // Rescale input into 9.7 fixed-point (the table index scaled by 128).
    Value op1_rescale_in =
        buildRescale(rewriter, op, int16_type, tfl_hardswish_op.input(),
                     (in_quant_type.getScale() * 128.0) / input_sample_grain,
                     in_quant_type.getZeroPoint(), 0);

    // TABLE op; its output is in 0.23 fixed-point.
    auto op2_table_op1 = rewriter.create<tosa::TableOp>(
        op->getLoc(), int32_type, op1_rescale_in, table_const);

    // Scale the table output back into the output's quantized space.
    Value op3_rescale_op2 =
        buildRescale(rewriter, op, output_type, op2_table_op1.getResult(),
                     1.0 / (128.0 * 32768.0 * out_quant_type.getScale()), 0,
                     out_quant_type.getZeroPoint());

    Value op4_rescale_in = buildRescale(rewriter, op, int32_type,
                                        tfl_hardswish_op.input(), 1.0, 0, 0);

    // Get 3.0 in quantized space
    int32_t quantized_3 =
        static_cast<int32_t>(std::ceil(3.0 / in_quant_type.getScale())) +
        in_quant_type.getZeroPoint();
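
    // For example (illustrative values): with an input scale of 0.0235 and a
    // zero point of 0, quantized_3 = ceil(3.0 / 0.0235) + 0 = 128.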

    auto op5_ge_op4 = rewriter.create<tosa::GreaterEqualOp>(
        op->getLoc(), bool_type, op4_rescale_in,
        getTosaConstTensorSingleI32(rewriter, op, quantized_3));

    auto op6_select_op5_op4_op3 = rewriter.create<tosa::SelectOp>(
        op->getLoc(), output_type, op5_ge_op4, tfl_hardswish_op.input(),
        op3_rescale_op2);

    rewriter.replaceOp(op, {op6_select_op5_op4_op3});

    return success();

  } else {
    // op1 = constop(3)
    // op2 = add(x, op1)
    // op3 = reluN(op2, 6)
    // op4 = mul(x, op3)
    // op5 = reciprocal(6)
    // op6 = mul(op4, op5)

    Value op1_value = getTosaConstTensorSingleF32(rewriter, op, 3.0);

    auto op2_add_x_op1 = rewriter.create<tosa::AddOp>(
        op->getLoc(), output_type, tfl_hardswish_op.input(), op1_value);

    auto op3_relu_op2_6 = rewriter.create<tosa::ReluNOp>(
        op->getLoc(), output_type, op2_add_x_op1.getResult(),
        rewriter.getI64IntegerAttr(0), rewriter.getF32FloatAttr(6.0));

    auto op4_mul_x_op3 = rewriter.create<tosa::MulOp>(
        op->getLoc(), output_type, tfl_hardswish_op.input(),
        op3_relu_op2_6.getResult(), 0);

    auto op5_reciprocal_6 = rewriter.create<tosa::ReciprocalOp>(
        op->getLoc(), output_type,
        getTosaConstTensorSingleF32(rewriter, op, 6.0));

    auto op6_mul_op4_op5 = rewriter.create<tosa::MulOp>(
        op->getLoc(), output_type, op4_mul_x_op3.getResult(),
        op5_reciprocal_6.getResult(), 0);

    rewriter.replaceOp(op, {op6_mul_op4_op5.getResult()});

    return success();
  }
}

LogicalResult ConvertTFLLogisticOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_logistic_op = cast<TFL::LogisticOp>(op);

  RankedTensorType output_type =
      tfl_logistic_op.getResult().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_type =
      tfl_logistic_op.x().getType().dyn_cast<RankedTensorType>();
  if (!input_type || !output_type) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLLogisticOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  if (input_is_qtype) {
    UniformQuantizedType int16_element_qtype =
        mlir::quant::UniformQuantizedType::get(
            true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
            -32768, 32767);
    RankedTensorType int16_type =
        RankedTensorType::get(output_type.getShape(), int16_element_qtype);
    RankedTensorType int32_type = RankedTensorType::get(
        output_type.getShape(), rewriter.getIntegerType(32));
    mlir::quant::UniformQuantizedType input_qtype =
        input_type.getElementType()
            .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
    mlir::quant::UniformQuantizedType output_qtype =
        output_type.getElementType()
            .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
    const double input_sample_grain = 1.0 / 16.0;
    auto sigmoid_func = [input_sample_grain](int32_t x) -> int32_t {
      // Input range [-16.0, 16.0], output range [0.0, 1.0]
      double v = static_cast<double>(x) * input_sample_grain;
      v = 1.0 / (1.0 + std::exp(-v));

      return std::lround(32768.0 * v);
    };
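
    // Worked example for sigmoid_func (illustrative values, not taken from
    // the original source): at table input x = 16, v = 16 / 16.0 = 1.0 and
    // 1.0 / (1.0 + exp(-1.0)) ~= 0.73106, so the table entry is
    // lround(32768.0 * 0.73106) = 23955.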

    Value table_const = getTosa1DConstTensorTable(rewriter, op, sigmoid_func);

    // Rescale input to 9.7 precision.
    Value op1_rescale_in =
        buildRescale(rewriter, op, int16_type, tfl_logistic_op.x(),
                     (input_qtype.getScale() * 128.0) / input_sample_grain,
                     input_qtype.getZeroPoint(), 0);

    auto op2_table_op1 = rewriter.create<tosa::TableOp>(
        op->getLoc(), int32_type, op1_rescale_in, table_const);

    double output_rescale_scale =
        1.0 / (output_qtype.getScale() * 32768.0 * 128.0);

    Value op3_rescale_op2 =
        buildRescale(rewriter, op, output_type, op2_table_op1.getResult(),
                     output_rescale_scale, 0, output_qtype.getZeroPoint());

    rewriter.replaceOp(op, {op3_rescale_op2});
  } else {
    rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(op, output_type,
                                                 tfl_logistic_op.x());
  }

  return success();
}

LogicalResult ConvertTFLTanhOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_tanh_op = cast<TFL::TanhOp>(op);
  RankedTensorType output_type =
      tfl_tanh_op.getResult().getType().dyn_cast<RankedTensorType>();
  RankedTensorType input_type =
      tfl_tanh_op.input().getType().dyn_cast<RankedTensorType>();
  if (!input_type || !output_type) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLTanhOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  if (input_is_qtype) {
    UniformQuantizedType int16_element_qtype =
        mlir::quant::UniformQuantizedType::get(
            true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0,
            -32768, 32767);
    RankedTensorType int16_type =
        RankedTensorType::get(output_type.getShape(), int16_element_qtype);
    RankedTensorType int32_type = RankedTensorType::get(
        output_type.getShape(), rewriter.getIntegerType(32));
    mlir::quant::UniformQuantizedType input_qtype =
        input_type.getElementType()
            .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
    mlir::quant::UniformQuantizedType output_qtype =
        output_type.getElementType()
            .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
    const double input_sample_grain = 1.0 / 32.0;
    auto tanh_func = [input_sample_grain](int32_t x) -> int32_t {
      // Input range [-8.0, 8.0], output range [-1.0, 1.0]
      double v = static_cast<double>(x) * input_sample_grain;
      v = std::exp(-2.0 * v);
      v = (1.0 - v) / (1.0 + v);

      return std::lround(32768.0 * v);
    };
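
    // Worked example for tanh_func (illustrative values, not taken from the
    // original source): at table input x = 32, v = 32 / 32.0 = 1.0 and
    // tanh(1.0) = (1 - exp(-2)) / (1 + exp(-2)) ~= 0.76159, so the table
    // entry is lround(32768.0 * 0.76159) = 24956.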

    Value table_const = getTosa1DConstTensorTable(rewriter, op, tanh_func);

    // Rescale input to 9.7 precision.
    Value op1_rescale_in =
        buildRescale(rewriter, op, int16_type, tfl_tanh_op.input(),
                     (input_qtype.getScale() * 128.0) / input_sample_grain,
                     input_qtype.getZeroPoint(), 0);

    auto op2_table_op1 = rewriter.create<tosa::TableOp>(
        op->getLoc(), int32_type, op1_rescale_in, table_const);

    double output_rescale_scale =
        1.0 / (output_qtype.getScale() * 32768.0 * 128.0);

    Value op3_rescale_op2 =
        buildRescale(rewriter, op, output_type, op2_table_op1.getResult(),
                     output_rescale_scale, 0, output_qtype.getZeroPoint());

    rewriter.replaceOp(op, {op3_rescale_op2});
  } else {
    rewriter.replaceOpWithNewOp<tosa::TanhOp>(op, output_type,
                                              tfl_tanh_op.input());
  }

  return success();
}

LogicalResult ConvertTFLPReluOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_prelu_op = cast<TFL::PReluOp>(op);
  RankedTensorType output_type =
      tfl_prelu_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  // TODO: add lowering with MUL + SELECT + RESCALE

  return failure();
}

LogicalResult ConvertTFLLeakyReluOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_leakyrelu_op = cast<TFL::LeakyReluOp>(op);
  RankedTensorType output_type =
      tfl_leakyrelu_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  // TODO: add lowering with MUL + SELECT + RESCALE

  return failure();
}

LogicalResult ConvertTFLNegOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_neg_op = cast<TFL::NegOp>(op);
  RankedTensorType output_type =
      tfl_neg_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  rewriter.replaceOpWithNewOp<tosa::NegateOp>(op, output_type, tfl_neg_op.x());

  return success();
}

LogicalResult ConvertTFLYieldOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  rewriter.replaceOpWithNewOp<tosa::YieldOp>(op, op->getResultTypes(),
                                             op->getOperands());

  return success();
}

matchAndRewrite(Operation * op,PatternRewriter & rewriter) const2742 LogicalResult ConvertTFLCustomOp::matchAndRewrite(
2743     Operation* op, PatternRewriter& rewriter) const {
2744   auto tfl_custom_op = cast<TFL::CustomOp>(op);
2745   rewriter.replaceOpWithNewOp<tosa::CustomOp>(
2746       op, op->getResultTypes(), tfl_custom_op.custom_code(), op->getOperands());
2747 
2748   return success();
2749 }
2750 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const2751 LogicalResult ConvertTFLReverseV2Op::matchAndRewrite(
2752     Operation* op, PatternRewriter& rewriter) const {
2753   auto tfl_reverse_op = cast<TFL::ReverseV2Op>(op);
2754 
2755   RankedTensorType input_type =
2756       tfl_reverse_op.input().getType().dyn_cast<RankedTensorType>();
2757   RankedTensorType output_type =
2758       tfl_reverse_op.getResult().getType().dyn_cast<RankedTensorType>();
2759   if (!input_type || !output_type) return failure();
2760 
2761   ElementsAttr axis_elems;
2762   if (!matchPattern(tfl_reverse_op.axis(), m_Constant(&axis_elems)))
2763     return failure();
2764 
2765   auto input_rank = input_type.getShape().size();
2766   Value val = tfl_reverse_op.input();
2767   if (axis_elems.getNumElements() == 0) {
2768     auto identity_op =
2769         rewriter.create<tosa::IdentityOp>(op->getLoc(), output_type, val);
2770     val = identity_op.getResult();
2771   } else {
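    // TOSA REVERSE flips a single axis, so a multi-axis ReverseV2 becomes a
    // chain of REVERSE ops, one per requested axis; the axes are independent,
    // so the order of the chain does not affect the result.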
    for (int i = 0; i < axis_elems.getNumElements(); i++) {
      int64_t axis_val = axis_elems.getValue<IntegerAttr>(i).getInt();
      if (axis_val < 0) axis_val += input_rank;
      auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
      auto reverse_op = rewriter.create<tosa::ReverseOp>(
          op->getLoc(), output_type, val, axis_attr);

      val = reverse_op.getResult();
    }
  }

  rewriter.replaceOp(op, {val});

  return success();
}

LogicalResult ConvertTFLQuantizeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_quantize_op = cast<TFL::QuantizeOp>(op);

  RankedTensorType input_type =
      tfl_quantize_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_quantize_op.getResult().getType().dyn_cast<RankedTensorType>();

  if (!input_type || !output_type) return failure();

  RankedTensorType qtype =
      tfl_quantize_op.qtypeAttr().getValue().dyn_cast<RankedTensorType>();
  if (!qtype) return failure();

  UniformQuantizedType element_type =
      qtype.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
  if (!element_type) return failure();

  UniformQuantizedType input_element_type =
      input_type.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();

  // If the input is already quantized, this is a requantization: lower it
  // directly to a single TOSA RESCALE (cf. tensorflow::ops::Requantize).
  if (input_element_type) {
    double rescale_scale =
        input_element_type.getScale() / element_type.getScale();
    Value rescale_op = buildRescale(
        rewriter, op, output_type, tfl_quantize_op.input(), rescale_scale,
        input_element_type.getZeroPoint(), element_type.getZeroPoint());
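    // Requantization math: the real value is r = in_scale * (q_in - in_zp),
    // so q_out = r / out_scale + out_zp
    //          = (in_scale / out_scale) * (q_in - in_zp) + out_zp,
    // which is exactly the RESCALE built here.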

    rewriter.replaceOp(op, {rescale_op});
    return success();
  } else {
    double scale = 1 / element_type.getScale();
    int64_t zp = element_type.getZeroPoint();
    int64_t num_bits = element_type.getStorageTypeIntegralWidth();
    zp = element_type.isSigned() ? zp : zp - (1 << (num_bits - 1));
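    // For unsigned storage the zero point is shifted into the signed range
    // TOSA uses: e.g. an asymmetric uint8 zero point of 128 becomes an int8
    // zero point of 128 - (1 << 7) = 0.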

    llvm::Optional<Value> result = convertQuantizeOp(
        rewriter, op, output_type, tfl_quantize_op.input(), scale, zp);

    if (!result) return failure();

    rewriter.replaceOp(op, {result.getValue()});

    return success();
  }
}

LogicalResult ConvertTFLDequantizeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_dequantize_op = cast<TFL::DequantizeOp>(op);

  RankedTensorType output_type =
      tfl_dequantize_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  RankedTensorType qtype =
      tfl_dequantize_op.input().getType().dyn_cast<RankedTensorType>();
  if (!qtype) return failure();

  UniformQuantizedType element_type =
      qtype.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
  if (!element_type) return failure();

  double scale = element_type.getScale();
  int64_t zp = element_type.getZeroPoint();
  int64_t num_bits = element_type.getStorageTypeIntegralWidth();
  zp = element_type.isSigned() ? zp : zp - (1 << (num_bits - 1));
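  // Same signed-range zero-point adjustment as in the quantize path above;
  // the lowering then recovers the real value as scale * (q - zp).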

  llvm::Optional<Value> result = convertDequantizeOp(
      rewriter, op, output_type, tfl_dequantize_op.input(), scale, zp);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLQConstOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_qconst_op = cast<TFL::QConstOp>(op);

  RankedTensorType output_type =
      tfl_qconst_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output_type,
                                             tfl_qconst_op.valueAttr());

  return success();
}

void LegalizeTFL::runOnFunction() {
  OwningRewritePatternList patterns;
  auto* ctx = &getContext();
  auto func = getFunction();

  // Add the generated patterns to the list.
  populateWithGenerated(ctx, patterns);

#define DEF_PATTERN_INSERT(PAT) patterns.insert<Convert##PAT##Op>(ctx);

  DEF_PATTERN_INSERT(TFLRelu);
  DEF_PATTERN_INSERT(TFLRelu6);
  DEF_PATTERN_INSERT(TFLEqual);
  DEF_PATTERN_INSERT(TFLNotEqual);
  DEF_PATTERN_INSERT(TFLGreater);
  DEF_PATTERN_INSERT(TFLGreaterEqual);
  DEF_PATTERN_INSERT(TFLAdd);
  DEF_PATTERN_INSERT(TFLSub);
  DEF_PATTERN_INSERT(TFLMul);
  DEF_PATTERN_INSERT(TFLSquare);
  DEF_PATTERN_INSERT(TFLDiv);
  DEF_PATTERN_INSERT(TFLMaximum);
  DEF_PATTERN_INSERT(TFLMinimum);
  DEF_PATTERN_INSERT(TFLFloorMod);
  DEF_PATTERN_INSERT(TFLFloorDiv);
  DEF_PATTERN_INSERT(TFLAddN);
  DEF_PATTERN_INSERT(TFLAveragePool2D);
  DEF_PATTERN_INSERT(TFLMaxPool2D);
  DEF_PATTERN_INSERT(TFLConcatenation);
  DEF_PATTERN_INSERT(TFLReshape);
  DEF_PATTERN_INSERT(TFLRank);
  DEF_PATTERN_INSERT(TFLShape);
  DEF_PATTERN_INSERT(TFLExpandDims);
  DEF_PATTERN_INSERT(TFLSqueeze);
  DEF_PATTERN_INSERT(TFLFill);
  DEF_PATTERN_INSERT(TFLElu);
  DEF_PATTERN_INSERT(TFLSoftmax);
  DEF_PATTERN_INSERT(TFLLogSoftmax);
  DEF_PATTERN_INSERT(TFLReduceAny);
  DEF_PATTERN_INSERT(TFLReduceMax);
  DEF_PATTERN_INSERT(TFLReduceMin);
  DEF_PATTERN_INSERT(TFLMean);
  DEF_PATTERN_INSERT(TFLReduceProd);
  DEF_PATTERN_INSERT(TFLSum);
  DEF_PATTERN_INSERT(TFLConv2D);
  DEF_PATTERN_INSERT(TFLTransposeConv);
  DEF_PATTERN_INSERT(TFLDepthwiseConv2D);
  DEF_PATTERN_INSERT(TFLFullyConnected);
  DEF_PATTERN_INSERT(TFLSplit);
  DEF_PATTERN_INSERT(TFLSplitV);
  DEF_PATTERN_INSERT(TFLPack);
  DEF_PATTERN_INSERT(TFLUnpack);
  DEF_PATTERN_INSERT(TFLTranspose);
  DEF_PATTERN_INSERT(TFLTile);
  DEF_PATTERN_INSERT(TFLSlice);
  DEF_PATTERN_INSERT(TFLStridedSlice);
  DEF_PATTERN_INSERT(TFLZerosLike);
  DEF_PATTERN_INSERT(TFLHardSwish);
  DEF_PATTERN_INSERT(TFLLess);
  DEF_PATTERN_INSERT(TFLLessEqual);
  DEF_PATTERN_INSERT(TFLPad);
  DEF_PATTERN_INSERT(TFLResizeBilinear);
  DEF_PATTERN_INSERT(TFLResizeNearestNeighbor);
  DEF_PATTERN_INSERT(TFLSelect);
  DEF_PATTERN_INSERT(TFLSelectV2);
  DEF_PATTERN_INSERT(TFLSpaceToBatchNd);
  DEF_PATTERN_INSERT(TFLBatchToSpaceNd);
  DEF_PATTERN_INSERT(TFLSpaceToDepth);
  DEF_PATTERN_INSERT(TFLDepthToSpace);
  DEF_PATTERN_INSERT(TFLLogistic);
  DEF_PATTERN_INSERT(TFLTanh);
  DEF_PATTERN_INSERT(TFLPRelu);
  DEF_PATTERN_INSERT(TFLLeakyRelu);
  DEF_PATTERN_INSERT(TFLNeg);
  DEF_PATTERN_INSERT(TFLYield);
  DEF_PATTERN_INSERT(TFLCustom);
  DEF_PATTERN_INSERT(TFLReverseV2);
  DEF_PATTERN_INSERT(TFLQuantize);
  DEF_PATTERN_INSERT(TFLDequantize);
  DEF_PATTERN_INSERT(TFLQConst);
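  // The greedy driver applies these patterns (plus the generated ones) until
  // no more fire, folding as it goes; the (void) cast discards the
  // LogicalResult, so failure to converge is not treated as a pass error.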
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
}  // namespace

// Creates an instance of the TensorFlow Lite dialect LegalizeTFL pass.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFLPass() {
  return std::make_unique<LegalizeTFL>();
}

static PassRegistration<LegalizeTFL> pass(
    PASS_NAME, "Legalize from TensorFlow Lite to TOSA dialect");

}  // namespace tosa
}  // namespace mlir