1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <functional>
21 #include <limits>
22 #include <numeric>
23 #include <string>
24 #include <tuple>
25 #include <type_traits>
26
27 #include "llvm/ADT/APFloat.h"
28 #include "llvm/ADT/APInt.h"
29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/None.h"
32 #include "llvm/ADT/Optional.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/Sequence.h"
35 #include "llvm/ADT/SmallVector.h"
36 #include "llvm/ADT/StringExtras.h"
37 #include "llvm/ADT/StringRef.h"
38 #include "llvm/ADT/StringSwitch.h"
39 #include "llvm/ADT/Twine.h"
40 #include "llvm/ADT/iterator_range.h"
41 #include "llvm/Support/Casting.h"
42 #include "llvm/Support/FormatVariadic.h"
43 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
44 #include "mlir/Dialect/Traits.h" // from @llvm-project
45 #include "mlir/IR/Attributes.h" // from @llvm-project
46 #include "mlir/IR/Builders.h" // from @llvm-project
47 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
48 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
49 #include "mlir/IR/Diagnostics.h" // from @llvm-project
50 #include "mlir/IR/DialectImplementation.h" // from @llvm-project
51 #include "mlir/IR/Identifier.h" // from @llvm-project
52 #include "mlir/IR/Location.h" // from @llvm-project
53 #include "mlir/IR/MLIRContext.h" // from @llvm-project
54 #include "mlir/IR/Matchers.h" // from @llvm-project
55 #include "mlir/IR/OpDefinition.h" // from @llvm-project
56 #include "mlir/IR/OpImplementation.h" // from @llvm-project
57 #include "mlir/IR/PatternMatch.h" // from @llvm-project
58 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
59 #include "mlir/IR/Types.h" // from @llvm-project
60 #include "mlir/IR/Value.h" // from @llvm-project
61 #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
62 #include "mlir/Parser.h" // from @llvm-project
63 #include "mlir/Support/LLVM.h" // from @llvm-project
64 #include "mlir/Support/LogicalResult.h" // from @llvm-project
65 #include "mlir/Transforms/InliningUtils.h" // from @llvm-project
66 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
67 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
68 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
69 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
70 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
71 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
72 #include "tensorflow/core/platform/logging.h"
73 #include "tensorflow/core/util/tensor_format.h"
74
75 namespace mlir {
76 namespace TF {
77
78 namespace {
79 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc"
80 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
81 } // namespace
82
83 //===----------------------------------------------------------------------===//
84 // NotEqualOp
85 //===----------------------------------------------------------------------===//
86
87 static LogicalResult Verify(NotEqualOp op) {
88 // If inputs are allowed to have incompatible types, there is nothing to check.
89 if (!op.incompatible_shape_error()) return success();
90
91 // Otherwise, check inputs are broadcastable.
92 return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
93 op.getOperation());
94 }
95
96 void NotEqualOp::build(OpBuilder &builder, OperationState &result, Value x,
97 Value y, BoolAttr incompatible_shape_error) {
98 auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y,
99 incompatible_shape_error);
100 return build(builder, result, result_type, x, y, incompatible_shape_error);
101 }
102
103 //===----------------------------------------------------------------------===//
104 // OneHotOp
105 //===----------------------------------------------------------------------===//
106
107 static LogicalResult Verify(OneHotOp op) {
108 int64_t axis = op.axis();
109
110 auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
111 if (indices_ty &&
112 !(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) {
113 return op.emitOpError()
114 << "expected axis (" << axis << ") to be -1 or between [0, "
115 << indices_ty.getShape().size() << "]";
116 }
117
118 if (axis < -1) {
119 return op.emitOpError() << "expected axis (" << axis
120 << ") to be -1 or between [0, rank(indices()))";
121 }
122
123 if (!IsOfRankOrUnranked(op.depth(), 0)) {
124 return op.emitOpError() << "requires depth to be a scalar";
125 }
126 if (!IsOfRankOrUnranked(op.on_value(), 0)) {
127 return op.emitOpError() << "requires on_value to be a scalar";
128 }
129 if (!IsOfRankOrUnranked(op.off_value(), 0)) {
130 return op.emitOpError() << "requires off_value to be a scalar";
131 }
132
133 DenseIntElementsAttr depth_attr;
134 if (matchPattern(op.depth(), m_Constant(&depth_attr))) {
135 if (depth_attr.getType().getRank() != 0)
136 return op.emitOpError() << "requires depth to be a scalar";
137 int64_t depth = depth_attr.getValue<APInt>({}).getSExtValue();
138 if (depth < 0) {
139 return op.emitOpError() << "depth must be non-negative, got: " << depth;
140 }
141 }
142
143 return success();
144 }
145
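// Infers the result type of a OneHotOp: when `indices` is ranked and `axis` is
// valid, the (possibly dynamic) depth dimension is inserted into the indices
// shape at `axis`; otherwise an unranked tensor type is returned.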
146 static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value,
147 Value off_value, IntegerAttr axis) {
148 int64_t axis_val = axis.getInt();
149 Type element_ty = on_value.getType().cast<TensorType>().getElementType();
150 auto unranked_ty = UnrankedTensorType::get(element_ty);
151 if (axis_val < -1) return unranked_ty;
152
153 auto indices_ty = indices.getType().dyn_cast<RankedTensorType>();
154 if (!indices_ty) return unranked_ty;
155
156 auto shape = llvm::to_vector<2>(indices_ty.getShape());
157 if (axis_val == -1) axis_val = shape.size();
158
159 int64_t depth_val = ShapedType::kDynamicSize;
160 DenseIntElementsAttr depth_attr;
161 if (matchPattern(depth, m_Constant(&depth_attr)) &&
162 depth_attr.getNumElements() == 1)
163 depth_val = (*depth_attr.begin()).getSExtValue();
164 shape.insert(shape.begin() + axis_val, depth_val);
165 return RankedTensorType::get(shape, element_ty);
166 }
167
168 void OneHotOp::build(OpBuilder &builder, OperationState &result, Value indices,
169 Value depth, Value on_value, Value off_value,
170 IntegerAttr axis) {
171 build(builder, result,
172 InferOneHotOpType(indices, depth, on_value, off_value, axis), indices,
173 depth, on_value, off_value, axis);
174 }
175
176 //===----------------------------------------------------------------------===//
177 // PackOp
178 //===----------------------------------------------------------------------===//
179
180 static LogicalResult Verify(PackOp op) {
181 // TODO(hinsu): Convert variadic length attributes to derived attributes.
182 Operation::operand_range values = op.values();
183
184 if (failed(VerifyTypesCompatibility(values,
185 /*mask_one_dim=*/false,
186 op.getOperation()))) {
187 return failure();
188 }
189
190 int64_t inputs_rank = -1;
191 for (Value value : values) {
192 if (auto ty = value.getType().dyn_cast<RankedTensorType>()) {
193 // Exit early as input types are verified to be compatible so all ranked
194 // tensors have the same rank.
195 inputs_rank = ty.getRank();
196 break;
197 }
198 }
199 if (inputs_rank == -1) return success();
200
201 // The values can be packed along any of the dimensions between 0 and the
202 // inputs' rank, inclusive. Negative axis values wrap around, so the valid
203 // axis range is [-(R+1), R+1).
204 int64_t range_begin = -inputs_rank - 1; // Inclusive
205 int64_t range_end = inputs_rank + 1; // Exclusive
206 int64_t axis = op.axis();
207 if (axis < range_begin || axis >= range_end) {
208 return op.emitError() << "attribute 'axis' should be within range ["
209 << range_begin << ", " << range_end
210 << "); actual value: " << axis;
211 }
212
213 return success();
214 }
215
216 OpFoldResult PackOp::fold(ArrayRef<Attribute> operands) {
217 // Fold pack operation if it computes the input tensor shape:
218 //
219 // %shape = tf.Shape(%arg) // [? x ...]
220 // %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim0 value
221 // %pack = tf.Pack(dim0, ...) { axis = 0 } // [? x ...]
222 //
223 // Where `...` are some statically known dimensions. In this case %pack can be
224 // replaced with %shape. This is a common pattern in models with a dynamic
225 // batch size.
226
227 // Pack operation should pack at least two values.
228 if (values().size() < 2) return {};
229
230 // Dimensions packed along axis = 0 (pack scalars into vector).
231 if (axis() != 0) return {};
232
233 // First packed value is defined by a strided slice operation.
234 auto slice_op = dyn_cast_or_null<StridedSliceOp>(values()[0].getDefiningOp());
235 if (!slice_op) return {};
236
237 // Input to the slice op is defined by shape operation.
238 auto shape_op = dyn_cast_or_null<ShapeOp>(slice_op.input().getDefiningOp());
239 if (!shape_op) return {};
240
241 // Input tensor, whose shape is reconstructed by the pack operation.
242 Value tensor = shape_op.input();
243
244 // All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing
245 // scalar value from input vector).
246 if (slice_op.begin_mask() != 0 || slice_op.ellipsis_mask() != 0 ||
247 slice_op.end_mask() != 0 || slice_op.new_axis_mask() != 0 ||
248 slice_op.shrink_axis_mask() != 1)
249 return {};
250
251 // Returns the integer if `value` is defined by a ConstOp holding a single
252 // integer element of the expected rank.
253 auto get_const_int = [](Value value, int expected_rank) -> Optional<int64_t> {
254 auto const_op = dyn_cast_or_null<ConstOp>(value.getDefiningOp());
255 if (!const_op) return None;
256
257 auto value_attr = const_op.value().dyn_cast<DenseIntElementsAttr>();
258 if (!value_attr || value_attr.getNumElements() != 1) return None;
259
260 auto value_ty = value_attr.getType();
261 if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None;
262
263 auto splat = value_attr.getSplatValue<IntegerAttr>();
264 return splat.getValue().getSExtValue();
265 };
266
267 // All other packed values are scalar constants.
268 SmallVector<int64_t, 4> packed_dims;
269 packed_dims.reserve(values().size() - 1);
270 for (Value operand : llvm::drop_begin(values(), 1)) {
271 if (auto dim = get_const_int(operand, /*expected_rank=*/0)) {
272 packed_dims.push_back(*dim);
273 } else {
274 return {};
275 }
276 }
277
278 // Slice exactly the first shape dimension:
279 // begin = [0] end = [1], strides = [1]
280 auto begin = get_const_int(slice_op.begin(), /*expected_rank=*/1);
281 auto end = get_const_int(slice_op.end(), /*expected_rank=*/1);
282 auto strides = get_const_int(slice_op.strides(), /*expected_rank=*/1);
283 if (!begin.hasValue() || !end.hasValue() || !strides.hasValue() ||
284 *begin != 0 || *end != 1 || *strides != 1)
285 return {};
286
287 // First tensor dimension is dynamic.
288 auto arg_ty = tensor.getType().dyn_cast<ShapedType>();
289 if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 ||
290 !arg_ty.isDynamicDim(0))
291 return {};
292
293 // Argument tensor rank is equal to the number of packed dimensions.
294 if (arg_ty.getRank() != values().size()) return {};
295
296 // All other dimensions are statically known and equal to packed dims.
297 auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1);
298 if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin()))
299 return {};
300
301 // Replace %pack with %shape.
302 return slice_op.input();
303 }
304
305 //===----------------------------------------------------------------------===//
306 // PadOp
307 //===----------------------------------------------------------------------===//
308
309 LogicalResult PadOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
310 // Paddings must be defined by a constant operation.
311 auto paddings_op = dyn_cast_or_null<TF::ConstOp>(paddings().getDefiningOp());
312 if (!paddings_op) return failure();
313
314 auto paddings_value = paddings_op.value().dyn_cast<DenseElementsAttr>();
315 if (!paddings_value ||
316 paddings_value.getNumElements() != permutation.size() * 2)
317 return failure();
318
319 SmallVector<int32_t, 8> shuffled_paddings(paddings_value.getNumElements());
320 for (auto index_pair : llvm::enumerate(paddings_value.getIntValues())) {
321 size_t outer_idx = index_pair.index() / 2;
322 size_t inner_idx = index_pair.index() % 2;
323
324 shuffled_paddings[permutation[outer_idx] * 2 + inner_idx] =
325 index_pair.value().getSExtValue();
326 }
327
328 // Add a constant operation with the new paddings.
329 OpBuilder builder(getOperation());
330 auto type = mlir::RankedTensorType::get(paddings_value.getType().getShape(),
331 builder.getIntegerType(32));
332 auto values = mlir::DenseIntElementsAttr::get(type, shuffled_paddings);
333 auto shuffled_paddings_op = builder.create<TF::ConstOp>(getLoc(), values);
334
335 // Use new paddings.
336 setOperand(1, shuffled_paddings_op);
337
338 // Change the result type.
339 getResult().setType(ShuffleRankedTensorType(getResult().getType(),
340 ReversePermutation(permutation)));
341
342 return success();
343 }
344
345 //===----------------------------------------------------------------------===//
346 // ParseExampleV2Op
347 //===----------------------------------------------------------------------===//
348
349 static LogicalResult Verify(ParseExampleV2Op op) {
350 // NOTE(mrry): This validates properties of an op that would previously be
351 // validated by the TensorFlow OpDef type checker. In addition to these
352 // checks, the shape inference function for ParseExampleV2 validates the
353 // consistency of the argument and result types.
354
355 // Validate dense variadic input and output lengths.
356 // NOTE(mrry): The Tdense attr is derived from dense_defaults, so we
357 // do not need to validate dense_defaults.
358 auto dense_types_count =
359 std::distance(op.Tdense().begin(), op.Tdense().end());
360 auto dense_values_count =
361 std::distance(op.dense_values().begin(), op.dense_values().end());
362 if (dense_values_count != dense_types_count) {
363 return op.emitError() << "output 'dense_values' should have same length "
364 << "as attribute 'Tdense'";
365 }
366
367 // Validate sparse variadic output lengths.
368 // NOTE(mrry): The sparse_types attr is derived from sparse_values, so we
369 // do not need to validate sparse_values.
370 auto sparse_types_count =
371 std::distance(op.sparse_types().begin(), op.sparse_types().end());
372 if (op.num_sparse() != sparse_types_count) {
373 return op.emitError() << "attribute 'num_sparse' should be the same as "
374 << "the length of attribute 'sparse_types'";
375 }
376 if (op.sparse_indices().size() != sparse_types_count) {
377 return op.emitError() << "output 'sparse_indices' should have same length "
378 << "as attribute 'sparse_types'";
379 }
380 if (op.sparse_shapes().size() != sparse_types_count) {
381 return op.emitError() << "output 'sparse_shapes' should have same length "
382 << "as attribute 'sparse_types'";
383 }
384
385 // Validate ragged variadic output lengths.
386 auto ragged_value_types_count = std::distance(op.ragged_value_types().begin(),
387 op.ragged_value_types().end());
388 auto ragged_split_types_count = std::distance(op.ragged_split_types().begin(),
389 op.ragged_split_types().end());
390 if (ragged_value_types_count != ragged_split_types_count) {
391 return op.emitError() << "attribute 'ragged_value_types' should have same "
392 << "length as attribute 'ragged_split_types'";
393 }
394
395 return success();
396 }
397
398 //===----------------------------------------------------------------------===//
399 // PartitionedCallOp
400 //===----------------------------------------------------------------------===//
401
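// Verifies that the callee referenced by the 'f' attribute exists in the
// enclosing module and that its declared argument count matches the number of
// 'args' operands passed to the call.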
402 template <class OpClass>
403 static LogicalResult VerifyPartitionedCall(OpClass op) {
404 auto module = op->template getParentOfType<ModuleOp>();
405 SymbolRefAttr func = op->getAttr("f").template cast<SymbolRefAttr>();
406
407 auto function =
408 dyn_cast_or_null<FuncOp>(SymbolTable::lookupSymbolIn(module, func));
409
410 if (!function) {
411 return op.emitError("'f' attribute refers to an undefined function: ")
412 << func;
413 }
414
415 FunctionType function_ty = function.getType();
416 int func_arg_count = function_ty.getNumInputs();
417 int arg_count = op.args().size();
418
419 if (arg_count != func_arg_count) {
420 return op.emitError() << "argument count mismatch: 'args' has " << arg_count
421 << " arguments, but '" << func << "' expects "
422 << func_arg_count;
423 }
424
425 return success();
426 }
427
428 //===----------------------------------------------------------------------===//
429 // PowOp
430 //===----------------------------------------------------------------------===//
431
432 OpFoldResult PowOp::fold(ArrayRef<Attribute> operands) {
433 auto constant_y = operands[1].dyn_cast_or_null<DenseFPElementsAttr>();
434 if (constant_y && constant_y.isSplat()) {
435 APFloat y_value = constant_y.getSplatValue<APFloat>();
436 auto output_type = getType().cast<ShapedType>();
437 if (y_value.isZero() && output_type.hasStaticShape()) {
438 return DenseElementsAttr::get(
439 output_type,
440 FloatAttr::get(output_type.getElementType(), /*value=*/1.0));
441 }
442 if (y_value.isExactlyValue(1.0)) {
443 return x();
444 }
445 }
446 return {};
447 }
448
449 //===----------------------------------------------------------------------===//
450 // QrOp
451 //===----------------------------------------------------------------------===//
452
453 // Verifies that,
454 //
455 // * Input type, if ranked, must have at least 2 dimensions and at most
456 // INT32_MAX dimensions.
457 //
458 static LogicalResult Verify(QrOp op) {
459 auto ttype = op.input().getType().cast<TensorType>();
460 if (!ttype.hasRank()) return success();
461 if (!HasRankAtLeast(op.input(), 2))
462 return op.emitOpError(
463 "requires ranked input tensor to be of rank 2 or more");
464 if (!HasRankAtMost(op.input(), std::numeric_limits<int32_t>::max()))
465 return op.emitOpError(
466 "requires ranked input tensor to be of rank INT32_MAX or less");
467
468 return success();
469 }
470
471 //===----------------------------------------------------------------------===//
472 // ReadVariableOp
473 //===----------------------------------------------------------------------===//
474
475 void ReadVariableOp::getCanonicalizationPatterns(
476 OwningRewritePatternList &results, MLIRContext *context) {
477 results.insert<ReadVariableOfCast>(context);
478 }
479
480 //===----------------------------------------------------------------------===//
481 // RandomUniformOp
482 //===----------------------------------------------------------------------===//
483
484 static LogicalResult Verify(RandomUniformOp op) {
485 if (!IsOfRankOrUnranked(op.shape(), 1))
486 return op.emitOpError("shape must be 1D tensor");
487 return success();
488 }
489
490 //===----------------------------------------------------------------------===//
491 // RangeOp
492 //===----------------------------------------------------------------------===//
493
494 void RangeOp::build(OpBuilder &builder, OperationState &result, Value start,
495 Value limit, Value delta) {
496 assert(start.getType() == limit.getType());
497 assert(start.getType() == delta.getType());
498 DenseIntElementsAttr start_val;
499 DenseIntElementsAttr limit_val;
500 DenseIntElementsAttr delta_val;
501 if (matchPattern(start, m_Constant(&start_val)) &&
502 matchPattern(limit, m_Constant(&limit_val)) &&
503 matchPattern(delta, m_Constant(&delta_val))) {
504 auto size = llvm::APIntOps::RoundingSDiv(
505 *limit_val.begin() - *start_val.begin(), *delta_val.begin(),
506 llvm::APInt::Rounding::DOWN);
507 return RangeOp::build(
508 builder, result,
509 RankedTensorType::get(
510 size.getSExtValue(),
511 start.getType().cast<TensorType>().getElementType()),
512 start, limit, delta);
513 }
514 return RangeOp::build(
515 builder, result,
516 RankedTensorType::get(
517 {-1}, start.getType().cast<TensorType>().getElementType()),
518 start, limit, delta);
519 }
520 //===----------------------------------------------------------------------===//
521 // RankOp
522 //===----------------------------------------------------------------------===//
523
524 void RankOp::build(OpBuilder &builder, OperationState &result, Value input) {
525 return RankOp::build(builder, result,
526 RankedTensorType::get({}, builder.getIntegerType(32)),
527 input);
528 }
529
530 // Creates a constant value for a RankOp whose input is a ranked tensor.
531 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
532 auto type = input().getType();
533 auto ranked_type = type.dyn_cast<RankedTensorType>();
534 if (!ranked_type) return {};
535
536 // DenseIntElementsAttr::get requires the output type be ranked with static
537 // shape.
538 auto output_type = getType().dyn_cast<RankedTensorType>();
539 if (!output_type || !output_type.hasStaticShape()) return {};
540
541 int32_t rank = ranked_type.getRank();
542 return DenseIntElementsAttr::get(output_type, rank);
543 }
544
545 //===----------------------------------------------------------------------===//
546 // RealDivOp
547 //===----------------------------------------------------------------------===//
548
549 void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
550 MLIRContext *context) {
551 results.insert<RealDivWithSqrtDivisor, RealDivWithConstDivisor>(context);
552 }
553
554 OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) {
555 return IdentityArithmeticOpFolder<RealDivOp>(*this, operands);
556 }
557
558 //===----------------------------------------------------------------------===//
559 // ReluOp
560 //===----------------------------------------------------------------------===//
561
562 Optional<ContractionFusion> ReluOp::GetContractionFusion() {
563 // Only f32 is supported for fusion.
564 if (!T().isF32()) return None;
565
566 return ContractionFusion("Relu", /*additional_arguments=*/{});
567 }
568
569 //===----------------------------------------------------------------------===//
570 // ReshapeOp
571 //===----------------------------------------------------------------------===//
572
573 namespace {
574 using ReshapeErrorHandler =
575 llvm::function_ref<LogicalResult(const llvm::Twine &)>;
576
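// Computes the output type of a reshape given the input tensor and the `shape`
// operand, reporting failures through `error_handler`. The result is as ranked
// and static as the available information allows (e.g. a single unknown (-1)
// dimension is resolved when the input has a static shape).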
577 LogicalResult GetReshapeOutputType(Value tensor, Value shape,
578 ReshapeErrorHandler error_handler,
579 TensorType &output_ty) {
580 auto tensor_ty = tensor.getType().cast<TensorType>();
581 auto element_ty = tensor_ty.getElementType();
582 output_ty = UnrankedTensorType::get(element_ty);
583
584 auto shape_ty = shape.getType().dyn_cast<RankedTensorType>();
585 if (!shape_ty) return success();
586 if (shape_ty.getRank() != 1)
587 return error_handler(llvm::formatv(
588 "requires 'shape' to be rank 1, but got {0}", shape_ty.getRank()));
589
590 DenseIntElementsAttr shape_attr;
591 if (!matchPattern(shape, m_Constant(&shape_attr))) {
592 // If only the shape of `shape` is known, return a ranked output type with
593 // all dimensions dynamic.
594 if (shape_ty.hasStaticShape()) {
595 llvm::SmallVector<int64_t, 8> dynamic_shape(shape_ty.getDimSize(0),
596 ShapedType::kDynamicSize);
597 output_ty = RankedTensorType::get(dynamic_shape, element_ty);
598 }
599 return success();
600 }
601
602 // Detect if reshape output shape is folded.
603 bool shape_ty_zero_dim = false;
604 int unknown_index = -1;
605 // The product of the constant shape argument, excluding the unknown dimension.
606 int64_t shape_ty_size = 1;
607 llvm::SmallVector<int64_t, 8> output_ty_shape;
608 output_ty_shape.reserve(shape_attr.getNumElements());
609 for (const auto &dim : llvm::enumerate(shape_attr.getIntValues())) {
610 const int64_t size = dim.value().getSExtValue();
611 if (size == ShapedType::kDynamicSize) {
612 if (unknown_index != -1)
613 return error_handler(llvm::formatv(
614 "requires 'shape' to have at most one dynamic dimension, but got "
615 "multiple dynamic dimensions at indices {0} and {1}",
616 unknown_index, dim.index()));
617
618 unknown_index = dim.index();
619 } else if (size == 0) {
620 shape_ty_zero_dim = true;
621 } else if (size > 0) {
622 shape_ty_size *= size;
623 } else {
624 return error_handler(
625 llvm::formatv("requires 'shape' to have dimensions greater than -1, "
626 "but got {0} at index {1}",
627 size, dim.index()));
628 }
629 output_ty_shape.push_back(size);
630 }
631
632 if (!tensor_ty.hasStaticShape()) {
633 output_ty = RankedTensorType::get(output_ty_shape, element_ty);
634 return success();
635 }
636
637 // Compute the value of the unknown dimension.
638 if (unknown_index != -1) {
639 // Compute number of elements in tensor shape.
640 int64_t tensor_ty_size = 1;
641 bool tensor_ty_zero_dim = false;
642 for (const auto &dim : tensor_ty.getShape()) {
643 if (dim > 0 || !shape_ty_zero_dim) {
644 tensor_ty_size *= dim;
645 } else {
646 tensor_ty_zero_dim = true;
647 }
648 }
649
650 const int64_t missing_dim = tensor_ty_size / shape_ty_size;
651 if (!tensor_ty_zero_dim && shape_ty_size * missing_dim != tensor_ty_size)
652 return error_handler(
653 llvm::formatv("requires 'tensor' number of elements be a multiple of "
654 "{0}, but got {1}",
655 shape_ty_size, tensor_ty_size));
656
657 // Set the unknown dimension such that the total number of elements remains
658 // constant.
659 output_ty_shape[unknown_index] = missing_dim;
660 }
661
662 output_ty = RankedTensorType::get(output_ty_shape, element_ty);
663
664 return success();
665 }
666 } // namespace
667
668 static LogicalResult Verify(ReshapeOp op) {
669 auto error_handler = [&op](const llvm::Twine &message) -> LogicalResult {
670 return op.emitOpError() << message;
671 };
672 TensorType expected_ty;
673 if (failed(GetReshapeOutputType(op.tensor(), op.shape(), error_handler,
674 expected_ty)))
675 return failure();
676
677 auto output_ty = op.getType().dyn_cast<RankedTensorType>();
678 if (!output_ty) return success();
679 auto tensor_ty = op.tensor().getType().cast<TensorType>();
680 if (output_ty.hasStaticShape() && tensor_ty.hasStaticShape()) {
681 const int64_t output_ty_size = output_ty.getNumElements();
682 const int64_t tensor_ty_size = tensor_ty.getNumElements();
683 if (tensor_ty_size != output_ty_size)
684 return op.emitOpError() << "requires 'output' number of elements to "
685 "match 'tensor' number of elements, but got "
686 << output_ty_size << " and " << tensor_ty_size;
687 }
688
689 if (!AreCastCompatible({output_ty, expected_ty}))
690 return op.emitOpError()
691 << "requires 'output' type " << output_ty
692 << " to be cast compatible with expected type " << expected_ty;
693
694 return success();
695 }
696
697 // Currently there are use cases that rely on partial evaluation of the `shape`
698 // operand, so InferTypeOpInterface is not used (nor is the generated builder
699 // of the same signature).
700 void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor,
701 Value shape) {
702 auto error_handler = [&result](const llvm::Twine &message) {
703 return mlir::emitError(result.location) << message;
704 };
705 TensorType output_ty;
706 if (failed(GetReshapeOutputType(tensor, shape, error_handler, output_ty)))
707 return;
708
709 return ReshapeOp::build(builder, result, output_ty, tensor, shape);
710 }
711
712 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
713 MLIRContext *context) {
714 results.insert<RedundantReshape, ReshapeToSelfShape>(context);
715 }
716
717 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
718 Value tensor = this->tensor();
719
720 // Fold reshape if operand and result types are the same and all dimensions
721 // are statically known (no-op reshape).
722 auto result_ty = getType().dyn_cast<ShapedType>();
723 if (result_ty && result_ty.hasStaticShape() &&
724 result_ty == tensor.getType()) {
725 return tensor;
726 }
727
728 return {};
729 }
730
731 //===----------------------------------------------------------------------===//
732 // SelectOp
733 //===----------------------------------------------------------------------===//
734
735 void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
736 MLIRContext *context) {
737 results.insert<SelectToSelectV2>(context);
738 }
739
740 // Verifies a few extra requirements on SelectOp:
741 // (1) `then` and `else` must have same shape
742 // (2) At least one of the following must be true:
743 // (a) `cond` has the same rank as `then` and `else`
744 // (b) `cond` is a scalar
745 // (c) `cond` is a vector AND `then` and `else` are non-scalar with their
746 // first dimension equal to the length of `cond`.
747 static LogicalResult Verify(SelectOp op) {
748 auto then_tensor = op.t().getType().cast<TensorType>();
749 auto else_tensor = op.e().getType().cast<TensorType>();
750 // Check (1).
751 if (!AreCastCompatible({then_tensor, else_tensor}))
752 return op.emitOpError() << "requires t and e have compatible shapes";
753
754 // Get data rank (if exists).
755 int data_rank;
756 // If data is unranked or data_rank is 0, this will remain -2. Otherwise it
757 // refers to the first dimension of then and/or else.
758 int data_first_dim = -2;
759 bool then_has_rank = then_tensor.hasRank();
760 bool else_has_rank = else_tensor.hasRank();
761 if (then_has_rank && else_has_rank) {
762 data_rank = then_tensor.getRank();
763 if (then_tensor.getRank() > 0)
764 data_first_dim = then_tensor.getShape().front();
765 if (else_tensor.getRank() > 0)
766 data_first_dim = std::max(
767 static_cast<int>(else_tensor.getShape().front()), data_first_dim);
768 } else if (then_has_rank) {
769 data_rank = then_tensor.getRank();
770 if (then_tensor.getRank() > 0)
771 data_first_dim = then_tensor.getShape().front();
772 } else if (else_has_rank) {
773 data_rank = else_tensor.getRank();
774 if (else_tensor.getRank() > 0)
775 data_first_dim = else_tensor.getShape().front();
776 } else {
777 // Neither has a rank.
778 return success();
779 }
780
781 auto cond_tensor = op.condition().getType().dyn_cast<RankedTensorType>();
782 if (!cond_tensor) return success();
783 auto cond_rank = cond_tensor.getRank();
784 // Check (2a) and (2b).
785 if (cond_rank == 0 || cond_rank == data_rank) return success();
786 // Check (2c).
787 if (cond_rank == 1) {
788 auto cond_shape = cond_tensor.getShape().front();
789 if (data_rank == 0) {
790 return op.emitOpError()
791 << "requires that t and e are nonscalar when pred is a vector";
792 }
793 // We know `data` tensor has a rank of at least 1.
794 if (data_first_dim != -1 && cond_shape != -1 &&
795 data_first_dim != cond_shape) {
796 return op.emitOpError() << "requires that, when pred is a vector, the "
797 "shape matches the first dimension of t and e";
798 }
799 return success();
800 }
801 // None of (2a,b,c) were true; fail.
802 return op.emitOpError() << "requires that pred is a scalar OR has the same "
803 "rank as t and e OR is a vector";
804 }
805
806 //===----------------------------------------------------------------------===//
807 // SelectV2Op
808 //===----------------------------------------------------------------------===//
809
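// Infers the result type of SelectV2 by broadcasting the shapes of `e` and `t`
// and then broadcasting the result with the shape of `condition`; returns an
// unranked tensor type whenever any of the required shapes is unavailable.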
810 static Type InferSelectV2OpType(Value condition, Value e, Value t) {
811 Type element_ty = e.getType().cast<TensorType>().getElementType();
812 auto unranked_ty = UnrankedTensorType::get(element_ty);
813
814 Type broadcasted_ty =
815 OpTrait::util::getBroadcastedType(e.getType(), t.getType());
816 if (!broadcasted_ty) return unranked_ty;
817
818 auto cond_ranked_ty = condition.getType().dyn_cast<RankedTensorType>();
819 auto broadcasted_ranked_ty = broadcasted_ty.dyn_cast<RankedTensorType>();
820 if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty;
821
822 // Explicitly compute the broadcasted output type, as the element type of
823 // the condition may not be the same as the broadcasted type's element type.
824 SmallVector<int64_t, 4> result_shape;
825 if (!OpTrait::util::getBroadcastedShape(cond_ranked_ty.getShape(),
826 broadcasted_ranked_ty.getShape(),
827 result_shape))
828 return unranked_ty;
829 return RankedTensorType::get(result_shape, element_ty);
830 }
831
832 void SelectV2Op::build(OpBuilder &builder, OperationState &result,
833 Value condition, Value e, Value t) {
834 build(builder, result, InferSelectV2OpType(condition, e, t), condition, e, t);
835 }
836
837 //===----------------------------------------------------------------------===//
838 // ShapeOp
839 //===----------------------------------------------------------------------===//
840
841 namespace {
842 // Validates Shape/ShapeN/VariableShape operand and associated result types.
843 LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
844 Type result_type,
845 int variadic_idx = -1) {
846 std::string variadic_idx_str =
847 variadic_idx < 0 ? "" : llvm::formatv(" #{0}", variadic_idx).str();
848
849 auto result_ranked_type = result_type.dyn_cast<RankedTensorType>();
850 if (!result_ranked_type) return success();
851 if (result_ranked_type.getShape().size() != 1)
852 return op->emitOpError("requires 1D type for result") << variadic_idx_str;
853
854 auto operand_ranked_type = operand_type.dyn_cast_or_null<RankedTensorType>();
855 if (operand_ranked_type) {
856 // The operand is a ranked tensor.
857 if (result_ranked_type.hasStaticShape() &&
858 !operand_ranked_type.getShape().empty() &&
859 result_ranked_type.getDimSize(0) !=
860 operand_ranked_type.getShape().size())
861 return op->emitOpError("requires dimension size of result")
862 << variadic_idx_str << " to match rank of operand"
863 << variadic_idx_str;
864 } else if (result_ranked_type.hasStaticShape()) {
865 // The operand is an unranked tensor; emit a warning if the result
866 // type has a static shape.
867 // Note: We do not handle this situation as an error, this would be too
868 // restrictive due to incompleteness of shape inference at this point.
869 op->emitWarning("has static shape result")
870 << variadic_idx_str << " for unranked operand" << variadic_idx_str;
871 }
872
873 Type element_type = result_ranked_type.getElementType();
874 if (!element_type.isSignlessInteger(32) &&
875 !element_type.isSignlessInteger(64))
876 return op->emitOpError("requires int32 or int64 return type for result")
877 << variadic_idx_str;
878
879 return success();
880 }
881 } // anonymous namespace
882
883 static LogicalResult Verify(ShapeOp op) {
884 return VerifyShapeOperandAndResult(op, op.input().getType(), op.getType());
885 }
886
887 // Converts the shape of the given type to an attribute if it is a statically
888 // shaped ranked tensor type. The returned attribute has integer elements of the given width.
889 static Attribute ConvertShapeToAttr(Type input_ty, int out_width) {
890 auto ranked_ty = input_ty.dyn_cast<RankedTensorType>();
891 if (!ranked_ty || !ranked_ty.hasStaticShape()) return {};
892
893 auto shape = ranked_ty.getShape();
894 int rank = shape.size();
895
896 SmallVector<APInt, 4> dimensions;
897 dimensions.reserve(rank);
898 for (int i = 0; i < rank; ++i)
899 dimensions.push_back(APInt(out_width, shape[i]));
900
901 auto result_type = RankedTensorType::get(
902 {rank}, IntegerType::get(input_ty.getContext(), out_width));
903 return DenseElementsAttr::get(result_type, dimensions);
904 }
905
906 OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
907 int width =
908 getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
909 return ConvertShapeToAttr(getOperand().getType(), width);
910 }
911
912 void ShapeOp::build(OpBuilder &builder, OperationState &result, Value input,
913 BoolAttr use32Bit) {
914 auto rankedTensorType = input.getType().dyn_cast<RankedTensorType>();
915 int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1;
916 auto out_type = use32Bit.getValue() ? builder.getIntegerType(32)
917 : builder.getIntegerType(64);
918 return ShapeOp::build(builder, result,
919 RankedTensorType::get({rank}, out_type), input);
920 }
921
922 //===----------------------------------------------------------------------===//
923 // ShapeNOp
924 //===----------------------------------------------------------------------===//
925
926 static LogicalResult Verify(ShapeNOp op) {
927 const size_t num_tensors = op.N();
928
929 if (op.getNumOperands() != num_tensors)
930 return op.emitOpError() << "requires " << num_tensors << " operand(s), got "
931 << op.getNumOperands() << " operand(s)";
932
933 if (op.getNumResults() != num_tensors)
934 return op.emitOpError() << "requires " << num_tensors << " result(s), got "
935 << op.getNumResults() << " result(s)";
936
937 for (auto i : llvm::seq<uint64_t>(0, num_tensors)) {
938 auto verification = VerifyShapeOperandAndResult(
939 op, op.getOperand(i).getType(), op.getResult(i).getType(), i);
940 if (failed(verification)) return verification;
941 }
942
943 return success();
944 }
945
946 namespace {
947 // Canonicalization pattern for a ShapeNOp that doesn't have all static input
948 // shapes. Replacing the output values that correspond to static input
949 // types may enable optimizations in users of the values.
950 class ShapeNPartialStaticInputShape : public OpRewritePattern<ShapeNOp> {
951 using OpRewritePattern<ShapeNOp>::OpRewritePattern;
952 LogicalResult matchAndRewrite(ShapeNOp op,
953 PatternRewriter &rewriter) const override {
954 if (op.getNumOperands() == 0) {
955 rewriter.eraseOp(op);
956 return success();
957 }
958
959 int width = getElementTypeOrSelf(op.getType(0)).getIntOrFloatBitWidth();
960
961 SmallVector<Value, 4> results(op.getNumOperands());
962 SmallVector<int64_t, 4> dynamic_indices;
963 SmallVector<Value, 4> dynamic_inputs;
964 SmallVector<Type, 4> result_types;
965 for (auto e : llvm::enumerate(op.getOperands())) {
966 if (Attribute result = ConvertShapeToAttr(e.value().getType(), width)) {
967 results[e.index()] = rewriter.create<TF::ConstOp>(op.getLoc(), result);
968 } else {
969 dynamic_indices.push_back(e.index());
970 dynamic_inputs.push_back(e.value());
971 result_types.push_back(op.getType(e.index()));
972 }
973 }
974
975 if (dynamic_inputs.size() == op.getNumOperands()) {
976 // Cannot canonicalize ShapeN if all inputs are dynamic.
977 return failure();
978 }
979
980 // Create a ShapeNOp for all dynamic inputs.
981 if (!dynamic_inputs.empty()) {
982 auto dynamic_shape_n = rewriter.create<TF::ShapeNOp>(
983 op.getLoc(), result_types, dynamic_inputs);
984 for (auto index_result :
985 llvm::zip(dynamic_indices, dynamic_shape_n.getResults())) {
986 results[std::get<0>(index_result)] = std::get<1>(index_result);
987 }
988 }
989
990 rewriter.replaceOp(op, results);
991 return success();
992 }
993 };
994
995 // Canonicalize ShapeNOp to ShapeOp if there is only one operand.
996 class ShapeNToShape : public OpRewritePattern<ShapeNOp> {
997 using OpRewritePattern<ShapeNOp>::OpRewritePattern;
998 LogicalResult matchAndRewrite(ShapeNOp op,
999 PatternRewriter &rewriter) const override {
1000 if (op.getNumOperands() != 1) {
1001 return failure();
1002 }
1003 auto shape = rewriter.create<TF::ShapeOp>(op.getLoc(), op.getType(0),
1004 op.getOperand(0));
1005 rewriter.replaceOp(op, {shape});
1006 return success();
1007 }
1008 };
1009 } // namespace
1010
1011 void ShapeNOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1012 MLIRContext *context) {
1013 results.insert<ShapeNToShape, ShapeNPartialStaticInputShape>(context);
1014 }
1015
1016 //===----------------------------------------------------------------------===//
1017 // SizeOp
1018 //===----------------------------------------------------------------------===//
1019
1020 // Verifies that,
1021 //
1022 // * Input type, if it is a ranked tensor, has at most INT32_MAX dimensions.
1023 //
1024 static LogicalResult Verify(SizeOp op) {
1025 if (!HasRankAtMost(op.input(), std::numeric_limits<int32_t>::max()))
1026 return op.emitOpError(
1027 "requires ranked input tensor to be of rank INT32_MAX or less");
1028
1029 // Output type needs to be scalar.
1030 if (!IsOfRankOrUnranked(op.output(), /*rank=*/0))
1031 return op.emitOpError("requires scalar output");
1032
1033 return success();
1034 }
1035
1036 OpFoldResult SizeOp::fold(ArrayRef<Attribute> operands) {
1037 ShapedType output_type = getType().cast<ShapedType>();
1038 if (!output_type.hasRank()) return {};
1039 ShapedType input_type = getOperand().getType().cast<ShapedType>();
1040 if (!input_type.hasStaticShape()) return {};
1041 int size = input_type.getNumElements();
1042 return DenseElementsAttr::get(
1043 output_type,
1044 IntegerAttr::get(output_type.getElementType(), /*value=*/size));
1045 }
1046
1047 //===----------------------------------------------------------------------===//
1048 // SliceOp
1049 //===----------------------------------------------------------------------===//
1050
1051 // Verifies that:
1052 //
1053 // - operands begin and size are 1D with the same number of elements.
1054 // - if the input is a ranked tensor, the rank of the input equals the number
1055 // of elements in operands begin and size.
1056 // - if begin are constants, that
1057 // 0 <= begin[i] <= begin[i] + size[i] <= input_ty.getShape()[i]
1058 // and
1059 // size[i] == output_ty.getShape()[i]
1060 // - if begin is not constant but the input is a ranked tensor, that
1061 // size[i] <= input_ty.getShape()[i]
1062 // - output rank is the same as input rank
1063 //
1064 static LogicalResult Verify(SliceOp op) {
1065 RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin());
1066 if (begin_ty && begin_ty.getRank() != 1) {
1067 return op.emitOpError() << "requires begin operand to be 1D tensor";
1068 }
1069
1070 RankedTensorType size_ty = GetRankedTensorTypeForOperand(op.size());
1071 if (size_ty && size_ty.getRank() != 1) {
1072 return op.emitOpError() << "requires size operand to be 1D tensor";
1073 }
1074
1075 if (!begin_ty || !size_ty || !begin_ty.hasStaticShape() ||
1076 !size_ty.hasStaticShape())
1077 return success();
1078
1079 if (begin_ty.getNumElements() != size_ty.getNumElements()) {
1080 return op.emitOpError() << "requires begin and size operands to have the"
1081 " same number of elements";
1082 }
1083
1084 auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
1085 if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) {
1086 return op.emitOpError() << "requires number of elements in begin and size"
1087 "are equal to input rank";
1088 }
1089
1090 auto output_ty = op.output().getType().dyn_cast<RankedTensorType>();
1091 if (output_ty && input_ty && output_ty.getRank() != input_ty.getRank()) {
1092 return op.emitOpError()
1093 << "requires output to have the same rank as input, but got input "
1094 "rank "
1095 << input_ty.getRank() << " and output rank " << output_ty.getRank();
1096 }
1097
1098 DenseIntElementsAttr begin_indices;
1099 if (matchPattern(op.begin(), m_Constant(&begin_indices))) {
1100 DenseIntElementsAttr slice_sizes;
1101 bool constant_slice_sizes =
1102 matchPattern(op.size(), m_Constant(&slice_sizes));
1103 int dim = 0;
1104 // TODO(jpienaar): Reformulate the shape verification below to not use magic
1105 // constants.
1106 for (const APInt &raw_begin_index : begin_indices.getValues<APInt>()) {
1107 int64_t begin_index = raw_begin_index.getSExtValue();
1108 int64_t input_size = input_ty ? input_ty.getShape()[dim] : -1;
1109 int64_t slice_size = constant_slice_sizes
1110 ? slice_sizes.getValue<APInt>(dim).getSExtValue()
1111 : 0;
1112 int64_t output_size = output_ty ? output_ty.getShape()[dim] : -1;
1113
1114 if (slice_size == -1 && input_size != -1) {
1115 slice_size = input_size - begin_index;
1116 }
1117 if (output_size != -1 && constant_slice_sizes &&
1118 output_size != slice_size) {
1119 return op.emitOpError()
1120 << "requires output size to have the same size of slice, got "
1121 "slice size "
1122 << slice_size << " and output size " << output_size;
1123 }
1124 if (begin_index < 0 ||
1125 (input_size != -1 && begin_index + slice_size > input_size)) {
1126 return op.emitOpError()
1127 << "requires 0 <= begin[i] <= begin[i] + size[i] <= Di";
1128 }
1129 ++dim;
1130 }
1131 } else if (input_ty) {
1132 // If the inputs are ranked, we can do a few more sanity checks.
1133 DenseIntElementsAttr slice_sizes;
1134 if (matchPattern(op.size(), m_Constant(&slice_sizes))) {
1135 auto input_shape = input_ty.getShape();
1136 for (int64_t i = 0; i < input_ty.getRank(); ++i) {
1137 int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
1138 int64_t input_size = input_shape[i];
1139 if (slice_size != -1 && input_size != -1 && slice_size > input_size) {
1140 return op.emitOpError() << "requires size[i] <= Di, even if begin[i] "
1141 "is unknown at compile time";
1142 }
1143 }
1144 }
1145 }
1146
1147 return success();
1148 }
1149
1150 //===----------------------------------------------------------------------===//
1151 // SoftmaxOp
1152 //===----------------------------------------------------------------------===//
1153
1154 static LogicalResult Verify(SoftmaxOp op) {
1155 if (!HasRankAtLeast(op.logits(), 1)) {
1156 return op.emitOpError("requires operand to have rank at least 1");
1157 }
1158 return success();
1159 }
1160
1161 //===----------------------------------------------------------------------===//
1162 // SoftmaxCrossEntropyWithLogitsOp
1163 //===----------------------------------------------------------------------===//
1164
1165 // Verifies that,
1166 //
1167 // * Input types are broadcast compatible and the broadcasted type has rank two.
1168 //
1169 static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) {
1170 auto broadcasted_ty = OpTrait::util::getBroadcastedType(
1171 op.features().getType(), op.labels().getType())
1172 .dyn_cast_or_null<ShapedType>();
1173 if (!broadcasted_ty ||
1174 (broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2))
1175 return op.emitOpError(
1176 "requires features and labels to be broadcast compatible to rank two");
1177
1178 return success();
1179 }
1180
1181 //===----------------------------------------------------------------------===//
1182 // SpaceToBatchNDOp
1183 //===----------------------------------------------------------------------===//
1184
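// Returns the block rank (the length of block_shape) if it can be determined
// statically from either the block_shape or paddings type; otherwise returns -1.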
1185 int64_t SpaceToBatchNDBlockRank(const TensorType block_shape_type,
1186 const TensorType paddings_type) {
1187 if (block_shape_type.hasStaticShape()) {
1188 return block_shape_type.getShape()[0];
1189 } else if (paddings_type.hasStaticShape()) {
1190 return paddings_type.getShape()[0];
1191 } else {
1192 return -1;
1193 }
1194 }
1195
1196 static LogicalResult Verify(SpaceToBatchNDOp op) {
1197 const auto input_type = op.input().getType().cast<TensorType>();
1198 const auto block_shape_type = op.block_shape().getType().cast<TensorType>();
1199 const auto paddings_type = op.paddings().getType().cast<TensorType>();
1200
1201 // Check that block_shape has rank 1.
1202 if (!IsOfRankOrUnranked(op.block_shape(), 1)) {
1203 return op.emitOpError() << "requires rank of block_shape = 1; got "
1204 << block_shape_type.getRank();
1205 }
1206
1207 // Check that paddings has rank 2.
1208 if (!IsOfRankOrUnranked(op.paddings(), 2)) {
1209 return op.emitOpError()
1210 << "requires rank of paddings = 2; got " << paddings_type.getRank();
1211 }
1212
1213 // Check that paddings.shape[1]=2.
1214 if (paddings_type.hasStaticShape() && paddings_type.getShape()[1] != 2) {
1215 return op.emitOpError() << "requires paddings.shape[1] to be 2; got "
1216 << paddings_type.getShape()[1];
1217 }
1218
1219 // Check that block_shape and paddings have consistent ranks.
1220 if (block_shape_type.hasStaticShape() && paddings_type.hasStaticShape() &&
1221 block_shape_type.getShape()[0] != paddings_type.getShape()[0]) {
1222 return op.emitOpError()
1223 << "requires block_shape.shape[0] must equal paddings.shape[0]";
1224 }
1225
1226 const int64_t block_rank =
1227 SpaceToBatchNDBlockRank(block_shape_type, paddings_type);
1228
1229 // Further checks require block_rank to be known.
1230 if (block_rank == -1) {
1231 return success();
1232 }
1233
1234 // Check that rank of input_type >= block_rank + 1.
1235 if (input_type.hasRank() && input_type.getRank() < 1 + block_rank) {
1236 return op.emitOpError() << "requires rank of input >= 1 + rank of block";
1237 }
1238
1239 ElementsAttr block_shape_attr = nullptr;
1240 ElementsAttr paddings_attr = nullptr;
1241
1242 // Check that block_shape[*] >= 1.
1243 if (matchPattern(op.block_shape(), m_Constant(&block_shape_attr))) {
1244 uint64_t i = 0;
1245 for (auto block_len : block_shape_attr.getValues<APInt>()) {
1246 if (block_len.getSExtValue() < 1) {
1247 return op.emitOpError()
1248 << "requires all values of block_shape to be >= 1; "
1249 "failed for dimension "
1250 << i;
1251 }
1252 ++i;
1253 }
1254 }
1255
1256 // Check that paddings[*] >= 0.
1257 if (matchPattern(op.paddings(), m_Constant(&paddings_attr))) {
1258 for (uint64_t i = 0; i < block_rank; ++i) {
1259 const int64_t pad_start =
1260 paddings_attr.getValue({i, 0}).cast<IntegerAttr>().getInt();
1261 const int64_t pad_end =
1262 paddings_attr.getValue({i, 1}).cast<IntegerAttr>().getInt();
1263 if (pad_start < 0 || pad_end < 0) {
1264 return op.emitOpError()
1265 << "requires all values of paddings to be >= 0; "
1266 "failed for dimension "
1267 << i;
1268 }
1269 }
1270 }
1271
1272 // Check that block_shape divides the padded input.
1273 if (input_type.hasStaticShape() && block_shape_attr && paddings_attr) {
1274 for (uint64_t i = 0; i < block_rank; ++i) {
1275 const int64_t input_len = input_type.getShape()[1 + i];
1276 const int64_t pad_start =
1277 paddings_attr.getValue({i, 0}).cast<IntegerAttr>().getInt();
1278 const int64_t pad_end =
1279 paddings_attr.getValue({i, 1}).cast<IntegerAttr>().getInt();
1280 const int64_t block_len =
1281 block_shape_attr.getValue({i}).cast<IntegerAttr>().getInt();
1282 if ((input_len + pad_start + pad_end) % block_len != 0) {
1283 return op.emitOpError()
1284 << "requires block_shape[i] divides "
1285 "input_shape[i + 1] + paddings[i, 0] + paddings[i, 1]; "
1286 "failed for i="
1287 << i;
1288 }
1289 }
1290 }
1291
1292 return success();
1293 }
1294
1295 // Infers the returned rank if possible, and the returned dimension sizes when
1296 // possible. For all dimension sizes to be inferred, the arguments
1297 // block_shape and paddings must be constant.
1298 LogicalResult SpaceToBatchNDOp::inferReturnTypes(
1299 MLIRContext *context, Optional<Location> location, ValueRange operands,
1300 DictionaryAttr attributes, RegionRange regions,
1301 SmallVectorImpl<Type> &inferredReturnTypes) {
1302 const Value input = operands[0];
1303 const Value block_shape_val = operands[1];
1304 const Value paddings_val = operands[2];
1305 const auto input_type = input.getType().cast<TensorType>();
1306 const auto block_shape_type = block_shape_val.getType().cast<TensorType>();
1307 const auto paddings_type = paddings_val.getType().cast<TensorType>();
1308
1309 // The return is unranked when the input is unranked.
1310 if (!input_type.hasRank()) {
1311 inferredReturnTypes.assign(
1312 {UnrankedTensorType::get(input_type.getElementType())});
1313 return success();
1314 }
1315
1316 const int64_t input_rank = input_type.getRank();
1317 const ArrayRef<int64_t> input_shape = input_type.getShape();
1318 const int64_t block_rank =
1319 SpaceToBatchNDBlockRank(block_shape_type, paddings_type);
1320 SmallVector<int64_t, 4> return_shape(input_rank, ShapedType::kDynamicSize);
1321
1322 // The return has all dimension sizes unknown when block_rank is unknown.
1323 if (block_rank == ShapedType::kDynamicSize) {
1324 inferredReturnTypes.assign(
1325 {RankedTensorType::get(return_shape, input_type.getElementType())});
1326 return success();
1327 }
1328
1329 // The return preserves the remaining dimensions after blocked dimensions.
1330 for (uint64_t i = 1 + block_rank; i < input_rank; ++i) {
1331 return_shape[i] = input_shape[i];
1332 }
1333
1334 // The rest of the dimension sizes can be calculated when block_shape and
1335 // paddings arguments are constant.
1336 DenseIntElementsAttr block_shape_attr;
1337 DenseIntElementsAttr paddings_attr;
1338 if (GetValueAsConstant(block_shape_val, block_shape_attr) &&
1339 GetValueAsConstant(paddings_val, paddings_attr)) {
1340 int64_t return_batch = input_shape[0];
1341 for (uint64_t i = 0; i < block_rank; ++i) {
1342 // Propagate dynamic dimension.
1343 if (input_shape[i + 1] == ShapedType::kDynamicSize) {
1344 return_batch = ShapedType::kDynamicSize;
1345 }
1346 if (return_batch == ShapedType::kDynamicSize) {
1347 return_shape[1 + i] = ShapedType::kDynamicSize;
1348 continue;
1349 }
1350 int64_t paddings_sum =
1351 paddings_attr.getValue<APInt>({i, 0}).getSExtValue() +
1352 paddings_attr.getValue<APInt>({i, 1}).getSExtValue();
1353 int64_t block_shape_i =
1354 block_shape_attr.getValue<APInt>({i}).getSExtValue();
1355 return_batch *= block_shape_i;
1356 return_shape[1 + i] = (paddings_sum + input_shape[i + 1]) / block_shape_i;
1357 }
1358 return_shape[0] = return_batch;
1359 }
1360
1361 inferredReturnTypes.assign(
1362 {RankedTensorType::get(return_shape, input_type.getElementType())});
1363 return success();
1364 }
1365
1366 //===----------------------------------------------------------------------===//
1367 // SparseSoftmaxCrossEntropyWithLogitsOp
1368 //===----------------------------------------------------------------------===//
1369
1370 static LogicalResult Verify(SparseSoftmaxCrossEntropyWithLogitsOp op) {
1371 if (!IsOfRankOrUnranked(op.features(), 2)) {
1372 return op.emitOpError("requires features operand of rank two");
1373 }
1374 if (!IsOfRankOrUnranked(op.labels(), 1)) {
1375 return op.emitOpError("requires labels operand of rank one");
1376 }
1377 auto features_ty = op.features().getType().dyn_cast<RankedTensorType>();
1378 auto labels_ty = op.labels().getType().dyn_cast<RankedTensorType>();
1379 if (features_ty && labels_ty) {
1380 int64_t features_batches = features_ty.getDimSize(0);
1381 int64_t labels_batches = labels_ty.getDimSize(0);
1382 if (!ShapedType::isDynamic(features_batches) &&
1383 !ShapedType::isDynamic(labels_batches) &&
1384 features_batches != labels_batches)
1385 return op.emitOpError(
1386 "requires features and labels with matching first dimension");
1387 }
1388 return success();
1389 }
1390
1391 //===----------------------------------------------------------------------===//
1392 // SplitOp
1393 //===----------------------------------------------------------------------===//
1394
1395 // Verifies the input and split dimension operands for tf.Split/tf.SplitV.
1396 // Writes the split dimension's index (adjusted with input rank) via `dim_index`
1397 // if it's a constant.
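//
// For example, splitting a tensor<4x8xf32> along dimension 0 into two outputs
// yields two tensor<2x8xf32> values; a negative split dimension such as -2
// here refers to dimension 0.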
template <class Op>
LogicalResult VerifySplitInputAndSplitDim(Op op, Optional<int64_t> *dim_index) {
  *dim_index = llvm::None;

  Value split_dim = op.split_dim();
  if (auto split_dim_type = split_dim.getType().dyn_cast<RankedTensorType>())
    if (split_dim_type.getRank() != 0)
      return op.emitOpError(
          "split dimension should be an integer scalar tensor");

  // We can perform further verification if the input tensor to be split has
  // known rank and the split dimension tensor is a constant.

  auto input_type = op.value().getType().template dyn_cast<RankedTensorType>();
  if (!input_type) return success();

  int64_t input_rank = input_type.getRank();
  if (input_rank == 0)
    return op.emitOpError("cannot split scalar input tensor");

  DenseIntElementsAttr split_dim_attr;
  if (!matchPattern(split_dim, m_Constant(&split_dim_attr))) return success();

  int64_t index = (*split_dim_attr.begin()).getSExtValue();

  if (index + input_rank < 0 || index >= input_rank) {
    return op.emitOpError("split dimension must be in range [-")
           << input_rank << ", " << input_rank << ")";
  }

  if (index < 0) index += input_rank;
  *dim_index = index;

  return success();
}

static LogicalResult Verify(SplitOp op) {
  Optional<int64_t> dim_index;
  if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
  if (!dim_index) return success();

  int64_t input_dim_size =
      op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
  if (input_dim_size == ShapedType::kDynamicSize) return success();

  if (input_dim_size % op.getNumResults() != 0)
    return op.emitOpError("dimension #")
           << *dim_index << " not divisible by the number of result tensors";

  return success();
}

//===----------------------------------------------------------------------===//
// SplitVOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(SplitVOp op) {
  auto split_sizes_type =
      op.size_splits().getType().dyn_cast<RankedTensorType>();
  if (!split_sizes_type) return success();

  if (split_sizes_type.getRank() != 1 ||
      (split_sizes_type.getDimSize(0) != ShapedType::kDynamicSize &&
       split_sizes_type.getDimSize(0) != op.getNumResults()))
    return op.emitOpError("split sizes should be a 1D tensor of ")
           << op.getNumResults() << " elements";

  Optional<int64_t> dim_index = 0;
  if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
  if (!dim_index) return success();

  int64_t input_dim_size =
      op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
  if (input_dim_size == ShapedType::kDynamicSize) return success();

  // If split sizes come from a constant, they must sum to the dimension size
  // along split_dim, and we can have no more than one dynamic dimension.
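  // For example, constant size_splits of [2, -1, 3] for a dimension of size
  // 10 are accepted here: the constant entries sum to 5 <= 10, and the single
  // -1 (dynamic) entry is then expected to cover the remaining 5 elements.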
  DenseIntElementsAttr split_sizes_attr;
  if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
    return success();

  int64_t total_dim_size = 0;  // Total dimension size assigned to splits
  llvm::Optional<int> dynamic_dim_index;

  SmallVector<int64_t, 4> split_sizes;
  split_sizes.reserve(
      split_sizes_attr.getType().cast<ShapedType>().getNumElements());

  for (auto dim : llvm::enumerate(split_sizes_attr)) {
    int64_t dim_val = dim.value().getSExtValue();
    split_sizes.push_back(dim_val);
    if (dim_val == ShapedType::kDynamicSize) {
      // We cannot have more than one dynamic dimension.
      if (dynamic_dim_index)
        return op.emitOpError(
            "cannot have more than one dynamic dimension in split sizes");
      dynamic_dim_index = dim.index();
    } else {
      total_dim_size += dim_val;
    }
  }

  if (!dynamic_dim_index && total_dim_size != input_dim_size)
    return op.emitOpError(
               "split sizes must sum up to the dimension size along split "
               "dimension, found ")
           << total_dim_size << " vs " << input_dim_size;

  if (dynamic_dim_index && total_dim_size > input_dim_size)
    return op.emitOpError(
               "split sizes must sum up to be less than or equal to the "
               "dimension size along split dimension, found ")
           << total_dim_size << " vs " << input_dim_size;

  return success();
}

//===----------------------------------------------------------------------===//
// SquareOp
//===----------------------------------------------------------------------===//

void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<SquareOfSub>(context);
}

//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//

void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
  results.insert<SubOfNeg>(context);
}

OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
  return IdentityArithmeticOpFolder<SubOp>(*this, operands);
}

//===----------------------------------------------------------------------===//
// SumOp
//===----------------------------------------------------------------------===//

void SumOp::build(OpBuilder &builder, OperationState &result, Value input,
                  Value reduction_indices, BoolAttr keep_dims) {
  Type out_ty =
      InferReductionOpType(input, reduction_indices, keep_dims, &builder);
  build(builder, result, out_ty, input, reduction_indices, keep_dims);
}

// TODO: Templatize this fold for all reduction ops.
OpFoldResult SumOp::fold(ArrayRef<Attribute> operands) {
  auto input_ty = input().getType().template dyn_cast<RankedTensorType>();
  if (!input_ty) return {};
  auto result_ty = getType().template dyn_cast<RankedTensorType>();
  if (!result_ty) return {};

  // Bypass this op if the result has the same shape and type. This can happen
  // if the input tensor has size 0 or size 1.
  if (!keep_dims() && input_ty == result_ty) {
    return input();
  }
  return {};
}

//===----------------------------------------------------------------------===//
// StridedSliceOp
//===----------------------------------------------------------------------===//

// TODO(b/154160827): Add a canonicalization pattern from tf.StridedSliceOp to
// tf.SliceOp if both of the following are true:
// - All strides have a known value equal to 1
// - No masks are set (or masks can be applied by transforming the inputs to
//   Slice)

// Verifies that,
//
// - begin, end and strides operands are 1D and they have the same number of
//   elements. Here, the number of elements should be less than 32 to support
//   32-bit mask attributes.
// - None of the strides values are zero.
// - Ellipsis mask can have at most one bit set.

template <class OpTy>
static LogicalResult VerifyStridedSliceBase(OpTy op) {
  // Expected size for operands begin, end and strides vector operands.
  int64_t expected_size = -1;

  for (Value val : {op.begin(), op.end(), op.strides()}) {
    auto operand_ty = val.getType().dyn_cast<ShapedType>();
    if (!operand_ty || !operand_ty.hasStaticShape()) {
      // TensorFlow constant ops may have non-static shape because the shape is
      // not propagated during constant folding. If the defining op for this
      // operand is a constant op, use the constant op's attribute to get the
      // actual shape.
      DenseIntElementsAttr attr;
      if (!matchPattern(val, m_Constant(&attr))) continue;
      operand_ty = attr.getType();
    }

    if (operand_ty.getRank() != 1)
      return op.emitOpError()
             << "requires begin, end and strides to be 1D tensors";

    int64_t length = operand_ty.getDimSize(0);
    if (length == -1) continue;

    if (expected_size == -1) {
      // This op uses 32-bit masks.
      if (length >= 32)
        return op.emitOpError(
            "requires begin, end and strides operands with less than 32 "
            "elements");

      expected_size = length;
    } else if (length != expected_size) {
      return op.emitOpError() << "requires begin, end and strides to have the "
                                 "same number of elements";
    }
  }

  // If strides are constants, verify that none of the elements is zero.
  DenseIntElementsAttr strides;
  if (matchPattern(op.strides(), m_Constant(&strides))) {
    if (llvm::is_contained(strides.getValues<APInt>(), 0))
      return op.emitOpError("requires non-zero strides");
  }

  // Use bit compares to ensure ellipsis_mask is 0 or a power of 2, i.e. there
  // is at most one ellipsis.
  uint32_t ellipsis_mask = op.ellipsis_mask();
  if (ellipsis_mask != 0 && !llvm::isPowerOf2_32(ellipsis_mask))
    return op.emitOpError("cannot have multiple ellipses");

  return success();
}

// Clamps the given `val`: returns `low` if `val` is less than `low`; returns
// `high` if `high` is less than `val`; otherwise returns `val`.
template <class T>
constexpr const T &Clamp(const T &val, const T &low, const T &high) {
  assert(!(high < low));
  return (val < low) ? low : (high < val) ? high : val;
}

// Checks if the `index` bit of `val` is set.
template <class T>
constexpr bool IsSet(const T &val, unsigned index) {
  return (val & (1 << index)) != 0;
}

// Sets the `index` bit of `val`.
template <class T>
constexpr void Set(T &val, unsigned index) {
  val |= (1 << index);
}

// Unsets the `index` bit of `val`.
template <class T>
constexpr void Unset(T &val, unsigned index) {
  val &= ~(1 << index);
}

// Copies the `src_index` bit of `src` to the `dst_index` bit of `dst`.
template <class T>
constexpr void CopyBit(const T &src, unsigned src_index, T &dst,
                       unsigned dst_index) {
  if (IsSet(src, src_index))
    Set(dst, dst_index);
  else
    Unset(dst, dst_index);
}

// The sparse spec of strided slice does not correspond to the number of
// dimensions. For example, the sparse spec for foo[..., 3:10] with foo of
// shape (2, 4, 8) has dims = 2.
struct SparseSliceSpec {
  int64_t dims;
  int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
  const ArrayRef<int64_t> &begin;
  const ArrayRef<int64_t> &end;
  const ArrayRef<int64_t> &strides;
};

// The dense spec of strided slice is the canonicalized version of the sparse
// spec. The number of dimensions of the dense spec corresponds to the number
// of dimensions in the operand tensor.
struct DenseSliceSpec {
  int64_t dims;
  int32_t begin_mask, end_mask, shrink_axis_mask;
  SmallVectorImpl<int64_t> &begin;
  SmallVectorImpl<int64_t> &end;
  SmallVectorImpl<int64_t> &strides;
};

// Makes a sparse spec into a dense index spec.
// The sparse spec does not correspond to the number of dimensions;
// make a dense spec that corresponds to the number of dimensions.
//
// For example, suppose foo[..., 3:, 2] on foo.shape = (2, 2, 3, 4). Then
// we need to produce the missing begin_mask, end_mask for the first two
// dimensions, i.e., foo[:, :, 3:, 2].
static void BuildDenseSliceSpec(const SparseSliceSpec &sparse,
                                DenseSliceSpec *dense) {
  // Build expanded dense begin, end, strides, begin_mask, end_mask, and
  // shrink_axis_mask.
  dense->begin.resize(dense->dims);
  dense->end.resize(dense->dims);
  dense->strides.resize(dense->dims);
  dense->begin_mask = 0;
  dense->end_mask = 0;
  dense->shrink_axis_mask = 0;

  // Count the number of new_axis after the ellipsis. This helps in calculating
  // the number of dimensions the ellipsis represents in the sparse spec.
  bool ellipsis_seen = false;
  int num_new_axis_after_ellipsis = 0;
  for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) {
    if (ellipsis_seen && IsSet(sparse.new_axis_mask, sparse_index))
      num_new_axis_after_ellipsis++;
    if (IsSet(sparse.ellipsis_mask, sparse_index)) ellipsis_seen = true;
  }

  int dense_index = 0;
  for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) {
    if (IsSet(sparse.new_axis_mask, sparse_index)) continue;
    if (IsSet(sparse.ellipsis_mask, sparse_index)) {
      auto next_index = std::min(dense->dims - (sparse.dims - sparse_index) +
                                     1 + num_new_axis_after_ellipsis,
                                 dense->dims);
      // Expand ellipsis into the appropriate dense indices. From the current
      // index until next_index, all dimensions have begin and end masks set
      // and stride 1, i.e., get all elements in those dimensions.
      for (; dense_index < next_index; ++dense_index) {
        dense->begin[dense_index] = dense->end[dense_index] = 0;
        dense->strides[dense_index] = 1;
        Set(dense->begin_mask, dense_index);
        Set(dense->end_mask, dense_index);
      }
      continue;
    }
    assert(dense_index < dense->dims);
    // Copy over the sparse indices to dense indices if ellipsis_mask and
    // new_axis_mask are not set.
    dense->begin[dense_index] = sparse.begin[sparse_index];
    dense->end[dense_index] = sparse.end[sparse_index];
    dense->strides[dense_index] = sparse.strides[sparse_index];
    CopyBit(sparse.begin_mask, sparse_index, dense->begin_mask, dense_index);
    CopyBit(sparse.end_mask, sparse_index, dense->end_mask, dense_index);
    CopyBit(sparse.shrink_axis_mask, sparse_index, dense->shrink_axis_mask,
            dense_index);
    dense_index++;
  }
}

// For the given `input_shape`, calculates the sliced shape using the given
// `begin`, `end`, and `stride` ranges and `begin_mask`, `end_mask`, and
// `shrink_axis_mask` masks. Updates the result back to `input_shape`. If
// `shrink_axis_mask` is not zero, this function will not drop the
// corresponding dimensions in `input_shape`; it will turn them into 1s. At the
// same time, canonicalizes `begin`, `end`, and `strides`. The calculation
// follows tf.StridedSlice op semantics.
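//
// For example, for input_shape = [10], begin = [-3], end = [10], stride = [1]
// and no masks set, the result is input_shape = [3], begin = [7], end = [10]:
// the negative begin index is offset by the dimension size and clamped before
// the sliced size is computed.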
static void CalculateSlicedShapeFromDenseIndices(
    MutableArrayRef<int64_t> input_shape, int32_t begin_mask, int32_t end_mask,
    int32_t shrink_axis_mask, MutableArrayRef<int64_t> begin,
    MutableArrayRef<int64_t> end, MutableArrayRef<int64_t> stride) {
  assert(input_shape.size() <= 32);  // Only 32-bit masks are supported.

  // Make sure ranges' ranks are consistent with the input.
  assert(input_shape.size() == begin.size());
  assert(input_shape.size() == end.size());
  assert(input_shape.size() == stride.size());

  for (int i = 0, e = input_shape.size(); i < e; ++i) {
    if (ShapedType::isDynamic(input_shape[i])) continue;

    int64_t dim_i = input_shape[i];
    int64_t begin_i = begin[i];
    int64_t end_i = end[i];
    int64_t stride_i = stride[i];

    // [0]: mask for begin, [1]: mask for end
    int64_t masks[] = {begin_mask & (1 << i), end_mask & (1 << i)};
    // [0]: bound for begin, [1]: bound for end
    int64_t bounds[] = {stride_i > 0 ? 0 : -1,
                        stride_i > 0 ? dim_i : dim_i - 1};

    // Canonicalizes the given range `point` (begin/end) according to the
    // current dimension. `c` means case: 0 for begin, 1 for end.
    auto canonicalize = [&](int64_t point, int c) {
      if (masks[c]) return stride_i > 0 ? bounds[c] : bounds[(c + 1) & 1];

      // Add dim as offset to negative range point.
      point = point < 0 ? dim_i + point : point;
      return Clamp(point, bounds[0], bounds[1]);
    };

    begin_i = canonicalize(begin_i, 0);
    end_i = canonicalize(end_i, 1);

    int64_t interval_len = end_i - begin_i;
    int64_t size_i = 0;
    // If the interval length is zero or has a different sign from the stride,
    // it's a degenerate case: we are slicing nothing. Otherwise, calculate the
    // sliced size.
    if (interval_len != 0 && (interval_len < 0) == (stride_i < 0))
      size_i = (interval_len / stride_i) + (interval_len % stride_i != 0);

    begin[i] = begin_i;
    if (IsSet(shrink_axis_mask, i)) {
      // Shrink this dimension. It means we only take the element at begin_i.
      input_shape[i] = 1;
      end[i] = begin_i + 1;
      stride[i] = 1;
    } else {
      input_shape[i] = size_i;
      end[i] = end_i;
      stride[i] = stride_i;
    }
  }
}

// For the given `input_shape`, calculates the sliced shape using the given
// `sparse_begin`, `sparse_end`, and `sparse_strides` ranges and `begin_mask`,
// `end_mask`, `ellipsis_mask`, `new_axis_mask` and `shrink_axis_mask` masks.
// Updates the result back to `input_shape`.
static void CalculateSlicedShapeFromSparseIndices(
    MutableArrayRef<int64_t> input_shape, ArrayRef<int64_t> sparse_begin,
    ArrayRef<int64_t> sparse_end, ArrayRef<int64_t> sparse_strides,
    int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask,
    int32_t new_axis_mask, int32_t shrink_axis_mask,
    SmallVectorImpl<int64_t> *begin, SmallVectorImpl<int64_t> *end,
    SmallVectorImpl<int64_t> *stride) {
  int64_t num_sparse_indices = sparse_begin.size();
  SparseSliceSpec sparse = {num_sparse_indices, begin_mask, end_mask,
                            ellipsis_mask, new_axis_mask, shrink_axis_mask,
                            sparse_begin, sparse_end, sparse_strides};

  // If no ellipsis_mask exists then an implicit ellipsis_mask at the end is
  // inserted. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields
  // a tensor of shape [2, 8], i.e., foo[2:4] is same as foo[2:4, ...].
  if (sparse.ellipsis_mask == 0) {
    Set(sparse.ellipsis_mask, sparse.dims);
    sparse.dims++;
  }

  int64_t dims = input_shape.size();
  DenseSliceSpec dense = {dims,
                          /*begin_mask = */ 0,
                          /*end_mask = */ 0,
                          /*shrink_axis_mask = */ 0,
                          *begin,
                          *end,
                          *stride};

  BuildDenseSliceSpec(sparse, &dense);
  CalculateSlicedShapeFromDenseIndices(input_shape, dense.begin_mask,
                                       dense.end_mask, dense.shrink_axis_mask,
                                       *begin, *end, *stride);
}

bool StridedSliceOp::GetSlicedBoundRanges(
    SmallVectorImpl<int64_t> *slice_begin, SmallVectorImpl<int64_t> *slice_end,
    SmallVectorImpl<int64_t> *slice_stride) {
  // TODO(hinsu): Support lowering for ops with dynamic begin and end values
  // when it is possible to derive indices based on mask attributes.
  DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr;
  if (!matchPattern(begin(), m_Constant(&sparse_begin_attr)) ||
      !matchPattern(end(), m_Constant(&sparse_end_attr)) ||
      !matchPattern(strides(), m_Constant(&sparse_strides_attr)))
    return false;

  auto input_ty = this->input().getType().dyn_cast<RankedTensorType>();
  if (!input_ty || !input_ty.hasStaticShape()) return false;
  auto input_shape = llvm::to_vector<4>(input_ty.getShape());

  SmallVector<int64_t, 4> sparse_begin, sparse_end, sparse_strides;

  for (const APInt &index : sparse_begin_attr)
    sparse_begin.push_back(index.getSExtValue());
  for (const APInt &index : sparse_end_attr)
    sparse_end.push_back(index.getSExtValue());
  for (const APInt &stride : sparse_strides_attr)
    sparse_strides.push_back(stride.getSExtValue());

  CalculateSlicedShapeFromSparseIndices(
      input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(),
      end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(),
      slice_begin, slice_end, slice_stride);
  return true;
}

OpFoldResult StridedSliceOp::fold(ArrayRef<Attribute> operands) {
  // Fold StridedSlice operation if it extracts statically known dimensions.
  //
  // For example,
  //
  //   %shape = tf.Shape(%arg)  // %arg: tensor<?x2x3x1xf32>
  //   %height = tf.StridedSlice(%shape, 1, 2, 1)
  //
  // In this case %height can be replaced with a constant 2.
  //
  // Or,
  //
  //   %shape = tf.Shape(%arg)  // %arg: tensor<?x2x3x1xf32>
  //   %spatial_shape = tf.StridedSlice(%shape, 1, 3, 1)
  //
  // In this case %spatial_shape can be replaced with a constant [2, 3].

  // Input to strided slice op is defined by shape operation.
  auto shape_op = input().getDefiningOp<ShapeOp>();
  if (!shape_op) {
    return {};
  }

  // `begin`, `end` and `strides` should be constant in order to infer static
  // dimension.
  DenseIntElementsAttr begin_attr, end_attr, strides_attr;
  if (!matchPattern(begin(), m_Constant(&begin_attr)) ||
      !matchPattern(end(), m_Constant(&end_attr)) ||
      !matchPattern(strides(), m_Constant(&strides_attr)) ||
      begin_attr.getNumElements() != 1 || end_attr.getNumElements() != 1 ||
      strides_attr.getNumElements() != 1) {
    return {};
  }

  // Do not fold when `new_axis_mask` is set. It's likely to break the shape
  // of output. Typically, `new_axis_mask` is not set in this canonicalization
  // pattern.
  if (new_axis_mask() != 0) return {};

  auto tensor_ty = shape_op.input().getType().dyn_cast<RankedTensorType>();
  // Only ranked tensor can be folded.
  if (!tensor_ty) return {};

  int64_t rank = tensor_ty.getRank();
  int64_t begin_int = begin_attr.getValue<APInt>(0).getSExtValue();
  int64_t end_int = end_attr.getValue<APInt>(0).getSExtValue();
  int64_t strides_int = strides_attr.getValue<APInt>(0).getSExtValue();

  // Canonicalize `begin` and `end` in case of negative index.
  if (begin_int < 0) begin_int += rank;
  if (end_int < 0) end_int += rank;

  // Create `begin` and `end` from `*_mask`. Note that we don't care about
  // `new_axis_mask` as it can be inferred from `output_ty`.
  if (shrink_axis_mask() == 1) {
    // When `shrink_axis_mask` is set, output is always a scalar so only
    // one element is sliced.
    end_int = begin_int + 1;
  }
  if (begin_mask() == 1) {
    begin_int = (strides_int > 0) ? 0 : rank - 1;
  }
  if (end_mask() == 1) {
    end_int = (strides_int > 0) ? rank : -1;
  }
  if (ellipsis_mask() == 1) {
    begin_int = 0;
    end_int = rank;
  }

  // It's possible that `begin` and `end` are out of bounds. See
  // https://docs.python.org/3/library/stdtypes.html#common-sequence-operations.
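  // For example, with rank = 4 and a positive stride, a begin of 100 is
  // clamped to 4 below, which produces an empty slice.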
  if (strides_int > 0) {
    begin_int = std::min(begin_int, rank);
    end_int = std::min(end_int, rank);
  } else {
    begin_int = std::min(begin_int, rank - 1);
    end_int = std::min(end_int, rank - 1);
  }

  SmallVector<int64_t, 2> sub_shape;
  // Only handle cases that have something to slice to avoid infinite for-loop.
  if ((end_int > begin_int && strides_int > 0) ||
      (end_int < begin_int && strides_int < 0)) {
    // Extract sub-shape only if all of those dimensions are static.
    for (int64_t i = begin_int; (strides_int > 0) ? i < end_int : i > end_int;
         i += strides_int) {
      if (tensor_ty.isDynamicDim(i)) {
        return {};
      }
      sub_shape.push_back(tensor_ty.getDimSize(i));
    }
  }

  // For unranked or dynamic output, we infer the output type to either a
  // scalar or a vector based on `shrink_axis_mask` because we have rejected
  // the case of `new_axis_mask` != 0.
  auto output_elt_ty = output().getType().cast<ShapedType>().getElementType();
  auto output_ty = output().getType().dyn_cast<RankedTensorType>();
  if (!output_ty || !output_ty.hasStaticShape()) {
    if (shrink_axis_mask() == 1) {
      output_ty = RankedTensorType::get({}, output_elt_ty);
    } else {
      output_ty = RankedTensorType::get(
          {static_cast<int64_t>(sub_shape.size())}, output_elt_ty);
    }
  }

  // Down-cast to 32 bit int if needed.
  if (output_elt_ty.isInteger(32)) {
    SmallVector<int32_t, 2> sub_shape_i32(sub_shape.size());
    std::transform(sub_shape.begin(), sub_shape.end(), sub_shape_i32.begin(),
                   [](int64_t d) { return static_cast<int32_t>(d); });
    return DenseIntElementsAttr::get(output_ty, sub_shape_i32);
  }
  return DenseIntElementsAttr::get(output_ty, sub_shape);
}

//===----------------------------------------------------------------------===//
// StridedSliceGradOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(StridedSliceGradOp op) {
  auto shape_type = op.shape().getType().dyn_cast<RankedTensorType>();
  if (shape_type && shape_type.getRank() != 1)
    return op.emitOpError("'shape' operand must be 1D tensor, but got ")
           << shape_type.getRank() << "D tensor";

  if (failed(VerifyStridedSliceBase(op))) return failure();

  // TODO(antiagainst): verify the gradient op.dy()'s shape is consistent with
  // the sliced type from StridedSlice.

  return success();
}

bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges(
    SmallVectorImpl<int64_t> *input_shape,
    SmallVectorImpl<int64_t> *slice_begin, SmallVectorImpl<int64_t> *slice_end,
    SmallVectorImpl<int64_t> *slice_stride) {
  DenseIntElementsAttr shape_attr;
  DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr;
  if (!matchPattern(shape(), m_Constant(&shape_attr)) ||
      !matchPattern(begin(), m_Constant(&sparse_begin_attr)) ||
      !matchPattern(end(), m_Constant(&sparse_end_attr)) ||
      !matchPattern(strides(), m_Constant(&sparse_strides_attr)))
    return false;

  int rank = std::distance(shape_attr.begin(), shape_attr.end());

  input_shape->clear();
  input_shape->reserve(rank);
  for (const APInt &dim : shape_attr)
    input_shape->push_back(dim.getSExtValue());

  SmallVector<int64_t, 4> sparse_begin, sparse_end, sparse_strides;

  for (const APInt &index : sparse_begin_attr)
    sparse_begin.push_back(index.getSExtValue());
  for (const APInt &index : sparse_end_attr)
    sparse_end.push_back(index.getSExtValue());
  for (const APInt &stride : sparse_strides_attr)
    sparse_strides.push_back(stride.getSExtValue());

  CalculateSlicedShapeFromSparseIndices(
      *input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(),
      end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(),
      slice_begin, slice_end, slice_stride);
  return true;
}

//===----------------------------------------------------------------------===//
// SummaryWriterOp
//===----------------------------------------------------------------------===//

ResourceHandleValueAndId SummaryWriterOp::GetResourceHandleValueAndId(
    llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
    int64_t &next_id) {
  llvm::StringRef device = GetDeviceOrEmpty(getOperation());
  return GetResourceHandleValueAndIdBase(container(), shared_name(), device,
                                         writer(), resource_handle_id_map,
                                         next_id);
}

//===----------------------------------------------------------------------===//
// TPUExecuteOp
//===----------------------------------------------------------------------===//

void TPUExecuteOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.reserve(args().size() + 1);

  // There may be some TPU Embedding ops in the computation, so this effect is
  // added conservatively.
  effects.emplace_back(MemoryEffects::Write::get(),
                       ResourceEffects::TPUEmbedding::get());

  for (Value value : args()) {
    if (value.getType()
            .cast<TensorType>()
            .getElementType()
            .isa<ResourceType>()) {
      // Conservatively mark resource handles as read and write, as without
      // analyzing TPUCompile, there is not sufficient information to determine
      // effects on resources. For the MLIR bridge, this op will never be
      // populated with resource handles and tf.TPUExecuteAndUpdateVariables is
      // used instead.
      effects.emplace_back(MemoryEffects::Read::get(), value,
                           ResourceEffects::Variable::get());
      effects.emplace_back(MemoryEffects::Write::get(), value,
                           ResourceEffects::Variable::get());
    }
  }
}

//===----------------------------------------------------------------------===//
// TPUExecuteAndUpdateVariablesOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(TPUExecuteAndUpdateVariablesOp op) {
  int num_resource_args = 0;
  for (Type arg_type : op.args().getTypes())
    if (arg_type.cast<TensorType>().getElementType().isa<ResourceType>())
      ++num_resource_args;

  auto check_attr = [&](ArrayAttr indices, llvm::StringRef name,
                        int min) -> LogicalResult {
    if (indices.size() != num_resource_args)
      return op.emitOpError()
             << "requires '" << name
             << "' to be the same size as number of resource handles in 'args' "
                "("
             << num_resource_args << "), but got " << indices.size();

    for (auto entry : llvm::enumerate(indices.getValue())) {
      auto int_attr = entry.value().cast<IntegerAttr>();
      if (int_attr.getInt() < min)
        return op.emitOpError()
               << "requires '" << name << "' to contain values of at least "
               << min << ", but got " << int_attr.getInt() << " at index "
               << entry.index();
    }

    return success();
  };

  return failure(
      failed(check_attr(op.device_var_reads_indices(),
                        /*name=*/"device_var_reads_indices", /*min=*/0)) ||
      failed(check_attr(op.device_var_updates_indices(),
                        /*name=*/"device_var_updates_indices", /*min=*/-1)));
}

void TPUExecuteAndUpdateVariablesOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.reserve(device_var_reads_indices().size() + 1);

  // There may be some TPU Embedding ops in the computation, so this effect is
  // added conservatively.
  effects.emplace_back(MemoryEffects::Write::get(),
                       ResourceEffects::TPUEmbedding::get());
  auto resource_handles = llvm::make_filter_range(args(), [](Value value) {
    return value.getType()
        .cast<TensorType>()
        .getElementType()
        .isa<ResourceType>();
  });

  for (auto &entry : llvm::enumerate(resource_handles)) {
    Value value = entry.value();
    effects.emplace_back(MemoryEffects::Read::get(), value,
                         ResourceEffects::Variable::get());
    if (device_var_updates_indices()
            .getValue()[entry.index()]
            .cast<IntegerAttr>()
            .getInt() >= 0)
      effects.emplace_back(MemoryEffects::Write::get(), value,
                           ResourceEffects::Variable::get());
  }
}

//===----------------------------------------------------------------------===//
// TensorListReserveOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(TensorListReserveOp op) {
  if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
      !IsOfRankOrUnranked(op.element_shape(), 1)) {
    return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
  }

  if (!IsOfRankOrUnranked(op.num_elements(), 0)) {
    return op.emitOpError("requires num_elements operand to be 0D tensor");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TensorListElementShapeOp
//===----------------------------------------------------------------------===//

OpFoldResult TensorListElementShapeOp::fold(ArrayRef<Attribute> operands) {
  int width =
      getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
  auto variant_type =
      getElementTypeOrSelf(getOperand().getType()).cast<TF::VariantType>();
  if (variant_type.getSubtypes().empty()) return {};
  return ConvertShapeToAttr(variant_type.getSubtypes()[0], width);
}

//===----------------------------------------------------------------------===//
// TensorListStackOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(TensorListStackOp op) {
  if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
      !IsOfRankOrUnranked(op.element_shape(), 1)) {
    return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TensorScatterUpdateOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(TensorScatterUpdateOp op) {
  if (!HasRankAtLeast(op.tensor(), 1))
    return op.emitOpError(
        "requires tensor operand to have at least 1 dimension");
  if (!HasRankAtLeast(op.indices(), 1))
    return op.emitOpError(
        "requires indices operand to have at least 1 dimension");
  if (!HasRankAtLeast(op.updates(), 1))
    return op.emitOpError(
        "requires updates operand to have at least 1 dimension");

  auto tensor_ty = op.tensor().getType().dyn_cast<RankedTensorType>();
  auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
  if (!tensor_ty || !indices_ty) return success();

  int64_t num_index_dims = indices_ty.getShape().back();
  if (ShapedType::isDynamic(num_index_dims)) return success();

  if (num_index_dims > tensor_ty.getRank())
    return op.emitOpError(
        "requires tensor operand with rank greater than or equal to the "
        "indices operand's last dimensions");
  return success();
}

//===----------------------------------------------------------------------===//
// TileOp
//===----------------------------------------------------------------------===//

// Verifies that,
//
// - input has at least rank 1
// - multiples is rank 1
// - multiples.size() == input.rank()
// - input.rank() == output.rank()
// - Elements in multiples are non-negative
// - input.shape[i] * multiples[i] == output.shape[i]
//   for i in [0, input.rank() - 1]
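//
// For example, tiling a tensor<2x3xf32> input with multiples [2, 1] must
// produce a tensor<4x3xf32> output.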

static LogicalResult Verify(TileOp op) {
  auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
  auto multiples_type = op.multiples().getType().dyn_cast<RankedTensorType>();
  auto output_type = op.output().getType().dyn_cast<RankedTensorType>();

  if (multiples_type && multiples_type.getRank() != 1) {
    return op.emitOpError() << "expected multiples to be rank 1, got rank = "
                            << multiples_type.getRank();
  }

  if (input_type && multiples_type && multiples_type.hasStaticShape() &&
      (input_type.getRank() != multiples_type.getNumElements() ||
       (input_type.getRank() == 0 && multiples_type.getNumElements() == 1))) {
    return op.emitOpError()
           << "expected size of multiples equal to rank of input"
           << ", got multiples of size " << multiples_type.getNumElements()
           << ", and input of rank " << input_type.getRank();
  }

  if (input_type && output_type) {
    if (input_type.getRank() != output_type.getRank()) {
      return op.emitOpError()
             << "expected rank of input to equal to rank of output"
             << ", got input of rank " << input_type.getRank()
             << ", and output of rank " << output_type.getRank();
    }

    DenseIntElementsAttr multiples_attr;
    if (matchPattern(op.multiples(), m_Constant(&multiples_attr))) {
      for (int32_t i = 0, e = input_type.getRank(); i < e; ++i) {
        const int64_t input_dim = input_type.getDimSize(i);
        const int64_t output_dim = output_type.getDimSize(i);
        const int64_t m = multiples_attr.getValue<APInt>(i).getSExtValue();

        if (m < 0) {
          return op.emitOpError()
                 << "expected multiples to be non-negative, got "
                 << "multiples[" << i << "] = " << m;
        }

        if (!ShapedType::isDynamic(input_dim) &&
            !ShapedType::isDynamic(output_dim) && output_dim != input_dim * m) {
          return op.emitOpError()
                 << "requires input.shape[" << i << "] (" << input_dim << ")"
                 << " * " << m << " to be equal to "
                 << "output.shape[" << i << "] (" << output_dim << ")";
        }
      }
    }
  }

  return success();
}

OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
  DenseIntElementsAttr multiples_attr;
  if (matchPattern(multiples(), m_Constant(&multiples_attr))) {
    // Return the input directly when multiples are all ones, regardless of
    // what the input is.
    if (multiples_attr.isSplat() &&
        multiples_attr.getSplatValue<APInt>().getSExtValue() == 1) {
      return input();
    }
  }
  return {};
}

//===----------------------------------------------------------------------===//
// TopKV2Op
//===----------------------------------------------------------------------===//

static LogicalResult Verify(TopKV2Op op) {
  if (!HasRankAtLeast(op.input(), 1))
    return op.emitOpError(
        "requires input operand to have at least 1 dimension");

  if (!IsOfRankOrUnranked(op.k(), 0))
    return op.emitOpError("requires k operand to be 0D tensor");

  return success();
}

//===----------------------------------------------------------------------===//
// ToBoolOp
//===----------------------------------------------------------------------===//

namespace {
// If the input to ToBoolOp is a ranked tensor, then the ToBoolOp can be folded
// into an identity or an equality comparison.
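// For example, a tensor<i1> input folds to the input itself; a scalar numeric
// input is rewritten to a tf.NotEqual comparison against zero; and a
// non-scalar ranked input becomes a constant i1 that is true iff none of its
// dimensions is zero.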
class ToBoolOfRankedTensor : public OpRewritePattern<ToBoolOp> {
  using OpRewritePattern<ToBoolOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ToBoolOp op,
                                PatternRewriter &rewriter) const override {
    auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
    // If the input is an unranked tensor, we cannot rewrite.
    if (!type) return failure();

    // Expected return type of the ToBool operation: the return type of ToBool
    // is always a 0D tensor of bool type.
    auto result_type = op.getResult().getType().cast<RankedTensorType>();

    // If input is already a tensor<i1>, it can be folded into an identity.
    if (type == result_type) {
      rewriter.replaceOp(op, op.getOperand());
      return success();
    }

    if (type.getRank() == 0) {
      // If the input is a scalar tensor, the ToBool can be expanded to
      // element != 0 (for numerical values) or element == empty (for string).
      Type element_type = type.getElementType();
      Attribute zero_attr;
      if (element_type.isIntOrFloat())
        zero_attr = rewriter.getZeroAttr(type);
      else if (element_type.isa<TF::StringType>())
        zero_attr = DenseStringElementsAttr::get(type, {""});

      if (!zero_attr) return failure();

      auto zero_const = rewriter.create<TF::ConstOp>(op.getLoc(), zero_attr);
      rewriter.replaceOpWithNewOp<TF::NotEqualOp>(
          op, result_type, op.getOperand(), zero_const, false);
    } else {
      // If the input is a non-scalar ranked tensor, ToBool can be expanded
      // to numElements != 0. numElements will be 0 iff one of the dimensions
      // is zero.
      bool any_zero =
          llvm::any_of(type.getShape(), [](int64_t dim) { return dim == 0; });
      rewriter.replaceOpWithNewOp<TF::ConstOp>(
          op, result_type, DenseElementsAttr::get(result_type, {!any_zero}));
    }
    return success();
  }
};
}  // namespace

void ToBoolOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<ToBoolOfRankedTensor>(context);
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(TransposeOp op) {
  auto perm_type = op.perm().getType().dyn_cast<RankedTensorType>();
  auto x_type = op.x().getType().dyn_cast<RankedTensorType>();
  auto y_type = op.y().getType().dyn_cast<RankedTensorType>();

  if (perm_type && perm_type.getRank() != 1) {
    return op.emitOpError()
           << "expected perm to be a 1-D Tensor, got perm of rank "
           << perm_type.getRank();
  }

  if (x_type && y_type && x_type.getRank() != y_type.getRank()) {
    return op.emitOpError() << "x should be of the same rank with y, got "
                            << "x of rank " << x_type.getRank()
                            << ", and y of rank " << y_type.getRank();
  }

  if (!x_type || !y_type || !perm_type || !perm_type.hasStaticShape()) {
    return success();
  }

  if (x_type.getRank() != perm_type.getNumElements()) {
    return op.emitOpError() << "expected perm to be a 1-D Tensor of size "
                            << "equal to the rank of x, got perm of size "
                            << perm_type.getNumElements() << ", and x of rank "
                            << x_type.getRank();
  }

  DenseIntElementsAttr attr_perm;
  if (matchPattern(op.perm(), m_Constant(&attr_perm))) {
    // y.shape[i] should be equal to x.shape[perm[i]]
    // for i = [0, 1, ..., rank(x) - 1]
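    // For example, transposing x: tensor<2x3x4xf32> with perm = [2, 0, 1]
    // must produce y: tensor<4x2x3xf32>.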
    for (auto e : llvm::enumerate(attr_perm)) {
      const int64_t y_idx = e.index();
      const int64_t y_dim = y_type.getDimSize(y_idx);
      const int64_t x_idx = e.value().getSExtValue();
      const int64_t x_dim = x_type.getDimSize(x_idx);
      if (y_dim != ShapedType::kDynamicSize &&
          x_dim != ShapedType::kDynamicSize && y_dim != x_dim) {
        return op.emitOpError()
               << "requires y.shape[" << y_idx << "] (" << y_dim << ") "
               << "to be equal to x.shape[perm[" << x_idx << "]] "
               << "(" << x_dim << ")";
      }
    }
  }

  return success();
}

// TODO(jpienaar): perm could be optional too.
void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x,
                        Value perm) {
  auto x_type = x.getType().cast<TensorType>();
  // If value is unranked, then so is results.
  if (!x_type.hasRank())
    return TransposeOp::build(builder, result,
                              UnrankedTensorType::get(x_type.getElementType()),
                              x, perm);

  // TODO(jpienaar): Handle unknown perm case.

  // TODO(jpienaar): Extract utility function.
  auto etype = x_type.cast<ShapedType>().getElementType();
  DenseIntElementsAttr attr_shape;
  if (matchPattern(perm, m_Constant(&attr_shape))) {
    llvm::SmallVector<int64_t, 4> const_shape;
    if (attr_shape.isSplat()) {
      const_shape.assign(
          attr_shape.getNumElements(),
          x_type.getDimSize((*attr_shape.begin()).getSExtValue()));
    } else {
      const_shape.reserve(attr_shape.getNumElements());
      for (const auto &dim : attr_shape)
        const_shape.push_back(x_type.getDimSize(dim.getSExtValue()));
    }
    return TransposeOp::build(
        builder, result, RankedTensorType::get(const_shape, etype), x, perm);
  }
  return TransposeOp::build(builder, result, UnrankedTensorType::get(etype), x,
                            perm);
}

namespace {

OpFoldResult FoldIdentityTranspose(TransposeOp op) {
  DenseIntElementsAttr perm;
  if (!matchPattern(op.perm(), m_Constant(&perm))) return {};
  const auto elements = perm.getValues<APInt>();

  for (auto it : llvm::enumerate(elements)) {
    if (it.index() != it.value()) return {};
  }

  // TODO(jpienaar): Remove if/when we handle this more generally.
  if (op.getType() != op.x().getType()) {
    // If the types don't match then only fold if all the users are in the TF
    // dialect.
    for (auto user : op.getOperation()->getUsers())
      if (user->getDialect() != op->getDialect()) return {};
  }

  return op.x();
}

OpFoldResult FoldCancellableTranspose(TransposeOp op) {
  // Operand is a TransposeOp.
  auto transpose = dyn_cast_or_null<TF::TransposeOp>(op.x().getDefiningOp());
  if (!transpose) return {};

  // Permutations defined by constant operations.
  DenseIntElementsAttr perm0;
  DenseIntElementsAttr perm1;
  if (!matchPattern(op.perm(), m_Constant(&perm0)) ||
      !matchPattern(transpose.perm(), m_Constant(&perm1)))
    return {};

  // With permutation indices that cancel each other
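  // For example, perm0 = [2, 0, 1] composed with perm1 = [1, 2, 0] gives the
  // identity permutation, so the two transposes cancel each other out.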
  if (!AreCancellablePermutations(perm0, perm1)) return {};

  return transpose.x();
}

}  // namespace

OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
  if (auto folded = FoldIdentityTranspose(*this)) return folded;
  if (auto folded = FoldCancellableTranspose(*this)) return folded;
  return {};
}

//===----------------------------------------------------------------------===//
// TruncateDivOp
//===----------------------------------------------------------------------===//

void TruncateDivOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<TruncateDivWithSqrtDivisor>(context);
}

//===----------------------------------------------------------------------===//
// NonMaxSuppressionV3Op
//===----------------------------------------------------------------------===//

namespace {

// Canonicalize NonMaxSuppressionV3Op to NonMaxSuppressionV4Op.
class NMSV3ToNMSV4Op : public OpRewritePattern<NonMaxSuppressionV3Op> {
  using OpRewritePattern<NonMaxSuppressionV3Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(NonMaxSuppressionV3Op nms_op,
                                PatternRewriter &rewriter) const override {
    if (nms_op.getNumOperands() != 5) {
      return failure();
    }
    SmallVector<Type, 2> new_result_types;
    new_result_types.push_back(nms_op.getType());
    auto input_ty = nms_op.getType().template cast<ShapedType>();
    // Corresponds to the second result type of NMSv4.
    RankedTensorType valid_output_type =
        RankedTensorType::get({}, input_ty.getElementType());
    new_result_types.push_back(valid_output_type);

    auto nmsv4 = rewriter.create<TF::NonMaxSuppressionV4Op>(
        nms_op.getLoc(), new_result_types, nms_op.boxes(), nms_op.scores(),
        nms_op.max_output_size(), nms_op.iou_threshold(),
        nms_op.score_threshold());
    // Cannot replace the NMSv3 op with NMSv4 directly since their outputs
    // differ (v4 produces two output values whereas v3 produces only one).
    nms_op.replaceAllUsesWith(nmsv4.getResult(0));
    return success();
  }
};
}  // namespace.

void NonMaxSuppressionV3Op::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<NMSV3ToNMSV4Op>(context);
}

//===----------------------------------------------------------------------===//
// FusedBatchNormOp
//===----------------------------------------------------------------------===//

namespace {

class ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
  using OpRewritePattern<FusedBatchNormOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op,
                                PatternRewriter &rewriter) const override {
    auto new_result_types =
        llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes());
    // reserve_space_3
    new_result_types.push_back(
        UnrankedTensorType::get(FloatType::getF32(rewriter.getContext())));

    OperationState new_state(tf_fused_batch_norm_op.getLoc(),
                             TF::FusedBatchNormV3Op::getOperationName(),
                             tf_fused_batch_norm_op.getOperands(),
                             new_result_types,
                             tf_fused_batch_norm_op.getAttrs());
    Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state);

    rewriter.replaceOp(tf_fused_batch_norm_op,
                       tf_fused_batch_norm_op_v3->getResults().drop_back());
    return success();
  }
};
}  // namespace.

void FusedBatchNormOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ConvertFusedBatchNorm>(context);
}

//===----------------------------------------------------------------------===//
// UnpackOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(UnpackOp op) {
  auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
  if (!value_type) return success();

  int64_t value_rank = value_type.getRank();
  int64_t axis = op.axis();
  if (axis < -value_rank || axis >= value_rank)
    return op.emitOpError("axis attribute must be in the range of [-")
           << value_rank << ", " << value_rank << ')';

  axis = GetDimForAxis(axis, value_rank);
  int64_t dim_size = value_type.getDimSize(axis);
  if (ShapedType::isDynamic(dim_size)) return success();

  if (dim_size != op.getNumResults())
    return op.emitOpError("result count must be equal to ") << dim_size;

  return success();
}

namespace {

// Hoist coefficient-wise unary operation out of the Unpack op:
//
//   %unpacked:N = "tf.Unpack"(%0)
//   %neg0 = "tf.Neg"(%unpacked#0)
//   %neg1 = "tf.Neg"(%unpacked#1)
//   ...
//   %negN-1 = "tf.Neg"(%unpacked:N-1)
//
// Rewrite it to:
//
//   %neg = "tf.Neg"(%0)
//   %unpacked:N = "tf.Unpack"(%neg)
class HoistCwiseUnaryOutOfUnpack : public OpRewritePattern<UnpackOp> {
 public:
  explicit HoistCwiseUnaryOutOfUnpack(MLIRContext *context)
      : OpRewritePattern<UnpackOp>(context) {}
  LogicalResult matchAndRewrite(UnpackOp op,
                                PatternRewriter &rewriter) const override;
};

LogicalResult HoistCwiseUnaryOutOfUnpack::matchAndRewrite(
    UnpackOp op, PatternRewriter &rewriter) const {
  auto loc = op.getLoc();

  // First unpack user must be coeff-wise unary operation.
  Operation *first_user = *op->getUsers().begin();
  if (!first_user->hasTrait<OpTrait::TF::CwiseUnary>()) return failure();

  // All unpack users must be defined by the op of same kind.
  bool users_same_op = llvm::all_of(op->getUsers(), [&](Operation *user) {
    return user->getName() == first_user->getName();
  });
  if (!users_same_op) return failure();

  // Pass unpack operand to unary operation.
  OperationState new_unary_op_state(loc, first_user->getName().getStringRef(),
                                    op.getOperand(), op.getOperand().getType(),
                                    ArrayRef<NamedAttribute>());
  Operation *new_unary_op = rewriter.createOperation(new_unary_op_state);

  // Unpack results after applying unary operation.
  auto unpack_unary_op = rewriter.create<UnpackOp>(
      loc, op.getResultTypes(), new_unary_op->getResult(0), op.axis());

  // Bypass all users of the original unpack operation and use
  // `unpack_unary_op` results instead.
  for (auto pair : llvm::zip(op.getResults(), unpack_unary_op.getResults())) {
    OpResult old_result = std::get<0>(pair);  // result of original Unpack
    OpResult new_result = std::get<1>(pair);  // result of transformed Unpack
    for (Operation *user : llvm::make_early_inc_range(old_result.getUsers()))
      rewriter.replaceOp(user, ValueRange(new_result));
  }

  // Erase original unpack operation.
  rewriter.eraseOp(op.getOperation());

  return success();
}

}  // namespace

void UnpackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<HoistCwiseUnaryOutOfUnpack>(context);
}

//===----------------------------------------------------------------------===//
// Unsorted segment reduction ops
//===----------------------------------------------------------------------===//

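// Verifies unsorted segment reduction ops (e.g. tf.UnsortedSegmentSum):
// `num_segments` must be a scalar (and non-negative if constant), and the
// shape of `segment_ids` must be a prefix of the shape of `data`. For example,
// data of shape [4, 5] with segment_ids of shape [4] is valid.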
template <class Op>
static LogicalResult VerifyUnsortedSegmentReduction(Op op) {
  if (!HasRankAtMost(op.num_segments(), 0))
    return op.emitOpError("number of segments should be a 0-D tensor");

  auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
  auto segment_ids_type =
      op.segment_ids().getType().template dyn_cast<RankedTensorType>();
  if (data_type && segment_ids_type) {
    if (data_type.getRank() < segment_ids_type.getRank())
      return op.emitOpError(
          "requires segment ids rank to be less than or equal to data's rank");

    int index = 0;
    for (auto shape_pair :
         llvm::zip_first(segment_ids_type.getShape(), data_type.getShape())) {
      int64_t segment_id_dim = std::get<0>(shape_pair);
      int64_t data_dim = std::get<1>(shape_pair);
      if (!ShapedType::isDynamic(segment_id_dim) &&
          !ShapedType::isDynamic(data_dim) && segment_id_dim != data_dim)
        return op.emitOpError(
                   "requires segment ids shape to be a prefix of data shape, "
                   "but dimension #")
               << index << " differs: " << segment_id_dim << " vs. "
               << data_dim;
      ++index;
    }
  }

  DenseIntElementsAttr num_segments_attr;
  if (matchPattern(op.num_segments(), m_Constant(&num_segments_attr))) {
    int64_t num_segments = (*num_segments_attr.begin()).getSExtValue();
    if (num_segments < 0)
      return op.emitOpError("num of segments cannot be negative");
  }

  return success();
}

//===----------------------------------------------------------------------===//
// VarHandleOp
//===----------------------------------------------------------------------===//

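// Resource handles produced by tf.VarHandleOp are keyed by their
// (container, shared_name, device) triple: handles that agree on all three
// are assigned the same id from `resource_handle_id_map`, while new triples
// receive fresh ids via `next_id` (see GetResourceHandleValueAndIdBase).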
ResourceHandleValueAndId VarHandleOp::GetResourceHandleValueAndId(
    llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
    int64_t &next_id) {
  llvm::StringRef device = GetDeviceOrEmpty(getOperation());
  return GetResourceHandleValueAndIdBase(container(), shared_name(), device,
                                         resource(), resource_handle_id_map,
                                         next_id);
}

//===----------------------------------------------------------------------===//
// VarIsInitializedOp
//===----------------------------------------------------------------------===//

namespace {

/// Erase VarIsInitializedOp operations with no uses. This op has a (read-only)
/// side effect on resources, but it can still be deleted if it has zero uses.
struct EraseDeadVarIsInitializedOp
    : public OpRewritePattern<VarIsInitializedOp> {
  using OpRewritePattern<VarIsInitializedOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VarIsInitializedOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.use_empty()) return failure();
    rewriter.eraseOp(op);
    return success();
  }
};
}  // end anonymous namespace.

void VarIsInitializedOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<EraseDeadVarIsInitializedOp>(context);
}

//===----------------------------------------------------------------------===//
// VariableOp
//===----------------------------------------------------------------------===//

void VariableOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
  results.insert<VariableToVariableV2>(context);
}

//===----------------------------------------------------------------------===//
// VariableShapeOp
//===----------------------------------------------------------------------===//

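// The resource operand carries the variable's shape as a subtype, e.g.
// tensor<*x!tf.resource<tensor<4x8xf32>>>. The verifier below accepts at most
// one such subtype, and when a subtype with a static shape is present the op
// folds to the constant shape [4, 8].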
static LogicalResult Verify(VariableShapeOp op) {
  auto input_type = op.input().getType().cast<TensorType>();
  if (input_type.hasStaticShape() && input_type.getNumElements() != 1)
    return op.emitOpError("requires input to have one resource");

  auto resource_type = input_type.getElementType().cast<TF::ResourceType>();
  auto subtypes = resource_type.getSubtypes();
  switch (subtypes.size()) {
    case 1:
      return VerifyShapeOperandAndResult(
          op, resource_type.getSubtypes().front(), op.getType());
    case 0:
      return VerifyShapeOperandAndResult(op, Type(), op.getType());
    default:
      return op.emitOpError(
          "requires resource input type to have at most 1 subtype");
  }
}

OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {
  int width =
      getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
  auto resource_type =
      getElementTypeOrSelf(getOperand().getType()).cast<TF::ResourceType>();
  if (resource_type.getSubtypes().empty()) return {};
  return ConvertShapeToAttr(resource_type.getSubtypes()[0], width);
}

//===----------------------------------------------------------------------===//
// WhileOp
//===----------------------------------------------------------------------===//

static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
                                      TypeRange body_input,
                                      TypeRange body_result,
                                      bool shape_invariant) {
  const TypeRangeWithDesc input_type = {op->getOperandTypes(), "input"};
  const TypeRangeWithDesc result_type = {op->getResultTypes(), "result"};
  constexpr int kNumRegionTypeLists = 3;
  const std::array<TypeRangeWithDesc, kNumRegionTypeLists> region_types = {{
      {body_result, "body result"},
      {cond_input, "condition input"},
      {body_input, "body input"},
  }};

  // A pair of type lists should be cast compatible with each other if one is
  // converted to the other for a function call or assignment, or if there is
  // a common source of inputs for both. Therefore, the While op requires the
  // following pairs of type lists to be cast compatible for the tensor_cast
  // operation:
  //
  // * Operands and cond inputs to call the cond function before the
  //   first iteration.
  // * Operands and body inputs to call the body function for the first
  //   iteration if the cond function returns True or an equivalent result.
  // * Operands and results to assign cond function arguments to op results if
  //   the cond function returns False or an equivalent result. If the op is
  //   shape invariant, this does not hold as shapes can differ.
  // * All three pairs using cond inputs, body inputs and results, as the
  //   operands are a common source for all three.
  // * Body results and cond inputs to call the cond function for subsequent
  //   iterations. Similarly, body results should be compatible with body
  //   inputs and op results.
  //
  // Note that the operands and body results need not be compatible, as they
  // are never converted from one to the other, nor is there a common source
  // of tensors. The compatibility requirement is not transitive.
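  //
  // As an illustration (a hypothetical sketch): operands of type
  // tensor<4xf32> are cast compatible with cond/body inputs of type
  // tensor<?xf32>, so such a While verifies; operands of type tensor<4xi32>
  // against body inputs of type tensor<4xf32> are not cast compatible and
  // are rejected.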

  if (!shape_invariant &&
      failed(VerifyTypeRangesAreCompatible(op, input_type, result_type)))
    return failure();

  // Skip the first pair as the While op operands and body function results do
  // not need to be compatible with each other.
  for (int i = 1; i < kNumRegionTypeLists; ++i)
    if (failed(VerifyTypeRangesAreCompatible(op, input_type, region_types[i])))
      return failure();

  for (int i = 0; i < kNumRegionTypeLists; ++i)
    if (failed(VerifyTypeRangesAreCompatible(op, result_type, region_types[i])))
      return failure();

  for (int i = 0; i < kNumRegionTypeLists; ++i)
    for (int j = i + 1; j < kNumRegionTypeLists; ++j)
      if (failed(VerifyTypeRangesAreCompatible(op, region_types[i],
                                               region_types[j])))
        return failure();

  return success();
}

static LogicalResult Verify(WhileOp op) {
  auto cond_fn = op.cond_function();
  auto body_fn = op.body_function();
  if (!cond_fn) {
    return op.emitOpError("cond refers to an undefined function : ")
           << op.cond();
  }
  if (!body_fn) {
    return op.emitOpError("body refers to an undefined function : ")
           << op.body();
  }

  auto cond_fn_type = cond_fn.getType();
  auto body_fn_type = body_fn.getType();

  // Verify that the cond function has exactly one result.
  if (cond_fn_type.getNumResults() != 1)
    return op.emitOpError("requires cond function to have exactly one result");

  if (failed(VerifyWhileTypes(op, /*cond_input=*/cond_fn_type.getInputs(),
                              /*body_input=*/body_fn_type.getInputs(),
                              /*body_result=*/body_fn_type.getResults(),
                              op.shape_invariant())))
    return failure();
  return success();
}

//===----------------------------------------------------------------------===//
// WhileRegionOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(WhileRegionOp op) {
  // Verify that the condition generates a single tensor<i1> result.
  Operation *cond_yield = op.cond().front().getTerminator();
  if (cond_yield->getNumOperands() != 1)
    return op.emitOpError()
           << "condition should have a single tensor<i1> result";

  auto cond_type =
      cond_yield->getOperand(0).getType().dyn_cast<RankedTensorType>();
  if (!cond_type || !cond_type.getShape().equals({}) ||
      !cond_type.getElementType().isInteger(/*width=*/1))
    return op.emitOpError()
           << "condition should have a single tensor<i1> result";

  Operation *body_yield = op.body().front().getTerminator();
  if (failed(VerifyWhileTypes(op, /*cond_input=*/op.cond().getArgumentTypes(),
                              /*body_input=*/op.body().getArgumentTypes(),
                              /*body_result=*/body_yield->getOperandTypes(),
                              op.shape_invariant())))
    return failure();
  return success();
}

//===----------------------------------------------------------------------===//
// WhileRegionOp LoopLikeOpInterface
//===----------------------------------------------------------------------===//

Region &WhileRegionOp::getLoopBody() { return body(); }

bool WhileRegionOp::isDefinedOutsideOfLoop(Value value) {
  // If the op defining the value exists and that op is outside the scope of
  // this WhileRegion, then we can infer that the value is defined outside the
  // loop. The defining op is outside the scope of this WhileRegion if this
  // WhileRegionOp is not an ancestor of the defining op in the parent chain.
  Operation *def_op = value.getDefiningOp();
  return def_op && !getOperation()->isAncestor(def_op);
}

LogicalResult WhileRegionOp::moveOutOfLoop(
    llvm::ArrayRef<mlir::Operation *> ops) {
  // Move the hoisted values to just before the while.
  Operation *while_op = this->getOperation();
  for (auto op : ops) op->moveBefore(while_op);
  return success();
}

//===----------------------------------------------------------------------===//
// WhileRegionOp canonicalization
//===----------------------------------------------------------------------===//
namespace {
// Eliminate values that pass through the WhileRegionOp body.
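//
// For example (a hypothetical sketch), in
//
//   %0:2 = "tf.WhileRegion"(%a, %b) ({
//     ^cond(%carg0: tensor<f32>, %carg1: tensor<f32>):
//       ...
//   }, {
//     ^body(%barg0: tensor<f32>, %barg1: tensor<f32>):
//       %new = "tf.AddV2"(%barg0, %barg0) : ...
//       "tf.Yield"(%new, %barg1) : ...
//   }) : ...
//
// %barg1 is yielded unchanged, so uses of %barg1, %carg1, and result #1 can
// be replaced by the operand %b; once all three are unused, the corresponding
// operand, block arguments, and result are removed.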
struct WhileRegionEliminatePassThrough
    : public OpRewritePattern<WhileRegionOp> {
  using OpRewritePattern<WhileRegionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileRegionOp while_op,
                                PatternRewriter &rewriter) const override {
    // Replace values that simply pass through the body with extern values.
    // The block arguments of the body and the while op match, so the
    // corresponding cond argument can be easily found.
    int old_num_operands = while_op.getNumOperands();
    int new_num_operands = old_num_operands;
    auto &body_block = while_op.body().front();
    auto &cond_block = while_op.cond().front();
    auto &yield = *body_block.getTerminator();

    // Bit mask indicating which operands will be removed.
    SmallVector<bool, 16> removed_operand(old_num_operands, false);

    for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
      auto body_arg = body_block.getArgument(op_idx);
      if (body_arg == yield.getOperand(op_idx)) {
        // Replace the use of the passthrough value with the while operand
        // in the body and condition regions, as well as the while output (if
        // the types match).
        // TODO(jurahul): Use PatternRewriter API for IR modification.
        auto value = while_op.getOperand(op_idx);
        if (body_arg.getType() == value.getType())
          body_arg.replaceAllUsesWith(value);

        auto cond_arg = cond_block.getArgument(op_idx);
        if (cond_arg.getType() == value.getType())
          cond_arg.replaceAllUsesWith(value);

        auto result = while_op.getResult(op_idx);
        if (result.getType() == value.getType())
          result.replaceAllUsesWith(value);
      }

      // Now check if the operand is unused in both regions as well as the
      // result. If so, mark it for removal.
      if (body_block.getArgument(op_idx).use_empty() &&
          cond_block.getArgument(op_idx).use_empty() &&
          while_op.getResult(op_idx).use_empty()) {
        removed_operand[op_idx] = true;
        new_num_operands--;
      }
    }

    if (new_num_operands == old_num_operands) return failure();

    // Compress the operands, region arguments, and outputs.
    SmallVector<Value, 4> new_while_operands;
    SmallVector<Type, 4> new_result_types;
    new_while_operands.reserve(new_num_operands);
    new_result_types.reserve(new_num_operands);

    // Build the new operands and result types.
    int next_idx = 0;
    for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
      if (removed_operand[op_idx]) continue;
      new_while_operands.push_back(while_op.getOperand(op_idx));
      new_result_types.push_back(while_op.getResult(op_idx).getType());
      next_idx++;
    }

    // Create the new while operation.
    auto new_while_op =
        rewriter.create<WhileRegionOp>(while_op.getLoc(), new_result_types,
                                       new_while_operands, while_op.getAttrs());

    // Move the region bodies to the new while.
    rewriter.inlineRegionBefore(while_op.cond(), new_while_op.cond(),
                                new_while_op.cond().end());
    rewriter.inlineRegionBefore(while_op.body(), new_while_op.body(),
                                new_while_op.body().end());

    auto &new_cond_block = new_while_op.cond().front();
    auto &new_body_block = new_while_op.body().front();
    auto &new_yield = *new_body_block.getTerminator();

    // Build a vector of new results. Also patch up the region bodies and
    // yield.
    SmallVector<Value, 4> new_results;
    next_idx = 0;
    for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
      if (removed_operand[op_idx]) {
        new_cond_block.eraseArgument(next_idx);
        new_body_block.eraseArgument(next_idx);
        new_yield.eraseOperand(next_idx);
        new_results.push_back(nullptr);
      } else {
        new_results.push_back(new_while_op.getResult(next_idx++));
      }
    }

    rewriter.replaceOp(while_op, new_results);
    return success();
  }
};

}  // anonymous namespace

void WhileRegionOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<WhileRegionEliminatePassThrough>(context);
}

//===----------------------------------------------------------------------===//
// XdivyOp
//===----------------------------------------------------------------------===//

void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<XdivyWithSqrtDivisor>(context);
}

//===----------------------------------------------------------------------===//
// XlaBroadcastHelperOp
//===----------------------------------------------------------------------===//

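// Return type inference for tf.XlaBroadcastHelper. For example (a
// hypothetical sketch), with lhs of type tensor<3x4xf32>, rhs of type
// tensor<4xf32>, and broadcast_dims = [1], the lower-rank rhs is expanded to
// tensor<1x4xf32> while lhs is returned unchanged; with an empty
// broadcast_dims, both arguments must have equal rank or one must be a
// scalar, and both are returned unchanged.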
LogicalResult XlaBroadcastHelperOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto loc = location ? *location : mlir::UnknownLoc::get(context);
  XlaBroadcastHelperOpAdaptor op(operands, attributes);
  if (failed(op.verify(loc))) {
    return failure();
  }

  Value lhs = op.lhs();
  Value rhs = op.rhs();
  auto set_unranked_results = [&]() {
    auto unranked_lhs = UnrankedTensorType::get(getElementTypeOrSelf(lhs));
    inferredReturnTypes.push_back(unranked_lhs);
    auto unranked_rhs = UnrankedTensorType::get(getElementTypeOrSelf(rhs));
    inferredReturnTypes.push_back(unranked_rhs);
    return success();
  };

  RankedTensorType lhs_ty = lhs.getType().dyn_cast<RankedTensorType>();
  RankedTensorType rhs_ty = rhs.getType().dyn_cast<RankedTensorType>();
  if (!lhs_ty || !rhs_ty) return set_unranked_results();

  int64_t lhs_rank = lhs_ty.getRank();
  int64_t rhs_rank = rhs_ty.getRank();

  DenseIntElementsAttr dims;
  if (!matchPattern(op.broadcast_dims(), m_Constant(&dims))) {
    return set_unranked_results();
  }

  if (dims.size() == 0) {
    if (lhs_rank != rhs_rank && lhs_rank != 0 && rhs_rank != 0) {
      return emitOptionalError(
          location,
          "if broadcast_dims is empty, both arguments must have equal rank or "
          "at least one argument must be a scalar");
    }
    inferredReturnTypes.push_back(lhs_ty);
    inferredReturnTypes.push_back(rhs_ty);
    return success();
  }

  const bool broadcast_lhs = lhs_rank < rhs_rank;
  RankedTensorType min_rank_ty = broadcast_lhs ? lhs_ty : rhs_ty;
  RankedTensorType max_rank_ty = broadcast_lhs ? rhs_ty : lhs_ty;

  if (dims.size() != min_rank_ty.getRank()) {
    return emitOptionalError(
        location,
        "broadcast_dims must have size equal to the smaller argument rank");
  }

  int64_t output_rank = max_rank_ty.getRank();
  llvm::SmallVector<int64_t, 4> broadcast_shape(output_rank, 1LL);
  llvm::SmallVector<bool, 4> is_broadcasted(output_rank, false);
  for (auto item : llvm::enumerate(dims)) {
    int64_t index = item.index();
    int64_t dim = item.value().getSExtValue();
    if (dim < 0 || dim > output_rank) {
      return emitOptionalError(location, "out of range broadcast dim");
    }
    if (is_broadcasted[dim]) {
      return emitOptionalError(location, "broadcast_dims has duplicates");
    }
    broadcast_shape[dim] = min_rank_ty.getDimSize(index);
    is_broadcasted[dim] = true;
  }

  if (broadcast_lhs) {
    inferredReturnTypes.push_back(
        RankedTensorType::get(broadcast_shape, lhs_ty.getElementType()));
    inferredReturnTypes.push_back(rhs_ty);
  } else {
    inferredReturnTypes.push_back(lhs_ty);
    inferredReturnTypes.push_back(
        RankedTensorType::get(broadcast_shape, rhs_ty.getElementType()));
  }
  return success();
}

//===----------------------------------------------------------------------===//
// XlaSetDynamicDimensionSizeOp
//===----------------------------------------------------------------------===//

LogicalResult XlaSetDynamicDimensionSizeOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({operands.front().getType()});
  return success();
}

}  // namespace TF
}  // namespace mlir

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc.inc"