/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass prepares for legalization to the TFLite dialect by
// converting TensorList operations in the TensorFlow dialect into operations
// that can be legalized to the TensorFlow Lite dialect with simple
// replacements. The newly created operations stay in the TensorFlow dialect if
// they can be represented using a TensorFlow op; otherwise, a TensorFlow Lite
// dialect op is used.
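//
// Illustrative sketch (hypothetical IR; exact types and attributes elided): a
// statically shaped list such as
//   %list = "tf.TensorListReserve"(%shape, %n)
//       : (...) -> tensor<!tf.variant<tensor<2xf32>>>
// is lowered to a zero-initialized dense tensor, roughly
//   %t = "tf.Fill"(%dims, %zero) : (tensor<2xi32>, tensor<f32>) -> tensor<?x2xf32>
// and subsequent TensorList ops on %list become slice/concat/gather ops on %t.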

#include <climits>
#include <cstdint>

#include "absl/container/inlined_vector.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/tensor_list.h"

#define DEBUG_TYPE "tf-tfl-legalization"

//===----------------------------------------------------------------------===//
// The actual LowerStaticTensorList Pass.
//
namespace mlir {
namespace {

/// Lower TensorList ops in functions for subsequent legalization.
struct LowerStaticTensorListPass
    : public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> {
  LowerStaticTensorListPass() = default;
  LowerStaticTensorListPass(const LowerStaticTensorListPass &) {}

  void runOnOperation() override;

  Option<bool> allow_tensorlist_pass_through{
      *this, "allow-tensorlist-pass-through",
      llvm::cl::desc(
          "When set to true, if the tensorlist ops can't be properly "
          "legalized by this pass, the IR is left unchanged so that the "
          "tensorlist ops can pass through (default: false)"),
      llvm::cl::init(false)};
};

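// Returns a constant op holding an i32 splat of `val` with the given `shape`.
// For example (illustrative), CreateI32SplatConst(loc, rewriter, {}, 0)
// materializes `dense<0> : tensor<i32>`.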
Value CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
                          ArrayRef<int64_t> shape, int32_t val) {
  RankedTensorType type =
      RankedTensorType::get(shape, rewriter->getIntegerType(32));
  DenseElementsAttr attr =
      DenseElementsAttr::get(type, rewriter->getI32IntegerAttr(val));
  return rewriter->create<ConstantOp>(loc, type, attr);
}

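// Returns a 1-D i32 tensor whose length is given by the value held in
// `shape_tensor`, with every element equal to `val`, materialized as a
// `tf.Fill` op.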
Value CreateI32SplatTensor(Location loc, PatternRewriter *rewriter,
                           Value shape_tensor, int32_t val) {
  Value scalar_val = CreateI32SplatConst(loc, rewriter, {}, val);
  return rewriter->create<TF::FillOp>(
      loc, RankedTensorType::get({-1}, rewriter->getIntegerType(32)),
      shape_tensor, scalar_val);
}

// Returns a new type by prepending the specified dimension to the shape of
// the given type if it is a ranked type.
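// For example (illustrative): with dim = 5, tensor<2x3xf32> becomes
// tensor<5x2x3xf32>, while an unranked type is returned unchanged.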
Type PrependLeadingDimIfRanked(int64_t dim, Type type,
                               PatternRewriter *rewriter) {
  Type dtype = getElementTypeOrSelf(type);
  if (RankedTensorType ty = type.dyn_cast<RankedTensorType>()) {
    llvm::SmallVector<int64_t, 4> shape = {dim};
    shape.append(ty.getShape().begin(), ty.getShape().end());
    return RankedTensorType::get(shape, dtype);
  }
  return type;
}

Type GetTensorTypeForTensorList(Type element_type, TF::VariantType handle_dtype,
                                PatternRewriter *rewriter) {
  // If the variant type in the output handle has its item shape available, use
  // it to derive the output shape with an unknown leading dimension.
  // Otherwise, the result type will be unranked.
  if (handle_dtype.getSubtypes().empty()) {
    return UnrankedTensorType::get(element_type);
  }
  return PrependLeadingDimIfRanked(-1, handle_dtype.getSubtypes()[0], rewriter);
}

// Creates a slice of the tensorlist `input_list`, starting from
// [start_index, 0, ...0], with size [size, -1, ...-1].
//
// Requires that `start_index` and `size` are scalar tensors and
// `item_rank` is a 1-D tensor with only one element equal to the rank of an
// item in the tensorlist.
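// Illustrative example: with start_index = 2, size = 3, and rank-3 items, the
// emitted tf.Slice starts at [2, 0, 0, 0] with size [3, -1, -1, -1].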
TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
                                       Value start_index, Value size,
                                       Value item_rank, Type result_type,
                                       PatternRewriter *rewriter) {
  // Create the start position of the slice. This is done by concatenating
  // `start_index` and `partial_start_position` together.
  IntegerType shape_dtype = rewriter->getIntegerType(32);
  RankedTensorType position_type = RankedTensorType::get({-1}, shape_dtype);
  Value partial_start_position =
      CreateI32SplatTensor(loc, rewriter, item_rank, 0);
  Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
  RankedTensorType vector_type = RankedTensorType::get({1}, shape_dtype);
  auto expanded_start_index = rewriter->create<TF::ExpandDimsOp>(
      loc, vector_type, start_index, scalar_zero);
  auto start_position = rewriter->create<TF::ConcatOp>(
      loc, position_type, scalar_zero,
      ArrayRef<Value>({expanded_start_index, partial_start_position}));

  // Create the slice size tensor. This is done by concatenating `size` and
  // `partial_size`.
  auto size_leading_dim =
      rewriter->create<TF::ExpandDimsOp>(loc, vector_type, size, scalar_zero);
  Value partial_size = CreateI32SplatTensor(loc, rewriter, item_rank, -1);
  auto slice_size = rewriter->create<TF::ConcatOp>(
      loc, position_type, scalar_zero,
      ArrayRef<Value>({size_leading_dim, partial_size}));

  return rewriter->create<TF::SliceOp>(loc, result_type, input_list,
                                       start_position, slice_size);
}

// Converts a tf.Const containing a variant of type TensorList to a tensor of
// primitive element type. Each individual tensor in the list is converted to
// an ElementsAttr, and the results are then packed together using a tf.Pack
// op.
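// Illustrative example (hypothetical IR): a variant constant holding the two
// tensors [1, 2] and [3, 4] becomes
//   %e0 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
//   %e1 = "tf.Const"() {value = dense<[3, 4]> : tensor<2xi32>} : () -> tensor<2xi32>
//   %r  = "tf.Pack"(%e0, %e1) {axis = 0 : i64} : (...) -> tensor<2x2xi32>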
struct ConvertConst : public OpConversionPattern<TF::ConstOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::ConstOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Verify that the opaque elements attribute contains a tensor of variant
    // type with scalar shape. The variant type should hold a TensorList.
    auto opaque_attr = op.value().dyn_cast<OpaqueElementsAttr>();
    if (!opaque_attr) return failure();
    tensorflow::Tensor tensor;
    if (!tensorflow::ConvertToTensor(opaque_attr, &tensor).ok())
      return failure();
    if (tensor.dtype() != tensorflow::DT_VARIANT) return failure();
    if (!tensorflow::TensorShapeUtils::IsScalar(tensor.shape()))
      return failure();

    const tensorflow::TensorList *list =
        tensor.scalar<tensorflow::Variant>()().get<tensorflow::TensorList>();
    if (!list) return failure();

    // Verify that the output type is variant and contains exactly one ranked
    // subtype.
    auto variant_ty =
        getElementTypeOrSelf(op.getType()).dyn_cast<TF::VariantType>();
    if (!variant_ty) return failure();
    ArrayRef<TensorType> subtypes = variant_ty.getSubtypes();
    if (subtypes.size() != 1) return failure();
    RankedTensorType list_element_ty =
        subtypes.front().dyn_cast<RankedTensorType>();
    if (!list_element_ty) return failure();

    // Extract the tensor elements of the TensorList and construct the result
    // type based on the number of elements and the element shape.
    const std::vector<tensorflow::Tensor> &tensors = list->tensors();
    llvm::SmallVector<int64_t, 4> result_shape = {
        static_cast<int64_t>(tensors.size())};
    result_shape.append(list_element_ty.getShape().begin(),
                        list_element_ty.getShape().end());
    auto result_ty =
        RankedTensorType::get(result_shape, list_element_ty.getElementType());

    // If the list is empty, directly create the final result instead of
    // creating the tf.Pack op. The tf.Pack op requires at least one operand.
    if (tensors.empty()) {
      tensorflow::Tensor tensor(list->element_dtype,
                                tensorflow::TensorShape(result_shape));
      auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
      if (!attr_or.ok()) return failure();
      rewriter.replaceOpWithNewOp<TF::ConstOp>(op, attr_or.ValueOrDie());
      return success();
    }

    // Extract the individual tensor list elements and combine them using the
    // tf.Pack op.
    Location loc = op.getLoc();
    llvm::SmallVector<Value, 4> values;
    values.reserve(tensors.size());
    for (const tensorflow::Tensor &tensor : tensors) {
      auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
      if (!attr_or.ok()) return failure();

      auto value = rewriter.create<TF::ConstOp>(loc, attr_or.ValueOrDie());
      values.push_back(value);
    }
    rewriter.replaceOpWithNewOp<TF::PackOp>(
        op, result_ty, values, /*axis=*/rewriter.getI64IntegerAttr(0));
    return success();
  }
};

struct ConvertTensorListSetItem
    : public OpConversionPattern<TF::TensorListSetItemOp> {
  using OpConversionPattern::OpConversionPattern;

  // This function rewrites the original op into a series of slice and concat
  // ops that produce the same result. It first slices the first `$index` rows,
  // then expands the dimension of `$item`, followed by another slice of the
  // remaining rows starting from `$index` + 1. Lastly, it concatenates the
  // three parts together.
  // At a high level, it does something like:
  // def : Pat<(TF_TensorListSetItemOp $input, $index, $item),
  //           (Concat
  //             concat_dim = 0,
  //             (Slice $input, [0, 0, ...],
  //                    (Concat (ExpandDims $index, expand_dim = 0),
  //                            [-1, -1, ...])),
  //             (ExpandDims $item, expand_dim = 0),
  //             (Slice $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
  LogicalResult matchAndRewrite(
      TF::TensorListSetItemOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = operands[0];
    Value index = operands[1];
    Value item = operands[2];

    IntegerType shape_dtype = rewriter.getIntegerType(32);
    auto item_rank = rewriter.create<TF::RankOp>(
        loc, RankedTensorType::get({}, shape_dtype), item);
    Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);

    // Calculate `index` + 1, which is used to generate the start position for
    // the second slice op.
    auto suffix_start =
        rewriter.create<TF::AddOp>(loc, index.getType(), index,
                                   CreateI32SplatConst(loc, &rewriter, {}, 1));

    auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
        loc, RankedTensorType::get({1}, shape_dtype), item_rank, scalar_zero);
    // Create two slice ops.
    Type element_type = input.getType().cast<TensorType>().getElementType();
    UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type);
    Value scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1);
    TF::SliceOp slice1 =
        CreateSliceOpForTensorList(loc, /*input_list=*/input,
                                   /*start_index=*/scalar_zero,
                                   /*size=*/index,
                                   /*item_rank=*/item_position_shape,
                                   /*result_type=*/unranked_tensor, &rewriter);
    TF::SliceOp slice2 =
        CreateSliceOpForTensorList(loc, /*input_list=*/input,
                                   /*start_index=*/suffix_start,
                                   /*size=*/scalar_minus_one,
                                   /*item_rank=*/item_position_shape,
                                   /*result_type=*/unranked_tensor, &rewriter);

    // Expand the dimension of `item` so that it has the same rank as the
    // input.
    auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
        op.getLoc(), unranked_tensor, item, scalar_zero);

    // Concatenate the three parts together to generate the final result.
    rewriter.replaceOpWithNewOp<TF::ConcatOp>(
        op, input.getType(), scalar_zero,
        ArrayRef<Value>({slice1, expanded_item, slice2}));
    return success();
  }
};

// Rewrites an op of the template type that initializes a TensorList, so that
// it generates an equivalent raw tensor. Derived classes are required to
// override the GetNumElements method.
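//
// Illustrative example (hypothetical IR): with num_elements = 10 and
// element_shape = [2, 3], the init op becomes roughly
//   %shape = "tf.Concat"(%zero, %n, %elem_shape) : (...) -> tensor<3xi32>
//   %t = "tf.Fill"(%shape, %zero_f32) : (...) -> tensor<10x2x3xf32>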
template <typename OpT>
struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
  using OpConversionPattern<OpT>::OpConversionPattern;

  // Creates and returns a 1-D tensor with exactly one element equal to the
  // number of list elements to initialize the output tensor list with.
  virtual Value GetNumElements(OpT op, ArrayRef<Value> operands,
                               PatternRewriter *rewriter) const = 0;

  // Rewrites the original op into `tf.fill`. The result tensor shape is
  // [num_element, element_shape]. All the values in the result tensor will be
  // initialized to 0.
  LogicalResult matchAndRewrite(
      OpT op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Type dtype = op.element_dtype();
    if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
          dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
          dtype.isInteger(32) || dtype.isInteger(64))) {
      return rewriter.notifyMatchFailure(
          op,
          "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
          "integer or 16-bit/32-bit/64-bit float type during TF Lite "
          "transformation pass");
    }

    Value element_shape = operands[0];
    Type shape_dtype = getElementTypeOrSelf(element_shape.getType());
    // If the `element_shape` is a scalar, we try to acquire its shape by
    // looking at the first `TensorListSetItemOp` writing to this tensor list.
    // Here we assume that the element_shape won't be changed before calling
    // the first `TensorListSetItemOp`.
    if (auto shaped_type = element_shape.getType().dyn_cast<ShapedType>()) {
      if (shaped_type.getRank() == 0) {
        bool element_shape_acquired = false;
        auto uses = op.getResult().getUses();
        for (auto &use : llvm::make_early_inc_range(uses)) {
          if (TF::TensorListSetItemOp set_op =
                  llvm::dyn_cast<TF::TensorListSetItemOp>(use.getOwner())) {
            element_shape = rewriter.create<TF::ShapeOp>(
                op.getLoc(), RankedTensorType::get({-1}, shape_dtype),
                set_op.item());
            element_shape_acquired = true;
          } else if (TF::WhileOp while_op =
                         llvm::dyn_cast<TF::WhileOp>(use.getOwner())) {
            // The tensorlist is passed into a while loop; check inside the
            // body function.
            auto inside_uses = while_op.body_function()
                                   .getArgument(use.getOperandNumber())
                                   .getUses();
            for (auto &inside_use : llvm::make_early_inc_range(inside_uses)) {
              if (TF::TensorListSetItemOp set_op =
                      llvm::dyn_cast<TF::TensorListSetItemOp>(
                          inside_use.getOwner())) {
                if (auto shaped_type =
                        set_op.item().getType().dyn_cast<ShapedType>()) {
                  if (shaped_type.hasStaticShape()) {
                    RankedTensorType type = RankedTensorType::get(
                        {shaped_type.getRank()}, rewriter.getIntegerType(32));
                    SmallVector<Attribute, 4> shape_attr;
                    for (int64_t dim : shaped_type.getShape()) {
                      shape_attr.push_back(rewriter.getI32IntegerAttr(dim));
                    }
                    DenseElementsAttr attr =
                        DenseElementsAttr::get(type, shape_attr);
                    element_shape =
                        rewriter.create<ConstantOp>(op.getLoc(), type, attr);
                    element_shape_acquired = true;
                    break;
                  }
                }
              }
            }
          }
          if (element_shape_acquired) break;
        }
        if (!element_shape_acquired) {
          return rewriter.notifyMatchFailure(
              op,
              "requires element_shape to be a 1-D tensor during TF Lite "
              "transformation pass");
        }
      }
    }

    DenseIntElementsAttr dense_elem_attr;
    if (matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
      // Note: It's technically unsafe to rewrite
      //     TensorListReserve(num_element, element_shape)
      // to
      //     Fill(Concat(num_element, element_shape), 0)
      // because element_shape may contain -1 to represent an unknown
      // dimension.
      //
      // In real-world use cases (e.g. Keras RNN), `element_shape` is usually
      // a constant, and the first dimension of `element_shape` is usually
      // the batch dimension. Currently, TFLiteConverter always rewrites an
      // unknown batch dimension to 1, so we also rewrite an unknown dimension
      // in `element_shape` to 1 here.
      //
      // This workaround enables converting Keras RNN without specifying the
      // batch dimension. It isn't guaranteed to work, but it doesn't break
      // any non-broken cases either (since it's already broken if
      // `element_shape` contains -1).
      // TODO(b/142096690): Support dynamic element shape and remove the
      // workaround.
      SmallVector<int32_t, 4> new_element_shape_values;

      auto int_values = dense_elem_attr.getIntValues();
      for (auto it = int_values.begin(); it != int_values.end(); ++it) {
        auto dim_value = (*it).getSExtValue();
        if (it == int_values.begin() && dim_value == -1) {
          dim_value = 1;
        }
        new_element_shape_values.push_back(dim_value);
      }

      auto attr = DenseIntElementsAttr::get(
          element_shape.getType().cast<ShapedType>(), new_element_shape_values);
      auto new_element_shape = rewriter.create<ConstantOp>(
          op.getLoc(), element_shape.getType(), attr);
      element_shape = new_element_shape;
    }

    int64_t result_rank = -1;  // -1 means unknown result rank.
    Type element_dtype = op.element_dtype();
    Type result_type = UnrankedTensorType::get(element_dtype);
    Value leading_dim = GetNumElements(op, operands, &rewriter);
    if (auto element_type =
            op.element_type().template dyn_cast<RankedTensorType>()) {
      result_rank = element_type.getRank() + 1;
      int64_t leading_dim_v = -1;
      ElementsAttr element_attr;
      if (matchPattern(leading_dim, m_Constant(&element_attr))) {
        leading_dim_v = element_attr.getValue<IntegerAttr>(0).getInt();
      }
      SmallVector<int64_t, 4> result_shape = {leading_dim_v};
      ArrayRef<int64_t> shape = element_type.getShape();
      result_shape.append(shape.begin(), shape.end());
      result_type = RankedTensorType::get(result_shape, element_dtype);
    }

    // Create a 1-D RankedTensorType for the result's shape. The number of
    // elements in it equals the rank of the result, if known. Otherwise, the
    // number of elements is unknown and represented with -1. In both cases,
    // we can specify the dimension using the rank of the result.
    Type shape_type = RankedTensorType::get({result_rank}, shape_dtype);

    Location loc = op.getLoc();
    // Add the number of elements as a prefix to the element shape to get the
    // shape of the output tensor.
    Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
    auto list_shape = rewriter.create<TF::ConcatOp>(
        loc, shape_type, scalar_zero,
        ArrayRef<Value>({leading_dim, element_shape}));

    // Create a zero-initialized constant tensor that has the same type
    // as specified by element_dtype.
    RankedTensorType zero_type = RankedTensorType::get({}, element_dtype);
    Attribute zero_attr = rewriter.getZeroAttr(zero_type);
    auto zero = rewriter.create<ConstantOp>(loc, zero_type, zero_attr);

    rewriter.replaceOpWithNewOp<TF::FillOp>(op, result_type, list_shape, zero);
    return success();
  }
};

struct ConvertTensorListReserve
    : public ConvertTensorListInitOp<TF::TensorListReserveOp> {
  explicit ConvertTensorListReserve(MLIRContext *context)
      : ConvertTensorListInitOp(context) {}

  Value GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value> operands,
                       PatternRewriter *rewriter) const override {
    Value scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0);
    Type shape_dtype = getElementTypeOrSelf(op.element_shape().getType());
    Value num_elements = operands[1];
    IntegerAttr attr;
    if (matchPattern(num_elements, m_Constant(&attr))) {
      return CreateI32SplatConst(op.getLoc(), rewriter, {1}, attr.getInt());
    }
    return rewriter->create<TF::ExpandDimsOp>(
        op.getLoc(), RankedTensorType::get({1}, shape_dtype), num_elements,
        scalar_zero);
  }
};

// Note that we ignore the second operand `max_num_elements` as we don't have
// any restrictions on the number of elements we can support. So this may
// have a different behavior compared to TensorFlow in case of errors.
struct ConvertEmptyTensorList
    : public ConvertTensorListInitOp<TF::EmptyTensorListOp> {
  explicit ConvertEmptyTensorList(MLIRContext *context)
      : ConvertTensorListInitOp(context) {}

  Value GetNumElements(TF::EmptyTensorListOp op, ArrayRef<Value> operands,
                       PatternRewriter *rewriter) const override {
    return CreateI32SplatConst(op.getLoc(), rewriter, {1}, 0);
  }
};

struct ConvertTensorListPushBack
    : public OpConversionPattern<TF::TensorListPushBackOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::TensorListPushBackOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Value input_handle = operands[0];
    Value item = operands[1];

    // Expand the shape of the item so that it has the same rank as the input
    // tensor and is compatible with the Concat op.
    Type expanded_item_type =
        PrependLeadingDimIfRanked(1, item.getType(), &rewriter);
    Location loc = op.getLoc();
    Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
    auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
        loc, expanded_item_type, item, scalar_zero);

    Type elem_type = getElementTypeOrSelf(item);
    auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
                            .cast<TF::VariantType>();
    Type result_type =
        GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);

    // Concatenate the tensor stored in the input handle with the expanded
    // item to get a tensor equivalent to the TensorList generated by this op.
    rewriter.replaceOpWithNewOp<TF::ConcatOp>(
        op, result_type, scalar_zero,
        ArrayRef<Value>({input_handle, expanded_item}));
    return success();
  }
};

// Rewrites a `TensorListResize` op into a functional `If` op and several basic
// TF ops to match the op semantics of TensorFlow. Basically, it does:
// 1) If the requested size is smaller than or equal to the input tensorlist's
//    size, rewrite it to a Slice op so that only the first 'size' rows are
//    returned.
// 2) If the requested size is larger than the input tensorlist's size, create
//    an additional tensorlist with 'size - input_size' elements and append it
//    to the end of the input tensorlist.
// TODO(haoliang): We could simplify this transformation by rewriting to pure
// tensorlist ops and a few non-tensorlist ops (such as `SliceOp`). By operating
// only on variant types, we could save some ops involved in rewriting this op.
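//
// Illustrative sketch (hypothetical IR) of the emitted control flow:
//   %diff = "tf.Sub"(%size, %input_size) : (...) -> tensor<i32>
//   %cond = "tf.Greater"(%diff, %zero) : (...) -> tensor<i1>
//   %r = "tf.If"(%cond, %list, %shape, %diff, %size)
//          {then_branch = @cond_true, else_branch = @cond_false, ...}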
struct ConvertTensorListResize
    : public OpConversionPattern<TF::TensorListResizeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::TensorListResizeOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Value input_handle = operands[0];
    Value size = operands[1];

    Location loc = op.getLoc();
    Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);

    // Compute the input tensorlist's length and store it in `input_size`.
    IntegerType shape_dtype = rewriter.getIntegerType(32);
    auto input_size = rewriter.create<TF::TensorListLengthOp>(
        loc, RankedTensorType::get({}, shape_dtype), op.getOperand(0));

    // Infer the result type of this op based on TF's shape inference result.
    Type elem_type = getElementTypeOrSelf(input_handle);
    auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
                            .cast<TF::VariantType>();
    Type result_type =
        GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);

    // Compute the difference of `size` and `input_size`, and store it in
    // `size_diff`, which is then consumed by `if_cond`.
    auto size_diff = rewriter.create<TF::SubOp>(
        loc, RankedTensorType::get({}, shape_dtype), size, input_size);
    auto if_cond = rewriter.create<TF::GreaterOp>(
        loc, RankedTensorType::get({}, rewriter.getI1Type()), size_diff,
        scalar_zero);

    // Build the argument/result types for the if branch functions.
    auto input_shape = rewriter.create<TF::ShapeOp>(
        loc, RankedTensorType::get({-1}, shape_dtype), input_handle);

    Type branch_args_type[] = {input_handle.getType(), input_shape.getType(),
                               size_diff.getType(), size.getType()};
    Type branch_result_type[] = {result_type};
    auto func_type = FunctionType::get(rewriter.getContext(), branch_args_type,
                                       branch_result_type);

    // Create the functions in a higher scope before restoring the insertion
    // point. Additionally, create the SymbolTable before further modifying the
    // module.
    auto original_point = rewriter.saveInsertionPoint();
    rewriter.setInsertionPointAfter(op->getParentOfType<FuncOp>());
    SymbolTable manager(op->getParentOfType<ModuleOp>());

    // Constructs `then_branch`, which is executed when `if_cond` evaluates to
    // true.
    auto then_branch_op = rewriter.create<FuncOp>(loc, "cond_true", func_type);
    CreateCondTrueBranch(op, shape_dtype, result_type, then_branch_op,
                         &rewriter);

    // Constructs `else_branch`, which is executed when `if_cond` evaluates to
    // false.
    auto else_branch_op = rewriter.create<FuncOp>(loc, "cond_false", func_type);
    CreateCondFalseBranch(loc, shape_dtype, result_type, else_branch_op,
                          &rewriter);

    // Inserts the two branch functions' names into the symbol table held by
    // the module. Using the SymbolTable ensures that the inserted symbol names
    // are unique.
    manager.insert(then_branch_op);
    manager.insert(else_branch_op);

    rewriter.restoreInsertionPoint(original_point);
    rewriter.replaceOpWithNewOp<TF::IfOp>(
        op, result_type, if_cond,
        /*input=*/
        ArrayRef<Value>({input_handle, input_shape, size_diff, size}),
        /*then_branch=*/rewriter.getSymbolRefAttr(then_branch_op),
        /*else_branch=*/rewriter.getSymbolRefAttr(else_branch_op),
        /*is_stateless=*/rewriter.getBoolAttr(true));
    return success();
  }

 private:
  // When the input tensorlist's size is smaller than the requested size,
  // the then branch is executed: create a new tensorlist of size
  // 'size - input_size' and concat it with the input tensorlist.
  void CreateCondTrueBranch(TF::TensorListResizeOp resize_op, Type shape_dtype,
                            Type result_type, FuncOp branch_func,
                            ConversionPatternRewriter *rewriter) const {
    auto guard = OpBuilder::InsertionGuard(*rewriter);
    Block *block =
        rewriter->createBlock(&branch_func.getBody(), branch_func.begin(),
                              branch_func.getType().getInputs());

    auto input_shape = block->getArgument(1);
    auto size_diff = block->getArgument(2);
    auto input = block->getArgument(0);

    Location loc = resize_op.getLoc();
    // Get the element shape by slicing from index 1 in the input shape.
    Value slice_size = CreateI32SplatConst(loc, rewriter, {1}, -1);
    Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
    Value slice_start = CreateI32SplatConst(loc, rewriter, {1}, 1);
    auto elem_shape = rewriter->create<TF::SliceOp>(
        loc, RankedTensorType::get({-1}, shape_dtype), input_shape, slice_start,
        slice_size);
    auto extended_part = rewriter->create<TF::TensorListReserveOp>(
        loc, resize_op.output_handle().getType(), elem_shape, size_diff);
    // `ConcatOp` expects a non-variant-typed input. Insert a
    // `TensorListStackOp` here to convert the type from variant to
    // non-variant. Note that we are using the same `result_type` for both the
    // `TensorListStackOp` and `ConcatOp`, since the first dimension of the
    // shape specified by `result_type` is -1.
    auto stacked_extended_part = rewriter->create<TF::TensorListStackOp>(
        loc, result_type, extended_part,
        /*element_shape=*/CreateI32SplatConst(loc, rewriter, {}, -1),
        /*num_elements=*/rewriter->getI32IntegerAttr(-1));
    auto concat_op = rewriter->create<TF::ConcatOp>(
        loc, result_type, scalar_zero,
        ArrayRef<Value>({input, stacked_extended_part}));
    rewriter->create<ReturnOp>(loc, ArrayRef<Value>({concat_op}));
  }

  void CreateCondFalseBranch(Location loc, Type shape_dtype, Type result_type,
                             FuncOp branch_func,
                             ConversionPatternRewriter *rewriter) const {
    // When the input tensorlist's size is larger than or equal to the
    // requested size, the else branch is executed: slice the first 'size'
    // rows from the input tensorlist.
    auto guard = OpBuilder::InsertionGuard(*rewriter);
    Block *block =
        rewriter->createBlock(&branch_func.getBody(), branch_func.begin(),
                              branch_func.getType().getInputs());

    Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
    Value vector_one = CreateI32SplatConst(loc, rewriter, {1}, 1);
    auto input = block->getArgument(0);
    auto size = block->getArgument(3);

    // Subtract 1 from `input_rank` to get the item's rank, which is used as
    // `partial_position_shape`.
    auto input_rank = rewriter->create<TF::RankOp>(
        loc, RankedTensorType::get({}, shape_dtype), input);
    auto partial_position_shape = rewriter->create<TF::SubOp>(
        loc, RankedTensorType::get({1}, shape_dtype), input_rank, vector_one);
    auto slice_op =
        CreateSliceOpForTensorList(loc, /*input_list=*/input,
                                   /*start_index=*/scalar_zero, /*size=*/size,
                                   /*item_rank=*/partial_position_shape,
                                   /*result_type=*/result_type, rewriter);
    rewriter->create<ReturnOp>(loc, ArrayRef<Value>({slice_op}));
  }
};

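// Rewrites `tf.TensorListGetItem` into a `tf.Gather` that reads one row of
// the tensor backing the list.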
struct ConvertTensorListGetItem
    : public OpConversionPattern<TF::TensorListGetItemOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::TensorListGetItemOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Value input = operands[0];
    Value index = operands[1];
    rewriter.replaceOpWithNewOp<TF::GatherOp>(op, op.getType(), input, index,
                                              rewriter.getBoolAttr(true));
    return success();
  }
};

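// Rewrites `tf.TensorListLength` into a `tf.Shape` followed by a `tf.Gather`
// of the leading dimension of the tensor backing the list.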
struct ConvertTensorListLength
    : public OpConversionPattern<TF::TensorListLengthOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::TensorListLengthOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input_handle = operands[0];

    BoolAttr true_attr = rewriter.getBoolAttr(true);
    auto shape = rewriter.create<TF::ShapeOp>(loc, input_handle,
                                              /*use_32bit=*/true_attr);
    rewriter.replaceOpWithNewOp<TF::GatherOp>(
        op, op.getType(), shape, CreateI32SplatConst(loc, &rewriter, {}, 0),
        /*validate_indices=*/true_attr);
    return success();
  }
};

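// Rewrites `tf.TensorListStack` into a (possibly trivial) `tf.Reshape` of the
// tensor backing the list, or simply forwards the operand when no constant
// `element_shape` is available.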
struct ConvertTensorListStack
    : public OpConversionPattern<TF::TensorListStackOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::TensorListStackOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = operands[0];
    Value element_shape = operands[1];

    // If the `element_shape` is a known constant (which is defined when
    // calling `tensor_list_stack`) and also valid (not scalar), we rewrite
    // this op to a trivial Reshape op (that doesn't actually change the
    // input's shape) and also propagate the shape info to the op result. The
    // shape of the tensorlist is inferred from `num_elements` and
    // `element_shape`.
    auto ranked_type = element_shape.getType().dyn_cast<RankedTensorType>();
    DenseIntElementsAttr dense_elem_attr;
    if ((ranked_type && ranked_type.getRank() == 0) ||
        !matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
      // If no constant is spotted, just forward the operand.
      rewriter.replaceOp(op, {input});
      return success();
    }

    RankedTensorType shape_type =
        RankedTensorType::get({-1}, rewriter.getIntegerType(32));
    auto new_shape = rewriter.create<TF::ShapeOp>(loc, shape_type, input);
    SmallVector<int64_t, 8> output_shape(/*Size=*/1, op.num_elements());
    for (const auto &dim : dense_elem_attr.getIntValues())
      output_shape.push_back(dim.getSExtValue());
    RankedTensorType result_type =
        RankedTensorType::get(output_shape, getElementTypeOrSelf(input));
    rewriter.replaceOpWithNewOp<TF::ReshapeOp>(op, result_type, input,
                                               new_shape);
    return success();
  }
};

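// Rebuilds `tf.Identity` with the converted operand so that its result type
// matches the (possibly no-longer-variant) input type.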
struct ConvertIdentity : public OpConversionPattern<TF::IdentityOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::IdentityOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Value input = operands[0];
    rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), operands,
                                                op.getAttrs());
    return success();
  }
};

// Returns an unranked tensor type with an element of the same type as `value`
// if `type` is a tensor of variant. Otherwise, returns `type` unmodified.
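// For example (illustrative types): given a `type` of
// tensor<!tf.variant<tensor<2xf32>>> and a `value` of type tensor<5x2xf32>,
// this returns tensor<*xf32>.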
Type VariantToUnrankedTensorType(Type type, Value value) {
  if (getElementTypeOrSelf(type).isa<TF::VariantType>())
    return UnrankedTensorType::get(getElementTypeOrSelf(value.getType()));
  return type;
}

// Returns the set of argument indices of the given While op's cond/body
// functions whose arguments are variant-typed values consumed by TensorList
// ops.
llvm::SmallSet<int, 4> GetTensorListArgumentsFromWhileOp(TF::WhileOp op) {
  llvm::SmallSet<int, 4> set;
  for (FuncOp func : {op.cond_function(), op.body_function()}) {
    if (!func) continue;

    for (auto arg_and_idx : llvm::enumerate(func.getArguments())) {
      mlir::BlockArgument arg = arg_and_idx.value();
      auto variant_ty =
          getElementTypeOrSelf(arg.getType()).dyn_cast<TF::VariantType>();
      if (!variant_ty) continue;

      for (auto &op_operand : arg.getUses()) {
        auto op = op_operand.getOwner();
        if (llvm::isa<TF::TensorListGetItemOp>(op) ||
            llvm::isa<TF::TensorListLengthOp>(op) ||
            llvm::isa<TF::TensorListPushBackOp>(op) ||
            llvm::isa<TF::TensorListReserveOp>(op) ||
            llvm::isa<TF::TensorListSetItemOp>(op) ||
            llvm::isa<TF::TensorListStackOp>(op) ||
            llvm::isa<TF::TensorListResizeOp>(op)) {
          set.insert(arg_and_idx.index());
          break;
        }
      }
    }
  }
  return set;
}

// Changes the function types of `cond_func` and `body_func` for the given
// While op.
LogicalResult UpdateFunctionTypes(TF::WhileOp op,
                                  llvm::SmallSet<int, 4> tensor_list_args) {
  int func_index = 0;
  for (FuncOp func : {op.cond_function(), op.body_function()}) {
    ++func_index;
    if (!func) continue;

    FunctionType func_type = func.getType();
    int num_inputs = func_type.getNumInputs();
    int num_results = func_type.getNumResults();

    // For each argument type in the function's arguments, change it to an
    // unranked tensor type if it's a variant type.
    SmallVector<Type, 8> updated_argument_types;
    updated_argument_types.reserve(num_inputs);
    int i = 0;
    for (auto it : llvm::zip(func_type.getInputs(), op.getOperands())) {
      if (tensor_list_args.count(i)) {
        updated_argument_types.push_back(
            VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
      } else {
        updated_argument_types.push_back(std::get<0>(it));
      }
      ++i;
    }

    // Change all DT_VARIANT result types in the function results to unranked
    // tensor types with element types derived from the corresponding input
    // operands. This is correct because the while body's inputs and results
    // have the same types.
    SmallVector<Type, 8> updated_result_types;
    updated_result_types.reserve(num_results);
    i = 0;
    for (auto it : llvm::zip(func_type.getResults(), op.getOperands())) {
      // Only update the body's results.
      if (func_index != 1 && tensor_list_args.count(i)) {
        updated_result_types.push_back(
            VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
      } else {
        updated_result_types.push_back(std::get<0>(it));
      }
      ++i;
    }

    // Change `func`'s argument types to `updated_argument_types`. If its
    // return types contain a `DT_VARIANT`, change it to the unranked type
    // derived from the corresponding argument.
    func.setType(FunctionType::get(op.getContext(), updated_argument_types,
                                   updated_result_types));

    // Change the argument types for the first block.
    llvm::for_each(func.getArguments(), [&](BlockArgument &arg) {
      arg.setType(updated_argument_types[arg.getArgNumber()]);
    });
  }
  return success();
}

struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::WhileOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Find all TensorList arguments.
    auto tensor_list_args = GetTensorListArgumentsFromWhileOp(op);

    llvm::SmallVector<Type, 8> result_types;
    result_types.reserve(op.getNumOperands());
    // Change all DT_VARIANT result types to unranked tensor types.
    int i = 0;
    for (auto it : llvm::zip(op.getResultTypes(), operands)) {
      if (tensor_list_args.count(i)) {
        result_types.push_back(
            VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
      } else {
        result_types.push_back(std::get<0>(it));
      }
      ++i;
    }

    // Create a new while op with new operands and updated result types.
    auto converted = rewriter.create<TF::WhileOp>(op.getLoc(), result_types,
                                                  operands, op.getAttrs());
    converted.removeAttr("T");
    (void)UpdateFunctionTypes(converted, tensor_list_args);

    rewriter.replaceOp(op, converted.getResults());
    return success();
  }
};

struct ConvertWhileRegion : public OpConversionPattern<TF::WhileRegionOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::WhileRegionOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Type, 8> result_types;
    result_types.reserve(op.getNumOperands());
    // Change all DT_VARIANT result types to unranked tensor types.
    for (auto it : llvm::zip(op.getResultTypes(), operands))
      result_types.push_back(
          VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));

    // Create a new while op with new operands and updated result types.
    auto converted = rewriter.create<TF::WhileRegionOp>(
        op.getLoc(), result_types, operands, op.getAttrs());

    // Inline the regions from the old while op into the new one, and apply
    // signature conversion to the inlined regions.
    for (auto it : llvm::zip(op.getRegions(), converted.getRegions())) {
      Region &old_region = *std::get<0>(it);
      Region &new_region = *std::get<1>(it);

      Block &entry = old_region.front();
      // Build the signature conversion for the region.
      TypeConverter::SignatureConversion signature_conversion(operands.size());
      for (auto it : llvm::zip(entry.getArguments(), operands)) {
        BlockArgument arg = std::get<0>(it);
        signature_conversion.addInputs(
            arg.getArgNumber(),
            VariantToUnrankedTensorType(arg.getType(), std::get<1>(it)));
      }

      rewriter.inlineRegionBefore(old_region, new_region, new_region.end());
      rewriter.applySignatureConversion(&new_region, signature_conversion);
    }

    rewriter.replaceOp(op, converted.getResults());
    return success();
  }
};

#include "tensorflow/compiler/mlir/lite/transforms/generated_lower_static_tensor_list.inc"

void LowerStaticTensorListPass::runOnOperation() {
  auto *context = &getContext();

  // TensorFlow operations that don't have operands or results of variant
  // type are legal. Here, we don't distinguish between variants encoding a
  // TensorList and those encoding some other type, as that information is
  // not available here. Partial legalization is used below so that ops with
  // variant types can still pass through.
  auto is_legal = [](Operation *op) {
    auto is_not_variant = [](Type ty) {
      return !ty.cast<ShapedType>().getElementType().isa<TF::VariantType>();
    };
    return llvm::all_of(op->getOperandTypes(), is_not_variant) &&
           llvm::all_of(op->getResultTypes(), is_not_variant);
  };

  ConversionTarget target(*context);
  target.addDynamicallyLegalDialect<TF::TensorFlowDialect>(
      llvm::Optional<ConversionTarget::DynamicLegalityCallbackFn>(is_legal));
  target.addIllegalOp<TF::EmptyTensorListOp, TF::TensorListFromTensorOp,
                      TF::TensorListGetItemOp, TF::TensorListLengthOp,
                      TF::TensorListPushBackOp, TF::TensorListReserveOp,
                      TF::TensorListSetItemOp, TF::TensorListStackOp,
                      TF::TensorListResizeOp, TF::TensorListConcatV2Op>();
  // TODO(hinsu): Use TFLite constant op for constants.
  target.addLegalOp<ConstantOp>();
  target.addLegalOp<FuncOp>();
  target.addLegalOp<ReturnOp>();
  target.addLegalOp<TFL::CustomOp>();
  // Register fused LSTM/RNN ops as legal.
  target.addLegalOp<TFL::LSTMOp>();
  target.addLegalOp<TFL::UnidirectionalSequenceLSTMOp>();
  target.addLegalOp<TFL::UnidirectionalSequenceRNNOp>();
  target.addLegalOp<TFL::BidirectionalSequenceLSTMOp>();

  OwningRewritePatternList patterns;
  populateWithGenerated(context, patterns);
  patterns.insert<ConvertConst, ConvertEmptyTensorList, ConvertIdentity,
                  ConvertTensorListGetItem, ConvertTensorListLength,
                  ConvertTensorListPushBack, ConvertTensorListReserve,
                  ConvertTensorListSetItem, ConvertTensorListStack,
                  ConvertTensorListResize, ConvertWhile, ConvertWhileRegion>(
      context);
  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(patterns)))) {
    if (!allow_tensorlist_pass_through) {
      signalPassFailure();
    }
  }
}

}  // namespace

/// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
/// pass.
std::unique_ptr<OperationPass<ModuleOp>>
TFL::CreateLowerStaticTensorListPass() {
  return std::make_unique<LowerStaticTensorListPass>();
}

static PassRegistration<LowerStaticTensorListPass> pass(
    "tfl-lower-static-tensor-list",
    "Lower TensorList ops within TensorFlow Lite dialect");

}  // namespace mlir