1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <functional>
21 #include <limits>
22 #include <numeric>
23 #include <string>
24 #include <tuple>
25 #include <type_traits>
26
27 #include "llvm/ADT/APFloat.h"
28 #include "llvm/ADT/APInt.h"
29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/BitVector.h"
31 #include "llvm/ADT/DenseMap.h"
32 #include "llvm/ADT/None.h"
33 #include "llvm/ADT/Optional.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/Sequence.h"
36 #include "llvm/ADT/SmallVector.h"
37 #include "llvm/ADT/StringExtras.h"
38 #include "llvm/ADT/StringRef.h"
39 #include "llvm/ADT/StringSwitch.h"
40 #include "llvm/ADT/Twine.h"
41 #include "llvm/ADT/iterator_range.h"
42 #include "llvm/Support/Casting.h"
43 #include "llvm/Support/FormatVariadic.h"
44 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
45 #include "mlir/Dialect/Traits.h" // from @llvm-project
46 #include "mlir/IR/Attributes.h" // from @llvm-project
47 #include "mlir/IR/Builders.h" // from @llvm-project
48 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
49 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
50 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
51 #include "mlir/IR/Diagnostics.h" // from @llvm-project
52 #include "mlir/IR/DialectImplementation.h" // from @llvm-project
53 #include "mlir/IR/Location.h" // from @llvm-project
54 #include "mlir/IR/MLIRContext.h" // from @llvm-project
55 #include "mlir/IR/Matchers.h" // from @llvm-project
56 #include "mlir/IR/OpDefinition.h" // from @llvm-project
57 #include "mlir/IR/OpImplementation.h" // from @llvm-project
58 #include "mlir/IR/PatternMatch.h" // from @llvm-project
59 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
60 #include "mlir/IR/Types.h" // from @llvm-project
61 #include "mlir/IR/Value.h" // from @llvm-project
62 #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
63 #include "mlir/Parser/Parser.h" // from @llvm-project
64 #include "mlir/Support/LLVM.h" // from @llvm-project
65 #include "mlir/Support/LogicalResult.h" // from @llvm-project
66 #include "mlir/Transforms/InliningUtils.h" // from @llvm-project
67 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h"
68 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
69 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
70 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_canonicalization_helper.h"
71 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_device_helper.h"
72 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h"
73 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h"
74 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
75 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
76 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
77 #include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h"
78 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
79 #include "tensorflow/core/platform/logging.h"
80 #include "tensorflow/core/util/tensor_format.h"
81
82 namespace mlir {
83 namespace TF {
84
85 namespace {
86 // Returns the equivalent Value skipping through identity nodes.
Value LookThroughIdentity(Value result) {
88 while (isa_and_nonnull<IdentityOp, IdentityNOp>(result.getDefiningOp())) {
89 auto op_result = result.cast<OpResult>();
90 result = op_result.getOwner()->getOperand(op_result.getResultNumber());
91 }
92 return result;
93 }
94
95 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
96 } // namespace
97
98 //===----------------------------------------------------------------------===//
99 // NotEqualOp
100 //===----------------------------------------------------------------------===//
101
LogicalResult NotEqualOp::verify() {
103 NotEqualOp op = *this;
  // If the inputs are allowed to have incompatible types, there is nothing to do.
105 if (!op.incompatible_shape_error()) return success();
106
107 // Otherwise, check inputs are broadcastable.
108 return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
109 op.getOperation());
110 }
111
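// Builds a NotEqualOp with a result type deduced from the operand types and
// the `incompatible_shape_error` attribute.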
void NotEqualOp::build(OpBuilder &builder, OperationState &result, Value x,
113 Value y, BoolAttr incompatible_shape_error) {
114 auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y,
115 incompatible_shape_error);
116 return build(builder, result, result_type, x, y, incompatible_shape_error);
117 }
118
119 //===----------------------------------------------------------------------===//
120 // OneHotOp
121 //===----------------------------------------------------------------------===//
122
LogicalResult OneHotOp::verify() {
124 OneHotOp op = *this;
125 int64_t axis = op.axis();
126
127 auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
128 if (indices_ty &&
129 !(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) {
130 return op.emitOpError()
131 << "expected axis (" << axis << ") to be -1 or between [0, "
132 << indices_ty.getShape().size() << "]";
133 }
134
135 if (axis < -1) {
136 return op.emitOpError() << "expected axis (" << axis
137 << ") to be -1 or between [0, rank(indices()))";
138 }
139
140 if (!IsOfRankOrUnranked(op.depth(), 0)) {
141 return op.emitOpError() << "requires depth to be a scalar";
142 }
143 if (!IsOfRankOrUnranked(op.on_value(), 0)) {
144 return op.emitOpError() << "requires on_value to be a scalar";
145 }
146 if (!IsOfRankOrUnranked(op.off_value(), 0)) {
147 return op.emitOpError() << "requires off_value to be a scalar";
148 }
149
150 DenseIntElementsAttr depth_attr;
151 if (matchPattern(op.depth(), m_Constant(&depth_attr))) {
152 if (depth_attr.getType().getRank() != 0)
153 return op.emitOpError() << "requires depth to be a scalar";
154 int64_t depth = depth_attr.getValues<APInt>()[0].getSExtValue();
155 if (depth < 0) {
156 return op.emitOpError() << "depth must be non-negative, got: " << depth;
157 }
158 }
159
160 return success();
161 }
162
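// Infers the result type of a OneHot op from its operands and the `axis`
// attribute; falls back to an unranked tensor type when the output shape
// cannot be determined.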
static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value,
164 Value off_value, IntegerAttr axis) {
165 int64_t axis_val = axis.getInt();
166 Type element_ty = on_value.getType().cast<TensorType>().getElementType();
167 auto unranked_ty = UnrankedTensorType::get(element_ty);
168 if (axis_val < -1) return unranked_ty;
169
170 auto indices_ty = indices.getType().dyn_cast<RankedTensorType>();
171 if (!indices_ty) return unranked_ty;
172
173 auto shape = llvm::to_vector<2>(indices_ty.getShape());
174 if (axis_val == -1) axis_val = shape.size();
175
176 int64_t depth_val = ShapedType::kDynamicSize;
177 DenseIntElementsAttr depth_attr;
178 if (matchPattern(depth, m_Constant(&depth_attr)) &&
179 depth_attr.getNumElements() == 1)
180 depth_val = (*depth_attr.begin()).getSExtValue();
181 shape.insert(shape.begin() + axis_val, depth_val);
182 return RankedTensorType::get(shape, element_ty);
183 }
184
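// Builds a OneHotOp using the result type inferred by InferOneHotOpType.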
void OneHotOp::build(OpBuilder &builder, OperationState &result, Value indices,
186 Value depth, Value on_value, Value off_value,
187 IntegerAttr axis) {
188 build(builder, result,
189 InferOneHotOpType(indices, depth, on_value, off_value, axis), indices,
190 depth, on_value, off_value, axis);
191 }
192
193 //===----------------------------------------------------------------------===//
194 // PackOp
195 //===----------------------------------------------------------------------===//
196
LogicalResult PackOp::verify() {
198 PackOp op = *this;
199 // TODO(hinsu): Convert variadic length attributes to derived attributes.
200 Operation::operand_range values = op.values();
201
202 if (failed(VerifyTypesCompatibility(values,
203 /*mask_one_dim=*/false,
204 op.getOperation()))) {
205 return failure();
206 }
207
208 int64_t inputs_rank = -1;
209 for (Value value : values) {
210 if (auto ty = value.getType().dyn_cast<RankedTensorType>()) {
211 // Exit early as input types are verified to be compatible so all ranked
212 // tensors have the same rank.
213 inputs_rank = ty.getRank();
214 break;
215 }
216 }
217 if (inputs_rank == -1) return success();
218
  // The values can be packed along any of the dimensions between 0 and the
  // inputs' rank, inclusive. Negative axis values wrap around, so the valid
  // axis range is [-(R+1), R+1).
222 int64_t range_begin = -inputs_rank - 1; // Inclusive
223 int64_t range_end = inputs_rank + 1; // Exclusive
224 int64_t axis = op.axis();
225 if (axis < range_begin || axis >= range_end) {
226 return op.emitError() << "attribute 'axis' should be within range ["
227 << range_begin << ", " << range_end
228 << "); actual value: " << axis;
229 }
230
231 return success();
232 }
233
OpFoldResult PackOp::fold(ArrayRef<Attribute> operands) {
235 // Fold pack operation if it computes the input tensor shape:
236 //
237 // %shape = tf.Shape(%arg) // [? x ...]
238 // %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim0 value
239 // %pack = tf.Pack(dim0, ...) { axis = 0 } // [? x ...]
240 //
  // Where `...` are some statically known dimensions. In this case %pack can
  // be replaced with %shape. This is a common pattern in models with a dynamic
  // batch size.
244
245 // Pack operation should pack at least two values.
246 if (values().size() < 2) return {};
247
248 // Dimensions packed along axis = 0 (pack scalars into vector).
249 if (axis() != 0) return {};
250
251 // First packed value is defined by a strided slice operation.
252 auto slice_op = dyn_cast_or_null<StridedSliceOp>(values()[0].getDefiningOp());
253 if (!slice_op) return {};
254
255 // Input to the slice op is defined by shape operation.
256 auto shape_op = dyn_cast_or_null<ShapeOp>(slice_op.input().getDefiningOp());
257 if (!shape_op) return {};
258
  // Input tensor whose shape is reconstructed by the pack operation.
260 Value tensor = shape_op.input();
261
262 // All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing
263 // scalar value from input vector).
264 if (slice_op.begin_mask() != 0 || slice_op.ellipsis_mask() != 0 ||
265 slice_op.end_mask() != 0 || slice_op.new_axis_mask() != 0 ||
266 slice_op.shrink_axis_mask() != 1)
267 return {};
268
  // Returns the integer value if `value` is defined by a ConstOp holding a
  // single integer element of the expected rank.
271 auto get_const_int = [](Value value, int expected_rank) -> Optional<int64_t> {
272 auto const_op = dyn_cast_or_null<ConstOp>(value.getDefiningOp());
273 if (!const_op) return None;
274
275 auto value_attr = const_op.value().dyn_cast<DenseIntElementsAttr>();
276 if (!value_attr || value_attr.getNumElements() != 1) return None;
277
278 auto value_ty = value_attr.getType();
279 if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None;
280
281 auto splat = value_attr.getSplatValue<IntegerAttr>();
282 return splat.getValue().getSExtValue();
283 };
284
285 // All other packed values are scalar constants.
286 SmallVector<int64_t, 4> packed_dims;
287 packed_dims.reserve(values().size() - 1);
288 for (Value operand : llvm::drop_begin(values(), 1)) {
289 if (auto dim = get_const_int(operand, /*expected_rank=*/0)) {
290 packed_dims.push_back(*dim);
291 } else {
292 return {};
293 }
294 }
295
296 // Slice exactly the first shape dimension:
297 // begin = [0] end = [1], strides = [1]
298 auto begin = get_const_int(slice_op.begin(), /*expected_rank=*/1);
299 auto end = get_const_int(slice_op.end(), /*expected_rank=*/1);
300 auto strides = get_const_int(slice_op.strides(), /*expected_rank=*/1);
301 if (!begin.has_value() || !end.has_value() || !strides.has_value() ||
302 *begin != 0 || *end != 1 || *strides != 1)
303 return {};
304
305 // First tensor dimension is dynamic.
306 auto arg_ty = tensor.getType().dyn_cast<ShapedType>();
307 if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 ||
308 !arg_ty.isDynamicDim(0))
309 return {};
310
311 // Argument tensor rank is equal to the number of packed dimensions.
312 if (arg_ty.getRank() != values().size()) return {};
313
314 // All other dimensions are statically known and equal to packed dims.
315 auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1);
316 if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin()))
317 return {};
318
319 // Replace %pack with %shape.
320 return slice_op.input();
321 }
322
323 // Convert Pack to Reshape when there is only one operand to be packed.
324 // For example,
325 //
326 // %0 = tf.Pack(%input) {axis = 0} // %input : tensor<2x3xf32>
327 //
328 // can be canonicalized to
329 //
330 // %shape = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi64>}
331 // %0 = tf.Reshape(%input, %shape)
332 struct ConvertPackToReshape : public OpRewritePattern<PackOp> {
333 using OpRewritePattern<PackOp>::OpRewritePattern;
334
  LogicalResult matchAndRewrite(PackOp pack_op,
336 PatternRewriter &rewriter) const override {
337 // Check if there is only one operand to be packed.
338 if (pack_op.N() != 1) {
339 return failure();
340 }
341
342 // Check if input and output are static.
343 auto input_ty = pack_op.getOperand(0).getType().cast<ShapedType>();
344 auto output_ty = pack_op.output().getType().cast<ShapedType>();
345 if (!input_ty.hasStaticShape() || !output_ty.hasStaticShape()) {
346 return failure();
347 }
348
349 // Create constant shape for reshape.
350 auto type =
351 RankedTensorType::get(output_ty.getRank(), rewriter.getIntegerType(64));
352 auto shape_attr = DenseIntElementsAttr::get(type, output_ty.getShape());
353 auto shape = rewriter.create<ConstOp>(pack_op.getLoc(), shape_attr);
354
355 // TODO(b/173622615): Remove after fixed.
356 ReplaceTfOpWithNewOp<ReshapeOp>(rewriter, pack_op, output_ty,
357 pack_op.getOperand(0), shape);
358 return success();
359 }
360 };
361
void PackOp::getCanonicalizationPatterns(RewritePatternSet &results,
363 MLIRContext *context) {
364 results.add<ConvertPackToReshape>(context);
365 }
366
367 //===----------------------------------------------------------------------===//
368 // PadOp
369 //===----------------------------------------------------------------------===//
370
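// Rewrites the constant `paddings` operand and the result type to reflect a
// permutation of the input dimensions.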
LogicalResult PadOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
372 // Paddings must be defined by a constant operation.
373 auto paddings_op = dyn_cast_or_null<TF::ConstOp>(paddings().getDefiningOp());
374 if (!paddings_op) return failure();
375
376 auto paddings_value = paddings_op.value().dyn_cast<DenseElementsAttr>();
377 if (!paddings_value ||
378 paddings_value.getNumElements() != permutation.size() * 2)
379 return failure();
380
381 SmallVector<int32_t, 8> shuffled_paddings(paddings_value.getNumElements());
382 for (auto index_pair : llvm::enumerate(paddings_value.getValues<APInt>())) {
383 size_t outer_idx = index_pair.index() / 2;
384 size_t inner_idx = index_pair.index() % 2;
385
386 shuffled_paddings[permutation[outer_idx] * 2 + inner_idx] =
387 index_pair.value().getSExtValue();
388 }
389
  // Add a constant operation with the new paddings.
391 OpBuilder builder(getOperation());
392 auto type = mlir::RankedTensorType::get(paddings_value.getType().getShape(),
393 builder.getIntegerType(32));
394 auto values = mlir::DenseIntElementsAttr::get(type, shuffled_paddings);
395 auto shuffled_paddings_op = builder.create<TF::ConstOp>(getLoc(), values);
396
397 // Use new paddings.
398 setOperand(1, shuffled_paddings_op);
399
400 // Change the result type.
401 getResult().setType(ShuffleRankedTensorType(getResult().getType(),
402 ReversePermutation(permutation)));
403
404 return success();
405 }
406
407 //===----------------------------------------------------------------------===//
408 // ParseExampleV2Op
409 //===----------------------------------------------------------------------===//
410
LogicalResult ParseExampleV2Op::verify() {
412 ParseExampleV2Op op = *this;
413 // NOTE(mrry): This validates properties of an op that would previously be
414 // validated by the TensorFlow OpDef type checker. In addition to these
415 // checks, the shape inference function for ParseExampleV2 validates the
416 // consistency of the argument and result types.
417
418 // Validate dense variadic input and output lengths.
419 // NOTE(mrry): The Tdense attr is derived from dense_defaults, so we
420 // do not need to validate dense_defaults.
421 auto dense_types_count =
422 std::distance(op.Tdense().begin(), op.Tdense().end());
423 auto dense_values_count =
424 std::distance(op.dense_values().begin(), op.dense_values().end());
425 if (dense_values_count != dense_types_count) {
426 return op.emitError() << "output 'dense_values' should have same length "
427 << "as attribute 'Tdense'";
428 }
429
430 // Validate sparse variadic output lengths.
431 // NOTE(mrry): The sparse_types attr is derived from sparse_values, so we
432 // do not need to validate sparse_values.
433 auto sparse_types_count =
434 std::distance(op.sparse_types().begin(), op.sparse_types().end());
435 if (op.num_sparse() != sparse_types_count) {
436 return op.emitError() << "attribute 'num_sparse' should be the same as "
437 << "the length of attribute 'sparse_types'";
438 }
439 if (op.sparse_indices().size() != sparse_types_count) {
440 return op.emitError() << "output 'sparse_indices' should have same length "
441 << "as attribute 'sparse_types'";
442 }
443 if (op.sparse_shapes().size() != sparse_types_count) {
444 return op.emitError() << "output 'sparse_shapes' should have same length "
445 << "as attribute 'sparse_types'";
446 }
447
448 // Validate ragged variadic output lengths.
449 auto ragged_value_types_count = std::distance(op.ragged_value_types().begin(),
450 op.ragged_value_types().end());
451 auto ragged_split_types_count = std::distance(op.ragged_split_types().begin(),
452 op.ragged_split_types().end());
453 if (ragged_value_types_count != ragged_split_types_count) {
454 return op.emitError() << "attribute 'ragged_value_types' should have same "
455 << "length as attribute 'ragged_split_types'";
456 }
457
458 return success();
459 }
460
461 //===----------------------------------------------------------------------===//
462 // PartitionedCallOp
463 //===----------------------------------------------------------------------===//
464
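// Verifies that the function referenced by the 'f' attribute of a
// (Stateful/TPU)PartitionedCall op exists and that the number of call
// arguments matches the callee's signature.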
465 template <typename CallOpClass>
static LogicalResult VerifyPartitionedCall(CallOpClass op,
467 SymbolTableCollection &symbolTable) {
468 SymbolRefAttr func = op->getAttr("f").template cast<SymbolRefAttr>();
469 auto function = symbolTable.lookupNearestSymbolFrom<func::FuncOp>(op, func);
470 if (!function) {
471 return op.emitError("'f' attribute refers to an undefined function: ")
472 << func;
473 }
474
475 FunctionType function_ty = function.getFunctionType();
476 int func_arg_count = function_ty.getNumInputs();
477 int arg_count = op.args().size();
478
479 if (arg_count != func_arg_count) {
480 return op.emitError() << "argument count mismatch: 'args' has " << arg_count
481 << " arguments, but '" << func << "' expects "
482 << func_arg_count;
483 }
484
485 return success();
486 }
487
LogicalResult PartitionedCallOp::verifySymbolUses(
489 SymbolTableCollection &symbolTable) {
490 return VerifyPartitionedCall(*this, symbolTable);
491 }
LogicalResult StatefulPartitionedCallOp::verifySymbolUses(
493 SymbolTableCollection &symbolTable) {
494 return VerifyPartitionedCall(*this, symbolTable);
495 }
LogicalResult TPUPartitionedCallOp::verifySymbolUses(
497 SymbolTableCollection &symbolTable) {
498 return VerifyPartitionedCall(*this, symbolTable);
499 }
500
501 //===----------------------------------------------------------------------===//
502 // PowOp
503 //===----------------------------------------------------------------------===//
504
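// Folds tf.Pow with a constant exponent: x^0 folds to a splat of ones when
// the result shape is static, and x^1 folds to x.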
OpFoldResult PowOp::fold(ArrayRef<Attribute> operands) {
506 auto constant_y = operands[1].dyn_cast_or_null<DenseFPElementsAttr>();
507 if (constant_y && constant_y.isSplat()) {
508 APFloat y_value = constant_y.getSplatValue<APFloat>();
509 auto output_type = getType().cast<ShapedType>();
510 if (y_value.isZero() && output_type.hasStaticShape()) {
511 return DenseElementsAttr::get(
512 output_type,
513 FloatAttr::get(output_type.getElementType(), /*value=*/1.0));
514 }
515 if (y_value.isExactlyValue(1.0)) {
516 return x();
517 }
518 }
519 return {};
520 }
521
522 //===----------------------------------------------------------------------===//
523 // QuantizeAndDequantizeV2Op
524 //===----------------------------------------------------------------------===//
525
void QuantizeAndDequantizeV2Op::getCanonicalizationPatterns(
527 RewritePatternSet &results, MLIRContext *context) {
528 results.add<QuantizeAndDequantizeV2ToQuantizeAndDequantizeV4>(context);
529 }
530
531 //===----------------------------------------------------------------------===//
532 // QrOp
533 //===----------------------------------------------------------------------===//
534
535 // Verifies that,
536 //
537 // * Input type, if ranked, must have at least 2 dimensions and at most
538 // INT32_MAX dimensions.
539 //
LogicalResult QrOp::verify() {
541 QrOp op = *this;
542 auto ttype = op.input().getType().cast<TensorType>();
543 if (!ttype.hasRank()) return success();
544 if (!HasRankAtLeast(op.input(), 2))
545 return op.emitOpError(
546 "requires ranked input tensor to be of rank 2 or more");
547 if (!HasRankAtMost(op.input(), std::numeric_limits<int32_t>::max()))
548 return op.emitOpError(
549 "requires ranked input tensor to be of rank INT32_MAX or less");
550
551 return success();
552 }
553
554 //===----------------------------------------------------------------------===//
555 // ReadVariableOp
556 //===----------------------------------------------------------------------===//
557
void ReadVariableOp::getCanonicalizationPatterns(RewritePatternSet &results,
559 MLIRContext *context) {
560 results.add<ReadVariableOfCast>(context);
561 }
562
563 //===----------------------------------------------------------------------===//
564 // RandomUniformOp
565 //===----------------------------------------------------------------------===//
566
LogicalResult RandomUniformOp::verify() {
568 RandomUniformOp op = *this;
569 if (!IsOfRankOrUnranked(op.shape(), 1))
570 return op.emitOpError("shape must be 1D tensor");
571 return success();
572 }
573
574 //===----------------------------------------------------------------------===//
575 // RangeOp
576 //===----------------------------------------------------------------------===//
577
578 namespace {
579
580 // Compute the length of a range (1-D) tensor given `start`, `limit`, `delta`.
// Template parameter `FloatOrInt` must be a standard C integer or
// floating-point type.
583 template <typename FloatOrInt>
int GetLengthOfRange(FloatOrInt start, FloatOrInt limit, FloatOrInt delta) {
585 // Refer to the implementation in
586 // tensorflow/lite/kernels/range.cc.
587 FloatOrInt diff = limit - start;
588 if (std::is_integral<FloatOrInt>::value) {
589 return ((std::abs(diff) + std::abs(delta) - 1) / std::abs(delta));
590 }
591 return std::ceil(std::abs(diff / delta));
592 }
593
594 // Builds a constant range tensor of `result_elem_type` elements.
595 // Template parameter `FloatOrIntAtrr` must be mlir::IntegerAttr or
596 // mlir::FloatAttr.
597 template <typename FloatOrIntAtrr>
DenseElementsAttr BuildConstRangeTensor(Type result_elem_type, int num_elements,
599 FloatOrIntAtrr start_attr,
600 FloatOrIntAtrr delta_attr) {
601 using ValueType = typename FloatOrIntAtrr::ValueType; // APInt or APFloat
602 ValueType start = start_attr.getValue();
603 ValueType delta = delta_attr.getValue();
604
605 SmallVector<ValueType, 16> new_values;
606 new_values.reserve(num_elements);
607 ValueType new_value = start;
608 for (int i = 0; i < num_elements; ++i) {
609 new_values.push_back(new_value);
610 new_value = new_value + delta;
611 }
612 // Result is always a 1-D tensor.
613 auto new_result_type =
614 RankedTensorType::get({num_elements}, result_elem_type);
615 return DenseElementsAttr::get(new_result_type, new_values);
616 }
617 } // namespace
618
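// Builds a RangeOp. If `start`, `limit`, and `delta` are all constant, the
// result type gets a static 1-D shape; otherwise the size is dynamic.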
void RangeOp::build(OpBuilder &builder, OperationState &result, Value start,
620 Value limit, Value delta) {
621 assert(start.getType() == limit.getType());
622 assert(start.getType() == delta.getType());
623 DenseIntElementsAttr start_val;
624 DenseIntElementsAttr limit_val;
625 DenseIntElementsAttr delta_val;
626 if (matchPattern(start, m_Constant(&start_val)) &&
627 matchPattern(limit, m_Constant(&limit_val)) &&
628 matchPattern(delta, m_Constant(&delta_val))) {
629 auto size = llvm::APIntOps::RoundingSDiv(
630 *limit_val.begin() - *start_val.begin(), *delta_val.begin(),
631 llvm::APInt::Rounding::DOWN);
632 return RangeOp::build(
633 builder, result,
634 RankedTensorType::get(
635 size.getSExtValue(),
636 start.getType().cast<TensorType>().getElementType()),
637 start, limit, delta);
638 }
639 return RangeOp::build(
640 builder, result,
641 RankedTensorType::get(
642 {-1}, start.getType().cast<TensorType>().getElementType()),
643 start, limit, delta);
644 }
645
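// Folds tf.Range into a constant 1-D tensor when `start`, `limit`, and
// `delta` are constant scalars, e.g. Range(0, 6, 2) folds to dense<[0, 2, 4]>.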
OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
647 assert(operands.size() == 3);
648 auto start_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
649 auto limit_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
650 auto delta_tensor = operands[2].dyn_cast_or_null<ElementsAttr>();
651 if (!(start_tensor && limit_tensor && delta_tensor)) return nullptr;
652
653 // Operands should all be scalars
654 assert(start_tensor.getType().getRank() == 0 &&
655 limit_tensor.getType().getRank() == 0 &&
656 delta_tensor.getType().getRank() == 0);
657 Type elem_type = getType().cast<ShapedType>().getElementType();
658 if (elem_type.isSignlessInteger() || elem_type.isUnsignedInteger()) {
659 auto start_attr = start_tensor.getValues<IntegerAttr>()[0];
660 auto limit_attr = limit_tensor.getValues<IntegerAttr>()[0];
661 auto delta_attr = delta_tensor.getValues<IntegerAttr>()[0];
662 int num_elements;
663 if (elem_type.isUnsignedInteger()) {
664 uint64_t start = start_attr.getUInt();
665 uint64_t limit = limit_attr.getUInt();
666 uint64_t delta = delta_attr.getUInt();
667 assert(start <= (uint64_t)INT_MAX);
668 assert(limit <= (uint64_t)INT_MAX);
669 assert(delta <= (uint64_t)INT_MAX);
670 num_elements =
671 GetLengthOfRange(static_cast<int>(start), static_cast<int>(limit),
672 static_cast<int>(delta));
673 } else {
674 num_elements = GetLengthOfRange(start_attr.getInt(), limit_attr.getInt(),
675 delta_attr.getInt());
676 }
677 return BuildConstRangeTensor(elem_type, num_elements, start_attr,
678 delta_attr);
679 } else if (elem_type.isa<FloatType>()) {
680 auto start_attr = start_tensor.getValues<FloatAttr>()[0];
681 auto limit_attr = limit_tensor.getValues<FloatAttr>()[0];
682 auto delta_attr = delta_tensor.getValues<FloatAttr>()[0];
683 const int num_elements = GetLengthOfRange(start_attr.getValueAsDouble(),
684 limit_attr.getValueAsDouble(),
685 delta_attr.getValueAsDouble());
686 return BuildConstRangeTensor(elem_type, num_elements, start_attr,
687 delta_attr);
688 }
689 return nullptr;
690 }
691
692 //===----------------------------------------------------------------------===//
693 // RankOp
694 //===----------------------------------------------------------------------===//
695
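// Builds a RankOp with the standard scalar i32 result type.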
void RankOp::build(OpBuilder &builder, OperationState &result, Value input) {
697 return RankOp::build(builder, result,
698 RankedTensorType::get({}, builder.getIntegerType(32)),
699 input);
700 }
701
702 // This will create a constant value for RankOp of a ranked tensor.
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
704 auto type = input().getType();
705 auto ranked_type = type.dyn_cast<RankedTensorType>();
706 if (!ranked_type) return {};
707
708 // DenseIntElementsAttr::get requires the output type be ranked with static
709 // shape.
710 auto output_type = getType().dyn_cast<RankedTensorType>();
711 if (!output_type || !output_type.hasStaticShape()) return {};
712
713 int32_t rank = ranked_type.getRank();
714 return DenseIntElementsAttr::get(output_type, rank);
715 }
716
717 //===----------------------------------------------------------------------===//
718 // RealDivOp
719 //===----------------------------------------------------------------------===//
720
void RealDivOp::getCanonicalizationPatterns(RewritePatternSet &results,
722 MLIRContext *context) {
723 results.add<RealDivWithSqrtDivisor, RealDivWithConstDivisor>(context);
724 }
725
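// Folding is delegated to the shared arithmetic identity folder (e.g.
// dividing by a constant one).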
OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) {
727 return IdentityArithmeticOpFolder<RealDivOp>(*this, operands);
728 }
729
730 //===----------------------------------------------------------------------===//
731 // ReluOp
732 //===----------------------------------------------------------------------===//
733
void ReluOp::getCanonicalizationPatterns(RewritePatternSet &results,
735 MLIRContext *context) {
736 results.add<ReluOfMinimum6ToRelu6>(context);
737 }
738
739 //===----------------------------------------------------------------------===//
740 // ReshapeOp
741 //===----------------------------------------------------------------------===//
742
743 namespace {
744 using ReshapeErrorHandler =
745 llvm::function_ref<LogicalResult(const llvm::Twine &)>;
746
LogicalResult GetReshapeOutputType(Value tensor, Value shape,
748 ReshapeErrorHandler error_handler,
749 TensorType &output_ty) {
750 auto tensor_ty = tensor.getType().cast<TensorType>();
751 auto element_ty = tensor_ty.getElementType();
752 output_ty = UnrankedTensorType::get(element_ty);
753
754 auto shape_ty = shape.getType().dyn_cast<RankedTensorType>();
755 if (!shape_ty) return success();
756 if (shape_ty.getRank() != 1)
757 return error_handler(llvm::formatv(
758 "requires 'shape' to be rank 1, but got {0}", shape_ty.getRank()));
759
760 DenseIntElementsAttr shape_attr;
761 if (!matchPattern(shape, m_Constant(&shape_attr))) {
    // If only the shape of `shape` is known, return a ranked output type with
    // all dimensions dynamic.
764 if (shape_ty.hasStaticShape()) {
765 llvm::SmallVector<int64_t, 8> dynamic_shape(shape_ty.getDimSize(0),
766 ShapedType::kDynamicSize);
767 output_ty = RankedTensorType::get(dynamic_shape, element_ty);
768 }
769 return success();
770 }
771
772 // Detect if reshape output shape is folded.
773 bool shape_ty_zero_dim = false;
774 int unknown_index = -1;
  // The product of the constant shape argument, excluding the unknown
  // dimension.
776 int64_t shape_ty_size = 1;
777 llvm::SmallVector<int64_t, 8> output_ty_shape;
778 output_ty_shape.reserve(shape_attr.getNumElements());
779 for (const auto &dim : llvm::enumerate(shape_attr.getValues<APInt>())) {
780 const int64_t size = dim.value().getSExtValue();
781 if (ShapedType::isDynamic(size)) {
782 if (unknown_index != -1)
783 return error_handler(llvm::formatv(
784 "requires 'shape' to have at most one dynamic dimension, but got "
785 "multiple dynamic dimensions at indices {0} and {1}",
786 unknown_index, dim.index()));
787
788 unknown_index = dim.index();
789 } else if (size == 0) {
790 shape_ty_zero_dim = true;
791 } else if (size > 0) {
792 shape_ty_size *= size;
793 } else {
794 return error_handler(
795 llvm::formatv("requires 'shape' to have dimensions greater than -1, "
796 "but got {0} at index {1}",
797 size, dim.index()));
798 }
799 output_ty_shape.push_back(size);
800 }
801
802 if (!tensor_ty.hasStaticShape()) {
803 output_ty = RankedTensorType::get(output_ty_shape, element_ty);
804 return success();
805 }
806
807 // Compute the value of the unknown dimension.
808 if (unknown_index != -1) {
809 // Compute number of elements in tensor shape.
810 int64_t tensor_ty_size = 1;
811 bool tensor_ty_zero_dim = false;
812 for (const auto &dim : tensor_ty.getShape()) {
813 if (dim > 0 || !shape_ty_zero_dim) {
814 tensor_ty_size *= dim;
815 } else {
816 tensor_ty_zero_dim = true;
817 }
818 }
819
820 const int64_t missing_dim = tensor_ty_size / shape_ty_size;
821 if (!tensor_ty_zero_dim && shape_ty_size * missing_dim != tensor_ty_size)
822 return error_handler(
823 llvm::formatv("requires 'tensor' number of elements be a multiple of "
824 "{0}, but got {1}",
825 shape_ty_size, tensor_ty_size));
826
    // Set the unknown dimension such that the total number of elements remains
    // constant.
829 output_ty_shape[unknown_index] = missing_dim;
830 }
831
832 output_ty = RankedTensorType::get(output_ty_shape, element_ty);
833
834 return success();
835 }
836 } // namespace
837
LogicalResult ReshapeOp::verify() {
839 ReshapeOp op = *this;
840 auto error_handler = [&op](const llvm::Twine &message) -> LogicalResult {
841 return op.emitOpError() << message;
842 };
843 TensorType expected_ty;
844 if (failed(GetReshapeOutputType(op.tensor(), op.shape(), error_handler,
845 expected_ty)))
846 return failure();
847
848 auto output_ty = op.getType().dyn_cast<RankedTensorType>();
849 if (!output_ty) return success();
850 auto tensor_ty = op.tensor().getType().cast<TensorType>();
851 if (output_ty.hasStaticShape() && tensor_ty.hasStaticShape()) {
852 const int64_t output_ty_size = output_ty.getNumElements();
853 const int64_t tensor_ty_size = tensor_ty.getNumElements();
854 if (tensor_ty_size != output_ty_size)
855 return op.emitOpError() << "requires 'output' number of elements to "
856 "match 'tensor' number of elements, but got "
857 << output_ty_size << " and " << tensor_ty_size;
858 }
859
860 if (!AreCastCompatible({output_ty, expected_ty}))
861 return op.emitOpError()
862 << "requires 'output' type " << output_ty
863 << " to be cast compatible with expected type " << expected_ty;
864
865 return success();
866 }
867
868 // Currently there are use cases that rely on partial evaluation of the `shape`
869 // operand, so InferTypeOpInterface is not used (along with generated builder of
870 // the same signature).
void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor,
872 Value shape) {
873 auto error_handler = [&result](const llvm::Twine &message) {
874 return mlir::emitError(result.location) << message;
875 };
876 TensorType output_ty;
877 if (failed(GetReshapeOutputType(tensor, shape, error_handler, output_ty)))
878 return;
879
880 return ReshapeOp::build(builder, result, output_ty, tensor, shape);
881 }
882
void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
884 MLIRContext *context) {
885 results.add<RedundantReshape, ReshapeToSelfShape>(context);
886 }
887
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
889 Value tensor = this->tensor();
890
891 // Fold reshape if operand and result types are the same and all dimensions
892 // are statically known (no-op reshape).
893 auto result_ty = getType().dyn_cast<ShapedType>();
894 if (result_ty && result_ty.hasStaticShape() &&
895 result_ty == tensor.getType()) {
896 return tensor;
897 }
898
899 return {};
900 }
901
902 //===----------------------------------------------------------------------===//
903 // SelectOp
904 //===----------------------------------------------------------------------===//
905
906 // Verifies a few extra requirements on SelectOp:
907 // (1) `then` and `else` must have same shape
908 // (2) At least one of the following must be true:
909 // (a) `cond` has the same rank as `then` and `else`
910 // (b) `cond` is a scalar
911 // (c) `cond` is a vector AND `then` and `else` are non-scalar with their
912 // first dimension equal to `cond`.
LogicalResult SelectOp::verify() {
914 SelectOp op = *this;
915 auto then_tensor = op.t().getType().cast<TensorType>();
916 auto else_tensor = op.e().getType().cast<TensorType>();
917 // Check (1).
918 if (!AreCastCompatible({then_tensor, else_tensor}))
919 return op.emitOpError() << "requires t and e have compatible shapes";
920
  // Get the data rank (if it exists).
  int data_rank;
  // If data is unranked or data_rank is 0, this will remain -2. Otherwise it
  // refers to the first dimension of `then` and/or `else`.
925 int data_first_dim = -2;
926 bool then_has_rank = then_tensor.hasRank();
927 bool else_has_rank = else_tensor.hasRank();
928 if (then_has_rank && else_has_rank) {
929 data_rank = then_tensor.getRank();
930 if (then_tensor.getRank() > 0)
931 data_first_dim = then_tensor.getShape().front();
932 if (else_tensor.getRank() > 0)
933 data_first_dim = std::max(
934 static_cast<int>(else_tensor.getShape().front()), data_first_dim);
935 } else if (then_has_rank) {
936 data_rank = then_tensor.getRank();
937 if (then_tensor.getRank() > 0)
938 data_first_dim = then_tensor.getShape().front();
939 } else if (else_has_rank) {
940 data_rank = else_tensor.getRank();
941 if (else_tensor.getRank() > 0)
942 data_first_dim = else_tensor.getShape().front();
943 } else {
944 // Neither has a rank.
945 return success();
946 }
947
948 auto cond_tensor = op.condition().getType().dyn_cast<RankedTensorType>();
949 if (!cond_tensor) return success();
950 auto cond_rank = cond_tensor.getRank();
951 // Check (2a) and (2b).
952 if (cond_rank == 0 || cond_rank == data_rank) return success();
953 // Check (2c).
954 if (cond_rank == 1) {
955 auto cond_shape = cond_tensor.getShape().front();
956 if (data_rank == 0) {
957 return op.emitOpError()
958 << "requires that t and e are nonscalar when pred is a vector";
959 }
960 // We know `data` tensor has a rank of at least 1.
961 if (data_first_dim != -1 && cond_shape != -1 &&
962 data_first_dim != cond_shape) {
963 return op.emitOpError() << "requires that, when pred is a vector, the "
964 "shape matches the first dimension of t and e";
965 }
966 return success();
967 }
968 // None of (2a,b,c) were true; fail.
969 return op.emitOpError() << "requires that pred is a scalar OR has the same "
970 "rank as t and e OR is a vector";
971 }
972
973 //===----------------------------------------------------------------------===//
974 // SelectV2Op
975 //===----------------------------------------------------------------------===//
976
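// Infers the SelectV2 result type by broadcasting the shapes of `e`, `t`, and
// `condition`; falls back to an unranked tensor type if broadcasting fails.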
static Type InferSelectV2OpType(Value condition, Value e, Value t) {
978 Type element_ty = e.getType().cast<TensorType>().getElementType();
979 auto unranked_ty = UnrankedTensorType::get(element_ty);
980
981 Type broadcasted_ty =
982 OpTrait::util::getBroadcastedType(e.getType(), t.getType());
983 if (!broadcasted_ty) return unranked_ty;
984
985 auto cond_ranked_ty = condition.getType().dyn_cast<RankedTensorType>();
986 auto broadcasted_ranked_ty = broadcasted_ty.dyn_cast<RankedTensorType>();
987 if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty;
988
  // Explicitly compute the broadcasted output shape, as the element type of
  // `condition` may not be the same as the broadcasted type's element type.
991 SmallVector<int64_t, 4> result_shape;
992 if (!OpTrait::util::getBroadcastedShape(cond_ranked_ty.getShape(),
993 broadcasted_ranked_ty.getShape(),
994 result_shape))
995 return unranked_ty;
996 return RankedTensorType::get(result_shape, element_ty);
997 }
998
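// Builds a SelectV2Op using the broadcasted result type inferred by
// InferSelectV2OpType.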
void SelectV2Op::build(OpBuilder &builder, OperationState &result,
1000 Value condition, Value e, Value t) {
1001 build(builder, result, InferSelectV2OpType(condition, e, t), condition, e, t);
1002 }
1003
1004 //===----------------------------------------------------------------------===//
1005 // ShapeOp
1006 //===----------------------------------------------------------------------===//
1007
1008 namespace {
1009 // Validates Shape/ShapeN/VariableShape operand and associated result types.
LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
1011 Type result_type,
1012 int variadic_idx = -1) {
1013 std::string variadic_idx_str =
1014 variadic_idx < 0 ? "" : llvm::formatv(" #{0}", variadic_idx).str();
1015
1016 auto result_ranked_type = result_type.dyn_cast<RankedTensorType>();
1017 if (!result_ranked_type) return success();
1018 if (result_ranked_type.getShape().size() != 1)
1019 return op->emitOpError("requires 1D type for result") << variadic_idx_str;
1020
1021 auto operand_ranked_type = operand_type.dyn_cast_or_null<RankedTensorType>();
1022 if (operand_ranked_type) {
1023 // The operand is a ranked tensor.
1024 if (result_ranked_type.hasStaticShape() &&
1025 !operand_ranked_type.getShape().empty() &&
1026 result_ranked_type.getDimSize(0) !=
1027 operand_ranked_type.getShape().size())
1028 return op->emitOpError("requires dimension size of result")
1029 << variadic_idx_str << " to match rank of operand"
1030 << variadic_idx_str;
1031 } else if (result_ranked_type.hasStaticShape()) {
    // The operand is an unranked tensor; print a warning if the result type
    // is static.
    // Note: We do not treat this situation as an error because that would be
    // too restrictive given the incompleteness of shape inference at this
    // point.
1036 mlir::InFlightDiagnostic diag =
1037 mlir::emitWarning(op->getLoc(), "has static shape result");
1038 if (op->getContext()->shouldPrintOpOnDiagnostic()) {
1039 diag.attachNote(op->getLoc())
1040 .append("see current operation: ")
1041 .appendOp(*op, OpPrintingFlags().printGenericOpForm());
1042 }
1043 diag << variadic_idx_str << " for unranked operand" << variadic_idx_str;
1044 }
1045
1046 Type element_type = result_ranked_type.getElementType();
1047 if (!element_type.isSignlessInteger(32) &&
1048 !element_type.isSignlessInteger(64))
1049 return op->emitOpError("requires int32 or int64 return type for result")
1050 << variadic_idx_str;
1051
1052 return success();
1053 }
1054 } // anonymous namespace
1055
LogicalResult ShapeOp::verify() {
1057 ShapeOp op = *this;
1058 return VerifyShapeOperandAndResult(op, op.input().getType(), op.getType());
1059 }
1060
// Converts the shape of the given type to an attribute if it is a ranked
// tensor type with a static shape. The returned attribute has integer
// elements of the given width.
static Attribute ConvertShapeToAttr(Type input_ty, int out_width) {
1064 auto ranked_ty = input_ty.dyn_cast<RankedTensorType>();
1065 if (!ranked_ty || !ranked_ty.hasStaticShape()) return {};
1066
1067 auto shape = ranked_ty.getShape();
1068 int rank = shape.size();
1069
1070 SmallVector<APInt, 4> dimensions;
1071 dimensions.reserve(rank);
1072 for (int i = 0; i < rank; ++i)
1073 dimensions.push_back(APInt(out_width, shape[i]));
1074
1075 auto result_type = RankedTensorType::get(
1076 {rank}, IntegerType::get(input_ty.getContext(), out_width));
1077 return DenseElementsAttr::get(result_type, dimensions);
1078 }
1079
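// Folds tf.Shape into a constant when the operand type has a static shape,
// e.g. a tensor<2x3xf32> operand folds to dense<[2, 3]>.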
OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
1081 int width =
1082 getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
1083 return ConvertShapeToAttr(getOperand().getType(), width);
1084 }
1085
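// Builds a ShapeOp with an i32 or i64 element type depending on `use32Bit`.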
void ShapeOp::build(OpBuilder &builder, OperationState &result, Value input,
1087 BoolAttr use32Bit) {
1088 auto rankedTensorType = input.getType().dyn_cast<RankedTensorType>();
1089 int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1;
1090 auto out_type = use32Bit.getValue() ? builder.getIntegerType(32)
1091 : builder.getIntegerType(64);
1092 return ShapeOp::build(builder, result,
1093 RankedTensorType::get({rank}, out_type), input);
1094 }
1095
1096 //===----------------------------------------------------------------------===//
1097 // ShapeNOp
1098 //===----------------------------------------------------------------------===//
1099
LogicalResult ShapeNOp::verify() {
1101 ShapeNOp op = *this;
1102 const size_t num_tensors = op.N();
1103
1104 if (op.getNumOperands() != num_tensors)
1105 return op.emitOpError() << "requires " << num_tensors << " operand(s), got "
1106 << op.getNumOperands() << " operand(s)";
1107
1108 if (op.getNumResults() != num_tensors)
1109 return op.emitOpError() << "requires " << num_tensors << " result(s), got "
1110 << op.getNumResults() << " result(s)";
1111
1112 for (auto i : llvm::seq<uint64_t>(0, num_tensors)) {
1113 auto verification = VerifyShapeOperandAndResult(
1114 op, op.getOperand(i).getType(), op.getResult(i).getType(), i);
1115 if (failed(verification)) return verification;
1116 }
1117
1118 return success();
1119 }
1120
1121 namespace {
// Canonicalization pattern for a ShapeNOp whose inputs do not all have static
// shapes. Replacing the output values that correspond to statically shaped
// inputs may enable optimizations in users of those values.
1125 class ShapeNPartialStaticInputShape : public OpRewritePattern<ShapeNOp> {
1126 using OpRewritePattern<ShapeNOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ShapeNOp op,
1128 PatternRewriter &rewriter) const override {
1129 if (op.getNumOperands() == 0) {
1130 rewriter.eraseOp(op);
1131 return success();
1132 }
1133
1134 int width = getElementTypeOrSelf(op.getType(0)).getIntOrFloatBitWidth();
1135
1136 SmallVector<Value, 4> results(op.getNumOperands());
1137 SmallVector<int64_t, 4> dynamic_indices;
1138 SmallVector<Value, 4> dynamic_inputs;
1139 SmallVector<Type, 4> result_types;
1140 for (auto e : llvm::enumerate(op.getOperands())) {
1141 if (Attribute result = ConvertShapeToAttr(e.value().getType(), width)) {
1142 results[e.index()] = rewriter.create<TF::ConstOp>(op.getLoc(), result);
1143 } else {
1144 dynamic_indices.push_back(e.index());
1145 dynamic_inputs.push_back(e.value());
1146 result_types.push_back(op.getType(e.index()));
1147 }
1148 }
1149
1150 if (dynamic_inputs.size() == op.getNumOperands()) {
1151 // Cannot canonicalize ShapeN if all inputs are dynamic.
1152 return failure();
1153 }
1154
1155 // Create a ShapeNOp for all dynamic inputs.
1156 if (!dynamic_inputs.empty()) {
1157 auto dynamic_shape_n = rewriter.create<TF::ShapeNOp>(
1158 op.getLoc(), result_types, dynamic_inputs);
1159 for (auto index_result :
1160 llvm::zip(dynamic_indices, dynamic_shape_n.getResults())) {
1161 results[std::get<0>(index_result)] = std::get<1>(index_result);
1162 }
1163 }
1164
1165 rewriter.replaceOp(op, results);
1166 return success();
1167 }
1168 };
1169
1170 // Canonicalize ShapeNOp to ShapeOp if there is only one operand.
1171 class ShapeNToShape : public OpRewritePattern<ShapeNOp> {
1172 using OpRewritePattern<ShapeNOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ShapeNOp op,
1174 PatternRewriter &rewriter) const override {
1175 if (op.getNumOperands() != 1) {
1176 return failure();
1177 }
1178 auto shape = rewriter.create<TF::ShapeOp>(op.getLoc(), op.getType(0),
1179 op.getOperand(0));
1180 rewriter.replaceOp(op, {shape});
1181 return success();
1182 }
1183 };
1184 } // namespace
1185
void ShapeNOp::getCanonicalizationPatterns(RewritePatternSet &results,
1187 MLIRContext *context) {
1188 results.add<ShapeNToShape, ShapeNPartialStaticInputShape>(context);
1189 }
1190
1191 //===----------------------------------------------------------------------===//
1192 // SizeOp
1193 //===----------------------------------------------------------------------===//
1194
1195 // Verifies that,
1196 //
// * Input type, if it is a ranked tensor, has at most INT32_MAX dimensions.
1198 //
LogicalResult SizeOp::verify() {
1200 SizeOp op = *this;
1201 if (!HasRankAtMost(op.input(), std::numeric_limits<int32_t>::max()))
1202 return op.emitOpError(
1203 "requires ranked input tensor to be of rank INT32_MAX or less");
1204
1205 // Output type needs to be scalar.
1206 if (!IsOfRankOrUnranked(op.output(), /*rank=*/0))
1207 return op.emitOpError("requires scalar output");
1208
1209 return success();
1210 }
1211
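// Folds tf.Size into a constant when the operand has a fully static shape,
// e.g. a tensor<2x3xf32> operand folds to a constant 6.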
OpFoldResult SizeOp::fold(ArrayRef<Attribute> operands) {
1213 ShapedType output_type = getType().cast<ShapedType>();
1214 if (!output_type.hasRank()) return {};
1215 ShapedType input_type = getOperand().getType().cast<ShapedType>();
1216 if (!input_type.hasStaticShape()) return {};
1217 int size = input_type.getNumElements();
1218 return DenseElementsAttr::get(
1219 output_type,
1220 IntegerAttr::get(output_type.getElementType(), /*value=*/size));
1221 }
1222
1223 //===----------------------------------------------------------------------===//
1224 // SliceOp
1225 //===----------------------------------------------------------------------===//
1226
1227 // Verifies that:
1228 //
1229 // - operands begin and size are 1D with the same number of elements.
1230 // - if the input is a ranked tensor, the rank of the input equals the number
1231 // of elements in operands begin and size.
1232 // - if begin are constants, that
1233 // 0 <= begin[i] <= begin[i] + size[i] <= input_ty.getShape()[i]
1234 // and
1235 // size[i] == output_ty.getShape()[i]
1236 // - if begins aren't constant but the input is a ranked tensor, that
1237 // size[i] <= input_ty.getShape()[i]
1238 // - output rank is the same as input rank
1239 //
LogicalResult SliceOp::verify() {
1241 SliceOp op = *this;
1242 RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin());
1243 if (begin_ty && begin_ty.getRank() != 1) {
1244 return op.emitOpError() << "requires begin operand to be 1D tensor";
1245 }
1246
1247 RankedTensorType size_ty = GetRankedTensorTypeForOperand(op.size());
1248 if (size_ty && size_ty.getRank() != 1) {
1249 return op.emitOpError() << "requires size operand to be 1D tensor";
1250 }
1251
1252 if (!begin_ty || !size_ty || !begin_ty.hasStaticShape() ||
1253 !size_ty.hasStaticShape())
1254 return success();
1255
1256 if (begin_ty.getNumElements() != size_ty.getNumElements()) {
1257 return op.emitOpError() << "requires begin and size operands to have the"
1258 " same number of elements";
1259 }
1260
1261 auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
1262 if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) {
1263 return op.emitOpError() << "requires number of elements in begin and size "
1264 "are equal to input rank";
1265 }
1266
1267 auto output_ty = op.output().getType().dyn_cast<RankedTensorType>();
1268 if (output_ty && input_ty && output_ty.getRank() != input_ty.getRank()) {
1269 return op.emitOpError()
1270 << "requires output to have the same rank as input, but got input "
1271 "rank "
1272 << input_ty.getRank() << " and output rank " << output_ty.getRank();
1273 }
1274
1275 DenseIntElementsAttr begin_indices;
1276 if (matchPattern(op.begin(), m_Constant(&begin_indices))) {
1277 DenseIntElementsAttr slice_sizes;
1278 bool constant_slice_sizes =
1279 matchPattern(op.size(), m_Constant(&slice_sizes));
1280 int dim = 0;
1281 // TODO(jpienaar): Reformulate the shape verification below to not use magic
1282 // constants.
1283 for (const APInt &raw_begin_index : begin_indices.getValues<APInt>()) {
1284 int64_t begin_index = raw_begin_index.getSExtValue();
1285 int64_t input_size = input_ty ? input_ty.getShape()[dim] : -1;
1286 int64_t slice_size =
1287 constant_slice_sizes
1288 ? slice_sizes.getValues<APInt>()[dim].getSExtValue()
1289 : 0;
1290 int64_t output_size = output_ty ? output_ty.getShape()[dim] : -1;
1291
1292 if (slice_size == -1 && input_size != -1) {
1293 slice_size = input_size - begin_index;
1294 }
1295 if (output_size != -1 && constant_slice_sizes &&
1296 output_size != slice_size) {
1297 return op.emitOpError()
1298 << "requires output size to have the same size of slice, got "
1299 "slice size "
1300 << slice_size << " and output size " << output_size;
1301 }
1302 if (begin_index < 0 ||
1303 (input_size != -1 && begin_index + slice_size > input_size)) {
1304 return op.emitOpError()
1305 << "requires 0 <= begin[i] <= begin[i] + size[i] <= Di";
1306 }
1307 ++dim;
1308 }
1309 } else if (input_ty) {
1310 // If the inputs are ranked, we can do a few more sanity checks.
1311 DenseIntElementsAttr slice_sizes;
1312 if (matchPattern(op.size(), m_Constant(&slice_sizes))) {
1313 auto input_shape = input_ty.getShape();
1314 for (int64_t i = 0; i < input_ty.getRank(); ++i) {
1315 int64_t slice_size = slice_sizes.getValues<APInt>()[i].getSExtValue();
1316 int64_t input_size = input_shape[i];
1317 if (slice_size != -1 && input_size != -1 && slice_size > input_size) {
1318 return op.emitOpError() << "requires size[i] <= Di, even if begin[i] "
1319 "is unknown at compile time";
1320 }
1321 }
1322 }
1323 }
1324
1325 return success();
1326 }
1327
1328 //===----------------------------------------------------------------------===//
1329 // SoftmaxOp
1330 //===----------------------------------------------------------------------===//
1331
LogicalResult SoftmaxOp::verify() {
1333 SoftmaxOp op = *this;
1334 if (!HasRankAtLeast(op.logits(), 1)) {
1335 return op.emitOpError("requires operand to have rank at least 1");
1336 }
1337 return success();
1338 }
1339
1340 //===----------------------------------------------------------------------===//
1341 // SoftmaxCrossEntropyWithLogitsOp
1342 //===----------------------------------------------------------------------===//
1343
1344 // Verifies that,
1345 //
1346 // * Input types are broadcast compatible and the broadcasted type has rank two.
1347 //
LogicalResult SoftmaxCrossEntropyWithLogitsOp::verify() {
1349 SoftmaxCrossEntropyWithLogitsOp op = *this;
1350 auto broadcasted_ty = OpTrait::util::getBroadcastedType(
1351 op.features().getType(), op.labels().getType())
1352 .dyn_cast_or_null<ShapedType>();
1353 if (!broadcasted_ty ||
1354 (broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2))
1355 return op.emitOpError(
1356 "requires features and labels to be broadcast compatible to rank two");
1357
1358 return success();
1359 }
1360
1361 //===----------------------------------------------------------------------===//
1362 // SpaceToBatchNDOp
1363 //===----------------------------------------------------------------------===//
1364
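// Returns the block rank implied by the static shape of either `block_shape`
// or `paddings`, or -1 if neither shape is statically known.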
int64_t SpaceToBatchNDBlockRank(const TensorType block_shape_type,
1366 const TensorType paddings_type) {
1367 if (block_shape_type.hasStaticShape()) {
1368 return block_shape_type.getShape()[0];
1369 } else if (paddings_type.hasStaticShape()) {
1370 return paddings_type.getShape()[0];
1371 } else {
1372 return -1;
1373 }
1374 }
1375
1376 LogicalResult SpaceToBatchNDOp::verify() {
1377 SpaceToBatchNDOp op = *this;
1378 const auto input_type = op.input().getType().cast<TensorType>();
1379 const auto block_shape_type = op.block_shape().getType().cast<TensorType>();
1380 const auto paddings_type = op.paddings().getType().cast<TensorType>();
1381
1382 // Check that block_shape has rank 1.
1383 if (!IsOfRankOrUnranked(op.block_shape(), 1)) {
1384 return op.emitOpError() << "requires rank of block_shape = 1; got "
1385 << block_shape_type.getRank();
1386 }
1387
1388 // Check that paddings has rank 2.
1389 if (!IsOfRankOrUnranked(op.paddings(), 2)) {
1390 return op.emitOpError()
1391 << "requires rank of paddings = 2; got " << paddings_type.getRank();
1392 }
1393
1394 // Check that paddings.shape[1]=2.
1395 if (paddings_type.hasStaticShape() && paddings_type.getShape()[1] != 2) {
1396 return op.emitOpError() << "requires paddings.shape[1] to be 2; got "
1397 << paddings_type.getShape()[1];
1398 }
1399
1400 // Check that block_shape and paddings have consistent ranks.
1401 if (block_shape_type.hasStaticShape() && paddings_type.hasStaticShape() &&
1402 block_shape_type.getShape()[0] != paddings_type.getShape()[0]) {
1403 return op.emitOpError()
1404 << "requires block_shape.shape[0] must equal paddings.shape[0]";
1405 }
1406
1407 const int64_t block_rank =
1408 SpaceToBatchNDBlockRank(block_shape_type, paddings_type);
1409
1410 // Further checks require block_rank to be known.
1411 if (block_rank == -1) {
1412 return success();
1413 }
1414
1415 // Check that rank of input_type >= block_rank + 1.
1416 if (input_type.hasRank() && input_type.getRank() < 1 + block_rank) {
1417 return op.emitOpError() << "requires rank of input >= 1 + rank of block";
1418 }
1419
1420 ElementsAttr block_shape_attr = nullptr;
1421 ElementsAttr paddings_attr = nullptr;
1422
1423 // Check that block_shape[*] >= 1.
1424 if (matchPattern(op.block_shape(), m_Constant(&block_shape_attr))) {
1425 uint64_t i = 0;
1426 for (auto block_len : block_shape_attr.getValues<APInt>()) {
1427 if (block_len.getSExtValue() < 1) {
1428 return op.emitOpError()
1429 << "requires all values of block_shape to be >= 1; "
1430 "failed for dimension "
1431 << i;
1432 }
1433 ++i;
1434 }
1435 }
1436
1437 // Check that paddings[*] >= 0.
1438 if (matchPattern(op.paddings(), m_Constant(&paddings_attr))) {
1439 for (uint64_t i = 0; i < block_rank; ++i) {
1440 const int64_t pad_start =
1441 paddings_attr.getValues<APInt>()[{i, 0}].getSExtValue();
1442 const int64_t pad_end =
1443 paddings_attr.getValues<APInt>()[{i, 1}].getSExtValue();
1444 if (pad_start < 0 || pad_end < 0) {
1445 return op.emitOpError()
1446 << "requires all values of paddings to be >= 0; "
1447 "failed for dimension "
1448 << i;
1449 }
1450 }
1451 }
1452
1453 // Check that block_shape divides the padded input.
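// For example (illustrative values): input_shape = [2, 5, 8, 1] with
// block_shape = [3, 2] and paddings = [[1, 0], [0, 0]] is accepted, since the
// padded spatial dims 5 + 1 + 0 = 6 and 8 + 0 + 0 = 8 are divisible by 3 and 2
// respectively.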
1454 if (input_type.hasStaticShape() && block_shape_attr && paddings_attr) {
1455 for (uint64_t i = 0; i < block_rank; ++i) {
1456 const int64_t input_len = input_type.getShape()[1 + i];
1457 const int64_t pad_start =
1458 paddings_attr.getValues<APInt>()[{i, 0}].getSExtValue();
1459 const int64_t pad_end =
1460 paddings_attr.getValues<APInt>()[{i, 1}].getSExtValue();
1461 const int64_t block_len =
1462 block_shape_attr.getValues<APInt>()[i].getSExtValue();
1463 if ((input_len + pad_start + pad_end) % block_len != 0) {
1464 return op.emitOpError()
1465 << "requires block_shape[i] divides "
1466 "input_shape[i + 1] + paddings[i, 0] + paddings[i, 1]; "
1467 "failed for i="
1468 << i;
1469 }
1470 }
1471 }
1472
1473 return success();
1474 }
1475
1476 //===----------------------------------------------------------------------===//
1477 // SparseSoftmaxCrossEntropyWithLogitsOp
1478 //===----------------------------------------------------------------------===//
1479
1480 LogicalResult SparseSoftmaxCrossEntropyWithLogitsOp::verify() {
1481 SparseSoftmaxCrossEntropyWithLogitsOp op = *this;
1482 if (!IsOfRankOrUnranked(op.features(), 2)) {
1483 return op.emitOpError("requires features operand of rank two");
1484 }
1485 if (!IsOfRankOrUnranked(op.labels(), 1)) {
1486 return op.emitOpError("requires labels operand of rank one");
1487 }
1488 auto features_ty = op.features().getType().dyn_cast<RankedTensorType>();
1489 auto labels_ty = op.labels().getType().dyn_cast<RankedTensorType>();
1490 if (features_ty && labels_ty) {
1491 int64_t features_batches = features_ty.getDimSize(0);
1492 int64_t labels_batches = labels_ty.getDimSize(0);
1493 if (!ShapedType::isDynamic(features_batches) &&
1494 !ShapedType::isDynamic(labels_batches) &&
1495 features_batches != labels_batches)
1496 return op.emitOpError(
1497 "requires features and labels with matching first dimension");
1498 }
1499 return success();
1500 }
1501
1502 //===----------------------------------------------------------------------===//
1503 // SplitOp
1504 //===----------------------------------------------------------------------===//
1505
1506 // Verifies the input and split dimension operands for tf.Split/tf.SplitV.
1507 // Writes the split dimension's index (adjusted with input rank) via `dim_index`
1508 // if it's a constant.
1509 template <class Op>
1510 LogicalResult VerifySplitInputAndSplitDim(Op op, Optional<int64_t> *dim_index) {
1511 *dim_index = llvm::None;
1512
1513 Value split_dim = op.split_dim();
1514 if (auto split_dim_type = split_dim.getType().dyn_cast<RankedTensorType>())
1515 if (split_dim_type.getRank() != 0)
1516 return op.emitOpError(
1517 "split dimension should be an integer scalar tensor");
1518
1519 // We can perform further verification if the input tensor to be split has
1520 // known rank and the split dimension tensor is a constant.
1521
1522 auto input_type = op.value().getType().template dyn_cast<RankedTensorType>();
1523 if (!input_type) return success();
1524
1525 int64_t input_rank = input_type.getRank();
1526 if (input_rank == 0)
1527 return op.emitOpError("cannot split scalar input tensor");
1528
1529 DenseIntElementsAttr split_dim_attr;
1530 if (!matchPattern(split_dim, m_Constant(&split_dim_attr))) return success();
1531
1532 int64_t index = (*split_dim_attr.begin()).getSExtValue();
1533
1534 if (index + input_rank < 0 || index >= input_rank) {
1535 return op.emitOpError("split dimension must be in range [-")
1536 << input_rank << ", " << input_rank << ")";
1537 }
1538
1539 if (index < 0) index += input_rank;
1540 *dim_index = index;
1541
1542 return success();
1543 }
1544
1545 LogicalResult SplitOp::verify() {
1546 SplitOp op = *this;
1547 Optional<int64_t> dim_index;
1548 if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
1549 if (!dim_index) return success();
1550
1551 int64_t input_dim_size =
1552 op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
1553 if (ShapedType::isDynamic(input_dim_size)) return success();
1554
1555 if (op.getNumResults() == 0) return failure();
1556
1557 if (input_dim_size % op.getNumResults() != 0)
1558 return op.emitOpError("dimension #")
1559 << *dim_index << " not divisible by the number of result tensors";
1560
1561 return success();
1562 }
1563
1564 //===----------------------------------------------------------------------===//
1565 // SplitVOp
1566 //===----------------------------------------------------------------------===//
1567
1568 LogicalResult SplitVOp::verify() {
1569 SplitVOp op = *this;
1570 auto split_sizes_type =
1571 op.size_splits().getType().dyn_cast<RankedTensorType>();
1572 if (!split_sizes_type) return success();
1573
1574 if (split_sizes_type.getRank() != 1 ||
1575 (!ShapedType::isDynamic(split_sizes_type.getDimSize(0)) &&
1576 split_sizes_type.getDimSize(0) != op.getNumResults()))
1577 return op.emitOpError("split sizes should be a 1D tensor of ")
1578 << op.getNumResults() << " elements";
1579
1580 Optional<int64_t> dim_index = 0;
1581 if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
1582 if (!dim_index) return success();
1583
1584 int64_t input_dim_size =
1585 op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
1586 if (ShapedType::isDynamic(input_dim_size)) return success();
1587
1588 // If split sizes come from a constant, they must sum to the dimension size
1589 // along split_dim, and we can have no more than one dynamic dimension.
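// For example (illustrative values): with an input dimension of size 10,
// size_splits = [2, 3, 5] is valid, and size_splits = [2, -1, 5] leaves the
// middle split to be inferred as 3.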
1590 DenseIntElementsAttr split_sizes_attr;
1591 if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
1592 return success();
1593
1594 int64_t total_dim_size = 0; // Total dimension size assigned to splits
1595 llvm::Optional<int> dynamic_dim_index;
1596
1597 SmallVector<int64_t, 4> split_sizes;
1598 split_sizes.reserve(
1599 split_sizes_attr.getType().cast<ShapedType>().getNumElements());
1600
1601 for (auto dim : llvm::enumerate(split_sizes_attr)) {
1602 int64_t dim_val = dim.value().getSExtValue();
1603 split_sizes.push_back(dim_val);
1604 if (ShapedType::isDynamic(dim_val)) {
1605 // We cannot have more than one dynamic dimension.
1606 if (dynamic_dim_index)
1607 return op.emitOpError(
1608 "cannot have more than one dynamic dimension in split sizes");
1609 dynamic_dim_index = dim.index();
1610 } else {
1611 total_dim_size += dim_val;
1612 }
1613 }
1614
1615 if (!dynamic_dim_index && total_dim_size != input_dim_size)
1616 return op.emitOpError(
1617 "split sizes must sum up to the dimension size along split "
1618 "dimension, found ")
1619 << total_dim_size << " vs " << input_dim_size;
1620
1621 if (dynamic_dim_index && total_dim_size > input_dim_size)
1622 return op.emitOpError(
1623 "split sizes must sum up to be less than or equal to the "
1624 "dimension size along split dimension, found ")
1625 << total_dim_size << " vs " << input_dim_size;
1626
1627 return success();
1628 }
1629
1630 //===----------------------------------------------------------------------===//
1631 // SquareOp
1632 //===----------------------------------------------------------------------===//
1633
1634 void SquareOp::getCanonicalizationPatterns(RewritePatternSet &results,
1635 MLIRContext *context) {
1636 results.add<SquareOfSub>(context);
1637 }
1638
1639 //===----------------------------------------------------------------------===//
1640 // SqueezeOp
1641 //===----------------------------------------------------------------------===//
1642
1643 LogicalResult SqueezeOp::verify() {
1644 SqueezeOp op = *this;
1645 auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
1646
1647 if (!input_type) return success(); // Can't verify squeeze dims.
1648
1649 int64_t input_rank = input_type.getRank();
1650 for (const auto &squeeze_dim_apint :
1651 op.squeeze_dims().getAsValueRange<IntegerAttr>()) {
1652 int64_t squeeze_dim = squeeze_dim_apint.getSExtValue();
1653 if (squeeze_dim < -input_rank || squeeze_dim >= input_rank) {
1654 return op.emitOpError()
1655 << "squeeze dimension " << squeeze_dim << " not in ["
1656 << -input_rank << ", " << input_rank << ")";
1657 }
1658 }
1659
1660 return success();
1661 }
1662
1663 //===----------------------------------------------------------------------===//
1664 // SubOp
1665 //===----------------------------------------------------------------------===//
1666
1667 void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
1668 MLIRContext *context) {
1669 results.add<SubOfNeg>(context);
1670 }
1671
1672 OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
1673 return IdentityArithmeticOpFolder<SubOp>(*this, operands);
1674 }
1675
1676 //===----------------------------------------------------------------------===//
1677 // SumOp
1678 //===----------------------------------------------------------------------===//
1679
1680 void SumOp::build(OpBuilder &builder, OperationState &result, Value input,
1681 Value reduction_indices, BoolAttr keep_dims) {
1682 Type out_ty = InferReductionOpType(input, reduction_indices, keep_dims);
1683 build(builder, result, out_ty, input, reduction_indices, keep_dims);
1684 }
1685
1686 // TODO: Templatize this fold for all reduction ops.
1687 OpFoldResult SumOp::fold(ArrayRef<Attribute> operands) {
1688 auto input_ty = input().getType().template dyn_cast<RankedTensorType>();
1689 if (!input_ty) return {};
1690 auto result_ty = getType().template dyn_cast<RankedTensorType>();
1691 if (!result_ty) return {};
1692
1693 // Bypass this op if the result has the same shape and type. This can happen
1694 // if the input tensor has size 0 or size 1.
1695 if (!keep_dims() && input_ty == result_ty) {
1696 return input();
1697 }
1698 return {};
1699 }
1700
1701 //===----------------------------------------------------------------------===//
1702 // StridedSliceOp
1703 //===----------------------------------------------------------------------===//
1704
1705 // TODO(b/154160827): Add a canonicalization pattern from tf.StridedSliceOp to
1706 // tf.SliceOp if both of the following are true:
1707 // - All strides have a known value equal to 1
1708 // - No masks are set (or masks can be applied by transforming the inputs to
1709 // Slice)
1710
1711 // Verifies that,
1712 //
1713 // - begin, end and strides operands are 1D and they have the same number of
1714 // elements. Here, the number of elements should be less than 32 to support
1715 // 32-bit mask attributes.
1716 // - None of the strides values are zero.
1717 // - Ellipsis mask can have at most one bit set.
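//
// For example, begin = [0, 1], end = [2, 3], strides = [1, 1] passes these
// checks, whereas strides = [1, 0] or a two-element end operand paired with a
// three-element begin operand is rejected.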
1718
1719 template <class OpTy>
1720 static LogicalResult VerifyStridedSliceBase(OpTy op) {
1721 // Expected size for operands begin, end and strides vector operands.
1722 int64_t expected_size = -1;
1723
1724 for (Value val : {op.begin(), op.end(), op.strides()}) {
1725 auto operand_ty = val.getType().dyn_cast<ShapedType>();
1726 if (!operand_ty || !operand_ty.hasStaticShape()) {
1727 // TensorFlow constant ops may have non-static shape because the shape is
1728 // not propagated during constant folding. If the defining op for this
1729 // operand is a constant op, use the constant op's attribute to get the
1730 // actual shape.
1731 DenseIntElementsAttr attr;
1732 if (!matchPattern(val, m_Constant(&attr))) continue;
1733 operand_ty = attr.getType();
1734 }
1735
1736 if (operand_ty.getRank() != 1)
1737 return op.emitOpError()
1738 << "requires begin, end and strides to be 1D tensors";
1739
1740 int64_t length = operand_ty.getDimSize(0);
1741 if (length == -1) continue;
1742
1743 if (expected_size == -1) {
1744 // This op uses 32-bit masks.
1745 if (length >= 32)
1746 return op.emitOpError(
1747 "requires begin, end and strides operands with less than 32 "
1748 "elements");
1749
1750 expected_size = length;
1751 } else if (length != expected_size) {
1752 return op.emitOpError() << "requires begin, end and strides to have the "
1753 "same number of elements";
1754 }
1755 }
1756
1757 // If strides are constants, verify that none of the element is zero.
1758 DenseIntElementsAttr strides;
1759 if (matchPattern(op.strides(), m_Constant(&strides))) {
1760 if (llvm::is_contained(strides.getValues<APInt>(), 0))
1761 return op.emitOpError("requires non-zero strides");
1762 }
1763
1764 // Use bit compares to ensure ellipsis_mask is 0 or a power of 2, i.e. there
1765 // is at most one ellipsis.
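// For example, ellipsis_mask = 0b0100 is accepted, while 0b0101 (two
// ellipses) is rejected.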
1766 uint32_t ellipsis_mask = op.ellipsis_mask();
1767 if (ellipsis_mask != 0 && !llvm::isPowerOf2_32(ellipsis_mask))
1768 return op.emitOpError("cannot have multiple ellipses");
1769
1770 return success();
1771 }
1772
1773 LogicalResult StridedSliceOp::verify() { return VerifyStridedSliceBase(*this); }
1774
1775 // Clamps the given `val`: returns `low` if `val` is less than `low`; returns
1776 // `high` if `high` is less than `val`; otherwise returns `val`.
1777 template <class T>
1778 constexpr const T &Clamp(const T &val, const T &low, const T &high) {
1779 assert(!(high < low));
1780 return (val < low) ? low : (high < val) ? high : val;
1781 }
1782
1783 // Checks if the `index` bit of `val` is set.
1784 template <class T>
1785 constexpr bool IsSet(const T &val, unsigned index) {
1786 return (val & (1 << index)) != 0;
1787 }
1788
1789 // Sets the `index` bit of `val`.
1790 template <class T>
1791 constexpr void Set(T &val, unsigned index) {
1792 val |= (1 << index);
1793 }
1794
1795 // Unset the `index` bit of `val`.
1796 template <class T>
1797 constexpr void Unset(T &val, unsigned index) {
1798 val &= ~(1 << index);
1799 }
1800
1801 // Copy the `src_index` bit of `src` to `dst_index` bit of `dst`.
1802 template <class T>
1803 constexpr void CopyBit(const T &src, unsigned src_index, T &dst,
1804 unsigned dst_index) {
1805 if (IsSet(src, src_index))
1806 Set(dst, dst_index);
1807 else
1808 Unset(dst, dst_index);
1809 }
1810
1811 // The sparse spec of strided slice does not correspond to the number of
1812 // dimensions. For example, sparse spec for foo[..., 3:10] for foo of shape (2,
1813 // 4, 8) would have dims = 2.
1814 struct SparseSliceSpec {
1815 int64_t dims;
1816 int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
1817 const ArrayRef<int64_t> &begin;
1818 const ArrayRef<int64_t> &end;
1819 const ArrayRef<int64_t> &strides;
1820 };
1821
1822 // The dense spec of strided slice is the canonicalized version of sparse spec.
1823 // The number of dimensions of the dense spec corresponds to the number of
1824 // dimensions in the operand tensor.
1825 struct DenseSliceSpec {
1826 int64_t dims;
1827 int32_t begin_mask, end_mask, shrink_axis_mask;
1828 SmallVectorImpl<int64_t> &begin;
1829 SmallVectorImpl<int64_t> &end;
1830 SmallVectorImpl<int64_t> &strides;
1831 };
1832
1833 // Makes a sparse spec into a dense index spec.
1834 // The sparse spec does not correspond to the number of dimensions, so this
1835 // builds a dense spec that does.
1836 //
1837 // For example suppose foo[...,3:, 2] on foo.shape=(2,2,3,4) then
1838 // we need to produce the missing begin_mask, end_mask for the first two
1839 // dimensions i.e. foo[:, :, 3:, 2].
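// In that example the resulting dense spec has dims = 4: the two dimensions
// covered by the ellipsis get begin/end mask bits set and stride 1, and the
// remaining sparse entries are copied over unchanged.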
1840 static void BuildDenseSliceSpec(const SparseSliceSpec &sparse,
1841 DenseSliceSpec *dense) {
1842 // Build expanded dense begin, end, strides, begin_mask, end_mask, and
1843 // shrink_axis_mask.
1844 dense->begin.resize(dense->dims);
1845 dense->end.resize(dense->dims);
1846 dense->strides.resize(dense->dims);
1847 dense->begin_mask = 0;
1848 dense->end_mask = 0;
1849 dense->shrink_axis_mask = 0;
1850
1851 // Count number of new_axis after ellipsis. This helps in calculating the
1852 // number of dimensions ellipsis represents in the sparse spec.
1853 bool ellipsis_seen = false;
1854 int num_new_axis_after_ellipsis = 0;
1855 for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) {
1856 if (ellipsis_seen && IsSet(sparse.new_axis_mask, sparse_index))
1857 num_new_axis_after_ellipsis++;
1858 if (IsSet(sparse.ellipsis_mask, sparse_index)) ellipsis_seen = true;
1859 }
1860
1861 int dense_index = 0;
1862 for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) {
1863 if (IsSet(sparse.new_axis_mask, sparse_index)) continue;
1864 if (IsSet(sparse.ellipsis_mask, sparse_index)) {
1865 auto next_index = std::min(dense->dims - (sparse.dims - sparse_index) +
1866 1 + num_new_axis_after_ellipsis,
1867 dense->dims);
1868 // Expand ellipsis into the appropriate dense indices. From current index
1869 // until next_index, all dimensions would have begin and end masks set and
1870 // stride 1, i.e., get all elements in those dimensions.
1871 for (; dense_index < next_index; ++dense_index) {
1872 dense->begin[dense_index] = dense->end[dense_index] = 0;
1873 dense->strides[dense_index] = 1;
1874 Set(dense->begin_mask, dense_index);
1875 Set(dense->end_mask, dense_index);
1876 }
1877 continue;
1878 }
1879 assert(dense_index < dense->dims);
1880 // Copy over the sparse indices to dense indices if ellipsis_mask and
1881 // new_axis_mask are not set.
1882 dense->begin[dense_index] = sparse.begin[sparse_index];
1883 dense->end[dense_index] = sparse.end[sparse_index];
1884 dense->strides[dense_index] = sparse.strides[sparse_index];
1885 CopyBit(sparse.begin_mask, sparse_index, dense->begin_mask, dense_index);
1886 CopyBit(sparse.end_mask, sparse_index, dense->end_mask, dense_index);
1887 CopyBit(sparse.shrink_axis_mask, sparse_index, dense->shrink_axis_mask,
1888 dense_index);
1889 dense_index++;
1890 }
1891 }
1892
1893 // For the given `input_shape`, calculates the sliced shape using the given
1894 // `begin`, `end`, and `stride` ranges and `begin_mask`, `end_mask`, and
1895 // `shrink_axis_mask` masks. Updates the result back to `input_shape`. If
1896 // `shrink_axis_mask` is not zero, this function will not drop the corresponding
1897 // dimensions in `input_shape`; it will turn them into 1s. At the same time,
1898 // canonicalizes `begin`, `end`, and `strides`. The calculation follows
1899 // tf.StridedSlice op semantics.
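//
// For example (illustrative values): for input_shape = [10] with begin = [-3],
// end = [10], stride = [1] and no masks set, begin is canonicalized to 7 and
// the sliced dimension size becomes 3.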
1900 static void CalculateSlicedShapeFromDenseIndices(
1901 MutableArrayRef<int64_t> input_shape, int32_t begin_mask, int32_t end_mask,
1902 int32_t shrink_axis_mask, MutableArrayRef<int64_t> begin,
1903 MutableArrayRef<int64_t> end, MutableArrayRef<int64_t> stride) {
1904 assert(input_shape.size() <= 32); // Only 32-bit masks are supported.
1905
1906 // Make sure ranges' ranks are consistent with the input.
1907 assert(input_shape.size() == begin.size());
1908 assert(input_shape.size() == end.size());
1909 assert(input_shape.size() == stride.size());
1910
1911 for (int i = 0, e = input_shape.size(); i < e; ++i) {
1912 if (ShapedType::isDynamic(input_shape[i])) continue;
1913
1914 int64_t dim_i = input_shape[i];
1915 int64_t begin_i = begin[i];
1916 int64_t end_i = end[i];
1917 int64_t stride_i = stride[i];
1918
1919 // [0]: mask for begin, [1]: mask for end
1920 int64_t masks[] = {begin_mask & (1 << i), end_mask & (1 << i)};
1921 // [0]: bound for begin, [1]: bound for end
1922 int64_t bounds[] = {stride_i > 0 ? 0 : -1,
1923 stride_i > 0 ? dim_i : dim_i - 1};
1924
1925 // Canonicalizes the given range `point` (begin/end) according to the
1926 // current dimension. `c` means case: 0 for begin, 1 for end.
1927 auto canonicalize = [&](int64_t point, int c) {
1928 if (masks[c]) return stride_i > 0 ? bounds[c] : bounds[(c + 1) & 1];
1929
1930 // Add dim as offset to negative range point.
1931 point = point < 0 ? dim_i + point : point;
1932 return Clamp(point, bounds[0], bounds[1]);
1933 };
1934
1935 begin_i = canonicalize(begin_i, 0);
1936 end_i = canonicalize(end_i, 1);
1937
1938 int64_t interval_len = end_i - begin_i;
1939 int64_t size_i = 0;
1940 // If the interval length is zero or has a different sign from the stride,
1941 // it's a degenerate case: we are slicing nothing. Otherwise, calculate the
1942 // sliced size.
1943 if (interval_len != 0 && (interval_len < 0) == (stride_i < 0))
1944 size_i = (interval_len / stride_i) + (interval_len % stride_i != 0);
1945
1946 begin[i] = begin_i;
1947 if (IsSet(shrink_axis_mask, i)) {
1948 // Shrink this dimension. It means we only take the element at begin_i.
1949 input_shape[i] = 1;
1950 end[i] = begin_i + 1;
1951 stride[i] = 1;
1952 } else {
1953 input_shape[i] = size_i;
1954 end[i] = end_i;
1955 stride[i] = stride_i;
1956 }
1957 }
1958 }
1959
1960 // For the given `input_shape`, calculates the sliced shape using the given
1961 // `sparse_begin`, `sparse_end`, and `sparse_strides` ranges and `begin_mask`,
1962 // `end_mask`, `ellipsis_mask` , `new_axis_mask` and `shrink_axis_mask` masks.
1963 // Updates the result back to `input_shape`.
1964 static void CalculateSlicedShapeFromSparseIndices(
1965 MutableArrayRef<int64_t> input_shape, ArrayRef<int64_t> sparse_begin,
1966 ArrayRef<int64_t> sparse_end, ArrayRef<int64_t> sparse_strides,
1967 int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask,
1968 int32_t new_axis_mask, int32_t shrink_axis_mask,
1969 SmallVectorImpl<int64_t> *begin, SmallVectorImpl<int64_t> *end,
1970 SmallVectorImpl<int64_t> *stride) {
1971 int64_t num_sparse_indices = sparse_begin.size();
1972 SparseSliceSpec sparse = {num_sparse_indices, begin_mask, end_mask,
1973 ellipsis_mask, new_axis_mask, shrink_axis_mask,
1974 sparse_begin, sparse_end, sparse_strides};
1975
1976 // If no ellipsis_mask exists then an implicit ellipsis_mask at the end is
1977 // inserted. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields
1978 // a tensor of shape [2, 8], i.e., foo[2:4] is same as foo[2:4, ...].
1979 if (sparse.ellipsis_mask == 0) {
1980 Set(sparse.ellipsis_mask, sparse.dims);
1981 sparse.dims++;
1982 }
1983
1984 int64_t dims = input_shape.size();
1985 DenseSliceSpec dense = {dims,
1986 /*begin_mask = */ 0,
1987 /*end_mask = */ 0,
1988 /*shrink_axis_mask = */ 0,
1989 *begin,
1990 *end,
1991 *stride};
1992
1993 BuildDenseSliceSpec(sparse, &dense);
1994 CalculateSlicedShapeFromDenseIndices(input_shape, dense.begin_mask,
1995 dense.end_mask, dense.shrink_axis_mask,
1996 *begin, *end, *stride);
1997 }
1998
1999 bool StridedSliceOp::GetSlicedBoundRanges(
2000 SmallVectorImpl<int64_t> *slice_begin, SmallVectorImpl<int64_t> *slice_end,
2001 SmallVectorImpl<int64_t> *slice_stride) {
2002 // TODO(hinsu): Support lowering for ops with dynamic begin and end values
2003 // when it is possible to derive indices based on mask attributes.
2004 DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr;
2005 if (!matchPattern(begin(), m_Constant(&sparse_begin_attr)) ||
2006 !matchPattern(end(), m_Constant(&sparse_end_attr)) ||
2007 !matchPattern(strides(), m_Constant(&sparse_strides_attr)))
2008 return false;
2009
2010 auto input_ty = this->input().getType().dyn_cast<RankedTensorType>();
2011 if (!input_ty || !input_ty.hasStaticShape()) return false;
2012 auto input_shape = llvm::to_vector<4>(input_ty.getShape());
2013
2014 SmallVector<int64_t, 4> sparse_begin, sparse_end, sparse_strides;
2015
2016 for (const APInt &index : sparse_begin_attr)
2017 sparse_begin.push_back(index.getSExtValue());
2018 for (const APInt &index : sparse_end_attr)
2019 sparse_end.push_back(index.getSExtValue());
2020 for (const APInt &stride : sparse_strides_attr)
2021 sparse_strides.push_back(stride.getSExtValue());
2022
2023 CalculateSlicedShapeFromSparseIndices(
2024 input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(),
2025 end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(),
2026 slice_begin, slice_end, slice_stride);
2027 return true;
2028 }
2029
2030 OpFoldResult StridedSliceOp::fold(ArrayRef<Attribute> operands) {
2031 // Fold StridedSlice operation if it extracts statically known dimensions.
2032 //
2033 // For example,
2034 //
2035 // %shape = tf.Shape(%arg) // %arg: tensor<?x2x3x1xf32>
2036 // %height = tf.StridedSlice(%shape, 1, 2, 1)
2037 //
2038 // In this case %height can be replaced with a constant 2.
2039 //
2040 // Or,
2041 //
2042 // %shape = tf.Shape(%arg) // %arg: tensor<?x2x3x1xf32>
2043 // %spatial_shape = tf.StridedSlice(%shape, 1, 3, 1)
2044 //
2045 // In this case %spatial_shape can be replaced with a constant [2, 3].
2046
2047 // Input to strided slice op is defined by shape operation.
2048 auto shape_op = input().getDefiningOp<ShapeOp>();
2049 if (!shape_op) {
2050 return {};
2051 }
2052
2053 // `begin`, `end` and `strides` should be constant in order to infer static
2054 // dimension.
2055 DenseIntElementsAttr begin_attr, end_attr, strides_attr;
2056 if (!matchPattern(begin(), m_Constant(&begin_attr)) ||
2057 !matchPattern(end(), m_Constant(&end_attr)) ||
2058 !matchPattern(strides(), m_Constant(&strides_attr)) ||
2059 begin_attr.getNumElements() != 1 || end_attr.getNumElements() != 1 ||
2060 strides_attr.getNumElements() != 1) {
2061 return {};
2062 }
2063
2064 // Do not fold when `new_axis_mask` is set. It's likely to break the shape
2065 // of output. Typically, `new_axis_mask` is not set in this canonicalization
2066 // pattern.
2067 if (new_axis_mask() != 0) return {};
2068
2069 auto tensor_ty = shape_op.input().getType().dyn_cast<RankedTensorType>();
2070 // Only ranked tensor can be folded.
2071 if (!tensor_ty) return {};
2072
2073 int64_t rank = tensor_ty.getRank();
2074 int64_t begin_int = begin_attr.getValues<APInt>()[0].getSExtValue();
2075 int64_t end_int = end_attr.getValues<APInt>()[0].getSExtValue();
2076 int64_t strides_int = strides_attr.getValues<APInt>()[0].getSExtValue();
2077
2078 // Canonicalize `begin` and `end` in case of negative index.
2079 if (begin_int < 0) begin_int += rank;
2080 if (end_int < 0) end_int += rank;
2081
2082 // Create `begin` and `end` from `*_mask`. Note that we don't care about
2083 // `new_axis_mask` as it can be inferred from `output_ty`.
2084 if (shrink_axis_mask() == 1) {
2085 // When `shrink_axis_mask` is set, output is always a scalar so only
2086 // one element is sliced.
2087 end_int = begin_int + 1;
2088 }
2089 if (begin_mask() == 1) {
2090 begin_int = (strides_int > 0) ? 0 : rank - 1;
2091 }
2092 if (end_mask() == 1) {
2093 end_int = (strides_int > 0) ? rank : -1;
2094 }
2095 if (ellipsis_mask() == 1) {
2096 begin_int = 0;
2097 end_int = rank;
2098 }
2099
2100 // It's possible that `begin` and `end` are out of bound. See
2101 // https://docs.python.org/3/library/stdtypes.html#common-sequence-operations.
2102 if (strides_int > 0) {
2103 begin_int = std::min(begin_int, rank);
2104 end_int = std::min(end_int, rank);
2105 } else {
2106 begin_int = std::min(begin_int, rank - 1);
2107 end_int = std::min(end_int, rank - 1);
2108 }
2109
2110 SmallVector<int64_t, 2> sub_shape;
2111 // Only handle cases that have something to slice to avoid infinite for-loop.
2112 if ((end_int > begin_int && strides_int > 0) ||
2113 (end_int < begin_int && strides_int < 0)) {
2114 // Extract sub-shape only if all of those dimensions are static.
2115 for (int64_t i = begin_int; (strides_int > 0) ? i < end_int : i > end_int;
2116 i += strides_int) {
2117 if (tensor_ty.isDynamicDim(i)) {
2118 return {};
2119 }
2120 sub_shape.push_back(tensor_ty.getDimSize(i));
2121 }
2122 }
2123
2124 // For unranked or dynamic output, we infer the output type to either a
2125 // scalar or a vector based on `shrink_axis_mask` because we have rejected
2126 // the case of `new_axis_mask` != 0.
2127 auto output_elt_ty = output().getType().cast<ShapedType>().getElementType();
2128 auto output_ty = output().getType().dyn_cast<RankedTensorType>();
2129 if (!output_ty || !output_ty.hasStaticShape()) {
2130 if (shrink_axis_mask() == 1) {
2131 output_ty = RankedTensorType::get({}, output_elt_ty);
2132 } else {
2133 output_ty = RankedTensorType::get(
2134 {static_cast<int64_t>(sub_shape.size())}, output_elt_ty);
2135 }
2136 }
2137
2138 // Down-cast to 32 bit int if needed.
2139 if (output_elt_ty.isInteger(32)) {
2140 SmallVector<int32_t, 2> sub_shape_i32(sub_shape.size());
2141 std::transform(sub_shape.begin(), sub_shape.end(), sub_shape_i32.begin(),
2142 [](int64_t d) { return static_cast<int32_t>(d); });
2143 return DenseIntElementsAttr::get(output_ty, sub_shape_i32);
2144 }
2145 return DenseIntElementsAttr::get(output_ty, sub_shape);
2146 }
2147
2148 //===----------------------------------------------------------------------===//
2149 // StridedSliceGradOp
2150 //===----------------------------------------------------------------------===//
2151
2152 LogicalResult StridedSliceGradOp::verify() {
2153 StridedSliceGradOp op = *this;
2154 auto shape_type = op.shape().getType().dyn_cast<RankedTensorType>();
2155 if (shape_type && shape_type.getRank() != 1)
2156 return op.emitOpError("'shape' operand must be 1D tensor, but got ")
2157 << shape_type.getRank() << "D tensor";
2158
2159 if (failed(VerifyStridedSliceBase(op))) return failure();
2160
2161 // TODO(antiagainst): verify the gradient op.dy()'s shape is consistent with
2162 // the sliced type from StridedSlice.
2163
2164 return success();
2165 }
2166
2167 bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges(
2168 SmallVectorImpl<int64_t> *input_shape,
2169 SmallVectorImpl<int64_t> *slice_begin, SmallVectorImpl<int64_t> *slice_end,
2170 SmallVectorImpl<int64_t> *slice_stride) {
2171 DenseIntElementsAttr shape_attr;
2172 DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr;
2173 if (!matchPattern(shape(), m_Constant(&shape_attr)) ||
2174 !matchPattern(begin(), m_Constant(&sparse_begin_attr)) ||
2175 !matchPattern(end(), m_Constant(&sparse_end_attr)) ||
2176 !matchPattern(strides(), m_Constant(&sparse_strides_attr)))
2177 return false;
2178
2179 int rank = std::distance(shape_attr.begin(), shape_attr.end());
2180
2181 input_shape->clear();
2182 input_shape->reserve(rank);
2183 for (const APInt &dim : shape_attr)
2184 input_shape->push_back(dim.getSExtValue());
2185
2186 SmallVector<int64_t, 4> sparse_begin, sparse_end, sparse_strides;
2187
2188 for (const APInt &index : sparse_begin_attr)
2189 sparse_begin.push_back(index.getSExtValue());
2190 for (const APInt &index : sparse_end_attr)
2191 sparse_end.push_back(index.getSExtValue());
2192 for (const APInt &stride : sparse_strides_attr)
2193 sparse_strides.push_back(stride.getSExtValue());
2194
2195 CalculateSlicedShapeFromSparseIndices(
2196 *input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(),
2197 end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(),
2198 slice_begin, slice_end, slice_stride);
2199 return true;
2200 }
2201
2202 //===----------------------------------------------------------------------===//
2203 // SummaryWriterOp
2204 //===----------------------------------------------------------------------===//
2205
2206 llvm::SmallVector<ResourceHandleValueAndId, 4>
2207 SummaryWriterOp::GetResourceHandleValueAndIdList(
2208 llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
2209 int64_t &next_id) {
2210 llvm::StringRef device = GetDeviceOrEmpty(getOperation());
2211 return {GetResourceHandleValueAndIdBase(container(), shared_name(), device,
2212 writer(), resource_handle_id_map,
2213 next_id)};
2214 }
2215
2216 //===----------------------------------------------------------------------===//
2217 // TPUExecuteOp
2218 //===----------------------------------------------------------------------===//
2219
2220 void TPUExecuteOp::getEffects(
2221 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2222 &effects) {
2223 effects.reserve(args().size() + 1);
2224 effects.emplace_back(MemoryEffects::Write::get(),
2225 ResourceEffects::TPUExecute::get());
2226
2227 for (Value value : args()) {
2228 if (value.getType()
2229 .cast<TensorType>()
2230 .getElementType()
2231 .isa<ResourceType>()) {
2232 // Conservatively mark resource handles as read and write, as without
2233 // analyzing TPUCompile, there is not sufficient information to determine
2234 // effects on resources. For the MLIR bridge, this op will never be
2235 // populated with resource handles and tf.TPUExecuteAndUpdateVariables is
2236 // used instead.
2237 effects.emplace_back(MemoryEffects::Read::get(), value,
2238 ResourceEffects::Variable::get());
2239 effects.emplace_back(MemoryEffects::Write::get(), value,
2240 ResourceEffects::Variable::get());
2241 }
2242 }
2243 }
2244
2245 //===----------------------------------------------------------------------===//
2246 // TPUExecuteAndUpdateVariablesOp
2247 //===----------------------------------------------------------------------===//
2248
2249 LogicalResult TPUExecuteAndUpdateVariablesOp::verify() {
2250 TPUExecuteAndUpdateVariablesOp op = *this;
2251 int num_resource_args = 0;
2252 for (Type arg_type : op.args().getTypes())
2253 if (arg_type.cast<TensorType>().getElementType().isa<ResourceType>())
2254 ++num_resource_args;
2255
2256 auto check_attr = [&](ArrayAttr indices, llvm::StringRef name,
2257 int min) -> LogicalResult {
2258 if (indices.size() != num_resource_args)
2259 return op.emitOpError()
2260 << "requires '" << name
2261 << "' to be the same size as number of resource handles in 'args' "
2262 "("
2263 << num_resource_args << "), but got " << indices.size();
2264
2265 for (auto entry : llvm::enumerate(indices.getValue())) {
2266 auto int_attr = entry.value().cast<IntegerAttr>();
2267 if (int_attr.getInt() < min)
2268 return op.emitOpError()
2269 << "requires '" << name << "' to contain values of at least "
2270 << min << ", but got " << int_attr.getInt() << " at index "
2271 << entry.index();
2272 }
2273
2274 return success();
2275 };
2276
2277 return failure(
2278 failed(check_attr(op.device_var_reads_indices(),
2279 /*name=*/"device_var_reads_indices", /*min=*/0)) ||
2280 failed(check_attr(op.device_var_updates_indices(),
2281 /*name=*/"device_var_updates_indices", /*min=*/-1)));
2282 }
2283
2284 void TPUExecuteAndUpdateVariablesOp::getEffects(
2285 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2286 &effects) {
2287 effects.reserve(device_var_reads_indices().size() + 1);
2288 effects.emplace_back(MemoryEffects::Write::get(),
2289 ResourceEffects::TPUExecute::get());
2290 auto resource_handles = llvm::make_filter_range(args(), [](Value value) {
2291 return value.getType()
2292 .cast<TensorType>()
2293 .getElementType()
2294 .isa<ResourceType>();
2295 });
2296
2297 for (auto &entry : llvm::enumerate(resource_handles)) {
2298 Value value = entry.value();
2299 effects.emplace_back(MemoryEffects::Read::get(), value,
2300 ResourceEffects::Variable::get());
2301 if (device_var_updates_indices()
2302 .getValue()[entry.index()]
2303 .cast<IntegerAttr>()
2304 .getInt() >= 0)
2305 effects.emplace_back(MemoryEffects::Write::get(), value,
2306 ResourceEffects::Variable::get());
2307 }
2308 }
2309
2310 //===----------------------------------------------------------------------===//
2311 // TensorListGetItemOp
2312 //===----------------------------------------------------------------------===//
2313
2314 namespace {
2315 // If the input of TensorListGetItemOp is TensorListFromTensorOp and the
2316 // TensorListFromTensorOp is only used by TensorListGetItemOp (not modified by
2317 // other TensorList ops), we can convert it to a GatherOp.
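// That is, %item = TensorListGetItem(TensorListFromTensor(%t, ...), %i, ...)
// is rewritten to %item = Gather(%t, %i).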
2318 class ConvertTensorListGetItemOpOfTensorListFromTensorOpToGather
2319 : public OpRewritePattern<TensorListGetItemOp> {
2320 using OpRewritePattern<TensorListGetItemOp>::OpRewritePattern;
2321 LogicalResult matchAndRewrite(TensorListGetItemOp op,
2322 PatternRewriter &rewriter) const override {
2323 // Checks that the input is created by TensorListFromTensorOp and the input
2324 // is only used by TensorListGetItemOp.
2325 auto tensor_list_from_tensor_op = dyn_cast_or_null<TensorListFromTensorOp>(
2326 op.input_handle().getDefiningOp());
2327 if (!tensor_list_from_tensor_op ||
2328 llvm::any_of(
2329 tensor_list_from_tensor_op->getUsers(),
2330 [](Operation *user) { return !isa<TensorListGetItemOp>(user); })) {
2331 return failure();
2332 }
2333
2334 rewriter.replaceOpWithNewOp<GatherOp>(
2335 op, op.getType(), tensor_list_from_tensor_op.tensor(), op.index());
2336 return success();
2337 }
2338 };
2339 } // namespace
2340
2341 void TensorListGetItemOp::getCanonicalizationPatterns(
2342 RewritePatternSet &results, MLIRContext *context) {
2343 results.add<ConvertTensorListGetItemOpOfTensorListFromTensorOpToGather>(
2344 context);
2345 }
2346
2347 //===----------------------------------------------------------------------===//
2348 // TensorListReserveOp
2349 //===----------------------------------------------------------------------===//
2350
2351 LogicalResult TensorListReserveOp::verify() {
2352 TensorListReserveOp op = *this;
2353 // This is required to populate derived attributes during export in a
2354 // meaningful way. Otherwise, the element_type() query during export to
2355 // GraphDef will result in an out-of-bounds access/assert.
2356 if (handle_dtype().getSubtypes().size() != 1) {
2357 return emitOpError(
2358 "must have exactly one subtype in the result variant type");
2359 }
2360 if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
2361 !IsOfRankOrUnranked(op.element_shape(), 1)) {
2362 return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
2363 }
2364
2365 if (!IsOfRankOrUnranked(op.num_elements(), 0)) {
2366 return op.emitOpError("requires num_elements operand to be 0D tensor");
2367 }
2368 return success();
2369 }
2370
2371 //===----------------------------------------------------------------------===//
2372 // TensorListElementShapeOp
2373 //===----------------------------------------------------------------------===//
2374
2375 OpFoldResult TensorListElementShapeOp::fold(ArrayRef<Attribute> operands) {
2376 int width =
2377 getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
2378 auto variant_type =
2379 getElementTypeOrSelf(getOperand().getType()).cast<TF::VariantType>();
2380 if (variant_type.getSubtypes().empty()) return {};
2381 return ConvertShapeToAttr(variant_type.getSubtypes()[0], width);
2382 }
2383
2384 //===----------------------------------------------------------------------===//
2385 // TensorListStackOp
2386 //===----------------------------------------------------------------------===//
2387
2388 LogicalResult TensorListStackOp::verify() {
2389 TensorListStackOp op = *this;
2390 if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
2391 !IsOfRankOrUnranked(op.element_shape(), 1)) {
2392 return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
2393 }
2394 return success();
2395 }
2396
2397 //===----------------------------------------------------------------------===//
2398 // TensorScatterUpdateOp
2399 //===----------------------------------------------------------------------===//
2400
2401 LogicalResult TensorScatterUpdateOp::verify() {
2402 TensorScatterUpdateOp op = *this;
2403 if (!HasRankAtLeast(op.tensor(), 1))
2404 return op.emitOpError(
2405 "requires tensor operand to have at least 1 dimension");
2406 if (!HasRankAtLeast(op.indices(), 1))
2407 return op.emitOpError(
2408 "requires indices operand to have at least 1 dimension");
2409
2410 auto tensor_ty = op.tensor().getType().dyn_cast<RankedTensorType>();
2411 auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
2412 if (!tensor_ty || !indices_ty) return success();
2413
2414 int64_t num_index_dims = indices_ty.getShape().back();
2415 if (ShapedType::isDynamic(num_index_dims)) return success();
2416
2417 if (num_index_dims > tensor_ty.getRank())
2418 return op.emitOpError(
2419 "requires tensor operand with rank greater than or equal to the "
2420 "indices operand's last dimensions");
2421 return success();
2422 }
2423
2424 //===----------------------------------------------------------------------===//
2425 // TileOp
2426 //===----------------------------------------------------------------------===//
2427
2428 // Verifies that,
2429 //
2430 // - input has at least rank 1
2431 // - multiples is rank 1
2432 // - multiples.size() == input.rank()
2433 // - input.rank() == output.rank()
2434 // - Elements in multiples are non-negative
2435 // - input.shape[i] * multiples[i] == output.shape[i]
2436 // for i in [0, input.rank() - 1]
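//
// For example, an input of shape [2, 3] with multiples = [2, 1] must have an
// output of shape [4, 3].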
2437
2438 LogicalResult TileOp::verify() {
2439 TileOp op = *this;
2440 auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
2441 auto multiples_type = op.multiples().getType().dyn_cast<RankedTensorType>();
2442 auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
2443
2444 if (multiples_type && multiples_type.getRank() != 1) {
2445 return op.emitOpError() << "expected multiples to be rank 1, got rank = "
2446 << multiples_type.getRank();
2447 }
2448
2449 if (input_type && multiples_type && multiples_type.hasStaticShape() &&
2450 (input_type.getRank() != multiples_type.getNumElements() ||
2451 (input_type.getRank() == 0 && multiples_type.getNumElements() == 1))) {
2452 return op.emitOpError()
2453 << "expected size of multiples equal to rank of input"
2454 << ", got multiples of size " << multiples_type.getNumElements()
2455 << ", and input of rank " << input_type.getRank();
2456 }
2457
2458 if (input_type && output_type) {
2459 if (input_type.getRank() != output_type.getRank()) {
2460 return op.emitOpError()
2461 << "expected rank of input to equal to rank of output"
2462 << ", got input of rank " << input_type.getRank()
2463 << ", and output of rank " << output_type.getRank();
2464 }
2465
2466 DenseIntElementsAttr multiples_attr;
2467 if (matchPattern(op.multiples(), m_Constant(&multiples_attr))) {
2468 for (int32_t i = 0, e = input_type.getRank(); i < e; ++i) {
2469 const int64_t input_dim = input_type.getDimSize(i);
2470 const int64_t output_dim = output_type.getDimSize(i);
2471 const int64_t m = multiples_attr.getValues<APInt>()[i].getSExtValue();
2472
2473 if (m < 0) {
2474 return op.emitOpError()
2475 << "expected multiples to be non-negative, got "
2476 << "multiples[" << i << "] = " << m;
2477 }
2478
2479 if (!ShapedType::isDynamic(input_dim) &&
2480 !ShapedType::isDynamic(output_dim) && output_dim != input_dim * m) {
2481 return op.emitOpError()
2482 << "requires input.shape[" << i << "] (" << input_dim << ")"
2483 << " * " << m << " to be equal to "
2484 << "output.shape[" << i << "] (" << output_dim << ")";
2485 }
2486 }
2487 }
2488 }
2489
2490 return success();
2491 }
2492
2493 OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
2494 DenseIntElementsAttr multiples_attr;
2495 if (matchPattern(multiples(), m_Constant(&multiples_attr))) {
2496 // Return the input directly when multiples are all ones,
2497 // regardless of what the input is.
2498 if (multiples_attr.isSplat() &&
2499 multiples_attr.getSplatValue<APInt>().getSExtValue() == 1) {
2500 return input();
2501 }
2502 }
2503 return {};
2504 }
2505
2506 //===----------------------------------------------------------------------===//
2507 // TopKV2Op
2508 //===----------------------------------------------------------------------===//
2509
2510 LogicalResult TopKV2Op::verify() {
2511 TopKV2Op op = *this;
2512 if (!HasRankAtLeast(op.input(), 1))
2513 return op.emitOpError(
2514 "requires input operand to have at least 1 dimension");
2515
2516 if (!IsOfRankOrUnranked(op.k(), 0))
2517 return op.emitOpError("requires k operand to be 0D tensor");
2518
2519 return success();
2520 }
2521
2522 //===----------------------------------------------------------------------===//
2523 // ToBoolOp
2524 //===----------------------------------------------------------------------===//
2525
2526 namespace {
2527 // If the input to ToBoolOp is a ranked tensor, then the ToBoolOp can be folded
2528 // into an identity or an equality comparison.
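// For example, ToBool on a tensor<i1> folds to its operand, ToBool on a scalar
// numeric tensor becomes NotEqual(x, 0), and ToBool on a non-scalar ranked
// tensor becomes a constant that is true iff every dimension is non-zero
// (i.e. the tensor has at least one element).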
2529 class ToBoolOfRankedTensor : public OpRewritePattern<ToBoolOp> {
2530 using OpRewritePattern<ToBoolOp>::OpRewritePattern;
2531 LogicalResult matchAndRewrite(ToBoolOp op,
2532 PatternRewriter &rewriter) const override {
2533 auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
2534 // If the input is an unranked tensor, we cannot rewrite.
2535 if (!type) return failure();
2536
2537 // Expected return type of the ToBool operation. The return type of the
2538 // ToBool operation is always a 0-D tensor of bool type.
2539 auto result_type = op.getResult().getType().cast<RankedTensorType>();
2540
2541 // If input is already a tensor<i1>, it can be folded into an identity.
2542 if (type == result_type) {
2543 rewriter.replaceOp(op, op.getOperand());
2544 return success();
2545 }
2546
2547 if (type.getRank() == 0) {
2548 // If the input is a scalar tensor, the ToBool can be expanded to
2549 // element != 0 (for numerical values) or element == empty (for string).
2550 Type element_type = type.getElementType();
2551 Attribute zero_attr;
2552 if (element_type.isIntOrFloat())
2553 zero_attr = rewriter.getZeroAttr(type);
2554 else if (element_type.isa<TF::StringType>())
2555 zero_attr = DenseStringElementsAttr::get(type, {""});
2556
2557 if (!zero_attr) return failure();
2558
2559 auto zero_const = rewriter.create<TF::ConstOp>(op.getLoc(), zero_attr);
2560 rewriter.replaceOpWithNewOp<TF::NotEqualOp>(
2561 op, result_type, op.getOperand(), zero_const, false);
2562 } else {
2563 // If the input is a non-scalar ranked tensor, ToBool can be expanded
2564 // to numElements != 0. numElements will be 0 iff one of the dimensions is
2565 // zero.
2566 bool any_zero =
2567 llvm::any_of(type.getShape(), [](int64_t dim) { return dim == 0; });
2568 rewriter.replaceOpWithNewOp<TF::ConstOp>(
2569 op, result_type, DenseElementsAttr::get(result_type, {!any_zero}));
2570 }
2571 return success();
2572 }
2573 };
2574 } // namespace
2575
2576 void ToBoolOp::getCanonicalizationPatterns(RewritePatternSet &results,
2577 MLIRContext *context) {
2578 results.add<ToBoolOfRankedTensor>(context);
2579 }
2580
2581 LogicalResult ToBoolOp::inferReturnTypes(
2582 MLIRContext *context, Optional<Location> location, ValueRange operands,
2583 DictionaryAttr attributes, RegionRange regions,
2584 SmallVectorImpl<Type> &inferredReturnTypes) {
2585 inferredReturnTypes.push_back(
2586 RankedTensorType::get({}, IntegerType::get(context, 1)));
2587 return success();
2588 }
2589
2590 //===----------------------------------------------------------------------===//
2591 // TransposeOp
2592 //===----------------------------------------------------------------------===//
2593
2594 LogicalResult TransposeOp::verify() {
2595 TransposeOp op = *this;
2596 auto perm_type = op.perm().getType().dyn_cast<RankedTensorType>();
2597 auto x_type = op.x().getType().dyn_cast<RankedTensorType>();
2598 auto y_type = op.y().getType().dyn_cast<RankedTensorType>();
2599
2600 if (perm_type && perm_type.getRank() != 1) {
2601 return op.emitOpError()
2602 << "expected perm to be a 1-D Tensor, got perm of rank "
2603 << perm_type.getRank();
2604 }
2605
2606 if (x_type && y_type && x_type.getRank() != y_type.getRank()) {
2607 return op.emitOpError() << "x should be of the same rank with y, got "
2608 << "x of rank " << x_type.getRank()
2609 << ", and y of rank " << y_type.getRank();
2610 }
2611
2612 if (!x_type || !y_type || !perm_type || !perm_type.hasStaticShape()) {
2613 return success();
2614 }
2615
2616 if (x_type.getRank() != perm_type.getNumElements()) {
2617 return op.emitOpError() << "expected perm to be a 1-D Tensor of size "
2618 << "equal to the rank of x, got perm of size "
2619 << perm_type.getNumElements() << ", and x of rank "
2620 << x_type.getRank();
2621 }
2622
2623 DenseIntElementsAttr attr_perm;
2624 if (matchPattern(op.perm(), m_Constant(&attr_perm))) {
2625 // y.shape[i] should be equal to x.shape[perm[i]]
2626 // for i = [0, 1, ..., rank(x) - 1]
2627 for (auto e : llvm::enumerate(attr_perm)) {
2628 const int64_t y_idx = e.index();
2629 const int64_t y_dim = y_type.getDimSize(y_idx);
2630 const int64_t x_idx = e.value().getSExtValue();
2631 const int64_t x_dim = x_type.getDimSize(x_idx);
2632 if (!ShapedType::isDynamic(y_dim) && !ShapedType::isDynamic(x_dim) &&
2633 y_dim != x_dim) {
2634 return op.emitOpError()
2635 << "requires y.shape[" << y_idx << "] (" << y_dim << ") "
2636 << "to be equal to x.shape[perm[" << x_idx << "]] "
2637 << "(" << x_dim << ")";
2638 }
2639 }
2640 }
2641
2642 return success();
2643 }
2644
2645 // TODO(jpienaar): perm could be optional too.
2646 void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x,
2647 Value perm) {
2648 auto x_type = x.getType().cast<TensorType>();
2649 // If the value is unranked, then so is the result.
2650 if (!x_type.hasRank())
2651 return TransposeOp::build(builder, result,
2652 UnrankedTensorType::get(x_type.getElementType()),
2653 x, perm);
2654
2655 // TODO(jpienaar): Handle unknown perm case.
2656
2657 // TODO(jpienaar): Extract utility function.
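// When perm is constant, the result shape is x's shape permuted by perm,
// e.g. x of shape [2, 3, 4] with perm = [2, 0, 1] gives a result of shape
// [4, 2, 3].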
2658 auto etype = x_type.cast<ShapedType>().getElementType();
2659 DenseIntElementsAttr attr_shape;
2660 if (matchPattern(perm, m_Constant(&attr_shape))) {
2661 llvm::SmallVector<int64_t, 4> const_shape;
2662 if (attr_shape.isSplat()) {
2663 const_shape.assign(
2664 attr_shape.getNumElements(),
2665 x_type.getDimSize((*attr_shape.begin()).getSExtValue()));
2666 } else {
2667 const_shape.reserve(attr_shape.getNumElements());
2668 for (const auto &dim : attr_shape)
2669 const_shape.push_back(x_type.getDimSize(dim.getSExtValue()));
2670 }
2671 return TransposeOp::build(
2672 builder, result, RankedTensorType::get(const_shape, etype), x, perm);
2673 }
2674 return TransposeOp::build(builder, result, UnrankedTensorType::get(etype), x,
2675 perm);
2676 }
2677
2678 namespace {
2679
2680 OpFoldResult FoldIdentityTranspose(TransposeOp op) {
2681 DenseIntElementsAttr perm;
2682 if (!matchPattern(op.perm(), m_Constant(&perm))) return {};
2683 const auto elements = perm.getValues<APInt>();
2684
2685 for (auto it : llvm::enumerate(elements)) {
2686 if (it.index() != it.value()) return {};
2687 }
2688
2689 // TODO(jpienaar): Remove if/when we handle this more generally.
2690 if (op.getType() != op.x().getType()) {
2691 // If the types don't match then only fold if all the users are in the TF
2692 // dialect.
2693 for (auto user : op.getOperation()->getUsers())
2694 if (user->getDialect() != op->getDialect()) return {};
2695 }
2696
2697 return op.x();
2698 }
2699
2700 OpFoldResult FoldCancellableTranspose(TransposeOp op) {
2701 // Operand is a TransposeOp.
2702 auto transpose = dyn_cast_or_null<TF::TransposeOp>(op.x().getDefiningOp());
2703 if (!transpose) return {};
2704
2705 // Permutations defined by constant operations.
2706 DenseIntElementsAttr perm0;
2707 DenseIntElementsAttr perm1;
2708 if (!matchPattern(op.perm(), m_Constant(&perm0)) ||
2709 !matchPattern(transpose.perm(), m_Constant(&perm1)))
2710 return {};
2711
2712 // With permutation indices that cancel each other
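// (e.g. perm0 = [0, 2, 1] followed by perm1 = [0, 2, 1] composes to the
// identity permutation).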
2713 if (!AreCancellablePermutations(perm0, perm1)) return {};
2714
2715 return transpose.x();
2716 }
2717
2718 } // namespace
2719
2720 OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
2721 if (auto folded = FoldIdentityTranspose(*this)) return folded;
2722 if (auto folded = FoldCancellableTranspose(*this)) return folded;
2723 return {};
2724 }
2725
2726 //===----------------------------------------------------------------------===//
2727 // TruncateDivOp
2728 //===----------------------------------------------------------------------===//
2729
2730 void TruncateDivOp::getCanonicalizationPatterns(RewritePatternSet &results,
2731 MLIRContext *context) {
2732 results.add<TruncateDivWithSqrtDivisor>(context);
2733 }
2734
2735 //===----------------------------------------------------------------------===//
2736 // NonMaxSuppressionV3Op
2737 //===----------------------------------------------------------------------===//
2738
2739 namespace {
2740
2741 // Canonicalize NonMaxSuppressionV3Op to NonMaxSuppressionV4Op.
2742 class NMSV3ToNMSV4Op : public OpRewritePattern<NonMaxSuppressionV3Op> {
2743 using OpRewritePattern<NonMaxSuppressionV3Op>::OpRewritePattern;
2744 LogicalResult matchAndRewrite(NonMaxSuppressionV3Op nms_op,
2745 PatternRewriter &rewriter) const override {
2746 if (nms_op.getNumOperands() != 5) {
2747 return failure();
2748 }
2749 SmallVector<Type, 2> new_result_types;
2750 new_result_types.push_back(nms_op.getType());
2751 auto input_ty = nms_op.getType().template cast<ShapedType>();
2752 // Corresponds to the second result type of NMSv4.
2753 RankedTensorType valid_output_type =
2754 RankedTensorType::get({}, input_ty.getElementType());
2755 new_result_types.push_back(valid_output_type);
2756
2757 auto nmsv4 = rewriter.create<TF::NonMaxSuppressionV4Op>(
2758 nms_op.getLoc(), new_result_types, nms_op.boxes(), nms_op.scores(),
2759 nms_op.max_output_size(), nms_op.iou_threshold(),
2760 nms_op.score_threshold());
2761 // Cannot directly replace the NMSv3 op with NMSv4 since their outputs differ
2762 // (v4 produces two output values whereas v3 produces only one).
2763 nms_op.replaceAllUsesWith(nmsv4.getResult(0));
2764 return success();
2765 }
2766 };
2767 } // namespace.
2768
2769 void NonMaxSuppressionV3Op::getCanonicalizationPatterns(
2770 RewritePatternSet &results, MLIRContext *context) {
2771 results.add<NMSV3ToNMSV4Op>(context);
2772 }
2773
2774 //===----------------------------------------------------------------------===//
2775 // FusedBatchNormOp
2776 //===----------------------------------------------------------------------===//
2777
2778 namespace {
2779
2780 class ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
2781 using OpRewritePattern<FusedBatchNormOp>::OpRewritePattern;
2782 LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op,
2783 PatternRewriter &rewriter) const override {
2784 auto new_result_types =
2785 llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes());
2786 // reserve_space_3
2787 new_result_types.push_back(
2788 UnrankedTensorType::get(FloatType::getF32(rewriter.getContext())));
2789
2790 OperationState new_state(tf_fused_batch_norm_op.getLoc(),
2791 TF::FusedBatchNormV3Op::getOperationName(),
2792 tf_fused_batch_norm_op.getOperands(),
2793 new_result_types,
2794 tf_fused_batch_norm_op->getAttrs());
2795 Operation *tf_fused_batch_norm_op_v3 = rewriter.create(new_state);
2796
2797 rewriter.replaceOp(tf_fused_batch_norm_op,
2798 tf_fused_batch_norm_op_v3->getResults().drop_back());
2799 return success();
2800 }
2801 };
2802 } // namespace.
2803
2804 void FusedBatchNormOp::getCanonicalizationPatterns(RewritePatternSet &results,
2805 MLIRContext *context) {
2806 results.add<ConvertFusedBatchNorm>(context);
2807 }
2808
2809 //===----------------------------------------------------------------------===//
2810 // UnpackOp
2811 //===----------------------------------------------------------------------===//
2812
2813 LogicalResult UnpackOp::verify() {
2814 UnpackOp op = *this;
2815 auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
2816 if (!value_type) return success();
2817
2818 int64_t value_rank = value_type.getRank();
2819 int64_t axis = op.axis();
2820 if (axis < -value_rank || axis >= value_rank)
2821 return op.emitOpError("axis attribute must be in the range of [-")
2822 << value_rank << ", " << value_rank << ')';
2823
2824 axis = GetDimForAxis(axis, value_rank);
2825 int64_t dim_size = value_type.getDimSize(axis);
2826 if (ShapedType::isDynamic(dim_size)) return success();
2827
2828 if (dim_size != op.getNumResults())
2829 return op.emitOpError("result count must be equal to ") << dim_size;
2830
2831 return success();
2832 }
2833
2834 namespace {
2835
2836 // Hoist coefficient-wise unary operation out of the Unpack op:
2837 //
2838 // %unpacked:N = "tf.Unpack"(%0)
2839 // %neg0 = "tf.Neg"(%unpacked#0)
2840 // %neg1 = "tf.Neg"(%unpacked#1)
2841 // ...
2842 // %negN-1 = "tf.Neg"(%unpacked#N-1)
2843 //
2844 // Rewrite it to:
2845 //
2846 // %neg = "tf.Neg"(%0)
2847 // %unpacked:N = "tf.Unpack"(%neg)
2848 class HoistCwiseUnaryOutOfUnpack : public OpRewritePattern<UnpackOp> {
2849 public:
2850 explicit HoistCwiseUnaryOutOfUnpack(MLIRContext *context)
2851 : OpRewritePattern<UnpackOp>(context) {}
2852 LogicalResult matchAndRewrite(UnpackOp op,
2853 PatternRewriter &rewriter) const override;
2854 };
2855
2856 LogicalResult HoistCwiseUnaryOutOfUnpack::matchAndRewrite(
2857 UnpackOp op, PatternRewriter &rewriter) const {
2858 auto loc = op.getLoc();
2859
2860 // The first unpack user must be a coefficient-wise unary operation.
2861 Operation *first_user = *op->getUsers().begin();
2862 if (!first_user->hasTrait<OpTrait::TF::CwiseUnary>()) return failure();
2863
2864 // All unpack users must be ops of the same kind.
2865 bool users_same_op = llvm::all_of(op->getUsers(), [&](Operation *user) {
2866 return user->getName() == first_user->getName();
2867 });
2868 if (!users_same_op) return failure();
2869
2870 // Pass unpack operand to unary operation.
2871 OperationState new_unary_op_state(loc, first_user->getName().getStringRef(),
2872 op.getOperand(), op.getOperand().getType(),
2873 ArrayRef<NamedAttribute>());
2874 Operation *new_unary_op = rewriter.create(new_unary_op_state);
2875
2876 // Unpack results after applying unary operation.
2877 auto unpack_unary_op = rewriter.create<UnpackOp>(
2878 loc, op.getResultTypes(), new_unary_op->getResult(0), op.axis());
2879
2880 // Bypass all users of the original unpack operation and use `unpack_unary_op`
2881 // results instead.
2882 for (auto pair : llvm::zip(op.getResults(), unpack_unary_op.getResults())) {
2883 OpResult old_result = std::get<0>(pair); // result of original Unpack
2884 OpResult new_result = std::get<1>(pair); // result of transformed Unpack
2885 for (Operation *user : llvm::make_early_inc_range(old_result.getUsers()))
2886 rewriter.replaceOp(user, ValueRange(new_result));
2887 }
2888
2889 // Erase original unpack operation.
2890 rewriter.eraseOp(op.getOperation());
2891
2892 return success();
2893 }
2894
2895 } // namespace
2896
2897 void UnpackOp::getCanonicalizationPatterns(RewritePatternSet &results,
2898 MLIRContext *context) {
2899 results.add<HoistCwiseUnaryOutOfUnpack>(context);
2900 }
2901
2902 //===----------------------------------------------------------------------===//
2903 // Unsorted segment reduction ops
2904 //===----------------------------------------------------------------------===//
2905
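// Shared verifier for the tf.UnsortedSegment{Max,Min,Prod,Sum} ops. It checks
// that num_segments is a scalar, that the segment_ids shape is a prefix of the
// data shape (e.g. segment_ids of shape [2, 3] is compatible with data of
// shape [2, 3, 4], while shape [3, 2] is not), and that a constant
// num_segments is non-negative.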
2906 template <class Op>
2907 static LogicalResult VerifyUnsortedSegmentReduction(Op op) {
2908 if (!HasRankAtMost(op.num_segments(), 0))
2909 return op.emitOpError("number of segments should be a 0-D tensor");
2910
2911 auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
2912 auto segment_ids_type =
2913 op.segment_ids().getType().template dyn_cast<RankedTensorType>();
2914 if (data_type && segment_ids_type) {
2915 if (data_type.getRank() < segment_ids_type.getRank())
2916 return op.emitOpError(
2917 "requires segment ids rank to be less than or equal to data's rank");
2918
2919 int index = 0;
2920 for (auto shape_pair :
2921 llvm::zip_first(segment_ids_type.getShape(), data_type.getShape())) {
2922 int64_t segment_id_dim = std::get<0>(shape_pair);
2923 int64_t data_dim = std::get<1>(shape_pair);
2924 if (!ShapedType::isDynamic(segment_id_dim) &&
2925 !ShapedType::isDynamic(data_dim) && segment_id_dim != data_dim)
2926 return op.emitOpError(
2927 "requires segment ids shape to be a prefix of data shape, "
2928 "but dimension #")
2929 << index << " differs: " << segment_id_dim << " vs. "
2930 << data_dim;
2931 ++index;
2932 }
2933 }
2934
2935 DenseIntElementsAttr num_segments_attr;
2936 if (matchPattern(op.num_segments(), m_Constant(&num_segments_attr))) {
2937 int64_t num_segments = (*num_segments_attr.begin()).getSExtValue();
2938 if (num_segments < 0)
2939 return op.emitOpError("num of segments cannot be negative");
2940 }
2941
2942 return success();
2943 }
2944
2945 LogicalResult UnsortedSegmentMaxOp::verify() {
2946 return VerifyUnsortedSegmentReduction(*this);
2947 }
2948 LogicalResult UnsortedSegmentMinOp::verify() {
2949 return VerifyUnsortedSegmentReduction(*this);
2950 }
2951 LogicalResult UnsortedSegmentProdOp::verify() {
2952 return VerifyUnsortedSegmentReduction(*this);
2953 }
2954 LogicalResult UnsortedSegmentSumOp::verify() {
2955 return VerifyUnsortedSegmentReduction(*this);
2956 }
2957
2958 //===----------------------------------------------------------------------===//
2959 // VarHandleOp
2960 //===----------------------------------------------------------------------===//
2961
2962 LogicalResult VarHandleOp::verify() {
2963 // VarHandleOp requires the resource handle to supply a single subtype from
2964 // which to derive the dtype and shape attributes.
2965 if (resource_type().getSubtypes().size() != 1) {
2966 return emitOpError(
2967 "must have exactly one subtype in the result resource type");
2968 }
2969
2970 return success();
2971 }
2972
2973 llvm::SmallVector<ResourceHandleValueAndId, 4>
2974 VarHandleOp::GetResourceHandleValueAndIdList(
2975 llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
2976 int64_t &next_id) {
2977 llvm::StringRef device = GetDeviceOrEmpty(getOperation());
2978 return {GetResourceHandleValueAndIdBase(container(), shared_name(), device,
2979 resource(), resource_handle_id_map,
2980 next_id)};
2981 }
2982
2983 //===----------------------------------------------------------------------===//
2984 // VarIsInitializedOp
2985 //===----------------------------------------------------------------------===//
2986
2987 namespace {
2988
2989 /// Erase VarIsInitializedOp operations with no uses. This op has a side effect
2990 /// on resources (read-only), but can still be deleted if it has zero uses.
2991 struct EraseDeadVarIsInitializedOp
2992 : public OpRewritePattern<VarIsInitializedOp> {
2993 using OpRewritePattern<VarIsInitializedOp>::OpRewritePattern;
2994
2995 LogicalResult matchAndRewrite(VarIsInitializedOp op,
2996 PatternRewriter &rewriter) const override {
2997 if (!op.use_empty()) return failure();
2998 rewriter.eraseOp(op);
2999 return success();
3000 }
3001 };
3002 } // end anonymous namespace.
3003
3004 void VarIsInitializedOp::getCanonicalizationPatterns(
3005 RewritePatternSet &patterns, MLIRContext *context) {
3006 patterns.add<EraseDeadVarIsInitializedOp>(context);
3007 }
3008
3009 //===----------------------------------------------------------------------===//
3010 // VariableOp
3011 //===----------------------------------------------------------------------===//
3012
3013 void VariableOp::getCanonicalizationPatterns(RewritePatternSet &results,
3014 MLIRContext *context) {
3015 results.add<VariableToVariableV2>(context);
3016 }
3017
3018 //===----------------------------------------------------------------------===//
3019 // VariableShapeOp
3020 //===----------------------------------------------------------------------===//
3021
3022 LogicalResult VariableShapeOp::verify() {
3023 VariableShapeOp op = *this;
3024 auto input_type = op.input().getType().cast<TensorType>();
3025 if (input_type.hasStaticShape() && input_type.getNumElements() != 1)
3026 return op.emitOpError("requires input to have one resource");
3027
3028 auto resource_type = input_type.getElementType().cast<TF::ResourceType>();
3029 auto subtypes = resource_type.getSubtypes();
3030 switch (subtypes.size()) {
3031 case 1:
3032 return VerifyShapeOperandAndResult(
3033 op, resource_type.getSubtypes().front(), op.getType());
3034 case 0:
3035 return VerifyShapeOperandAndResult(op, Type(), op.getType());
3036 default:
3037 return op.emitOpError(
3038 "requires resource input type to have at most 1 subtype");
3039 }
3040 }
3041
3042 OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {
3043 int width =
3044 getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
3045 auto resource_type =
3046 getElementTypeOrSelf(getOperand().getType()).cast<TF::ResourceType>();
3047 if (resource_type.getSubtypes().empty()) return {};
3048 return ConvertShapeToAttr(resource_type.getSubtypes()[0], width);
3049 }
3050
3051 //===----------------------------------------------------------------------===//
3052 // WhileOp
3053 //===----------------------------------------------------------------------===//
3054
3055 static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
3056 TypeRange body_input,
3057 TypeRange body_result,
3058 bool shape_invariant) {
3059 const TypeRangeWithDesc input_type = {op->getOperandTypes(), "input"};
3060 const TypeRangeWithDesc result_type = {op->getResultTypes(), "result"};
3061 constexpr int kNumRegionTypeLists = 3;
3062 const std::array<TypeRangeWithDesc, kNumRegionTypeLists> region_types = {{
3063 {body_result, "body result"},
3064 {cond_input, "condition input"},
3065 {body_input, "body input"},
3066 }};
3067
3068 // A pair of type lists should be cast compatible with each other if one is
3069 // converted to the other for a function call or assignment, or there is a
3070 // common source of inputs for both. Therefore, the While op requires the
3071 // following pairs of type lists to be cast compatible for the tensor_cast
3072 // operation:
3073 //
3074 // * Operands and cond inputs to call the cond function before the
3075 // first iteration.
3076 // * Operands and body inputs to call the body function for the first
3077 // iteration if the cond function returns True or an equivalent result.
3078 // * Operands and results to assign cond function arguments to op results if
3079 // the cond function returns False or equivalent result. If the op is shape
3080 // invariant, this does not hold as shapes can differ.
3081 // * All three pairs using cond inputs, body inputs and results, as the
3082 // operands are a common source for all three.
3083 // * Body result and cond inputs to call the cond function for the subsequent
3084 // iterations. Similarly, Body result should be compatible with body inputs
3085 // and op results.
3086 //
3087 // Note that the operands and body results need not be compatible as they are
3088 // never converted from one to the other, nor is there a common source of
3089 // tensors. The compatibility requirement is not transitive.
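//
// For example (illustrative types only): an operand of type tensor<*xf32> is
// cast compatible with a cond/body input of type tensor<4xf32>, whereas
// tensor<4xf32> and tensor<8xf32> are not and would be rejected by the checks
// below.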
3090
3091 if (!shape_invariant &&
3092 failed(VerifyTypeRangesAreCompatible(op, input_type, result_type)))
3093 return failure();
3094
3095 // Skip the first pair as the While op operands and body function results do
3096 // not need to be compatible with each other.
3097 for (int i = 1; i < kNumRegionTypeLists; ++i)
3098 if (failed(VerifyTypeRangesAreCompatible(op, input_type, region_types[i])))
3099 return failure();
3100
3101 for (int i = 0; i < kNumRegionTypeLists; ++i)
3102 if (failed(VerifyTypeRangesAreCompatible(op, result_type, region_types[i])))
3103 return failure();
3104
3105 for (int i = 0; i < kNumRegionTypeLists; ++i)
3106 for (int j = i + 1; j < kNumRegionTypeLists; ++j)
3107 if (failed(VerifyTypeRangesAreCompatible(op, region_types[i],
3108 region_types[j])))
3109 return failure();
3110
3111 return success();
3112 }
3113
3114 LogicalResult WhileOp::verifySymbolUses(SymbolTableCollection &symbol_table) {
3115 auto cond_fn =
3116 symbol_table.lookupNearestSymbolFrom<func::FuncOp>(*this, condAttr());
3117 auto body_fn =
3118 symbol_table.lookupNearestSymbolFrom<func::FuncOp>(*this, bodyAttr());
3119 if (!cond_fn) {
3120 return emitOpError("cond refers to an undefined function : ") << cond();
3121 }
3122 if (!body_fn) {
3123 return emitOpError("body refers to an undefined function : ") << body();
3124 }
3125
3126 auto cond_fn_type = cond_fn.getFunctionType();
3127 auto body_fn_type = body_fn.getFunctionType();
3128
3129 // Verify that the cond function has exactly one result.
3130 if (cond_fn_type.getNumResults() != 1)
3131 return emitOpError("requires cond function to have exactly one result");
3132
3133 return VerifyWhileTypes(*this, /*cond_input=*/cond_fn_type.getInputs(),
3134 /*body_input=*/body_fn_type.getInputs(),
3135 /*body_result=*/body_fn_type.getResults(),
3136 shape_invariant());
3137 }
3138
3139 //===----------------------------------------------------------------------===//
3140 // WhileRegionOp
3141 //===----------------------------------------------------------------------===//
3142 LogicalResult WhileRegionOp::verify() {
3143 WhileRegionOp op = *this;
3144 // Verify that the condition generates a single tensor<i1> result.
3145 Operation *cond_yield = op.cond().front().getTerminator();
3146 if (cond_yield->getNumOperands() != 1)
3147 return op.emitOpError()
3148 << "condition should have a single tensor<i1> result";
3149
3150 auto cond_type =
3151 cond_yield->getOperand(0).getType().dyn_cast<RankedTensorType>();
3152 if (!cond_type || !cond_type.getShape().equals({}) ||
3153 !cond_type.getElementType().isInteger(/*width=*/1))
3154 return op.emitOpError()
3155 << "condition should have a single tensor<i1> result";
3156
3157 Operation *body_yield = op.body().front().getTerminator();
3158 if (failed(VerifyWhileTypes(op, /*cond_input=*/op.cond().getArgumentTypes(),
3159 /*body_input=*/op.body().getArgumentTypes(),
3160 /*body_result=*/body_yield->getOperandTypes(),
3161 op.shape_invariant())))
3162 return failure();
3163 return success();
3164 }
3165
3166 //===----------------------------------------------------------------------===//
3167 // WhileRegionOp LoopLikeOpInterface
3168 //===----------------------------------------------------------------------===//
3169
3170 Region &WhileRegionOp::getLoopBody() { return body(); }
3171
3172 //===----------------------------------------------------------------------===//
3173 // WhileRegionOp canonicalization
3174 //===----------------------------------------------------------------------===//
3175 namespace {
3176 // Eliminate values that pass through the WhileRegionOp body.
3177 struct WhileRegionEliminatePassThrough
3178 : public OpRewritePattern<WhileRegionOp> {
3179 using OpRewritePattern<WhileRegionOp>::OpRewritePattern;
3180
3181 LogicalResult matchAndRewrite(WhileRegionOp while_op,
3182 PatternRewriter &rewriter) const override {
3183 // Remove any extern values that are explicitly captured and returned. Also
3184 // replace values that simply pass through the body with extern values. The
3185 // block arguments of the body and the while op match, so the corresponding
3186 // cond argument can easily be found.
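// A schematic example (hypothetical IR): for %w:2 = "tf.WhileRegion"(%a, %b)
// whose body yields (%body_arg0, %new_value), the first value simply passes
// through, so uses of %body_arg0, of the corresponding cond argument, and of
// %w#0 can all be rewritten to use %a; the operand/result pair is then dropped
// if it becomes unused.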
3187 int old_num_operands = while_op.getNumOperands();
3188 int new_num_operands = old_num_operands;
3189 auto &body_block = while_op.body().front();
3190 auto &cond_block = while_op.cond().front();
3191 auto &yield = *body_block.getTerminator();
3192
3193 // Bit mask indicating which operands will be removed.
3194 llvm::BitVector removed_operand(old_num_operands);
3195
3196 for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
3197 auto body_arg = body_block.getArgument(op_idx);
3198 auto yield_operand = LookThroughIdentity(yield.getOperand(op_idx));
3199 auto while_operand = while_op.getOperand(op_idx);
3200 if (body_arg == yield_operand || while_operand == yield_operand) {
3201 // Replace the use of the passthrough value with the while operand
3202 // in the body and condition regions, as well as the while output (if
3203 // types match).
3204 // TODO(jurahul): Use PatternRewriter API for IR modification.
3205 if (body_arg.getType() == while_operand.getType())
3206 body_arg.replaceAllUsesWith(while_operand);
3207
3208 auto cond_arg = cond_block.getArgument(op_idx);
3209 if (cond_arg.getType() == while_operand.getType())
3210 cond_arg.replaceAllUsesWith(while_operand);
3211
3212 auto result = while_op.getResult(op_idx);
3213 if (result.getType() == while_operand.getType())
3214 result.replaceAllUsesWith(while_operand);
3215 }
3216
3217 // Now check if the operand is unused in both regions as well as the
3218 // result. If so, mark it for removal.
3219 if (body_block.getArgument(op_idx).use_empty() &&
3220 cond_block.getArgument(op_idx).use_empty() &&
3221 while_op.getResult(op_idx).use_empty()) {
3222 removed_operand.set(op_idx);
3223 new_num_operands--;
3224 }
3225 }
3226
3227 if (new_num_operands == old_num_operands) return failure();
3228
3229 // Compress the operands, region arguments, and outputs.
3230 SmallVector<Value, 4> new_while_operands;
3231 SmallVector<Type, 4> new_result_types;
3232 new_while_operands.reserve(new_num_operands);
3233 new_result_types.reserve(new_num_operands);
3234
3235 // Build new operands and result type.
3236 for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
3237 if (removed_operand.test(op_idx)) continue;
3238 new_while_operands.push_back(while_op.getOperand(op_idx));
3239 new_result_types.push_back(while_op.getResult(op_idx).getType());
3240 }
3241
3242 // Create the new while operation.
3243 auto new_while_op = rewriter.create<WhileRegionOp>(
3244 while_op.getLoc(), new_result_types, new_while_operands,
3245 while_op->getAttrs());
3246
3247 // Move region bodies to the new while.
3248 rewriter.inlineRegionBefore(while_op.cond(), new_while_op.cond(),
3249 new_while_op.cond().end());
3250 rewriter.inlineRegionBefore(while_op.body(), new_while_op.body(),
3251 new_while_op.body().end());
3252
3253 auto &new_cond_block = new_while_op.cond().front();
3254 auto &new_body_block = new_while_op.body().front();
3255 auto &new_yield = *new_body_block.getTerminator();
3256
3257 // Patch up the region bodies and yield.
3258 new_cond_block.eraseArguments(removed_operand);
3259 new_body_block.eraseArguments(removed_operand);
3260 new_yield.eraseOperands(removed_operand);
3261
3262 // Build a vector of new results.
3264 SmallVector<Value, 4> new_results(old_num_operands);
3265 int next_idx = 0;
3266 for (int op_idx : llvm::seq<int>(0, old_num_operands))
3267 if (!removed_operand.test(op_idx))
3268 new_results[op_idx] = new_while_op.getResult(next_idx++);
3269
3270 rewriter.replaceOp(while_op, new_results);
3271 return success();
3272 }
3273 };
3274
3275 } // anonymous namespace
3276
3277 void WhileRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
3278 MLIRContext *context) {
3279 results.add<WhileRegionEliminatePassThrough>(context);
3280 }
3281
3282 //===----------------------------------------------------------------------===//
3283 // XdivyOp
3284 //===----------------------------------------------------------------------===//
3285
3286 void XdivyOp::getCanonicalizationPatterns(RewritePatternSet &results,
3287 MLIRContext *context) {
3288 results.add<XdivyWithSqrtDivisor>(context);
3289 }
3290
3291 //===----------------------------------------------------------------------===//
3292 // XlaBroadcastHelperOp
3293 //===----------------------------------------------------------------------===//
3294
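// Infers the result shapes of tf.XlaBroadcastHelper. A small illustrative
// example (assuming a constant broadcast_dims): with lhs : tensor<4xf32>,
// rhs : tensor<2x3x4xf32> and broadcast_dims = [2], the lhs is expanded for
// broadcasting to tensor<1x1x4xf32> while the rhs keeps its shape.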
3295 LogicalResult XlaBroadcastHelperOp::inferReturnTypeComponents(
3296 MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
3297 DictionaryAttr attributes, RegionRange regions,
3298 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3299 XlaBroadcastHelperOpAdaptor op(operands.getValues(), attributes);
3300 Value lhs = op.lhs();
3301 Value rhs = op.rhs();
3302 auto set_unranked_results = [&]() {
3303 inferredReturnShapes.emplace_back(getElementTypeOrSelf(lhs));
3304 inferredReturnShapes.emplace_back(getElementTypeOrSelf(rhs));
3305 return success();
3306 };
3307
3308 RankedTensorType lhs_ty = lhs.getType().dyn_cast<RankedTensorType>();
3309 RankedTensorType rhs_ty = rhs.getType().dyn_cast<RankedTensorType>();
3310 if (!lhs_ty || !rhs_ty) return set_unranked_results();
3311
3312 int64_t lhs_rank = lhs_ty.getRank();
3313 int64_t rhs_rank = rhs_ty.getRank();
3314
3315 DenseIntElementsAttr dims;
3316 if (!matchPattern(op.broadcast_dims(), m_Constant(&dims))) {
3317 return set_unranked_results();
3318 }
3319
3320 if (dims.size() == 0) {
3321 if (lhs_rank != rhs_rank && lhs_rank != 0 && rhs_rank != 0) {
3322 return emitOptionalError(
3323 location,
3324 "if broadcast_dims is empty, both arguments must have equal rank or "
3325 "at least one argument must be a scalar");
3326 }
3327 inferredReturnShapes.emplace_back(lhs_ty.cast<ShapedType>());
3328 inferredReturnShapes.emplace_back(rhs_ty.cast<ShapedType>());
3329 return success();
3330 }
3331
3332 const bool broadcast_lhs = lhs_rank < rhs_rank;
3333 RankedTensorType min_rank_ty = broadcast_lhs ? lhs_ty : rhs_ty;
3334 RankedTensorType max_rank_ty = broadcast_lhs ? rhs_ty : lhs_ty;
3335
3336 if (dims.size() != min_rank_ty.getRank()) {
3337 return emitOptionalError(
3338 location,
3339 "broadcast_dims must have size equal to the smaller argument rank");
3340 }
3341
3342 int64_t output_rank = max_rank_ty.getRank();
3343 llvm::SmallVector<int64_t, 4> broadcast_shape(output_rank, 1LL);
3344 llvm::SmallVector<bool, 4> is_broadcasted(output_rank, false);
3345 for (auto item : llvm::enumerate(dims)) {
3346 int64_t index = item.index();
3347 int64_t dim = item.value().getSExtValue();
3348 if (dim < 0 || dim >= output_rank) {
3349 return emitOptionalError(location, "out of range broadcast dim");
3350 }
3351 if (is_broadcasted[dim]) {
3352 return emitOptionalError(location, "broadcast_dims has duplicates");
3353 }
3354 broadcast_shape[dim] = min_rank_ty.getDimSize(index);
3355 is_broadcasted[dim] = true;
3356 }
3357
3358 if (broadcast_lhs) {
3359 inferredReturnShapes.emplace_back(broadcast_shape, lhs_ty.getElementType());
3360 inferredReturnShapes.emplace_back(rhs_ty.cast<ShapedType>());
3361 } else {
3362 inferredReturnShapes.emplace_back(lhs_ty.cast<ShapedType>());
3363 inferredReturnShapes.emplace_back(broadcast_shape, rhs_ty.getElementType());
3364 }
3365 return success();
3366 }
3367
3368 //===----------------------------------------------------------------------===//
3369 // XlaConvOp
3370 //===----------------------------------------------------------------------===//
3371
3372 class XlaConvToV2 : public OpRewritePattern<TF::XlaConvOp> {
3373 public:
3374 using OpRewritePattern::OpRewritePattern;
3375
3376 LogicalResult matchAndRewrite(TF::XlaConvOp op,
3377 PatternRewriter &rewriter) const override {
3378 SmallVector<Type> result_types{op.getResult().getType()};
3379 rewriter.replaceOpWithNewOp<TF::XlaConvV2Op>(
3380 op, op.getResult().getType(), op.lhs(), op.rhs(), op.window_strides(),
3381 op.padding(), op.lhs_dilation(), op.rhs_dilation(),
3382 op.feature_group_count(), op.dimension_numbers(), op.precision_config(),
3383 1);
3384 return ::mlir::success();
3385 };
3386 };
3387
3388 void XlaConvOp::getCanonicalizationPatterns(RewritePatternSet &results,
3389 MLIRContext *context) {
3390 results.insert<XlaConvToV2>(context);
3391 }
3392
3393 //===----------------------------------------------------------------------===//
3394 // XlaConvV2Op
3395 //===----------------------------------------------------------------------===//
3396
3397 LogicalResult XlaConvV2Op::verify() {
3398 XlaConvV2Op op = *this;
3399 DenseElementsAttr window_strides_attr, padding_attr, lhs_dilation_attr,
3400 rhs_dilation_attr, feature_group_count_attr;
3401 if (!(matchPattern(op.window_strides(), m_Constant(&window_strides_attr)) &&
3402 matchPattern(op.padding(), m_Constant(&padding_attr)) &&
3403 matchPattern(op.lhs_dilation(), m_Constant(&lhs_dilation_attr)) &&
3404 matchPattern(op.rhs_dilation(), m_Constant(&rhs_dilation_attr)) &&
3405 matchPattern(op.feature_group_count(),
3406 m_Constant(&feature_group_count_attr))))
3407 return success();
3408
3409 if (window_strides_attr.getType().getRank() != 1)
3410 return op.emitOpError() << "expects window_strides to be a vector";
3411
3412 const ShapedType &padding_ty = padding_attr.getType();
3413 if (padding_ty.getRank() != 2 || padding_ty.getDimSize(1) != 2)
3414 return op.emitOpError()
3415 << "expects padding to be a matrix with minor dimension 2";
3416
3417 if (lhs_dilation_attr.getType().getRank() != 1)
3418 return op.emitOpError() << "expects lhs_dilation to be a vector";
3419
3420 if (rhs_dilation_attr.getType().getRank() != 1)
3421 return op.emitOpError() << "expects rhs_dilation to be a vector";
3422
3423 if (feature_group_count_attr.getType().getRank())
3424 return op.emitOpError() << "expects feature_group_count to be a scalar";
3425
3426 return success();
3427 }
3428
3429 //===----------------------------------------------------------------------===//
3430 // XlaSetDynamicDimensionSizeOp
3431 //===----------------------------------------------------------------------===//
3432
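// Infers the result shape of tf.XlaSetDynamicDimensionSize. A sketch of the
// behavior (assuming a constant dim_index): for input : tensor<4x8xi32> and
// dim_index = 1 the inferred result is tensor<4x?xi32>; with a non-constant
// dim_index every dimension becomes dynamic.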
3433 LogicalResult XlaSetDynamicDimensionSizeOp::inferReturnTypeComponents(
3434 MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
3435 DictionaryAttr attributes, RegionRange regions,
3436 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3437 XlaSetDynamicDimensionSizeOpAdaptor op(operands.getValues(), attributes);
3438
3439 TensorType operand_ty = op.input().getType().cast<TensorType>();
3440 Type element_ty = operand_ty.getElementType();
3441
3442 TensorType result_ty;
3443 if (operand_ty.hasRank()) {
3444 auto shape = llvm::to_vector<4>(operand_ty.getShape());
3445
3446 DenseIntElementsAttr dim_index_attr;
3447 if (matchPattern(op.dim_index(), m_Constant(&dim_index_attr))) {
3448 int64_t dim_index = dim_index_attr.getValues<APInt>()[0].getSExtValue();
3449
3450 int64_t rank = operand_ty.getRank();
3451 if (dim_index < 0 || dim_index >= rank) {
3452 return emitOptionalError(location, "dim_index (", dim_index,
3453 ") is out of range [0, ", rank, ")");
3454 }
3455 shape[dim_index] = ShapedType::kDynamicSize;
3456 } else {
3457 shape.assign(shape.size(), ShapedType::kDynamicSize);
3458 }
3459 result_ty = RankedTensorType::get(shape, element_ty);
3460 } else {
3461 result_ty = UnrankedTensorType::get(element_ty);
3462 }
3463
3464 inferredReturnShapes.emplace_back(result_ty.cast<ShapedType>());
3465 return success();
3466 }
3467
3468 //===----------------------------------------------------------------------===//
3469 // XlaReduceOp
3470 //===----------------------------------------------------------------------===//
3471
3472 class XlaReduceToXlaVariadicReduceV2
3473 : public OpRewritePattern<TF::XlaReduceOp> {
3474 public:
3475 using OpRewritePattern::OpRewritePattern;
3476
3477 LogicalResult matchAndRewrite(TF::XlaReduceOp op,
3478 PatternRewriter &rewriter) const override {
3479 SmallVector<Value> inputs{op.input()};
3480 SmallVector<Value> init_values{op.init_value()};
3481 SmallVector<Type> result_types{op.getResult().getType()};
3482 rewriter.replaceOpWithNewOp<TF::XlaVariadicReduceV2Op>(
3483 op, result_types, inputs, init_values, op.dimensions_to_reduce(),
3484 op.reducer());
3485 return ::mlir::success();
3486 };
3487 };
3488
3489 void XlaReduceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3490 MLIRContext *context) {
3491 results.add<XlaReduceToXlaVariadicReduceV2>(context);
3492 }
3493
3494 //===----------------------------------------------------------------------===//
3495 // XlaReduceWindowOp
3496 //===----------------------------------------------------------------------===//
3497
3498 LogicalResult XlaReduceWindowOp::verify() {
3499 XlaReduceWindowOp op = *this;
3500 const auto &input_ty = op.input().getType().cast<ShapedType>();
3501
3502 auto check = [&](mlir::Value val, std::string attr_name) -> LogicalResult {
3503 ElementsAttr attr;
3504 if (matchPattern(val, m_Constant(&attr))) {
3505 if (attr.getType().getRank() != 1) {
3506 return op.emitOpError() << "expects the rank of " << attr_name
3507 << " to be 1, got " << attr.getType().getRank();
3508 }
3509 if (input_ty.hasRank()) {
3510 int64_t input_rank = input_ty.getRank();
3511 int64_t size = attr.size();
3512 if (input_rank != size) {
3513 return op.emitOpError() << "expects the size of " << attr_name
3514 << " to be equal to the input "
3515 "rank ("
3516 << size << " vs. " << input_rank << ")";
3517 }
3518 }
3519 }
3520 return success();
3521 };
3522
3523 if (check(op.window_dimensions(), "window_dimensions").failed())
3524 return failure();
3525
3526 if (check(op.window_strides(), "window_strides").failed()) return failure();
3527
3528 if (check(op.base_dilations(), "base_dilations").failed()) return failure();
3529
3530 if (check(op.window_dilations(), "window_dilations").failed())
3531 return failure();
3532
3533 ElementsAttr padding;
3534 if (matchPattern(op.padding(), m_Constant(&padding))) {
3535 const ShapedType &padding_ty = padding.getType();
3536 if (padding_ty.getRank() != 2 || padding_ty.getDimSize(1) != 2) {
3537 return op.emitOpError()
3538 << "expects padding to be a matrix with minor dimension 2, got "
3539 << padding.getType().getShape();
3540 }
3541 }
3542
3543 auto module = op->getParentOfType<mlir::ModuleOp>();
3544 auto func = dyn_cast_or_null<mlir::func::FuncOp>(
3545 SymbolTable::lookupSymbolIn(module, op.computation()));
3546 if (!func) {
3547 return op.emitOpError() << "has no reduction function specified";
3548 }
3549
3550 auto func_type = func.getFunctionType();
3551
3552 if (func_type.getNumInputs() != 2) {
3553 return op.emitOpError()
3554 << "expects reduction function to take 2 parameters, but "
3555 "has "
3556 << func_type.getNumInputs() << " parameter(s)";
3557 }
3558
3559 return success();
3560 }
3561
3562 //===----------------------------------------------------------------------===//
3563 // XlaSelectAndScatterOp
3564 //===----------------------------------------------------------------------===//
3565
3566 LogicalResult XlaSelectAndScatterOp::verify() {
3567 XlaSelectAndScatterOp op = *this;
3568 auto input_ty = op.operand().getType().cast<ShapedType>();
3569
3570 auto check = [&](mlir::Value val, std::string attr_name) -> LogicalResult {
3571 ElementsAttr attr;
3572 if (input_ty.hasRank() && matchPattern(val, m_Constant(&attr))) {
3573 int64_t input_rank = input_ty.getRank();
3574 int64_t size = attr.size();
3575 if (input_rank != size) {
3576 return op.emitOpError() << "expects the size of " << attr_name
3577 << " to be equal to the input "
3578 "rank ("
3579 << size << " vs. " << input_rank << ")";
3580 }
3581 }
3582 return success();
3583 };
3584
3585 if (check(op.window_dimensions(), "window_dimensions").failed())
3586 return failure();
3587
3588 if (check(op.window_strides(), "window_strides").failed()) return failure();
3589
3590 ElementsAttr padding;
3591 if (matchPattern(op.padding(), m_Constant(&padding))) {
3592 const ShapedType &padding_ty = padding.getType();
3593 if (padding_ty.getRank() != 2 || padding_ty.getDimSize(1) != 2) {
3594 return op.emitOpError()
3595 << "expects padding to be a matrix with minor dimension 2, got "
3596 << padding.getType().getShape();
3597 }
3598 }
3599
3600 auto module = op->getParentOfType<mlir::ModuleOp>();
3601 auto select_func = dyn_cast_or_null<mlir::func::FuncOp>(
3602 SymbolTable::lookupSymbolIn(module, op.select()));
3603 if (!select_func) {
3604 return op.emitOpError() << "has no select function specified";
3605 }
3606 auto select_func_type = select_func.getFunctionType();
3607 if (select_func_type.getNumInputs() != 2) {
3608 return op.emitOpError()
3609 << "expects select function to take 2 parameters, but has "
3610 << select_func_type.getNumInputs() << " parameter(s)";
3611 }
3612 if (select_func_type.getNumResults() != 1 ||
3613 !getElementTypeOrSelf(select_func_type.getResult(0)).isInteger(1)) {
3614 return op.emitOpError() << "expects select function to return a single "
3615 "boolean result but got "
3616 << select_func_type.getResult(0);
3617 }
3618 auto scatter_func = dyn_cast_or_null<mlir::func::FuncOp>(
3619 SymbolTable::lookupSymbolIn(module, op.scatter()));
3620 if (!scatter_func) {
3621 return op.emitOpError() << "has no scatter function specified";
3622 }
3623 auto scatter_func_type = scatter_func.getFunctionType();
3624 if (scatter_func_type.getNumInputs() != 2) {
3625 return op.emitOpError()
3626 << "expects scatter function to take 2 parameters, but has "
3627 << scatter_func_type.getNumInputs() << " parameter(s)";
3628 }
3629
3630 return success();
3631 }
3632
3633 //===----------------------------------------------------------------------===//
3634 // XlaVariadicReduceOp
3635 //===----------------------------------------------------------------------===//
3636
3637 LogicalResult XlaVariadicReduceOp::verify() {
3638 XlaVariadicReduceOp op = *this;
3639 // We rely on V2 for the majority of the checks.
3640 const auto &input_ty = op.input().getType();
3641 if (input_ty.empty()) return op.emitOpError() << "No input";
3642 const auto &dtype = input_ty[0].cast<TensorType>().getElementType();
3643 for (const auto &ty : input_ty) {
3644 if (ty.cast<TensorType>().getElementType() != dtype)
3645 return op.emitOpError()
3646 << "This version is limited to operands of the same dtype";
3647 }
3648 return success();
3649 }
3650
3651 class XlaVariadicReduceToV2 : public OpRewritePattern<TF::XlaVariadicReduceOp> {
3652 public:
3653 using OpRewritePattern::OpRewritePattern;
3654
3655 LogicalResult matchAndRewrite(TF::XlaVariadicReduceOp op,
3656 PatternRewriter &rewriter) const override {
3657 mlir::TF::XlaVariadicReduceV2Op xla_variadic_reduce_v2_op =
3658 rewriter.create<::mlir::TF::XlaVariadicReduceV2Op>(
3659 op.getLoc(), op.getResults().getTypes(), op.input(),
3660 op.init_value(), op.dimensions_to_reduce(), op.reducer());
3661
3662 rewriter.replaceOp(op, xla_variadic_reduce_v2_op.getResults());
3663 return ::mlir::success();
3664 };
3665 };
3666
3667 void XlaVariadicReduceOp::getCanonicalizationPatterns(
3668 RewritePatternSet &results, MLIRContext *context) {
3669 results.add<XlaVariadicReduceToV2>(context);
3670 }
3671
3672 //===----------------------------------------------------------------------===//
3673 // XlaVariadicReduceV2Op
3674 //===----------------------------------------------------------------------===//
3675
3676 LogicalResult XlaVariadicReduceV2Op::verify() {
3677 XlaVariadicReduceV2Op op = *this;
3678 const auto &inputs_ty = op.inputs().getType();
3679 int n_inputs = inputs_ty.size();
3680 if (n_inputs < 1) return op.emitOpError() << "No inputs";
3681
3682 const auto &init_values_ty = op.init_values().getType();
3683 int n_init_values = init_values_ty.size();
3684 if (n_init_values != n_inputs) {
3685 return op.emitOpError() << "Number of inputs (" << n_inputs
3686 << ") is different than number of init_values ("
3687 << n_init_values << ")";
3688 }
3689
3690 auto input_ty_0 = inputs_ty[0].cast<ShapedType>();
3691 if (input_ty_0.hasStaticShape()) {
3692 for (int i = 1; i < n_inputs; ++i) {
3693 auto input_ty_i = inputs_ty[i].cast<ShapedType>();
3694 if (input_ty_i.hasStaticShape() &&
3695 input_ty_i.getShape() != input_ty_0.getShape()) {
3696 return op.emitOpError()
3697 << "inputs[" << i << "] has shape [" << input_ty_i.getShape()
3698 << "] different than the shape of inputs[0]: "
3699 << input_ty_0.getShape();
3700 }
3701 }
3702
3703 if (op.dimensions_to_reduce().size() > input_ty_0.getRank()) {
3704 return op.emitOpError()
3705 << "Invalid dimensions_to_reduce argument to XlaVariadicReduceV2";
3706 }
3707 }
3708
3709 for (int i = 0; i < n_inputs; ++i) {
3710 auto init_value_ty_i = init_values_ty[i].cast<ShapedType>();
3711 if (init_value_ty_i.hasRank() && init_value_ty_i.getRank() != 0) {
3712 return op.emitOpError()
3713 << "init_values[" << i << "] must be a scalar but got ["
3714 << init_value_ty_i.getShape() << "]";
3715 }
3716 }
3717
3718 auto module = op->getParentOfType<mlir::ModuleOp>();
3719 auto function = dyn_cast_or_null<mlir::func::FuncOp>(
3720 SymbolTable::lookupSymbolIn(module, op.reducer()));
3721 if (!function) return op.emitOpError() << "No reducer";
3722 if (!function.getBody().hasOneBlock())
3723 return op.emitOpError() << "reducer has more than one block";
3724
3725 return success();
3726 }
3727
3728 //===----------------------------------------------------------------------===//
3729 // XlaVariadicSortOp
3730 //===----------------------------------------------------------------------===//
3731
3732 LogicalResult XlaVariadicSortOp::verify() {
3733 XlaVariadicSortOp op = *this;
3734 const auto &inputs_ty = op.inputs().getType();
3735 int n_inputs = inputs_ty.size();
3736 auto input_ty_0 = inputs_ty[0].cast<ShapedType>();
3737 if (input_ty_0.hasStaticShape()) {
3738 for (int i = 1; i < n_inputs; ++i) {
3739 auto input_ty_i = inputs_ty[i].cast<ShapedType>();
3740 if (input_ty_i.hasStaticShape() &&
3741 input_ty_i.getShape() != input_ty_0.getShape()) {
3742 return op.emitOpError()
3743 << "input[" << i << "] has shape [" << input_ty_i.getShape()
3744 << "] different than the shape of input[0]: "
3745 << input_ty_0.getShape();
3746 }
3747 }
3748 }
3749
3750 ElementsAttr dimension;
3751 if (matchPattern(op.dimension(), m_Constant(&dimension))) {
3752 if (dimension.getType().getRank() != 0 ||
3753 dimension.getType().getNumElements() != 1)
3754 return op.emitOpError() << "dimension must be a scalar";
3755 }
3756
3757 auto module = op->getParentOfType<mlir::ModuleOp>();
3758 auto function = dyn_cast_or_null<mlir::func::FuncOp>(
3759 SymbolTable::lookupSymbolIn(module, op.comparator()));
3760 if (!function) return op.emitOpError() << "No comparator";
3761 if (!function.getBody().hasOneBlock())
3762 return op.emitOpError() << "comparator has more than one block";
3763
3764 return success();
3765 }
3766
3767 //===----------------------------------------------------------------------===//
3768 // SetStaticDimensionBoundsOp
3769 //===----------------------------------------------------------------------===//
3770 //
3771
3772 LogicalResult SetStaticDimensionBoundsOp::verify() {
3773 SetStaticDimensionBoundsOp op = *this;
3774 mlir::ShapedType input_type = op.input().getType().cast<mlir::ShapedType>();
3775 mlir::ShapedType static_shape_type =
3776 op.static_shape().getType().cast<mlir::ShapedType>();
3777 int input_type_rank = input_type.hasRank() ? input_type.getRank() : -1;
3778 if (input_type_rank > 2) {
3779 return op.emitOpError() << "was used with an input tensor with rank > 2, "
3780 "only tensors of rank 1,2 are supported";
3781 }
3782
3783 if (static_shape_type.hasRank() && static_shape_type.getRank() != 1) {
3784 return op.emitOpError("static shape must be of rank 1 (vector)");
3785 }
3786 if (input_type_rank != -1 && static_shape_type.hasStaticShape()) {
3787 if (static_shape_type.getShape()[0] != input_type_rank) {
3788 return op.emitOpError(
3789 "static shape must have num_elements == rank of input "
3790 "tensor");
3791 }
3792 }
3793
3794 return success();
3795 }
3796
3797 } // namespace TF
3798 } // namespace mlir
3799
3800 //===----------------------------------------------------------------------===//
3801 // TableGen'd op method definitions
3802 //===----------------------------------------------------------------------===//
3803
3804 #define GET_OP_CLASSES
3805 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc.inc"
3806