/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass prepares for legalization to the TFLite dialect by
// converting TensorList operations in the TensorFlow dialect into operations
// that can be legalized to the TensorFlow Lite dialect with simple
// replacements. The newly created operations are in the TensorFlow dialect if
// the operation can be represented using a TensorFlow op. Otherwise, a
// TensorFlow Lite dialect op is used.
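//
// As a rough illustration (operand names and exact types here are made up for
// this comment, not taken from a real test), a typical rewrite replaces
//
//   %list = "tf.TensorListReserve"(%element_shape, %num_elements)
//       : (tensor<2xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<2x3xf32>>>
//
// with a zero-initialized raw tensor:
//
//   %shape = "tf.Concat"(%zero, %leading_dim, %element_shape)
//   %list = "tf.Fill"(%shape, %zero_f32) : (tensor<3xi32>, tensor<f32>)
//       -> tensor<?x2x3xf32>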

#include <climits>
#include <cstdint>
#include <utility>

#include "absl/container/inlined_vector.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/TypeRange.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/tensor_list.h"

#define DEBUG_TYPE "tf-tfl-legalization"

//===----------------------------------------------------------------------===//
// The actual LowerStaticTensorList Pass.
//
namespace mlir {
namespace {

/// Lower TensorList ops in functions for subsequent legalization.
struct LowerStaticTensorListPass
    : public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> {
  LowerStaticTensorListPass() = default;
  LowerStaticTensorListPass(const LowerStaticTensorListPass &) {}
  explicit LowerStaticTensorListPass(bool allow_tensorlist_pass_through) {
    this->allow_tensorlist_pass_through = allow_tensorlist_pass_through;
  }

  StringRef getArgument() const final {
    // This is the argument used to refer to the pass in
    // the textual format (on the command line, for example).
    return "tfl-lower-static-tensor-list";
  }
  StringRef getDescription() const final {
    // This is a brief description of the pass.
    return "Lower TensorList ops within TensorFlow Lite dialect";
  }

  void runOnOperation() override;

  Option<bool> allow_tensorlist_pass_through{
      *this, "allow-tensorlist-pass-through",
      llvm::cl::desc(
          "When set to true, if the tensorlist ops can't be properly "
          "legalized by this pass, the IR is left unchanged so that the "
          "tensorlist ops can pass through (default: false)"),
      llvm::cl::init(false)};
};

// Creates a 32-bit integer splat constant of the given shape and value.
Value CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
                          ArrayRef<int64_t> shape, int32_t val) {
  RankedTensorType type =
      RankedTensorType::get(shape, rewriter->getIntegerType(32));
  DenseElementsAttr attr =
      DenseElementsAttr::get(type, rewriter->getI32IntegerAttr(val));
  return rewriter->create<ConstantOp>(loc, type, attr);
}
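// For example (illustrative): CreateI32SplatConst(loc, &rewriter, {}, 0)
// materializes a scalar tensor<i32> constant holding 0, and
// CreateI32SplatConst(loc, &rewriter, {1}, -1) a tensor<1xi32> holding [-1].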

// Creates a 64-bit integer splat constant of the given shape and value.
Value CreateI64SplatConst(Location loc, PatternRewriter *rewriter,
                          ArrayRef<int64_t> shape, int64_t val) {
  RankedTensorType type =
      RankedTensorType::get(shape, rewriter->getIntegerType(64));
  DenseElementsAttr attr =
      DenseElementsAttr::get(type, rewriter->getI64IntegerAttr(val));
  return rewriter->create<ConstantOp>(loc, type, attr);
}

// Creates a 1-D 32-bit integer tensor of shape `shape_tensor`, with every
// element set to the scalar value `val`.
Value CreateI32SplatTensor(Location loc, PatternRewriter *rewriter,
                           Value shape_tensor, int32_t val) {
  Value scalar_val = CreateI32SplatConst(loc, rewriter, {}, val);
  return rewriter->create<TF::FillOp>(
      loc, RankedTensorType::get({-1}, rewriter->getIntegerType(32)),
      shape_tensor, scalar_val);
}

// Returns a new type by prepending the specified dimension to the shape of
// the given type if it is a ranked type.
Type PrependLeadingDimIfRanked(int64_t dim, Type type,
                               PatternRewriter *rewriter) {
  Type dtype = getElementTypeOrSelf(type);
  if (RankedTensorType ty = type.dyn_cast<RankedTensorType>()) {
    llvm::SmallVector<int64_t, 4> shape = {dim};
    shape.append(ty.getShape().begin(), ty.getShape().end());
    return RankedTensorType::get(shape, dtype);
  }
  return type;
}

Type GetTensorTypeForTensorList(Type element_type, TF::VariantType handle_dtype,
                                PatternRewriter *rewriter) {
  // If the variant type in the output handle has the item shape available, use
  // it to derive the output shape by prepending an unknown leading dimension.
  // Otherwise, the result type will be unranked.
  if (handle_dtype.getSubtypes().empty()) {
    return UnrankedTensorType::get(element_type);
  }
  return PrependLeadingDimIfRanked(-1, handle_dtype.getSubtypes()[0], rewriter);
}

// Gets the indices of tensorlist arguments whose size might get changed by the
// function.
llvm::SmallSet<int, 4> GetResizedTensorListIndexes(
    FuncOp func, const llvm::SmallSet<int, 4> &tensor_list_args) {
  // `indexes` stores the argument indices of tensorlists whose size may get
  // updated in the function.
  llvm::SmallSet<int, 4> indexes;
  for (BlockArgument &arg : func.getArguments()) {
    if (tensor_list_args.contains(arg.getArgNumber())) {
      for (const mlir::OpOperand &use : arg.getUses()) {
        mlir::Operation *op = use.getOwner();
        // Currently we only check if the tensorlist argument is consumed by
        // `TensorListPushBack` or `TensorListResize`, since those are the only
        // length-mutating ops supported in this pass.
        if (llvm::isa<TF::TensorListPushBackOp>(op) ||
            llvm::isa<TF::TensorListResizeOp>(op)) {
          indexes.insert(arg.getArgNumber());
        }
      }
    }
  }
  return indexes;
}

// Creates a slice of the tensorlist `input_list`, starting from
// [start_index, 0, ...0], with size [size, -1, ...-1].
//
// Requires that `start_index` and `size` are scalar tensors and
// `item_position_shape` is a 1-D tensor with only one element equal to the rank
// of an item in the tensorlist.
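//
// For example (illustrative): for a list of rank-2 items (so the backing
// tensor has rank 3), start_index = 2 and size = 3 produce roughly
// Slice(input_list, begin = [2, 0, 0], size = [3, -1, -1]).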
TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
                                       Value start_index, Value size,
                                       Value item_rank, Type result_type,
                                       PatternRewriter *rewriter) {
  // Create the start position of the slice. This is done by concatenating
  // `start_index` and `partial_start_position` together.
  IntegerType shape_dtype = rewriter->getIntegerType(32);
  RankedTensorType position_type = RankedTensorType::get({-1}, shape_dtype);
  Value partial_start_position =
      CreateI32SplatTensor(loc, rewriter, item_rank, 0);
  Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
  RankedTensorType vector_type = RankedTensorType::get({1}, shape_dtype);
  auto expanded_start_index = rewriter->create<TF::ExpandDimsOp>(
      loc, vector_type, start_index, scalar_zero);
  auto start_position = rewriter->create<TF::ConcatOp>(
      loc, position_type, scalar_zero,
      ArrayRef<Value>({expanded_start_index, partial_start_position}));

  // Create the slice size tensor. This is done by concatenating `size` and
  // `partial_size`.
  auto size_leading_dim =
      rewriter->create<TF::ExpandDimsOp>(loc, vector_type, size, scalar_zero);
  Value partial_size = CreateI32SplatTensor(loc, rewriter, item_rank, -1);
  auto slice_size = rewriter->create<TF::ConcatOp>(
      loc, position_type, scalar_zero,
      ArrayRef<Value>({size_leading_dim, partial_size}));

  return rewriter->create<TF::SliceOp>(loc, result_type, input_list,
                                       start_position, slice_size);
}

template <typename OpT>
class TensorListOpConverterBase : public OpConversionPattern<OpT> {
 public:
  explicit TensorListOpConverterBase<OpT>(MLIRContext *context,
                                          bool allow_tensorlist_pass_through)
      : OpConversionPattern<OpT>::OpConversionPattern(context),
        allow_tensorlist_pass_through_(allow_tensorlist_pass_through) {}

 protected:
  // This flag controls the behavior of error emitting during rewrite:
  // 1) If it's true, then patterns will only emit errors during debug or
  //    tracing mode.
  // 2) If it's false, then patterns will emit standard errors when there is a
  //    rewrite failure.
  bool allow_tensorlist_pass_through_;
};

// Converts a tf.Const containing a variant of type TensorList to a tensor of
// primitive element types. Each of the individual tensors in the list is
// converted to an ElementsAttr, and those are then packed together using a
// tf.Pack op.
struct ConvertConst : public OpConversionPattern<TF::ConstOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::ConstOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Verify that the opaque elements attribute contains a tensor of variant
    // type with scalar shape. The variant type should hold a TensorList.
    auto opaque_attr = op.value().dyn_cast<OpaqueElementsAttr>();
    if (!opaque_attr) return failure();
    tensorflow::Tensor tensor;
    if (!tensorflow::ConvertToTensor(opaque_attr, &tensor).ok())
      return failure();
    if (tensor.dtype() != tensorflow::DT_VARIANT) return failure();
    if (!tensorflow::TensorShapeUtils::IsScalar(tensor.shape()))
      return failure();

    const tensorflow::TensorList *list =
        tensor.scalar<tensorflow::Variant>()().get<tensorflow::TensorList>();
    if (!list) return failure();

    // Verify that the output type is variant and contains exactly one ranked
    // subtype.
    auto variant_ty =
        getElementTypeOrSelf(op.getType()).dyn_cast<TF::VariantType>();
    if (!variant_ty) return failure();
    ArrayRef<TensorType> subtypes = variant_ty.getSubtypes();
    if (subtypes.size() != 1) return failure();
    RankedTensorType list_element_ty =
        subtypes.front().dyn_cast<RankedTensorType>();
    if (!list_element_ty) return failure();

    // Extract tensor elements for the TensorList and construct the result type
    // based on the number of elements and the element shape.
    const std::vector<tensorflow::Tensor> &tensors = list->tensors();
    llvm::SmallVector<int64_t, 4> result_shape = {
        static_cast<int64_t>(tensors.size())};
    result_shape.append(list_element_ty.getShape().begin(),
                        list_element_ty.getShape().end());
    auto result_ty =
        RankedTensorType::get(result_shape, list_element_ty.getElementType());

    // If the list is empty, directly create the final result instead of
    // creating the tf.Pack op, since tf.Pack requires at least one operand.
    if (tensors.empty()) {
      tensorflow::Tensor tensor(list->element_dtype,
                                tensorflow::TensorShape(result_shape));
      auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
      if (!attr_or.ok()) return failure();
      rewriter.replaceOpWithNewOp<TF::ConstOp>(op, attr_or.ValueOrDie());
      return success();
    }

    // Extract the individual tensorlist elements and combine them using the
    // tf.Pack op.
    Location loc = op.getLoc();
    llvm::SmallVector<Value, 4> values;
    values.reserve(tensors.size());
    for (const tensorflow::Tensor &tensor : tensors) {
      auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
      if (!attr_or.ok()) return failure();

      auto value = rewriter.create<TF::ConstOp>(loc, attr_or.ValueOrDie());
      values.push_back(value);
    }
    rewriter.replaceOpWithNewOp<TF::PackOp>(
        op, result_ty, values, /*axis=*/rewriter.getI64IntegerAttr(0));
    return success();
  }
};

struct ConvertTensorListSetItem
    : public OpConversionPattern<TF::TensorListSetItemOp> {
  using OpConversionPattern::OpConversionPattern;

  // This function rewrites the original op into a series of slice and concat
  // ops that produce the same result. It first slices the first `$index` rows.
  // Then it expands the dimension of the `$item`, followed by another slice of
  // the remaining rows starting from `$index` + 1. Lastly it concatenates the
  // three parts together.
  // At a high level, it's doing something like:
  // def : Pat<(TF_TensorListSetItemOp $input, $index, $item),
  //           (Concat
  //             concat_dim = 0,
  //             (Slice $input, [0, 0, ...],
  //                    (Concat (ExpandDims $index, expand_dim = 0),
  //                            [-1, -1, ...])),
  //             (ExpandDims $item, expand_dim = 0),
  //             (Slice $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
  LogicalResult matchAndRewrite(
      TF::TensorListSetItemOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = operands[0];
    Value index = operands[1];
    Value item = operands[2];

    IntegerType shape_dtype = rewriter.getIntegerType(32);
    auto item_rank = rewriter.create<TF::RankOp>(
        loc, RankedTensorType::get({}, shape_dtype), item);
    Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);

    // Calculate `index` + 1, which is used to generate the start position for
    // the second slice op.
    auto suffix_start =
        rewriter.create<TF::AddOp>(loc, index.getType(), index,
                                   CreateI32SplatConst(loc, &rewriter, {}, 1));

    auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
        loc, RankedTensorType::get({1}, shape_dtype), item_rank, scalar_zero);
    // Create two slice ops.
    Type element_type = input.getType().cast<TensorType>().getElementType();
    UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type);
    Value scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1);
    TF::SliceOp slice1 =
        CreateSliceOpForTensorList(loc, /*input_list=*/input,
                                   /*start_index=*/scalar_zero,
                                   /*size=*/index,
                                   /*item_rank=*/item_position_shape,
                                   /*result_type=*/unranked_tensor, &rewriter);
    TF::SliceOp slice2 =
        CreateSliceOpForTensorList(loc, /*input_list=*/input,
                                   /*start_index=*/suffix_start,
                                   /*size=*/scalar_minus_one,
                                   /*item_rank=*/item_position_shape,
                                   /*result_type=*/unranked_tensor, &rewriter);

    // Expand the dimension of the item so that it will have the same rank as
    // the input.
    auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
        op.getLoc(), unranked_tensor, item, scalar_zero);

    // Concatenate the three parts together to generate the final result.
    rewriter.replaceOpWithNewOp<TF::ConcatOp>(
        op, input.getType(), scalar_zero,
        ArrayRef<Value>({slice1, expanded_item, slice2}));
    return success();
  }
};

// Rewrites an op of the template type, which initializes a TensorList, into a
// list of ops that generate an equivalent raw tensor. Derived classes are
// required to override the GetNumElements method.
template <typename OpT>
struct ConvertTensorListInitOp : public TensorListOpConverterBase<OpT> {
  using TensorListOpConverterBase<OpT>::TensorListOpConverterBase;
  using TensorListOpConverterBase<OpT>::allow_tensorlist_pass_through_;

  // Creates and returns a 1-D tensor with exactly one element equal to the
  // number of list elements to initialize the output tensor list with.
  virtual Value GetNumElements(OpT op, ArrayRef<Value> operands,
                               PatternRewriter *rewriter) const = 0;

  // Rewrites the original op into `tf.fill`. The result tensor shape is
  // [num_element, element_shape]. All the values in the result tensor will be
  // initialized to 0.
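  //
  // For example (illustrative): a TensorListReserve with num_elements = 5 and
  // element_shape = [2, 3] becomes, roughly, tf.Fill([5, 2, 3], 0).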
  LogicalResult matchAndRewrite(
      OpT op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Type dtype = op.element_dtype();
    if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
          dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
          dtype.isInteger(32) || dtype.isInteger(64))) {
      const char *error_info =
          "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
          "integer or 16-bit/32-bit/64-bit float type during TF Lite "
          "transformation pass";
      return allow_tensorlist_pass_through_
                 ? rewriter.notifyMatchFailure(op, error_info)
                 : op.emitOpError(error_info);
    }

    Value element_shape = operands[0];
    Type shape_dtype = getElementTypeOrSelf(element_shape.getType());
    // If the `element_shape` is a scalar, we try to acquire its shape by
    // looking at the first `TensorListSetItemOp` writing to this tensor list.
    // Here we assume that the element_shape won't be changed before calling
    // the first `TensorListSetItemOp`.
    if (auto shaped_type = element_shape.getType().dyn_cast<ShapedType>()) {
      if (shaped_type.hasRank() && shaped_type.getRank() == 0) {
        bool element_shape_acquired = false;
        auto uses = op.getResult().getUses();
        for (auto &use : llvm::make_early_inc_range(uses)) {
          if (TF::TensorListSetItemOp set_op =
                  llvm::dyn_cast<TF::TensorListSetItemOp>(use.getOwner())) {
            element_shape = rewriter.create<TF::ShapeOp>(
                op.getLoc(), RankedTensorType::get({-1}, shape_dtype),
                set_op.item());
            element_shape_acquired = true;
          } else if (TF::WhileOp while_op =
                         llvm::dyn_cast<TF::WhileOp>(use.getOwner())) {
            // The tensorlist is passed into a while loop; check inside the
            // body function.
            auto inside_uses = while_op.body_function()
                                   .getArgument(use.getOperandNumber())
                                   .getUses();
            for (auto &inside_use : llvm::make_early_inc_range(inside_uses)) {
              if (TF::TensorListSetItemOp set_op =
                      llvm::dyn_cast<TF::TensorListSetItemOp>(
                          inside_use.getOwner())) {
                if (auto shaped_type =
                        set_op.item().getType().dyn_cast<ShapedType>()) {
                  if (shaped_type.hasStaticShape()) {
                    RankedTensorType type = RankedTensorType::get(
                        {shaped_type.getRank()}, rewriter.getIntegerType(32));
                    SmallVector<Attribute, 4> shape_attr;
                    for (int64_t dim : shaped_type.getShape()) {
                      shape_attr.push_back(rewriter.getI32IntegerAttr(dim));
                    }
                    DenseElementsAttr attr =
                        DenseElementsAttr::get(type, shape_attr);
                    element_shape =
                        rewriter.create<ConstantOp>(op.getLoc(), type, attr);
                    element_shape_acquired = true;
                    break;
                  }
                }
              }
            }
          }
          if (element_shape_acquired) break;
        }
        if (!element_shape_acquired) {
          const char *error_info =
              "requires element_shape to be 1D tensor during TF Lite "
              "transformation pass";
          return allow_tensorlist_pass_through_
                     ? rewriter.notifyMatchFailure(op, error_info)
                     : op.emitOpError(error_info);
        }
      }
    }

    DenseIntElementsAttr dense_elem_attr;
    if (matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
      // Note: It's technically unsafe to rewrite
      //   TensorListReserve(num_element, element_shape)
      // to
      //   Fill(Concat(num_element, element_shape), 0)
      // because element_shape may contain -1 to represent an unknown
      // dimension.
      //
      // In real world use cases (e.g. Keras RNN), `element_shape` is usually
      // a constant, and the first dimension of `element_shape` is usually
      // the batch dimension. Currently TFLiteConverter always rewrites an
      // unknown batch dimension to 1, therefore we also rewrite an unknown
      // dimension in `element_shape` to 1 here.
      //
      // This workaround enables converting Keras RNN without specifying the
      // batch dimension. This isn't guaranteed to work, but it doesn't break
      // any non-broken cases either (since it's already broken if
      // `element_shape` contains -1).
      // TODO(b/142096690): Support dynamic element shape and remove the
      // workaround.
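      //
      // For example (illustrative): an `element_shape` of [-1, 4] is
      // rewritten to [1, 4] below; a -1 in any later dimension is kept as-is.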
      SmallVector<int32_t, 4> new_element_shape_values;

      auto int_values = dense_elem_attr.getIntValues();
      for (auto it = int_values.begin(); it != int_values.end(); ++it) {
        auto dim_value = (*it).getSExtValue();
        if (it == int_values.begin() && dim_value == -1) {
          dim_value = 1;
        }
        new_element_shape_values.push_back(dim_value);
      }

      auto attr = DenseIntElementsAttr::get(
          element_shape.getType().cast<ShapedType>(), new_element_shape_values);
      auto new_element_shape = rewriter.create<ConstantOp>(
          op.getLoc(), element_shape.getType(), attr);
      element_shape = new_element_shape;
    }

    int64_t result_rank = -1;  // -1 means unknown result rank.
    Type element_dtype = op.element_dtype();
    Type result_type = UnrankedTensorType::get(element_dtype);
    Value leading_dim = GetNumElements(op, operands, &rewriter);
    if (auto element_type =
            op.element_type().template dyn_cast<RankedTensorType>()) {
      result_rank = element_type.getRank() + 1;
      int64_t leading_dim_v = -1;
      ElementsAttr element_attr;
      if (matchPattern(leading_dim, m_Constant(&element_attr))) {
        leading_dim_v = element_attr.getValue<IntegerAttr>(0).getInt();
      }
      SmallVector<int64_t, 4> result_shape = {leading_dim_v};
      ArrayRef<int64_t> shape = element_type.getShape();
      result_shape.append(shape.begin(), shape.end());
      result_type = RankedTensorType::get(result_shape, element_dtype);
    }

    // Create a 1-D RankedTensorType for the result's shape. The number of
    // elements in it is equal to the rank of the result, if known. Otherwise,
    // the number of elements is unknown and represented with -1. In both
    // cases, we can specify the dimension using the rank of the result.
    Type shape_type = RankedTensorType::get({result_rank}, shape_dtype);

    Location loc = op.getLoc();
    // Prepend the number of elements to the element shape to get the shape of
    // the output tensor.
    Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
    auto list_shape = rewriter.create<TF::ConcatOp>(
        loc, shape_type, scalar_zero,
        ArrayRef<Value>({leading_dim, element_shape}));

    // Create a zero-initialized constant tensor that has the same type
    // as specified by element_dtype.
    RankedTensorType zero_type = RankedTensorType::get({}, element_dtype);
    Attribute zero_attr = rewriter.getZeroAttr(zero_type);
    auto zero = rewriter.create<ConstantOp>(loc, zero_type, zero_attr);

    rewriter.replaceOpWithNewOp<TF::FillOp>(op, result_type, list_shape, zero);
    return success();
  }
};

struct ConvertTensorListReserve
    : public ConvertTensorListInitOp<TF::TensorListReserveOp> {
  explicit ConvertTensorListReserve(MLIRContext *context,
                                    bool allow_tensorlist_pass_through)
      : ConvertTensorListInitOp(context, allow_tensorlist_pass_through) {}

  Value GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value> operands,
                       PatternRewriter *rewriter) const override {
    Value scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0);
    Type shape_dtype = getElementTypeOrSelf(op.element_shape().getType());
    Value num_elements = operands[1];
    IntegerAttr attr;
    if (matchPattern(num_elements, m_Constant(&attr))) {
      return CreateI32SplatConst(op.getLoc(), rewriter, {1}, attr.getInt());
    }
    if (auto const_op = num_elements.getDefiningOp<TF::ConstOp>()) {
      return CreateI32SplatConst(
          op->getLoc(), rewriter, {1},
          (*const_op.value().cast<DenseElementsAttr>().getIntValues().begin())
              .getSExtValue());
    }
    return rewriter->create<TF::ExpandDimsOp>(
        op.getLoc(), RankedTensorType::get({1}, shape_dtype), num_elements,
        scalar_zero);
  }
};

// Note that we ignore the second operand `max_num_elements` since we don't
// have any restrictions on the number of elements we can support. So this may
// behave differently from TensorFlow in case of errors.
struct ConvertEmptyTensorList
    : public ConvertTensorListInitOp<TF::EmptyTensorListOp> {
  explicit ConvertEmptyTensorList(MLIRContext *context,
                                  bool allow_tensorlist_pass_through)
      : ConvertTensorListInitOp(context, allow_tensorlist_pass_through) {}

  Value GetNumElements(TF::EmptyTensorListOp op, ArrayRef<Value> operands,
                       PatternRewriter *rewriter) const override {
    return CreateI32SplatConst(op.getLoc(), rewriter, {1}, 0);
  }
};

struct ConvertTensorListPushBack
    : public OpConversionPattern<TF::TensorListPushBackOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::TensorListPushBackOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Value input_handle = operands[0];
    Value item = operands[1];

    // Expand the shape of the item so that it will have the same rank as the
    // input tensor and is compatible with the Concat op.
    Type expanded_item_type =
        PrependLeadingDimIfRanked(1, item.getType(), &rewriter);
    Location loc = op.getLoc();
    Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
    auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
        loc, expanded_item_type, item, scalar_zero);

    Type elem_type = getElementTypeOrSelf(item);
    auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
                            .cast<TF::VariantType>();
    Type result_type =
        GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);

    // Concatenate the tensor stored in the input handle with the expanded
    // item to get a tensor equivalent to the TensorList generated by this op.
    rewriter.replaceOpWithNewOp<TF::ConcatOp>(
        op, result_type, scalar_zero,
        ArrayRef<Value>({input_handle, expanded_item}));
    return success();
  }
};

// Rewrites a `TensorListResize` op into a functional `If` op and several basic
// TF ops to match the op semantics of TensorFlow. Basically, it does:
// 1) If the requested size is smaller than or equal to the input tensorlist's
//    size, rewrite it to a Slice op so that only the first 'size' rows are
//    returned.
// 2) If the requested size is larger than the input tensorlist's size, create
//    an additional tensorlist with 'size - input_size' elements and append it
//    to the end of the input tensorlist.
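//
// A rough sketch of the IR emitted below (names abbreviated, types elided):
//
//   %input_size = "tf.TensorListLength"(%list)
//   %size_diff = "tf.Sub"(%size, %input_size)
//   %cond = "tf.Greater"(%size_diff, %zero)
//   %result = "tf.If"(%cond, %list, %shape, %size_diff, %size)
//       {then_branch = @cond_true, else_branch = @cond_false}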
struct ConvertTensorListResize
    : public OpConversionPattern<TF::TensorListResizeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::TensorListResizeOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Value input_handle = operands[0];
    Value size = operands[1];

    Location loc = op.getLoc();
    Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);

    // Compute the input tensorlist's length and store it in `input_size`.
    IntegerType shape_dtype = rewriter.getIntegerType(32);
    auto input_size = rewriter.create<TF::TensorListLengthOp>(
        loc, RankedTensorType::get({}, shape_dtype), op.getOperand(0));

    // Infer the result type of this op based on TF's shape inference result.
    Type elem_type = getElementTypeOrSelf(input_handle);
    auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
                            .cast<TF::VariantType>();
    Type result_type =
        GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);

    // Compute the difference of `size` and `input_size`, and store it in
    // `size_diff`, which is then consumed by `if_cond`.
    auto size_diff = rewriter.create<TF::SubOp>(
        loc, RankedTensorType::get({}, shape_dtype), size, input_size);
    auto if_cond = rewriter.create<TF::GreaterOp>(
        loc, RankedTensorType::get({}, rewriter.getI1Type()), size_diff,
        scalar_zero);

    // Build the argument/result types for the if branch functions.
    auto input_shape = rewriter.create<TF::ShapeOp>(
        loc, RankedTensorType::get({-1}, shape_dtype), input_handle);

    Type branch_args_type[] = {input_handle.getType(), input_shape.getType(),
                               size_diff.getType(), size.getType()};
    Type branch_result_type[] = {result_type};
    auto func_type = FunctionType::get(rewriter.getContext(), branch_args_type,
                                       branch_result_type);

    // Create the functions in a higher scope before restoring the insertion
    // point. Additionally, create the SymbolTable before further modifying the
    // module.
    auto original_point = rewriter.saveInsertionPoint();
    rewriter.setInsertionPointAfter(op->getParentOfType<FuncOp>());
    SymbolTable manager(op->getParentOfType<ModuleOp>());

    // Constructs `then_branch`, which is executed when `if_cond` evaluates to
    // true.
    auto then_branch_op = rewriter.create<FuncOp>(loc, "cond_true", func_type);
    CreateCondTrueBranch(op, shape_dtype, result_type, then_branch_op,
                         &rewriter);

    // Constructs `else_branch`, which is executed when `if_cond` evaluates to
    // false.
    auto else_branch_op = rewriter.create<FuncOp>(loc, "cond_false", func_type);
    CreateCondFalseBranch(loc, shape_dtype, result_type, else_branch_op,
                          &rewriter);

    // Inserts the two branch functions' names into the symbol table held by
    // the module. Using SymbolTable will ensure that the inserted symbol names
    // are unique.
    manager.insert(then_branch_op);
    manager.insert(else_branch_op);

    rewriter.restoreInsertionPoint(original_point);
    rewriter.replaceOpWithNewOp<TF::IfOp>(
        op, result_type, if_cond,
        /*input=*/
        ArrayRef<Value>({input_handle, input_shape, size_diff, size}),
        /*then_branch=*/rewriter.getSymbolRefAttr(then_branch_op),
        /*else_branch=*/rewriter.getSymbolRefAttr(else_branch_op),
        /*is_stateless=*/rewriter.getBoolAttr(true));
    return success();
  }

 private:
  // When the input tensorlist's size is smaller than the requested size, the
  // then branch is executed.
  // Create a new tensorlist of size 'size - input_size' and concat it
  // with the input tensorlist.
  void CreateCondTrueBranch(TF::TensorListResizeOp resize_op, Type shape_dtype,
                            Type result_type, FuncOp branch_func,
                            ConversionPatternRewriter *rewriter) const {
    auto guard = OpBuilder::InsertionGuard(*rewriter);
    Block *block =
        rewriter->createBlock(&branch_func.getBody(), branch_func.begin(),
                              branch_func.getType().getInputs());

    auto input_shape = block->getArgument(1);
    auto size_diff = block->getArgument(2);
    auto input = block->getArgument(0);

    Location loc = resize_op.getLoc();
    // Get the element shape by slicing from index 1 in the input shape.
    Value slice_size = CreateI32SplatConst(loc, rewriter, {1}, -1);
    Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
    Value slice_start = CreateI32SplatConst(loc, rewriter, {1}, 1);
    auto elem_shape = rewriter->create<TF::SliceOp>(
        loc, RankedTensorType::get({-1}, shape_dtype), input_shape, slice_start,
        slice_size);
    auto extended_part = rewriter->create<TF::TensorListReserveOp>(
        loc, resize_op.output_handle().getType(), elem_shape, size_diff);
    // `ConcatOp` expects non-variant-typed input. Insert a
    // `TensorListStackOp` here to convert the type from variant to
    // non-variant. Note that we are using the same `result_type` for both the
    // `TensorListStackOp` and `ConcatOp`, since the first dimension of the
    // shape specified by `result_type` is -1.
    auto stacked_extended_part = rewriter->create<TF::TensorListStackOp>(
        loc, result_type, extended_part,
        /*element_shape=*/CreateI32SplatConst(loc, rewriter, {}, -1),
        /*num_elements=*/rewriter->getI32IntegerAttr(-1));
    auto concat_op = rewriter->create<TF::ConcatOp>(
        loc, result_type, scalar_zero,
        ArrayRef<Value>({input, stacked_extended_part}));
    rewriter->create<ReturnOp>(loc, ArrayRef<Value>({concat_op}));
  }

  void CreateCondFalseBranch(Location loc, Type shape_dtype, Type result_type,
                             FuncOp branch_func,
                             ConversionPatternRewriter *rewriter) const {
    // When the input tensorlist's size is larger than or equal to the
    // requested size, the else branch is executed.
    // Slice the first 'size' rows from the input tensorlist.
    auto guard = OpBuilder::InsertionGuard(*rewriter);
    Block *block =
        rewriter->createBlock(&branch_func.getBody(), branch_func.begin(),
                              branch_func.getType().getInputs());

    Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
    Value vector_one = CreateI32SplatConst(loc, rewriter, {1}, 1);
    auto input = block->getArgument(0);
    auto size = block->getArgument(3);

    // Subtract `input_rank` by 1 to get the item's rank, which is used as
    // `partial_position_shape`.
    auto input_rank = rewriter->create<TF::RankOp>(
        loc, RankedTensorType::get({}, shape_dtype), input);
    auto partial_position_shape = rewriter->create<TF::SubOp>(
        loc, RankedTensorType::get({1}, shape_dtype), input_rank, vector_one);
    auto slice_op =
        CreateSliceOpForTensorList(loc, /*input_list=*/input,
                                   /*start_index=*/scalar_zero, /*size=*/size,
                                   /*item_rank=*/partial_position_shape,
                                   /*result_type=*/result_type, rewriter);
    rewriter->create<ReturnOp>(loc, ArrayRef<Value>({slice_op}));
  }
};

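// Converts `TensorListGetItem` into a `tf.Gather` op that reads row `index`
// from the raw tensor backing the list.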
struct ConvertTensorListGetItem
    : public OpConversionPattern<TF::TensorListGetItemOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::TensorListGetItemOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Value input = operands[0];
    Value index = operands[1];
    rewriter.replaceOpWithNewOp<TF::GatherOp>(op, op.getType(), input, index,
                                              rewriter.getBoolAttr(true));
    return success();
  }
};

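// Converts `TensorListLength` into ops that read the size of the list's
// leading dimension: a `tf.Shape` followed by a `tf.Gather` at index 0.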
struct ConvertTensorListLength
    : public OpConversionPattern<TF::TensorListLengthOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::TensorListLengthOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input_handle = operands[0];

    BoolAttr true_attr = rewriter.getBoolAttr(true);
    auto shape = rewriter.create<TF::ShapeOp>(loc, input_handle,
                                              /*use_32bit=*/true_attr);
    rewriter.replaceOpWithNewOp<TF::GatherOp>(
        op, op.getType(), shape, CreateI32SplatConst(loc, &rewriter, {}, 0),
        /*validate_indices=*/true_attr);
    return success();
  }
};

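// Converts `TensorListStack` into either a pass-through of the underlying raw
// tensor or a trivial `tf.Reshape` that only refines the result shape info.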
struct ConvertTensorListStack
    : public OpConversionPattern<TF::TensorListStackOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::TensorListStackOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = operands[0];
    Value element_shape = operands[1];

    // If the `element_shape` is a known constant (which is defined when
    // calling `tensor_list_stack`) and is also valid (not scalar), we rewrite
    // this op to a trivial Reshape op (that doesn't actually change the
    // input's shape) and also populate the shape info to the op result. The
    // shape of the tensorlist is inferred from `num_elements` and
    // `element_shape`.
    auto ranked_type = element_shape.getType().dyn_cast<RankedTensorType>();
    DenseIntElementsAttr dense_elem_attr;
    if ((ranked_type && ranked_type.getRank() == 0) ||
        !matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
      // If no constant is spotted, just forward the operand.
      rewriter.replaceOp(op, {input});
      return success();
    }

    RankedTensorType shape_type =
        RankedTensorType::get({-1}, rewriter.getIntegerType(32));
    auto new_shape = rewriter.create<TF::ShapeOp>(loc, shape_type, input);
    SmallVector<int64_t, 8> output_shape(/*Size=*/1, op.num_elements());
    for (const auto &dim : dense_elem_attr.getIntValues())
      output_shape.push_back(dim.getSExtValue());
    RankedTensorType result_type =
        RankedTensorType::get(output_shape, getElementTypeOrSelf(input));
    rewriter.replaceOpWithNewOp<TF::ReshapeOp>(op, result_type, input,
                                               new_shape);
    return success();
  }
};

// Converts `TensorListConcatV2` into Unpack and Concat. First we unpack
// the input tensorlist along the first dimension, which results in N (where N
// is the first dim's size) tensors (each with shape [element_shape]). Then
// we concatenate all those tensors along the first dimension.
// The pattern will be rejected if either `element_shape` is not constant, or
// the first dimension of `input` is not known.
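//
// For example (illustrative): a list backed by tensor<3x2x4xf32> with
// element_shape = [2, 4] is unpacked into three tensor<2x4xf32> values, which
// are then concatenated along dimension 0 into a tensor<?x4xf32>.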
struct ConvertTensorListConcatV2
    : public TensorListOpConverterBase<TF::TensorListConcatV2Op> {
  using TensorListOpConverterBase<
      TF::TensorListConcatV2Op>::TensorListOpConverterBase;
  using TensorListOpConverterBase<
      TF::TensorListConcatV2Op>::allow_tensorlist_pass_through_;

  LogicalResult matchAndRewrite(
      TF::TensorListConcatV2Op op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = operands[0];
    Value element_shape = operands[1];

    // Only match when `element_shape` is a constant.
    DenseIntElementsAttr dense_elem_attr;
    if (!matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
      const char *error_info = "requires element_shape to be a constant";
      return allow_tensorlist_pass_through_
                 ? rewriter.notifyMatchFailure(op, error_info)
                 : op.emitOpError(error_info);
    }
    llvm::SmallVector<int64_t, 4> output_shape;
    for (const auto &dim : dense_elem_attr.getIntValues()) {
      output_shape.push_back(dim.getSExtValue());
    }

    // First unpack the input tensor along the first dimension.
    Type input_element_type = getElementTypeOrSelf(input);
    int64_t num_unpacked = 0;
    if (auto type = input.getType().dyn_cast<RankedTensorType>()) {
      if (type.getDimSize(0) > 0) {
        num_unpacked = type.getDimSize(0);
      } else {
        const char *error_info =
            "requires the first dimension of the input tensor to have a "
            "size > 0";
        return allow_tensorlist_pass_through_
                   ? rewriter.notifyMatchFailure(op, error_info)
                   : op.emitOpError(error_info);
      }
    }
    llvm::SmallVector<Type, 1> unpack_output_type;
    unpack_output_type.insert(
        unpack_output_type.begin(), num_unpacked,
        RankedTensorType::get(output_shape, input_element_type));
    auto unpack = rewriter.create<TF::UnpackOp>(loc, unpack_output_type, input,
                                                /*axis=*/0);

    // Concatenate the unpacked tensors along the first dimension.
    // Since we're concatenating along the first dimension, change its dim
    // size to -1.
    output_shape[0] = -1;
    Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
    auto concat = rewriter.create<TF::ConcatOp>(
        loc, RankedTensorType::get(output_shape, input_element_type),
        scalar_zero, unpack->getResults());
    // `lengths` is only useful for computing gradients. For now we just
    // return a placeholder tensor.
    rewriter.replaceOp(
        op, {concat.getResult(), CreateI64SplatConst(loc, &rewriter, {0}, 0)});
    return success();
  }
};

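// Rewrites a `tf.Identity` op so that its result type matches the (possibly
// already type-converted) operand type.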
struct ConvertIdentity : public OpConversionPattern<TF::IdentityOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::IdentityOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Value input = operands[0];
    rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), operands,
                                                op->getAttrs());
    return success();
  }
};

// Returns an unranked tensor type with an element of the same type as `value`
// if `type` is a tensor of variant. Otherwise, returns `type` unmodified.
Type VariantToUnrankedTensorType(Type type, Value value) {
  TF::VariantType variant_ty =
      getElementTypeOrSelf(type).dyn_cast<TF::VariantType>();
  if (!variant_ty) {
    return type;
  }
  if (!variant_ty.getSubtypes().empty()) {
    // Short-circuit if the variant type has subtype info.
    return UnrankedTensorType::get(
        variant_ty.getSubtypes()[0].getElementType());
  }
  Type value_type = value.getType();
  Type element_type;
  variant_ty = value_type.dyn_cast<TF::VariantType>();
  if (variant_ty && !variant_ty.getSubtypes().empty()) {
    element_type = variant_ty.getSubtypes()[0].getElementType();
  } else {
    element_type = getElementTypeOrSelf(value_type);
  }
  return UnrankedTensorType::get(element_type);
}

// Returns true if we can deduce that the type is a tensorlist.
bool IsTensorListType(Type type, llvm::Optional<Value> value) {
  TF::VariantType variant_ty =
      getElementTypeOrSelf(type).dyn_cast<TF::VariantType>();
  if (!variant_ty) {
    return false;
  }
  // Check that there is only one subtype contained in the variant type. Note
  // that `subtypes.size() == 1` does not always mean the type is actually
  // a tensorlist. We probably need some form of data flow analysis.
  if (variant_ty.getSubtypes().size() == 1) {
    return true;
  }
  // If subtype info is not available, check if the value is used by any of
  // the following TensorList operations.
  if (!value.hasValue()) {
    return false;
  }
  for (const mlir::OpOperand &use : value.getValue().getUses()) {
    mlir::Operation *op = use.getOwner();
    if (llvm::isa<TF::TensorListGetItemOp>(op) ||
        llvm::isa<TF::TensorListLengthOp>(op) ||
        llvm::isa<TF::TensorListPushBackOp>(op) ||
        llvm::isa<TF::TensorListReserveOp>(op) ||
        llvm::isa<TF::TensorListSetItemOp>(op) ||
        llvm::isa<TF::TensorListStackOp>(op) ||
        llvm::isa<TF::TensorListResizeOp>(op)) {
      return true;
    }
  }
  return false;
}

// Returns a set of integers that correspond to the tensorlist arguments in
// the function.
llvm::SmallSet<int, 4> GetTensorListArgumentsIndex(FuncOp func) {
  llvm::SmallSet<int, 4> set;
  for (const auto &arg_and_idx : llvm::enumerate(func.getArguments())) {
    if (IsTensorListType(arg_and_idx.value().getType(), arg_and_idx.value())) {
      set.insert(arg_and_idx.index());
    }
  }
  return set;
}

// Returns a set of integers that correspond to the tensorlist results in the
// function.
llvm::SmallSet<int, 4> GetTensorListResultsIndex(FuncOp func) {
  llvm::SmallSet<int, 4> set;

  for (const auto &result_and_idx :
       llvm::enumerate(func.getType().getResults())) {
    if (IsTensorListType(result_and_idx.value(), llvm::None)) {
      set.insert(result_and_idx.index());
    }
  }
  return set;
}

// Updates the tensorlist types based on the input index. If the tensorlist's
// size isn't changed (which is indicated by `resized_tensor_list_index`), we
// will use the original operand's type; otherwise we update it with the
// unranked tensor type.
template <typename R>
void UpdateTensorListTypes(
    const llvm::SmallSet<int, 4> &tensor_list_index,
    const llvm::SmallSet<int, 4> &resized_tensor_list_index,
    ArrayRef<Type> types, R &&range, ArrayRef<Value> operands,
    llvm::SmallVectorImpl<Type> *updated_types) {
  int i = 0;
  for (const auto it : llvm::zip(types, range, operands)) {
    if (tensor_list_index.count(i)) {
      // Only change the tensorlist's type to unranked tensor if it has been
      // resized.
      if (resized_tensor_list_index.count(i)) {
        updated_types->push_back(
            VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
      } else {
        updated_types->push_back(std::get<2>(it).getType());
      }
    } else {
      updated_types->push_back(std::get<0>(it));
    }
    ++i;
  }
}

// Updates the tensorlist types to unranked tensor types based on the input
// index.
template <typename R>
void ChangeVariantToUnrankedTensorType(
    const llvm::SmallSet<int, 4> &tensor_list_index, ArrayRef<Type> types,
    R &&range, llvm::SmallVectorImpl<Type> *updated_types) {
  int i = 0;
  for (const auto it : llvm::zip(types, range)) {
    if (tensor_list_index.count(i)) {
      updated_types->push_back(
          VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
    } else {
      updated_types->push_back(std::get<0>(it));
    }
    ++i;
  }
}

// Updates the specified function's type and region signature.
void UpdateFunctionAndRegionType(ConversionPatternRewriter &rewriter,
                                 FuncOp func,
                                 llvm::ArrayRef<Type> updated_argument_types,
                                 llvm::ArrayRef<Type> updated_result_types) {
  // Change `func`'s argument types to `updated_argument_types`. If its
  // return types contain a `DT_VARIANT`, change them to the unranked type
  // derived from the corresponding argument.
  rewriter.updateRootInPlace(func, [&] {
    func.setType(FunctionType::get(func.getContext(), updated_argument_types,
                                   updated_result_types));
  });
  Region &entry = func.getRegion();
  TypeConverter::SignatureConversion signature_conversion(
      entry.getNumArguments());
  for (const BlockArgument &arg : entry.getArguments()) {
    signature_conversion.addInputs(arg.getArgNumber(),
                                   updated_argument_types[arg.getArgNumber()]);
  }
  rewriter.applySignatureConversion(&entry, signature_conversion);
}

// Changes the function types of `cond_func` and `body_func` for the given
// While op.
LogicalResult UpdateFunctionTypesForWhileOp(
    ConversionPatternRewriter &rewriter, TF::WhileOp op,
    ArrayRef<Value> operands, const llvm::SmallSet<int, 4> &tensor_list_args,
    const llvm::SmallSet<int, 4> &resized_tensor_lists) {
  int func_index = 0;
  for (FuncOp func : {op.cond_function(), op.body_function()}) {
    ++func_index;
    if (!func) continue;

    FunctionType func_type = func.getType();
    int num_inputs = func_type.getNumInputs();
    int num_results = func_type.getNumResults();

    // For each argument type in the function's arguments, change it to an
    // unranked tensor type if it's a variant type.
    SmallVector<Type, 8> updated_argument_types;
    updated_argument_types.reserve(num_inputs);
    UpdateTensorListTypes<mlir::OperandRange>(
        tensor_list_args, resized_tensor_lists, func_type.getInputs(),
        op.getOperands(), operands, &updated_argument_types);

    // Change all DT_VARIANT result types in the function results to unranked
    // tensor types with the element type derived from the corresponding input
    // operand. This is correct because the while body's inputs and results
    // have the same type.
    SmallVector<Type, 8> updated_result_types;
    updated_result_types.reserve(num_results);
    if (func_index == 1) {
      // We only update the result types for the body function.
      for (Type ty : func_type.getResults()) {
        updated_result_types.push_back(ty);
      }
    } else {
      UpdateTensorListTypes<mlir::OperandRange>(
          tensor_list_args, resized_tensor_lists, func_type.getResults(),
          op.getOperands(), operands, &updated_result_types);
    }

    UpdateFunctionAndRegionType(rewriter, func, updated_argument_types,
                                updated_result_types);
  }
  return success();
}

// Changes the function types of `then_function` and `else_function` for the
// given If op.
LogicalResult UpdateFunctionTypesForIfOp(
    ConversionPatternRewriter &rewriter, TF::IfOp op,
    llvm::ArrayRef<Value> operands,
    const llvm::SmallSet<int, 4> &tensor_list_args,
    const llvm::SmallSet<int, 4> &resized_tensor_lists,
    llvm::ArrayRef<Type> updated_result_types) {
  for (FuncOp func : {op.else_function(), op.then_function()}) {
    if (!func) continue;

    FunctionType func_type = func.getType();
    int num_inputs = func_type.getNumInputs();

    // Update the argument types of the function. If an argument is a
    // tensorlist and is not resized inside the function, we will use the
    // corresponding operand's type; otherwise we change its type to an
    // unranked tensor type.
    SmallVector<Type, 8> updated_argument_types;
    updated_argument_types.reserve(num_inputs);
    UpdateTensorListTypes<mlir::OperandRange>(
        tensor_list_args, resized_tensor_lists, func_type.getInputs(),
        op.getOperands().drop_front(), operands.drop_front(),
        &updated_argument_types);

    UpdateFunctionAndRegionType(rewriter, func, updated_argument_types,
                                updated_result_types);
  }
  return success();
}

// Returns a `llvm::DenseMap` which maps from the index of a tensorlist in the
// results to the index of the same tensorlist in the arguments. For an `If`
// op's branch functions, the results and arguments are not usually matched
// 1-1. This will let us know which tensorlist result maps to which tensorlist
// in the arguments. Once we know this info, it will help us decide the types
// of the result tensorlists based on the operands of the `If` op.
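//
// For example (illustrative): if a branch function ends with
//   %new_list = "tf.TensorListSetItem"(%arg2, %index, %item)
//   return %0, %new_list
// then result index 1 maps to argument index 2, i.e. the map contains 1 -> 2.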
llvm::DenseMap<int, int> MapTensorListResultToArgument(FuncOp func) {
  // `map_fn` will trace upwards along the use-def chain of the ssa value. It
  // starts from the last ssa value (returned by the function), and checks its
  // parent op iteratively. If the root ssa value appears in the function's
  // argument list, it will return the index of the corresponding argument,
  // otherwise it will return -1.
  auto map_fn = [](Value value) -> int {
    Value parent = value;
    while (true) {
      if (auto identity = parent.getDefiningOp<TF::IdentityOp>()) {
        parent = identity.input();
      } else if (auto set_item =
                     parent.getDefiningOp<TF::TensorListSetItemOp>()) {
        parent = set_item.input_handle();
      } else {
        break;
      }
    }
    if (auto block_arg = parent.dyn_cast<mlir::BlockArgument>()) {
      return block_arg.getArgNumber();
    }
    // Returns -1 if we can't find which argument this result maps to.
    return -1;
  };

  llvm::SmallVector<Value, 4> returns;
  for (auto res : func.getBody().back().getTerminator()->getOperands()) {
    returns.push_back(res);
  }
  llvm::DenseMap<int, int> result;
  for (const auto &result_and_idx : llvm::enumerate(returns)) {
    if (IsTensorListType(result_and_idx.value().getType(),
                         result_and_idx.value())) {
      int arg_idx = map_fn(result_and_idx.value());
      if (arg_idx != -1) {
        result.insert({result_and_idx.index(), arg_idx});
      }
    }
  }
  return result;
}

// Updates the tensorlist result types for the `If` op. If a tensorlist result
// maps to a specific argument (indicated by `tensor_list_map`), and that
// tensorlist argument's size isn't changed (indicated by
// `resized_tensor_list_index`), we will update this tensorlist's result type
// to the corresponding operand's type. In all other cases we change the
// tensorlist's type to an unranked tensor type.
template <typename R>
void UpdateTensorListResultTypesForIf(
    const llvm::SmallSet<int, 4> &tensor_list_index,
    const llvm::SmallSet<int, 4> &resized_tensor_list_index,
    const llvm::DenseMap<int, int> &tensor_list_map, ArrayRef<Type> types,
    R &&range, ArrayRef<Value> operands,
    llvm::SmallVectorImpl<Type> *updated_types) {
  int i = 0;
  for (const auto it : llvm::zip(types, range)) {
    if (!tensor_list_index.count(i)) {
      updated_types->push_back(std::get<0>(it));
      ++i;
      continue;
    }
    auto iter = tensor_list_map.find(i);
    if (iter != tensor_list_map.end()) {
      int arg_idx = iter->second;
      if (!resized_tensor_list_index.count(arg_idx)) {
        // If the mapped tensorlist argument's size isn't changed, we will
        // use the corresponding `operand` type.
        updated_types->push_back(operands[arg_idx].getType());
        ++i;
        continue;
      }
    }
    updated_types->push_back(
        VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
    ++i;
  }
}
1257
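// Rewrites a `tf.If` op whose operands or results carry tensorlists. A
// sketch of the intended effect (hypothetical IR, for illustration only):
//
//   %r = "tf.If"(%cond, %list) :
//       (tensor<i1>, tensor<!tf.variant<tensor<f32>>>) ->
//       tensor<!tf.variant<tensor<f32>>>
//
// becomes, when the tensorlist result maps to the (un-resized) tensorlist
// operand that has already been converted to, say, tensor<3x4xf32>:
//
//   %r = "tf.If"(%cond, %list) :
//       (tensor<i1>, tensor<3x4xf32>) -> tensor<3x4xf32>
//
// The branch functions' signatures are updated to match.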
struct ConvertIf : public OpConversionPattern<TF::IfOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::IfOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Find all tensorlist arguments.
    auto tensor_list_args = GetTensorListArgumentsIndex(op.else_function());
    auto tensor_list_results = GetTensorListResultsIndex(op.else_function());
    auto tensor_list_map = MapTensorListResultToArgument(op.else_function());
    llvm::SmallSet<int, 4> resized_tensor_lists =
        GetResizedTensorListIndexes(op.else_function(), tensor_list_args);

    llvm::SmallVector<Type, 8> result_types;
    result_types.reserve(op.getNumResults());
    llvm::SmallVector<Type, 4> op_result_types;
    for (Type ty : op.getResultTypes()) {
      op_result_types.push_back(ty);
    }

    UpdateTensorListResultTypesForIf<mlir::ResultRange>(
        tensor_list_results, resized_tensor_lists, tensor_list_map,
        op_result_types, op->getResults(), operands.drop_front(),
        &result_types);

    // Create a new `tf.If` op with the new operands and updated result
    // types.
    auto converted = rewriter.create<TF::IfOp>(op.getLoc(), result_types,
                                               operands, op->getAttrs());
    converted->removeAttr("T");
    (void)UpdateFunctionTypesForIfOp(rewriter, converted, operands,
                                     tensor_list_args, resized_tensor_lists,
                                     result_types);
    rewriter.replaceOp(op, converted.getResults());
    return success();
  }
};

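// Rewrites a `tf.While` op: each tensorlist result type is replaced with the
// matching operand's type when the body does not resize the list, and with
// an unranked tensor type otherwise; the cond/body function signatures are
// then updated to match.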
struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::WhileOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Find all tensorlist arguments.
    auto tensor_list_args = GetTensorListArgumentsIndex(op.body_function());

    llvm::SmallVector<Type, 8> result_types;
    result_types.reserve(op.getNumOperands());
    // Collect the current result types; tensorlist (DT_VARIANT) result types
    // are rewritten below.
    llvm::SmallVector<Type, 4> op_result_types;
    for (Type ty : op.getResultTypes()) {
      op_result_types.push_back(ty);
    }

    llvm::SmallSet<int, 4> resized_tensor_lists =
        GetResizedTensorListIndexes(op.body_function(), tensor_list_args);
    UpdateTensorListTypes<mlir::OperandRange>(
        tensor_list_args, resized_tensor_lists, op_result_types,
        op.getOperands(), operands, &result_types);

    // Create a new while op with the new operands and updated result types.
    auto converted = rewriter.create<TF::WhileOp>(op.getLoc(), result_types,
                                                  operands, op->getAttrs());
    converted->removeAttr("T");
    (void)UpdateFunctionTypesForWhileOp(rewriter, converted, operands,
                                        tensor_list_args,
                                        resized_tensor_lists);

    rewriter.replaceOp(op, converted.getResults());
    return success();
  }
};

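// Rewrites a `tf.WhileRegion` op: every DT_VARIANT result type is changed to
// an unranked tensor type, the regions are inlined into the new op, and the
// region block arguments are remapped with a signature conversion.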
struct ConvertWhileRegion : public OpConversionPattern<TF::WhileRegionOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::WhileRegionOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Type, 8> result_types;
    result_types.reserve(op.getNumOperands());
    // Change all DT_VARIANT result types to the unranked tensor type.
    for (auto it : llvm::zip(op.getResultTypes(), operands))
      result_types.push_back(
          VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));

    // Create a new while op with the new operands and updated result types.
    auto converted = rewriter.create<TF::WhileRegionOp>(
        op.getLoc(), result_types, operands, op->getAttrs());

    // Inline the regions from the old while op into the new one, and apply
    // the signature conversion to each inlined region.
    for (auto it : llvm::zip(op.getRegions(), converted.getRegions())) {
      Region &old_region = *std::get<0>(it);
      Region &new_region = *std::get<1>(it);

      Block &entry = old_region.front();
      // Build the signature conversion for the region.
      TypeConverter::SignatureConversion signature_conversion(operands.size());
      for (auto arg_and_operand : llvm::zip(entry.getArguments(), operands)) {
        BlockArgument arg = std::get<0>(arg_and_operand);
        signature_conversion.addInputs(
            arg.getArgNumber(),
            VariantToUnrankedTensorType(arg.getType(),
                                        std::get<1>(arg_and_operand)));
      }

      rewriter.inlineRegionBefore(old_region, new_region, new_region.end());
      rewriter.applySignatureConversion(&new_region, signature_conversion);
    }

    rewriter.replaceOp(op, converted.getResults());
    return success();
  }
};

#include "tensorflow/compiler/mlir/lite/transforms/generated_lower_static_tensor_list.inc"

void LowerStaticTensorListPass::runOnOperation() {
  auto *context = &getContext();

  // TensorFlow operations that don't have any variant operands or results
  // are legal. Here, we don't distinguish between variants encoding a
  // TensorList and variants encoding some other type, as that information is
  // not available here. Partial conversion is used below so that ops with
  // variant types can still be allowed.
  auto is_legal = [](Operation *op) {
    auto is_not_variant = [](Type ty) {
      return !ty.cast<ShapedType>().getElementType().isa<TF::VariantType>();
    };
    return llvm::all_of(op->getOperandTypes(), is_not_variant) &&
           llvm::all_of(op->getResultTypes(), is_not_variant);
  };
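  // For example (illustration only): an op whose operand and result types
  // are all tensor<...xf32> is legal, while an op with any
  // tensor<!tf.variant<...>> operand or result is illegal and must be
  // handled by one of the patterns below.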

  ConversionTarget target(*context);
  target.addDynamicallyLegalDialect<TF::TensorFlowDialect>(is_legal);
  target.addIllegalOp<TF::EmptyTensorListOp, TF::TensorListFromTensorOp,
                      TF::TensorListGetItemOp, TF::TensorListLengthOp,
                      TF::TensorListPushBackOp, TF::TensorListReserveOp,
                      TF::TensorListSetItemOp, TF::TensorListStackOp,
                      TF::TensorListResizeOp, TF::TensorListConcatV2Op>();
  // TODO(hinsu): Use TFLite constant op for constants.
  target.addLegalOp<ConstantOp>();
  target.addLegalOp<FuncOp>();
  target.addLegalOp<ReturnOp>();
  target.addLegalOp<TFL::CustomOp>();
  // Register fused LSTM/RNN ops as legal.
  target.addLegalOp<TFL::LSTMOp>();
  target.addLegalOp<TFL::UnidirectionalSequenceLSTMOp>();
  target.addLegalOp<TFL::UnidirectionalSequenceRNNOp>();
  target.addLegalOp<TFL::BidirectionalSequenceLSTMOp>();

  OwningRewritePatternList patterns(&getContext());
  populateWithGenerated(patterns);
  patterns.insert<ConvertConst, ConvertIdentity, ConvertTensorListGetItem,
                  ConvertTensorListLength, ConvertTensorListPushBack,
                  ConvertTensorListSetItem, ConvertTensorListStack,
                  ConvertTensorListResize, ConvertWhile, ConvertWhileRegion,
                  ConvertIf>(context);
  patterns.insert<ConvertEmptyTensorList, ConvertTensorListReserve,
                  ConvertTensorListConcatV2>(context,
                                             allow_tensorlist_pass_through);
  ModuleOp module = getOperation();
  if (!allow_tensorlist_pass_through) {
    if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
      module.emitError(
          "Failed to lower tensor list ops. Please consider using Select TF "
          "ops and disabling the `_experimental_lower_tensor_list_ops` flag "
          "in the TFLite converter object. For example, "
          "converter.target_spec.supported_ops = "
          "[tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]\n "
          "converter._experimental_lower_tensor_list_ops = False");
      signalPassFailure();
    }
  } else {
    // When `allow_tensorlist_pass_through` is true, a legalization failure
    // should not leak diagnostics outside this pass. Hence we use a
    // `StatusScopedDiagnosticHandler` here to capture any diagnostics
    // generated within this pass.
    StatusScopedDiagnosticHandler handler(context);
    if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
      auto _ = handler.ConsumeStatus();
    }
  }
}

}  // namespace

/// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
/// pass.
std::unique_ptr<OperationPass<ModuleOp>> TFL::CreateLowerStaticTensorListPass(
    bool allow_tensorlist_pass_through) {
  return std::make_unique<LowerStaticTensorListPass>(
      allow_tensorlist_pass_through);
}

static PassRegistration<LowerStaticTensorListPass> pass;

}  // namespace mlir