• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Legalize TensorFlow to TOSA
17 
18 #include <climits>
19 #include <cstddef>
20 #include <cstdint>
21 #include <iterator>
22 #include <numeric>
23 
24 #include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
25 #include "mlir/Support/LLVM.h"  // from @llvm-project
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
28 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
29 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
30 #include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
31 
32 #define PASS_NAME "tosa-legalize-tf"
33 #define DEBUG_TYPE PASS_NAME
34 
35 namespace mlir {
36 
37 namespace tosa {
38 
39 namespace {
40 // Performs lowering to TOSA dialect
41 class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
42  public:
LegalizeTF()43   explicit LegalizeTF() {}
44   void runOnFunction() override;
45 };
46 
47 // All the Pat<> lowering mappings.
48 #include "tensorflow/compiler/mlir/tosa/transforms/tf_legalize_patterns.inc"
49 
// Declares a RewritePattern named ConvertTF<tf_op>Op that matches ops named
// TF::<tf_op>Op::getOperationName() with benefit 1. Only the declaration of
// matchAndRewrite is emitted here; each one is defined out of line.
// (A trailing `//` comment cannot live on the continuation lines below, since
// it would swallow the line-splicing backslash.)
#define DECL_CONVERT_OP(tf_op)                                               \
  struct ConvertTF##tf_op##Op : public RewritePattern {                      \
    explicit ConvertTF##tf_op##Op(MLIRContext* context)                      \
        : RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {}   \
    LogicalResult matchAndRewrite(Operation* op,                             \
                                  PatternRewriter& rewriter) const override; \
  }

// All the explicitly implemented complex lowerings.
DECL_CONVERT_OP(MatMul);
DECL_CONVERT_OP(Relu);
DECL_CONVERT_OP(Relu6);
DECL_CONVERT_OP(Equal);
DECL_CONVERT_OP(NotEqual);
DECL_CONVERT_OP(Greater);
DECL_CONVERT_OP(GreaterEqual);
DECL_CONVERT_OP(Add);
DECL_CONVERT_OP(AddV2);
DECL_CONVERT_OP(AddN);
DECL_CONVERT_OP(Sub);
DECL_CONVERT_OP(Mul);
DECL_CONVERT_OP(Square);
DECL_CONVERT_OP(SquaredDifference);
DECL_CONVERT_OP(Round);
DECL_CONVERT_OP(FloorDiv);
DECL_CONVERT_OP(FloorMod);
DECL_CONVERT_OP(Assert);
DECL_CONVERT_OP(Maximum);
DECL_CONVERT_OP(Minimum);
DECL_CONVERT_OP(RealDiv);
DECL_CONVERT_OP(ArgMax);
DECL_CONVERT_OP(AvgPool);
DECL_CONVERT_OP(MaxPool);
DECL_CONVERT_OP(ConcatV2);
DECL_CONVERT_OP(Reshape);
DECL_CONVERT_OP(Rank);
DECL_CONVERT_OP(Shape);
DECL_CONVERT_OP(ExpandDims);
DECL_CONVERT_OP(Squeeze);
DECL_CONVERT_OP(Fill);
DECL_CONVERT_OP(Conv2D);
DECL_CONVERT_OP(DepthwiseConv2dNative);
DECL_CONVERT_OP(Conv2DBackpropInput);
DECL_CONVERT_OP(Elu);
DECL_CONVERT_OP(Softmax);
DECL_CONVERT_OP(LogSoftmax);
DECL_CONVERT_OP(All);
DECL_CONVERT_OP(Any);
DECL_CONVERT_OP(Max);
DECL_CONVERT_OP(Min);
DECL_CONVERT_OP(Mean);
DECL_CONVERT_OP(Prod);
DECL_CONVERT_OP(Sum);
DECL_CONVERT_OP(FusedBatchNorm);
DECL_CONVERT_OP(FusedBatchNormV3);
DECL_CONVERT_OP(BiasAdd);
DECL_CONVERT_OP(Split);
DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(Pack);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(Transpose);
DECL_CONVERT_OP(Tile);
DECL_CONVERT_OP(Slice);
DECL_CONVERT_OP(StridedSlice);
DECL_CONVERT_OP(Less);
DECL_CONVERT_OP(LessEqual);
DECL_CONVERT_OP(Pad);
DECL_CONVERT_OP(ResizeBilinear);
DECL_CONVERT_OP(ResizeNearestNeighbor);
DECL_CONVERT_OP(Gather);
DECL_CONVERT_OP(GatherV2);
DECL_CONVERT_OP(SelectV2);
DECL_CONVERT_OP(SpaceToDepth);
DECL_CONVERT_OP(DepthToSpace);
DECL_CONVERT_OP(SpaceToBatchND);
DECL_CONVERT_OP(BatchToSpaceND);
DECL_CONVERT_OP(ZerosLike);
DECL_CONVERT_OP(Sigmoid);
DECL_CONVERT_OP(Tanh);
DECL_CONVERT_OP(LeakyRelu);
DECL_CONVERT_OP(Neg);
DECL_CONVERT_OP(StopGradient);
DECL_CONVERT_OP(ReverseV2);
DECL_CONVERT_OP(FakeQuantWithMinMaxArgs);
DECL_CONVERT_OP(FakeQuantWithMinMaxVars);
#undef DECL_CONVERT_OP
136 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const137 LogicalResult ConvertTFReluOp::matchAndRewrite(
138     Operation* op, PatternRewriter& rewriter) const {
139   auto tf_relu_op = cast<TF::ReluOp>(op);
140 
141   RankedTensorType output_type =
142       tf_relu_op.getResult().getType().dyn_cast<RankedTensorType>();
143   // Not a ranked tensor output
144   if (!output_type) return failure();
145 
146   if (output_type.getElementType().isa<mlir::FloatType>()) {
147     rewriter.replaceOpWithNewOp<tosa::ReluNOp>(
148         op, output_type, tf_relu_op.features(), rewriter.getI64IntegerAttr(0),
149         rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
150   } else {
151     rewriter.replaceOpWithNewOp<tosa::ReluNOp>(
152         op, output_type, tf_relu_op.features(),
153         rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
154         rewriter.getF32FloatAttr(0.0f));
155   }
156   return success();
157 }
158 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const159 LogicalResult ConvertTFRelu6Op::matchAndRewrite(
160     Operation* op, PatternRewriter& rewriter) const {
161   auto tf_relu6_op = cast<TF::Relu6Op>(op);
162 
163   RankedTensorType output_type =
164       tf_relu6_op.getResult().getType().dyn_cast<RankedTensorType>();
165   // Not a ranked tensor output
166   if (!output_type) return failure();
167 
168   if (output_type.getElementType().isa<mlir::FloatType>()) {
169     rewriter.replaceOpWithNewOp<tosa::ReluNOp>(
170         op, output_type, tf_relu6_op.features(), rewriter.getI64IntegerAttr(0),
171         rewriter.getF32FloatAttr(6.0f));
172   } else {
173     rewriter.replaceOpWithNewOp<tosa::ReluNOp>(
174         op, output_type, tf_relu6_op.features(), rewriter.getI64IntegerAttr(6),
175         rewriter.getF32FloatAttr(0.0f));
176   }
177   return success();
178 }
179 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const180 LogicalResult ConvertTFEqualOp::matchAndRewrite(
181     Operation* op, PatternRewriter& rewriter) const {
182   auto tf_equal_op = cast<TF::EqualOp>(op);
183 
184   RankedTensorType output_type =
185       tf_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
186   // Not a ranked tensor output
187   if (!output_type) return failure();
188 
189   rewriter.replaceOpWithNewOp<tosa::EqualOp>(op, output_type, tf_equal_op.x(),
190                                              tf_equal_op.y());
191   return success();
192 }
193 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const194 LogicalResult ConvertTFNotEqualOp::matchAndRewrite(
195     Operation* op, PatternRewriter& rewriter) const {
196   auto tf_not_equal_op = cast<TF::NotEqualOp>(op);
197 
198   RankedTensorType output_type =
199       tf_not_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
200   // Not a ranked tensor output
201   if (!output_type) return failure();
202 
203   auto op1_equal_in = rewriter.create<tosa::EqualOp>(
204       op->getLoc(), output_type, tf_not_equal_op.x(), tf_not_equal_op.y());
205 
206   auto op2_not_op1 = rewriter.create<tosa::LogicalNotOp>(
207       op->getLoc(), output_type, op1_equal_in.getResult());
208 
209   rewriter.replaceOp(op, {op2_not_op1.getResult()});
210 
211   return success();
212 }
213 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const214 LogicalResult ConvertTFGreaterOp::matchAndRewrite(
215     Operation* op, PatternRewriter& rewriter) const {
216   auto tf_greater_op = cast<TF::GreaterOp>(op);
217 
218   RankedTensorType output_type =
219       tf_greater_op.getResult().getType().dyn_cast<RankedTensorType>();
220   // Not a ranked tensor output
221   if (!output_type) return failure();
222 
223   rewriter.replaceOpWithNewOp<tosa::GreaterOp>(
224       op, output_type, tf_greater_op.x(), tf_greater_op.y());
225   return success();
226 }
227 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const228 LogicalResult ConvertTFGreaterEqualOp::matchAndRewrite(
229     Operation* op, PatternRewriter& rewriter) const {
230   auto tf_greater_equal_op = cast<TF::GreaterEqualOp>(op);
231 
232   RankedTensorType output_type =
233       tf_greater_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
234   // Not a ranked tensor output
235   if (!output_type) return failure();
236 
237   rewriter.replaceOpWithNewOp<tosa::GreaterEqualOp>(
238       op, output_type, tf_greater_equal_op.x(), tf_greater_equal_op.y());
239   return success();
240 }
241 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const242 LogicalResult ConvertTFAddOp::matchAndRewrite(Operation* op,
243                                               PatternRewriter& rewriter) const {
244   auto tf_add_op = cast<TF::AddOp>(op);
245 
246   RankedTensorType output_type =
247       tf_add_op.getResult().getType().dyn_cast<RankedTensorType>();
248   // Not a ranked tensor output
249   if (!output_type) return failure();
250 
251   rewriter.replaceOpWithNewOp<tosa::AddOp>(op, output_type, tf_add_op.x(),
252                                            tf_add_op.y());
253   return success();
254 }
255 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const256 LogicalResult ConvertTFAddV2Op::matchAndRewrite(
257     Operation* op, PatternRewriter& rewriter) const {
258   auto tf_addv2_op = cast<TF::AddV2Op>(op);
259 
260   RankedTensorType output_type =
261       tf_addv2_op.getResult().getType().dyn_cast<RankedTensorType>();
262   // Not a ranked tensor output
263   if (!output_type) return failure();
264 
265   rewriter.replaceOpWithNewOp<tosa::AddOp>(op, output_type, tf_addv2_op.x(),
266                                            tf_addv2_op.y());
267   return success();
268 }
269 
270 // AddN is commutative
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const271 LogicalResult ConvertTFAddNOp::matchAndRewrite(
272     Operation* op, PatternRewriter& rewriter) const {
273   auto tf_addn_op = cast<TF::AddNOp>(op);
274 
275   RankedTensorType output_type =
276       tf_addn_op.getResult().getType().dyn_cast<RankedTensorType>();
277   // Not a ranked tensor output
278   if (!output_type) return failure();
279 
280   SmallVector<Value, 8> inputs(tf_addn_op.inputs());
281 
282   assert(inputs.size() >= 2);
283 
284   auto newOp = rewriter.create<tosa::AddOp>(op->getLoc(), output_type,
285                                             inputs[0], inputs[1]);
286   for (int i = 2; i < inputs.size(); i++) {
287     newOp = rewriter.create<tosa::AddOp>(op->getLoc(), output_type, inputs[i],
288                                          newOp.getResult());
289   }
290 
291   rewriter.replaceOp(op, {newOp.getResult()});
292 
293   return success();
294 }
295 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const296 LogicalResult ConvertTFSubOp::matchAndRewrite(Operation* op,
297                                               PatternRewriter& rewriter) const {
298   auto tf_sub_op = cast<TF::SubOp>(op);
299 
300   RankedTensorType output_type =
301       tf_sub_op.getResult().getType().dyn_cast<RankedTensorType>();
302   // Not a ranked tensor output
303   if (!output_type) return failure();
304 
305   rewriter.replaceOpWithNewOp<tosa::SubOp>(op, output_type, tf_sub_op.x(),
306                                            tf_sub_op.y());
307   return success();
308 }
309 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const310 LogicalResult ConvertTFMulOp::matchAndRewrite(Operation* op,
311                                               PatternRewriter& rewriter) const {
312   auto tf_mul_op = cast<TF::MulOp>(op);
313 
314   llvm::Optional<Value> result = convertMultiplyOp(
315       rewriter, op, tf_mul_op.getResult(), tf_mul_op.x(), tf_mul_op.y());
316 
317   if (!result) return failure();
318 
319   rewriter.replaceOp(op, {result.getValue()});
320   return success();
321 }
322 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const323 LogicalResult ConvertTFSquareOp::matchAndRewrite(
324     Operation* op, PatternRewriter& rewriter) const {
325   auto tf_square_op = cast<TF::SquareOp>(op);
326 
327   llvm::Optional<Value> result =
328       convertMultiplyOp(rewriter, op, tf_square_op.getResult(),
329                         tf_square_op.x(), tf_square_op.x());
330 
331   if (!result) return failure();
332 
333   rewriter.replaceOp(op, {result.getValue()});
334   return success();
335 }
336 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const337 LogicalResult ConvertTFSquaredDifferenceOp::matchAndRewrite(
338     Operation* op, PatternRewriter& rewriter) const {
339   auto tf_squared_op = cast<TF::SquaredDifferenceOp>(op);
340 
341   llvm::Optional<Value> result =
342       convertSquaredDifferenceOp(rewriter, op, tf_squared_op.getResult(),
343                                  tf_squared_op.x(), tf_squared_op.y());
344 
345   if (!result) return failure();
346 
347   rewriter.replaceOp(op, {result.getValue()});
348   return success();
349 }
350 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const351 LogicalResult ConvertTFRoundOp::matchAndRewrite(
352     Operation* op, PatternRewriter& rewriter) const {
353   auto tf_round_op = cast<TF::RoundOp>(op);
354 
355   RankedTensorType input_type =
356       tf_round_op.x().getType().dyn_cast<RankedTensorType>();
357   if (!input_type) {
358     return op->emitOpError("Round: input not ranked tensor type");
359   }
360 
361   if (input_type.getElementType().isa<FloatType>()) {
362     llvm::Optional<Value> result =
363         convertRoundOp(rewriter, op, tf_round_op.getResult(), tf_round_op.x());
364 
365     if (!result) return failure();
366 
367     rewriter.replaceOp(op, {result.getValue()});
368     return success();
369 
370   } else {
371     tf_round_op.replaceAllUsesWith(tf_round_op.x());
372     return success();
373   }
374 }
375 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const376 LogicalResult ConvertTFFloorDivOp::matchAndRewrite(
377     Operation* op, PatternRewriter& rewriter) const {
378   auto tf_floordiv_op = cast<TF::FloorDivOp>(op);
379 
380   llvm::Optional<Value> result =
381       convertFloorDivOp(rewriter, op, tf_floordiv_op.getResult(),
382                         tf_floordiv_op.x(), tf_floordiv_op.y());
383 
384   if (!result) return failure();
385 
386   rewriter.replaceOp(op, {result.getValue()});
387 
388   return success();
389 }
390 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const391 LogicalResult ConvertTFFloorModOp::matchAndRewrite(
392     Operation* op, PatternRewriter& rewriter) const {
393   auto tf_floormod_op = cast<TF::FloorModOp>(op);
394 
395   llvm::Optional<Value> result =
396       convertFloorModOp(rewriter, op, tf_floormod_op.getResult(),
397                         tf_floormod_op.x(), tf_floormod_op.y());
398 
399   if (!result) return failure();
400 
401   rewriter.replaceOp(op, {result.getValue()});
402 
403   return success();
404 }
405 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const406 LogicalResult ConvertTFAssertOp::matchAndRewrite(
407     Operation* op, PatternRewriter& rewriter) const {
408   op->dropAllReferences();
409   op->erase();
410   return success();
411 }
412 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const413 LogicalResult ConvertTFMaximumOp::matchAndRewrite(
414     Operation* op, PatternRewriter& rewriter) const {
415   auto tf_maximum_op = cast<TF::MaximumOp>(op);
416 
417   RankedTensorType output_type =
418       tf_maximum_op.getResult().getType().dyn_cast<RankedTensorType>();
419   // Not a ranked tensor output
420   if (!output_type) return failure();
421 
422   rewriter.replaceOpWithNewOp<tosa::MaximumOp>(
423       op, output_type, tf_maximum_op.x(), tf_maximum_op.y());
424   return success();
425 }
426 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const427 LogicalResult ConvertTFMinimumOp::matchAndRewrite(
428     Operation* op, PatternRewriter& rewriter) const {
429   auto tf_minimum_op = cast<TF::MinimumOp>(op);
430 
431   RankedTensorType output_type =
432       tf_minimum_op.getResult().getType().dyn_cast<RankedTensorType>();
433   // Not a ranked tensor output
434   if (!output_type) return failure();
435 
436   rewriter.replaceOpWithNewOp<tosa::MinimumOp>(
437       op, output_type, tf_minimum_op.x(), tf_minimum_op.y());
438   return success();
439 }
440 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const441 LogicalResult ConvertTFRealDivOp::matchAndRewrite(
442     Operation* op, PatternRewriter& rewriter) const {
443   auto tf_div_op = cast<TF::RealDivOp>(op);
444 
445   RankedTensorType y_type =
446       tf_div_op.y().getType().dyn_cast<RankedTensorType>();
447   RankedTensorType output_type =
448       tf_div_op.getResult().getType().dyn_cast<RankedTensorType>();
449   // Not a ranked tensor output
450   if (!output_type || !y_type) return failure();
451 
452   auto reciprocal_op =
453       rewriter.create<tosa::ReciprocalOp>(op->getLoc(), y_type, tf_div_op.y());
454 
455   auto mul_op = rewriter.create<tosa::MulOp>(
456       op->getLoc(), output_type, tf_div_op.x(), reciprocal_op.getResult(), 0);
457   rewriter.replaceOp(op, {mul_op.getResult()});
458 
459   return success();
460 }
461 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const462 LogicalResult ConvertTFArgMaxOp::matchAndRewrite(
463     Operation* op, PatternRewriter& rewriter) const {
464   auto tf_argmax_op = cast<TF::ArgMaxOp>(op);
465 
466   RankedTensorType input_type =
467       tf_argmax_op.input().getType().dyn_cast<RankedTensorType>();
468   RankedTensorType output_type =
469       tf_argmax_op.getResult().getType().dyn_cast<RankedTensorType>();
470   // Not a ranked tensor output
471   if (!output_type || !input_type) return failure();
472 
473   ElementsAttr axis_elems;
474   if (!matchPattern(tf_argmax_op.dimension(), m_Constant(&axis_elems)))
475     return failure();
476 
477   int32_t axis = axis_elems.getValue<IntegerAttr>({}).getInt();
478   if (axis < 0) {
479     axis += input_type.getRank();
480   }
481 
482   if (axis < 0 || axis >= input_type.getRank()) {
483     return op->emitOpError("TFArgMax: invalid axis value");
484   }
485 
486   IntegerAttr axis_attr = rewriter.getI64IntegerAttr(axis);
487 
488   rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, output_type,
489                                               tf_argmax_op.input(), axis_attr);
490 
491   return success();
492 }
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const493 LogicalResult ConvertTFAvgPoolOp::matchAndRewrite(
494     Operation* op, PatternRewriter& rewriter) const {
495   auto tf_avgpool_op = cast<TF::AvgPoolOp>(op);
496 
497   RankedTensorType input_type =
498       tf_avgpool_op.value().getType().dyn_cast<RankedTensorType>();
499   RankedTensorType output_type =
500       tf_avgpool_op.getResult().getType().dyn_cast<RankedTensorType>();
501   // Not a ranked tensor output
502   if (!input_type || !output_type) return failure();
503 
504   auto tmpAttr = tf_avgpool_op.data_formatAttr();
505   if (tmpAttr && tmpAttr.getValue().str() != "NHWC") return failure();
506 
507   ArrayAttr pad;
508   ArrayAttr stride;
509   ArrayAttr kernel;
510   {
511     auto tmpAttr = tf_avgpool_op.strides();
512     if (!tmpAttr) {
513       stride = rewriter.getI64ArrayAttr({1, 1});
514     } else {
515       // Note: hardcoded to NHWC for now
516       int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
517       int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
518       stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
519     }
520   }
521   {
522     auto tmpAttr = tf_avgpool_op.ksize();
523     if (!tmpAttr) {
524       kernel = rewriter.getI64ArrayAttr({1, 1});
525     } else {
526       // Note: hardcoded to NHWC for now
527       int64_t kernel_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
528       int64_t kernel_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
529       kernel = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
530     }
531   }
532   {
533     tensorflow::Padding tf_pad;
534     if (!GetPaddingFromString(tf_avgpool_op.padding().str(), &tf_pad).ok())
535       return failure();
536 
537     ArrayAttr dilation =
538         rewriter.getI64ArrayAttr({1, 1});  // Pooling has no non-unit dilation
539 
540     SmallVector<int64_t, 2> i64array;
541 
542     for (auto& elem : tf_avgpool_op.ksize()) {
543       int64_t value = elem.dyn_cast<IntegerAttr>().getInt();
544       i64array.emplace_back(value);
545     }
546 
547     RankedTensorType filter_type = RankedTensorType::get(
548         llvm::makeArrayRef<int64_t>(i64array), rewriter.getIntegerType(64));
549 
550     if (!getPaddingValuesFromPadType(
551             tf_pad,
552             tensorflow::FORMAT_NHWC,  // TFLite only supports this
553             1,                        // tensorflow::FORMAT_OHWI,
554             input_type, filter_type, stride, dilation, rewriter, pad))
555       return failure();
556   }
557 
558   rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
559       op, output_type, tf_avgpool_op.value(), kernel, stride, pad);
560   return success();
561 }
562 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const563 LogicalResult ConvertTFMaxPoolOp::matchAndRewrite(
564     Operation* op, PatternRewriter& rewriter) const {
565   auto tf_maxpool_op = cast<TF::MaxPoolOp>(op);
566 
567   RankedTensorType input_type =
568       tf_maxpool_op.input().getType().dyn_cast<RankedTensorType>();
569   RankedTensorType output_type =
570       tf_maxpool_op.getResult().getType().dyn_cast<RankedTensorType>();
571   // Not a ranked tensor output
572   if (!input_type || !output_type) return failure();
573 
574   auto tmpAttr = tf_maxpool_op.data_formatAttr();
575   if (tmpAttr && tmpAttr.getValue().str() != "NHWC") return failure();
576 
577   ArrayAttr pad;
578   ArrayAttr stride;
579   ArrayAttr kernel;
580   {
581     auto tmpAttr = tf_maxpool_op.strides();
582     if (!tmpAttr) {
583       stride = rewriter.getI64ArrayAttr({1, 1});
584     } else {
585       // Note: hardcoded to NHWC for now
586       int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
587       int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
588       stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
589     }
590   }
591   {
592     auto tmpAttr = tf_maxpool_op.ksize();
593     if (!tmpAttr) {
594       kernel = rewriter.getI64ArrayAttr({1, 1});
595     } else {
596       // Note: hardcoded to NHWC for now
597       int64_t kernel_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
598       int64_t kernel_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
599       kernel = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
600     }
601   }
602   {
603     tensorflow::Padding tf_pad;
604     if (!GetPaddingFromString(tf_maxpool_op.padding().str(), &tf_pad).ok())
605       return failure();
606 
607     // Pooling has no non-unit dilation
608     ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});
609 
610     SmallVector<int64_t, 4> i64array;
611 
612     for (auto& elem : tf_maxpool_op.ksize()) {
613       int64_t value = elem.dyn_cast<IntegerAttr>().getInt();
614       i64array.emplace_back(value);
615     }
616 
617     RankedTensorType filter_type = RankedTensorType::get(
618         llvm::makeArrayRef<int64_t>(i64array), rewriter.getIntegerType(64));
619 
620     if (!getPaddingValuesFromPadType(
621             tf_pad,
622             tensorflow::FORMAT_NHWC,  // TFLite only supports this
623             1,                        // tensorflow::FORMAT_OHWI,
624             input_type, filter_type, stride, dilation, rewriter, pad))
625       return failure();
626   }
627 
628   rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
629       op, output_type, tf_maxpool_op.input(), kernel, stride, pad);
630   return success();
631 }
632 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const633 LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
634     Operation* op, PatternRewriter& rewriter) const {
635   auto tf_concatv2_op = cast<TF::ConcatV2Op>(op);
636   SmallVector<Value, 8> values(tf_concatv2_op.values());
637 
638   ElementsAttr axis_elems;
639   if (!matchPattern(tf_concatv2_op.axis(), m_Constant(&axis_elems)))
640     return failure();
641 
642   int32_t axis = axis_elems.getValue<IntegerAttr>({}).getInt();
643 
644   llvm::Optional<Value> result =
645       convertConcatV2Op(rewriter, op, tf_concatv2_op.getResult(), values, axis);
646 
647   if (!result) return failure();
648 
649   rewriter.replaceOp(op, {result.getValue()});
650 
651   return success();
652 }
653 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const654 LogicalResult ConvertTFReshapeOp::matchAndRewrite(
655     Operation* op, PatternRewriter& rewriter) const {
656   auto tf_reshape_op = cast<TF::ReshapeOp>(op);
657 
658   RankedTensorType output_type =
659       tf_reshape_op.getResult().getType().dyn_cast<RankedTensorType>();
660   // Not a ranked tensor output
661   if (!output_type) return failure();
662 
663   // Regular way to match tensor as element attribute doesn't always work
664   // use output_type.getShape() which is more stable
665   SmallVector<int64_t, 8> shape_vals;
666   for (int i = 0; i < output_type.getShape().size(); i++) {
667     shape_vals.push_back(output_type.getShape()[i]);
668   }
669   ArrayAttr shape_attr = rewriter.getI64ArrayAttr(shape_vals);
670 
671   rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
672       op, output_type, tf_reshape_op.tensor(), shape_attr);
673   return success();
674 }
675 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const676 LogicalResult ConvertTFRankOp::matchAndRewrite(
677     Operation* op, PatternRewriter& rewriter) const {
678   auto tf_rank_op = cast<TF::RankOp>(op);
679 
680   RankedTensorType input_type =
681       tf_rank_op.input().getType().dyn_cast<RankedTensorType>();
682   if (!input_type) return failure();
683 
684   int32_t rank = input_type.getRank();
685 
686   RankedTensorType rank_type =
687       RankedTensorType::get({1}, rewriter.getIntegerType(32));
688   auto rank_attr = DenseElementsAttr::get(rank_type, {rank});
689   auto rank_const =
690       rewriter.create<tosa::ConstOp>(op->getLoc(), rank_type, rank_attr);
691 
692   rewriter.replaceOp(op, {rank_const.getResult()});
693 
694   return success();
695 }
696 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const697 LogicalResult ConvertTFShapeOp::matchAndRewrite(
698     Operation* op, PatternRewriter& rewriter) const {
699   auto tf_shape_op = cast<TF::ShapeOp>(op);
700 
701   RankedTensorType output_type =
702       tf_shape_op.getResult().getType().dyn_cast<RankedTensorType>();
703   // Not a ranked tensor output
704   if (!output_type) return failure();
705 
706   RankedTensorType input_type =
707       tf_shape_op.input().getType().dyn_cast<RankedTensorType>();
708   if (!input_type) return failure();
709 
710   auto input_shape = input_type.getShape();
711 
712   SmallVector<int32_t, 8> shape_arr;
713   for (int i = 0; i < input_shape.size(); i++) {
714     shape_arr.emplace_back(input_shape[i]);
715   }
716 
717   RankedTensorType shape_type = RankedTensorType::get(
718       {static_cast<int32_t>(shape_arr.size())}, rewriter.getIntegerType(32));
719   auto shape_attr = DenseElementsAttr::get(
720       shape_type, llvm::makeArrayRef<int32_t>(shape_arr));
721   auto shape_const =
722       rewriter.create<tosa::ConstOp>(op->getLoc(), shape_type, shape_attr);
723 
724   rewriter.replaceOp(op, {shape_const.getResult()});
725 
726   return success();
727 }
728 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const729 LogicalResult ConvertTFExpandDimsOp::matchAndRewrite(
730     Operation* op, PatternRewriter& rewriter) const {
731   auto tf_expanddims_op = cast<TF::ExpandDimsOp>(op);
732 
733   llvm::Optional<Value> result =
734       convertExpandDimsOp(rewriter, op, tf_expanddims_op.getResult(),
735                           tf_expanddims_op.input(), tf_expanddims_op.dim());
736 
737   if (!result) return failure();
738 
739   rewriter.replaceOp(op, {result.getValue()});
740 
741   return success();
742 }
743 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const744 LogicalResult ConvertTFSqueezeOp::matchAndRewrite(
745     Operation* op, PatternRewriter& rewriter) const {
746   auto tf_squeeze_op = cast<TF::SqueezeOp>(op);
747 
748   // Copy squeeze_dims into int32_t array
749   auto squeeze_dims_attr = tf_squeeze_op.squeeze_dimsAttr();
750   SmallVector<int32_t, 8> squeeze_dims;
751   for (auto& squeeze_dim : squeeze_dims_attr) {
752     squeeze_dims.emplace_back(squeeze_dim.dyn_cast<IntegerAttr>().getInt());
753   }
754 
755   llvm::Optional<Value> result =
756       convertSqueezeOp(rewriter, op, tf_squeeze_op.getResult(),
757                        tf_squeeze_op.input(), squeeze_dims);
758 
759   if (!result) return failure();
760 
761   rewriter.replaceOp(op, {result.getValue()});
762 
763   return success();
764 }
765 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const766 LogicalResult ConvertTFFillOp::matchAndRewrite(
767     Operation* op, PatternRewriter& rewriter) const {
768   auto tf_fill_op = cast<TF::FillOp>(op);
769 
770   RankedTensorType output_type =
771       tf_fill_op.getResult().getType().dyn_cast<RankedTensorType>();
772   // Not a ranked tensor output
773   if (!output_type) return failure();
774 
775   ElementsAttr dims_elems;
776   if (!matchPattern(tf_fill_op.dims(), m_Constant(&dims_elems)))
777     return failure();
778   SmallVector<int64_t, 4> dims_vals;
779   uint32_t total_size = 1;
780   for (int i = 0; i < dims_elems.getNumElements(); i++) {
781     dims_vals.push_back(dims_elems.getValue<IntegerAttr>(i).getInt());
782     total_size *= dims_vals[i];
783   }
784 
785   ElementsAttr value_elem;
786   if (!matchPattern(tf_fill_op.value(), m_Constant(&value_elem)))
787     return failure();
788 
789   RankedTensorType fill_type = RankedTensorType::get(
790       ArrayRef<int64_t>(dims_vals), value_elem.getType().getElementType());
791   DenseElementsAttr fill_attr;
792 
793   // Convert to a compatible zero type
794   if (value_elem.getType().getElementType().isa<FloatType>()) {
795     llvm::SmallVector<float, 4> fill_arr(
796         total_size,
797         value_elem.getValue<FloatAttr>(0).getValue().convertToFloat());
798     fill_attr =
799         DenseElementsAttr::get(fill_type, llvm::makeArrayRef<float>(fill_arr));
800   } else {
801     llvm::SmallVector<int32_t, 4> fill_arr(
802         total_size,
803         value_elem.getValue<IntegerAttr>(0).getValue().getLimitedValue());
804     fill_attr = DenseElementsAttr::get(fill_type,
805                                        llvm::makeArrayRef<int32_t>(fill_arr));
806   }
807   auto fill_const_op =
808       rewriter.create<tosa::ConstOp>(op->getLoc(), fill_type, fill_attr);
809   rewriter.replaceOp(op, {fill_const_op.getResult()});
810 
811   return success();
812 }
813 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const814 LogicalResult ConvertTFConv2DOp::matchAndRewrite(
815     Operation* op, PatternRewriter& rewriter) const {
816   auto tf_conv2d_op = cast<TF::Conv2DOp>(op);
817 
818   RankedTensorType filter_type =
819       tf_conv2d_op.filter().getType().dyn_cast<RankedTensorType>();
820   RankedTensorType output_type =
821       tf_conv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
822 
823   // Set up a zero attr for subsequent pattern replacement if required
824   auto bias_dim = filter_type.getShape().back();
825   RankedTensorType bias_type =
826       RankedTensorType::get({bias_dim}, filter_type.getElementType());
827   auto bias_attr = rewriter.getZeroAttr(bias_type);
828   auto bias = rewriter.create<tosa::ConstOp>(op->getLoc(), bias_type,
829                                              bias_attr.cast<ElementsAttr>());
830 
831   llvm::Optional<Value> result = convertTFConv2DCommon(
832       rewriter, op, output_type, tf_conv2d_op.input(), tf_conv2d_op.filter(),
833       bias, tf_conv2d_op.strides(), tf_conv2d_op.dilations(),
834       tf_conv2d_op.explicit_paddings(), tf_conv2d_op.padding(),
835       tf_conv2d_op.data_format());
836 
837   if (!result) return failure();
838 
839   rewriter.replaceOp(op, {result.getValue()});
840 
841   return success();
842 }
843 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const844 LogicalResult ConvertTFDepthwiseConv2dNativeOp::matchAndRewrite(
845     Operation* op, PatternRewriter& rewriter) const {
846   auto tf_dwconv2d_op = cast<TF::DepthwiseConv2dNativeOp>(op);
847 
848   RankedTensorType input_type =
849       tf_dwconv2d_op.input().getType().dyn_cast<RankedTensorType>();
850   RankedTensorType filter_type =
851       tf_dwconv2d_op.filter().getType().dyn_cast<RankedTensorType>();
852   RankedTensorType output_type =
853       tf_dwconv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
854   // Not a ranked tensor output
855   if (!input_type) return failure();
856   if (!output_type) return failure();
857 
858   // Set up a zero attr for subsequent pattern replacement if required
859   if (!filter_type) {
860     return op->emitOpError("DepthwiseConv2d: filter type unranked tensor");
861   }
862 
863   auto tmpAttr = tf_dwconv2d_op.data_formatAttr();
864   if (tmpAttr && tmpAttr.getValue().str() != "NHWC") return failure();
865 
866   ArrayAttr stride;
867   ArrayAttr dilation;
868   ArrayAttr pad;
869   {
870     auto tmpAttr = tf_dwconv2d_op.strides();
871     if (!tmpAttr) {
872       stride = rewriter.getI64ArrayAttr({1, 1});
873     } else {
874       // Note: hardcoded to NHWC for now
875       int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
876       int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
877       stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
878     }
879   }
880   {
881     auto tmpAttr = tf_dwconv2d_op.dilations();
882     if (!tmpAttr) {
883       dilation = rewriter.getI64ArrayAttr({1, 1});
884     } else {
885       // Note: hardcoded to NHWC for now
886       int64_t dilation_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
887       int64_t dilation_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
888       dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
889     }
890   }
891   {
892     tensorflow::Padding tf_pad;
893     if (!GetPaddingFromString(tf_dwconv2d_op.padding().str(), &tf_pad).ok())
894       return failure();
895 
896     tensorflow::TensorFormat data_format_tf;
897     if (!FormatFromString(tf_dwconv2d_op.data_format().str(), &data_format_tf))
898       return failure();
899 
900     if (tf_pad == tensorflow::Padding::EXPLICIT) {
901       pad = getPaddingValuesFromExplicitPadAttr(
902           tf_dwconv2d_op.explicit_paddings(), data_format_tf, rewriter);
903     } else {
904       if (!getPaddingValuesFromPadType(tf_pad, data_format_tf,
905                                        0,  // tensorflow::FORMAT_HWIO
906                                        input_type, filter_type, stride,
907                                        dilation, rewriter, pad))
908         return failure();
909     }
910   }
911 
912   auto filter_shape = filter_type.getShape();
913   auto bias_dim = filter_shape[2] * filter_shape[3];
914   RankedTensorType bias_type =
915       RankedTensorType::get({bias_dim}, filter_type.getElementType());
916   auto bias_attr = rewriter.getZeroAttr(bias_type);
917   auto bias = rewriter.create<tosa::ConstOp>(op->getLoc(), bias_type,
918                                              bias_attr.cast<ElementsAttr>());
919 
920   rewriter.replaceOpWithNewOp<tosa::DepthwiseConv2DOp>(
921       op, output_type, tf_dwconv2d_op.input(), tf_dwconv2d_op.filter(), bias,
922       pad, stride, dilation);
923   return success();
924 }
925 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const926 LogicalResult ConvertTFConv2DBackpropInputOp::matchAndRewrite(
927     Operation* op, PatternRewriter& rewriter) const {
928   auto tf_conv_op = cast<TF::Conv2DBackpropInputOp>(op);
929 
930   RankedTensorType input_type =
931       tf_conv_op.out_backprop().getType().dyn_cast<RankedTensorType>();
932   RankedTensorType filter_type =
933       tf_conv_op.filter().getType().dyn_cast<RankedTensorType>();
934   RankedTensorType output_type =
935       tf_conv_op.getResult().getType().dyn_cast<RankedTensorType>();
936   // Not a ranked tensor output
937   if (!input_type) return failure();
938   if (!filter_type) return failure();
939   if (!output_type) return failure();
940 
941   // Transpose [H, W, I, O] to [O, H, W, I]
942   auto filter_shape = filter_type.getShape();
943   llvm::SmallVector<int64_t, 4> a1_transpose_dims;
944   a1_transpose_dims.push_back(filter_shape[2]);
945   a1_transpose_dims.push_back(filter_shape[0]);
946   a1_transpose_dims.push_back(filter_shape[1]);
947   a1_transpose_dims.push_back(filter_shape[3]);
948   Value a1_filter_transpose_perm =
949       get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, {2, 0, 1, 3});
950   auto a1_filter_transpose_op = rewriter.create<tosa::TransposeOp>(
951       op->getLoc(),
952       RankedTensorType::get(ArrayRef<int64_t>(a1_transpose_dims),
953                             filter_type.getElementType()),
954       tf_conv_op.filter(), a1_filter_transpose_perm);
955 
956   ArrayAttr stride;
957   ArrayAttr dilation;
958   ArrayAttr outpad;
959   ArrayAttr output_shape;
960   {
961     auto tmpAttr = tf_conv_op.strides();
962     if (!tmpAttr) {
963       stride = rewriter.getI64ArrayAttr({1, 1});
964     } else {
965       // Note: hardcoded to NHWC for now
966       int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
967       int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
968       stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
969     }
970   }
971   {
972     auto tmpAttr = tf_conv_op.dilations();
973     if (!tmpAttr) {
974       dilation = rewriter.getI64ArrayAttr({1, 1});
975     } else {
976       // Note: hardcoded to NHWC for now
977       int64_t dilation_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
978       int64_t dilation_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
979       dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
980     }
981   }
982   {
983     tensorflow::Padding tf_pad;
984     if (!GetPaddingFromString(tf_conv_op.padding().str(), &tf_pad).ok())
985       return failure();
986 
987     tensorflow::TensorFormat data_format_tf;
988     if (!FormatFromString(tf_conv_op.data_format().str(), &data_format_tf))
989       return failure();
990 
991     if (tf_pad == tensorflow::Padding::EXPLICIT) {
992       outpad = getPaddingValuesFromExplicitPadAttr(
993           tf_conv_op.explicit_paddings(), data_format_tf, rewriter);
994     } else {
995       if (!getTransposeConv2dPaddingValues(tf_pad, data_format_tf,
996                                            0,  // tensorflow::FORMAT_HWIO,
997                                            input_type, filter_type, output_type,
998                                            stride, dilation, rewriter, outpad))
999         return failure();
1000     }
1001   }
1002   {
1003     ElementsAttr output_shape_elems;
1004     // Match from input_sizes tensor first.
1005     if (matchPattern(tf_conv_op.input_sizes(),
1006                      m_Constant(&output_shape_elems))) {
1007       llvm::SmallVector<int64_t, 4> shape_vec;
1008       for (int i = 0; i < output_shape_elems.getNumElements(); i++)
1009         shape_vec.push_back(
1010             output_shape_elems.getValue<IntegerAttr>(i).getInt());
1011       output_shape = rewriter.getI64ArrayAttr(shape_vec);
1012     } else {
1013       // Use output tensor's shape otherwise.
1014       output_shape = rewriter.getI64ArrayAttr(output_type.getShape());
1015     }
1016   }
1017 
1018   SmallVector<float, 8> zero_bias_vec(output_type.getShape()[3], 0.0f);
1019   Value zero_bias =
1020       get1DConstTensor<tosa::ConstOp, float>(rewriter, op, zero_bias_vec);
1021 
1022   rewriter.replaceOpWithNewOp<tosa::TransposeConv2DOp>(
1023       op, output_type, tf_conv_op.out_backprop(),
1024       a1_filter_transpose_op.getResult(), zero_bias, outpad, stride, dilation,
1025       output_shape);
1026 
1027   return success();
1028 }
1029 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1030 LogicalResult ConvertTFAllOp::matchAndRewrite(Operation* op,
1031                                               PatternRewriter& rewriter) const {
1032   auto tf_all_op = cast<TF::AllOp>(op);
1033 
1034   RankedTensorType output_type =
1035       tf_all_op.getResult().getType().dyn_cast<RankedTensorType>();
1036   if (!output_type) return failure();
1037 
1038   ElementsAttr axes_elems;
1039   if (!matchPattern(tf_all_op.reduction_indices(), m_Constant(&axes_elems)))
1040     return failure();
1041 
1042   bool keep_dims = false;
1043   auto keep_dims_attr = tf_all_op.keep_dimsAttr();
1044   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1045 
1046   llvm::Optional<Value> result = convertReduceAllOp(
1047       rewriter, op, output_type, tf_all_op.input(), axes_elems, keep_dims);
1048 
1049   if (!result) return failure();
1050 
1051   rewriter.replaceOp(op, {result.getValue()});
1052 
1053   return success();
1054 }
1055 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1056 LogicalResult ConvertTFAnyOp::matchAndRewrite(Operation* op,
1057                                               PatternRewriter& rewriter) const {
1058   auto tf_any_op = cast<TF::AnyOp>(op);
1059 
1060   RankedTensorType output_type =
1061       tf_any_op.getResult().getType().dyn_cast<RankedTensorType>();
1062   if (!output_type) return failure();
1063 
1064   ElementsAttr axes_elems;
1065   if (!matchPattern(tf_any_op.reduction_indices(), m_Constant(&axes_elems)))
1066     return failure();
1067 
1068   bool keep_dims = false;
1069   auto keep_dims_attr = tf_any_op.keep_dimsAttr();
1070   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1071 
1072   llvm::Optional<Value> result = convertReduceAnyOp(
1073       rewriter, op, output_type, tf_any_op.input(), axes_elems, keep_dims);
1074 
1075   if (!result) return failure();
1076 
1077   rewriter.replaceOp(op, {result.getValue()});
1078 
1079   return success();
1080 }
1081 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1082 LogicalResult ConvertTFMaxOp::matchAndRewrite(Operation* op,
1083                                               PatternRewriter& rewriter) const {
1084   auto tf_max_op = cast<TF::MaxOp>(op);
1085 
1086   RankedTensorType output_type =
1087       tf_max_op.getResult().getType().dyn_cast<RankedTensorType>();
1088   if (!output_type) return failure();
1089 
1090   ElementsAttr axes_elems;
1091   if (!matchPattern(tf_max_op.reduction_indices(), m_Constant(&axes_elems)))
1092     return failure();
1093 
1094   bool keep_dims = false;
1095   auto keep_dims_attr = tf_max_op.keep_dimsAttr();
1096   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1097 
1098   llvm::Optional<Value> result = convertReduceMaxOp(
1099       rewriter, op, output_type, tf_max_op.input(), axes_elems, keep_dims);
1100 
1101   if (!result) return failure();
1102 
1103   rewriter.replaceOp(op, {result.getValue()});
1104 
1105   return success();
1106 }
1107 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1108 LogicalResult ConvertTFMinOp::matchAndRewrite(Operation* op,
1109                                               PatternRewriter& rewriter) const {
1110   auto tf_min_op = cast<TF::MinOp>(op);
1111 
1112   RankedTensorType output_type =
1113       tf_min_op.getResult().getType().dyn_cast<RankedTensorType>();
1114   if (!output_type) return failure();
1115 
1116   ElementsAttr axes_elems;
1117   if (!matchPattern(tf_min_op.reduction_indices(), m_Constant(&axes_elems)))
1118     return failure();
1119 
1120   bool keep_dims = false;
1121   auto keep_dims_attr = tf_min_op.keep_dimsAttr();
1122   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1123 
1124   llvm::Optional<Value> result = convertReduceMinOp(
1125       rewriter, op, output_type, tf_min_op.input(), axes_elems, keep_dims);
1126 
1127   if (!result) return failure();
1128 
1129   rewriter.replaceOp(op, {result.getValue()});
1130 
1131   return success();
1132 }
1133 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1134 LogicalResult ConvertTFMeanOp::matchAndRewrite(
1135     Operation* op, PatternRewriter& rewriter) const {
1136   auto tf_mean_op = cast<TF::MeanOp>(op);
1137 
1138   RankedTensorType output_type =
1139       tf_mean_op.getResult().getType().dyn_cast<RankedTensorType>();
1140   if (!output_type) return failure();
1141 
1142   ElementsAttr axes_elems;
1143   if (!matchPattern(tf_mean_op.reduction_indices(), m_Constant(&axes_elems)))
1144     return failure();
1145 
1146   bool keep_dims = false;
1147   auto keep_dims_attr = tf_mean_op.keep_dimsAttr();
1148   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1149 
1150   llvm::Optional<Value> result = convertReduceMeanOp(
1151       rewriter, op, output_type, tf_mean_op.input(), axes_elems, keep_dims);
1152 
1153   if (!result) return failure();
1154 
1155   rewriter.replaceOp(op, {result.getValue()});
1156 
1157   return success();
1158 }
1159 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1160 LogicalResult ConvertTFProdOp::matchAndRewrite(
1161     Operation* op, PatternRewriter& rewriter) const {
1162   auto tf_prod_op = cast<TF::ProdOp>(op);
1163 
1164   RankedTensorType output_type =
1165       tf_prod_op.getResult().getType().dyn_cast<RankedTensorType>();
1166   if (!output_type) return failure();
1167 
1168   ElementsAttr axes_elems;
1169   if (!matchPattern(tf_prod_op.reduction_indices(), m_Constant(&axes_elems)))
1170     return failure();
1171 
1172   bool keep_dims = false;
1173   auto keep_dims_attr = tf_prod_op.keep_dimsAttr();
1174   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1175 
1176   llvm::Optional<Value> result = convertReduceProdOp(
1177       rewriter, op, output_type, tf_prod_op.input(), axes_elems, keep_dims);
1178 
1179   if (!result) return failure();
1180 
1181   rewriter.replaceOp(op, {result.getValue()});
1182 
1183   return success();
1184 }
1185 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1186 LogicalResult ConvertTFSumOp::matchAndRewrite(Operation* op,
1187                                               PatternRewriter& rewriter) const {
1188   auto tf_sum_op = cast<TF::SumOp>(op);
1189 
1190   RankedTensorType output_type =
1191       tf_sum_op.getResult().getType().dyn_cast<RankedTensorType>();
1192   if (!output_type) return failure();
1193 
1194   ElementsAttr axes_elems;
1195   if (!matchPattern(tf_sum_op.reduction_indices(), m_Constant(&axes_elems)))
1196     return failure();
1197 
1198   bool keep_dims = false;
1199   auto keep_dims_attr = tf_sum_op.keep_dimsAttr();
1200   if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1201 
1202   llvm::Optional<Value> result = convertReduceSumOp(
1203       rewriter, op, output_type, tf_sum_op.input(), axes_elems, keep_dims);
1204 
1205   if (!result) return failure();
1206 
1207   rewriter.replaceOp(op, {result.getValue()});
1208 
1209   return success();
1210 }
1211 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1212 LogicalResult ConvertTFEluOp::matchAndRewrite(Operation* op,
1213                                               PatternRewriter& rewriter) const {
1214   auto tf_elu_op = cast<TF::EluOp>(op);
1215 
1216   llvm::Optional<Value> result =
1217       convertEluOp(rewriter, op, tf_elu_op.getResult(), tf_elu_op.features());
1218 
1219   if (!result) return failure();
1220 
1221   rewriter.replaceOp(op, {result.getValue()});
1222 
1223   return success();
1224 }
1225 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1226 LogicalResult ConvertTFSoftmaxOp::matchAndRewrite(
1227     Operation* op, PatternRewriter& rewriter) const {
1228   auto tf_softmax_op = cast<TF::SoftmaxOp>(op);
1229 
1230   llvm::Optional<Value> result = convertSoftmaxOp(
1231       rewriter, op, tf_softmax_op.getResult(), tf_softmax_op.logits());
1232 
1233   if (!result) return failure();
1234 
1235   rewriter.replaceOp(op, {result.getValue()});
1236 
1237   return success();
1238 }
1239 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1240 LogicalResult ConvertTFLogSoftmaxOp::matchAndRewrite(
1241     Operation* op, PatternRewriter& rewriter) const {
1242   auto tf_logsoftmax_op = cast<TF::LogSoftmaxOp>(op);
1243 
1244   llvm::Optional<Value> result = convertLogSoftmaxOp(
1245       rewriter, op, tf_logsoftmax_op.getResult(), tf_logsoftmax_op.logits());
1246 
1247   if (!result) return failure();
1248 
1249   rewriter.replaceOp(op, {result.getValue()});
1250 
1251   return success();
1252 }
1253 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1254 LogicalResult ConvertTFFusedBatchNormOp::matchAndRewrite(
1255     Operation* op, PatternRewriter& rewriter) const {
1256   auto tf_batchnorm_op = cast<TF::FusedBatchNormOp>(op);
1257 
1258   RankedTensorType output_type =
1259       tf_batchnorm_op.getResult(0).getType().dyn_cast<RankedTensorType>();
1260   // Not a ranked tensor output
1261   if (!output_type) return failure();
1262 
1263   // Lowering:
1264   // fused batchnorm = (input-mean) * scale * rsqrt(var+epsilon)) + offset
1265   //
1266   // shape_0 = ones(input.rank)
1267   // shape_0[input.rank-1] = input.shape[input.rank-1]
1268   // shape_1 = ones(1)
1269   //
1270   // bmean  = reshape(mean, shape_0)
1271   // bscale = reshape(scale, shape_0)
1272   // boffset= reshape(offset, shape_0)
1273   // beps   = reshape(epsilon, shape_1)
1274   //
1275   // op1 = sub(input, bmean)
1276   // op2 = add(var, beps)
1277   // op3 = rsqrt(op2)
1278   // bvar = reshape(op3, shape_0)
1279   // op4 = mul(op1, bvar)
1280   // op5 = mul(op4, bscale)
1281   // op6 = add(op5, boffset)
1282 
1283   RankedTensorType mean_type =
1284       tf_batchnorm_op.mean().getType().dyn_cast<RankedTensorType>();
1285   RankedTensorType variance_type =
1286       tf_batchnorm_op.variance().getType().dyn_cast<RankedTensorType>();
1287   if (!variance_type || !mean_type) return failure();
1288 
1289   Value mean_val, variance_val;
1290 
1291   if (mean_type.getNumElements() == 0) {
1292     mean_val = getTosaConstTensorSingleF32(rewriter, tf_batchnorm_op, 0);
1293   } else {
1294     mean_val = tf_batchnorm_op.mean();
1295   }
1296 
1297   if (variance_type.getNumElements() == 0) {
1298     variance_val = getTosaConstTensorSingleF32(rewriter, tf_batchnorm_op, 1.0);
1299   } else {
1300     variance_val = tf_batchnorm_op.variance();
1301   }
1302 
1303   RankedTensorType epsilon_type =
1304       RankedTensorType::get({1}, variance_type.getElementType());
1305   auto epsilon_attr =
1306       DenseFPElementsAttr::get(epsilon_type, {tf_batchnorm_op.epsilon()});
1307   auto epsilon_const =
1308       rewriter.create<tosa::ConstOp>(op->getLoc(), epsilon_type, epsilon_attr);
1309 
1310   auto op1_sub_input_mean = rewriter.create<tosa::SubOp>(
1311       op->getLoc(), tf_batchnorm_op.getResult(0).getType(), tf_batchnorm_op.x(),
1312       mean_val);
1313 
1314   auto op2_add_var_epsilon =
1315       rewriter.create<tosa::AddOp>(op->getLoc(), variance_val.getType(),
1316                                    variance_val, epsilon_const.getResult());
1317 
1318   auto op3_rsqrt_op2 = rewriter.create<tosa::RsqrtOp>(
1319       op->getLoc(), variance_val.getType(), op2_add_var_epsilon.getResult());
1320 
1321   auto op4_mul_op1_op3 = rewriter.create<tosa::MulOp>(
1322       op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
1323       op1_sub_input_mean.getResult(), op3_rsqrt_op2.getResult(), 0);
1324 
1325   auto op5_mul_op4_scale = rewriter.create<tosa::MulOp>(
1326       op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
1327       op4_mul_op1_op3.getResult(), tf_batchnorm_op.scale(), 0);
1328 
1329   auto op6_add_op5_offset = rewriter.create<tosa::AddOp>(
1330       op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
1331       op5_mul_op4_scale.getResult(), tf_batchnorm_op.offset());
1332 
1333   rewriter.replaceOp(op, {op6_add_op5_offset.getResult()});
1334   return success();
1335 }
1336 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1337 LogicalResult ConvertTFFusedBatchNormV3Op::matchAndRewrite(
1338     Operation* op, PatternRewriter& rewriter) const {
1339   auto tf_batchnorm_op = cast<TF::FusedBatchNormV3Op>(op);
1340 
1341   RankedTensorType output_type =
1342       tf_batchnorm_op.getResult(0).getType().dyn_cast<RankedTensorType>();
1343   // Not a ranked tensor output
1344   if (!output_type) return failure();
1345 
1346   // Lowering:
1347   // fused batchnorm = (input-mean) * scale * rsqrt(var+epsilon)) + offset
1348   // op1 = sub(input, mean)
1349   // op2 = add(var, epsilon)
1350   // op3 = rsqrt(op2)
1351   // op4 = mul(op1, op3)
1352   // op5 = mul(op4, scale)
1353   // op6 = add(op5, offset)
1354 
1355   auto op1_sub_input_mean = rewriter.create<tosa::SubOp>(
1356       op->getLoc(), tf_batchnorm_op.getResult(0).getType(), tf_batchnorm_op.x(),
1357       tf_batchnorm_op.mean());
1358 
1359   RankedTensorType variance_type =
1360       tf_batchnorm_op.variance().getType().dyn_cast<RankedTensorType>();
1361   if (!variance_type) return failure();
1362 
1363   auto epsilon_type =
1364       RankedTensorType::get({1}, variance_type.getElementType());
1365   auto epsilon_attr =
1366       DenseFPElementsAttr::get(epsilon_type, {tf_batchnorm_op.epsilon()});
1367   auto epsilon_const =
1368       rewriter.create<tosa::ConstOp>(op->getLoc(), epsilon_type, epsilon_attr);
1369 
1370   auto op2_add_var_epsilon = rewriter.create<tosa::AddOp>(
1371       op->getLoc(), tf_batchnorm_op.variance().getType(),
1372       tf_batchnorm_op.variance(), epsilon_const);
1373 
1374   auto op3_rsqrt_op2 = rewriter.create<tosa::RsqrtOp>(
1375       op->getLoc(), tf_batchnorm_op.variance().getType(),
1376       op2_add_var_epsilon.getResult());
1377 
1378   auto op4_mul_op1_op3 = rewriter.create<tosa::MulOp>(
1379       op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
1380       op1_sub_input_mean.getResult(), op3_rsqrt_op2.getResult(), 0);
1381 
1382   auto op5_mul_op4_scale = rewriter.create<tosa::MulOp>(
1383       op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
1384       op4_mul_op1_op3.getResult(), tf_batchnorm_op.scale(), 0);
1385 
1386   auto op6_add_op5_offset = rewriter.create<tosa::AddOp>(
1387       op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
1388       op5_mul_op4_scale.getResult(), tf_batchnorm_op.offset());
1389 
1390   rewriter.replaceOp(op, {op6_add_op5_offset.getResult()});
1391   return success();
1392 }
1393 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1394 LogicalResult ConvertTFBiasAddOp::matchAndRewrite(
1395     Operation* op, PatternRewriter& rewriter) const {
1396   auto tf_biasadd_op = cast<TF::BiasAddOp>(op);
1397 
1398   RankedTensorType output_type =
1399       tf_biasadd_op.getResult().getType().dyn_cast<RankedTensorType>();
1400   // Not a ranked tensor output
1401   if (!output_type) return failure();
1402 
1403   auto add_op = rewriter.create<tosa::AddOp>(
1404       op->getLoc(), output_type, tf_biasadd_op.value(), tf_biasadd_op.bias());
1405 
1406   rewriter.replaceOp(op, {add_op.getResult()});
1407   return success();
1408 }
1409 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1410 LogicalResult ConvertTFSliceOp::matchAndRewrite(
1411     Operation* op, PatternRewriter& rewriter) const {
1412   auto tf_slice_op = cast<TF::SliceOp>(op);
1413 
1414   RankedTensorType output_type =
1415       tf_slice_op.getResult().getType().dyn_cast<RankedTensorType>();
1416   // Not a ranked tensor output
1417   if (!output_type) return failure();
1418 
1419   ElementsAttr begin_elems, size_elems;
1420 
1421   SmallVector<int64_t, 4> begin_vals, size_vals;
1422 
1423   // Assuming begin is always compile-time constant
1424   if (!matchPattern(tf_slice_op.begin(), m_Constant(&begin_elems))) {
1425     return op->emitOpError("TF::Slice error: begin is not constant");
1426   }
1427 
1428   for (int i = 0; i < begin_elems.getNumElements(); i++)
1429     begin_vals.push_back(begin_elems.getValue<IntegerAttr>(i).getInt());
1430 
1431   // Try to match size as compile-time constant first,
1432   // if this fails, use the output tensor shape instead.
1433   if (matchPattern(tf_slice_op.size(), m_Constant(&size_elems))) {
1434     for (int i = 0; i < size_elems.getNumElements(); i++)
1435       size_vals.push_back(size_elems.getValue<IntegerAttr>(i).getInt());
1436   } else {
1437     size_vals.assign(output_type.getShape().begin(),
1438                      output_type.getShape().end());
1439   }
1440 
1441   ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
1442   ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
1443 
1444   rewriter.replaceOpWithNewOp<tosa::SliceOp>(op, output_type,
1445                                              tf_slice_op.input(), begin, size);
1446   return success();
1447 }
1448 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1449 LogicalResult ConvertTFTileOp::matchAndRewrite(
1450     Operation* op, PatternRewriter& rewriter) const {
1451   auto tf_tile_op = cast<TF::TileOp>(op);
1452 
1453   RankedTensorType output_type =
1454       tf_tile_op.getResult().getType().dyn_cast<RankedTensorType>();
1455   // Not a ranked tensor output
1456   if (!output_type) return failure();
1457 
1458   ElementsAttr multiples_elems;
1459   if (!matchPattern(tf_tile_op.multiples(), m_Constant(&multiples_elems)))
1460     return failure();
1461   SmallVector<int64_t, 4> multiples_vals;
1462   for (int i = 0; i < multiples_elems.getNumElements(); i++)
1463     multiples_vals.push_back(multiples_elems.getValue<IntegerAttr>(i).getInt());
1464 
1465   ArrayAttr multiples_attr = rewriter.getI64ArrayAttr(multiples_vals);
1466 
1467   rewriter.replaceOpWithNewOp<tosa::TileOp>(op, output_type, tf_tile_op.input(),
1468                                             multiples_attr);
1469 
1470   return success();
1471 }
1472 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1473 LogicalResult ConvertTFTransposeOp::matchAndRewrite(
1474     Operation* op, PatternRewriter& rewriter) const {
1475   auto tf_transpose_op = cast<TF::TransposeOp>(op);
1476 
1477   RankedTensorType output_type =
1478       tf_transpose_op.getResult().getType().dyn_cast<RankedTensorType>();
1479   // Not a ranked tensor output
1480   if (!output_type) {
1481     return failure();
1482   }
1483 
1484   rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
1485       op, output_type, tf_transpose_op.x(), tf_transpose_op.perm());
1486 
1487   return success();
1488 }
1489 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1490 LogicalResult ConvertTFPackOp::matchAndRewrite(
1491     Operation* op, PatternRewriter& rewriter) const {
1492   auto tf_pack_op = cast<TF::PackOp>(op);
1493 
1494   SmallVector<Value, 8> inputs(tf_pack_op.values());
1495 
1496   assert(inputs.size() >= 2);
1497 
1498   IntegerAttr axis_attr;
1499   {
1500     auto tmpAttr = tf_pack_op.axisAttr();
1501     if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
1502     axis_attr = tmpAttr;
1503   }
1504   int32_t axis_i32 = axis_attr.getInt();
1505 
1506   llvm::Optional<Value> result =
1507       convertPackOp(rewriter, op, tf_pack_op.getResult(), inputs, axis_i32);
1508 
1509   if (!result) return failure();
1510 
1511   rewriter.replaceOp(op, {result.getValue()});
1512 
1513   return success();
1514 }
1515 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1516 LogicalResult ConvertTFUnpackOp::matchAndRewrite(
1517     Operation* op, PatternRewriter& rewriter) const {
1518   auto tf_unpack_op = cast<TF::UnpackOp>(op);
1519 
1520   IntegerAttr axis_attr;
1521   {
1522     auto tmpAttr = tf_unpack_op.axisAttr();
1523     if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
1524     axis_attr = tmpAttr;
1525   }
1526   int32_t axis_i32 = axis_attr.getInt();
1527 
1528   llvm::Optional<ValueRange> results =
1529       convertUnpackOp(rewriter, op, tf_unpack_op.value(), axis_i32);
1530 
1531   if (!results) return failure();
1532 
1533   rewriter.replaceOp(op, results.getValue());
1534 
1535   return success();
1536 }
1537 
1538 // Splits in num_split parts along split_dim
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1539 LogicalResult ConvertTFSplitOp::matchAndRewrite(
1540     Operation* op, PatternRewriter& rewriter) const {
1541   auto tf_split_op = cast<TF::SplitOp>(op);
1542 
1543   // Get the number of splits
1544   int32_t num_split = -1;
1545 
1546   auto range = tf_split_op.getODSResults(0);
1547   num_split = std::distance(range.begin(), range.end());
1548 
1549   // Get the axis
1550   int32_t axis = 0;
1551   ElementsAttr axisAttrElems;
1552   if (matchPattern(tf_split_op.split_dim(), m_Constant(&axisAttrElems))) {
1553     axis = axisAttrElems.getValue<IntegerAttr>({}).getInt();
1554   }
1555 
1556   llvm::Optional<ValueRange> results =
1557       convertSplitOp(rewriter, op, tf_split_op.getResult(0),
1558                      tf_split_op.value(), num_split, axis);
1559 
1560   if (!results) return failure();
1561 
1562   rewriter.replaceOp(op, results.getValue());
1563 
1564   return success();
1565 }
1566 
1567 // TFSplitV op splits based on a vector of sizes
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1568 LogicalResult ConvertTFSplitVOp::matchAndRewrite(
1569     Operation* op, PatternRewriter& rewriter) const {
1570   auto tf_splitv_op = cast<TF::SplitVOp>(op);
1571 
1572   // Get the size_splits array
1573   SmallVector<int32_t, 4> size_split;
1574   ElementsAttr size_split_elems;
1575   if (!matchPattern(tf_splitv_op.size_splits(),
1576                     m_Constant(&size_split_elems))) {
1577     return failure();
1578   }
1579 
1580   for (int i = 0; i < size_split_elems.getNumElements(); i++) {
1581     size_split.push_back(size_split_elems.getValue<IntegerAttr>(i).getInt());
1582   }
1583 
1584   // Get the axis
1585   ElementsAttr axisAttrElems;
1586   if (!matchPattern(tf_splitv_op.split_dim(), m_Constant(&axisAttrElems))) {
1587     return op->emitOpError("Cannot read split_dim elems");
1588   }
1589 
1590   int32_t axis = axisAttrElems.getValue<IntegerAttr>(0).getInt();
1591 
1592   llvm::Optional<ValueRange> results =
1593       convertSplitVOp(rewriter, op, tf_splitv_op.getResult(0),
1594                       tf_splitv_op.value(), size_split, axis);
1595 
1596   if (!results) return failure();
1597 
1598   rewriter.replaceOp(op, results.getValue());
1599 
1600   return success();
1601 }
1602 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1603 LogicalResult ConvertTFLessOp::matchAndRewrite(
1604     Operation* op, PatternRewriter& rewriter) const {
1605   auto tf_less_op = cast<TF::LessOp>(op);
1606 
1607   RankedTensorType output_type =
1608       tf_less_op.getResult().getType().dyn_cast<RankedTensorType>();
1609   // Not a ranked tensor output
1610   if (!output_type) return failure();
1611 
1612   // less(x, y) is not(greater_equal(x, y))
1613   auto greater_equal_op = rewriter.create<tosa::GreaterEqualOp>(
1614       op->getLoc(), output_type, tf_less_op.x(), tf_less_op.y());
1615 
1616   auto not_op = rewriter.create<tosa::LogicalNotOp>(
1617       op->getLoc(), output_type, greater_equal_op.getResult());
1618 
1619   rewriter.replaceOp(op, {not_op.getResult()});
1620   return success();
1621 }
1622 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1623 LogicalResult ConvertTFLessEqualOp::matchAndRewrite(
1624     Operation* op, PatternRewriter& rewriter) const {
1625   auto tf_less_equal_op = cast<TF::LessEqualOp>(op);
1626 
1627   RankedTensorType output_type =
1628       tf_less_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
1629   // Not a ranked tensor output
1630   if (!output_type) return failure();
1631 
1632   // less_equal(x, y) is not(greater(x, y))
1633   auto greater_op = rewriter.create<tosa::GreaterOp>(
1634       op->getLoc(), output_type, tf_less_equal_op.x(), tf_less_equal_op.y());
1635 
1636   auto not_op = rewriter.create<tosa::LogicalNotOp>(op->getLoc(), output_type,
1637                                                     greater_op.getResult());
1638 
1639   rewriter.replaceOp(op, {not_op.getResult()});
1640   return success();
1641 }
1642 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1643 LogicalResult ConvertTFPadOp::matchAndRewrite(Operation* op,
1644                                               PatternRewriter& rewriter) const {
1645   auto tf_pad_op = cast<TF::PadOp>(op);
1646 
1647   RankedTensorType output_type =
1648       tf_pad_op.getResult().getType().dyn_cast<RankedTensorType>();
1649   // Not a ranked tensor output
1650   if (!output_type) return failure();
1651 
1652   auto pad_op = rewriter.create<tosa::PadOp>(
1653       op->getLoc(), output_type, tf_pad_op.input(), tf_pad_op.paddings());
1654 
1655   rewriter.replaceOp(op, {pad_op.getResult()});
1656   return success();
1657 }
1658 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1659 LogicalResult ConvertTFResizeBilinearOp::matchAndRewrite(
1660     Operation* op, PatternRewriter& rewriter) const {
1661   auto tf_resize_op = cast<TF::ResizeBilinearOp>(op);
1662 
1663   RankedTensorType output_type =
1664       tf_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
1665   // Not a ranked tensor output
1666   if (!output_type) return failure();
1667 
1668   llvm::Optional<Value> result = convertResizeOp(
1669       rewriter, op, output_type, tf_resize_op.images(), StringRef("BILINEAR"));
1670 
1671   if (!result) return failure();
1672 
1673   rewriter.replaceOp(op, {result.getValue()});
1674 
1675   return success();
1676 }
1677 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1678 LogicalResult ConvertTFResizeNearestNeighborOp::matchAndRewrite(
1679     Operation* op, PatternRewriter& rewriter) const {
1680   auto tf_resize_op = cast<TF::ResizeNearestNeighborOp>(op);
1681 
1682   RankedTensorType output_type =
1683       tf_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
1684   // Not a ranked tensor output
1685   if (!output_type) return failure();
1686 
1687   llvm::Optional<Value> result = convertResizeOp(
1688       rewriter, op, output_type, tf_resize_op.images(), StringRef("NEAREST"));
1689 
1690   if (!result) return failure();
1691 
1692   rewriter.replaceOp(op, {result.getValue()});
1693 
1694   return success();
1695 }
1696 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1697 LogicalResult ConvertTFMatMulOp::matchAndRewrite(
1698     Operation* op, PatternRewriter& rewriter) const {
1699   auto tf_matmul_op = cast<TF::MatMulOp>(op);
1700 
1701   RankedTensorType a_type =
1702       tf_matmul_op.a().getType().dyn_cast<RankedTensorType>();
1703   RankedTensorType b_type =
1704       tf_matmul_op.b().getType().dyn_cast<RankedTensorType>();
1705   RankedTensorType output_type =
1706       tf_matmul_op.getResult().getType().dyn_cast<RankedTensorType>();
1707 
1708   if (!(a_type && b_type && output_type)) {
1709     return op->emitOpError("MatMul: a/b/output not ranked tensors");
1710   }
1711 
1712   // Can only handle rank=2 inputs
1713   if (a_type.getShape().size() != 2) return failure();
1714 
1715   rewriter.replaceOpWithNewOp<tosa::MatMulOp>(op, output_type, tf_matmul_op.a(),
1716                                               tf_matmul_op.b());
1717 
1718   return success();
1719 }
1720 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1721 LogicalResult ConvertTFGatherOp::matchAndRewrite(
1722     Operation* op, PatternRewriter& rewriter) const {
1723   auto tf_gather_op = cast<TF::GatherOp>(op);
1724 
1725   RankedTensorType output_type =
1726       tf_gather_op.getResult().getType().dyn_cast<RankedTensorType>();
1727   if (!output_type) return failure();
1728 
1729   IntegerAttr axis_attr = rewriter.getI32IntegerAttr(0);
1730 
1731   // TODO: batchdim_attr handling to be implemented with a revised
1732   // defintion of the TOSA operator.
1733   rewriter.replaceOpWithNewOp<tosa::GatherOp>(
1734       op, output_type, tf_gather_op.params(), tf_gather_op.indices(),
1735       axis_attr);
1736 
1737   return success();
1738 }
1739 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1740 LogicalResult ConvertTFGatherV2Op::matchAndRewrite(
1741     Operation* op, PatternRewriter& rewriter) const {
1742   auto tf_gather_op = cast<TF::GatherV2Op>(op);
1743 
1744   RankedTensorType output_type =
1745       tf_gather_op.getResult().getType().dyn_cast<RankedTensorType>();
1746   if (!output_type) return failure();
1747 
1748   // Axis is a tensor in TF. Convert to I64Attr for TOSA
1749   ElementsAttr axis_elem;
1750   if (!matchPattern(tf_gather_op.axis(), m_Constant(&axis_elem)))
1751     return failure();
1752   assert(axis_elem.getType().getRank() == 0 && "expected 0D tensor");
1753 
1754   IntegerAttr batchdim_attr;
1755   {
1756     auto tmpAttr = tf_gather_op.batch_dimsAttr();
1757     if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
1758     batchdim_attr = tmpAttr;
1759   }
1760 
1761   // TODO: batchdim_attr handling to be implemented with a revised
1762   // defintion of the TOSA operator.
1763   rewriter.replaceOpWithNewOp<tosa::GatherOp>(
1764       op, output_type, tf_gather_op.params(), tf_gather_op.indices(),
1765       rewriter.getI32IntegerAttr(axis_elem.getValue<IntegerAttr>({}).getInt()));
1766 
1767   return success();
1768 }
1769 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1770 LogicalResult ConvertTFSelectV2Op::matchAndRewrite(
1771     Operation* op, PatternRewriter& rewriter) const {
1772   auto tf_sel_op = cast<TF::SelectV2Op>(op);
1773 
1774   llvm::Optional<Value> result =
1775       convertSelectOp(rewriter, op, tf_sel_op.getResult(),
1776                       tf_sel_op.condition(), tf_sel_op.t(), tf_sel_op.e());
1777 
1778   if (!result) return failure();
1779 
1780   rewriter.replaceOp(op, {result.getValue()});
1781 
1782   return success();
1783 }
1784 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1785 LogicalResult ConvertTFSpaceToDepthOp::matchAndRewrite(
1786     Operation* op, PatternRewriter& rewriter) const {
1787   auto tf_s2d_op = cast<TF::SpaceToDepthOp>(op);
1788 
1789   llvm::Optional<Value> result = convertSpaceToDepthOp(
1790       rewriter, op, tf_s2d_op.getResult(), tf_s2d_op.input(),
1791       tf_s2d_op.block_sizeAttr(), tf_s2d_op.data_formatAttr());
1792 
1793   if (!result) return failure();
1794 
1795   rewriter.replaceOp(op, {result.getValue()});
1796 
1797   return success();
1798 }
1799 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1800 LogicalResult ConvertTFDepthToSpaceOp::matchAndRewrite(
1801     Operation* op, PatternRewriter& rewriter) const {
1802   auto tf_d2s_op = cast<TF::DepthToSpaceOp>(op);
1803 
1804   llvm::Optional<Value> result = convertDepthToSpaceOp(
1805       rewriter, op, tf_d2s_op.getResult(), tf_d2s_op.input(),
1806       tf_d2s_op.block_sizeAttr(), tf_d2s_op.data_formatAttr());
1807 
1808   if (!result) return failure();
1809 
1810   rewriter.replaceOp(op, {result.getValue()});
1811 
1812   return success();
1813 }
1814 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1815 LogicalResult ConvertTFSpaceToBatchNDOp::matchAndRewrite(
1816     Operation* op, PatternRewriter& rewriter) const {
1817   auto tf_s2b_op = cast<TF::SpaceToBatchNDOp>(op);
1818 
1819   llvm::Optional<Value> result = convertSpaceToBatchNDOp(
1820       rewriter, op, tf_s2b_op.getResult(), tf_s2b_op.input(),
1821       tf_s2b_op.block_shape(), tf_s2b_op.paddings());
1822   if (!result) return failure();
1823 
1824   rewriter.replaceOp(op, {result.getValue()});
1825 
1826   return success();
1827 }
1828 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1829 LogicalResult ConvertTFBatchToSpaceNDOp::matchAndRewrite(
1830     Operation* op, PatternRewriter& rewriter) const {
1831   auto tf_b2s_op = cast<TF::BatchToSpaceNDOp>(op);
1832 
1833   llvm::Optional<Value> result = convertBatchToSpaceNDOp(
1834       rewriter, op, tf_b2s_op.getResult(), tf_b2s_op.input(),
1835       tf_b2s_op.block_shape(), tf_b2s_op.crops());
1836 
1837   if (!result) return failure();
1838 
1839   rewriter.replaceOp(op, {result.getValue()});
1840 
1841   return success();
1842 }
1843 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1844 LogicalResult ConvertTFStridedSliceOp::matchAndRewrite(
1845     Operation* op, PatternRewriter& rewriter) const {
1846   auto tf_ss_op = cast<TF::StridedSliceOp>(op);
1847 
1848   llvm::Optional<Value> result = convertStridedSliceOp(
1849       rewriter, op, tf_ss_op.getResult(), tf_ss_op.input(), tf_ss_op.begin(),
1850       tf_ss_op.end(), tf_ss_op.strides(), tf_ss_op.begin_maskAttr().getInt(),
1851       tf_ss_op.end_maskAttr().getInt(), tf_ss_op.ellipsis_maskAttr().getInt(),
1852       tf_ss_op.new_axis_maskAttr().getInt(),
1853       tf_ss_op.shrink_axis_maskAttr().getInt());
1854 
1855   if (!result) return failure();
1856 
1857   rewriter.replaceOp(op, {result.getValue()});
1858 
1859   return success();
1860 }
1861 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1862 LogicalResult ConvertTFZerosLikeOp::matchAndRewrite(
1863     Operation* op, PatternRewriter& rewriter) const {
1864   auto tf_zeroslike_op = cast<TF::ZerosLikeOp>(op);
1865 
1866   llvm::Optional<Value> result = convertZerosLikeOp(
1867       rewriter, op, tf_zeroslike_op.getResult(), tf_zeroslike_op.x());
1868 
1869   if (!result) return failure();
1870 
1871   rewriter.replaceOp(op, {result.getValue()});
1872 
1873   return success();
1874 }
1875 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1876 LogicalResult ConvertTFSigmoidOp::matchAndRewrite(
1877     Operation* op, PatternRewriter& rewriter) const {
1878   auto tf_sigmoid_op = cast<TF::SigmoidOp>(op);
1879   RankedTensorType output_type =
1880       tf_sigmoid_op.getResult().getType().dyn_cast<RankedTensorType>();
1881   if (!output_type) return failure();
1882 
1883   rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(op, output_type,
1884                                                tf_sigmoid_op.x());
1885 
1886   return success();
1887 }
1888 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1889 LogicalResult ConvertTFTanhOp::matchAndRewrite(
1890     Operation* op, PatternRewriter& rewriter) const {
1891   auto tf_tanh_op = cast<TF::TanhOp>(op);
1892   RankedTensorType output_type =
1893       tf_tanh_op.getResult().getType().dyn_cast<RankedTensorType>();
1894   if (!output_type) return failure();
1895 
1896   rewriter.replaceOpWithNewOp<tosa::TanhOp>(op, output_type, tf_tanh_op.x());
1897 
1898   return success();
1899 }
1900 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1901 LogicalResult ConvertTFLeakyReluOp::matchAndRewrite(
1902     Operation* op, PatternRewriter& rewriter) const {
1903   auto tf_leakyrelu_op = cast<TF::LeakyReluOp>(op);
1904   RankedTensorType output_type =
1905       tf_leakyrelu_op.getResult().getType().dyn_cast<RankedTensorType>();
1906   if (!output_type) return failure();
1907 
1908   // TODO: add lowering with MUL + SELECT
1909 
1910   return failure();
1911 }
1912 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1913 LogicalResult ConvertTFNegOp::matchAndRewrite(Operation* op,
1914                                               PatternRewriter& rewriter) const {
1915   auto tf_neg_op = cast<TF::NegOp>(op);
1916   RankedTensorType output_type =
1917       tf_neg_op.getResult().getType().dyn_cast<RankedTensorType>();
1918   if (!output_type) return failure();
1919 
1920   rewriter.replaceOpWithNewOp<tosa::NegateOp>(op, output_type, tf_neg_op.x());
1921 
1922   return success();
1923 }
1924 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1925 LogicalResult ConvertTFStopGradientOp::matchAndRewrite(
1926     Operation* op, PatternRewriter& rewriter) const {
1927   auto tf_stopgrad_op = cast<TF::StopGradientOp>(op);
1928   RankedTensorType output_type =
1929       tf_stopgrad_op.getResult().getType().dyn_cast<RankedTensorType>();
1930   if (!output_type) return failure();
1931 
1932   rewriter.replaceOpWithNewOp<tosa::IdentityOp>(op, output_type,
1933                                                 tf_stopgrad_op.input());
1934 
1935   return success();
1936 }
1937 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1938 LogicalResult ConvertTFReverseV2Op::matchAndRewrite(
1939     Operation* op, PatternRewriter& rewriter) const {
1940   auto tf_reverse_op = cast<TF::ReverseV2Op>(op);
1941   RankedTensorType input_type =
1942       tf_reverse_op.tensor().getType().dyn_cast<RankedTensorType>();
1943   RankedTensorType output_type =
1944       tf_reverse_op.getResult().getType().dyn_cast<RankedTensorType>();
1945   if (!input_type || !output_type) return failure();
1946 
1947   ElementsAttr axis_elems;
1948   if (!matchPattern(tf_reverse_op.axis(), m_Constant(&axis_elems)))
1949     return failure();
1950 
1951   auto input_rank = input_type.getShape().size();
1952   Value val = tf_reverse_op.tensor();
1953   if (axis_elems.getNumElements() == 0) {
1954     auto identity_op =
1955         rewriter.create<tosa::IdentityOp>(op->getLoc(), output_type, val);
1956     val = identity_op.getResult();
1957   } else {
1958     for (int i = 0; i < axis_elems.getNumElements(); i++) {
1959       int64_t axis_val = axis_elems.getValue<IntegerAttr>(i).getInt();
1960       if (axis_val < 0) axis_val += input_rank;
1961       auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
1962       auto reverse_op = rewriter.create<tosa::ReverseOp>(
1963           op->getLoc(), output_type, val, axis_attr);
1964 
1965       val = reverse_op.getResult();
1966     }
1967   }
1968 
1969   rewriter.replaceOp(op, {val});
1970 
1971   return success();
1972 }
1973 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1974 LogicalResult ConvertTFFakeQuantWithMinMaxArgsOp::matchAndRewrite(
1975     Operation* op, PatternRewriter& rewriter) const {
1976   auto tf_fakequant_op = cast<TF::FakeQuantWithMinMaxArgsOp>(op);
1977 
1978   RankedTensorType output_type =
1979       tf_fakequant_op.getResult().getType().dyn_cast<RankedTensorType>();
1980   // Not a ranked tensor output
1981   if (!output_type) return failure();
1982 
1983   llvm::Optional<Value> result =
1984       convertFakeQuantOp(rewriter, op, output_type, tf_fakequant_op.inputs(),
1985                          tf_fakequant_op.minAttr().getValueAsDouble(),
1986                          tf_fakequant_op.maxAttr().getValueAsDouble(),
1987                          tf_fakequant_op.num_bitsAttr().getInt(),
1988                          tf_fakequant_op.narrow_rangeAttr().getValue());
1989 
1990   if (!result) return failure();
1991 
1992   rewriter.replaceOp(op, {result.getValue()});
1993 
1994   return success();
1995 }
1996 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1997 LogicalResult ConvertTFFakeQuantWithMinMaxVarsOp::matchAndRewrite(
1998     Operation* op, PatternRewriter& rewriter) const {
1999   auto tf_fakequant_op = cast<TF::FakeQuantWithMinMaxVarsOp>(op);
2000 
2001   RankedTensorType output_type =
2002       tf_fakequant_op.getResult().getType().dyn_cast<RankedTensorType>();
2003   // Not a ranked tensor output
2004   if (!output_type) return failure();
2005 
2006   // Only support min/max that can be matched at compile time
2007   ElementsAttr min_elems, max_elems;
2008   if (!matchPattern(tf_fakequant_op.min(), m_Constant(&min_elems)))
2009     return failure();
2010 
2011   if (!matchPattern(tf_fakequant_op.max(), m_Constant(&max_elems)))
2012     return failure();
2013 
2014   if (min_elems.getNumElements() != 1 && max_elems.getNumElements() != 1)
2015     return failure();
2016 
2017   int64_t min_val = min_elems.getValue<IntegerAttr>(0).getInt();
2018   int64_t max_val = max_elems.getValue<IntegerAttr>(0).getInt();
2019 
2020   llvm::Optional<Value> result = convertFakeQuantOp(
2021       rewriter, op, output_type, tf_fakequant_op.inputs(), min_val, max_val,
2022       tf_fakequant_op.num_bitsAttr().getInt(),
2023       tf_fakequant_op.narrow_rangeAttr().getValue());
2024 
2025   if (!result) return failure();
2026 
2027   rewriter.replaceOp(op, {result.getValue()});
2028 
2029   return success();
2030 }
2031 
runOnFunction()2032 void LegalizeTF::runOnFunction() {
2033   OwningRewritePatternList patterns;
2034   auto* ctx = &getContext();
2035   auto func = getFunction();
2036 
2037   // Add the generated patterns to the list.
2038   populateWithGenerated(ctx, patterns);
2039   patterns.insert<ConvertTFMatMulOp>(ctx);
2040   patterns.insert<ConvertTFReluOp>(ctx);
2041   patterns.insert<ConvertTFRelu6Op>(ctx);
2042   patterns.insert<ConvertTFEqualOp>(ctx);
2043   patterns.insert<ConvertTFNotEqualOp>(ctx);
2044   patterns.insert<ConvertTFGreaterOp>(ctx);
2045   patterns.insert<ConvertTFGreaterEqualOp>(ctx);
2046   patterns.insert<ConvertTFAddOp>(ctx);
2047   patterns.insert<ConvertTFAddV2Op>(ctx);
2048   patterns.insert<ConvertTFAddNOp>(ctx);
2049   patterns.insert<ConvertTFSubOp>(ctx);
2050   patterns.insert<ConvertTFMulOp>(ctx);
2051   patterns.insert<ConvertTFSquareOp>(ctx);
2052   patterns.insert<ConvertTFSquaredDifferenceOp>(ctx);
2053   patterns.insert<ConvertTFRoundOp>(ctx);
2054   patterns.insert<ConvertTFFloorDivOp>(ctx);
2055   patterns.insert<ConvertTFFloorModOp>(ctx);
2056   patterns.insert<ConvertTFAssertOp>(ctx);
2057   patterns.insert<ConvertTFMaximumOp>(ctx);
2058   patterns.insert<ConvertTFMinimumOp>(ctx);
2059   patterns.insert<ConvertTFRealDivOp>(ctx);
2060   patterns.insert<ConvertTFArgMaxOp>(ctx);
2061   patterns.insert<ConvertTFAvgPoolOp>(ctx);
2062   patterns.insert<ConvertTFMaxPoolOp>(ctx);
2063   patterns.insert<ConvertTFConcatV2Op>(ctx);
2064   patterns.insert<ConvertTFReshapeOp>(ctx);
2065   patterns.insert<ConvertTFRankOp>(ctx);
2066   patterns.insert<ConvertTFShapeOp>(ctx);
2067   patterns.insert<ConvertTFExpandDimsOp>(ctx);
2068   patterns.insert<ConvertTFSqueezeOp>(ctx);
2069   patterns.insert<ConvertTFFillOp>(ctx);
2070   patterns.insert<ConvertTFConv2DOp>(ctx);
2071   patterns.insert<ConvertTFDepthwiseConv2dNativeOp>(ctx);
2072   patterns.insert<ConvertTFConv2DBackpropInputOp>(ctx);
2073   patterns.insert<ConvertTFEluOp>(ctx);
2074   patterns.insert<ConvertTFSoftmaxOp>(ctx);
2075   patterns.insert<ConvertTFLogSoftmaxOp>(ctx);
2076   patterns.insert<ConvertTFAllOp>(ctx);
2077   patterns.insert<ConvertTFAnyOp>(ctx);
2078   patterns.insert<ConvertTFMaxOp>(ctx);
2079   patterns.insert<ConvertTFMinOp>(ctx);
2080   patterns.insert<ConvertTFMeanOp>(ctx);
2081   patterns.insert<ConvertTFProdOp>(ctx);
2082   patterns.insert<ConvertTFSumOp>(ctx);
2083   patterns.insert<ConvertTFFusedBatchNormOp>(ctx);
2084   patterns.insert<ConvertTFFusedBatchNormV3Op>(ctx);
2085   patterns.insert<ConvertTFBiasAddOp>(ctx);
2086   patterns.insert<ConvertTFSplitOp>(ctx);
2087   patterns.insert<ConvertTFSplitVOp>(ctx);
2088   patterns.insert<ConvertTFPackOp>(ctx);
2089   patterns.insert<ConvertTFUnpackOp>(ctx);
2090   patterns.insert<ConvertTFTransposeOp>(ctx);
2091   patterns.insert<ConvertTFTileOp>(ctx);
2092   patterns.insert<ConvertTFSliceOp>(ctx);
2093   patterns.insert<ConvertTFStridedSliceOp>(ctx);
2094   patterns.insert<ConvertTFLessOp>(ctx);
2095   patterns.insert<ConvertTFLessEqualOp>(ctx);
2096   patterns.insert<ConvertTFPadOp>(ctx);
2097   patterns.insert<ConvertTFResizeBilinearOp>(ctx);
2098   patterns.insert<ConvertTFResizeNearestNeighborOp>(ctx);
2099   patterns.insert<ConvertTFGatherOp>(ctx);
2100   patterns.insert<ConvertTFGatherV2Op>(ctx);
2101   patterns.insert<ConvertTFSelectV2Op>(ctx);
2102   patterns.insert<ConvertTFSpaceToDepthOp>(ctx);
2103   patterns.insert<ConvertTFDepthToSpaceOp>(ctx);
2104   patterns.insert<ConvertTFSpaceToBatchNDOp>(ctx);
2105   patterns.insert<ConvertTFBatchToSpaceNDOp>(ctx);
2106   patterns.insert<ConvertTFZerosLikeOp>(ctx);
2107   patterns.insert<ConvertTFSigmoidOp>(ctx);
2108   patterns.insert<ConvertTFTanhOp>(ctx);
2109   patterns.insert<ConvertTFLeakyReluOp>(ctx);
2110   patterns.insert<ConvertTFNegOp>(ctx);
2111   patterns.insert<ConvertTFStopGradientOp>(ctx);
2112   patterns.insert<ConvertTFReverseV2Op>(ctx);
2113   patterns.insert<ConvertTFFakeQuantWithMinMaxArgsOp>(ctx);
2114   patterns.insert<ConvertTFFakeQuantWithMinMaxVarsOp>(ctx);
2115   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
2116 }
2117 
2118 }  // anonymous namespace
2119 
2120 // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
createLegalizeTFPass()2121 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass() {
2122   return std::make_unique<LegalizeTF>();
2123 }
2124 
2125 static PassRegistration<LegalizeTF> pass(
2126     PASS_NAME, "Legalize from TensorFlow to TOSA dialect");
2127 
2128 }  // namespace tosa
2129 
2130 }  // namespace mlir
2131