/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass converts operations in the TensorFlow dialect into
// operations that are legal in the TensorFlow Lite dialect. Operations that
// can be legalized to the TensorFlow Lite dialect with simple replacements
// belong in this pass; operations that may create extra ops should go in the
// PrepareTF pass, which should be run before this pass. That way any constant
// folding opportunities created by the extra ops can be exploited by the
// constant folding support for the TensorFlow ops.
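//
// As an illustrative sketch only (the full set of rewrites lives in the
// generated patterns and the converters below), a simple one-to-one
// legalization looks like:
//
//   %0 = "tf.Relu"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
//
// becoming
//
//   %0 = "tfl.relu"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>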

#include <climits>
#include <complex>
#include <cstdint>
#include <utility>

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Threading.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/UniformSupport.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace mlir {
namespace TFL {

//===----------------------------------------------------------------------===//
// The actual LegalizeTF Pass.
namespace {

constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm";
constexpr char kUnidirectionalSequenceRnn[] = "tf.UnidirectionalSequenceRnn";
constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";

// Legalize operations in functions.
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<quant::QuantizationDialect, TFL::TensorFlowLiteDialect>();
  }

 public:
  LegalizeTF() = default;
  LegalizeTF(const LegalizeTF&) {}
  explicit LegalizeTF(bool run_tfl_runtime_verification) {
    run_tfl_runtime_verification_ = run_tfl_runtime_verification;
  }

  StringRef getArgument() const final {
    // This is the argument used to refer to the pass in
    // the textual format (on the commandline for example).
    return "tfl-legalize-tf";
  }
  StringRef getDescription() const final {
    // This is a brief description of the pass.
    return "Legalize from TensorFlow to TensorFlow Lite dialect";
  }

  /// Performs the lowering to TFLite dialect.
  void runOnFunction() override;

 private:
  Option<bool> run_tfl_runtime_verification_{
      *this, "run-tfl-runtime-verification",
      llvm::cl::desc("Allow tfl runtime verification."), llvm::cl::init(true)};
};

// Returns true if all of `op`'s operand tensors have static and identical
// shapes.
bool HasSameStaticShapes(Operation* op) {
  auto values = op->getOperands();
  int index = 0;
  ArrayRef<int64_t> shape;
  for (Value value : values) {
    auto shaped_type = value.getType().dyn_cast<ShapedType>();
    if (!shaped_type || !shaped_type.hasStaticShape()) {
      return false;
    }
    if (index == 0) {
      shape = shaped_type.getShape();
    } else {
      if (shape != shaped_type.getShape()) {
        return false;
      }
    }
    ++index;
  }
  return true;
}

// Util that casts 'val' to Int32 by adding a cast Op.
Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) {
  IntegerType new_ele_type = rewriter.getIntegerType(32);
  if (auto shaped_type = val.getType().dyn_cast<RankedTensorType>()) {
    ShapedType new_type =
        RankedTensorType::get(shaped_type.getShape(), new_ele_type);
    return rewriter.createOrFold<TF::CastOp>(loc, new_type, val,
                                             rewriter.getBoolAttr(false));
  }
  return rewriter.createOrFold<TF::CastOp>(
      loc, UnrankedTensorType::get(new_ele_type), val,
      rewriter.getBoolAttr(false));
}

// Returns the shape of an operand or result; supports both dynamic and static
// shapes.
Value GetShape(Value input, Location loc, PatternRewriter& rewriter) {
  auto shaped_type = input.getType().cast<ShapedType>();
  if (shaped_type.hasStaticShape()) {
    auto static_shape = shaped_type.getShape();
    auto static_shape_type =
        RankedTensorType::get(static_shape.size(), rewriter.getIntegerType(64));
    auto static_shape_attr =
        mlir::DenseIntElementsAttr::get(static_shape_type, static_shape);
    return rewriter.create<TF::ConstOp>(loc, static_shape_attr).output();
  }

  // If the shape is not static, create a new ShapeOp.
  BoolAttr false_attr = rewriter.getBoolAttr(false);
  return rewriter
      .create<TF::ShapeOp>(loc, input,
                           /*use_32bit=*/false_attr)
      .output();
}

#include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc"

#define DECL_CONVERT_OP(tf_op)                                               \
  struct ConvertTF##tf_op##Op : public RewritePattern {                      \
    explicit ConvertTF##tf_op##Op(MLIRContext* context)                      \
        : RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {}   \
    LogicalResult matchAndRewrite(Operation* op,                             \
                                  PatternRewriter& rewriter) const override; \
  }
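
// For example, DECL_CONVERT_OP(Pack) declares a ConvertTFPackOp pattern that
// matches tf.Pack operations with benefit 1; its matchAndRewrite body is
// defined further below.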

// TODO(antiagainst): Define this pattern in a table-driven manner once variadic
// operands are properly supported in declarative rewrite rule specification.

DECL_CONVERT_OP(Assert);
DECL_CONVERT_OP(ConcatV2);
DECL_CONVERT_OP(MatMul);
DECL_CONVERT_OP(MatrixDiagV2);
DECL_CONVERT_OP(MatrixDiagV3);
DECL_CONVERT_OP(Pack);
DECL_CONVERT_OP(Split);
DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(RandomUniform);
DECL_CONVERT_OP(Conv3D);
DECL_CONVERT_OP(Conv3DBackpropInputV2);

#undef DECL_CONVERT_OP

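// Folds tf.RandomUniform ops that carry a fixed, non-zero seed and produce a
// fully static f32 shape into a constant tensor, by replaying TensorFlow's
// Philox-based uniform distribution at compile time. Ops with default seeds
// or dynamic output shapes are left untouched.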
LogicalResult ConvertTFRandomUniformOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto random_uniform_op = cast<TF::RandomUniformOp>(op);
  if (random_uniform_op.seed() == 0 && random_uniform_op.seed2() == 0) {
    return failure();
  }
  if (!random_uniform_op.dtype().isF32()) {
    return failure();
  }
  typedef tensorflow::random::UniformDistribution<
      tensorflow::random::PhiloxRandom, float>
      Distribution;

  tensorflow::random::PhiloxRandom generator(random_uniform_op.seed(),
                                             random_uniform_op.seed2());
  Distribution dist;
  size_t num_elements = 0;
  if (auto output_type =
          random_uniform_op.output().getType().dyn_cast_or_null<ShapedType>()) {
    if (auto ranked_output = output_type.dyn_cast_or_null<RankedTensorType>()) {
      if (!ranked_output.hasRank() || ranked_output.getNumDynamicDims() != 0) {
        return failure();
      }
      num_elements = output_type.getNumElements();
      size_t offset = 0;
      size_t num_samples = Distribution::kResultElementCount;
      llvm::SmallVector<float, 32> data;
      data.resize(num_elements);
      while (offset < num_elements) {
        const typename Distribution::ResultType samples = dist(&generator);
        std::copy(&samples[0],
                  &samples[0] + std::min(num_samples, data.size() - offset),
                  &data[0] + offset);
        offset += num_samples;
      }
      auto output_data = DenseFPElementsAttr::get(output_type, data);
      rewriter.replaceOpWithNewOp<ConstantOp>(op, output_type, output_data);
      return success();
    }
  }
  return failure();
}

// Converts any IntegerAttr to an IntegerAttr of an i32 type.
// The value won't change in the new attribute, but if the value lies outside
// the range of i32, the function returns a failure.
LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) {
  if (attr.getType().isInteger(/*width=*/32)) {
    *attr_i32 = attr;
    return success();
  }

  int64_t value = attr.getInt();
  if (value > std::numeric_limits<int>::max() ||
      value < std::numeric_limits<int>::min()) {
    return failure();
  }

  *attr_i32 = IntegerAttr::get(
      IntegerType::get(attr.getContext(), /*width=*/32), value);
  return success();
}

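// As an illustrative sketch, the rewrite below turns
//   %0 = "tf.ConcatV2"(%a, %b, %axis)
//       : (tensor<1x2xf32>, tensor<1x2xf32>, tensor<i64>) -> tensor<2x2xf32>
// with %axis constant 0 into
//   %0 = "tfl.concatenation"(%a, %b)
//       {axis = 0 : i32, fused_activation_function = "NONE"}
//       : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>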
LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_concat_op = cast<TF::ConcatV2Op>(op);

  auto values = tf_concat_op.values();
  auto output_type = tf_concat_op.output().getType();
  // Extract the axis attribute from the constant axis tensor.
  ElementsAttr axis;
  if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure();
  IntegerAttr axis_int = ExtractSingleElementAsInteger(axis);

  // The "axis" operand could be an i64 tensor. Resolve it here.
  IntegerAttr axis_i32;
  if (failed(ConvertToI32Attr(axis_int, &axis_i32))) return failure();

  StringAttr fused_activation_function =
      StringAttr::get(rewriter.getContext(), "NONE");
  rewriter.replaceOpWithNewOp<ConcatenationOp>(
      op, output_type, values, axis_i32, fused_activation_function);
  return success();
}

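// Lowers tf.MatMul to tfl.fully_connected. TFL FullyConnected expects its
// weights in [num_units, input_size] layout, i.e. already transposed relative
// to MatMul's rhs, so the rhs is transposed whenever transpose_b is false
// (and the lhs is transposed whenever transpose_a is true).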
LogicalResult ConvertTFMatMulOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_matmul_op = cast<TF::MatMulOp>(op);
  auto lhs = op->getOperand(0);
  auto rhs = op->getOperand(1);
  auto transpose = [&](Value input) -> std::pair<LogicalResult, Value> {
    RankedTensorType type =
        input.getType().dyn_cast_or_null<RankedTensorType>();
    if (!type || type.getRank() != 2) return {failure(), nullptr};

    auto permute_attr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI32Type()), {1, 0});
    auto permute = rewriter.create<ConstantOp>(
        op->getLoc(), permute_attr.getType(), permute_attr);
    llvm::SmallVector<int64_t, 2> new_shape{type.getShape()[1],
                                            type.getShape()[0]};
    auto output = rewriter.create<TFL::TransposeOp>(
        op->getLoc(), RankedTensorType::get(new_shape, type.getElementType()),
        input, permute);
    return {success(), output};
  };

  // TODO(jpienaar): Remove once handled via dialect conversion.
  if (tf_matmul_op.transpose_a()) {
    LogicalResult result = success();
    std::tie(result, lhs) = transpose(lhs);
    if (failed(result)) return failure();
  }
  if (!tf_matmul_op.transpose_b()) {
    LogicalResult result = success();
    std::tie(result, rhs) = transpose(rhs);
    if (failed(result)) return failure();
  }

  Type output_type = tf_matmul_op.getResult().getType();
  auto no_input = rewriter.create<ConstantOp>(
      op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
  auto fc_op = rewriter.create<FullyConnectedOp>(
      op->getLoc(), ArrayRef<Type>{output_type}, lhs, rhs, no_input,
      rewriter.getStringAttr("NONE"), rewriter.getStringAttr("DEFAULT"),
      rewriter.getBoolAttr(false));
  rewriter.replaceOp(op, {fc_op.getResult(0)});
  return success();
}

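// For example (illustrative only), "tf.Pack"(%a, %b) {axis = 0 : i64, N = 2 : i64}
// maps directly onto "tfl.pack"(%a, %b) {axis = 0 : i32, values_count = 2 : i32}.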
LogicalResult ConvertTFPackOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_pack_op = cast<TF::PackOp>(op);

  SmallVector<Value, 4> values(tf_pack_op.values());
  auto output_type = tf_pack_op.output().getType();
  auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N());
  // Axis can be negative.
  auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis());

  rewriter.replaceOpWithNewOp<PackOp>(op, output_type, values, values_count,
                                      axis);
  return success();
}

LogicalResult ConvertTFSplitOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_split_op = cast<TF::SplitOp>(op);

  // Number of splits cannot be negative.
  auto num_split = rewriter.getI32IntegerAttr(tf_split_op.num_split());

  rewriter.replaceOpWithNewOp<TFL::SplitOp>(op, tf_split_op.output().getTypes(),
                                            tf_split_op.split_dim(),
                                            tf_split_op.value(), num_split);
  return success();
}

LogicalResult ConvertTFSplitVOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_splitv_op = cast<TF::SplitVOp>(op);

  // Number of splits cannot be negative.
  auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.num_split());

  rewriter.replaceOpWithNewOp<TFL::SplitVOp>(
      op, tf_splitv_op.output().getTypes(), tf_splitv_op.value(),
      tf_splitv_op.size_splits(), tf_splitv_op.split_dim(), num_split);
  return success();
}

LogicalResult ConvertTFUnpackOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_unpack_op = cast<TF::UnpackOp>(op);

  auto input = tf_unpack_op.value();
  auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num());
  // Axis can be negative.
  auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis());

  rewriter.replaceOpWithNewOp<UnpackOp>(op, tf_unpack_op.output().getTypes(),
                                        input, num, axis);
  return success();
}

LogicalResult ConvertTFConv3DOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  if (!TFDataFormatIsNDHWC(op)) return failure();

  auto tf_op = cast<TF::Conv3DOp>(op);

  IntegerAttr stride_depth, stride_height, stride_width;
  if (!TFIntListIs1XYZ1(op, "strides", &stride_depth, &stride_height,
                        &stride_width))
    return failure();

  IntegerAttr dilation_depth_factor, dilation_height_factor,
      dilation_width_factor;
  if (!TFIntListIs1XYZ1(op, "dilations", &dilation_depth_factor,
                        &dilation_height_factor, &dilation_width_factor)) {
    // If the 'dilations' attribute is missing, use the default value (1) for
    // the dilation depth, height and width factors.
    dilation_depth_factor = rewriter.getI32IntegerAttr(1);
    dilation_height_factor = rewriter.getI32IntegerAttr(1);
    dilation_width_factor = rewriter.getI32IntegerAttr(1);
  }

  StringAttr padding;
  if (!TFPaddingIsSameOrValid(op, &padding)) return failure();

  // TensorFlow Conv3D has no bias; optimization patterns that fuse Conv3D
  // with other ops can fill in the bias.
  Value none = rewriter.create<mlir::ConstantOp>(
      op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());

  rewriter.replaceOpWithNewOp<TFL::Conv3DOp>(
      op, tf_op.getType(), tf_op.input(), tf_op.filter(),
      /*bias=*/none, dilation_depth_factor, dilation_height_factor,
      dilation_width_factor,
      /*fused_activation_function=*/rewriter.getStringAttr("NONE"), padding,
      stride_depth, stride_height, stride_width);

  return success();
}

LogicalResult ConvertTFConv3DBackpropInputV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  if (!TFDataFormatIsNDHWC(op)) return failure();

  auto tf_op = cast<TF::Conv3DBackpropInputV2Op>(op);

  IntegerAttr stride_depth, stride_height, stride_width;
  if (!TFIntListIs1XYZ1(op, "strides", &stride_depth, &stride_height,
                        &stride_width))
    return failure();

  IntegerAttr dilation_depth_factor, dilation_height_factor,
      dilation_width_factor;
  if (!TFIntListIs1XYZ1(op, "dilations", &dilation_depth_factor,
                        &dilation_height_factor, &dilation_width_factor)) {
    // If the 'dilations' attribute is missing, use the default value (1) for
    // the dilation depth, height and width factors.
    dilation_depth_factor = rewriter.getI32IntegerAttr(1);
    dilation_height_factor = rewriter.getI32IntegerAttr(1);
    dilation_width_factor = rewriter.getI32IntegerAttr(1);
  }

  StringAttr padding;
  if (!TFPaddingIsSameOrValid(op, &padding)) return failure();

  // TensorFlow Conv3D has no bias; optimization patterns that fuse Conv3D
  // with other ops can fill in the bias.
  Value none = rewriter.create<mlir::ConstantOp>(
      op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());

  Value output_shape =
      CreateCastToInt32(tf_op.input_sizes(), op->getLoc(), rewriter);

  rewriter.replaceOpWithNewOp<TFL::Conv3DTransposeOp>(
      op, tf_op.getType(), output_shape, tf_op.filter(), tf_op.out_backprop(),
      /*bias=*/none, dilation_depth_factor, dilation_height_factor,
      dilation_width_factor,
      /*fused_activation_function=*/rewriter.getStringAttr("NONE"), padding,
      stride_depth, stride_height, stride_width);

  return success();
}

// MatrixDiagV3 is MatrixDiagV2 with an alignment attribute. This attribute
// only has effects when processing multiple diagonals. Since TFLite converts
// MatrixDiagV{2,3} to MatrixDiag, which only takes single-diagonal inputs, we
// can safely ignore this V3 attribute.
// We can't pass `rewriter` by reference because clang-tidy will want it to be
// constant (`const PatternRewriter& rewriter`). If we do that, we won't be able
// to call `PatternRewriter::replaceOpWithNewOp`, which is not a const member
// function.
template <typename MatrixDiagV2OrV3Op>
bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) {
  auto tf_matrix_diag_v2_or_v3_op = cast<MatrixDiagV2OrV3Op>(op);

  if (tf_matrix_diag_v2_or_v3_op.getNumOperands() != 5) return false;

  auto input = tf_matrix_diag_v2_or_v3_op.diagonal();
  auto output_type = tf_matrix_diag_v2_or_v3_op.output().getType();

  // Extract the constant k tensor and check that its value is 0.
  ElementsAttr k;
  if (!matchPattern(tf_matrix_diag_v2_or_v3_op.k(), m_Constant(&k)))
    return false;
  if (ExtractSingleElementAsInteger(k).getInt() != 0) return false;

  // Extract the constant num_rows tensor and check that its value is -1.
  ElementsAttr num_rows;
  if (!matchPattern(tf_matrix_diag_v2_or_v3_op.num_rows(),
                    m_Constant(&num_rows)))
    return false;
  if (ExtractSingleElementAsInteger(num_rows).getInt() != -1) return false;

  // Extract the constant num_cols tensor and check that its value is -1.
  ElementsAttr num_cols;
  if (!matchPattern(tf_matrix_diag_v2_or_v3_op.num_cols(),
                    m_Constant(&num_cols)))
    return false;
  if (ExtractSingleElementAsInteger(num_cols).getInt() != -1) return false;

  // Verify padding_value is an integer tensor with all 0s.
  ElementsAttr padding_value;
  if (!matchPattern(tf_matrix_diag_v2_or_v3_op.padding_value(),
                    m_Constant(&padding_value)))
    return false;
  for (const auto& value : padding_value.getValues<APInt>()) {
    if (value != 0) return false;
  }

  rewriter->replaceOpWithNewOp<MatrixDiagOp>(op, output_type, input);
  return true;
}

LogicalResult ConvertTFMatrixDiagV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  if (ConvertTFMatrixDiagV2orV3<TF::MatrixDiagV2Op>(op, &rewriter))
    return success();
  return failure();
}

LogicalResult ConvertTFMatrixDiagV3Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  if (ConvertTFMatrixDiagV2orV3<TF::MatrixDiagV3Op>(op, &rewriter))
    return success();
  return failure();
}

// TF Lite doesn't support Assert, so we just drop the assert from the graph.
LogicalResult ConvertTFAssertOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  rewriter.eraseOp(op);
  return success();
}

// Legalize unidirectional sequence lstm.
struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
  explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context)
      : RewritePattern(kUnidirectionalSequenceLstm, 1, context) {}

  LogicalResult matchAndRewrite(Operation* op,
                                PatternRewriter& rewriter) const override {
    auto tflite_indices_attr =
        op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices);
    if (!tflite_indices_attr) return failure();

    SmallVector<int64_t, 20> tflite_indices;
    for (auto index_attr : tflite_indices_attr.getValue()) {
      IntegerAttr index = index_attr.cast<IntegerAttr>();
      tflite_indices.push_back(index.getInt());
    }

    // Optional input placeholder.
    Value none = rewriter.create<mlir::ConstantOp>(
        op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());

    // Populate inputs.
    // UnidirectionalSequenceLstm is expected to have 24 inputs.
    SmallVector<Value, 24> inputs;
    int count = 0;
    int total_ophint_converted_inputs = tflite_indices.size();
    for (int i = 0; i < 24; ++i) {
      if (count < total_ophint_converted_inputs && tflite_indices[count] == i) {
        // Specified input.
        inputs.push_back(op->getOperand(i));
        count++;
      } else {
        // Non-specified input.
        inputs.push_back(none);
      }
    }

    // Populate outputs.
    // UnidirectionalSequenceLstm should only have 1 output, and that is the
    // original ophint converted node's 3rd output.
    SmallVector<Type, 4> result_types;
    result_types.push_back(op->getOpResult(2).getType());

    // Populate attributes.
    SmallVector<NamedAttribute, 4> attributes;
    // Activation will always be tanh.
    attributes.push_back(rewriter.getNamedAttr("fused_activation_function",
                                               rewriter.getStringAttr("TANH")));
    // cell_clip.
    attributes.push_back(
        rewriter.getNamedAttr("cell_clip", rewriter.getF32FloatAttr(0.0)));
    // proj_clip.
    attributes.push_back(
        rewriter.getNamedAttr("proj_clip", rewriter.getF32FloatAttr(0.0)));
    // Will always be time-majored.
    attributes.push_back(
        rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true)));

    Value lstm_result = rewriter.create<TFL::UnidirectionalSequenceLSTMOp>(
        op->getLoc(), result_types, inputs, attributes);

    // Rewire the output.
    rewriter.replaceOp(op, {nullptr, nullptr, lstm_result});
    return success();
  }
};

// Legalize unidirectional sequence rnn.
struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
  explicit LegalizeUnidirectionalSequenceRnn(MLIRContext* context)
      : RewritePattern(kUnidirectionalSequenceRnn, 1, context) {}

  LogicalResult matchAndRewrite(Operation* op,
                                PatternRewriter& rewriter) const override {
    auto tflite_indices_attr =
        op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices);
    if (!tflite_indices_attr) return failure();

    if (op->getNumOperands() != 5) {
      op->emitError()
          << "We're expecting 5 inputs for UnidirectionalSequenceRNN, only "
          << op->getNumOperands() << " provided";
      return failure();
    }

    if (op->getNumResults() != 2) {
      op->emitError()
          << "We're expecting 2 outputs for UnidirectionalSequenceRNN, only "
          << op->getNumResults() << " found";
      return failure();
    }

    // Populate inputs.
    // UnidirectionalSequenceRnn is expected to have 5 inputs, and none of them
    // are optional inputs.
    SmallVector<Value, 5> inputs;
    for (int i = 0; i < 5; ++i) {
      inputs.push_back(op->getOperand(i));
    }

    // Populate outputs.
    // UnidirectionalSequenceRnn should only have 1 output, and that is the
    // original ophint converted node's 2nd output.
    SmallVector<Type, 4> result_types;
    result_types.push_back(op->getOpResult(1).getType());

    // Populate attributes.
    SmallVector<NamedAttribute, 2> attributes;
    // Activation will always be tanh.
    attributes.push_back(rewriter.getNamedAttr("fused_activation_function",
                                               rewriter.getStringAttr("TANH")));

    // Will always be time-majored.
    attributes.push_back(
        rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true)));

    Value rnn_result = rewriter.create<TFL::UnidirectionalSequenceRNNOp>(
        op->getLoc(), result_types, inputs, attributes);

    // Rewire the output.
    rewriter.replaceOp(op, {nullptr, rnn_result});

    return success();
  }
};

// Inserts two TFL BroadcastTo ops in front of the given TF binary broadcasting
// op so that the conversion of binary broadcast-able ops always succeeds and
// does not require the flex delegate.
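// For example (illustrative sketch), with static shapes
//   %0 = "tf.AddV2"(%a, %b) : (tensor<2x3xf32>, tensor<3xf32>) -> tensor<2x3xf32>
// is rewritten as
//   %s = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
//   %b2 = "tf.BroadcastTo"(%b, %s)
//       : (tensor<3xf32>, tensor<2xi64>) -> tensor<2x3xf32>
//   %0 = "tf.AddV2"(%a, %b2)
//       : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
// The re-created op and the tf.BroadcastTo ops are then legalized by the
// regular patterns.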
template <typename SourceOp>
class ApplyExplicitBroadcasting : public OpRewritePattern<SourceOp> {
 public:
  using OpRewritePattern<SourceOp>::OpRewritePattern;

  LogicalResult rewriteOpWithDynamicInput(Operation* op,
                                          PatternRewriter& rewriter) const {
    auto lhs = op->getOperand(0);
    auto rhs = op->getOperand(1);
    auto out = op->getResult(0);

    // Calculates symbolic broadcast shape that is only used in types.
    SmallVector<int64_t, 4> symbolic_broadcast_shape;
    if (!OpTrait::util::getBroadcastedShape(
            lhs.getType().cast<ShapedType>().getShape(),
            rhs.getType().cast<ShapedType>().getShape(),
            symbolic_broadcast_shape)) {
      return failure();
    }

    // Calculates the broadcast shape using BroadcastArgs op.
    Value lhs_shape = GetShape(lhs, op->getLoc(), rewriter);
    Value rhs_shape = GetShape(rhs, op->getLoc(), rewriter);
    auto broadcast_shape =
        rewriter
            .create<TF::BroadcastArgsOp>(
                op->getLoc(),
                RankedTensorType::get(symbolic_broadcast_shape.size(),
                                      rewriter.getIntegerType(64)),
                lhs_shape, rhs_shape)
            .r0();

    // Broadcasts inputs using BroadcastTo op.
    auto broadcast_type = RankedTensorType::get(
        symbolic_broadcast_shape, getElementTypeOrSelf(lhs.getType()));
    auto broadcasted_lhs =
        rewriter
            .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs,
                                       broadcast_shape)
            .output();
    auto broadcasted_rhs =
        rewriter
            .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs,
                                       broadcast_shape)
            .output();

    // Recreate an op with the above BroadcastTo op results.
    RankedTensorType result_type = RankedTensorType::get(
        symbolic_broadcast_shape, getElementTypeOrSelf(out.getType()));
    rewriter.replaceOpWithNewOp<SourceOp>(op, result_type, broadcasted_lhs,
                                          broadcasted_rhs);
    return success();
  }

  LogicalResult matchAndRewrite(SourceOp src_op,
                                PatternRewriter& rewriter) const override {
    Operation* op = static_cast<Operation*>(src_op);
    auto lhs = op->getOperand(0);
    auto rhs = op->getOperand(1);

    if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
        !rhs.getType().cast<ShapedType>().hasStaticShape()) {
      return rewriteOpWithDynamicInput(op, rewriter);
    }

    auto lhs_shape = lhs.getType().cast<ShapedType>().getShape();
    auto rhs_shape = rhs.getType().cast<ShapedType>().getShape();

    if (lhs_shape == rhs_shape) {
      return failure();
    }

    // Calculate the broadcasted shape.
    SmallVector<int64_t, 4> result_shape;
    if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape,
                                            result_shape)) {
      return failure();
    }

    RankedTensorType result_type = RankedTensorType::get(
        result_shape, getElementTypeOrSelf(op->getResult(0).getType()));

    // Create a const op that stores the above broadcasted shape.
    auto new_shape_attr = mlir::DenseIntElementsAttr::get(
        RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64)),
        result_shape);
    auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);

    // Apply BroadcastTo ops to each input.
    auto broadcast_type = RankedTensorType::get(
        result_shape, getElementTypeOrSelf(lhs.getType()));

    if (result_type.getShape() != lhs_shape) {
      lhs = rewriter
                .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs,
                                           new_shape)
                .output();
    }
    if (result_type.getShape() != rhs_shape) {
      rhs = rewriter
                .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs,
                                           new_shape)
                .output();
    }

    // Recreate an op with the above Broadcast op results.
    rewriter.replaceOpWithNewOp<SourceOp>(op, result_type, lhs, rhs);
    return success();
  }
};

// This specialization is for the TF SelectV2 op. SelectV2 has three inputs,
// and they should have broadcastable shapes.
template <>
class ApplyExplicitBroadcasting<TF::SelectV2Op>
    : public OpRewritePattern<TF::SelectV2Op> {
 public:
  using OpRewritePattern<TF::SelectV2Op>::OpRewritePattern;

  LogicalResult rewriteOpWithDynamicInput(Operation* op,
                                          PatternRewriter& rewriter) const {
    auto cond = op->getOperand(0);
    auto lhs = op->getOperand(1);
    auto rhs = op->getOperand(2);
    auto out = op->getResult(0);

    // Calculates symbolic broadcast shape that is only used in types.
    SmallVector<int64_t, 4> symbolic_broadcast_lhs_rhs_shape;
    if (!OpTrait::util::getBroadcastedShape(
            lhs.getType().cast<ShapedType>().getShape(),
            rhs.getType().cast<ShapedType>().getShape(),
            symbolic_broadcast_lhs_rhs_shape)) {
      return failure();
    }
    SmallVector<int64_t, 4> symbolic_broadcast_shape;
    if (!OpTrait::util::getBroadcastedShape(
            cond.getType().cast<ShapedType>().getShape(),
            symbolic_broadcast_lhs_rhs_shape, symbolic_broadcast_shape)) {
      return failure();
    }

    // Calculates the broadcast shape using BroadcastArgs op.
    Value cond_shape = GetShape(cond, op->getLoc(), rewriter);
    Value lhs_shape = GetShape(lhs, op->getLoc(), rewriter);
    Value rhs_shape = GetShape(rhs, op->getLoc(), rewriter);
    auto broadcast_shape_value =
        rewriter
            .create<TF::BroadcastArgsOp>(op->getLoc(), lhs_shape.getType(),
                                         lhs_shape, rhs_shape)
            .r0();
    broadcast_shape_value =
        rewriter
            .create<TF::BroadcastArgsOp>(op->getLoc(), lhs_shape.getType(),
                                         broadcast_shape_value, cond_shape)
            .r0();

    // Broadcast inputs using the BroadcastTo op.
    auto broadcast_type = RankedTensorType::get(
        symbolic_broadcast_shape, getElementTypeOrSelf(out.getType()));
    auto broadcasted_cond =
        rewriter
            .create<TF::BroadcastToOp>(
                op->getLoc(),
                RankedTensorType::get(symbolic_broadcast_shape,
                                      rewriter.getIntegerType(1)),
                cond, broadcast_shape_value)
            .output();
    auto broadcasted_lhs =
        rewriter
            .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs,
                                       broadcast_shape_value)
            .output();
    auto broadcasted_rhs =
        rewriter
            .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs,
                                       broadcast_shape_value)
            .output();

    // Recreate an op with the above BroadcastTo op results.
    rewriter.replaceOpWithNewOp<TF::SelectV2Op>(
        op, broadcast_type, broadcasted_cond, broadcasted_lhs, broadcasted_rhs);
    return success();
  }

  LogicalResult matchAndRewrite(TF::SelectV2Op src_op,
                                PatternRewriter& rewriter) const override {
    Operation* op = static_cast<Operation*>(src_op);
    auto cond = op->getOperand(0);
    auto lhs = op->getOperand(1);
    auto rhs = op->getOperand(2);

    // Should have static shapes to calculate the broadcasted shape.
    if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
        !rhs.getType().cast<ShapedType>().hasStaticShape() ||
        !cond.getType().cast<ShapedType>().hasStaticShape()) {
      return rewriteOpWithDynamicInput(op, rewriter);
    }

    auto lhs_shape = lhs.getType().cast<ShapedType>().getShape();
    auto rhs_shape = rhs.getType().cast<ShapedType>().getShape();
    auto cond_shape = cond.getType().cast<ShapedType>().getShape();

    if (lhs_shape == rhs_shape && cond_shape == lhs_shape) {
      return failure();
    }

    // Calculate the broadcasted shape.
    SmallVector<int64_t, 4> broadcasted_shape;
    if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape,
                                            broadcasted_shape)) {
      return failure();
    }

    SmallVector<int64_t, 4> result_shape;
    if (!OpTrait::util::getBroadcastedShape(broadcasted_shape, cond_shape,
                                            result_shape)) {
      return failure();
    }

    // Create a const op that stores the above broadcasted shape.
    auto shape_type =
        RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64));
    auto new_shape_attr =
        mlir::DenseIntElementsAttr::get(shape_type, result_shape);
    auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);

    // Apply BroadcastTo ops to each input.
    auto cond_result_type =
        RankedTensorType::get(result_shape, rewriter.getIntegerType(1));
    auto result_type = RankedTensorType::get(
        result_shape, getElementTypeOrSelf(lhs.getType()));

    if (result_shape != cond_shape) {
      cond = rewriter
                 .create<TF::BroadcastToOp>(op->getLoc(), cond_result_type,
                                            cond, new_shape)
                 .output();
    }
    if (result_shape != lhs_shape) {
      lhs = rewriter
                .create<TF::BroadcastToOp>(op->getLoc(), result_type, lhs,
                                           new_shape)
                .output();
    }
    if (result_shape != rhs_shape) {
      rhs = rewriter
                .create<TF::BroadcastToOp>(op->getLoc(), result_type, rhs,
                                           new_shape)
                .output();
    }

    // Recreate an op with the above Broadcast op results.
    rewriter.replaceOpWithNewOp<TF::SelectV2Op>(op, result_type, cond, lhs,
                                                rhs);
    return success();
  }
};

void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) {
  // Add TF->TF lowering patterns.
  TF::PopulateLoweringTFPatterns(context, &patterns);

  // Add the generated patterns to the list.
  populateWithGenerated(patterns);
  patterns
      .insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
              ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFSplitOp,
              ConvertTFSplitVOp, ConvertTFUnpackOp, ConvertTFAssertOp,
              ConvertTFRandomUniformOp, ConvertTFConv3DOp,
              ConvertTFConv3DBackpropInputV2Op>(context);

  // Patterns for TF nodes produced by the ophint Python converter.
  patterns.insert<LegalizeUnidirectionalSequenceLstm,
                  LegalizeUnidirectionalSequenceRnn>(context);
}

void applyPatterns(FuncOp func, ConversionTarget& target,
                   FrozenRewritePatternSet& frozenPatterns) {
  // Keep trying to convert.
  // TODO(karimnosseir): This is similar to what applying greedy patterns does.
  // Look for a function that keeps trying until it converges.
  // Currently the unit test doesn't do multiple tries, so we need this.
  const int max_iterations = 15;
  for (int i = 0; i < max_iterations; ++i) {
    if (failed(applyPartialConversion(func, target, frozenPatterns))) {
      return;
    }
  }
}

void LegalizeTF::runOnFunction() {
  auto* context = &getContext();
  auto func = getFunction();

  ConversionTarget target(*context);
  // It is legal to still have TF ops in the graph: they can be used later, and
  // in the case of SELECT we allow TF ops in the final graph.
  target.addLegalOp<mlir::ConstantOp>();
  target.addLegalOp<ConstOp>();
  if (run_tfl_runtime_verification_) {
    target.addDynamicallyLegalDialect<TensorFlowLiteDialect>([](Operation* op) {
      auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
      if (!tfl_op) return false;
      return succeeded(tfl_op.VerifyTflRuntimeConstraints(op));
    });
  } else {
    target.addLegalDialect<TensorFlowLiteDialect>();
  }

  // Ignore transient errors by registering a no-op handler.
  // Applying legalization patterns will emit unwanted, transient errors when
  // the replaced TFLite ops do not meet the sanity checks. In order to ignore
  // the transient errors, the following lines override the diagnostic handler
  // with a no-op handler only while this pass runs.
  uint64_t current_thread_id = llvm::get_threadid();
  ScopedDiagnosticHandler scoped_diag_handler(
      context, [&current_thread_id](Diagnostic&) -> LogicalResult {
        // Consume only errors that are coming from the same thread in order
        // not to ignore errors from other passes that are running. Things
        // running in the pass manager can be multi-threaded.
        return success(current_thread_id == llvm::get_threadid());
      });

  OwningRewritePatternList stage1Patterns(&getContext());

  addPatterns(context, stage1Patterns);

  FrozenRewritePatternSet stage1FrozenPatterns(std::move(stage1Patterns));
  applyPatterns(func, target, stage1FrozenPatterns);

  // Explicit BroadcastTo addition for left-over broadcast-able ops.
  // The following pattern matchings should be done after the other
  // legalization rules in order not to add unnecessary BroadcastTo ops.
  OwningRewritePatternList stage2Patterns(&getContext());

  addPatterns(context, stage2Patterns);

  stage2Patterns.insert<ApplyExplicitBroadcasting<TF::LessEqualOp>,
                        ApplyExplicitBroadcasting<TF::GreaterEqualOp>,
                        ApplyExplicitBroadcasting<TF::NotEqualOp>,
                        ApplyExplicitBroadcasting<TF::GreaterOp>,
                        ApplyExplicitBroadcasting<TF::LessOp>,
                        ApplyExplicitBroadcasting<TF::EqualOp>,
                        ApplyExplicitBroadcasting<TF::AddOp>,
                        ApplyExplicitBroadcasting<TF::AddV2Op>,
                        ApplyExplicitBroadcasting<TF::MulOp>,
                        ApplyExplicitBroadcasting<TF::DivOp>,
                        ApplyExplicitBroadcasting<TF::RealDivOp>,
                        ApplyExplicitBroadcasting<TF::SubOp>,
                        ApplyExplicitBroadcasting<TF::FloorDivOp>,
                        ApplyExplicitBroadcasting<TF::FloorModOp>,
                        ApplyExplicitBroadcasting<TF::PowOp>,
                        ApplyExplicitBroadcasting<TF::MaximumOp>,
                        ApplyExplicitBroadcasting<TF::MinimumOp>,
                        ApplyExplicitBroadcasting<TF::SquaredDifferenceOp>,
                        ApplyExplicitBroadcasting<TF::SelectV2Op>>(context);

  FrozenRewritePatternSet stage2FrozenPatterns(std::move(stage2Patterns));
  applyPatterns(func, target, stage2FrozenPatterns);
}

}  // namespace

// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
    bool run_tfl_runtime_verification) {
  return std::make_unique<LegalizeTF>(run_tfl_runtime_verification);
}

static PassRegistration<LegalizeTF> pass;

}  // namespace TFL
}  // namespace mlir