/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <numeric>
#include <string>
#include <tuple>
#include <type_traits>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/DialectImplementation.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_canonicalization_helper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_device_helper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"

namespace mlir {
namespace TF {

namespace {
// Returns the equivalent Value skipping through identity nodes.
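// For example, if %id = tf.Identity(%x), looking through %id yields %x;
// chains of Identity/IdentityN results are followed until the defining op is
// no longer an identity.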
Value LookThroughIdentity(Value result) {
  while (isa_and_nonnull<IdentityOp, IdentityNOp>(result.getDefiningOp())) {
    auto op_result = result.cast<OpResult>();
    result = op_result.getOwner()->getOperand(op_result.getResultNumber());
  }
  return result;
}

#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
}  // namespace

//===----------------------------------------------------------------------===//
// NotEqualOp
//===----------------------------------------------------------------------===//

LogicalResult NotEqualOp::verify() {
  NotEqualOp op = *this;
  // If the inputs are allowed to have incompatible types, there is nothing to
  // verify.
  if (!op.incompatible_shape_error()) return success();

  // Otherwise, check inputs are broadcastable.
  return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
      op.getOperation());
}

void NotEqualOp::build(OpBuilder &builder, OperationState &result, Value x,
                       Value y, BoolAttr incompatible_shape_error) {
  auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y,
                                          incompatible_shape_error);
  return build(builder, result, result_type, x, y, incompatible_shape_error);
}

//===----------------------------------------------------------------------===//
// OneHotOp
//===----------------------------------------------------------------------===//

LogicalResult OneHotOp::verify() {
  OneHotOp op = *this;
  int64_t axis = op.axis();

  auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
  if (indices_ty &&
      !(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) {
    return op.emitOpError()
           << "expected axis (" << axis << ") to be -1 or between [0, "
           << indices_ty.getShape().size() << "]";
  }

  if (axis < -1) {
    return op.emitOpError() << "expected axis (" << axis
                            << ") to be -1 or between [0, rank(indices()))";
  }

  if (!IsOfRankOrUnranked(op.depth(), 0)) {
    return op.emitOpError() << "requires depth to be a scalar";
  }
  if (!IsOfRankOrUnranked(op.on_value(), 0)) {
    return op.emitOpError() << "requires on_value to be a scalar";
  }
  if (!IsOfRankOrUnranked(op.off_value(), 0)) {
    return op.emitOpError() << "requires off_value to be a scalar";
  }

  DenseIntElementsAttr depth_attr;
  if (matchPattern(op.depth(), m_Constant(&depth_attr))) {
    if (depth_attr.getType().getRank() != 0)
      return op.emitOpError() << "requires depth to be a scalar";
    int64_t depth = depth_attr.getValues<APInt>()[0].getSExtValue();
    if (depth < 0) {
      return op.emitOpError() << "depth must be non-negative, got: " << depth;
    }
  }

  return success();
}

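// For example, given indices of type tensor<2x3xi32>, a constant depth of 5,
// axis = -1, and f32 on/off values, the inferred result type is
// tensor<2x3x5xf32>; a non-constant depth yields a dynamic dimension instead.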
static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value,
                                    Value off_value, IntegerAttr axis) {
  int64_t axis_val = axis.getInt();
  Type element_ty = on_value.getType().cast<TensorType>().getElementType();
  auto unranked_ty = UnrankedTensorType::get(element_ty);
  if (axis_val < -1) return unranked_ty;

  auto indices_ty = indices.getType().dyn_cast<RankedTensorType>();
  if (!indices_ty) return unranked_ty;

  auto shape = llvm::to_vector<2>(indices_ty.getShape());
  if (axis_val == -1) axis_val = shape.size();

  int64_t depth_val = ShapedType::kDynamicSize;
  DenseIntElementsAttr depth_attr;
  if (matchPattern(depth, m_Constant(&depth_attr)) &&
      depth_attr.getNumElements() == 1)
    depth_val = (*depth_attr.begin()).getSExtValue();
  shape.insert(shape.begin() + axis_val, depth_val);
  return RankedTensorType::get(shape, element_ty);
}

void OneHotOp::build(OpBuilder &builder, OperationState &result, Value indices,
                     Value depth, Value on_value, Value off_value,
                     IntegerAttr axis) {
  build(builder, result,
        InferOneHotOpType(indices, depth, on_value, off_value, axis), indices,
        depth, on_value, off_value, axis);
}

//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//

LogicalResult PackOp::verify() {
  PackOp op = *this;
  // TODO(hinsu): Convert variadic length attributes to derived attributes.
  Operation::operand_range values = op.values();

  if (failed(VerifyTypesCompatibility(values,
                                      /*mask_one_dim=*/false,
                                      op.getOperation()))) {
    return failure();
  }

  int64_t inputs_rank = -1;
  for (Value value : values) {
    if (auto ty = value.getType().dyn_cast<RankedTensorType>()) {
      // Exit early as input types are verified to be compatible, so all ranked
      // tensors have the same rank.
      inputs_rank = ty.getRank();
      break;
    }
  }
  if (inputs_rank == -1) return success();

  // The values can be packed along any of the dimensions between 0 and the
  // inputs' rank, inclusive. Also, since negative axis values wrap around, the
  // valid axis range is [-(R+1), R+1).
  int64_t range_begin = -inputs_rank - 1;  // Inclusive
  int64_t range_end = inputs_rank + 1;     // Exclusive
  int64_t axis = op.axis();
  if (axis < range_begin || axis >= range_end) {
    return op.emitError() << "attribute 'axis' should be within range ["
                          << range_begin << ", " << range_end
                          << "); actual value: " << axis;
  }

  return success();
}

OpFoldResult PackOp::fold(ArrayRef<Attribute> operands) {
  // Fold the pack operation if it computes the input tensor shape:
  //
  //   %shape  = tf.Shape(%arg)                    // [? x ...]
  //   %dim0   = tf.StridedSlice(%shape, 0, 1, 1)  // get unknown dim0 value
  //   %pack   = tf.Pack(dim0, ...) { axis = 0 }   // [? x ...]
  //
  // where `...` are some statically known dimensions. In this case %pack can
  // be replaced with %shape. This is a common pattern in models with a dynamic
  // batch size.

  // Pack operation should pack at least two values.
  if (values().size() < 2) return {};

  // Dimensions packed along axis = 0 (pack scalars into vector).
  if (axis() != 0) return {};

  // First packed value is defined by a strided slice operation.
  auto slice_op = dyn_cast_or_null<StridedSliceOp>(values()[0].getDefiningOp());
  if (!slice_op) return {};

  // Input to the slice op is defined by a shape operation.
  auto shape_op = dyn_cast_or_null<ShapeOp>(slice_op.input().getDefiningOp());
  if (!shape_op) return {};

  // Input tensor whose shape is reconstructed by the pack operation.
  Value tensor = shape_op.input();

  // All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing
  // scalar value from input vector).
  if (slice_op.begin_mask() != 0 || slice_op.ellipsis_mask() != 0 ||
      slice_op.end_mask() != 0 || slice_op.new_axis_mask() != 0 ||
      slice_op.shrink_axis_mask() != 1)
    return {};

  // Returns a value if the `value` is defined by a ConstOp with a single
  // integer element in it and has an expected rank.
  auto get_const_int = [](Value value, int expected_rank) -> Optional<int64_t> {
    auto const_op = dyn_cast_or_null<ConstOp>(value.getDefiningOp());
    if (!const_op) return None;

    auto value_attr = const_op.value().dyn_cast<DenseIntElementsAttr>();
    if (!value_attr || value_attr.getNumElements() != 1) return None;

    auto value_ty = value_attr.getType();
    if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None;

    auto splat = value_attr.getSplatValue<IntegerAttr>();
    return splat.getValue().getSExtValue();
  };

  // All other packed values are scalar constants.
  SmallVector<int64_t, 4> packed_dims;
  packed_dims.reserve(values().size() - 1);
  for (Value operand : llvm::drop_begin(values(), 1)) {
    if (auto dim = get_const_int(operand, /*expected_rank=*/0)) {
      packed_dims.push_back(*dim);
    } else {
      return {};
    }
  }

  // Slice exactly the first shape dimension:
  //   begin = [0], end = [1], strides = [1]
  auto begin = get_const_int(slice_op.begin(), /*expected_rank=*/1);
  auto end = get_const_int(slice_op.end(), /*expected_rank=*/1);
  auto strides = get_const_int(slice_op.strides(), /*expected_rank=*/1);
  if (!begin.has_value() || !end.has_value() || !strides.has_value() ||
      *begin != 0 || *end != 1 || *strides != 1)
    return {};

  // First tensor dimension is dynamic.
  auto arg_ty = tensor.getType().dyn_cast<ShapedType>();
  if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 ||
      !arg_ty.isDynamicDim(0))
    return {};

  // Argument tensor rank is equal to the number of packed dimensions.
  if (arg_ty.getRank() != values().size()) return {};

  // All other dimensions are statically known and equal to packed dims.
  auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1);
  if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin()))
    return {};

  // Replace %pack with %shape.
  return slice_op.input();
}

// Convert Pack to Reshape when there is only one operand to be packed.
// For example,
//
//   %0 = tf.Pack(%input) {axis = 0} // %input : tensor<2x3xf32>
//
// can be canonicalized to
//
//   %shape = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi64>}
//   %0 = tf.Reshape(%input, %shape)
struct ConvertPackToReshape : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp pack_op,
                                PatternRewriter &rewriter) const override {
    // Check if there is only one operand to be packed.
    if (pack_op.N() != 1) {
      return failure();
    }

    // Check if input and output are static.
    auto input_ty = pack_op.getOperand(0).getType().cast<ShapedType>();
    auto output_ty = pack_op.output().getType().cast<ShapedType>();
    if (!input_ty.hasStaticShape() || !output_ty.hasStaticShape()) {
      return failure();
    }

    // Create constant shape for reshape.
    auto type =
        RankedTensorType::get(output_ty.getRank(), rewriter.getIntegerType(64));
    auto shape_attr = DenseIntElementsAttr::get(type, output_ty.getShape());
    auto shape = rewriter.create<ConstOp>(pack_op.getLoc(), shape_attr);

    // TODO(b/173622615): Remove after fixed.
    ReplaceTfOpWithNewOp<ReshapeOp>(rewriter, pack_op, output_ty,
                                    pack_op.getOperand(0), shape);
    return success();
  }
};

void PackOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ConvertPackToReshape>(context);
}

//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//

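// For example, with permutation [2, 0, 1], row i of the original paddings is
// moved to row permutation[i] of the shuffled paddings, so paddings
// [[a,b],[c,d],[e,f]] become [[c,d],[e,f],[a,b]].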
LogicalResult PadOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
  // Paddings must be defined by a constant operation.
  auto paddings_op = dyn_cast_or_null<TF::ConstOp>(paddings().getDefiningOp());
  if (!paddings_op) return failure();

  auto paddings_value = paddings_op.value().dyn_cast<DenseElementsAttr>();
  if (!paddings_value ||
      paddings_value.getNumElements() != permutation.size() * 2)
    return failure();

  SmallVector<int32_t, 8> shuffled_paddings(paddings_value.getNumElements());
  for (auto index_pair : llvm::enumerate(paddings_value.getValues<APInt>())) {
    size_t outer_idx = index_pair.index() / 2;
    size_t inner_idx = index_pair.index() % 2;

    shuffled_paddings[permutation[outer_idx] * 2 + inner_idx] =
        index_pair.value().getSExtValue();
  }

  // Add a constant operation with the new paddings.
  OpBuilder builder(getOperation());
  auto type = mlir::RankedTensorType::get(paddings_value.getType().getShape(),
                                          builder.getIntegerType(32));
  auto values = mlir::DenseIntElementsAttr::get(type, shuffled_paddings);
  auto shuffled_paddings_op = builder.create<TF::ConstOp>(getLoc(), values);

  // Use the new paddings.
  setOperand(1, shuffled_paddings_op);

  // Change the result type.
  getResult().setType(ShuffleRankedTensorType(getResult().getType(),
                                              ReversePermutation(permutation)));

  return success();
}

//===----------------------------------------------------------------------===//
// ParseExampleV2Op
//===----------------------------------------------------------------------===//

LogicalResult ParseExampleV2Op::verify() {
  ParseExampleV2Op op = *this;
  // NOTE(mrry): This validates properties of an op that would previously be
  // validated by the TensorFlow OpDef type checker. In addition to these
  // checks, the shape inference function for ParseExampleV2 validates the
  // consistency of the argument and result types.

  // Validate dense variadic input and output lengths.
  // NOTE(mrry): The Tdense attr is derived from dense_defaults, so we
  // do not need to validate dense_defaults.
  auto dense_types_count =
      std::distance(op.Tdense().begin(), op.Tdense().end());
  auto dense_values_count =
      std::distance(op.dense_values().begin(), op.dense_values().end());
  if (dense_values_count != dense_types_count) {
    return op.emitError() << "output 'dense_values' should have same length "
                          << "as attribute 'Tdense'";
  }

  // Validate sparse variadic output lengths.
  // NOTE(mrry): The sparse_types attr is derived from sparse_values, so we
  // do not need to validate sparse_values.
  auto sparse_types_count =
      std::distance(op.sparse_types().begin(), op.sparse_types().end());
  if (op.num_sparse() != sparse_types_count) {
    return op.emitError() << "attribute 'num_sparse' should be the same as "
                          << "the length of attribute 'sparse_types'";
  }
  if (op.sparse_indices().size() != sparse_types_count) {
    return op.emitError() << "output 'sparse_indices' should have same length "
                          << "as attribute 'sparse_types'";
  }
  if (op.sparse_shapes().size() != sparse_types_count) {
    return op.emitError() << "output 'sparse_shapes' should have same length "
                          << "as attribute 'sparse_types'";
  }

  // Validate ragged variadic output lengths.
  auto ragged_value_types_count = std::distance(op.ragged_value_types().begin(),
                                                op.ragged_value_types().end());
  auto ragged_split_types_count = std::distance(op.ragged_split_types().begin(),
                                                op.ragged_split_types().end());
  if (ragged_value_types_count != ragged_split_types_count) {
    return op.emitError() << "attribute 'ragged_value_types' should have same "
                          << "length as attribute 'ragged_split_types'";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// PartitionedCallOp
//===----------------------------------------------------------------------===//

template <typename CallOpClass>
static LogicalResult VerifyPartitionedCall(CallOpClass op,
                                           SymbolTableCollection &symbolTable) {
  SymbolRefAttr func = op->getAttr("f").template cast<SymbolRefAttr>();
  auto function = symbolTable.lookupNearestSymbolFrom<func::FuncOp>(op, func);
  if (!function) {
    return op.emitError("'f' attribute refers to an undefined function: ")
           << func;
  }

  FunctionType function_ty = function.getFunctionType();
  int func_arg_count = function_ty.getNumInputs();
  int arg_count = op.args().size();

  if (arg_count != func_arg_count) {
    return op.emitError() << "argument count mismatch: 'args' has " << arg_count
                          << " arguments, but '" << func << "' expects "
                          << func_arg_count;
  }

  return success();
}

LogicalResult PartitionedCallOp::verifySymbolUses(
    SymbolTableCollection &symbolTable) {
  return VerifyPartitionedCall(*this, symbolTable);
}
LogicalResult StatefulPartitionedCallOp::verifySymbolUses(
    SymbolTableCollection &symbolTable) {
  return VerifyPartitionedCall(*this, symbolTable);
}
LogicalResult TPUPartitionedCallOp::verifySymbolUses(
    SymbolTableCollection &symbolTable) {
  return VerifyPartitionedCall(*this, symbolTable);
}

//===----------------------------------------------------------------------===//
// PowOp
//===----------------------------------------------------------------------===//

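// Folds Pow when the exponent is a splat float constant: x ** 0 folds to a
// constant tensor of ones (when the result type has a static shape), and
// x ** 1 folds to x.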
OpFoldResult PowOp::fold(ArrayRef<Attribute> operands) {
  auto constant_y = operands[1].dyn_cast_or_null<DenseFPElementsAttr>();
  if (constant_y && constant_y.isSplat()) {
    APFloat y_value = constant_y.getSplatValue<APFloat>();
    auto output_type = getType().cast<ShapedType>();
    if (y_value.isZero() && output_type.hasStaticShape()) {
      return DenseElementsAttr::get(
          output_type,
          FloatAttr::get(output_type.getElementType(), /*value=*/1.0));
    }
    if (y_value.isExactlyValue(1.0)) {
      return x();
    }
  }
  return {};
}

//===----------------------------------------------------------------------===//
// QuantizeAndDequantizeV2Op
//===----------------------------------------------------------------------===//

void QuantizeAndDequantizeV2Op::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<QuantizeAndDequantizeV2ToQuantizeAndDequantizeV4>(context);
}

//===----------------------------------------------------------------------===//
// QrOp
//===----------------------------------------------------------------------===//

// Verifies that:
//
// * The input type, if ranked, has at least 2 dimensions and at most
//   INT32_MAX dimensions.
//
LogicalResult QrOp::verify() {
  QrOp op = *this;
  auto ttype = op.input().getType().cast<TensorType>();
  if (!ttype.hasRank()) return success();
  if (!HasRankAtLeast(op.input(), 2))
    return op.emitOpError(
        "requires ranked input tensor to be of rank 2 or more");
  if (!HasRankAtMost(op.input(), std::numeric_limits<int32_t>::max()))
    return op.emitOpError(
        "requires ranked input tensor to be of rank INT32_MAX or less");

  return success();
}

//===----------------------------------------------------------------------===//
// ReadVariableOp
//===----------------------------------------------------------------------===//

void ReadVariableOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ReadVariableOfCast>(context);
}

//===----------------------------------------------------------------------===//
// RandomUniformOp
//===----------------------------------------------------------------------===//

LogicalResult RandomUniformOp::verify() {
  RandomUniformOp op = *this;
  if (!IsOfRankOrUnranked(op.shape(), 1))
    return op.emitOpError("shape must be 1D tensor");
  return success();
}

//===----------------------------------------------------------------------===//
// RangeOp
//===----------------------------------------------------------------------===//

namespace {

// Computes the length of a range (1-D) tensor given `start`, `limit`, `delta`.
// Template parameter `FloatOrInt` must be a standard C integer or
// floating-point type.
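// For example, start = 0, limit = 10, delta = 3 yields a length of 4
// (elements 0, 3, 6, 9); for floats, start = 0.0, limit = 1.0, delta = 0.3
// yields ceil(1.0 / 0.3) = 4.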
template <typename FloatOrInt>
int GetLengthOfRange(FloatOrInt start, FloatOrInt limit, FloatOrInt delta) {
  // Refer to the implementation in
  // tensorflow/lite/kernels/range.cc.
  FloatOrInt diff = limit - start;
  if (std::is_integral<FloatOrInt>::value) {
    return ((std::abs(diff) + std::abs(delta) - 1) / std::abs(delta));
  }
  return std::ceil(std::abs(diff / delta));
}

// Builds a constant range tensor of `result_elem_type` elements.
// Template parameter `FloatOrIntAttr` must be mlir::IntegerAttr or
// mlir::FloatAttr.
template <typename FloatOrIntAttr>
DenseElementsAttr BuildConstRangeTensor(Type result_elem_type, int num_elements,
                                        FloatOrIntAttr start_attr,
                                        FloatOrIntAttr delta_attr) {
  using ValueType = typename FloatOrIntAttr::ValueType;  // APInt or APFloat
  ValueType start = start_attr.getValue();
  ValueType delta = delta_attr.getValue();

  SmallVector<ValueType, 16> new_values;
  new_values.reserve(num_elements);
  ValueType new_value = start;
  for (int i = 0; i < num_elements; ++i) {
    new_values.push_back(new_value);
    new_value = new_value + delta;
  }
  // Result is always a 1-D tensor.
  auto new_result_type =
      RankedTensorType::get({num_elements}, result_elem_type);
  return DenseElementsAttr::get(new_result_type, new_values);
}
}  // namespace

void RangeOp::build(OpBuilder &builder, OperationState &result, Value start,
                    Value limit, Value delta) {
  assert(start.getType() == limit.getType());
  assert(start.getType() == delta.getType());
  DenseIntElementsAttr start_val;
  DenseIntElementsAttr limit_val;
  DenseIntElementsAttr delta_val;
  if (matchPattern(start, m_Constant(&start_val)) &&
      matchPattern(limit, m_Constant(&limit_val)) &&
      matchPattern(delta, m_Constant(&delta_val))) {
    auto size = llvm::APIntOps::RoundingSDiv(
        *limit_val.begin() - *start_val.begin(), *delta_val.begin(),
        llvm::APInt::Rounding::DOWN);
    return RangeOp::build(
        builder, result,
        RankedTensorType::get(
            size.getSExtValue(),
            start.getType().cast<TensorType>().getElementType()),
        start, limit, delta);
  }
  return RangeOp::build(
      builder, result,
      RankedTensorType::get(
          {-1}, start.getType().cast<TensorType>().getElementType()),
      start, limit, delta);
}

OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 3);
  auto start_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
  auto limit_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
  auto delta_tensor = operands[2].dyn_cast_or_null<ElementsAttr>();
  if (!(start_tensor && limit_tensor && delta_tensor)) return nullptr;

  // Operands should all be scalars.
  assert(start_tensor.getType().getRank() == 0 &&
         limit_tensor.getType().getRank() == 0 &&
         delta_tensor.getType().getRank() == 0);
  Type elem_type = getType().cast<ShapedType>().getElementType();
  if (elem_type.isSignlessInteger() || elem_type.isUnsignedInteger()) {
    auto start_attr = start_tensor.getValues<IntegerAttr>()[0];
    auto limit_attr = limit_tensor.getValues<IntegerAttr>()[0];
    auto delta_attr = delta_tensor.getValues<IntegerAttr>()[0];
    int num_elements;
    if (elem_type.isUnsignedInteger()) {
      uint64_t start = start_attr.getUInt();
      uint64_t limit = limit_attr.getUInt();
      uint64_t delta = delta_attr.getUInt();
      assert(start <= (uint64_t)INT_MAX);
      assert(limit <= (uint64_t)INT_MAX);
      assert(delta <= (uint64_t)INT_MAX);
      num_elements =
          GetLengthOfRange(static_cast<int>(start), static_cast<int>(limit),
                           static_cast<int>(delta));
    } else {
      num_elements = GetLengthOfRange(start_attr.getInt(), limit_attr.getInt(),
                                      delta_attr.getInt());
    }
    return BuildConstRangeTensor(elem_type, num_elements, start_attr,
                                 delta_attr);
  } else if (elem_type.isa<FloatType>()) {
    auto start_attr = start_tensor.getValues<FloatAttr>()[0];
    auto limit_attr = limit_tensor.getValues<FloatAttr>()[0];
    auto delta_attr = delta_tensor.getValues<FloatAttr>()[0];
    const int num_elements = GetLengthOfRange(start_attr.getValueAsDouble(),
                                              limit_attr.getValueAsDouble(),
                                              delta_attr.getValueAsDouble());
    return BuildConstRangeTensor(elem_type, num_elements, start_attr,
                                 delta_attr);
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

void RankOp::build(OpBuilder &builder, OperationState &result, Value input) {
  return RankOp::build(builder, result,
                       RankedTensorType::get({}, builder.getIntegerType(32)),
                       input);
}

// This will create a constant value for RankOp of a ranked tensor.
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
  auto type = input().getType();
  auto ranked_type = type.dyn_cast<RankedTensorType>();
  if (!ranked_type) return {};

  // DenseIntElementsAttr::get requires the output type be ranked with static
  // shape.
  auto output_type = getType().dyn_cast<RankedTensorType>();
  if (!output_type || !output_type.hasStaticShape()) return {};

  int32_t rank = ranked_type.getRank();
  return DenseIntElementsAttr::get(output_type, rank);
}

//===----------------------------------------------------------------------===//
// RealDivOp
//===----------------------------------------------------------------------===//

void RealDivOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<RealDivWithSqrtDivisor, RealDivWithConstDivisor>(context);
}

OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) {
  return IdentityArithmeticOpFolder<RealDivOp>(*this, operands);
}

//===----------------------------------------------------------------------===//
// ReluOp
//===----------------------------------------------------------------------===//

void ReluOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ReluOfMinimum6ToRelu6>(context);
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

namespace {
using ReshapeErrorHandler =
    llvm::function_ref<LogicalResult(const llvm::Twine &)>;

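// For example, reshaping a tensor<6x4xf32> with a constant shape of [-1, 8]
// infers the output type tensor<3x8xf32>: the -1 dimension is computed so the
// total element count (24) is preserved.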
LogicalResult GetReshapeOutputType(Value tensor, Value shape,
                                   ReshapeErrorHandler error_handler,
                                   TensorType &output_ty) {
  auto tensor_ty = tensor.getType().cast<TensorType>();
  auto element_ty = tensor_ty.getElementType();
  output_ty = UnrankedTensorType::get(element_ty);

  auto shape_ty = shape.getType().dyn_cast<RankedTensorType>();
  if (!shape_ty) return success();
  if (shape_ty.getRank() != 1)
    return error_handler(llvm::formatv(
        "requires 'shape' to be rank 1, but got {0}", shape_ty.getRank()));

  DenseIntElementsAttr shape_attr;
  if (!matchPattern(shape, m_Constant(&shape_attr))) {
    // If only the shape of `shape` is known, return a ranked but dynamic
    // output shape.
    if (shape_ty.hasStaticShape()) {
      llvm::SmallVector<int64_t, 8> dynamic_shape(shape_ty.getDimSize(0),
                                                  ShapedType::kDynamicSize);
      output_ty = RankedTensorType::get(dynamic_shape, element_ty);
    }
    return success();
  }

  // Detect if the reshape output shape is folded.
  bool shape_ty_zero_dim = false;
  int unknown_index = -1;
  // The product of the constant shape argument, excluding the unknown
  // dimension.
  int64_t shape_ty_size = 1;
  llvm::SmallVector<int64_t, 8> output_ty_shape;
  output_ty_shape.reserve(shape_attr.getNumElements());
  for (const auto &dim : llvm::enumerate(shape_attr.getValues<APInt>())) {
    const int64_t size = dim.value().getSExtValue();
    if (ShapedType::isDynamic(size)) {
      if (unknown_index != -1)
        return error_handler(llvm::formatv(
            "requires 'shape' to have at most one dynamic dimension, but got "
            "multiple dynamic dimensions at indices {0} and {1}",
            unknown_index, dim.index()));

      unknown_index = dim.index();
    } else if (size == 0) {
      shape_ty_zero_dim = true;
    } else if (size > 0) {
      shape_ty_size *= size;
    } else {
      return error_handler(
          llvm::formatv("requires 'shape' to have dimensions greater than -1, "
                        "but got {0} at index {1}",
                        size, dim.index()));
    }
    output_ty_shape.push_back(size);
  }

  if (!tensor_ty.hasStaticShape()) {
    output_ty = RankedTensorType::get(output_ty_shape, element_ty);
    return success();
  }

  // Compute the value of the unknown dimension.
  if (unknown_index != -1) {
    // Compute the number of elements in the tensor shape.
    int64_t tensor_ty_size = 1;
    bool tensor_ty_zero_dim = false;
    for (const auto &dim : tensor_ty.getShape()) {
      if (dim > 0 || !shape_ty_zero_dim) {
        tensor_ty_size *= dim;
      } else {
        tensor_ty_zero_dim = true;
      }
    }

    const int64_t missing_dim = tensor_ty_size / shape_ty_size;
    if (!tensor_ty_zero_dim && shape_ty_size * missing_dim != tensor_ty_size)
      return error_handler(
          llvm::formatv("requires 'tensor' number of elements be a multiple of "
                        "{0}, but got {1}",
                        shape_ty_size, tensor_ty_size));

    // Set the unknown dimension such that the total number of elements remains
    // constant.
    output_ty_shape[unknown_index] = missing_dim;
  }

  output_ty = RankedTensorType::get(output_ty_shape, element_ty);

  return success();
}
}  // namespace

LogicalResult ReshapeOp::verify() {
  ReshapeOp op = *this;
  auto error_handler = [&op](const llvm::Twine &message) -> LogicalResult {
    return op.emitOpError() << message;
  };
  TensorType expected_ty;
  if (failed(GetReshapeOutputType(op.tensor(), op.shape(), error_handler,
                                  expected_ty)))
    return failure();

  auto output_ty = op.getType().dyn_cast<RankedTensorType>();
  if (!output_ty) return success();
  auto tensor_ty = op.tensor().getType().cast<TensorType>();
  if (output_ty.hasStaticShape() && tensor_ty.hasStaticShape()) {
    const int64_t output_ty_size = output_ty.getNumElements();
    const int64_t tensor_ty_size = tensor_ty.getNumElements();
    if (tensor_ty_size != output_ty_size)
      return op.emitOpError() << "requires 'output' number of elements to "
                                 "match 'tensor' number of elements, but got "
                              << output_ty_size << " and " << tensor_ty_size;
  }

  if (!AreCastCompatible({output_ty, expected_ty}))
    return op.emitOpError()
           << "requires 'output' type " << output_ty
           << " to be cast compatible with expected type " << expected_ty;

  return success();
}

// Currently there are use cases that rely on partial evaluation of the `shape`
// operand, so InferTypeOpInterface is not used (along with the generated
// builder of the same signature).
void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor,
                      Value shape) {
  auto error_handler = [&result](const llvm::Twine &message) {
    return mlir::emitError(result.location) << message;
  };
  TensorType output_ty;
  if (failed(GetReshapeOutputType(tensor, shape, error_handler, output_ty)))
    return;

  return ReshapeOp::build(builder, result, output_ty, tensor, shape);
}

void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<RedundantReshape, ReshapeToSelfShape>(context);
}

OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
  Value tensor = this->tensor();

  // Fold reshape if operand and result types are the same and all dimensions
  // are statically known (no-op reshape).
  auto result_ty = getType().dyn_cast<ShapedType>();
  if (result_ty && result_ty.hasStaticShape() &&
      result_ty == tensor.getType()) {
    return tensor;
  }

  return {};
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

// Verifies a few extra requirements on SelectOp:
// (1) `then` and `else` must have same shape
// (2) At least one of the following must be true:
//     (a) `cond` has the same rank as `then` and `else`
//     (b) `cond` is a scalar
//     (c) `cond` is a vector AND `then` and `else` are non-scalar with their
//         first dimension equal to `cond`.
LogicalResult SelectOp::verify() {
  SelectOp op = *this;
  auto then_tensor = op.t().getType().cast<TensorType>();
  auto else_tensor = op.e().getType().cast<TensorType>();
  // Check (1).
  if (!AreCastCompatible({then_tensor, else_tensor}))
    return op.emitOpError() << "requires t and e have compatible shapes";

  // Get data rank (if it exists).
  int data_rank;
  // If data is unranked or data_rank is 0, this will remain -2. Otherwise it
  // refers to the first dimension of then and/or else.
  int data_first_dim = -2;
  bool then_has_rank = then_tensor.hasRank();
  bool else_has_rank = else_tensor.hasRank();
  if (then_has_rank && else_has_rank) {
    data_rank = then_tensor.getRank();
    if (then_tensor.getRank() > 0)
      data_first_dim = then_tensor.getShape().front();
    if (else_tensor.getRank() > 0)
      data_first_dim = std::max(
          static_cast<int>(else_tensor.getShape().front()), data_first_dim);
  } else if (then_has_rank) {
    data_rank = then_tensor.getRank();
    if (then_tensor.getRank() > 0)
      data_first_dim = then_tensor.getShape().front();
  } else if (else_has_rank) {
    data_rank = else_tensor.getRank();
    if (else_tensor.getRank() > 0)
      data_first_dim = else_tensor.getShape().front();
  } else {
    // Neither has a rank.
    return success();
  }

  auto cond_tensor = op.condition().getType().dyn_cast<RankedTensorType>();
  if (!cond_tensor) return success();
  auto cond_rank = cond_tensor.getRank();
  // Check (2a) and (2b).
  if (cond_rank == 0 || cond_rank == data_rank) return success();
  // Check (2c).
  if (cond_rank == 1) {
    auto cond_shape = cond_tensor.getShape().front();
    if (data_rank == 0) {
      return op.emitOpError()
             << "requires that t and e are nonscalar when pred is a vector";
    }
    // We know `data` tensor has a rank of at least 1.
    if (data_first_dim != -1 && cond_shape != -1 &&
        data_first_dim != cond_shape) {
      return op.emitOpError() << "requires that, when pred is a vector, the "
                                 "shape matches the first dimension of t and e";
    }
    return success();
  }
  // None of (2a,b,c) were true; fail.
  return op.emitOpError() << "requires that pred is a scalar OR has the same "
                             "rank as t and e OR is a vector";
}

//===----------------------------------------------------------------------===//
// SelectV2Op
//===----------------------------------------------------------------------===//

static Type InferSelectV2OpType(Value condition, Value e, Value t) {
  Type element_ty = e.getType().cast<TensorType>().getElementType();
  auto unranked_ty = UnrankedTensorType::get(element_ty);

  Type broadcasted_ty =
      OpTrait::util::getBroadcastedType(e.getType(), t.getType());
  if (!broadcasted_ty) return unranked_ty;

  auto cond_ranked_ty = condition.getType().dyn_cast<RankedTensorType>();
  auto broadcasted_ranked_ty = broadcasted_ty.dyn_cast<RankedTensorType>();
  if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty;

  // Explicitly get the broadcasted output type, as the element type of
  // `condition` may not be the same as the broadcasted type's element type.
  SmallVector<int64_t, 4> result_shape;
  if (!OpTrait::util::getBroadcastedShape(cond_ranked_ty.getShape(),
                                          broadcasted_ranked_ty.getShape(),
                                          result_shape))
    return unranked_ty;
  return RankedTensorType::get(result_shape, element_ty);
}

void SelectV2Op::build(OpBuilder &builder, OperationState &result,
                       Value condition, Value e, Value t) {
  build(builder, result, InferSelectV2OpType(condition, e, t), condition, e, t);
}

//===----------------------------------------------------------------------===//
// ShapeOp
//===----------------------------------------------------------------------===//

namespace {
// Validates Shape/ShapeN/VariableShape operand and associated result types.
LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
                                          Type result_type,
                                          int variadic_idx = -1) {
  std::string variadic_idx_str =
      variadic_idx < 0 ? "" : llvm::formatv(" #{0}", variadic_idx).str();

  auto result_ranked_type = result_type.dyn_cast<RankedTensorType>();
  if (!result_ranked_type) return success();
  if (result_ranked_type.getShape().size() != 1)
    return op->emitOpError("requires 1D type for result") << variadic_idx_str;

  auto operand_ranked_type = operand_type.dyn_cast_or_null<RankedTensorType>();
  if (operand_ranked_type) {
    // The operand is a ranked tensor.
    if (result_ranked_type.hasStaticShape() &&
        !operand_ranked_type.getShape().empty() &&
        result_ranked_type.getDimSize(0) !=
            operand_ranked_type.getShape().size())
      return op->emitOpError("requires dimension size of result")
             << variadic_idx_str << " to match rank of operand"
             << variadic_idx_str;
  } else if (result_ranked_type.hasStaticShape()) {
    // The operand is an unranked tensor, print a warning if the result
    // is static.
    // Note: We do not handle this situation as an error, this would be too
    // restrictive due to incompleteness of shape inference at this point.
    mlir::InFlightDiagnostic diag =
        mlir::emitWarning(op->getLoc(), "has static shape result");
    if (op->getContext()->shouldPrintOpOnDiagnostic()) {
      diag.attachNote(op->getLoc())
          .append("see current operation: ")
          .appendOp(*op, OpPrintingFlags().printGenericOpForm());
    }
    diag << variadic_idx_str << " for unranked operand" << variadic_idx_str;
  }

  Type element_type = result_ranked_type.getElementType();
  if (!element_type.isSignlessInteger(32) &&
      !element_type.isSignlessInteger(64))
    return op->emitOpError("requires int32 or int64 return type for result")
           << variadic_idx_str;

  return success();
}
}  // anonymous namespace

LogicalResult ShapeOp::verify() {
  ShapeOp op = *this;
  return VerifyShapeOperandAndResult(op, op.input().getType(), op.getType());
}

// Converts the shape of the given type to an attribute if it is of ranked
// tensor type. The returned attribute has integer elements of the given width.
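// For example, a tensor<2x3xf32> input with out_width = 32 produces
// dense<[2, 3]> : tensor<2xi32>; unranked or dynamically shaped inputs
// produce a null attribute.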
static Attribute ConvertShapeToAttr(Type input_ty, int out_width) {
  auto ranked_ty = input_ty.dyn_cast<RankedTensorType>();
  if (!ranked_ty || !ranked_ty.hasStaticShape()) return {};

  auto shape = ranked_ty.getShape();
  int rank = shape.size();

  SmallVector<APInt, 4> dimensions;
  dimensions.reserve(rank);
  for (int i = 0; i < rank; ++i)
    dimensions.push_back(APInt(out_width, shape[i]));

  auto result_type = RankedTensorType::get(
      {rank}, IntegerType::get(input_ty.getContext(), out_width));
  return DenseElementsAttr::get(result_type, dimensions);
}

OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
  int width =
      getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
  return ConvertShapeToAttr(getOperand().getType(), width);
}

void ShapeOp::build(OpBuilder &builder, OperationState &result, Value input,
                    BoolAttr use32Bit) {
  auto rankedTensorType = input.getType().dyn_cast<RankedTensorType>();
  int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1;
  auto out_type = use32Bit.getValue() ? builder.getIntegerType(32)
                                      : builder.getIntegerType(64);
  return ShapeOp::build(builder, result,
                        RankedTensorType::get({rank}, out_type), input);
}

//===----------------------------------------------------------------------===//
// ShapeNOp
//===----------------------------------------------------------------------===//

LogicalResult ShapeNOp::verify() {
  ShapeNOp op = *this;
  const size_t num_tensors = op.N();

  if (op.getNumOperands() != num_tensors)
    return op.emitOpError() << "requires " << num_tensors << " operand(s), got "
                            << op.getNumOperands() << " operand(s)";

  if (op.getNumResults() != num_tensors)
    return op.emitOpError() << "requires " << num_tensors << " result(s), got "
                            << op.getNumResults() << " result(s)";

  for (auto i : llvm::seq<uint64_t>(0, num_tensors)) {
    auto verification = VerifyShapeOperandAndResult(
        op, op.getOperand(i).getType(), op.getResult(i).getType(), i);
    if (failed(verification)) return verification;
  }

  return success();
}

namespace {
// Canonicalization pattern for ShapeNOp ops that don't have all static input
// shapes. Replacing output values corresponding to static input types may
// enable optimizations in users of the values.
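// For example, tf.ShapeN(%static, %dynamic), where only %static has a fully
// static shape, is rewritten so that the first result is a tf.Const and only
// %dynamic is fed to a smaller tf.ShapeN.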
class ShapeNPartialStaticInputShape : public OpRewritePattern<ShapeNOp> {
  using OpRewritePattern<ShapeNOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ShapeNOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() == 0) {
      rewriter.eraseOp(op);
      return success();
    }

    int width = getElementTypeOrSelf(op.getType(0)).getIntOrFloatBitWidth();

    SmallVector<Value, 4> results(op.getNumOperands());
    SmallVector<int64_t, 4> dynamic_indices;
    SmallVector<Value, 4> dynamic_inputs;
    SmallVector<Type, 4> result_types;
    for (auto e : llvm::enumerate(op.getOperands())) {
      if (Attribute result = ConvertShapeToAttr(e.value().getType(), width)) {
        results[e.index()] = rewriter.create<TF::ConstOp>(op.getLoc(), result);
      } else {
        dynamic_indices.push_back(e.index());
        dynamic_inputs.push_back(e.value());
        result_types.push_back(op.getType(e.index()));
      }
    }

    if (dynamic_inputs.size() == op.getNumOperands()) {
      // Cannot canonicalize ShapeN if all inputs are dynamic.
      return failure();
    }

    // Create a ShapeNOp for all dynamic inputs.
    if (!dynamic_inputs.empty()) {
      auto dynamic_shape_n = rewriter.create<TF::ShapeNOp>(
          op.getLoc(), result_types, dynamic_inputs);
      for (auto index_result :
           llvm::zip(dynamic_indices, dynamic_shape_n.getResults())) {
        results[std::get<0>(index_result)] = std::get<1>(index_result);
      }
    }

    rewriter.replaceOp(op, results);
    return success();
  }
};

// Canonicalize ShapeNOp to ShapeOp if there is only one operand.
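// For example, a single-operand %s = "tf.ShapeN"(%x) becomes
// %s = "tf.Shape"(%x) with the same result type.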
class ShapeNToShape : public OpRewritePattern<ShapeNOp> {
  using OpRewritePattern<ShapeNOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ShapeNOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 1) {
      return failure();
    }
    auto shape = rewriter.create<TF::ShapeOp>(op.getLoc(), op.getType(0),
                                              op.getOperand(0));
    rewriter.replaceOp(op, {shape});
    return success();
  }
};
}  // namespace

void ShapeNOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<ShapeNToShape, ShapeNPartialStaticInputShape>(context);
}

//===----------------------------------------------------------------------===//
// SizeOp
//===----------------------------------------------------------------------===//

// Verifies that:
//
// * The input type, if it is a ranked tensor, has at most INT32_MAX
//   dimensions.
//
LogicalResult SizeOp::verify() {
  SizeOp op = *this;
  if (!HasRankAtMost(op.input(), std::numeric_limits<int32_t>::max()))
    return op.emitOpError(
        "requires ranked input tensor to be of rank INT32_MAX or less");

  // Output type needs to be scalar.
  if (!IsOfRankOrUnranked(op.output(), /*rank=*/0))
    return op.emitOpError("requires scalar output");

  return success();
}

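// Folds Size when the operand has a fully static shape; for example, an input
// of type tensor<2x3x4xf32> folds to a constant scalar 24.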
OpFoldResult SizeOp::fold(ArrayRef<Attribute> operands) {
  ShapedType output_type = getType().cast<ShapedType>();
  if (!output_type.hasRank()) return {};
  ShapedType input_type = getOperand().getType().cast<ShapedType>();
  if (!input_type.hasStaticShape()) return {};
  int size = input_type.getNumElements();
  return DenseElementsAttr::get(
      output_type,
      IntegerAttr::get(output_type.getElementType(), /*value=*/size));
}

//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//

// Verifies that:
//
// - operands begin and size are 1D with the same number of elements.
// - if the input is a ranked tensor, the rank of the input equals the number
//   of elements in operands begin and size.
// - if begin is constant, that
//   0 <= begin[i] <= begin[i] + size[i] <= input_ty.getShape()[i]
//   and
//   size[i] == output_ty.getShape()[i]
// - if begin isn't constant but the input is a ranked tensor, that
//   size[i] <= input_ty.getShape()[i]
// - output rank is the same as input rank
//
LogicalResult SliceOp::verify() {
  SliceOp op = *this;
  RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin());
  if (begin_ty && begin_ty.getRank() != 1) {
    return op.emitOpError() << "requires begin operand to be 1D tensor";
  }

  RankedTensorType size_ty = GetRankedTensorTypeForOperand(op.size());
  if (size_ty && size_ty.getRank() != 1) {
    return op.emitOpError() << "requires size operand to be 1D tensor";
  }

  if (!begin_ty || !size_ty || !begin_ty.hasStaticShape() ||
      !size_ty.hasStaticShape())
    return success();

  if (begin_ty.getNumElements() != size_ty.getNumElements()) {
    return op.emitOpError() << "requires begin and size operands to have the"
                               " same number of elements";
  }

  auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
  if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) {
    return op.emitOpError() << "requires number of elements in begin and size "
                               "are equal to input rank";
  }

  auto output_ty = op.output().getType().dyn_cast<RankedTensorType>();
  if (output_ty && input_ty && output_ty.getRank() != input_ty.getRank()) {
    return op.emitOpError()
           << "requires output to have the same rank as input, but got input "
              "rank "
           << input_ty.getRank() << " and output rank " << output_ty.getRank();
  }

  DenseIntElementsAttr begin_indices;
  if (matchPattern(op.begin(), m_Constant(&begin_indices))) {
    DenseIntElementsAttr slice_sizes;
    bool constant_slice_sizes =
        matchPattern(op.size(), m_Constant(&slice_sizes));
    int dim = 0;
    // TODO(jpienaar): Reformulate the shape verification below to not use magic
    // constants.
    for (const APInt &raw_begin_index : begin_indices.getValues<APInt>()) {
      int64_t begin_index = raw_begin_index.getSExtValue();
      int64_t input_size = input_ty ? input_ty.getShape()[dim] : -1;
      int64_t slice_size =
          constant_slice_sizes
              ? slice_sizes.getValues<APInt>()[dim].getSExtValue()
              : 0;
      int64_t output_size = output_ty ? output_ty.getShape()[dim] : -1;

      if (slice_size == -1 && input_size != -1) {
        slice_size = input_size - begin_index;
      }
      if (output_size != -1 && constant_slice_sizes &&
          output_size != slice_size) {
        return op.emitOpError()
               << "requires output size to have the same size of slice, got "
                  "slice size "
               << slice_size << " and output size " << output_size;
      }
      if (begin_index < 0 ||
          (input_size != -1 && begin_index + slice_size > input_size)) {
        return op.emitOpError()
               << "requires 0 <= begin[i] <= begin[i] + size[i] <= Di";
      }
      ++dim;
    }
  } else if (input_ty) {
    // If the inputs are ranked, we can do a few more sanity checks.
    DenseIntElementsAttr slice_sizes;
    if (matchPattern(op.size(), m_Constant(&slice_sizes))) {
      auto input_shape = input_ty.getShape();
      for (int64_t i = 0; i < input_ty.getRank(); ++i) {
        int64_t slice_size = slice_sizes.getValues<APInt>()[i].getSExtValue();
        int64_t input_size = input_shape[i];
1317         if (slice_size != -1 && input_size != -1 && slice_size > input_size) {
1318           return op.emitOpError() << "requires size[i] <= Di, even if begin[i] "
1319                                      "is unknown at compile time";
1320         }
1321       }
1322     }
1323   }
1324 
1325   return success();
1326 }
1327 
1328 //===----------------------------------------------------------------------===//
1329 // SoftmaxOp
1330 //===----------------------------------------------------------------------===//
1331 
1332 LogicalResult SoftmaxOp::verify() {
1333   SoftmaxOp op = *this;
1334   if (!HasRankAtLeast(op.logits(), 1)) {
1335     return op.emitOpError("requires operand to have rank at least 1");
1336   }
1337   return success();
1338 }
1339 
1340 //===----------------------------------------------------------------------===//
1341 // SoftmaxCrossEntropyWithLogitsOp
1342 //===----------------------------------------------------------------------===//
1343 
1344 // Verifies that,
1345 //
1346 // * Input types are broadcast compatible and the broadcasted type has rank two.
1347 //
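// For example (shapes illustrative), features of type tensor<2x3xf32> and
// labels of type tensor<3xf32> broadcast to tensor<2x3xf32>, which has rank
// two, so verification passes; rank-three operands would be rejected.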
1348 LogicalResult SoftmaxCrossEntropyWithLogitsOp::verify() {
1349   SoftmaxCrossEntropyWithLogitsOp op = *this;
1350   auto broadcasted_ty = OpTrait::util::getBroadcastedType(
1351                             op.features().getType(), op.labels().getType())
1352                             .dyn_cast_or_null<ShapedType>();
1353   if (!broadcasted_ty ||
1354       (broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2))
1355     return op.emitOpError(
1356         "requires features and labels to be broadcast compatible to rank two");
1357 
1358   return success();
1359 }
1360 
1361 //===----------------------------------------------------------------------===//
1362 // SpaceToBatchNDOp
1363 //===----------------------------------------------------------------------===//
1364 
1365 int64_t SpaceToBatchNDBlockRank(const TensorType block_shape_type,
1366                                 const TensorType paddings_type) {
1367   if (block_shape_type.hasStaticShape()) {
1368     return block_shape_type.getShape()[0];
1369   } else if (paddings_type.hasStaticShape()) {
1370     return paddings_type.getShape()[0];
1371   } else {
1372     return -1;
1373   }
1374 }
1375 
1376 LogicalResult SpaceToBatchNDOp::verify() {
1377   SpaceToBatchNDOp op = *this;
1378   const auto input_type = op.input().getType().cast<TensorType>();
1379   const auto block_shape_type = op.block_shape().getType().cast<TensorType>();
1380   const auto paddings_type = op.paddings().getType().cast<TensorType>();
1381 
1382   // Check that block_shape has rank 1.
1383   if (!IsOfRankOrUnranked(op.block_shape(), 1)) {
1384     return op.emitOpError() << "requires rank of block_shape = 1; got "
1385                             << block_shape_type.getRank();
1386   }
1387 
1388   // Check that paddings has rank 2.
1389   if (!IsOfRankOrUnranked(op.paddings(), 2)) {
1390     return op.emitOpError()
1391            << "requires rank of paddings = 2; got " << paddings_type.getRank();
1392   }
1393 
1394   // Check that paddings.shape[1]=2.
1395   if (paddings_type.hasStaticShape() && paddings_type.getShape()[1] != 2) {
1396     return op.emitOpError() << "requires paddings.shape[1] to be 2; got "
1397                             << paddings_type.getShape()[1];
1398   }
1399 
1400   // Check that block_shape and paddings have consistent ranks.
1401   if (block_shape_type.hasStaticShape() && paddings_type.hasStaticShape() &&
1402       block_shape_type.getShape()[0] != paddings_type.getShape()[0]) {
1403     return op.emitOpError()
1404            << "requires block_shape.shape[0] must equal paddings.shape[0]";
1405   }
1406 
1407   const int64_t block_rank =
1408       SpaceToBatchNDBlockRank(block_shape_type, paddings_type);
1409 
1410   // Further checks require block_rank to be known.
1411   if (block_rank == -1) {
1412     return success();
1413   }
1414 
1415   // Check that rank of input_type >= 1 + block_rank.
1416   if (input_type.hasRank() && input_type.getRank() < 1 + block_rank) {
1417     return op.emitOpError() << "requires rank of input >= 1 + rank of block";
1418   }
1419 
1420   ElementsAttr block_shape_attr = nullptr;
1421   ElementsAttr paddings_attr = nullptr;
1422 
1423   // Check that block_shape[*] >= 1.
1424   if (matchPattern(op.block_shape(), m_Constant(&block_shape_attr))) {
1425     uint64_t i = 0;
1426     for (auto block_len : block_shape_attr.getValues<APInt>()) {
1427       if (block_len.getSExtValue() < 1) {
1428         return op.emitOpError()
1429                << "requires all values of block_shape to be >= 1; "
1430                   "failed for dimension "
1431                << i;
1432       }
1433       ++i;
1434     }
1435   }
1436 
1437   // Check that paddings[*] >= 0.
1438   if (matchPattern(op.paddings(), m_Constant(&paddings_attr))) {
1439     for (uint64_t i = 0; i < block_rank; ++i) {
1440       const int64_t pad_start =
1441           paddings_attr.getValues<APInt>()[{i, 0}].getSExtValue();
1442       const int64_t pad_end =
1443           paddings_attr.getValues<APInt>()[{i, 1}].getSExtValue();
1444       if (pad_start < 0 || pad_end < 0) {
1445         return op.emitOpError()
1446                << "requires all values of paddings to be >= 0; "
1447                   "failed for dimension "
1448                << i;
1449       }
1450     }
1451   }
1452 
1453   // Check that block_shape divides the padded input.
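  // For example (values illustrative), input_shape[1 + i] = 5 with
  // paddings[i, :] = [1, 2] gives a padded length of 8, which is divisible by
  // a block_shape[i] of 2 or 4 but not by 3.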
1454   if (input_type.hasStaticShape() && block_shape_attr && paddings_attr) {
1455     for (uint64_t i = 0; i < block_rank; ++i) {
1456       const int64_t input_len = input_type.getShape()[1 + i];
1457       const int64_t pad_start =
1458           paddings_attr.getValues<APInt>()[{i, 0}].getSExtValue();
1459       const int64_t pad_end =
1460           paddings_attr.getValues<APInt>()[{i, 1}].getSExtValue();
1461       const int64_t block_len =
1462           block_shape_attr.getValues<APInt>()[i].getSExtValue();
1463       if ((input_len + pad_start + pad_end) % block_len != 0) {
1464         return op.emitOpError()
1465                << "requires block_shape[i] divides "
1466                   "input_shape[i + 1] + paddings[i, 0] + paddings[i, 1]; "
1467                   "failed for i="
1468                << i;
1469       }
1470     }
1471   }
1472 
1473   return success();
1474 }
1475 
1476 //===----------------------------------------------------------------------===//
1477 // SparseSoftmaxCrossEntropyWithLogitsOp
1478 //===----------------------------------------------------------------------===//
1479 
1480 LogicalResult SparseSoftmaxCrossEntropyWithLogitsOp::verify() {
1481   SparseSoftmaxCrossEntropyWithLogitsOp op = *this;
1482   if (!IsOfRankOrUnranked(op.features(), 2)) {
1483     return op.emitOpError("requires features operand of rank two");
1484   }
1485   if (!IsOfRankOrUnranked(op.labels(), 1)) {
1486     return op.emitOpError("requires labels operand of rank one");
1487   }
1488   auto features_ty = op.features().getType().dyn_cast<RankedTensorType>();
1489   auto labels_ty = op.labels().getType().dyn_cast<RankedTensorType>();
1490   if (features_ty && labels_ty) {
1491     int64_t features_batches = features_ty.getDimSize(0);
1492     int64_t labels_batches = labels_ty.getDimSize(0);
1493     if (!ShapedType::isDynamic(features_batches) &&
1494         !ShapedType::isDynamic(labels_batches) &&
1495         features_batches != labels_batches)
1496       return op.emitOpError(
1497           "requires features and labels with matching first dimension");
1498   }
1499   return success();
1500 }
1501 
1502 //===----------------------------------------------------------------------===//
1503 // SplitOp
1504 //===----------------------------------------------------------------------===//
1505 
1506 // Verifies the input and split dimension operands for tf.Split/tf.SplitV.
1507 // Writes the split dimension's index (adjusted with input rank) via `dim_index`
1508 // if it's a constant.
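// For example (values illustrative), with an input of rank 3 a constant
// split_dim of -1 is accepted and reported through `dim_index` as 2, while a
// split_dim of 3 or -4 falls outside [-3, 3) and is rejected.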
1509 template <class Op>
1510 LogicalResult VerifySplitInputAndSplitDim(Op op, Optional<int64_t> *dim_index) {
1511   *dim_index = llvm::None;
1512 
1513   Value split_dim = op.split_dim();
1514   if (auto split_dim_type = split_dim.getType().dyn_cast<RankedTensorType>())
1515     if (split_dim_type.getRank() != 0)
1516       return op.emitOpError(
1517           "split dimension should be an integer scalar tensor");
1518 
1519   // We can perform further verification if the input tensor to be split has
1520   // known rank and the split dimension tensor is a constant.
1521 
1522   auto input_type = op.value().getType().template dyn_cast<RankedTensorType>();
1523   if (!input_type) return success();
1524 
1525   int64_t input_rank = input_type.getRank();
1526   if (input_rank == 0)
1527     return op.emitOpError("cannot split scalar input tensor");
1528 
1529   DenseIntElementsAttr split_dim_attr;
1530   if (!matchPattern(split_dim, m_Constant(&split_dim_attr))) return success();
1531 
1532   int64_t index = (*split_dim_attr.begin()).getSExtValue();
1533 
1534   if (index + input_rank < 0 || index >= input_rank) {
1535     return op.emitOpError("split dimension must be in range [-")
1536            << input_rank << ", " << input_rank << ")";
1537   }
1538 
1539   if (index < 0) index += input_rank;
1540   *dim_index = index;
1541 
1542   return success();
1543 }
1544 
1545 LogicalResult SplitOp::verify() {
1546   SplitOp op = *this;
1547   Optional<int64_t> dim_index;
1548   if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
1549   if (!dim_index) return success();
1550 
1551   int64_t input_dim_size =
1552       op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
1553   if (ShapedType::isDynamic(input_dim_size)) return success();
1554 
1555   if (op.getNumResults() == 0) return failure();
1556 
1557   if (input_dim_size % op.getNumResults() != 0)
1558     return op.emitOpError("dimension #")
1559            << *dim_index << " not divisible by the number of result tensors";
1560 
1561   return success();
1562 }
1563 
1564 //===----------------------------------------------------------------------===//
1565 // SplitVOp
1566 //===----------------------------------------------------------------------===//
1567 
1568 LogicalResult SplitVOp::verify() {
1569   SplitVOp op = *this;
1570   auto split_sizes_type =
1571       op.size_splits().getType().dyn_cast<RankedTensorType>();
1572   if (!split_sizes_type) return success();
1573 
1574   if (split_sizes_type.getRank() != 1 ||
1575       (!ShapedType::isDynamic(split_sizes_type.getDimSize(0)) &&
1576        split_sizes_type.getDimSize(0) != op.getNumResults()))
1577     return op.emitOpError("split sizes should be a 1D tensor of ")
1578            << op.getNumResults() << " elements";
1579 
1580   Optional<int64_t> dim_index = 0;
1581   if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
1582   if (!dim_index) return success();
1583 
1584   int64_t input_dim_size =
1585       op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
1586   if (ShapedType::isDynamic(input_dim_size)) return success();
1587 
1588   // If split sizes come from a constant, they must sum to the dimension size
1589   // along split_dim, and we can have no more than one dynamic dimension.
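  // For example (values illustrative), size_splits = [2, -1, 3] with a split
  // dimension of size 10 is accepted; the single -1 (dynamic) entry implicitly
  // receives the remaining 5 elements. size_splits = [2, -1, -1] is rejected
  // for having two dynamic entries.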
1590   DenseIntElementsAttr split_sizes_attr;
1591   if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
1592     return success();
1593 
1594   int64_t total_dim_size = 0;  // Total dimension size assigned to splits
1595   llvm::Optional<int> dynamic_dim_index;
1596 
1597   SmallVector<int64_t, 4> split_sizes;
1598   split_sizes.reserve(
1599       split_sizes_attr.getType().cast<ShapedType>().getNumElements());
1600 
1601   for (auto dim : llvm::enumerate(split_sizes_attr)) {
1602     int64_t dim_val = dim.value().getSExtValue();
1603     split_sizes.push_back(dim_val);
1604     if (ShapedType::isDynamic(dim_val)) {
1605       // We cannot have more than one dynamic dimension.
1606       if (dynamic_dim_index)
1607         return op.emitOpError(
1608             "cannot have more than one dynamic dimension in split sizes");
1609       dynamic_dim_index = dim.index();
1610     } else {
1611       total_dim_size += dim_val;
1612     }
1613   }
1614 
1615   if (!dynamic_dim_index && total_dim_size != input_dim_size)
1616     return op.emitOpError(
1617                "split sizes must sum up to the dimension size along split "
1618                "dimension, found ")
1619            << total_dim_size << " vs " << input_dim_size;
1620 
1621   if (dynamic_dim_index && total_dim_size > input_dim_size)
1622     return op.emitOpError(
1623                "split sizes must sum up to be less than or equal to the "
1624                "dimension size along split dimension, found ")
1625            << total_dim_size << " vs " << input_dim_size;
1626 
1627   return success();
1628 }
1629 
1630 //===----------------------------------------------------------------------===//
1631 // SquareOp
1632 //===----------------------------------------------------------------------===//
1633 
1634 void SquareOp::getCanonicalizationPatterns(RewritePatternSet &results,
1635                                            MLIRContext *context) {
1636   results.add<SquareOfSub>(context);
1637 }
1638 
1639 //===----------------------------------------------------------------------===//
1640 // SqueezeOp
1641 //===----------------------------------------------------------------------===//
1642 
1643 LogicalResult SqueezeOp::verify() {
1644   SqueezeOp op = *this;
1645   auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
1646 
1647   if (!input_type) return success();  // Can't verify squeeze dims.
1648 
1649   int64_t input_rank = input_type.getRank();
1650   for (const auto &squeeze_dim_apint :
1651        op.squeeze_dims().getAsValueRange<IntegerAttr>()) {
1652     int64_t squeeze_dim = squeeze_dim_apint.getSExtValue();
1653     if (squeeze_dim < -input_rank || squeeze_dim >= input_rank) {
1654       return op.emitOpError()
1655              << "squeeze dimension " << squeeze_dim << " not in ["
1656              << -input_rank << ", " << input_rank << ")";
1657     }
1658   }
1659 
1660   return success();
1661 }
1662 
1663 //===----------------------------------------------------------------------===//
1664 // SubOp
1665 //===----------------------------------------------------------------------===//
1666 
1667 void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
1668                                         MLIRContext *context) {
1669   results.add<SubOfNeg>(context);
1670 }
1671 
1672 OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
1673   return IdentityArithmeticOpFolder<SubOp>(*this, operands);
1674 }
1675 
1676 //===----------------------------------------------------------------------===//
1677 // SumOp
1678 //===----------------------------------------------------------------------===//
1679 
1680 void SumOp::build(OpBuilder &builder, OperationState &result, Value input,
1681                   Value reduction_indices, BoolAttr keep_dims) {
1682   Type out_ty = InferReductionOpType(input, reduction_indices, keep_dims);
1683   build(builder, result, out_ty, input, reduction_indices, keep_dims);
1684 }
1685 
1686 // TODO: Templatize this fold for all reduction ops.
1687 OpFoldResult SumOp::fold(ArrayRef<Attribute> operands) {
1688   auto input_ty = input().getType().template dyn_cast<RankedTensorType>();
1689   if (!input_ty) return {};
1690   auto result_ty = getType().template dyn_cast<RankedTensorType>();
1691   if (!result_ty) return {};
1692 
1693   // Bypass this op if the result has the same shape and type. This can happen
1694   // if the input tensor has size 0 or size 1.
1695   if (!keep_dims() && input_ty == result_ty) {
1696     return input();
1697   }
1698   return {};
1699 }
1700 
1701 //===----------------------------------------------------------------------===//
1702 // StridedSliceOp
1703 //===----------------------------------------------------------------------===//
1704 
1705 // TODO(b/154160827): Add a canonicalization pattern from tf.StridedSliceOp to
1706 // tf.SliceOp if both of the following are true:
1707 // - All strides have a known value equal to 1
1708 // - No masks are set (or masks can be applied by transforming the inputs to
1709 //   Slice)
1710 
1711 // Verifies that,
1712 //
1713 // - begin, end and strides operands are 1D and they have the same number of
1714 //   elements. Here, the number of elements should be less than 32 to support
1715 //   32-bit mask attributes.
1716 // - None of the strides values are zero.
1717 // - Ellipsis mask can have at most one bit set.
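//
// For example (values illustrative), begin = [0, 1], end = [2, 3] and
// strides = [1, 1] are accepted, whereas strides = [1, 0] is rejected for the
// zero stride, and an ellipsis_mask of 0b101 is rejected because it contains
// two ellipses.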
1718 
1719 template <class OpTy>
1720 static LogicalResult VerifyStridedSliceBase(OpTy op) {
1721   // Expected size for the begin, end and strides vector operands.
1722   int64_t expected_size = -1;
1723 
1724   for (Value val : {op.begin(), op.end(), op.strides()}) {
1725     auto operand_ty = val.getType().dyn_cast<ShapedType>();
1726     if (!operand_ty || !operand_ty.hasStaticShape()) {
1727       // TensorFlow constant ops may have non-static shape because the shape is
1728       // not propagated during constant folding. If the defining op for this
1729       // operand is a constant op, use the constant op's attribute to get the
1730       // actual shape.
1731       DenseIntElementsAttr attr;
1732       if (!matchPattern(val, m_Constant(&attr))) continue;
1733       operand_ty = attr.getType();
1734     }
1735 
1736     if (operand_ty.getRank() != 1)
1737       return op.emitOpError()
1738              << "requires begin, end and strides to be 1D tensors";
1739 
1740     int64_t length = operand_ty.getDimSize(0);
1741     if (length == -1) continue;
1742 
1743     if (expected_size == -1) {
1744       // This op uses 32-bit masks.
1745       if (length >= 32)
1746         return op.emitOpError(
1747             "requires begin, end and strides operands with less than 32 "
1748             "elements");
1749 
1750       expected_size = length;
1751     } else if (length != expected_size) {
1752       return op.emitOpError() << "requires begin, end and strides to have the "
1753                                  "same number of elements";
1754     }
1755   }
1756 
1757   // If strides are constants, verify that none of the elements is zero.
1758   DenseIntElementsAttr strides;
1759   if (matchPattern(op.strides(), m_Constant(&strides))) {
1760     if (llvm::is_contained(strides.getValues<APInt>(), 0))
1761       return op.emitOpError("requires non-zero strides");
1762   }
1763 
1764   // Use bit compares to ensure ellipsis_mask is 0 or a power of 2, i.e. there
1765   // is at most one ellipsis.
1766   uint32_t ellipsis_mask = op.ellipsis_mask();
1767   if (ellipsis_mask != 0 && !llvm::isPowerOf2_32(ellipsis_mask))
1768     return op.emitOpError("cannot have multiple ellipses");
1769 
1770   return success();
1771 }
1772 
1773 LogicalResult StridedSliceOp::verify() { return VerifyStridedSliceBase(*this); }
1774 
1775 // Clamps the given `val`: returns `low` if `val` is less than `low`; returns
1776 // `high` if `high` is less than `val`; otherwise returns `val`.
1777 template <class T>
1778 constexpr const T &Clamp(const T &val, const T &low, const T &high) {
1779   assert(!(high < low));
1780   return (val < low) ? low : (high < val) ? high : val;
1781 }
1782 
1783 // Checks if the `index` bit of `val` is set.
1784 template <class T>
1785 constexpr bool IsSet(const T &val, unsigned index) {
1786   return (val & (1 << index)) != 0;
1787 }
1788 
1789 // Sets the `index` bit of `val`.
1790 template <class T>
1791 constexpr void Set(T &val, unsigned index) {
1792   val |= (1 << index);
1793 }
1794 
1795 // Unset the `index` bit of `val`.
1796 template <class T>
1797 constexpr void Unset(T &val, unsigned index) {
1798   val &= ~(1 << index);
1799 }
1800 
1801 // Copy the `src_index` bit of `src` to `dst_index` bit of `dst`.
1802 template <class T>
1803 constexpr void CopyBit(const T &src, unsigned src_index, T &dst,
1804                        unsigned dst_index) {
1805   if (IsSet(src, src_index))
1806     Set(dst, dst_index);
1807   else
1808     Unset(dst, dst_index);
1809 }
1810 
1811 // The sparse spec of strided slice does not necessarily have one entry per
1812 // input dimension. For example, the sparse spec for foo[..., 3:10] with foo of
1813 // shape (2, 4, 8) has dims = 2.
1814 struct SparseSliceSpec {
1815   int64_t dims;
1816   int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
1817   const ArrayRef<int64_t> &begin;
1818   const ArrayRef<int64_t> &end;
1819   const ArrayRef<int64_t> &strides;
1820 };
1821 
1822 // The dense spec of strided slice is the canonicalized version of sparse spec.
1823 // The number of dimensions of the dense spec corresponds to the number of
1824 // dimensions in the operand tensor.
1825 struct DenseSliceSpec {
1826   int64_t dims;
1827   int32_t begin_mask, end_mask, shrink_axis_mask;
1828   SmallVectorImpl<int64_t> &begin;
1829   SmallVectorImpl<int64_t> &end;
1830   SmallVectorImpl<int64_t> &strides;
1831 };
1832 
1833 // Makes a sparse spec into a dense index spec. The sparse spec does not
1834 // necessarily have one entry per input dimension, so this expands it into a
1835 // dense spec that does.
1836 //
1837 // For example suppose foo[...,3:, 2] on foo.shape=(2,2,3,4) then
1838 // we need to produce the missing begin_mask, end_mask for the first two
1839 // dimensions i.e. foo[:, :, 3:, 2].
1840 static void BuildDenseSliceSpec(const SparseSliceSpec &sparse,
1841                                 DenseSliceSpec *dense) {
1842   // Build expanded dense begin, end, strides, begin_mask, end_mask, and
1843   // shrink_axis_mask.
1844   dense->begin.resize(dense->dims);
1845   dense->end.resize(dense->dims);
1846   dense->strides.resize(dense->dims);
1847   dense->begin_mask = 0;
1848   dense->end_mask = 0;
1849   dense->shrink_axis_mask = 0;
1850 
1851   // Count number of new_axis after ellipsis. This helps in calculating the
1852   // number of dimensions ellipsis represents in the sparse spec.
1853   bool ellipsis_seen = false;
1854   int num_new_axis_after_ellipsis = 0;
1855   for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) {
1856     if (ellipsis_seen && IsSet(sparse.new_axis_mask, sparse_index))
1857       num_new_axis_after_ellipsis++;
1858     if (IsSet(sparse.ellipsis_mask, sparse_index)) ellipsis_seen = true;
1859   }
1860 
1861   int dense_index = 0;
1862   for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) {
1863     if (IsSet(sparse.new_axis_mask, sparse_index)) continue;
1864     if (IsSet(sparse.ellipsis_mask, sparse_index)) {
1865       auto next_index = std::min(dense->dims - (sparse.dims - sparse_index) +
1866                                      1 + num_new_axis_after_ellipsis,
1867                                  dense->dims);
1868       // Expand ellipsis into the appropriate dense indices. From current index
1869       // until next_index, all dimensions would have begin and end masks set and
1870       // stride 1, i.e., get all elements in those dimensions.
1871       for (; dense_index < next_index; ++dense_index) {
1872         dense->begin[dense_index] = dense->end[dense_index] = 0;
1873         dense->strides[dense_index] = 1;
1874         Set(dense->begin_mask, dense_index);
1875         Set(dense->end_mask, dense_index);
1876       }
1877       continue;
1878     }
1879     assert(dense_index < dense->dims);
1880     // Copy over the sparse indices to dense indices if ellipsis_mask and
1881     // new_axis_mask are not set.
1882     dense->begin[dense_index] = sparse.begin[sparse_index];
1883     dense->end[dense_index] = sparse.end[sparse_index];
1884     dense->strides[dense_index] = sparse.strides[sparse_index];
1885     CopyBit(sparse.begin_mask, sparse_index, dense->begin_mask, dense_index);
1886     CopyBit(sparse.end_mask, sparse_index, dense->end_mask, dense_index);
1887     CopyBit(sparse.shrink_axis_mask, sparse_index, dense->shrink_axis_mask,
1888             dense_index);
1889     dense_index++;
1890   }
1891 }
1892 
1893 // For the given `input_shape`, calculates the sliced shape using the given
1894 // `begin`, `end`, and `stride` ranges and `begin_mask`, `end_mask`, and
1895 // `shrink_axis_mask` masks. Updates the result back to `input_shape`. If
1896 // `shrink_axis_mask` is not zero, this function will not drop the corresponding
1897 // dimensions in `input_shape`; it will turn them into 1s. At the same time, it
1898 // canonicalizes `begin`, `end`, and `strides`. The calculation follows
1899 // tf.StridedSlice op semantics.
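// For example (values illustrative), a dimension of size 10 with begin = 2,
// end = 8 and stride = 3 yields a sliced size of 2 (the elements at indices 2
// and 5); if shrink_axis_mask is set for that dimension, the size becomes 1
// and end is canonicalized to begin + 1.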
1900 static void CalculateSlicedShapeFromDenseIndices(
1901     MutableArrayRef<int64_t> input_shape, int32_t begin_mask, int32_t end_mask,
1902     int32_t shrink_axis_mask, MutableArrayRef<int64_t> begin,
1903     MutableArrayRef<int64_t> end, MutableArrayRef<int64_t> stride) {
1904   assert(input_shape.size() <= 32);  // Only 32-bit masks are supported.
1905 
1906   // Make sure ranges' ranks are consistent with the input.
1907   assert(input_shape.size() == begin.size());
1908   assert(input_shape.size() == end.size());
1909   assert(input_shape.size() == stride.size());
1910 
1911   for (int i = 0, e = input_shape.size(); i < e; ++i) {
1912     if (ShapedType::isDynamic(input_shape[i])) continue;
1913 
1914     int64_t dim_i = input_shape[i];
1915     int64_t begin_i = begin[i];
1916     int64_t end_i = end[i];
1917     int64_t stride_i = stride[i];
1918 
1919     // [0]: mask for begin, [1]: mask for end
1920     int64_t masks[] = {begin_mask & (1 << i), end_mask & (1 << i)};
1921     // [0]: bound for begin, [1]: bound for end
1922     int64_t bounds[] = {stride_i > 0 ? 0 : -1,
1923                         stride_i > 0 ? dim_i : dim_i - 1};
1924 
1925     // Canonicalizes the given range `point` (begin/end) according to the
1926     // current dimension. `c` means case: 0 for begin, 1 for end.
1927     auto canonicalize = [&](int64_t point, int c) {
1928       if (masks[c]) return stride_i > 0 ? bounds[c] : bounds[(c + 1) & 1];
1929 
1930       // Add dim as offset to negative range point.
1931       point = point < 0 ? dim_i + point : point;
1932       return Clamp(point, bounds[0], bounds[1]);
1933     };
1934 
1935     begin_i = canonicalize(begin_i, 0);
1936     end_i = canonicalize(end_i, 1);
1937 
1938     int64_t interval_len = end_i - begin_i;
1939     int64_t size_i = 0;
1940     // If the interval length is zero or has a different sign from the stride,
1941     // it's a degenerate case: we are slicing nothing. Otherwise, calculate the
1942     // sliced size.
1943     if (interval_len != 0 && (interval_len < 0) == (stride_i < 0))
1944       size_i = (interval_len / stride_i) + (interval_len % stride_i != 0);
1945 
1946     begin[i] = begin_i;
1947     if (IsSet(shrink_axis_mask, i)) {
1948       // Shrink this dimension. It means we only take the element at begin_i.
1949       input_shape[i] = 1;
1950       end[i] = begin_i + 1;
1951       stride[i] = 1;
1952     } else {
1953       input_shape[i] = size_i;
1954       end[i] = end_i;
1955       stride[i] = stride_i;
1956     }
1957   }
1958 }
1959 
1960 // For the given `input_shape`, calculates the sliced shape using the given
1961 // `sparse_begin`, `sparse_end`, and `sparse_strides` ranges and `begin_mask`,
1962 // `end_mask`, `ellipsis_mask` , `new_axis_mask` and `shrink_axis_mask` masks.
1963 // Updates the result back to `input_shape`.
1964 static void CalculateSlicedShapeFromSparseIndices(
1965     MutableArrayRef<int64_t> input_shape, ArrayRef<int64_t> sparse_begin,
1966     ArrayRef<int64_t> sparse_end, ArrayRef<int64_t> sparse_strides,
1967     int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask,
1968     int32_t new_axis_mask, int32_t shrink_axis_mask,
1969     SmallVectorImpl<int64_t> *begin, SmallVectorImpl<int64_t> *end,
1970     SmallVectorImpl<int64_t> *stride) {
1971   int64_t num_sparse_indices = sparse_begin.size();
1972   SparseSliceSpec sparse = {num_sparse_indices, begin_mask,    end_mask,
1973                             ellipsis_mask,      new_axis_mask, shrink_axis_mask,
1974                             sparse_begin,       sparse_end,    sparse_strides};
1975 
1976   // If no ellipsis_mask exists then an implicit ellipsis_mask at the end is
1977   // inserted. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields
1978   // a tensor of shape [2, 8], i.e., foo[2:4] is same as foo[2:4, ...].
1979   if (sparse.ellipsis_mask == 0) {
1980     Set(sparse.ellipsis_mask, sparse.dims);
1981     sparse.dims++;
1982   }
1983 
1984   int64_t dims = input_shape.size();
1985   DenseSliceSpec dense = {dims,
1986                           /*begin_mask = */ 0,
1987                           /*end_mask = */ 0,
1988                           /*shrink_axis_mask = */ 0,
1989                           *begin,
1990                           *end,
1991                           *stride};
1992 
1993   BuildDenseSliceSpec(sparse, &dense);
1994   CalculateSlicedShapeFromDenseIndices(input_shape, dense.begin_mask,
1995                                        dense.end_mask, dense.shrink_axis_mask,
1996                                        *begin, *end, *stride);
1997 }
1998 
1999 bool StridedSliceOp::GetSlicedBoundRanges(
2000     SmallVectorImpl<int64_t> *slice_begin, SmallVectorImpl<int64_t> *slice_end,
2001     SmallVectorImpl<int64_t> *slice_stride) {
2002   // TODO(hinsu): Support lowering for ops with dynamic begin and end values
2003   // when it is possible to derive indices based on mask attributes.
2004   DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr;
2005   if (!matchPattern(begin(), m_Constant(&sparse_begin_attr)) ||
2006       !matchPattern(end(), m_Constant(&sparse_end_attr)) ||
2007       !matchPattern(strides(), m_Constant(&sparse_strides_attr)))
2008     return false;
2009 
2010   auto input_ty = this->input().getType().dyn_cast<RankedTensorType>();
2011   if (!input_ty || !input_ty.hasStaticShape()) return false;
2012   auto input_shape = llvm::to_vector<4>(input_ty.getShape());
2013 
2014   SmallVector<int64_t, 4> sparse_begin, sparse_end, sparse_strides;
2015 
2016   for (const APInt &index : sparse_begin_attr)
2017     sparse_begin.push_back(index.getSExtValue());
2018   for (const APInt &index : sparse_end_attr)
2019     sparse_end.push_back(index.getSExtValue());
2020   for (const APInt &stride : sparse_strides_attr)
2021     sparse_strides.push_back(stride.getSExtValue());
2022 
2023   CalculateSlicedShapeFromSparseIndices(
2024       input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(),
2025       end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(),
2026       slice_begin, slice_end, slice_stride);
2027   return true;
2028 }
2029 
2030 OpFoldResult StridedSliceOp::fold(ArrayRef<Attribute> operands) {
2031   // Fold StridedSlice operation if it extracts statically known dimensions.
2032   //
2033   // For example,
2034   //
2035   //   %shape  = tf.Shape(%arg)                   // %arg: tensor<?x2x3x1xf32>
2036   //   %height = tf.StridedSlice(%shape, 1, 2, 1)
2037   //
2038   // In this case %height can be replaced with a constant 2.
2039   //
2040   // Or,
2041   //
2042   //   %shape  = tf.Shape(%arg)                   // %arg: tensor<?x2x3x1xf32>
2043   //   %spatial_shape = tf.StridedSlice(%shape, 1, 3, 1)
2044   //
2045   // In this case %spatial_shape can be replaced with a constant [2, 3].
2046 
2047   // Input to strided slice op is defined by shape operation.
2048   auto shape_op = input().getDefiningOp<ShapeOp>();
2049   if (!shape_op) {
2050     return {};
2051   }
2052 
2053   // `begin`, `end` and `strides` should be constant in order to infer static
2054   // dimension.
2055   DenseIntElementsAttr begin_attr, end_attr, strides_attr;
2056   if (!matchPattern(begin(), m_Constant(&begin_attr)) ||
2057       !matchPattern(end(), m_Constant(&end_attr)) ||
2058       !matchPattern(strides(), m_Constant(&strides_attr)) ||
2059       begin_attr.getNumElements() != 1 || end_attr.getNumElements() != 1 ||
2060       strides_attr.getNumElements() != 1) {
2061     return {};
2062   }
2063 
2064   // Do not fold when `new_axis_mask` is set. It's likely to break the shape
2065   // of output. Typically, `new_axis_mask` is not set in this canonicalization
2066   // pattern.
2067   if (new_axis_mask() != 0) return {};
2068 
2069   auto tensor_ty = shape_op.input().getType().dyn_cast<RankedTensorType>();
2070   // Only ranked tensor can be folded.
2071   if (!tensor_ty) return {};
2072 
2073   int64_t rank = tensor_ty.getRank();
2074   int64_t begin_int = begin_attr.getValues<APInt>()[0].getSExtValue();
2075   int64_t end_int = end_attr.getValues<APInt>()[0].getSExtValue();
2076   int64_t strides_int = strides_attr.getValues<APInt>()[0].getSExtValue();
2077 
2078   // Canonicalize `begin` and `end` in case of negative index.
2079   if (begin_int < 0) begin_int += rank;
2080   if (end_int < 0) end_int += rank;
2081 
2082   // Create `begin` and `end` from `*_mask`. Note that we don't care about
2083   // `new_axis_mask` as it can be inferred from `output_ty`.
2084   if (shrink_axis_mask() == 1) {
2085     // When `shrink_axis_mask` is set, output is always a scalar so only
2086     // one element is sliced.
2087     end_int = begin_int + 1;
2088   }
2089   if (begin_mask() == 1) {
2090     begin_int = (strides_int > 0) ? 0 : rank - 1;
2091   }
2092   if (end_mask() == 1) {
2093     end_int = (strides_int > 0) ? rank : -1;
2094   }
2095   if (ellipsis_mask() == 1) {
2096     begin_int = 0;
2097     end_int = rank;
2098   }
2099 
2100   // It's possible that `begin` and `end` are out of bound. See
2101   // https://docs.python.org/3/library/stdtypes.html#common-sequence-operations.
2102   if (strides_int > 0) {
2103     begin_int = std::min(begin_int, rank);
2104     end_int = std::min(end_int, rank);
2105   } else {
2106     begin_int = std::min(begin_int, rank - 1);
2107     end_int = std::min(end_int, rank - 1);
2108   }
2109 
2110   SmallVector<int64_t, 2> sub_shape;
2111   // Only handle cases that have something to slice to avoid infinite for-loop.
2112   if ((end_int > begin_int && strides_int > 0) ||
2113       (end_int < begin_int && strides_int < 0)) {
2114     // Extract sub-shape only if all of those dimensions are static.
2115     for (int64_t i = begin_int; (strides_int > 0) ? i < end_int : i > end_int;
2116          i += strides_int) {
2117       if (tensor_ty.isDynamicDim(i)) {
2118         return {};
2119       }
2120       sub_shape.push_back(tensor_ty.getDimSize(i));
2121     }
2122   }
2123 
2124   // For unranked or dynamic output, we infer the output type to either a
2125   // scalar or a vector based on `shrink_axis_mask` because we have rejected
2126   // the case of `new_axis_mask` != 0.
2127   auto output_elt_ty = output().getType().cast<ShapedType>().getElementType();
2128   auto output_ty = output().getType().dyn_cast<RankedTensorType>();
2129   if (!output_ty || !output_ty.hasStaticShape()) {
2130     if (shrink_axis_mask() == 1) {
2131       output_ty = RankedTensorType::get({}, output_elt_ty);
2132     } else {
2133       output_ty = RankedTensorType::get(
2134           {static_cast<int64_t>(sub_shape.size())}, output_elt_ty);
2135     }
2136   }
2137 
2138   // Down-cast to 32 bit int if needed.
2139   if (output_elt_ty.isInteger(32)) {
2140     SmallVector<int32_t, 2> sub_shape_i32(sub_shape.size());
2141     std::transform(sub_shape.begin(), sub_shape.end(), sub_shape_i32.begin(),
2142                    [](int64_t d) { return static_cast<int32_t>(d); });
2143     return DenseIntElementsAttr::get(output_ty, sub_shape_i32);
2144   }
2145   return DenseIntElementsAttr::get(output_ty, sub_shape);
2146 }
2147 
2148 //===----------------------------------------------------------------------===//
2149 // StridedSliceGradOp
2150 //===----------------------------------------------------------------------===//
2151 
2152 LogicalResult StridedSliceGradOp::verify() {
2153   StridedSliceGradOp op = *this;
2154   auto shape_type = op.shape().getType().dyn_cast<RankedTensorType>();
2155   if (shape_type && shape_type.getRank() != 1)
2156     return op.emitOpError("'shape' operand must be 1D tensor, but got ")
2157            << shape_type.getRank() << "D tensor";
2158 
2159   if (failed(VerifyStridedSliceBase(op))) return failure();
2160 
2161   // TODO(antiagainst): verify the gradient op.dy()'s shape is consistent with
2162   // the sliced type from StridedSlice.
2163 
2164   return success();
2165 }
2166 
2167 bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges(
2168     SmallVectorImpl<int64_t> *input_shape,
2169     SmallVectorImpl<int64_t> *slice_begin, SmallVectorImpl<int64_t> *slice_end,
2170     SmallVectorImpl<int64_t> *slice_stride) {
2171   DenseIntElementsAttr shape_attr;
2172   DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr;
2173   if (!matchPattern(shape(), m_Constant(&shape_attr)) ||
2174       !matchPattern(begin(), m_Constant(&sparse_begin_attr)) ||
2175       !matchPattern(end(), m_Constant(&sparse_end_attr)) ||
2176       !matchPattern(strides(), m_Constant(&sparse_strides_attr)))
2177     return false;
2178 
2179   int rank = std::distance(shape_attr.begin(), shape_attr.end());
2180 
2181   input_shape->clear();
2182   input_shape->reserve(rank);
2183   for (const APInt &dim : shape_attr)
2184     input_shape->push_back(dim.getSExtValue());
2185 
2186   SmallVector<int64_t, 4> sparse_begin, sparse_end, sparse_strides;
2187 
2188   for (const APInt &index : sparse_begin_attr)
2189     sparse_begin.push_back(index.getSExtValue());
2190   for (const APInt &index : sparse_end_attr)
2191     sparse_end.push_back(index.getSExtValue());
2192   for (const APInt &stride : sparse_strides_attr)
2193     sparse_strides.push_back(stride.getSExtValue());
2194 
2195   CalculateSlicedShapeFromSparseIndices(
2196       *input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(),
2197       end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(),
2198       slice_begin, slice_end, slice_stride);
2199   return true;
2200 }
2201 
2202 //===----------------------------------------------------------------------===//
2203 // SummaryWriterOp
2204 //===----------------------------------------------------------------------===//
2205 
2206 llvm::SmallVector<ResourceHandleValueAndId, 4>
2207 SummaryWriterOp::GetResourceHandleValueAndIdList(
2208     llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
2209     int64_t &next_id) {
2210   llvm::StringRef device = GetDeviceOrEmpty(getOperation());
2211   return {GetResourceHandleValueAndIdBase(container(), shared_name(), device,
2212                                           writer(), resource_handle_id_map,
2213                                           next_id)};
2214 }
2215 
2216 //===----------------------------------------------------------------------===//
2217 // TPUExecuteOp
2218 //===----------------------------------------------------------------------===//
2219 
2220 void TPUExecuteOp::getEffects(
2221     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2222         &effects) {
2223   effects.reserve(args().size() + 1);
2224   effects.emplace_back(MemoryEffects::Write::get(),
2225                        ResourceEffects::TPUExecute::get());
2226 
2227   for (Value value : args()) {
2228     if (value.getType()
2229             .cast<TensorType>()
2230             .getElementType()
2231             .isa<ResourceType>()) {
2232       // Conservatively mark resource handles as read and write, as without
2233       // analyzing TPUCompile, there is not sufficient information to determine
2234       // effects on resources. For the MLIR bridge, this op will never be
2235       // populated with resource handles and tf.TPUExecuteAndUpdateVariables is
2236       // used instead.
2237       effects.emplace_back(MemoryEffects::Read::get(), value,
2238                            ResourceEffects::Variable::get());
2239       effects.emplace_back(MemoryEffects::Write::get(), value,
2240                            ResourceEffects::Variable::get());
2241     }
2242   }
2243 }
2244 
2245 //===----------------------------------------------------------------------===//
2246 // TPUExecuteAndUpdateVariablesOp
2247 //===----------------------------------------------------------------------===//
2248 
2249 LogicalResult TPUExecuteAndUpdateVariablesOp::verify() {
2250   TPUExecuteAndUpdateVariablesOp op = *this;
2251   int num_resource_args = 0;
2252   for (Type arg_type : op.args().getTypes())
2253     if (arg_type.cast<TensorType>().getElementType().isa<ResourceType>())
2254       ++num_resource_args;
2255 
2256   auto check_attr = [&](ArrayAttr indices, llvm::StringRef name,
2257                         int min) -> LogicalResult {
2258     if (indices.size() != num_resource_args)
2259       return op.emitOpError()
2260              << "requires '" << name
2261              << "' to be the same size as number of resource handles in 'args' "
2262                 "("
2263              << num_resource_args << "), but got " << indices.size();
2264 
2265     for (auto entry : llvm::enumerate(indices.getValue())) {
2266       auto int_attr = entry.value().cast<IntegerAttr>();
2267       if (int_attr.getInt() < min)
2268         return op.emitOpError()
2269                << "requires '" << name << "' to contain values of at least "
2270                << min << ", but got " << int_attr.getInt() << " at index "
2271                << entry.index();
2272     }
2273 
2274     return success();
2275   };
2276 
2277   return failure(
2278       failed(check_attr(op.device_var_reads_indices(),
2279                         /*name=*/"device_var_reads_indices", /*min=*/0)) ||
2280       failed(check_attr(op.device_var_updates_indices(),
2281                         /*name=*/"device_var_updates_indices", /*min=*/-1)));
2282 }
2283 
2284 void TPUExecuteAndUpdateVariablesOp::getEffects(
2285     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2286         &effects) {
2287   effects.reserve(device_var_reads_indices().size() + 1);
2288   effects.emplace_back(MemoryEffects::Write::get(),
2289                        ResourceEffects::TPUExecute::get());
2290   auto resource_handles = llvm::make_filter_range(args(), [](Value value) {
2291     return value.getType()
2292         .cast<TensorType>()
2293         .getElementType()
2294         .isa<ResourceType>();
2295   });
2296 
2297   for (auto &entry : llvm::enumerate(resource_handles)) {
2298     Value value = entry.value();
2299     effects.emplace_back(MemoryEffects::Read::get(), value,
2300                          ResourceEffects::Variable::get());
2301     if (device_var_updates_indices()
2302             .getValue()[entry.index()]
2303             .cast<IntegerAttr>()
2304             .getInt() >= 0)
2305       effects.emplace_back(MemoryEffects::Write::get(), value,
2306                            ResourceEffects::Variable::get());
2307   }
2308 }
2309 
2310 //===----------------------------------------------------------------------===//
2311 // TensorListGetItemOp
2312 //===----------------------------------------------------------------------===//
2313 
2314 namespace {
2315 // If the input of TensorListGetItemOp is TensorListFromTensorOp and the
2316 // TensorListFromTensorOp is only used by TensorListGetItemOp (not modified by
2317 // other TensorList ops), we can convert it to a GatherOp.
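// Illustrative sketch of the rewrite (operands and types hypothetical):
//
//   %list = tf.TensorListFromTensor(%tensor, %element_shape)
//   %item = tf.TensorListGetItem(%list, %index, %element_shape)
//
// becomes
//
//   %item = tf.Gather(%tensor, %index)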
2318 class ConvertTensorListGetItemOpOfTensorListFromTensorOpToGather
2319     : public OpRewritePattern<TensorListGetItemOp> {
2320   using OpRewritePattern<TensorListGetItemOp>::OpRewritePattern;
2321   LogicalResult matchAndRewrite(TensorListGetItemOp op,
2322                                 PatternRewriter &rewriter) const override {
2323     // Checks that the input is created by TensorListFromTensorOp and the input
2324     // is only used by TensorListGetItemOp.
2325     auto tensor_list_from_tensor_op = dyn_cast_or_null<TensorListFromTensorOp>(
2326         op.input_handle().getDefiningOp());
2327     if (!tensor_list_from_tensor_op ||
2328         llvm::any_of(
2329             tensor_list_from_tensor_op->getUsers(),
2330             [](Operation *user) { return !isa<TensorListGetItemOp>(user); })) {
2331       return failure();
2332     }
2333 
2334     rewriter.replaceOpWithNewOp<GatherOp>(
2335         op, op.getType(), tensor_list_from_tensor_op.tensor(), op.index());
2336     return success();
2337   }
2338 };
2339 }  // namespace
2340 
2341 void TensorListGetItemOp::getCanonicalizationPatterns(
2342     RewritePatternSet &results, MLIRContext *context) {
2343   results.add<ConvertTensorListGetItemOpOfTensorListFromTensorOpToGather>(
2344       context);
2345 }
2346 
2347 //===----------------------------------------------------------------------===//
2348 // TensorListReserveOp
2349 //===----------------------------------------------------------------------===//
2350 
2351 LogicalResult TensorListReserveOp::verify() {
2352   TensorListReserveOp op = *this;
2353   // This is required to populate derived attributes during export in a
2354   // meaningful way. Otherwise, the element_type() query during export to
2355   // GraphDef will result in an out-of-bounds access/assert.
2356   if (handle_dtype().getSubtypes().size() != 1) {
2357     return emitOpError(
2358         "must have exactly one subtype in the result variant type");
2359   }
2360   if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
2361       !IsOfRankOrUnranked(op.element_shape(), 1)) {
2362     return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
2363   }
2364 
2365   if (!IsOfRankOrUnranked(op.num_elements(), 0)) {
2366     return op.emitOpError("requires num_elements operand to be 0D tensor");
2367   }
2368   return success();
2369 }
2370 
2371 //===----------------------------------------------------------------------===//
2372 // TensorListElementShapeOp
2373 //===----------------------------------------------------------------------===//
2374 
2375 OpFoldResult TensorListElementShapeOp::fold(ArrayRef<Attribute> operands) {
2376   int width =
2377       getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
2378   auto variant_type =
2379       getElementTypeOrSelf(getOperand().getType()).cast<TF::VariantType>();
2380   if (variant_type.getSubtypes().empty()) return {};
2381   return ConvertShapeToAttr(variant_type.getSubtypes()[0], width);
2382 }
2383 
2384 //===----------------------------------------------------------------------===//
2385 // TensorListStackOp
2386 //===----------------------------------------------------------------------===//
2387 
2388 LogicalResult TensorListStackOp::verify() {
2389   TensorListStackOp op = *this;
2390   if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
2391       !IsOfRankOrUnranked(op.element_shape(), 1)) {
2392     return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
2393   }
2394   return success();
2395 }
2396 
2397 //===----------------------------------------------------------------------===//
2398 // TensorScatterUpdateOp
2399 //===----------------------------------------------------------------------===//
2400 
2401 LogicalResult TensorScatterUpdateOp::verify() {
2402   TensorScatterUpdateOp op = *this;
2403   if (!HasRankAtLeast(op.tensor(), 1))
2404     return op.emitOpError(
2405         "requires tensor operand to have at least 1 dimension");
2406   if (!HasRankAtLeast(op.indices(), 1))
2407     return op.emitOpError(
2408         "requires indices operand to have at least 1 dimension");
2409 
2410   auto tensor_ty = op.tensor().getType().dyn_cast<RankedTensorType>();
2411   auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
2412   if (!tensor_ty || !indices_ty) return success();
2413 
2414   int64_t num_index_dims = indices_ty.getShape().back();
2415   if (ShapedType::isDynamic(num_index_dims)) return success();
2416 
2417   if (num_index_dims > tensor_ty.getRank())
2418     return op.emitOpError(
2419         "requires tensor operand with rank greater than or equal to the "
2420         "indices operand's last dimensions");
2421   return success();
2422 }
2423 
2424 //===----------------------------------------------------------------------===//
2425 // TileOp
2426 //===----------------------------------------------------------------------===//
2427 
2428 // Verifies that,
2429 //
2430 // - input has at least rank 1
2431 // - multiples is rank 1
2432 // - multiples.size() == input.rank()
2433 // - input.rank() == output.rank()
2434 // - Elements in multiples are non-negative
2435 // - input.shape[i] * multiples[i] == output.shape[i]
2436 //   for i in [0, input.rank() - 1]
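//
// For example (shapes illustrative), an input of type tensor<2x3xf32> with
// multiples = [3, 2] must produce an output of type tensor<6x6xf32>, and a
// negative multiple such as -1 is rejected.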
2437 
2438 LogicalResult TileOp::verify() {
2439   TileOp op = *this;
2440   auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
2441   auto multiples_type = op.multiples().getType().dyn_cast<RankedTensorType>();
2442   auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
2443 
2444   if (multiples_type && multiples_type.getRank() != 1) {
2445     return op.emitOpError() << "expected multiples to be rank 1, got rank = "
2446                             << multiples_type.getRank();
2447   }
2448 
2449   if (input_type && multiples_type && multiples_type.hasStaticShape() &&
2450       (input_type.getRank() != multiples_type.getNumElements() ||
2451        (input_type.getRank() == 0 && multiples_type.getNumElements() == 1))) {
2452     return op.emitOpError()
2453            << "expected size of multiples equal to rank of input"
2454            << ", got multiples of size " << multiples_type.getNumElements()
2455            << ", and input of rank " << input_type.getRank();
2456   }
2457 
2458   if (input_type && output_type) {
2459     if (input_type.getRank() != output_type.getRank()) {
2460       return op.emitOpError()
2461              << "expected rank of input to equal to rank of output"
2462              << ", got input of rank " << input_type.getRank()
2463              << ", and output of rank " << output_type.getRank();
2464     }
2465 
2466     DenseIntElementsAttr multiples_attr;
2467     if (matchPattern(op.multiples(), m_Constant(&multiples_attr))) {
2468       for (int32_t i = 0, e = input_type.getRank(); i < e; ++i) {
2469         const int64_t input_dim = input_type.getDimSize(i);
2470         const int64_t output_dim = output_type.getDimSize(i);
2471         const int64_t m = multiples_attr.getValues<APInt>()[i].getSExtValue();
2472 
2473         if (m < 0) {
2474           return op.emitOpError()
2475                  << "expected multiples to be non-negative, got "
2476                  << "multiples[" << i << "] = " << m;
2477         }
2478 
2479         if (!ShapedType::isDynamic(input_dim) &&
2480             !ShapedType::isDynamic(output_dim) && output_dim != input_dim * m) {
2481           return op.emitOpError()
2482                  << "requires input.shape[" << i << "] (" << input_dim << ")"
2483                  << " * " << m << " to be equal to "
2484                  << "output.shape[" << i << "] (" << output_dim << ")";
2485         }
2486       }
2487     }
2488   }
2489 
2490   return success();
2491 }
2492 
2493 OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
2494   DenseIntElementsAttr multiples_attr;
2495   if (matchPattern(multiples(), m_Constant(&multiples_attr))) {
2496     // Return the input directly when multiples are all ones,
2497     // regardless of what the input is.
2498     if (multiples_attr.isSplat() &&
2499         multiples_attr.getSplatValue<APInt>().getSExtValue() == 1) {
2500       return input();
2501     }
2502   }
2503   return {};
2504 }
2505 
2506 //===----------------------------------------------------------------------===//
2507 // TopKV2Op
2508 //===----------------------------------------------------------------------===//
2509 
2510 LogicalResult TopKV2Op::verify() {
2511   TopKV2Op op = *this;
2512   if (!HasRankAtLeast(op.input(), 1))
2513     return op.emitOpError(
2514         "requires input operand to have at least 1 dimension");
2515 
2516   if (!IsOfRankOrUnranked(op.k(), 0))
2517     return op.emitOpError("requires k operand to be 0D tensor");
2518 
2519   return success();
2520 }
2521 
2522 //===----------------------------------------------------------------------===//
2523 // ToBoolOp
2524 //===----------------------------------------------------------------------===//
2525 
2526 namespace {
2527 // If the input to ToBoolOp is a ranked tensor, then the ToBoolOp can be folded
2528 // into an identity or an equality comparison.
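// A sketch of the intended rewrites (illustrative, assuming a ranked input):
//
//   tf.ToBool(%x : tensor<i1>)      -> %x                    (identity)
//   tf.ToBool(%x : tensor<f32>)     -> tf.NotEqual(%x, 0.0)  (scalar case)
//   tf.ToBool(%x : tensor<2x3xf32>) -> constant true         (numElements != 0)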
2529 class ToBoolOfRankedTensor : public OpRewritePattern<ToBoolOp> {
2530   using OpRewritePattern<ToBoolOp>::OpRewritePattern;
2531   LogicalResult matchAndRewrite(ToBoolOp op,
2532                                 PatternRewriter &rewriter) const override {
2533     auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
2534     // If the input is an unranked tensor, cannot rewrite.
2535     if (!type) return failure();
2536 
2537     // Expected return type of the ToBool operation. The return type of ToBool
2538     // operation is always 0D tensor of bool type.
2539     auto result_type = op.getResult().getType().cast<RankedTensorType>();
2540 
2541     // If input is already a tensor<i1>, it can be folded into an identity.
2542     if (type == result_type) {
2543       rewriter.replaceOp(op, op.getOperand());
2544       return success();
2545     }
2546 
2547     if (type.getRank() == 0) {
2548       // If the input is a scalar tensor, the ToBool can be expanded to
2549       // element != 0 (for numerical values) or element == empty (for string).
2550       Type element_type = type.getElementType();
2551       Attribute zero_attr;
2552       if (element_type.isIntOrFloat())
2553         zero_attr = rewriter.getZeroAttr(type);
2554       else if (element_type.isa<TF::StringType>())
2555         zero_attr = DenseStringElementsAttr::get(type, {""});
2556 
2557       if (!zero_attr) return failure();
2558 
2559       auto zero_const = rewriter.create<TF::ConstOp>(op.getLoc(), zero_attr);
2560       rewriter.replaceOpWithNewOp<TF::NotEqualOp>(
2561           op, result_type, op.getOperand(), zero_const, false);
2562     } else {
2563       // If the input is a non-scalar ranked tensor, ToBool can be expanded
2564       // to numElements != 0. numElements will be 0 iff one of the dimensions is
2565       // zero.
2566       bool any_zero =
2567           llvm::any_of(type.getShape(), [](int64_t dim) { return dim == 0; });
2568       rewriter.replaceOpWithNewOp<TF::ConstOp>(
2569           op, result_type, DenseElementsAttr::get(result_type, {!any_zero}));
2570     }
2571     return success();
2572   }
2573 };
2574 }  // namespace
2575 
2576 void ToBoolOp::getCanonicalizationPatterns(RewritePatternSet &results,
2577                                            MLIRContext *context) {
2578   results.add<ToBoolOfRankedTensor>(context);
2579 }
2580 
2581 LogicalResult ToBoolOp::inferReturnTypes(
2582     MLIRContext *context, Optional<Location> location, ValueRange operands,
2583     DictionaryAttr attributes, RegionRange regions,
2584     SmallVectorImpl<Type> &inferredReturnTypes) {
2585   inferredReturnTypes.push_back(
2586       RankedTensorType::get({}, IntegerType::get(context, 1)));
2587   return success();
2588 }
2589 
2590 //===----------------------------------------------------------------------===//
2591 // TransposeOp
2592 //===----------------------------------------------------------------------===//
2593 
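// Verification sketch (illustrative shapes): for
//   %y = "tf.Transpose"(%x, %perm) with %x : tensor<2x3x4xf32> and
//   perm = [2, 0, 1],
// the result must satisfy y.shape[i] == x.shape[perm[i]], i.e.
// %y : tensor<4x2x3xf32>.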
2594 LogicalResult TransposeOp::verify() {
2595   TransposeOp op = *this;
2596   auto perm_type = op.perm().getType().dyn_cast<RankedTensorType>();
2597   auto x_type = op.x().getType().dyn_cast<RankedTensorType>();
2598   auto y_type = op.y().getType().dyn_cast<RankedTensorType>();
2599 
2600   if (perm_type && perm_type.getRank() != 1) {
2601     return op.emitOpError()
2602            << "expected perm to be a 1-D Tensor, got perm of rank "
2603            << perm_type.getRank();
2604   }
2605 
2606   if (x_type && y_type && x_type.getRank() != y_type.getRank()) {
2607     return op.emitOpError() << "x should be of the same rank with y, got "
2608                             << "x of rank " << x_type.getRank()
2609                             << ", and y of rank " << y_type.getRank();
2610   }
2611 
2612   if (!x_type || !y_type || !perm_type || !perm_type.hasStaticShape()) {
2613     return success();
2614   }
2615 
2616   if (x_type.getRank() != perm_type.getNumElements()) {
2617     return op.emitOpError() << "expected perm to be a 1-D Tensor of size "
2618                             << "equal to the rank of x, got perm of size "
2619                             << perm_type.getNumElements() << ", and x of rank "
2620                             << x_type.getRank();
2621   }
2622 
2623   DenseIntElementsAttr attr_perm;
2624   if (matchPattern(op.perm(), m_Constant(&attr_perm))) {
2625     // y.shape[i] should be equal to x.shape[perm[i]]
2626     // for i = [0, 1, ..., rank(x) - 1]
2627     for (auto e : llvm::enumerate(attr_perm)) {
2628       const int64_t y_idx = e.index();
2629       const int64_t y_dim = y_type.getDimSize(y_idx);
2630       const int64_t x_idx = e.value().getSExtValue();
2631       const int64_t x_dim = x_type.getDimSize(x_idx);
2632       if (!ShapedType::isDynamic(y_dim) && !ShapedType::isDynamic(x_dim) &&
2633           y_dim != x_dim) {
2634         return op.emitOpError()
2635                << "requires y.shape[" << y_idx << "] (" << y_dim << ") "
2636                << "to be equal to x.shape[perm[" << x_idx << "]] "
2637                << "(" << x_dim << ")";
2638       }
2639     }
2640   }
2641 
2642   return success();
2643 }
2644 
2645 // TODO(jpienaar): perm could be optional too.
2646 void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x,
2647                         Value perm) {
2648   auto x_type = x.getType().cast<TensorType>();
2649   // If value is unranked, then so is results.
2650   if (!x_type.hasRank())
2651     return TransposeOp::build(builder, result,
2652                               UnrankedTensorType::get(x_type.getElementType()),
2653                               x, perm);
2654 
2655   // TODO(jpienaar): Handle unknown perm case.
2656 
2657   // TODO(jpienaar): Extract utility function.
2658   auto etype = x_type.cast<ShapedType>().getElementType();
2659   DenseIntElementsAttr attr_shape;
2660   if (matchPattern(perm, m_Constant(&attr_shape))) {
2661     llvm::SmallVector<int64_t, 4> const_shape;
2662     if (attr_shape.isSplat()) {
2663       const_shape.assign(
2664           attr_shape.getNumElements(),
2665           x_type.getDimSize((*attr_shape.begin()).getSExtValue()));
2666     } else {
2667       const_shape.reserve(attr_shape.getNumElements());
2668       for (const auto &dim : attr_shape)
2669         const_shape.push_back(x_type.getDimSize(dim.getSExtValue()));
2670     }
2671     return TransposeOp::build(
2672         builder, result, RankedTensorType::get(const_shape, etype), x, perm);
2673   }
2674   return TransposeOp::build(builder, result, UnrankedTensorType::get(etype), x,
2675                             perm);
2676 }
2677 
2678 namespace {
2679 
2680 OpFoldResult FoldIdentityTranspose(TransposeOp op) {
2681   DenseIntElementsAttr perm;
2682   if (!matchPattern(op.perm(), m_Constant(&perm))) return {};
2683   const auto elements = perm.getValues<APInt>();
2684 
2685   for (auto it : llvm::enumerate(elements)) {
2686     if (it.index() != it.value()) return {};
2687   }
2688 
2689   // TODO(jpienaar): Remove if/when we handle this more generally.
2690   if (op.getType() != op.x().getType()) {
2691     // If the types don't match then only fold if all the users are in the TF
2692     // dialect.
2693     for (auto user : op.getOperation()->getUsers())
2694       if (user->getDialect() != op->getDialect()) return {};
2695   }
2696 
2697   return op.x();
2698 }
2699 
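// Folds a pair of back-to-back transposes whose permutations cancel each
// other out; e.g. (illustrative) applying perm = [0, 2, 1] twice:
//
//   %t = "tf.Transpose"(%x, %perm)
//   %y = "tf.Transpose"(%t, %perm)   ==>   uses of %y are replaced by %x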
2700 OpFoldResult FoldCancellableTranspose(TransposeOp op) {
2701   // Operand is a TransposeOp.
2702   auto transpose = dyn_cast_or_null<TF::TransposeOp>(op.x().getDefiningOp());
2703   if (!transpose) return {};
2704 
2705   // Permutations defined by constant operations.
2706   DenseIntElementsAttr perm0;
2707   DenseIntElementsAttr perm1;
2708   if (!matchPattern(op.perm(), m_Constant(&perm0)) ||
2709       !matchPattern(transpose.perm(), m_Constant(&perm1)))
2710     return {};
2711 
2712   // With permutation indices that cancel each other
2713   if (!AreCancellablePermutations(perm0, perm1)) return {};
2714 
2715   return transpose.x();
2716 }
2717 
2718 }  // namespace
2719 
2720 OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
2721   if (auto folded = FoldIdentityTranspose(*this)) return folded;
2722   if (auto folded = FoldCancellableTranspose(*this)) return folded;
2723   return {};
2724 }
2725 
2726 //===----------------------------------------------------------------------===//
2727 // TruncateDivOp
2728 //===----------------------------------------------------------------------===//
2729 
2730 void TruncateDivOp::getCanonicalizationPatterns(RewritePatternSet &results,
2731                                                 MLIRContext *context) {
2732   results.add<TruncateDivWithSqrtDivisor>(context);
2733 }
2734 
2735 //===----------------------------------------------------------------------===//
2736 // NonMaxSuppressionV3Op
2737 //===----------------------------------------------------------------------===//
2738 
2739 namespace {
2740 
2741 // Canonicalize NonMaxSuppressionV3Op to NonMaxSuppressionV4Op.
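// Rewrite sketch (illustrative operand names): V4 takes the same five operands
// but returns an additional valid-output count, which is left unused here:
//
//   %out = "tf.NonMaxSuppressionV3"(%boxes, %scores, %max, %iou, %score_thr)
//     ==>
//   %out, %valid = "tf.NonMaxSuppressionV4"(%boxes, %scores, %max, %iou,
//                                           %score_thr)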
2742 class NMSV3ToNMSV4Op : public OpRewritePattern<NonMaxSuppressionV3Op> {
2743   using OpRewritePattern<NonMaxSuppressionV3Op>::OpRewritePattern;
2744   LogicalResult matchAndRewrite(NonMaxSuppressionV3Op nms_op,
2745                                 PatternRewriter &rewriter) const override {
2746     if (nms_op.getNumOperands() != 5) {
2747       return failure();
2748     }
2749     SmallVector<Type, 2> new_result_types;
2750     new_result_types.push_back(nms_op.getType());
2751     auto input_ty = nms_op.getType().template cast<ShapedType>();
2752     // Corresponds to the second result type of NMSv4.
2753     RankedTensorType valid_output_type =
2754         RankedTensorType::get({}, input_ty.getElementType());
2755     new_result_types.push_back(valid_output_type);
2756 
2757     auto nmsv4 = rewriter.create<TF::NonMaxSuppressionV4Op>(
2758         nms_op.getLoc(), new_result_types, nms_op.boxes(), nms_op.scores(),
2759         nms_op.max_output_size(), nms_op.iou_threshold(),
2760         nms_op.score_threshold());
2761     // Cannot directly replace the NMSv3 op with NMSv4 since the outputs of the
2762     // two differ (v4 produces two output values whereas v3 produces only one).
2763     nms_op.replaceAllUsesWith(nmsv4.getResult(0));
2764     return success();
2765   }
2766 };
2767 }  // namespace.
2768 
2769 void NonMaxSuppressionV3Op::getCanonicalizationPatterns(
2770     RewritePatternSet &results, MLIRContext *context) {
2771   results.add<NMSV3ToNMSV4Op>(context);
2772 }
2773 
2774 //===----------------------------------------------------------------------===//
2775 // FusedBatchNormOp
2776 //===----------------------------------------------------------------------===//
2777 
2778 namespace {
2779 
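// Rewrites tf.FusedBatchNorm to tf.FusedBatchNormV3 by appending an extra
// `reserve_space_3` result (an unranked f32 tensor) and dropping that result
// when replacing the original op, so existing users are unaffected.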
2780 class ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
2781   using OpRewritePattern<FusedBatchNormOp>::OpRewritePattern;
2782   LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op,
2783                                 PatternRewriter &rewriter) const override {
2784     auto new_result_types =
2785         llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes());
2786     // reserve_space_3
2787     new_result_types.push_back(
2788         UnrankedTensorType::get(FloatType::getF32(rewriter.getContext())));
2789 
2790     OperationState new_state(tf_fused_batch_norm_op.getLoc(),
2791                              TF::FusedBatchNormV3Op::getOperationName(),
2792                              tf_fused_batch_norm_op.getOperands(),
2793                              new_result_types,
2794                              tf_fused_batch_norm_op->getAttrs());
2795     Operation *tf_fused_batch_norm_op_v3 = rewriter.create(new_state);
2796 
2797     rewriter.replaceOp(tf_fused_batch_norm_op,
2798                        tf_fused_batch_norm_op_v3->getResults().drop_back());
2799     return success();
2800   }
2801 };
2802 }  // namespace.
2803 
2804 void FusedBatchNormOp::getCanonicalizationPatterns(RewritePatternSet &results,
2805                                                    MLIRContext *context) {
2806   results.add<ConvertFusedBatchNorm>(context);
2807 }
2808 
2809 //===----------------------------------------------------------------------===//
2810 // UnpackOp
2811 //===----------------------------------------------------------------------===//
2812 
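// Illustrative example (hypothetical shapes): unpacking %x : tensor<3x4xf32>
// along axis = 0 must produce exactly 3 results, each of type tensor<4xf32>;
// an axis outside [-2, 2) would be rejected by the verifier below.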
2813 LogicalResult UnpackOp::verify() {
2814   UnpackOp op = *this;
2815   auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
2816   if (!value_type) return success();
2817 
2818   int64_t value_rank = value_type.getRank();
2819   int64_t axis = op.axis();
2820   if (axis < -value_rank || axis >= value_rank)
2821     return op.emitOpError("axis attribute must be in the range of [-")
2822            << value_rank << ", " << value_rank << ')';
2823 
2824   axis = GetDimForAxis(axis, value_rank);
2825   int64_t dim_size = value_type.getDimSize(axis);
2826   if (ShapedType::isDynamic(dim_size)) return success();
2827 
2828   if (dim_size != op.getNumResults())
2829     return op.emitOpError("result count must be equal to ") << dim_size;
2830 
2831   return success();
2832 }
2833 
2834 namespace {
2835 
2836 // Hoist coefficient-wise unary operation out of the Unpack op:
2837 //
2838 //   %unpacked:N = "tf.Unpack"(%0)
2839 //   %neg0 = "tf.Neg"(%unpacked#0)
2840 //   %neg1 = "tf.Neg"(%unpacked#1)
2841 //   ...
2842 //   %negN-1 = "tf.Neg"(%unpacked#N-1)
2843 //
2844 // Rewrite it to:
2845 //
2846 //   %neg = "tf.Neg"(%0)
2847 //   %unpacked:N = "tf.Unpack"(%neg)
2848 class HoistCwiseUnaryOutOfUnpack : public OpRewritePattern<UnpackOp> {
2849  public:
2850   explicit HoistCwiseUnaryOutOfUnpack(MLIRContext *context)
2851       : OpRewritePattern<UnpackOp>(context) {}
2852   LogicalResult matchAndRewrite(UnpackOp op,
2853                                 PatternRewriter &rewriter) const override;
2854 };
2855 
2856 LogicalResult HoistCwiseUnaryOutOfUnpack::matchAndRewrite(
2857     UnpackOp op, PatternRewriter &rewriter) const {
2858   auto loc = op.getLoc();
2859 
2860   // First unpack user must be coeff-wise unary operation.
2861   Operation *first_user = *op->getUsers().begin();
2862   if (!first_user->hasTrait<OpTrait::TF::CwiseUnary>()) return failure();
2863 
2864   // All unpack users must be defined by the op of same kind.
2865   bool users_same_op = llvm::all_of(op->getUsers(), [&](Operation *user) {
2866     return user->getName() == first_user->getName();
2867   });
2868   if (!users_same_op) return failure();
2869 
2870   // Pass unpack operand to unary operation.
2871   OperationState new_unary_op_state(loc, first_user->getName().getStringRef(),
2872                                     op.getOperand(), op.getOperand().getType(),
2873                                     ArrayRef<NamedAttribute>());
2874   Operation *new_unary_op = rewriter.create(new_unary_op_state);
2875 
2876   // Unpack results after applying unary operation.
2877   auto unpack_unary_op = rewriter.create<UnpackOp>(
2878       loc, op.getResultTypes(), new_unary_op->getResult(0), op.axis());
2879 
2880   // Bypass all users of the original unpack operation and use `unpack_unary_op`
2881   // results instead.
2882   for (auto pair : llvm::zip(op.getResults(), unpack_unary_op.getResults())) {
2883     OpResult old_result = std::get<0>(pair);  // result of original Unpack
2884     OpResult new_result = std::get<1>(pair);  // result of transformed Unpack
2885     for (Operation *user : llvm::make_early_inc_range(old_result.getUsers()))
2886       rewriter.replaceOp(user, ValueRange(new_result));
2887   }
2888 
2889   // Erase original unpack operation.
2890   rewriter.eraseOp(op.getOperation());
2891 
2892   return success();
2893 }
2894 
2895 }  // namespace
2896 
2897 void UnpackOp::getCanonicalizationPatterns(RewritePatternSet &results,
2898                                            MLIRContext *context) {
2899   results.add<HoistCwiseUnaryOutOfUnpack>(context);
2900 }
2901 
2902 //===----------------------------------------------------------------------===//
2903 // Unsorted segment reduction ops
2904 //===----------------------------------------------------------------------===//
2905 
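// Illustrative example (hypothetical shapes): data : tensor<4x5x6xf32> with
// segment_ids : tensor<4x5xi32> satisfies the prefix requirement checked
// below, whereas segment_ids : tensor<5xi32> does not (dimension #0 differs).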
2906 template <class Op>
2907 static LogicalResult VerifyUnsortedSegmentReduction(Op op) {
2908   if (!HasRankAtMost(op.num_segments(), 0))
2909     return op.emitOpError("number of segments should be a 0-D tensor");
2910 
2911   auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
2912   auto segment_ids_type =
2913       op.segment_ids().getType().template dyn_cast<RankedTensorType>();
2914   if (data_type && segment_ids_type) {
2915     if (data_type.getRank() < segment_ids_type.getRank())
2916       return op.emitOpError(
2917           "requires segment ids rank to be less than or equal to data's rank");
2918 
2919     int index = 0;
2920     for (auto shape_pair :
2921          llvm::zip_first(segment_ids_type.getShape(), data_type.getShape())) {
2922       int64_t segment_id_dim = std::get<0>(shape_pair);
2923       int64_t data_dim = std::get<1>(shape_pair);
2924       if (!ShapedType::isDynamic(segment_id_dim) &&
2925           !ShapedType::isDynamic(data_dim) && segment_id_dim != data_dim)
2926         return op.emitOpError(
2927                    "requires segment ids shape to be a prefix of data shape, "
2928                    "but dimension #")
2929                << index << " differs: " << segment_id_dim << " vs. "
2930                << data_dim;
2931       ++index;
2932     }
2933   }
2934 
2935   DenseIntElementsAttr num_segments_attr;
2936   if (matchPattern(op.num_segments(), m_Constant(&num_segments_attr))) {
2937     int64_t num_segments = (*num_segments_attr.begin()).getSExtValue();
2938     if (num_segments < 0)
2939       return op.emitOpError("num of segments cannot be negative");
2940   }
2941 
2942   return success();
2943 }
2944 
2945 LogicalResult UnsortedSegmentMaxOp::verify() {
2946   return VerifyUnsortedSegmentReduction(*this);
2947 }
2948 LogicalResult UnsortedSegmentMinOp::verify() {
2949   return VerifyUnsortedSegmentReduction(*this);
2950 }
2951 LogicalResult UnsortedSegmentProdOp::verify() {
2952   return VerifyUnsortedSegmentReduction(*this);
2953 }
2954 LogicalResult UnsortedSegmentSumOp::verify() {
2955   return VerifyUnsortedSegmentReduction(*this);
2956 }
2957 
2958 //===----------------------------------------------------------------------===//
2959 // VarHandleOp
2960 //===----------------------------------------------------------------------===//
2961 
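// Illustrative example (hypothetical type): a result type such as
//   tensor<!tf_type.resource<tensor<8xf32>>>
// carries exactly one subtype (tensor<8xf32>) from which the dtype and shape
// attributes can be derived, as required below.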
2962 LogicalResult VarHandleOp::verify() {
2963   // VarHandleOp requires the resource handle to supply a single subtype from
2964   // which to derive the dtype and shape attributes.
2965   if (resource_type().getSubtypes().size() != 1) {
2966     return emitOpError(
2967         "must have exactly one subtype in the result resource type");
2968   }
2969 
2970   return success();
2971 }
2972 
2973 llvm::SmallVector<ResourceHandleValueAndId, 4>
2974 VarHandleOp::GetResourceHandleValueAndIdList(
2975     llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
2976     int64_t &next_id) {
2977   llvm::StringRef device = GetDeviceOrEmpty(getOperation());
2978   return {GetResourceHandleValueAndIdBase(container(), shared_name(), device,
2979                                           resource(), resource_handle_id_map,
2980                                           next_id)};
2981 }
2982 
2983 //===----------------------------------------------------------------------===//
2984 // VarIsInitializedOp
2985 //===----------------------------------------------------------------------===//
2986 
2987 namespace {
2988 
2989 /// Erase VarIsInitializedOp operations with no uses. This op has a (read-only)
2990 /// side effect on resources, but can still be deleted if it has zero uses.
2991 struct EraseDeadVarIsInitializedOp
2992     : public OpRewritePattern<VarIsInitializedOp> {
2993   using OpRewritePattern<VarIsInitializedOp>::OpRewritePattern;
2994 
2995   LogicalResult matchAndRewrite(VarIsInitializedOp op,
2996                                 PatternRewriter &rewriter) const override {
2997     if (!op.use_empty()) return failure();
2998     rewriter.eraseOp(op);
2999     return success();
3000   }
3001 };
3002 }  // end anonymous namespace.
3003 
3004 void VarIsInitializedOp::getCanonicalizationPatterns(
3005     RewritePatternSet &patterns, MLIRContext *context) {
3006   patterns.add<EraseDeadVarIsInitializedOp>(context);
3007 }
3008 
3009 //===----------------------------------------------------------------------===//
3010 // VariableOp
3011 //===----------------------------------------------------------------------===//
3012 
3013 void VariableOp::getCanonicalizationPatterns(RewritePatternSet &results,
3014                                              MLIRContext *context) {
3015   results.add<VariableToVariableV2>(context);
3016 }
3017 
3018 //===----------------------------------------------------------------------===//
3019 // VariableShapeOp
3020 //===----------------------------------------------------------------------===//
3021 
3022 LogicalResult VariableShapeOp::verify() {
3023   VariableShapeOp op = *this;
3024   auto input_type = op.input().getType().cast<TensorType>();
3025   if (input_type.hasStaticShape() && input_type.getNumElements() != 1)
3026     return op.emitOpError("requires input to have one resource");
3027 
3028   auto resource_type = input_type.getElementType().cast<TF::ResourceType>();
3029   auto subtypes = resource_type.getSubtypes();
3030   switch (subtypes.size()) {
3031     case 1:
3032       return VerifyShapeOperandAndResult(
3033           op, resource_type.getSubtypes().front(), op.getType());
3034     case 0:
3035       return VerifyShapeOperandAndResult(op, Type(), op.getType());
3036     default:
3037       return op.emitOpError(
3038           "requires resource input type to have at most 1 subtype");
3039   }
3040 }
3041 
3042 OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {
3043   int width =
3044       getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
3045   auto resource_type =
3046       getElementTypeOrSelf(getOperand().getType()).cast<TF::ResourceType>();
3047   if (resource_type.getSubtypes().empty()) return {};
3048   return ConvertShapeToAttr(resource_type.getSubtypes()[0], width);
3049 }
3050 
3051 //===----------------------------------------------------------------------===//
3052 // WhileOp
3053 //===----------------------------------------------------------------------===//
3054 
3055 static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
3056                                       TypeRange body_input,
3057                                       TypeRange body_result,
3058                                       bool shape_invariant) {
3059   const TypeRangeWithDesc input_type = {op->getOperandTypes(), "input"};
3060   const TypeRangeWithDesc result_type = {op->getResultTypes(), "result"};
3061   constexpr int kNumRegionTypeLists = 3;
3062   const std::array<TypeRangeWithDesc, kNumRegionTypeLists> region_types = {{
3063       {body_result, "body result"},
3064       {cond_input, "condition input"},
3065       {body_input, "body input"},
3066   }};
3067 
3068   // A pair of type lists should be cast compatible with each other if one is
3069   // converted to the other for a function call or assignment, or there is a
3070   // common source of inputs for both. Therefore, the While op requires the
3071   // following pairs of type lists to be cast compatible for the tensor_cast
3072   // operation:
3073   //
3074   // * Operands and cond inputs to call the cond function before the
3075   //   first iteration.
3076   // * Operands and body inputs to call the body function for the first
3077   //   iteration if the cond functions returns True or equivalent result.
3078   // * Operands and results to assign cond function arguments to op results if
3079   //   the cond function returns False or equivalent result. If the op is shape
3080   //   invariant, this does not hold as shapes can differ.
3081   // * All three pairs using cond inputs, body inputs and results as operand is
3082   //   a common source for all three.
3083   // * Body result and cond inputs to call the cond function for the subsequent
3084   //   iterations. Similarly, Body result should be compatible with body inputs
3085   //   and op results.
3086   //
3087   // Note that the operands and body results need not be compatible, as they are
3088   // never converted from one to the other, nor is there a common source of
3089   // tensors. The compatibility requirement is not transitive.
3090 
3091   if (!shape_invariant &&
3092       failed(VerifyTypeRangesAreCompatible(op, input_type, result_type)))
3093     return failure();
3094 
3095   // Skip the first pair as the While op operands and body function results do
3096   // not need to be compatible with each other.
3097   for (int i = 1; i < kNumRegionTypeLists; ++i)
3098     if (failed(VerifyTypeRangesAreCompatible(op, input_type, region_types[i])))
3099       return failure();
3100 
3101   for (int i = 0; i < kNumRegionTypeLists; ++i)
3102     if (failed(VerifyTypeRangesAreCompatible(op, result_type, region_types[i])))
3103       return failure();
3104 
3105   for (int i = 0; i < kNumRegionTypeLists; ++i)
3106     for (int j = i + 1; j < kNumRegionTypeLists; ++j)
3107       if (failed(VerifyTypeRangesAreCompatible(op, region_types[i],
3108                                                region_types[j])))
3109         return failure();
3110 
3111   return success();
3112 }
3113 
3114 LogicalResult WhileOp::verifySymbolUses(SymbolTableCollection &symbol_table) {
3115   auto cond_fn =
3116       symbol_table.lookupNearestSymbolFrom<func::FuncOp>(*this, condAttr());
3117   auto body_fn =
3118       symbol_table.lookupNearestSymbolFrom<func::FuncOp>(*this, bodyAttr());
3119   if (!cond_fn) {
3120     return emitOpError("cond refers to an undefined function : ") << cond();
3121   }
3122   if (!body_fn) {
3123     return emitOpError("body refers to an undefined function : ") << body();
3124   }
3125 
3126   auto cond_fn_type = cond_fn.getFunctionType();
3127   auto body_fn_type = body_fn.getFunctionType();
3128 
3129   // Verify that the cond function has exactly one result.
3130   if (cond_fn_type.getNumResults() != 1)
3131     return emitOpError("requires cond function to have exactly one result");
3132 
3133   return VerifyWhileTypes(*this, /*cond_input=*/cond_fn_type.getInputs(),
3134                           /*body_input=*/body_fn_type.getInputs(),
3135                           /*body_result=*/body_fn_type.getResults(),
3136                           shape_invariant());
3137 }
3138 
3139 //===----------------------------------------------------------------------===//
3140 // WhileRegionOp
3141 //===----------------------------------------------------------------------===//
3142 LogicalResult WhileRegionOp::verify() {
3143   WhileRegionOp op = *this;
3144   // Verify that the condition generates a single tensor<i1> result.
3145   Operation *cond_yield = op.cond().front().getTerminator();
3146   if (cond_yield->getNumOperands() != 1)
3147     return op.emitOpError()
3148            << "condition should have a single tensor<i1> result";
3149 
3150   auto cond_type =
3151       cond_yield->getOperand(0).getType().dyn_cast<RankedTensorType>();
3152   if (!cond_type || !cond_type.getShape().equals({}) ||
3153       !cond_type.getElementType().isInteger(/*width=*/1))
3154     return op.emitOpError()
3155            << "condition should have a single tensor<i1> result";
3156 
3157   Operation *body_yield = op.body().front().getTerminator();
3158   if (failed(VerifyWhileTypes(op, /*cond_input=*/op.cond().getArgumentTypes(),
3159                               /*body_input=*/op.body().getArgumentTypes(),
3160                               /*body_result=*/body_yield->getOperandTypes(),
3161                               op.shape_invariant())))
3162     return failure();
3163   return success();
3164 }
3165 
3166 //===----------------------------------------------------------------------===//
3167 // WhileRegionOp LoopLikeOpInterface
3168 //===----------------------------------------------------------------------===//
3169 
3170 Region &WhileRegionOp::getLoopBody() { return body(); }
3171 
3172 //===----------------------------------------------------------------------===//
3173 // WhileRegionOp canonicalization
3174 //===----------------------------------------------------------------------===//
3175 namespace {
3176 // Eliminate values that pass through the WhileRegionOp body.
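// For example (a simplified sketch), if the body yields its second block
// argument unchanged,
//
//   %r:2 = "tf.WhileRegion"(%a, %b) ({...}, {
//     ^bb0(%barg0, %barg1): ... "tf.Yield"(%new0, %barg1)
//   })
//
// then uses of %barg1, of the matching condition argument, and of %r#1 can all
// be rewritten to use %b directly, and that operand/result pair removed.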
3177 struct WhileRegionEliminatePassThrough
3178     : public OpRewritePattern<WhileRegionOp> {
3179   using OpRewritePattern<WhileRegionOp>::OpRewritePattern;
3180 
3181   LogicalResult matchAndRewrite(WhileRegionOp while_op,
3182                                 PatternRewriter &rewriter) const override {
3183     // Remove any extern values that are explicitly captured and returned. Also
3184     // replace values that simply pass through the body with extern values. The
3185     // block arguments of body and while match and so the corresponding cond
3186     // argument can be easily found.
3187     int old_num_operands = while_op.getNumOperands();
3188     int new_num_operands = old_num_operands;
3189     auto &body_block = while_op.body().front();
3190     auto &cond_block = while_op.cond().front();
3191     auto &yield = *body_block.getTerminator();
3192 
3193     // Bit mask indicating which operands will be removed.
3194     llvm::BitVector removed_operand(old_num_operands);
3195 
3196     for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
3197       auto body_arg = body_block.getArgument(op_idx);
3198       auto yield_operand = LookThroughIdentity(yield.getOperand(op_idx));
3199       auto while_operand = while_op.getOperand(op_idx);
3200       if (body_arg == yield_operand || while_operand == yield_operand) {
3201         // Replace the use of the passthrough value with the while operand
3202         // in the body and condition regions, as well as the while output (if
3203         // the types match).
3204         // TODO(jurahul): Use PatternRewriter API for IR modification.
3205         if (body_arg.getType() == while_operand.getType())
3206           body_arg.replaceAllUsesWith(while_operand);
3207 
3208         auto cond_arg = cond_block.getArgument(op_idx);
3209         if (cond_arg.getType() == while_operand.getType())
3210           cond_arg.replaceAllUsesWith(while_operand);
3211 
3212         auto result = while_op.getResult(op_idx);
3213         if (result.getType() == while_operand.getType())
3214           result.replaceAllUsesWith(while_operand);
3215       }
3216 
3217       // Now check if the operand is unused in both regions as well as the
3218       // result. If so, mark it for removal.
3219       if (body_block.getArgument(op_idx).use_empty() &&
3220           cond_block.getArgument(op_idx).use_empty() &&
3221           while_op.getResult(op_idx).use_empty()) {
3222         removed_operand.set(op_idx);
3223         new_num_operands--;
3224       }
3225     }
3226 
3227     if (new_num_operands == old_num_operands) return failure();
3228 
3229     // Compress the operands, region arguments, and outputs.
3230     SmallVector<Value, 4> new_while_operands;
3231     SmallVector<Type, 4> new_result_types;
3232     new_while_operands.reserve(new_num_operands);
3233     new_result_types.reserve(new_num_operands);
3234 
3235     // Build new operands and result type.
3236     for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
3237       if (removed_operand.test(op_idx)) continue;
3238       new_while_operands.push_back(while_op.getOperand(op_idx));
3239       new_result_types.push_back(while_op.getResult(op_idx).getType());
3240     }
3241 
3242     // Create the new while operation.
3243     auto new_while_op = rewriter.create<WhileRegionOp>(
3244         while_op.getLoc(), new_result_types, new_while_operands,
3245         while_op->getAttrs());
3246 
3247     // Move region bodies to the new while.
3248     rewriter.inlineRegionBefore(while_op.cond(), new_while_op.cond(),
3249                                 new_while_op.cond().end());
3250     rewriter.inlineRegionBefore(while_op.body(), new_while_op.body(),
3251                                 new_while_op.body().end());
3252 
3253     auto &new_cond_block = new_while_op.cond().front();
3254     auto &new_body_block = new_while_op.body().front();
3255     auto &new_yield = *new_body_block.getTerminator();
3256 
3257     // Patch up the region bodies and yield.
3258     new_cond_block.eraseArguments(removed_operand);
3259     new_body_block.eraseArguments(removed_operand);
3260     new_yield.eraseOperands(removed_operand);
3261 
3262     // Build a vector of new results. Also patch up the region bodies and
3263     // yield.
3264     SmallVector<Value, 4> new_results(old_num_operands);
3265     int next_idx = 0;
3266     for (int op_idx : llvm::seq<int>(0, old_num_operands))
3267       if (!removed_operand.test(op_idx))
3268         new_results[op_idx] = new_while_op.getResult(next_idx++);
3269 
3270     rewriter.replaceOp(while_op, new_results);
3271     return success();
3272   }
3273 };
3274 
3275 }  // anonymous namespace
3276 
3277 void WhileRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
3278                                                 MLIRContext *context) {
3279   results.add<WhileRegionEliminatePassThrough>(context);
3280 }
3281 
3282 //===----------------------------------------------------------------------===//
3283 // XdivyOp
3284 //===----------------------------------------------------------------------===//
3285 
3286 void XdivyOp::getCanonicalizationPatterns(RewritePatternSet &results,
3287                                           MLIRContext *context) {
3288   results.add<XdivyWithSqrtDivisor>(context);
3289 }
3290 
3291 //===----------------------------------------------------------------------===//
3292 // XlaBroadcastHelperOp
3293 //===----------------------------------------------------------------------===//
3294 
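// Shape inference sketch (illustrative shapes): with lhs : tensor<4xf32>,
// rhs : tensor<2x4x3xf32> and a constant broadcast_dims = [1], the inferred
// results are tensor<1x4x1xf32> for the lhs (expanded for broadcasting) and
// the unchanged tensor<2x4x3xf32> for the rhs.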
3295 LogicalResult XlaBroadcastHelperOp::inferReturnTypeComponents(
3296     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
3297     DictionaryAttr attributes, RegionRange regions,
3298     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3299   XlaBroadcastHelperOpAdaptor op(operands.getValues(), attributes);
3300   Value lhs = op.lhs();
3301   Value rhs = op.rhs();
3302   auto set_unranked_results = [&]() {
3303     inferredReturnShapes.emplace_back(getElementTypeOrSelf(lhs));
3304     inferredReturnShapes.emplace_back(getElementTypeOrSelf(rhs));
3305     return success();
3306   };
3307 
3308   RankedTensorType lhs_ty = lhs.getType().dyn_cast<RankedTensorType>();
3309   RankedTensorType rhs_ty = rhs.getType().dyn_cast<RankedTensorType>();
3310   if (!lhs_ty || !rhs_ty) return set_unranked_results();
3311 
3312   int64_t lhs_rank = lhs_ty.getRank();
3313   int64_t rhs_rank = rhs_ty.getRank();
3314 
3315   DenseIntElementsAttr dims;
3316   if (!matchPattern(op.broadcast_dims(), m_Constant(&dims))) {
3317     return set_unranked_results();
3318   }
3319 
3320   if (dims.size() == 0) {
3321     if (lhs_rank != rhs_rank && lhs_rank != 0 && rhs_rank != 0) {
3322       return emitOptionalError(
3323           location,
3324           "if broadcast_dims is empty, both arguments must have equal rank or "
3325           "at least one argument must be a scalar");
3326     }
3327     inferredReturnShapes.emplace_back(lhs_ty.cast<ShapedType>());
3328     inferredReturnShapes.emplace_back(rhs_ty.cast<ShapedType>());
3329     return success();
3330   }
3331 
3332   const bool broadcast_lhs = lhs_rank < rhs_rank;
3333   RankedTensorType min_rank_ty = broadcast_lhs ? lhs_ty : rhs_ty;
3334   RankedTensorType max_rank_ty = broadcast_lhs ? rhs_ty : lhs_ty;
3335 
3336   if (dims.size() != min_rank_ty.getRank()) {
3337     return emitOptionalError(
3338         location,
3339         "broadcast_dims must have size equal to the smaller argument rank");
3340   }
3341 
3342   int64_t output_rank = max_rank_ty.getRank();
3343   llvm::SmallVector<int64_t, 4> broadcast_shape(output_rank, 1LL);
3344   llvm::SmallVector<bool, 4> is_broadcasted(output_rank, false);
3345   for (auto item : llvm::enumerate(dims)) {
3346     int64_t index = item.index();
3347     int64_t dim = item.value().getSExtValue();
3348     if (dim < 0 || dim > output_rank) {
3349       return emitOptionalError(location, "out of range broadcast dim");
3350     }
3351     if (is_broadcasted[dim]) {
3352       return emitOptionalError(location, "broadcast_dims has duplicates");
3353     }
3354     broadcast_shape[dim] = min_rank_ty.getDimSize(index);
3355     is_broadcasted[dim] = true;
3356   }
3357 
3358   if (broadcast_lhs) {
3359     inferredReturnShapes.emplace_back(broadcast_shape, lhs_ty.getElementType());
3360     inferredReturnShapes.emplace_back(rhs_ty.cast<ShapedType>());
3361   } else {
3362     inferredReturnShapes.emplace_back(lhs_ty.cast<ShapedType>());
3363     inferredReturnShapes.emplace_back(broadcast_shape, rhs_ty.getElementType());
3364   }
3365   return success();
3366 }
3367 
3368 //===----------------------------------------------------------------------===//
3369 // XlaConvOp
3370 //===----------------------------------------------------------------------===//
3371 
3372 class XlaConvToV2 : public OpRewritePattern<TF::XlaConvOp> {
3373  public:
3374   using OpRewritePattern::OpRewritePattern;
3375 
3376   LogicalResult matchAndRewrite(TF::XlaConvOp op,
3377                                 PatternRewriter &rewriter) const override {
3378     SmallVector<Type> result_types{op.getResult().getType()};
3379     rewriter.replaceOpWithNewOp<TF::XlaConvV2Op>(
3380         op, op.getResult().getType(), op.lhs(), op.rhs(), op.window_strides(),
3381         op.padding(), op.lhs_dilation(), op.rhs_dilation(),
3382         op.feature_group_count(), op.dimension_numbers(), op.precision_config(),
3383         1);
3384     return ::mlir::success();
3385   };
3386 };
3387 
3388 void XlaConvOp::getCanonicalizationPatterns(RewritePatternSet &results,
3389                                             MLIRContext *context) {
3390   results.insert<XlaConvToV2>(context);
3391 }
3392 
3393 //===----------------------------------------------------------------------===//
3394 // XlaConvV2Op
3395 //===----------------------------------------------------------------------===//
3396 
3397 LogicalResult XlaConvV2Op::verify() {
3398   XlaConvV2Op op = *this;
3399   DenseElementsAttr window_strides_attr, padding_attr, lhs_dilation_attr,
3400       rhs_dilation_attr, feature_group_count_attr;
3401   if (!(matchPattern(op.window_strides(), m_Constant(&window_strides_attr)) &&
3402         matchPattern(op.padding(), m_Constant(&padding_attr)) &&
3403         matchPattern(op.lhs_dilation(), m_Constant(&lhs_dilation_attr)) &&
3404         matchPattern(op.rhs_dilation(), m_Constant(&rhs_dilation_attr)) &&
3405         matchPattern(op.feature_group_count(),
3406                      m_Constant(&feature_group_count_attr))))
3407     return success();
3408 
3409   if (window_strides_attr.getType().getRank() != 1)
3410     return op.emitOpError() << "expects window_stride to be a vector";
3411 
3412   const ShapedType &padding_ty = padding_attr.getType();
3413   if (padding_ty.getRank() != 2 || padding_ty.getDimSize(1) != 2)
3414     return op.emitOpError()
3415            << "expects padding to be a matrix with minor dimension 2";
3416 
3417   if (lhs_dilation_attr.getType().getRank() != 1)
3418     return op.emitOpError() << "expects lhs_dilation to be a vector";
3419 
3420   if (rhs_dilation_attr.getType().getRank() != 1)
3421     return op.emitOpError() << "expects rhs_dilation to be a vector";
3422 
3423   if (feature_group_count_attr.getType().getRank())
3424     return op.emitOpError() << "expects feature_group_count to be a scalar";
3425 
3426   return success();
3427 }
3428 
3429 //===----------------------------------------------------------------------===//
3430 // XlaSetDynamicDimensionSizeOp
3431 //===----------------------------------------------------------------------===//
3432 
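// Shape inference sketch (illustrative shapes): for input : tensor<4x5xf32>
// with a constant dim_index = 1, the inferred result type is tensor<4x?xf32>;
// if dim_index is not a constant, every dimension becomes dynamic.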
3433 LogicalResult XlaSetDynamicDimensionSizeOp::inferReturnTypeComponents(
3434     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
3435     DictionaryAttr attributes, RegionRange regions,
3436     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3437   XlaSetDynamicDimensionSizeOpAdaptor op(operands.getValues(), attributes);
3438 
3439   TensorType operand_ty = op.input().getType().cast<TensorType>();
3440   Type element_ty = operand_ty.getElementType();
3441 
3442   TensorType result_ty;
3443   if (operand_ty.hasRank()) {
3444     auto shape = llvm::to_vector<4>(operand_ty.getShape());
3445 
3446     DenseIntElementsAttr dim_index_attr;
3447     if (matchPattern(op.dim_index(), m_Constant(&dim_index_attr))) {
3448       int64_t dim_index = dim_index_attr.getValues<APInt>()[0].getSExtValue();
3449 
3450       int64_t rank = operand_ty.getRank();
3451       if (dim_index < 0 || dim_index >= rank) {
3452         return emitOptionalError(location, "dim_index (", dim_index,
3453                                  ") is out of range [0, ", rank, ")");
3454       }
3455       shape[dim_index] = ShapedType::kDynamicSize;
3456     } else {
3457       shape.assign(shape.size(), ShapedType::kDynamicSize);
3458     }
3459     result_ty = RankedTensorType::get(shape, element_ty);
3460   } else {
3461     result_ty = UnrankedTensorType::get(element_ty);
3462   }
3463 
3464   inferredReturnShapes.emplace_back(result_ty.cast<ShapedType>());
3465   return success();
3466 }
3467 
3468 //===----------------------------------------------------------------------===//
3469 // XlaReduceOp
3470 //===----------------------------------------------------------------------===//
3471 
3472 class XlaReduceToXlaVariadicReduceV2
3473     : public OpRewritePattern<TF::XlaReduceOp> {
3474  public:
3475   using OpRewritePattern::OpRewritePattern;
3476 
3477   LogicalResult matchAndRewrite(TF::XlaReduceOp op,
3478                                 PatternRewriter &rewriter) const override {
3479     SmallVector<Value> inputs{op.input()};
3480     SmallVector<Value> init_values{op.init_value()};
3481     SmallVector<Type> result_types{op.getResult().getType()};
3482     rewriter.replaceOpWithNewOp<TF::XlaVariadicReduceV2Op>(
3483         op, result_types, inputs, init_values, op.dimensions_to_reduce(),
3484         op.reducer());
3485     return ::mlir::success();
3486   };
3487 };
3488 
3489 void XlaReduceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3490                                               MLIRContext *context) {
3491   results.add<XlaReduceToXlaVariadicReduceV2>(context);
3492 }
3493 
3494 //===----------------------------------------------------------------------===//
3495 // XlaReduceWindowOp
3496 //===----------------------------------------------------------------------===//
3497 
3498 LogicalResult XlaReduceWindowOp::verify() {
3499   XlaReduceWindowOp op = *this;
3500   const auto &input_ty = op.input().getType().cast<ShapedType>();
3501 
3502   auto check = [&](mlir::Value val, std::string attr_name) -> LogicalResult {
3503     ElementsAttr attr;
3504     if (matchPattern(val, m_Constant(&attr))) {
3505       if (attr.getType().getRank() != 1) {
3506         return op.emitOpError() << "expects the rank of " << attr_name
3507                                 << " to be 1, got " << attr.getType().getRank();
3508       }
3509       if (input_ty.hasRank()) {
3510         int64_t input_rank = input_ty.getRank();
3511         int64_t size = attr.size();
3512         if (input_rank != size) {
3513           return op.emitOpError() << "expects the size of " << attr_name
3514                                   << " to be equal to the input "
3515                                      "rank ("
3516                                   << size << " vs. " << input_rank << ")";
3517         }
3518       }
3519     }
3520     return success();
3521   };
3522 
3523   if (check(op.window_dimensions(), "window_dimensions").failed())
3524     return failure();
3525 
3526   if (check(op.window_strides(), "window_strides").failed()) return failure();
3527 
3528   if (check(op.base_dilations(), "base_dilations").failed()) return failure();
3529 
3530   if (check(op.window_dilations(), "window_dilations").failed())
3531     return failure();
3532 
3533   ElementsAttr padding;
3534   if (matchPattern(op.padding(), m_Constant(&padding))) {
3535     const ShapedType &padding_ty = padding.getType();
3536     if (padding_ty.getRank() != 2 || padding_ty.getDimSize(1) != 2) {
3537       return op.emitOpError()
3538              << "expects padding to be a matrix with minor dimension 2, got "
3539              << padding.getType().getShape();
3540     }
3541   }
3542 
3543   auto module = op->getParentOfType<mlir::ModuleOp>();
3544   auto func = dyn_cast_or_null<mlir::func::FuncOp>(
3545       SymbolTable::lookupSymbolIn(module, op.computation()));
3546   if (!func) {
3547     return op.emitOpError() << "has no reduction function specified";
3548   }
3549 
3550   auto func_type = func.getFunctionType();
3551 
3552   if (func_type.getNumInputs() != 2) {
3553     return op.emitOpError()
3554            << "expects reduction function to take 2 parameters, but "
3555               "has "
3556            << func_type.getNumInputs() << " parameter(s)";
3557   }
3558 
3559   return success();
3560 }
3561 
3562 //===----------------------------------------------------------------------===//
3563 // XlaSelectAndScatterOp
3564 //===----------------------------------------------------------------------===//
3565 
3566 LogicalResult XlaSelectAndScatterOp::verify() {
3567   XlaSelectAndScatterOp op = *this;
3568   auto input_ty = op.operand().getType().cast<ShapedType>();
3569 
3570   auto check = [&](mlir::Value val, std::string attr_name) -> LogicalResult {
3571     ElementsAttr attr;
3572     if (input_ty.hasRank() && matchPattern(val, m_Constant(&attr))) {
3573       int64_t input_rank = input_ty.getRank();
3574       int64_t size = attr.size();
3575       if (input_rank != size) {
3576         return op.emitOpError() << "expects the size of " << attr_name
3577                                 << " to be equal to the input "
3578                                    "rank ("
3579                                 << size << " vs. " << input_rank << ")";
3580       }
3581     }
3582     return success();
3583   };
3584 
3585   if (check(op.window_dimensions(), "window_dimensions").failed())
3586     return failure();
3587 
3588   if (check(op.window_strides(), "window_strides").failed()) return failure();
3589 
3590   ElementsAttr padding;
3591   if (matchPattern(op.padding(), m_Constant(&padding))) {
3592     const ShapedType &padding_ty = padding.getType();
3593     if (padding_ty.getRank() != 2 || padding_ty.getDimSize(1) != 2) {
3594       return op.emitOpError()
3595              << "expects padding to be a matrix with minor dimension 2, got "
3596              << padding.getType().getShape();
3597     }
3598   }
3599 
3600   auto module = op->getParentOfType<mlir::ModuleOp>();
3601   auto select_func = dyn_cast_or_null<mlir::func::FuncOp>(
3602       SymbolTable::lookupSymbolIn(module, op.select()));
3603   if (!select_func) {
3604     return op.emitOpError() << "has no select function specified";
3605   }
3606   auto select_func_type = select_func.getFunctionType();
3607   if (select_func_type.getNumInputs() != 2) {
3608     return op.emitOpError()
3609            << "expects select function to take 2 parameters, but has "
3610            << select_func_type.getNumInputs() << " parameter(s)";
3611   }
3612   if (select_func_type.getNumResults() != 1 ||
3613       !getElementTypeOrSelf(select_func_type.getResult(0)).isInteger(1)) {
3614     return op.emitOpError() << "expects select function to return a single "
3615                                "boolean result but got "
3616                             << select_func_type.getResult(0);
3617   }
3618   auto scatter_func = dyn_cast_or_null<mlir::func::FuncOp>(
3619       SymbolTable::lookupSymbolIn(module, op.scatter()));
3620   if (!scatter_func) {
3621     return op.emitOpError() << "has no scatter function specified";
3622   }
3623   auto scatter_func_type = scatter_func.getFunctionType();
3624   if (scatter_func_type.getNumInputs() != 2) {
3625     return op.emitOpError()
3626            << "expects scatter function to take 2 parameters, but has "
3627            << scatter_func_type.getNumInputs() << " parameter(s)";
3628   }
3629 
3630   return success();
3631 }
3632 
3633 //===----------------------------------------------------------------------===//
3634 // XlaVariadicReduceOp
3635 //===----------------------------------------------------------------------===//
3636 
3637 LogicalResult XlaVariadicReduceOp::verify() {
3638   XlaVariadicReduceOp op = *this;
3639   // We rely on V2 for the majority of the checks.
3640   const auto &input_ty = op.input().getType();
3641   if (input_ty.empty()) return op.emitOpError() << "No input";
3642   const auto &dtype = input_ty[0].cast<TensorType>().getElementType();
3643   for (const auto &ty : input_ty) {
3644     if (ty.cast<TensorType>().getElementType() != dtype)
3645       return op.emitOpError()
3646              << "This version is limited to operands of the same dtype";
3647   }
3648   return success();
3649 }
3650 
3651 class XlaVariadicReduceToV2 : public OpRewritePattern<TF::XlaVariadicReduceOp> {
3652  public:
3653   using OpRewritePattern::OpRewritePattern;
3654 
3655   LogicalResult matchAndRewrite(TF::XlaVariadicReduceOp op,
3656                                 PatternRewriter &rewriter) const override {
3657     mlir::TF::XlaVariadicReduceV2Op xla_variadic_reduce_v2_op =
3658         rewriter.create<::mlir::TF::XlaVariadicReduceV2Op>(
3659             op.getLoc(), op.getResults().getTypes(), op.input(),
3660             op.init_value(), op.dimensions_to_reduce(), op.reducer());
3661 
3662     rewriter.replaceOp(op, xla_variadic_reduce_v2_op.getResults());
3663     return ::mlir::success();
3664   };
3665 };
3666 
3667 void XlaVariadicReduceOp::getCanonicalizationPatterns(
3668     RewritePatternSet &results, MLIRContext *context) {
3669   results.add<XlaVariadicReduceToV2>(context);
3670 }
3671 
3672 //===----------------------------------------------------------------------===//
3673 // XlaVariadicReduceV2Op
3674 //===----------------------------------------------------------------------===//
3675 
3676 LogicalResult XlaVariadicReduceV2Op::verify() {
3677   XlaVariadicReduceV2Op op = *this;
3678   const auto &inputs_ty = op.inputs().getType();
3679   int n_inputs = inputs_ty.size();
3680   if (n_inputs < 1) return op.emitOpError() << "No inputs";
3681 
3682   const auto &init_values_ty = op.init_values().getType();
3683   int n_init_values = init_values_ty.size();
3684   if (n_init_values != n_inputs) {
3685     return op.emitOpError() << "Number of inputs (" << n_inputs
3686                             << ") is different than number of init_values ("
3687                             << n_init_values << ")";
3688   }
3689 
3690   auto input_ty_0 = inputs_ty[0].cast<ShapedType>();
3691   if (input_ty_0.hasStaticShape()) {
3692     for (int i = 1; i < n_inputs; ++i) {
3693       auto input_ty_i = inputs_ty[i].cast<ShapedType>();
3694       if (input_ty_i.hasStaticShape() &&
3695           input_ty_i.getShape() != input_ty_0.getShape()) {
3696         return op.emitOpError()
3697                << "inputs[" << i << "] has shape [" << input_ty_i.getShape()
3698                << "] different than the shape of inputs[0]: "
3699                << input_ty_0.getShape();
3700       }
3701     }
3702 
3703     if (op.dimensions_to_reduce().size() > input_ty_0.getRank()) {
3704       return op.emitOpError()
3705              << "Invalid dimensions_to_reduce argument to XlaVariadicReduceV2";
3706     }
3707   }
3708 
3709   for (int i = 0; i < n_inputs; ++i) {
3710     auto init_value_ty_i = init_values_ty[i].cast<ShapedType>();
3711     if (init_value_ty_i.hasRank() && init_value_ty_i.getRank() != 0) {
3712       return op.emitOpError()
3713              << "init_values[" << i << "] must be a scalar but got ["
3714              << init_value_ty_i.getShape() << "]";
3715     }
3716   }
3717 
3718   auto module = op->getParentOfType<mlir::ModuleOp>();
3719   auto function = dyn_cast_or_null<mlir::func::FuncOp>(
3720       SymbolTable::lookupSymbolIn(module, op.reducer()));
3721   if (!function) return op.emitOpError() << "No reducer";
3722   if (!function.getBody().hasOneBlock())
3723     return op.emitOpError() << "reducer has more than one block";
3724 
3725   return success();
3726 }
3727 
3728 //===----------------------------------------------------------------------===//
3729 // XlaVariadicSortOp
3730 //===----------------------------------------------------------------------===//
3731 
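// Verifies that all statically shaped inputs have the same shape as inputs[0],
// that a constant `dimension` operand is a scalar, and that the comparator
// symbol resolves to a single-block function in the module.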
LogicalResult XlaVariadicSortOp::verify() {
  XlaVariadicSortOp op = *this;
  const auto &inputs_ty = op.inputs().getType();
  int n_inputs = inputs_ty.size();
  auto input_ty_0 = inputs_ty[0].cast<ShapedType>();
  if (input_ty_0.hasStaticShape()) {
    for (int i = 1; i < n_inputs; ++i) {
      auto input_ty_i = inputs_ty[i].cast<ShapedType>();
      if (input_ty_i.hasStaticShape() &&
          input_ty_i.getShape() != input_ty_0.getShape()) {
        return op.emitOpError()
               << "input[" << i << "] has shape [" << input_ty_i.getShape()
               << "] different than the shape of input[0]: "
               << input_ty_0.getShape();
      }
    }
  }

  ElementsAttr dimension;
  if (matchPattern(op.dimension(), m_Constant(&dimension))) {
    if (dimension.getType().getRank() != 0 ||
        dimension.getType().getNumElements() != 1)
      return op.emitOpError() << "dimension must be a scalar";
  }

  auto module = op->getParentOfType<mlir::ModuleOp>();
  auto function = dyn_cast_or_null<mlir::func::FuncOp>(
      SymbolTable::lookupSymbolIn(module, op.comparator()));
  if (!function) return op.emitOpError() << "No comparator";
  if (!function.getBody().hasOneBlock())
    return op.emitOpError() << "comparator has more than one block";

  return success();
}

//===----------------------------------------------------------------------===//
// SetStaticDimensionBoundsOp
//===----------------------------------------------------------------------===//
//

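// Verifies that the input tensor has rank at most 2, that static_shape is a
// rank-1 vector, and that its number of elements matches the input rank when
// both are statically known.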
LogicalResult SetStaticDimensionBoundsOp::verify() {
  SetStaticDimensionBoundsOp op = *this;
  mlir::ShapedType input_type = op.input().getType().cast<mlir::ShapedType>();
  mlir::ShapedType static_shape_type =
      op.static_shape().getType().cast<mlir::ShapedType>();
  int input_type_rank = input_type.hasRank() ? input_type.getRank() : -1;
  if (input_type_rank > 2) {
    return op.emitOpError() << "was used with an input tensor with rank > 2, "
                               "only tensors of rank 1,2 are supported";
  }

  if (static_shape_type.hasRank() && static_shape_type.getRank() != 1) {
    return op.emitOpError("static shape must be of rank 1 (vector)");
  }
  if (input_type_rank != -1 && static_shape_type.hasStaticShape()) {
    if (static_shape_type.getShape()[0] != input_type_rank) {
      return op.emitOpError(
          "static shape must have num_elements == rank of input "
          "tensor");
    }
  }

  return success();
}

}  // namespace TF
}  // namespace mlir

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc.inc"