1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This transformation pass prepares for legalization to the TFLite dialect by
17 // converting Tensorlist operations in TensorFlow dialect into operations that
18 // can be legalized to TensorFlow Lite dialect with simple replacements. The
19 // newly created operations are in the TensorFlow dialect if the operation can
20 // be represented using a TensorFlow op. Otherwise, TensorFlow Lite dialect op
21 // is used.
22
23 #include <climits>
24 #include <cstdint>
25
26 #include "absl/container/inlined_vector.h"
27 #include "llvm/ADT/ArrayRef.h"
28 #include "llvm/ADT/None.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallSet.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/StringSwitch.h"
33 #include "llvm/Support/Casting.h"
34 #include "llvm/Support/Debug.h"
35 #include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project
36 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
37 #include "mlir/IR/Attributes.h" // from @llvm-project
38 #include "mlir/IR/Block.h" // from @llvm-project
39 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
40 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
41 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
42 #include "mlir/IR/MLIRContext.h" // from @llvm-project
43 #include "mlir/IR/Matchers.h" // from @llvm-project
44 #include "mlir/IR/Operation.h" // from @llvm-project
45 #include "mlir/IR/OperationSupport.h" // from @llvm-project
46 #include "mlir/IR/PatternMatch.h" // from @llvm-project
47 #include "mlir/IR/SymbolTable.h" // from @llvm-project
48 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
49 #include "mlir/IR/Types.h" // from @llvm-project
50 #include "mlir/IR/Value.h" // from @llvm-project
51 #include "mlir/Pass/Pass.h" // from @llvm-project
52 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
53 #include "mlir/Support/LLVM.h" // from @llvm-project
54 #include "mlir/Support/LogicalResult.h" // from @llvm-project
55 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
56 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
57 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
58 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
59 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
60 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
61 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
62 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
63 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
64 #include "tensorflow/core/framework/tensor.h"
65 #include "tensorflow/core/framework/types.pb.h"
66 #include "tensorflow/core/kernels/tensor_list.h"
67
68 #define DEBUG_TYPE "tf-tfl-legalization"
69
70 //===----------------------------------------------------------------------===//
71 // The actual LowerStaticTensorList Pass.
72 //
73 namespace mlir {
74 namespace {
75
76 /// Lower TensorList ops in functions for subsequent legalization.
77 struct LowerStaticTensorListPass
78 : public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> {
79 LowerStaticTensorListPass() = default;
LowerStaticTensorListPassmlir::__anon3c12621d0111::LowerStaticTensorListPass80 LowerStaticTensorListPass(const LowerStaticTensorListPass &) {}
81
82 void runOnOperation() override;
83
84 Option<bool> allow_tensorlist_pass_through{
85 *this, "allow-tensorlist-pass-through",
86 llvm::cl::desc(
87 "When specified to true, if the tensorlist ops can't be properly "
88 "legalized by this pass, then the IR won't be changed so that "
89 "tensorlist ops can pass through (default false)"),
90 llvm::cl::init(false)};
91 };
92
CreateI32SplatConst(Location loc,PatternRewriter * rewriter,ArrayRef<int64_t> shape,int32_t val)93 Value CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
94 ArrayRef<int64_t> shape, int32_t val) {
95 RankedTensorType type =
96 RankedTensorType::get(shape, rewriter->getIntegerType(32));
97 DenseElementsAttr attr =
98 DenseElementsAttr::get(type, rewriter->getI32IntegerAttr(val));
99 return rewriter->create<ConstantOp>(loc, type, attr);
100 }
101
CreateI32SplatTensor(Location loc,PatternRewriter * rewriter,Value shape_tensor,int32_t val)102 Value CreateI32SplatTensor(Location loc, PatternRewriter *rewriter,
103 Value shape_tensor, int32_t val) {
104 Value scalar_val = CreateI32SplatConst(loc, rewriter, {}, val);
105 return rewriter->create<TF::FillOp>(
106 loc, RankedTensorType::get({-1}, rewriter->getIntegerType(32)),
107 shape_tensor, scalar_val);
108 }
109
110 // Returns a new type by prepending the specified dimension to the shape of
111 // the given type if it is a ranked type.
PrependLeadingDimIfRanked(int64_t dim,Type type,PatternRewriter * rewriter)112 Type PrependLeadingDimIfRanked(int64_t dim, Type type,
113 PatternRewriter *rewriter) {
114 Type dtype = getElementTypeOrSelf(type);
115 if (RankedTensorType ty = type.dyn_cast<RankedTensorType>()) {
116 llvm::SmallVector<int64_t, 4> shape = {dim};
117 shape.append(ty.getShape().begin(), ty.getShape().end());
118 return RankedTensorType::get(shape, dtype);
119 }
120 return type;
121 }
122
GetTensorTypeForTensorList(Type element_type,TF::VariantType handle_dtype,PatternRewriter * rewriter)123 Type GetTensorTypeForTensorList(Type element_type, TF::VariantType handle_dtype,
124 PatternRewriter *rewriter) {
125 // If the variant type in the output handle has item shape available, use it
126 // to derive the output shape by setting unknown leading dimension.
127 // Otherwise, result type will be of unranked type.
128 if (handle_dtype.getSubtypes().empty()) {
129 return UnrankedTensorType::get(element_type);
130 }
131 return PrependLeadingDimIfRanked(-1, handle_dtype.getSubtypes()[0], rewriter);
132 }
133
134 // Creates a slice of the tensorlist `input_list`, starting from
135 // [start_index, 0, ...0], with size [size, -1, ...-1].
136 //
137 // Requires that `start_index` and `size` are scalar tensors and
138 // `item_position_shape` is a 1-D tensor with only one element equal to the rank
139 // of an item in the tensorlist.
CreateSliceOpForTensorList(Location loc,Value input_list,Value start_index,Value size,Value item_rank,Type result_type,PatternRewriter * rewriter)140 TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
141 Value start_index, Value size,
142 Value item_rank, Type result_type,
143 PatternRewriter *rewriter) {
144 // Create the start position of slice. This is done by concatenating
145 // `start_index` and `partial_start_position` together.
146 IntegerType shape_dtype = rewriter->getIntegerType(32);
147 RankedTensorType position_type = RankedTensorType::get({-1}, shape_dtype);
148 Value partial_start_position =
149 CreateI32SplatTensor(loc, rewriter, item_rank, 0);
150 Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
151 RankedTensorType vector_type = RankedTensorType::get({1}, shape_dtype);
152 auto expanded_start_index = rewriter->create<TF::ExpandDimsOp>(
153 loc, vector_type, start_index, scalar_zero);
154 auto start_position = rewriter->create<TF::ConcatOp>(
155 loc, position_type, scalar_zero,
156 ArrayRef<Value>({expanded_start_index, partial_start_position}));
157
158 // Create the slice size tensor. This is done by concatenating `size` and
159 // `partial_size`.
160 auto size_leading_dim =
161 rewriter->create<TF::ExpandDimsOp>(loc, vector_type, size, scalar_zero);
162 Value partial_size = CreateI32SplatTensor(loc, rewriter, item_rank, -1);
163 auto slice_size = rewriter->create<TF::ConcatOp>(
164 loc, position_type, scalar_zero,
165 ArrayRef<Value>({size_leading_dim, partial_size}));
166
167 return rewriter->create<TF::SliceOp>(loc, result_type, input_list,
168 start_position, slice_size);
169 }
170
171 // Converts tf.Const containing variant of type TensorList to a tensor of
172 // primitive element types. Each of the individual tensor in the list is
173 // converted to an ElementsAttr and then those are packed together using
174 // tf.Pack op.
175 struct ConvertConst : public OpConversionPattern<TF::ConstOp> {
176 using OpConversionPattern::OpConversionPattern;
177
matchAndRewritemlir::__anon3c12621d0111::ConvertConst178 LogicalResult matchAndRewrite(
179 TF::ConstOp op, ArrayRef<Value> operands,
180 ConversionPatternRewriter &rewriter) const override {
181 // Verify that the opaque elements attribute contains tensor of type variant
182 // and scalar shape. The variant type should hold a TensorList.
183 auto opaque_attr = op.value().dyn_cast<OpaqueElementsAttr>();
184 if (!opaque_attr) return failure();
185 tensorflow::Tensor tensor;
186 if (!tensorflow::ConvertToTensor(opaque_attr, &tensor).ok())
187 return failure();
188 if (tensor.dtype() != tensorflow::DT_VARIANT) return failure();
189 if (!tensorflow::TensorShapeUtils::IsScalar(tensor.shape()))
190 return failure();
191
192 const tensorflow::TensorList *list =
193 tensor.scalar<tensorflow::Variant>()().get<tensorflow::TensorList>();
194 if (!list) return failure();
195
196 // Verify output type is variant and contains exactly one ranked subtypes.
197 auto variant_ty =
198 getElementTypeOrSelf(op.getType()).dyn_cast<TF::VariantType>();
199 if (!variant_ty) return failure();
200 ArrayRef<TensorType> subtypes = variant_ty.getSubtypes();
201 if (subtypes.size() != 1) return failure();
202 RankedTensorType list_element_ty =
203 subtypes.front().dyn_cast<RankedTensorType>();
204 if (!list_element_ty) return failure();
205
206 // Extract tensor elements for the TensorList and construct result type
207 // based on the number of elements and element shape.
208 const std::vector<tensorflow::Tensor> &tensors = list->tensors();
209 llvm::SmallVector<int64_t, 4> result_shape = {
210 static_cast<int64_t>(tensors.size())};
211 result_shape.append(list_element_ty.getShape().begin(),
212 list_element_ty.getShape().end());
213 auto result_ty =
214 RankedTensorType::get(result_shape, list_element_ty.getElementType());
215
216 // If the list is empty, directly create the final result instead of
217 // creating the tf.Pack op. tf.Pack op requires at least one operand.
218 if (tensors.empty()) {
219 tensorflow::Tensor tensor(list->element_dtype,
220 tensorflow::TensorShape(result_shape));
221 auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
222 if (!attr_or.ok()) return failure();
223 rewriter.replaceOpWithNewOp<TF::ConstOp>(op, attr_or.ValueOrDie());
224 return success();
225 }
226
227 // Extract individual tensor list element and combine them using the tf.Pack
228 // op.
229 Location loc = op.getLoc();
230 llvm::SmallVector<Value, 4> values;
231 values.reserve(tensors.size());
232 for (const tensorflow::Tensor &tensor : tensors) {
233 auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
234 if (!attr_or.ok()) return failure();
235
236 auto value = rewriter.create<TF::ConstOp>(loc, attr_or.ValueOrDie());
237 values.push_back(value);
238 }
239 rewriter.replaceOpWithNewOp<TF::PackOp>(
240 op, result_ty, values, /*axis=*/rewriter.getI64IntegerAttr(0));
241 return success();
242 }
243 };
244
245 struct ConvertTensorListSetItem
246 : public OpConversionPattern<TF::TensorListSetItemOp> {
247 using OpConversionPattern::OpConversionPattern;
248
249 // This function rewrites the original op into a series of slice and concat op
250 // to produce the same result. It first slices the first `$index` rows. Then
251 // expands the dimension of the `$item`, followed by another slice of the
252 // remaining rows starting from `$index` + 1. Lastly it concatenates the
253 // three parts together.
254 // On a high level, it's doing something like:
255 // def : Pat<(TF_TensorListSetItemOp $input, $index, $item),
256 // (Concat
257 // concat_dim = 0,
258 // (Slice $input, [0, 0, ...], (Concat (ExpandDims $index, expand_dim =
259 // 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice
260 // $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
matchAndRewritemlir::__anon3c12621d0111::ConvertTensorListSetItem261 LogicalResult matchAndRewrite(
262 TF::TensorListSetItemOp op, ArrayRef<Value> operands,
263 ConversionPatternRewriter &rewriter) const override {
264 Location loc = op.getLoc();
265 Value input = operands[0];
266 Value index = operands[1];
267 Value item = operands[2];
268
269 IntegerType shape_dtype = rewriter.getIntegerType(32);
270 auto item_rank = rewriter.create<TF::RankOp>(
271 loc, RankedTensorType::get({}, shape_dtype), item);
272 Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
273
274 // Calculate `index` + 1, which is used to generate the start position for
275 // the second slice op.
276 auto suffix_start =
277 rewriter.create<TF::AddOp>(loc, index.getType(), index,
278 CreateI32SplatConst(loc, &rewriter, {}, 1));
279
280 auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
281 loc, RankedTensorType::get({1}, shape_dtype), item_rank, scalar_zero);
282 // Create two slice ops.
283 Type element_type = input.getType().cast<TensorType>().getElementType();
284 UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type);
285 Value scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1);
286 TF::SliceOp slice1 =
287 CreateSliceOpForTensorList(loc, /*input_list=*/input,
288 /*start_index=*/scalar_zero,
289 /*size=*/index,
290 /*item_rank=*/item_position_shape,
291 /*result_type=*/unranked_tensor, &rewriter);
292 TF::SliceOp slice2 =
293 CreateSliceOpForTensorList(loc, /*input_list=*/input,
294 /*start_index=*/suffix_start,
295 /*size=*/scalar_minus_one,
296 /*item_rank=*/item_position_shape,
297 /*result_type=*/unranked_tensor, &rewriter);
298
299 // Expand the dimension of item so that it will have the same rank with
300 // input.
301 auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
302 op.getLoc(), unranked_tensor, item, scalar_zero);
303
304 // Concatenate three parts together to generate the final result.
305 rewriter.replaceOpWithNewOp<TF::ConcatOp>(
306 op, input.getType(), scalar_zero,
307 ArrayRef<Value>({slice1, expanded_item, slice2}));
308 return success();
309 }
310 };
311
312 // Rewrites op of the template type initializing a TensorList with a list of ops
313 // to generate an equivalent raw tensor. Derived classes are required to
314 // override GetNumElements method.
315 template <typename OpT>
316 struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
317 using OpConversionPattern<OpT>::OpConversionPattern;
318
319 // Create and return a 1-d tensor with exactly one element equal to the number
320 // of list elements to initialize the output tensor list with.
321 virtual Value GetNumElements(OpT op, ArrayRef<Value> operands,
322 PatternRewriter *rewriter) const = 0;
323
324 // Rewrites the original op into `tf.fill`. The result tensor shape is
325 // [num_element, element_shape]. All the values in the result tensor will be
326 // initialized to 0.
matchAndRewritemlir::__anon3c12621d0111::ConvertTensorListInitOp327 LogicalResult matchAndRewrite(
328 OpT op, ArrayRef<Value> operands,
329 ConversionPatternRewriter &rewriter) const override {
330 Type dtype = op.element_dtype();
331 if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
332 dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
333 dtype.isInteger(32) || dtype.isInteger(64))) {
334 return rewriter.notifyMatchFailure(
335 op,
336 "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
337 "integer or 16-bit/32-bit/64-bit float type during TF Lite "
338 "transformation pass");
339 }
340
341 Value element_shape = operands[0];
342 Type shape_dtype = getElementTypeOrSelf(element_shape.getType());
343 // If the `element_shape` is a scalar, we try to acquire its shape by
344 // looking at the first `TensorListSetItemOp` writing to this tensor list.
345 // Here we assume that the element_shape won't be changed before calling
346 // the first `TensorListSetItemOp`.
347 if (auto shaped_type = element_shape.getType().dyn_cast<ShapedType>()) {
348 if (shaped_type.getRank() == 0) {
349 bool element_shape_acquired = false;
350 auto uses = op.getResult().getUses();
351 for (auto &use : llvm::make_early_inc_range(uses)) {
352 if (TF::TensorListSetItemOp set_op =
353 llvm::dyn_cast<TF::TensorListSetItemOp>(use.getOwner())) {
354 element_shape = rewriter.create<TF::ShapeOp>(
355 op.getLoc(), RankedTensorType::get({-1}, shape_dtype),
356 set_op.item());
357 element_shape_acquired = true;
358 } else if (TF::WhileOp while_op =
359 llvm::dyn_cast<TF::WhileOp>(use.getOwner())) {
360 // Tensorlist is passed into a while loop, check inside the body
361 // function.
362 auto inside_uses = while_op.body_function()
363 .getArgument(use.getOperandNumber())
364 .getUses();
365 for (auto &inside_use : llvm::make_early_inc_range(inside_uses)) {
366 if (TF::TensorListSetItemOp set_op =
367 llvm::dyn_cast<TF::TensorListSetItemOp>(
368 inside_use.getOwner())) {
369 if (auto shaped_type =
370 set_op.item().getType().dyn_cast<ShapedType>()) {
371 if (shaped_type.hasStaticShape()) {
372 RankedTensorType type = RankedTensorType::get(
373 {shaped_type.getRank()}, rewriter.getIntegerType(32));
374 SmallVector<Attribute, 4> shape_attr;
375 for (int64_t dim : shaped_type.getShape()) {
376 shape_attr.push_back(rewriter.getI32IntegerAttr(dim));
377 }
378 DenseElementsAttr attr =
379 DenseElementsAttr::get(type, shape_attr);
380 element_shape =
381 rewriter.create<ConstantOp>(op.getLoc(), type, attr);
382 element_shape_acquired = true;
383 break;
384 }
385 }
386 }
387 }
388 }
389 if (element_shape_acquired) break;
390 }
391 if (!element_shape_acquired) {
392 return rewriter.notifyMatchFailure(
393 op,
394 "requires element_shape to be 1D tensor during TF Lite "
395 "transformation pass");
396 }
397 }
398 }
399
400 DenseIntElementsAttr dense_elem_attr;
401 if (matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
402 // Note: It's technically unsafe to rewrite
403 // TensorListReserve(num_element, element_shape)
404 // to
405 // Fill(Concat(num_element, element_shape), 0)
406 // because element_shape may contain -1 to represent unknown dimension.
407 //
408 // In real world use cases (e.g. Keras RNN), `element_shape` is usually
409 // a constant, and the first dimension of `element_shape` is usually
410 // batch dimension. Currently TFLiteConverter always rewrite unknown
411 // batch dimension to 1, therefore we also rewrite unknown dimension in
412 // `element_shape` to 1 here.
413 //
414 // This workaround enables converting Keras RNN without specifying batch
415 // dimension. This isn't guaranteed to work, but it doesn't break any
416 // non-broken cases either (since it's already broken if `element_shape`
417 // contains -1).
418 // TODO(b/142096690): Support dynamic element shape and remove the
419 // workaround.
420 SmallVector<int32_t, 4> new_element_shape_values;
421
422 auto int_values = dense_elem_attr.getIntValues();
423 for (auto it = int_values.begin(); it != int_values.end(); ++it) {
424 auto dim_value = (*it).getSExtValue();
425 if (it == int_values.begin() && dim_value == -1) {
426 dim_value = 1;
427 }
428 new_element_shape_values.push_back(dim_value);
429 }
430
431 auto attr = DenseIntElementsAttr::get(
432 element_shape.getType().cast<ShapedType>(), new_element_shape_values);
433 auto new_element_shape = rewriter.create<ConstantOp>(
434 op.getLoc(), element_shape.getType(), attr);
435 element_shape = new_element_shape;
436 }
437
438 int64_t result_rank = -1; // -1 means unknown result rank.
439 Type element_dtype = op.element_dtype();
440 Type result_type = UnrankedTensorType::get(element_dtype);
441 Value leading_dim = GetNumElements(op, operands, &rewriter);
442 if (auto element_type =
443 op.element_type().template dyn_cast<RankedTensorType>()) {
444 result_rank = element_type.getRank() + 1;
445 int64_t leading_dim_v = -1;
446 ElementsAttr element_attr;
447 if (matchPattern(leading_dim, m_Constant(&element_attr))) {
448 leading_dim_v = element_attr.getValue<IntegerAttr>(0).getInt();
449 }
450 SmallVector<int64_t, 4> result_shape = {leading_dim_v};
451 ArrayRef<int64_t> shape = element_type.getShape();
452 result_shape.append(shape.begin(), shape.end());
453 result_type = RankedTensorType::get(result_shape, element_dtype);
454 }
455
456 // Create a 1-D RankedTensorType for result's shape. Number of elements in
457 // it is equal to the rank of the result, if known. Otherwise, the number of
458 // elements are unknown and represented with -1. In both cases, we can
459 // specify dimension using rank of the result.
460 Type shape_type = RankedTensorType::get({result_rank}, shape_dtype);
461
462 Location loc = op.getLoc();
463 // Add number of elements as the prefix to the element shape to get shape of
464 // the output tensor.
465 Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
466 auto list_shape = rewriter.create<TF::ConcatOp>(
467 loc, shape_type, scalar_zero,
468 ArrayRef<Value>({leading_dim, element_shape}));
469
470 // Create a zero-initialized constant tensor that has the same type
471 // as specified by element_dtype.
472 RankedTensorType zero_type = RankedTensorType::get({}, element_dtype);
473 Attribute zero_attr = rewriter.getZeroAttr(zero_type);
474 auto zero = rewriter.create<ConstantOp>(loc, zero_type, zero_attr);
475
476 rewriter.replaceOpWithNewOp<TF::FillOp>(op, result_type, list_shape, zero);
477 return success();
478 }
479 };
480
481 struct ConvertTensorListReserve
482 : public ConvertTensorListInitOp<TF::TensorListReserveOp> {
ConvertTensorListReservemlir::__anon3c12621d0111::ConvertTensorListReserve483 explicit ConvertTensorListReserve(MLIRContext *context)
484 : ConvertTensorListInitOp(context) {}
485
GetNumElementsmlir::__anon3c12621d0111::ConvertTensorListReserve486 Value GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value> operands,
487 PatternRewriter *rewriter) const override {
488 Value scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0);
489 Type shape_dtype = getElementTypeOrSelf(op.element_shape().getType());
490 Value num_elements = operands[1];
491 IntegerAttr attr;
492 if (matchPattern(num_elements, m_Constant(&attr))) {
493 return CreateI32SplatConst(op.getLoc(), rewriter, {1}, attr.getInt());
494 }
495 return rewriter->create<TF::ExpandDimsOp>(
496 op.getLoc(), RankedTensorType::get({1}, shape_dtype), num_elements,
497 scalar_zero);
498 }
499 };
500
501 // Note that we ignore the second operand `max_num_elements` as we don't have
502 // any restrictions on the number of elements we can support. So this may
503 // have a different behavior compared to TensorFlow in case of errors.
504 struct ConvertEmptyTensorList
505 : public ConvertTensorListInitOp<TF::EmptyTensorListOp> {
ConvertEmptyTensorListmlir::__anon3c12621d0111::ConvertEmptyTensorList506 explicit ConvertEmptyTensorList(MLIRContext *context)
507 : ConvertTensorListInitOp(context) {}
508
GetNumElementsmlir::__anon3c12621d0111::ConvertEmptyTensorList509 Value GetNumElements(TF::EmptyTensorListOp op, ArrayRef<Value> operands,
510 PatternRewriter *rewriter) const override {
511 return CreateI32SplatConst(op.getLoc(), rewriter, {1}, 0);
512 }
513 };
514
515 struct ConvertTensorListPushBack
516 : public OpConversionPattern<TF::TensorListPushBackOp> {
517 using OpConversionPattern::OpConversionPattern;
518
matchAndRewritemlir::__anon3c12621d0111::ConvertTensorListPushBack519 LogicalResult matchAndRewrite(
520 TF::TensorListPushBackOp op, ArrayRef<Value> operands,
521 ConversionPatternRewriter &rewriter) const override {
522 Value input_handle = operands[0];
523 Value item = operands[1];
524
525 // Expand the shape of the item so that it will have rank same as the input
526 // tensor and it is compatible for the Concat Op.
527 Type expanded_item_type =
528 PrependLeadingDimIfRanked(1, item.getType(), &rewriter);
529 Location loc = op.getLoc();
530 Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
531 auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
532 loc, expanded_item_type, item, scalar_zero);
533
534 Type elem_type = getElementTypeOrSelf(item);
535 auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
536 .cast<TF::VariantType>();
537 Type result_type =
538 GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
539
540 // Concatenate tensor stored in the input handle with the expanded item to
541 // get a tensor equivalent to the TensorList generated by this op.
542 rewriter.replaceOpWithNewOp<TF::ConcatOp>(
543 op, result_type, scalar_zero,
544 ArrayRef<Value>({input_handle, expanded_item}));
545 return success();
546 }
547 };
548
549 // Rewrites `TensorListResize` op into a functional `If` op and several basic
550 // TF ops to match the op semantics of Tensorflow. Basically, it does:
551 // 1) If the requested size is smaller or equal than the input tensorlist's
552 // size, rewrite it to a Slice op so that only the first 'size' rows are
553 // returned. 2) If the requested size is larger than the input tensorlist's
554 // size. We need to create an additional tensorlist with 'size - input_size'
555 // elements, and append it to the end of the input tensorlist.
556 // TODO(haoliang): We could simplify this transformation by rewriting to pure
557 // tensorlist ops and a few non-tensorlist ops (such as `SliceOp`). By operating
558 // only on variant types, we could save some ops involved in rewriting this op.
559 struct ConvertTensorListResize
560 : public OpConversionPattern<TF::TensorListResizeOp> {
561 using OpConversionPattern::OpConversionPattern;
562
matchAndRewritemlir::__anon3c12621d0111::ConvertTensorListResize563 LogicalResult matchAndRewrite(
564 TF::TensorListResizeOp op, ArrayRef<Value> operands,
565 ConversionPatternRewriter &rewriter) const override {
566 Value input_handle = operands[0];
567 Value size = operands[1];
568
569 Location loc = op.getLoc();
570 Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
571
572 // Compute the input tensorlist's length and store it in `input_size`.
573 IntegerType shape_dtype = rewriter.getIntegerType(32);
574 auto input_size = rewriter.create<TF::TensorListLengthOp>(
575 loc, RankedTensorType::get({}, shape_dtype), op.getOperand(0));
576
577 // Infer result type of this op based on TF's shape inference result.
578 Type elem_type = getElementTypeOrSelf(input_handle);
579 auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
580 .cast<TF::VariantType>();
581 Type result_type =
582 GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
583
584 // Compute the difference of `size` and `input_size`, and store it in
585 // `size_diff`, which is then consumed by `if_cond`.
586 auto size_diff = rewriter.create<TF::SubOp>(
587 loc, RankedTensorType::get({}, shape_dtype), size, input_size);
588 auto if_cond = rewriter.create<TF::GreaterOp>(
589 loc, RankedTensorType::get({}, rewriter.getI1Type()), size_diff,
590 scalar_zero);
591
592 // Build the argument/result types for if branch function.
593 auto input_shape = rewriter.create<TF::ShapeOp>(
594 loc, RankedTensorType::get({-1}, shape_dtype), input_handle);
595
596 Type branch_args_type[] = {input_handle.getType(), input_shape.getType(),
597 size_diff.getType(), size.getType()};
598 Type branch_result_type[] = {result_type};
599 auto func_type = FunctionType::get(rewriter.getContext(), branch_args_type,
600 branch_result_type);
601
602 // Create functions in a higher scope before restoring the insertion point.
603 // Additionally, create the SymbolTable before further modifying the module.
604 auto original_point = rewriter.saveInsertionPoint();
605 rewriter.setInsertionPointAfter(op->getParentOfType<FuncOp>());
606 SymbolTable manager(op->getParentOfType<ModuleOp>());
607
608 // Constructs `then_branch`, which is executed when `if_cond` evaluates to
609 // true.
610 auto then_branch_op = rewriter.create<FuncOp>(loc, "cond_true", func_type);
611 CreateCondTrueBranch(op, shape_dtype, result_type, then_branch_op,
612 &rewriter);
613
614 // Constructs `else_branch`, which is executed when `if_cond` evaluates to
615 // false.
616 auto else_branch_op = rewriter.create<FuncOp>(loc, "cond_false", func_type);
617 CreateCondFalseBranch(loc, shape_dtype, result_type, else_branch_op,
618 &rewriter);
619
620 // Inserts the two blocks' names into the symbol table held by the module.
621 // Using SymbolTable will ensure that the inserted symbol names are
622 // unique.
623 manager.insert(then_branch_op);
624 manager.insert(else_branch_op);
625
626 rewriter.restoreInsertionPoint(original_point);
627 rewriter.replaceOpWithNewOp<TF::IfOp>(
628 op, result_type, if_cond,
629 /*input=*/
630 ArrayRef<Value>({input_handle, input_shape, size_diff, size}),
631 /*then_branch=*/rewriter.getSymbolRefAttr(then_branch_op),
632 /*else_branch=*/rewriter.getSymbolRefAttr(else_branch_op),
633 /*is_stateless=*/rewriter.getBoolAttr(true));
634 return success();
635 }
636
637 private:
638 // When the input tensorlist's size is smaller than the requested size,
639 // then branch is executed.
640 // Create a new tensorlist of size 'size - input_size' and concat it
641 // with the input tensorlist.
CreateCondTrueBranchmlir::__anon3c12621d0111::ConvertTensorListResize642 void CreateCondTrueBranch(TF::TensorListResizeOp resize_op, Type shape_dtype,
643 Type result_type, FuncOp branch_func,
644 ConversionPatternRewriter *rewriter) const {
645 auto guard = OpBuilder::InsertionGuard(*rewriter);
646 Block *block =
647 rewriter->createBlock(&branch_func.getBody(), branch_func.begin(),
648 branch_func.getType().getInputs());
649
650 auto input_shape = block->getArgument(1);
651 auto size_diff = block->getArgument(2);
652 auto input = block->getArgument(0);
653
654 Location loc = resize_op.getLoc();
655 // Get the element shape by slicing from index 1 in the input shape.
656 Value slice_size = CreateI32SplatConst(loc, rewriter, {1}, -1);
657 Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
658 Value slice_start = CreateI32SplatConst(loc, rewriter, {1}, 1);
659 auto elem_shape = rewriter->create<TF::SliceOp>(
660 loc, RankedTensorType::get({-1}, shape_dtype), input_shape, slice_start,
661 slice_size);
662 auto extended_part = rewriter->create<TF::TensorListReserveOp>(
663 loc, resize_op.output_handle().getType(), elem_shape, size_diff);
664 // `ConcatOp` expects non-variant-typed input. Insert a
665 // `TensorListStackOp` here to convert type from variant to non-variant.
666 // Note that we are using the same `result_type` for both the
667 // `TensorListStackOp` and `ConcatOp`, since the first dimension of the
668 // shape specified by `result_type` is -1.
669 auto stacked_extended_part = rewriter->create<TF::TensorListStackOp>(
670 loc, result_type, extended_part,
671 /*element_shape=*/CreateI32SplatConst(loc, rewriter, {}, -1),
672 /*num_elements=*/rewriter->getI32IntegerAttr(-1));
673 auto concat_op = rewriter->create<TF::ConcatOp>(
674 loc, result_type, scalar_zero,
675 ArrayRef<Value>({input, stacked_extended_part}));
676 rewriter->create<ReturnOp>(loc, ArrayRef<Value>({concat_op}));
677 }
678
CreateCondFalseBranchmlir::__anon3c12621d0111::ConvertTensorListResize679 void CreateCondFalseBranch(Location loc, Type shape_dtype, Type result_type,
680 FuncOp branch_func,
681 ConversionPatternRewriter *rewriter) const {
682 // When the input tensorlist's size is larger or equal than the requested
683 // size, the else branch is executed.
684 // Slice the first 'size' rows from the input tensorlist.
685 auto guard = OpBuilder::InsertionGuard(*rewriter);
686 Block *block =
687 rewriter->createBlock(&branch_func.getBody(), branch_func.begin(),
688 branch_func.getType().getInputs());
689
690 Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
691 Value vector_one = CreateI32SplatConst(loc, rewriter, {1}, 1);
692 auto input = block->getArgument(0);
693 auto size = block->getArgument(3);
694
695 // Subtract `input_rank` by 1 to get the item's rank, which is used as
696 // `partial_position_shape`.
697 auto input_rank = rewriter->create<TF::RankOp>(
698 loc, RankedTensorType::get({}, shape_dtype), input);
699 auto partial_position_shape = rewriter->create<TF::SubOp>(
700 loc, RankedTensorType::get({1}, shape_dtype), input_rank, vector_one);
701 auto slice_op =
702 CreateSliceOpForTensorList(loc, /*input_list=*/input,
703 /*start_index=*/scalar_zero, /*size=*/size,
704 /*item_rank=*/partial_position_shape,
705 /*result_type=*/result_type, rewriter);
706 rewriter->create<ReturnOp>(loc, ArrayRef<Value>({slice_op}));
707 }
708 };
709
710 struct ConvertTensorListGetItem
711 : public OpConversionPattern<TF::TensorListGetItemOp> {
712 using OpConversionPattern::OpConversionPattern;
713
matchAndRewritemlir::__anon3c12621d0111::ConvertTensorListGetItem714 LogicalResult matchAndRewrite(
715 TF::TensorListGetItemOp op, ArrayRef<Value> operands,
716 ConversionPatternRewriter &rewriter) const override {
717 Value input = operands[0];
718 Value index = operands[1];
719 rewriter.replaceOpWithNewOp<TF::GatherOp>(op, op.getType(), input, index,
720 rewriter.getBoolAttr(true));
721 return success();
722 }
723 };
724
725 struct ConvertTensorListLength
726 : public OpConversionPattern<TF::TensorListLengthOp> {
727 using OpConversionPattern::OpConversionPattern;
728
matchAndRewritemlir::__anon3c12621d0111::ConvertTensorListLength729 LogicalResult matchAndRewrite(
730 TF::TensorListLengthOp op, ArrayRef<Value> operands,
731 ConversionPatternRewriter &rewriter) const override {
732 Location loc = op.getLoc();
733 Value input_handle = operands[0];
734
735 BoolAttr true_attr = rewriter.getBoolAttr(true);
736 auto shape = rewriter.create<TF::ShapeOp>(loc, input_handle,
737 /*use_32bit=*/true_attr);
738 rewriter.replaceOpWithNewOp<TF::GatherOp>(
739 op, op.getType(), shape, CreateI32SplatConst(loc, &rewriter, {}, 0),
740 /*validate_indices=*/true_attr);
741 return success();
742 }
743 };
744
745 struct ConvertTensorListStack
746 : public OpConversionPattern<TF::TensorListStackOp> {
747 using OpConversionPattern::OpConversionPattern;
748
matchAndRewritemlir::__anon3c12621d0111::ConvertTensorListStack749 LogicalResult matchAndRewrite(
750 TF::TensorListStackOp op, ArrayRef<Value> operands,
751 ConversionPatternRewriter &rewriter) const override {
752 Location loc = op.getLoc();
753 Value input = operands[0];
754 Value element_shape = operands[1];
755
756 // If the `element_shape` is a known constant (which is defined when calling
757 // `tensor_list_stack`) and also valid (not scalar), we rewrite this op to a
758 // trivial Reshape op (that doesn't actually change the input's shape) and
759 // also populate the shape info to the op result. The shape of the
760 // tensorlist is inferred from `num_elements` and `element_shape`.
761 auto ranked_type = element_shape.getType().dyn_cast<RankedTensorType>();
762 DenseIntElementsAttr dense_elem_attr;
763 if ((ranked_type && ranked_type.getRank() == 0) ||
764 !matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
765 // If no constant is spotted, just forward the operand.
766 rewriter.replaceOp(op, {input});
767 return success();
768 }
769
770 RankedTensorType shape_type =
771 RankedTensorType::get({-1}, rewriter.getIntegerType(32));
772 auto new_shape = rewriter.create<TF::ShapeOp>(loc, shape_type, input);
773 SmallVector<int64_t, 8> output_shape(/*Size=*/1, op.num_elements());
774 for (const auto &dim : dense_elem_attr.getIntValues())
775 output_shape.push_back(dim.getSExtValue());
776 RankedTensorType result_type =
777 RankedTensorType::get(output_shape, getElementTypeOrSelf(input));
778 rewriter.replaceOpWithNewOp<TF::ReshapeOp>(op, result_type, input,
779 new_shape);
780 return success();
781 }
782 };
783
784 struct ConvertIdentity : public OpConversionPattern<TF::IdentityOp> {
785 using OpConversionPattern::OpConversionPattern;
786
matchAndRewritemlir::__anon3c12621d0111::ConvertIdentity787 LogicalResult matchAndRewrite(
788 TF::IdentityOp op, ArrayRef<Value> operands,
789 ConversionPatternRewriter &rewriter) const override {
790 Value input = operands[0];
791 rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), operands,
792 op.getAttrs());
793 return success();
794 }
795 };
796
797 // Returns an unranked tensor type with an element of the same type as `value`
798 // if `type` is a tensor of variant. Otherwise, returns `type` unmodified.
VariantToUnrankedTensorType(Type type,Value value)799 Type VariantToUnrankedTensorType(Type type, Value value) {
800 if (getElementTypeOrSelf(type).isa<TF::VariantType>())
801 return UnrankedTensorType::get(getElementTypeOrSelf(value.getType()));
802 return type;
803 }
804
GetTensorListArgumentsFromWhileOp(TF::WhileOp op)805 llvm::SmallSet<int, 4> GetTensorListArgumentsFromWhileOp(TF::WhileOp op) {
806 llvm::SmallSet<int, 4> set;
807 for (FuncOp func : {op.cond_function(), op.body_function()}) {
808 if (!func) continue;
809
810 for (auto arg_and_idx : llvm::enumerate(func.getArguments())) {
811 mlir::BlockArgument arg = arg_and_idx.value();
812 auto variant_ty =
813 getElementTypeOrSelf(arg.getType()).dyn_cast<TF::VariantType>();
814 if (!variant_ty) continue;
815
816 for (auto &op_operand : arg.getUses()) {
817 auto op = op_operand.getOwner();
818 if (llvm::isa<TF::TensorListGetItemOp>(op) ||
819 llvm::isa<TF::TensorListLengthOp>(op) ||
820 llvm::isa<TF::TensorListPushBackOp>(op) ||
821 llvm::isa<TF::TensorListReserveOp>(op) ||
822 llvm::isa<TF::TensorListSetItemOp>(op) ||
823 llvm::isa<TF::TensorListStackOp>(op) ||
824 llvm::isa<TF::TensorListResizeOp>(op)) {
825 set.insert(arg_and_idx.index());
826 break;
827 }
828 }
829 }
830 }
831 return set;
832 }
833
834 // Changes the function type of `cond_func` and `body_func` for the given While
835 // op.
UpdateFunctionTypes(TF::WhileOp op,llvm::SmallSet<int,4> tensor_list_args)836 LogicalResult UpdateFunctionTypes(TF::WhileOp op,
837 llvm::SmallSet<int, 4> tensor_list_args) {
838 int func_index = 0;
839 for (FuncOp func : {op.cond_function(), op.body_function()}) {
840 ++func_index;
841 if (!func) continue;
842
843 FunctionType func_type = func.getType();
844 int num_inputs = func_type.getNumInputs();
845 int num_results = func_type.getNumResults();
846
847 // For each argument type in function's arguments, change it to uranked
848 // tensor type if it's a variant type.
849 SmallVector<Type, 8> updated_argument_types;
850 updated_argument_types.reserve(num_inputs);
851 int i = 0;
852 for (auto it : llvm::zip(func_type.getInputs(), op.getOperands())) {
853 if (tensor_list_args.count(i)) {
854 updated_argument_types.push_back(
855 VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
856 } else {
857 updated_argument_types.push_back(std::get<0>(it));
858 }
859 ++i;
860 }
861
862 // Change all DT_VARIANT result types in function results to unranked tensor
863 // type with element type derived from the corresponding input operand. This
864 // is correct because while body's inputs and results have the same type.
865 SmallVector<Type, 8> updated_result_types;
866 updated_result_types.reserve(num_results);
867 i = 0;
868 for (auto it : llvm::zip(func_type.getResults(), op.getOperands())) {
869 // Only update body's results.
870 if (func_index != 1 && tensor_list_args.count(i)) {
871 updated_result_types.push_back(
872 VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
873 } else {
874 updated_result_types.push_back(std::get<0>(it));
875 }
876 ++i;
877 }
878
879 // Change `func`'s argument type to `unranked_argument_types`. If it
880 // return types contain a `DT_VARIANT`, change it to the unranked type
881 // derived from the corresponding argument.
882 func.setType(FunctionType::get(op.getContext(), updated_argument_types,
883 updated_result_types));
884
885 // Change the argument type for the first block.
886 llvm::for_each(func.getArguments(), [&](BlockArgument &arg) {
887 arg.setType(updated_argument_types[arg.getArgNumber()]);
888 });
889 }
890 return success();
891 }
892
893 struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
894 using OpConversionPattern::OpConversionPattern;
895
matchAndRewritemlir::__anon3c12621d0111::ConvertWhile896 LogicalResult matchAndRewrite(
897 TF::WhileOp op, ArrayRef<Value> operands,
898 ConversionPatternRewriter &rewriter) const override {
899 // Find all Tensor List arugments.
900 auto tensor_list_args = GetTensorListArgumentsFromWhileOp(op);
901
902 llvm::SmallVector<Type, 8> result_types;
903 result_types.reserve(op.getNumOperands());
904 // Change all DT_VARIANT result types to unranked tensor type.
905 int i = 0;
906 for (auto it : llvm::zip(op.getResultTypes(), operands)) {
907 if (tensor_list_args.count(i)) {
908 result_types.push_back(
909 VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
910 } else {
911 result_types.push_back(std::get<0>(it));
912 }
913 ++i;
914 }
915
916 // Create a new while op with new operands and updated result types.
917 auto converted = rewriter.create<TF::WhileOp>(op.getLoc(), result_types,
918 operands, op.getAttrs());
919 converted.removeAttr("T");
920 (void)UpdateFunctionTypes(converted, tensor_list_args);
921
922 rewriter.replaceOp(op, converted.getResults());
923 return success();
924 }
925 };
926
927 struct ConvertWhileRegion : public OpConversionPattern<TF::WhileRegionOp> {
928 using OpConversionPattern::OpConversionPattern;
929
matchAndRewritemlir::__anon3c12621d0111::ConvertWhileRegion930 LogicalResult matchAndRewrite(
931 TF::WhileRegionOp op, ArrayRef<Value> operands,
932 ConversionPatternRewriter &rewriter) const override {
933 llvm::SmallVector<Type, 8> result_types;
934 result_types.reserve(op.getNumOperands());
935 // Change all DT_VARIANT result types to unranked tensor type.
936 for (auto it : llvm::zip(op.getResultTypes(), operands))
937 result_types.push_back(
938 VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
939
940 // Create a new while op with new operands and updated result types.
941 auto converted = rewriter.create<TF::WhileRegionOp>(
942 op.getLoc(), result_types, operands, op.getAttrs());
943
944 // Inline the regions from the old while into the new one, and apply
945 // signature conversion to inlined region.
946 for (auto it : llvm::zip(op.getRegions(), converted.getRegions())) {
947 Region &old_region = *std::get<0>(it);
948 Region &new_region = *std::get<1>(it);
949
950 Block &entry = old_region.front();
951 // Build signature conversion for the region.
952 TypeConverter::SignatureConversion signature_conversion(operands.size());
953 for (auto it : llvm::zip(entry.getArguments(), operands)) {
954 BlockArgument arg = std::get<0>(it);
955 signature_conversion.addInputs(
956 arg.getArgNumber(),
957 VariantToUnrankedTensorType(arg.getType(), std::get<1>(it)));
958 }
959
960 rewriter.inlineRegionBefore(old_region, new_region, new_region.end());
961 rewriter.applySignatureConversion(&new_region, signature_conversion);
962 }
963
964 rewriter.replaceOp(op, converted.getResults());
965 return success();
966 }
967 };
968
969 #include "tensorflow/compiler/mlir/lite/transforms/generated_lower_static_tensor_list.inc"
970
runOnOperation()971 void LowerStaticTensorListPass::runOnOperation() {
972 auto *context = &getContext();
973
974 // TensorFlow operations that doesn't have operands and results of type
975 // variant are legal. Here, we don't distinguish between variants encoding
976 // TensorList or some other type as that information is not available here.
977 // Partial legalization is used below to still allow ops with variant types
978 // still.
979 auto is_legal = [](Operation *op) {
980 auto is_not_variant = [](Type ty) {
981 return !ty.cast<ShapedType>().getElementType().isa<TF::VariantType>();
982 };
983 return llvm::all_of(op->getOperandTypes(), is_not_variant) &&
984 llvm::all_of(op->getResultTypes(), is_not_variant);
985 };
986
987 ConversionTarget target(*context);
988 target.addDynamicallyLegalDialect<TF::TensorFlowDialect>(
989 llvm::Optional<ConversionTarget::DynamicLegalityCallbackFn>(is_legal));
990 target.addIllegalOp<TF::EmptyTensorListOp, TF::TensorListFromTensorOp,
991 TF::TensorListGetItemOp, TF::TensorListLengthOp,
992 TF::TensorListPushBackOp, TF::TensorListReserveOp,
993 TF::TensorListSetItemOp, TF::TensorListStackOp,
994 TF::TensorListResizeOp, TF::TensorListConcatV2Op>();
995 // TODO(hinsu): Use TFLite constant op for constants.
996 target.addLegalOp<ConstantOp>();
997 target.addLegalOp<FuncOp>();
998 target.addLegalOp<ReturnOp>();
999 target.addLegalOp<TFL::CustomOp>();
1000 // Register fused LSTM/RNN ops as legal.
1001 target.addLegalOp<TFL::LSTMOp>();
1002 target.addLegalOp<TFL::UnidirectionalSequenceLSTMOp>();
1003 target.addLegalOp<TFL::UnidirectionalSequenceRNNOp>();
1004 target.addLegalOp<TFL::BidirectionalSequenceLSTMOp>();
1005
1006 OwningRewritePatternList patterns;
1007 populateWithGenerated(context, patterns);
1008 patterns.insert<ConvertConst, ConvertEmptyTensorList, ConvertIdentity,
1009 ConvertTensorListGetItem, ConvertTensorListLength,
1010 ConvertTensorListPushBack, ConvertTensorListReserve,
1011 ConvertTensorListSetItem, ConvertTensorListStack,
1012 ConvertTensorListResize, ConvertWhile, ConvertWhileRegion>(
1013 context);
1014 if (failed(applyPartialConversion(getOperation(), target,
1015 std::move(patterns)))) {
1016 if (!allow_tensorlist_pass_through) {
1017 signalPassFailure();
1018 }
1019 }
1020 }
1021
1022 } // namespace
1023
1024 /// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
1025 /// pass.
1026 std::unique_ptr<OperationPass<ModuleOp>>
CreateLowerStaticTensorListPass()1027 TFL::CreateLowerStaticTensorListPass() {
1028 return std::make_unique<LowerStaticTensorListPass>();
1029 }
1030
1031 static PassRegistration<LowerStaticTensorListPass> pass(
1032 "tfl-lower-static-tensor-list",
1033 "Lower TensorList ops within TensorFlow Lite dialect");
1034
1035 } // namespace mlir
1036