/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <numeric>
#include <string>
#include <tuple>
#include <type_traits>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/DialectImplementation.h"  // from @llvm-project
#include "mlir/IR/Identifier.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
#include "mlir/Parser.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"

namespace mlir {
namespace TF {

namespace {
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc"
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
}  // namespace

//===----------------------------------------------------------------------===//
// NotEqualOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(NotEqualOp op) {
  // If the inputs are allowed to have incompatible types, there is nothing to
  // verify.
  if (!op.incompatible_shape_error()) return success();

  // Otherwise, check that the inputs are broadcastable.
  return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
      op.getOperation());
}

void NotEqualOp::build(OpBuilder &builder, OperationState &result, Value x,
                       Value y, BoolAttr incompatible_shape_error) {
  auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y,
                                          incompatible_shape_error);
  return build(builder, result, result_type, x, y, incompatible_shape_error);
}

//===----------------------------------------------------------------------===//
// OneHotOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(OneHotOp op) {
  int64_t axis = op.axis();

  auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
  if (indices_ty &&
      !(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) {
    return op.emitOpError()
           << "expected axis (" << axis << ") to be -1 or between [0, "
           << indices_ty.getShape().size() << "]";
  }

  if (axis < -1) {
    return op.emitOpError() << "expected axis (" << axis
                            << ") to be -1 or between [0, rank(indices()))";
  }

  if (!IsOfRankOrUnranked(op.depth(), 0)) {
    return op.emitOpError() << "requires depth to be a scalar";
  }
  if (!IsOfRankOrUnranked(op.on_value(), 0)) {
    return op.emitOpError() << "requires on_value to be a scalar";
  }
  if (!IsOfRankOrUnranked(op.off_value(), 0)) {
    return op.emitOpError() << "requires off_value to be a scalar";
  }

  DenseIntElementsAttr depth_attr;
  if (matchPattern(op.depth(), m_Constant(&depth_attr))) {
    if (depth_attr.getType().getRank() != 0)
      return op.emitOpError() << "requires depth to be a scalar";
    int64_t depth = depth_attr.getValue<APInt>({}).getSExtValue();
    if (depth < 0) {
      return op.emitOpError() << "depth must be non-negative, got: " << depth;
    }
  }

  return success();
}

static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value,
                                    Value off_value, IntegerAttr axis) {
  int64_t axis_val = axis.getInt();
  Type element_ty = on_value.getType().cast<TensorType>().getElementType();
  auto unranked_ty = UnrankedTensorType::get(element_ty);
  if (axis_val < -1) return unranked_ty;

  auto indices_ty = indices.getType().dyn_cast<RankedTensorType>();
  if (!indices_ty) return unranked_ty;

  auto shape = llvm::to_vector<2>(indices_ty.getShape());
  if (axis_val == -1) axis_val = shape.size();

  int64_t depth_val = ShapedType::kDynamicSize;
  DenseIntElementsAttr depth_attr;
  if (matchPattern(depth, m_Constant(&depth_attr)) &&
      depth_attr.getNumElements() == 1)
    depth_val = (*depth_attr.begin()).getSExtValue();
  shape.insert(shape.begin() + axis_val, depth_val);
  return RankedTensorType::get(shape, element_ty);
}
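
// For illustration, a sketch of the inferred type (assuming a constant depth):
//
//   %result = tf.OneHot(%indices, %depth, %on, %off) {axis = -1}
//   // %indices : tensor<3xi32>, %depth = 5, %on/%off : tensor<f32>
//   // inferred result type: tensor<3x5xf32>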

void OneHotOp::build(OpBuilder &builder, OperationState &result, Value indices,
                     Value depth, Value on_value, Value off_value,
                     IntegerAttr axis) {
  build(builder, result,
        InferOneHotOpType(indices, depth, on_value, off_value, axis), indices,
        depth, on_value, off_value, axis);
}

//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(PackOp op) {
  // TODO(hinsu): Convert variadic length attributes to derived attributes.
  Operation::operand_range values = op.values();

  if (failed(VerifyTypesCompatibility(values,
                                      /*mask_one_dim=*/false,
                                      op.getOperation()))) {
    return failure();
  }

  int64_t inputs_rank = -1;
  for (Value value : values) {
    if (auto ty = value.getType().dyn_cast<RankedTensorType>()) {
      // Exit early as input types are verified to be compatible so all ranked
      // tensors have the same rank.
      inputs_rank = ty.getRank();
      break;
    }
  }
  if (inputs_rank == -1) return success();

  // The values can be packed along any of the dimensions between 0 and the
  // inputs' rank, inclusive. Also, because negative axis values wrap around,
  // the valid axis range is [-(R+1), R+1).
  int64_t range_begin = -inputs_rank - 1;  // Inclusive
  int64_t range_end = inputs_rank + 1;     // Exclusive
  int64_t axis = op.axis();
  if (axis < range_begin || axis >= range_end) {
    return op.emitError() << "attribute 'axis' should be within range ["
                          << range_begin << ", " << range_end
                          << "); actual value: " << axis;
  }

  return success();
}
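
// For example (illustrative): packing rank-2 inputs accepts any axis in
// [-3, 3), so
//
//   %0 = tf.Pack(%a, %b) {axis = 2}   // %a, %b : tensor<2x3xf32> -> OK
//   %1 = tf.Pack(%a, %b) {axis = 3}   // error: axis outside [-3, 3)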

OpFoldResult PackOp::fold(ArrayRef<Attribute> operands) {
  // Fold pack operation if it computes the input tensor shape:
  //
  //   %shape  = tf.Shape(%arg)                    // [? x ...]
  //   %dim0   = tf.StridedSlice(%shape, 0, 1, 1)  // get unknown dim0 value
  //   %pack   = tf.Pack(dim0, ...) { axis = 0 }   // [? x ...]
  //
  // where `...` are some statically known dimensions. In this case %pack can
  // be replaced with %shape. This is a common pattern in models with a dynamic
  // batch size.

  // Pack operation should pack at least two values.
  if (values().size() < 2) return {};

  // Dimensions packed along axis = 0 (pack scalars into vector).
  if (axis() != 0) return {};

  // First packed value is defined by a strided slice operation.
  auto slice_op = dyn_cast_or_null<StridedSliceOp>(values()[0].getDefiningOp());
  if (!slice_op) return {};

  // Input to the slice op is defined by a shape operation.
  auto shape_op = dyn_cast_or_null<ShapeOp>(slice_op.input().getDefiningOp());
  if (!shape_op) return {};

  // Input tensor, whose shape is reconstructed by the pack operation.
  Value tensor = shape_op.input();

  // All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing
  // a scalar value from the input vector).
  if (slice_op.begin_mask() != 0 || slice_op.ellipsis_mask() != 0 ||
      slice_op.end_mask() != 0 || slice_op.new_axis_mask() != 0 ||
      slice_op.shrink_axis_mask() != 1)
    return {};

  // Returns a value if the `value` is defined by a ConstOp with a single
  // integer element in it and has an expected rank.
  auto get_const_int = [](Value value, int expected_rank) -> Optional<int64_t> {
    auto const_op = dyn_cast_or_null<ConstOp>(value.getDefiningOp());
    if (!const_op) return None;

    auto value_attr = const_op.value().dyn_cast<DenseIntElementsAttr>();
    if (!value_attr || value_attr.getNumElements() != 1) return None;

    auto value_ty = value_attr.getType();
    if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None;

    auto splat = value_attr.getSplatValue<IntegerAttr>();
    return splat.getValue().getSExtValue();
  };

  // All other packed values are scalar constants.
  SmallVector<int64_t, 4> packed_dims;
  packed_dims.reserve(values().size() - 1);
  for (Value operand : llvm::drop_begin(values(), 1)) {
    if (auto dim = get_const_int(operand, /*expected_rank=*/0)) {
      packed_dims.push_back(*dim);
    } else {
      return {};
    }
  }

  // Slice exactly the first shape dimension:
  //   begin = [0], end = [1], strides = [1]
  auto begin = get_const_int(slice_op.begin(), /*expected_rank=*/1);
  auto end = get_const_int(slice_op.end(), /*expected_rank=*/1);
  auto strides = get_const_int(slice_op.strides(), /*expected_rank=*/1);
  if (!begin.hasValue() || !end.hasValue() || !strides.hasValue() ||
      *begin != 0 || *end != 1 || *strides != 1)
    return {};

  // First tensor dimension is dynamic.
  auto arg_ty = tensor.getType().dyn_cast<ShapedType>();
  if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 ||
      !arg_ty.isDynamicDim(0))
    return {};

  // Argument tensor rank is equal to the number of packed dimensions.
  if (arg_ty.getRank() != values().size()) return {};

  // All other dimensions are statically known and equal to packed dims.
  auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1);
  if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin()))
    return {};

  // Replace %pack with %shape.
  return slice_op.input();
}

// Convert Pack to Reshape when there is only one operand to be packed.
// For example,
//
//   %0 = tf.Pack(%input) {axis = 0} // %input : tensor<2x3xf32>
//
// can be canonicalized to
//
//   %shape = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi64>}
//   %0 = tf.Reshape(%input, %shape)
struct ConvertPackToReshape : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp pack_op,
                                PatternRewriter &rewriter) const override {
    // Check if there is only one operand to be packed.
    if (pack_op.N() != 1) {
      return failure();
    }

    // Check if input and output are static.
    auto input_ty = pack_op.getOperand(0).getType().cast<ShapedType>();
    auto output_ty = pack_op.output().getType().cast<ShapedType>();
    if (!input_ty.hasStaticShape() || !output_ty.hasStaticShape()) {
      return failure();
    }

    // Create constant shape for reshape.
    auto type =
        RankedTensorType::get(output_ty.getRank(), rewriter.getIntegerType(64));
    auto shape_attr = DenseIntElementsAttr::get(type, output_ty.getShape());
    auto shape = rewriter.create<ConstOp>(pack_op.getLoc(), shape_attr);

    // TODO(b/173622615): Remove after fixed.
    ReplaceTfOpWithNewOp<ReshapeOp>(rewriter, pack_op, output_ty,
                                    pack_op.getOperand(0), shape);
    return success();
  }
};

void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  results.insert<ConvertPackToReshape>(context);
}

//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//

LogicalResult PadOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
  // Paddings must be defined by a constant operation.
  auto paddings_op = dyn_cast_or_null<TF::ConstOp>(paddings().getDefiningOp());
  if (!paddings_op) return failure();

  auto paddings_value = paddings_op.value().dyn_cast<DenseElementsAttr>();
  if (!paddings_value ||
      paddings_value.getNumElements() != permutation.size() * 2)
    return failure();

  SmallVector<int32_t, 8> shuffled_paddings(paddings_value.getNumElements());
  for (auto index_pair : llvm::enumerate(paddings_value.getIntValues())) {
    size_t outer_idx = index_pair.index() / 2;
    size_t inner_idx = index_pair.index() % 2;

    shuffled_paddings[permutation[outer_idx] * 2 + inner_idx] =
        index_pair.value().getSExtValue();
  }

  // Add a constant operation with the new paddings.
  OpBuilder builder(getOperation());
  auto type = mlir::RankedTensorType::get(paddings_value.getType().getShape(),
                                          builder.getIntegerType(32));
  auto values = mlir::DenseIntElementsAttr::get(type, shuffled_paddings);
  auto shuffled_paddings_op = builder.create<TF::ConstOp>(getLoc(), values);

  // Use the new paddings.
  setOperand(1, shuffled_paddings_op);

  // Change the result type.
  getResult().setType(ShuffleRankedTensorType(getResult().getType(),
                                              ReversePermutation(permutation)));

  return success();
}
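
// For illustration, a sketch of how the paddings rows are shuffled (assuming
// an NHWC-to-NCHW operand permutation [0, 3, 1, 2]):
//
//   paddings          = [[a0, a1], [b0, b1], [c0, c1], [d0, d1]]
//   shuffled_paddings = [[a0, a1], [c0, c1], [d0, d1], [b0, b1]]
//
// i.e. row i of the original paddings moves to row permutation[i].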

//===----------------------------------------------------------------------===//
// ParseExampleV2Op
//===----------------------------------------------------------------------===//

static LogicalResult Verify(ParseExampleV2Op op) {
  // NOTE(mrry): This validates properties of an op that would previously be
  // validated by the TensorFlow OpDef type checker. In addition to these
  // checks, the shape inference function for ParseExampleV2 validates the
  // consistency of the argument and result types.

  // Validate dense variadic input and output lengths.
  // NOTE(mrry): The Tdense attr is derived from dense_defaults, so we
  // do not need to validate dense_defaults.
  auto dense_types_count =
      std::distance(op.Tdense().begin(), op.Tdense().end());
  auto dense_values_count =
      std::distance(op.dense_values().begin(), op.dense_values().end());
  if (dense_values_count != dense_types_count) {
    return op.emitError() << "output 'dense_values' should have same length "
                          << "as attribute 'Tdense'";
  }

  // Validate sparse variadic output lengths.
  // NOTE(mrry): The sparse_types attr is derived from sparse_values, so we
  // do not need to validate sparse_values.
  auto sparse_types_count =
      std::distance(op.sparse_types().begin(), op.sparse_types().end());
  if (op.num_sparse() != sparse_types_count) {
    return op.emitError() << "attribute 'num_sparse' should be the same as "
                          << "the length of attribute 'sparse_types'";
  }
  if (op.sparse_indices().size() != sparse_types_count) {
    return op.emitError() << "output 'sparse_indices' should have same length "
                          << "as attribute 'sparse_types'";
  }
  if (op.sparse_shapes().size() != sparse_types_count) {
    return op.emitError() << "output 'sparse_shapes' should have same length "
                          << "as attribute 'sparse_types'";
  }

  // Validate ragged variadic output lengths.
  auto ragged_value_types_count = std::distance(op.ragged_value_types().begin(),
                                                op.ragged_value_types().end());
  auto ragged_split_types_count = std::distance(op.ragged_split_types().begin(),
                                                op.ragged_split_types().end());
  if (ragged_value_types_count != ragged_split_types_count) {
    return op.emitError() << "attribute 'ragged_value_types' should have same "
                          << "length as attribute 'ragged_split_types'";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// PartitionedCallOp
//===----------------------------------------------------------------------===//

template <class OpClass>
static LogicalResult VerifyPartitionedCall(OpClass op) {
  auto module = op->template getParentOfType<ModuleOp>();
  SymbolRefAttr func = op->getAttr("f").template cast<SymbolRefAttr>();

  auto function =
      dyn_cast_or_null<FuncOp>(SymbolTable::lookupSymbolIn(module, func));

  if (!function) {
    return op.emitError("'f' attribute refers to an undefined function: ")
           << func;
  }

  FunctionType function_ty = function.getType();
  int func_arg_count = function_ty.getNumInputs();
  int arg_count = op.args().size();

  if (arg_count != func_arg_count) {
    return op.emitError() << "argument count mismatch: 'args' has " << arg_count
                          << " arguments, but '" << func << "' expects "
                          << func_arg_count;
  }

  return success();
}

//===----------------------------------------------------------------------===//
// PowOp
//===----------------------------------------------------------------------===//

OpFoldResult PowOp::fold(ArrayRef<Attribute> operands) {
  auto constant_y = operands[1].dyn_cast_or_null<DenseFPElementsAttr>();
  if (constant_y && constant_y.isSplat()) {
    APFloat y_value = constant_y.getSplatValue<APFloat>();
    auto output_type = getType().cast<ShapedType>();
    if (y_value.isZero() && output_type.hasStaticShape()) {
      return DenseElementsAttr::get(
          output_type,
          FloatAttr::get(output_type.getElementType(), /*value=*/1.0));
    }
    if (y_value.isExactlyValue(1.0)) {
      return x();
    }
  }
  return {};
}
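
// For illustration: with a splat constant exponent,
//
//   tf.Pow(%x, 0.0) folds to a constant tensor of ones (static shape only),
//   tf.Pow(%x, 1.0) folds to %x itself.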

//===----------------------------------------------------------------------===//
// QuantizeAndDequantizeV2Op
//===----------------------------------------------------------------------===//

void QuantizeAndDequantizeV2Op::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<QuantizeAndDequantizeV2ToQuantizeAndDequantizeV4>(context);
}

//===----------------------------------------------------------------------===//
// QrOp
//===----------------------------------------------------------------------===//

// Verifies that,
//
// * Input type, if ranked, must have at least 2 dimensions and at most
//   INT32_MAX dimensions.
//
static LogicalResult Verify(QrOp op) {
  auto ttype = op.input().getType().cast<TensorType>();
  if (!ttype.hasRank()) return success();
  if (!HasRankAtLeast(op.input(), 2))
    return op.emitOpError(
        "requires ranked input tensor to be of rank 2 or more");
  if (!HasRankAtMost(op.input(), std::numeric_limits<int32_t>::max()))
    return op.emitOpError(
        "requires ranked input tensor to be of rank INT32_MAX or less");

  return success();
}

//===----------------------------------------------------------------------===//
// ReadVariableOp
//===----------------------------------------------------------------------===//

void ReadVariableOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ReadVariableOfCast>(context);
}

//===----------------------------------------------------------------------===//
// RandomUniformOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(RandomUniformOp op) {
  if (!IsOfRankOrUnranked(op.shape(), 1))
    return op.emitOpError("shape must be 1D tensor");
  return success();
}

//===----------------------------------------------------------------------===//
// RangeOp
//===----------------------------------------------------------------------===//

namespace {

// Computes the length of a range (1-D) tensor given `start`, `limit`, `delta`.
// Template parameter `FloatOrInt` must be a standard C integer or
// floating-point type.
template <typename FloatOrInt>
int GetLengthOfRange(FloatOrInt start, FloatOrInt limit, FloatOrInt delta) {
  // Refer to the implementation in
  // tensorflow/lite/kernels/range.cc.
  FloatOrInt diff = limit - start;
  if (std::is_integral<FloatOrInt>::value) {
    return ((std::abs(diff) + std::abs(delta) - 1) / std::abs(delta));
  }
  return std::ceil(std::abs(diff / delta));
}
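
// For example: GetLengthOfRange(0, 10, 3) == 4, since the integer branch
// computes ceil(10 / 3) as (10 + 3 - 1) / 3 and the range values are
// 0, 3, 6, 9.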

// Builds a constant range tensor of `result_elem_type` elements.
// Template parameter `FloatOrIntAttr` must be mlir::IntegerAttr or
// mlir::FloatAttr.
template <typename FloatOrIntAttr>
DenseElementsAttr BuildConstRangeTensor(Type result_elem_type, int num_elements,
                                        FloatOrIntAttr start_attr,
                                        FloatOrIntAttr delta_attr) {
  using ValueType = typename FloatOrIntAttr::ValueType;  // APInt or APFloat
  ValueType start = start_attr.getValue();
  ValueType delta = delta_attr.getValue();

  SmallVector<ValueType, 16> new_values;
  new_values.reserve(num_elements);
  ValueType new_value = start;
  for (int i = 0; i < num_elements; ++i) {
    new_values.push_back(new_value);
    new_value = new_value + delta;
  }
  // Result is always a 1-D tensor.
  auto new_result_type =
      RankedTensorType::get({num_elements}, result_elem_type);
  return DenseElementsAttr::get(new_result_type, new_values);
}
}  // namespace

void RangeOp::build(OpBuilder &builder, OperationState &result, Value start,
                    Value limit, Value delta) {
  assert(start.getType() == limit.getType());
  assert(start.getType() == delta.getType());
  DenseIntElementsAttr start_val;
  DenseIntElementsAttr limit_val;
  DenseIntElementsAttr delta_val;
  if (matchPattern(start, m_Constant(&start_val)) &&
      matchPattern(limit, m_Constant(&limit_val)) &&
      matchPattern(delta, m_Constant(&delta_val))) {
    auto size = llvm::APIntOps::RoundingSDiv(
        *limit_val.begin() - *start_val.begin(), *delta_val.begin(),
        llvm::APInt::Rounding::DOWN);
    return RangeOp::build(
        builder, result,
        RankedTensorType::get(
            size.getSExtValue(),
            start.getType().cast<TensorType>().getElementType()),
        start, limit, delta);
  }
  return RangeOp::build(
      builder, result,
      RankedTensorType::get(
          {-1}, start.getType().cast<TensorType>().getElementType()),
      start, limit, delta);
}

OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 3);
  auto start_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
  auto limit_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
  auto delta_tensor = operands[2].dyn_cast_or_null<ElementsAttr>();
  if (!(start_tensor && limit_tensor && delta_tensor)) return nullptr;

  // Operands should all be scalars.
  assert(start_tensor.getType().getRank() == 0 &&
         limit_tensor.getType().getRank() == 0 &&
         delta_tensor.getType().getRank() == 0);
  Type elem_type = getType().cast<ShapedType>().getElementType();
  if (elem_type.isSignlessInteger() || elem_type.isUnsignedInteger()) {
    auto start_attr = start_tensor.getValue<IntegerAttr>({});
    auto limit_attr = limit_tensor.getValue<IntegerAttr>({});
    auto delta_attr = delta_tensor.getValue<IntegerAttr>({});
    int num_elements;
    if (elem_type.isUnsignedInteger()) {
      uint64_t start = start_attr.getUInt();
      uint64_t limit = limit_attr.getUInt();
      uint64_t delta = delta_attr.getUInt();
      assert(start <= (uint64_t)INT_MAX);
      assert(limit <= (uint64_t)INT_MAX);
      assert(delta <= (uint64_t)INT_MAX);
      num_elements =
          GetLengthOfRange(static_cast<int>(start), static_cast<int>(limit),
                           static_cast<int>(delta));
    } else {
      num_elements = GetLengthOfRange(start_attr.getInt(), limit_attr.getInt(),
                                      delta_attr.getInt());
    }
    return BuildConstRangeTensor(elem_type, num_elements, start_attr,
                                 delta_attr);
  } else if (elem_type.isa<FloatType>()) {
    auto start_attr = start_tensor.getValue<FloatAttr>({});
    auto limit_attr = limit_tensor.getValue<FloatAttr>({});
    auto delta_attr = delta_tensor.getValue<FloatAttr>({});
    const int num_elements = GetLengthOfRange(start_attr.getValueAsDouble(),
                                              limit_attr.getValueAsDouble(),
                                              delta_attr.getValueAsDouble());
    return BuildConstRangeTensor(elem_type, num_elements, start_attr,
                                 delta_attr);
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

void RankOp::build(OpBuilder &builder, OperationState &result, Value input) {
  return RankOp::build(builder, result,
                       RankedTensorType::get({}, builder.getIntegerType(32)),
                       input);
}

// This folds RankOp to a constant when the input is a ranked tensor.
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
  auto type = input().getType();
  auto ranked_type = type.dyn_cast<RankedTensorType>();
  if (!ranked_type) return {};

  // DenseIntElementsAttr::get requires the output type be ranked with static
  // shape.
  auto output_type = getType().dyn_cast<RankedTensorType>();
  if (!output_type || !output_type.hasStaticShape()) return {};

  int32_t rank = ranked_type.getRank();
  return DenseIntElementsAttr::get(output_type, rank);
}
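
// For example (illustrative):
//
//   %rank = tf.Rank(%arg)   // %arg : tensor<2x3x4xf32>
//
// folds to a constant dense<3> : tensor<i32>.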

//===----------------------------------------------------------------------===//
// RealDivOp
//===----------------------------------------------------------------------===//

void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  results.insert<RealDivWithSqrtDivisor, RealDivWithConstDivisor>(context);
}

OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) {
  return IdentityArithmeticOpFolder<RealDivOp>(*this, operands);
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

namespace {
using ReshapeErrorHandler =
    llvm::function_ref<LogicalResult(const llvm::Twine &)>;

LogicalResult GetReshapeOutputType(Value tensor, Value shape,
                                   ReshapeErrorHandler error_handler,
                                   TensorType &output_ty) {
  auto tensor_ty = tensor.getType().cast<TensorType>();
  auto element_ty = tensor_ty.getElementType();
  output_ty = UnrankedTensorType::get(element_ty);

  auto shape_ty = shape.getType().dyn_cast<RankedTensorType>();
  if (!shape_ty) return success();
  if (shape_ty.getRank() != 1)
    return error_handler(llvm::formatv(
        "requires 'shape' to be rank 1, but got {0}", shape_ty.getRank()));

  DenseIntElementsAttr shape_attr;
  if (!matchPattern(shape, m_Constant(&shape_attr))) {
    // If only the shape of `shape` is known, return a ranked but dynamic
    // output shape.
    if (shape_ty.hasStaticShape()) {
      llvm::SmallVector<int64_t, 8> dynamic_shape(shape_ty.getDimSize(0),
                                                  ShapedType::kDynamicSize);
      output_ty = RankedTensorType::get(dynamic_shape, element_ty);
    }
    return success();
  }

  // Detect if the reshape output shape is folded.
  bool shape_ty_zero_dim = false;
  int unknown_index = -1;
  // The product of the constant shape argument, excluding the unknown
  // dimension.
  int64_t shape_ty_size = 1;
  llvm::SmallVector<int64_t, 8> output_ty_shape;
  output_ty_shape.reserve(shape_attr.getNumElements());
  for (const auto &dim : llvm::enumerate(shape_attr.getIntValues())) {
    const int64_t size = dim.value().getSExtValue();
    if (size == ShapedType::kDynamicSize) {
      if (unknown_index != -1)
        return error_handler(llvm::formatv(
            "requires 'shape' to have at most one dynamic dimension, but got "
            "multiple dynamic dimensions at indices {0} and {1}",
            unknown_index, dim.index()));

      unknown_index = dim.index();
    } else if (size == 0) {
      shape_ty_zero_dim = true;
    } else if (size > 0) {
      shape_ty_size *= size;
    } else {
      return error_handler(
          llvm::formatv("requires 'shape' to have dimensions greater than -1, "
                        "but got {0} at index {1}",
                        size, dim.index()));
    }
    output_ty_shape.push_back(size);
  }

  if (!tensor_ty.hasStaticShape()) {
    output_ty = RankedTensorType::get(output_ty_shape, element_ty);
    return success();
  }

  // Compute the value of the unknown dimension.
  if (unknown_index != -1) {
    // Compute the number of elements in the tensor shape.
    int64_t tensor_ty_size = 1;
    bool tensor_ty_zero_dim = false;
    for (const auto &dim : tensor_ty.getShape()) {
      if (dim > 0 || !shape_ty_zero_dim) {
        tensor_ty_size *= dim;
      } else {
        tensor_ty_zero_dim = true;
      }
    }

    const int64_t missing_dim = tensor_ty_size / shape_ty_size;
    if (!tensor_ty_zero_dim && shape_ty_size * missing_dim != tensor_ty_size)
      return error_handler(
          llvm::formatv("requires 'tensor' number of elements be a multiple of "
                        "{0}, but got {1}",
                        shape_ty_size, tensor_ty_size));

    // Set the unknown dimension such that the total number of elements remains
    // constant.
    output_ty_shape[unknown_index] = missing_dim;
  }

  output_ty = RankedTensorType::get(output_ty_shape, element_ty);

  return success();
}
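
// For example (illustrative): reshaping a tensor<2x6xf32> with the constant
// shape [3, -1] infers the unknown dimension as 12 / 3 = 4, so the output
// type is tensor<3x4xf32>.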
}  // namespace

static LogicalResult Verify(ReshapeOp op) {
  auto error_handler = [&op](const llvm::Twine &message) -> LogicalResult {
    return op.emitOpError() << message;
  };
  TensorType expected_ty;
  if (failed(GetReshapeOutputType(op.tensor(), op.shape(), error_handler,
                                  expected_ty)))
    return failure();

  auto output_ty = op.getType().dyn_cast<RankedTensorType>();
  if (!output_ty) return success();
  auto tensor_ty = op.tensor().getType().cast<TensorType>();
  if (output_ty.hasStaticShape() && tensor_ty.hasStaticShape()) {
    const int64_t output_ty_size = output_ty.getNumElements();
    const int64_t tensor_ty_size = tensor_ty.getNumElements();
    if (tensor_ty_size != output_ty_size)
      return op.emitOpError() << "requires 'output' number of elements to "
                                 "match 'tensor' number of elements, but got "
                              << output_ty_size << " and " << tensor_ty_size;
  }

  if (!AreCastCompatible({output_ty, expected_ty}))
    return op.emitOpError()
           << "requires 'output' type " << output_ty
           << " to be cast compatible with expected type " << expected_ty;

  return success();
}

// Currently there are use cases that rely on partial evaluation of the `shape`
// operand, so InferTypeOpInterface is not used (along with the generated
// builder of the same signature).
void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor,
                      Value shape) {
  auto error_handler = [&result](const llvm::Twine &message) {
    return mlir::emitError(result.location) << message;
  };
  TensorType output_ty;
  if (failed(GetReshapeOutputType(tensor, shape, error_handler, output_ty)))
    return;

  return ReshapeOp::build(builder, result, output_ty, tensor, shape);
}

void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  results.insert<RedundantReshape, ReshapeToSelfShape>(context);
}

OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
  Value tensor = this->tensor();

  // Fold reshape if operand and result types are the same and all dimensions
  // are statically known (no-op reshape).
  auto result_ty = getType().dyn_cast<ShapedType>();
  if (result_ty && result_ty.hasStaticShape() &&
      result_ty == tensor.getType()) {
    return tensor;
  }

  return {};
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

// Verifies a few extra requirements on SelectOp:
// (1) `then` and `else` must have the same shape
// (2) At least one of the following must be true:
//     (a) `cond` has the same rank as `then` and `else`
//     (b) `cond` is a scalar
//     (c) `cond` is a vector AND `then` and `else` are non-scalar with their
//         first dimension equal to the size of `cond`.
static LogicalResult Verify(SelectOp op) {
  auto then_tensor = op.t().getType().cast<TensorType>();
  auto else_tensor = op.e().getType().cast<TensorType>();
  // Check (1).
  if (!AreCastCompatible({then_tensor, else_tensor}))
    return op.emitOpError() << "requires t and e have compatible shapes";

  // Get the data rank (if it exists).
  int data_rank;
  // If data is unranked or data_rank is 0, this will remain -2. Otherwise it
  // refers to the first dimension of then and/or else.
  int data_first_dim = -2;
  bool then_has_rank = then_tensor.hasRank();
  bool else_has_rank = else_tensor.hasRank();
  if (then_has_rank && else_has_rank) {
    data_rank = then_tensor.getRank();
    if (then_tensor.getRank() > 0)
      data_first_dim = then_tensor.getShape().front();
    if (else_tensor.getRank() > 0)
      data_first_dim = std::max(
          static_cast<int>(else_tensor.getShape().front()), data_first_dim);
  } else if (then_has_rank) {
    data_rank = then_tensor.getRank();
    if (then_tensor.getRank() > 0)
      data_first_dim = then_tensor.getShape().front();
  } else if (else_has_rank) {
    data_rank = else_tensor.getRank();
    if (else_tensor.getRank() > 0)
      data_first_dim = else_tensor.getShape().front();
  } else {
    // Neither has a rank.
    return success();
  }

  auto cond_tensor = op.condition().getType().dyn_cast<RankedTensorType>();
  if (!cond_tensor) return success();
  auto cond_rank = cond_tensor.getRank();
  // Check (2a) and (2b).
  if (cond_rank == 0 || cond_rank == data_rank) return success();
  // Check (2c).
  if (cond_rank == 1) {
    auto cond_shape = cond_tensor.getShape().front();
    if (data_rank == 0) {
      return op.emitOpError()
             << "requires that t and e are nonscalar when pred is a vector";
    }
    // We know the `data` tensor has a rank of at least 1.
    if (data_first_dim != -1 && cond_shape != -1 &&
        data_first_dim != cond_shape) {
      return op.emitOpError() << "requires that, when pred is a vector, the "
                                 "shape matches the first dimension of t and e";
    }
    return success();
  }
  // None of (2a,b,c) were true; fail.
  return op.emitOpError() << "requires that pred is a scalar OR has the same "
                             "rank as t and e OR is a vector";
}
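
// For example (illustrative): case (2c) accepts
//
//   tf.Select(%pred, %t, %e)
//   // %pred : tensor<2xi1>, %t, %e : tensor<2x3xf32>
//
// because pred is a vector whose size matches the first dimension of t and e.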

//===----------------------------------------------------------------------===//
// SelectV2Op
//===----------------------------------------------------------------------===//

static Type InferSelectV2OpType(Value condition, Value e, Value t) {
  Type element_ty = e.getType().cast<TensorType>().getElementType();
  auto unranked_ty = UnrankedTensorType::get(element_ty);

  Type broadcasted_ty =
      OpTrait::util::getBroadcastedType(e.getType(), t.getType());
  if (!broadcasted_ty) return unranked_ty;

  auto cond_ranked_ty = condition.getType().dyn_cast<RankedTensorType>();
  auto broadcasted_ranked_ty = broadcasted_ty.dyn_cast<RankedTensorType>();
  if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty;

  // Explicitly compute the broadcasted output shape, as the element type of
  // the condition may not be the same as the broadcasted type's element type.
  SmallVector<int64_t, 4> result_shape;
  if (!OpTrait::util::getBroadcastedShape(cond_ranked_ty.getShape(),
                                          broadcasted_ranked_ty.getShape(),
                                          result_shape))
    return unranked_ty;
  return RankedTensorType::get(result_shape, element_ty);
}
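
// For example (illustrative): with %cond : tensor<2x1xi1> and
// %e, %t : tensor<1x3xf32>, the inferred result type is tensor<2x3xf32>,
// the broadcast of all three operand shapes.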

void SelectV2Op::build(OpBuilder &builder, OperationState &result,
                       Value condition, Value e, Value t) {
  build(builder, result, InferSelectV2OpType(condition, e, t), condition, e, t);
}

//===----------------------------------------------------------------------===//
// ShapeOp
//===----------------------------------------------------------------------===//

namespace {
// Validates Shape/ShapeN/VariableShape operand and associated result types.
LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
                                          Type result_type,
                                          int variadic_idx = -1) {
  std::string variadic_idx_str =
      variadic_idx < 0 ? "" : llvm::formatv(" #{0}", variadic_idx).str();

  auto result_ranked_type = result_type.dyn_cast<RankedTensorType>();
  if (!result_ranked_type) return success();
  if (result_ranked_type.getShape().size() != 1)
    return op->emitOpError("requires 1D type for result") << variadic_idx_str;

  auto operand_ranked_type = operand_type.dyn_cast_or_null<RankedTensorType>();
  if (operand_ranked_type) {
    // The operand is a ranked tensor.
    if (result_ranked_type.hasStaticShape() &&
        !operand_ranked_type.getShape().empty() &&
        result_ranked_type.getDimSize(0) !=
            operand_ranked_type.getShape().size())
      return op->emitOpError("requires dimension size of result")
             << variadic_idx_str << " to match rank of operand"
             << variadic_idx_str;
  } else if (result_ranked_type.hasStaticShape()) {
    // The operand is an unranked tensor; print a warning if the result type
    // is static.
    // Note: We do not handle this situation as an error, this would be too
    // restrictive due to incompleteness of shape inference at this point.
    op->emitWarning("has static shape result")
        << variadic_idx_str << " for unranked operand" << variadic_idx_str;
  }

  Type element_type = result_ranked_type.getElementType();
  if (!element_type.isSignlessInteger(32) &&
      !element_type.isSignlessInteger(64))
    return op->emitOpError("requires int32 or int64 return type for result")
           << variadic_idx_str;

  return success();
}
}  // anonymous namespace

static LogicalResult Verify(ShapeOp op) {
  return VerifyShapeOperandAndResult(op, op.input().getType(), op.getType());
}

// Converts the shape of the given type to an attribute if it is a ranked
// tensor type. The returned attribute has integer elements of the given width.
static Attribute ConvertShapeToAttr(Type input_ty, int out_width) {
  auto ranked_ty = input_ty.dyn_cast<RankedTensorType>();
  if (!ranked_ty || !ranked_ty.hasStaticShape()) return {};

  auto shape = ranked_ty.getShape();
  int rank = shape.size();

  SmallVector<APInt, 4> dimensions;
  dimensions.reserve(rank);
  for (int i = 0; i < rank; ++i)
    dimensions.push_back(APInt(out_width, shape[i]));

  auto result_type = RankedTensorType::get(
      {rank}, IntegerType::get(input_ty.getContext(), out_width));
  return DenseElementsAttr::get(result_type, dimensions);
}
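
// For example (illustrative): ConvertShapeToAttr(tensor<2x3xf32>, 32) yields
// dense<[2, 3]> : tensor<2xi32>, while any type with a dynamic dimension
// yields a null attribute.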

OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
  int width =
      getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
  return ConvertShapeToAttr(getOperand().getType(), width);
}

void ShapeOp::build(OpBuilder &builder, OperationState &result, Value input,
                    BoolAttr use32Bit) {
  auto rankedTensorType = input.getType().dyn_cast<RankedTensorType>();
  int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1;
  auto out_type = use32Bit.getValue() ? builder.getIntegerType(32)
                                      : builder.getIntegerType(64);
  return ShapeOp::build(builder, result,
                        RankedTensorType::get({rank}, out_type), input);
}

//===----------------------------------------------------------------------===//
// ShapeNOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(ShapeNOp op) {
  const size_t num_tensors = op.N();

  if (op.getNumOperands() != num_tensors)
    return op.emitOpError() << "requires " << num_tensors << " operand(s), got "
                            << op.getNumOperands() << " operand(s)";

  if (op.getNumResults() != num_tensors)
    return op.emitOpError() << "requires " << num_tensors << " result(s), got "
                            << op.getNumResults() << " result(s)";

  for (auto i : llvm::seq<uint64_t>(0, num_tensors)) {
    auto verification = VerifyShapeOperandAndResult(
        op, op.getOperand(i).getType(), op.getResult(i).getType(), i);
    if (failed(verification)) return verification;
  }

  return success();
}

namespace {
// Canonicalization pattern for ShapeNOp operations that don't have all static
// input shapes. Replacing output values corresponding to static input types
// may enable optimizations in users of the values.
class ShapeNPartialStaticInputShape : public OpRewritePattern<ShapeNOp> {
  using OpRewritePattern<ShapeNOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ShapeNOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() == 0) {
      rewriter.eraseOp(op);
      return success();
    }

    int width = getElementTypeOrSelf(op.getType(0)).getIntOrFloatBitWidth();

    SmallVector<Value, 4> results(op.getNumOperands());
    SmallVector<int64_t, 4> dynamic_indices;
    SmallVector<Value, 4> dynamic_inputs;
    SmallVector<Type, 4> result_types;
    for (auto e : llvm::enumerate(op.getOperands())) {
      if (Attribute result = ConvertShapeToAttr(e.value().getType(), width)) {
        results[e.index()] = rewriter.create<TF::ConstOp>(op.getLoc(), result);
      } else {
        dynamic_indices.push_back(e.index());
        dynamic_inputs.push_back(e.value());
        result_types.push_back(op.getType(e.index()));
      }
    }

    if (dynamic_inputs.size() == op.getNumOperands()) {
      // Cannot canonicalize ShapeN if all inputs are dynamic.
      return failure();
    }

    // Create a ShapeNOp for all dynamic inputs.
    if (!dynamic_inputs.empty()) {
      auto dynamic_shape_n = rewriter.create<TF::ShapeNOp>(
          op.getLoc(), result_types, dynamic_inputs);
      for (auto index_result :
           llvm::zip(dynamic_indices, dynamic_shape_n.getResults())) {
        results[std::get<0>(index_result)] = std::get<1>(index_result);
      }
    }

    rewriter.replaceOp(op, results);
    return success();
  }
};
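
// For example (illustrative): with %a : tensor<2x3xf32> (static) and
// %b : tensor<?xf32> (dynamic),
//
//   %s:2 = tf.ShapeN(%a, %b)
//
// is rewritten so that the first result becomes a tf.Const of dense<[2, 3]>
// and a smaller tf.ShapeN(%b) produces the second.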

// Canonicalize ShapeNOp to ShapeOp if there is only one operand.
class ShapeNToShape : public OpRewritePattern<ShapeNOp> {
  using OpRewritePattern<ShapeNOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ShapeNOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 1) {
      return failure();
    }
    auto shape = rewriter.create<TF::ShapeOp>(op.getLoc(), op.getType(0),
                                              op.getOperand(0));
    rewriter.replaceOp(op, {shape});
    return success();
  }
};
}  // namespace

void ShapeNOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<ShapeNToShape, ShapeNPartialStaticInputShape>(context);
}

//===----------------------------------------------------------------------===//
// SizeOp
//===----------------------------------------------------------------------===//

// Verifies that,
//
// * Input type, if it is a ranked tensor, has at most INT32_MAX dimensions.
//
static LogicalResult Verify(SizeOp op) {
  if (!HasRankAtMost(op.input(), std::numeric_limits<int32_t>::max()))
    return op.emitOpError(
        "requires ranked input tensor to be of rank INT32_MAX or less");

  // Output type needs to be scalar.
  if (!IsOfRankOrUnranked(op.output(), /*rank=*/0))
    return op.emitOpError("requires scalar output");

  return success();
}

OpFoldResult SizeOp::fold(ArrayRef<Attribute> operands) {
  ShapedType output_type = getType().cast<ShapedType>();
  if (!output_type.hasRank()) return {};
  ShapedType input_type = getOperand().getType().cast<ShapedType>();
  if (!input_type.hasStaticShape()) return {};
  int size = input_type.getNumElements();
  return DenseElementsAttr::get(
      output_type,
      IntegerAttr::get(output_type.getElementType(), /*value=*/size));
}
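
// For example (illustrative): tf.Size of a tensor<2x3x4xf32> operand folds to
// a constant dense<24>, the total number of elements.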

//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//

// Verifies that:
//
// - operands begin and size are 1D with the same number of elements.
// - if the input is a ranked tensor, the rank of the input equals the number
//   of elements in operands begin and size.
// - if begin is constant, that
//   0 <= begin[i] <= begin[i] + size[i] <= input_ty.getShape()[i]
//   and
//   size[i] == output_ty.getShape()[i]
// - if begin isn't constant but the input is a ranked tensor, that
//   size[i] <= input_ty.getShape()[i]
// - output rank is the same as input rank
//
static LogicalResult Verify(SliceOp op) {
  RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin());
  if (begin_ty && begin_ty.getRank() != 1) {
    return op.emitOpError() << "requires begin operand to be 1D tensor";
  }

  RankedTensorType size_ty = GetRankedTensorTypeForOperand(op.size());
  if (size_ty && size_ty.getRank() != 1) {
    return op.emitOpError() << "requires size operand to be 1D tensor";
  }

  if (!begin_ty || !size_ty || !begin_ty.hasStaticShape() ||
      !size_ty.hasStaticShape())
    return success();

  if (begin_ty.getNumElements() != size_ty.getNumElements()) {
    return op.emitOpError() << "requires begin and size operands to have the"
                               " same number of elements";
  }

  auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
  if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) {
    return op.emitOpError() << "requires number of elements in begin and size "
                               "are equal to input rank";
  }

  auto output_ty = op.output().getType().dyn_cast<RankedTensorType>();
  if (output_ty && input_ty && output_ty.getRank() != input_ty.getRank()) {
    return op.emitOpError()
           << "requires output to have the same rank as input, but got input "
              "rank "
           << input_ty.getRank() << " and output rank " << output_ty.getRank();
  }

  DenseIntElementsAttr begin_indices;
  if (matchPattern(op.begin(), m_Constant(&begin_indices))) {
    DenseIntElementsAttr slice_sizes;
    bool constant_slice_sizes =
        matchPattern(op.size(), m_Constant(&slice_sizes));
    int dim = 0;
    // TODO(jpienaar): Reformulate the shape verification below to not use
    // magic constants.
    for (const APInt &raw_begin_index : begin_indices.getValues<APInt>()) {
      int64_t begin_index = raw_begin_index.getSExtValue();
      int64_t input_size = input_ty ? input_ty.getShape()[dim] : -1;
      int64_t slice_size = constant_slice_sizes
                               ? slice_sizes.getValue<APInt>(dim).getSExtValue()
                               : 0;
      int64_t output_size = output_ty ? output_ty.getShape()[dim] : -1;

      if (slice_size == -1 && input_size != -1) {
        slice_size = input_size - begin_index;
      }
      if (output_size != -1 && constant_slice_sizes &&
          output_size != slice_size) {
        return op.emitOpError()
               << "requires output size to have the same size of slice, got "
                  "slice size "
               << slice_size << " and output size " << output_size;
      }
      if (begin_index < 0 ||
          (input_size != -1 && begin_index + slice_size > input_size)) {
        return op.emitOpError()
               << "requires 0 <= begin[i] <= begin[i] + size[i] <= Di";
      }
      ++dim;
    }
  } else if (input_ty) {
    // If the inputs are ranked, we can do a few more sanity checks.
    DenseIntElementsAttr slice_sizes;
    if (matchPattern(op.size(), m_Constant(&slice_sizes))) {
      auto input_shape = input_ty.getShape();
      for (int64_t i = 0; i < input_ty.getRank(); ++i) {
        int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
        int64_t input_size = input_shape[i];
        if (slice_size != -1 && input_size != -1 && slice_size > input_size) {
          return op.emitOpError() << "requires size[i] <= Di, even if begin[i] "
                                     "is unknown at compile time";
        }
      }
    }
  }

  return success();
}
1275 
1276 //===----------------------------------------------------------------------===//
1277 // SoftmaxOp
1278 //===----------------------------------------------------------------------===//
1279 
Verify(SoftmaxOp op)1280 static LogicalResult Verify(SoftmaxOp op) {
1281   if (!HasRankAtLeast(op.logits(), 1)) {
1282     return op.emitOpError("requires operand to have rank at least 1");
1283   }
1284   return success();
1285 }
1286 
1287 //===----------------------------------------------------------------------===//
1288 // SoftmaxCrossEntropyWithLogitsOp
1289 //===----------------------------------------------------------------------===//
1290 
1291 // Verifies that,
1292 //
1293 // * Input types are broadcast compatible and the broadcasted type has rank two.
1294 //
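// For example (illustrative only), features of type tensor<2x10xf32> and
// labels of type tensor<2x10xf32> broadcast to the rank-2 type
// tensor<2x10xf32> and pass this check, while a rank-3 features operand
// would be rejected.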
static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) {
  auto broadcasted_ty = OpTrait::util::getBroadcastedType(
                            op.features().getType(), op.labels().getType())
                            .dyn_cast_or_null<ShapedType>();
  if (!broadcasted_ty ||
      (broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2))
    return op.emitOpError(
        "requires features and labels to be broadcast compatible to rank two");

  return success();
}

//===----------------------------------------------------------------------===//
// SpaceToBatchNDOp
//===----------------------------------------------------------------------===//

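// Returns the number of block dimensions for SpaceToBatchND, inferred from
// whichever of the block_shape or paddings operands has a static shape, or
// -1 if neither does.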
int64_t SpaceToBatchNDBlockRank(const TensorType block_shape_type,
                                const TensorType paddings_type) {
  if (block_shape_type.hasStaticShape()) {
    return block_shape_type.getShape()[0];
  } else if (paddings_type.hasStaticShape()) {
    return paddings_type.getShape()[0];
  } else {
    return -1;
  }
}

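// As a worked example (illustrative only): for an input of shape
// [1, 4, 4, 1], block_shape = [2, 2], and paddings = [[0, 0], [0, 0]],
// block_rank is 2, each padded spatial dimension (4 + 0 + 0) is divisible
// by its block length 2, and the op produces an output of shape
// [4, 2, 2, 1].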
static LogicalResult Verify(SpaceToBatchNDOp op) {
  const auto input_type = op.input().getType().cast<TensorType>();
  const auto block_shape_type = op.block_shape().getType().cast<TensorType>();
  const auto paddings_type = op.paddings().getType().cast<TensorType>();

  // Check that block_shape has rank 1.
  if (!IsOfRankOrUnranked(op.block_shape(), 1)) {
    return op.emitOpError() << "requires rank of block_shape = 1; got "
                            << block_shape_type.getRank();
  }

  // Check that paddings has rank 2.
  if (!IsOfRankOrUnranked(op.paddings(), 2)) {
    return op.emitOpError()
           << "requires rank of paddings = 2; got " << paddings_type.getRank();
  }

  // Check that paddings.shape[1] = 2.
  if (paddings_type.hasStaticShape() && paddings_type.getShape()[1] != 2) {
    return op.emitOpError() << "requires paddings.shape[1] to be 2; got "
                            << paddings_type.getShape()[1];
  }

  // Check that block_shape and paddings have consistent ranks.
  if (block_shape_type.hasStaticShape() && paddings_type.hasStaticShape() &&
      block_shape_type.getShape()[0] != paddings_type.getShape()[0]) {
    return op.emitOpError()
           << "requires block_shape.shape[0] must equal paddings.shape[0]";
  }

  const int64_t block_rank =
      SpaceToBatchNDBlockRank(block_shape_type, paddings_type);

  // Further checks require block_rank to be known.
  if (block_rank == -1) {
    return success();
  }

  // Check that rank of input_type >= block_rank + 1.
  if (input_type.hasRank() && input_type.getRank() < 1 + block_rank) {
    return op.emitOpError() << "requires rank of input >= 1 + rank of block";
  }

  ElementsAttr block_shape_attr = nullptr;
  ElementsAttr paddings_attr = nullptr;

  // Check that block_shape[*] >= 1.
  if (matchPattern(op.block_shape(), m_Constant(&block_shape_attr))) {
    uint64_t i = 0;
    for (auto block_len : block_shape_attr.getValues<APInt>()) {
      if (block_len.getSExtValue() < 1) {
        return op.emitOpError()
               << "requires all values of block_shape to be >= 1; "
                  "failed for dimension "
               << i;
      }
      ++i;
    }
  }

  // Check that paddings[*] >= 0.
  if (matchPattern(op.paddings(), m_Constant(&paddings_attr))) {
    for (uint64_t i = 0; i < block_rank; ++i) {
      const int64_t pad_start =
          paddings_attr.getValue({i, 0}).cast<IntegerAttr>().getInt();
      const int64_t pad_end =
          paddings_attr.getValue({i, 1}).cast<IntegerAttr>().getInt();
      if (pad_start < 0 || pad_end < 0) {
        return op.emitOpError()
               << "requires all values of paddings to be >= 0; "
                  "failed for dimension "
               << i;
      }
    }
  }

  // Check that block_shape divides the padded input.
  if (input_type.hasStaticShape() && block_shape_attr && paddings_attr) {
    for (uint64_t i = 0; i < block_rank; ++i) {
      const int64_t input_len = input_type.getShape()[1 + i];
      const int64_t pad_start =
          paddings_attr.getValue({i, 0}).cast<IntegerAttr>().getInt();
      const int64_t pad_end =
          paddings_attr.getValue({i, 1}).cast<IntegerAttr>().getInt();
      const int64_t block_len =
          block_shape_attr.getValue({i}).cast<IntegerAttr>().getInt();
      if ((input_len + pad_start + pad_end) % block_len != 0) {
        return op.emitOpError()
               << "requires block_shape[i] divides "
                  "input_shape[i + 1] + paddings[i, 0] + paddings[i, 1]; "
                  "failed for i="
               << i;
      }
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// SparseSoftmaxCrossEntropyWithLogitsOp
//===----------------------------------------------------------------------===//

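// For example (illustrative only), features of type tensor<8x10xf32>
// (a batch of 8 examples with 10 classes each) pairs with labels of type
// tensor<8xi32>; mismatched batch dimensions such as tensor<8x10xf32> and
// tensor<4xi32> are rejected below.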
static LogicalResult Verify(SparseSoftmaxCrossEntropyWithLogitsOp op) {
  if (!IsOfRankOrUnranked(op.features(), 2)) {
    return op.emitOpError("requires features operand of rank two");
  }
  if (!IsOfRankOrUnranked(op.labels(), 1)) {
    return op.emitOpError("requires labels operand of rank one");
  }
  auto features_ty = op.features().getType().dyn_cast<RankedTensorType>();
  auto labels_ty = op.labels().getType().dyn_cast<RankedTensorType>();
  if (features_ty && labels_ty) {
    int64_t features_batches = features_ty.getDimSize(0);
    int64_t labels_batches = labels_ty.getDimSize(0);
    if (!ShapedType::isDynamic(features_batches) &&
        !ShapedType::isDynamic(labels_batches) &&
        features_batches != labels_batches)
      return op.emitOpError(
          "requires features and labels with matching first dimension");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// SplitOp
//===----------------------------------------------------------------------===//

// Verifies the input and split dimension operands for tf.Split/tf.SplitV.
// Writes the split dimension's index (adjusted with input rank) via `dim_index`
// if it's a constant.
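// For example (illustrative only), a constant split_dim of -1 on a rank-3
// input is accepted and reported back as dim_index = 2, since negative
// indices count from the last dimension.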
template <class Op>
LogicalResult VerifySplitInputAndSplitDim(Op op, Optional<int64_t> *dim_index) {
  *dim_index = llvm::None;

  Value split_dim = op.split_dim();
  if (auto split_dim_type = split_dim.getType().dyn_cast<RankedTensorType>())
    if (split_dim_type.getRank() != 0)
      return op.emitOpError(
          "split dimension should be an integer scalar tensor");

  // We can perform further verification if the input tensor to be split has
  // known rank and the split dimension tensor is a constant.

  auto input_type = op.value().getType().template dyn_cast<RankedTensorType>();
  if (!input_type) return success();

  int64_t input_rank = input_type.getRank();
  if (input_rank == 0)
    return op.emitOpError("cannot split scalar input tensor");

  DenseIntElementsAttr split_dim_attr;
  if (!matchPattern(split_dim, m_Constant(&split_dim_attr))) return success();

  int64_t index = (*split_dim_attr.begin()).getSExtValue();

  if (index + input_rank < 0 || index >= input_rank) {
    return op.emitOpError("split dimension must be in range [-")
           << input_rank << ", " << input_rank << ")";
  }

  if (index < 0) index += input_rank;
  *dim_index = index;

  return success();
}

static LogicalResult Verify(SplitOp op) {
  Optional<int64_t> dim_index;
  if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
  if (!dim_index) return success();

  int64_t input_dim_size =
      op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
  if (input_dim_size == ShapedType::kDynamicSize) return success();

  if (input_dim_size % op.getNumResults() != 0)
    return op.emitOpError("dimension #")
           << *dim_index << " not divisible by the number of result tensors";

  return success();
}

//===----------------------------------------------------------------------===//
// SplitVOp
//===----------------------------------------------------------------------===//

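// For example (illustrative only), splitting a dimension of size 10 with
// constant size_splits = [2, -1, 3] produces three results; the -1 entry is
// a dynamic size inferred as 10 - (2 + 3) = 5, and at most one such dynamic
// entry is allowed.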
static LogicalResult Verify(SplitVOp op) {
  auto split_sizes_type =
      op.size_splits().getType().dyn_cast<RankedTensorType>();
  if (!split_sizes_type) return success();

  if (split_sizes_type.getRank() != 1 ||
      (split_sizes_type.getDimSize(0) != ShapedType::kDynamicSize &&
       split_sizes_type.getDimSize(0) != op.getNumResults()))
    return op.emitOpError("split sizes should be a 1D tensor of ")
           << op.getNumResults() << " elements";

  Optional<int64_t> dim_index = 0;
  if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
  if (!dim_index) return success();

  int64_t input_dim_size =
      op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
  if (input_dim_size == ShapedType::kDynamicSize) return success();

  // If split sizes come from a constant, they must sum to the dimension size
  // along split_dim, and we can have no more than one dynamic dimension.
  DenseIntElementsAttr split_sizes_attr;
  if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
    return success();

  int64_t total_dim_size = 0;  // Total dimension size assigned to splits.
  llvm::Optional<int> dynamic_dim_index;

  SmallVector<int64_t, 4> split_sizes;
  split_sizes.reserve(
      split_sizes_attr.getType().cast<ShapedType>().getNumElements());

  for (auto dim : llvm::enumerate(split_sizes_attr)) {
    int64_t dim_val = dim.value().getSExtValue();
    split_sizes.push_back(dim_val);
    if (dim_val == ShapedType::kDynamicSize) {
      // We cannot have more than one dynamic dimension.
      if (dynamic_dim_index)
        return op.emitOpError(
            "cannot have more than one dynamic dimension in split sizes");
      dynamic_dim_index = dim.index();
    } else {
      total_dim_size += dim_val;
    }
  }

  if (!dynamic_dim_index && total_dim_size != input_dim_size)
    return op.emitOpError(
               "split sizes must sum up to the dimension size along split "
               "dimension, found ")
           << total_dim_size << " vs " << input_dim_size;

  if (dynamic_dim_index && total_dim_size > input_dim_size)
    return op.emitOpError(
               "split sizes must sum up to be less than or equal to the "
               "dimension size along split dimension, found ")
           << total_dim_size << " vs " << input_dim_size;

  return success();
}

//===----------------------------------------------------------------------===//
// SquareOp
//===----------------------------------------------------------------------===//

void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<SquareOfSub>(context);
}

//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//

void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
  results.insert<SubOfNeg>(context);
}

OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
  return IdentityArithmeticOpFolder<SubOp>(*this, operands);
}

//===----------------------------------------------------------------------===//
// SumOp
//===----------------------------------------------------------------------===//

void SumOp::build(OpBuilder &builder, OperationState &result, Value input,
                  Value reduction_indices, BoolAttr keep_dims) {
  Type out_ty =
      InferReductionOpType(input, reduction_indices, keep_dims, &builder);
  build(builder, result, out_ty, input, reduction_indices, keep_dims);
}

// TODO: Templatize this fold for all reduction ops.
OpFoldResult SumOp::fold(ArrayRef<Attribute> operands) {
  auto input_ty = input().getType().template dyn_cast<RankedTensorType>();
  if (!input_ty) return {};
  auto result_ty = getType().template dyn_cast<RankedTensorType>();
  if (!result_ty) return {};

  // Bypass this op if the result has the same shape and type. This can happen
  // if the input tensor has size 0 or size 1.
  if (!keep_dims() && input_ty == result_ty) {
    return input();
  }
  return {};
}

//===----------------------------------------------------------------------===//
// StridedSliceOp
//===----------------------------------------------------------------------===//

// TODO(b/154160827): Add a canonicalization pattern from tf.StridedSliceOp to
// tf.SliceOp if both of the following are true:
// - All strides have a known value equal to 1
// - No masks are set (or masks can be applied by transforming the inputs to
//   Slice)

// Verifies that,
//
// - begin, end and strides operands are 1D and they have the same number of
//   elements. Here, the number of elements should be less than 32 to support
//   32-bit mask attributes.
// - None of the strides values are zero.
// - Ellipsis mask can have at most one bit set.
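//
// For example (illustrative only), begin = [0, 1], end = [2, 3],
// strides = [1, 1] is a valid combination, while strides = [1, 0] is
// rejected because a zero stride would never advance, and
// ellipsis_mask = 0b101 is rejected because it marks two ellipses.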

template <class OpTy>
static LogicalResult VerifyStridedSliceBase(OpTy op) {
  // Expected size for operands begin, end and strides vector operands.
  int64_t expected_size = -1;

  for (Value val : {op.begin(), op.end(), op.strides()}) {
    auto operand_ty = val.getType().dyn_cast<ShapedType>();
    if (!operand_ty || !operand_ty.hasStaticShape()) {
      // TensorFlow constant ops may have non-static shape because the shape is
      // not propagated during constant folding. If the defining op for this
      // operand is a constant op, use the constant op's attribute to get the
      // actual shape.
      DenseIntElementsAttr attr;
      if (!matchPattern(val, m_Constant(&attr))) continue;
      operand_ty = attr.getType();
    }

    if (operand_ty.getRank() != 1)
      return op.emitOpError()
             << "requires begin, end and strides to be 1D tensors";

    int64_t length = operand_ty.getDimSize(0);
    if (length == -1) continue;

    if (expected_size == -1) {
      // This op uses 32-bit masks.
      if (length >= 32)
        return op.emitOpError(
            "requires begin, end and strides operands with less than 32 "
            "elements");

      expected_size = length;
    } else if (length != expected_size) {
      return op.emitOpError() << "requires begin, end and strides to have the "
                                 "same number of elements";
    }
  }

  // If strides are constants, verify that none of the elements is zero.
  DenseIntElementsAttr strides;
  if (matchPattern(op.strides(), m_Constant(&strides))) {
    if (llvm::is_contained(strides.getValues<APInt>(), 0))
      return op.emitOpError("requires non-zero strides");
  }

  // Use bit compares to ensure ellipsis_mask is 0 or a power of 2, i.e. there
  // exists no more than one ellipsis.
  uint32_t ellipsis_mask = op.ellipsis_mask();
  if (ellipsis_mask != 0 && !llvm::isPowerOf2_32(ellipsis_mask))
    return op.emitOpError("cannot have multiple ellipses");

  return success();
}

// Clamps the given `val`: returns `low` if `val` is less than `low`; returns
// `high` if `high` is less than `val`; otherwise returns `val`.
template <class T>
constexpr const T &Clamp(const T &val, const T &low, const T &high) {
  assert(!(high < low));
  return (val < low) ? low : (high < val) ? high : val;
}

// Checks if the `index` bit of `val` is set.
template <class T>
constexpr bool IsSet(const T &val, unsigned index) {
  return (val & (1 << index)) != 0;
}

// Sets the `index` bit of `val`.
template <class T>
constexpr void Set(T &val, unsigned index) {
  val |= (1 << index);
}

// Unsets the `index` bit of `val`.
template <class T>
constexpr void Unset(T &val, unsigned index) {
  val &= ~(1 << index);
}

// Copies the `src_index` bit of `src` to the `dst_index` bit of `dst`.
template <class T>
constexpr void CopyBit(const T &src, unsigned src_index, T &dst,
                       unsigned dst_index) {
  if (IsSet(src, src_index))
    Set(dst, dst_index);
  else
    Unset(dst, dst_index);
}

// The sparse spec of strided slice does not correspond to the number of
// dimensions. For example, the sparse spec for foo[..., 3:10] for foo of
// shape (2, 4, 8) would have dims = 2.
struct SparseSliceSpec {
  int64_t dims;
  int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
  const ArrayRef<int64_t> &begin;
  const ArrayRef<int64_t> &end;
  const ArrayRef<int64_t> &strides;
};

// The dense spec of strided slice is the canonicalized version of the sparse
// spec. The number of dimensions of the dense spec corresponds to the number
// of dimensions in the operand tensor.
struct DenseSliceSpec {
  int64_t dims;
  int32_t begin_mask, end_mask, shrink_axis_mask;
  SmallVectorImpl<int64_t> &begin;
  SmallVectorImpl<int64_t> &end;
  SmallVectorImpl<int64_t> &strides;
};

// Makes a sparse spec into a dense index spec. The sparse spec does not
// correspond to the number of dimensions, so this produces a dense spec that
// does.
//
// For example, suppose foo[..., 3:, 2] on foo.shape = (2, 2, 3, 4); then we
// need to produce the missing begin_mask and end_mask for the first two
// dimensions, i.e., foo[:, :, 3:, 2].
static void BuildDenseSliceSpec(const SparseSliceSpec &sparse,
                                DenseSliceSpec *dense) {
  // Build expanded dense begin, end, strides, begin_mask, end_mask, and
  // shrink_axis_mask.
  dense->begin.resize(dense->dims);
  dense->end.resize(dense->dims);
  dense->strides.resize(dense->dims);
  dense->begin_mask = 0;
  dense->end_mask = 0;
  dense->shrink_axis_mask = 0;

  // Count the number of new_axis after the ellipsis. This helps in calculating
  // the number of dimensions the ellipsis represents in the sparse spec.
  bool ellipsis_seen = false;
  int num_new_axis_after_ellipsis = 0;
  for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) {
    if (ellipsis_seen && IsSet(sparse.new_axis_mask, sparse_index))
      num_new_axis_after_ellipsis++;
    if (IsSet(sparse.ellipsis_mask, sparse_index)) ellipsis_seen = true;
  }

  int dense_index = 0;
  for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) {
    if (IsSet(sparse.new_axis_mask, sparse_index)) continue;
    if (IsSet(sparse.ellipsis_mask, sparse_index)) {
      auto next_index = std::min(dense->dims - (sparse.dims - sparse_index) +
                                     1 + num_new_axis_after_ellipsis,
                                 dense->dims);
      // Expand the ellipsis into the appropriate dense indices. From the
      // current index until next_index, all dimensions have begin and end
      // masks set and stride 1, i.e., get all elements in those dimensions.
      for (; dense_index < next_index; ++dense_index) {
        dense->begin[dense_index] = dense->end[dense_index] = 0;
        dense->strides[dense_index] = 1;
        Set(dense->begin_mask, dense_index);
        Set(dense->end_mask, dense_index);
      }
      continue;
    }
    assert(dense_index < dense->dims);
    // Copy over the sparse indices to dense indices if ellipsis_mask and
    // new_axis_mask are not set.
    dense->begin[dense_index] = sparse.begin[sparse_index];
    dense->end[dense_index] = sparse.end[sparse_index];
    dense->strides[dense_index] = sparse.strides[sparse_index];
    CopyBit(sparse.begin_mask, sparse_index, dense->begin_mask, dense_index);
    CopyBit(sparse.end_mask, sparse_index, dense->end_mask, dense_index);
    CopyBit(sparse.shrink_axis_mask, sparse_index, dense->shrink_axis_mask,
            dense_index);
    dense_index++;
  }
}

// For the given `input_shape`, calculates the sliced shape using the given
// `begin`, `end`, and `stride` ranges and `begin_mask`, `end_mask`, and
// `shrink_axis_mask` masks. Updates the result back into `input_shape`. If
// `shrink_axis_mask` is not zero, this function does not drop the
// corresponding dimensions in `input_shape`; it turns them into 1s. At the
// same time, it canonicalizes `begin`, `end`, and `strides`. The calculation
// follows tf.StridedSlice op semantics.
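//
// As a worked example (illustrative only): for input_shape = [10],
// begin = [2], end = [8], stride = [3], and no masks set, the interval
// length is 8 - 2 = 6 and the sliced size is ceil(6 / 3) = 2, so
// input_shape becomes [2], matching foo[2:8:3] on a length-10 axis.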
static void CalculateSlicedShapeFromDenseIndices(
    MutableArrayRef<int64_t> input_shape, int32_t begin_mask, int32_t end_mask,
    int32_t shrink_axis_mask, MutableArrayRef<int64_t> begin,
    MutableArrayRef<int64_t> end, MutableArrayRef<int64_t> stride) {
  assert(input_shape.size() <= 32);  // Only 32-bit masks are supported.

  // Make sure the ranges' ranks are consistent with the input.
  assert(input_shape.size() == begin.size());
  assert(input_shape.size() == end.size());
  assert(input_shape.size() == stride.size());

  for (int i = 0, e = input_shape.size(); i < e; ++i) {
    if (ShapedType::isDynamic(input_shape[i])) continue;

    int64_t dim_i = input_shape[i];
    int64_t begin_i = begin[i];
    int64_t end_i = end[i];
    int64_t stride_i = stride[i];

    // [0]: mask for begin, [1]: mask for end
    int64_t masks[] = {begin_mask & (1 << i), end_mask & (1 << i)};
    // [0]: bound for begin, [1]: bound for end
    int64_t bounds[] = {stride_i > 0 ? 0 : -1,
                        stride_i > 0 ? dim_i : dim_i - 1};

    // Canonicalizes the given range `point` (begin/end) according to the
    // current dimension. `c` means case: 0 for begin, 1 for end.
    auto canonicalize = [&](int64_t point, int c) {
      if (masks[c]) return stride_i > 0 ? bounds[c] : bounds[(c + 1) & 1];

      // Add dim as offset to negative range point.
      point = point < 0 ? dim_i + point : point;
      return Clamp(point, bounds[0], bounds[1]);
    };

    begin_i = canonicalize(begin_i, 0);
    end_i = canonicalize(end_i, 1);

    int64_t interval_len = end_i - begin_i;
    int64_t size_i = 0;
    // If the interval length is zero or has a different sign from the stride,
    // it's a degenerate case: we are slicing nothing. Otherwise, calculate the
    // sliced size.
    if (interval_len != 0 && (interval_len < 0) == (stride_i < 0))
      size_i = (interval_len / stride_i) + (interval_len % stride_i != 0);

    begin[i] = begin_i;
    if (IsSet(shrink_axis_mask, i)) {
      // Shrink this dimension: we only take the element at begin_i.
      input_shape[i] = 1;
      end[i] = begin_i + 1;
      stride[i] = 1;
    } else {
      input_shape[i] = size_i;
      end[i] = end_i;
      stride[i] = stride_i;
    }
  }
}

// For the given `input_shape`, calculates the sliced shape using the given
// `sparse_begin`, `sparse_end`, and `sparse_strides` ranges and `begin_mask`,
// `end_mask`, `ellipsis_mask`, `new_axis_mask` and `shrink_axis_mask` masks.
// Updates the result back into `input_shape`.
static void CalculateSlicedShapeFromSparseIndices(
    MutableArrayRef<int64_t> input_shape, ArrayRef<int64_t> sparse_begin,
    ArrayRef<int64_t> sparse_end, ArrayRef<int64_t> sparse_strides,
    int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask,
    int32_t new_axis_mask, int32_t shrink_axis_mask,
    SmallVectorImpl<int64_t> *begin, SmallVectorImpl<int64_t> *end,
    SmallVectorImpl<int64_t> *stride) {
  int64_t num_sparse_indices = sparse_begin.size();
  SparseSliceSpec sparse = {num_sparse_indices, begin_mask,    end_mask,
                            ellipsis_mask,      new_axis_mask, shrink_axis_mask,
                            sparse_begin,       sparse_end,    sparse_strides};

  // If no ellipsis_mask exists, then an implicit ellipsis_mask at the end is
  // inserted. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields
  // a tensor of shape [2, 8], i.e., foo[2:4] is the same as foo[2:4, ...].
  if (sparse.ellipsis_mask == 0) {
    Set(sparse.ellipsis_mask, sparse.dims);
    sparse.dims++;
  }

  int64_t dims = input_shape.size();
  DenseSliceSpec dense = {dims,
                          /*begin_mask = */ 0,
                          /*end_mask = */ 0,
                          /*shrink_axis_mask = */ 0,
                          *begin,
                          *end,
                          *stride};

  BuildDenseSliceSpec(sparse, &dense);
  CalculateSlicedShapeFromDenseIndices(input_shape, dense.begin_mask,
                                       dense.end_mask, dense.shrink_axis_mask,
                                       *begin, *end, *stride);
}

bool StridedSliceOp::GetSlicedBoundRanges(
    SmallVectorImpl<int64_t> *slice_begin, SmallVectorImpl<int64_t> *slice_end,
    SmallVectorImpl<int64_t> *slice_stride) {
  // TODO(hinsu): Support lowering for ops with dynamic begin and end values
  // when it is possible to derive indices based on mask attributes.
  DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr;
  if (!matchPattern(begin(), m_Constant(&sparse_begin_attr)) ||
      !matchPattern(end(), m_Constant(&sparse_end_attr)) ||
      !matchPattern(strides(), m_Constant(&sparse_strides_attr)))
    return false;

  auto input_ty = this->input().getType().dyn_cast<RankedTensorType>();
  if (!input_ty || !input_ty.hasStaticShape()) return false;
  auto input_shape = llvm::to_vector<4>(input_ty.getShape());

  SmallVector<int64_t, 4> sparse_begin, sparse_end, sparse_strides;

  for (const APInt &index : sparse_begin_attr)
    sparse_begin.push_back(index.getSExtValue());
  for (const APInt &index : sparse_end_attr)
    sparse_end.push_back(index.getSExtValue());
  for (const APInt &stride : sparse_strides_attr)
    sparse_strides.push_back(stride.getSExtValue());

  CalculateSlicedShapeFromSparseIndices(
      input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(),
      end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(),
      slice_begin, slice_end, slice_stride);
  return true;
}

OpFoldResult StridedSliceOp::fold(ArrayRef<Attribute> operands) {
  // Fold the StridedSlice operation if it extracts statically known
  // dimensions.
  //
  // For example,
  //
  //   %shape  = tf.Shape(%arg)                   // %arg: tensor<?x2x3x1xf32>
  //   %height = tf.StridedSlice(%shape, 1, 2, 1)
  //
  // In this case %height can be replaced with a constant 2.
  //
  // Or,
  //
  //   %shape  = tf.Shape(%arg)                   // %arg: tensor<?x2x3x1xf32>
  //   %spatial_shape = tf.StridedSlice(%shape, 1, 3, 1)
  //
  // In this case %spatial_shape can be replaced with a constant [2, 3].

  // The input to the strided slice op must be defined by a shape operation.
  auto shape_op = input().getDefiningOp<ShapeOp>();
  if (!shape_op) {
    return {};
  }

  // `begin`, `end` and `strides` should be constant in order to infer a static
  // dimension.
  DenseIntElementsAttr begin_attr, end_attr, strides_attr;
  if (!matchPattern(begin(), m_Constant(&begin_attr)) ||
      !matchPattern(end(), m_Constant(&end_attr)) ||
      !matchPattern(strides(), m_Constant(&strides_attr)) ||
      begin_attr.getNumElements() != 1 || end_attr.getNumElements() != 1 ||
      strides_attr.getNumElements() != 1) {
    return {};
  }

  // Do not fold when `new_axis_mask` is set. It's likely to break the shape
  // of the output. Typically, `new_axis_mask` is not set in this
  // canonicalization pattern.
  if (new_axis_mask() != 0) return {};

  auto tensor_ty = shape_op.input().getType().dyn_cast<RankedTensorType>();
  // Only a ranked tensor can be folded.
  if (!tensor_ty) return {};

  int64_t rank = tensor_ty.getRank();
  int64_t begin_int = begin_attr.getValue<APInt>(0).getSExtValue();
  int64_t end_int = end_attr.getValue<APInt>(0).getSExtValue();
  int64_t strides_int = strides_attr.getValue<APInt>(0).getSExtValue();

  // Canonicalize `begin` and `end` in case of a negative index.
  if (begin_int < 0) begin_int += rank;
  if (end_int < 0) end_int += rank;

  // Create `begin` and `end` from `*_mask`. Note that we don't care about
  // `new_axis_mask` as it can be inferred from `output_ty`.
  if (shrink_axis_mask() == 1) {
    // When `shrink_axis_mask` is set, the output is always a scalar, so only
    // one element is sliced.
    end_int = begin_int + 1;
  }
  if (begin_mask() == 1) {
    begin_int = (strides_int > 0) ? 0 : rank - 1;
  }
  if (end_mask() == 1) {
    end_int = (strides_int > 0) ? rank : -1;
  }
  if (ellipsis_mask() == 1) {
    begin_int = 0;
    end_int = rank;
  }

  // It's possible that `begin` and `end` are out of bounds. See
  // https://docs.python.org/3/library/stdtypes.html#common-sequence-operations.
  if (strides_int > 0) {
    begin_int = std::min(begin_int, rank);
    end_int = std::min(end_int, rank);
  } else {
    begin_int = std::min(begin_int, rank - 1);
    end_int = std::min(end_int, rank - 1);
  }

  SmallVector<int64_t, 2> sub_shape;
  // Only handle cases that have something to slice, to avoid an infinite
  // for-loop.
  if ((end_int > begin_int && strides_int > 0) ||
      (end_int < begin_int && strides_int < 0)) {
    // Extract the sub-shape only if all of those dimensions are static.
    for (int64_t i = begin_int; (strides_int > 0) ? i < end_int : i > end_int;
         i += strides_int) {
      if (tensor_ty.isDynamicDim(i)) {
        return {};
      }
      sub_shape.push_back(tensor_ty.getDimSize(i));
    }
  }

  // For unranked or dynamic output, we infer the output type to be either a
  // scalar or a vector based on `shrink_axis_mask`, because we have rejected
  // the case of `new_axis_mask` != 0.
  auto output_elt_ty = output().getType().cast<ShapedType>().getElementType();
  auto output_ty = output().getType().dyn_cast<RankedTensorType>();
  if (!output_ty || !output_ty.hasStaticShape()) {
    if (shrink_axis_mask() == 1) {
      output_ty = RankedTensorType::get({}, output_elt_ty);
    } else {
      output_ty = RankedTensorType::get(
          {static_cast<int64_t>(sub_shape.size())}, output_elt_ty);
    }
  }

  // Down-cast to 32-bit int if needed.
  if (output_elt_ty.isInteger(32)) {
    SmallVector<int32_t, 2> sub_shape_i32(sub_shape.size());
    std::transform(sub_shape.begin(), sub_shape.end(), sub_shape_i32.begin(),
                   [](int64_t d) { return static_cast<int32_t>(d); });
    return DenseIntElementsAttr::get(output_ty, sub_shape_i32);
  }
  return DenseIntElementsAttr::get(output_ty, sub_shape);
}

//===----------------------------------------------------------------------===//
// StridedSliceGradOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(StridedSliceGradOp op) {
  auto shape_type = op.shape().getType().dyn_cast<RankedTensorType>();
  if (shape_type && shape_type.getRank() != 1)
    return op.emitOpError("'shape' operand must be 1D tensor, but got ")
           << shape_type.getRank() << "D tensor";

  if (failed(VerifyStridedSliceBase(op))) return failure();

  // TODO(antiagainst): verify that the gradient op.dy()'s shape is consistent
  // with the sliced type from StridedSlice.

  return success();
}

bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges(
    SmallVectorImpl<int64_t> *input_shape,
    SmallVectorImpl<int64_t> *slice_begin, SmallVectorImpl<int64_t> *slice_end,
    SmallVectorImpl<int64_t> *slice_stride) {
  DenseIntElementsAttr shape_attr;
  DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr;
  if (!matchPattern(shape(), m_Constant(&shape_attr)) ||
      !matchPattern(begin(), m_Constant(&sparse_begin_attr)) ||
      !matchPattern(end(), m_Constant(&sparse_end_attr)) ||
      !matchPattern(strides(), m_Constant(&sparse_strides_attr)))
    return false;

  int rank = std::distance(shape_attr.begin(), shape_attr.end());

  input_shape->clear();
  input_shape->reserve(rank);
  for (const APInt &dim : shape_attr)
    input_shape->push_back(dim.getSExtValue());

  SmallVector<int64_t, 4> sparse_begin, sparse_end, sparse_strides;

  for (const APInt &index : sparse_begin_attr)
    sparse_begin.push_back(index.getSExtValue());
  for (const APInt &index : sparse_end_attr)
    sparse_end.push_back(index.getSExtValue());
  for (const APInt &stride : sparse_strides_attr)
    sparse_strides.push_back(stride.getSExtValue());

  CalculateSlicedShapeFromSparseIndices(
      *input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(),
      end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(),
      slice_begin, slice_end, slice_stride);
  return true;
}

//===----------------------------------------------------------------------===//
// SummaryWriterOp
//===----------------------------------------------------------------------===//

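// Returns the resource handle carried by the op's `writer` result, paired
// with a unique id for it (reusing the id from `resource_handle_id_map` if
// the same handle was seen before).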
llvm::SmallVector<ResourceHandleValueAndId, 4>
SummaryWriterOp::GetResourceHandleValueAndIdList(
    llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
    int64_t &next_id) {
  llvm::StringRef device = GetDeviceOrEmpty(getOperation());
  return {GetResourceHandleValueAndIdBase(container(), shared_name(), device,
                                          writer(), resource_handle_id_map,
                                          next_id)};
}

//===----------------------------------------------------------------------===//
// TPUExecuteOp
//===----------------------------------------------------------------------===//

void TPUExecuteOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.reserve(args().size() + 1);

  // There may be some TPU Embedding ops in the computation, so this effect is
  // added conservatively.
  effects.emplace_back(MemoryEffects::Write::get(),
                       ResourceEffects::TPUEmbedding::get());

  for (Value value : args()) {
    if (value.getType()
            .cast<TensorType>()
            .getElementType()
            .isa<ResourceType>()) {
      // Conservatively mark resource handles as read and write, as without
      // analyzing TPUCompile, there is not sufficient information to determine
      // effects on resources. For the MLIR bridge, this op will never be
      // populated with resource handles and tf.TPUExecuteAndUpdateVariables is
      // used instead.
      effects.emplace_back(MemoryEffects::Read::get(), value,
                           ResourceEffects::Variable::get());
      effects.emplace_back(MemoryEffects::Write::get(), value,
                           ResourceEffects::Variable::get());
    }
  }
}

//===----------------------------------------------------------------------===//
// TPUExecuteAndUpdateVariablesOp
//===----------------------------------------------------------------------===//

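// For example (illustrative only), an op whose `args` contain two resource
// handles might carry device_var_reads_indices = [0, 1] and
// device_var_updates_indices = [1, -1]: both variables are read, the first
// one's updated value is at output index 1, and -1 marks a variable that is
// not updated.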
static LogicalResult Verify(TPUExecuteAndUpdateVariablesOp op) {
  int num_resource_args = 0;
  for (Type arg_type : op.args().getTypes())
    if (arg_type.cast<TensorType>().getElementType().isa<ResourceType>())
      ++num_resource_args;

  auto check_attr = [&](ArrayAttr indices, llvm::StringRef name,
                        int min) -> LogicalResult {
    if (indices.size() != num_resource_args)
      return op.emitOpError()
             << "requires '" << name
             << "' to be the same size as number of resource handles in 'args' "
                "("
             << num_resource_args << "), but got " << indices.size();

    for (auto entry : llvm::enumerate(indices.getValue())) {
      auto int_attr = entry.value().cast<IntegerAttr>();
      if (int_attr.getInt() < min)
        return op.emitOpError()
               << "requires '" << name << "' to contain values of at least "
               << min << ", but got " << int_attr.getInt() << " at index "
               << entry.index();
    }

    return success();
  };

  return failure(
      failed(check_attr(op.device_var_reads_indices(),
                        /*name=*/"device_var_reads_indices", /*min=*/0)) ||
      failed(check_attr(op.device_var_updates_indices(),
                        /*name=*/"device_var_updates_indices", /*min=*/-1)));
}

void TPUExecuteAndUpdateVariablesOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.reserve(device_var_reads_indices().size() + 1);

  // There may be some TPU Embedding ops in the computation, so this effect is
  // added conservatively.
  effects.emplace_back(MemoryEffects::Write::get(),
                       ResourceEffects::TPUEmbedding::get());
  auto resource_handles = llvm::make_filter_range(args(), [](Value value) {
    return value.getType()
        .cast<TensorType>()
        .getElementType()
        .isa<ResourceType>();
  });

  for (auto &entry : llvm::enumerate(resource_handles)) {
    Value value = entry.value();
    effects.emplace_back(MemoryEffects::Read::get(), value,
                         ResourceEffects::Variable::get());
    if (device_var_updates_indices()
            .getValue()[entry.index()]
            .cast<IntegerAttr>()
            .getInt() >= 0)
      effects.emplace_back(MemoryEffects::Write::get(), value,
                           ResourceEffects::Variable::get());
  }
}

//===----------------------------------------------------------------------===//
// TensorListReserveOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(TensorListReserveOp op) {
  if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
      !IsOfRankOrUnranked(op.element_shape(), 1)) {
    return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
  }

  if (!IsOfRankOrUnranked(op.num_elements(), 0)) {
    return op.emitOpError("requires num_elements operand to be 0D tensor");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TensorListElementShapeOp
//===----------------------------------------------------------------------===//

OpFoldResult TensorListElementShapeOp::fold(ArrayRef<Attribute> operands) {
  int width =
      getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
  auto variant_type =
      getElementTypeOrSelf(getOperand().getType()).cast<TF::VariantType>();
  if (variant_type.getSubtypes().empty()) return {};
  return ConvertShapeToAttr(variant_type.getSubtypes()[0], width);
}

//===----------------------------------------------------------------------===//
// TensorListStackOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(TensorListStackOp op) {
  if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
      !IsOfRankOrUnranked(op.element_shape(), 1)) {
    return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TensorScatterUpdateOp
//===----------------------------------------------------------------------===//

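// For example (illustrative only), with a tensor of shape [4, 4, 4] and
// indices of shape [2, 1], the last dimension of indices (1) must not exceed
// the tensor rank (3), so the op verifies; an indices operand of shape
// [2, 5] would be rejected.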
static LogicalResult Verify(TensorScatterUpdateOp op) {
  if (!HasRankAtLeast(op.tensor(), 1))
    return op.emitOpError(
        "requires tensor operand to have at least 1 dimension");
  if (!HasRankAtLeast(op.indices(), 1))
    return op.emitOpError(
        "requires indices operand to have at least 1 dimension");
  if (!HasRankAtLeast(op.updates(), 1))
    return op.emitOpError(
        "requires updates operand to have at least 1 dimension");

  auto tensor_ty = op.tensor().getType().dyn_cast<RankedTensorType>();
  auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
  if (!tensor_ty || !indices_ty) return success();

  int64_t num_index_dims = indices_ty.getShape().back();
  if (ShapedType::isDynamic(num_index_dims)) return success();

  if (num_index_dims > tensor_ty.getRank())
    return op.emitOpError(
        "requires tensor operand with rank greater than or equal to the "
        "indices operand's last dimensions");
  return success();
}

//===----------------------------------------------------------------------===//
// TileOp
//===----------------------------------------------------------------------===//

// Verifies that,
//
// - input has at least rank 1
// - multiples is rank 1
// - multiples.size() == input.rank()
// - input.rank() == output.rank()
// - Elements in multiples are non-negative
// - input.shape[i] * multiples[i] == output.shape[i]
//   for i in [0, input.rank() - 1]
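//
// For example (illustrative only), tiling an input of shape [2, 3] with
// multiples = [2, 2] must produce an output of shape [4, 6].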

static LogicalResult Verify(TileOp op) {
  auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
  auto multiples_type = op.multiples().getType().dyn_cast<RankedTensorType>();
  auto output_type = op.output().getType().dyn_cast<RankedTensorType>();

  if (multiples_type && multiples_type.getRank() != 1) {
    return op.emitOpError() << "expected multiples to be rank 1, got rank = "
                            << multiples_type.getRank();
  }

  if (input_type && multiples_type && multiples_type.hasStaticShape() &&
      (input_type.getRank() != multiples_type.getNumElements() ||
       (input_type.getRank() == 0 && multiples_type.getNumElements() == 1))) {
    return op.emitOpError()
           << "expected size of multiples equal to rank of input"
           << ", got multiples of size " << multiples_type.getNumElements()
           << ", and input of rank " << input_type.getRank();
  }

  if (input_type && output_type) {
    if (input_type.getRank() != output_type.getRank()) {
      return op.emitOpError()
             << "expected rank of input to equal to rank of output"
             << ", got input of rank " << input_type.getRank()
             << ", and output of rank " << output_type.getRank();
    }

    DenseIntElementsAttr multiples_attr;
    if (matchPattern(op.multiples(), m_Constant(&multiples_attr))) {
      for (int32_t i = 0, e = input_type.getRank(); i < e; ++i) {
        const int64_t input_dim = input_type.getDimSize(i);
        const int64_t output_dim = output_type.getDimSize(i);
        const int64_t m = multiples_attr.getValue<APInt>(i).getSExtValue();

        if (m < 0) {
          return op.emitOpError()
                 << "expected multiples to be non-negative, got "
                 << "multiples[" << i << "] = " << m;
        }

        if (!ShapedType::isDynamic(input_dim) &&
            !ShapedType::isDynamic(output_dim) && output_dim != input_dim * m) {
          return op.emitOpError()
                 << "requires input.shape[" << i << "] (" << input_dim << ")"
                 << " * " << m << " to be equal to "
                 << "output.shape[" << i << "] (" << output_dim << ")";
        }
      }
    }
  }

  return success();
}

OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
  DenseIntElementsAttr multiples_attr;
  if (matchPattern(multiples(), m_Constant(&multiples_attr))) {
    // Return the input directly when multiples are all ones, regardless of
    // what the input is.
    if (multiples_attr.isSplat() &&
        multiples_attr.getSplatValue<APInt>().getSExtValue() == 1) {
      return input();
    }
  }
  return {};
}

//===----------------------------------------------------------------------===//
// TopKV2Op
//===----------------------------------------------------------------------===//

static LogicalResult Verify(TopKV2Op op) {
  if (!HasRankAtLeast(op.input(), 1))
    return op.emitOpError(
        "requires input operand to have at least 1 dimension");

  if (!IsOfRankOrUnranked(op.k(), 0))
    return op.emitOpError("requires k operand to be 0D tensor");

  return success();
}

//===----------------------------------------------------------------------===//
// ToBoolOp
//===----------------------------------------------------------------------===//

namespace {
// If the input to ToBoolOp is a ranked tensor, then the ToBoolOp can be folded
// into an identity or an equality comparison.
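//
// For example (illustrative only):
//   - scalar input: ToBool(%x : tensor<f32>) becomes NotEqual(%x, 0.0);
//   - non-scalar ranked input: ToBool(%x : tensor<2x3xf32>) becomes the
//     constant true, since the number of elements (6) is non-zero.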
2402 class ToBoolOfRankedTensor : public OpRewritePattern<ToBoolOp> {
2403   using OpRewritePattern<ToBoolOp>::OpRewritePattern;
matchAndRewrite(ToBoolOp op,PatternRewriter & rewriter) const2404   LogicalResult matchAndRewrite(ToBoolOp op,
2405                                 PatternRewriter &rewriter) const override {
2406     auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
2407     // If the input is an unranked tensor, cannpt rewrite.
2408     if (!type) return failure();
2409 
2410     // Expected return type of the ToBool operation. The return type of ToBool
2411     // operation is always 0D tensor of bool type.
2412     auto result_type = op.getResult().getType().cast<RankedTensorType>();
2413 
2414     // If input is already a tensor<i1>, it can be folded into an identity.
2415     if (type == result_type) {
2416       rewriter.replaceOp(op, op.getOperand());
2417       return success();
2418     }
2419 
2420     if (type.getRank() == 0) {
2421       // If the input is a scalar tensor, the ToBool can be expanded to
2422       // element != 0 (for numerical values) or element == empty (for string).
2423       Type element_type = type.getElementType();
2424       Attribute zero_attr;
2425       if (element_type.isIntOrFloat())
2426         zero_attr = rewriter.getZeroAttr(type);
2427       else if (element_type.isa<TF::StringType>())
2428         zero_attr = DenseStringElementsAttr::get(type, {""});
2429 
2430       if (!zero_attr) return failure();
2431 
2432       auto zero_const = rewriter.create<TF::ConstOp>(op.getLoc(), zero_attr);
2433       rewriter.replaceOpWithNewOp<TF::NotEqualOp>(
2434           op, result_type, op.getOperand(), zero_const, false);
2435     } else {
2436       // If the input is a non-scalar ranked tensor, ToBool can be expanded
2437       // to numElements != 0. numElements will be 0 iff one of the dimensions is
2438       // zero.
2439       bool any_zero =
2440           llvm::any_of(type.getShape(), [](int64_t dim) { return dim == 0; });
2441       rewriter.replaceOpWithNewOp<TF::ConstOp>(
2442           op, result_type, DenseElementsAttr::get(result_type, {!any_zero}));
2443     }
2444     return success();
2445   }
2446 };
}  // namespace

void ToBoolOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<ToBoolOfRankedTensor>(context);
}

LogicalResult ToBoolOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.push_back(
      RankedTensorType::get({}, IntegerType::get(context, 1)));
  return success();
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

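// As an illustration of the shape constraint checked below (y.shape[i] ==
// x.shape[perm[i]]), a transpose with perm = [2, 0, 1] maps
// x : tensor<2x3x4xf32> to y : tensor<4x2x3xf32>.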
static LogicalResult Verify(TransposeOp op) {
  auto perm_type = op.perm().getType().dyn_cast<RankedTensorType>();
  auto x_type = op.x().getType().dyn_cast<RankedTensorType>();
  auto y_type = op.y().getType().dyn_cast<RankedTensorType>();

  if (perm_type && perm_type.getRank() != 1) {
    return op.emitOpError()
           << "expected perm to be a 1-D Tensor, got perm of rank "
           << perm_type.getRank();
  }

  if (x_type && y_type && x_type.getRank() != y_type.getRank()) {
    return op.emitOpError() << "x should be of the same rank as y, got "
                            << "x of rank " << x_type.getRank()
                            << ", and y of rank " << y_type.getRank();
  }

  if (!x_type || !y_type || !perm_type || !perm_type.hasStaticShape()) {
    return success();
  }

  if (x_type.getRank() != perm_type.getNumElements()) {
    return op.emitOpError() << "expected perm to be a 1-D Tensor of size "
                            << "equal to the rank of x, got perm of size "
                            << perm_type.getNumElements() << ", and x of rank "
                            << x_type.getRank();
  }

  DenseIntElementsAttr attr_perm;
  if (matchPattern(op.perm(), m_Constant(&attr_perm))) {
    // y.shape[i] should be equal to x.shape[perm[i]]
    // for i = [0, 1, ..., rank(x) - 1]
    for (auto e : llvm::enumerate(attr_perm)) {
      const int64_t y_idx = e.index();
      const int64_t y_dim = y_type.getDimSize(y_idx);
      const int64_t x_idx = e.value().getSExtValue();
      const int64_t x_dim = x_type.getDimSize(x_idx);
      if (y_dim != ShapedType::kDynamicSize &&
          x_dim != ShapedType::kDynamicSize && y_dim != x_dim) {
        return op.emitOpError()
               << "requires y.shape[" << y_idx << "] (" << y_dim << ") "
               << "to be equal to x.shape[perm[" << y_idx << "]] "
               << "(" << x_dim << ")";
      }
    }
  }

  return success();
}

// TODO(jpienaar): perm could be optional too.
void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x,
                        Value perm) {
  auto x_type = x.getType().cast<TensorType>();
  // If the value is unranked, then so is the result.
  if (!x_type.hasRank())
    return TransposeOp::build(builder, result,
                              UnrankedTensorType::get(x_type.getElementType()),
                              x, perm);

  // TODO(jpienaar): Handle unknown perm case.

  // TODO(jpienaar): Extract utility function.
  auto etype = x_type.cast<ShapedType>().getElementType();
  DenseIntElementsAttr attr_shape;
  if (matchPattern(perm, m_Constant(&attr_shape))) {
    llvm::SmallVector<int64_t, 4> const_shape;
    if (attr_shape.isSplat()) {
      const_shape.assign(
          attr_shape.getNumElements(),
          x_type.getDimSize((*attr_shape.begin()).getSExtValue()));
    } else {
      const_shape.reserve(attr_shape.getNumElements());
      for (const auto &dim : attr_shape)
        const_shape.push_back(x_type.getDimSize(dim.getSExtValue()));
    }
    return TransposeOp::build(
        builder, result, RankedTensorType::get(const_shape, etype), x, perm);
  }
  return TransposeOp::build(builder, result, UnrankedTensorType::get(etype), x,
                            perm);
}

namespace {

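// Folds a tf.Transpose whose permutation is the identity. For example
// (illustrative IR only):
//
//   %perm = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>}
//   %1 = "tf.Transpose"(%arg0, %perm)
//
// folds %1 to %arg0.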
OpFoldResult FoldIdentityTranspose(TransposeOp op) {
  DenseIntElementsAttr perm;
  if (!matchPattern(op.perm(), m_Constant(&perm))) return {};
  const auto elements = perm.getValues<APInt>();

  for (auto it : llvm::enumerate(elements)) {
    if (it.index() != it.value()) return {};
  }

  // TODO(jpienaar): Remove if/when we handle this more generally.
  if (op.getType() != op.x().getType()) {
    // If the types don't match, then only fold if all the users are in the TF
    // dialect.
    for (auto user : op.getOperation()->getUsers())
      if (user->getDialect() != op->getDialect()) return {};
  }

  return op.x();
}

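// Folds a pair of transposes that cancel each other. For example
// (illustrative IR only, assuming %perm1 is the inverse permutation of
// %perm0):
//
//   %1 = "tf.Transpose"(%arg0, %perm0)
//   %2 = "tf.Transpose"(%1, %perm1)
//
// folds %2 to %arg0.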
OpFoldResult FoldCancellableTranspose(TransposeOp op) {
  // The operand must be defined by a TransposeOp.
  auto transpose = dyn_cast_or_null<TF::TransposeOp>(op.x().getDefiningOp());
  if (!transpose) return {};

  // Both permutations must be defined by constant operations.
  DenseIntElementsAttr perm0;
  DenseIntElementsAttr perm1;
  if (!matchPattern(op.perm(), m_Constant(&perm0)) ||
      !matchPattern(transpose.perm(), m_Constant(&perm1)))
    return {};

  // The permutation indices must cancel each other.
  if (!AreCancellablePermutations(perm0, perm1)) return {};

  return transpose.x();
}

}  // namespace

OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
  if (auto folded = FoldIdentityTranspose(*this)) return folded;
  if (auto folded = FoldCancellableTranspose(*this)) return folded;
  return {};
}

//===----------------------------------------------------------------------===//
// TruncateDivOp
//===----------------------------------------------------------------------===//

void TruncateDivOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<TruncateDivWithSqrtDivisor>(context);
}

//===----------------------------------------------------------------------===//
// NonMaxSuppressionV3Op
//===----------------------------------------------------------------------===//

namespace {

// Canonicalize NonMaxSuppressionV3Op to NonMaxSuppressionV4Op.
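//
// NMSv4 returns an additional scalar `valid_outputs` result, so the rewrite
// keeps only the first NMSv4 result. Roughly (illustrative IR only; operand
// names are hypothetical):
//
//   %idx = "tf.NonMaxSuppressionV3"(%boxes, %scores, %max, %iou, %score)
//
// becomes
//
//   %idx, %valid = "tf.NonMaxSuppressionV4"(%boxes, %scores, %max, %iou,
//                                           %score)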
class NMSV3ToNMSV4Op : public OpRewritePattern<NonMaxSuppressionV3Op> {
  using OpRewritePattern<NonMaxSuppressionV3Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(NonMaxSuppressionV3Op nms_op,
                                PatternRewriter &rewriter) const override {
    if (nms_op.getNumOperands() != 5) {
      return failure();
    }
    SmallVector<Type, 2> new_result_types;
    new_result_types.push_back(nms_op.getType());
    auto input_ty = nms_op.getType().template cast<ShapedType>();
    // Corresponds to the second result type of NMSv4.
    RankedTensorType valid_output_type =
        RankedTensorType::get({}, input_ty.getElementType());
    new_result_types.push_back(valid_output_type);

    auto nmsv4 = rewriter.create<TF::NonMaxSuppressionV4Op>(
        nms_op.getLoc(), new_result_types, nms_op.boxes(), nms_op.scores(),
        nms_op.max_output_size(), nms_op.iou_threshold(),
        nms_op.score_threshold());
    // Cannot replace the NMSv3 op with NMSv4 directly since their outputs
    // differ (v4 produces two output values, whereas v3 produces only one).
    nms_op.replaceAllUsesWith(nmsv4.getResult(0));
    return success();
  }
};
}  // namespace.

void NonMaxSuppressionV3Op::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<NMSV3ToNMSV4Op>(context);
}

//===----------------------------------------------------------------------===//
// FusedBatchNormOp
//===----------------------------------------------------------------------===//

namespace {

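// Canonicalizes tf.FusedBatchNorm to tf.FusedBatchNormV3 by appending the
// extra `reserve_space_3` f32 result that V3 produces and dropping it from
// the replacement, so existing users continue to see the original results.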
class ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
  using OpRewritePattern<FusedBatchNormOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op,
                                PatternRewriter &rewriter) const override {
    auto new_result_types =
        llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes());
    // reserve_space_3
    new_result_types.push_back(
        UnrankedTensorType::get(FloatType::getF32(rewriter.getContext())));

    OperationState new_state(tf_fused_batch_norm_op.getLoc(),
                             TF::FusedBatchNormV3Op::getOperationName(),
                             tf_fused_batch_norm_op.getOperands(),
                             new_result_types,
                             tf_fused_batch_norm_op->getAttrs());
    Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state);

    rewriter.replaceOp(tf_fused_batch_norm_op,
                       tf_fused_batch_norm_op_v3->getResults().drop_back());
    return success();
  }
};
}  // namespace.

void FusedBatchNormOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ConvertFusedBatchNorm>(context);
}

//===----------------------------------------------------------------------===//
// UnpackOp
//===----------------------------------------------------------------------===//

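// As an illustration of the constraint verified below, unpacking along axis 0
// requires the result count to match that dimension's size (illustrative IR
// only):
//
//   %0:3 = "tf.Unpack"(%arg0) {axis = 0 : i64}
//            : (tensor<3x4xf32>) -> (tensor<4xf32>, tensor<4xf32>,
//               tensor<4xf32>)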
static LogicalResult Verify(UnpackOp op) {
  auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
  if (!value_type) return success();

  int64_t value_rank = value_type.getRank();
  int64_t axis = op.axis();
  if (axis < -value_rank || axis >= value_rank)
    return op.emitOpError("axis attribute must be in the range of [-")
           << value_rank << ", " << value_rank << ')';

  axis = GetDimForAxis(axis, value_rank);
  int64_t dim_size = value_type.getDimSize(axis);
  if (ShapedType::isDynamic(dim_size)) return success();

  if (dim_size != op.getNumResults())
    return op.emitOpError("result count must be equal to ") << dim_size;

  return success();
}

namespace {

// Hoist a coefficient-wise unary operation out of the Unpack op:
//
//   %unpacked:N = "tf.Unpack"(%0)
//   %neg0 = "tf.Neg"(%unpacked#0)
//   %neg1 = "tf.Neg"(%unpacked#1)
//   ...
//   %negN-1 = "tf.Neg"(%unpacked#N-1)
//
// Rewrite it to:
//
//   %neg = "tf.Neg"(%0)
//   %unpacked:N = "tf.Unpack"(%neg)
class HoistCwiseUnaryOutOfUnpack : public OpRewritePattern<UnpackOp> {
 public:
  explicit HoistCwiseUnaryOutOfUnpack(MLIRContext *context)
      : OpRewritePattern<UnpackOp>(context) {}
  LogicalResult matchAndRewrite(UnpackOp op,
                                PatternRewriter &rewriter) const override;
};

LogicalResult HoistCwiseUnaryOutOfUnpack::matchAndRewrite(
    UnpackOp op, PatternRewriter &rewriter) const {
  auto loc = op.getLoc();

  // The first unpack user must be a coefficient-wise unary operation.
  Operation *first_user = *op->getUsers().begin();
  if (!first_user->hasTrait<OpTrait::TF::CwiseUnary>()) return failure();

  // All unpack users must be ops of the same kind.
  bool users_same_op = llvm::all_of(op->getUsers(), [&](Operation *user) {
    return user->getName() == first_user->getName();
  });
  if (!users_same_op) return failure();

  // Pass the unpack operand to the unary operation.
  OperationState new_unary_op_state(loc, first_user->getName().getStringRef(),
                                    op.getOperand(), op.getOperand().getType(),
                                    ArrayRef<NamedAttribute>());
  Operation *new_unary_op = rewriter.createOperation(new_unary_op_state);

  // Unpack results after applying the unary operation.
  auto unpack_unary_op = rewriter.create<UnpackOp>(
      loc, op.getResultTypes(), new_unary_op->getResult(0), op.axis());

  // Bypass all users of the original unpack operation and use
  // `unpack_unary_op` results instead.
  for (auto pair : llvm::zip(op.getResults(), unpack_unary_op.getResults())) {
    OpResult old_result = std::get<0>(pair);  // result of original Unpack
    OpResult new_result = std::get<1>(pair);  // result of transformed Unpack
    for (Operation *user : llvm::make_early_inc_range(old_result.getUsers()))
      rewriter.replaceOp(user, ValueRange(new_result));
  }

  // Erase the original unpack operation.
  rewriter.eraseOp(op.getOperation());

  return success();
}

}  // namespace

void UnpackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<HoistCwiseUnaryOutOfUnpack>(context);
}

//===----------------------------------------------------------------------===//
// Unsorted segment reduction ops
//===----------------------------------------------------------------------===//

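// Verifies unsorted segment reduction ops (e.g. tf.UnsortedSegmentSum). The
// segment ids shape must be a prefix of the data shape: for example, data of
// shape [4, 3] is compatible with segment ids of shape [4], but not with
// segment ids of shape [3].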
template <class Op>
static LogicalResult VerifyUnsortedSegmentReduction(Op op) {
  if (!HasRankAtMost(op.num_segments(), 0))
    return op.emitOpError("number of segments should be a 0-D tensor");

  auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
  auto segment_ids_type =
      op.segment_ids().getType().template dyn_cast<RankedTensorType>();
  if (data_type && segment_ids_type) {
    if (data_type.getRank() < segment_ids_type.getRank())
      return op.emitOpError(
          "requires segment ids rank to be less than or equal to data's rank");

    int index = 0;
    for (auto shape_pair :
         llvm::zip_first(segment_ids_type.getShape(), data_type.getShape())) {
      int64_t segment_id_dim = std::get<0>(shape_pair);
      int64_t data_dim = std::get<1>(shape_pair);
      if (!ShapedType::isDynamic(segment_id_dim) &&
          !ShapedType::isDynamic(data_dim) && segment_id_dim != data_dim)
        return op.emitOpError(
                   "requires segment ids shape to be a prefix of data shape, "
                   "but dimension #")
               << index << " differs: " << segment_id_dim << " vs. "
               << data_dim;
      ++index;
    }
  }

  DenseIntElementsAttr num_segments_attr;
  if (matchPattern(op.num_segments(), m_Constant(&num_segments_attr))) {
    int64_t num_segments = (*num_segments_attr.begin()).getSExtValue();
    if (num_segments < 0)
      return op.emitOpError("num of segments cannot be negative");
  }

  return success();
}

//===----------------------------------------------------------------------===//
// VarHandleOp
//===----------------------------------------------------------------------===//

llvm::SmallVector<ResourceHandleValueAndId, 4>
VarHandleOp::GetResourceHandleValueAndIdList(
    llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
    int64_t &next_id) {
  llvm::StringRef device = GetDeviceOrEmpty(getOperation());
  return {GetResourceHandleValueAndIdBase(container(), shared_name(), device,
                                          resource(), resource_handle_id_map,
                                          next_id)};
}

//===----------------------------------------------------------------------===//
// VarIsInitializedOp
//===----------------------------------------------------------------------===//

namespace {

/// Erase VarIsInitializedOp operations with no uses. This op has a side effect
/// on resources (read-only), but can still be deleted if it has zero uses.
struct EraseDeadVarIsInitializedOp
    : public OpRewritePattern<VarIsInitializedOp> {
  using OpRewritePattern<VarIsInitializedOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VarIsInitializedOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.use_empty()) return failure();
    rewriter.eraseOp(op);
    return success();
  }
};
}  // end anonymous namespace.

void VarIsInitializedOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<EraseDeadVarIsInitializedOp>(context);
}

//===----------------------------------------------------------------------===//
// VariableOp
//===----------------------------------------------------------------------===//

void VariableOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
  results.insert<VariableToVariableV2>(context);
}

//===----------------------------------------------------------------------===//
// VariableShapeOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(VariableShapeOp op) {
  auto input_type = op.input().getType().cast<TensorType>();
  if (input_type.hasStaticShape() && input_type.getNumElements() != 1)
    return op.emitOpError("requires input to have one resource");

  auto resource_type = input_type.getElementType().cast<TF::ResourceType>();
  auto subtypes = resource_type.getSubtypes();
  switch (subtypes.size()) {
    case 1:
      return VerifyShapeOperandAndResult(
          op, resource_type.getSubtypes().front(), op.getType());
    case 0:
      return VerifyShapeOperandAndResult(op, Type(), op.getType());
    default:
      return op.emitOpError(
          "requires resource input type to have at most 1 subtype");
  }
}

OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {
  int width =
      getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
  auto resource_type =
      getElementTypeOrSelf(getOperand().getType()).cast<TF::ResourceType>();
  if (resource_type.getSubtypes().empty()) return {};
  return ConvertShapeToAttr(resource_type.getSubtypes()[0], width);
}

//===----------------------------------------------------------------------===//
// WhileOp
//===----------------------------------------------------------------------===//

static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
                                      TypeRange body_input,
                                      TypeRange body_result,
                                      bool shape_invariant) {
  const TypeRangeWithDesc input_type = {op->getOperandTypes(), "input"};
  const TypeRangeWithDesc result_type = {op->getResultTypes(), "result"};
  constexpr int kNumRegionTypeLists = 3;
  const std::array<TypeRangeWithDesc, kNumRegionTypeLists> region_types = {{
      {body_result, "body result"},
      {cond_input, "condition input"},
      {body_input, "body input"},
  }};

  // A pair of type lists should be cast compatible with each other if one is
  // converted to the other for a function call or assignment, or if there is
  // a common source of inputs for both. Therefore, the While op requires the
  // following pairs of type lists to be cast compatible for the tensor_cast
  // operation:
  //
  // * Operands and cond inputs to call the cond function before the
  //   first iteration.
  // * Operands and body inputs to call the body function for the first
  //   iteration if the cond function returns True or an equivalent result.
  // * Operands and results to assign cond function arguments to op results if
  //   the cond function returns False or an equivalent result. If the op is
  //   shape invariant, this does not hold as shapes can differ.
  // * All three pairings of cond inputs, body inputs and results, as the op
  //   operands are a common source for all three.
  // * Body result and cond inputs to call the cond function for the subsequent
  //   iterations. Similarly, the body result should be compatible with the
  //   body inputs and op results.
  //
  // Note that the operands and body results need not be compatible, as they
  // are never converted from one to the other, nor is there a common source
  // of tensors. The compatibility requirement is not transitive.

  if (!shape_invariant &&
      failed(VerifyTypeRangesAreCompatible(op, input_type, result_type)))
    return failure();

  // Skip the first pair as the While op operands and body function results do
  // not need to be compatible with each other.
  for (int i = 1; i < kNumRegionTypeLists; ++i)
    if (failed(VerifyTypeRangesAreCompatible(op, input_type, region_types[i])))
      return failure();

  for (int i = 0; i < kNumRegionTypeLists; ++i)
    if (failed(VerifyTypeRangesAreCompatible(op, result_type, region_types[i])))
      return failure();

  for (int i = 0; i < kNumRegionTypeLists; ++i)
    for (int j = i + 1; j < kNumRegionTypeLists; ++j)
      if (failed(VerifyTypeRangesAreCompatible(op, region_types[i],
                                               region_types[j])))
        return failure();

  return success();
}

LogicalResult WhileOp::verifySymbolUses(SymbolTableCollection &symbol_table) {
  // TODO(jpienaar): Remove.
  if (failed(WhileOpAdaptor(*this).verify(getLoc()))) return failure();

  auto cond_fn = symbol_table.lookupNearestSymbolFrom<FuncOp>(*this, cond());
  auto body_fn = symbol_table.lookupNearestSymbolFrom<FuncOp>(*this, body());
  if (!cond_fn) {
    return emitOpError("cond refers to an undefined function : ") << cond();
  }
  if (!body_fn) {
    return emitOpError("body refers to an undefined function : ") << body();
  }

  auto cond_fn_type = cond_fn.getType();
  auto body_fn_type = body_fn.getType();

  // Verify that the cond function has exactly one result.
  if (cond_fn_type.getNumResults() != 1)
    return emitOpError("requires cond function to have exactly one result");

  return VerifyWhileTypes(*this, /*cond_input=*/cond_fn_type.getInputs(),
                          /*body_input=*/body_fn_type.getInputs(),
                          /*body_result=*/body_fn_type.getResults(),
                          shape_invariant());
}

//===----------------------------------------------------------------------===//
// WhileRegionOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(WhileRegionOp op) {
  // Verify that the condition generates a single tensor<i1> result.
  Operation *cond_yield = op.cond().front().getTerminator();
  if (cond_yield->getNumOperands() != 1)
    return op.emitOpError()
           << "condition should have a single tensor<i1> result";

  auto cond_type =
      cond_yield->getOperand(0).getType().dyn_cast<RankedTensorType>();
  if (!cond_type || !cond_type.getShape().equals({}) ||
      !cond_type.getElementType().isInteger(/*width=*/1))
    return op.emitOpError()
           << "condition should have a single tensor<i1> result";

  Operation *body_yield = op.body().front().getTerminator();
  if (failed(VerifyWhileTypes(op, /*cond_input=*/op.cond().getArgumentTypes(),
                              /*body_input=*/op.body().getArgumentTypes(),
                              /*body_result=*/body_yield->getOperandTypes(),
                              op.shape_invariant())))
    return failure();
  return success();
}

//===----------------------------------------------------------------------===//
// WhileRegionOp LoopLikeOpInterface
//===----------------------------------------------------------------------===//

Region &WhileRegionOp::getLoopBody() { return body(); }

bool WhileRegionOp::isDefinedOutsideOfLoop(Value value) {
  // If the op defining the value exists and is outside the scope of this
  // WhileRegion, then we can infer that it is defined outside. The defining
  // op is outside the scope of this WhileRegion if this WhileRegionOp is not
  // an ancestor of the defining op in the parent chain.
  Operation *def_op = value.getDefiningOp();
  return def_op && !getOperation()->isAncestor(def_op);
}

LogicalResult WhileRegionOp::moveOutOfLoop(
    llvm::ArrayRef<mlir::Operation *> ops) {
  // Move the hoisted ops to just before the while.
  Operation *while_op = this->getOperation();
  for (auto op : ops) op->moveBefore(while_op);
  return success();
}

//===----------------------------------------------------------------------===//
// WhileRegionOp canonicalization
//===----------------------------------------------------------------------===//
namespace {
// Eliminate values that pass through the WhileRegionOp body.
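//
// For example (illustrative IR only), if the body simply yields one of its
// block arguments unchanged:
//
//   %0 = "tf.WhileRegion"(%arg0) ({ /* cond */ }, {
//     ^bb0(%barg0: tensor<f32>):
//       "tf.Yield"(%barg0) : (tensor<f32>) -> ()
//   })
//
// then uses of %0 can be replaced with %arg0 and the operand removed.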
struct WhileRegionEliminatePassThrough
    : public OpRewritePattern<WhileRegionOp> {
  using OpRewritePattern<WhileRegionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileRegionOp while_op,
                                PatternRewriter &rewriter) const override {
    // Remove any extern values that are explicitly captured and returned.
    // Also replace values that simply pass through the body with extern
    // values. The block arguments of the body and the while match, so the
    // corresponding cond argument can be easily found.
    int old_num_operands = while_op.getNumOperands();
    int new_num_operands = old_num_operands;
    auto &body_block = while_op.body().front();
    auto &cond_block = while_op.cond().front();
    auto &yield = *body_block.getTerminator();

    // Bit mask indicating which operands will be removed.
    llvm::BitVector removed_operand(old_num_operands);

    for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
      auto body_arg = body_block.getArgument(op_idx);
      auto yield_operand = LookThroughIdentity(yield.getOperand(op_idx));
      auto while_operand = while_op.getOperand(op_idx);
      if (body_arg == yield_operand || while_operand == yield_operand) {
        // Replace the use of the passthrough value with the while operand
        // in the body and condition regions, as well as the while output (if
        // the types match).
        // TODO(jurahul): Use PatternRewriter API for IR modification.
        if (body_arg.getType() == while_operand.getType())
          body_arg.replaceAllUsesWith(while_operand);

        auto cond_arg = cond_block.getArgument(op_idx);
        if (cond_arg.getType() == while_operand.getType())
          cond_arg.replaceAllUsesWith(while_operand);

        auto result = while_op.getResult(op_idx);
        if (result.getType() == while_operand.getType())
          result.replaceAllUsesWith(while_operand);
      }

      // Now check if the operand is unused in both regions as well as the
      // result. If so, mark it for removal.
      if (body_block.getArgument(op_idx).use_empty() &&
          cond_block.getArgument(op_idx).use_empty() &&
          while_op.getResult(op_idx).use_empty()) {
        removed_operand.set(op_idx);
        new_num_operands--;
      }
    }

    if (new_num_operands == old_num_operands) return failure();

    // Compress the operands, region arguments, and outputs.
    SmallVector<Value, 4> new_while_operands;
    SmallVector<Type, 4> new_result_types;
    new_while_operands.reserve(new_num_operands);
    new_result_types.reserve(new_num_operands);

    // Build the new operands and result types.
    for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
      if (removed_operand.test(op_idx)) continue;
      new_while_operands.push_back(while_op.getOperand(op_idx));
      new_result_types.push_back(while_op.getResult(op_idx).getType());
    }

    // Create the new while operation.
    auto new_while_op = rewriter.create<WhileRegionOp>(
        while_op.getLoc(), new_result_types, new_while_operands,
        while_op->getAttrs());

    // Move the region bodies to the new while.
    rewriter.inlineRegionBefore(while_op.cond(), new_while_op.cond(),
                                new_while_op.cond().end());
    rewriter.inlineRegionBefore(while_op.body(), new_while_op.body(),
                                new_while_op.body().end());

    auto &new_cond_block = new_while_op.cond().front();
    auto &new_body_block = new_while_op.body().front();
    auto &new_yield = *new_body_block.getTerminator();

    // Patch up the region bodies and yield.
    new_cond_block.eraseArguments(removed_operand);
    new_body_block.eraseArguments(removed_operand);
    new_yield.eraseOperands(removed_operand);

    // Build a vector of the new results, mapping each retained operand's
    // result to the corresponding result of the new while.
    SmallVector<Value, 4> new_results(old_num_operands);
    int next_idx = 0;
    for (int op_idx : llvm::seq<int>(0, old_num_operands))
      if (!removed_operand.test(op_idx))
        new_results[op_idx] = new_while_op.getResult(next_idx++);

    rewriter.replaceOp(while_op, new_results);
    return success();
  }
};

}  // anonymous namespace

void WhileRegionOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<WhileRegionEliminatePassThrough>(context);
}

//===----------------------------------------------------------------------===//
// XdivyOp
//===----------------------------------------------------------------------===//

void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<XdivyWithSqrtDivisor>(context);
}

//===----------------------------------------------------------------------===//
// XlaBroadcastHelperOp
//===----------------------------------------------------------------------===//

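// As an illustration of the broadcast_dims semantics checked below: with
// lhs : tensor<2x3xf32>, rhs : tensor<3xf32>, and broadcast_dims = [1], the
// lower-rank rhs is lifted to tensor<1x3xf32> so that both sides have equal
// rank.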
LogicalResult XlaBroadcastHelperOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto loc = location ? *location : mlir::UnknownLoc::get(context);
  XlaBroadcastHelperOpAdaptor op(operands, attributes);
  if (failed(op.verify(loc))) {
    return failure();
  }

  Value lhs = op.lhs();
  Value rhs = op.rhs();
  auto set_unranked_results = [&]() {
    auto unranked_lhs = UnrankedTensorType::get(getElementTypeOrSelf(lhs));
    inferredReturnTypes.push_back(unranked_lhs);
    auto unranked_rhs = UnrankedTensorType::get(getElementTypeOrSelf(rhs));
    inferredReturnTypes.push_back(unranked_rhs);
    return success();
  };

  RankedTensorType lhs_ty = lhs.getType().dyn_cast<RankedTensorType>();
  RankedTensorType rhs_ty = rhs.getType().dyn_cast<RankedTensorType>();
  if (!lhs_ty || !rhs_ty) return set_unranked_results();

  int64_t lhs_rank = lhs_ty.getRank();
  int64_t rhs_rank = rhs_ty.getRank();

  DenseIntElementsAttr dims;
  if (!matchPattern(op.broadcast_dims(), m_Constant(&dims))) {
    return set_unranked_results();
  }

  if (dims.size() == 0) {
    if (lhs_rank != rhs_rank && lhs_rank != 0 && rhs_rank != 0) {
      return emitOptionalError(
          location,
          "if broadcast_dims is empty, both arguments must have equal rank or "
          "at least one argument must be a scalar");
    }
    inferredReturnTypes.push_back(lhs_ty);
    inferredReturnTypes.push_back(rhs_ty);
    return success();
  }

  const bool broadcast_lhs = lhs_rank < rhs_rank;
  RankedTensorType min_rank_ty = broadcast_lhs ? lhs_ty : rhs_ty;
  RankedTensorType max_rank_ty = broadcast_lhs ? rhs_ty : lhs_ty;

  if (dims.size() != min_rank_ty.getRank()) {
    return emitOptionalError(
        location,
        "broadcast_dims must have size equal to the smaller argument rank");
  }

  int64_t output_rank = max_rank_ty.getRank();
  llvm::SmallVector<int64_t, 4> broadcast_shape(output_rank, 1LL);
  llvm::SmallVector<bool, 4> is_broadcasted(output_rank, false);
  for (auto item : llvm::enumerate(dims)) {
    int64_t index = item.index();
    int64_t dim = item.value().getSExtValue();
    if (dim < 0 || dim >= output_rank) {
      return emitOptionalError(location, "out of range broadcast dim");
    }
    if (is_broadcasted[dim]) {
      return emitOptionalError(location, "broadcast_dims has duplicates");
    }
    broadcast_shape[dim] = min_rank_ty.getDimSize(index);
    is_broadcasted[dim] = true;
  }

  if (broadcast_lhs) {
    inferredReturnTypes.push_back(
        RankedTensorType::get(broadcast_shape, lhs_ty.getElementType()));
    inferredReturnTypes.push_back(rhs_ty);
  } else {
    inferredReturnTypes.push_back(lhs_ty);
    inferredReturnTypes.push_back(
        RankedTensorType::get(broadcast_shape, rhs_ty.getElementType()));
  }
  return success();
}

//===----------------------------------------------------------------------===//
// XlaSetDynamicDimensionSizeOp
//===----------------------------------------------------------------------===//

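// As an illustration of the inference below: with a constant dim_index of 1,
// an input of type tensor<4x8xf32> produces a result of type tensor<4x?xf32>.
// If dim_index is not a constant, every dimension of the result is dynamic.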
LogicalResult XlaSetDynamicDimensionSizeOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto loc = location ? *location : mlir::UnknownLoc::get(context);
  XlaSetDynamicDimensionSizeOpAdaptor op(operands, attributes);
  if (failed(op.verify(loc))) return failure();

  TensorType operand_ty = op.input().getType().cast<TensorType>();
  Type element_ty = operand_ty.getElementType();

  TensorType result_ty;
  if (operand_ty.hasRank()) {
    auto shape = llvm::to_vector<4>(operand_ty.getShape());

    DenseIntElementsAttr dim_index_attr;
    if (matchPattern(op.dim_index(), m_Constant(&dim_index_attr))) {
      int64_t dim_index = dim_index_attr.getValue<APInt>({}).getSExtValue();

      int64_t rank = operand_ty.getRank();
      if (dim_index < 0 || dim_index >= rank) {
        return emitOptionalError(location, "dim_index (", dim_index,
                                 ") is out of range [0, ", rank, ")");
      }
      shape[dim_index] = RankedTensorType::kDynamicSize;
    } else {
      shape.assign(shape.size(), RankedTensorType::kDynamicSize);
    }
    result_ty = RankedTensorType::get(shape, element_ty);
  } else {
    result_ty = UnrankedTensorType::get(element_ty);
  }

  inferredReturnTypes.push_back(result_ty);
  return success();
}

}  // namespace TF
}  // namespace mlir

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc.inc"