/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <numeric>
#include <string>
#include <tuple>
#include <type_traits>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/DialectImplementation.h"  // from @llvm-project
#include "mlir/IR/Identifier.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
#include "mlir/Parser.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"

namespace mlir {
namespace TF {

namespace {
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc"
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
}  // namespace

//===----------------------------------------------------------------------===//
// NotEqualOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(NotEqualOp op) {
  // If we allow inputs to have incompatible type, then nothing to do.
  if (!op.incompatible_shape_error()) return success();

  // Otherwise, check inputs are broadcastable.
  return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
      op.getOperation());
}

void NotEqualOp::build(OpBuilder &builder, OperationState &result, Value x,
                       Value y, BoolAttr incompatible_shape_error) {
  auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y,
                                          incompatible_shape_error);
  return build(builder, result, result_type, x, y, incompatible_shape_error);
}

//===----------------------------------------------------------------------===//
// OneHotOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(OneHotOp op) {
  int64_t axis = op.axis();

  auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
  if (indices_ty &&
      !(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) {
    return op.emitOpError()
           << "expected axis (" << axis << ") to be -1 or between [0, "
           << indices_ty.getShape().size() << "]";
  }

  if (axis < -1) {
    return op.emitOpError() << "expected axis (" << axis
                            << ") to be -1 or between [0, rank(indices()))";
  }

  if (!IsOfRankOrUnranked(op.depth(), 0)) {
    return op.emitOpError() << "requires depth to be a scalar";
  }
  if (!IsOfRankOrUnranked(op.on_value(), 0)) {
    return op.emitOpError() << "requires on_value to be a scalar";
  }
  if (!IsOfRankOrUnranked(op.off_value(), 0)) {
    return op.emitOpError() << "requires off_value to be a scalar";
  }

  DenseIntElementsAttr depth_attr;
  if (matchPattern(op.depth(), m_Constant(&depth_attr))) {
    if (depth_attr.getType().getRank() != 0)
      return op.emitOpError() << "requires depth to be a scalar";
    int64_t depth = depth_attr.getValue<APInt>({}).getSExtValue();
    if (depth < 0) {
      return op.emitOpError() << "depth must be non-negative, got: " << depth;
    }
  }

  return success();
}

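// Infers the output type of OneHot from its operands. As an illustrative
// example: with indices of type tensor<3xi32>, a constant depth of 5, scalar
// on/off values of element type f32, and axis = -1, the inferred type is
// tensor<3x5xf32>.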
static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value,
                                    Value off_value, IntegerAttr axis) {
  int64_t axis_val = axis.getInt();
  Type element_ty = on_value.getType().cast<TensorType>().getElementType();
  auto unranked_ty = UnrankedTensorType::get(element_ty);
  if (axis_val < -1) return unranked_ty;

  auto indices_ty = indices.getType().dyn_cast<RankedTensorType>();
  if (!indices_ty) return unranked_ty;

  auto shape = llvm::to_vector<2>(indices_ty.getShape());
  if (axis_val == -1) axis_val = shape.size();

  int64_t depth_val = ShapedType::kDynamicSize;
  DenseIntElementsAttr depth_attr;
  if (matchPattern(depth, m_Constant(&depth_attr)) &&
      depth_attr.getNumElements() == 1)
    depth_val = (*depth_attr.begin()).getSExtValue();
  shape.insert(shape.begin() + axis_val, depth_val);
  return RankedTensorType::get(shape, element_ty);
}

void OneHotOp::build(OpBuilder &builder, OperationState &result, Value indices,
                     Value depth, Value on_value, Value off_value,
                     IntegerAttr axis) {
  build(builder, result,
        InferOneHotOpType(indices, depth, on_value, off_value, axis), indices,
        depth, on_value, off_value, axis);
}

//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(PackOp op) {
  // TODO(hinsu): Convert variadic length attributes to derived attributes.
  Operation::operand_range values = op.values();

  if (failed(VerifyTypesCompatibility(values,
                                      /*mask_one_dim=*/false,
                                      op.getOperation()))) {
    return failure();
  }

  int64_t inputs_rank = -1;
  for (Value value : values) {
    if (auto ty = value.getType().dyn_cast<RankedTensorType>()) {
      // Exit early as input types are verified to be compatible so all ranked
      // tensors have the same rank.
      inputs_rank = ty.getRank();
      break;
    }
  }
  if (inputs_rank == -1) return success();

  // The values can be packed along any of the dimensions between 0 and the
  // inputs' rank, inclusive. Also, as negative axis values wrap around, the
  // axis value range is [-(R+1), R+1).
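  // For example, packing rank-2 tensors allows any axis value in [-3, 3).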
  int64_t range_begin = -inputs_rank - 1;  // Inclusive
  int64_t range_end = inputs_rank + 1;     // Exclusive
  int64_t axis = op.axis();
  if (axis < range_begin || axis >= range_end) {
    return op.emitError() << "attribute 'axis' should be within range ["
                          << range_begin << ", " << range_end
                          << "); actual value: " << axis;
  }

  return success();
}

OpFoldResult PackOp::fold(ArrayRef<Attribute> operands) {
  // Fold pack operation if it computes the input tensor shape:
  //
  //   %shape = tf.Shape(%arg)                    // [? x ...]
  //   %dim0  = tf.StridedSlice(%shape, 0, 1, 1)  // get unknown dim0 value
  //   %pack  = tf.Pack(dim0, ...) { axis = 0 }   // [? x ...]
  //
  // Where `...` are some statically known dimensions. In this case %pack can
  // be replaced with %shape. This is a common pattern in models with a dynamic
  // batch size.

  // Pack operation should pack at least two values.
  if (values().size() < 2) return {};

  // Dimensions packed along axis = 0 (pack scalars into vector).
  if (axis() != 0) return {};

  // First packed value is defined by a strided slice operation.
  auto slice_op = dyn_cast_or_null<StridedSliceOp>(values()[0].getDefiningOp());
  if (!slice_op) return {};

  // Input to the slice op is defined by shape operation.
  auto shape_op = dyn_cast_or_null<ShapeOp>(slice_op.input().getDefiningOp());
  if (!shape_op) return {};

  // Input tensor, whose shape is reconstructed by the pack operation.
  Value tensor = shape_op.input();

  // All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing
  // scalar value from input vector).
  if (slice_op.begin_mask() != 0 || slice_op.ellipsis_mask() != 0 ||
      slice_op.end_mask() != 0 || slice_op.new_axis_mask() != 0 ||
      slice_op.shrink_axis_mask() != 1)
    return {};

  // Returns a value if the `value` is defined by a ConstOp with a single
  // integer element in it and has an expected rank.
  auto get_const_int = [](Value value, int expected_rank) -> Optional<int64_t> {
    auto const_op = dyn_cast_or_null<ConstOp>(value.getDefiningOp());
    if (!const_op) return None;

    auto value_attr = const_op.value().dyn_cast<DenseIntElementsAttr>();
    if (!value_attr || value_attr.getNumElements() != 1) return None;

    auto value_ty = value_attr.getType();
    if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None;

    auto splat = value_attr.getSplatValue<IntegerAttr>();
    return splat.getValue().getSExtValue();
  };

  // All other packed values are scalar constants.
  SmallVector<int64_t, 4> packed_dims;
  packed_dims.reserve(values().size() - 1);
  for (Value operand : llvm::drop_begin(values(), 1)) {
    if (auto dim = get_const_int(operand, /*expected_rank=*/0)) {
      packed_dims.push_back(*dim);
    } else {
      return {};
    }
  }

  // Slice exactly the first shape dimension:
  //   begin = [0], end = [1], strides = [1]
  auto begin = get_const_int(slice_op.begin(), /*expected_rank=*/1);
  auto end = get_const_int(slice_op.end(), /*expected_rank=*/1);
  auto strides = get_const_int(slice_op.strides(), /*expected_rank=*/1);
  if (!begin.hasValue() || !end.hasValue() || !strides.hasValue() ||
      *begin != 0 || *end != 1 || *strides != 1)
    return {};

  // First tensor dimension is dynamic.
  auto arg_ty = tensor.getType().dyn_cast<ShapedType>();
  if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 ||
      !arg_ty.isDynamicDim(0))
    return {};

  // Argument tensor rank is equal to the number of packed dimensions.
  if (arg_ty.getRank() != values().size()) return {};

  // All other dimensions are statically known and equal to packed dims.
  auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1);
  if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin()))
    return {};

  // Replace %pack with %shape.
  return slice_op.input();
}

// Convert Pack to Reshape when there is only one operand to be packed.
// For example,
//
//   %0 = tf.Pack(%input) {axis = 0} // %input : tensor<2x3xf32>
//
// can be canonicalized to
//
//   %shape = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi64>}
//   %0 = tf.Reshape(%input, %shape)
struct ConvertPackToReshape : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp pack_op,
                                PatternRewriter &rewriter) const override {
    // Check if there is only one operand to be packed.
    if (pack_op.N() != 1) {
      return failure();
    }

    // Check if input and output are static.
    auto input_ty = pack_op.getOperand(0).getType().cast<ShapedType>();
    auto output_ty = pack_op.output().getType().cast<ShapedType>();
    if (!input_ty.hasStaticShape() || !output_ty.hasStaticShape()) {
      return failure();
    }

    // Create constant shape for reshape.
    auto type =
        RankedTensorType::get(output_ty.getRank(), rewriter.getIntegerType(64));
    auto shape_attr = DenseIntElementsAttr::get(type, output_ty.getShape());
    auto shape = rewriter.create<ConstOp>(pack_op.getLoc(), shape_attr);

    // TODO(b/173622615): Remove after fixed.
    ReplaceTfOpWithNewOp<ReshapeOp>(rewriter, pack_op, output_ty,
                                    pack_op.getOperand(0), shape);
    return success();
  }
};

void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  results.insert<ConvertPackToReshape>(context);
}

//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//

LogicalResult PadOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
  // Paddings must be defined by a constant operation.
  auto paddings_op = dyn_cast_or_null<TF::ConstOp>(paddings().getDefiningOp());
  if (!paddings_op) return failure();

  auto paddings_value = paddings_op.value().dyn_cast<DenseElementsAttr>();
  if (!paddings_value ||
      paddings_value.getNumElements() != permutation.size() * 2)
    return failure();

  SmallVector<int32_t, 8> shuffled_paddings(paddings_value.getNumElements());
  for (auto index_pair : llvm::enumerate(paddings_value.getIntValues())) {
    size_t outer_idx = index_pair.index() / 2;
    size_t inner_idx = index_pair.index() % 2;

    shuffled_paddings[permutation[outer_idx] * 2 + inner_idx] =
        index_pair.value().getSExtValue();
  }

  // Add a constant operation with the new paddings.
  OpBuilder builder(getOperation());
  auto type = mlir::RankedTensorType::get(paddings_value.getType().getShape(),
                                          builder.getIntegerType(32));
  auto values = mlir::DenseIntElementsAttr::get(type, shuffled_paddings);
  auto shuffled_paddings_op = builder.create<TF::ConstOp>(getLoc(), values);

  // Use new paddings.
  setOperand(1, shuffled_paddings_op);

  // Change the result type.
  getResult().setType(ShuffleRankedTensorType(getResult().getType(),
                                              ReversePermutation(permutation)));

  return success();
}

//===----------------------------------------------------------------------===//
// ParseExampleV2Op
//===----------------------------------------------------------------------===//

static LogicalResult Verify(ParseExampleV2Op op) {
  // NOTE(mrry): This validates properties of an op that would previously be
  // validated by the TensorFlow OpDef type checker. In addition to these
  // checks, the shape inference function for ParseExampleV2 validates the
  // consistency of the argument and result types.

  // Validate dense variadic input and output lengths.
  // NOTE(mrry): The Tdense attr is derived from dense_defaults, so we
  // do not need to validate dense_defaults.
  auto dense_types_count =
      std::distance(op.Tdense().begin(), op.Tdense().end());
  auto dense_values_count =
      std::distance(op.dense_values().begin(), op.dense_values().end());
  if (dense_values_count != dense_types_count) {
    return op.emitError() << "output 'dense_values' should have same length "
                          << "as attribute 'Tdense'";
  }

  // Validate sparse variadic output lengths.
  // NOTE(mrry): The sparse_types attr is derived from sparse_values, so we
  // do not need to validate sparse_values.
  auto sparse_types_count =
      std::distance(op.sparse_types().begin(), op.sparse_types().end());
  if (op.num_sparse() != sparse_types_count) {
    return op.emitError() << "attribute 'num_sparse' should be the same as "
                          << "the length of attribute 'sparse_types'";
  }
  if (op.sparse_indices().size() != sparse_types_count) {
    return op.emitError() << "output 'sparse_indices' should have same length "
                          << "as attribute 'sparse_types'";
  }
  if (op.sparse_shapes().size() != sparse_types_count) {
    return op.emitError() << "output 'sparse_shapes' should have same length "
                          << "as attribute 'sparse_types'";
  }

  // Validate ragged variadic output lengths.
  auto ragged_value_types_count = std::distance(op.ragged_value_types().begin(),
                                                op.ragged_value_types().end());
  auto ragged_split_types_count = std::distance(op.ragged_split_types().begin(),
                                                op.ragged_split_types().end());
  if (ragged_value_types_count != ragged_split_types_count) {
    return op.emitError() << "attribute 'ragged_value_types' should have same "
                          << "length as attribute 'ragged_split_types'";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// PartitionedCallOp
//===----------------------------------------------------------------------===//

template <class OpClass>
static LogicalResult VerifyPartitionedCall(OpClass op) {
  auto module = op->template getParentOfType<ModuleOp>();
  SymbolRefAttr func = op->getAttr("f").template cast<SymbolRefAttr>();

  auto function =
      dyn_cast_or_null<FuncOp>(SymbolTable::lookupSymbolIn(module, func));

  if (!function) {
    return op.emitError("'f' attribute refers to an undefined function: ")
           << func;
  }

  FunctionType function_ty = function.getType();
  int func_arg_count = function_ty.getNumInputs();
  int arg_count = op.args().size();

  if (arg_count != func_arg_count) {
    return op.emitError() << "argument count mismatch: 'args' has " << arg_count
                          << " arguments, but '" << func << "' expects "
                          << func_arg_count;
  }

  return success();
}

//===----------------------------------------------------------------------===//
// PowOp
//===----------------------------------------------------------------------===//

OpFoldResult PowOp::fold(ArrayRef<Attribute> operands) {
  auto constant_y = operands[1].dyn_cast_or_null<DenseFPElementsAttr>();
  if (constant_y && constant_y.isSplat()) {
    APFloat y_value = constant_y.getSplatValue<APFloat>();
    auto output_type = getType().cast<ShapedType>();
    if (y_value.isZero() && output_type.hasStaticShape()) {
      return DenseElementsAttr::get(
          output_type,
          FloatAttr::get(output_type.getElementType(), /*value=*/1.0));
    }
    if (y_value.isExactlyValue(1.0)) {
      return x();
    }
  }
  return {};
}

//===----------------------------------------------------------------------===//
// QuantizeAndDequantizeV2Op
//===----------------------------------------------------------------------===//

void QuantizeAndDequantizeV2Op::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<QuantizeAndDequantizeV2ToQuantizeAndDequantizeV4>(context);
}

//===----------------------------------------------------------------------===//
// QrOp
//===----------------------------------------------------------------------===//

// Verifies that,
//
// * Input type, if ranked, must have at least 2 dimensions and at most
//   INT32_MAX dimensions.
//
static LogicalResult Verify(QrOp op) {
  auto ttype = op.input().getType().cast<TensorType>();
  if (!ttype.hasRank()) return success();
  if (!HasRankAtLeast(op.input(), 2))
    return op.emitOpError(
        "requires ranked input tensor to be of rank 2 or more");
  if (!HasRankAtMost(op.input(), std::numeric_limits<int32_t>::max()))
    return op.emitOpError(
        "requires ranked input tensor to be of rank INT32_MAX or less");

  return success();
}

//===----------------------------------------------------------------------===//
// ReadVariableOp
//===----------------------------------------------------------------------===//

void ReadVariableOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ReadVariableOfCast>(context);
}

//===----------------------------------------------------------------------===//
// RandomUniformOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(RandomUniformOp op) {
  if (!IsOfRankOrUnranked(op.shape(), 1))
    return op.emitOpError("shape must be 1D tensor");
  return success();
}

//===----------------------------------------------------------------------===//
// RangeOp
//===----------------------------------------------------------------------===//

namespace {

// Computes the length of a range (1-D) tensor given `start`, `limit`, `delta`.
// Template parameter `FloatOrInt` must be a standard C integer or
// floating-point type.
template <typename FloatOrInt>
int GetLengthOfRange(FloatOrInt start, FloatOrInt limit, FloatOrInt delta) {
  // Refer to the implementation in
  // tensorflow/lite/kernels/range.cc.
  FloatOrInt diff = limit - start;
  if (std::is_integral<FloatOrInt>::value) {
    return ((std::abs(diff) + std::abs(delta) - 1) / std::abs(delta));
  }
  return std::ceil(std::abs(diff / delta));
}
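
// Illustrative example: GetLengthOfRange(0, 10, 3) returns 4, matching the
// four elements {0, 3, 6, 9} produced by tf.Range(0, 10, 3).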

// Builds a constant range tensor of `result_elem_type` elements.
// Template parameter `FloatOrIntAttr` must be mlir::IntegerAttr or
// mlir::FloatAttr.
template <typename FloatOrIntAttr>
DenseElementsAttr BuildConstRangeTensor(Type result_elem_type, int num_elements,
                                        FloatOrIntAttr start_attr,
                                        FloatOrIntAttr delta_attr) {
  using ValueType = typename FloatOrIntAttr::ValueType;  // APInt or APFloat
  ValueType start = start_attr.getValue();
  ValueType delta = delta_attr.getValue();

  SmallVector<ValueType, 16> new_values;
  new_values.reserve(num_elements);
  ValueType new_value = start;
  for (int i = 0; i < num_elements; ++i) {
    new_values.push_back(new_value);
    new_value = new_value + delta;
  }
  // Result is always a 1-D tensor.
  auto new_result_type =
      RankedTensorType::get({num_elements}, result_elem_type);
  return DenseElementsAttr::get(new_result_type, new_values);
}
}  // namespace

void RangeOp::build(OpBuilder &builder, OperationState &result, Value start,
                    Value limit, Value delta) {
  assert(start.getType() == limit.getType());
  assert(start.getType() == delta.getType());
  DenseIntElementsAttr start_val;
  DenseIntElementsAttr limit_val;
  DenseIntElementsAttr delta_val;
  if (matchPattern(start, m_Constant(&start_val)) &&
      matchPattern(limit, m_Constant(&limit_val)) &&
      matchPattern(delta, m_Constant(&delta_val))) {
    auto size = llvm::APIntOps::RoundingSDiv(
        *limit_val.begin() - *start_val.begin(), *delta_val.begin(),
        llvm::APInt::Rounding::DOWN);
    return RangeOp::build(
        builder, result,
        RankedTensorType::get(
            size.getSExtValue(),
            start.getType().cast<TensorType>().getElementType()),
        start, limit, delta);
  }
  return RangeOp::build(
      builder, result,
      RankedTensorType::get(
          {-1}, start.getType().cast<TensorType>().getElementType()),
      start, limit, delta);
}

OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 3);
  auto start_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
  auto limit_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
  auto delta_tensor = operands[2].dyn_cast_or_null<ElementsAttr>();
  if (!(start_tensor && limit_tensor && delta_tensor)) return nullptr;

  // Operands should all be scalars
  assert(start_tensor.getType().getRank() == 0 &&
         limit_tensor.getType().getRank() == 0 &&
         delta_tensor.getType().getRank() == 0);
  Type elem_type = getType().cast<ShapedType>().getElementType();
  if (elem_type.isSignlessInteger() || elem_type.isUnsignedInteger()) {
    auto start_attr = start_tensor.getValue<IntegerAttr>({});
    auto limit_attr = limit_tensor.getValue<IntegerAttr>({});
    auto delta_attr = delta_tensor.getValue<IntegerAttr>({});
    int num_elements;
    if (elem_type.isUnsignedInteger()) {
      uint64_t start = start_attr.getUInt();
      uint64_t limit = limit_attr.getUInt();
      uint64_t delta = delta_attr.getUInt();
      assert(start <= (uint64_t)INT_MAX);
      assert(limit <= (uint64_t)INT_MAX);
      assert(delta <= (uint64_t)INT_MAX);
      num_elements =
          GetLengthOfRange(static_cast<int>(start), static_cast<int>(limit),
                           static_cast<int>(delta));
    } else {
      num_elements = GetLengthOfRange(start_attr.getInt(), limit_attr.getInt(),
                                      delta_attr.getInt());
    }
    return BuildConstRangeTensor(elem_type, num_elements, start_attr,
                                 delta_attr);
  } else if (elem_type.isa<FloatType>()) {
    auto start_attr = start_tensor.getValue<FloatAttr>({});
    auto limit_attr = limit_tensor.getValue<FloatAttr>({});
    auto delta_attr = delta_tensor.getValue<FloatAttr>({});
    const int num_elements = GetLengthOfRange(start_attr.getValueAsDouble(),
                                              limit_attr.getValueAsDouble(),
                                              delta_attr.getValueAsDouble());
    return BuildConstRangeTensor(elem_type, num_elements, start_attr,
                                 delta_attr);
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

void RankOp::build(OpBuilder &builder, OperationState &result, Value input) {
  return RankOp::build(builder, result,
                       RankedTensorType::get({}, builder.getIntegerType(32)),
                       input);
}

// This will create a constant value for RankOp of a ranked tensor.
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
  auto type = input().getType();
  auto ranked_type = type.dyn_cast<RankedTensorType>();
  if (!ranked_type) return {};

  // DenseIntElementsAttr::get requires the output type be ranked with static
  // shape.
  auto output_type = getType().dyn_cast<RankedTensorType>();
  if (!output_type || !output_type.hasStaticShape()) return {};

  int32_t rank = ranked_type.getRank();
  return DenseIntElementsAttr::get(output_type, rank);
}

//===----------------------------------------------------------------------===//
// RealDivOp
//===----------------------------------------------------------------------===//

void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  results.insert<RealDivWithSqrtDivisor, RealDivWithConstDivisor>(context);
}

OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) {
  return IdentityArithmeticOpFolder<RealDivOp>(*this, operands);
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

namespace {
using ReshapeErrorHandler =
    llvm::function_ref<LogicalResult(const llvm::Twine &)>;

LogicalResult GetReshapeOutputType(Value tensor, Value shape,
                                   ReshapeErrorHandler error_handler,
                                   TensorType &output_ty) {
  auto tensor_ty = tensor.getType().cast<TensorType>();
  auto element_ty = tensor_ty.getElementType();
  output_ty = UnrankedTensorType::get(element_ty);

  auto shape_ty = shape.getType().dyn_cast<RankedTensorType>();
  if (!shape_ty) return success();
  if (shape_ty.getRank() != 1)
    return error_handler(llvm::formatv(
        "requires 'shape' to be rank 1, but got {0}", shape_ty.getRank()));

  DenseIntElementsAttr shape_attr;
  if (!matchPattern(shape, m_Constant(&shape_attr))) {
    // If only the shape of `shape` is known, return a ranked but dynamic
    // output shape.
    if (shape_ty.hasStaticShape()) {
      llvm::SmallVector<int64_t, 8> dynamic_shape(shape_ty.getDimSize(0),
                                                  ShapedType::kDynamicSize);
      output_ty = RankedTensorType::get(dynamic_shape, element_ty);
    }
    return success();
  }

  // Detect if reshape output shape is folded.
  bool shape_ty_zero_dim = false;
  int unknown_index = -1;
  // The product of constant shape argument excluding unknown dimension.
  int64_t shape_ty_size = 1;
  llvm::SmallVector<int64_t, 8> output_ty_shape;
  output_ty_shape.reserve(shape_attr.getNumElements());
  for (const auto &dim : llvm::enumerate(shape_attr.getIntValues())) {
    const int64_t size = dim.value().getSExtValue();
    if (size == ShapedType::kDynamicSize) {
      if (unknown_index != -1)
        return error_handler(llvm::formatv(
            "requires 'shape' to have at most one dynamic dimension, but got "
            "multiple dynamic dimensions at indices {0} and {1}",
            unknown_index, dim.index()));

      unknown_index = dim.index();
    } else if (size == 0) {
      shape_ty_zero_dim = true;
    } else if (size > 0) {
      shape_ty_size *= size;
    } else {
      return error_handler(
          llvm::formatv("requires 'shape' to have dimensions greater than -1, "
                        "but got {0} at index {1}",
                        size, dim.index()));
    }
    output_ty_shape.push_back(size);
  }

  if (!tensor_ty.hasStaticShape()) {
    output_ty = RankedTensorType::get(output_ty_shape, element_ty);
    return success();
  }

  // Compute the value of the unknown dimension.
  if (unknown_index != -1) {
    // Compute the number of elements in the tensor shape.
    int64_t tensor_ty_size = 1;
    bool tensor_ty_zero_dim = false;
    for (const auto &dim : tensor_ty.getShape()) {
      if (dim > 0 || !shape_ty_zero_dim) {
        tensor_ty_size *= dim;
      } else {
        tensor_ty_zero_dim = true;
      }
    }

    const int64_t missing_dim = tensor_ty_size / shape_ty_size;
    if (!tensor_ty_zero_dim && shape_ty_size * missing_dim != tensor_ty_size)
      return error_handler(
          llvm::formatv("requires 'tensor' number of elements be a multiple of "
                        "{0}, but got {1}",
                        shape_ty_size, tensor_ty_size));

    // Set the unknown dimension such that the total number of elements
    // remains constant.
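    // For example, reshaping a 12-element tensor with shape = [3, -1] sets
    // the unknown dimension to 12 / 3 = 4, giving an output shape of [3, 4].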
    output_ty_shape[unknown_index] = missing_dim;
  }

  output_ty = RankedTensorType::get(output_ty_shape, element_ty);

  return success();
}
}  // namespace

static LogicalResult Verify(ReshapeOp op) {
  auto error_handler = [&op](const llvm::Twine &message) -> LogicalResult {
    return op.emitOpError() << message;
  };
  TensorType expected_ty;
  if (failed(GetReshapeOutputType(op.tensor(), op.shape(), error_handler,
                                  expected_ty)))
    return failure();

  auto output_ty = op.getType().dyn_cast<RankedTensorType>();
  if (!output_ty) return success();
  auto tensor_ty = op.tensor().getType().cast<TensorType>();
  if (output_ty.hasStaticShape() && tensor_ty.hasStaticShape()) {
    const int64_t output_ty_size = output_ty.getNumElements();
    const int64_t tensor_ty_size = tensor_ty.getNumElements();
    if (tensor_ty_size != output_ty_size)
      return op.emitOpError() << "requires 'output' number of elements to "
                                 "match 'tensor' number of elements, but got "
                              << output_ty_size << " and " << tensor_ty_size;
  }

  if (!AreCastCompatible({output_ty, expected_ty}))
    return op.emitOpError()
           << "requires 'output' type " << output_ty
           << " to be cast compatible with expected type " << expected_ty;

  return success();
}

// Currently there are use cases that rely on partial evaluation of the `shape`
// operand, so InferTypeOpInterface is not used (along with the generated
// builder of the same signature).
void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor,
                      Value shape) {
  auto error_handler = [&result](const llvm::Twine &message) {
    return mlir::emitError(result.location) << message;
  };
  TensorType output_ty;
  if (failed(GetReshapeOutputType(tensor, shape, error_handler, output_ty)))
    return;

  return ReshapeOp::build(builder, result, output_ty, tensor, shape);
}

void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  results.insert<RedundantReshape, ReshapeToSelfShape>(context);
}

OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
  Value tensor = this->tensor();

  // Fold reshape if operand and result types are the same and all dimensions
  // are statically known (no-op reshape).
  auto result_ty = getType().dyn_cast<ShapedType>();
  if (result_ty && result_ty.hasStaticShape() &&
      result_ty == tensor.getType()) {
    return tensor;
  }

  return {};
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

// Verifies a few extra requirements on SelectOp:
// (1) `then` and `else` must have same shape
// (2) At least one of the following must be true:
//     (a) `cond` has the same rank as `then` and `else`
//     (b) `cond` is a scalar
//     (c) `cond` is a vector AND `then` and `else` are non-scalar with their
//         first dimension equal to `cond`.
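// For example, a vector `cond` of type tensor<2xi1> with `then`/`else` of type
// tensor<2x3xf32> satisfies (2c).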
static LogicalResult Verify(SelectOp op) {
  auto then_tensor = op.t().getType().cast<TensorType>();
  auto else_tensor = op.e().getType().cast<TensorType>();
  // Check (1).
  if (!AreCastCompatible({then_tensor, else_tensor}))
    return op.emitOpError() << "requires t and e have compatible shapes";

  // Get data rank (if it exists).
  int data_rank;
  // If data is unranked or data_rank is 0, this will remain -2. Otherwise it
  // refers to the first dimension of `then` and/or `else`.
  int data_first_dim = -2;
  bool then_has_rank = then_tensor.hasRank();
  bool else_has_rank = else_tensor.hasRank();
  if (then_has_rank && else_has_rank) {
    data_rank = then_tensor.getRank();
    if (then_tensor.getRank() > 0)
      data_first_dim = then_tensor.getShape().front();
    if (else_tensor.getRank() > 0)
      data_first_dim = std::max(
          static_cast<int>(else_tensor.getShape().front()), data_first_dim);
  } else if (then_has_rank) {
    data_rank = then_tensor.getRank();
    if (then_tensor.getRank() > 0)
      data_first_dim = then_tensor.getShape().front();
  } else if (else_has_rank) {
    data_rank = else_tensor.getRank();
    if (else_tensor.getRank() > 0)
      data_first_dim = else_tensor.getShape().front();
  } else {
    // Neither has a rank.
    return success();
  }

  auto cond_tensor = op.condition().getType().dyn_cast<RankedTensorType>();
  if (!cond_tensor) return success();
  auto cond_rank = cond_tensor.getRank();
  // Check (2a) and (2b).
  if (cond_rank == 0 || cond_rank == data_rank) return success();
  // Check (2c).
  if (cond_rank == 1) {
    auto cond_shape = cond_tensor.getShape().front();
    if (data_rank == 0) {
      return op.emitOpError()
             << "requires that t and e are nonscalar when pred is a vector";
    }
    // We know `data` tensor has a rank of at least 1.
    if (data_first_dim != -1 && cond_shape != -1 &&
        data_first_dim != cond_shape) {
      return op.emitOpError() << "requires that, when pred is a vector, the "
                                 "shape matches the first dimension of t and e";
    }
    return success();
  }
  // None of (2a,b,c) were true; fail.
  return op.emitOpError() << "requires that pred is a scalar OR has the same "
                             "rank as t and e OR is a vector";
}

//===----------------------------------------------------------------------===//
// SelectV2Op
//===----------------------------------------------------------------------===//

static Type InferSelectV2OpType(Value condition, Value e, Value t) {
  Type element_ty = e.getType().cast<TensorType>().getElementType();
  auto unranked_ty = UnrankedTensorType::get(element_ty);

  Type broadcasted_ty =
      OpTrait::util::getBroadcastedType(e.getType(), t.getType());
  if (!broadcasted_ty) return unranked_ty;

  auto cond_ranked_ty = condition.getType().dyn_cast<RankedTensorType>();
  auto broadcasted_ranked_ty = broadcasted_ty.dyn_cast<RankedTensorType>();
  if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty;

  // Explicitly compute the broadcasted output shape, as the element type of
  // the condition may not be the same as the broadcasted type's element type.
  SmallVector<int64_t, 4> result_shape;
  if (!OpTrait::util::getBroadcastedShape(cond_ranked_ty.getShape(),
                                          broadcasted_ranked_ty.getShape(),
                                          result_shape))
    return unranked_ty;
  return RankedTensorType::get(result_shape, element_ty);
}

void SelectV2Op::build(OpBuilder &builder, OperationState &result,
                       Value condition, Value e, Value t) {
  build(builder, result, InferSelectV2OpType(condition, e, t), condition, e, t);
}

//===----------------------------------------------------------------------===//
// ShapeOp
//===----------------------------------------------------------------------===//

namespace {
// Validates Shape/ShapeN/VariableShape operand and associated result types.
LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
                                          Type result_type,
                                          int variadic_idx = -1) {
  std::string variadic_idx_str =
      variadic_idx < 0 ? "" : llvm::formatv(" #{0}", variadic_idx).str();

  auto result_ranked_type = result_type.dyn_cast<RankedTensorType>();
  if (!result_ranked_type) return success();
  if (result_ranked_type.getShape().size() != 1)
    return op->emitOpError("requires 1D type for result") << variadic_idx_str;

  auto operand_ranked_type = operand_type.dyn_cast_or_null<RankedTensorType>();
  if (operand_ranked_type) {
    // The operand is a ranked tensor.
    if (result_ranked_type.hasStaticShape() &&
        !operand_ranked_type.getShape().empty() &&
        result_ranked_type.getDimSize(0) !=
            operand_ranked_type.getShape().size())
      return op->emitOpError("requires dimension size of result")
             << variadic_idx_str << " to match rank of operand"
             << variadic_idx_str;
  } else if (result_ranked_type.hasStaticShape()) {
    // The operand is an unranked tensor, print a warning if the result
    // is static.
    // Note: We do not handle this situation as an error, this would be too
    // restrictive due to incompleteness of shape inference at this point.
    op->emitWarning("has static shape result")
        << variadic_idx_str << " for unranked operand" << variadic_idx_str;
  }

  Type element_type = result_ranked_type.getElementType();
  if (!element_type.isSignlessInteger(32) &&
      !element_type.isSignlessInteger(64))
    return op->emitOpError("requires int32 or int64 return type for result")
           << variadic_idx_str;

  return success();
}
}  // anonymous namespace

static LogicalResult Verify(ShapeOp op) {
  return VerifyShapeOperandAndResult(op, op.input().getType(), op.getType());
}

// Converts the shape of the given type to an attribute if it is of ranked
// tensor type. The returned attribute has integer elements of the given width.
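// For example, a ranked type tensor<2x3xf32> with out_width = 32 becomes
// dense<[2, 3]> : tensor<2xi32>.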
static Attribute ConvertShapeToAttr(Type input_ty, int out_width) {
  auto ranked_ty = input_ty.dyn_cast<RankedTensorType>();
  if (!ranked_ty || !ranked_ty.hasStaticShape()) return {};

  auto shape = ranked_ty.getShape();
  int rank = shape.size();

  SmallVector<APInt, 4> dimensions;
  dimensions.reserve(rank);
  for (int i = 0; i < rank; ++i)
    dimensions.push_back(APInt(out_width, shape[i]));

  auto result_type = RankedTensorType::get(
      {rank}, IntegerType::get(input_ty.getContext(), out_width));
  return DenseElementsAttr::get(result_type, dimensions);
}

OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
  int width =
      getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
  return ConvertShapeToAttr(getOperand().getType(), width);
}

void ShapeOp::build(OpBuilder &builder, OperationState &result, Value input,
                    BoolAttr use32Bit) {
  auto rankedTensorType = input.getType().dyn_cast<RankedTensorType>();
  int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1;
  auto out_type = use32Bit.getValue() ? builder.getIntegerType(32)
                                      : builder.getIntegerType(64);
  return ShapeOp::build(builder, result,
                        RankedTensorType::get({rank}, out_type), input);
}

//===----------------------------------------------------------------------===//
// ShapeNOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(ShapeNOp op) {
  const size_t num_tensors = op.N();

  if (op.getNumOperands() != num_tensors)
    return op.emitOpError() << "requires " << num_tensors << " operand(s), got "
                            << op.getNumOperands() << " operand(s)";

  if (op.getNumResults() != num_tensors)
    return op.emitOpError() << "requires " << num_tensors << " result(s), got "
                            << op.getNumResults() << " result(s)";

  for (auto i : llvm::seq<uint64_t>(0, num_tensors)) {
    auto verification = VerifyShapeOperandAndResult(
        op, op.getOperand(i).getType(), op.getResult(i).getType(), i);
    if (failed(verification)) return verification;
  }

  return success();
}

namespace {
// Canonicalization pattern for ShapeNOps that don't have all static input
// shapes. Replacing output values corresponding to static input types may
// enable optimizations in users of the values.
class ShapeNPartialStaticInputShape : public OpRewritePattern<ShapeNOp> {
  using OpRewritePattern<ShapeNOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ShapeNOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() == 0) {
      rewriter.eraseOp(op);
      return success();
    }

    int width = getElementTypeOrSelf(op.getType(0)).getIntOrFloatBitWidth();

    SmallVector<Value, 4> results(op.getNumOperands());
    SmallVector<int64_t, 4> dynamic_indices;
    SmallVector<Value, 4> dynamic_inputs;
    SmallVector<Type, 4> result_types;
    for (auto e : llvm::enumerate(op.getOperands())) {
      if (Attribute result = ConvertShapeToAttr(e.value().getType(), width)) {
        results[e.index()] = rewriter.create<TF::ConstOp>(op.getLoc(), result);
      } else {
        dynamic_indices.push_back(e.index());
        dynamic_inputs.push_back(e.value());
        result_types.push_back(op.getType(e.index()));
      }
    }

    if (dynamic_inputs.size() == op.getNumOperands()) {
      // Cannot canonicalize ShapeN if all inputs are dynamic.
      return failure();
    }

    // Create a ShapeNOp for all dynamic inputs.
    if (!dynamic_inputs.empty()) {
      auto dynamic_shape_n = rewriter.create<TF::ShapeNOp>(
          op.getLoc(), result_types, dynamic_inputs);
      for (auto index_result :
           llvm::zip(dynamic_indices, dynamic_shape_n.getResults())) {
        results[std::get<0>(index_result)] = std::get<1>(index_result);
      }
    }

    rewriter.replaceOp(op, results);
    return success();
  }
};

// Canonicalize ShapeNOp to ShapeOp if there is only one operand.
class ShapeNToShape : public OpRewritePattern<ShapeNOp> {
  using OpRewritePattern<ShapeNOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ShapeNOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 1) {
      return failure();
    }
    auto shape = rewriter.create<TF::ShapeOp>(op.getLoc(), op.getType(0),
                                              op.getOperand(0));
    rewriter.replaceOp(op, {shape});
    return success();
  }
};
}  // namespace

void ShapeNOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<ShapeNToShape, ShapeNPartialStaticInputShape>(context);
}

//===----------------------------------------------------------------------===//
// SizeOp
//===----------------------------------------------------------------------===//

// Verifies that,
//
// * Input type, if it is a ranked tensor, has at most INT32_MAX dimensions.
//
static LogicalResult Verify(SizeOp op) {
  if (!HasRankAtMost(op.input(), std::numeric_limits<int32_t>::max()))
    return op.emitOpError(
        "requires ranked input tensor to be of rank INT32_MAX or less");

  // Output type needs to be scalar.
  if (!IsOfRankOrUnranked(op.output(), /*rank=*/0))
    return op.emitOpError("requires scalar output");

  return success();
}

OpFoldResult SizeOp::fold(ArrayRef<Attribute> operands) {
  ShapedType output_type = getType().cast<ShapedType>();
  if (!output_type.hasRank()) return {};
  ShapedType input_type = getOperand().getType().cast<ShapedType>();
  if (!input_type.hasStaticShape()) return {};
  int size = input_type.getNumElements();
  return DenseElementsAttr::get(
      output_type,
      IntegerAttr::get(output_type.getElementType(), /*value=*/size));
}

//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//

// Verifies that:
//
// - operands begin and size are 1D with the same number of elements.
// - if the input is a ranked tensor, the rank of the input equals the number
//   of elements in operands begin and size.
// - if begin are constants, that
//   0 <= begin[i] <= begin[i] + size[i] <= input_ty.getShape()[i]
//   and
//   size[i] == output_ty.getShape()[i]
// - if begins aren't constant but the input is a ranked tensor, that
//   size[i] <= input_ty.getShape()[i]
// - output rank is the same as input rank
//
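// For example, slicing a tensor<4x8xf32> with constant begin = [1, 2] and
// size = [2, -1] (where -1 means "through the end of the dimension") yields an
// output of type tensor<2x6xf32>.
//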
static LogicalResult Verify(SliceOp op) {
  RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin());
  if (begin_ty && begin_ty.getRank() != 1) {
    return op.emitOpError() << "requires begin operand to be 1D tensor";
  }

  RankedTensorType size_ty = GetRankedTensorTypeForOperand(op.size());
  if (size_ty && size_ty.getRank() != 1) {
    return op.emitOpError() << "requires size operand to be 1D tensor";
  }

  if (!begin_ty || !size_ty || !begin_ty.hasStaticShape() ||
      !size_ty.hasStaticShape())
    return success();

  if (begin_ty.getNumElements() != size_ty.getNumElements()) {
    return op.emitOpError() << "requires begin and size operands to have the"
                               " same number of elements";
  }

  auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
  if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) {
    return op.emitOpError() << "requires number of elements in begin and size "
                               "are equal to input rank";
  }

  auto output_ty = op.output().getType().dyn_cast<RankedTensorType>();
  if (output_ty && input_ty && output_ty.getRank() != input_ty.getRank()) {
    return op.emitOpError()
           << "requires output to have the same rank as input, but got input "
              "rank "
           << input_ty.getRank() << " and output rank " << output_ty.getRank();
  }

  DenseIntElementsAttr begin_indices;
  if (matchPattern(op.begin(), m_Constant(&begin_indices))) {
    DenseIntElementsAttr slice_sizes;
    bool constant_slice_sizes =
        matchPattern(op.size(), m_Constant(&slice_sizes));
    int dim = 0;
    // TODO(jpienaar): Reformulate the shape verification below to not use
    // magic constants.
    for (const APInt &raw_begin_index : begin_indices.getValues<APInt>()) {
      int64_t begin_index = raw_begin_index.getSExtValue();
      int64_t input_size = input_ty ? input_ty.getShape()[dim] : -1;
      int64_t slice_size = constant_slice_sizes
                               ? slice_sizes.getValue<APInt>(dim).getSExtValue()
                               : 0;
      int64_t output_size = output_ty ? output_ty.getShape()[dim] : -1;

      if (slice_size == -1 && input_size != -1) {
        slice_size = input_size - begin_index;
      }
      if (output_size != -1 && constant_slice_sizes &&
          output_size != slice_size) {
        return op.emitOpError()
               << "requires output size to have the same size of slice, got "
                  "slice size "
               << slice_size << " and output size " << output_size;
      }
      if (begin_index < 0 ||
          (input_size != -1 && begin_index + slice_size > input_size)) {
        return op.emitOpError()
               << "requires 0 <= begin[i] <= begin[i] + size[i] <= Di";
      }
      ++dim;
    }
  } else if (input_ty) {
    // If the inputs are ranked, we can do a few more sanity checks.
    DenseIntElementsAttr slice_sizes;
    if (matchPattern(op.size(), m_Constant(&slice_sizes))) {
      auto input_shape = input_ty.getShape();
      for (int64_t i = 0; i < input_ty.getRank(); ++i) {
        int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
        int64_t input_size = input_shape[i];
        if (slice_size != -1 && input_size != -1 && slice_size > input_size) {
          return op.emitOpError() << "requires size[i] <= Di, even if begin[i] "
                                     "is unknown at compile time";
        }
      }
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// SoftmaxOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(SoftmaxOp op) {
  if (!HasRankAtLeast(op.logits(), 1)) {
    return op.emitOpError("requires operand to have rank at least 1");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// SoftmaxCrossEntropyWithLogitsOp
//===----------------------------------------------------------------------===//

// Verifies that,
//
// * Input types are broadcast compatible and the broadcasted type has rank two.
//
static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) {
  auto broadcasted_ty = OpTrait::util::getBroadcastedType(
                            op.features().getType(), op.labels().getType())
                            .dyn_cast_or_null<ShapedType>();
  if (!broadcasted_ty ||
      (broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2))
    return op.emitOpError(
        "requires features and labels to be broadcast compatible to rank two");

  return success();
}

//===----------------------------------------------------------------------===//
// SpaceToBatchNDOp
//===----------------------------------------------------------------------===//

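// Returns the rank of the spatial block, derived from the static shape of the
// block_shape operand if available, otherwise from the static shape of the
// paddings operand; returns -1 if neither shape is statically known.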
int64_t SpaceToBatchNDBlockRank(const TensorType block_shape_type,
                                const TensorType paddings_type) {
  if (block_shape_type.hasStaticShape()) {
    return block_shape_type.getShape()[0];
  } else if (paddings_type.hasStaticShape()) {
    return paddings_type.getShape()[0];
  } else {
    return -1;
  }
}

static LogicalResult Verify(SpaceToBatchNDOp op) {
  const auto input_type = op.input().getType().cast<TensorType>();
  const auto block_shape_type = op.block_shape().getType().cast<TensorType>();
  const auto paddings_type = op.paddings().getType().cast<TensorType>();

  // Check that block_shape has rank 1.
  if (!IsOfRankOrUnranked(op.block_shape(), 1)) {
    return op.emitOpError() << "requires rank of block_shape = 1; got "
                            << block_shape_type.getRank();
  }

  // Check that paddings has rank 2.
  if (!IsOfRankOrUnranked(op.paddings(), 2)) {
    return op.emitOpError()
           << "requires rank of paddings = 2; got " << paddings_type.getRank();
  }

  // Check that paddings.shape[1] = 2.
  if (paddings_type.hasStaticShape() && paddings_type.getShape()[1] != 2) {
    return op.emitOpError() << "requires paddings.shape[1] to be 2; got "
                            << paddings_type.getShape()[1];
  }

  // Check that block_shape and paddings have consistent ranks.
  if (block_shape_type.hasStaticShape() && paddings_type.hasStaticShape() &&
      block_shape_type.getShape()[0] != paddings_type.getShape()[0]) {
    return op.emitOpError()
           << "requires block_shape.shape[0] must equal paddings.shape[0]";
  }

  const int64_t block_rank =
      SpaceToBatchNDBlockRank(block_shape_type, paddings_type);

  // Further checks require block_rank to be known.
  if (block_rank == -1) {
    return success();
  }

  // Check that rank of input_type >= block_rank + 1.
  if (input_type.hasRank() && input_type.getRank() < 1 + block_rank) {
    return op.emitOpError() << "requires rank of input >= 1 + rank of block";
  }

  ElementsAttr block_shape_attr = nullptr;
  ElementsAttr paddings_attr = nullptr;

  // Check that block_shape[*] >= 1.
  if (matchPattern(op.block_shape(), m_Constant(&block_shape_attr))) {
    uint64_t i = 0;
    for (auto block_len : block_shape_attr.getValues<APInt>()) {
      if (block_len.getSExtValue() < 1) {
        return op.emitOpError()
               << "requires all values of block_shape to be >= 1; "
                  "failed for dimension "
               << i;
      }
      ++i;
    }
  }

  // Check that paddings[*] >= 0.
  if (matchPattern(op.paddings(), m_Constant(&paddings_attr))) {
    for (uint64_t i = 0; i < block_rank; ++i) {
      const int64_t pad_start =
          paddings_attr.getValue({i, 0}).cast<IntegerAttr>().getInt();
      const int64_t pad_end =
          paddings_attr.getValue({i, 1}).cast<IntegerAttr>().getInt();
      if (pad_start < 0 || pad_end < 0) {
        return op.emitOpError()
               << "requires all values of paddings to be >= 0; "
                  "failed for dimension "
               << i;
      }
    }
  }

  // Check that block_shape divides the padded input.
  if (input_type.hasStaticShape() && block_shape_attr && paddings_attr) {
    for (uint64_t i = 0; i < block_rank; ++i) {
      const int64_t input_len = input_type.getShape()[1 + i];
      const int64_t pad_start =
          paddings_attr.getValue({i, 0}).cast<IntegerAttr>().getInt();
      const int64_t pad_end =
          paddings_attr.getValue({i, 1}).cast<IntegerAttr>().getInt();
      const int64_t block_len =
          block_shape_attr.getValue({i}).cast<IntegerAttr>().getInt();
      if ((input_len + pad_start + pad_end) % block_len != 0) {
        return op.emitOpError()
               << "requires block_shape[i] divides "
                  "input_shape[i + 1] + paddings[i, 0] + paddings[i, 1]; "
                  "failed for i="
               << i;
      }
    }
  }

  return success();
}
1420
1421 //===----------------------------------------------------------------------===//
1422 // SparseSoftmaxCrossEntropyWithLogitsOp
1423 //===----------------------------------------------------------------------===//
1424
1425 static LogicalResult Verify(SparseSoftmaxCrossEntropyWithLogitsOp op) {
1426 if (!IsOfRankOrUnranked(op.features(), 2)) {
1427 return op.emitOpError("requires features operand of rank two");
1428 }
1429 if (!IsOfRankOrUnranked(op.labels(), 1)) {
1430 return op.emitOpError("requires labels operand of rank one");
1431 }
1432 auto features_ty = op.features().getType().dyn_cast<RankedTensorType>();
1433 auto labels_ty = op.labels().getType().dyn_cast<RankedTensorType>();
1434 if (features_ty && labels_ty) {
1435 int64_t features_batches = features_ty.getDimSize(0);
1436 int64_t labels_batches = labels_ty.getDimSize(0);
1437 if (!ShapedType::isDynamic(features_batches) &&
1438 !ShapedType::isDynamic(labels_batches) &&
1439 features_batches != labels_batches)
1440 return op.emitOpError(
1441 "requires features and labels with matching first dimension");
1442 }
1443 return success();
1444 }
1445
1446 //===----------------------------------------------------------------------===//
1447 // SplitOp
1448 //===----------------------------------------------------------------------===//
1449
1450 // Verifies the input and split dimension operands for tf.Split/tf.SplitV.
1451 // Writes the split dimension's index (adjusted with input rank) via `dim_index`
1452 // if it's a constant.
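// For example (illustrative), with a rank-3 input a constant split_dim of -1
// is accepted and reported through `dim_index` as 2, whereas 3 or -4 is
// rejected as out of range.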
1453 template <class Op>
1454 LogicalResult VerifySplitInputAndSplitDim(Op op, Optional<int64_t> *dim_index) {
1455 *dim_index = llvm::None;
1456
1457 Value split_dim = op.split_dim();
1458 if (auto split_dim_type = split_dim.getType().dyn_cast<RankedTensorType>())
1459 if (split_dim_type.getRank() != 0)
1460 return op.emitOpError(
1461 "split dimension should be an integer scalar tensor");
1462
1463 // We can perform further verification if the input tensor to be split has
1464 // known rank and the split dimension tensor is a constant.
1465
1466 auto input_type = op.value().getType().template dyn_cast<RankedTensorType>();
1467 if (!input_type) return success();
1468
1469 int64_t input_rank = input_type.getRank();
1470 if (input_rank == 0)
1471 return op.emitOpError("cannot split scalar input tensor");
1472
1473 DenseIntElementsAttr split_dim_attr;
1474 if (!matchPattern(split_dim, m_Constant(&split_dim_attr))) return success();
1475
1476 int64_t index = (*split_dim_attr.begin()).getSExtValue();
1477
1478 if (index + input_rank < 0 || index >= input_rank) {
1479 return op.emitOpError("split dimension must be in range [-")
1480 << input_rank << ", " << input_rank << ")";
1481 }
1482
1483 if (index < 0) index += input_rank;
1484 *dim_index = index;
1485
1486 return success();
1487 }
1488
1489 static LogicalResult Verify(SplitOp op) {
1490 Optional<int64_t> dim_index;
1491 if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
1492 if (!dim_index) return success();
1493
1494 int64_t input_dim_size =
1495 op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
1496 if (input_dim_size == ShapedType::kDynamicSize) return success();
1497
1498 if (input_dim_size % op.getNumResults() != 0)
1499 return op.emitOpError("dimension #")
1500 << *dim_index << " not divisible by the number of result tensors";
1501
1502 return success();
1503 }
1504
1505 //===----------------------------------------------------------------------===//
1506 // SplitVOp
1507 //===----------------------------------------------------------------------===//
1508
1509 static LogicalResult Verify(SplitVOp op) {
1510 auto split_sizes_type =
1511 op.size_splits().getType().dyn_cast<RankedTensorType>();
1512 if (!split_sizes_type) return success();
1513
1514 if (split_sizes_type.getRank() != 1 ||
1515 (split_sizes_type.getDimSize(0) != ShapedType::kDynamicSize &&
1516 split_sizes_type.getDimSize(0) != op.getNumResults()))
1517 return op.emitOpError("split sizes should be a 1D tensor of ")
1518 << op.getNumResults() << " elements";
1519
1520 Optional<int64_t> dim_index = 0;
1521 if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
1522 if (!dim_index) return success();
1523
1524 int64_t input_dim_size =
1525 op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
1526 if (input_dim_size == ShapedType::kDynamicSize) return success();
1527
1528 // If split sizes come from a constant, they must sum to the dimension size
1529 // along split_dim, and we can have no more than one dynamic dimension.
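// For example (illustrative), with 3 results and an input dimension of size
// 10 along split_dim, size_splits = [2, 3, 5] verifies, [2, 3, -1] verifies
// because the -1 entry is the single dynamic size, and [2, 3, 4] is rejected.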
1530 DenseIntElementsAttr split_sizes_attr;
1531 if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
1532 return success();
1533
1534 int64_t total_dim_size = 0; // Total dimension size assigned to splits
1535 llvm::Optional<int> dynamic_dim_index;
1536
1537 SmallVector<int64_t, 4> split_sizes;
1538 split_sizes.reserve(
1539 split_sizes_attr.getType().cast<ShapedType>().getNumElements());
1540
1541 for (auto dim : llvm::enumerate(split_sizes_attr)) {
1542 int64_t dim_val = dim.value().getSExtValue();
1543 split_sizes.push_back(dim_val);
1544 if (dim_val == ShapedType::kDynamicSize) {
1545 // We cannot have more than one dynamic dimension.
1546 if (dynamic_dim_index)
1547 return op.emitOpError(
1548 "cannot have more than one dynamic dimension in split sizes");
1549 dynamic_dim_index = dim.index();
1550 } else {
1551 total_dim_size += dim_val;
1552 }
1553 }
1554
1555 if (!dynamic_dim_index && total_dim_size != input_dim_size)
1556 return op.emitOpError(
1557 "split sizes must sum up to the dimension size along split "
1558 "dimension, found ")
1559 << total_dim_size << " vs " << input_dim_size;
1560
1561 if (dynamic_dim_index && total_dim_size > input_dim_size)
1562 return op.emitOpError(
1563 "split sizes must sum up to be less than or equal to the "
1564 "dimension size along split dimension, found ")
1565 << total_dim_size << " vs " << input_dim_size;
1566
1567 return success();
1568 }
1569
1570 //===----------------------------------------------------------------------===//
1571 // SquareOp
1572 //===----------------------------------------------------------------------===//
1573
1574 void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1575 MLIRContext *context) {
1576 results.insert<SquareOfSub>(context);
1577 }
1578
1579 //===----------------------------------------------------------------------===//
1580 // SubOp
1581 //===----------------------------------------------------------------------===//
1582
1583 void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1584 MLIRContext *context) {
1585 results.insert<SubOfNeg>(context);
1586 }
1587
fold(ArrayRef<Attribute> operands)1588 OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
1589 return IdentityArithmeticOpFolder<SubOp>(*this, operands);
1590 }
1591
1592 //===----------------------------------------------------------------------===//
1593 // SumOp
1594 //===----------------------------------------------------------------------===//
1595
1596 void SumOp::build(OpBuilder &builder, OperationState &result, Value input,
1597 Value reduction_indices, BoolAttr keep_dims) {
1598 Type out_ty =
1599 InferReductionOpType(input, reduction_indices, keep_dims, &builder);
1600 build(builder, result, out_ty, input, reduction_indices, keep_dims);
1601 }
1602
1603 // TODO: Templatize this fold for all reduction ops.
fold(ArrayRef<Attribute> operands)1604 OpFoldResult SumOp::fold(ArrayRef<Attribute> operands) {
1605 auto input_ty = input().getType().template dyn_cast<RankedTensorType>();
1606 if (!input_ty) return {};
1607 auto result_ty = getType().template dyn_cast<RankedTensorType>();
1608 if (!result_ty) return {};
1609
1610 // Bypass this op if the result has the same shape and type as the input. This
1611 // can happen if the input tensor has size 0 or size 1.
1612 if (!keep_dims() && input_ty == result_ty) {
1613 return input();
1614 }
1615 return {};
1616 }
1617
1618 //===----------------------------------------------------------------------===//
1619 // StridedSliceOp
1620 //===----------------------------------------------------------------------===//
1621
1622 // TODO(b/154160827): Add a canonicalization pattern from tf.StridedSliceOp to
1623 // tf.SliceOp if both of the following are true:
1624 // - All strides have a known value equal to 1
1625 // - No masks are set (or masks can be applied by transforming the inputs to
1626 // Slice)
1627
1628 // Verifies that,
1629 //
1630 // - begin, end and strides operands are 1D and they have the same number of
1631 // elements. Here, the number of elements should be less than 32 to support
1632 // 32-bit mask attributes.
1633 // - None of the strides values are zero.
1634 // - Ellipsis mask can have at most one bit set.
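//
// Illustrative example: begin = [0, 0], end = [2, 3], strides = [1, 1] with
// ellipsis_mask = 0 passes these checks; strides containing a 0, or
// ellipsis_mask = 0b101 (two ellipses), is rejected.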
1635
1636 template <class OpTy>
1637 static LogicalResult VerifyStridedSliceBase(OpTy op) {
1638 // Expected size for operands begin, end and strides vector operands.
1639 int64_t expected_size = -1;
1640
1641 for (Value val : {op.begin(), op.end(), op.strides()}) {
1642 auto operand_ty = val.getType().dyn_cast<ShapedType>();
1643 if (!operand_ty || !operand_ty.hasStaticShape()) {
1644 // TensorFlow constant ops may have non-static shape because the shape is
1645 // not propagated during constant folding. If the defining op for this
1646 // operand is a constant op, use the constant op's attribute to get the
1647 // actual shape.
1648 DenseIntElementsAttr attr;
1649 if (!matchPattern(val, m_Constant(&attr))) continue;
1650 operand_ty = attr.getType();
1651 }
1652
1653 if (operand_ty.getRank() != 1)
1654 return op.emitOpError()
1655 << "requires begin, end and strides to be 1D tensors";
1656
1657 int64_t length = operand_ty.getDimSize(0);
1658 if (length == -1) continue;
1659
1660 if (expected_size == -1) {
1661 // This op uses 32-bit masks.
1662 if (length >= 32)
1663 return op.emitOpError(
1664 "requires begin, end and strides operands with less than 32 "
1665 "elements");
1666
1667 expected_size = length;
1668 } else if (length != expected_size) {
1669 return op.emitOpError() << "requires begin, end and strides to have the "
1670 "same number of elements";
1671 }
1672 }
1673
1674 // If strides are constants, verify that none of the element is zero.
1675 DenseIntElementsAttr strides;
1676 if (matchPattern(op.strides(), m_Constant(&strides))) {
1677 if (llvm::is_contained(strides.getValues<APInt>(), 0))
1678 return op.emitOpError("requires non-zero strides");
1679 }
1680
1681 // Use bit compares to ensure ellipsis_mask is 0 or a power of 2, i.e. there
1682 // is at most one ellipsis.
1683 uint32_t ellipsis_mask = op.ellipsis_mask();
1684 if (ellipsis_mask != 0 && !llvm::isPowerOf2_32(ellipsis_mask))
1685 return op.emitOpError("cannot have multiple ellipses");
1686
1687 return success();
1688 }
1689
1690 // Clamps the given `val`: returns `low` if `val` is less than `low`; returns
1691 // `high` if `high` is less than `val`; otherwise returns `val`.
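// For example, Clamp(5, 0, 3) returns 3 and Clamp(-1, 0, 3) returns 0.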
1692 template <class T>
Clamp(const T & val,const T & low,const T & high)1693 constexpr const T &Clamp(const T &val, const T &low, const T &high) {
1694 assert(!(high < low));
1695 return (val < low) ? low : (high < val) ? high : val;
1696 }
1697
1698 // Checks if the `index` bit of `val` is set.
1699 template <class T>
IsSet(const T & val,unsigned index)1700 constexpr bool IsSet(const T &val, unsigned index) {
1701 return (val & (1 << index)) != 0;
1702 }
1703
1704 // Sets the `index` bit of `val`.
1705 template <class T>
Set(T & val,unsigned index)1706 constexpr void Set(T &val, unsigned index) {
1707 val |= (1 << index);
1708 }
1709
1710 // Unsets the `index` bit of `val`.
1711 template <class T>
Unset(T & val,unsigned index)1712 constexpr void Unset(T &val, unsigned index) {
1713 val &= ~(1 << index);
1714 }
1715
1716 // Copies the `src_index` bit of `src` to the `dst_index` bit of `dst`.
1717 template <class T>
1718 constexpr void CopyBit(const T &src, unsigned src_index, T &dst,
1719 unsigned dst_index) {
1720 if (IsSet(src, src_index))
1721 Set(dst, dst_index);
1722 else
1723 Unset(dst, dst_index);
1724 }
1725
1726 // The sparse spec of strided slice does not correspond to the number of
1727 // dimensions. For example, sparse spec for foo[..., 3:10] for foo of shape (2,
1728 // 4, 8) would have dims = 2.
1729 struct SparseSliceSpec {
1730 int64_t dims;
1731 int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
1732 const ArrayRef<int64_t> &begin;
1733 const ArrayRef<int64_t> &end;
1734 const ArrayRef<int64_t> &strides;
1735 };
1736
1737 // The dense spec of strided slice is the canonicalized version of sparse spec.
1738 // The number of dimensions of the dense spec corresponds to the number of
1739 // dimensions in the operand tensor.
1740 struct DenseSliceSpec {
1741 int64_t dims;
1742 int32_t begin_mask, end_mask, shrink_axis_mask;
1743 SmallVectorImpl<int64_t> &begin;
1744 SmallVectorImpl<int64_t> &end;
1745 SmallVectorImpl<int64_t> &strides;
1746 };
1747
1748 // Makes a sparse spec into a dense index spec.
1749 // The sparse spec does not correspond to the number of dimensions, so build a
1750 // dense spec that corresponds to the number of dimensions.
1751 //
1752 // For example, suppose foo[..., 3:, 2] on foo.shape = (2, 2, 3, 4); then we
1753 // need to produce the missing begin_mask and end_mask for the first two
1754 // dimensions, i.e. foo[:, :, 3:, 2].
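//
// Continuing that example (illustrative), the two leading dense entries get
// begin_mask and end_mask bits set with stride 1, and the sparse [3:, 2]
// entries are copied into the last two dense positions.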
1755 static void BuildDenseSliceSpec(const SparseSliceSpec &sparse,
1756 DenseSliceSpec *dense) {
1757 // Build expanded dense begin, end, strides, begin_mask, end_mask, and
1758 // shrink_axis_mask.
1759 dense->begin.resize(dense->dims);
1760 dense->end.resize(dense->dims);
1761 dense->strides.resize(dense->dims);
1762 dense->begin_mask = 0;
1763 dense->end_mask = 0;
1764 dense->shrink_axis_mask = 0;
1765
1766 // Count number of new_axis after ellipsis. This helps in calculating the
1767 // number of dimensions ellipsis represents in the sparse spec.
1768 bool ellipsis_seen = false;
1769 int num_new_axis_after_ellipsis = 0;
1770 for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) {
1771 if (ellipsis_seen && IsSet(sparse.new_axis_mask, sparse_index))
1772 num_new_axis_after_ellipsis++;
1773 if (IsSet(sparse.ellipsis_mask, sparse_index)) ellipsis_seen = true;
1774 }
1775
1776 int dense_index = 0;
1777 for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) {
1778 if (IsSet(sparse.new_axis_mask, sparse_index)) continue;
1779 if (IsSet(sparse.ellipsis_mask, sparse_index)) {
1780 auto next_index = std::min(dense->dims - (sparse.dims - sparse_index) +
1781 1 + num_new_axis_after_ellipsis,
1782 dense->dims);
1783 // Expand ellipsis into the appropriate dense indices. From current index
1784 // until next_index, all dimensions would have begin and end masks set and
1785 // stride 1, i.e., get all elements in those dimensions.
1786 for (; dense_index < next_index; ++dense_index) {
1787 dense->begin[dense_index] = dense->end[dense_index] = 0;
1788 dense->strides[dense_index] = 1;
1789 Set(dense->begin_mask, dense_index);
1790 Set(dense->end_mask, dense_index);
1791 }
1792 continue;
1793 }
1794 assert(dense_index < dense->dims);
1795 // Copy over the sparse indices to dense indices if ellipsis_mask and
1796 // new_axis_mask are not set.
1797 dense->begin[dense_index] = sparse.begin[sparse_index];
1798 dense->end[dense_index] = sparse.end[sparse_index];
1799 dense->strides[dense_index] = sparse.strides[sparse_index];
1800 CopyBit(sparse.begin_mask, sparse_index, dense->begin_mask, dense_index);
1801 CopyBit(sparse.end_mask, sparse_index, dense->end_mask, dense_index);
1802 CopyBit(sparse.shrink_axis_mask, sparse_index, dense->shrink_axis_mask,
1803 dense_index);
1804 dense_index++;
1805 }
1806 }
1807
1808 // For the given `input_shape`, calculates the sliced shape using the given
1809 // `begin`, `end`, and `stride` ranges and `begin_mask`, `end_mask`, and
1810 // `shrink_axis_mask` masks. Updates the result back to `input_shape`. If
1811 // `shrink_axis_mask` is not zero, this function will not drop the corresponding
1812 // dimensions in `input_shape`; it will turn them into 1s. At the same time,
1813 // canonicalizes `begin`, `end`, and `strides`. The calculation follows
1814 // tf.StridedSlice op semantics.
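//
// Illustrative example: input_shape = [10], begin = [2], end = [8],
// stride = [3] with no masks set gives a sliced size of
// ceil((8 - 2) / 3) = 2, so input_shape is updated to [2].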
1815 static void CalculateSlicedShapeFromDenseIndices(
1816 MutableArrayRef<int64_t> input_shape, int32_t begin_mask, int32_t end_mask,
1817 int32_t shrink_axis_mask, MutableArrayRef<int64_t> begin,
1818 MutableArrayRef<int64_t> end, MutableArrayRef<int64_t> stride) {
1819 assert(input_shape.size() <= 32); // Only 32-bit masks are supported.
1820
1821 // Make sure ranges' ranks are consistent with the input.
1822 assert(input_shape.size() == begin.size());
1823 assert(input_shape.size() == end.size());
1824 assert(input_shape.size() == stride.size());
1825
1826 for (int i = 0, e = input_shape.size(); i < e; ++i) {
1827 if (ShapedType::isDynamic(input_shape[i])) continue;
1828
1829 int64_t dim_i = input_shape[i];
1830 int64_t begin_i = begin[i];
1831 int64_t end_i = end[i];
1832 int64_t stride_i = stride[i];
1833
1834 // [0]: mask for begin, [1]: mask for end
1835 int64_t masks[] = {begin_mask & (1 << i), end_mask & (1 << i)};
1836 // [0]: bound for begin, [1]: bound for end
1837 int64_t bounds[] = {stride_i > 0 ? 0 : -1,
1838 stride_i > 0 ? dim_i : dim_i - 1};
1839
1840 // Canonicalizes the given range `point` (begin/end) according to the
1841 // current dimension. `c` means case: 0 for begin, 1 for end.
1842 auto canonicalize = [&](int64_t point, int c) {
1843 if (masks[c]) return stride_i > 0 ? bounds[c] : bounds[(c + 1) & 1];
1844
1845 // Add dim as offset to negative range point.
1846 point = point < 0 ? dim_i + point : point;
1847 return Clamp(point, bounds[0], bounds[1]);
1848 };
1849
1850 begin_i = canonicalize(begin_i, 0);
1851 end_i = canonicalize(end_i, 1);
1852
1853 int64_t interval_len = end_i - begin_i;
1854 int64_t size_i = 0;
1855 // If the interval length is zero or has a different sign from the stride,
1856 // it's a degenerate case: we are slicing nothing. Otherwise, calculate the
1857 // sliced size.
1858 if (interval_len != 0 && (interval_len < 0) == (stride_i < 0))
1859 size_i = (interval_len / stride_i) + (interval_len % stride_i != 0);
1860
1861 begin[i] = begin_i;
1862 if (IsSet(shrink_axis_mask, i)) {
1863 // Shrink this dimension. It means we only take the element at begin_i.
1864 input_shape[i] = 1;
1865 end[i] = begin_i + 1;
1866 stride[i] = 1;
1867 } else {
1868 input_shape[i] = size_i;
1869 end[i] = end_i;
1870 stride[i] = stride_i;
1871 }
1872 }
1873 }
1874
1875 // For the given `input_shape`, calculates the sliced shape using the given
1876 // `sparse_begin`, `sparse_end`, and `sparse_strides` ranges and `begin_mask`,
1877 // `end_mask`, `ellipsis_mask` , `new_axis_mask` and `shrink_axis_mask` masks.
1878 // Updates the result back to `input_shape`.
1879 static void CalculateSlicedShapeFromSparseIndices(
1880 MutableArrayRef<int64_t> input_shape, ArrayRef<int64_t> sparse_begin,
1881 ArrayRef<int64_t> sparse_end, ArrayRef<int64_t> sparse_strides,
1882 int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask,
1883 int32_t new_axis_mask, int32_t shrink_axis_mask,
1884 SmallVectorImpl<int64_t> *begin, SmallVectorImpl<int64_t> *end,
1885 SmallVectorImpl<int64_t> *stride) {
1886 int64_t num_sparse_indices = sparse_begin.size();
1887 SparseSliceSpec sparse = {num_sparse_indices, begin_mask, end_mask,
1888 ellipsis_mask, new_axis_mask, shrink_axis_mask,
1889 sparse_begin, sparse_end, sparse_strides};
1890
1891 // If no ellipsis_mask exists then an implicit ellipsis_mask at the end is
1892 // inserted. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields
1893 // a tensor of shape [2, 8], i.e., foo[2:4] is the same as foo[2:4, ...].
1894 if (sparse.ellipsis_mask == 0) {
1895 Set(sparse.ellipsis_mask, sparse.dims);
1896 sparse.dims++;
1897 }
1898
1899 int64_t dims = input_shape.size();
1900 DenseSliceSpec dense = {dims,
1901 /*begin_mask = */ 0,
1902 /*end_mask = */ 0,
1903 /*shrink_axis_mask = */ 0,
1904 *begin,
1905 *end,
1906 *stride};
1907
1908 BuildDenseSliceSpec(sparse, &dense);
1909 CalculateSlicedShapeFromDenseIndices(input_shape, dense.begin_mask,
1910 dense.end_mask, dense.shrink_axis_mask,
1911 *begin, *end, *stride);
1912 }
1913
1914 bool StridedSliceOp::GetSlicedBoundRanges(
1915 SmallVectorImpl<int64_t> *slice_begin, SmallVectorImpl<int64_t> *slice_end,
1916 SmallVectorImpl<int64_t> *slice_stride) {
1917 // TODO(hinsu): Support lowering for ops with dynamic begin and end values
1918 // when it is possible to derive indices based on mask attributes.
1919 DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr;
1920 if (!matchPattern(begin(), m_Constant(&sparse_begin_attr)) ||
1921 !matchPattern(end(), m_Constant(&sparse_end_attr)) ||
1922 !matchPattern(strides(), m_Constant(&sparse_strides_attr)))
1923 return false;
1924
1925 auto input_ty = this->input().getType().dyn_cast<RankedTensorType>();
1926 if (!input_ty || !input_ty.hasStaticShape()) return false;
1927 auto input_shape = llvm::to_vector<4>(input_ty.getShape());
1928
1929 SmallVector<int64_t, 4> sparse_begin, sparse_end, sparse_strides;
1930
1931 for (const APInt &index : sparse_begin_attr)
1932 sparse_begin.push_back(index.getSExtValue());
1933 for (const APInt &index : sparse_end_attr)
1934 sparse_end.push_back(index.getSExtValue());
1935 for (const APInt &stride : sparse_strides_attr)
1936 sparse_strides.push_back(stride.getSExtValue());
1937
1938 CalculateSlicedShapeFromSparseIndices(
1939 input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(),
1940 end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(),
1941 slice_begin, slice_end, slice_stride);
1942 return true;
1943 }
1944
fold(ArrayRef<Attribute> operands)1945 OpFoldResult StridedSliceOp::fold(ArrayRef<Attribute> operands) {
1946 // Fold StridedSlice operation if it extracts statically known dimensions.
1947 //
1948 // For example,
1949 //
1950 // %shape = tf.Shape(%arg) // %arg: tensor<?x2x3x1xf32>
1951 // %height = tf.StridedSlice(%shape, 1, 2, 1)
1952 //
1953 // In this case %height can be replaced with a constant 2.
1954 //
1955 // Or,
1956 //
1957 // %shape = tf.Shape(%arg) // %arg: tensor<?x2x3x1xf32>
1958 // %spatial_shape = tf.StridedSlice(%shape, 1, 3, 1)
1959 //
1960 // In this case %spatial_shape can be replaced with a constant [2, 3].
1961
1962 // Input to strided slice op is defined by shape operation.
1963 auto shape_op = input().getDefiningOp<ShapeOp>();
1964 if (!shape_op) {
1965 return {};
1966 }
1967
1968 // `begin`, `end` and `strides` should be constant in order to infer static
1969 // dimension.
1970 DenseIntElementsAttr begin_attr, end_attr, strides_attr;
1971 if (!matchPattern(begin(), m_Constant(&begin_attr)) ||
1972 !matchPattern(end(), m_Constant(&end_attr)) ||
1973 !matchPattern(strides(), m_Constant(&strides_attr)) ||
1974 begin_attr.getNumElements() != 1 || end_attr.getNumElements() != 1 ||
1975 strides_attr.getNumElements() != 1) {
1976 return {};
1977 }
1978
1979 // Do not fold when `new_axis_mask` is set, as it is likely to change the
1980 // output shape. Typically, `new_axis_mask` is not set in this
1981 // canonicalization pattern.
1982 if (new_axis_mask() != 0) return {};
1983
1984 auto tensor_ty = shape_op.input().getType().dyn_cast<RankedTensorType>();
1985 // Only ranked tensor can be folded.
1986 if (!tensor_ty) return {};
1987
1988 int64_t rank = tensor_ty.getRank();
1989 int64_t begin_int = begin_attr.getValue<APInt>(0).getSExtValue();
1990 int64_t end_int = end_attr.getValue<APInt>(0).getSExtValue();
1991 int64_t strides_int = strides_attr.getValue<APInt>(0).getSExtValue();
1992
1993 // Canonicalize `begin` and `end` in case of negative index.
1994 if (begin_int < 0) begin_int += rank;
1995 if (end_int < 0) end_int += rank;
1996
1997 // Create `begin` and `end` from `*_mask`. Note that we don't care about
1998 // `new_axis_mask` as it can be inferred from `output_ty`.
1999 if (shrink_axis_mask() == 1) {
2000 // When `shrink_axis_mask` is set, output is always a scalar so only
2001 // one element is sliced.
2002 end_int = begin_int + 1;
2003 }
2004 if (begin_mask() == 1) {
2005 begin_int = (strides_int > 0) ? 0 : rank - 1;
2006 }
2007 if (end_mask() == 1) {
2008 end_int = (strides_int > 0) ? rank : -1;
2009 }
2010 if (ellipsis_mask() == 1) {
2011 begin_int = 0;
2012 end_int = rank;
2013 }
2014
2015 // It's possible that `begin` and `end` are out of bound. See
2016 // https://docs.python.org/3/library/stdtypes.html#common-sequence-operations.
2017 if (strides_int > 0) {
2018 begin_int = std::min(begin_int, rank);
2019 end_int = std::min(end_int, rank);
2020 } else {
2021 begin_int = std::min(begin_int, rank - 1);
2022 end_int = std::min(end_int, rank - 1);
2023 }
2024
2025 SmallVector<int64_t, 2> sub_shape;
2026 // Only handle cases that have something to slice to avoid infinite for-loop.
2027 if ((end_int > begin_int && strides_int > 0) ||
2028 (end_int < begin_int && strides_int < 0)) {
2029 // Extract sub-shape only if all of those dimensions are static.
2030 for (int64_t i = begin_int; (strides_int > 0) ? i < end_int : i > end_int;
2031 i += strides_int) {
2032 if (tensor_ty.isDynamicDim(i)) {
2033 return {};
2034 }
2035 sub_shape.push_back(tensor_ty.getDimSize(i));
2036 }
2037 }
2038
2039 // For unranked or dynamic output, we infer the output type to either a
2040 // scalar or a vector based on `shrink_axis_mask` because we have rejected
2041 // the case of `new_axis_mask` != 0.
2042 auto output_elt_ty = output().getType().cast<ShapedType>().getElementType();
2043 auto output_ty = output().getType().dyn_cast<RankedTensorType>();
2044 if (!output_ty || !output_ty.hasStaticShape()) {
2045 if (shrink_axis_mask() == 1) {
2046 output_ty = RankedTensorType::get({}, output_elt_ty);
2047 } else {
2048 output_ty = RankedTensorType::get(
2049 {static_cast<int64_t>(sub_shape.size())}, output_elt_ty);
2050 }
2051 }
2052
2053 // Down-cast to 32 bit int if needed.
2054 if (output_elt_ty.isInteger(32)) {
2055 SmallVector<int32_t, 2> sub_shape_i32(sub_shape.size());
2056 std::transform(sub_shape.begin(), sub_shape.end(), sub_shape_i32.begin(),
2057 [](int64_t d) { return static_cast<int32_t>(d); });
2058 return DenseIntElementsAttr::get(output_ty, sub_shape_i32);
2059 }
2060 return DenseIntElementsAttr::get(output_ty, sub_shape);
2061 }
2062
2063 //===----------------------------------------------------------------------===//
2064 // StridedSliceGradOp
2065 //===----------------------------------------------------------------------===//
2066
2067 static LogicalResult Verify(StridedSliceGradOp op) {
2068 auto shape_type = op.shape().getType().dyn_cast<RankedTensorType>();
2069 if (shape_type && shape_type.getRank() != 1)
2070 return op.emitOpError("'shape' operand must be 1D tensor, but got ")
2071 << shape_type.getRank() << "D tensor";
2072
2073 if (failed(VerifyStridedSliceBase(op))) return failure();
2074
2075 // TODO(antiagainst): verify the gradient op.dy()'s shape is consistent with
2076 // the sliced type from StridedSlice.
2077
2078 return success();
2079 }
2080
2081 bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges(
2082 SmallVectorImpl<int64_t> *input_shape,
2083 SmallVectorImpl<int64_t> *slice_begin, SmallVectorImpl<int64_t> *slice_end,
2084 SmallVectorImpl<int64_t> *slice_stride) {
2085 DenseIntElementsAttr shape_attr;
2086 DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr;
2087 if (!matchPattern(shape(), m_Constant(&shape_attr)) ||
2088 !matchPattern(begin(), m_Constant(&sparse_begin_attr)) ||
2089 !matchPattern(end(), m_Constant(&sparse_end_attr)) ||
2090 !matchPattern(strides(), m_Constant(&sparse_strides_attr)))
2091 return false;
2092
2093 int rank = std::distance(shape_attr.begin(), shape_attr.end());
2094
2095 input_shape->clear();
2096 input_shape->reserve(rank);
2097 for (const APInt &dim : shape_attr)
2098 input_shape->push_back(dim.getSExtValue());
2099
2100 SmallVector<int64_t, 4> sparse_begin, sparse_end, sparse_strides;
2101
2102 for (const APInt &index : sparse_begin_attr)
2103 sparse_begin.push_back(index.getSExtValue());
2104 for (const APInt &index : sparse_end_attr)
2105 sparse_end.push_back(index.getSExtValue());
2106 for (const APInt &stride : sparse_strides_attr)
2107 sparse_strides.push_back(stride.getSExtValue());
2108
2109 CalculateSlicedShapeFromSparseIndices(
2110 *input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(),
2111 end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(),
2112 slice_begin, slice_end, slice_stride);
2113 return true;
2114 }
2115
2116 //===----------------------------------------------------------------------===//
2117 // SummaryWriterOp
2118 //===----------------------------------------------------------------------===//
2119
2120 llvm::SmallVector<ResourceHandleValueAndId, 4>
2121 SummaryWriterOp::GetResourceHandleValueAndIdList(
2122 llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
2123 int64_t &next_id) {
2124 llvm::StringRef device = GetDeviceOrEmpty(getOperation());
2125 return {GetResourceHandleValueAndIdBase(container(), shared_name(), device,
2126 writer(), resource_handle_id_map,
2127 next_id)};
2128 }
2129
2130 //===----------------------------------------------------------------------===//
2131 // TPUExecuteOp
2132 //===----------------------------------------------------------------------===//
2133
2134 void TPUExecuteOp::getEffects(
2135 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2136 &effects) {
2137 effects.reserve(args().size() + 1);
2138
2139 // There may be some TPU Embedding ops in the computation, so this effect is
2140 // added conservatively.
2141 effects.emplace_back(MemoryEffects::Write::get(),
2142 ResourceEffects::TPUEmbedding::get());
2143
2144 for (Value value : args()) {
2145 if (value.getType()
2146 .cast<TensorType>()
2147 .getElementType()
2148 .isa<ResourceType>()) {
2149 // Conservatively mark resource handles as read and write, as without
2150 // analyzing TPUCompile, there is not sufficient information to determine
2151 // effects on resources. For the MLIR bridge, this op will never be
2152 // populated with resource handles and tf.TPUExecuteAndUpdateVariables is
2153 // used instead.
2154 effects.emplace_back(MemoryEffects::Read::get(), value,
2155 ResourceEffects::Variable::get());
2156 effects.emplace_back(MemoryEffects::Write::get(), value,
2157 ResourceEffects::Variable::get());
2158 }
2159 }
2160 }
2161
2162 //===----------------------------------------------------------------------===//
2163 // TPUExecuteAndUpdateVariablesOp
2164 //===----------------------------------------------------------------------===//
2165
2166 static LogicalResult Verify(TPUExecuteAndUpdateVariablesOp op) {
2167 int num_resource_args = 0;
2168 for (Type arg_type : op.args().getTypes())
2169 if (arg_type.cast<TensorType>().getElementType().isa<ResourceType>())
2170 ++num_resource_args;
2171
2172 auto check_attr = [&](ArrayAttr indices, llvm::StringRef name,
2173 int min) -> LogicalResult {
2174 if (indices.size() != num_resource_args)
2175 return op.emitOpError()
2176 << "requires '" << name
2177 << "' to be the same size as number of resource handles in 'args' "
2178 "("
2179 << num_resource_args << "), but got " << indices.size();
2180
2181 for (auto entry : llvm::enumerate(indices.getValue())) {
2182 auto int_attr = entry.value().cast<IntegerAttr>();
2183 if (int_attr.getInt() < min)
2184 return op.emitOpError()
2185 << "requires '" << name << "' to contain values of at least "
2186 << min << ", but got " << int_attr.getInt() << " at index "
2187 << entry.index();
2188 }
2189
2190 return success();
2191 };
2192
2193 return failure(
2194 failed(check_attr(op.device_var_reads_indices(),
2195 /*name=*/"device_var_reads_indices", /*min=*/0)) ||
2196 failed(check_attr(op.device_var_updates_indices(),
2197 /*name=*/"device_var_updates_indices", /*min=*/-1)));
2198 }
2199
2200 void TPUExecuteAndUpdateVariablesOp::getEffects(
2201 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2202 &effects) {
2203 effects.reserve(device_var_reads_indices().size() + 1);
2204
2205 // There may be some TPU Embedding ops in the computation, so this effect is
2206 // added conservatively.
2207 effects.emplace_back(MemoryEffects::Write::get(),
2208 ResourceEffects::TPUEmbedding::get());
2209 auto resource_handles = llvm::make_filter_range(args(), [](Value value) {
2210 return value.getType()
2211 .cast<TensorType>()
2212 .getElementType()
2213 .isa<ResourceType>();
2214 });
2215
2216 for (auto &entry : llvm::enumerate(resource_handles)) {
2217 Value value = entry.value();
2218 effects.emplace_back(MemoryEffects::Read::get(), value,
2219 ResourceEffects::Variable::get());
2220 if (device_var_updates_indices()
2221 .getValue()[entry.index()]
2222 .cast<IntegerAttr>()
2223 .getInt() >= 0)
2224 effects.emplace_back(MemoryEffects::Write::get(), value,
2225 ResourceEffects::Variable::get());
2226 }
2227 }
2228
2229 //===----------------------------------------------------------------------===//
2230 // TensorListReserveOp
2231 //===----------------------------------------------------------------------===//
2232
2233 static LogicalResult Verify(TensorListReserveOp op) {
2234 if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
2235 !IsOfRankOrUnranked(op.element_shape(), 1)) {
2236 return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
2237 }
2238
2239 if (!IsOfRankOrUnranked(op.num_elements(), 0)) {
2240 return op.emitOpError("requires num_elements operand to be 0D tensor");
2241 }
2242 return success();
2243 }
2244
2245 //===----------------------------------------------------------------------===//
2246 // TensorListElementShapeOp
2247 //===----------------------------------------------------------------------===//
2248
fold(ArrayRef<Attribute> operands)2249 OpFoldResult TensorListElementShapeOp::fold(ArrayRef<Attribute> operands) {
2250 int width =
2251 getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
2252 auto variant_type =
2253 getElementTypeOrSelf(getOperand().getType()).cast<TF::VariantType>();
2254 if (variant_type.getSubtypes().empty()) return {};
2255 return ConvertShapeToAttr(variant_type.getSubtypes()[0], width);
2256 }
2257
2258 //===----------------------------------------------------------------------===//
2259 // TensorListStackOp
2260 //===----------------------------------------------------------------------===//
2261
2262 static LogicalResult Verify(TensorListStackOp op) {
2263 if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
2264 !IsOfRankOrUnranked(op.element_shape(), 1)) {
2265 return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
2266 }
2267 return success();
2268 }
2269
2270 //===----------------------------------------------------------------------===//
2271 // TensorScatterUpdateOp
2272 //===----------------------------------------------------------------------===//
2273
2274 static LogicalResult Verify(TensorScatterUpdateOp op) {
2275 if (!HasRankAtLeast(op.tensor(), 1))
2276 return op.emitOpError(
2277 "requires tensor operand to have at least 1 dimension");
2278 if (!HasRankAtLeast(op.indices(), 1))
2279 return op.emitOpError(
2280 "requires indices operand to have at least 1 dimension");
2281 if (!HasRankAtLeast(op.updates(), 1))
2282 return op.emitOpError(
2283 "requires updates operand to have at least 1 dimension");
2284
2285 auto tensor_ty = op.tensor().getType().dyn_cast<RankedTensorType>();
2286 auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
2287 if (!tensor_ty || !indices_ty) return success();
2288
2289 int64_t num_index_dims = indices_ty.getShape().back();
2290 if (ShapedType::isDynamic(num_index_dims)) return success();
2291
2292 if (num_index_dims > tensor_ty.getRank())
2293 return op.emitOpError(
2294 "requires tensor operand with rank greater than or equal to the "
2295 "indices operand's last dimensions");
2296 return success();
2297 }
2298
2299 //===----------------------------------------------------------------------===//
2300 // TileOp
2301 //===----------------------------------------------------------------------===//
2302
2303 // Verifies that,
2304 //
2305 // - input has at least rank 1
2306 // - multiples is rank 1
2307 // - multiples.size() == input.rank()
2308 // - input.rank() == output.rank()
2309 // - Elements in multiples are non-negative
2310 // - input.shape[i] * multiples[i] == output.shape[i]
2311 // for i in [0, input.rank() - 1]
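//
// Illustrative example: input tensor<2x3xf32>, multiples = [2, 1], and output
// tensor<4x3xf32> satisfy all of the above; an output of tensor<4x4xf32>
// would violate the last condition for i = 1.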
2312
2313 static LogicalResult Verify(TileOp op) {
2314 auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
2315 auto multiples_type = op.multiples().getType().dyn_cast<RankedTensorType>();
2316 auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
2317
2318 if (multiples_type && multiples_type.getRank() != 1) {
2319 return op.emitOpError() << "expected multiples to be rank 1, got rank = "
2320 << multiples_type.getRank();
2321 }
2322
2323 if (input_type && multiples_type && multiples_type.hasStaticShape() &&
2324 (input_type.getRank() != multiples_type.getNumElements() ||
2325 (input_type.getRank() == 0 && multiples_type.getNumElements() == 1))) {
2326 return op.emitOpError()
2327 << "expected size of multiples equal to rank of input"
2328 << ", got multiples of size " << multiples_type.getNumElements()
2329 << ", and input of rank " << input_type.getRank();
2330 }
2331
2332 if (input_type && output_type) {
2333 if (input_type.getRank() != output_type.getRank()) {
2334 return op.emitOpError()
2335 << "expected rank of input to equal to rank of output"
2336 << ", got input of rank " << input_type.getRank()
2337 << ", and output of rank " << output_type.getRank();
2338 }
2339
2340 DenseIntElementsAttr multiples_attr;
2341 if (matchPattern(op.multiples(), m_Constant(&multiples_attr))) {
2342 for (int32_t i = 0, e = input_type.getRank(); i < e; ++i) {
2343 const int64_t input_dim = input_type.getDimSize(i);
2344 const int64_t output_dim = output_type.getDimSize(i);
2345 const int64_t m = multiples_attr.getValue<APInt>(i).getSExtValue();
2346
2347 if (m < 0) {
2348 return op.emitOpError()
2349 << "expected multiples to be non-negative, got "
2350 << "multiples[" << i << "] = " << m;
2351 }
2352
2353 if (!ShapedType::isDynamic(input_dim) &&
2354 !ShapedType::isDynamic(output_dim) && output_dim != input_dim * m) {
2355 return op.emitOpError()
2356 << "requires input.shape[" << i << "] (" << input_dim << ")"
2357 << " * " << m << " to be equal to "
2358 << "output.shape[" << i << "] (" << output_dim << ")";
2359 }
2360 }
2361 }
2362 }
2363
2364 return success();
2365 }
2366
fold(ArrayRef<Attribute> operands)2367 OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
2368 DenseIntElementsAttr multiples_attr;
2369 if (matchPattern(multiples(), m_Constant(&multiples_attr))) {
2370 // Return the input directly when multiples are all ones, regardless of what
2371 // the input is.
2372 if (multiples_attr.isSplat() &&
2373 multiples_attr.getSplatValue<APInt>().getSExtValue() == 1) {
2374 return input();
2375 }
2376 }
2377 return {};
2378 }
2379
2380 //===----------------------------------------------------------------------===//
2381 // TopKV2Op
2382 //===----------------------------------------------------------------------===//
2383
2384 static LogicalResult Verify(TopKV2Op op) {
2385 if (!HasRankAtLeast(op.input(), 1))
2386 return op.emitOpError(
2387 "requires input operand to have at least 1 dimension");
2388
2389 if (!IsOfRankOrUnranked(op.k(), 0))
2390 return op.emitOpError("requires k operand to be 0D tensor");
2391
2392 return success();
2393 }
2394
2395 //===----------------------------------------------------------------------===//
2396 // ToBoolOp
2397 //===----------------------------------------------------------------------===//
2398
2399 namespace {
2400 // If the input to ToBoolOp is a ranked tensor, then the ToBoolOp can be folded
2401 // into an identity or an equality comparison.
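// For example (illustrative), a 0-D tensor<i1> operand folds to itself, a
// 0-D tensor<f32> operand is rewritten to "element != 0", and a
// tensor<0x4xf32> operand becomes a constant false since it has no elements.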
2402 class ToBoolOfRankedTensor : public OpRewritePattern<ToBoolOp> {
2403 using OpRewritePattern<ToBoolOp>::OpRewritePattern;
2404 LogicalResult matchAndRewrite(ToBoolOp op,
2405 PatternRewriter &rewriter) const override {
2406 auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
2407 // If the input is an unranked tensor, we cannot rewrite it.
2408 if (!type) return failure();
2409
2410 // Expected return type of the ToBool operation. The return type of the
2411 // ToBool operation is always a 0D tensor of bool type.
2412 auto result_type = op.getResult().getType().cast<RankedTensorType>();
2413
2414 // If input is already a tensor<i1>, it can be folded into an identity.
2415 if (type == result_type) {
2416 rewriter.replaceOp(op, op.getOperand());
2417 return success();
2418 }
2419
2420 if (type.getRank() == 0) {
2421 // If the input is a scalar tensor, the ToBool can be expanded to
2422 // element != 0 (for numerical values) or element != "" (for strings).
2423 Type element_type = type.getElementType();
2424 Attribute zero_attr;
2425 if (element_type.isIntOrFloat())
2426 zero_attr = rewriter.getZeroAttr(type);
2427 else if (element_type.isa<TF::StringType>())
2428 zero_attr = DenseStringElementsAttr::get(type, {""});
2429
2430 if (!zero_attr) return failure();
2431
2432 auto zero_const = rewriter.create<TF::ConstOp>(op.getLoc(), zero_attr);
2433 rewriter.replaceOpWithNewOp<TF::NotEqualOp>(
2434 op, result_type, op.getOperand(), zero_const, false);
2435 } else {
2436 // If the input is a non-scalar ranked tensor, ToBool can be expanded
2437 // to numElements != 0. numElements will be 0 iff one of the dimensions is
2438 // zero.
2439 bool any_zero =
2440 llvm::any_of(type.getShape(), [](int64_t dim) { return dim == 0; });
2441 rewriter.replaceOpWithNewOp<TF::ConstOp>(
2442 op, result_type, DenseElementsAttr::get(result_type, {!any_zero}));
2443 }
2444 return success();
2445 }
2446 };
2447 } // namespace
2448
2449 void ToBoolOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2450 MLIRContext *context) {
2451 results.insert<ToBoolOfRankedTensor>(context);
2452 }
2453
2454 LogicalResult ToBoolOp::inferReturnTypes(
2455 MLIRContext *context, Optional<Location> location, ValueRange operands,
2456 DictionaryAttr attributes, RegionRange regions,
2457 SmallVectorImpl<Type> &inferredReturnTypes) {
2458 inferredReturnTypes.push_back(
2459 RankedTensorType::get({}, IntegerType::get(context, 1)));
2460 return success();
2461 }
2462
2463 //===----------------------------------------------------------------------===//
2464 // TransposeOp
2465 //===----------------------------------------------------------------------===//
2466
2467 static LogicalResult Verify(TransposeOp op) {
2468 auto perm_type = op.perm().getType().dyn_cast<RankedTensorType>();
2469 auto x_type = op.x().getType().dyn_cast<RankedTensorType>();
2470 auto y_type = op.y().getType().dyn_cast<RankedTensorType>();
2471
2472 if (perm_type && perm_type.getRank() != 1) {
2473 return op.emitOpError()
2474 << "expected perm to be a 1-D Tensor, got perm of rank "
2475 << perm_type.getRank();
2476 }
2477
2478 if (x_type && y_type && x_type.getRank() != y_type.getRank()) {
2479 return op.emitOpError() << "x should be of the same rank with y, got "
2480 << "x of rank " << x_type.getRank()
2481 << ", and y of rank " << y_type.getRank();
2482 }
2483
2484 if (!x_type || !y_type || !perm_type || !perm_type.hasStaticShape()) {
2485 return success();
2486 }
2487
2488 if (x_type.getRank() != perm_type.getNumElements()) {
2489 return op.emitOpError() << "expected perm to be a 1-D Tensor of size "
2490 << "equal to the rank of x, got perm of size "
2491 << perm_type.getNumElements() << ", and x of rank "
2492 << x_type.getRank();
2493 }
2494
2495 DenseIntElementsAttr attr_perm;
2496 if (matchPattern(op.perm(), m_Constant(&attr_perm))) {
2497 // y.shape[i] should be equal to x.shape[perm[i]]
2498 // for i = [0, 1, ..., rank(x) - 1]
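// For example (illustrative), x of type tensor<2x3x4xf32> with
// perm = [2, 0, 1] requires y to be tensor<4x2x3xf32>.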
2499 for (auto e : llvm::enumerate(attr_perm)) {
2500 const int64_t y_idx = e.index();
2501 const int64_t y_dim = y_type.getDimSize(y_idx);
2502 const int64_t x_idx = e.value().getSExtValue();
2503 const int64_t x_dim = x_type.getDimSize(x_idx);
2504 if (y_dim != ShapedType::kDynamicSize &&
2505 x_dim != ShapedType::kDynamicSize && y_dim != x_dim) {
2506 return op.emitOpError()
2507 << "requires y.shape[" << y_idx << "] (" << y_dim << ") "
2508 << "to be equal to x.shape[perm[" << x_idx << "]] "
2509 << "(" << x_dim << ")";
2510 }
2511 }
2512 }
2513
2514 return success();
2515 }
2516
2517 // TODO(jpienaar): perm could be optional too.
2518 void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x,
2519 Value perm) {
2520 auto x_type = x.getType().cast<TensorType>();
2521 // If value is unranked, then so is results.
2522 if (!x_type.hasRank())
2523 return TransposeOp::build(builder, result,
2524 UnrankedTensorType::get(x_type.getElementType()),
2525 x, perm);
2526
2527 // TODO(jpienaar): Handle unknown perm case.
2528
2529 // TODO(jpienaar): Extract utility function.
2530 auto etype = x_type.cast<ShapedType>().getElementType();
2531 DenseIntElementsAttr attr_shape;
2532 if (matchPattern(perm, m_Constant(&attr_shape))) {
2533 llvm::SmallVector<int64_t, 4> const_shape;
2534 if (attr_shape.isSplat()) {
2535 const_shape.assign(
2536 attr_shape.getNumElements(),
2537 x_type.getDimSize((*attr_shape.begin()).getSExtValue()));
2538 } else {
2539 const_shape.reserve(attr_shape.getNumElements());
2540 for (const auto &dim : attr_shape)
2541 const_shape.push_back(x_type.getDimSize(dim.getSExtValue()));
2542 }
2543 return TransposeOp::build(
2544 builder, result, RankedTensorType::get(const_shape, etype), x, perm);
2545 }
2546 return TransposeOp::build(builder, result, UnrankedTensorType::get(etype), x,
2547 perm);
2548 }
2549
2550 namespace {
2551
2552 OpFoldResult FoldIdentityTranspose(TransposeOp op) {
2553 DenseIntElementsAttr perm;
2554 if (!matchPattern(op.perm(), m_Constant(&perm))) return {};
2555 const auto elements = perm.getValues<APInt>();
2556
2557 for (auto it : llvm::enumerate(elements)) {
2558 if (it.index() != it.value()) return {};
2559 }
2560
2561 // TODO(jpienaar): Remove if/when we handle this more generally.
2562 if (op.getType() != op.x().getType()) {
2563 // If the types don't match then only fold if all the users are in the TF
2564 // dialect.
2565 for (auto user : op.getOperation()->getUsers())
2566 if (user->getDialect() != op->getDialect()) return {};
2567 }
2568
2569 return op.x();
2570 }
2571
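// Folds a transpose of a transpose: if the operand of this tf.Transpose is
// itself a tf.Transpose and the two constant permutations cancel each other
// (illustrative example: perm0 = [1, 0] and perm1 = [1, 0]), the op folds to
// the inner transpose's input.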
2572 OpFoldResult FoldCancellableTranspose(TransposeOp op) {
2573 // Operand is a TransposeOp.
2574 auto transpose = dyn_cast_or_null<TF::TransposeOp>(op.x().getDefiningOp());
2575 if (!transpose) return {};
2576
2577 // Permutations defined by constant operations.
2578 DenseIntElementsAttr perm0;
2579 DenseIntElementsAttr perm1;
2580 if (!matchPattern(op.perm(), m_Constant(&perm0)) ||
2581 !matchPattern(transpose.perm(), m_Constant(&perm1)))
2582 return {};
2583
2584 // With permutation indices that cancel each other
2585 if (!AreCancellablePermutations(perm0, perm1)) return {};
2586
2587 return transpose.x();
2588 }
2589
2590 } // namespace
2591
fold(ArrayRef<Attribute> operands)2592 OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
2593 if (auto folded = FoldIdentityTranspose(*this)) return folded;
2594 if (auto folded = FoldCancellableTranspose(*this)) return folded;
2595 return {};
2596 }
2597
2598 //===----------------------------------------------------------------------===//
2599 // TruncateDivOp
2600 //===----------------------------------------------------------------------===//
2601
2602 void TruncateDivOp::getCanonicalizationPatterns(
2603 OwningRewritePatternList &results, MLIRContext *context) {
2604 results.insert<TruncateDivWithSqrtDivisor>(context);
2605 }
2606
2607 //===----------------------------------------------------------------------===//
2608 // NonMaxSuppressionV3Op
2609 //===----------------------------------------------------------------------===//
2610
2611 namespace {
2612
2613 // Canonicalize NonMaxSuppressionV3Op to NonMaxSuppressionV4Op.
2614 class NMSV3ToNMSV4Op : public OpRewritePattern<NonMaxSuppressionV3Op> {
2615 using OpRewritePattern<NonMaxSuppressionV3Op>::OpRewritePattern;
2616 LogicalResult matchAndRewrite(NonMaxSuppressionV3Op nms_op,
2617 PatternRewriter &rewriter) const override {
2618 if (nms_op.getNumOperands() != 5) {
2619 return failure();
2620 }
2621 SmallVector<Type, 2> new_result_types;
2622 new_result_types.push_back(nms_op.getType());
2623 auto input_ty = nms_op.getType().template cast<ShapedType>();
2624 // Corresponds to the second result type of NonMaxSuppressionV4.
2625 RankedTensorType valid_output_type =
2626 RankedTensorType::get({}, input_ty.getElementType());
2627 new_result_types.push_back(valid_output_type);
2628
2629 auto nmsv4 = rewriter.create<TF::NonMaxSuppressionV4Op>(
2630 nms_op.getLoc(), new_result_types, nms_op.boxes(), nms_op.scores(),
2631 nms_op.max_output_size(), nms_op.iou_threshold(),
2632 nms_op.score_threshold());
2633 // Cannot directly replace the NMSv3 op with NMSv4 since their outputs
2634 // differ (v4 produces two output values whereas v3 produces only one).
2635 nms_op.replaceAllUsesWith(nmsv4.getResult(0));
2636 return success();
2637 }
2638 };
2639 } // namespace.
2640
2641 void NonMaxSuppressionV3Op::getCanonicalizationPatterns(
2642 OwningRewritePatternList &results, MLIRContext *context) {
2643 results.insert<NMSV3ToNMSV4Op>(context);
2644 }
2645
2646 //===----------------------------------------------------------------------===//
2647 // FusedBatchNormOp
2648 //===----------------------------------------------------------------------===//
2649
2650 namespace {
2651
2652 class ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
2653 using OpRewritePattern<FusedBatchNormOp>::OpRewritePattern;
2654 LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op,
2655 PatternRewriter &rewriter) const override {
2656 auto new_result_types =
2657 llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes());
2658 // reserve_space_3
2659 new_result_types.push_back(
2660 UnrankedTensorType::get(FloatType::getF32(rewriter.getContext())));
2661
2662 OperationState new_state(tf_fused_batch_norm_op.getLoc(),
2663 TF::FusedBatchNormV3Op::getOperationName(),
2664 tf_fused_batch_norm_op.getOperands(),
2665 new_result_types,
2666 tf_fused_batch_norm_op->getAttrs());
2667 Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state);
2668
2669 rewriter.replaceOp(tf_fused_batch_norm_op,
2670 tf_fused_batch_norm_op_v3->getResults().drop_back());
2671 return success();
2672 }
2673 };
2674 } // namespace.
2675
2676 void FusedBatchNormOp::getCanonicalizationPatterns(
2677 OwningRewritePatternList &results, MLIRContext *context) {
2678 results.insert<ConvertFusedBatchNorm>(context);
2679 }
2680
2681 //===----------------------------------------------------------------------===//
2682 // UnpackOp
2683 //===----------------------------------------------------------------------===//
2684
2685 static LogicalResult Verify(UnpackOp op) {
2686 auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
2687 if (!value_type) return success();
2688
2689 int64_t value_rank = value_type.getRank();
2690 int64_t axis = op.axis();
2691 if (axis < -value_rank || axis >= value_rank)
2692 return op.emitOpError("axis attribute must be in the range of [-")
2693 << value_rank << ", " << value_rank << ')';
2694
2695 axis = GetDimForAxis(axis, value_rank);
2696 int64_t dim_size = value_type.getDimSize(axis);
2697 if (ShapedType::isDynamic(dim_size)) return success();
2698
2699 if (dim_size != op.getNumResults())
2700 return op.emitOpError("result count must be equal to ") << dim_size;
2701
2702 return success();
2703 }
2704
2705 namespace {
2706
2707 // Hoist coefficient-wise unary operation out of the Unpack op:
2708 //
2709 // %unpacked:N = "tf.Unpack"(%0)
2710 // %neg0 = "tf.Neg"(%unpacked#0)
2711 // %neg1 = "tf.Neg"(%unpacked#1)
2712 // ...
2713 // %negN-1 = "tf.Neg"(%unpacked#N-1)
2714 //
2715 // Rewrite it to:
2716 //
2717 // %neg = "tf.Neg"(%0)
2718 // %unpacked:N = "tf.Unpack"(%neg)
2719 class HoistCwiseUnaryOutOfUnpack : public OpRewritePattern<UnpackOp> {
2720 public:
2721 explicit HoistCwiseUnaryOutOfUnpack(MLIRContext *context)
2722 : OpRewritePattern<UnpackOp>(context) {}
2723 LogicalResult matchAndRewrite(UnpackOp op,
2724 PatternRewriter &rewriter) const override;
2725 };
2726
2727 LogicalResult HoistCwiseUnaryOutOfUnpack::matchAndRewrite(
2728 UnpackOp op, PatternRewriter &rewriter) const {
2729 auto loc = op.getLoc();
2730
2731 // First unpack user must be a coeff-wise unary operation. An unpack with no
2732 // users cannot be matched (guard against dereferencing an empty user range).
2733 if (op.use_empty()) return failure();
Operation *first_user = *op->getUsers().begin();
if (!first_user->hasTrait<OpTrait::TF::CwiseUnary>()) return failure();
2734
2735 // All unpack users must be defined by the op of same kind.
2736 bool users_same_op = llvm::all_of(op->getUsers(), [&](Operation *user) {
2737 return user->getName() == first_user->getName();
2738 });
2739 if (!users_same_op) return failure();
2740
2741 // Pass unpack operand to unary operation.
2742 OperationState new_unary_op_state(loc, first_user->getName().getStringRef(),
2743 op.getOperand(), op.getOperand().getType(),
2744 ArrayRef<NamedAttribute>());
2745 Operation *new_unary_op = rewriter.createOperation(new_unary_op_state);
2746
2747 // Unpack results after applying unary operation.
2748 auto unpack_unary_op = rewriter.create<UnpackOp>(
2749 loc, op.getResultTypes(), new_unary_op->getResult(0), op.axis());
2750
2751 // Bypass all users of the original unpack operation and use `unpack_unary_op`
2752 // results instead.
2753 for (auto pair : llvm::zip(op.getResults(), unpack_unary_op.getResults())) {
2754 OpResult old_result = std::get<0>(pair); // result of original Unpack
2755 OpResult new_result = std::get<1>(pair); // result of transformed Unpack
2756 for (Operation *user : llvm::make_early_inc_range(old_result.getUsers()))
2757 rewriter.replaceOp(user, ValueRange(new_result));
2758 }
2759
2760 // Erase original unpack operation.
2761 rewriter.eraseOp(op.getOperation());
2762
2763 return success();
2764 }
2765
2766 } // namespace
2767
2768 void UnpackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2769 MLIRContext *context) {
2770 results.insert<HoistCwiseUnaryOutOfUnpack>(context);
2771 }
2772
2773 //===----------------------------------------------------------------------===//
2774 // Unsorted segment reduction ops
2775 //===----------------------------------------------------------------------===//
2776
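// Example of the prefix-shape requirement verified below (shapes are
// illustrative): data of type tensor<4x5x6xf32> with segment_ids of type
// tensor<4x5xi32> is accepted, whereas segment_ids of type tensor<4x7xi32> is
// rejected because dimension #1 (7 vs. 5) does not match the data shape.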
2777 template <class Op>
2778 static LogicalResult VerifyUnsortedSegmentReduction(Op op) {
2779 if (!HasRankAtMost(op.num_segments(), 0))
2780 return op.emitOpError("number of segments should be a 0-D tensor");
2781
2782 auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
2783 auto segment_ids_type =
2784 op.segment_ids().getType().template dyn_cast<RankedTensorType>();
2785 if (data_type && segment_ids_type) {
2786 if (data_type.getRank() < segment_ids_type.getRank())
2787 return op.emitOpError(
2788 "requires segment ids rank to be less than or equal to data's rank");
2789
2790 int index = 0;
2791 for (auto shape_pair :
2792 llvm::zip_first(segment_ids_type.getShape(), data_type.getShape())) {
2793 int64_t segment_id_dim = std::get<0>(shape_pair);
2794 int64_t data_dim = std::get<1>(shape_pair);
2795 if (!ShapedType::isDynamic(segment_id_dim) &&
2796 !ShapedType::isDynamic(data_dim) && segment_id_dim != data_dim)
2797 return op.emitOpError(
2798 "requires segment ids shape to be a prefix of data shape, "
2799 "but dimension #")
2800 << index << " differs: " << segment_id_dim << " vs. "
2801 << data_dim;
2802 ++index;
2803 }
2804 }
2805
2806 DenseIntElementsAttr num_segments_attr;
2807 if (matchPattern(op.num_segments(), m_Constant(&num_segments_attr))) {
2808 int64_t num_segments = (*num_segments_attr.begin()).getSExtValue();
2809 if (num_segments < 0)
2810 return op.emitOpError("num of segments cannot be negative");
2811 }
2812
2813 return success();
2814 }
2815
2816 //===----------------------------------------------------------------------===//
2817 // VarHandleOp
2818 //===----------------------------------------------------------------------===//
2819
2820 llvm::SmallVector<ResourceHandleValueAndId, 4>
2821 VarHandleOp::GetResourceHandleValueAndIdList(
2822 llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
2823 int64_t &next_id) {
2824 llvm::StringRef device = GetDeviceOrEmpty(getOperation());
2825 return {GetResourceHandleValueAndIdBase(container(), shared_name(), device,
2826 resource(), resource_handle_id_map,
2827 next_id)};
2828 }
2829
2830 //===----------------------------------------------------------------------===//
2831 // VarIsInitializedOp
2832 //===----------------------------------------------------------------------===//
2833
2834 namespace {
2835
2836 /// Erase VarIsInitializedOp operations with no uses. This op has a read-only
2837 /// side effect on resources, but it can still be deleted if it has zero uses.
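/// For example, a "tf.VarIsInitializedOp" whose result is never used can be
/// erased even though it reads the resource, since dropping a read cannot
/// change observable behavior.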
2838 struct EraseDeadVarIsInitializedOp
2839 : public OpRewritePattern<VarIsInitializedOp> {
2840 using OpRewritePattern<VarIsInitializedOp>::OpRewritePattern;
2841
2842 LogicalResult matchAndRewrite(VarIsInitializedOp op,
2843 PatternRewriter &rewriter) const override {
2844 if (!op.use_empty()) return failure();
2845 rewriter.eraseOp(op);
2846 return success();
2847 }
2848 };
2849 } // end anonymous namespace.
2850
2851 void VarIsInitializedOp::getCanonicalizationPatterns(
2852 OwningRewritePatternList &patterns, MLIRContext *context) {
2853 patterns.insert<EraseDeadVarIsInitializedOp>(context);
2854 }
2855
2856 //===----------------------------------------------------------------------===//
2857 // VariableOp
2858 //===----------------------------------------------------------------------===//
2859
2860 void VariableOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2861 MLIRContext *context) {
2862 results.insert<VariableToVariableV2>(context);
2863 }
2864
2865 //===----------------------------------------------------------------------===//
2866 // VariableShapeOp
2867 //===----------------------------------------------------------------------===//
2868
2869 static LogicalResult Verify(VariableShapeOp op) {
2870 auto input_type = op.input().getType().cast<TensorType>();
2871 if (input_type.hasStaticShape() && input_type.getNumElements() != 1)
2872 return op.emitOpError("requires input to have one resource");
2873
2874 auto resource_type = input_type.getElementType().cast<TF::ResourceType>();
2875 auto subtypes = resource_type.getSubtypes();
2876 switch (subtypes.size()) {
2877 case 1:
2878 return VerifyShapeOperandAndResult(
2879 op, resource_type.getSubtypes().front(), op.getType());
2880 case 0:
2881 return VerifyShapeOperandAndResult(op, Type(), op.getType());
2882 default:
2883 return op.emitOpError(
2884 "requires resource input type to have at most 1 subtype");
2885 }
2886 }
2887
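// Folding sketch for VariableShape (types are illustrative): if the resource
// operand carries a single subtype such as tensor<2x3xf32>, the op folds to a
// constant shape attribute equivalent to dense<[2, 3]>; without a subtype
// there is nothing to fold.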
2888 OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {
2889 int width =
2890 getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
2891 auto resource_type =
2892 getElementTypeOrSelf(getOperand().getType()).cast<TF::ResourceType>();
2893 if (resource_type.getSubtypes().empty()) return {};
2894 return ConvertShapeToAttr(resource_type.getSubtypes()[0], width);
2895 }
2896
2897 //===----------------------------------------------------------------------===//
2898 // WhileOp
2899 //===----------------------------------------------------------------------===//
2900
2901 static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
2902 TypeRange body_input,
2903 TypeRange body_result,
2904 bool shape_invariant) {
2905 const TypeRangeWithDesc input_type = {op->getOperandTypes(), "input"};
2906 const TypeRangeWithDesc result_type = {op->getResultTypes(), "result"};
2907 constexpr int kNumRegionTypeLists = 3;
2908 const std::array<TypeRangeWithDesc, kNumRegionTypeLists> region_types = {{
2909 {body_result, "body result"},
2910 {cond_input, "condition input"},
2911 {body_input, "body input"},
2912 }};
2913
2914 // A pair of type lists should be cast compatible with each other if one is
2915 // converted to the other for a function call or assignment, or if there is a
2916 // common source of inputs for both. Therefore, the While op requires the
2917 // following pairs of type lists to be cast compatible for the tensor_cast
2918 // operation:
2919 //
2920 // * Operands and cond inputs to call the cond function before the
2921 // first iteration.
2922 // * Operands and body inputs to call the body function for the first
2923 // iteration if the cond function returns True or an equivalent result.
2924 // * Operands and results to assign cond function arguments to op results if
2925 // the cond function returns False or an equivalent result. If the op is
2926 // shape invariant, this does not hold as shapes can differ.
2927 // * All pairwise combinations of cond inputs, body inputs, and results, as
2928 // the operands are a common source for all three.
2929 // * Body results and cond inputs to call the cond function for subsequent
2930 // iterations. Similarly, body results should be compatible with body inputs
2931 // and op results.
2932 //
2933 // Note that the operands and body results need not be compatible as they are
2934 // never converted from one to the other, nor is there a common source of
2935 // tensors. The compatibility requirement is not transitive.
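//
// As an illustrative example of the cast compatibility required here (types
// are hypothetical): an operand of type tensor<4xf32> is compatible with a
// cond input of type tensor<?xf32> or tensor<*xf32>, since a tensor_cast can
// convert between them, but it is not compatible with tensor<5xf32>.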
2936
2937 if (!shape_invariant &&
2938 failed(VerifyTypeRangesAreCompatible(op, input_type, result_type)))
2939 return failure();
2940
2941 // Skip the first pair as the While op operands and body function results do
2942 // not need to be compatible with each other.
2943 for (int i = 1; i < kNumRegionTypeLists; ++i)
2944 if (failed(VerifyTypeRangesAreCompatible(op, input_type, region_types[i])))
2945 return failure();
2946
2947 for (int i = 0; i < kNumRegionTypeLists; ++i)
2948 if (failed(VerifyTypeRangesAreCompatible(op, result_type, region_types[i])))
2949 return failure();
2950
2951 for (int i = 0; i < kNumRegionTypeLists; ++i)
2952 for (int j = i + 1; j < kNumRegionTypeLists; ++j)
2953 if (failed(VerifyTypeRangesAreCompatible(op, region_types[i],
2954 region_types[j])))
2955 return failure();
2956
2957 return success();
2958 }
2959
2960 LogicalResult WhileOp::verifySymbolUses(SymbolTableCollection &symbol_table) {
2961 // TODO(jpienaar): Remove.
2962 if (failed(WhileOpAdaptor(*this).verify(getLoc()))) return failure();
2963
2964 auto cond_fn = symbol_table.lookupNearestSymbolFrom<FuncOp>(*this, cond());
2965 auto body_fn = symbol_table.lookupNearestSymbolFrom<FuncOp>(*this, body());
2966 if (!cond_fn) {
2967 return emitOpError("cond refers to an undefined function : ") << cond();
2968 }
2969 if (!body_fn) {
2970 return emitOpError("body refers to an undefined function : ") << body();
2971 }
2972
2973 auto cond_fn_type = cond_fn.getType();
2974 auto body_fn_type = body_fn.getType();
2975
2976 // Verify that the cond function has exactly one result.
2977 if (cond_fn_type.getNumResults() != 1)
2978 return emitOpError("requires cond function to have exactly one result");
2979
2980 return VerifyWhileTypes(*this, /*cond_input=*/cond_fn_type.getInputs(),
2981 /*body_input=*/body_fn_type.getInputs(),
2982 /*body_result=*/body_fn_type.getResults(),
2983 shape_invariant());
2984 }
2985
2986 //===----------------------------------------------------------------------===//
2987 // WhileRegionOp
2988 //===----------------------------------------------------------------------===//
2989 static LogicalResult Verify(WhileRegionOp op) {
2990 // Verify that the condition generates a single tensor<i1> result.
2991 Operation *cond_yield = op.cond().front().getTerminator();
2992 if (cond_yield->getNumOperands() != 1)
2993 return op.emitOpError()
2994 << "condition should have a single tensor<i1> result";
2995
2996 auto cond_type =
2997 cond_yield->getOperand(0).getType().dyn_cast<RankedTensorType>();
2998 if (!cond_type || !cond_type.getShape().equals({}) ||
2999 !cond_type.getElementType().isInteger(/*width=*/1))
3000 return op.emitOpError()
3001 << "condition should have a single tensor<i1> result";
3002
3003 Operation *body_yield = op.body().front().getTerminator();
3004 if (failed(VerifyWhileTypes(op, /*cond_input=*/op.cond().getArgumentTypes(),
3005 /*body_input=*/op.body().getArgumentTypes(),
3006 /*body_result=*/body_yield->getOperandTypes(),
3007 op.shape_invariant())))
3008 return failure();
3009 return success();
3010 }
3011
3012 //===----------------------------------------------------------------------===//
3013 // WhileRegionOp LoopLikeOpInterface
3014 //===----------------------------------------------------------------------===//
3015
3016 Region &WhileRegionOp::getLoopBody() { return body(); }
3017
3018 bool WhileRegionOp::isDefinedOutsideOfLoop(Value value) {
3019 // If the op defining the value exists and the defining op is outside the
3020 // scope of this WhileRegion, then we can infer that the value is defined
3021 // outside the loop. The defining op is outside the scope of this WhileRegion
3022 // if this WhileRegionOp is not an ancestor of the defining op in the parent chain.
3023 Operation *def_op = value.getDefiningOp();
3024 return def_op && !getOperation()->isAncestor(def_op);
3025 }
3026
3027 LogicalResult WhileRegionOp::moveOutOfLoop(
3028 llvm::ArrayRef<mlir::Operation *> ops) {
3029 // Move the hoisted value to just before the while.
3030 Operation *while_op = this->getOperation();
3031 for (auto op : ops) op->moveBefore(while_op);
3032 return success();
3033 }
3034
3035 //===----------------------------------------------------------------------===//
3036 // WhileRegionOp canonicalization
3037 //===----------------------------------------------------------------------===//
3038 namespace {
3039 // Eliminate values that pass through the WhileRegionOp body.
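//
// As a rough sketch (value names are hypothetical), for a while whose first
// operand simply passes through the body:
//
//   %r:2 = "tf.WhileRegion"(%a, %b) ({...cond...}, {
//     ^bb0(%arg0, %arg1):
//       ...
//       "tf.Yield"(%arg0, %b_next) : ...
//   }) : ...
//
// the pattern replaces uses of %arg0 (and the matching cond argument and
// %r#0) with %a, and rebuilds the while without that operand and result.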
3040 struct WhileRegionEliminatePassThrough
3041 : public OpRewritePattern<WhileRegionOp> {
3042 using OpRewritePattern<WhileRegionOp>::OpRewritePattern;
3043
3044 LogicalResult matchAndRewrite(WhileRegionOp while_op,
3045 PatternRewriter &rewriter) const override {
3046 // Remove any extern values that are explicitly captured and returned. Also
3047 // replace values that simply pass through the body with extern values. The
3048 // block arguments of the body match the while operands, so the corresponding
3049 // cond argument can be easily found.
3050 int old_num_operands = while_op.getNumOperands();
3051 int new_num_operands = old_num_operands;
3052 auto &body_block = while_op.body().front();
3053 auto &cond_block = while_op.cond().front();
3054 auto &yield = *body_block.getTerminator();
3055
3056 // Bit mask indicating which operands will be removed.
3057 llvm::BitVector removed_operand(old_num_operands);
3058
3059 for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
3060 auto body_arg = body_block.getArgument(op_idx);
3061 auto yield_operand = LookThroughIdentity(yield.getOperand(op_idx));
3062 auto while_operand = while_op.getOperand(op_idx);
3063 if (body_arg == yield_operand || while_operand == yield_operand) {
3064 // Replace the use of the passthrough value with the while operand
3065 // in the body and condition regions, as well as the while output (if
3066 // the types match).
3067 // TODO(jurahul): Use PatternRewriter API for IR modification.
3068 if (body_arg.getType() == while_operand.getType())
3069 body_arg.replaceAllUsesWith(while_operand);
3070
3071 auto cond_arg = cond_block.getArgument(op_idx);
3072 if (cond_arg.getType() == while_operand.getType())
3073 cond_arg.replaceAllUsesWith(while_operand);
3074
3075 auto result = while_op.getResult(op_idx);
3076 if (result.getType() == while_operand.getType())
3077 result.replaceAllUsesWith(while_operand);
3078 }
3079
3080 // Now check if the operand is unused in both regions as well as the
3081 // result. If so, mark it for removal.
3082 if (body_block.getArgument(op_idx).use_empty() &&
3083 cond_block.getArgument(op_idx).use_empty() &&
3084 while_op.getResult(op_idx).use_empty()) {
3085 removed_operand.set(op_idx);
3086 new_num_operands--;
3087 }
3088 }
3089
3090 if (new_num_operands == old_num_operands) return failure();
3091
3092 // Compress the operands, region arguments, and outputs.
3093 SmallVector<Value, 4> new_while_operands;
3094 SmallVector<Type, 4> new_result_types;
3095 new_while_operands.reserve(new_num_operands);
3096 new_result_types.reserve(new_num_operands);
3097
3098 // Build new operands and result type.
3099 for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
3100 if (removed_operand.test(op_idx)) continue;
3101 new_while_operands.push_back(while_op.getOperand(op_idx));
3102 new_result_types.push_back(while_op.getResult(op_idx).getType());
3103 }
3104
3105 // Create the new while operation.
3106 auto new_while_op = rewriter.create<WhileRegionOp>(
3107 while_op.getLoc(), new_result_types, new_while_operands,
3108 while_op->getAttrs());
3109
3110 // Move region bodies to the new while.
3111 rewriter.inlineRegionBefore(while_op.cond(), new_while_op.cond(),
3112 new_while_op.cond().end());
3113 rewriter.inlineRegionBefore(while_op.body(), new_while_op.body(),
3114 new_while_op.body().end());
3115
3116 auto &new_cond_block = new_while_op.cond().front();
3117 auto &new_body_block = new_while_op.body().front();
3118 auto &new_yield = *new_body_block.getTerminator();
3119
3120 // Patch up the region bodies and yield.
3121 new_cond_block.eraseArguments(removed_operand);
3122 new_body_block.eraseArguments(removed_operand);
3123 new_yield.eraseOperands(removed_operand);
3124
3125 // Build a vector of new results. Also patch up the region bodies and
3126 // yield.
3127 SmallVector<Value, 4> new_results(old_num_operands);
3128 int next_idx = 0;
3129 for (int op_idx : llvm::seq<int>(0, old_num_operands))
3130 if (!removed_operand.test(op_idx))
3131 new_results[op_idx] = new_while_op.getResult(next_idx++);
3132
3133 rewriter.replaceOp(while_op, new_results);
3134 return success();
3135 }
3136 };
3137
3138 } // anonymous namespace
3139
3140 void WhileRegionOp::getCanonicalizationPatterns(
3141 OwningRewritePatternList &results, MLIRContext *context) {
3142 results.insert<WhileRegionEliminatePassThrough>(context);
3143 }
3144
3145 //===----------------------------------------------------------------------===//
3146 // XdivyOp
3147 //===----------------------------------------------------------------------===//
3148
3149 void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
3150 MLIRContext *context) {
3151 results.insert<XdivyWithSqrtDivisor>(context);
3152 }
3153
3154 //===----------------------------------------------------------------------===//
3155 // XlaBroadcastHelperOp
3156 //===----------------------------------------------------------------------===//
3157
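// Shape inference sketch for XlaBroadcastHelper (types are illustrative): with
// lhs : tensor<4xf32>, rhs : tensor<2x3x4xf32>, and broadcast_dims = [2], the
// lower-rank operand is inferred as tensor<1x1x4xf32> while the higher-rank
// operand keeps its type. With an empty broadcast_dims, both arguments must
// have equal rank or at least one must be a scalar.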
3158 LogicalResult XlaBroadcastHelperOp::inferReturnTypes(
3159 MLIRContext *context, Optional<Location> location, ValueRange operands,
3160 DictionaryAttr attributes, RegionRange regions,
3161 SmallVectorImpl<Type> &inferredReturnTypes) {
3162 auto loc = location ? *location : mlir::UnknownLoc::get(context);
3163 XlaBroadcastHelperOpAdaptor op(operands, attributes);
3164 if (failed(op.verify(loc))) {
3165 return failure();
3166 }
3167
3168 Value lhs = op.lhs();
3169 Value rhs = op.rhs();
3170 auto set_unranked_results = [&]() {
3171 auto unranked_lhs = UnrankedTensorType::get(getElementTypeOrSelf(lhs));
3172 inferredReturnTypes.push_back(unranked_lhs);
3173 auto unranked_rhs = UnrankedTensorType::get(getElementTypeOrSelf(rhs));
3174 inferredReturnTypes.push_back(unranked_rhs);
3175 return success();
3176 };
3177
3178 RankedTensorType lhs_ty = lhs.getType().dyn_cast<RankedTensorType>();
3179 RankedTensorType rhs_ty = rhs.getType().dyn_cast<RankedTensorType>();
3180 if (!lhs_ty || !rhs_ty) return set_unranked_results();
3181
3182 int64_t lhs_rank = lhs_ty.getRank();
3183 int64_t rhs_rank = rhs_ty.getRank();
3184
3185 DenseIntElementsAttr dims;
3186 if (!matchPattern(op.broadcast_dims(), m_Constant(&dims))) {
3187 return set_unranked_results();
3188 }
3189
3190 if (dims.size() == 0) {
3191 if (lhs_rank != rhs_rank && lhs_rank != 0 && rhs_rank != 0) {
3192 return emitOptionalError(
3193 location,
3194 "if broadcast_dims is empty, both arguments must have equal rank or "
3195 "at least one argument must be a scalar");
3196 }
3197 inferredReturnTypes.push_back(lhs_ty);
3198 inferredReturnTypes.push_back(rhs_ty);
3199 return success();
3200 }
3201
3202 const bool broadcast_lhs = lhs_rank < rhs_rank;
3203 RankedTensorType min_rank_ty = broadcast_lhs ? lhs_ty : rhs_ty;
3204 RankedTensorType max_rank_ty = broadcast_lhs ? rhs_ty : lhs_ty;
3205
3206 if (dims.size() != min_rank_ty.getRank()) {
3207 return emitOptionalError(
3208 location,
3209 "broadcast_dims must have size equal to the smaller argument rank");
3210 }
3211
3212 int64_t output_rank = max_rank_ty.getRank();
3213 llvm::SmallVector<int64_t, 4> broadcast_shape(output_rank, 1LL);
3214 llvm::SmallVector<bool, 4> is_broadcasted(output_rank, false);
3215 for (auto item : llvm::enumerate(dims)) {
3216 int64_t index = item.index();
3217 int64_t dim = item.value().getSExtValue();
3218 if (dim < 0 || dim >= output_rank) {
3219 return emitOptionalError(location, "out of range broadcast dim");
3220 }
3221 if (is_broadcasted[dim]) {
3222 return emitOptionalError(location, "broadcast_dims has duplicates");
3223 }
3224 broadcast_shape[dim] = min_rank_ty.getDimSize(index);
3225 is_broadcasted[dim] = true;
3226 }
3227
3228 if (broadcast_lhs) {
3229 inferredReturnTypes.push_back(
3230 RankedTensorType::get(broadcast_shape, lhs_ty.getElementType()));
3231 inferredReturnTypes.push_back(rhs_ty);
3232 } else {
3233 inferredReturnTypes.push_back(lhs_ty);
3234 inferredReturnTypes.push_back(
3235 RankedTensorType::get(broadcast_shape, rhs_ty.getElementType()));
3236 }
3237 return success();
3238 }
3239
3240 //===----------------------------------------------------------------------===//
3241 // XlaSetDynamicDimensionSizeOp
3242 //===----------------------------------------------------------------------===//
3243
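// Shape inference sketch for XlaSetDynamicDimensionSize (types illustrative):
// with input : tensor<4x5xf32> and a constant dim_index of 1, the result type
// is inferred as tensor<4x?xf32>; if dim_index is not a constant, all result
// dimensions become dynamic, and an unranked input stays unranked.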
3244 LogicalResult XlaSetDynamicDimensionSizeOp::inferReturnTypes(
3245 MLIRContext *context, Optional<Location> location, ValueRange operands,
3246 DictionaryAttr attributes, RegionRange regions,
3247 SmallVectorImpl<Type> &inferredReturnTypes) {
3248 auto loc = location ? *location : mlir::UnknownLoc::get(context);
3249 XlaSetDynamicDimensionSizeOpAdaptor op(operands, attributes);
3250 if (failed(op.verify(loc))) return failure();
3251
3252 TensorType operand_ty = op.input().getType().cast<TensorType>();
3253 Type element_ty = operand_ty.getElementType();
3254
3255 TensorType result_ty;
3256 if (operand_ty.hasRank()) {
3257 auto shape = llvm::to_vector<4>(operand_ty.getShape());
3258
3259 DenseIntElementsAttr dim_index_attr;
3260 if (matchPattern(op.dim_index(), m_Constant(&dim_index_attr))) {
3261 int64_t dim_index = dim_index_attr.getValue<APInt>({}).getSExtValue();
3262
3263 int64_t rank = operand_ty.getRank();
3264 if (dim_index < 0 || dim_index >= rank) {
3265 return emitOptionalError(location, "dim_index (", dim_index,
3266 ") is out of range [0, ", rank, ")");
3267 }
3268 shape[dim_index] = RankedTensorType::kDynamicSize;
3269 } else {
3270 shape.assign(shape.size(), RankedTensorType::kDynamicSize);
3271 }
3272 result_ty = RankedTensorType::get(shape, element_ty);
3273 } else {
3274 result_ty = UnrankedTensorType::get(element_ty);
3275 }
3276
3277 inferredReturnTypes.push_back(result_ty);
3278 return success();
3279 }
3280
3281 } // namespace TF
3282 } // namespace mlir
3283
3284 //===----------------------------------------------------------------------===//
3285 // TableGen'd op method definitions
3286 //===----------------------------------------------------------------------===//
3287
3288 #define GET_OP_CLASSES
3289 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc.inc"
3290