/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass converts operations in the TensorFlow dialect into
// operations that are legal in the TensorFlow Lite dialect.  Operations that
// can be legalized to the TensorFlow Lite dialect with simple replacements are
// part of this pass; other operations that may create extra ops should be part
// of the PrepareTF pass, which should be run before this pass.  That way any
// constant folding opportunities from the extra ops can be exploited by the
// constant folding support for the TensorFlow ops.

#include <climits>
#include <complex>
#include <cstdint>
#include <utility>

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Threading.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/UniformSupport.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace mlir {
namespace TFL {

//===----------------------------------------------------------------------===//
// The actual LegalizeTF Pass.
namespace {

constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm";
constexpr char kUnidirectionalSequenceRnn[] = "tf.UnidirectionalSequenceRnn";
constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";

// Legalize operations in functions.
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<quant::QuantizationDialect, TFL::TensorFlowLiteDialect>();
  }

 public:
  LegalizeTF() = default;
  LegalizeTF(const LegalizeTF&) {}
  explicit LegalizeTF(bool run_tfl_runtime_verification) {
    run_tfl_runtime_verification_ = run_tfl_runtime_verification;
  }

  StringRef getArgument() const final {
    // This is the argument used to refer to the pass in
    // the textual format (on the commandline for example).
    return "tfl-legalize-tf";
  }
  StringRef getDescription() const final {
    // This is a brief description of the pass.
    return "Legalize from TensorFlow to TensorFlow Lite dialect";
  }

  /// Performs the lowering to TFLite dialect.
  void runOnFunction() override;

 private:
  Option<bool> run_tfl_runtime_verification_{
      *this, "run-tfl-runtime-verification",
      llvm::cl::desc("Allow tfl runtime verification."), llvm::cl::init(true)};
};

// Returns true if all tensor values in `values` have static shapes and the
// same shape.
bool HasSameStaticShapes(Operation* op) {
  auto values = op->getOperands();
  int index = 0;
  ArrayRef<int64_t> shape;
  for (Value value : values) {
    auto shaped_type = value.getType().dyn_cast<ShapedType>();
    if (!shaped_type || !shaped_type.hasStaticShape()) {
      return false;
    }
    if (index == 0) {
      shape = shaped_type.getShape();
    } else {
      if (shape != shaped_type.getShape()) {
        return false;
      }
    }
    ++index;
  }
  return true;
}

// Util that casts 'val' to Int32 by adding a cast Op.
Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) {
  IntegerType new_ele_type = rewriter.getIntegerType(32);
  if (auto shaped_type = val.getType().dyn_cast<RankedTensorType>()) {
    ShapedType new_type =
        RankedTensorType::get(shaped_type.getShape(), new_ele_type);
    return rewriter.createOrFold<TF::CastOp>(loc, new_type, val,
                                             rewriter.getBoolAttr(false));
  }
  return rewriter.createOrFold<TF::CastOp>(
      loc, UnrankedTensorType::get(new_ele_type), val,
      rewriter.getBoolAttr(false));
}

// Gets the shape of an operand or result; supports both dynamic and static
// shapes.
Value GetShape(Value input, Location loc, PatternRewriter& rewriter) {
  auto shaped_type = input.getType().cast<ShapedType>();
  if (shaped_type.hasStaticShape()) {
    auto static_shape = shaped_type.getShape();
    auto static_shape_type =
        RankedTensorType::get(static_shape.size(), rewriter.getIntegerType(64));
    auto static_shape_attr =
        mlir::DenseIntElementsAttr::get(static_shape_type, static_shape);
    return rewriter.create<TF::ConstOp>(loc, static_shape_attr).output();
  }

  // If the shape is not static, create a new ShapeOp.
  BoolAttr false_attr = rewriter.getBoolAttr(false);
  return rewriter
      .create<TF::ShapeOp>(loc, input,
                           /*use_32bit=*/false_attr)
      .output();
}

#include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc"

#define DECL_CONVERT_OP(tf_op)                                               \
  struct ConvertTF##tf_op##Op : public RewritePattern {                      \
    explicit ConvertTF##tf_op##Op(MLIRContext* context)                      \
        : RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {}   \
    LogicalResult matchAndRewrite(Operation* op,                             \
                                  PatternRewriter& rewriter) const override; \
  }

// TODO(antiagainst): Define this pattern in a table-driven manner once variadic
// operands are properly supported in declarative rewrite rule specification.

DECL_CONVERT_OP(Assert);
DECL_CONVERT_OP(ConcatV2);
DECL_CONVERT_OP(MatMul);
DECL_CONVERT_OP(MatrixDiagV2);
DECL_CONVERT_OP(MatrixDiagV3);
DECL_CONVERT_OP(Pack);
DECL_CONVERT_OP(Split);
DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(RandomUniform);
DECL_CONVERT_OP(Conv3D);
DECL_CONVERT_OP(Conv3DBackpropInputV2);

#undef DECL_CONVERT_OP

LogicalResult ConvertTFRandomUniformOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto random_uniform_op = cast<TF::RandomUniformOp>(op);
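  // Require explicit, non-zero seeds: with seed == 0 and seed2 == 0 the op is
  // non-deterministic, so it must not be folded into a fixed constant.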
  if (random_uniform_op.seed() == 0 && random_uniform_op.seed2() == 0) {
    return failure();
  }
  if (!random_uniform_op.dtype().isF32()) {
    return failure();
  }
  typedef tensorflow::random::UniformDistribution<
      tensorflow::random::PhiloxRandom, float>
      Distribution;

  tensorflow::random::PhiloxRandom generator(random_uniform_op.seed(),
                                             random_uniform_op.seed2());
  Distribution dist;
  size_t num_elements = 0;
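  // Fold only when the output shape is fully static, so the total number of
  // elements to generate is known here.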
  if (auto output_type =
          random_uniform_op.output().getType().dyn_cast_or_null<ShapedType>()) {
    if (auto ranked_output = output_type.dyn_cast_or_null<RankedTensorType>()) {
      if (!ranked_output.hasRank() || ranked_output.getNumDynamicDims() != 0) {
        return failure();
      }
      num_elements = output_type.getNumElements();
      size_t offset = 0;
      size_t num_samples = Distribution::kResultElementCount;
      llvm::SmallVector<float, 32> data;
      data.resize(num_elements);
      while (offset < num_elements) {
        const typename Distribution::ResultType samples = dist(&generator);
        std::copy(&samples[0],
                  &samples[0] + std::min(num_samples, data.size() - offset),
                  &data[0] + offset);
        offset += num_samples;
      }
      auto output_data = DenseFPElementsAttr::get(output_type, data);
      rewriter.replaceOpWithNewOp<ConstantOp>(op, output_type, output_data);
      return success();
    }
  }
  return failure();
}

// Converts any IntegerAttr to an IntegerAttr of an i32 type.
// The value won't change in the new attribute, but if the value is out of
// the bounds of i32, the function returns a failure.
LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) {
  if (attr.getType().isInteger(/*width=*/32)) {
    *attr_i32 = attr;
    return success();
  }

  int64_t value = attr.getInt();
  if (value > std::numeric_limits<int>::max() ||
      value < std::numeric_limits<int>::min()) {
    return failure();
  }

  *attr_i32 = IntegerAttr::get(
      IntegerType::get(attr.getContext(), /*width=*/32), value);
  return success();
}

LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_concat_op = cast<TF::ConcatV2Op>(op);

  auto values = tf_concat_op.values();
  auto output_type = tf_concat_op.output().getType();
  // Extract the axis attribute from the constant axis tensor.
  ElementsAttr axis;
  if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure();
  IntegerAttr axis_int = ExtractSingleElementAsInteger(axis);

  // The "axis" operand could be an i64 tensor. Resolve it here.
  IntegerAttr axis_i32;
  if (failed(ConvertToI32Attr(axis_int, &axis_i32))) return failure();

  StringAttr fused_activation_function =
      StringAttr::get(rewriter.getContext(), "NONE");
  rewriter.replaceOpWithNewOp<ConcatenationOp>(
      op, output_type, values, axis_i32, fused_activation_function);
  return success();
}

LogicalResult ConvertTFMatMulOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_matmul_op = cast<TF::MatMulOp>(op);
  auto lhs = op->getOperand(0);
  auto rhs = op->getOperand(1);
  auto transpose = [&](Value input) -> std::pair<LogicalResult, Value> {
    RankedTensorType type =
        input.getType().dyn_cast_or_null<RankedTensorType>();
    if (!type || type.getRank() != 2) return {failure(), nullptr};

    auto permute_attr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI32Type()), {1, 0});
    auto permute = rewriter.create<ConstantOp>(
        op->getLoc(), permute_attr.getType(), permute_attr);
    llvm::SmallVector<int64_t, 2> new_shape{type.getShape()[1],
                                            type.getShape()[0]};
    auto output = rewriter.create<TFL::TransposeOp>(
        op->getLoc(), RankedTensorType::get(new_shape, type.getElementType()),
        input, permute);
    return {success(), output};
  };

  // TODO(jpienaar): Remove once handled via dialect conversion.
  if (tf_matmul_op.transpose_a()) {
    LogicalResult result = success();
    std::tie(result, lhs) = transpose(lhs);
    if (failed(result)) return failure();
  }
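  // Note: TFL FullyConnected expects its weights (rhs) in a transposed
  // [num_units, input_size] layout, so the rhs is transposed when TF MatMul's
  // transpose_b is *not* set.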
  if (!tf_matmul_op.transpose_b()) {
    LogicalResult result = success();
    std::tie(result, rhs) = transpose(rhs);
    if (failed(result)) return failure();
  }

  Type output_type = tf_matmul_op.getResult().getType();
  auto no_input = rewriter.create<ConstantOp>(
      op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
  auto fc_op = rewriter.create<FullyConnectedOp>(
      op->getLoc(), ArrayRef<Type>{output_type}, lhs, rhs, no_input,
      rewriter.getStringAttr("NONE"), rewriter.getStringAttr("DEFAULT"),
      rewriter.getBoolAttr(false));
  rewriter.replaceOp(op, {fc_op.getResult(0)});
  return success();
}

LogicalResult ConvertTFPackOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_pack_op = cast<TF::PackOp>(op);

  SmallVector<Value, 4> values(tf_pack_op.values());
  auto output_type = tf_pack_op.output().getType();
  auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N());
  // Axis can be negative.
  auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis());

  rewriter.replaceOpWithNewOp<PackOp>(op, output_type, values, values_count,
                                      axis);
  return success();
}

LogicalResult ConvertTFSplitOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_split_op = cast<TF::SplitOp>(op);

  // Number of splits cannot be negative.
  auto num_split = rewriter.getI32IntegerAttr(tf_split_op.num_split());

  rewriter.replaceOpWithNewOp<TFL::SplitOp>(op, tf_split_op.output().getTypes(),
                                            tf_split_op.split_dim(),
                                            tf_split_op.value(), num_split);
  return success();
}

LogicalResult ConvertTFSplitVOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_splitv_op = cast<TF::SplitVOp>(op);

  // Number of splits cannot be negative.
  auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.num_split());

  rewriter.replaceOpWithNewOp<TFL::SplitVOp>(
      op, tf_splitv_op.output().getTypes(), tf_splitv_op.value(),
      tf_splitv_op.size_splits(), tf_splitv_op.split_dim(), num_split);
  return success();
}

LogicalResult ConvertTFUnpackOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_unpack_op = cast<TF::UnpackOp>(op);

  auto input = tf_unpack_op.value();
  auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num());
  // Axis can be negative.
  auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis());

  rewriter.replaceOpWithNewOp<UnpackOp>(op, tf_unpack_op.output().getTypes(),
                                        input, num, axis);
  return success();
}

LogicalResult ConvertTFConv3DOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  if (!TFDataFormatIsNDHWC(op)) return failure();

  auto tf_op = cast<TF::Conv3DOp>(op);

  IntegerAttr stride_depth, stride_height, stride_width;
  if (!TFIntListIs1XYZ1(op, "strides", &stride_depth, &stride_height,
                        &stride_width))
    return failure();

  IntegerAttr dilation_depth_factor, dilation_height_factor,
      dilation_width_factor;
  if (!TFIntListIs1XYZ1(op, "dilations", &dilation_depth_factor,
                        &dilation_height_factor, &dilation_width_factor)) {
    // If the 'dilations' attribute is missing, we use the default value (1)
    // for the dilation depth, height and width factors.
    dilation_depth_factor = rewriter.getI32IntegerAttr(1);
    dilation_height_factor = rewriter.getI32IntegerAttr(1);
    dilation_width_factor = rewriter.getI32IntegerAttr(1);
  }

  StringAttr padding;
  if (!TFPaddingIsSameOrValid(op, &padding)) return failure();

  // TensorFlow Conv3D has no bias; optimization patterns that fuse Conv3D
  // with other ops can fill in the bias.
  Value none = rewriter.create<mlir::ConstantOp>(
      op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());

  rewriter.replaceOpWithNewOp<TFL::Conv3DOp>(
      op, tf_op.getType(), tf_op.input(), tf_op.filter(),
      /*bias=*/none, dilation_depth_factor, dilation_height_factor,
      dilation_width_factor,
      /*fused_activation_function=*/rewriter.getStringAttr("NONE"), padding,
      stride_depth, stride_height, stride_width);

  return success();
}

LogicalResult ConvertTFConv3DBackpropInputV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  if (!TFDataFormatIsNDHWC(op)) return failure();

  auto tf_op = cast<TF::Conv3DBackpropInputV2Op>(op);

  IntegerAttr stride_depth, stride_height, stride_width;
  if (!TFIntListIs1XYZ1(op, "strides", &stride_depth, &stride_height,
                        &stride_width))
    return failure();

  IntegerAttr dilation_depth_factor, dilation_height_factor,
      dilation_width_factor;
  if (!TFIntListIs1XYZ1(op, "dilations", &dilation_depth_factor,
                        &dilation_height_factor, &dilation_width_factor)) {
    // If the 'dilations' attribute is missing, we use the default value (1)
    // for the dilation depth, height and width factors.
    dilation_depth_factor = rewriter.getI32IntegerAttr(1);
    dilation_height_factor = rewriter.getI32IntegerAttr(1);
    dilation_width_factor = rewriter.getI32IntegerAttr(1);
  }

  StringAttr padding;
  if (!TFPaddingIsSameOrValid(op, &padding)) return failure();

  // TensorFlow Conv3D has no bias; optimization patterns that fuse Conv3D
  // with other ops can fill in the bias.
  Value none = rewriter.create<mlir::ConstantOp>(
      op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());

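  // Cast the requested output shape to i32: TF's input_sizes operand may be an
  // i64 tensor, while the TFL op takes an i32 shape tensor.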
  Value output_shape =
      CreateCastToInt32(tf_op.input_sizes(), op->getLoc(), rewriter);

  rewriter.replaceOpWithNewOp<TFL::Conv3DTransposeOp>(
      op, tf_op.getType(), output_shape, tf_op.filter(), tf_op.out_backprop(),
      /*bias=*/none, dilation_depth_factor, dilation_height_factor,
      dilation_width_factor,
      /*fused_activation_function=*/rewriter.getStringAttr("NONE"), padding,
      stride_depth, stride_height, stride_width);

  return success();
}

// MatrixDiagV3 is MatrixDiagV2 with an alignment attribute. This attribute
// only has effects when processing multiple diagonals. Since TFLite converts
// MatrixDiagV{2,3} to MatrixDiag, which only takes single-diagonal inputs, we
// can safely ignore this V3 attribute.
// We can't pass `rewriter` by reference because clang-tidy will want it to be
// constant (`const PatternRewriter& rewriter`). If we do that, we won't be able
// to call `rewriter.replaceOpWithNewOp`, which is not a const member function.
template <typename MatrixDiagV2OrV3Op>
bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) {
  auto tf_matrix_diag_v2_or_v3_op = cast<MatrixDiagV2OrV3Op>(op);

  if (tf_matrix_diag_v2_or_v3_op.getNumOperands() != 5) return false;

  auto input = tf_matrix_diag_v2_or_v3_op.diagonal();
  auto output_type = tf_matrix_diag_v2_or_v3_op.output().getType();

  // Extract the k constant tensor and check that its value is 0.
  ElementsAttr k;
  if (!matchPattern(tf_matrix_diag_v2_or_v3_op.k(), m_Constant(&k)))
    return false;
  if (ExtractSingleElementAsInteger(k).getInt() != 0) return false;

  // Extract the num_rows constant tensor and check that its value is -1.
  ElementsAttr num_rows;
  if (!matchPattern(tf_matrix_diag_v2_or_v3_op.num_rows(),
                    m_Constant(&num_rows)))
    return false;
  if (ExtractSingleElementAsInteger(num_rows).getInt() != -1) return false;

  // Extract the num_cols constant tensor and check that its value is -1.
  ElementsAttr num_cols;
  if (!matchPattern(tf_matrix_diag_v2_or_v3_op.num_cols(),
                    m_Constant(&num_cols)))
    return false;
  if (ExtractSingleElementAsInteger(num_cols).getInt() != -1) return false;

  // Verify that padding_value is an integer tensor with all 0s.
  ElementsAttr padding_value;
  if (!matchPattern(tf_matrix_diag_v2_or_v3_op.padding_value(),
                    m_Constant(&padding_value)))
    return false;
  for (const auto& value : padding_value.getValues<APInt>()) {
    if (value != 0) return false;
  }

  rewriter->replaceOpWithNewOp<MatrixDiagOp>(op, output_type, input);
  return true;
}

LogicalResult ConvertTFMatrixDiagV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  if (ConvertTFMatrixDiagV2orV3<TF::MatrixDiagV2Op>(op, &rewriter))
    return success();
  return failure();
}

LogicalResult ConvertTFMatrixDiagV3Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  if (ConvertTFMatrixDiagV2orV3<TF::MatrixDiagV3Op>(op, &rewriter))
    return success();
  return failure();
}

// TF Lite doesn't support Assert, so we just drop the assert from the graph.
LogicalResult ConvertTFAssertOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  rewriter.eraseOp(op);
  return success();
}

// Legalize unidirectional sequence lstm.
struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
  explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context)
      : RewritePattern(kUnidirectionalSequenceLstm, 1, context) {}

  LogicalResult matchAndRewrite(Operation* op,
                                PatternRewriter& rewriter) const override {
    auto tflite_indices_attr =
        op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices);
    if (!tflite_indices_attr) return failure();

    SmallVector<int64_t, 20> tflite_indices;
    for (auto index_attr : tflite_indices_attr.getValue()) {
      IntegerAttr index = index_attr.cast<IntegerAttr>();
      tflite_indices.push_back(index.getInt());
    }

    // Optional input placeholder.
    Value none = rewriter.create<mlir::ConstantOp>(
        op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());

    // Populate inputs.
    // UnidirectionalSequenceLstm is expected to have 24 inputs.
    SmallVector<Value, 24> inputs;
    int count = 0;
    int total_ophint_converted_inputs = tflite_indices.size();
    for (int i = 0; i < 24; ++i) {
      if (count < total_ophint_converted_inputs && tflite_indices[count] == i) {
        // Specified input.
        inputs.push_back(op->getOperand(i));
        count++;
      } else {
        // Non-specified input.
        inputs.push_back(none);
      }
    }

    // Populate outputs.
    // UnidirectionalSequenceLstm should only have 1 output, and that is the
    // original ophint converted node's 3rd output.
    SmallVector<Type, 4> result_types;
    result_types.push_back(op->getOpResult(2).getType());

    // Populate attributes.
    SmallVector<NamedAttribute, 4> attributes;
    // Activation will always be tanh.
    attributes.push_back(rewriter.getNamedAttr("fused_activation_function",
                                               rewriter.getStringAttr("TANH")));
    // cell_clip.
    attributes.push_back(
        rewriter.getNamedAttr("cell_clip", rewriter.getF32FloatAttr(0.0)));
    // proj_clip.
    attributes.push_back(
        rewriter.getNamedAttr("proj_clip", rewriter.getF32FloatAttr(0.0)));
    // Will always be time-majored.
    attributes.push_back(
        rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true)));

    Value lstm_result = rewriter.create<TFL::UnidirectionalSequenceLSTMOp>(
        op->getLoc(), result_types, inputs, attributes);

    // Rewire the output.
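    // The first two results of the ophint-converted op are unused; map them to
    // nullptr and forward only the LSTM result.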
    rewriter.replaceOp(op, {nullptr, nullptr, lstm_result});
    return success();
  }
};

// Legalize unidirectional sequence rnn.
struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
  explicit LegalizeUnidirectionalSequenceRnn(MLIRContext* context)
      : RewritePattern(kUnidirectionalSequenceRnn, 1, context) {}

  LogicalResult matchAndRewrite(Operation* op,
                                PatternRewriter& rewriter) const override {
    auto tflite_indices_attr =
        op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices);
    if (!tflite_indices_attr) return failure();

    if (op->getNumOperands() != 5) {
      op->emitError()
          << "We're expecting 5 inputs for UnidirectionalSequenceRNN, only "
          << op->getNumOperands() << " provided";
      return failure();
    }

    if (op->getNumResults() != 2) {
      op->emitError()
          << "We're expecting 2 outputs for UnidirectionalSequenceRNN, only "
          << op->getNumResults() << " found";
      return failure();
    }

    // Populate inputs.
    // UnidirectionalSequenceRnn is expected to have 5 inputs, and none of them
    // are optional inputs.
    SmallVector<Value, 5> inputs;
    for (int i = 0; i < 5; ++i) {
      inputs.push_back(op->getOperand(i));
    }

    // Populate outputs.
    // UnidirectionalSequenceRnn should only have 1 output, and that is the
    // original ophint converted node's 2nd output.
    SmallVector<Type, 4> result_types;
    result_types.push_back(op->getOpResult(1).getType());

    // Populate attributes.
    SmallVector<NamedAttribute, 2> attributes;
    // Activation will always be tanh.
    attributes.push_back(rewriter.getNamedAttr("fused_activation_function",
                                               rewriter.getStringAttr("TANH")));

    // Will always be time-majored.
    attributes.push_back(
        rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true)));

    Value rnn_result = rewriter.create<TFL::UnidirectionalSequenceRNNOp>(
        op->getLoc(), result_types, inputs, attributes);

    // Rewire the output.
    rewriter.replaceOp(op, {nullptr, rnn_result});

    return success();
  }
};

// Put two TFL BroadcastTo ops in front of the given TF binary broadcast op to
// make the binary broadcast-able op conversion always successful and not
// require the flex delegate.
template <typename SourceOp>
class ApplyExplicitBroadcasting : public OpRewritePattern<SourceOp> {
 public:
  using OpRewritePattern<SourceOp>::OpRewritePattern;

  LogicalResult rewriteOpWithDynamicInput(Operation* op,
                                          PatternRewriter& rewriter) const {
    auto lhs = op->getOperand(0);
    auto rhs = op->getOperand(1);
    auto out = op->getResult(0);

    // Calculates symbolic broadcast shape that is only used in types.
    SmallVector<int64_t, 4> symbolic_broadcast_shape;
    if (!OpTrait::util::getBroadcastedShape(
            lhs.getType().cast<ShapedType>().getShape(),
            rhs.getType().cast<ShapedType>().getShape(),
            symbolic_broadcast_shape)) {
      return failure();
    }

    // Calculates the broadcast shape using BroadcastArgs op.
    Value lhs_shape = GetShape(lhs, op->getLoc(), rewriter);
    Value rhs_shape = GetShape(rhs, op->getLoc(), rewriter);
    auto broadcast_shape =
        rewriter
            .create<TF::BroadcastArgsOp>(
                op->getLoc(),
                RankedTensorType::get(symbolic_broadcast_shape.size(),
                                      rewriter.getIntegerType(64)),
                lhs_shape, rhs_shape)
            .r0();

    // Broadcasts inputs using BroadcastTo op.
    auto broadcast_type = RankedTensorType::get(
        symbolic_broadcast_shape, getElementTypeOrSelf(lhs.getType()));
    auto broadcasted_lhs =
        rewriter
            .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs,
                                       broadcast_shape)
            .output();
    auto broadcasted_rhs =
        rewriter
            .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs,
                                       broadcast_shape)
            .output();

    // Recreate an op with the above BroadcastTo op results.
    RankedTensorType result_type = RankedTensorType::get(
        symbolic_broadcast_shape, getElementTypeOrSelf(out.getType()));
    rewriter.replaceOpWithNewOp<SourceOp>(op, result_type, broadcasted_lhs,
                                          broadcasted_rhs);
    return success();
  }

  LogicalResult matchAndRewrite(SourceOp src_op,
                                PatternRewriter& rewriter) const override {
    Operation* op = static_cast<Operation*>(src_op);
    auto lhs = op->getOperand(0);
    auto rhs = op->getOperand(1);

    if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
        !rhs.getType().cast<ShapedType>().hasStaticShape()) {
      return rewriteOpWithDynamicInput(op, rewriter);
    }

    auto lhs_shape = lhs.getType().cast<ShapedType>().getShape();
    auto rhs_shape = rhs.getType().cast<ShapedType>().getShape();

    if (lhs_shape == rhs_shape) {
      return failure();
    }

    // Calculate the broadcasted shape.
    SmallVector<int64_t, 4> result_shape;
    if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape,
                                            result_shape)) {
      return failure();
    }

    RankedTensorType result_type = RankedTensorType::get(
        result_shape, getElementTypeOrSelf(op->getResult(0).getType()));

    // Create a const op that stores the above broadcasted shape.
    auto new_shape_attr = mlir::DenseIntElementsAttr::get(
        RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64)),
        result_shape);
    auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);

    // Apply BroadcastTo ops to each input.
    auto broadcast_type = RankedTensorType::get(
        result_shape, getElementTypeOrSelf(lhs.getType()));

    if (result_type.getShape() != lhs_shape) {
      lhs = rewriter
                .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs,
                                           new_shape)
                .output();
    }
    if (result_type.getShape() != rhs_shape) {
      rhs = rewriter
                .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs,
                                           new_shape)
                .output();
    }

    // Recreate an op with the above Broadcast op results.
    rewriter.replaceOpWithNewOp<SourceOp>(op, result_type, lhs, rhs);
    return success();
  }
};

// This specialization is for the TF SelectV2 op. The SelectV2 op has three
// inputs, and they should have broadcastable shapes.
template <>
class ApplyExplicitBroadcasting<TF::SelectV2Op>
    : public OpRewritePattern<TF::SelectV2Op> {
 public:
  using OpRewritePattern<TF::SelectV2Op>::OpRewritePattern;

  LogicalResult rewriteOpWithDynamicInput(Operation* op,
                                          PatternRewriter& rewriter) const {
    auto cond = op->getOperand(0);
    auto lhs = op->getOperand(1);
    auto rhs = op->getOperand(2);
    auto out = op->getResult(0);

    // Calculates symbolic broadcast shape that is only used in types.
    SmallVector<int64_t, 4> symbolic_broadcast_lhs_rhs_shape;
    if (!OpTrait::util::getBroadcastedShape(
            lhs.getType().cast<ShapedType>().getShape(),
            rhs.getType().cast<ShapedType>().getShape(),
            symbolic_broadcast_lhs_rhs_shape)) {
      return failure();
    }
    SmallVector<int64_t, 4> symbolic_broadcast_shape;
    if (!OpTrait::util::getBroadcastedShape(
            cond.getType().cast<ShapedType>().getShape(),
            symbolic_broadcast_lhs_rhs_shape, symbolic_broadcast_shape)) {
      return failure();
    }

    // Calculates the broadcast shape using BroadcastArgs op.
    Value cond_shape = GetShape(cond, op->getLoc(), rewriter);
    Value lhs_shape = GetShape(lhs, op->getLoc(), rewriter);
    Value rhs_shape = GetShape(rhs, op->getLoc(), rewriter);
    auto broadcast_shape_value =
        rewriter
            .create<TF::BroadcastArgsOp>(op->getLoc(), lhs_shape.getType(),
                                         lhs_shape, rhs_shape)
            .r0();
    broadcast_shape_value =
        rewriter
            .create<TF::BroadcastArgsOp>(op->getLoc(), lhs_shape.getType(),
                                         broadcast_shape_value, cond_shape)
            .r0();

    // Broadcasting inputs using BroadcastTo op.
    auto broadcast_type = RankedTensorType::get(
        symbolic_broadcast_shape, getElementTypeOrSelf(out.getType()));
    auto broadcasted_cond =
        rewriter
            .create<TF::BroadcastToOp>(
                op->getLoc(),
                RankedTensorType::get(symbolic_broadcast_shape,
                                      rewriter.getIntegerType(1)),
                cond, broadcast_shape_value)
            .output();
    auto broadcasted_lhs =
        rewriter
            .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs,
                                       broadcast_shape_value)
            .output();
    auto broadcasted_rhs =
        rewriter
            .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs,
                                       broadcast_shape_value)
            .output();

    // Recreate an op with the above BroadcastTo op results.
    rewriter.replaceOpWithNewOp<TF::SelectV2Op>(
        op, broadcast_type, broadcasted_cond, broadcasted_lhs, broadcasted_rhs);
    return success();
  }

  LogicalResult matchAndRewrite(TF::SelectV2Op src_op,
                                PatternRewriter& rewriter) const override {
    Operation* op = static_cast<Operation*>(src_op);
    auto cond = op->getOperand(0);
    auto lhs = op->getOperand(1);
    auto rhs = op->getOperand(2);

    // Should have static shapes to calculate the broadcasted shape.
    if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
        !rhs.getType().cast<ShapedType>().hasStaticShape() ||
        !cond.getType().cast<ShapedType>().hasStaticShape()) {
      return rewriteOpWithDynamicInput(op, rewriter);
    }

    auto lhs_shape = lhs.getType().cast<ShapedType>().getShape();
    auto rhs_shape = rhs.getType().cast<ShapedType>().getShape();
    auto cond_shape = cond.getType().cast<ShapedType>().getShape();

    if (lhs_shape == rhs_shape && cond_shape == lhs_shape) {
      return failure();
    }

    // Calculate the broadcasted shape.
    SmallVector<int64_t, 4> broadcasted_shape;
    if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape,
                                            broadcasted_shape)) {
      return failure();
    }

    SmallVector<int64_t, 4> result_shape;
    if (!OpTrait::util::getBroadcastedShape(broadcasted_shape, cond_shape,
                                            result_shape)) {
      return failure();
    }

    // Create a const op that stores the above broadcasted shape.
    auto shape_type =
        RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64));
    auto new_shape_attr =
        mlir::DenseIntElementsAttr::get(shape_type, result_shape);
    auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);

    // Apply BroadcastTo ops to each input.
    auto cond_result_type =
        RankedTensorType::get(result_shape, rewriter.getIntegerType(1));
    auto result_type = RankedTensorType::get(
        result_shape, getElementTypeOrSelf(lhs.getType()));

    if (result_shape != cond_shape) {
      cond = rewriter
                 .create<TF::BroadcastToOp>(op->getLoc(), cond_result_type,
                                            cond, new_shape)
                 .output();
    }
    if (result_shape != lhs_shape) {
      lhs = rewriter
                .create<TF::BroadcastToOp>(op->getLoc(), result_type, lhs,
                                           new_shape)
                .output();
    }
    if (result_shape != rhs_shape) {
      rhs = rewriter
                .create<TF::BroadcastToOp>(op->getLoc(), result_type, rhs,
                                           new_shape)
                .output();
    }

    // Recreate an op with the above Broadcast op results.
    rewriter.replaceOpWithNewOp<TF::SelectV2Op>(op, result_type, cond, lhs,
                                                rhs);
    return success();
  }
};

void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) {
  // Add TF->TF lowering patterns.
  TF::PopulateLoweringTFPatterns(context, &patterns);

  // Add the generated patterns to the list.
  populateWithGenerated(patterns);
  patterns
      .insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
              ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFSplitOp,
              ConvertTFSplitVOp, ConvertTFUnpackOp, ConvertTFAssertOp,
              ConvertTFRandomUniformOp, ConvertTFConv3DOp,
              ConvertTFConv3DBackpropInputV2Op>(context);

  // Patterns for TF nodes converted by the ophint Python converter.
  patterns.insert<LegalizeUnidirectionalSequenceLstm,
                  LegalizeUnidirectionalSequenceRnn>(context);
}

void applyPatterns(FuncOp func, ConversionTarget& target,
                   FrozenRewritePatternSet& frozenPatterns) {
  // Keep trying to convert.
  // TODO(karimnosseir): This is similar to what applying greedy patterns does.
  // Look if there is a function that tries until it converges.
  // Currently the unit test doesn't do multiple tries, so we need this.
  const int max_iterations = 15;
  for (int i = 0; i < max_iterations; ++i) {
    if (failed(applyPartialConversion(func, target, frozenPatterns))) {
      return;
    }
  }
}

void LegalizeTF::runOnFunction() {
  auto* context = &getContext();
  auto func = getFunction();

  ConversionTarget target(*context);
  // It is legal to still have TF ops in the graph: they can be used later, or,
  // in the case of SELECT, we allow TF ops in the final graph.
  target.addLegalOp<mlir::ConstantOp>();
  target.addLegalOp<ConstOp>();
  if (run_tfl_runtime_verification_) {
    target.addDynamicallyLegalDialect<TensorFlowLiteDialect>([](Operation* op) {
      auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
      if (!tfl_op) return false;
      return succeeded(tfl_op.VerifyTflRuntimeConstraints(op));
    });
  } else {
    target.addLegalDialect<TensorFlowLiteDialect>();
  }

  // Ignore transient errors by registering a no-op handler.
  // Applying legalization patterns will emit unwanted, transient errors when
  // the replaced TFLite ops do not meet the sanity checks. In order to ignore
  // the transient errors, the following lines override the diagnostic handler
  // with a no-op handler only while this pass runs.
  uint64_t current_thread_id = llvm::get_threadid();
  ScopedDiagnosticHandler scoped_diag_handler(
      context, [&current_thread_id](Diagnostic&) -> LogicalResult {
        // Consume only errors that are coming from the same thread in order not
        // to ignore errors from other passes that are running. Things running
        // in the pass manager can be multi-threaded.
        return success(current_thread_id == llvm::get_threadid());
      });

  OwningRewritePatternList stage1Patterns(&getContext());

  addPatterns(context, stage1Patterns);

  FrozenRewritePatternSet stage1FrozenPatterns(std::move(stage1Patterns));
  applyPatterns(func, target, stage1FrozenPatterns);

  // Explicit BroadcastTo addition for left-over broadcast-able ops.
  // The following pattern matchings should be done after the other legalization
  // rules in order not to add unnecessary BroadcastTo ops.
  OwningRewritePatternList stage2Patterns(&getContext());

  addPatterns(context, stage2Patterns);

  stage2Patterns.insert<ApplyExplicitBroadcasting<TF::LessEqualOp>,
                        ApplyExplicitBroadcasting<TF::GreaterEqualOp>,
                        ApplyExplicitBroadcasting<TF::NotEqualOp>,
                        ApplyExplicitBroadcasting<TF::GreaterOp>,
                        ApplyExplicitBroadcasting<TF::LessOp>,
                        ApplyExplicitBroadcasting<TF::EqualOp>,
                        ApplyExplicitBroadcasting<TF::AddOp>,
                        ApplyExplicitBroadcasting<TF::AddV2Op>,
                        ApplyExplicitBroadcasting<TF::MulOp>,
                        ApplyExplicitBroadcasting<TF::DivOp>,
                        ApplyExplicitBroadcasting<TF::RealDivOp>,
                        ApplyExplicitBroadcasting<TF::SubOp>,
                        ApplyExplicitBroadcasting<TF::FloorDivOp>,
                        ApplyExplicitBroadcasting<TF::FloorModOp>,
                        ApplyExplicitBroadcasting<TF::PowOp>,
                        ApplyExplicitBroadcasting<TF::MaximumOp>,
                        ApplyExplicitBroadcasting<TF::MinimumOp>,
                        ApplyExplicitBroadcasting<TF::SquaredDifferenceOp>,
                        ApplyExplicitBroadcasting<TF::SelectV2Op>>(context);

  FrozenRewritePatternSet stage2FrozenPatterns(std::move(stage2Patterns));
  applyPatterns(func, target, stage2FrozenPatterns);
}

}  // namespace

// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
    bool run_tfl_runtime_verification) {
  return std::make_unique<LegalizeTF>(run_tfl_runtime_verification);
}

static PassRegistration<LegalizeTF> pass;

}  // namespace TFL
}  // namespace mlir