/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass converts operations in the TensorFlow dialect into
// operations that are legal in the TensorFlow Lite dialect. Operations that
// can be legalized to the TensorFlow Lite dialect with simple replacements
// belong in this pass; operations that may create extra ops should instead go
// in the PrepareTF pass, which runs before this pass. That way any constant
// folding opportunities created by the extra ops can be exploited by the
// constant folding support for the TensorFlow ops.

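// As an illustrative sketch only (types, SSA names, and attribute values are
// placeholders), a typical rewrite performed by this pass turns a tf.ConcatV2
// with a constant axis
//   %c = "tf.ConcatV2"(%a, %b, %axis) : (...) -> tensor<...xf32>
// into
//   %c = "tfl.concatenation"(%a, %b) {axis = 0 : i32,
//       fused_activation_function = "NONE"} : (...) -> tensor<...xf32>
// folding the axis operand into an i32 attribute.
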
#include <algorithm>
#include <climits>
#include <complex>
#include <cstdint>
#include <limits>
#include <tuple>
#include <utility>

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Threading.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/UniformSupport.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace mlir {
namespace TFL {

//===----------------------------------------------------------------------===//
// The actual LegalizeTF Pass.
namespace {

using xla::StatusOr;

constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm";
constexpr char kUnidirectionalSequenceRnn[] = "tf.UnidirectionalSequenceRnn";
constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";

// Legalize operations in functions.
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<quant::QuantizationDialect, TFL::TensorFlowLiteDialect>();
  }

 public:
  LegalizeTF() = default;
  LegalizeTF(const LegalizeTF&) {}
  explicit LegalizeTF(bool run_tfl_runtime_verification) {
    run_tfl_runtime_verification_ = run_tfl_runtime_verification;
  }

  /// Performs the lowering to TFLite dialect.
  void runOnFunction() override;

 private:
  Option<bool> run_tfl_runtime_verification_{
      *this, "run-tfl-runtime-verification",
      llvm::cl::desc("Allow tfl runtime verification."), llvm::cl::init(true)};
};

// Returns true if every operand of `op` has a static shape and all of those
// shapes are the same.
bool HasSameStaticShapes(Operation* op) {
  auto values = op->getOperands();
  int index = 0;
  ArrayRef<int64_t> shape;
  for (Value value : values) {
    auto shaped_type = value.getType().dyn_cast<ShapedType>();
    if (!shaped_type || !shaped_type.hasStaticShape()) {
      return false;
    }
    if (index == 0) {
      shape = shaped_type.getShape();
    } else {
      if (shape != shaped_type.getShape()) {
        return false;
      }
    }
    ++index;
  }
  return true;
}

// Util that casts 'val' to Int32 by adding a cast Op.
Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) {
  auto shape = val.getType().dyn_cast<RankedTensorType>().getShape();
  IntegerType new_ele_type = rewriter.getIntegerType(32);
  ShapedType new_type = RankedTensorType::get(shape, new_ele_type);
  return rewriter.createOrFold<TF::CastOp>(loc, new_type, val,
                                           rewriter.getBoolAttr(false));
}

#include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc"

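// The generated file above supplies `populateWithGenerated`, the table-driven
// (DRR) legalization patterns produced at build time. The hand-written
// patterns declared below cover conversions those rules cannot yet express
// (e.g. ops with variadic operands).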
#define DECL_CONVERT_OP(tf_op)                                               \
  struct ConvertTF##tf_op##Op : public RewritePattern {                      \
    explicit ConvertTF##tf_op##Op(MLIRContext* context)                      \
        : RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {}   \
    LogicalResult matchAndRewrite(Operation* op,                             \
                                  PatternRewriter& rewriter) const override; \
  }

// TODO(antiagainst): Define this pattern in a table-driven manner once
// variadic operands are properly supported in declarative rewrite rule
// specification.

DECL_CONVERT_OP(Assert);
DECL_CONVERT_OP(ConcatV2);
DECL_CONVERT_OP(MatMul);
DECL_CONVERT_OP(MatrixDiagV2);
DECL_CONVERT_OP(MatrixDiagV3);
DECL_CONVERT_OP(Pack);
DECL_CONVERT_OP(Split);
DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(RandomUniform);
DECL_CONVERT_OP(Conv3D);

#undef DECL_CONVERT_OP

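// Folds a tf.RandomUniform that has at least one non-zero seed into a
// constant by evaluating the Philox-based uniform distribution at conversion
// time. Only f32 outputs with a fully static shape are handled; anything else
// is left for other lowerings.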
LogicalResult ConvertTFRandomUniformOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto random_uniform_op = cast<TF::RandomUniformOp>(op);
  if (random_uniform_op.seed() == 0 && random_uniform_op.seed2() == 0) {
    return failure();
  }
  if (!random_uniform_op.dtype().isF32()) {
    return failure();
  }
  typedef tensorflow::random::UniformDistribution<
      tensorflow::random::PhiloxRandom, float>
      Distribution;

  tensorflow::random::PhiloxRandom generator(random_uniform_op.seed(),
                                             random_uniform_op.seed2());
  Distribution dist;
  size_t num_elements = 0;
  if (auto output_type =
          random_uniform_op.output().getType().dyn_cast_or_null<ShapedType>()) {
    if (auto ranked_output = output_type.dyn_cast_or_null<RankedTensorType>()) {
      if (!ranked_output.hasRank() || ranked_output.getNumDynamicDims() != 0) {
        return failure();
      }
      num_elements = output_type.getNumElements();
      size_t offset = 0;
      size_t num_samples = Distribution::kResultElementCount;
      llvm::SmallVector<float, 32> data;
      data.resize(num_elements);
      while (offset < num_elements) {
        const typename Distribution::ResultType samples = dist(&generator);
        std::copy(&samples[0],
                  &samples[0] + std::min(num_samples, data.size() - offset),
                  &data[0] + offset);
        offset += num_samples;
      }
      auto output_data = DenseFPElementsAttr::get(output_type, data);
      rewriter.replaceOpWithNewOp<ConstantOp>(op, output_type, output_data);
      return success();
    }
  }
  return failure();
}

// Converts any IntegerAttr to an IntegerAttr of an i32 type.
// The value won't change in the new attribute, but if the value is outside
// the range of i32, the function returns a failure.
LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) {
  if (attr.getType().isInteger(/*width=*/32)) {
    *attr_i32 = attr;
    return success();
  }

  int64_t value = attr.getInt();
  if (value > std::numeric_limits<int>::max() ||
      value < std::numeric_limits<int>::min()) {
    return failure();
  }

  *attr_i32 = IntegerAttr::get(
      IntegerType::get(attr.getContext(), /*width=*/32), value);
  return success();
}

LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_concat_op = cast<TF::ConcatV2Op>(op);

  auto values = tf_concat_op.values();
  auto output_type = tf_concat_op.output().getType();
  // Extract the axis attribute from the constant axis tensor.
  ElementsAttr axis;
  if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure();
  IntegerAttr axis_int = ExtractSingleElementAsInteger(axis);

  // The "axis" operand could be an i64 tensor. Resolve it here.
  IntegerAttr axis_i32;
  if (failed(ConvertToI32Attr(axis_int, &axis_i32))) return failure();

  StringAttr fused_activation_function =
      StringAttr::get(rewriter.getContext(), "NONE");
  rewriter.replaceOpWithNewOp<ConcatenationOp>(
      op, output_type, values, axis_i32, fused_activation_function);
  return success();
}

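// tf.MatMul is lowered to tfl.fully_connected, whose weights operand is
// expected with the contraction dimension last (roughly a transposed layout
// relative to tf.MatMul's default). That is why the rhs is explicitly
// transposed when transpose_b is false, and the lhs when transpose_a is true.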
LogicalResult ConvertTFMatMulOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_matmul_op = cast<TF::MatMulOp>(op);
  auto lhs = op->getOperand(0);
  auto rhs = op->getOperand(1);
  auto transpose = [&](Value input) -> std::pair<LogicalResult, Value> {
    RankedTensorType type =
        input.getType().dyn_cast_or_null<RankedTensorType>();
    if (!type || type.getRank() != 2) return {failure(), nullptr};

    auto permute_attr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI32Type()), {1, 0});
    auto permute = rewriter.create<ConstantOp>(
        op->getLoc(), permute_attr.getType(), permute_attr);
    llvm::SmallVector<int64_t, 2> new_shape{type.getShape()[1],
                                            type.getShape()[0]};
    auto output = rewriter.create<TFL::TransposeOp>(
        op->getLoc(), RankedTensorType::get(new_shape, type.getElementType()),
        input, permute);
    return {success(), output};
  };

  // TODO(jpienaar): Remove once handled via dialect conversion.
  if (tf_matmul_op.transpose_a()) {
    LogicalResult result = success();
    std::tie(result, lhs) = transpose(lhs);
    if (failed(result)) return failure();
  }
  if (!tf_matmul_op.transpose_b()) {
    LogicalResult result = success();
    std::tie(result, rhs) = transpose(rhs);
    if (failed(result)) return failure();
  }

  Type output_type = tf_matmul_op.getResult().getType();
  auto no_input = rewriter.create<ConstantOp>(
      op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
  auto fc_op = rewriter.create<FullyConnectedOp>(
      op->getLoc(), ArrayRef<Type>{output_type}, lhs, rhs, no_input,
      rewriter.getStringAttr("NONE"), rewriter.getStringAttr("DEFAULT"),
      rewriter.getBoolAttr(false));
  rewriter.replaceOp(op, {fc_op.getResult(0)});
  return success();
}

LogicalResult ConvertTFPackOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_pack_op = cast<TF::PackOp>(op);

  SmallVector<Value, 4> values(tf_pack_op.values());
  auto output_type = tf_pack_op.output().getType();
  auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N());
  // Axis can be negative.
  auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis());

  rewriter.replaceOpWithNewOp<PackOp>(op, output_type, values, values_count,
                                      axis);
  return success();
}

LogicalResult ConvertTFSplitOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_split_op = cast<TF::SplitOp>(op);

  // The number of splits cannot be negative.
  auto num_split = rewriter.getI32IntegerAttr(tf_split_op.num_split());

  rewriter.replaceOpWithNewOp<TFL::SplitOp>(op, tf_split_op.output().getTypes(),
                                            tf_split_op.split_dim(),
                                            tf_split_op.value(), num_split);
  return success();
}

LogicalResult ConvertTFSplitVOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_splitv_op = cast<TF::SplitVOp>(op);

  // The number of splits cannot be negative.
  auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.num_split());

  rewriter.replaceOpWithNewOp<TFL::SplitVOp>(
      op, tf_splitv_op.output().getTypes(), tf_splitv_op.value(),
      tf_splitv_op.size_splits(), tf_splitv_op.split_dim(), num_split);
  return success();
}

LogicalResult ConvertTFUnpackOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_unpack_op = cast<TF::UnpackOp>(op);

  auto input = tf_unpack_op.value();
  auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num());
  // Axis can be negative.
  auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis());

  rewriter.replaceOpWithNewOp<UnpackOp>(op, tf_unpack_op.output().getTypes(),
                                        input, num, axis);
  return success();
}

LogicalResult ConvertTFConv3DOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  if (!TFDataFormatIsNDHWC(op)) return failure();

  auto tf_op = cast<TF::Conv3DOp>(op);

  IntegerAttr stride_depth, stride_height, stride_width;
  if (!TFIntListIs1XYZ1(op, "strides", &stride_depth, &stride_height,
                        &stride_width))
    return failure();

  IntegerAttr dilation_depth_factor, dilation_height_factor,
      dilation_width_factor;
  if (!TFIntListIs1XYZ1(op, "dilations", &dilation_depth_factor,
                        &dilation_height_factor, &dilation_width_factor)) {
    // If the 'dilations' attribute is missing, use the default value (1) for
    // the dilation depth, height, and width factors.
    dilation_depth_factor = rewriter.getI32IntegerAttr(1);
    dilation_height_factor = rewriter.getI32IntegerAttr(1);
    dilation_width_factor = rewriter.getI32IntegerAttr(1);
  }

  StringAttr padding;
  if (!TFPaddingIsSameOrValid(op, &padding)) return failure();

  // TensorFlow Conv3D has no bias; optimization patterns that fuse Conv3D
  // with other ops can fill in the bias later.
  Value none = rewriter.create<mlir::ConstantOp>(
      op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());

  rewriter.replaceOpWithNewOp<TFL::Conv3DOp>(
      op, tf_op.getType(), tf_op.input(), tf_op.filter(),
      /*bias=*/none, dilation_depth_factor, dilation_height_factor,
      dilation_width_factor,
      /*fused_activation_function=*/rewriter.getStringAttr("NONE"), padding,
      stride_depth, stride_height, stride_width);

  return success();
}

// MatrixDiagV3 is MatrixDiagV2 with an alignment attribute. This attribute
// only has effects when processing multiple diagonals. Since TFLite converts
// MatrixDiagV{2,3} to MatrixDiag, which only takes single-diagonal inputs, we
// can safely ignore this V3 attribute.
// We can't pass `rewriter` by reference because clang-tidy will want it to be
// constant (`const PatternRewriter& rewriter`). If we do that, we won't be
// able to call `rewriter->replaceOpWithNewOp`, which is not a const member
// function.
template <typename MatrixDiagV2OrV3Op>
bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) {
  auto tf_matrix_diag_v2_or_v3_op = cast<MatrixDiagV2OrV3Op>(op);

  if (tf_matrix_diag_v2_or_v3_op.getNumOperands() != 5) return false;

  auto input = tf_matrix_diag_v2_or_v3_op.diagonal();
  auto output_type = tf_matrix_diag_v2_or_v3_op.output().getType();

  // Extract the constant `k` tensor and check that its value is 0.
  ElementsAttr k;
  if (!matchPattern(tf_matrix_diag_v2_or_v3_op.k(), m_Constant(&k)))
    return false;
  if (ExtractSingleElementAsInteger(k).getInt() != 0) return false;

  // Extract the constant `num_rows` tensor and check that its value is -1.
  ElementsAttr num_rows;
  if (!matchPattern(tf_matrix_diag_v2_or_v3_op.num_rows(),
                    m_Constant(&num_rows)))
    return false;
  if (ExtractSingleElementAsInteger(num_rows).getInt() != -1) return false;

  // Extract the constant `num_cols` tensor and check that its value is -1.
  ElementsAttr num_cols;
  if (!matchPattern(tf_matrix_diag_v2_or_v3_op.num_cols(),
                    m_Constant(&num_cols)))
    return false;
  if (ExtractSingleElementAsInteger(num_cols).getInt() != -1) return false;

  // Verify that padding_value is an integer tensor with all 0s.
  ElementsAttr padding_value;
  if (!matchPattern(tf_matrix_diag_v2_or_v3_op.padding_value(),
                    m_Constant(&padding_value)))
    return false;
  for (const auto& value : padding_value.getValues<APInt>()) {
    if (value != 0) return false;
  }

  rewriter->replaceOpWithNewOp<MatrixDiagOp>(op, output_type, input);
  return true;
}

LogicalResult ConvertTFMatrixDiagV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  if (ConvertTFMatrixDiagV2orV3<TF::MatrixDiagV2Op>(op, &rewriter))
    return success();
  return failure();
}

LogicalResult ConvertTFMatrixDiagV3Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  if (ConvertTFMatrixDiagV2orV3<TF::MatrixDiagV3Op>(op, &rewriter))
    return success();
  return failure();
}

// TF Lite doesn't support Assert, so we simply drop the op from the graph.
LogicalResult ConvertTFAssertOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  rewriter.eraseOp(op);
  return success();
}

// Legalize unidirectional sequence lstm.
struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
  explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context)
      : RewritePattern(kUnidirectionalSequenceLstm, 1, context) {}

  LogicalResult matchAndRewrite(Operation* op,
                                PatternRewriter& rewriter) const override {
    auto tflite_indices_attr =
        op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices);
    if (!tflite_indices_attr) return failure();

    SmallVector<int64_t, 20> tflite_indices;
    for (auto index_attr : tflite_indices_attr.getValue()) {
      IntegerAttr index = index_attr.cast<IntegerAttr>();
      tflite_indices.push_back(index.getInt());
    }

    // Optional input placeholder.
    Value none = rewriter.create<mlir::ConstantOp>(
        op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());

    // Populate inputs.
    // UnidirectionalSequenceLstm is expected to have 24 inputs.
    SmallVector<Value, 24> inputs;
    int count = 0;
    int total_ophint_converted_inputs = tflite_indices.size();
    for (int i = 0; i < 24; ++i) {
      if (count < total_ophint_converted_inputs && tflite_indices[count] == i) {
        // Specified input.
        inputs.push_back(op->getOperand(i));
        count++;
      } else {
        // Non-specified input.
        inputs.push_back(none);
      }
    }

    // Populate outputs.
    // UnidirectionalSequenceLstm should only have 1 output, and that is the
    // original ophint-converted node's 3rd output.
    SmallVector<Type, 4> result_types;
    result_types.push_back(op->getOpResult(2).getType());

    // Populate attributes.
    SmallVector<NamedAttribute, 4> attributes;
    // Activation will always be tanh.
    attributes.push_back(rewriter.getNamedAttr("fused_activation_function",
                                               rewriter.getStringAttr("TANH")));
    // cell_clip.
    attributes.push_back(
        rewriter.getNamedAttr("cell_clip", rewriter.getF32FloatAttr(0.0)));
    // proj_clip.
    attributes.push_back(
        rewriter.getNamedAttr("proj_clip", rewriter.getF32FloatAttr(0.0)));
    // Will always be time-majored.
    attributes.push_back(
        rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true)));

    Value lstm_result = rewriter.create<TFL::UnidirectionalSequenceLSTMOp>(
        op->getLoc(), result_types, inputs, attributes);

    // Rewire the output.
    rewriter.replaceOp(op, {nullptr, nullptr, lstm_result});
    return success();
  }
};

// Legalize unidirectional sequence rnn.
struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
  explicit LegalizeUnidirectionalSequenceRnn(MLIRContext* context)
      : RewritePattern(kUnidirectionalSequenceRnn, 1, context) {}

  LogicalResult matchAndRewrite(Operation* op,
                                PatternRewriter& rewriter) const override {
    auto tflite_indices_attr =
        op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices);
    if (!tflite_indices_attr) return failure();

    if (op->getNumOperands() != 5) {
      op->emitError()
          << "We're expecting 5 inputs for UnidirectionalSequenceRNN, only "
          << op->getNumOperands() << " provided";
      return failure();
    }

    if (op->getNumResults() != 2) {
      op->emitError()
          << "We're expecting 2 outputs for UnidirectionalSequenceRNN, only "
          << op->getNumResults() << " found";
      return failure();
    }

    // Populate inputs.
    // UnidirectionalSequenceRnn is expected to have 5 inputs, and none of them
    // are optional inputs.
    SmallVector<Value, 5> inputs;
    for (int i = 0; i < 5; ++i) {
      inputs.push_back(op->getOperand(i));
    }

    // Populate outputs.
    // UnidirectionalSequenceRnn should only have 1 output, and that is the
    // original ophint-converted node's 2nd output.
    SmallVector<Type, 4> result_types;
    result_types.push_back(op->getOpResult(1).getType());

    // Populate attributes.
    SmallVector<NamedAttribute, 2> attributes;
    // Activation will always be tanh.
    attributes.push_back(rewriter.getNamedAttr("fused_activation_function",
                                               rewriter.getStringAttr("TANH")));

    // Will always be time-majored.
    attributes.push_back(
        rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true)));

    Value rnn_result = rewriter.create<TFL::UnidirectionalSequenceRNNOp>(
        op->getLoc(), result_types, inputs, attributes);

    // Rewire the output.
    rewriter.replaceOp(op, {nullptr, rnn_result});

    return success();
  }
};

// Put two TF BroadcastTo ops in front of the given TF binary broadcast op so
// that the binary broadcast-able op conversion always succeeds and does not
// require the flex delegate.
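//
// As an illustrative sketch only (shapes, SSA names, and element types are
// placeholders), a binary op such as
//   %sum = "tf.AddV2"(%lhs, %rhs)
//       : (tensor<1x4xf32>, tensor<3x1xf32>) -> tensor<3x4xf32>
// first gets both operands broadcast to tensor<3x4xf32> via tf.BroadcastTo,
// and the op is then recreated on the broadcast results.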
template <typename SourceOp>
class ApplyExplicitBroadcasting : public OpRewritePattern<SourceOp> {
 public:
  using OpRewritePattern<SourceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SourceOp src_op,
                                PatternRewriter& rewriter) const override {
    Operation* op = static_cast<Operation*>(src_op);
    auto lhs = op->getOperand(0);
    auto rhs = op->getOperand(1);

    // Should have static shapes to calculate the broadcasted shape.
    if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
        !rhs.getType().cast<ShapedType>().hasStaticShape()) {
      return failure();
    }

    auto lhs_shape = lhs.getType().cast<ShapedType>().getShape();
    auto rhs_shape = rhs.getType().cast<ShapedType>().getShape();

    if (lhs_shape == rhs_shape) {
      return failure();
    }

    // Calculate the broadcasted shape.
    SmallVector<int64_t, 4> result_shape;
    if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape,
                                            result_shape)) {
      return failure();
    }

    RankedTensorType result_type = RankedTensorType::get(
        result_shape, getElementTypeOrSelf(op->getResult(0).getType()));

    // Create a const op that stores the above broadcasted shape.
    auto new_shape_attr = mlir::DenseIntElementsAttr::get(
        RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64)),
        result_shape);
    auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);

    // Apply BroadcastTo ops to each input.
    auto broadcast_type = RankedTensorType::get(
        result_shape, getElementTypeOrSelf(lhs.getType()));

    if (result_type.getShape() != lhs_shape) {
      lhs = rewriter
                .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs,
                                           new_shape)
                .output();
    }
    if (result_type.getShape() != rhs_shape) {
      rhs = rewriter
                .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs,
                                           new_shape)
                .output();
    }

    // Recreate the op with the above BroadcastTo op results.
    rewriter.replaceOpWithNewOp<SourceOp>(op, result_type, lhs, rhs);
    return success();
  }
};

// This specialization is for the TF SelectV2 op. SelectV2 has three inputs,
// and they should have broadcastable shapes.
template <>
class ApplyExplicitBroadcasting<TF::SelectV2Op>
    : public OpRewritePattern<TF::SelectV2Op> {
 public:
  using OpRewritePattern<TF::SelectV2Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(TF::SelectV2Op src_op,
                                PatternRewriter& rewriter) const override {
    Operation* op = static_cast<Operation*>(src_op);
    auto cond = op->getOperand(0);
    auto lhs = op->getOperand(1);
    auto rhs = op->getOperand(2);

    // Should have static shapes to calculate the broadcasted shape.
    if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
        !rhs.getType().cast<ShapedType>().hasStaticShape() ||
        !cond.getType().cast<ShapedType>().hasStaticShape()) {
      return failure();
    }

    auto lhs_shape = lhs.getType().cast<ShapedType>().getShape();
    auto rhs_shape = rhs.getType().cast<ShapedType>().getShape();
    auto cond_shape = cond.getType().cast<ShapedType>().getShape();

    if (lhs_shape == rhs_shape && cond_shape == lhs_shape) {
      return failure();
    }

    // Calculate the broadcasted shape.
    SmallVector<int64_t, 4> broadcasted_shape;
    if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape,
                                            broadcasted_shape)) {
      return failure();
    }

    SmallVector<int64_t, 4> result_shape;
    if (!OpTrait::util::getBroadcastedShape(broadcasted_shape, cond_shape,
                                            result_shape)) {
      return failure();
    }

    // Create a const op that stores the above broadcasted shape.
    auto shape_type =
        RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64));
    auto new_shape_attr =
        mlir::DenseIntElementsAttr::get(shape_type, result_shape);
    auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);

    // Apply BroadcastTo ops to each input.
    auto cond_result_type =
        RankedTensorType::get(result_shape, rewriter.getIntegerType(1));
    auto result_type = RankedTensorType::get(
        result_shape, getElementTypeOrSelf(lhs.getType()));

    if (result_shape != cond_shape) {
      cond = rewriter
                 .create<TF::BroadcastToOp>(op->getLoc(), cond_result_type,
                                            cond, new_shape)
                 .output();
    }
    if (result_shape != lhs_shape) {
      lhs = rewriter
                .create<TF::BroadcastToOp>(op->getLoc(), result_type, lhs,
                                           new_shape)
                .output();
    }
    if (result_shape != rhs_shape) {
      rhs = rewriter
                .create<TF::BroadcastToOp>(op->getLoc(), result_type, rhs,
                                           new_shape)
                .output();
    }

    // Recreate the op with the above BroadcastTo op results.
    rewriter.replaceOpWithNewOp<TF::SelectV2Op>(op, result_type, cond, lhs,
                                                rhs);
    return success();
  }
};

void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) {
  // Add TF->TF lowering patterns.
  TF::PopulateLoweringTFPatterns(context, &patterns);

  // Add the generated patterns to the list.
  populateWithGenerated(context, patterns);
  patterns
      .insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
              ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFSplitOp,
              ConvertTFSplitVOp, ConvertTFUnpackOp, ConvertTFAssertOp,
              ConvertTFRandomUniformOp, ConvertTFConv3DOp>(context);

  // Patterns for TF nodes produced by the ophint Python converter.
  patterns.insert<LegalizeUnidirectionalSequenceLstm,
                  LegalizeUnidirectionalSequenceRnn>(context);
}

void applyPatterns(FuncOp func, ConversionTarget& target,
                   FrozenRewritePatternList& frozenPatterns) {
  // Keep trying to convert.
  // TODO(karimnosseir): This is similar to what applying greedy patterns does.
  // Look for a function that retries until it converges.
  // Currently the unit test doesn't do multiple tries, so we need this.
  const int max_iterations = 15;
  for (int i = 0; i < max_iterations; ++i) {
    if (failed(applyPartialConversion(func, target, frozenPatterns))) {
      return;
    }
  }
}

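// The pass applies patterns in two stages: stage 1 runs the legalization
// patterns registered above; stage 2 re-runs them together with the
// ApplyExplicitBroadcasting patterns so that any left-over broadcast-able
// binary ops get explicit BroadcastTo ops before being legalized.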
void LegalizeTF::runOnFunction() {
  auto* context = &getContext();
  auto func = getFunction();

  ConversionTarget target(*context);
  // It is still legal to have TF ops in the graph: they can be used later, and
  // in the case of SELECT we allow TF ops in the final graph.
  target.addLegalOp<mlir::ConstantOp>();
  target.addLegalOp<ConstOp>();
  if (run_tfl_runtime_verification_) {
    target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
        Optional<ConversionTarget::DynamicLegalityCallbackFn>(
            [](Operation* op) {
              auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
              if (!tfl_op) return false;
              return succeeded(tfl_op.VerifyTflRuntimeConstraints(op));
            }));
  } else {
    target.addLegalDialect<TensorFlowLiteDialect>();
  }

  // Ignore transient errors by registering a no-op handler.
  // Applying legalization patterns will emit unwanted, transient errors when
  // the replaced TFLite ops do not meet the sanity checks. In order to ignore
  // the transient errors, the following lines override the diagnostic handler
  // with a no-op handler only while this pass runs.
  uint64_t current_thread_id = llvm::get_threadid();
  ScopedDiagnosticHandler scoped_diag_handler(
      context, [&current_thread_id](Diagnostic&) -> LogicalResult {
        // Consume only errors coming from the same thread, so that errors from
        // other passes that are running are not ignored. Things running in the
        // pass manager can be multi-threaded.
        return success(current_thread_id == llvm::get_threadid());
      });

  OwningRewritePatternList stage1Patterns;

  addPatterns(context, stage1Patterns);

  FrozenRewritePatternList stage1FrozenPatterns(std::move(stage1Patterns));
  applyPatterns(func, target, stage1FrozenPatterns);

  // Explicit BroadcastTo addition for left-over broadcast-able ops.
  // The following pattern matching should be done after the other legalization
  // rules in order not to add unnecessary BroadcastTo ops.
  OwningRewritePatternList stage2Patterns;

  addPatterns(context, stage2Patterns);

  stage2Patterns.insert<ApplyExplicitBroadcasting<TF::LessEqualOp>,
                        ApplyExplicitBroadcasting<TF::GreaterEqualOp>,
                        ApplyExplicitBroadcasting<TF::NotEqualOp>,
                        ApplyExplicitBroadcasting<TF::GreaterOp>,
                        ApplyExplicitBroadcasting<TF::LessOp>,
                        ApplyExplicitBroadcasting<TF::EqualOp>,
                        ApplyExplicitBroadcasting<TF::AddOp>,
                        ApplyExplicitBroadcasting<TF::AddV2Op>,
                        ApplyExplicitBroadcasting<TF::MulOp>,
                        ApplyExplicitBroadcasting<TF::DivOp>,
                        ApplyExplicitBroadcasting<TF::RealDivOp>,
                        ApplyExplicitBroadcasting<TF::SubOp>,
                        ApplyExplicitBroadcasting<TF::FloorDivOp>,
                        ApplyExplicitBroadcasting<TF::FloorModOp>,
                        ApplyExplicitBroadcasting<TF::PowOp>,
                        ApplyExplicitBroadcasting<TF::MaximumOp>,
                        ApplyExplicitBroadcasting<TF::MinimumOp>,
                        ApplyExplicitBroadcasting<TF::SquaredDifferenceOp>,
                        ApplyExplicitBroadcasting<TF::SelectV2Op>>(context);

  FrozenRewritePatternList stage2FrozenPatterns(std::move(stage2Patterns));
  applyPatterns(func, target, stage2FrozenPatterns);
}

}  // namespace

// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
    bool run_tfl_runtime_verification) {
  return std::make_unique<LegalizeTF>(run_tfl_runtime_verification);
}

static PassRegistration<LegalizeTF> pass(
    "tfl-legalize-tf", "Legalize from TensorFlow to TensorFlow Lite dialect");

}  // namespace TFL
}  // namespace mlir