• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This transformation pass prepares for legalization to the TFLite dialect by
17 // converting Tensorlist operations in TensorFlow dialect into operations that
18 // can be legalized to TensorFlow Lite dialect with simple replacements.  The
19 // newly created operations are in the TensorFlow dialect if the operation can
20 // be represented using a TensorFlow op. Otherwise, TensorFlow Lite dialect op
21 // is used.
22 
23 #include <climits>
24 #include <cstdint>
25 #include <utility>
26 
27 #include "absl/container/inlined_vector.h"
28 #include "llvm/ADT/ArrayRef.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/None.h"
31 #include "llvm/ADT/Optional.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/SmallSet.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/ADT/StringSwitch.h"
36 #include "llvm/Support/Casting.h"
37 #include "llvm/Support/Debug.h"
38 #include "mlir/Analysis/LoopAnalysis.h"  // from @llvm-project
39 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
40 #include "mlir/IR/Attributes.h"  // from @llvm-project
41 #include "mlir/IR/Block.h"  // from @llvm-project
42 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
43 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
44 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
45 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
46 #include "mlir/IR/Matchers.h"  // from @llvm-project
47 #include "mlir/IR/Operation.h"  // from @llvm-project
48 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
49 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
50 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
51 #include "mlir/IR/TypeRange.h"  // from @llvm-project
52 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
53 #include "mlir/IR/Types.h"  // from @llvm-project
54 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
55 #include "mlir/IR/Value.h"  // from @llvm-project
56 #include "mlir/Pass/Pass.h"  // from @llvm-project
57 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
58 #include "mlir/Support/LLVM.h"  // from @llvm-project
59 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
60 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
61 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
62 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
63 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
64 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
65 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
66 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
67 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
68 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
69 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
70 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
71 #include "tensorflow/core/framework/tensor.h"
72 #include "tensorflow/core/framework/types.pb.h"
73 #include "tensorflow/core/kernels/tensor_list.h"
74 
75 #define DEBUG_TYPE "tf-tfl-legalization"
76 
77 //===----------------------------------------------------------------------===//
78 // The actual LowerStaticTensorList Pass.
79 //
80 namespace mlir {
81 namespace {
82 
83 /// Lower TensorList ops in functions for subsequent legalization.
84 struct LowerStaticTensorListPass
85     : public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> {
86   LowerStaticTensorListPass() = default;
LowerStaticTensorListPassmlir::__anonce850f620111::LowerStaticTensorListPass87   LowerStaticTensorListPass(const LowerStaticTensorListPass &) {}
LowerStaticTensorListPassmlir::__anonce850f620111::LowerStaticTensorListPass88   explicit LowerStaticTensorListPass(bool allow_tensorlist_pass_through) {
89     this->allow_tensorlist_pass_through = allow_tensorlist_pass_through;
90   }
91 
getArgumentmlir::__anonce850f620111::LowerStaticTensorListPass92   StringRef getArgument() const final {
93     // This is the argument used to refer to the pass in
94     // the textual format (on the commandline for example).
95     return "tfl-lower-static-tensor-list";
96   }
getDescriptionmlir::__anonce850f620111::LowerStaticTensorListPass97   StringRef getDescription() const final {
98     // This is a brief description of the pass.
99     return "Lower TensorList ops within TensorFlow Lite dialect";
100   }
101 
102   void runOnOperation() override;
103 
104   Option<bool> allow_tensorlist_pass_through{
105       *this, "allow-tensorlist-pass-through",
106       llvm::cl::desc(
107           "When specified to true, if the tensorlist ops can't be properly "
108           "legalized by this pass, then the IR won't be changed so that "
109           "tensorlist ops can pass through (default false)"),
110       llvm::cl::init(false)};
111 };
112 
CreateI32SplatConst(Location loc,PatternRewriter * rewriter,ArrayRef<int64_t> shape,int32_t val)113 Value CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
114                           ArrayRef<int64_t> shape, int32_t val) {
115   RankedTensorType type =
116       RankedTensorType::get(shape, rewriter->getIntegerType(32));
117   DenseElementsAttr attr =
118       DenseElementsAttr::get(type, rewriter->getI32IntegerAttr(val));
119   return rewriter->create<ConstantOp>(loc, type, attr);
120 }
121 
CreateI64SplatConst(Location loc,PatternRewriter * rewriter,ArrayRef<int64_t> shape,int64_t val)122 Value CreateI64SplatConst(Location loc, PatternRewriter *rewriter,
123                           ArrayRef<int64_t> shape, int64_t val) {
124   RankedTensorType type =
125       RankedTensorType::get(shape, rewriter->getIntegerType(64));
126   DenseElementsAttr attr =
127       DenseElementsAttr::get(type, rewriter->getI64IntegerAttr(val));
128   return rewriter->create<ConstantOp>(loc, type, attr);
129 }
130 
CreateI32SplatTensor(Location loc,PatternRewriter * rewriter,Value shape_tensor,int32_t val)131 Value CreateI32SplatTensor(Location loc, PatternRewriter *rewriter,
132                            Value shape_tensor, int32_t val) {
133   Value scalar_val = CreateI32SplatConst(loc, rewriter, {}, val);
134   return rewriter->create<TF::FillOp>(
135       loc, RankedTensorType::get({-1}, rewriter->getIntegerType(32)),
136       shape_tensor, scalar_val);
137 }
138 
139 // Returns a new type by prepending the specified dimension to the shape of
140 // the given type if it is a ranked type.
PrependLeadingDimIfRanked(int64_t dim,Type type,PatternRewriter * rewriter)141 Type PrependLeadingDimIfRanked(int64_t dim, Type type,
142                                PatternRewriter *rewriter) {
143   Type dtype = getElementTypeOrSelf(type);
144   if (RankedTensorType ty = type.dyn_cast<RankedTensorType>()) {
145     llvm::SmallVector<int64_t, 4> shape = {dim};
146     shape.append(ty.getShape().begin(), ty.getShape().end());
147     return RankedTensorType::get(shape, dtype);
148   }
149   return type;
150 }
151 
GetTensorTypeForTensorList(Type element_type,TF::VariantType handle_dtype,PatternRewriter * rewriter)152 Type GetTensorTypeForTensorList(Type element_type, TF::VariantType handle_dtype,
153                                 PatternRewriter *rewriter) {
154   // If the variant type in the output handle has item shape available, use it
155   // to derive the output shape by setting unknown leading dimension.
156   // Otherwise, result type will be of unranked type.
157   if (handle_dtype.getSubtypes().empty()) {
158     return UnrankedTensorType::get(element_type);
159   }
160   return PrependLeadingDimIfRanked(-1, handle_dtype.getSubtypes()[0], rewriter);
161 }
162 
163 // Gets the index of tensorlist arguments which size might get changed by the
164 // function.
GetResizedTensorListIndexes(FuncOp func,const llvm::SmallSet<int,4> & tensor_list_args)165 llvm::SmallSet<int, 4> GetResizedTensorListIndexes(
166     FuncOp func, const llvm::SmallSet<int, 4> &tensor_list_args) {
167   // `indexes` stores the argument index of tensorlists which size may get
168   // updated in the function.
169   llvm::SmallSet<int, 4> indexes;
170   for (BlockArgument &arg : func.getArguments()) {
171     if (tensor_list_args.contains(arg.getArgNumber())) {
172       for (const mlir::OpOperand &use : arg.getUses()) {
173         mlir::Operation *op = use.getOwner();
174         // Currently we only check if the tensorlist argument is consumed by
175         // `TensorListPushBack` or `TensorListResize`, since those are the only
176         // length-mutating ops supported in this pass.
177         if (llvm::isa<TF::TensorListPushBackOp>(op) ||
178             llvm::isa<TF::TensorListResizeOp>(op)) {
179           indexes.insert(arg.getArgNumber());
180         }
181       }
182     }
183   }
184   return indexes;
185 }
186 
187 // Creates a slice of the tensorlist `input_list`, starting from
188 // [start_index, 0, ...0], with size [size, -1, ...-1].
189 //
190 // Requires that `start_index` and `size` are scalar tensors and
191 // `item_position_shape` is a 1-D tensor with only one element equal to the rank
192 // of an item in the tensorlist.
CreateSliceOpForTensorList(Location loc,Value input_list,Value start_index,Value size,Value item_rank,Type result_type,PatternRewriter * rewriter)193 TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
194                                        Value start_index, Value size,
195                                        Value item_rank, Type result_type,
196                                        PatternRewriter *rewriter) {
197   // Create the start position of slice. This is done by concatenating
198   // `start_index` and `partial_start_position` together.
199   IntegerType shape_dtype = rewriter->getIntegerType(32);
200   RankedTensorType position_type = RankedTensorType::get({-1}, shape_dtype);
201   Value partial_start_position =
202       CreateI32SplatTensor(loc, rewriter, item_rank, 0);
203   Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
204   RankedTensorType vector_type = RankedTensorType::get({1}, shape_dtype);
205   auto expanded_start_index = rewriter->create<TF::ExpandDimsOp>(
206       loc, vector_type, start_index, scalar_zero);
207   auto start_position = rewriter->create<TF::ConcatOp>(
208       loc, position_type, scalar_zero,
209       ArrayRef<Value>({expanded_start_index, partial_start_position}));
210 
211   // Create the slice size tensor. This is done by concatenating `size` and
212   // `partial_size`.
213   auto size_leading_dim =
214       rewriter->create<TF::ExpandDimsOp>(loc, vector_type, size, scalar_zero);
215   Value partial_size = CreateI32SplatTensor(loc, rewriter, item_rank, -1);
216   auto slice_size = rewriter->create<TF::ConcatOp>(
217       loc, position_type, scalar_zero,
218       ArrayRef<Value>({size_leading_dim, partial_size}));
219 
220   return rewriter->create<TF::SliceOp>(loc, result_type, input_list,
221                                        start_position, slice_size);
222 }
223 
224 template <typename OpT>
225 class TensorListOpConverterBase : public OpConversionPattern<OpT> {
226  public:
TensorListOpConverterBase(MLIRContext * context,bool allow_tensorlist_pass_through)227   explicit TensorListOpConverterBase<OpT>(MLIRContext *context,
228                                           bool allow_tensorlist_pass_through)
229       : OpConversionPattern<OpT>::OpConversionPattern(context),
230         allow_tensorlist_pass_through_(allow_tensorlist_pass_through) {}
231 
232  protected:
233   // This flag will control the behavior of error emitting during rewrite:
234   // 1) If it's true, then patterns will only emit errors during debug or
235   // tracing mode. 2) If it's false, then patterns will emit standard errors
236   // when there is a rewrite failure.
237   bool allow_tensorlist_pass_through_;
238 };
239 
240 // Converts tf.Const containing variant of type TensorList to a tensor of
241 // primitive element types. Each of the individual tensor in the list is
242 // converted to an ElementsAttr and then those are packed together using
243 // tf.Pack op.
244 struct ConvertConst : public OpConversionPattern<TF::ConstOp> {
245   using OpConversionPattern::OpConversionPattern;
246 
matchAndRewritemlir::__anonce850f620111::ConvertConst247   LogicalResult matchAndRewrite(
248       TF::ConstOp op, ArrayRef<Value> operands,
249       ConversionPatternRewriter &rewriter) const override {
250     // Verify that the opaque elements attribute contains tensor of type variant
251     // and scalar shape. The variant type should hold a TensorList.
252     auto opaque_attr = op.value().dyn_cast<OpaqueElementsAttr>();
253     if (!opaque_attr) return failure();
254     tensorflow::Tensor tensor;
255     if (!tensorflow::ConvertToTensor(opaque_attr, &tensor).ok())
256       return failure();
257     if (tensor.dtype() != tensorflow::DT_VARIANT) return failure();
258     if (!tensorflow::TensorShapeUtils::IsScalar(tensor.shape()))
259       return failure();
260 
261     const tensorflow::TensorList *list =
262         tensor.scalar<tensorflow::Variant>()().get<tensorflow::TensorList>();
263     if (!list) return failure();
264 
265     // Verify output type is variant and contains exactly one ranked subtypes.
266     auto variant_ty =
267         getElementTypeOrSelf(op.getType()).dyn_cast<TF::VariantType>();
268     if (!variant_ty) return failure();
269     ArrayRef<TensorType> subtypes = variant_ty.getSubtypes();
270     if (subtypes.size() != 1) return failure();
271     RankedTensorType list_element_ty =
272         subtypes.front().dyn_cast<RankedTensorType>();
273     if (!list_element_ty) return failure();
274 
275     // Extract tensor elements for the TensorList and construct result type
276     // based on the number of elements and element shape.
277     const std::vector<tensorflow::Tensor> &tensors = list->tensors();
278     llvm::SmallVector<int64_t, 4> result_shape = {
279         static_cast<int64_t>(tensors.size())};
280     result_shape.append(list_element_ty.getShape().begin(),
281                         list_element_ty.getShape().end());
282     auto result_ty =
283         RankedTensorType::get(result_shape, list_element_ty.getElementType());
284 
285     // If the list is empty, directly create the final result instead of
286     // creating the tf.Pack op. tf.Pack op requires at least one operand.
287     if (tensors.empty()) {
288       tensorflow::Tensor tensor(list->element_dtype,
289                                 tensorflow::TensorShape(result_shape));
290       auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
291       if (!attr_or.ok()) return failure();
292       rewriter.replaceOpWithNewOp<TF::ConstOp>(op, attr_or.ValueOrDie());
293       return success();
294     }
295 
296     // Extract individual tensor list element and combine them using the tf.Pack
297     // op.
298     Location loc = op.getLoc();
299     llvm::SmallVector<Value, 4> values;
300     values.reserve(tensors.size());
301     for (const tensorflow::Tensor &tensor : tensors) {
302       auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
303       if (!attr_or.ok()) return failure();
304 
305       auto value = rewriter.create<TF::ConstOp>(loc, attr_or.ValueOrDie());
306       values.push_back(value);
307     }
308     rewriter.replaceOpWithNewOp<TF::PackOp>(
309         op, result_ty, values, /*axis=*/rewriter.getI64IntegerAttr(0));
310     return success();
311   }
312 };
313 
314 struct ConvertTensorListSetItem
315     : public OpConversionPattern<TF::TensorListSetItemOp> {
316   using OpConversionPattern::OpConversionPattern;
317 
318   // This function rewrites the original op into a series of slice and concat op
319   // to produce the same result. It first slices the first `$index` rows. Then
320   // expands the dimension of the `$item`, followed by another slice of the
321   // remaining rows starting from `$index` + 1. Lastly it concatenates the
322   // three parts together.
323   // On a high level, it's doing something like:
324   // def : Pat<(TF_TensorListSetItemOp $input, $index, $item),
325   //      (Concat
326   //        concat_dim = 0,
327   //        (Slice $input, [0, 0, ...], (Concat (ExpandDims $index, expand_dim =
328   //        0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice
329   //        $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
matchAndRewritemlir::__anonce850f620111::ConvertTensorListSetItem330   LogicalResult matchAndRewrite(
331       TF::TensorListSetItemOp op, ArrayRef<Value> operands,
332       ConversionPatternRewriter &rewriter) const override {
333     Location loc = op.getLoc();
334     Value input = operands[0];
335     Value index = operands[1];
336     Value item = operands[2];
337 
338     IntegerType shape_dtype = rewriter.getIntegerType(32);
339     auto item_rank = rewriter.create<TF::RankOp>(
340         loc, RankedTensorType::get({}, shape_dtype), item);
341     Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
342 
343     // Calculate `index` + 1, which is used to generate the start position for
344     // the second slice op.
345     auto suffix_start =
346         rewriter.create<TF::AddOp>(loc, index.getType(), index,
347                                    CreateI32SplatConst(loc, &rewriter, {}, 1));
348 
349     auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
350         loc, RankedTensorType::get({1}, shape_dtype), item_rank, scalar_zero);
351     // Create two slice ops.
352     Type element_type = input.getType().cast<TensorType>().getElementType();
353     UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type);
354     Value scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1);
355     TF::SliceOp slice1 =
356         CreateSliceOpForTensorList(loc, /*input_list=*/input,
357                                    /*start_index=*/scalar_zero,
358                                    /*size=*/index,
359                                    /*item_rank=*/item_position_shape,
360                                    /*result_type=*/unranked_tensor, &rewriter);
361     TF::SliceOp slice2 =
362         CreateSliceOpForTensorList(loc, /*input_list=*/input,
363                                    /*start_index=*/suffix_start,
364                                    /*size=*/scalar_minus_one,
365                                    /*item_rank=*/item_position_shape,
366                                    /*result_type=*/unranked_tensor, &rewriter);
367 
368     // Expand the dimension of item so that it will have the same rank with
369     // input.
370     auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
371         op.getLoc(), unranked_tensor, item, scalar_zero);
372 
373     // Concatenate three parts together to generate the final result.
374     rewriter.replaceOpWithNewOp<TF::ConcatOp>(
375         op, input.getType(), scalar_zero,
376         ArrayRef<Value>({slice1, expanded_item, slice2}));
377     return success();
378   }
379 };
380 
381 // Rewrites op of the template type initializing a TensorList with a list of ops
382 // to generate an equivalent raw tensor. Derived classes are required to
383 // override GetNumElements method.
384 template <typename OpT>
385 struct ConvertTensorListInitOp : public TensorListOpConverterBase<OpT> {
386   using TensorListOpConverterBase<OpT>::TensorListOpConverterBase;
387   using TensorListOpConverterBase<OpT>::allow_tensorlist_pass_through_;
388 
389   // Create and return a 1-d tensor with exactly one element equal to the number
390   // of list elements to initialize the output tensor list with.
391   virtual Value GetNumElements(OpT op, ArrayRef<Value> operands,
392                                PatternRewriter *rewriter) const = 0;
393 
394   // Rewrites the original op into `tf.fill`. The result tensor shape is
395   // [num_element, element_shape]. All the values in the result tensor will be
396   // initialized to 0.
matchAndRewritemlir::__anonce850f620111::ConvertTensorListInitOp397   LogicalResult matchAndRewrite(
398       OpT op, ArrayRef<Value> operands,
399       ConversionPatternRewriter &rewriter) const override {
400     Type dtype = op.element_dtype();
401     if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
402           dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
403           dtype.isInteger(32) || dtype.isInteger(64))) {
404       const char *error_info =
405           "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
406           "integer or 16-bit/32-bit/64-bit float type during TF Lite "
407           "transformation pass";
408       return allow_tensorlist_pass_through_
409                  ? rewriter.notifyMatchFailure(op, error_info)
410                  : op.emitOpError(error_info);
411     }
412 
413     Value element_shape = operands[0];
414     Type shape_dtype = getElementTypeOrSelf(element_shape.getType());
415     // If the `element_shape` is a scalar, we try to acquire its shape by
416     // looking at the first `TensorListSetItemOp` writing to this tensor list.
417     // Here we assume that the element_shape won't be changed before calling
418     // the first `TensorListSetItemOp`.
419     if (auto shaped_type = element_shape.getType().dyn_cast<ShapedType>()) {
420       if (shaped_type.hasRank() && shaped_type.getRank() == 0) {
421         bool element_shape_acquired = false;
422         auto uses = op.getResult().getUses();
423         for (auto &use : llvm::make_early_inc_range(uses)) {
424           if (TF::TensorListSetItemOp set_op =
425                   llvm::dyn_cast<TF::TensorListSetItemOp>(use.getOwner())) {
426             element_shape = rewriter.create<TF::ShapeOp>(
427                 op.getLoc(), RankedTensorType::get({-1}, shape_dtype),
428                 set_op.item());
429             element_shape_acquired = true;
430           } else if (TF::WhileOp while_op =
431                          llvm::dyn_cast<TF::WhileOp>(use.getOwner())) {
432             // Tensorlist is passed into a while loop, check inside the body
433             // function.
434             auto inside_uses = while_op.body_function()
435                                    .getArgument(use.getOperandNumber())
436                                    .getUses();
437             for (auto &inside_use : llvm::make_early_inc_range(inside_uses)) {
438               if (TF::TensorListSetItemOp set_op =
439                       llvm::dyn_cast<TF::TensorListSetItemOp>(
440                           inside_use.getOwner())) {
441                 if (auto shaped_type =
442                         set_op.item().getType().dyn_cast<ShapedType>()) {
443                   if (shaped_type.hasStaticShape()) {
444                     RankedTensorType type = RankedTensorType::get(
445                         {shaped_type.getRank()}, rewriter.getIntegerType(32));
446                     SmallVector<Attribute, 4> shape_attr;
447                     for (int64_t dim : shaped_type.getShape()) {
448                       shape_attr.push_back(rewriter.getI32IntegerAttr(dim));
449                     }
450                     DenseElementsAttr attr =
451                         DenseElementsAttr::get(type, shape_attr);
452                     element_shape =
453                         rewriter.create<ConstantOp>(op.getLoc(), type, attr);
454                     element_shape_acquired = true;
455                     break;
456                   }
457                 }
458               }
459             }
460           }
461           if (element_shape_acquired) break;
462         }
463         if (!element_shape_acquired) {
464           const char *error_info =
465               "requires element_shape to be 1D tensor during TF Lite "
466               "transformation pass";
467           return allow_tensorlist_pass_through_
468                      ? rewriter.notifyMatchFailure(op, error_info)
469                      : op.emitOpError(error_info);
470         }
471       }
472     }
473 
474     DenseIntElementsAttr dense_elem_attr;
475     if (matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
476       // Note: It's technically unsafe to rewrite
477       //     TensorListReserve(num_element, element_shape)
478       // to
479       //     Fill(Concat(num_element, element_shape), 0)
480       // because element_shape may contain -1 to represent unknown dimension.
481       //
482       // In real world use cases (e.g. Keras RNN), `element_shape` is usually
483       // a constant, and the first dimension of `element_shape` is usually
484       // batch dimension. Currently TFLiteConverter always rewrite unknown
485       // batch dimension to 1, therefore we also rewrite unknown dimension in
486       // `element_shape` to 1 here.
487       //
488       // This workaround enables converting Keras RNN without specifying batch
489       // dimension. This isn't guaranteed to work, but it doesn't break any
490       // non-broken cases either (since it's already broken if `element_shape`
491       // contains -1).
492       // TODO(b/142096690): Support dynamic element shape and remove the
493       // workaround.
494       SmallVector<int32_t, 4> new_element_shape_values;
495 
496       auto int_values = dense_elem_attr.getIntValues();
497       for (auto it = int_values.begin(); it != int_values.end(); ++it) {
498         auto dim_value = (*it).getSExtValue();
499         if (it == int_values.begin() && dim_value == -1) {
500           dim_value = 1;
501         }
502         new_element_shape_values.push_back(dim_value);
503       }
504 
505       auto attr = DenseIntElementsAttr::get(
506           element_shape.getType().cast<ShapedType>(), new_element_shape_values);
507       auto new_element_shape = rewriter.create<ConstantOp>(
508           op.getLoc(), element_shape.getType(), attr);
509       element_shape = new_element_shape;
510     }
511 
512     int64_t result_rank = -1;  // -1 means unknown result rank.
513     Type element_dtype = op.element_dtype();
514     Type result_type = UnrankedTensorType::get(element_dtype);
515     Value leading_dim = GetNumElements(op, operands, &rewriter);
516     if (auto element_type =
517             op.element_type().template dyn_cast<RankedTensorType>()) {
518       result_rank = element_type.getRank() + 1;
519       int64_t leading_dim_v = -1;
520       ElementsAttr element_attr;
521       if (matchPattern(leading_dim, m_Constant(&element_attr))) {
522         leading_dim_v = element_attr.getValue<IntegerAttr>(0).getInt();
523       }
524       SmallVector<int64_t, 4> result_shape = {leading_dim_v};
525       ArrayRef<int64_t> shape = element_type.getShape();
526       result_shape.append(shape.begin(), shape.end());
527       result_type = RankedTensorType::get(result_shape, element_dtype);
528     }
529 
530     // Create a 1-D RankedTensorType for result's shape. Number of elements in
531     // it is equal to the rank of the result, if known. Otherwise, the number of
532     // elements are unknown and represented with -1. In both cases, we can
533     // specify dimension using rank of the result.
534     Type shape_type = RankedTensorType::get({result_rank}, shape_dtype);
535 
536     Location loc = op.getLoc();
537     // Add number of elements as the prefix to the element shape to get shape of
538     // the output tensor.
539     Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
540     auto list_shape = rewriter.create<TF::ConcatOp>(
541         loc, shape_type, scalar_zero,
542         ArrayRef<Value>({leading_dim, element_shape}));
543 
544     // Create a zero-initialized constant tensor that has the same type
545     // as specified by element_dtype.
546     RankedTensorType zero_type = RankedTensorType::get({}, element_dtype);
547     Attribute zero_attr = rewriter.getZeroAttr(zero_type);
548     auto zero = rewriter.create<ConstantOp>(loc, zero_type, zero_attr);
549 
550     rewriter.replaceOpWithNewOp<TF::FillOp>(op, result_type, list_shape, zero);
551     return success();
552   }
553 };
554 
555 struct ConvertTensorListReserve
556     : public ConvertTensorListInitOp<TF::TensorListReserveOp> {
ConvertTensorListReservemlir::__anonce850f620111::ConvertTensorListReserve557   explicit ConvertTensorListReserve(MLIRContext *context,
558                                     bool allow_tensorlist_pass_through)
559       : ConvertTensorListInitOp(context, allow_tensorlist_pass_through) {}
560 
GetNumElementsmlir::__anonce850f620111::ConvertTensorListReserve561   Value GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value> operands,
562                        PatternRewriter *rewriter) const override {
563     Value scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0);
564     Type shape_dtype = getElementTypeOrSelf(op.element_shape().getType());
565     Value num_elements = operands[1];
566     IntegerAttr attr;
567     if (matchPattern(num_elements, m_Constant(&attr))) {
568       return CreateI32SplatConst(op.getLoc(), rewriter, {1}, attr.getInt());
569     }
570     if (auto const_op = num_elements.getDefiningOp<TF::ConstOp>()) {
571       return CreateI32SplatConst(
572           op->getLoc(), rewriter, {1},
573           (*const_op.value().cast<DenseElementsAttr>().getIntValues().begin())
574               .getSExtValue());
575     }
576     return rewriter->create<TF::ExpandDimsOp>(
577         op.getLoc(), RankedTensorType::get({1}, shape_dtype), num_elements,
578         scalar_zero);
579   }
580 };
581 
582 // Note that we ignore the second operand `max_num_elements` as we don't have
583 // any restrictions on the number of elements we can support. So this may
584 // have a different behavior compared to TensorFlow in case of errors.
585 struct ConvertEmptyTensorList
586     : public ConvertTensorListInitOp<TF::EmptyTensorListOp> {
ConvertEmptyTensorListmlir::__anonce850f620111::ConvertEmptyTensorList587   explicit ConvertEmptyTensorList(MLIRContext *context,
588                                   bool allow_tensorlist_pass_through)
589       : ConvertTensorListInitOp(context, allow_tensorlist_pass_through) {}
590 
GetNumElementsmlir::__anonce850f620111::ConvertEmptyTensorList591   Value GetNumElements(TF::EmptyTensorListOp op, ArrayRef<Value> operands,
592                        PatternRewriter *rewriter) const override {
593     return CreateI32SplatConst(op.getLoc(), rewriter, {1}, 0);
594   }
595 };
596 
597 struct ConvertTensorListPushBack
598     : public OpConversionPattern<TF::TensorListPushBackOp> {
599   using OpConversionPattern::OpConversionPattern;
600 
matchAndRewritemlir::__anonce850f620111::ConvertTensorListPushBack601   LogicalResult matchAndRewrite(
602       TF::TensorListPushBackOp op, ArrayRef<Value> operands,
603       ConversionPatternRewriter &rewriter) const override {
604     Value input_handle = operands[0];
605     Value item = operands[1];
606 
607     // Expand the shape of the item so that it will have rank same as the input
608     // tensor and it is compatible for the Concat Op.
609     Type expanded_item_type =
610         PrependLeadingDimIfRanked(1, item.getType(), &rewriter);
611     Location loc = op.getLoc();
612     Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
613     auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
614         loc, expanded_item_type, item, scalar_zero);
615 
616     Type elem_type = getElementTypeOrSelf(item);
617     auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
618                             .cast<TF::VariantType>();
619     Type result_type =
620         GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
621 
622     // Concatenate tensor stored in the input handle with the expanded item to
623     // get a tensor equivalent to the TensorList generated by this op.
624     rewriter.replaceOpWithNewOp<TF::ConcatOp>(
625         op, result_type, scalar_zero,
626         ArrayRef<Value>({input_handle, expanded_item}));
627     return success();
628   }
629 };
630 
// Rewrites `TensorListResize` op into a functional `If` op and several basic
// TF ops to match the op semantics of TensorFlow. Basically, it does:
// 1) If the requested size is smaller than or equal to the input tensorlist's
//    size, rewrite it to a Slice op so that only the first 'size' rows are
//    returned (else branch).
// 2) If the requested size is larger than the input tensorlist's size, create
//    an additional tensorlist with 'size - input_size' elements and append it
//    to the end of the input tensorlist (then branch).
// The two branches are materialized as new functions ("cond_true" and
// "cond_false", renamed as needed by the module's SymbolTable to stay unique)
// that are created right after the function containing the op.
struct ConvertTensorListResize
    : public OpConversionPattern<TF::TensorListResizeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::TensorListResizeOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // `operands` holds the remapped operands of `op`: [0] is the input list,
    // [1] is the requested size.
    Value input_handle = operands[0];
    Value size = operands[1];

    Location loc = op.getLoc();
    Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);

    // Compute the input tensorlist's length and store it in `input_size`.
    // NOTE(review): this reads the original operand `op.getOperand(0)` rather
    // than the remapped `input_handle`; presumably so the new
    // TensorListLengthOp still sees a variant-typed handle and is legalized
    // later by its own pattern. Confirm before changing.
    IntegerType shape_dtype = rewriter.getIntegerType(32);
    auto input_size = rewriter.create<TF::TensorListLengthOp>(
        loc, RankedTensorType::get({}, shape_dtype), op.getOperand(0));

    // Infer result type of this op based on TF's shape inference result.
    Type elem_type = getElementTypeOrSelf(input_handle);
    auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
                            .cast<TF::VariantType>();
    Type result_type =
        GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);

    // Compute the difference of `size` and `input_size`, and store it in
    // `size_diff`, which is then consumed by `if_cond`.
    auto size_diff = rewriter.create<TF::SubOp>(
        loc, RankedTensorType::get({}, shape_dtype), size, input_size);
    // `if_cond` is true exactly when the list must grow (size_diff > 0).
    auto if_cond = rewriter.create<TF::GreaterOp>(
        loc, RankedTensorType::get({}, rewriter.getI1Type()), size_diff,
        scalar_zero);

    // Build the argument/result types for if branch function.
    auto input_shape = rewriter.create<TF::ShapeOp>(
        loc, RankedTensorType::get({-1}, shape_dtype), input_handle);

    // Both branches take (input, input_shape, size_diff, size) in this order,
    // matching the `If` operands built below, and return the resized list.
    Type branch_args_type[] = {input_handle.getType(), input_shape.getType(),
                               size_diff.getType(), size.getType()};
    Type branch_result_type[] = {result_type};
    auto func_type = FunctionType::get(rewriter.getContext(), branch_args_type,
                                       branch_result_type);

    // Create functions in a higher scope before restoring the insertion point.
    // Additionally, create the SymbolTable before further modifying the module.
    auto original_point = rewriter.saveInsertionPoint();
    rewriter.setInsertionPointAfter(op->getParentOfType<FuncOp>());
    SymbolTable manager(op->getParentOfType<ModuleOp>());

    // Constructs `then_branch`, which is executed when `if_cond` evaluates to
    // true.
    auto then_branch_op = rewriter.create<FuncOp>(loc, "cond_true", func_type);
    CreateCondTrueBranch(op, shape_dtype, result_type, then_branch_op,
                         &rewriter);

    // Constructs `else_branch`, which is executed when `if_cond` evaluates to
    // false.
    auto else_branch_op = rewriter.create<FuncOp>(loc, "cond_false", func_type);
    CreateCondFalseBranch(loc, shape_dtype, result_type, else_branch_op,
                          &rewriter);

    // Inserts the two branch functions' names into the symbol table held by
    // the module. Using SymbolTable will ensure that the inserted symbol names
    // are unique.
    manager.insert(then_branch_op);
    manager.insert(else_branch_op);

    rewriter.restoreInsertionPoint(original_point);
    rewriter.replaceOpWithNewOp<TF::IfOp>(
        op, result_type, if_cond,
        /*input=*/
        ArrayRef<Value>({input_handle, input_shape, size_diff, size}),
        /*then_branch=*/rewriter.getSymbolRefAttr(then_branch_op),
        /*else_branch=*/rewriter.getSymbolRefAttr(else_branch_op),
        /*is_stateless=*/rewriter.getBoolAttr(true));
    return success();
  }

 private:
  // Body of the then branch, executed when the input tensorlist's size is
  // smaller than the requested size.
  // Creates a new tensorlist of size 'size - input_size' and concatenates it
  // after the input tensorlist. The insertion point is restored on return.
  void CreateCondTrueBranch(TF::TensorListResizeOp resize_op, Type shape_dtype,
                            Type result_type, FuncOp branch_func,
                            ConversionPatternRewriter *rewriter) const {
    auto guard = OpBuilder::InsertionGuard(*rewriter);
    Block *block =
        rewriter->createBlock(&branch_func.getBody(), branch_func.begin(),
                              branch_func.getType().getInputs());

    // Block arguments follow `branch_args_type`:
    // (input, input_shape, size_diff, size).
    auto input_shape = block->getArgument(1);
    auto size_diff = block->getArgument(2);
    auto input = block->getArgument(0);

    Location loc = resize_op.getLoc();
    // Get the element shape by slicing from index 1 in the input shape
    // (a Slice size of -1 takes all remaining entries).
    Value slice_size = CreateI32SplatConst(loc, rewriter, {1}, -1);
    Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
    Value slice_start = CreateI32SplatConst(loc, rewriter, {1}, 1);
    auto elem_shape = rewriter->create<TF::SliceOp>(
        loc, RankedTensorType::get({-1}, shape_dtype), input_shape, slice_start,
        slice_size);
    auto extended_part = rewriter->create<TF::TensorListReserveOp>(
        loc, resize_op.output_handle().getType(), elem_shape, size_diff);
    // `ConcatOp` expects non-variant-typed input. Insert a
    // `TensorListStackOp` here to convert type from variant to non-variant.
    // Note that we are using the same `result_type` for both the
    // `TensorListStackOp` and `ConcatOp`, since the first dimension of the
    // shape specified by `result_type` is -1.
    auto stacked_extended_part = rewriter->create<TF::TensorListStackOp>(
        loc, result_type, extended_part,
        /*element_shape=*/CreateI32SplatConst(loc, rewriter, {}, -1),
        /*num_elements=*/rewriter->getI32IntegerAttr(-1));
    auto concat_op = rewriter->create<TF::ConcatOp>(
        loc, result_type, scalar_zero,
        ArrayRef<Value>({input, stacked_extended_part}));
    rewriter->create<ReturnOp>(loc, ArrayRef<Value>({concat_op}));
  }

  // Body of the else branch. The insertion point is restored on return.
  void CreateCondFalseBranch(Location loc, Type shape_dtype, Type result_type,
                             FuncOp branch_func,
                             ConversionPatternRewriter *rewriter) const {
    // When the input tensorlist's size is larger than or equal to the
    // requested size, the else branch is executed.
    // Slice the first 'size' rows from the input tensorlist.
    auto guard = OpBuilder::InsertionGuard(*rewriter);
    Block *block =
        rewriter->createBlock(&branch_func.getBody(), branch_func.begin(),
                              branch_func.getType().getInputs());

    Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
    Value vector_one = CreateI32SplatConst(loc, rewriter, {1}, 1);
    // Block arguments follow (input, input_shape, size_diff, size); only the
    // input list and the requested size are needed here.
    auto input = block->getArgument(0);
    auto size = block->getArgument(3);

    // Subtract 1 from `input_rank` to get the item's rank, which is used as
    // `partial_position_shape`.
    auto input_rank = rewriter->create<TF::RankOp>(
        loc, RankedTensorType::get({}, shape_dtype), input);
    auto partial_position_shape = rewriter->create<TF::SubOp>(
        loc, RankedTensorType::get({1}, shape_dtype), input_rank, vector_one);
    auto slice_op =
        CreateSliceOpForTensorList(loc, /*input_list=*/input,
                                   /*start_index=*/scalar_zero, /*size=*/size,
                                   /*item_rank=*/partial_position_shape,
                                   /*result_type=*/result_type, rewriter);
    rewriter->create<ReturnOp>(loc, ArrayRef<Value>({slice_op}));
  }
};
788 
789 struct ConvertTensorListGetItem
790     : public OpConversionPattern<TF::TensorListGetItemOp> {
791   using OpConversionPattern::OpConversionPattern;
792 
matchAndRewritemlir::__anonce850f620111::ConvertTensorListGetItem793   LogicalResult matchAndRewrite(
794       TF::TensorListGetItemOp op, ArrayRef<Value> operands,
795       ConversionPatternRewriter &rewriter) const override {
796     Value input = operands[0];
797     Value index = operands[1];
798     rewriter.replaceOpWithNewOp<TF::GatherOp>(op, op.getType(), input, index,
799                                               rewriter.getBoolAttr(true));
800     return success();
801   }
802 };
803 
804 struct ConvertTensorListLength
805     : public OpConversionPattern<TF::TensorListLengthOp> {
806   using OpConversionPattern::OpConversionPattern;
807 
matchAndRewritemlir::__anonce850f620111::ConvertTensorListLength808   LogicalResult matchAndRewrite(
809       TF::TensorListLengthOp op, ArrayRef<Value> operands,
810       ConversionPatternRewriter &rewriter) const override {
811     Location loc = op.getLoc();
812     Value input_handle = operands[0];
813 
814     BoolAttr true_attr = rewriter.getBoolAttr(true);
815     auto shape = rewriter.create<TF::ShapeOp>(loc, input_handle,
816                                               /*use_32bit=*/true_attr);
817     rewriter.replaceOpWithNewOp<TF::GatherOp>(
818         op, op.getType(), shape, CreateI32SplatConst(loc, &rewriter, {}, 0),
819         /*validate_indices=*/true_attr);
820     return success();
821   }
822 };
823 
824 struct ConvertTensorListStack
825     : public OpConversionPattern<TF::TensorListStackOp> {
826   using OpConversionPattern::OpConversionPattern;
827 
matchAndRewritemlir::__anonce850f620111::ConvertTensorListStack828   LogicalResult matchAndRewrite(
829       TF::TensorListStackOp op, ArrayRef<Value> operands,
830       ConversionPatternRewriter &rewriter) const override {
831     Location loc = op.getLoc();
832     Value input = operands[0];
833     Value element_shape = operands[1];
834 
835     // If the `element_shape` is a known constant (which is defined when calling
836     // `tensor_list_stack`) and also valid (not scalar), we rewrite this op to a
837     // trivial Reshape op (that doesn't actually change the input's shape) and
838     // also populate the shape info to the op result. The shape of the
839     // tensorlist is inferred from `num_elements` and `element_shape`.
840     auto ranked_type = element_shape.getType().dyn_cast<RankedTensorType>();
841     DenseIntElementsAttr dense_elem_attr;
842     if ((ranked_type && ranked_type.getRank() == 0) ||
843         !matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
844       // If no constant is spotted, just forward the operand.
845       rewriter.replaceOp(op, {input});
846       return success();
847     }
848 
849     RankedTensorType shape_type =
850         RankedTensorType::get({-1}, rewriter.getIntegerType(32));
851     auto new_shape = rewriter.create<TF::ShapeOp>(loc, shape_type, input);
852     SmallVector<int64_t, 8> output_shape(/*Size=*/1, op.num_elements());
853     for (const auto &dim : dense_elem_attr.getIntValues())
854       output_shape.push_back(dim.getSExtValue());
855     RankedTensorType result_type =
856         RankedTensorType::get(output_shape, getElementTypeOrSelf(input));
857     rewriter.replaceOpWithNewOp<TF::ReshapeOp>(op, result_type, input,
858                                                new_shape);
859     return success();
860   }
861 };
862 
863 // Converts `TensorListConcatV2` into Unpack and Concat. First we unpack
864 // the input tensorlist along the first dimension, which results in N (where N
865 // is the first dim's size) tensors (each with shape [element_shape]). Then
866 // we concatenate all those tensors along the first dimension.
867 // The pattern will be rejected if either `element_shape` is not constant, or
868 // the first dimension of `input` is not known.
869 struct ConvertTensorListConcatV2
870     : public TensorListOpConverterBase<TF::TensorListConcatV2Op> {
871   using TensorListOpConverterBase<
872       TF::TensorListConcatV2Op>::TensorListOpConverterBase;
873   using TensorListOpConverterBase<
874       TF::TensorListConcatV2Op>::allow_tensorlist_pass_through_;
875 
matchAndRewritemlir::__anonce850f620111::ConvertTensorListConcatV2876   LogicalResult matchAndRewrite(
877       TF::TensorListConcatV2Op op, ArrayRef<Value> operands,
878       ConversionPatternRewriter &rewriter) const override {
879     Location loc = op.getLoc();
880     Value input = operands[0];
881     Value element_shape = operands[1];
882 
883     // Only match when `element_shape` is a constant.
884     DenseIntElementsAttr dense_elem_attr;
885     if (!matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
886       const char *error_info = "requires element_shape to be a constant";
887       return allow_tensorlist_pass_through_
888                  ? rewriter.notifyMatchFailure(op, error_info)
889                  : op.emitOpError(error_info);
890     }
891     llvm::SmallVector<int64_t, 4> output_shape;
892     for (const auto &dim : dense_elem_attr.getIntValues()) {
893       output_shape.push_back(dim.getSExtValue());
894     }
895 
896     // First unpack the input tensor along the first dimension.
897     Type input_element_type = getElementTypeOrSelf(input);
898     int64_t num_unpacked = 0;
899     if (auto type = input.getType().dyn_cast<RankedTensorType>()) {
900       if (type.getDimSize(0) > 0) {
901         num_unpacked = type.getDimSize(0);
902       } else {
903         const char *error_info =
904             "requires the first dimension of input tensor to have > 0 "
905             "dimension";
906         return allow_tensorlist_pass_through_
907                    ? rewriter.notifyMatchFailure(op, error_info)
908                    : op.emitOpError(error_info);
909       }
910     }
911     llvm::SmallVector<Type, 1> unpack_output_type;
912     unpack_output_type.insert(
913         unpack_output_type.begin(), num_unpacked,
914         RankedTensorType::get(output_shape, input_element_type));
915     auto unpack = rewriter.create<TF::UnpackOp>(loc, unpack_output_type, input,
916                                                 /*axis=*/0);
917 
918     // Concatenate the unpacked tensors along the first dimension.
919     // Since we're concatenating along first dimension, change its dim size to
920     // -1.
921     output_shape[0] = -1;
922     Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
923     auto concat = rewriter.create<TF::ConcatOp>(
924         loc, RankedTensorType::get(output_shape, input_element_type),
925         scalar_zero, unpack->getResults());
926     // `lengths` is only useful for computing gradient. For now we just return
927     // a placeholder tensor.
928     rewriter.replaceOp(
929         op, {concat.getResult(), CreateI64SplatConst(loc, &rewriter, {0}, 0)});
930     return success();
931   }
932 };
933 
934 struct ConvertIdentity : public OpConversionPattern<TF::IdentityOp> {
935   using OpConversionPattern::OpConversionPattern;
936 
matchAndRewritemlir::__anonce850f620111::ConvertIdentity937   LogicalResult matchAndRewrite(
938       TF::IdentityOp op, ArrayRef<Value> operands,
939       ConversionPatternRewriter &rewriter) const override {
940     Value input = operands[0];
941     rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), operands,
942                                                 op->getAttrs());
943     return success();
944   }
945 };
946 
947 // Returns an unranked tensor type with an element of the same type as `value`
948 // if `type` is a tensor of variant. Otherwise, returns `type` unmodified.
VariantToUnrankedTensorType(Type type,Value value)949 Type VariantToUnrankedTensorType(Type type, Value value) {
950   TF::VariantType variant_ty =
951       getElementTypeOrSelf(type).dyn_cast<TF::VariantType>();
952   if (!variant_ty) {
953     return type;
954   }
955   if (!variant_ty.getSubtypes().empty()) {
956     // Short-circut if the variant type has subtype info.
957     return UnrankedTensorType::get(
958         variant_ty.getSubtypes()[0].getElementType());
959   }
960   Type value_type = value.getType();
961   Type element_type;
962   variant_ty = value_type.dyn_cast<TF::VariantType>();
963   if (variant_ty && !variant_ty.getSubtypes().empty()) {
964     element_type = variant_ty.getSubtypes()[0].getElementType();
965   } else {
966     element_type = getElementTypeOrSelf(value_type);
967   }
968   return UnrankedTensorType::get(element_type);
969 }
970 
971 // Returns true if we can deduce the type is tensorlist.
IsTensorListType(Type type,llvm::Optional<Value> value)972 bool IsTensorListType(Type type, llvm::Optional<Value> value) {
973   TF::VariantType variant_ty =
974       getElementTypeOrSelf(type).dyn_cast<TF::VariantType>();
975   if (!variant_ty) {
976     return false;
977   }
978   // Check there is only one subtype contained in the variant type. Note that
979   // when `subtypes.size() == 1` does not always mean the type is actually
980   // a tensorlist. We probably need some form of data flow analysis.
981   if (variant_ty.getSubtypes().size() == 1) {
982     return true;
983   }
984   // If subtype info is not available, check if the value is used by any of
985   // the following TensorList operations.
986   if (!value.hasValue()) {
987     return false;
988   }
989   for (const mlir::OpOperand &use : value.getValue().getUses()) {
990     mlir::Operation *op = use.getOwner();
991     if (llvm::isa<TF::TensorListGetItemOp>(op) ||
992         llvm::isa<TF::TensorListLengthOp>(op) ||
993         llvm::isa<TF::TensorListPushBackOp>(op) ||
994         llvm::isa<TF::TensorListReserveOp>(op) ||
995         llvm::isa<TF::TensorListSetItemOp>(op) ||
996         llvm::isa<TF::TensorListStackOp>(op) ||
997         llvm::isa<TF::TensorListResizeOp>(op)) {
998       return true;
999     }
1000   }
1001   return false;
1002 }
1003 
1004 // Returns a set of integers that correspond to the tensorlist arguments in
1005 // the function.
GetTensorListArgumentsIndex(FuncOp func)1006 llvm::SmallSet<int, 4> GetTensorListArgumentsIndex(FuncOp func) {
1007   llvm::SmallSet<int, 4> set;
1008   for (const auto &arg_and_idx : llvm::enumerate(func.getArguments())) {
1009     if (IsTensorListType(arg_and_idx.value().getType(), arg_and_idx.value())) {
1010       set.insert(arg_and_idx.index());
1011     }
1012   }
1013   return set;
1014 }
1015 
1016 // Returns a set of integers that correspond to the tensorlist results in the
1017 // function.
GetTensorListResultsIndex(FuncOp func)1018 llvm::SmallSet<int, 4> GetTensorListResultsIndex(FuncOp func) {
1019   llvm::SmallSet<int, 4> set;
1020 
1021   for (const auto &result_and_idx :
1022        llvm::enumerate(func.getType().getResults())) {
1023     if (IsTensorListType(result_and_idx.value(), llvm::None)) {
1024       set.insert(result_and_idx.index());
1025     }
1026   }
1027   return set;
1028 }
1029 
1030 // Updates the tensorlist types based on the input index. If the tensorlist's
1031 // size isn't changed(which is indicated by `resized_tensor_list_index`), then
1032 // we will use the original operand's type, otherwise update it with the
1033 // unranked tensor type.
1034 template <typename R>
UpdateTensorListTypes(const llvm::SmallSet<int,4> & tensor_list_index,const llvm::SmallSet<int,4> & resized_tensor_list_index,ArrayRef<Type> types,R && range,ArrayRef<Value> operands,llvm::SmallVectorImpl<Type> * updated_types)1035 void UpdateTensorListTypes(
1036     const llvm::SmallSet<int, 4> &tensor_list_index,
1037     const llvm::SmallSet<int, 4> &resized_tensor_list_index,
1038     ArrayRef<Type> types, R &&range, ArrayRef<Value> operands,
1039     llvm::SmallVectorImpl<Type> *updated_types) {
1040   int i = 0;
1041   for (const auto it : llvm::zip(types, range, operands)) {
1042     if (tensor_list_index.count(i)) {
1043       // Only change the tensorlist's type to unranked tensor if it has been
1044       // resized.
1045       if (resized_tensor_list_index.count(i)) {
1046         updated_types->push_back(
1047             VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
1048       } else {
1049         updated_types->push_back(std::get<2>(it).getType());
1050       }
1051     } else {
1052       updated_types->push_back(std::get<0>(it));
1053     }
1054     ++i;
1055   }
1056 }
1057 
1058 // Updates the tensorlist types to unranked tensor types based on the input
1059 // index.
1060 template <typename R>
ChangeVariantToUnrankedTensorType(const llvm::SmallSet<int,4> & tensor_list_index,ArrayRef<Type> types,R && range,llvm::SmallVectorImpl<Type> * updated_types)1061 void ChangeVariantToUnrankedTensorType(
1062     const llvm::SmallSet<int, 4> &tensor_list_index, ArrayRef<Type> types,
1063     R &&range, llvm::SmallVectorImpl<Type> *updated_types) {
1064   int i = 0;
1065   for (const auto it : llvm::zip(types, range)) {
1066     if (tensor_list_index.count(i)) {
1067       updated_types->push_back(
1068           VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
1069     } else {
1070       updated_types->push_back(std::get<0>(it));
1071     }
1072     ++i;
1073   }
1074 }
1075 
1076 // Updates the specified function's type and region signature.
UpdateFunctionAndRegionType(ConversionPatternRewriter & rewriter,FuncOp func,llvm::ArrayRef<Type> updated_argument_types,llvm::ArrayRef<Type> updated_result_types)1077 void UpdateFunctionAndRegionType(ConversionPatternRewriter &rewriter,
1078                                  FuncOp func,
1079                                  llvm::ArrayRef<Type> updated_argument_types,
1080                                  llvm::ArrayRef<Type> updated_result_types) {
1081   // Change `func`'s argument type to `unranked_argument_types`. If its
1082   // return types contain a `DT_VARIANT`, change it to the unranked type
1083   // derived from the corresponding argument.
1084   rewriter.updateRootInPlace(func, [&] {
1085     func.setType(FunctionType::get(func.getContext(), updated_argument_types,
1086                                    updated_result_types));
1087   });
1088   Region &entry = func.getRegion();
1089   TypeConverter::SignatureConversion signature_conversion(
1090       entry.getNumArguments());
1091   for (const BlockArgument &arg : entry.getArguments()) {
1092     signature_conversion.addInputs(arg.getArgNumber(),
1093                                    updated_argument_types[arg.getArgNumber()]);
1094   }
1095   rewriter.applySignatureConversion(&entry, signature_conversion);
1096 }
1097 
1098 // Changes the function type of `cond_func` and `body_func` for the given While
1099 // op.
UpdateFunctionTypesForWhileOp(ConversionPatternRewriter & rewriter,TF::WhileOp op,ArrayRef<Value> operands,const llvm::SmallSet<int,4> & tensor_list_args,const llvm::SmallSet<int,4> & resized_tensor_lists)1100 LogicalResult UpdateFunctionTypesForWhileOp(
1101     ConversionPatternRewriter &rewriter, TF::WhileOp op,
1102     ArrayRef<Value> operands, const llvm::SmallSet<int, 4> &tensor_list_args,
1103     const llvm::SmallSet<int, 4> &resized_tensor_lists) {
1104   int func_index = 0;
1105   for (FuncOp func : {op.cond_function(), op.body_function()}) {
1106     ++func_index;
1107     if (!func) continue;
1108 
1109     FunctionType func_type = func.getType();
1110     int num_inputs = func_type.getNumInputs();
1111     int num_results = func_type.getNumResults();
1112 
1113     // For each argument type in function's arguments, change it to uranked
1114     // tensor type if it's a variant type.
1115     SmallVector<Type, 8> updated_argument_types;
1116     updated_argument_types.reserve(num_inputs);
1117     UpdateTensorListTypes<mlir::OperandRange>(
1118         tensor_list_args, resized_tensor_lists, func_type.getInputs(),
1119         op.getOperands(), operands, &updated_argument_types);
1120 
1121     // Change all DT_VARIANT result types in function results to unranked tensor
1122     // type with element type derived from the corresponding input operand. This
1123     // is correct because while body's inputs and results have the same type.
1124     SmallVector<Type, 8> updated_result_types;
1125     updated_result_types.reserve(num_results);
1126     if (func_index == 1) {
1127       // We only update the result types for the body function.
1128       for (Type ty : func_type.getResults()) {
1129         updated_result_types.push_back(ty);
1130       }
1131     } else {
1132       UpdateTensorListTypes<mlir::OperandRange>(
1133           tensor_list_args, resized_tensor_lists, func_type.getResults(),
1134           op.getOperands(), operands, &updated_result_types);
1135     }
1136 
1137     UpdateFunctionAndRegionType(rewriter, func, updated_argument_types,
1138                                 updated_result_types);
1139   }
1140   return success();
1141 }
1142 
1143 // Changes the function type of `then_function` and `else_function` for the
1144 // given If op.
UpdateFunctionTypesForIfOp(ConversionPatternRewriter & rewriter,TF::IfOp op,llvm::ArrayRef<Value> operands,const llvm::SmallSet<int,4> & tensor_list_args,const llvm::SmallSet<int,4> & resized_tensor_lists,llvm::ArrayRef<Type> updated_result_types)1145 LogicalResult UpdateFunctionTypesForIfOp(
1146     ConversionPatternRewriter &rewriter, TF::IfOp op,
1147     llvm::ArrayRef<Value> operands,
1148     const llvm::SmallSet<int, 4> &tensor_list_args,
1149     const llvm::SmallSet<int, 4> &resized_tensor_lists,
1150     llvm::ArrayRef<Type> updated_result_types) {
1151   for (FuncOp func : {op.else_function(), op.then_function()}) {
1152     if (!func) continue;
1153 
1154     FunctionType func_type = func.getType();
1155     int num_inputs = func_type.getNumInputs();
1156 
1157     // Update the argument types of the function. If it's a tensorlist and
1158     // is not resized inside the function, we will use the corresponding
1159     // operand's type, otherwise change its type to unranked tensor type.
1160     SmallVector<Type, 8> updated_argument_types;
1161     updated_argument_types.reserve(num_inputs);
1162     UpdateTensorListTypes<mlir::OperandRange>(
1163         tensor_list_args, resized_tensor_lists, func_type.getInputs(),
1164         op.getOperands().drop_front(), operands.drop_front(),
1165         &updated_argument_types);
1166 
1167     UpdateFunctionAndRegionType(rewriter, func, updated_argument_types,
1168                                 updated_result_types);
1169   }
1170   return success();
1171 }
1172 
1173 // Returns a `llvm::DenseMap` which maps from the index of tensorlist in the
1174 // result, to the index of the same tensorlist in the arguments. For `If` op's
1175 // branch functions, the results and arguments are not usually matched 1-1. This
1176 // will let us konw which tensorlist result maps to which tensorlist in the
1177 // arguments. Once we know this info it will help us decide the types of the
1178 // result tensorlist based on the operand's of the `If` op.
MapTensorListResultToArgument(FuncOp func)1179 llvm::DenseMap<int, int> MapTensorListResultToArgument(FuncOp func) {
1180   // `map_fn` will trace upwards along the use-def chain of the ssa value. It
1181   // starts from the last ssa value (returned by the function), and check its
1182   // parent op iteratively. If the root ssa value appears in the function's
1183   // argument list, it will return the index of the corresponding argument,
1184   // otherwise it will return -1.
1185   auto map_fn = [](Value value) -> int {
1186     Value parent = value;
1187     while (true) {
1188       if (auto identity = parent.getDefiningOp<TF::IdentityOp>()) {
1189         parent = identity.input();
1190       } else if (auto set_item =
1191                      parent.getDefiningOp<TF::TensorListSetItemOp>()) {
1192         parent = set_item.input_handle();
1193       } else {
1194         break;
1195       }
1196     }
1197     if (auto block_arg = parent.dyn_cast<mlir::BlockArgument>()) {
1198       return block_arg.getArgNumber();
1199     }
1200     // Returns -1 if we don't find which this result maps to.
1201     return -1;
1202   };
1203 
1204   llvm::SmallVector<Value, 4> returns;
1205   for (auto res : func.getBody().back().getTerminator()->getOperands()) {
1206     returns.push_back(res);
1207   }
1208   llvm::DenseMap<int, int> result;
1209   for (const auto &result_and_idx : llvm::enumerate(returns)) {
1210     if (IsTensorListType(result_and_idx.value().getType(),
1211                          result_and_idx.value())) {
1212       int arg_idx = map_fn(result_and_idx.value());
1213       if (arg_idx != -1) {
1214         result.insert({result_and_idx.index(), arg_idx});
1215       }
1216     }
1217   }
1218   return result;
1219 }
1220 
1221 // Updates the tensorlist result types for the `If` Op. If the tensorlist result
1222 // maps to a specific argument (indicated by `tensor_list_map`), and also that
1223 // tensorlist argument's shape isn't changed (indicated by
1224 // `resized_tensor_list_index`), we will update this tensorlist's result type to
1225 // the corresponding operand's type. In all other cases we change the
1226 // tensorlist's type to unranked tensor type.
1227 template <typename R>
UpdateTensorListResultTypesForIf(const llvm::SmallSet<int,4> & tensor_list_index,const llvm::SmallSet<int,4> & resized_tensor_list_index,const llvm::DenseMap<int,int> & tensor_list_map,ArrayRef<Type> types,R && range,ArrayRef<Value> operands,llvm::SmallVectorImpl<Type> * updated_types)1228 void UpdateTensorListResultTypesForIf(
1229     const llvm::SmallSet<int, 4> &tensor_list_index,
1230     const llvm::SmallSet<int, 4> &resized_tensor_list_index,
1231     const llvm::DenseMap<int, int> &tensor_list_map, ArrayRef<Type> types,
1232     R &&range, ArrayRef<Value> operands,
1233     llvm::SmallVectorImpl<Type> *updated_types) {
1234   int i = 0;
1235   for (const auto it : llvm::zip(types, range)) {
1236     if (!tensor_list_index.count(i)) {
1237       updated_types->push_back(std::get<0>(it));
1238       ++i;
1239       continue;
1240     }
1241     auto iter = tensor_list_map.find(i);
1242     if (iter != tensor_list_map.end()) {
1243       int arg_idx = iter->second;
1244       if (!resized_tensor_list_index.count(arg_idx)) {
1245         // If the mapped tensorlist argument's size isn't changed, we will
1246         // use the corresponding `operand` type.
1247         updated_types->push_back(operands[arg_idx].getType());
1248         ++i;
1249         continue;
1250       }
1251     }
1252     updated_types->push_back(
1253         VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
1254     ++i;
1255   }
1256 }
1257 
1258 struct ConvertIf : public OpConversionPattern<TF::IfOp> {
1259   using OpConversionPattern::OpConversionPattern;
1260 
matchAndRewritemlir::__anonce850f620111::ConvertIf1261   LogicalResult matchAndRewrite(
1262       TF::IfOp op, ArrayRef<Value> operands,
1263       ConversionPatternRewriter &rewriter) const override {
1264     // Find all Tensor List arugments.
1265     auto tensor_list_args = GetTensorListArgumentsIndex(op.else_function());
1266     auto tensor_list_results = GetTensorListResultsIndex(op.else_function());
1267     auto tensor_list_map = MapTensorListResultToArgument(op.else_function());
1268     llvm::SmallSet<int, 4> resized_tensor_lists =
1269         GetResizedTensorListIndexes(op.else_function(), tensor_list_args);
1270 
1271     llvm::SmallVector<Type, 8> result_types;
1272     result_types.reserve(op.getNumResults());
1273     llvm::SmallVector<Type, 4> op_result_types;
1274     for (Type ty : op.getResultTypes()) {
1275       op_result_types.push_back(ty);
1276     }
1277 
1278     UpdateTensorListResultTypesForIf<mlir::ResultRange>(
1279         tensor_list_results, resized_tensor_lists, tensor_list_map,
1280         op_result_types, op->getResults(), operands.drop_front(),
1281         &result_types);
1282 
1283     // Create a new if op with new operands and updated result types.
1284     auto converted = rewriter.create<TF::IfOp>(op.getLoc(), result_types,
1285                                                operands, op->getAttrs());
1286     converted->removeAttr("T");
1287     (void)UpdateFunctionTypesForIfOp(rewriter, converted, operands,
1288                                      tensor_list_args, resized_tensor_lists,
1289                                      result_types);
1290     rewriter.replaceOp(op, converted.getResults());
1291     return success();
1292   }
1293 };
1294 
1295 struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
1296   using OpConversionPattern::OpConversionPattern;
1297 
matchAndRewritemlir::__anonce850f620111::ConvertWhile1298   LogicalResult matchAndRewrite(
1299       TF::WhileOp op, ArrayRef<Value> operands,
1300       ConversionPatternRewriter &rewriter) const override {
1301     // Find all Tensor List arugments.
1302     auto tensor_list_args = GetTensorListArgumentsIndex(op.body_function());
1303 
1304     llvm::SmallVector<Type, 8> result_types;
1305     result_types.reserve(op.getNumOperands());
1306     // Change all DT_VARIANT result types to unranked tensor type.
1307     llvm::SmallVector<Type, 4> op_result_types;
1308     for (Type ty : op.getResultTypes()) {
1309       op_result_types.push_back(ty);
1310     }
1311 
1312     llvm::SmallSet<int, 4> resized_tensor_lists =
1313         GetResizedTensorListIndexes(op.body_function(), tensor_list_args);
1314     UpdateTensorListTypes<mlir::OperandRange>(
1315         tensor_list_args, resized_tensor_lists, op_result_types,
1316         op.getOperands(), operands, &result_types);
1317 
1318     // Create a new while op with new operands and updated result types.
1319     auto converted = rewriter.create<TF::WhileOp>(op.getLoc(), result_types,
1320                                                   operands, op->getAttrs());
1321     converted->removeAttr("T");
1322     (void)UpdateFunctionTypesForWhileOp(rewriter, converted, operands,
1323                                         tensor_list_args, resized_tensor_lists);
1324 
1325     rewriter.replaceOp(op, converted.getResults());
1326     return success();
1327   }
1328 };
1329 
1330 struct ConvertWhileRegion : public OpConversionPattern<TF::WhileRegionOp> {
1331   using OpConversionPattern::OpConversionPattern;
1332 
matchAndRewritemlir::__anonce850f620111::ConvertWhileRegion1333   LogicalResult matchAndRewrite(
1334       TF::WhileRegionOp op, ArrayRef<Value> operands,
1335       ConversionPatternRewriter &rewriter) const override {
1336     llvm::SmallVector<Type, 8> result_types;
1337     result_types.reserve(op.getNumOperands());
1338     // Change all DT_VARIANT result types to unranked tensor type.
1339     for (auto it : llvm::zip(op.getResultTypes(), operands))
1340       result_types.push_back(
1341           VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
1342 
1343     // Create a new while op with new operands and updated result types.
1344     auto converted = rewriter.create<TF::WhileRegionOp>(
1345         op.getLoc(), result_types, operands, op->getAttrs());
1346 
1347     // Inline the regions from the old while into the new one, and apply
1348     // signature conversion to inlined region.
1349     for (auto it : llvm::zip(op.getRegions(), converted.getRegions())) {
1350       Region &old_region = *std::get<0>(it);
1351       Region &new_region = *std::get<1>(it);
1352 
1353       Block &entry = old_region.front();
1354       // Build signature conversion for the region.
1355       TypeConverter::SignatureConversion signature_conversion(operands.size());
1356       for (auto it : llvm::zip(entry.getArguments(), operands)) {
1357         BlockArgument arg = std::get<0>(it);
1358         signature_conversion.addInputs(
1359             arg.getArgNumber(),
1360             VariantToUnrankedTensorType(arg.getType(), std::get<1>(it)));
1361       }
1362 
1363       rewriter.inlineRegionBefore(old_region, new_region, new_region.end());
1364       rewriter.applySignatureConversion(&new_region, signature_conversion);
1365     }
1366 
1367     rewriter.replaceOp(op, converted.getResults());
1368     return success();
1369   }
1370 };
1371 
1372 #include "tensorflow/compiler/mlir/lite/transforms/generated_lower_static_tensor_list.inc"
1373 
runOnOperation()1374 void LowerStaticTensorListPass::runOnOperation() {
1375   auto *context = &getContext();
1376 
1377   // TensorFlow operations that doesn't have operands and results of type
1378   // variant are legal. Here, we don't distinguish between variants encoding
1379   // TensorList or some other type as that information is not available here.
1380   // Partial legalization is used below to still allow ops with variant types
1381   // still.
1382   auto is_legal = [](Operation *op) {
1383     auto is_not_variant = [](Type ty) {
1384       return !ty.cast<ShapedType>().getElementType().isa<TF::VariantType>();
1385     };
1386     return llvm::all_of(op->getOperandTypes(), is_not_variant) &&
1387            llvm::all_of(op->getResultTypes(), is_not_variant);
1388   };
1389 
1390   ConversionTarget target(*context);
1391   target.addDynamicallyLegalDialect<TF::TensorFlowDialect>(is_legal);
1392   target.addIllegalOp<TF::EmptyTensorListOp, TF::TensorListFromTensorOp,
1393                       TF::TensorListGetItemOp, TF::TensorListLengthOp,
1394                       TF::TensorListPushBackOp, TF::TensorListReserveOp,
1395                       TF::TensorListSetItemOp, TF::TensorListStackOp,
1396                       TF::TensorListResizeOp, TF::TensorListConcatV2Op>();
1397   // TODO(hinsu): Use TFLite constant op for constants.
1398   target.addLegalOp<ConstantOp>();
1399   target.addLegalOp<FuncOp>();
1400   target.addLegalOp<ReturnOp>();
1401   target.addLegalOp<TFL::CustomOp>();
1402   // Register fused LSTM/RNN ops as legal.
1403   target.addLegalOp<TFL::LSTMOp>();
1404   target.addLegalOp<TFL::UnidirectionalSequenceLSTMOp>();
1405   target.addLegalOp<TFL::UnidirectionalSequenceRNNOp>();
1406   target.addLegalOp<TFL::BidirectionalSequenceLSTMOp>();
1407 
1408   OwningRewritePatternList patterns(&getContext());
1409   populateWithGenerated(patterns);
1410   patterns.insert<ConvertConst, ConvertIdentity, ConvertTensorListGetItem,
1411                   ConvertTensorListLength, ConvertTensorListPushBack,
1412                   ConvertTensorListSetItem, ConvertTensorListStack,
1413                   ConvertTensorListResize, ConvertWhile, ConvertWhileRegion,
1414                   ConvertIf>(context);
1415   patterns.insert<ConvertEmptyTensorList, ConvertTensorListReserve,
1416                   ConvertTensorListConcatV2>(context,
1417                                              allow_tensorlist_pass_through);
1418   ModuleOp module = getOperation();
1419   if (!allow_tensorlist_pass_through) {
1420     if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
1421       module.emitError(
1422           "Lowering tensor list ops is failed. Please consider using Select TF "
1423           "ops and disabling `_experimental_lower_tensor_list_ops` flag in the "
1424           "TFLite converter object. For example, "
1425           "converter.target_spec.supported_ops = "
1426           "[tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]\\n "
1427           "converter._experimental_lower_tensor_list_ops = False");
1428       signalPassFailure();
1429     }
1430   } else {
1431     // If `allow_tensorlist_pass_through` is set to true, if legalization fails
1432     // we should not leak the diagnostic info outside this pass. Hence we use
1433     // a `StatusScopedDiagnosticHandler` here to capture diagnostics generated
1434     // within this pass.
1435     StatusScopedDiagnosticHandler handler(context);
1436     if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
1437       auto _ = handler.ConsumeStatus();
1438     }
1439   }
1440 }
1441 
1442 }  // namespace
1443 
1444 /// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
1445 /// pass.
CreateLowerStaticTensorListPass(bool allow_tensorlist_pass_through)1446 std::unique_ptr<OperationPass<ModuleOp>> TFL::CreateLowerStaticTensorListPass(
1447     bool allow_tensorlist_pass_through) {
1448   return std::make_unique<LowerStaticTensorListPass>(
1449       allow_tensorlist_pass_through);
1450 }
1451 
1452 static PassRegistration<LowerStaticTensorListPass> pass;
1453 
1454 }  // namespace mlir
1455