/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <numeric>
#include <string>
#include <tuple>
#include <type_traits>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/DialectImplementation.h"  // from @llvm-project
#include "mlir/IR/Identifier.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
#include "mlir/Parser.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"

namespace mlir {
namespace TF {

namespace {
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc"
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
}  // namespace

//===----------------------------------------------------------------------===//
// NotEqualOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(NotEqualOp op) {
  // If we allow inputs to have incompatible type, then nothing to do.
  if (!op.incompatible_shape_error()) return success();

  // Otherwise, check inputs are broadcastable.
  return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
      op.getOperation());
}

void NotEqualOp::build(OpBuilder &builder, OperationState &result, Value x,
                       Value y, BoolAttr incompatible_shape_error) {
  auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y,
                                          incompatible_shape_error);
  return build(builder, result, result_type, x, y, incompatible_shape_error);
}

//===----------------------------------------------------------------------===//
// OneHotOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(OneHotOp op) {
  int64_t axis = op.axis();

  auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
  if (indices_ty &&
      !(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) {
    return op.emitOpError()
           << "expected axis (" << axis << ") to be -1 or between [0, "
           << indices_ty.getShape().size() << "]";
  }

  if (axis < -1) {
    return op.emitOpError() << "expected axis (" << axis
                            << ") to be -1 or between [0, rank(indices()))";
  }

  if (!IsOfRankOrUnranked(op.depth(), 0)) {
    return op.emitOpError() << "requires depth to be a scalar";
  }
  if (!IsOfRankOrUnranked(op.on_value(), 0)) {
    return op.emitOpError() << "requires on_value to be a scalar";
  }
  if (!IsOfRankOrUnranked(op.off_value(), 0)) {
    return op.emitOpError() << "requires off_value to be a scalar";
  }

  DenseIntElementsAttr depth_attr;
  if (matchPattern(op.depth(), m_Constant(&depth_attr))) {
    if (depth_attr.getType().getRank() != 0)
      return op.emitOpError() << "requires depth to be a scalar";
    int64_t depth = depth_attr.getValue<APInt>({}).getSExtValue();
    if (depth < 0) {
      return op.emitOpError() << "depth must be non-negative, got: " << depth;
    }
  }

  return success();
}

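// Illustrative example of the type inference below (assuming tf.OneHot
// semantics): for indices of type tensor<2x3xi32>, a constant depth of 5, and
// axis = -1, the inferred result type is tensor<2x3x5xT>, where T is the
// element type of on_value/off_value.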
static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value,
                                    Value off_value, IntegerAttr axis) {
  int64_t axis_val = axis.getInt();
  Type element_ty = on_value.getType().cast<TensorType>().getElementType();
  auto unranked_ty = UnrankedTensorType::get(element_ty);
  if (axis_val < -1) return unranked_ty;

  auto indices_ty = indices.getType().dyn_cast<RankedTensorType>();
  if (!indices_ty) return unranked_ty;

  auto shape = llvm::to_vector<2>(indices_ty.getShape());
  if (axis_val == -1) axis_val = shape.size();

  int64_t depth_val = ShapedType::kDynamicSize;
  DenseIntElementsAttr depth_attr;
  if (matchPattern(depth, m_Constant(&depth_attr)) &&
      depth_attr.getNumElements() == 1)
    depth_val = (*depth_attr.begin()).getSExtValue();
  shape.insert(shape.begin() + axis_val, depth_val);
  return RankedTensorType::get(shape, element_ty);
}

void OneHotOp::build(OpBuilder &builder, OperationState &result, Value indices,
                     Value depth, Value on_value, Value off_value,
                     IntegerAttr axis) {
  build(builder, result,
        InferOneHotOpType(indices, depth, on_value, off_value, axis), indices,
        depth, on_value, off_value, axis);
}

//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(PackOp op) {
  // TODO(hinsu): Convert variadic length attributes to derived attributes.
  Operation::operand_range values = op.values();

  if (failed(VerifyTypesCompatibility(values,
                                      /*mask_one_dim=*/false,
                                      op.getOperation()))) {
    return failure();
  }

  int64_t inputs_rank = -1;
  for (Value value : values) {
    if (auto ty = value.getType().dyn_cast<RankedTensorType>()) {
      // Exit early as input types are verified to be compatible so all ranked
      // tensors have the same rank.
      inputs_rank = ty.getRank();
      break;
    }
  }
  if (inputs_rank == -1) return success();

  // The values can be packed along any of the dimensions between 0 and the
  // input rank, inclusive. Also, negative axis values wrap around, so the
  // valid axis value range is [-(R+1), R+1).
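  // For example (illustrative): for rank-2 inputs, R = 2, so the valid axis
  // values are -3, -2, -1, 0, 1, and 2.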
  int64_t range_begin = -inputs_rank - 1;  // Inclusive
  int64_t range_end = inputs_rank + 1;     // Exclusive
  int64_t axis = op.axis();
  if (axis < range_begin || axis >= range_end) {
    return op.emitError() << "attribute 'axis' should be within range ["
                          << range_begin << ", " << range_end
                          << "); actual value: " << axis;
  }

  return success();
}

OpFoldResult PackOp::fold(ArrayRef<Attribute> operands) {
  // Fold pack operation if it computes the input tensor shape:
  //
  //   %shape  = tf.Shape(%arg)                    // [? x ...]
  //   %dim0   = tf.StridedSlice(%shape, 0, 1, 1)  // get unknown dim0 value
  //   %pack   = tf.Pack(dim0, ...) { axis = 0 }   // [? x ...]
  //
  // Where `...` are some statically known dimensions. In this case %pack can
  // be replaced with %shape. This is a common pattern in models with a
  // dynamic batch size.

  // Pack operation should pack at least two values.
  if (values().size() < 2) return {};

  // Dimensions packed along axis = 0 (pack scalars into vector).
  if (axis() != 0) return {};

  // First packed value is defined by a strided slice operation.
  auto slice_op = dyn_cast_or_null<StridedSliceOp>(values()[0].getDefiningOp());
  if (!slice_op) return {};

  // Input to the slice op is defined by shape operation.
  auto shape_op = dyn_cast_or_null<ShapeOp>(slice_op.input().getDefiningOp());
  if (!shape_op) return {};

  // Input tensor, whose shape is reconstructed by the pack operation.
  Value tensor = shape_op.input();

  // All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing
  // scalar value from input vector).
  if (slice_op.begin_mask() != 0 || slice_op.ellipsis_mask() != 0 ||
      slice_op.end_mask() != 0 || slice_op.new_axis_mask() != 0 ||
      slice_op.shrink_axis_mask() != 1)
    return {};

  // Returns a value if the `value` is defined by a ConstOp with a single
  // integer element in it and has an expected rank.
  auto get_const_int = [](Value value, int expected_rank) -> Optional<int64_t> {
    auto const_op = dyn_cast_or_null<ConstOp>(value.getDefiningOp());
    if (!const_op) return None;

    auto value_attr = const_op.value().dyn_cast<DenseIntElementsAttr>();
    if (!value_attr || value_attr.getNumElements() != 1) return None;

    auto value_ty = value_attr.getType();
    if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None;

    auto splat = value_attr.getSplatValue<IntegerAttr>();
    return splat.getValue().getSExtValue();
  };

  // All other packed values are scalar constants.
  SmallVector<int64_t, 4> packed_dims;
  packed_dims.reserve(values().size() - 1);
  for (Value operand : llvm::drop_begin(values(), 1)) {
    if (auto dim = get_const_int(operand, /*expected_rank=*/0)) {
      packed_dims.push_back(*dim);
    } else {
      return {};
    }
  }

  // Slice exactly the first shape dimension:
  //   begin = [0] end = [1], strides = [1]
  auto begin = get_const_int(slice_op.begin(), /*expected_rank=*/1);
  auto end = get_const_int(slice_op.end(), /*expected_rank=*/1);
  auto strides = get_const_int(slice_op.strides(), /*expected_rank=*/1);
  if (!begin.hasValue() || !end.hasValue() || !strides.hasValue() ||
      *begin != 0 || *end != 1 || *strides != 1)
    return {};

  // First tensor dimension is dynamic.
  auto arg_ty = tensor.getType().dyn_cast<ShapedType>();
  if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 ||
      !arg_ty.isDynamicDim(0))
    return {};

  // Argument tensor rank is equal to the number of packed dimensions.
  if (arg_ty.getRank() != values().size()) return {};

  // All other dimensions are statically known and equal to packed dims.
  auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1);
  if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin()))
    return {};

  // Replace %pack with %shape.
  return slice_op.input();
}

//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//

LogicalResult PadOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
  // Paddings must be defined by a constant operation.
  auto paddings_op = dyn_cast_or_null<TF::ConstOp>(paddings().getDefiningOp());
  if (!paddings_op) return failure();

  auto paddings_value = paddings_op.value().dyn_cast<DenseElementsAttr>();
  if (!paddings_value ||
      paddings_value.getNumElements() != permutation.size() * 2)
    return failure();

  SmallVector<int32_t, 8> shuffled_paddings(paddings_value.getNumElements());
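  // Permute the rows of the paddings attribute (one [low, high] pair per
  // dimension): row `i` of the original paddings becomes row `permutation[i]`
  // of the shuffled paddings. For example (illustrative): with permutation
  // {0, 2, 3, 1}, the paddings for dimension 1 move to row 2.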
  for (auto index_pair : llvm::enumerate(paddings_value.getIntValues())) {
    size_t outer_idx = index_pair.index() / 2;
    size_t inner_idx = index_pair.index() % 2;

    shuffled_paddings[permutation[outer_idx] * 2 + inner_idx] =
        index_pair.value().getSExtValue();
  }

  // Add a constant operation with the new paddings.
  OpBuilder builder(getOperation());
  auto type = mlir::RankedTensorType::get(paddings_value.getType().getShape(),
                                          builder.getIntegerType(32));
  auto values = mlir::DenseIntElementsAttr::get(type, shuffled_paddings);
  auto shuffled_paddings_op = builder.create<TF::ConstOp>(getLoc(), values);

  // Use the new paddings.
  setOperand(1, shuffled_paddings_op);

  // Change the result type.
  getResult().setType(ShuffleRankedTensorType(getResult().getType(),
                                              ReversePermutation(permutation)));

  return success();
}

//===----------------------------------------------------------------------===//
// ParseExampleV2Op
//===----------------------------------------------------------------------===//

static LogicalResult Verify(ParseExampleV2Op op) {
  // NOTE(mrry): This validates properties of an op that would previously be
  // validated by the TensorFlow OpDef type checker. In addition to these
  // checks, the shape inference function for ParseExampleV2 validates the
  // consistency of the argument and result types.

  // Validate dense variadic input and output lengths.
  // NOTE(mrry): The Tdense attr is derived from dense_defaults, so we
  // do not need to validate dense_defaults.
  auto dense_types_count =
      std::distance(op.Tdense().begin(), op.Tdense().end());
  auto dense_values_count =
      std::distance(op.dense_values().begin(), op.dense_values().end());
  if (dense_values_count != dense_types_count) {
    return op.emitError() << "output 'dense_values' should have same length "
                          << "as attribute 'Tdense'";
  }

  // Validate sparse variadic output lengths.
  // NOTE(mrry): The sparse_types attr is derived from sparse_values, so we
  // do not need to validate sparse_values.
  auto sparse_types_count =
      std::distance(op.sparse_types().begin(), op.sparse_types().end());
  if (op.num_sparse() != sparse_types_count) {
    return op.emitError() << "attribute 'num_sparse' should be the same as "
                          << "the length of attribute 'sparse_types'";
  }
  if (op.sparse_indices().size() != sparse_types_count) {
    return op.emitError() << "output 'sparse_indices' should have same length "
                          << "as attribute 'sparse_types'";
  }
  if (op.sparse_shapes().size() != sparse_types_count) {
    return op.emitError() << "output 'sparse_shapes' should have same length "
                          << "as attribute 'sparse_types'";
  }

  // Validate ragged variadic output lengths.
  auto ragged_value_types_count = std::distance(op.ragged_value_types().begin(),
                                                op.ragged_value_types().end());
  auto ragged_split_types_count = std::distance(op.ragged_split_types().begin(),
                                                op.ragged_split_types().end());
  if (ragged_value_types_count != ragged_split_types_count) {
    return op.emitError() << "attribute 'ragged_value_types' should have same "
                          << "length as attribute 'ragged_split_types'";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// PartitionedCallOp
//===----------------------------------------------------------------------===//

template <class OpClass>
static LogicalResult VerifyPartitionedCall(OpClass op) {
  auto module = op->template getParentOfType<ModuleOp>();
  SymbolRefAttr func = op->getAttr("f").template cast<SymbolRefAttr>();

  auto function =
      dyn_cast_or_null<FuncOp>(SymbolTable::lookupSymbolIn(module, func));

  if (!function) {
    return op.emitError("'f' attribute refers to an undefined function: ")
           << func;
  }

  FunctionType function_ty = function.getType();
  int func_arg_count = function_ty.getNumInputs();
  int arg_count = op.args().size();

  if (arg_count != func_arg_count) {
    return op.emitError() << "argument count mismatch: 'args' has " << arg_count
                          << " arguments, but '" << func << "' expects "
                          << func_arg_count;
  }

  return success();
}

//===----------------------------------------------------------------------===//
// PowOp
//===----------------------------------------------------------------------===//

OpFoldResult PowOp::fold(ArrayRef<Attribute> operands) {
  auto constant_y = operands[1].dyn_cast_or_null<DenseFPElementsAttr>();
  if (constant_y && constant_y.isSplat()) {
    APFloat y_value = constant_y.getSplatValue<APFloat>();
    auto output_type = getType().cast<ShapedType>();
    if (y_value.isZero() && output_type.hasStaticShape()) {
      return DenseElementsAttr::get(
          output_type,
          FloatAttr::get(output_type.getElementType(), /*value=*/1.0));
    }
    if (y_value.isExactlyValue(1.0)) {
      return x();
    }
  }
  return {};
}

//===----------------------------------------------------------------------===//
// QrOp
//===----------------------------------------------------------------------===//

// Verifies that,
//
// * Input type, if ranked, must have at least 2 dimensions and at most
//   INT32_MAX dimensions.
//
static LogicalResult Verify(QrOp op) {
  auto ttype = op.input().getType().cast<TensorType>();
  if (!ttype.hasRank()) return success();
  if (!HasRankAtLeast(op.input(), 2))
    return op.emitOpError(
        "requires ranked input tensor to be of rank 2 or more");
  if (!HasRankAtMost(op.input(), std::numeric_limits<int32_t>::max()))
    return op.emitOpError(
        "requires ranked input tensor to be of rank INT32_MAX or less");

  return success();
}

//===----------------------------------------------------------------------===//
// ReadVariableOp
//===----------------------------------------------------------------------===//

void ReadVariableOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ReadVariableOfCast>(context);
}

//===----------------------------------------------------------------------===//
// RandomUniformOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(RandomUniformOp op) {
  if (!IsOfRankOrUnranked(op.shape(), 1))
    return op.emitOpError("shape must be 1D tensor");
  return success();
}

//===----------------------------------------------------------------------===//
// RangeOp
//===----------------------------------------------------------------------===//

void RangeOp::build(OpBuilder &builder, OperationState &result, Value start,
                    Value limit, Value delta) {
  assert(start.getType() == limit.getType());
  assert(start.getType() == delta.getType());
  DenseIntElementsAttr start_val;
  DenseIntElementsAttr limit_val;
  DenseIntElementsAttr delta_val;
  if (matchPattern(start, m_Constant(&start_val)) &&
      matchPattern(limit, m_Constant(&limit_val)) &&
      matchPattern(delta, m_Constant(&delta_val))) {
    auto size = llvm::APIntOps::RoundingSDiv(
        *limit_val.begin() - *start_val.begin(), *delta_val.begin(),
        llvm::APInt::Rounding::DOWN);
    return RangeOp::build(
        builder, result,
        RankedTensorType::get(
            size.getSExtValue(),
            start.getType().cast<TensorType>().getElementType()),
        start, limit, delta);
  }
  return RangeOp::build(
      builder, result,
      RankedTensorType::get(
          {-1}, start.getType().cast<TensorType>().getElementType()),
      start, limit, delta);
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

void RankOp::build(OpBuilder &builder, OperationState &result, Value input) {
  return RankOp::build(builder, result,
                       RankedTensorType::get({}, builder.getIntegerType(32)),
                       input);
}

// This will create a constant value for RankOp of a ranked tensor.
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
  auto type = input().getType();
  auto ranked_type = type.dyn_cast<RankedTensorType>();
  if (!ranked_type) return {};

  // DenseIntElementsAttr::get requires the output type be ranked with static
  // shape.
  auto output_type = getType().dyn_cast<RankedTensorType>();
  if (!output_type || !output_type.hasStaticShape()) return {};

  int32_t rank = ranked_type.getRank();
  return DenseIntElementsAttr::get(output_type, rank);
}

//===----------------------------------------------------------------------===//
// RealDivOp
//===----------------------------------------------------------------------===//

void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  results.insert<RealDivWithSqrtDivisor, RealDivWithConstDivisor>(context);
}

OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) {
  return IdentityArithmeticOpFolder<RealDivOp>(*this, operands);
}

//===----------------------------------------------------------------------===//
// ReluOp
//===----------------------------------------------------------------------===//

Optional<ContractionFusion> ReluOp::GetContractionFusion() {
  // Only f32 is supported for fusion.
  if (!T().isF32()) return None;

  return ContractionFusion("Relu", /*additional_arguments=*/{});
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

namespace {
using ReshapeErrorHandler =
    llvm::function_ref<LogicalResult(const llvm::Twine &)>;

LogicalResult GetReshapeOutputType(Value tensor, Value shape,
                                   ReshapeErrorHandler error_handler,
                                   TensorType &output_ty) {
  auto tensor_ty = tensor.getType().cast<TensorType>();
  auto element_ty = tensor_ty.getElementType();
  output_ty = UnrankedTensorType::get(element_ty);

  auto shape_ty = shape.getType().dyn_cast<RankedTensorType>();
  if (!shape_ty) return success();
  if (shape_ty.getRank() != 1)
    return error_handler(llvm::formatv(
        "requires 'shape' to be rank 1, but got {0}", shape_ty.getRank()));

  DenseIntElementsAttr shape_attr;
  if (!matchPattern(shape, m_Constant(&shape_attr))) {
    // If only the shape of `shape` is known, return a ranked but fully dynamic
    // output shape.
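    // For example (illustrative): a non-constant `shape` operand of type
    // tensor<3xi32> yields an output type of tensor<?x?x?> with the input's
    // element type.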
    if (shape_ty.hasStaticShape()) {
      llvm::SmallVector<int64_t, 8> dynamic_shape(shape_ty.getDimSize(0),
                                                  ShapedType::kDynamicSize);
      output_ty = RankedTensorType::get(dynamic_shape, element_ty);
    }
    return success();
  }

  // Detect if reshape output shape is folded.
  bool shape_ty_zero_dim = false;
  int unknown_index = -1;
  // The product of the constant shape argument, excluding the unknown
  // dimension.
  int64_t shape_ty_size = 1;
  llvm::SmallVector<int64_t, 8> output_ty_shape;
  output_ty_shape.reserve(shape_attr.getNumElements());
  for (const auto &dim : llvm::enumerate(shape_attr.getIntValues())) {
    const int64_t size = dim.value().getSExtValue();
    if (size == ShapedType::kDynamicSize) {
      if (unknown_index != -1)
        return error_handler(llvm::formatv(
            "requires 'shape' to have at most one dynamic dimension, but got "
            "multiple dynamic dimensions at indices {0} and {1}",
            unknown_index, dim.index()));

      unknown_index = dim.index();
    } else if (size == 0) {
      shape_ty_zero_dim = true;
    } else if (size > 0) {
      shape_ty_size *= size;
    } else {
      return error_handler(
          llvm::formatv("requires 'shape' to have dimensions greater than -1, "
                        "but got {0} at index {1}",
                        size, dim.index()));
    }
    output_ty_shape.push_back(size);
  }

  if (!tensor_ty.hasStaticShape()) {
    output_ty = RankedTensorType::get(output_ty_shape, element_ty);
    return success();
  }

  // Compute the value of the unknown dimension.
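  // Worked example (illustrative): for a tensor with 12 elements and a
  // constant shape of [3, -1], shape_ty_size is 3, so the unknown dimension is
  // computed as 12 / 3 = 4 and the inferred output shape is [3, 4].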
  if (unknown_index != -1) {
    // Compute number of elements in tensor shape.
    int64_t tensor_ty_size = 1;
    bool tensor_ty_zero_dim = false;
    for (const auto &dim : tensor_ty.getShape()) {
      if (dim > 0 || !shape_ty_zero_dim) {
        tensor_ty_size *= dim;
      } else {
        tensor_ty_zero_dim = true;
      }
    }

    const int64_t missing_dim = tensor_ty_size / shape_ty_size;
    if (!tensor_ty_zero_dim && shape_ty_size * missing_dim != tensor_ty_size)
      return error_handler(
          llvm::formatv("requires 'tensor' number of elements be a multiple of "
                        "{0}, but got {1}",
                        shape_ty_size, tensor_ty_size));

    // Set the unknown dimension such that the total number of elements remains
    // constant.
    output_ty_shape[unknown_index] = missing_dim;
  }

  output_ty = RankedTensorType::get(output_ty_shape, element_ty);

  return success();
}
}  // namespace

static LogicalResult Verify(ReshapeOp op) {
  auto error_handler = [&op](const llvm::Twine &message) -> LogicalResult {
    return op.emitOpError() << message;
  };
  TensorType expected_ty;
  if (failed(GetReshapeOutputType(op.tensor(), op.shape(), error_handler,
                                  expected_ty)))
    return failure();

  auto output_ty = op.getType().dyn_cast<RankedTensorType>();
  if (!output_ty) return success();
  auto tensor_ty = op.tensor().getType().cast<TensorType>();
  if (output_ty.hasStaticShape() && tensor_ty.hasStaticShape()) {
    const int64_t output_ty_size = output_ty.getNumElements();
    const int64_t tensor_ty_size = tensor_ty.getNumElements();
    if (tensor_ty_size != output_ty_size)
      return op.emitOpError() << "requires 'output' number of elements to "
                                 "match 'tensor' number of elements, but got "
                              << output_ty_size << " and " << tensor_ty_size;
  }

  if (!AreCastCompatible({output_ty, expected_ty}))
    return op.emitOpError()
           << "requires 'output' type " << output_ty
           << " to be cast compatible with expected type " << expected_ty;

  return success();
}

// Currently there are use cases that rely on partial evaluation of the `shape`
// operand, so InferTypeOpInterface is not used (along with the generated
// builder of the same signature).
void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor,
                      Value shape) {
  auto error_handler = [&result](const llvm::Twine &message) {
    return mlir::emitError(result.location) << message;
  };
  TensorType output_ty;
  if (failed(GetReshapeOutputType(tensor, shape, error_handler, output_ty)))
    return;

  return ReshapeOp::build(builder, result, output_ty, tensor, shape);
}

void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  results.insert<RedundantReshape, ReshapeToSelfShape>(context);
}

OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
  Value tensor = this->tensor();

  // Fold reshape if operand and result types are the same and all dimensions
  // are statically known (no-op reshape).
  auto result_ty = getType().dyn_cast<ShapedType>();
  if (result_ty && result_ty.hasStaticShape() &&
      result_ty == tensor.getType()) {
    return tensor;
  }

  return {};
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<SelectToSelectV2>(context);
}

// Verifies a few extra requirements on SelectOp:
// (1) `then` and `else` must have same shape
// (2) At least one of the following must be true:
//     (a) `cond` has the same rank as `then` and `else`
//     (b) `cond` is a scalar
//     (c) `cond` is a vector AND `then` and `else` are non-scalar with their
//         first dimension equal to the size of `cond`.
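// For example (illustrative): a `cond` of type tensor<2xi1> with `then`/`else`
// of type tensor<2x3xf32> satisfies (2c), while a `cond` of type tensor<4xi1>
// with the same `then`/`else` does not.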
static LogicalResult Verify(SelectOp op) {
  auto then_tensor = op.t().getType().cast<TensorType>();
  auto else_tensor = op.e().getType().cast<TensorType>();
  // Check (1).
  if (!AreCastCompatible({then_tensor, else_tensor}))
    return op.emitOpError() << "requires t and e have compatible shapes";

  // Get data rank (if exists).
  int data_rank;
  // If data is unranked or data_rank is 0, this will remain -2. Otherwise it
  // refers to the first dimension of then and/or else.
  int data_first_dim = -2;
  bool then_has_rank = then_tensor.hasRank();
  bool else_has_rank = else_tensor.hasRank();
  if (then_has_rank && else_has_rank) {
    data_rank = then_tensor.getRank();
    if (then_tensor.getRank() > 0)
      data_first_dim = then_tensor.getShape().front();
    if (else_tensor.getRank() > 0)
      data_first_dim = std::max(
          static_cast<int>(else_tensor.getShape().front()), data_first_dim);
  } else if (then_has_rank) {
    data_rank = then_tensor.getRank();
    if (then_tensor.getRank() > 0)
      data_first_dim = then_tensor.getShape().front();
  } else if (else_has_rank) {
    data_rank = else_tensor.getRank();
    if (else_tensor.getRank() > 0)
      data_first_dim = else_tensor.getShape().front();
  } else {
    // Neither has a rank.
    return success();
  }

  auto cond_tensor = op.condition().getType().dyn_cast<RankedTensorType>();
  if (!cond_tensor) return success();
  auto cond_rank = cond_tensor.getRank();
  // Check (2a) and (2b).
  if (cond_rank == 0 || cond_rank == data_rank) return success();
  // Check (2c).
  if (cond_rank == 1) {
    auto cond_shape = cond_tensor.getShape().front();
    if (data_rank == 0) {
      return op.emitOpError()
             << "requires that t and e are nonscalar when pred is a vector";
    }
    // We know `data` tensor has a rank of at least 1.
    if (data_first_dim != -1 && cond_shape != -1 &&
        data_first_dim != cond_shape) {
      return op.emitOpError() << "requires that, when pred is a vector, the "
                                 "shape matches the first dimension of t and e";
    }
    return success();
  }
  // None of (2a,b,c) were true; fail.
  return op.emitOpError() << "requires that pred is a scalar OR has the same "
                             "rank as t and e OR is a vector";
}

//===----------------------------------------------------------------------===//
// SelectV2Op
//===----------------------------------------------------------------------===//

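// Illustrative example of the inference below (assuming standard broadcasting
// rules): for `condition` of type tensor<2x1xi1> and `e`/`t` of type
// tensor<2x3xf32>, the inferred result type is tensor<2x3xf32>.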
static Type InferSelectV2OpType(Value condition, Value e, Value t) {
  Type element_ty = e.getType().cast<TensorType>().getElementType();
  auto unranked_ty = UnrankedTensorType::get(element_ty);

  Type broadcasted_ty =
      OpTrait::util::getBroadcastedType(e.getType(), t.getType());
  if (!broadcasted_ty) return unranked_ty;

  auto cond_ranked_ty = condition.getType().dyn_cast<RankedTensorType>();
  auto broadcasted_ranked_ty = broadcasted_ty.dyn_cast<RankedTensorType>();
  if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty;

  // Explicitly compute the broadcasted output shape, as the element type of
  // `condition` may not be the same as the broadcasted type's element type.
  SmallVector<int64_t, 4> result_shape;
  if (!OpTrait::util::getBroadcastedShape(cond_ranked_ty.getShape(),
                                          broadcasted_ranked_ty.getShape(),
                                          result_shape))
    return unranked_ty;
  return RankedTensorType::get(result_shape, element_ty);
}

void SelectV2Op::build(OpBuilder &builder, OperationState &result,
                       Value condition, Value e, Value t) {
  build(builder, result, InferSelectV2OpType(condition, e, t), condition, e, t);
}

//===----------------------------------------------------------------------===//
// ShapeOp
//===----------------------------------------------------------------------===//

namespace {
// Validates Shape/ShapeN/VariableShape operand and associated result types.
LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
                                          Type result_type,
                                          int variadic_idx = -1) {
  std::string variadic_idx_str =
      variadic_idx < 0 ? "" : llvm::formatv(" #{0}", variadic_idx).str();

  auto result_ranked_type = result_type.dyn_cast<RankedTensorType>();
  if (!result_ranked_type) return success();
  if (result_ranked_type.getShape().size() != 1)
    return op->emitOpError("requires 1D type for result") << variadic_idx_str;

  auto operand_ranked_type = operand_type.dyn_cast_or_null<RankedTensorType>();
  if (operand_ranked_type) {
    // The operand is a ranked tensor.
    if (result_ranked_type.hasStaticShape() &&
        !operand_ranked_type.getShape().empty() &&
        result_ranked_type.getDimSize(0) !=
            operand_ranked_type.getShape().size())
      return op->emitOpError("requires dimension size of result")
             << variadic_idx_str << " to match rank of operand"
             << variadic_idx_str;
  } else if (result_ranked_type.hasStaticShape()) {
    // The operand is an unranked tensor, print a warning if the result
    // is static.
    // Note: We do not handle this situation as an error, this would be too
    // restrictive due to incompleteness of shape inference at this point.
    op->emitWarning("has static shape result")
        << variadic_idx_str << " for unranked operand" << variadic_idx_str;
  }

  Type element_type = result_ranked_type.getElementType();
  if (!element_type.isSignlessInteger(32) &&
      !element_type.isSignlessInteger(64))
    return op->emitOpError("requires int32 or int64 return type for result")
           << variadic_idx_str;

  return success();
}
}  // anonymous namespace

static LogicalResult Verify(ShapeOp op) {
  return VerifyShapeOperandAndResult(op, op.input().getType(), op.getType());
}

// Converts the shape of the given type to an attribute if it is a ranked
// tensor type. The returned attribute has integer elements of the given width.
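// For example (illustrative): for an input type of tensor<2x3xf32> and
// out_width = 32, this returns dense<[2, 3]> : tensor<2xi32>.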
static Attribute ConvertShapeToAttr(Type input_ty, int out_width) {
  auto ranked_ty = input_ty.dyn_cast<RankedTensorType>();
  if (!ranked_ty || !ranked_ty.hasStaticShape()) return {};

  auto shape = ranked_ty.getShape();
  int rank = shape.size();

  SmallVector<APInt, 4> dimensions;
  dimensions.reserve(rank);
  for (int i = 0; i < rank; ++i)
    dimensions.push_back(APInt(out_width, shape[i]));

  auto result_type = RankedTensorType::get(
      {rank}, IntegerType::get(input_ty.getContext(), out_width));
  return DenseElementsAttr::get(result_type, dimensions);
}

OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
  int width =
      getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
  return ConvertShapeToAttr(getOperand().getType(), width);
}

void ShapeOp::build(OpBuilder &builder, OperationState &result, Value input,
                    BoolAttr use32Bit) {
  auto rankedTensorType = input.getType().dyn_cast<RankedTensorType>();
  int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1;
  auto out_type = use32Bit.getValue() ? builder.getIntegerType(32)
                                      : builder.getIntegerType(64);
  return ShapeOp::build(builder, result,
                        RankedTensorType::get({rank}, out_type), input);
}

//===----------------------------------------------------------------------===//
// ShapeNOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(ShapeNOp op) {
  const size_t num_tensors = op.N();

  if (op.getNumOperands() != num_tensors)
    return op.emitOpError() << "requires " << num_tensors << " operand(s), got "
                            << op.getNumOperands() << " operand(s)";

  if (op.getNumResults() != num_tensors)
    return op.emitOpError() << "requires " << num_tensors << " result(s), got "
                            << op.getNumResults() << " result(s)";

  for (auto i : llvm::seq<uint64_t>(0, num_tensors)) {
    auto verification = VerifyShapeOperandAndResult(
        op, op.getOperand(i).getType(), op.getResult(i).getType(), i);
    if (failed(verification)) return verification;
  }

  return success();
}

namespace {
// Canonicalization pattern for ShapeNOp operations that don't have all static
// input shapes. Replacing output values corresponding to static input types
// may enable optimizations in users of the values.
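// For example (illustrative): a tf.ShapeN with operands of type
// tensor<2x3xf32> and tensor<?x?xf32> is rewritten into a tf.Const holding
// [2, 3] plus a single-operand tf.ShapeN for the dynamic input.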
class ShapeNPartialStaticInputShape : public OpRewritePattern<ShapeNOp> {
  using OpRewritePattern<ShapeNOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ShapeNOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() == 0) {
      rewriter.eraseOp(op);
      return success();
    }

    int width = getElementTypeOrSelf(op.getType(0)).getIntOrFloatBitWidth();

    SmallVector<Value, 4> results(op.getNumOperands());
    SmallVector<int64_t, 4> dynamic_indices;
    SmallVector<Value, 4> dynamic_inputs;
    SmallVector<Type, 4> result_types;
    for (auto e : llvm::enumerate(op.getOperands())) {
      if (Attribute result = ConvertShapeToAttr(e.value().getType(), width)) {
        results[e.index()] = rewriter.create<TF::ConstOp>(op.getLoc(), result);
      } else {
        dynamic_indices.push_back(e.index());
        dynamic_inputs.push_back(e.value());
        result_types.push_back(op.getType(e.index()));
      }
    }

    if (dynamic_inputs.size() == op.getNumOperands()) {
      // Cannot canonicalize ShapeN if all inputs are dynamic.
      return failure();
    }

    // Create a ShapeNOp for all dynamic inputs.
    if (!dynamic_inputs.empty()) {
      auto dynamic_shape_n = rewriter.create<TF::ShapeNOp>(
          op.getLoc(), result_types, dynamic_inputs);
      for (auto index_result :
           llvm::zip(dynamic_indices, dynamic_shape_n.getResults())) {
        results[std::get<0>(index_result)] = std::get<1>(index_result);
      }
    }

    rewriter.replaceOp(op, results);
    return success();
  }
};

// Canonicalize ShapeNOp to ShapeOp if there is only one operand.
class ShapeNToShape : public OpRewritePattern<ShapeNOp> {
  using OpRewritePattern<ShapeNOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ShapeNOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 1) {
      return failure();
    }
    auto shape = rewriter.create<TF::ShapeOp>(op.getLoc(), op.getType(0),
                                              op.getOperand(0));
    rewriter.replaceOp(op, {shape});
    return success();
  }
};
}  // namespace

void ShapeNOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<ShapeNToShape, ShapeNPartialStaticInputShape>(context);
}

//===----------------------------------------------------------------------===//
// SizeOp
//===----------------------------------------------------------------------===//

// Verifies that,
//
// * Input type, if it is a ranked tensor, has at most INT32_MAX dimensions.
//
static LogicalResult Verify(SizeOp op) {
  if (!HasRankAtMost(op.input(), std::numeric_limits<int32_t>::max()))
    return op.emitOpError(
        "requires ranked input tensor to be of rank INT32_MAX or less");

  // Output type needs to be scalar.
  if (!IsOfRankOrUnranked(op.output(), /*rank=*/0))
    return op.emitOpError("requires scalar output");

  return success();
}

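// Folds tf.Size to a constant when the input shape is fully static. For
// example (illustrative): an input of type tensor<2x3x4xf32> folds to a
// constant holding 24.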
fold(ArrayRef<Attribute> operands)1036 OpFoldResult SizeOp::fold(ArrayRef<Attribute> operands) {
1037   ShapedType output_type = getType().cast<ShapedType>();
1038   if (!output_type.hasRank()) return {};
1039   ShapedType input_type = getOperand().getType().cast<ShapedType>();
1040   if (!input_type.hasStaticShape()) return {};
1041   int size = input_type.getNumElements();
1042   return DenseElementsAttr::get(
1043       output_type,
1044       IntegerAttr::get(output_type.getElementType(), /*value=*/size));
1045 }
1046 
1047 //===----------------------------------------------------------------------===//
1048 // SliceOp
1049 //===----------------------------------------------------------------------===//
1050 
1051 // Verifies that:
1052 //
1053 // - operands begin and size are 1D with the same number of elements.
1054 // - if the input is a ranked tensor, the rank of the input equals the number
1055 //   of elements in operands begin and size.
1056 // - if begin are constants, that
1057 //   0 <= begin[i] <= begin[i] + size[i] <= input_ty.getShape()[i]
1058 //   and
1059 //   size[i] == output_ty.getShape()[i]
1060 // - if begins aren't constant but the input is a ranked tensor, that
1061 //   size[i] <= input_ty.getShape()[i]
1062 // - output rank is the same as input rank
1063 //
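// For example (illustrative): for an input of type tensor<4x8xf32> with
// begin = [1, 2] and size = [2, -1] (where -1 means "to the end"), the slice
// covers rows 1..2 and columns 2..7, so the output type must be
// tensor<2x6xf32>.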
static LogicalResult Verify(SliceOp op) {
  RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin());
  if (begin_ty && begin_ty.getRank() != 1) {
    return op.emitOpError() << "requires begin operand to be 1D tensor";
  }

  RankedTensorType size_ty = GetRankedTensorTypeForOperand(op.size());
  if (size_ty && size_ty.getRank() != 1) {
    return op.emitOpError() << "requires size operand to be 1D tensor";
  }

  if (!begin_ty || !size_ty || !begin_ty.hasStaticShape() ||
      !size_ty.hasStaticShape())
    return success();

  if (begin_ty.getNumElements() != size_ty.getNumElements()) {
    return op.emitOpError() << "requires begin and size operands to have the"
                               " same number of elements";
  }

  auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
  if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) {
    return op.emitOpError() << "requires number of elements in begin and size "
                               "to be equal to input rank";
  }

  auto output_ty = op.output().getType().dyn_cast<RankedTensorType>();
  if (output_ty && input_ty && output_ty.getRank() != input_ty.getRank()) {
    return op.emitOpError()
           << "requires output to have the same rank as input, but got input "
              "rank "
           << input_ty.getRank() << " and output rank " << output_ty.getRank();
  }

  DenseIntElementsAttr begin_indices;
  if (matchPattern(op.begin(), m_Constant(&begin_indices))) {
    DenseIntElementsAttr slice_sizes;
    bool constant_slice_sizes =
        matchPattern(op.size(), m_Constant(&slice_sizes));
    int dim = 0;
    // TODO(jpienaar): Reformulate the shape verification below to not use magic
    // constants.
    for (const APInt &raw_begin_index : begin_indices.getValues<APInt>()) {
      int64_t begin_index = raw_begin_index.getSExtValue();
      int64_t input_size = input_ty ? input_ty.getShape()[dim] : -1;
      int64_t slice_size = constant_slice_sizes
                               ? slice_sizes.getValue<APInt>(dim).getSExtValue()
                               : 0;
      int64_t output_size = output_ty ? output_ty.getShape()[dim] : -1;

      if (slice_size == -1 && input_size != -1) {
        slice_size = input_size - begin_index;
      }
      if (output_size != -1 && constant_slice_sizes &&
          output_size != slice_size) {
        return op.emitOpError()
               << "requires output size to have the same size of slice, got "
                  "slice size "
               << slice_size << " and output size " << output_size;
      }
      if (begin_index < 0 ||
          (input_size != -1 && begin_index + slice_size > input_size)) {
        return op.emitOpError()
               << "requires 0 <= begin[i] <= begin[i] + size[i] <= Di";
      }
      ++dim;
    }
  } else if (input_ty) {
    // If the inputs are ranked, we can do a few more sanity checks.
    DenseIntElementsAttr slice_sizes;
    if (matchPattern(op.size(), m_Constant(&slice_sizes))) {
      auto input_shape = input_ty.getShape();
      for (int64_t i = 0; i < input_ty.getRank(); ++i) {
        int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
        int64_t input_size = input_shape[i];
        if (slice_size != -1 && input_size != -1 && slice_size > input_size) {
          return op.emitOpError() << "requires size[i] <= Di, even if begin[i] "
                                     "is unknown at compile time";
        }
      }
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// SoftmaxOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(SoftmaxOp op) {
  if (!HasRankAtLeast(op.logits(), 1)) {
    return op.emitOpError("requires operand to have rank at least 1");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// SoftmaxCrossEntropyWithLogitsOp
//===----------------------------------------------------------------------===//

// Verifies that,
//
// * Input types are broadcast compatible and the broadcasted type has rank two.
//
static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) {
  auto broadcasted_ty = OpTrait::util::getBroadcastedType(
                            op.features().getType(), op.labels().getType())
                            .dyn_cast_or_null<ShapedType>();
  if (!broadcasted_ty ||
      (broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2))
    return op.emitOpError(
        "requires features and labels to be broadcast compatible to rank two");

  return success();
}

//===----------------------------------------------------------------------===//
// SpaceToBatchNDOp
//===----------------------------------------------------------------------===//

int64_t SpaceToBatchNDBlockRank(const TensorType block_shape_type,
                                const TensorType paddings_type) {
  if (block_shape_type.hasStaticShape()) {
    return block_shape_type.getShape()[0];
  } else if (paddings_type.hasStaticShape()) {
    return paddings_type.getShape()[0];
  } else {
    return -1;
  }
}

static LogicalResult Verify(SpaceToBatchNDOp op) {
  const auto input_type = op.input().getType().cast<TensorType>();
  const auto block_shape_type = op.block_shape().getType().cast<TensorType>();
  const auto paddings_type = op.paddings().getType().cast<TensorType>();

  // Check that block_shape has rank 1.
  if (!IsOfRankOrUnranked(op.block_shape(), 1)) {
    return op.emitOpError() << "requires rank of block_shape = 1; got "
                            << block_shape_type.getRank();
  }

  // Check that paddings has rank 2.
  if (!IsOfRankOrUnranked(op.paddings(), 2)) {
    return op.emitOpError()
           << "requires rank of paddings = 2; got " << paddings_type.getRank();
  }

  // Check that paddings.shape[1] = 2.
  if (paddings_type.hasStaticShape() && paddings_type.getShape()[1] != 2) {
    return op.emitOpError() << "requires paddings.shape[1] to be 2; got "
                            << paddings_type.getShape()[1];
  }

  // Check that block_shape and paddings have consistent ranks.
  if (block_shape_type.hasStaticShape() && paddings_type.hasStaticShape() &&
      block_shape_type.getShape()[0] != paddings_type.getShape()[0]) {
    return op.emitOpError()
           << "requires block_shape.shape[0] must equal paddings.shape[0]";
  }

  const int64_t block_rank =
      SpaceToBatchNDBlockRank(block_shape_type, paddings_type);

  // Further checks require block_rank to be known.
  if (block_rank == -1) {
    return success();
  }

  // Check that the rank of the input is >= block_rank + 1.
  if (input_type.hasRank() && input_type.getRank() < 1 + block_rank) {
    return op.emitOpError() << "requires rank of input >= 1 + rank of block";
  }

  ElementsAttr block_shape_attr = nullptr;
  ElementsAttr paddings_attr = nullptr;

  // Check that block_shape[*] >= 1.
  if (matchPattern(op.block_shape(), m_Constant(&block_shape_attr))) {
    uint64_t i = 0;
    for (auto block_len : block_shape_attr.getValues<APInt>()) {
      if (block_len.getSExtValue() < 1) {
        return op.emitOpError()
               << "requires all values of block_shape to be >= 1; "
                  "failed for dimension "
               << i;
      }
      ++i;
    }
  }

  // Check that paddings[*] >= 0.
  if (matchPattern(op.paddings(), m_Constant(&paddings_attr))) {
    for (uint64_t i = 0; i < block_rank; ++i) {
      const int64_t pad_start =
          paddings_attr.getValue({i, 0}).cast<IntegerAttr>().getInt();
      const int64_t pad_end =
          paddings_attr.getValue({i, 1}).cast<IntegerAttr>().getInt();
      if (pad_start < 0 || pad_end < 0) {
        return op.emitOpError()
               << "requires all values of paddings to be >= 0; "
                  "failed for dimension "
               << i;
      }
    }
  }

  // Check that block_shape divides the padded input.
1273   if (input_type.hasStaticShape() && block_shape_attr && paddings_attr) {
1274     for (uint64_t i = 0; i < block_rank; ++i) {
1275       const int64_t input_len = input_type.getShape()[1 + i];
1276       const int64_t pad_start =
1277           paddings_attr.getValue({i, 0}).cast<IntegerAttr>().getInt();
1278       const int64_t pad_end =
1279           paddings_attr.getValue({i, 1}).cast<IntegerAttr>().getInt();
1280       const int64_t block_len =
1281           block_shape_attr.getValue({i}).cast<IntegerAttr>().getInt();
1282       if ((input_len + pad_start + pad_end) % block_len != 0) {
1283         return op.emitOpError()
1284                << "requires block_shape[i] divides "
1285                   "input_shape[i + 1] + paddings[i, 0] + paddings[i, 1]; "
1286                   "failed for i="
1287                << i;
1288       }
1289     }
1290   }
1291 
1292   return success();
1293 }
1294 
1295 // Infers returned rank if possible. Further, infers returned dimension sizes
1296 // when possible. For all dimension sizes to be inferred, the arguments
1297 // block_shape and paddings must be constant.
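// For example, a static input of shape [2, 4, 6, 3] with block_shape = [2, 2]
// and paddings = [[0, 0], [0, 2]] infers the return shape [8, 2, 4, 3]: the
// batch becomes 2 * 2 * 2 = 8, the blocked dimensions become (4 + 0) / 2 = 2
// and (6 + 2) / 2 = 4, and the trailing dimension 3 is preserved.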
1298 LogicalResult SpaceToBatchNDOp::inferReturnTypes(
1299     MLIRContext *context, Optional<Location> location, ValueRange operands,
1300     DictionaryAttr attributes, RegionRange regions,
1301     SmallVectorImpl<Type> &inferredReturnTypes) {
1302   const Value input = operands[0];
1303   const Value block_shape_val = operands[1];
1304   const Value paddings_val = operands[2];
1305   const auto input_type = input.getType().cast<TensorType>();
1306   const auto block_shape_type = block_shape_val.getType().cast<TensorType>();
1307   const auto paddings_type = paddings_val.getType().cast<TensorType>();
1308 
1309   // The return is unranked when the input is unranked.
1310   if (!input_type.hasRank()) {
1311     inferredReturnTypes.assign(
1312         {UnrankedTensorType::get(input_type.getElementType())});
1313     return success();
1314   }
1315 
1316   const int64_t input_rank = input_type.getRank();
1317   const ArrayRef<int64_t> input_shape = input_type.getShape();
1318   const int64_t block_rank =
1319       SpaceToBatchNDBlockRank(block_shape_type, paddings_type);
1320   SmallVector<int64_t, 4> return_shape(input_rank, ShapedType::kDynamicSize);
1321 
1322   // The return has all dimension sizes unknown when block_rank is unknown.
1323   if (block_rank == ShapedType::kDynamicSize) {
1324     inferredReturnTypes.assign(
1325         {RankedTensorType::get(return_shape, input_type.getElementType())});
1326     return success();
1327   }
1328 
1329   // The return preserves the remaining dimensions after blocked dimensions.
1330   for (uint64_t i = 1 + block_rank; i < input_rank; ++i) {
1331     return_shape[i] = input_shape[i];
1332   }
1333 
1334   // The rest of the dimension sizes can be calculated when block_shape and
1335   // paddings arguments are constant.
1336   DenseIntElementsAttr block_shape_attr;
1337   DenseIntElementsAttr paddings_attr;
1338   if (GetValueAsConstant(block_shape_val, block_shape_attr) &&
1339       GetValueAsConstant(paddings_val, paddings_attr)) {
1340     int64_t return_batch = input_shape[0];
1341     for (uint64_t i = 0; i < block_rank; ++i) {
1342       // Propagate dynamic dimension.
1343       if (input_shape[i + 1] == ShapedType::kDynamicSize) {
1344         return_batch = ShapedType::kDynamicSize;
1345       }
1346       if (return_batch == ShapedType::kDynamicSize) {
1347         return_shape[1 + i] = ShapedType::kDynamicSize;
1348         continue;
1349       }
1350       int64_t paddings_sum =
1351           paddings_attr.getValue<APInt>({i, 0}).getSExtValue() +
1352           paddings_attr.getValue<APInt>({i, 1}).getSExtValue();
1353       int64_t block_shape_i =
1354           block_shape_attr.getValue<APInt>({i}).getSExtValue();
1355       return_batch *= block_shape_i;
1356       return_shape[1 + i] = (paddings_sum + input_shape[i + 1]) / block_shape_i;
1357     }
1358     return_shape[0] = return_batch;
1359   }
1360 
1361   inferredReturnTypes.assign(
1362       {RankedTensorType::get(return_shape, input_type.getElementType())});
1363   return success();
1364 }
1365 
1366 //===----------------------------------------------------------------------===//
1367 // SparseSoftmaxCrossEntropyWithLogitsOp
1368 //===----------------------------------------------------------------------===//
1369 
1370 static LogicalResult Verify(SparseSoftmaxCrossEntropyWithLogitsOp op) {
1371   if (!IsOfRankOrUnranked(op.features(), 2)) {
1372     return op.emitOpError("requires features operand of rank two");
1373   }
1374   if (!IsOfRankOrUnranked(op.labels(), 1)) {
1375     return op.emitOpError("requires labels operand of rank one");
1376   }
1377   auto features_ty = op.features().getType().dyn_cast<RankedTensorType>();
1378   auto labels_ty = op.labels().getType().dyn_cast<RankedTensorType>();
1379   if (features_ty && labels_ty) {
1380     int64_t features_batches = features_ty.getDimSize(0);
1381     int64_t labels_batches = labels_ty.getDimSize(0);
1382     if (!ShapedType::isDynamic(features_batches) &&
1383         !ShapedType::isDynamic(labels_batches) &&
1384         features_batches != labels_batches)
1385       return op.emitOpError(
1386           "requires features and labels with matching first dimension");
1387   }
1388   return success();
1389 }
1390 
1391 //===----------------------------------------------------------------------===//
1392 // SplitOp
1393 //===----------------------------------------------------------------------===//
1394 
1395 // Verifies the input and split dimension operands for tf.Split/tf.SplitV.
1396 // Writes the split dimension's index (adjusted with input rank) via `dim_index`
1397 // if it's a constant.
1398 template <class Op>
1399 LogicalResult VerifySplitInputAndSplitDim(Op op, Optional<int64_t> *dim_index) {
1400   *dim_index = llvm::None;
1401 
1402   Value split_dim = op.split_dim();
1403   if (auto split_dim_type = split_dim.getType().dyn_cast<RankedTensorType>())
1404     if (split_dim_type.getRank() != 0)
1405       return op.emitOpError(
1406           "split dimension should be an integer scalar tensor");
1407 
1408   // We can perform further verification if the input tensor to be split has
1409   // known rank and the split dimension tensor is a constant.
1410 
1411   auto input_type = op.value().getType().template dyn_cast<RankedTensorType>();
1412   if (!input_type) return success();
1413 
1414   int64_t input_rank = input_type.getRank();
1415   if (input_rank == 0)
1416     return op.emitOpError("cannot split scalar input tensor");
1417 
1418   DenseIntElementsAttr split_dim_attr;
1419   if (!matchPattern(split_dim, m_Constant(&split_dim_attr))) return success();
1420 
1421   int64_t index = (*split_dim_attr.begin()).getSExtValue();
1422 
1423   if (index + input_rank < 0 || index >= input_rank) {
1424     return op.emitOpError("split dimension must be in range [-")
1425            << input_rank << ", " << input_rank << ")";
1426   }
1427 
1428   if (index < 0) index += input_rank;
1429   *dim_index = index;
1430 
1431   return success();
1432 }
1433 
1434 static LogicalResult Verify(SplitOp op) {
1435   Optional<int64_t> dim_index;
1436   if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
1437   if (!dim_index) return success();
1438 
1439   int64_t input_dim_size =
1440       op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
1441   if (input_dim_size == ShapedType::kDynamicSize) return success();
1442 
1443   if (input_dim_size % op.getNumResults() != 0)
1444     return op.emitOpError("dimension #")
1445            << *dim_index << " not divisible by the number of result tensors";
1446 
1447   return success();
1448 }
1449 
1450 //===----------------------------------------------------------------------===//
1451 // SplitVOp
1452 //===----------------------------------------------------------------------===//
1453 
1454 static LogicalResult Verify(SplitVOp op) {
1455   auto split_sizes_type =
1456       op.size_splits().getType().dyn_cast<RankedTensorType>();
1457   if (!split_sizes_type) return success();
1458 
1459   if (split_sizes_type.getRank() != 1 ||
1460       (split_sizes_type.getDimSize(0) != ShapedType::kDynamicSize &&
1461        split_sizes_type.getDimSize(0) != op.getNumResults()))
1462     return op.emitOpError("split sizes should be a 1D tensor of ")
1463            << op.getNumResults() << " elements";
1464 
1465   Optional<int64_t> dim_index = 0;
1466   if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
1467   if (!dim_index) return success();
1468 
1469   int64_t input_dim_size =
1470       op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
1471   if (input_dim_size == ShapedType::kDynamicSize) return success();
1472 
1473   // If split sizes come from a constant, they must sum to the dimension size
1474   // along split_dim, and we can have no more than one dynamic dimension.
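  // For example, splitting a dimension of size 10 with size_splits = [2, -1, 3]
  // is accepted: the single -1 (ShapedType::kDynamicSize) entry is the dynamic
  // size and the remaining entries sum to 5 <= 10. In contrast,
  // size_splits = [2, -1, -1] (two dynamic entries) or [4, 4, 4] (sums to 12)
  // would be rejected by the checks below.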
1475   DenseIntElementsAttr split_sizes_attr;
1476   if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
1477     return success();
1478 
1479   int64_t total_dim_size = 0;  // Total dimension size assigned to splits
1480   llvm::Optional<int> dynamic_dim_index;
1481 
1482   SmallVector<int64_t, 4> split_sizes;
1483   split_sizes.reserve(
1484       split_sizes_attr.getType().cast<ShapedType>().getNumElements());
1485 
1486   for (auto dim : llvm::enumerate(split_sizes_attr)) {
1487     int64_t dim_val = dim.value().getSExtValue();
1488     split_sizes.push_back(dim_val);
1489     if (dim_val == ShapedType::kDynamicSize) {
1490       // We cannot have more than one dynamic dimension.
1491       if (dynamic_dim_index)
1492         return op.emitOpError(
1493             "cannot have more than one dynamic dimension in split sizes");
1494       dynamic_dim_index = dim.index();
1495     } else {
1496       total_dim_size += dim_val;
1497     }
1498   }
1499 
1500   if (!dynamic_dim_index && total_dim_size != input_dim_size)
1501     return op.emitOpError(
1502                "split sizes must sum up to the dimension size along split "
1503                "dimension, found ")
1504            << total_dim_size << " vs " << input_dim_size;
1505 
1506   if (dynamic_dim_index && total_dim_size > input_dim_size)
1507     return op.emitOpError(
1508                "split sizes must sum up to be less than or equal to the "
1509                "dimension size along split dimension, found ")
1510            << total_dim_size << " vs " << input_dim_size;
1511 
1512   return success();
1513 }
1514 
1515 //===----------------------------------------------------------------------===//
1516 // SquareOp
1517 //===----------------------------------------------------------------------===//
1518 
1519 void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1520                                            MLIRContext *context) {
1521   results.insert<SquareOfSub>(context);
1522 }
1523 
1524 //===----------------------------------------------------------------------===//
1525 // SubOp
1526 //===----------------------------------------------------------------------===//
1527 
1528 void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1529                                         MLIRContext *context) {
1530   results.insert<SubOfNeg>(context);
1531 }
1532 
1533 OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
1534   return IdentityArithmeticOpFolder<SubOp>(*this, operands);
1535 }
1536 
1537 //===----------------------------------------------------------------------===//
1538 // SumOp
1539 //===----------------------------------------------------------------------===//
1540 
1541 void SumOp::build(OpBuilder &builder, OperationState &result, Value input,
1542                   Value reduction_indices, BoolAttr keep_dims) {
1543   Type out_ty =
1544       InferReductionOpType(input, reduction_indices, keep_dims, &builder);
1545   build(builder, result, out_ty, input, reduction_indices, keep_dims);
1546 }
1547 
1548 // TODO: Templatize this fold for all reduction ops.
1549 OpFoldResult SumOp::fold(ArrayRef<Attribute> operands) {
1550   auto input_ty = input().getType().template dyn_cast<RankedTensorType>();
1551   if (!input_ty) return {};
1552   auto result_ty = getType().template dyn_cast<RankedTensorType>();
1553   if (!result_ty) return {};
1554 
1555   // Bypass this op if the result has the same shape and type. This can happen
1556   // if the input tensor has size 0 or size 1.
1557   if (!keep_dims() && input_ty == result_ty) {
1558     return input();
1559   }
1560   return {};
1561 }
1562 
1563 //===----------------------------------------------------------------------===//
1564 // StridedSliceOp
1565 //===----------------------------------------------------------------------===//
1566 
1567 // TODO(b/154160827): Add a canonicalization pattern from tf.StridedSliceOp to
1568 // tf.SliceOp if both of the following are true:
1569 // - All strides have a known value equal to 1
1570 // - No masks are set (or masks can be applied by transforming the inputs to
1571 //   Slice)
1572 
1573 // Verifies that,
1574 //
1575 // - begin, end and strides operands are 1D and they have the same number of
1576 //   elements. Here, the number of elements should be less than 32 to support
1577 //   32-bit mask attributes.
1578 // - None of the strides values are zero.
1579 // - Ellipsis mask can have at most one bit set.
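// For example, begin = [0, 1], end = [2, 3], strides = [1, 1] passes these
// checks, while strides = [1, 0] is rejected for the zero stride and an
// ellipsis_mask of 0b101 is rejected for containing two ellipses.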
1580 
1581 template <class OpTy>
1582 static LogicalResult VerifyStridedSliceBase(OpTy op) {
1583   // Expected size for operands begin, end and strides vector operands.
1584   int64_t expected_size = -1;
1585 
1586   for (Value val : {op.begin(), op.end(), op.strides()}) {
1587     auto operand_ty = val.getType().dyn_cast<ShapedType>();
1588     if (!operand_ty || !operand_ty.hasStaticShape()) {
1589       // TensorFlow constant ops may have non-static shape because the shape is
1590       // not propagated during constant folding. If the defining op for this
1591       // operand is a constant op, use the constant op's attribute to get the
1592       // actual shape.
1593       DenseIntElementsAttr attr;
1594       if (!matchPattern(val, m_Constant(&attr))) continue;
1595       operand_ty = attr.getType();
1596     }
1597 
1598     if (operand_ty.getRank() != 1)
1599       return op.emitOpError()
1600              << "requires begin, end and strides to be 1D tensors";
1601 
1602     int64_t length = operand_ty.getDimSize(0);
1603     if (length == -1) continue;
1604 
1605     if (expected_size == -1) {
1606       // This op uses 32-bit masks.
1607       if (length >= 32)
1608         return op.emitOpError(
1609             "requires begin, end and strides operands with less than 32 "
1610             "elements");
1611 
1612       expected_size = length;
1613     } else if (length != expected_size) {
1614       return op.emitOpError() << "requires begin, end and strides to have the "
1615                                  "same number of elements";
1616     }
1617   }
1618 
1619   // If strides are constants, verify that none of the element is zero.
1620   DenseIntElementsAttr strides;
1621   if (matchPattern(op.strides(), m_Constant(&strides))) {
1622     if (llvm::is_contained(strides.getValues<APInt>(), 0))
1623       return op.emitOpError("requires non-zero strides");
1624   }
1625 
1626   // Use bit compares to ensure ellipsis_mask is 0 or a power of 2, i.e. there
1627   // is at most one ellipsis.
1628   uint32_t ellipsis_mask = op.ellipsis_mask();
1629   if (ellipsis_mask != 0 && !llvm::isPowerOf2_32(ellipsis_mask))
1630     return op.emitOpError("cannot have multiple ellipses");
1631 
1632   return success();
1633 }
1634 
1635 // Clamps the given `val`: returns `low` if `val` is less than `low`; returns
1636 // `high` if `high` is less than `val`; otherwise returns `val`.
1637 template <class T>
1638 constexpr const T &Clamp(const T &val, const T &low, const T &high) {
1639   assert(!(high < low));
1640   return (val < low) ? low : (high < val) ? high : val;
1641 }
1642 
1643 // Checks if the `index` bit of `val` is set.
1644 template <class T>
1645 constexpr bool IsSet(const T &val, unsigned index) {
1646   return (val & (1 << index)) != 0;
1647 }
1648 
1649 // Sets the `index` bit of `val`.
1650 template <class T>
1651 constexpr void Set(T &val, unsigned index) {
1652   val |= (1 << index);
1653 }
1654 
1655 // Unset the `index` bit of `val`.
1656 template <class T>
1657 constexpr void Unset(T &val, unsigned index) {
1658   val &= ~(1 << index);
1659 }
1660 
1661 // Copy the `src_index` bit of `src` to `dst_index` bit of `dst`.
1662 template <class T>
1663 constexpr void CopyBit(const T &src, unsigned src_index, T &dst,
1664                        unsigned dst_index) {
1665   if (IsSet(src, src_index))
1666     Set(dst, dst_index);
1667   else
1668     Unset(dst, dst_index);
1669 }
1670 
1671 // The sparse spec of strided slice need not have the same number of entries
1672 // as the rank of the operand. For example, the sparse spec for foo[..., 3:10]
1673 // on foo of shape (2, 4, 8) would have dims = 2.
1674 struct SparseSliceSpec {
1675   int64_t dims;
1676   int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
1677   const ArrayRef<int64_t> &begin;
1678   const ArrayRef<int64_t> &end;
1679   const ArrayRef<int64_t> &strides;
1680 };
1681 
1682 // The dense spec of strided slice is the canonicalized version of sparse spec.
1683 // The number of dimensions of the dense spec corresponds to the number of
1684 // dimensions in the operand tensor.
1685 struct DenseSliceSpec {
1686   int64_t dims;
1687   int32_t begin_mask, end_mask, shrink_axis_mask;
1688   SmallVectorImpl<int64_t> &begin;
1689   SmallVectorImpl<int64_t> &end;
1690   SmallVectorImpl<int64_t> &strides;
1691 };
1692 
1693 // Makes a sparse spec into a dense index spec.
1694 // The sparse spec need not match the operand's number of dimensions; the
1695 // dense spec produced here does.
1696 //
1697 // For example suppose foo[...,3:, 2] on foo.shape=(2,2,3,4) then
1698 // we need to produce the missing begin_mask, end_mask for the first two
1699 // dimensions i.e. foo[:, :, 3:, 2].
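// Concretely (assuming the usual tf.StridedSlice encoding where the omitted
// end in "3:" sets the end_mask bit and the scalar index "2" sets the
// shrink_axis_mask bit): the ellipsis expands to dense indices 0 and 1 with
// stride 1 and their begin_mask/end_mask bits set, "3:" maps to dense index 2
// with begin = 3, and "2" maps to dense index 3 with begin = 2 and its
// shrink_axis_mask bit set.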
1700 static void BuildDenseSliceSpec(const SparseSliceSpec &sparse,
1701                                 DenseSliceSpec *dense) {
1702   // Build expanded dense begin, end, strides, begin_mask, end_mask, and
1703   // shrink_axis_mask.
1704   dense->begin.resize(dense->dims);
1705   dense->end.resize(dense->dims);
1706   dense->strides.resize(dense->dims);
1707   dense->begin_mask = 0;
1708   dense->end_mask = 0;
1709   dense->shrink_axis_mask = 0;
1710 
1711   // Count number of new_axis after ellipsis. This helps in calculating the
1712   // number of dimensions ellipsis represents in the sparse spec.
1713   bool ellipsis_seen = false;
1714   int num_new_axis_after_ellipsis = 0;
1715   for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) {
1716     if (ellipsis_seen && IsSet(sparse.new_axis_mask, sparse_index))
1717       num_new_axis_after_ellipsis++;
1718     if (IsSet(sparse.ellipsis_mask, sparse_index)) ellipsis_seen = true;
1719   }
1720 
1721   int dense_index = 0;
1722   for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) {
1723     if (IsSet(sparse.new_axis_mask, sparse_index)) continue;
1724     if (IsSet(sparse.ellipsis_mask, sparse_index)) {
1725       auto next_index = std::min(dense->dims - (sparse.dims - sparse_index) +
1726                                      1 + num_new_axis_after_ellipsis,
1727                                  dense->dims);
1728       // Expand ellipsis into the appropriate dense indices. From current index
1729       // until next_index, all dimensions would have begin and end masks set and
1730       // stride 1, i.e., get all elements in those dimensions.
1731       for (; dense_index < next_index; ++dense_index) {
1732         dense->begin[dense_index] = dense->end[dense_index] = 0;
1733         dense->strides[dense_index] = 1;
1734         Set(dense->begin_mask, dense_index);
1735         Set(dense->end_mask, dense_index);
1736       }
1737       continue;
1738     }
1739     assert(dense_index < dense->dims);
1740     // Copy over the sparse indices to dense indices if ellipsis_mask and
1741     // new_axis_mask are not set.
1742     dense->begin[dense_index] = sparse.begin[sparse_index];
1743     dense->end[dense_index] = sparse.end[sparse_index];
1744     dense->strides[dense_index] = sparse.strides[sparse_index];
1745     CopyBit(sparse.begin_mask, sparse_index, dense->begin_mask, dense_index);
1746     CopyBit(sparse.end_mask, sparse_index, dense->end_mask, dense_index);
1747     CopyBit(sparse.shrink_axis_mask, sparse_index, dense->shrink_axis_mask,
1748             dense_index);
1749     dense_index++;
1750   }
1751 }
1752 
1753 // For the given `input_shape`, calculates the sliced shape using the given
1754 // `begin`, `end`, and `stride` ranges and `begin_mask`, `end_mask`, and
1755 // `shrink_axis_mask` masks. Updates the result back to `input_shape`. If
1756 // `shrink_axis_mask` is not zero, this function will not drop the corresponding
1757 // dimensions in `input_shape`; it will turn them into 1s. At the same time,
1758 // canonicalizes `begin`, `end`, and `strides`. The calculation follows
1759 // tf.StridedSlice op semantics.
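// For example, a dimension of size 10 with begin = -3, end = 10, stride = 1
// and no mask bits set canonicalizes begin to 10 + (-3) = 7 and yields a
// sliced size of (10 - 7) / 1 = 3.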
1760 static void CalculateSlicedShapeFromDenseIndices(
1761     MutableArrayRef<int64_t> input_shape, int32_t begin_mask, int32_t end_mask,
1762     int32_t shrink_axis_mask, MutableArrayRef<int64_t> begin,
1763     MutableArrayRef<int64_t> end, MutableArrayRef<int64_t> stride) {
1764   assert(input_shape.size() <= 32);  // Only 32-bit masks are supported.
1765 
1766   // Make sure ranges' ranks are consistent with the input.
1767   assert(input_shape.size() == begin.size());
1768   assert(input_shape.size() == end.size());
1769   assert(input_shape.size() == stride.size());
1770 
1771   for (int i = 0, e = input_shape.size(); i < e; ++i) {
1772     if (ShapedType::isDynamic(input_shape[i])) continue;
1773 
1774     int64_t dim_i = input_shape[i];
1775     int64_t begin_i = begin[i];
1776     int64_t end_i = end[i];
1777     int64_t stride_i = stride[i];
1778 
1779     // [0]: mask for begin, [1]: mask for end
1780     int64_t masks[] = {begin_mask & (1 << i), end_mask & (1 << i)};
1781     // [0]: bound for begin, [1]: bound for end
1782     int64_t bounds[] = {stride_i > 0 ? 0 : -1,
1783                         stride_i > 0 ? dim_i : dim_i - 1};
1784 
1785     // Canonicalizes the given range `point` (begin/end) according to the
1786     // current dimension. `c` means case: 0 for begin, 1 for end.
1787     auto canonicalize = [&](int64_t point, int c) {
1788       if (masks[c]) return stride_i > 0 ? bounds[c] : bounds[(c + 1) & 1];
1789 
1790       // Add dim as offset to negative range point.
1791       point = point < 0 ? dim_i + point : point;
1792       return Clamp(point, bounds[0], bounds[1]);
1793     };
1794 
1795     begin_i = canonicalize(begin_i, 0);
1796     end_i = canonicalize(end_i, 1);
1797 
1798     int64_t interval_len = end_i - begin_i;
1799     int64_t size_i = 0;
1800     // If the interval length is zero or has a different sign from the stride,
1801     // it's a degenerate case: we are slicing nothing. Otherwise, calculate the
1802     // sliced size.
1803     if (interval_len != 0 && (interval_len < 0) == (stride_i < 0))
1804       size_i = (interval_len / stride_i) + (interval_len % stride_i != 0);
1805 
1806     begin[i] = begin_i;
1807     if (IsSet(shrink_axis_mask, i)) {
1808       // Shrink this dimension. It means we only take the element at begin_i.
1809       input_shape[i] = 1;
1810       end[i] = begin_i + 1;
1811       stride[i] = 1;
1812     } else {
1813       input_shape[i] = size_i;
1814       end[i] = end_i;
1815       stride[i] = stride_i;
1816     }
1817   }
1818 }
1819 
1820 // For the given `input_shape`, calculates the sliced shape using the given
1821 // `sparse_begin`, `sparse_end`, and `sparse_strides` ranges and `begin_mask`,
1822 // `end_mask`, `ellipsis_mask` , `new_axis_mask` and `shrink_axis_mask` masks.
1823 // Updates the result back to `input_shape`.
1824 static void CalculateSlicedShapeFromSparseIndices(
1825     MutableArrayRef<int64_t> input_shape, ArrayRef<int64_t> sparse_begin,
1826     ArrayRef<int64_t> sparse_end, ArrayRef<int64_t> sparse_strides,
1827     int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask,
1828     int32_t new_axis_mask, int32_t shrink_axis_mask,
1829     SmallVectorImpl<int64_t> *begin, SmallVectorImpl<int64_t> *end,
1830     SmallVectorImpl<int64_t> *stride) {
1831   int64_t num_sparse_indices = sparse_begin.size();
1832   SparseSliceSpec sparse = {num_sparse_indices, begin_mask,    end_mask,
1833                             ellipsis_mask,      new_axis_mask, shrink_axis_mask,
1834                             sparse_begin,       sparse_end,    sparse_strides};
1835 
1836   // If no ellipsis_mask exists then an implicit ellipsis_mask at the end is
1837   // inserted. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields
1838   // a tensor of shape [2, 8], i.e., foo[2:4] is the same as foo[2:4, ...].
1839   if (sparse.ellipsis_mask == 0) {
1840     Set(sparse.ellipsis_mask, sparse.dims);
1841     sparse.dims++;
1842   }
1843 
1844   int64_t dims = input_shape.size();
1845   DenseSliceSpec dense = {dims,
1846                           /*begin_mask = */ 0,
1847                           /*end_mask = */ 0,
1848                           /*shrink_axis_mask = */ 0,
1849                           *begin,
1850                           *end,
1851                           *stride};
1852 
1853   BuildDenseSliceSpec(sparse, &dense);
1854   CalculateSlicedShapeFromDenseIndices(input_shape, dense.begin_mask,
1855                                        dense.end_mask, dense.shrink_axis_mask,
1856                                        *begin, *end, *stride);
1857 }
1858 
1859 bool StridedSliceOp::GetSlicedBoundRanges(
1860     SmallVectorImpl<int64_t> *slice_begin, SmallVectorImpl<int64_t> *slice_end,
1861     SmallVectorImpl<int64_t> *slice_stride) {
1862   // TODO(hinsu): Support lowering for ops with dynamic begin and end values
1863   // when it is possible to derive indices based on mask attributes.
1864   DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr;
1865   if (!matchPattern(begin(), m_Constant(&sparse_begin_attr)) ||
1866       !matchPattern(end(), m_Constant(&sparse_end_attr)) ||
1867       !matchPattern(strides(), m_Constant(&sparse_strides_attr)))
1868     return false;
1869 
1870   auto input_ty = this->input().getType().dyn_cast<RankedTensorType>();
1871   if (!input_ty || !input_ty.hasStaticShape()) return false;
1872   auto input_shape = llvm::to_vector<4>(input_ty.getShape());
1873 
1874   SmallVector<int64_t, 4> sparse_begin, sparse_end, sparse_strides;
1875 
1876   for (const APInt &index : sparse_begin_attr)
1877     sparse_begin.push_back(index.getSExtValue());
1878   for (const APInt &index : sparse_end_attr)
1879     sparse_end.push_back(index.getSExtValue());
1880   for (const APInt &stride : sparse_strides_attr)
1881     sparse_strides.push_back(stride.getSExtValue());
1882 
1883   CalculateSlicedShapeFromSparseIndices(
1884       input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(),
1885       end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(),
1886       slice_begin, slice_end, slice_stride);
1887   return true;
1888 }
1889 
1890 OpFoldResult StridedSliceOp::fold(ArrayRef<Attribute> operands) {
1891   // Fold StridedSlice operation if it extracts statically known dimensions.
1892   //
1893   // For example,
1894   //
1895   //   %shape  = tf.Shape(%arg)                   // %arg: tensor<?x2x3x1xf32>
1896   //   %height = tf.StridedSlice(%shape, 1, 2, 1)
1897   //
1898   // In this case %height can be replaced with a constant 2.
1899   //
1900   // Or,
1901   //
1902   //   %shape  = tf.Shape(%arg)                   // %arg: tensor<?x2x3x1xf32>
1903   //   %spatial_shape = tf.StridedSlice(%shape, 1, 3, 1)
1904   //
1905   // In this case %spatial_shape can be replaced with a constant [2, 3].
1906 
1907   // Input to strided slice op is defined by shape operation.
1908   auto shape_op = input().getDefiningOp<ShapeOp>();
1909   if (!shape_op) {
1910     return {};
1911   }
1912 
1913   // `begin`, `end` and `strides` should be constant in order to infer static
1914   // dimension.
1915   DenseIntElementsAttr begin_attr, end_attr, strides_attr;
1916   if (!matchPattern(begin(), m_Constant(&begin_attr)) ||
1917       !matchPattern(end(), m_Constant(&end_attr)) ||
1918       !matchPattern(strides(), m_Constant(&strides_attr)) ||
1919       begin_attr.getNumElements() != 1 || end_attr.getNumElements() != 1 ||
1920       strides_attr.getNumElements() != 1) {
1921     return {};
1922   }
1923 
1924   // Do not fold when `new_axis_mask` is set. It's likely to break the shape
1925   // of output. Typically, `new_axis_mask` is not set in this canonicalization
1926   // pattern.
1927   if (new_axis_mask() != 0) return {};
1928 
1929   auto tensor_ty = shape_op.input().getType().dyn_cast<RankedTensorType>();
1930   // Only ranked tensor can be folded.
1931   if (!tensor_ty) return {};
1932 
1933   int64_t rank = tensor_ty.getRank();
1934   int64_t begin_int = begin_attr.getValue<APInt>(0).getSExtValue();
1935   int64_t end_int = end_attr.getValue<APInt>(0).getSExtValue();
1936   int64_t strides_int = strides_attr.getValue<APInt>(0).getSExtValue();
1937 
1938   // Canonicalize `begin` and `end` in case of negative index.
1939   if (begin_int < 0) begin_int += rank;
1940   if (end_int < 0) end_int += rank;
1941 
1942   // Create `begin` and `end` from `*_mask`. Note that we don't care about
1943   // `new_axis_mask` as it can be inferred from `output_ty`.
1944   if (shrink_axis_mask() == 1) {
1945     // When `shrink_axis_mask` is set, output is always a scalar so only
1946     // one element is sliced.
1947     end_int = begin_int + 1;
1948   }
1949   if (begin_mask() == 1) {
1950     begin_int = (strides_int > 0) ? 0 : rank - 1;
1951   }
1952   if (end_mask() == 1) {
1953     end_int = (strides_int > 0) ? rank : -1;
1954   }
1955   if (ellipsis_mask() == 1) {
1956     begin_int = 0;
1957     end_int = rank;
1958   }
1959 
1960   // It's possible that `begin` and `end` are out of bound. See
1961   // https://docs.python.org/3/library/stdtypes.html#common-sequence-operations.
1962   if (strides_int > 0) {
1963     begin_int = std::min(begin_int, rank);
1964     end_int = std::min(end_int, rank);
1965   } else {
1966     begin_int = std::min(begin_int, rank - 1);
1967     end_int = std::min(end_int, rank - 1);
1968   }
1969 
1970   SmallVector<int64_t, 2> sub_shape;
1971   // Only handle cases that have something to slice to avoid infinite for-loop.
1972   if ((end_int > begin_int && strides_int > 0) ||
1973       (end_int < begin_int && strides_int < 0)) {
1974     // Extract sub-shape only if all of those dimensions are static.
1975     for (int64_t i = begin_int; (strides_int > 0) ? i < end_int : i > end_int;
1976          i += strides_int) {
1977       if (tensor_ty.isDynamicDim(i)) {
1978         return {};
1979       }
1980       sub_shape.push_back(tensor_ty.getDimSize(i));
1981     }
1982   }
1983 
1984   // For unranked or dynamic output, we infer the output type to be either a
1985   // scalar or a vector based on `shrink_axis_mask` because we have rejected
1986   // the case of `new_axis_mask` != 0.
1987   auto output_elt_ty = output().getType().cast<ShapedType>().getElementType();
1988   auto output_ty = output().getType().dyn_cast<RankedTensorType>();
1989   if (!output_ty || !output_ty.hasStaticShape()) {
1990     if (shrink_axis_mask() == 1) {
1991       output_ty = RankedTensorType::get({}, output_elt_ty);
1992     } else {
1993       output_ty = RankedTensorType::get(
1994           {static_cast<int64_t>(sub_shape.size())}, output_elt_ty);
1995     }
1996   }
1997 
1998   // Down-cast to 32 bit int if needed.
1999   if (output_elt_ty.isInteger(32)) {
2000     SmallVector<int32_t, 2> sub_shape_i32(sub_shape.size());
2001     std::transform(sub_shape.begin(), sub_shape.end(), sub_shape_i32.begin(),
2002                    [](int64_t d) { return static_cast<int32_t>(d); });
2003     return DenseIntElementsAttr::get(output_ty, sub_shape_i32);
2004   }
2005   return DenseIntElementsAttr::get(output_ty, sub_shape);
2006 }
2007 
2008 //===----------------------------------------------------------------------===//
2009 // StridedSliceGradOp
2010 //===----------------------------------------------------------------------===//
2011 
2012 static LogicalResult Verify(StridedSliceGradOp op) {
2013   auto shape_type = op.shape().getType().dyn_cast<RankedTensorType>();
2014   if (shape_type && shape_type.getRank() != 1)
2015     return op.emitOpError("'shape' operand must be 1D tensor, but got ")
2016            << shape_type.getRank() << "D tensor";
2017 
2018   if (failed(VerifyStridedSliceBase(op))) return failure();
2019 
2020   // TODO(antiagainst): verify the gradient op.dy()'s shape is consistent with
2021   // the sliced type from StridedSlice.
2022 
2023   return success();
2024 }
2025 
2026 bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges(
2027     SmallVectorImpl<int64_t> *input_shape,
2028     SmallVectorImpl<int64_t> *slice_begin, SmallVectorImpl<int64_t> *slice_end,
2029     SmallVectorImpl<int64_t> *slice_stride) {
2030   DenseIntElementsAttr shape_attr;
2031   DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr;
2032   if (!matchPattern(shape(), m_Constant(&shape_attr)) ||
2033       !matchPattern(begin(), m_Constant(&sparse_begin_attr)) ||
2034       !matchPattern(end(), m_Constant(&sparse_end_attr)) ||
2035       !matchPattern(strides(), m_Constant(&sparse_strides_attr)))
2036     return false;
2037 
2038   int rank = std::distance(shape_attr.begin(), shape_attr.end());
2039 
2040   input_shape->clear();
2041   input_shape->reserve(rank);
2042   for (const APInt &dim : shape_attr)
2043     input_shape->push_back(dim.getSExtValue());
2044 
2045   SmallVector<int64_t, 4> sparse_begin, sparse_end, sparse_strides;
2046 
2047   for (const APInt &index : sparse_begin_attr)
2048     sparse_begin.push_back(index.getSExtValue());
2049   for (const APInt &index : sparse_end_attr)
2050     sparse_end.push_back(index.getSExtValue());
2051   for (const APInt &stride : sparse_strides_attr)
2052     sparse_strides.push_back(stride.getSExtValue());
2053 
2054   CalculateSlicedShapeFromSparseIndices(
2055       *input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(),
2056       end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(),
2057       slice_begin, slice_end, slice_stride);
2058   return true;
2059 }
2060 
2061 //===----------------------------------------------------------------------===//
2062 // SummaryWriterOp
2063 //===----------------------------------------------------------------------===//
2064 
2065 ResourceHandleValueAndId SummaryWriterOp::GetResourceHandleValueAndId(
2066     llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
2067     int64_t &next_id) {
2068   llvm::StringRef device = GetDeviceOrEmpty(getOperation());
2069   return GetResourceHandleValueAndIdBase(container(), shared_name(), device,
2070                                          writer(), resource_handle_id_map,
2071                                          next_id);
2072 }
2073 
2074 //===----------------------------------------------------------------------===//
2075 // TPUExecuteOp
2076 //===----------------------------------------------------------------------===//
2077 
2078 void TPUExecuteOp::getEffects(
2079     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2080         &effects) {
2081   effects.reserve(args().size() + 1);
2082 
2083   // There may be some TPU Embedding ops in the computation, so this effect is
2084   // added conservatively.
2085   effects.emplace_back(MemoryEffects::Write::get(),
2086                        ResourceEffects::TPUEmbedding::get());
2087 
2088   for (Value value : args()) {
2089     if (value.getType()
2090             .cast<TensorType>()
2091             .getElementType()
2092             .isa<ResourceType>()) {
2093       // Conservatively mark resource handles as read and write, as without
2094       // analyzing TPUCompile, there is not sufficient information to determine
2095       // effects on resources. For the MLIR bridge, this op will never be
2096       // populated with resource handles and tf.TPUExecuteAndUpdateVariables is
2097       // used instead.
2098       effects.emplace_back(MemoryEffects::Read::get(), value,
2099                            ResourceEffects::Variable::get());
2100       effects.emplace_back(MemoryEffects::Write::get(), value,
2101                            ResourceEffects::Variable::get());
2102     }
2103   }
2104 }
2105 
2106 //===----------------------------------------------------------------------===//
2107 // TPUExecuteAndUpdateVariablesOp
2108 //===----------------------------------------------------------------------===//
2109 
2110 static LogicalResult Verify(TPUExecuteAndUpdateVariablesOp op) {
2111   int num_resource_args = 0;
2112   for (Type arg_type : op.args().getTypes())
2113     if (arg_type.cast<TensorType>().getElementType().isa<ResourceType>())
2114       ++num_resource_args;
2115 
2116   auto check_attr = [&](ArrayAttr indices, llvm::StringRef name,
2117                         int min) -> LogicalResult {
2118     if (indices.size() != num_resource_args)
2119       return op.emitOpError()
2120              << "requires '" << name
2121              << "' to be the same size as number of resource handles in 'args' "
2122                 "("
2123              << num_resource_args << "), but got " << indices.size();
2124 
2125     for (auto entry : llvm::enumerate(indices.getValue())) {
2126       auto int_attr = entry.value().cast<IntegerAttr>();
2127       if (int_attr.getInt() < min)
2128         return op.emitOpError()
2129                << "requires '" << name << "' to contain values of at least "
2130                << min << ", but got " << int_attr.getInt() << " at index "
2131                << entry.index();
2132     }
2133 
2134     return success();
2135   };
2136 
2137   return failure(
2138       failed(check_attr(op.device_var_reads_indices(),
2139                         /*name=*/"device_var_reads_indices", /*min=*/0)) ||
2140       failed(check_attr(op.device_var_updates_indices(),
2141                         /*name=*/"device_var_updates_indices", /*min=*/-1)));
2142 }
2143 
2144 void TPUExecuteAndUpdateVariablesOp::getEffects(
2145     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2146         &effects) {
2147   effects.reserve(device_var_reads_indices().size() + 1);
2148 
2149   // There may be some TPU Embedding ops in the computation, so this effect is
2150   // added conservatively.
2151   effects.emplace_back(MemoryEffects::Write::get(),
2152                        ResourceEffects::TPUEmbedding::get());
2153   auto resource_handles = llvm::make_filter_range(args(), [](Value value) {
2154     return value.getType()
2155         .cast<TensorType>()
2156         .getElementType()
2157         .isa<ResourceType>();
2158   });
2159 
2160   for (auto &entry : llvm::enumerate(resource_handles)) {
2161     Value value = entry.value();
2162     effects.emplace_back(MemoryEffects::Read::get(), value,
2163                          ResourceEffects::Variable::get());
2164     if (device_var_updates_indices()
2165             .getValue()[entry.index()]
2166             .cast<IntegerAttr>()
2167             .getInt() >= 0)
2168       effects.emplace_back(MemoryEffects::Write::get(), value,
2169                            ResourceEffects::Variable::get());
2170   }
2171 }
2172 
2173 //===----------------------------------------------------------------------===//
2174 // TensorListReserveOp
2175 //===----------------------------------------------------------------------===//
2176 
2177 static LogicalResult Verify(TensorListReserveOp op) {
2178   if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
2179       !IsOfRankOrUnranked(op.element_shape(), 1)) {
2180     return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
2181   }
2182 
2183   if (!IsOfRankOrUnranked(op.num_elements(), 0)) {
2184     return op.emitOpError("requires num_elements operand to be 0D tensor");
2185   }
2186   return success();
2187 }
2188 
2189 //===----------------------------------------------------------------------===//
2190 // TensorListElementShapeOp
2191 //===----------------------------------------------------------------------===//
2192 
2193 OpFoldResult TensorListElementShapeOp::fold(ArrayRef<Attribute> operands) {
2194   int width =
2195       getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
2196   auto variant_type =
2197       getElementTypeOrSelf(getOperand().getType()).cast<TF::VariantType>();
2198   if (variant_type.getSubtypes().empty()) return {};
2199   return ConvertShapeToAttr(variant_type.getSubtypes()[0], width);
2200 }
2201 
2202 //===----------------------------------------------------------------------===//
2203 // TensorListStackOp
2204 //===----------------------------------------------------------------------===//
2205 
2206 static LogicalResult Verify(TensorListStackOp op) {
2207   if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
2208       !IsOfRankOrUnranked(op.element_shape(), 1)) {
2209     return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
2210   }
2211   return success();
2212 }
2213 
2214 //===----------------------------------------------------------------------===//
2215 // TensorScatterUpdateOp
2216 //===----------------------------------------------------------------------===//
2217 
2218 static LogicalResult Verify(TensorScatterUpdateOp op) {
2219   if (!HasRankAtLeast(op.tensor(), 1))
2220     return op.emitOpError(
2221         "requires tensor operand to have at least 1 dimension");
2222   if (!HasRankAtLeast(op.indices(), 1))
2223     return op.emitOpError(
2224         "requires indices operand to have at least 1 dimension");
2225   if (!HasRankAtLeast(op.updates(), 1))
2226     return op.emitOpError(
2227         "requires updates operand to have at least 1 dimension");
2228 
2229   auto tensor_ty = op.tensor().getType().dyn_cast<RankedTensorType>();
2230   auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
2231   if (!tensor_ty || !indices_ty) return success();
2232 
2233   int64_t num_index_dims = indices_ty.getShape().back();
2234   if (ShapedType::isDynamic(num_index_dims)) return success();
2235 
2236   if (num_index_dims > tensor_ty.getRank())
2237     return op.emitOpError(
2238         "requires tensor operand with rank greater than or equal to the "
2239         "indices operand's last dimension");
2240   return success();
2241 }
2242 
2243 //===----------------------------------------------------------------------===//
2244 // TileOp
2245 //===----------------------------------------------------------------------===//
2246 
2247 // Verifies that,
2248 //
2249 // - input has at least rank 1
2250 // - multiples is rank 1
2251 // - multiples.size() == input.rank()
2252 // - input.rank() == output.rank()
2253 // - Elements in multiples are non-negative
2254 // - input.shape[i] * multiples[i] == output.shape[i]
2255 //   for i in [0, input.rank() - 1]
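// For example, an input of shape [2, 3] with multiples = [3, 1] must have an
// output of shape [6, 3].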
2256 
2257 static LogicalResult Verify(TileOp op) {
2258   auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
2259   auto multiples_type = op.multiples().getType().dyn_cast<RankedTensorType>();
2260   auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
2261 
2262   if (multiples_type && multiples_type.getRank() != 1) {
2263     return op.emitOpError() << "expected multiples to be rank 1, got rank = "
2264                             << multiples_type.getRank();
2265   }
2266 
2267   if (input_type && multiples_type && multiples_type.hasStaticShape() &&
2268       (input_type.getRank() != multiples_type.getNumElements() ||
2269        (input_type.getRank() == 0 && multiples_type.getNumElements() == 1))) {
2270     return op.emitOpError()
2271            << "expected size of multiples equal to rank of input"
2272            << ", got multiples of size " << multiples_type.getNumElements()
2273            << ", and input of rank " << input_type.getRank();
2274   }
2275 
2276   if (input_type && output_type) {
2277     if (input_type.getRank() != output_type.getRank()) {
2278       return op.emitOpError()
2279              << "expected rank of input to equal rank of output"
2280              << ", got input of rank " << input_type.getRank()
2281              << ", and output of rank " << output_type.getRank();
2282     }
2283 
2284     DenseIntElementsAttr multiples_attr;
2285     if (matchPattern(op.multiples(), m_Constant(&multiples_attr))) {
2286       for (int32_t i = 0, e = input_type.getRank(); i < e; ++i) {
2287         const int64_t input_dim = input_type.getDimSize(i);
2288         const int64_t output_dim = output_type.getDimSize(i);
2289         const int64_t m = multiples_attr.getValue<APInt>(i).getSExtValue();
2290 
2291         if (m < 0) {
2292           return op.emitOpError()
2293                  << "expected multiples to be non-negative, got "
2294                  << "multiples[" << i << "] = " << m;
2295         }
2296 
2297         if (!ShapedType::isDynamic(input_dim) &&
2298             !ShapedType::isDynamic(output_dim) && output_dim != input_dim * m) {
2299           return op.emitOpError()
2300                  << "requires input.shape[" << i << "] (" << input_dim << ")"
2301                  << " * " << m << " to be equal to "
2302                  << "output.shape[" << i << "] (" << output_dim << ")";
2303         }
2304       }
2305     }
2306   }
2307 
2308   return success();
2309 }
2310 
2311 OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
2312   DenseIntElementsAttr multiples_attr;
2313   if (matchPattern(multiples(), m_Constant(&multiples_attr))) {
2314     // Return input directly when multiples are all ones,
2315     // regardless of what the input is.
2316     if (multiples_attr.isSplat() &&
2317         multiples_attr.getSplatValue<APInt>().getSExtValue() == 1) {
2318       return input();
2319     }
2320   }
2321   return {};
2322 }
2323 
2324 //===----------------------------------------------------------------------===//
2325 // TopKV2Op
2326 //===----------------------------------------------------------------------===//
2327 
2328 static LogicalResult Verify(TopKV2Op op) {
2329   if (!HasRankAtLeast(op.input(), 1))
2330     return op.emitOpError(
2331         "requires input operand to have at least 1 dimension");
2332 
2333   if (!IsOfRankOrUnranked(op.k(), 0))
2334     return op.emitOpError("requires k operand to be 0D tensor");
2335 
2336   return success();
2337 }
2338 
2339 //===----------------------------------------------------------------------===//
2340 // ToBoolOp
2341 //===----------------------------------------------------------------------===//
2342 
2343 namespace {
2344 // If the input to ToBoolOp is a ranked tensor, then the ToBoolOp can be folded
2345 // into an identity or an equality comparison.
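// For example, a tensor<i1> operand folds to itself, a scalar tensor<f32>
// operand is rewritten to tf.NotEqual(operand, 0.0), and a tensor<2x0xf32>
// operand becomes a constant false since it has zero elements.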
2346 class ToBoolOfRankedTensor : public OpRewritePattern<ToBoolOp> {
2347   using OpRewritePattern<ToBoolOp>::OpRewritePattern;
2348   LogicalResult matchAndRewrite(ToBoolOp op,
2349                                 PatternRewriter &rewriter) const override {
2350     auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
2351     // If the input is an unranked tensor, we cannot rewrite it.
2352     if (!type) return failure();
2353 
2354     // Expected return type of the ToBool operation. The return type of ToBool
2355     // operation is always a 0D tensor of bool type.
2356     auto result_type = op.getResult().getType().cast<RankedTensorType>();
2357 
2358     // If input is already a tensor<i1>, it can be folded into an identity.
2359     if (type == result_type) {
2360       rewriter.replaceOp(op, op.getOperand());
2361       return success();
2362     }
2363 
2364     if (type.getRank() == 0) {
2365       // If the input is a scalar tensor, the ToBool can be expanded to
2366       // element != 0 (for numerical values) or element == empty (for string).
2367       Type element_type = type.getElementType();
2368       Attribute zero_attr;
2369       if (element_type.isIntOrFloat())
2370         zero_attr = rewriter.getZeroAttr(type);
2371       else if (element_type.isa<TF::StringType>())
2372         zero_attr = DenseStringElementsAttr::get(type, {""});
2373 
2374       if (!zero_attr) return failure();
2375 
2376       auto zero_const = rewriter.create<TF::ConstOp>(op.getLoc(), zero_attr);
2377       rewriter.replaceOpWithNewOp<TF::NotEqualOp>(
2378           op, result_type, op.getOperand(), zero_const, false);
2379     } else {
2380       // If the input is a non-scalar ranked tensor, ToBool can be expanded
2381       // to numElements != 0. numElements will be 0 iff one of the dimensions is
2382       // zero.
2383       bool any_zero =
2384           llvm::any_of(type.getShape(), [](int64_t dim) { return dim == 0; });
2385       rewriter.replaceOpWithNewOp<TF::ConstOp>(
2386           op, result_type, DenseElementsAttr::get(result_type, {!any_zero}));
2387     }
2388     return success();
2389   }
2390 };
2391 }  // namespace
2392 
2393 void ToBoolOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2394                                            MLIRContext *context) {
2395   results.insert<ToBoolOfRankedTensor>(context);
2396 }
2397 
2398 //===----------------------------------------------------------------------===//
2399 // TransposeOp
2400 //===----------------------------------------------------------------------===//
2401 
2402 static LogicalResult Verify(TransposeOp op) {
2403   auto perm_type = op.perm().getType().dyn_cast<RankedTensorType>();
2404   auto x_type = op.x().getType().dyn_cast<RankedTensorType>();
2405   auto y_type = op.y().getType().dyn_cast<RankedTensorType>();
2406 
2407   if (perm_type && perm_type.getRank() != 1) {
2408     return op.emitOpError()
2409            << "expected perm to be a 1-D Tensor, got perm of rank "
2410            << perm_type.getRank();
2411   }
2412 
2413   if (x_type && y_type && x_type.getRank() != y_type.getRank()) {
2414     return op.emitOpError() << "x should be of the same rank as y, got "
2415                             << "x of rank " << x_type.getRank()
2416                             << ", and y of rank " << y_type.getRank();
2417   }
2418 
2419   if (!x_type || !y_type || !perm_type || !perm_type.hasStaticShape()) {
2420     return success();
2421   }
2422 
2423   if (x_type.getRank() != perm_type.getNumElements()) {
2424     return op.emitOpError() << "expected perm to be a 1-D Tensor of size "
2425                             << "equal to the rank of x, got perm of size "
2426                             << perm_type.getNumElements() << ", and x of rank "
2427                             << x_type.getRank();
2428   }
2429 
2430   DenseIntElementsAttr attr_perm;
2431   if (matchPattern(op.perm(), m_Constant(&attr_perm))) {
2432     // y.shape[i] should be equal to x.shape[perm[i]]
2433     // for i = [0, 1, ..., rank(x) - 1]
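    // For example, x of shape [2, 3, 4] with perm = [2, 0, 1] requires y to
    // have shape [4, 2, 3].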
2434     for (auto e : llvm::enumerate(attr_perm)) {
2435       const int64_t y_idx = e.index();
2436       const int64_t y_dim = y_type.getDimSize(y_idx);
2437       const int64_t x_idx = e.value().getSExtValue();
2438       const int64_t x_dim = x_type.getDimSize(x_idx);
2439       if (y_dim != ShapedType::kDynamicSize &&
2440           x_dim != ShapedType::kDynamicSize && y_dim != x_dim) {
2441         return op.emitOpError()
2442                << "requires y.shape[" << y_idx << "] (" << y_dim << ") "
2443                << "to be equal to x.shape[perm[" << x_idx << "]] "
2444                << "(" << x_dim << ")";
2445       }
2446     }
2447   }
2448 
2449   return success();
2450 }
2451 
2452 // TODO(jpienaar): perm could be optional too.
2453 void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x,
2454                         Value perm) {
2455   auto x_type = x.getType().cast<TensorType>();
2456   // If value is unranked, then so is the result.
2457   if (!x_type.hasRank())
2458     return TransposeOp::build(builder, result,
2459                               UnrankedTensorType::get(x_type.getElementType()),
2460                               x, perm);
2461 
2462   // TODO(jpienaar): Handle unknown perm case.
2463 
2464   // TODO(jpienaar): Extract utility function.
2465   auto etype = x_type.cast<ShapedType>().getElementType();
2466   DenseIntElementsAttr attr_shape;
2467   if (matchPattern(perm, m_Constant(&attr_shape))) {
2468     llvm::SmallVector<int64_t, 4> const_shape;
2469     if (attr_shape.isSplat()) {
2470       const_shape.assign(
2471           attr_shape.getNumElements(),
2472           x_type.getDimSize((*attr_shape.begin()).getSExtValue()));
2473     } else {
2474       const_shape.reserve(attr_shape.getNumElements());
2475       for (const auto &dim : attr_shape)
2476         const_shape.push_back(x_type.getDimSize(dim.getSExtValue()));
2477     }
2478     return TransposeOp::build(
2479         builder, result, RankedTensorType::get(const_shape, etype), x, perm);
2480   }
2481   return TransposeOp::build(builder, result, UnrankedTensorType::get(etype), x,
2482                             perm);
2483 }
2484 
2485 namespace {
2486 
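// Folds a TransposeOp whose permutation is the identity. Illustrative example:
// "tf.Transpose"(%x, [0, 1, 2]) can be replaced by %x directly, subject to the
// type/dialect checks below.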
OpFoldResult FoldIdentityTranspose(TransposeOp op) {
  DenseIntElementsAttr perm;
  if (!matchPattern(op.perm(), m_Constant(&perm))) return {};
  const auto elements = perm.getValues<APInt>();

  for (auto it : llvm::enumerate(elements)) {
    if (it.index() != it.value()) return {};
  }

  // TODO(jpienaar): Remove if/when we handle this more generally.
  if (op.getType() != op.x().getType()) {
    // If the types don't match, then only fold if all the users are in the TF
    // dialect.
    for (auto user : op.getOperation()->getUsers())
      if (user->getDialect() != op->getDialect()) return {};
  }

  return op.x();
}

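// Folds a pair of transposes whose constant permutations cancel each other.
// Illustrative example:
//
//   %0 = "tf.Transpose"(%x, [1, 2, 0])
//   %1 = "tf.Transpose"(%0, [2, 0, 1])
//
// folds %1 to %x, since the second permutation is the inverse of the first.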
OpFoldResult FoldCancellableTranspose(TransposeOp op) {
  // Operand is a TransposeOp.
  auto transpose = dyn_cast_or_null<TF::TransposeOp>(op.x().getDefiningOp());
  if (!transpose) return {};

  // Permutations defined by constant operations.
  DenseIntElementsAttr perm0;
  DenseIntElementsAttr perm1;
  if (!matchPattern(op.perm(), m_Constant(&perm0)) ||
      !matchPattern(transpose.perm(), m_Constant(&perm1)))
    return {};

  // With permutation indices that cancel each other
  if (!AreCancellablePermutations(perm0, perm1)) return {};

  return transpose.x();
}

}  // namespace

OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
  if (auto folded = FoldIdentityTranspose(*this)) return folded;
  if (auto folded = FoldCancellableTranspose(*this)) return folded;
  return {};
}

//===----------------------------------------------------------------------===//
// TruncateDivOp
//===----------------------------------------------------------------------===//

void TruncateDivOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<TruncateDivWithSqrtDivisor>(context);
}

//===----------------------------------------------------------------------===//
// NonMaxSuppressionV3Op
//===----------------------------------------------------------------------===//

namespace {

// Canonicalize NonMaxSuppressionV3Op to NonMaxSuppressionV4Op.
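// NonMaxSuppressionV4Op produces a second, scalar `valid_outputs` result that
// V3 does not have, so the pattern builds the V4 op with an extra scalar
// result type and only rewires users of the original V3 result.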
class NMSV3ToNMSV4Op : public OpRewritePattern<NonMaxSuppressionV3Op> {
  using OpRewritePattern<NonMaxSuppressionV3Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(NonMaxSuppressionV3Op nms_op,
                                PatternRewriter &rewriter) const override {
    if (nms_op.getNumOperands() != 5) {
      return failure();
    }
    SmallVector<Type, 2> new_result_types;
    new_result_types.push_back(nms_op.getType());
    auto input_ty = nms_op.getType().template cast<ShapedType>();
    // Corresponds to the second result type of NMSv4.
    RankedTensorType valid_output_type =
        RankedTensorType::get({}, input_ty.getElementType());
    new_result_types.push_back(valid_output_type);

    auto nmsv4 = rewriter.create<TF::NonMaxSuppressionV4Op>(
        nms_op.getLoc(), new_result_types, nms_op.boxes(), nms_op.scores(),
        nms_op.max_output_size(), nms_op.iou_threshold(),
        nms_op.score_threshold());
    // Cannot replace the NMSv3 op with NMSv4 directly since their outputs
    // differ (v4 produces two output values whereas v3 produces only one).
    nms_op.replaceAllUsesWith(nmsv4.getResult(0));
    return success();
  }
};
}  // namespace.

void NonMaxSuppressionV3Op::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<NMSV3ToNMSV4Op>(context);
}

//===----------------------------------------------------------------------===//
// FusedBatchNormOp
//===----------------------------------------------------------------------===//

namespace {

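// Converts FusedBatchNormOp to FusedBatchNormV3Op by appending the extra
// `reserve_space_3` result type expected by V3, then dropping that result
// again when replacing the original op.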
class ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
  using OpRewritePattern<FusedBatchNormOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op,
                                PatternRewriter &rewriter) const override {
    auto new_result_types =
        llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes());
    // reserve_space_3
    new_result_types.push_back(
        UnrankedTensorType::get(FloatType::getF32(rewriter.getContext())));

    OperationState new_state(tf_fused_batch_norm_op.getLoc(),
                             TF::FusedBatchNormV3Op::getOperationName(),
                             tf_fused_batch_norm_op.getOperands(),
                             new_result_types,
                             tf_fused_batch_norm_op.getAttrs());
    Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state);

    rewriter.replaceOp(tf_fused_batch_norm_op,
                       tf_fused_batch_norm_op_v3->getResults().drop_back());
    return success();
  }
};
}  // namespace.

void FusedBatchNormOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ConvertFusedBatchNorm>(context);
}

//===----------------------------------------------------------------------===//
// UnpackOp
//===----------------------------------------------------------------------===//

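// Illustrative example of the checks below: unpacking a tensor<2x6xf32> along
// axis 0 requires the op to have exactly 2 results, and the axis attribute
// must lie in [-2, 2).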
static LogicalResult Verify(UnpackOp op) {
  auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
  if (!value_type) return success();

  int64_t value_rank = value_type.getRank();
  int64_t axis = op.axis();
  if (axis < -value_rank || axis >= value_rank)
    return op.emitOpError("axis attribute must be in the range of [-")
           << value_rank << ", " << value_rank << ')';

  axis = GetDimForAxis(axis, value_rank);
  int64_t dim_size = value_type.getDimSize(axis);
  if (ShapedType::isDynamic(dim_size)) return success();

  if (dim_size != op.getNumResults())
    return op.emitOpError("result count must be equal to ") << dim_size;

  return success();
}

namespace {

// Hoist a coefficient-wise unary operation out of the Unpack op:
//
//   %unpacked:N = "tf.Unpack"(%0)
//   %neg0 = "tf.Neg"(%unpacked#0)
//   %neg1 = "tf.Neg"(%unpacked#1)
//   ...
//   %negN-1 = "tf.Neg"(%unpacked#N-1)
//
// Rewrite it to:
//
//   %neg = "tf.Neg"(%0)
//   %unpacked:N = "tf.Unpack"(%neg)
class HoistCwiseUnaryOutOfUnpack : public OpRewritePattern<UnpackOp> {
 public:
  explicit HoistCwiseUnaryOutOfUnpack(MLIRContext *context)
      : OpRewritePattern<UnpackOp>(context) {}
  LogicalResult matchAndRewrite(UnpackOp op,
                                PatternRewriter &rewriter) const override;
};

LogicalResult HoistCwiseUnaryOutOfUnpack::matchAndRewrite(
    UnpackOp op, PatternRewriter &rewriter) const {
  auto loc = op.getLoc();

  // The first unpack user must be a coeff-wise unary operation.
  Operation *first_user = *op->getUsers().begin();
  if (!first_user->hasTrait<OpTrait::TF::CwiseUnary>()) return failure();

  // All unpack users must be ops of the same kind.
  bool users_same_op = llvm::all_of(op->getUsers(), [&](Operation *user) {
    return user->getName() == first_user->getName();
  });
  if (!users_same_op) return failure();

  // Pass the unpack operand to the unary operation.
  OperationState new_unary_op_state(loc, first_user->getName().getStringRef(),
                                    op.getOperand(), op.getOperand().getType(),
                                    ArrayRef<NamedAttribute>());
  Operation *new_unary_op = rewriter.createOperation(new_unary_op_state);

  // Unpack results after applying the unary operation.
  auto unpack_unary_op = rewriter.create<UnpackOp>(
      loc, op.getResultTypes(), new_unary_op->getResult(0), op.axis());

  // Bypass all users of the original unpack operation and use
  // `unpack_unary_op` results instead.
  for (auto pair : llvm::zip(op.getResults(), unpack_unary_op.getResults())) {
    OpResult old_result = std::get<0>(pair);  // result of original Unpack
    OpResult new_result = std::get<1>(pair);  // result of transformed Unpack
    for (Operation *user : llvm::make_early_inc_range(old_result.getUsers()))
      rewriter.replaceOp(user, ValueRange(new_result));
  }

  // Erase the original unpack operation.
  rewriter.eraseOp(op.getOperation());

  return success();
}

}  // namespace

void UnpackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  results.insert<HoistCwiseUnaryOutOfUnpack>(context);
}

//===----------------------------------------------------------------------===//
// Unsorted segment reduction ops
//===----------------------------------------------------------------------===//

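// Illustrative example of the constraints checked below: for data of type
// tensor<4x5x6xf32>, segment_ids must have a shape that is a prefix of the
// data shape (e.g. tensor<4xi32> or tensor<4x5xi32>), and num_segments must
// be a non-negative 0-D tensor.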
template <class Op>
static LogicalResult VerifyUnsortedSegmentReduction(Op op) {
  if (!HasRankAtMost(op.num_segments(), 0))
    return op.emitOpError("number of segments should be a 0-D tensor");

  auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
  auto segment_ids_type =
      op.segment_ids().getType().template dyn_cast<RankedTensorType>();
  if (data_type && segment_ids_type) {
    if (data_type.getRank() < segment_ids_type.getRank())
      return op.emitOpError(
          "requires segment ids rank to be less than or equal to data's rank");

    int index = 0;
    for (auto shape_pair :
         llvm::zip_first(segment_ids_type.getShape(), data_type.getShape())) {
      int64_t segment_id_dim = std::get<0>(shape_pair);
      int64_t data_dim = std::get<1>(shape_pair);
      if (!ShapedType::isDynamic(segment_id_dim) &&
          !ShapedType::isDynamic(data_dim) && segment_id_dim != data_dim)
        return op.emitOpError(
                   "requires segment ids shape to be a prefix of data shape, "
                   "but dimension #")
               << index << " differs: " << segment_id_dim << " vs. "
               << data_dim;
      ++index;
    }
  }

  DenseIntElementsAttr num_segments_attr;
  if (matchPattern(op.num_segments(), m_Constant(&num_segments_attr))) {
    int64_t num_segments = (*num_segments_attr.begin()).getSExtValue();
    if (num_segments < 0)
      return op.emitOpError("num of segments cannot be negative");
  }

  return success();
}

//===----------------------------------------------------------------------===//
// VarHandleOp
//===----------------------------------------------------------------------===//

ResourceHandleValueAndId VarHandleOp::GetResourceHandleValueAndId(
    llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
    int64_t &next_id) {
  llvm::StringRef device = GetDeviceOrEmpty(getOperation());
  return GetResourceHandleValueAndIdBase(container(), shared_name(), device,
                                         resource(), resource_handle_id_map,
                                         next_id);
}

//===----------------------------------------------------------------------===//
// VarIsInitializedOp
//===----------------------------------------------------------------------===//

namespace {

/// Erase VarIsInitializedOp operations with no uses. This op has a side effect
/// on resources (read-only), but it can still be deleted if it has zero uses.
struct EraseDeadVarIsInitializedOp
    : public OpRewritePattern<VarIsInitializedOp> {
  using OpRewritePattern<VarIsInitializedOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VarIsInitializedOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.use_empty()) return failure();
    rewriter.eraseOp(op);
    return success();
  }
};
}  // end anonymous namespace.

void VarIsInitializedOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<EraseDeadVarIsInitializedOp>(context);
}

//===----------------------------------------------------------------------===//
// VariableOp
//===----------------------------------------------------------------------===//

void VariableOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
  results.insert<VariableToVariableV2>(context);
}

//===----------------------------------------------------------------------===//
// VariableShapeOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(VariableShapeOp op) {
  auto input_type = op.input().getType().cast<TensorType>();
  if (input_type.hasStaticShape() && input_type.getNumElements() != 1)
    return op.emitOpError("requires input to have one resource");

  auto resource_type = input_type.getElementType().cast<TF::ResourceType>();
  auto subtypes = resource_type.getSubtypes();
  switch (subtypes.size()) {
    case 1:
      return VerifyShapeOperandAndResult(
          op, resource_type.getSubtypes().front(), op.getType());
    case 0:
      return VerifyShapeOperandAndResult(op, Type(), op.getType());
    default:
      return op.emitOpError(
          "requires resource input type to have at most 1 subtype");
  }
}

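// Folds VariableShapeOp to a constant when the resource element type carries a
// subtype. Illustrative example (assuming a fully static subtype): an operand
// of type tensor<*x!tf.resource<tensor<2x3xf32>>> may fold to the constant
// shape [2, 3].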
OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {
  int width =
      getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
  auto resource_type =
      getElementTypeOrSelf(getOperand().getType()).cast<TF::ResourceType>();
  if (resource_type.getSubtypes().empty()) return {};
  return ConvertShapeToAttr(resource_type.getSubtypes()[0], width);
}

//===----------------------------------------------------------------------===//
// WhileOp
//===----------------------------------------------------------------------===//

static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
                                      TypeRange body_input,
                                      TypeRange body_result,
                                      bool shape_invariant) {
  const TypeRangeWithDesc input_type = {op->getOperandTypes(), "input"};
  const TypeRangeWithDesc result_type = {op->getResultTypes(), "result"};
  constexpr int kNumRegionTypeLists = 3;
  const std::array<TypeRangeWithDesc, kNumRegionTypeLists> region_types = {{
      {body_result, "body result"},
      {cond_input, "condition input"},
      {body_input, "body input"},
  }};

  // A pair of type lists should be cast compatible with each other if one is
  // converted to the other for a function call or assignment, or if there is a
  // common source of inputs for both. Therefore, the While op requires the
  // following pairs of type lists to be cast compatible for the tensor_cast
  // operation:
  //
  // * Operands and cond inputs to call the cond function before the
  //   first iteration.
  // * Operands and body inputs to call the body function for the first
  //   iteration if the cond function returns True or an equivalent result.
  // * Operands and results to assign cond function arguments to op results if
  //   the cond function returns False or an equivalent result. If the op is
  //   shape invariant, this does not hold as shapes can differ.
  // * All three pairwise combinations of cond inputs, body inputs and results,
  //   as the op operands are a common source for all three.
  // * Body results and cond inputs to call the cond function for the
  //   subsequent iterations. Similarly, body results should be compatible with
  //   body inputs and op results.
  //
  // Note that the operands and body results need not be compatible, as they
  // are never converted from one to the other, nor is there a common source of
  // tensors. The compatibility requirement is not transitive.

  if (!shape_invariant &&
      failed(VerifyTypeRangesAreCompatible(op, input_type, result_type)))
    return failure();

  // Skip the first pair, as the While op operands and body function results do
  // not need to be compatible with each other.
  for (int i = 1; i < kNumRegionTypeLists; ++i)
    if (failed(VerifyTypeRangesAreCompatible(op, input_type, region_types[i])))
      return failure();

  for (int i = 0; i < kNumRegionTypeLists; ++i)
    if (failed(VerifyTypeRangesAreCompatible(op, result_type, region_types[i])))
      return failure();

  for (int i = 0; i < kNumRegionTypeLists; ++i)
    for (int j = i + 1; j < kNumRegionTypeLists; ++j)
      if (failed(VerifyTypeRangesAreCompatible(op, region_types[i],
                                               region_types[j])))
        return failure();

  return success();
}

static LogicalResult Verify(WhileOp op) {
  auto cond_fn = op.cond_function();
  auto body_fn = op.body_function();
  if (!cond_fn) {
    return op.emitOpError("cond refers to an undefined function : ")
           << op.cond();
  }
  if (!body_fn) {
    return op.emitOpError("body refers to an undefined function : ")
           << op.body();
  }

  auto cond_fn_type = cond_fn.getType();
  auto body_fn_type = body_fn.getType();

  // Verify that the cond function has exactly one result.
  if (cond_fn_type.getNumResults() != 1)
    return op.emitOpError("requires cond function to have exactly one result");

  if (failed(VerifyWhileTypes(op, /*cond_input=*/cond_fn_type.getInputs(),
                              /*body_input=*/body_fn_type.getInputs(),
                              /*body_result=*/body_fn_type.getResults(),
                              op.shape_invariant())))
    return failure();
  return success();
}

//===----------------------------------------------------------------------===//
// WhileRegionOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(WhileRegionOp op) {
  // Verify that the condition generates a single tensor<i1> result.
  Operation *cond_yield = op.cond().front().getTerminator();
  if (cond_yield->getNumOperands() != 1)
    return op.emitOpError()
           << "condition should have a single tensor<i1> result";

  auto cond_type =
      cond_yield->getOperand(0).getType().dyn_cast<RankedTensorType>();
  if (!cond_type || !cond_type.getShape().equals({}) ||
      !cond_type.getElementType().isInteger(/*width=*/1))
    return op.emitOpError()
           << "condition should have a single tensor<i1> result";

  Operation *body_yield = op.body().front().getTerminator();
  if (failed(VerifyWhileTypes(op, /*cond_input=*/op.cond().getArgumentTypes(),
                              /*body_input=*/op.body().getArgumentTypes(),
                              /*body_result=*/body_yield->getOperandTypes(),
                              op.shape_invariant())))
    return failure();
  return success();
}

//===----------------------------------------------------------------------===//
// WhileRegionOp LoopLikeOpInterface
//===----------------------------------------------------------------------===//

Region &WhileRegionOp::getLoopBody() { return body(); }

bool WhileRegionOp::isDefinedOutsideOfLoop(Value value) {
  // If the op defining the value exists and the defining op is outside the
  // scope of this WhileRegion, then we can infer that it's defined outside.
  // The defining op is outside the scope of this WhileRegion if this
  // WhileRegionOp is not an ancestor of the defining op in the parent chain.
  Operation *def_op = value.getDefiningOp();
  return def_op && !getOperation()->isAncestor(def_op);
}

LogicalResult WhileRegionOp::moveOutOfLoop(
    llvm::ArrayRef<mlir::Operation *> ops) {
  // Move the hoisted ops to just before the while.
  Operation *while_op = this->getOperation();
  for (auto op : ops) op->moveBefore(while_op);
  return success();
}

//===----------------------------------------------------------------------===//
// WhileRegionOp canonicalization
//===----------------------------------------------------------------------===//
namespace {
// Eliminate values that pass through the WhileRegionOp body.
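//
// Illustrative example: if the body block yields its argument %arg1 unchanged,
// every use of %arg1, of the matching cond argument, and of the corresponding
// while result can be replaced by the while operand that feeds it; the
// now-unused operand/argument/result triple is then removed from the op.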
struct WhileRegionEliminatePassThrough
    : public OpRewritePattern<WhileRegionOp> {
  using OpRewritePattern<WhileRegionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileRegionOp while_op,
                                PatternRewriter &rewriter) const override {
    // Replace values that simply pass through the body with external values.
    // The block arguments of the body and the while op match, so the
    // corresponding cond argument can be easily found.
    int old_num_operands = while_op.getNumOperands();
    int new_num_operands = old_num_operands;
    auto &body_block = while_op.body().front();
    auto &cond_block = while_op.cond().front();
    auto &yield = *body_block.getTerminator();

    // Bit mask indicating which operands will be removed.
    SmallVector<bool, 16> removed_operand(old_num_operands, false);

    for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
      auto body_arg = body_block.getArgument(op_idx);
      if (body_arg == yield.getOperand(op_idx)) {
        // Replace the uses of the passthrough value with the while operand
        // in the body and condition regions, as well as the while output (if
        // the types match).
        // TODO(jurahul): Use PatternRewriter API for IR modification.
        auto value = while_op.getOperand(op_idx);
        if (body_arg.getType() == value.getType())
          body_arg.replaceAllUsesWith(value);

        auto cond_arg = cond_block.getArgument(op_idx);
        if (cond_arg.getType() == value.getType())
          cond_arg.replaceAllUsesWith(value);

        auto result = while_op.getResult(op_idx);
        if (result.getType() == value.getType())
          result.replaceAllUsesWith(value);
      }

      // Now check if the operand is unused in both regions as well as the
      // result. If so, mark it for removal.
      if (body_block.getArgument(op_idx).use_empty() &&
          cond_block.getArgument(op_idx).use_empty() &&
          while_op.getResult(op_idx).use_empty()) {
        removed_operand[op_idx] = true;
        new_num_operands--;
      }
    }

    if (new_num_operands == old_num_operands) return failure();

    // Compress the operands, region arguments, and outputs.
    SmallVector<Value, 4> new_while_operands;
    SmallVector<Type, 4> new_result_types;
    new_while_operands.reserve(new_num_operands);
    new_result_types.reserve(new_num_operands);

    // Build new operands and result type.
    int next_idx = 0;
    for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
      if (removed_operand[op_idx]) continue;
      new_while_operands.push_back(while_op.getOperand(op_idx));
      new_result_types.push_back(while_op.getResult(op_idx).getType());
      next_idx++;
    }

    // Create the new while operation.
    auto new_while_op =
        rewriter.create<WhileRegionOp>(while_op.getLoc(), new_result_types,
                                       new_while_operands, while_op.getAttrs());

    // Move region bodies to the new while.
    rewriter.inlineRegionBefore(while_op.cond(), new_while_op.cond(),
                                new_while_op.cond().end());
    rewriter.inlineRegionBefore(while_op.body(), new_while_op.body(),
                                new_while_op.body().end());

    auto &new_cond_block = new_while_op.cond().front();
    auto &new_body_block = new_while_op.body().front();
    auto &new_yield = *new_body_block.getTerminator();

    // Build a vector of new results. Also patch up the region bodies and
    // yield.
    SmallVector<Value, 4> new_results;
    next_idx = 0;
    for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
      if (removed_operand[op_idx]) {
        new_cond_block.eraseArgument(next_idx);
        new_body_block.eraseArgument(next_idx);
        new_yield.eraseOperand(next_idx);
        new_results.push_back(nullptr);
      } else {
        new_results.push_back(new_while_op.getResult(next_idx++));
      }
    }

    rewriter.replaceOp(while_op, new_results);
    return success();
  }
};

}  // anonymous namespace

void WhileRegionOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<WhileRegionEliminatePassThrough>(context);
}

//===----------------------------------------------------------------------===//
// XdivyOp
//===----------------------------------------------------------------------===//

void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<XdivyWithSqrtDivisor>(context);
}

//===----------------------------------------------------------------------===//
// XlaBroadcastHelperOp
//===----------------------------------------------------------------------===//

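// Infers broadcast result types for the lower-rank operand. Illustrative
// example: for lhs of type tensor<4xf32>, rhs of type tensor<2x4x3xf32> and a
// constant broadcast_dims of [1], the inferred return types are
// tensor<1x4x1xf32> for lhs and the unchanged tensor<2x4x3xf32> for rhs.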
LogicalResult XlaBroadcastHelperOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto loc = location ? *location : mlir::UnknownLoc::get(context);
  XlaBroadcastHelperOpAdaptor op(operands, attributes);
  if (failed(op.verify(loc))) {
    return failure();
  }

  Value lhs = op.lhs();
  Value rhs = op.rhs();
  auto set_unranked_results = [&]() {
    auto unranked_lhs = UnrankedTensorType::get(getElementTypeOrSelf(lhs));
    inferredReturnTypes.push_back(unranked_lhs);
    auto unranked_rhs = UnrankedTensorType::get(getElementTypeOrSelf(rhs));
    inferredReturnTypes.push_back(unranked_rhs);
    return success();
  };

  RankedTensorType lhs_ty = lhs.getType().dyn_cast<RankedTensorType>();
  RankedTensorType rhs_ty = rhs.getType().dyn_cast<RankedTensorType>();
  if (!lhs_ty || !rhs_ty) return set_unranked_results();

  int64_t lhs_rank = lhs_ty.getRank();
  int64_t rhs_rank = rhs_ty.getRank();

  DenseIntElementsAttr dims;
  if (!matchPattern(op.broadcast_dims(), m_Constant(&dims))) {
    return set_unranked_results();
  }

  if (dims.size() == 0) {
    if (lhs_rank != rhs_rank && lhs_rank != 0 && rhs_rank != 0) {
      return emitOptionalError(
          location,
          "if broadcast_dims is empty, both arguments must have equal rank or "
          "at least one argument must be a scalar");
    }
    inferredReturnTypes.push_back(lhs_ty);
    inferredReturnTypes.push_back(rhs_ty);
    return success();
  }

  const bool broadcast_lhs = lhs_rank < rhs_rank;
  RankedTensorType min_rank_ty = broadcast_lhs ? lhs_ty : rhs_ty;
  RankedTensorType max_rank_ty = broadcast_lhs ? rhs_ty : lhs_ty;

  if (dims.size() != min_rank_ty.getRank()) {
    return emitOptionalError(
        location,
        "broadcast_dims must have size equal to the smaller argument rank");
  }

  int64_t output_rank = max_rank_ty.getRank();
  llvm::SmallVector<int64_t, 4> broadcast_shape(output_rank, 1LL);
  llvm::SmallVector<bool, 4> is_broadcasted(output_rank, false);
  for (auto item : llvm::enumerate(dims)) {
    int64_t index = item.index();
    int64_t dim = item.value().getSExtValue();
    if (dim < 0 || dim > output_rank) {
      return emitOptionalError(location, "out of range broadcast dim");
    }
    if (is_broadcasted[dim]) {
      return emitOptionalError(location, "broadcast_dims has duplicates");
    }
    broadcast_shape[dim] = min_rank_ty.getDimSize(index);
    is_broadcasted[dim] = true;
  }

  if (broadcast_lhs) {
    inferredReturnTypes.push_back(
        RankedTensorType::get(broadcast_shape, lhs_ty.getElementType()));
    inferredReturnTypes.push_back(rhs_ty);
  } else {
    inferredReturnTypes.push_back(lhs_ty);
    inferredReturnTypes.push_back(
        RankedTensorType::get(broadcast_shape, rhs_ty.getElementType()));
  }
  return success();
}

//===----------------------------------------------------------------------===//
// XlaSetDynamicDimensionSizeOp
//===----------------------------------------------------------------------===//

LogicalResult XlaSetDynamicDimensionSizeOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({operands.front().getType()});
  return success();
}

}  // namespace TF
}  // namespace mlir

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc.inc"