1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Legalize TensorFlow to TOSA
17
#include <climits>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <limits>
#include <numeric>
23
24 #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
25 #include "mlir/Support/LLVM.h" // from @llvm-project
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
28 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
29 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
30 #include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
31
32 #define PASS_NAME "tosa-legalize-tf"
33 #define DEBUG_TYPE PASS_NAME
34
35 namespace mlir {
36
37 namespace tosa {
38
39 namespace {
// Performs lowering to TOSA dialect.
//
// Function-level pass wrapper; the actual pattern application happens in
// runOnFunction() (defined later in this file), which applies the rewrite
// patterns declared below.
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
 public:
  explicit LegalizeTF() {}
  void runOnFunction() override;
};
46
47 // All the Pat<> lowering mappings.
48 #include "tensorflow/compiler/mlir/tosa/transforms/tf_legalize_patterns.inc"
49
// Declares one RewritePattern subclass, ConvertTF<op>Op, per TF op. Each
// pattern matches by operation name with benefit 1; the matchAndRewrite
// bodies are implemented individually below this list.
#define DECL_CONVERT_OP(tf_op)                                               \
  struct ConvertTF##tf_op##Op : public RewritePattern {                      \
    explicit ConvertTF##tf_op##Op(MLIRContext* context)                      \
        : RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {}   \
    LogicalResult matchAndRewrite(Operation* op,                             \
                                  PatternRewriter& rewriter) const override; \
  }

// All the explicitly implemented complex lowerings (ops that cannot be
// handled by the simple DRR patterns in tf_legalize_patterns.inc).
DECL_CONVERT_OP(MatMul);
DECL_CONVERT_OP(Relu);
DECL_CONVERT_OP(Relu6);
DECL_CONVERT_OP(Equal);
DECL_CONVERT_OP(NotEqual);
DECL_CONVERT_OP(Greater);
DECL_CONVERT_OP(GreaterEqual);
DECL_CONVERT_OP(Add);
DECL_CONVERT_OP(AddV2);
DECL_CONVERT_OP(AddN);
DECL_CONVERT_OP(Sub);
DECL_CONVERT_OP(Mul);
DECL_CONVERT_OP(Square);
DECL_CONVERT_OP(SquaredDifference);
DECL_CONVERT_OP(Round);
DECL_CONVERT_OP(FloorDiv);
DECL_CONVERT_OP(FloorMod);
DECL_CONVERT_OP(Assert);
DECL_CONVERT_OP(Maximum);
DECL_CONVERT_OP(Minimum);
DECL_CONVERT_OP(RealDiv);
DECL_CONVERT_OP(ArgMax);
DECL_CONVERT_OP(AvgPool);
DECL_CONVERT_OP(MaxPool);
DECL_CONVERT_OP(ConcatV2);
DECL_CONVERT_OP(Reshape);
DECL_CONVERT_OP(Rank);
DECL_CONVERT_OP(Shape);
DECL_CONVERT_OP(ExpandDims);
DECL_CONVERT_OP(Squeeze);
DECL_CONVERT_OP(Fill);
DECL_CONVERT_OP(Conv2D);
DECL_CONVERT_OP(DepthwiseConv2dNative);
DECL_CONVERT_OP(Conv2DBackpropInput);
DECL_CONVERT_OP(Elu);
DECL_CONVERT_OP(Softmax);
DECL_CONVERT_OP(LogSoftmax);
DECL_CONVERT_OP(All);
DECL_CONVERT_OP(Any);
DECL_CONVERT_OP(Max);
DECL_CONVERT_OP(Min);
DECL_CONVERT_OP(Mean);
DECL_CONVERT_OP(Prod);
DECL_CONVERT_OP(Sum);
DECL_CONVERT_OP(FusedBatchNorm);
DECL_CONVERT_OP(FusedBatchNormV3);
DECL_CONVERT_OP(BiasAdd);
DECL_CONVERT_OP(Split);
DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(Pack);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(Transpose);
DECL_CONVERT_OP(Tile);
DECL_CONVERT_OP(Slice);
DECL_CONVERT_OP(StridedSlice);
DECL_CONVERT_OP(Less);
DECL_CONVERT_OP(LessEqual);
DECL_CONVERT_OP(Pad);
DECL_CONVERT_OP(ResizeBilinear);
DECL_CONVERT_OP(ResizeNearestNeighbor);
DECL_CONVERT_OP(Gather);
DECL_CONVERT_OP(GatherV2);
DECL_CONVERT_OP(SelectV2);
DECL_CONVERT_OP(SpaceToDepth);
DECL_CONVERT_OP(DepthToSpace);
DECL_CONVERT_OP(SpaceToBatchND);
DECL_CONVERT_OP(BatchToSpaceND);
DECL_CONVERT_OP(ZerosLike);
DECL_CONVERT_OP(Sigmoid);
DECL_CONVERT_OP(Tanh);
DECL_CONVERT_OP(LeakyRelu);
DECL_CONVERT_OP(Neg);
DECL_CONVERT_OP(StopGradient);
DECL_CONVERT_OP(ReverseV2);
DECL_CONVERT_OP(FakeQuantWithMinMaxArgs);
DECL_CONVERT_OP(FakeQuantWithMinMaxVars);
#undef DECL_CONVERT_OP
136
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const137 LogicalResult ConvertTFReluOp::matchAndRewrite(
138 Operation* op, PatternRewriter& rewriter) const {
139 auto tf_relu_op = cast<TF::ReluOp>(op);
140
141 RankedTensorType output_type =
142 tf_relu_op.getResult().getType().dyn_cast<RankedTensorType>();
143 // Not a ranked tensor output
144 if (!output_type) return failure();
145
146 if (output_type.getElementType().isa<mlir::FloatType>()) {
147 rewriter.replaceOpWithNewOp<tosa::ReluNOp>(
148 op, output_type, tf_relu_op.features(), rewriter.getI64IntegerAttr(0),
149 rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
150 } else {
151 rewriter.replaceOpWithNewOp<tosa::ReluNOp>(
152 op, output_type, tf_relu_op.features(),
153 rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
154 rewriter.getF32FloatAttr(0.0f));
155 }
156 return success();
157 }
158
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const159 LogicalResult ConvertTFRelu6Op::matchAndRewrite(
160 Operation* op, PatternRewriter& rewriter) const {
161 auto tf_relu6_op = cast<TF::Relu6Op>(op);
162
163 RankedTensorType output_type =
164 tf_relu6_op.getResult().getType().dyn_cast<RankedTensorType>();
165 // Not a ranked tensor output
166 if (!output_type) return failure();
167
168 if (output_type.getElementType().isa<mlir::FloatType>()) {
169 rewriter.replaceOpWithNewOp<tosa::ReluNOp>(
170 op, output_type, tf_relu6_op.features(), rewriter.getI64IntegerAttr(0),
171 rewriter.getF32FloatAttr(6.0f));
172 } else {
173 rewriter.replaceOpWithNewOp<tosa::ReluNOp>(
174 op, output_type, tf_relu6_op.features(), rewriter.getI64IntegerAttr(6),
175 rewriter.getF32FloatAttr(0.0f));
176 }
177 return success();
178 }
179
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const180 LogicalResult ConvertTFEqualOp::matchAndRewrite(
181 Operation* op, PatternRewriter& rewriter) const {
182 auto tf_equal_op = cast<TF::EqualOp>(op);
183
184 RankedTensorType output_type =
185 tf_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
186 // Not a ranked tensor output
187 if (!output_type) return failure();
188
189 rewriter.replaceOpWithNewOp<tosa::EqualOp>(op, output_type, tf_equal_op.x(),
190 tf_equal_op.y());
191 return success();
192 }
193
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const194 LogicalResult ConvertTFNotEqualOp::matchAndRewrite(
195 Operation* op, PatternRewriter& rewriter) const {
196 auto tf_not_equal_op = cast<TF::NotEqualOp>(op);
197
198 RankedTensorType output_type =
199 tf_not_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
200 // Not a ranked tensor output
201 if (!output_type) return failure();
202
203 auto op1_equal_in = rewriter.create<tosa::EqualOp>(
204 op->getLoc(), output_type, tf_not_equal_op.x(), tf_not_equal_op.y());
205
206 auto op2_not_op1 = rewriter.create<tosa::LogicalNotOp>(
207 op->getLoc(), output_type, op1_equal_in.getResult());
208
209 rewriter.replaceOp(op, {op2_not_op1.getResult()});
210
211 return success();
212 }
213
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const214 LogicalResult ConvertTFGreaterOp::matchAndRewrite(
215 Operation* op, PatternRewriter& rewriter) const {
216 auto tf_greater_op = cast<TF::GreaterOp>(op);
217
218 RankedTensorType output_type =
219 tf_greater_op.getResult().getType().dyn_cast<RankedTensorType>();
220 // Not a ranked tensor output
221 if (!output_type) return failure();
222
223 rewriter.replaceOpWithNewOp<tosa::GreaterOp>(
224 op, output_type, tf_greater_op.x(), tf_greater_op.y());
225 return success();
226 }
227
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const228 LogicalResult ConvertTFGreaterEqualOp::matchAndRewrite(
229 Operation* op, PatternRewriter& rewriter) const {
230 auto tf_greater_equal_op = cast<TF::GreaterEqualOp>(op);
231
232 RankedTensorType output_type =
233 tf_greater_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
234 // Not a ranked tensor output
235 if (!output_type) return failure();
236
237 rewriter.replaceOpWithNewOp<tosa::GreaterEqualOp>(
238 op, output_type, tf_greater_equal_op.x(), tf_greater_equal_op.y());
239 return success();
240 }
241
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const242 LogicalResult ConvertTFAddOp::matchAndRewrite(Operation* op,
243 PatternRewriter& rewriter) const {
244 auto tf_add_op = cast<TF::AddOp>(op);
245
246 RankedTensorType output_type =
247 tf_add_op.getResult().getType().dyn_cast<RankedTensorType>();
248 // Not a ranked tensor output
249 if (!output_type) return failure();
250
251 rewriter.replaceOpWithNewOp<tosa::AddOp>(op, output_type, tf_add_op.x(),
252 tf_add_op.y());
253 return success();
254 }
255
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const256 LogicalResult ConvertTFAddV2Op::matchAndRewrite(
257 Operation* op, PatternRewriter& rewriter) const {
258 auto tf_addv2_op = cast<TF::AddV2Op>(op);
259
260 RankedTensorType output_type =
261 tf_addv2_op.getResult().getType().dyn_cast<RankedTensorType>();
262 // Not a ranked tensor output
263 if (!output_type) return failure();
264
265 rewriter.replaceOpWithNewOp<tosa::AddOp>(op, output_type, tf_addv2_op.x(),
266 tf_addv2_op.y());
267 return success();
268 }
269
270 // AddN is commutative
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const271 LogicalResult ConvertTFAddNOp::matchAndRewrite(
272 Operation* op, PatternRewriter& rewriter) const {
273 auto tf_addn_op = cast<TF::AddNOp>(op);
274
275 RankedTensorType output_type =
276 tf_addn_op.getResult().getType().dyn_cast<RankedTensorType>();
277 // Not a ranked tensor output
278 if (!output_type) return failure();
279
280 SmallVector<Value, 8> inputs(tf_addn_op.inputs());
281
282 assert(inputs.size() >= 2);
283
284 auto newOp = rewriter.create<tosa::AddOp>(op->getLoc(), output_type,
285 inputs[0], inputs[1]);
286 for (int i = 2; i < inputs.size(); i++) {
287 newOp = rewriter.create<tosa::AddOp>(op->getLoc(), output_type, inputs[i],
288 newOp.getResult());
289 }
290
291 rewriter.replaceOp(op, {newOp.getResult()});
292
293 return success();
294 }
295
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const296 LogicalResult ConvertTFSubOp::matchAndRewrite(Operation* op,
297 PatternRewriter& rewriter) const {
298 auto tf_sub_op = cast<TF::SubOp>(op);
299
300 RankedTensorType output_type =
301 tf_sub_op.getResult().getType().dyn_cast<RankedTensorType>();
302 // Not a ranked tensor output
303 if (!output_type) return failure();
304
305 rewriter.replaceOpWithNewOp<tosa::SubOp>(op, output_type, tf_sub_op.x(),
306 tf_sub_op.y());
307 return success();
308 }
309
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const310 LogicalResult ConvertTFMulOp::matchAndRewrite(Operation* op,
311 PatternRewriter& rewriter) const {
312 auto tf_mul_op = cast<TF::MulOp>(op);
313
314 llvm::Optional<Value> result = convertMultiplyOp(
315 rewriter, op, tf_mul_op.getResult(), tf_mul_op.x(), tf_mul_op.y());
316
317 if (!result) return failure();
318
319 rewriter.replaceOp(op, {result.getValue()});
320 return success();
321 }
322
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const323 LogicalResult ConvertTFSquareOp::matchAndRewrite(
324 Operation* op, PatternRewriter& rewriter) const {
325 auto tf_square_op = cast<TF::SquareOp>(op);
326
327 llvm::Optional<Value> result =
328 convertMultiplyOp(rewriter, op, tf_square_op.getResult(),
329 tf_square_op.x(), tf_square_op.x());
330
331 if (!result) return failure();
332
333 rewriter.replaceOp(op, {result.getValue()});
334 return success();
335 }
336
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const337 LogicalResult ConvertTFSquaredDifferenceOp::matchAndRewrite(
338 Operation* op, PatternRewriter& rewriter) const {
339 auto tf_squared_op = cast<TF::SquaredDifferenceOp>(op);
340
341 llvm::Optional<Value> result =
342 convertSquaredDifferenceOp(rewriter, op, tf_squared_op.getResult(),
343 tf_squared_op.x(), tf_squared_op.y());
344
345 if (!result) return failure();
346
347 rewriter.replaceOp(op, {result.getValue()});
348 return success();
349 }
350
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const351 LogicalResult ConvertTFRoundOp::matchAndRewrite(
352 Operation* op, PatternRewriter& rewriter) const {
353 auto tf_round_op = cast<TF::RoundOp>(op);
354
355 RankedTensorType input_type =
356 tf_round_op.x().getType().dyn_cast<RankedTensorType>();
357 if (!input_type) {
358 return op->emitOpError("Round: input not ranked tensor type");
359 }
360
361 if (input_type.getElementType().isa<FloatType>()) {
362 llvm::Optional<Value> result =
363 convertRoundOp(rewriter, op, tf_round_op.getResult(), tf_round_op.x());
364
365 if (!result) return failure();
366
367 rewriter.replaceOp(op, {result.getValue()});
368 return success();
369
370 } else {
371 tf_round_op.replaceAllUsesWith(tf_round_op.x());
372 return success();
373 }
374 }
375
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const376 LogicalResult ConvertTFFloorDivOp::matchAndRewrite(
377 Operation* op, PatternRewriter& rewriter) const {
378 auto tf_floordiv_op = cast<TF::FloorDivOp>(op);
379
380 llvm::Optional<Value> result =
381 convertFloorDivOp(rewriter, op, tf_floordiv_op.getResult(),
382 tf_floordiv_op.x(), tf_floordiv_op.y());
383
384 if (!result) return failure();
385
386 rewriter.replaceOp(op, {result.getValue()});
387
388 return success();
389 }
390
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const391 LogicalResult ConvertTFFloorModOp::matchAndRewrite(
392 Operation* op, PatternRewriter& rewriter) const {
393 auto tf_floormod_op = cast<TF::FloorModOp>(op);
394
395 llvm::Optional<Value> result =
396 convertFloorModOp(rewriter, op, tf_floormod_op.getResult(),
397 tf_floormod_op.x(), tf_floormod_op.y());
398
399 if (!result) return failure();
400
401 rewriter.replaceOp(op, {result.getValue()});
402
403 return success();
404 }
405
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const406 LogicalResult ConvertTFAssertOp::matchAndRewrite(
407 Operation* op, PatternRewriter& rewriter) const {
408 op->dropAllReferences();
409 op->erase();
410 return success();
411 }
412
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const413 LogicalResult ConvertTFMaximumOp::matchAndRewrite(
414 Operation* op, PatternRewriter& rewriter) const {
415 auto tf_maximum_op = cast<TF::MaximumOp>(op);
416
417 RankedTensorType output_type =
418 tf_maximum_op.getResult().getType().dyn_cast<RankedTensorType>();
419 // Not a ranked tensor output
420 if (!output_type) return failure();
421
422 rewriter.replaceOpWithNewOp<tosa::MaximumOp>(
423 op, output_type, tf_maximum_op.x(), tf_maximum_op.y());
424 return success();
425 }
426
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const427 LogicalResult ConvertTFMinimumOp::matchAndRewrite(
428 Operation* op, PatternRewriter& rewriter) const {
429 auto tf_minimum_op = cast<TF::MinimumOp>(op);
430
431 RankedTensorType output_type =
432 tf_minimum_op.getResult().getType().dyn_cast<RankedTensorType>();
433 // Not a ranked tensor output
434 if (!output_type) return failure();
435
436 rewriter.replaceOpWithNewOp<tosa::MinimumOp>(
437 op, output_type, tf_minimum_op.x(), tf_minimum_op.y());
438 return success();
439 }
440
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const441 LogicalResult ConvertTFRealDivOp::matchAndRewrite(
442 Operation* op, PatternRewriter& rewriter) const {
443 auto tf_div_op = cast<TF::RealDivOp>(op);
444
445 RankedTensorType y_type =
446 tf_div_op.y().getType().dyn_cast<RankedTensorType>();
447 RankedTensorType output_type =
448 tf_div_op.getResult().getType().dyn_cast<RankedTensorType>();
449 // Not a ranked tensor output
450 if (!output_type || !y_type) return failure();
451
452 auto reciprocal_op =
453 rewriter.create<tosa::ReciprocalOp>(op->getLoc(), y_type, tf_div_op.y());
454
455 auto mul_op = rewriter.create<tosa::MulOp>(
456 op->getLoc(), output_type, tf_div_op.x(), reciprocal_op.getResult(), 0);
457 rewriter.replaceOp(op, {mul_op.getResult()});
458
459 return success();
460 }
461
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const462 LogicalResult ConvertTFArgMaxOp::matchAndRewrite(
463 Operation* op, PatternRewriter& rewriter) const {
464 auto tf_argmax_op = cast<TF::ArgMaxOp>(op);
465
466 RankedTensorType input_type =
467 tf_argmax_op.input().getType().dyn_cast<RankedTensorType>();
468 RankedTensorType output_type =
469 tf_argmax_op.getResult().getType().dyn_cast<RankedTensorType>();
470 // Not a ranked tensor output
471 if (!output_type || !input_type) return failure();
472
473 ElementsAttr axis_elems;
474 if (!matchPattern(tf_argmax_op.dimension(), m_Constant(&axis_elems)))
475 return failure();
476
477 int32_t axis = axis_elems.getValue<IntegerAttr>({}).getInt();
478 if (axis < 0) {
479 axis += input_type.getRank();
480 }
481
482 if (axis < 0 || axis >= input_type.getRank()) {
483 return op->emitOpError("TFArgMax: invalid axis value");
484 }
485
486 IntegerAttr axis_attr = rewriter.getI64IntegerAttr(axis);
487
488 rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, output_type,
489 tf_argmax_op.input(), axis_attr);
490
491 return success();
492 }
// Lowers TF AvgPool to tosa.AVG_POOL2D. Only NHWC layout is supported; the
// kernel, stride and padding attributes are translated from the TF op, with
// padding computed by the shared getPaddingValuesFromPadType helper.
LogicalResult ConvertTFAvgPoolOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_avgpool_op = cast<TF::AvgPoolOp>(op);

  RankedTensorType input_type =
      tf_avgpool_op.value().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tf_avgpool_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type || !output_type) return failure();

  // Bail out on any explicitly non-NHWC layout.
  auto tmpAttr = tf_avgpool_op.data_formatAttr();
  if (tmpAttr && tmpAttr.getValue().str() != "NHWC") return failure();

  ArrayAttr pad;
  ArrayAttr stride;
  ArrayAttr kernel;
  {
    auto tmpAttr = tf_avgpool_op.strides();
    if (!tmpAttr) {
      stride = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      // (indices 1 and 2 of the 4-element strides attr are H and W).
      int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
    }
  }
  {
    auto tmpAttr = tf_avgpool_op.ksize();
    if (!tmpAttr) {
      kernel = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      // (indices 1 and 2 of the 4-element ksize attr are H and W).
      int64_t kernel_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t kernel_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      kernel = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
    }
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tf_avgpool_op.padding().str(), &tf_pad).ok())
      return failure();

    ArrayAttr dilation =
        rewriter.getI64ArrayAttr({1, 1});  // Pooling has no non-unit dilation

    // Package the full ksize as a synthetic "filter" tensor type so the
    // shared conv/pool padding helper can read kernel dimensions from it.
    SmallVector<int64_t, 2> i64array;

    for (auto& elem : tf_avgpool_op.ksize()) {
      int64_t value = elem.dyn_cast<IntegerAttr>().getInt();
      i64array.emplace_back(value);
    }

    RankedTensorType filter_type = RankedTensorType::get(
        llvm::makeArrayRef<int64_t>(i64array), rewriter.getIntegerType(64));

    if (!getPaddingValuesFromPadType(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, stride, dilation, rewriter, pad))
      return failure();
  }

  rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
      op, output_type, tf_avgpool_op.value(), kernel, stride, pad);
  return success();
}
562
// Lowers TF MaxPool to tosa.MAX_POOL2D. Structured identically to the
// AvgPool lowering above: NHWC only, kernel/stride/pad translated from the
// TF attributes.
LogicalResult ConvertTFMaxPoolOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_maxpool_op = cast<TF::MaxPoolOp>(op);

  RankedTensorType input_type =
      tf_maxpool_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tf_maxpool_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type || !output_type) return failure();

  // Bail out on any explicitly non-NHWC layout.
  auto tmpAttr = tf_maxpool_op.data_formatAttr();
  if (tmpAttr && tmpAttr.getValue().str() != "NHWC") return failure();

  ArrayAttr pad;
  ArrayAttr stride;
  ArrayAttr kernel;
  {
    auto tmpAttr = tf_maxpool_op.strides();
    if (!tmpAttr) {
      stride = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      // (indices 1 and 2 of the 4-element strides attr are H and W).
      int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
    }
  }
  {
    auto tmpAttr = tf_maxpool_op.ksize();
    if (!tmpAttr) {
      kernel = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      // (indices 1 and 2 of the 4-element ksize attr are H and W).
      int64_t kernel_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t kernel_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      kernel = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
    }
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tf_maxpool_op.padding().str(), &tf_pad).ok())
      return failure();

    // Pooling has no non-unit dilation
    ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});

    // Package the full ksize as a synthetic "filter" tensor type so the
    // shared conv/pool padding helper can read kernel dimensions from it.
    SmallVector<int64_t, 4> i64array;

    for (auto& elem : tf_maxpool_op.ksize()) {
      int64_t value = elem.dyn_cast<IntegerAttr>().getInt();
      i64array.emplace_back(value);
    }

    RankedTensorType filter_type = RankedTensorType::get(
        llvm::makeArrayRef<int64_t>(i64array), rewriter.getIntegerType(64));

    if (!getPaddingValuesFromPadType(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, stride, dilation, rewriter, pad))
      return failure();
  }

  rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
      op, output_type, tf_maxpool_op.input(), kernel, stride, pad);
  return success();
}
632
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const633 LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
634 Operation* op, PatternRewriter& rewriter) const {
635 auto tf_concatv2_op = cast<TF::ConcatV2Op>(op);
636 SmallVector<Value, 8> values(tf_concatv2_op.values());
637
638 ElementsAttr axis_elems;
639 if (!matchPattern(tf_concatv2_op.axis(), m_Constant(&axis_elems)))
640 return failure();
641
642 int32_t axis = axis_elems.getValue<IntegerAttr>({}).getInt();
643
644 llvm::Optional<Value> result =
645 convertConcatV2Op(rewriter, op, tf_concatv2_op.getResult(), values, axis);
646
647 if (!result) return failure();
648
649 rewriter.replaceOp(op, {result.getValue()});
650
651 return success();
652 }
653
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const654 LogicalResult ConvertTFReshapeOp::matchAndRewrite(
655 Operation* op, PatternRewriter& rewriter) const {
656 auto tf_reshape_op = cast<TF::ReshapeOp>(op);
657
658 RankedTensorType output_type =
659 tf_reshape_op.getResult().getType().dyn_cast<RankedTensorType>();
660 // Not a ranked tensor output
661 if (!output_type) return failure();
662
663 // Regular way to match tensor as element attribute doesn't always work
664 // use output_type.getShape() which is more stable
665 SmallVector<int64_t, 8> shape_vals;
666 for (int i = 0; i < output_type.getShape().size(); i++) {
667 shape_vals.push_back(output_type.getShape()[i]);
668 }
669 ArrayAttr shape_attr = rewriter.getI64ArrayAttr(shape_vals);
670
671 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
672 op, output_type, tf_reshape_op.tensor(), shape_attr);
673 return success();
674 }
675
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const676 LogicalResult ConvertTFRankOp::matchAndRewrite(
677 Operation* op, PatternRewriter& rewriter) const {
678 auto tf_rank_op = cast<TF::RankOp>(op);
679
680 RankedTensorType input_type =
681 tf_rank_op.input().getType().dyn_cast<RankedTensorType>();
682 if (!input_type) return failure();
683
684 int32_t rank = input_type.getRank();
685
686 RankedTensorType rank_type =
687 RankedTensorType::get({1}, rewriter.getIntegerType(32));
688 auto rank_attr = DenseElementsAttr::get(rank_type, {rank});
689 auto rank_const =
690 rewriter.create<tosa::ConstOp>(op->getLoc(), rank_type, rank_attr);
691
692 rewriter.replaceOp(op, {rank_const.getResult()});
693
694 return success();
695 }
696
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const697 LogicalResult ConvertTFShapeOp::matchAndRewrite(
698 Operation* op, PatternRewriter& rewriter) const {
699 auto tf_shape_op = cast<TF::ShapeOp>(op);
700
701 RankedTensorType output_type =
702 tf_shape_op.getResult().getType().dyn_cast<RankedTensorType>();
703 // Not a ranked tensor output
704 if (!output_type) return failure();
705
706 RankedTensorType input_type =
707 tf_shape_op.input().getType().dyn_cast<RankedTensorType>();
708 if (!input_type) return failure();
709
710 auto input_shape = input_type.getShape();
711
712 SmallVector<int32_t, 8> shape_arr;
713 for (int i = 0; i < input_shape.size(); i++) {
714 shape_arr.emplace_back(input_shape[i]);
715 }
716
717 RankedTensorType shape_type = RankedTensorType::get(
718 {static_cast<int32_t>(shape_arr.size())}, rewriter.getIntegerType(32));
719 auto shape_attr = DenseElementsAttr::get(
720 shape_type, llvm::makeArrayRef<int32_t>(shape_arr));
721 auto shape_const =
722 rewriter.create<tosa::ConstOp>(op->getLoc(), shape_type, shape_attr);
723
724 rewriter.replaceOp(op, {shape_const.getResult()});
725
726 return success();
727 }
728
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const729 LogicalResult ConvertTFExpandDimsOp::matchAndRewrite(
730 Operation* op, PatternRewriter& rewriter) const {
731 auto tf_expanddims_op = cast<TF::ExpandDimsOp>(op);
732
733 llvm::Optional<Value> result =
734 convertExpandDimsOp(rewriter, op, tf_expanddims_op.getResult(),
735 tf_expanddims_op.input(), tf_expanddims_op.dim());
736
737 if (!result) return failure();
738
739 rewriter.replaceOp(op, {result.getValue()});
740
741 return success();
742 }
743
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const744 LogicalResult ConvertTFSqueezeOp::matchAndRewrite(
745 Operation* op, PatternRewriter& rewriter) const {
746 auto tf_squeeze_op = cast<TF::SqueezeOp>(op);
747
748 // Copy squeeze_dims into int32_t array
749 auto squeeze_dims_attr = tf_squeeze_op.squeeze_dimsAttr();
750 SmallVector<int32_t, 8> squeeze_dims;
751 for (auto& squeeze_dim : squeeze_dims_attr) {
752 squeeze_dims.emplace_back(squeeze_dim.dyn_cast<IntegerAttr>().getInt());
753 }
754
755 llvm::Optional<Value> result =
756 convertSqueezeOp(rewriter, op, tf_squeeze_op.getResult(),
757 tf_squeeze_op.input(), squeeze_dims);
758
759 if (!result) return failure();
760
761 rewriter.replaceOp(op, {result.getValue()});
762
763 return success();
764 }
765
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const766 LogicalResult ConvertTFFillOp::matchAndRewrite(
767 Operation* op, PatternRewriter& rewriter) const {
768 auto tf_fill_op = cast<TF::FillOp>(op);
769
770 RankedTensorType output_type =
771 tf_fill_op.getResult().getType().dyn_cast<RankedTensorType>();
772 // Not a ranked tensor output
773 if (!output_type) return failure();
774
775 ElementsAttr dims_elems;
776 if (!matchPattern(tf_fill_op.dims(), m_Constant(&dims_elems)))
777 return failure();
778 SmallVector<int64_t, 4> dims_vals;
779 uint32_t total_size = 1;
780 for (int i = 0; i < dims_elems.getNumElements(); i++) {
781 dims_vals.push_back(dims_elems.getValue<IntegerAttr>(i).getInt());
782 total_size *= dims_vals[i];
783 }
784
785 ElementsAttr value_elem;
786 if (!matchPattern(tf_fill_op.value(), m_Constant(&value_elem)))
787 return failure();
788
789 RankedTensorType fill_type = RankedTensorType::get(
790 ArrayRef<int64_t>(dims_vals), value_elem.getType().getElementType());
791 DenseElementsAttr fill_attr;
792
793 // Convert to a compatible zero type
794 if (value_elem.getType().getElementType().isa<FloatType>()) {
795 llvm::SmallVector<float, 4> fill_arr(
796 total_size,
797 value_elem.getValue<FloatAttr>(0).getValue().convertToFloat());
798 fill_attr =
799 DenseElementsAttr::get(fill_type, llvm::makeArrayRef<float>(fill_arr));
800 } else {
801 llvm::SmallVector<int32_t, 4> fill_arr(
802 total_size,
803 value_elem.getValue<IntegerAttr>(0).getValue().getLimitedValue());
804 fill_attr = DenseElementsAttr::get(fill_type,
805 llvm::makeArrayRef<int32_t>(fill_arr));
806 }
807 auto fill_const_op =
808 rewriter.create<tosa::ConstOp>(op->getLoc(), fill_type, fill_attr);
809 rewriter.replaceOp(op, {fill_const_op.getResult()});
810
811 return success();
812 }
813
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const814 LogicalResult ConvertTFConv2DOp::matchAndRewrite(
815 Operation* op, PatternRewriter& rewriter) const {
816 auto tf_conv2d_op = cast<TF::Conv2DOp>(op);
817
818 RankedTensorType filter_type =
819 tf_conv2d_op.filter().getType().dyn_cast<RankedTensorType>();
820 RankedTensorType output_type =
821 tf_conv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
822
823 // Set up a zero attr for subsequent pattern replacement if required
824 auto bias_dim = filter_type.getShape().back();
825 RankedTensorType bias_type =
826 RankedTensorType::get({bias_dim}, filter_type.getElementType());
827 auto bias_attr = rewriter.getZeroAttr(bias_type);
828 auto bias = rewriter.create<tosa::ConstOp>(op->getLoc(), bias_type,
829 bias_attr.cast<ElementsAttr>());
830
831 llvm::Optional<Value> result = convertTFConv2DCommon(
832 rewriter, op, output_type, tf_conv2d_op.input(), tf_conv2d_op.filter(),
833 bias, tf_conv2d_op.strides(), tf_conv2d_op.dilations(),
834 tf_conv2d_op.explicit_paddings(), tf_conv2d_op.padding(),
835 tf_conv2d_op.data_format());
836
837 if (!result) return failure();
838
839 rewriter.replaceOp(op, {result.getValue()});
840
841 return success();
842 }
843
// Lowers TF DepthwiseConv2dNative to tosa.DEPTHWISE_CONV2D. Only NHWC layout
// is supported; stride/dilation come from the TF attributes, padding is
// computed by the shared helpers, and an all-zero bias is synthesized since
// the TOSA op requires a bias operand.
LogicalResult ConvertTFDepthwiseConv2dNativeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_dwconv2d_op = cast<TF::DepthwiseConv2dNativeOp>(op);

  RankedTensorType input_type =
      tf_dwconv2d_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType filter_type =
      tf_dwconv2d_op.filter().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tf_dwconv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type) return failure();
  if (!output_type) return failure();

  // Set up a zero attr for subsequent pattern replacement if required
  if (!filter_type) {
    return op->emitOpError("DepthwiseConv2d: filter type unranked tensor");
  }

  // Bail out on any explicitly non-NHWC layout.
  auto tmpAttr = tf_dwconv2d_op.data_formatAttr();
  if (tmpAttr && tmpAttr.getValue().str() != "NHWC") return failure();

  ArrayAttr stride;
  ArrayAttr dilation;
  ArrayAttr pad;
  {
    auto tmpAttr = tf_dwconv2d_op.strides();
    if (!tmpAttr) {
      stride = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      // (indices 1 and 2 of the 4-element strides attr are H and W).
      int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
    }
  }
  {
    auto tmpAttr = tf_dwconv2d_op.dilations();
    if (!tmpAttr) {
      dilation = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      // (indices 1 and 2 of the 4-element dilations attr are H and W).
      int64_t dilation_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t dilation_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
    }
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tf_dwconv2d_op.padding().str(), &tf_pad).ok())
      return failure();

    tensorflow::TensorFormat data_format_tf;
    if (!FormatFromString(tf_dwconv2d_op.data_format().str(), &data_format_tf))
      return failure();

    // EXPLICIT padding carries its own values; other modes (SAME/VALID) are
    // derived from input/filter/stride/dilation by the shared helper.
    if (tf_pad == tensorflow::Padding::EXPLICIT) {
      pad = getPaddingValuesFromExplicitPadAttr(
          tf_dwconv2d_op.explicit_paddings(), data_format_tf, rewriter);
    } else {
      if (!getPaddingValuesFromPadType(tf_pad, data_format_tf,
                                       0,  // tensorflow::FORMAT_HWIO
                                       input_type, filter_type, stride,
                                       dilation, rewriter, pad))
        return failure();
    }
  }

  // Bias length is in_channels * channel_multiplier (filter dims 2 and 3 in
  // TF's HWIM depthwise filter layout).
  auto filter_shape = filter_type.getShape();
  auto bias_dim = filter_shape[2] * filter_shape[3];
  RankedTensorType bias_type =
      RankedTensorType::get({bias_dim}, filter_type.getElementType());
  auto bias_attr = rewriter.getZeroAttr(bias_type);
  auto bias = rewriter.create<tosa::ConstOp>(op->getLoc(), bias_type,
                                             bias_attr.cast<ElementsAttr>());

  rewriter.replaceOpWithNewOp<tosa::DepthwiseConv2DOp>(
      op, output_type, tf_dwconv2d_op.input(), tf_dwconv2d_op.filter(), bias,
      pad, stride, dilation);
  return success();
}
925
// Lowers TF Conv2DBackpropInput (gradient w.r.t. input, i.e. a transposed
// convolution) to tosa.transpose_conv2d.
LogicalResult ConvertTFConv2DBackpropInputOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_conv_op = cast<TF::Conv2DBackpropInputOp>(op);

  // Here "input" to the TOSA transpose-conv is TF's out_backprop operand.
  RankedTensorType input_type =
      tf_conv_op.out_backprop().getType().dyn_cast<RankedTensorType>();
  RankedTensorType filter_type =
      tf_conv_op.filter().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tf_conv_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type) return failure();
  if (!filter_type) return failure();
  if (!output_type) return failure();

  // Transpose [H, W, I, O] to [O, H, W, I]
  auto filter_shape = filter_type.getShape();
  llvm::SmallVector<int64_t, 4> a1_transpose_dims;
  a1_transpose_dims.push_back(filter_shape[2]);
  a1_transpose_dims.push_back(filter_shape[0]);
  a1_transpose_dims.push_back(filter_shape[1]);
  a1_transpose_dims.push_back(filter_shape[3]);
  Value a1_filter_transpose_perm =
      get1DConstTensor<tosa::ConstOp, int32_t>(rewriter, op, {2, 0, 1, 3});
  auto a1_filter_transpose_op = rewriter.create<tosa::TransposeOp>(
      op->getLoc(),
      RankedTensorType::get(ArrayRef<int64_t>(a1_transpose_dims),
                            filter_type.getElementType()),
      tf_conv_op.filter(), a1_filter_transpose_perm);

  ArrayAttr stride;
  ArrayAttr dilation;
  ArrayAttr outpad;
  ArrayAttr output_shape;
  {
    // Extract [H, W] strides; default 1x1 when the attribute is absent.
    auto tmpAttr = tf_conv_op.strides();
    if (!tmpAttr) {
      stride = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
    }
  }
  {
    // Extract [H, W] dilations; default 1x1 when the attribute is absent.
    auto tmpAttr = tf_conv_op.dilations();
    if (!tmpAttr) {
      dilation = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      int64_t dilation_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t dilation_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
    }
  }
  {
    // Translate TF's padding mode into transpose-conv output padding.
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tf_conv_op.padding().str(), &tf_pad).ok())
      return failure();

    tensorflow::TensorFormat data_format_tf;
    if (!FormatFromString(tf_conv_op.data_format().str(), &data_format_tf))
      return failure();

    if (tf_pad == tensorflow::Padding::EXPLICIT) {
      outpad = getPaddingValuesFromExplicitPadAttr(
          tf_conv_op.explicit_paddings(), data_format_tf, rewriter);
    } else {
      if (!getTransposeConv2dPaddingValues(tf_pad, data_format_tf,
                                           0,  // tensorflow::FORMAT_HWIO,
                                           input_type, filter_type, output_type,
                                           stride, dilation, rewriter, outpad))
        return failure();
    }
  }
  {
    ElementsAttr output_shape_elems;
    // Match from input_sizes tensor first.
    if (matchPattern(tf_conv_op.input_sizes(),
                     m_Constant(&output_shape_elems))) {
      llvm::SmallVector<int64_t, 4> shape_vec;
      for (int i = 0; i < output_shape_elems.getNumElements(); i++)
        shape_vec.push_back(
            output_shape_elems.getValue<IntegerAttr>(i).getInt());
      output_shape = rewriter.getI64ArrayAttr(shape_vec);
    } else {
      // Use output tensor's shape otherwise.
      output_shape = rewriter.getI64ArrayAttr(output_type.getShape());
    }
  }

  // TOSA requires an explicit bias; synthesize an all-zero float bias sized
  // by the last output dimension.
  // NOTE(review): assumes a 4-D NHWC output and float element type — confirm
  // callers never reach here with other formats/dtypes.
  SmallVector<float, 8> zero_bias_vec(output_type.getShape()[3], 0.0f);
  Value zero_bias =
      get1DConstTensor<tosa::ConstOp, float>(rewriter, op, zero_bias_vec);

  rewriter.replaceOpWithNewOp<tosa::TransposeConv2DOp>(
      op, output_type, tf_conv_op.out_backprop(),
      a1_filter_transpose_op.getResult(), zero_bias, outpad, stride, dilation,
      output_shape);

  return success();
}
1029
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1030 LogicalResult ConvertTFAllOp::matchAndRewrite(Operation* op,
1031 PatternRewriter& rewriter) const {
1032 auto tf_all_op = cast<TF::AllOp>(op);
1033
1034 RankedTensorType output_type =
1035 tf_all_op.getResult().getType().dyn_cast<RankedTensorType>();
1036 if (!output_type) return failure();
1037
1038 ElementsAttr axes_elems;
1039 if (!matchPattern(tf_all_op.reduction_indices(), m_Constant(&axes_elems)))
1040 return failure();
1041
1042 bool keep_dims = false;
1043 auto keep_dims_attr = tf_all_op.keep_dimsAttr();
1044 if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1045
1046 llvm::Optional<Value> result = convertReduceAllOp(
1047 rewriter, op, output_type, tf_all_op.input(), axes_elems, keep_dims);
1048
1049 if (!result) return failure();
1050
1051 rewriter.replaceOp(op, {result.getValue()});
1052
1053 return success();
1054 }
1055
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1056 LogicalResult ConvertTFAnyOp::matchAndRewrite(Operation* op,
1057 PatternRewriter& rewriter) const {
1058 auto tf_any_op = cast<TF::AnyOp>(op);
1059
1060 RankedTensorType output_type =
1061 tf_any_op.getResult().getType().dyn_cast<RankedTensorType>();
1062 if (!output_type) return failure();
1063
1064 ElementsAttr axes_elems;
1065 if (!matchPattern(tf_any_op.reduction_indices(), m_Constant(&axes_elems)))
1066 return failure();
1067
1068 bool keep_dims = false;
1069 auto keep_dims_attr = tf_any_op.keep_dimsAttr();
1070 if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1071
1072 llvm::Optional<Value> result = convertReduceAnyOp(
1073 rewriter, op, output_type, tf_any_op.input(), axes_elems, keep_dims);
1074
1075 if (!result) return failure();
1076
1077 rewriter.replaceOp(op, {result.getValue()});
1078
1079 return success();
1080 }
1081
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1082 LogicalResult ConvertTFMaxOp::matchAndRewrite(Operation* op,
1083 PatternRewriter& rewriter) const {
1084 auto tf_max_op = cast<TF::MaxOp>(op);
1085
1086 RankedTensorType output_type =
1087 tf_max_op.getResult().getType().dyn_cast<RankedTensorType>();
1088 if (!output_type) return failure();
1089
1090 ElementsAttr axes_elems;
1091 if (!matchPattern(tf_max_op.reduction_indices(), m_Constant(&axes_elems)))
1092 return failure();
1093
1094 bool keep_dims = false;
1095 auto keep_dims_attr = tf_max_op.keep_dimsAttr();
1096 if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1097
1098 llvm::Optional<Value> result = convertReduceMaxOp(
1099 rewriter, op, output_type, tf_max_op.input(), axes_elems, keep_dims);
1100
1101 if (!result) return failure();
1102
1103 rewriter.replaceOp(op, {result.getValue()});
1104
1105 return success();
1106 }
1107
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1108 LogicalResult ConvertTFMinOp::matchAndRewrite(Operation* op,
1109 PatternRewriter& rewriter) const {
1110 auto tf_min_op = cast<TF::MinOp>(op);
1111
1112 RankedTensorType output_type =
1113 tf_min_op.getResult().getType().dyn_cast<RankedTensorType>();
1114 if (!output_type) return failure();
1115
1116 ElementsAttr axes_elems;
1117 if (!matchPattern(tf_min_op.reduction_indices(), m_Constant(&axes_elems)))
1118 return failure();
1119
1120 bool keep_dims = false;
1121 auto keep_dims_attr = tf_min_op.keep_dimsAttr();
1122 if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1123
1124 llvm::Optional<Value> result = convertReduceMinOp(
1125 rewriter, op, output_type, tf_min_op.input(), axes_elems, keep_dims);
1126
1127 if (!result) return failure();
1128
1129 rewriter.replaceOp(op, {result.getValue()});
1130
1131 return success();
1132 }
1133
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1134 LogicalResult ConvertTFMeanOp::matchAndRewrite(
1135 Operation* op, PatternRewriter& rewriter) const {
1136 auto tf_mean_op = cast<TF::MeanOp>(op);
1137
1138 RankedTensorType output_type =
1139 tf_mean_op.getResult().getType().dyn_cast<RankedTensorType>();
1140 if (!output_type) return failure();
1141
1142 ElementsAttr axes_elems;
1143 if (!matchPattern(tf_mean_op.reduction_indices(), m_Constant(&axes_elems)))
1144 return failure();
1145
1146 bool keep_dims = false;
1147 auto keep_dims_attr = tf_mean_op.keep_dimsAttr();
1148 if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1149
1150 llvm::Optional<Value> result = convertReduceMeanOp(
1151 rewriter, op, output_type, tf_mean_op.input(), axes_elems, keep_dims);
1152
1153 if (!result) return failure();
1154
1155 rewriter.replaceOp(op, {result.getValue()});
1156
1157 return success();
1158 }
1159
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1160 LogicalResult ConvertTFProdOp::matchAndRewrite(
1161 Operation* op, PatternRewriter& rewriter) const {
1162 auto tf_prod_op = cast<TF::ProdOp>(op);
1163
1164 RankedTensorType output_type =
1165 tf_prod_op.getResult().getType().dyn_cast<RankedTensorType>();
1166 if (!output_type) return failure();
1167
1168 ElementsAttr axes_elems;
1169 if (!matchPattern(tf_prod_op.reduction_indices(), m_Constant(&axes_elems)))
1170 return failure();
1171
1172 bool keep_dims = false;
1173 auto keep_dims_attr = tf_prod_op.keep_dimsAttr();
1174 if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1175
1176 llvm::Optional<Value> result = convertReduceProdOp(
1177 rewriter, op, output_type, tf_prod_op.input(), axes_elems, keep_dims);
1178
1179 if (!result) return failure();
1180
1181 rewriter.replaceOp(op, {result.getValue()});
1182
1183 return success();
1184 }
1185
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1186 LogicalResult ConvertTFSumOp::matchAndRewrite(Operation* op,
1187 PatternRewriter& rewriter) const {
1188 auto tf_sum_op = cast<TF::SumOp>(op);
1189
1190 RankedTensorType output_type =
1191 tf_sum_op.getResult().getType().dyn_cast<RankedTensorType>();
1192 if (!output_type) return failure();
1193
1194 ElementsAttr axes_elems;
1195 if (!matchPattern(tf_sum_op.reduction_indices(), m_Constant(&axes_elems)))
1196 return failure();
1197
1198 bool keep_dims = false;
1199 auto keep_dims_attr = tf_sum_op.keep_dimsAttr();
1200 if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
1201
1202 llvm::Optional<Value> result = convertReduceSumOp(
1203 rewriter, op, output_type, tf_sum_op.input(), axes_elems, keep_dims);
1204
1205 if (!result) return failure();
1206
1207 rewriter.replaceOp(op, {result.getValue()});
1208
1209 return success();
1210 }
1211
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1212 LogicalResult ConvertTFEluOp::matchAndRewrite(Operation* op,
1213 PatternRewriter& rewriter) const {
1214 auto tf_elu_op = cast<TF::EluOp>(op);
1215
1216 llvm::Optional<Value> result =
1217 convertEluOp(rewriter, op, tf_elu_op.getResult(), tf_elu_op.features());
1218
1219 if (!result) return failure();
1220
1221 rewriter.replaceOp(op, {result.getValue()});
1222
1223 return success();
1224 }
1225
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1226 LogicalResult ConvertTFSoftmaxOp::matchAndRewrite(
1227 Operation* op, PatternRewriter& rewriter) const {
1228 auto tf_softmax_op = cast<TF::SoftmaxOp>(op);
1229
1230 llvm::Optional<Value> result = convertSoftmaxOp(
1231 rewriter, op, tf_softmax_op.getResult(), tf_softmax_op.logits());
1232
1233 if (!result) return failure();
1234
1235 rewriter.replaceOp(op, {result.getValue()});
1236
1237 return success();
1238 }
1239
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1240 LogicalResult ConvertTFLogSoftmaxOp::matchAndRewrite(
1241 Operation* op, PatternRewriter& rewriter) const {
1242 auto tf_logsoftmax_op = cast<TF::LogSoftmaxOp>(op);
1243
1244 llvm::Optional<Value> result = convertLogSoftmaxOp(
1245 rewriter, op, tf_logsoftmax_op.getResult(), tf_logsoftmax_op.logits());
1246
1247 if (!result) return failure();
1248
1249 rewriter.replaceOp(op, {result.getValue()});
1250
1251 return success();
1252 }
1253
// Lowers TF FusedBatchNorm (inference form) to elementwise TOSA ops.
LogicalResult ConvertTFFusedBatchNormOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_batchnorm_op = cast<TF::FusedBatchNormOp>(op);

  RankedTensorType output_type =
      tf_batchnorm_op.getResult(0).getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  // Lowering:
  // fused batchnorm = (input - mean) * scale * rsqrt(var + epsilon) + offset
  //
  // op1 = sub(input, mean)
  // op2 = add(var, epsilon)
  // op3 = rsqrt(op2)
  // op4 = mul(op1, op3)
  // op5 = mul(op4, scale)
  // op6 = add(op5, offset)
  //
  // NOTE(review): no explicit reshape/broadcast of mean/variance/scale/offset
  // is emitted here; presumably the TOSA elementwise ops' implicit
  // broadcasting covers the NHWC channel dimension — confirm.

  RankedTensorType mean_type =
      tf_batchnorm_op.mean().getType().dyn_cast<RankedTensorType>();
  RankedTensorType variance_type =
      tf_batchnorm_op.variance().getType().dyn_cast<RankedTensorType>();
  if (!variance_type || !mean_type) return failure();

  Value mean_val, variance_val;

  // TF allows empty mean/variance tensors (e.g. training-mode placeholders);
  // substitute the identity values 0 (mean) and 1 (variance) in that case.
  if (mean_type.getNumElements() == 0) {
    mean_val = getTosaConstTensorSingleF32(rewriter, tf_batchnorm_op, 0);
  } else {
    mean_val = tf_batchnorm_op.mean();
  }

  if (variance_type.getNumElements() == 0) {
    variance_val = getTosaConstTensorSingleF32(rewriter, tf_batchnorm_op, 1.0);
  } else {
    variance_val = tf_batchnorm_op.variance();
  }

  // Materialize epsilon as a rank-1, single-element constant tensor.
  RankedTensorType epsilon_type =
      RankedTensorType::get({1}, variance_type.getElementType());
  auto epsilon_attr =
      DenseFPElementsAttr::get(epsilon_type, {tf_batchnorm_op.epsilon()});
  auto epsilon_const =
      rewriter.create<tosa::ConstOp>(op->getLoc(), epsilon_type, epsilon_attr);

  auto op1_sub_input_mean = rewriter.create<tosa::SubOp>(
      op->getLoc(), tf_batchnorm_op.getResult(0).getType(), tf_batchnorm_op.x(),
      mean_val);

  auto op2_add_var_epsilon =
      rewriter.create<tosa::AddOp>(op->getLoc(), variance_val.getType(),
                                   variance_val, epsilon_const.getResult());

  auto op3_rsqrt_op2 = rewriter.create<tosa::RsqrtOp>(
      op->getLoc(), variance_val.getType(), op2_add_var_epsilon.getResult());

  auto op4_mul_op1_op3 = rewriter.create<tosa::MulOp>(
      op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      op1_sub_input_mean.getResult(), op3_rsqrt_op2.getResult(), 0);

  auto op5_mul_op4_scale = rewriter.create<tosa::MulOp>(
      op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      op4_mul_op1_op3.getResult(), tf_batchnorm_op.scale(), 0);

  auto op6_add_op5_offset = rewriter.create<tosa::AddOp>(
      op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      op5_mul_op4_scale.getResult(), tf_batchnorm_op.offset());

  // Only result 0 (the normalized output) is produced; the op's auxiliary
  // batch statistics results are not materialized by this lowering.
  rewriter.replaceOp(op, {op6_add_op5_offset.getResult()});
  return success();
}
1336
// Lowers TF FusedBatchNormV3 (inference form) to elementwise TOSA ops.
LogicalResult ConvertTFFusedBatchNormV3Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_batchnorm_op = cast<TF::FusedBatchNormV3Op>(op);

  RankedTensorType output_type =
      tf_batchnorm_op.getResult(0).getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  // Lowering:
  // fused batchnorm = (input - mean) * scale * rsqrt(var + epsilon) + offset
  // op1 = sub(input, mean)
  // op2 = add(var, epsilon)
  // op3 = rsqrt(op2)
  // op4 = mul(op1, op3)
  // op5 = mul(op4, scale)
  // op6 = add(op5, offset)
  //
  // NOTE(review): unlike the V1 lowering above, empty mean/variance tensors
  // are not special-cased here — confirm V3 inputs are always populated.

  auto op1_sub_input_mean = rewriter.create<tosa::SubOp>(
      op->getLoc(), tf_batchnorm_op.getResult(0).getType(), tf_batchnorm_op.x(),
      tf_batchnorm_op.mean());

  RankedTensorType variance_type =
      tf_batchnorm_op.variance().getType().dyn_cast<RankedTensorType>();
  if (!variance_type) return failure();

  // Materialize epsilon as a rank-1, single-element constant tensor.
  auto epsilon_type =
      RankedTensorType::get({1}, variance_type.getElementType());
  auto epsilon_attr =
      DenseFPElementsAttr::get(epsilon_type, {tf_batchnorm_op.epsilon()});
  auto epsilon_const =
      rewriter.create<tosa::ConstOp>(op->getLoc(), epsilon_type, epsilon_attr);

  auto op2_add_var_epsilon = rewriter.create<tosa::AddOp>(
      op->getLoc(), tf_batchnorm_op.variance().getType(),
      tf_batchnorm_op.variance(), epsilon_const);

  auto op3_rsqrt_op2 = rewriter.create<tosa::RsqrtOp>(
      op->getLoc(), tf_batchnorm_op.variance().getType(),
      op2_add_var_epsilon.getResult());

  auto op4_mul_op1_op3 = rewriter.create<tosa::MulOp>(
      op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      op1_sub_input_mean.getResult(), op3_rsqrt_op2.getResult(), 0);

  auto op5_mul_op4_scale = rewriter.create<tosa::MulOp>(
      op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      op4_mul_op1_op3.getResult(), tf_batchnorm_op.scale(), 0);

  auto op6_add_op5_offset = rewriter.create<tosa::AddOp>(
      op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      op5_mul_op4_scale.getResult(), tf_batchnorm_op.offset());

  // Only result 0 (the normalized output) is produced; the op's auxiliary
  // batch statistics results are not materialized by this lowering.
  rewriter.replaceOp(op, {op6_add_op5_offset.getResult()});
  return success();
}
1393
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1394 LogicalResult ConvertTFBiasAddOp::matchAndRewrite(
1395 Operation* op, PatternRewriter& rewriter) const {
1396 auto tf_biasadd_op = cast<TF::BiasAddOp>(op);
1397
1398 RankedTensorType output_type =
1399 tf_biasadd_op.getResult().getType().dyn_cast<RankedTensorType>();
1400 // Not a ranked tensor output
1401 if (!output_type) return failure();
1402
1403 auto add_op = rewriter.create<tosa::AddOp>(
1404 op->getLoc(), output_type, tf_biasadd_op.value(), tf_biasadd_op.bias());
1405
1406 rewriter.replaceOp(op, {add_op.getResult()});
1407 return success();
1408 }
1409
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1410 LogicalResult ConvertTFSliceOp::matchAndRewrite(
1411 Operation* op, PatternRewriter& rewriter) const {
1412 auto tf_slice_op = cast<TF::SliceOp>(op);
1413
1414 RankedTensorType output_type =
1415 tf_slice_op.getResult().getType().dyn_cast<RankedTensorType>();
1416 // Not a ranked tensor output
1417 if (!output_type) return failure();
1418
1419 ElementsAttr begin_elems, size_elems;
1420
1421 SmallVector<int64_t, 4> begin_vals, size_vals;
1422
1423 // Assuming begin is always compile-time constant
1424 if (!matchPattern(tf_slice_op.begin(), m_Constant(&begin_elems))) {
1425 return op->emitOpError("TF::Slice error: begin is not constant");
1426 }
1427
1428 for (int i = 0; i < begin_elems.getNumElements(); i++)
1429 begin_vals.push_back(begin_elems.getValue<IntegerAttr>(i).getInt());
1430
1431 // Try to match size as compile-time constant first,
1432 // if this fails, use the output tensor shape instead.
1433 if (matchPattern(tf_slice_op.size(), m_Constant(&size_elems))) {
1434 for (int i = 0; i < size_elems.getNumElements(); i++)
1435 size_vals.push_back(size_elems.getValue<IntegerAttr>(i).getInt());
1436 } else {
1437 size_vals.assign(output_type.getShape().begin(),
1438 output_type.getShape().end());
1439 }
1440
1441 ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
1442 ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
1443
1444 rewriter.replaceOpWithNewOp<tosa::SliceOp>(op, output_type,
1445 tf_slice_op.input(), begin, size);
1446 return success();
1447 }
1448
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1449 LogicalResult ConvertTFTileOp::matchAndRewrite(
1450 Operation* op, PatternRewriter& rewriter) const {
1451 auto tf_tile_op = cast<TF::TileOp>(op);
1452
1453 RankedTensorType output_type =
1454 tf_tile_op.getResult().getType().dyn_cast<RankedTensorType>();
1455 // Not a ranked tensor output
1456 if (!output_type) return failure();
1457
1458 ElementsAttr multiples_elems;
1459 if (!matchPattern(tf_tile_op.multiples(), m_Constant(&multiples_elems)))
1460 return failure();
1461 SmallVector<int64_t, 4> multiples_vals;
1462 for (int i = 0; i < multiples_elems.getNumElements(); i++)
1463 multiples_vals.push_back(multiples_elems.getValue<IntegerAttr>(i).getInt());
1464
1465 ArrayAttr multiples_attr = rewriter.getI64ArrayAttr(multiples_vals);
1466
1467 rewriter.replaceOpWithNewOp<tosa::TileOp>(op, output_type, tf_tile_op.input(),
1468 multiples_attr);
1469
1470 return success();
1471 }
1472
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1473 LogicalResult ConvertTFTransposeOp::matchAndRewrite(
1474 Operation* op, PatternRewriter& rewriter) const {
1475 auto tf_transpose_op = cast<TF::TransposeOp>(op);
1476
1477 RankedTensorType output_type =
1478 tf_transpose_op.getResult().getType().dyn_cast<RankedTensorType>();
1479 // Not a ranked tensor output
1480 if (!output_type) {
1481 return failure();
1482 }
1483
1484 rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
1485 op, output_type, tf_transpose_op.x(), tf_transpose_op.perm());
1486
1487 return success();
1488 }
1489
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1490 LogicalResult ConvertTFPackOp::matchAndRewrite(
1491 Operation* op, PatternRewriter& rewriter) const {
1492 auto tf_pack_op = cast<TF::PackOp>(op);
1493
1494 SmallVector<Value, 8> inputs(tf_pack_op.values());
1495
1496 assert(inputs.size() >= 2);
1497
1498 IntegerAttr axis_attr;
1499 {
1500 auto tmpAttr = tf_pack_op.axisAttr();
1501 if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
1502 axis_attr = tmpAttr;
1503 }
1504 int32_t axis_i32 = axis_attr.getInt();
1505
1506 llvm::Optional<Value> result =
1507 convertPackOp(rewriter, op, tf_pack_op.getResult(), inputs, axis_i32);
1508
1509 if (!result) return failure();
1510
1511 rewriter.replaceOp(op, {result.getValue()});
1512
1513 return success();
1514 }
1515
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1516 LogicalResult ConvertTFUnpackOp::matchAndRewrite(
1517 Operation* op, PatternRewriter& rewriter) const {
1518 auto tf_unpack_op = cast<TF::UnpackOp>(op);
1519
1520 IntegerAttr axis_attr;
1521 {
1522 auto tmpAttr = tf_unpack_op.axisAttr();
1523 if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
1524 axis_attr = tmpAttr;
1525 }
1526 int32_t axis_i32 = axis_attr.getInt();
1527
1528 llvm::Optional<ValueRange> results =
1529 convertUnpackOp(rewriter, op, tf_unpack_op.value(), axis_i32);
1530
1531 if (!results) return failure();
1532
1533 rewriter.replaceOp(op, results.getValue());
1534
1535 return success();
1536 }
1537
1538 // Splits in num_split parts along split_dim
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1539 LogicalResult ConvertTFSplitOp::matchAndRewrite(
1540 Operation* op, PatternRewriter& rewriter) const {
1541 auto tf_split_op = cast<TF::SplitOp>(op);
1542
1543 // Get the number of splits
1544 int32_t num_split = -1;
1545
1546 auto range = tf_split_op.getODSResults(0);
1547 num_split = std::distance(range.begin(), range.end());
1548
1549 // Get the axis
1550 int32_t axis = 0;
1551 ElementsAttr axisAttrElems;
1552 if (matchPattern(tf_split_op.split_dim(), m_Constant(&axisAttrElems))) {
1553 axis = axisAttrElems.getValue<IntegerAttr>({}).getInt();
1554 }
1555
1556 llvm::Optional<ValueRange> results =
1557 convertSplitOp(rewriter, op, tf_split_op.getResult(0),
1558 tf_split_op.value(), num_split, axis);
1559
1560 if (!results) return failure();
1561
1562 rewriter.replaceOp(op, results.getValue());
1563
1564 return success();
1565 }
1566
1567 // TFSplitV op splits based on a vector of sizes
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1568 LogicalResult ConvertTFSplitVOp::matchAndRewrite(
1569 Operation* op, PatternRewriter& rewriter) const {
1570 auto tf_splitv_op = cast<TF::SplitVOp>(op);
1571
1572 // Get the size_splits array
1573 SmallVector<int32_t, 4> size_split;
1574 ElementsAttr size_split_elems;
1575 if (!matchPattern(tf_splitv_op.size_splits(),
1576 m_Constant(&size_split_elems))) {
1577 return failure();
1578 }
1579
1580 for (int i = 0; i < size_split_elems.getNumElements(); i++) {
1581 size_split.push_back(size_split_elems.getValue<IntegerAttr>(i).getInt());
1582 }
1583
1584 // Get the axis
1585 ElementsAttr axisAttrElems;
1586 if (!matchPattern(tf_splitv_op.split_dim(), m_Constant(&axisAttrElems))) {
1587 return op->emitOpError("Cannot read split_dim elems");
1588 }
1589
1590 int32_t axis = axisAttrElems.getValue<IntegerAttr>(0).getInt();
1591
1592 llvm::Optional<ValueRange> results =
1593 convertSplitVOp(rewriter, op, tf_splitv_op.getResult(0),
1594 tf_splitv_op.value(), size_split, axis);
1595
1596 if (!results) return failure();
1597
1598 rewriter.replaceOp(op, results.getValue());
1599
1600 return success();
1601 }
1602
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1603 LogicalResult ConvertTFLessOp::matchAndRewrite(
1604 Operation* op, PatternRewriter& rewriter) const {
1605 auto tf_less_op = cast<TF::LessOp>(op);
1606
1607 RankedTensorType output_type =
1608 tf_less_op.getResult().getType().dyn_cast<RankedTensorType>();
1609 // Not a ranked tensor output
1610 if (!output_type) return failure();
1611
1612 // less(x, y) is not(greater_equal(x, y))
1613 auto greater_equal_op = rewriter.create<tosa::GreaterEqualOp>(
1614 op->getLoc(), output_type, tf_less_op.x(), tf_less_op.y());
1615
1616 auto not_op = rewriter.create<tosa::LogicalNotOp>(
1617 op->getLoc(), output_type, greater_equal_op.getResult());
1618
1619 rewriter.replaceOp(op, {not_op.getResult()});
1620 return success();
1621 }
1622
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1623 LogicalResult ConvertTFLessEqualOp::matchAndRewrite(
1624 Operation* op, PatternRewriter& rewriter) const {
1625 auto tf_less_equal_op = cast<TF::LessEqualOp>(op);
1626
1627 RankedTensorType output_type =
1628 tf_less_equal_op.getResult().getType().dyn_cast<RankedTensorType>();
1629 // Not a ranked tensor output
1630 if (!output_type) return failure();
1631
1632 // less_equal(x, y) is not(greater(x, y))
1633 auto greater_op = rewriter.create<tosa::GreaterOp>(
1634 op->getLoc(), output_type, tf_less_equal_op.x(), tf_less_equal_op.y());
1635
1636 auto not_op = rewriter.create<tosa::LogicalNotOp>(op->getLoc(), output_type,
1637 greater_op.getResult());
1638
1639 rewriter.replaceOp(op, {not_op.getResult()});
1640 return success();
1641 }
1642
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1643 LogicalResult ConvertTFPadOp::matchAndRewrite(Operation* op,
1644 PatternRewriter& rewriter) const {
1645 auto tf_pad_op = cast<TF::PadOp>(op);
1646
1647 RankedTensorType output_type =
1648 tf_pad_op.getResult().getType().dyn_cast<RankedTensorType>();
1649 // Not a ranked tensor output
1650 if (!output_type) return failure();
1651
1652 auto pad_op = rewriter.create<tosa::PadOp>(
1653 op->getLoc(), output_type, tf_pad_op.input(), tf_pad_op.paddings());
1654
1655 rewriter.replaceOp(op, {pad_op.getResult()});
1656 return success();
1657 }
1658
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1659 LogicalResult ConvertTFResizeBilinearOp::matchAndRewrite(
1660 Operation* op, PatternRewriter& rewriter) const {
1661 auto tf_resize_op = cast<TF::ResizeBilinearOp>(op);
1662
1663 RankedTensorType output_type =
1664 tf_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
1665 // Not a ranked tensor output
1666 if (!output_type) return failure();
1667
1668 llvm::Optional<Value> result = convertResizeOp(
1669 rewriter, op, output_type, tf_resize_op.images(), StringRef("BILINEAR"));
1670
1671 if (!result) return failure();
1672
1673 rewriter.replaceOp(op, {result.getValue()});
1674
1675 return success();
1676 }
1677
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1678 LogicalResult ConvertTFResizeNearestNeighborOp::matchAndRewrite(
1679 Operation* op, PatternRewriter& rewriter) const {
1680 auto tf_resize_op = cast<TF::ResizeNearestNeighborOp>(op);
1681
1682 RankedTensorType output_type =
1683 tf_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
1684 // Not a ranked tensor output
1685 if (!output_type) return failure();
1686
1687 llvm::Optional<Value> result = convertResizeOp(
1688 rewriter, op, output_type, tf_resize_op.images(), StringRef("NEAREST"));
1689
1690 if (!result) return failure();
1691
1692 rewriter.replaceOp(op, {result.getValue()});
1693
1694 return success();
1695 }
1696
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1697 LogicalResult ConvertTFMatMulOp::matchAndRewrite(
1698 Operation* op, PatternRewriter& rewriter) const {
1699 auto tf_matmul_op = cast<TF::MatMulOp>(op);
1700
1701 RankedTensorType a_type =
1702 tf_matmul_op.a().getType().dyn_cast<RankedTensorType>();
1703 RankedTensorType b_type =
1704 tf_matmul_op.b().getType().dyn_cast<RankedTensorType>();
1705 RankedTensorType output_type =
1706 tf_matmul_op.getResult().getType().dyn_cast<RankedTensorType>();
1707
1708 if (!(a_type && b_type && output_type)) {
1709 return op->emitOpError("MatMul: a/b/output not ranked tensors");
1710 }
1711
1712 // Can only handle rank=2 inputs
1713 if (a_type.getShape().size() != 2) return failure();
1714
1715 rewriter.replaceOpWithNewOp<tosa::MatMulOp>(op, output_type, tf_matmul_op.a(),
1716 tf_matmul_op.b());
1717
1718 return success();
1719 }
1720
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1721 LogicalResult ConvertTFGatherOp::matchAndRewrite(
1722 Operation* op, PatternRewriter& rewriter) const {
1723 auto tf_gather_op = cast<TF::GatherOp>(op);
1724
1725 RankedTensorType output_type =
1726 tf_gather_op.getResult().getType().dyn_cast<RankedTensorType>();
1727 if (!output_type) return failure();
1728
1729 IntegerAttr axis_attr = rewriter.getI32IntegerAttr(0);
1730
1731 // TODO: batchdim_attr handling to be implemented with a revised
1732 // defintion of the TOSA operator.
1733 rewriter.replaceOpWithNewOp<tosa::GatherOp>(
1734 op, output_type, tf_gather_op.params(), tf_gather_op.indices(),
1735 axis_attr);
1736
1737 return success();
1738 }
1739
// Lowers TF GatherV2 to tosa.gather, converting the constant axis tensor
// into the I32 axis attribute TOSA expects.
LogicalResult ConvertTFGatherV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_gather_op = cast<TF::GatherV2Op>(op);

  RankedTensorType output_type =
      tf_gather_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  // Axis is a tensor in TF. Convert to I64Attr for TOSA
  ElementsAttr axis_elem;
  if (!matchPattern(tf_gather_op.axis(), m_Constant(&axis_elem)))
    return failure();
  // Axis must be a scalar (0-D) constant; the {} index below reads it.
  assert(axis_elem.getType().getRank() == 0 && "expected 0D tensor");

  // Normalize batch_dims to 0 when the attribute is absent.
  // NOTE(review): batchdim_attr is currently computed but unused (see TODO
  // below); kept so the pending batch-dims support has its default in place.
  IntegerAttr batchdim_attr;
  {
    auto tmpAttr = tf_gather_op.batch_dimsAttr();
    if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
    batchdim_attr = tmpAttr;
  }

  // TODO: batchdim_attr handling to be implemented with a revised
  // defintion of the TOSA operator.
  rewriter.replaceOpWithNewOp<tosa::GatherOp>(
      op, output_type, tf_gather_op.params(), tf_gather_op.indices(),
      rewriter.getI32IntegerAttr(axis_elem.getValue<IntegerAttr>({}).getInt()));

  return success();
}
1769
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1770 LogicalResult ConvertTFSelectV2Op::matchAndRewrite(
1771 Operation* op, PatternRewriter& rewriter) const {
1772 auto tf_sel_op = cast<TF::SelectV2Op>(op);
1773
1774 llvm::Optional<Value> result =
1775 convertSelectOp(rewriter, op, tf_sel_op.getResult(),
1776 tf_sel_op.condition(), tf_sel_op.t(), tf_sel_op.e());
1777
1778 if (!result) return failure();
1779
1780 rewriter.replaceOp(op, {result.getValue()});
1781
1782 return success();
1783 }
1784
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1785 LogicalResult ConvertTFSpaceToDepthOp::matchAndRewrite(
1786 Operation* op, PatternRewriter& rewriter) const {
1787 auto tf_s2d_op = cast<TF::SpaceToDepthOp>(op);
1788
1789 llvm::Optional<Value> result = convertSpaceToDepthOp(
1790 rewriter, op, tf_s2d_op.getResult(), tf_s2d_op.input(),
1791 tf_s2d_op.block_sizeAttr(), tf_s2d_op.data_formatAttr());
1792
1793 if (!result) return failure();
1794
1795 rewriter.replaceOp(op, {result.getValue()});
1796
1797 return success();
1798 }
1799
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1800 LogicalResult ConvertTFDepthToSpaceOp::matchAndRewrite(
1801 Operation* op, PatternRewriter& rewriter) const {
1802 auto tf_d2s_op = cast<TF::DepthToSpaceOp>(op);
1803
1804 llvm::Optional<Value> result = convertDepthToSpaceOp(
1805 rewriter, op, tf_d2s_op.getResult(), tf_d2s_op.input(),
1806 tf_d2s_op.block_sizeAttr(), tf_d2s_op.data_formatAttr());
1807
1808 if (!result) return failure();
1809
1810 rewriter.replaceOp(op, {result.getValue()});
1811
1812 return success();
1813 }
1814
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1815 LogicalResult ConvertTFSpaceToBatchNDOp::matchAndRewrite(
1816 Operation* op, PatternRewriter& rewriter) const {
1817 auto tf_s2b_op = cast<TF::SpaceToBatchNDOp>(op);
1818
1819 llvm::Optional<Value> result = convertSpaceToBatchNDOp(
1820 rewriter, op, tf_s2b_op.getResult(), tf_s2b_op.input(),
1821 tf_s2b_op.block_shape(), tf_s2b_op.paddings());
1822 if (!result) return failure();
1823
1824 rewriter.replaceOp(op, {result.getValue()});
1825
1826 return success();
1827 }
1828
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1829 LogicalResult ConvertTFBatchToSpaceNDOp::matchAndRewrite(
1830 Operation* op, PatternRewriter& rewriter) const {
1831 auto tf_b2s_op = cast<TF::BatchToSpaceNDOp>(op);
1832
1833 llvm::Optional<Value> result = convertBatchToSpaceNDOp(
1834 rewriter, op, tf_b2s_op.getResult(), tf_b2s_op.input(),
1835 tf_b2s_op.block_shape(), tf_b2s_op.crops());
1836
1837 if (!result) return failure();
1838
1839 rewriter.replaceOp(op, {result.getValue()});
1840
1841 return success();
1842 }
1843
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1844 LogicalResult ConvertTFStridedSliceOp::matchAndRewrite(
1845 Operation* op, PatternRewriter& rewriter) const {
1846 auto tf_ss_op = cast<TF::StridedSliceOp>(op);
1847
1848 llvm::Optional<Value> result = convertStridedSliceOp(
1849 rewriter, op, tf_ss_op.getResult(), tf_ss_op.input(), tf_ss_op.begin(),
1850 tf_ss_op.end(), tf_ss_op.strides(), tf_ss_op.begin_maskAttr().getInt(),
1851 tf_ss_op.end_maskAttr().getInt(), tf_ss_op.ellipsis_maskAttr().getInt(),
1852 tf_ss_op.new_axis_maskAttr().getInt(),
1853 tf_ss_op.shrink_axis_maskAttr().getInt());
1854
1855 if (!result) return failure();
1856
1857 rewriter.replaceOp(op, {result.getValue()});
1858
1859 return success();
1860 }
1861
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1862 LogicalResult ConvertTFZerosLikeOp::matchAndRewrite(
1863 Operation* op, PatternRewriter& rewriter) const {
1864 auto tf_zeroslike_op = cast<TF::ZerosLikeOp>(op);
1865
1866 llvm::Optional<Value> result = convertZerosLikeOp(
1867 rewriter, op, tf_zeroslike_op.getResult(), tf_zeroslike_op.x());
1868
1869 if (!result) return failure();
1870
1871 rewriter.replaceOp(op, {result.getValue()});
1872
1873 return success();
1874 }
1875
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1876 LogicalResult ConvertTFSigmoidOp::matchAndRewrite(
1877 Operation* op, PatternRewriter& rewriter) const {
1878 auto tf_sigmoid_op = cast<TF::SigmoidOp>(op);
1879 RankedTensorType output_type =
1880 tf_sigmoid_op.getResult().getType().dyn_cast<RankedTensorType>();
1881 if (!output_type) return failure();
1882
1883 rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(op, output_type,
1884 tf_sigmoid_op.x());
1885
1886 return success();
1887 }
1888
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1889 LogicalResult ConvertTFTanhOp::matchAndRewrite(
1890 Operation* op, PatternRewriter& rewriter) const {
1891 auto tf_tanh_op = cast<TF::TanhOp>(op);
1892 RankedTensorType output_type =
1893 tf_tanh_op.getResult().getType().dyn_cast<RankedTensorType>();
1894 if (!output_type) return failure();
1895
1896 rewriter.replaceOpWithNewOp<tosa::TanhOp>(op, output_type, tf_tanh_op.x());
1897
1898 return success();
1899 }
1900
// Placeholder lowering for tf.LeakyRelu. No TOSA sequence is emitted yet:
// this pattern always returns failure(), leaving the op unlegalized.
LogicalResult ConvertTFLeakyReluOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_leakyrelu_op = cast<TF::LeakyReluOp>(op);
  // Reject unranked results up front so the future lowering can rely on a
  // ranked output type.
  RankedTensorType output_type =
      tf_leakyrelu_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  // TODO: add lowering with MUL + SELECT

  return failure();
}
1912
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1913 LogicalResult ConvertTFNegOp::matchAndRewrite(Operation* op,
1914 PatternRewriter& rewriter) const {
1915 auto tf_neg_op = cast<TF::NegOp>(op);
1916 RankedTensorType output_type =
1917 tf_neg_op.getResult().getType().dyn_cast<RankedTensorType>();
1918 if (!output_type) return failure();
1919
1920 rewriter.replaceOpWithNewOp<tosa::NegateOp>(op, output_type, tf_neg_op.x());
1921
1922 return success();
1923 }
1924
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1925 LogicalResult ConvertTFStopGradientOp::matchAndRewrite(
1926 Operation* op, PatternRewriter& rewriter) const {
1927 auto tf_stopgrad_op = cast<TF::StopGradientOp>(op);
1928 RankedTensorType output_type =
1929 tf_stopgrad_op.getResult().getType().dyn_cast<RankedTensorType>();
1930 if (!output_type) return failure();
1931
1932 rewriter.replaceOpWithNewOp<tosa::IdentityOp>(op, output_type,
1933 tf_stopgrad_op.input());
1934
1935 return success();
1936 }
1937
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1938 LogicalResult ConvertTFReverseV2Op::matchAndRewrite(
1939 Operation* op, PatternRewriter& rewriter) const {
1940 auto tf_reverse_op = cast<TF::ReverseV2Op>(op);
1941 RankedTensorType input_type =
1942 tf_reverse_op.tensor().getType().dyn_cast<RankedTensorType>();
1943 RankedTensorType output_type =
1944 tf_reverse_op.getResult().getType().dyn_cast<RankedTensorType>();
1945 if (!input_type || !output_type) return failure();
1946
1947 ElementsAttr axis_elems;
1948 if (!matchPattern(tf_reverse_op.axis(), m_Constant(&axis_elems)))
1949 return failure();
1950
1951 auto input_rank = input_type.getShape().size();
1952 Value val = tf_reverse_op.tensor();
1953 if (axis_elems.getNumElements() == 0) {
1954 auto identity_op =
1955 rewriter.create<tosa::IdentityOp>(op->getLoc(), output_type, val);
1956 val = identity_op.getResult();
1957 } else {
1958 for (int i = 0; i < axis_elems.getNumElements(); i++) {
1959 int64_t axis_val = axis_elems.getValue<IntegerAttr>(i).getInt();
1960 if (axis_val < 0) axis_val += input_rank;
1961 auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
1962 auto reverse_op = rewriter.create<tosa::ReverseOp>(
1963 op->getLoc(), output_type, val, axis_attr);
1964
1965 val = reverse_op.getResult();
1966 }
1967 }
1968
1969 rewriter.replaceOp(op, {val});
1970
1971 return success();
1972 }
1973
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1974 LogicalResult ConvertTFFakeQuantWithMinMaxArgsOp::matchAndRewrite(
1975 Operation* op, PatternRewriter& rewriter) const {
1976 auto tf_fakequant_op = cast<TF::FakeQuantWithMinMaxArgsOp>(op);
1977
1978 RankedTensorType output_type =
1979 tf_fakequant_op.getResult().getType().dyn_cast<RankedTensorType>();
1980 // Not a ranked tensor output
1981 if (!output_type) return failure();
1982
1983 llvm::Optional<Value> result =
1984 convertFakeQuantOp(rewriter, op, output_type, tf_fakequant_op.inputs(),
1985 tf_fakequant_op.minAttr().getValueAsDouble(),
1986 tf_fakequant_op.maxAttr().getValueAsDouble(),
1987 tf_fakequant_op.num_bitsAttr().getInt(),
1988 tf_fakequant_op.narrow_rangeAttr().getValue());
1989
1990 if (!result) return failure();
1991
1992 rewriter.replaceOp(op, {result.getValue()});
1993
1994 return success();
1995 }
1996
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1997 LogicalResult ConvertTFFakeQuantWithMinMaxVarsOp::matchAndRewrite(
1998 Operation* op, PatternRewriter& rewriter) const {
1999 auto tf_fakequant_op = cast<TF::FakeQuantWithMinMaxVarsOp>(op);
2000
2001 RankedTensorType output_type =
2002 tf_fakequant_op.getResult().getType().dyn_cast<RankedTensorType>();
2003 // Not a ranked tensor output
2004 if (!output_type) return failure();
2005
2006 // Only support min/max that can be matched at compile time
2007 ElementsAttr min_elems, max_elems;
2008 if (!matchPattern(tf_fakequant_op.min(), m_Constant(&min_elems)))
2009 return failure();
2010
2011 if (!matchPattern(tf_fakequant_op.max(), m_Constant(&max_elems)))
2012 return failure();
2013
2014 if (min_elems.getNumElements() != 1 && max_elems.getNumElements() != 1)
2015 return failure();
2016
2017 int64_t min_val = min_elems.getValue<IntegerAttr>(0).getInt();
2018 int64_t max_val = max_elems.getValue<IntegerAttr>(0).getInt();
2019
2020 llvm::Optional<Value> result = convertFakeQuantOp(
2021 rewriter, op, output_type, tf_fakequant_op.inputs(), min_val, max_val,
2022 tf_fakequant_op.num_bitsAttr().getInt(),
2023 tf_fakequant_op.narrow_rangeAttr().getValue());
2024
2025 if (!result) return failure();
2026
2027 rewriter.replaceOp(op, {result.getValue()});
2028
2029 return success();
2030 }
2031
runOnFunction()2032 void LegalizeTF::runOnFunction() {
2033 OwningRewritePatternList patterns;
2034 auto* ctx = &getContext();
2035 auto func = getFunction();
2036
2037 // Add the generated patterns to the list.
2038 populateWithGenerated(ctx, patterns);
2039 patterns.insert<ConvertTFMatMulOp>(ctx);
2040 patterns.insert<ConvertTFReluOp>(ctx);
2041 patterns.insert<ConvertTFRelu6Op>(ctx);
2042 patterns.insert<ConvertTFEqualOp>(ctx);
2043 patterns.insert<ConvertTFNotEqualOp>(ctx);
2044 patterns.insert<ConvertTFGreaterOp>(ctx);
2045 patterns.insert<ConvertTFGreaterEqualOp>(ctx);
2046 patterns.insert<ConvertTFAddOp>(ctx);
2047 patterns.insert<ConvertTFAddV2Op>(ctx);
2048 patterns.insert<ConvertTFAddNOp>(ctx);
2049 patterns.insert<ConvertTFSubOp>(ctx);
2050 patterns.insert<ConvertTFMulOp>(ctx);
2051 patterns.insert<ConvertTFSquareOp>(ctx);
2052 patterns.insert<ConvertTFSquaredDifferenceOp>(ctx);
2053 patterns.insert<ConvertTFRoundOp>(ctx);
2054 patterns.insert<ConvertTFFloorDivOp>(ctx);
2055 patterns.insert<ConvertTFFloorModOp>(ctx);
2056 patterns.insert<ConvertTFAssertOp>(ctx);
2057 patterns.insert<ConvertTFMaximumOp>(ctx);
2058 patterns.insert<ConvertTFMinimumOp>(ctx);
2059 patterns.insert<ConvertTFRealDivOp>(ctx);
2060 patterns.insert<ConvertTFArgMaxOp>(ctx);
2061 patterns.insert<ConvertTFAvgPoolOp>(ctx);
2062 patterns.insert<ConvertTFMaxPoolOp>(ctx);
2063 patterns.insert<ConvertTFConcatV2Op>(ctx);
2064 patterns.insert<ConvertTFReshapeOp>(ctx);
2065 patterns.insert<ConvertTFRankOp>(ctx);
2066 patterns.insert<ConvertTFShapeOp>(ctx);
2067 patterns.insert<ConvertTFExpandDimsOp>(ctx);
2068 patterns.insert<ConvertTFSqueezeOp>(ctx);
2069 patterns.insert<ConvertTFFillOp>(ctx);
2070 patterns.insert<ConvertTFConv2DOp>(ctx);
2071 patterns.insert<ConvertTFDepthwiseConv2dNativeOp>(ctx);
2072 patterns.insert<ConvertTFConv2DBackpropInputOp>(ctx);
2073 patterns.insert<ConvertTFEluOp>(ctx);
2074 patterns.insert<ConvertTFSoftmaxOp>(ctx);
2075 patterns.insert<ConvertTFLogSoftmaxOp>(ctx);
2076 patterns.insert<ConvertTFAllOp>(ctx);
2077 patterns.insert<ConvertTFAnyOp>(ctx);
2078 patterns.insert<ConvertTFMaxOp>(ctx);
2079 patterns.insert<ConvertTFMinOp>(ctx);
2080 patterns.insert<ConvertTFMeanOp>(ctx);
2081 patterns.insert<ConvertTFProdOp>(ctx);
2082 patterns.insert<ConvertTFSumOp>(ctx);
2083 patterns.insert<ConvertTFFusedBatchNormOp>(ctx);
2084 patterns.insert<ConvertTFFusedBatchNormV3Op>(ctx);
2085 patterns.insert<ConvertTFBiasAddOp>(ctx);
2086 patterns.insert<ConvertTFSplitOp>(ctx);
2087 patterns.insert<ConvertTFSplitVOp>(ctx);
2088 patterns.insert<ConvertTFPackOp>(ctx);
2089 patterns.insert<ConvertTFUnpackOp>(ctx);
2090 patterns.insert<ConvertTFTransposeOp>(ctx);
2091 patterns.insert<ConvertTFTileOp>(ctx);
2092 patterns.insert<ConvertTFSliceOp>(ctx);
2093 patterns.insert<ConvertTFStridedSliceOp>(ctx);
2094 patterns.insert<ConvertTFLessOp>(ctx);
2095 patterns.insert<ConvertTFLessEqualOp>(ctx);
2096 patterns.insert<ConvertTFPadOp>(ctx);
2097 patterns.insert<ConvertTFResizeBilinearOp>(ctx);
2098 patterns.insert<ConvertTFResizeNearestNeighborOp>(ctx);
2099 patterns.insert<ConvertTFGatherOp>(ctx);
2100 patterns.insert<ConvertTFGatherV2Op>(ctx);
2101 patterns.insert<ConvertTFSelectV2Op>(ctx);
2102 patterns.insert<ConvertTFSpaceToDepthOp>(ctx);
2103 patterns.insert<ConvertTFDepthToSpaceOp>(ctx);
2104 patterns.insert<ConvertTFSpaceToBatchNDOp>(ctx);
2105 patterns.insert<ConvertTFBatchToSpaceNDOp>(ctx);
2106 patterns.insert<ConvertTFZerosLikeOp>(ctx);
2107 patterns.insert<ConvertTFSigmoidOp>(ctx);
2108 patterns.insert<ConvertTFTanhOp>(ctx);
2109 patterns.insert<ConvertTFLeakyReluOp>(ctx);
2110 patterns.insert<ConvertTFNegOp>(ctx);
2111 patterns.insert<ConvertTFStopGradientOp>(ctx);
2112 patterns.insert<ConvertTFReverseV2Op>(ctx);
2113 patterns.insert<ConvertTFFakeQuantWithMinMaxArgsOp>(ctx);
2114 patterns.insert<ConvertTFFakeQuantWithMinMaxVarsOp>(ctx);
2115 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
2116 }
2117
2118 } // anonymous namespace
2119
// Creates an instance of the TOSA dialect LegalizeTF pass (TensorFlow ->
// TOSA legalization), for use by pass pipelines.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass() {
  return std::make_unique<LegalizeTF>();
}
2124
// Registers the pass with the global pass registry under the
// "tosa-legalize-tf" command-line name (see PASS_NAME above).
static PassRegistration<LegalizeTF> pass(
    PASS_NAME, "Legalize from TensorFlow to TOSA dialect");
2127
2128 } // namespace tosa
2129
2130 } // namespace mlir
2131