/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines the operations used in the MHLO dialect.

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <algorithm>
#include <functional>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
#include "mlir-hlo/utils/convert_op_folder.h"
#include "mlir-hlo/utils/hlo_utils.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/InliningUtils.h"

namespace mlir {
#include "hlo_patterns.cc.inc"
}  // namespace mlir

namespace mlir {
namespace mhlo {

Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value,
                                            Type type, Location loc) {
  // HLO dialect constants only support ElementsAttr, unlike the standard
  // dialect constant which supports all attributes.
  if (value.isa<ElementsAttr>())
    return builder.create<mhlo::ConstOp>(loc, type, value.cast<ElementsAttr>());
  return nullptr;
}

template <typename T>
static LogicalResult Verify(T op) {
  return success();
}

namespace {

//===----------------------------------------------------------------------===//
// Utilities for the canonicalize patterns
//===----------------------------------------------------------------------===//

// Verifies that the dimension attribute for the op correctly indexes into the
// operand or result shape.
template <typename OpT>
static LogicalResult VerifyDimAttr(OpT op) {
  int64_t rank = -1;
  if (auto ty = op.operand().getType().template dyn_cast<RankedTensorType>()) {
    rank = ty.getRank();
  } else if (auto ty = op.getType().template dyn_cast<RankedTensorType>()) {
    rank = ty.getRank();
  } else {
    return success();
  }

  int64_t dim = op.dimension();
  if (dim < 0 || dim >= rank)
    return op.emitOpError() << "requires dimension attribute in range [0, "
                            << rank << "); found (" << dim << ")";
  return success();
}

// Returns a 1D 64-bit dense elements attribute with the given values.
DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
                                        Builder* builder) {
  RankedTensorType ty = RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, values);
}

// Given the start indices and slice sizes for a dynamic-slice that can be
// converted to a static slice, returns the limits for the static slice.
DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices,
                                      DenseIntElementsAttr slice_sizes,
                                      Builder* builder) {
  SmallVector<int64_t, 4> slice_limits;
  for (int64_t i = 0; i < slice_sizes.getNumElements(); ++i) {
    int64_t start_index = start_indices.getValue<IntegerAttr>(i).getInt();
    int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
    slice_limits.push_back(start_index + slice_size);
  }
  return GetI64ElementsAttr(slice_limits, builder);
}
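
// Worked example (illustrative, not from the original source): for
// BuildSliceLimits with start_indices = dense<[1, 2]> and
// slice_sizes = dense<[3, 4]>, the returned limits are dense<[4, 6]>,
// i.e. limit[i] = start_index[i] + slice_size[i] along each dimension.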

/// Replaces the given op with the contents of the given single-block region,
/// using the operands of the block terminator to replace operation results.
static void ReplaceOpWithRegion(PatternRewriter& rewriter, Operation* op,
                                Region& region, ValueRange blockArgs = {}) {
  assert(llvm::hasSingleElement(region) && "expected single-block region");
  Block* block = &region.front();
  Operation* terminator = block->getTerminator();
  ValueRange results = terminator->getOperands();
  rewriter.mergeBlockBefore(block, op, blockArgs);
  rewriter.replaceOp(op, results);
  rewriter.eraseOp(terminator);
}

#include "mhlo_canonicalize.inc"

// Common shape function helper for RngNormal and RngUniform.
static LogicalResult rngInferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  if (operands.size() != 3)
    return emitOptionalError(location, "expected 3 operands");

  SmallVector<int64_t> shapeVector;
  Value shapeOperand = operands[2];
  auto shapeOperandType = shapeOperand.getType().cast<ShapedType>();
  Type elementType = getElementTypeOrSelf(operands[1]);

  // Match constant shape arguments.
  DenseIntElementsAttr shape;
  if (!matchPattern(shapeOperand, m_Constant(&shape))) {
    if (!shapeOperandType.hasRank()) {
      inferredReturnShapes.emplace_back(elementType);
      return success();
    }
    if (shapeOperandType.getRank() != 1)
      return emitOptionalError(location, "shape operand required to be 1D");
    int size = shapeOperandType.getDimSize(0);
    if (size == ShapedType::kDynamicSize) {
      inferredReturnShapes.emplace_back(elementType);
      return success();
    }
    shapeVector.resize(size, ShapedType::kDynamicSize);
    inferredReturnShapes.emplace_back(shapeVector, elementType);
    return success();
  }

  shapeVector.reserve(shape.size());
  for (const APInt& dimSize : shape.getIntValues())
    shapeVector.push_back(dimSize.getSExtValue());
  inferredReturnShapes.emplace_back(shapeVector, elementType);
  return success();
}
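
// Inference sketch (illustrative): for an rng op whose shape operand is the
// constant dense<[2, 3]> : tensor<2xi64> and whose second operand has element
// type f32, the inferred component is a 2x3 shape with element type f32. A
// non-constant but statically 1D shape operand of size N instead yields N
// dynamic dimensions, and an unranked shape operand yields element type only.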

// Returns a new scalar integer value having type `type`. Here `type` must be
// an integer or index type.
Value MaybeCastTo(OpBuilder& b, Location loc, Value value, Type type) {
  if (type == value.getType()) return value;
  assert(type.isIndex() || value.getType().isIndex());
  return b.create<IndexCastOp>(loc, value, type);
}

}  // namespace

//===----------------------------------------------------------------------===//
// ReduceScatterOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(ReduceScatterOp op) {
  return mlir::hlo::VerifyReduceScatter(
      op,
      /*operand_types=*/{op.operand().getType()},
      /*result_types=*/{op.getType()},
      /*scatter_dimension=*/op.scatter_dimension());
}

//===----------------------------------------------------------------------===//
// ConstOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");

  // Return the held attribute value.
  return value();
}

// Builds a constant op with the specified attribute `value`.
void ConstOp::build(OpBuilder& builder, OperationState& result,
                    Attribute value) {
  Type type;
  if (auto elemAttr = value.dyn_cast<ElementsAttr>()) {
    type = elemAttr.getType();
  } else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
             value.isa<IntegerAttr>()) {
    // All XLA types must be tensor types. In the build() method, we want to
    // provide more flexibility by allowing attributes of scalar types. But we
    // need to wrap them in an ElementsAttr to construct valid XLA constants.
    type = RankedTensorType::get(/*shape=*/{}, value.getType());
    value = DenseElementsAttr::get(type.cast<TensorType>(), value);
  }

  // TODO: support other XLA specific types.
  assert(type && "unsupported attribute type for building mhlo.constant");
  result.types.push_back(type);
  result.addAttribute("value", value);
}

//===----------------------------------------------------------------------===//
// DotGeneralOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(DotGeneralOp op) {
  auto dot_dimension_numbers = op.dot_dimension_numbers();
  int64_t lhs_batching_dimensions_size = llvm::size(
      dot_dimension_numbers.lhs_batching_dimensions().getValues<int64_t>());
  int64_t rhs_batching_dimensions_size = llvm::size(
      dot_dimension_numbers.rhs_batching_dimensions().getValues<int64_t>());
  if (lhs_batching_dimensions_size != rhs_batching_dimensions_size) {
    return op.emitError()
           << "lhs and rhs should have the same number of batching dimensions";
  }
  int64_t lhs_contracting_dimensions_size = llvm::size(
      dot_dimension_numbers.lhs_contracting_dimensions().getValues<int64_t>());
  int64_t rhs_contracting_dimensions_size = llvm::size(
      dot_dimension_numbers.rhs_contracting_dimensions().getValues<int64_t>());
  if (lhs_contracting_dimensions_size != rhs_contracting_dimensions_size) {
    return op.emitError() << "lhs and rhs should have the same number of "
                             "contracting dimensions";
  }
  return success();
}
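
// Example (illustrative): a batched matmul of lhs tensor<8x16x32xf32> with
// rhs tensor<8x32x4xf32> would use lhs_batching_dimensions = [0],
// rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], and
// rhs_contracting_dimensions = [1]. The verifier above checks only that the
// two batching lists and the two contracting lists have matching lengths.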

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

// Converts gather ops to slice ops in case we have a single set of constant
// indices.
struct GatherSlice : public OpRewritePattern<GatherOp> {
  using OpRewritePattern<GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter& rewriter) const override {
    DenseIntElementsAttr index;
    if (!matchPattern(gather.start_indices(), m_Constant(&index)))
      return failure();

    const auto& dnums = gather.dimension_numbers();
    if (dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1)
      return failure();

    // TODO(tberghammer): Remove when the verifier catches this case, which is
    // invalid if all of the previous conditions hold.
    if (index.getNumElements() != dnums.start_index_map().getNumElements())
      return failure();

    auto slice_end =
        llvm::to_vector<8>(gather.slice_sizes().getValues<int64_t>());
    llvm::SmallVector<int64_t, 8> slice_start(slice_end.size(), 0);
    for (auto it : llvm::zip(dnums.start_index_map().getIntValues(),
                             index.getIntValues())) {
      int64_t map_index = std::get<0>(it).getSExtValue();
      int64_t offset = std::get<1>(it).getSExtValue();
      slice_start[map_index] += offset;
      slice_end[map_index] += offset;
    }

    llvm::SmallVector<int64_t, 8> slice_stride(slice_end.size(), 1);
    llvm::SmallVector<int64_t, 8> slice_shape(slice_end.size());
    for (size_t i = 0; i < slice_end.size(); ++i) {
      slice_shape[i] = slice_end[i] - slice_start[i];
    }
    Type element_type = gather.getType().cast<TensorType>().getElementType();
    auto slice_type = RankedTensorType::get(slice_shape, element_type);
    Value result = rewriter.create<SliceOp>(
        gather.getLoc(), slice_type, gather.getOperand(0),
        GetI64ElementsAttr(slice_start, &rewriter),
        GetI64ElementsAttr(slice_end, &rewriter),
        GetI64ElementsAttr(slice_stride, &rewriter));

    if (dnums.collapsed_slice_dims().getNumElements() > 0) {
      auto collapsed_slice_dims = llvm::to_vector<8>(llvm::map_range(
          dnums.collapsed_slice_dims().getIntValues(),
          [](const llvm::APInt& i) { return i.getSExtValue(); }));
      llvm::SmallVector<int64_t, 8> reshape_shape;
      for (size_t i = 0; i < slice_shape.size(); ++i) {
        if (llvm::count(collapsed_slice_dims, i) == 0) {
          reshape_shape.push_back(slice_shape[i]);
        }
      }
      auto reshape_type = RankedTensorType::get(reshape_shape, element_type);
      result =
          rewriter.create<ReshapeOp>(gather.getLoc(), reshape_type, result);
    }

    result.setType(gather.getType());
    rewriter.replaceOp(gather, result);
    return success();
  }
};
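
// Rewrite sketch (illustrative): with constant start indices dense<[1]>,
// start_index_map = [0], and slice_sizes = [2, 4] on a tensor<4x4xf32>
// operand, the pattern emits an mhlo.slice with start_indices = [1, 0],
// limit_indices = [3, 4], and strides = [1, 1], followed by an mhlo.reshape
// when collapsed_slice_dims is non-empty.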

void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                           MLIRContext* context) {
  results.insert<GatherSlice>(context);
}

namespace {

// Following https://www.tensorflow.org/xla/operation_semantics#gather, the
// bounds for the output array along dimension i are computed as follows:
// (1) If i is present in batch_dims (i.e. is equal to batch_dims[k] for some
// k), then we pick the corresponding dimension bounds out of
// start_indices.shape, skipping index_vector_dim (i.e. pick
// start_indices.shape.dims[k] if k < index_vector_dim and
// start_indices.shape.dims[k+1] otherwise).
// (2) If i is present in offset_dims (i.e. equal to offset_dims[k] for some
// k), then we pick the corresponding bound out of slice_sizes after
// accounting for collapsed_slice_dims (i.e. we pick adjusted_slice_sizes[k]
// where adjusted_slice_sizes is slice_sizes with the bounds at indices
// collapsed_slice_dims removed).
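//
// Worked example (illustrative): with start_indices of shape [5, 2],
// index_vector_dim = 1, slice_sizes = [1, 8], collapsed_slice_dims = [0],
// and offset_dims = [1] (hence batch_dims = [0]), adjusted_slice_sizes is
// [8], so the output bounds are [5, 8]: dimension 0 comes from
// start_indices.shape.dims[0] and dimension 1 from adjusted_slice_sizes[0].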

void GetSliceSizeValues(GatherOp* gather, OpBuilder& builder, Location loc,
                        ValueRange operands,
                        SmallVectorImpl<Value>& slice_sizes) {
  for (int64_t val : gather->slice_sizes().getValues<int64_t>()) {
    slice_sizes.push_back(builder.create<ConstantIndexOp>(loc, val));
  }
}

void GetSliceSizeValues(DynamicGatherOp* d_gather, OpBuilder& builder,
                        Location loc, ValueRange operands,
                        SmallVectorImpl<Value>& slice_size_values) {
  DynamicGatherOp::Adaptor adaptor(operands);
  Value slice_sizes = adaptor.slice_sizes();
  auto slice_sizes_ty = slice_sizes.getType().cast<ShapedType>();
  for (int64_t i = 0; i < slice_sizes_ty.getDimSize(0); ++i) {
    Value idx = builder.create<ConstantIndexOp>(loc, i);
    slice_size_values.push_back(
        builder.create<tensor::ExtractOp>(loc, slice_sizes, idx));
  }
}

template <typename Op>
LogicalResult GatherShapeInferImpl(
    Op* op, OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  // Unranked gather is not supported at the moment.
  auto result_ty =
      op->getResult().getType().template dyn_cast<RankedTensorType>();
  if (!result_ty) return failure();

  typename Op::Adaptor adaptor(operands);
  Value start_indices = adaptor.start_indices();

  Location loc = op->getLoc();
  int result_rank = result_ty.getRank();
  Type shape_scalar_type =
      start_indices.getType().cast<ShapedType>().getElementType();
  auto to_shape_scalar_type = [&](Value v) {
    return MaybeCastTo(builder, loc, v, shape_scalar_type);
  };

  auto dimension_numbers = op->dimension_numbers();
  SmallVector<int64_t, 4> collapsed_slice_dims(
      dimension_numbers.collapsed_slice_dims().template getValues<int64_t>());
  SmallVector<int64_t, 4> offset_dims(
      dimension_numbers.offset_dims().template getValues<int64_t>());
  int64_t index_vector_dim =
      dimension_numbers.index_vector_dim().getValue().getSExtValue();

  SmallVector<Value, 4> slice_sizes;
  GetSliceSizeValues(op, builder, loc, operands, slice_sizes);
  // Convert to `shape_scalar_type`.
  llvm::transform(slice_sizes, slice_sizes.begin(),
                  [&](Value v) { return to_shape_scalar_type(v); });

  // We label dimensions in the output array that are not in offset_dims as
  // batch_dims.
  SmallVector<int64_t, 4> batch_dims;
  for (int64_t i = 0; i < result_rank; ++i) {
    if (std::find(offset_dims.begin(), offset_dims.end(), i) ==
        offset_dims.end()) {
      batch_dims.push_back(i);
    }
  }
  // adjusted_slice_sizes is slice_sizes with the bounds at indices
  // collapsed_slice_dims removed.
  SmallVector<Value, 4> adjusted_slice_sizes;
  for (int64_t i = 0; i < slice_sizes.size(); ++i) {
    if (std::find(collapsed_slice_dims.begin(), collapsed_slice_dims.end(),
                  i) == collapsed_slice_dims.end()) {
      adjusted_slice_sizes.push_back(slice_sizes[i]);
    }
  }

  SmallVector<Value, 4> shape_values;
  shape_values.reserve(result_rank);
  for (int64_t i = 0; i < result_rank; ++i) {
    auto iter = std::find(batch_dims.begin(), batch_dims.end(), i);
    if (iter != batch_dims.end()) {
      // i is present in batch_dims.
      int64_t k = std::distance(batch_dims.begin(), iter);
      if (k < index_vector_dim) {
        shape_values.push_back(to_shape_scalar_type(
            builder.create<tensor::DimOp>(loc, start_indices, k)));
      } else {
        shape_values.push_back(to_shape_scalar_type(
            builder.create<tensor::DimOp>(loc, start_indices, k + 1)));
      }
    } else {
      // i is present in offset_dims.
      auto offset_dims_iter =
          std::find(offset_dims.begin(), offset_dims.end(), i);
      assert(offset_dims_iter != offset_dims.end());
      int64_t k = std::distance(offset_dims.begin(), offset_dims_iter);
      assert(k < adjusted_slice_sizes.size());
      shape_values.push_back(adjusted_slice_sizes[k]);
    }
  }

  Value output_shape = builder.create<tensor::FromElementsOp>(
      loc, shape_scalar_type, shape_values);
  reifiedReturnShapes.push_back(output_shape);

  return success();
}

}  // namespace

LogicalResult GatherOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  return GatherShapeInferImpl(this, builder, operands, reifiedReturnShapes);
}

//===----------------------------------------------------------------------===//
// DynamicGatherOp
//===----------------------------------------------------------------------===//
//

LogicalResult DynamicGatherOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  return GatherShapeInferImpl(this, builder, operands, reifiedReturnShapes);
}

//===----------------------------------------------------------------------===//
// GetDimensionSizeOp
//===----------------------------------------------------------------------===//
//
static LogicalResult Verify(GetDimensionSizeOp op) { return VerifyDimAttr(op); }

/// Fold get_dimension_size when the said shape dimension is a constant.
OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
  RankedTensorType type = operand().getType().dyn_cast<RankedTensorType>();
  if (!type) return {};

  int32_t dim = dimension();
  if (type.isDynamic(dim)) return {};
  // The result type is always a 0-d i32 tensor.
  return DenseIntElementsAttr::get<int32_t>(
      getResult().getType().cast<RankedTensorType>(), type.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// IotaOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(IotaOp op) {
  auto shape = op.getType().cast<ShapedType>();
  if (!shape.hasRank()) return success();

  if (shape.getRank() == 0)
    return op.emitOpError() << "does not support scalars.";

  auto iota_dimension = op.iota_dimension();
  if (iota_dimension >= shape.getRank() || iota_dimension < 0)
    return op.emitOpError() << "iota dimension cannot go beyond the output "
                               "rank or be negative.";
  return success();
}

// Iota operations across multiple dimensions can be reduced to an iota and a
// ranked broadcast.
struct IotaBroadcast : public OpRewritePattern<IotaOp> {
  using OpRewritePattern<IotaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IotaOp iota,
                                PatternRewriter& rewriter) const override {
    auto result_ty = iota.getType().cast<ShapedType>();
    if (!result_ty.hasRank() || result_ty.getRank() < 2) {
      return failure();
    }

    auto iota_dimension = iota.iota_dimension();

    auto iota_type = RankedTensorType::get(
        {result_ty.getDimSize(iota_dimension)}, result_ty.getElementType());

    auto new_iota = rewriter.create<IotaOp>(iota.getLoc(), iota_type,
                                            rewriter.getI64IntegerAttr(0));

    auto broadcast_attr = DenseIntElementsAttr::get(
        RankedTensorType::get({1}, rewriter.getIntegerType(64)),
        {iota_dimension});
    rewriter.replaceOpWithNewOp<BroadcastInDimOp>(iota, result_ty, new_iota,
                                                  broadcast_attr);
    return success();
  }
};
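
// Rewrite sketch (illustrative, schematic IR):
//   %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<5x4xf32>
// becomes
//   %1 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
//   %0 = "mhlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<1> :
//        tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32>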

void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                         MLIRContext* context) {
  results.insert<IotaBroadcast>(context);
}

OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
  auto dimension = iota_dimension();
  auto result_ty = getResult().getType().cast<ShapedType>();
  if (result_ty.hasRank() && result_ty.getDimSize(dimension) == 1) {
    Builder builder(getContext());
    return builder.getZeroAttr(result_ty);
  }

  return {};
}
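
// Folding note (illustrative): an iota whose iota_dimension has size 1 can
// only ever produce the value 0 along that dimension, so e.g. an iota of type
// tensor<1x3xf32> with iota_dimension = 0 folds to a zero-filled constant.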

//===----------------------------------------------------------------------===//
// DynamicIotaOp
//===----------------------------------------------------------------------===//

namespace {

struct DynamicIotaIsStatic : public OpRewritePattern<DynamicIotaOp> {
  using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicIotaOp iota,
                                PatternRewriter& rewriter) const override {
    auto result_ty = iota.getType().cast<ShapedType>();
    if (!result_ty.hasStaticShape()) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<IotaOp>(iota, result_ty, iota.iota_dimension());
    return success();
  }
};

// Dynamic Iota operations across multiple dimensions can be reduced to an iota
// and a ranked broadcast.
struct DynamicIotaBroadcast : public OpRewritePattern<DynamicIotaOp> {
  using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicIotaOp iota,
                                PatternRewriter& rewriter) const override {
    auto result_ty = iota.getType().cast<ShapedType>();
    if (!result_ty.hasRank() || result_ty.getRank() < 2) {
      return failure();
    }

    auto iota_dimension = iota.iota_dimension();
    auto iota_dimension_int = iota_dimension;

    auto converted_shape = rewriter.create<IndexCastOp>(
        iota.getLoc(),
        RankedTensorType::get(
            iota.output_shape().getType().cast<ShapedType>().getShape(),
            rewriter.getI64Type()),
        iota.output_shape());

    auto sliced_shape = rewriter.create<SliceOp>(
        iota.getLoc(), converted_shape,
        GetI64ElementsAttr(iota_dimension_int, &rewriter),
        GetI64ElementsAttr(iota_dimension_int + 1, &rewriter),
        GetI64ElementsAttr(1, &rewriter));

    auto converted_sliced_shape = rewriter.create<IndexCastOp>(
        iota.getLoc(),
        RankedTensorType::get(
            {1},
            iota.output_shape().getType().cast<ShapedType>().getElementType()),
        sliced_shape);

    auto iota_type = RankedTensorType::get(
        {result_ty.getDimSize(iota_dimension_int)}, result_ty.getElementType());

    auto new_iota = rewriter.create<DynamicIotaOp>(
        iota.getLoc(), iota_type, converted_sliced_shape,
        rewriter.getI64IntegerAttr(0));

    auto broadcast_attr = DenseIntElementsAttr::get(
        RankedTensorType::get({1}, rewriter.getIntegerType(64)),
        {iota_dimension});
    rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
        iota, result_ty, new_iota, iota.output_shape(), broadcast_attr);
    return success();
  }
};

}  // namespace

void DynamicIotaOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<DynamicIotaIsStatic>(context);
  results.insert<DynamicIotaBroadcast>(context);
}

static Value castToIndexTensor(OpBuilder& builder, Location loc,
                               Value shape_op) {
  ShapedType result_ty = shape::getExtentTensorType(
      builder.getContext(),
      shape_op.getType().cast<ShapedType>().getDimSize(0));
  if (shape_op.getType() == result_ty) return shape_op;  // Nothing to do.
  return builder.create<IndexCastOp>(loc, shape_op, result_ty);
}

LogicalResult DynamicIotaOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  DynamicIotaOp::Adaptor adaptor(operands);
  reifiedReturnShapes.push_back(
      castToIndexTensor(builder, getLoc(), adaptor.output_shape()));
  return success();
}

//===----------------------------------------------------------------------===//
// DynamicUpdateSliceOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(DynamicUpdateSliceOp op) {
  OperandRange indices = op.start_indices();
  if (indices.size() <= 1) return success();

  // Note: start_indices is constrained to Variadic<HLO_ScalarIntTensor>, so it
  // is OK to cast indices to ShapedType here.
  auto idx_tensor = indices.take_front().front().getType().cast<ShapedType>();
  Type first_elem_ty = idx_tensor.getElementType();
  Type elem_ty;

  for (auto idx : llvm::drop_begin(indices, 1)) {
    idx_tensor = idx.getType().cast<ShapedType>();
    elem_ty = idx_tensor.getElementType();

    if (first_elem_ty != elem_ty) {
      return op.emitOpError() << "start indices must have same element type "
                                 "(encountered mismatch: "
                              << first_elem_ty << " vs " << elem_ty << ")";
    }
  }
  return success();
}

OpFoldResult DynamicUpdateSliceOp::fold(ArrayRef<Attribute> operands) {
  auto operand_shape = this->operand().getType().cast<RankedTensorType>();
  auto update_shape = this->update().getType().cast<RankedTensorType>();

  if (operand_shape != update_shape || !operand_shape.hasStaticShape()) {
    return {};
  }

  // Ensure that the indices are constant zeros. The zero check mostly ensures
  // correctness; for non-constant indices the pattern does not fold, to avoid
  // hiding the behavior of incorrect user input.
  for (Value index : this->start_indices()) {
    DenseIntElementsAttr de_attr;
    if (!matchPattern(index, m_Constant(&de_attr))) return {};
    int start_val = de_attr.getSplatValue<IntegerAttr>().getInt();
    if (start_val != 0) return {};
  }
  return this->update();
}
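
// Folding note (illustrative): when the update covers the whole operand
// (identical static shapes) and every start index is a constant 0, the
// dynamic-update-slice simply overwrites the operand, so the op folds to its
// update value.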

//===----------------------------------------------------------------------===//
// AbsOp
//===----------------------------------------------------------------------===//

LogicalResult AbsOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  auto operand_ty = (*operands.begin()).getType().cast<ShapedType>();
  Type element_ty = operand_ty.getElementType();
  if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
    element_ty = complex_ty.getElementType();
  }

  Type result_ty;
  if (operand_ty.hasRank()) {
    result_ty = RankedTensorType::get(operand_ty.getShape(), element_ty);
  } else {
    result_ty = UnrankedTensorType::get(element_ty);
  }
  inferredReturnTypes.push_back(result_ty);
  return success();
}

//===----------------------------------------------------------------------===//
// CollectivePermuteOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(CollectivePermuteOp op) {
  return mlir::hlo::VerifyCollectivePermuteSourceTargetPairs(
      op, op.source_target_pairs());
}

//===----------------------------------------------------------------------===//
// ConvertOp
//===----------------------------------------------------------------------===//

void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand,
                      Type result_element_ty) {
  Type result_ty;
  Type operand_ty = operand.getType();
  if (auto ranked_ty = operand_ty.dyn_cast<RankedTensorType>()) {
    result_ty = RankedTensorType::get(ranked_ty.getShape(), result_element_ty);
  } else {
    result_ty = UnrankedTensorType::get(result_element_ty);
  }
  build(builder, result, result_ty, operand);
}

OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
  auto operand_ty = getOperand().getType().cast<TensorType>();
  auto result_ty = getResult().getType().cast<TensorType>();
  if (operand_ty == result_ty) return getOperand();

  // If the result has non-static shape, a convert op is necessary to go from
  // static shape to non-static shape.
  if (!result_ty.hasStaticShape()) return {};

  // TODO(hinsu): Handle unsigned types.
  if (operand_ty.getElementType().isUnsignedInteger() ||
      result_ty.getElementType().isUnsignedInteger()) {
    return {};
  }

  // If the operand is constant, we can do the conversion now.
  if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
    return hlo::ConvertElementsAttr(elementsAttr,
                                    getElementTypeOrSelf(getResult()));
  }

  return {};
}
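
// Folding note (illustrative): an identity convert such as
// tensor<2xf32> -> tensor<2xf32> folds to its operand, and a convert of a
// constant (e.g. dense<[1, 2]> : tensor<2xi32> converted to tensor<2xf32>)
// folds to the converted constant via hlo::ConvertElementsAttr.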

void ConvertOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                            MLIRContext* context) {
  results.insert<EliminateIdentityConvert>(context);
}

//===----------------------------------------------------------------------===//
// DequantizeOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(DequantizeOp op) {
  auto input_type = op.input().getType().dyn_cast<ShapedType>();
  auto output_type = op.output().getType().dyn_cast<ShapedType>();
  if (!input_type || !output_type) {
    return op.emitError() << "requires ranked input and output.";
  }
  auto input_shape = input_type.getShape();
  auto output_shape = output_type.getShape().vec();
  if (op.transpose_output()) {
    std::reverse(output_shape.begin(), output_shape.end());
  }

  // Check that the input and output ranks match and that all but the last
  // dimension are the same.
  if (input_shape.size() != output_shape.size() ||
      !std::equal(input_shape.begin(),
                  std::next(input_shape.begin(), input_shape.size() - 1),
                  output_shape.begin())) {
    return op.emitError() << "mismatched dimensions.";
  }

  // Check that the last dimension of the output is 2x or 4x that of the
  // input, depending on whether the unpacked input is 16 or 8 bits.
  int input_last_dim = *input_shape.rbegin();
  int output_last_dim = *output_shape.rbegin();
  int scale_factor = op.is_16bits() ? 2 : 4;
  if (output_last_dim != scale_factor * input_last_dim) {
    return op.emitError() << "last dimension of output should be "
                          << scale_factor << "x of the input.";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// GetTupleElementOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(GetTupleElementOp op) {
  auto indexVal = op.index();
  auto operandType = op.getOperand().getType().cast<TupleType>();
  if (indexVal >= operandType.size()) {
    return op.emitOpError(
        llvm::formatv("index {0} is out of bounds of operand with size {1}",
                      indexVal, operandType.size()));
  }

  auto expectedType = operandType.getType(indexVal);
  if (op.getType() != expectedType) {
    return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}",
                                        op.getType(), expectedType));
  }
  return success();
}

OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
  if (auto tuple_op = getOperand().getDefiningOp<mhlo::TupleOp>()) {
    return tuple_op.getOperand(index());
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TupleOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(TupleOp op) {
  auto opType = op.getType().dyn_cast<TupleType>();
  if (!opType) return op.emitOpError("tuple op with non-tuple result");
  if (op.getNumOperands() != opType.size())
    return op.emitOpError(
        "number of operands to tuple expected to match number of types in "
        "resultant tuple type");
  for (auto it : llvm::enumerate(
           llvm::zip_first(op.getOperandTypes(), opType.getTypes()))) {
    if (std::get<0>(it.value()) != std::get<1>(it.value()))
      return op.emitOpError("has return type mismatch at ")
             << it.index() << "th value (" << std::get<0>(it.value())
             << " != " << std::get<1>(it.value()) << ")";
  }
  return success();
}

namespace {

// Pattern for unpacking and repacking the same tuple.
struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> {
  using OpRewritePattern<TupleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TupleOp op,
                                PatternRewriter& rewriter) const override {
    if (op.val().empty()) return failure();

    Value first_element = op.val().front();
    auto first_element_op = first_element.getDefiningOp<GetTupleElementOp>();
    if (!first_element_op || first_element_op.indexAttr().getInt() != 0)
      return failure();

    Value tuple_predecessor = first_element_op.getOperand();
    if (tuple_predecessor.getType() != op.getType()) return failure();

    for (auto element_and_idx : llvm::enumerate(op.val().drop_front(1))) {
      auto element_op =
          element_and_idx.value().getDefiningOp<GetTupleElementOp>();
      if (!element_op ||
          element_op.indexAttr().getInt() != element_and_idx.index() + 1 ||
          element_op.getOperand() != tuple_predecessor)
        return failure();
    }

    rewriter.replaceOp(op, tuple_predecessor);
    return success();
  }
};
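
// Rewrite sketch (illustrative, schematic IR): a tuple rebuilt from complete,
// in-order extractions of another tuple, e.g.
//   %0 = "mhlo.get_tuple_element"(%t) {index = 0 : i32}
//   %1 = "mhlo.get_tuple_element"(%t) {index = 1 : i32}
//   %2 = "mhlo.tuple"(%0, %1)
// is replaced by %t directly, provided the tuple types match.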

}  // namespace

void TupleOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                          MLIRContext* context) {
  results.insert<UnpackRepackSameTuple>(context);
}

//===----------------------------------------------------------------------===//
// AllToAllOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(AllToAllOp op) {
  // If the operand is ranked, the size of the split dimension should be a
  // multiple of the split count.
  auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
  if (!type) return success();
  auto split_dim_size = type.getDimSize(op.split_dimension());
  auto split_count = op.split_count();
  if (split_dim_size % split_count != 0) {
    return op.emitError() << "split dimension has size " << split_dim_size
                          << ", expected to be a multiple of split_count "
                          << split_count;
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AllGatherOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(AllGatherOp op) {
  // If operand and result are both ranked, then the size of the gather
  // dimension in the result should be a multiple of the size of the gather
  // dimension in the operand.
  auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
  auto resultType = op.getType().dyn_cast<RankedTensorType>();
  uint64_t allGatherDimIndex = op.all_gather_dim();
  if (!operandType || !resultType ||
      operandType.isDynamicDim(allGatherDimIndex) ||
      resultType.isDynamicDim(allGatherDimIndex))
    return success();
  if (operandType.getDimSize(allGatherDimIndex) == 0)
    return op.emitOpError() << "operand gather dimension cannot be zero.";
  if ((resultType.getDimSize(allGatherDimIndex) %
       operandType.getDimSize(allGatherDimIndex)) != 0)
    return op.emitOpError()
           << "result gather dimension has size "
           << resultType.getDimSize(allGatherDimIndex)
           << ", expected to be a multiple of operand gather dimension size "
           << operandType.getDimSize(allGatherDimIndex);

  return success();
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

// TODO(b/129012527) These should be expressed as type constraints.
static LogicalResult Verify(BroadcastOp op) {
  auto sizes = op.broadcast_sizes();
  auto sizesType = sizes.getType();
  auto sizesRank = sizesType.getRank();
  if (sizesRank != 1) {
    return op.emitOpError(llvm::formatv(
        "broadcast_sizes has rank {0} instead of rank 1", sizesRank));
  }

  auto resultType = op.getResult().getType().cast<RankedTensorType>();
  auto resultRank = resultType.getRank();
  auto operandType = op.operand().getType().cast<RankedTensorType>();
  auto operandRank = operandType.getRank();
  auto sizesSize = sizesType.getNumElements();
  auto expectedRank = operandRank + sizesSize;

  if (resultRank != expectedRank) {
    return op.emitOpError(
        llvm::formatv("result rank ({0}) does not match operand rank "
                      "({1}) plus size of broadcast_sizes ({2})",
                      resultRank, operandRank, sizesSize));
  }

  llvm::SmallVector<int64_t, 10> expectedShape(sizes.getValues<int64_t>());

  auto operandShape = operandType.getShape();
  expectedShape.insert(expectedShape.end(), operandShape.begin(),
                       operandShape.end());

  auto resultShape = resultType.getShape();
  if (resultShape != llvm::makeArrayRef(expectedShape)) {
    return op.emitOpError(llvm::formatv(
        "result has shape [{0}] instead of [{1}]",
        llvm::make_range(resultShape.begin(), resultShape.end()),
        llvm::make_range(expectedShape.begin(), expectedShape.end())));
  }

  return success();
}

//===----------------------------------------------------------------------===//
// BroadcastInDimOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(BroadcastInDimOp op) {
  auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
  if (!operandType) {
    // The following verification checks all depend on knowing the rank of
    // the operand. Bail out now if we don't know the rank of the operand.
    return success();
  }

  auto operandRank = operandType.getRank();
  if (!op.broadcast_dimensions()) {
    if (operandRank == 0) {
      return success();
    }
    return op.emitOpError(
        llvm::formatv("broadcast_dimensions is absent, but required because "
                      "operand has non-zero rank ({0})",
                      operandRank));
  }

  auto dimensions = op.broadcast_dimensions();
  auto dimensionsType = op.broadcast_dimensions().getType();
  auto dimensionsRank = dimensionsType.getRank();
  if (dimensionsRank != 1) {
    return op.emitOpError(llvm::formatv(
        "broadcast_dimensions has rank {0} instead of rank 1", dimensionsRank));
  }

  auto dimensionsSize = dimensionsType.getNumElements();
  if (dimensionsSize != operandRank) {
    return op.emitOpError(llvm::formatv(
        "broadcast_dimensions size ({0}) does not match operand rank ({1})",
        dimensionsSize, operandRank));
  }

  auto resultType = op.getResult().getType().cast<RankedTensorType>();
  auto resultRank = resultType.getRank();
  if (resultRank < operandRank) {
    return op.emitOpError(
        llvm::formatv("result rank ({0}) is less than operand rank ({1})",
                      resultRank, operandRank));
  }

  for (int i = 0; i != dimensionsSize; ++i) {
    auto dimIndex = dimensions.getValue<int64_t>(i);
    if (dimIndex >= resultRank) {
      return op.emitOpError(
          llvm::formatv("broadcast_dimensions contains invalid value {0} for "
                        "result with rank {1}",
                        dimIndex, resultRank));
    }

    if (!operandType.isDynamicDim(i)) {
      auto dimSize = operandType.getDimSize(i);
      auto resultDimSize = resultType.getDimSize(dimIndex);
      if (dimSize != 1 && dimSize != resultDimSize) {
        return op.emitOpError(
            llvm::formatv("size of operand dimension {0} ({1}) is not equal to "
                          "1 or size of result dimension {2} ({3})",
                          i, dimSize, dimIndex, resultDimSize));
      }
    }
  }

  return success();
}

OpFoldResult BroadcastInDimOp::fold(ArrayRef<Attribute> attrs) {
  auto type = getType().cast<RankedTensorType>();
  if (type == getOperand().getType()) {
    auto broadcast_values = broadcast_dimensions().getValues<int64_t>();
    if (!std::equal(broadcast_values.begin(), broadcast_values.end(),
                    llvm::seq<int64_t>(0, type.getRank()).begin())) {
      return {};
    }
    return getOperand();
  }

  // Constant fold when an operand is a splat tensor attribute.
  if (!attrs[0] || !type.hasStaticShape()) return {};
  auto splatOperandAttr = attrs[0].dyn_cast<SplatElementsAttr>();
  if (!splatOperandAttr) return {};
  // MLIR core bug (https://bugs.llvm.org/show_bug.cgi?id=46588): dense element
  // attribute iterator not implemented for complex element types.
  if (type.getElementType().isa<ComplexType>()) return {};
  return SplatElementsAttr::get(type, splatOperandAttr.getSplatValue());
}
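
// Folding note (illustrative): a broadcast_in_dim whose result type equals
// its operand type and whose broadcast_dimensions are the identity
// [0, 1, ..., r-1] is a no-op and folds to the operand; a splat constant
// operand (e.g. dense<1.0> : tensor<4xf32> broadcast to tensor<2x4xf32>)
// folds to a splat of the result type.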

//===----------------------------------------------------------------------===//
// DynamicBroadcastInDimOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(DynamicBroadcastInDimOp op) {
  auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
  auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();

  // If either the operand or result are unranked, there is very little
  // to verify statically.
  if (!operandType || !resultType) {
    return success();
  }

  auto outputDimensionsType =
      op.output_dimensions().getType().cast<RankedTensorType>();
  auto outputDimensionsSize = outputDimensionsType.getDimSize(0);
  auto operandRank = operandType.getRank();
  auto resultRank = resultType.getRank();

  // Verify broadcast_dimensions.
  auto bcastDimensions = op.broadcast_dimensions();
  auto bcastDimensionsType = op.broadcast_dimensions().getType();
  auto bcastDimensionsRank = bcastDimensionsType.getRank();
  // TODO(laurenzo): Update the BroadcastDimAttr to constrain its rank to 1.
  if (bcastDimensionsRank != 1) {
    return op.emitOpError(
        llvm::formatv("broadcast_dimensions has rank {0} instead of rank 1",
                      bcastDimensionsRank));
  }

  auto bcastDimensionsSize = bcastDimensionsType.getNumElements();
  if (bcastDimensionsSize != operandRank) {
    return op.emitOpError(llvm::formatv(
        "broadcast_dimensions size ({0}) does not match operand rank ({1})",
        bcastDimensionsSize, operandRank));
  }

  if (resultRank < operandRank) {
    return op.emitOpError(
        llvm::formatv("result rank ({0}) is less than operand rank ({1})",
                      resultRank, operandRank));
  }

  for (int i = 0; i != bcastDimensionsSize; ++i) {
    auto dimIndex = bcastDimensions.getValue<int64_t>(i);
    if (dimIndex >= resultRank) {
      return op.emitOpError(
          llvm::formatv("broadcast_dimensions contains invalid value {0} for "
                        "result with rank {1}",
                        dimIndex, resultRank));
    }

    auto dimSize = operandType.getDimSize(i);
    auto resultDimSize = resultType.getDimSize(dimIndex);
    // Note: verifyCompatibleShapes doesn't consider size-1 broadcasting, so we
    // add a manual check for this.
    if (dimSize != 1 && failed(verifyCompatibleShape(dimSize, resultDimSize))) {
      return op.emitOpError(
          llvm::formatv("size of operand dimension {0} ({1}) is not compatible "
                        "with size of result dimension {2} ({3})",
                        i, dimSize, dimIndex, resultDimSize));
    }
  }

  if (outputDimensionsSize != resultRank) {
    return op.emitOpError(
        llvm::formatv("result rank ({0}) is not equal to number of output "
                      "dimensions ({1})",
                      resultRank, outputDimensionsSize));
  }

  return success();
}

namespace {
// If a DynamicBroadcastInDimOp is not actually dynamic, use an ordinary
// BroadcastInDimOp.
class DynamicBroadcastInDimOpNotActuallyDynamic
    : public OpRewritePattern<DynamicBroadcastInDimOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op,
                                PatternRewriter& rewriter) const override {
    auto type = op.getType().dyn_cast<RankedTensorType>();
    if (!type || !type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op, "requires static shape");
    }
    rewriter.replaceOpWithNewOp<BroadcastInDimOp>(
        op, op.getType(), op.operand(), op.broadcast_dimensions());
    return success();
  }
};

class ChainedDynamicBroadcastInDimCanonicalization
    : public OpRewritePattern<DynamicBroadcastInDimOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast,
                                PatternRewriter& rewriter) const override {
    auto preceding_bcast =
        bcast.operand().getDefiningOp<DynamicBroadcastInDimOp>();
    if (!preceding_bcast) return failure();

    // Compose broadcast dimensions.
    DenseIntElementsAttr preceding_bcast_dims =
        preceding_bcast.broadcast_dimensions();
    DenseIntElementsAttr bcast_dims = bcast.broadcast_dimensions();
    SmallVector<APInt, 4> composition;
    for (APInt preceding_dim : preceding_bcast_dims) {
      auto composed_dim = bcast_dims.getValue({preceding_dim.getZExtValue()})
                              .cast<IntegerAttr>();
      composition.push_back(composed_dim.getValue());
    }
    auto composed_bcast_dims =
        DenseIntElementsAttr::get(preceding_bcast_dims.getType(), composition);

    rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
        bcast, bcast.getType(), preceding_bcast.operand(),
        bcast.output_dimensions(), composed_bcast_dims);
    return success();
  }
};
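
// Composition sketch (illustrative): if a first broadcast maps operand
// dimension 0 to intermediate dimension 1 (dims = [1]) and a second broadcast
// maps intermediate dimension 1 to final dimension 2 (dims = [0, 2]), the
// fused broadcast maps operand dimension 0 directly to final dimension 2
// (composed dims = [2]).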
1217 }  // namespace
1218 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1219 void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
1220     OwningRewritePatternList& results, MLIRContext* context) {
1221   results.insert<ChainedDynamicBroadcastInDimCanonicalization,
1222                  DynamicBroadcastInDimOpNotActuallyDynamic,
1223                  DynamicBroadcastToOwnShape_1, DynamicBroadcastToOwnShape_2,
1224                  DynamicBroadcastToOwnShape_3, DynamicBroadcastToOwnShape_4>(
1225       context);
1226 }
1227 
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)1228 LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
1229     OpBuilder& builder, ValueRange operands,
1230     SmallVectorImpl<Value>& reifiedReturnShapes) {
1231   DynamicBroadcastInDimOp::Adaptor adaptor(operands);
1232   reifiedReturnShapes.push_back(
1233       castToIndexTensor(builder, getLoc(), adaptor.output_dimensions()));
1234   return success();
1235 }
1236 
1237 //===----------------------------------------------------------------------===//
1238 // ClampOp
1239 //===----------------------------------------------------------------------===//
1240 
static LogicalResult Verify(ClampOp op) {
  auto operandType = op.operand().getType().cast<RankedTensorType>();
  auto operandShape = operandType.getShape();
  auto minType = op.min().getType().cast<RankedTensorType>();

  auto minShape = minType.getShape();
  if (minShape != operandShape && minType.getRank() != 0) {
    return op.emitOpError(llvm::formatv(
        "min shape [{0}] is not scalar and does not match operand shape [{1}]",
        llvm::make_range(minShape.begin(), minShape.end()),
        llvm::make_range(operandShape.begin(), operandShape.end())));
  }

  auto maxType = op.max().getType().cast<RankedTensorType>();
  auto maxShape = maxType.getShape();
  if (maxShape != operandShape && maxType.getRank() != 0) {
    return op.emitOpError(llvm::formatv(
        "max shape [{0}] is not scalar and does not match operand shape [{1}]",
        llvm::make_range(maxShape.begin(), maxShape.end()),
        llvm::make_range(operandShape.begin(), operandShape.end())));
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ComplexOp
//===----------------------------------------------------------------------===//

LogicalResult ComplexOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  auto type = operands[0].getType();
  auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
  Type result_ty;
  if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
    result_ty = RankedTensorType::get(ranked_type.getShape(), element_ty);
  } else if (type.isa<UnrankedTensorType>()) {
    result_ty = UnrankedTensorType::get(element_ty);
  } else {
    result_ty = element_ty;
  }
  inferredReturnTypes.push_back(result_ty);
  return success();
}

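// Folds complex(real(%x), imag(%x)) back to %x. Illustrative example
// (invented for this comment):
//   %r = "mhlo.real"(%x) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
//   %i = "mhlo.imag"(%x) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
//   %c = "mhlo.complex"(%r, %i)  // folds to %x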
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
  auto real_op = getOperand(0).getDefiningOp<mhlo::RealOp>();
  auto imag_op = getOperand(1).getDefiningOp<mhlo::ImagOp>();
  if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) {
    return real_op.getOperand();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// ImagOp
//===----------------------------------------------------------------------===//

namespace {
Type CreateRealType(Type type) {
  auto element_ty = getElementTypeOrSelf(type);
  if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
    element_ty = complex_ty.getElementType();
  }

  if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
    return RankedTensorType::get(ranked_type.getShape(), element_ty);
  } else if (type.dyn_cast<UnrankedTensorType>()) {
    return UnrankedTensorType::get(element_ty);
  }

  return element_ty;
}
}  // namespace

LogicalResult ImagOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
  return success();
}

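// Folds imag(complex(%a, %b)) to %b; symmetrically, RealOp::fold below folds
// real(complex(%a, %b)) to %a. Illustrative example (invented for this
// comment):
//   %c = "mhlo.complex"(%a, %b)
//       : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xcomplex<f32>>
//   %i = "mhlo.imag"(%c)  // folds to %b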
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
  if (auto complex_op = getOperand().getDefiningOp<mhlo::ComplexOp>()) {
    return complex_op.getOperand(1);
  }

  return {};
}

//===----------------------------------------------------------------------===//
// IsFiniteOp
//===----------------------------------------------------------------------===//

TensorType getSameShapeTensorType(TensorType tensor_type, Type element_type) {
  if (auto ranked_tensor_ty = tensor_type.dyn_cast<RankedTensorType>()) {
    return RankedTensorType::get(ranked_tensor_ty.getShape(), element_type);
  }
  if (auto unranked_tensor_ty = tensor_type.dyn_cast<UnrankedTensorType>()) {
    return UnrankedTensorType::get(element_type);
  }
  llvm_unreachable("unhandled type");
}

LogicalResult IsFiniteOp::inferReturnTypes(
    MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  auto arg_ty = operands.front().getType().cast<TensorType>();
  Builder b(ctx);
  inferredReturnTypes.push_back(getSameShapeTensorType(arg_ty, b.getI1Type()));
  return success();
}

//===----------------------------------------------------------------------===//
// RealOp
//===----------------------------------------------------------------------===//

LogicalResult RealOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
  return success();
}

OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
  if (auto complex_op = getOperand().getDefiningOp<mhlo::ComplexOp>()) {
    return complex_op.getOperand(0);
  }

  return {};
}

//===----------------------------------------------------------------------===//
// ConcatenateOp
//===----------------------------------------------------------------------===//

namespace {
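// Removes operands whose size along the concatenation dimension is zero.
// Illustrative example (invented for this comment):
//   "mhlo.concatenate"(%a, %empty, %b) {dimension = 0 : i64}
//       : (tensor<2xf32>, tensor<0xf32>, tensor<3xf32>) -> tensor<5xf32>
// becomes
//   "mhlo.concatenate"(%a, %b) {dimension = 0 : i64}
//       : (tensor<2xf32>, tensor<3xf32>) -> tensor<5xf32>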
class ConcatenateOperandRemoval : public OpRewritePattern<ConcatenateOp> {
 public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ConcatenateOp op,
                                PatternRewriter& rewriter) const override {
    auto axis = op.dimension();
    llvm::SmallVector<Value, 6> new_operands;
    for (auto operand : op.getOperands()) {
      auto ty = operand.getType().cast<ShapedType>();
      if (ty.getDimSize(axis) != 0) {
        new_operands.push_back(operand);
      }
    }

    if (!new_operands.empty() && new_operands.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<ConcatenateOp>(op, op.getResult().getType(),
                                                 new_operands, op.dimension());
      return success();
    }

    return failure();
  }
};
}  // namespace

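// Shape inference sketch (illustrative): concatenating tensor<2x3xf32> and
// tensor<2x?xf32> along dimension 1 infers tensor<2x?xf32>. A dynamic size on
// the concatenation dimension makes the result dynamic there, while the
// non-concatenation dimensions are refined to any static sizes seen among the
// operands.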
LogicalResult ConcatenateOp::inferReturnTypes(
    MLIRContext*, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type>& inferredReturnTypes) {
  if (operands.empty()) {
    return failure();
  }

  auto dimension_attr = attributes.get("dimension").cast<IntegerAttr>();
  auto dimension = dimension_attr.getInt();

  auto first_type = (*operands.begin()).getType().cast<ShapedType>();
  auto out_element = first_type.getElementType();

  for (auto operand : operands.getTypes()) {
    auto element_type = getElementTypeOrSelf(operand);
    if (element_type != out_element) {
      return failure();
    }
  }

  // Find the first ranked input to determine the output rank.
  for (auto type : operands.getTypes()) {
    auto shaped_type = type.cast<ShapedType>();
    if (shaped_type.hasRank()) {
      first_type = shaped_type;
      break;
    }
  }

  // If all inputs are unranked, the result must be unranked.
  if (!first_type.hasRank()) {
    inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
    return success();
  }

  if (first_type.getRank() == 0)
    return emitOptionalError(location, "rank-0 values cannot be concatenated");

  auto out_shape = llvm::to_vector<6>(first_type.getShape());

  // Determine what the non-concatenate dimensions should be.
  for (auto type : operands.getTypes()) {
    auto shaped_ty = type.cast<ShapedType>();
    if (!shaped_ty.hasRank()) {
      continue;
    }

    for (auto it : llvm::enumerate(shaped_ty.getShape())) {
      // If the output dimension is still dynamic, refine it with this
      // operand's size for that dimension.
      if (ShapedType::isDynamic(out_shape[it.index()])) {
        out_shape[it.index()] = it.value();
      }
    }
  }

  out_shape[dimension] = 0;

  for (auto operand : operands.getTypes()) {
    auto type = operand.cast<ShapedType>();
    if (!type.hasRank()) {
      inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
      return success();
    }

    // If the dimension is dynamic we know the output dimension is dynamic.
    auto dim = type.getShape()[dimension];
    if (dim == -1) {
      out_shape[dimension] = -1;
      break;
    }

    out_shape[dimension] += dim;
  }

  inferredReturnTypes.push_back(RankedTensorType::get(out_shape, out_element));

  return success();
}

void ConcatenateOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<ConcatenateOperandRemoval>(context);
}

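// Worked example for the index arithmetic below (illustrative): folding a
// concatenation of two tensor<2x2xi32> constants along axis 1 gives
// top_size = 2 (the product of the dimensions before the axis). For each of
// the two "rows", bottom_size = 4 / 2 = 2 elements are copied from each
// operand in turn, interleaving the operands' rows into the result buffer.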
template <typename T>
static Attribute foldConcatenateHelper(ConcatenateOp* op,
                                       ArrayRef<Attribute> operands) {
  auto axis = op->dimension();
  auto type = op->getType().cast<ShapedType>();

  SmallVector<T, 6> values;
  auto shape = type.getShape();

  size_t top_size = 1;
  for (int i = 0, e = axis; i < e; i++) {
    top_size = top_size * shape[i];
  }

  for (size_t i = 0; i < top_size; i++) {
    for (auto operand : operands) {
      DenseElementsAttr attr = operand.cast<DenseElementsAttr>();
      size_t bottom_size = attr.getNumElements() / top_size;
      auto iter = attr.getValues<T>().begin() + i * bottom_size;
      values.append(iter, iter + bottom_size);
    }
  }

  return DenseElementsAttr::get(type, values);
}

static Attribute foldConcatenate(ConcatenateOp* op,
                                 ArrayRef<Attribute> operands) {
  for (auto operand : operands) {
    if (!operand) return {};
  }

  auto type = op->getResult().getType().cast<ShapedType>();
  auto etype = type.getElementType();
  if (etype.isa<IntegerType>()) {
    return foldConcatenateHelper<APInt>(op, operands);
  }

  if (etype.isa<FloatType>()) {
    return foldConcatenateHelper<APFloat>(op, operands);
  }

  return {};
}

OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
  if (getNumOperands() == 1) return getOperand(0);

  ShapedType type = getResult().getType().cast<ShapedType>();
  if (!type.hasStaticShape()) return {};

  auto axis = dimension();
  if (auto attr = foldConcatenate(this, operands)) {
    return attr;
  }

  // If every operand is empty along the concatenation axis, the result is an
  // empty constant.
  for (auto operand : getOperands()) {
    auto ty = operand.getType().cast<ShapedType>();
    if (ty.getDimSize(axis) != 0) {
      return {};
    }
  }

  return DenseElementsAttr::get(type, ArrayRef<Attribute>());
}

static LogicalResult Verify(ConcatenateOp op) {
  Type element_type = getElementTypeOrSelf(op.getOperand(0).getType());
  RankedTensorType first_ranked_type;
  int num_operands = op.getNumOperands();
  for (int i = 0; i < num_operands; i++) {
    auto second_type = op.getOperand(i).getType().dyn_cast<ShapedType>();
    if (second_type.getElementType() != element_type) {
      return op.emitOpError(
          llvm::formatv("operands (0) and ({0}) do not match element type", i));
    }

    if (!second_type.hasRank()) {
      continue;
    }

    if (!first_ranked_type) {
      first_ranked_type = second_type.cast<RankedTensorType>();
      continue;
    }

    if (first_ranked_type.getRank() != second_type.getRank()) {
      return op.emitOpError(
          llvm::formatv("operands (0) and ({0}) do not match rank", i));
    }

    // Compare against the first ranked operand's shape; taking both shapes
    // from second_type would make this check vacuous.
    auto first_shape = first_ranked_type.getShape();
    auto second_shape = second_type.getShape();
    for (int d = 0; d < first_ranked_type.getRank(); ++d) {
      if (first_shape[d] != second_shape[d] && d != op.dimension()) {
        return op.emitOpError(llvm::formatv(
            "operands (0) and ({0}) non-concat dimensions do not match "
            "({1}) != ({2})",
            i, llvm::make_range(first_shape.begin(), first_shape.end()),
            llvm::make_range(second_shape.begin(), second_shape.end())));
      }
    }
  }
  return success();
}

LogicalResult ConcatenateOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  ConcatenateOp::Adaptor adaptor(operands);
  auto inputs = adaptor.val();

  auto operand_type = inputs[0].getType().dyn_cast<RankedTensorType>();
  // Unranked operands are not supported at the moment.
  if (!operand_type) return failure();

  Location loc = this->getLoc();
  Type shape_scalar_type = builder.getIndexType();
  auto to_shape_scalar_type = [&](Value v) {
    return MaybeCastTo(builder, loc, v, shape_scalar_type);
  };

  SmallVector<SmallVector<Value, 4>, 4> all_shape_values;
  for (size_t input_id = 0; input_id < inputs.size(); ++input_id) {
    Value operand = inputs[input_id];
    auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
    if (!operand_type) return failure();

    SmallVector<Value, 4> shape_vals;
    for (const auto& element : llvm::enumerate(operand_type.getShape())) {
      Value value_dim = to_shape_scalar_type(
          builder.create<tensor::DimOp>(loc, operand, element.index()));
      shape_vals.push_back(value_dim);
    }
    all_shape_values.emplace_back(std::move(shape_vals));
  }

  int axis = this->dimension();
  auto& shape_values = all_shape_values[0];
  for (size_t vec_id = 1; vec_id < all_shape_values.size(); ++vec_id) {
    auto& other_shape_values = all_shape_values[vec_id];
    if (other_shape_values.size() != shape_values.size()) {
      this->emitOpError()
          << "Concatenate expects all operands to be of the same rank";
      return failure();
    }
    shape_values[axis] = builder.create<AddIOp>(loc, shape_values[axis],
                                                other_shape_values[axis]);
  }

  Value output_shape = builder.create<tensor::FromElementsOp>(
      loc, shape_scalar_type, shape_values);
  reifiedReturnShapes.push_back(output_shape);

  return success();
}

//===----------------------------------------------------------------------===//
// DynamicReshapeOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(DynamicReshapeOp op) {
  auto result_type = op.result().getType().dyn_cast<RankedTensorType>();
  auto output_shape_type =
      op.output_shape().getType().dyn_cast<RankedTensorType>();
  if (result_type && output_shape_type && output_shape_type.hasStaticShape() &&
      output_shape_type.getDimSize(0) != result_type.getRank()) {
    return op.emitError() << "output should have a rank equal to the number of "
                             "elements in output_shape";
  }
  return success();
}

LogicalResult DynamicReshapeOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  DynamicReshapeOp::Adaptor adaptor(operands);
  reifiedReturnShapes.push_back(
      castToIndexTensor(builder, getLoc(), adaptor.output_shape()));
  return success();
}

namespace {
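// Replaces a dynamic_reshape whose result type is fully static with a plain
// reshape. Illustrative example (invented for this comment):
//   %r = "mhlo.dynamic_reshape"(%x, %shape)
//       : (tensor<4xf32>, tensor<2xindex>) -> tensor<2x2xf32>
// becomes
//   %r = "mhlo.reshape"(%x) : (tensor<4xf32>) -> tensor<2x2xf32>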
class DynamicReshapeOpNotActuallyDynamic
    : public OpRewritePattern<DynamicReshapeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(DynamicReshapeOp op,
                                PatternRewriter& rewriter) const override {
    auto type = op.result().getType().dyn_cast<RankedTensorType>();
    if (!type || !type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op, "requires static shape tensor");
    }
    rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), op.operand());
    return success();
  }
};

// Canonicalizes
// %0 = some_op(%tensor)
// %1 = "mhlo.dynamic_reshape"(%0, %shape)
//      (tensor<?xT>, tensor<1xindex>) -> tensor<?xT>
// ... uses of %1.
//
// into
//
// ... uses of %0.
// This canonicalization is only correct if the input is correct!
// TODO(b/178779691): Use a more sophisticated canonicalization that preserves
// errors in input, and still allows us to get rid of redundant reshapes.
class RemoveRedundantRank1DynamicReshape
    : public OpRewritePattern<DynamicReshapeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(DynamicReshapeOp op,
                                PatternRewriter& rewriter) const override {
    auto type = op.result().getType().dyn_cast<RankedTensorType>();
    if (!type || type.getRank() != 1 || type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "requires rank 1 shape tensor with dynamic dimension");
    }
    auto operand_type = op.operand().getType().dyn_cast<RankedTensorType>();
    if (!operand_type || operand_type.getRank() != 1 ||
        operand_type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "requires rank 1 shape tensor with dynamic dimension");
    }
    rewriter.replaceOp(op, {op.operand()});
    return success();
  }
};

// Canonicalizes
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
// %1 = same_operands_and_result_shape_op(%tensor)
// %2 = "mhlo.dynamic_reshape"(%1, %shape)
// ... uses of %2.
//
// into
//
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
// %1 = same_operands_and_result_shape_op(%tensor)
// ... uses of %1.
class DynamicReshapeOpSameShapeOpResult
    : public OpRewritePattern<DynamicReshapeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicReshapeOp op,
                                PatternRewriter& rewriter) const override {
    Operation* def_op = op.operand().getDefiningOp();
    if (!def_op ||
        !def_op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) {
      return failure();
    }
    Operation* input_def_op = def_op->getOperand(0).getDefiningOp();
    if (!input_def_op) {
      return failure();
    }
    auto reshape = dyn_cast<DynamicReshapeOp>(*input_def_op);
    if (reshape && reshape.output_shape() == op.output_shape()) {
      rewriter.replaceOp(op, {def_op->getResult(0)});
      return success();
    }
    return failure();
  }
};
}  // namespace

void DynamicReshapeOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  // clang-format off
  results.insert<
      DynamicReshapeOpNotActuallyDynamic,
      DynamicReshapeOpSameShapeOpResult,
      RemoveRedundantDynamicBroadcast,
      RemoveRedundantDynamicReshape,
      RemoveRedundantRank1DynamicReshape,
      ShapeOfDynamicReshape
    >(context);
  // clang-format on
}

//===----------------------------------------------------------------------===//
// DynamicSliceOp
//===----------------------------------------------------------------------===//

namespace {
// Canonicalizes DynamicSlice ops that can be replaced instead with Slice ops.
// This canonicalization is applied in the case when the `begin` input values
// are compile-time constants and thus can be made into a tensor.
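// Illustrative example (invented for this comment): with
//   %c1 = mhlo.constant dense<1> : tensor<i64>
//   %r = "mhlo.dynamic_slice"(%x, %c1)
//        {slice_sizes = dense<2> : tensor<1xi64>}
//       : (tensor<4xf32>, tensor<i64>) -> tensor<2xf32>
// the op becomes
//   %r = "mhlo.slice"(%x)
//        {start_indices = dense<1> : tensor<1xi64>,
//         limit_indices = dense<3> : tensor<1xi64>,
//         strides = dense<1> : tensor<1xi64>}
//       : (tensor<4xf32>) -> tensor<2xf32>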
struct DynamicSliceToSlice : public OpRewritePattern<DynamicSliceOp> {
  using OpRewritePattern<DynamicSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicSliceOp dynamic_slice,
                                PatternRewriter& rewriter) const override {
    Value input = dynamic_slice.operand();
    auto input_tensor = input.getType().dyn_cast<RankedTensorType>();
    if (!input_tensor) return failure();

    SmallVector<int64_t, 4> temp_start_indices;
    for (Value start : dynamic_slice.start_indices()) {
      APInt val;
      if (!matchPattern(start, m_ConstantInt(&val))) {
        return failure();
      }
      temp_start_indices.push_back(*(val.getRawData()));
    }

    // At this point we've determined that the start indices are all constants;
    // pack them into a single tensor.
    auto loc = dynamic_slice.getLoc();
    int64_t input_rank = input_tensor.getRank();
    auto slice_start_indices =
        GetI64ElementsAttr(temp_start_indices, &rewriter);
    DenseIntElementsAttr slice_limits = BuildSliceLimits(
        slice_start_indices, dynamic_slice.slice_sizes(), &rewriter);
    DenseIntElementsAttr slice_strides =
        GetI64ElementsAttr(SmallVector<int64_t, 4>(input_rank, 1), &rewriter);
    auto result = rewriter.create<SliceOp>(loc, input, slice_start_indices,
                                           slice_limits, slice_strides);
    rewriter.replaceOp(dynamic_slice, {result});
    return success();
  }
};

}  // namespace

void DynamicSliceOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<DynamicSliceToSlice>(context);
}

// Verifies that the number of slice sizes and the number of start indices
// match.
static LogicalResult Verify(DynamicSliceOp op) {
  int num_slice_sizes = op.slice_sizes().getNumElements();
  int num_start_indices = op.start_indices().size();
  if (num_start_indices != num_slice_sizes) {
    return op.emitOpError()
           << "has mismatched number of slice sizes (" << num_slice_sizes
           << ") and number of start indices (" << num_start_indices << ")";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// RealDynamicSliceOp
//===----------------------------------------------------------------------===//
// Verifies that the operand rank matches the element counts of
// start_indices, limit_indices, and strides.
static LogicalResult Verify(RealDynamicSliceOp op) {
  auto input_type = op.operand().getType().dyn_cast<RankedTensorType>();
  // If the operand is unranked, there is very little to verify statically.
  if (!input_type) return success();
  int input_rank = input_type.getRank();

  auto start_type = op.start_indices().getType().cast<RankedTensorType>();
  auto limit_type = op.limit_indices().getType().cast<RankedTensorType>();
  auto strides_type = op.strides().getType().cast<RankedTensorType>();

  if (input_rank != start_type.getNumElements()) {
    return op.emitOpError() << "has mismatched number of operand rank ("
                            << input_rank << ") and start_indices size ("
                            << start_type.getNumElements() << ")";
  }

  if (input_rank != limit_type.getNumElements()) {
    return op.emitOpError() << "has mismatched number of operand rank ("
                            << input_rank << ") and limit_indices size ("
                            << limit_type.getNumElements() << ")";
  }

  if (input_rank != strides_type.getNumElements()) {
    return op.emitOpError()
           << "has mismatched number of operand rank (" << input_rank
           << ") and strides size (" << strides_type.getNumElements() << ")";
  }

  return success();
}

namespace {
// Canonicalizes RealDynamicSlice ops that can be replaced instead with Slice
// ops. This canonicalization is applied in the case when the `begin` input
// values are compile-time constants and thus can be made into a tensor.
struct RealDynamicSliceIsStatic : public OpRewritePattern<RealDynamicSliceOp> {
  using OpRewritePattern<RealDynamicSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(RealDynamicSliceOp real_dynamic_slice,
                                PatternRewriter& rewriter) const override {
    Location loc = real_dynamic_slice.getLoc();
    Value input = real_dynamic_slice.operand();
    Value output = real_dynamic_slice.result();
    auto input_ty = input.getType().dyn_cast<RankedTensorType>();
    auto output_ty = output.getType().dyn_cast<RankedTensorType>();

    if (!input_ty || !output_ty || !input_ty.hasStaticShape() ||
        !output_ty.hasStaticShape()) {
      return failure();
    }

    int64_t input_rank = input_ty.getRank();

    auto start_val = real_dynamic_slice.start_indices();
    auto limit_val = real_dynamic_slice.limit_indices();
    auto stride_val = real_dynamic_slice.strides();
    auto start_op = start_val.getDefiningOp<mlir::ConstantOp>();
    auto limit_op = limit_val.getDefiningOp<mlir::ConstantOp>();
    auto stride_op = stride_val.getDefiningOp<mlir::ConstantOp>();
    if (!start_op || !limit_op || !stride_op) return failure();

    auto start_attr =
        start_op.getValue().dyn_cast_or_null<DenseIntElementsAttr>();
    auto limit_attr =
        limit_op.getValue().dyn_cast_or_null<DenseIntElementsAttr>();
    auto stride_attr =
        stride_op.getValue().dyn_cast_or_null<DenseIntElementsAttr>();
    if (!start_attr || !limit_attr || !stride_attr) return failure();

    SmallVector<int64_t, 4> temp_start_indices;
    SmallVector<int64_t, 4> temp_limit_indices;
    SmallVector<int64_t, 4> temp_stride;
    for (int64_t dim_idx = 0; dim_idx < input_rank; dim_idx++) {
      int64_t start = start_attr.getValue<IntegerAttr>(dim_idx).getInt();
      temp_start_indices.push_back(start);
      int64_t limit = limit_attr.getValue<IntegerAttr>(dim_idx).getInt();
      temp_limit_indices.push_back(limit);
      int64_t stride = stride_attr.getValue<IntegerAttr>(dim_idx).getInt();
      temp_stride.push_back(stride);
    }

    DenseIntElementsAttr slice_start_indices =
        GetI64ElementsAttr(temp_start_indices, &rewriter);
    DenseIntElementsAttr slice_limit_indices =
        GetI64ElementsAttr(temp_limit_indices, &rewriter);
    DenseIntElementsAttr slice_strides =
        GetI64ElementsAttr(temp_stride, &rewriter);
    auto result = rewriter.create<SliceOp>(loc, input, slice_start_indices,
                                           slice_limit_indices, slice_strides);
    rewriter.replaceOp(real_dynamic_slice, {result});
    return success();
  }
};
}  // namespace

void RealDynamicSliceOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<RealDynamicSliceIsStatic, RealDSliceToSlice>(context);
}

LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  RealDynamicSliceOp::Adaptor adaptor(operands);
  Value operand = adaptor.operand();
  Value start_indices = adaptor.start_indices();
  Value limit_indices = adaptor.limit_indices();
  Value strides = adaptor.strides();

  auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
  // Unranked operands are not supported at the moment.
  if (!operand_type) return failure();

  Location loc = this->getLoc();
  SmallVector<Value, 4> shape_values;
  shape_values.reserve(operand_type.getRank());
  Type shape_scalar_type =
      start_indices.getType().cast<ShapedType>().getElementType();
  Value one = builder.create<ConstantIndexOp>(loc, 1);
  one = MaybeCastTo(builder, loc, one, shape_scalar_type);
  for (const auto& element : llvm::enumerate(operand_type.getShape())) {
    Value offset = builder.create<ConstantIndexOp>(loc, element.index());
    Value value_start =
        builder.create<tensor::ExtractOp>(loc, start_indices, offset);
    Value value_limit =
        builder.create<tensor::ExtractOp>(loc, limit_indices, offset);
    Value value_stride =
        builder.create<tensor::ExtractOp>(loc, strides, offset);
    // size = (limit - start + stride - 1) / stride
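    // For example, start = 1, limit = 7, stride = 2 gives
    // size = (7 - 1 + 2 - 1) / 2 = 3 (the elements at indices 1, 3, and 5).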
    shape_values.push_back(builder.create<SignedDivIOp>(
        loc,
        builder.create<SubIOp>(
            loc,
            builder.create<AddIOp>(
                loc, value_stride,
                builder.create<SubIOp>(loc, value_limit, value_start)),
            one),
        value_stride));
  }

  reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
      loc, shape_scalar_type, shape_values));

  return success();
}

//===----------------------------------------------------------------------===//
// InfeedOp
//===----------------------------------------------------------------------===//

// Checks that the result type is of the form `tuple<any_type, token>`.
static LogicalResult Verify(InfeedOp op) {
  auto result_ty = op.getResult().getType().cast<TupleType>();
  auto subtypes = result_ty.getTypes();
  if (subtypes.size() != 2)
    return op.emitOpError()
           << "result is expected to be a tuple of size 2, but got "
           << subtypes.size();
  if (!subtypes[1].isa<TokenType>())
    return op.emitOpError() << "second element of result tuple is expected to "
                               "be of token type, but got "
                            << subtypes[1];
  return success();
}

//===----------------------------------------------------------------------===//
// Logical Ops
//===----------------------------------------------------------------------===//

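// The folds below rely on the usual boolean identities over splat constants:
// and(x, ones) = x, and(x, zeros) = zeros, or(x, ones) = ones,
// or(x, zeros) = x, xor(x, zeros) = x, and xor(x, x) = zeros. When both
// operands are constant, the ops fold elementwise.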
OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
  if (lhs() == rhs()) return lhs();

  auto rType = getType().cast<ShapedType>();
  auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();

  if (lhsVal && lhsVal.isSplat()) {
    if (lhsVal.getSplatValue()
            .cast<IntegerAttr>()
            .getValue()
            .isAllOnesValue()) {
      return rhs();
    }

    if (lhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
      return lhsVal;
    }
  }

  if (rhsVal && rhsVal.isSplat()) {
    if (rhsVal.getSplatValue()
            .cast<IntegerAttr>()
            .getValue()
            .isAllOnesValue()) {
      return lhs();
    }

    if (rhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
      return rhsVal;
    }
  }

  if (!rhsVal || !lhsVal) return {};

  llvm::SmallVector<APInt, 4> values;
  values.reserve(rhsVal.getNumElements());
  for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) {
    values.push_back(std::get<0>(it) & std::get<1>(it));
  }

  return DenseIntElementsAttr::get(rType, values);
}

OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
  if (lhs() == rhs()) return lhs();

  auto rType = getType().cast<ShapedType>();
  auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();

  if (lhsVal && lhsVal.isSplat()) {
    if (lhsVal.getSplatValue()
            .cast<IntegerAttr>()
            .getValue()
            .isAllOnesValue()) {
      return lhsVal;
    }

    if (lhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
      return rhs();
    }
  }

  if (rhsVal && rhsVal.isSplat()) {
    if (rhsVal.getSplatValue()
            .cast<IntegerAttr>()
            .getValue()
            .isAllOnesValue()) {
      return rhsVal;
    }

    if (rhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
      return lhs();
    }
  }

  if (!rhsVal || !lhsVal) return {};

  llvm::SmallVector<APInt, 4> values;
  values.reserve(rhsVal.getNumElements());
  for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) {
    values.push_back(std::get<0>(it) | std::get<1>(it));
  }

  return DenseIntElementsAttr::get(rType, values);
}

OpFoldResult XorOp::fold(ArrayRef<Attribute> operands) {
  auto rType = getType().cast<ShapedType>();
  if (lhs() == rhs()) {
    Builder builder(getContext());
    return builder.getZeroAttr(rType);
  }

  auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();

  if (lhsVal && lhsVal.isSplat()) {
    if (lhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
      return rhs();
    }
  }

  if (rhsVal && rhsVal.isSplat()) {
    if (rhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
      return lhs();
    }
  }

  if (!rhsVal || !lhsVal) return {};

  llvm::SmallVector<APInt, 4> values;
  values.reserve(rhsVal.getNumElements());
  for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) {
    values.push_back(std::get<0>(it) ^ std::get<1>(it));
  }

  return DenseIntElementsAttr::get(rType, values);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(MapOp op) {
  // Checks that the number of `operands` matches the arity of the map
  // `computation` region.
  auto& computation_block = op.computation().front();
  auto computation_args = computation_block.getArguments();
  if (op.operands().size() != computation_args.size())
    return op.emitOpError()
           << "expects number of operands to match the arity "
              "of map computation, but got: "
           << op.operands().size() << " and " << computation_args.size();

  // The parameters of the computation should all be scalars and match the
  // element type of the operands.
  auto operand_type = op.operands()[0].getType().cast<TensorType>();
  auto operand_elem_ty = operand_type.getElementType();

  for (auto indexed_arg : llvm::enumerate(computation_args)) {
    auto arg_type = indexed_arg.value().getType().dyn_cast<TensorType>();
    if (!arg_type || arg_type.getRank() != 0)
      return op.emitOpError()
             << "computation arguments must be 0-rank tensor, but got: arg #"
             << indexed_arg.index() << " of type "
             << indexed_arg.value().getType();
    if (arg_type.getElementType() != operand_elem_ty) {
      return op.emitOpError()
             << "element type of operands and computation arguments must "
                "match, but got: "
             << operand_elem_ty << " and " << arg_type.getElementType();
    }
  }

  // The mapped computation must return a single output.
  auto computation_outputs = computation_block.getTerminator()->getOperands();
  if (computation_outputs.size() != 1)
    return op.emitOpError()
           << "computation must return single output, but got: "
           << computation_outputs.size();

  // The output of the computation must be scalar and have the same element
  // type as the op result.
  auto computation_output_type =
      computation_outputs[0].getType().dyn_cast<TensorType>();
  if (!computation_output_type || computation_output_type.getRank() != 0)
    return op.emitOpError()
           << "computation must return 0-rank tensor, but got: "
           << computation_outputs[0].getType();

  auto result_type = op.getType().cast<TensorType>();
  if (computation_output_type.getElementType() != result_type.getElementType())
    return op.emitOpError() << "element type of result and computation output "
                               "must match, but got: "
                            << result_type.getElementType() << " and "
                            << computation_output_type.getElementType();

  // Checks that the requested map dimension numbers are monotonically
  // increasing.
  auto values = op.dimensions().getValues<int64_t>();
  auto dimensions = std::vector<int64_t>{values.begin(), values.end()};
  for (int i = 0, e = dimensions.size(); i < e; ++i) {
    if (dimensions[i] != i)
      return op.emitOpError() << "requires monotonically increasing dimension "
                                 "numbers, but got: "
                              << op.dimensions();
  }

  // Checks that the number of dimensions of the operands matches the size of
  // `dimensions`, since we currently only support mapping across all
  // dimensions, i.e., scalar map functions.
  if (operand_type.hasRank()) {
    if (dimensions.size() != operand_type.getShape().size())
      return op.emitOpError()
             << "applied to a subset of dimensions currently not supported: "
                "operand dimensions = "
             << operand_type.getShape().size()
             << ", requested map dimensions size = " << dimensions.size();
  }

  return success();
}

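// Folds a map whose computation merely returns one of its block arguments to
// the corresponding operand. Illustrative example (invented for this
// comment):
//   %r = "mhlo.map"(%a, %b) ({
//     ^bb0(%x: tensor<f32>, %y: tensor<f32>):
//       "mhlo.return"(%y) : (tensor<f32>) -> ()
//   }) {dimensions = dense<0> : tensor<1xi64>}
//       : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// folds to %b.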
OpFoldResult MapOp::fold(ArrayRef<Attribute> operands) {
  mlir::Block& bb = computation().front();
  mlir::Operation& front_op = bb.front();

  auto ret_op = mlir::dyn_cast<ReturnOp>(front_op);
  if (!ret_op) return nullptr;
  if (ret_op.results().size() != 1) return nullptr;

  for (mlir::BlockArgument barg : bb.getArguments()) {
    if (barg == ret_op.results()[0]) return getOperands()[barg.getArgNumber()];
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// RecvOp
//===----------------------------------------------------------------------===//

// Checks that the result type is of the form `tuple<any_type, mhlo::token>`.
static LogicalResult Verify(RecvOp op) {
  auto result_ty = op.getResult().getType().cast<TupleType>();
  auto subtypes = result_ty.getTypes();
  if (subtypes.size() != 2)
    return op.emitOpError()
           << "result is expected to be a tuple of size 2, but got "
           << subtypes.size();
  if (!subtypes[1].isa<TokenType>())
    return op.emitOpError() << "second element of result tuple is expected to "
                               "be of token type, but got "
                            << subtypes[1];
  return success();
}

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

OpFoldResult CopyOp::fold(ArrayRef<Attribute> operands) { return getOperand(); }

//===----------------------------------------------------------------------===//
// ReduceWindowOp
//===----------------------------------------------------------------------===//

// For reduce-window, all `inputs` need to have compatible shapes.
static LogicalResult Verify(ReduceWindowOp op) {
  if (failed(verifyCompatibleShapes(op.inputs().getTypes())))
    return op.emitOpError() << "requires same shape for all inputs";
  return success();
}

// Gets the operation used for the reduction applied to the `result_index`th
// result. It's expected to be a binary operation that consumes the
// `result_index`th and `result_index + inputs().size()`th arguments of the
// body.
Operation* ReduceWindowOp::getReductionOp(int result_index) {
  auto return_op = cast<ReturnOp>(body().front().getTerminator());
  Operation* compute_op = return_op.results()[result_index].getDefiningOp();
  if (compute_op->getNumOperands() != 2) return nullptr;
  auto arg0 = compute_op->getOperand(0).dyn_cast<BlockArgument>();
  auto arg1 = compute_op->getOperand(1).dyn_cast<BlockArgument>();
  if (!arg0 || !arg1) return nullptr;
  int arg0_num = arg0.getArgNumber();
  int arg1_num = arg1.getArgNumber();
  int other_arg_index = result_index + inputs().size();
  if (arg0_num == result_index && arg1_num == other_arg_index)
    return compute_op;
  if (arg0_num == other_arg_index && arg1_num == result_index &&
      compute_op->hasTrait<mlir::OpTrait::IsCommutative>())
    return compute_op;
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ReverseOp
//===----------------------------------------------------------------------===//

OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
  auto input = operand();

  // No dimensions to reverse.
  if (dimensions().getNumElements() == 0) return input;

  // Reversing is a no-op when every dimension to reverse has size 1.
  auto shaped_type = input.getType().cast<ShapedType>();
  for (auto dim : dimensions().getValues<APInt>()) {
    if (shaped_type.getDimSize(dim.getLimitedValue()) != 1) {
      return nullptr;
    }
  }

  return input;
}

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

// Returns the result type after reducing operand of the given type across the
// specified dimensions.
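// For example, reducing tensor<4x8xf32> across dimension 1 yields
// tensor<4xf32>, and reducing an unranked tensor yields an unranked tensor
// with the same element type.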
static TensorType GetReduceResultType(Type operand_ty,
                                      DenseIntElementsAttr dimensions,
                                      Builder* builder) {
  Type element_ty = getElementTypeOrSelf(operand_ty);

  auto ranked_ty = operand_ty.dyn_cast<RankedTensorType>();
  if (!ranked_ty) return UnrankedTensorType::get(element_ty);

  int64_t rank = ranked_ty.getRank();
  llvm::SmallVector<bool, 4> dims_mask(rank, false);
  for (int64_t dim : dimensions.getValues<int64_t>()) dims_mask[dim] = true;

  SmallVector<int64_t, 4> shape;
  for (int64_t i = 0; i < rank; ++i) {
    if (!dims_mask[i]) shape.push_back(ranked_ty.getDimSize(i));
  }

  return RankedTensorType::get(shape, element_ty);
}

void ReduceOp::build(OpBuilder& builder, OperationState& state,
                     ValueRange inputs, ValueRange init_values,
                     DenseIntElementsAttr dimensions) {
  SmallVector<Type, 1> result_ty;
  result_ty.reserve(inputs.size());

  for (Value input : inputs) {
    result_ty.push_back(
        GetReduceResultType(input.getType(), dimensions, &builder));
  }
  build(builder, state, result_ty, inputs, init_values, dimensions);
}

LogicalResult ReduceOp::fold(ArrayRef<Attribute> operands,
                             SmallVectorImpl<OpFoldResult>& results) {
  // No dimensions to reduce.
  if (dimensions().getNumElements() == 0) {
    for (Value input : this->inputs()) {
      results.push_back(input);
    }
    return success();
  }

  // If all values returned from the ReduceOp region are defined outside the
  // region, replace the ReduceOp with those values.
  mlir::Block& bb = this->body().front();
  SmallVector<Value> replaced_results;
  if (auto ret_op = mlir::dyn_cast<ReturnOp>(bb.back())) {
    for (Value result : ret_op.results()) {
      if (result.getParentRegion() == ret_op->getParentRegion())
        return failure();
      replaced_results.push_back(result);
    }

    results.insert(results.end(), replaced_results.begin(),
                   replaced_results.end());
    return success();
  }

  return failure();
}

// Enable constant folding to occur within the region of the ReduceOp
// by replacing block argument uses with constants if:
//  1. All the ReduceOp operands are splat constants.
//  2. The ReduceOp region consists of a single logical AND or logical OR.
// The pattern leverages the idempotent property of the AND and OR operators
// to determine the value of a reduction on splat constants. Other boolean
// operators do not have this property, and need separate patterns to resolve
// reductions of their splat constants.
struct LowerBoolSplatConstantsIntoRegion : public OpRewritePattern<ReduceOp> {
  using OpRewritePattern<ReduceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReduceOp op,
                                PatternRewriter& rewriter) const override {
    mlir::Block& bb = op.body().front();

    // Ensure only a compute op and return op exist and the
    // compute op is an AND or OR op.
    if (bb.getOperations().size() != 2) return failure();
    if (!mlir::isa<mhlo::AndOp, mhlo::OrOp>(bb.front())) return failure();

    // Ensure all operands are splat constants.
    SmallVector<DenseElementsAttr, 4> barg_cst_attrs;
    for (auto inp_and_barg : llvm::zip(op.getOperands(), bb.getArguments())) {
      Value inp = std::get<0>(inp_and_barg);
      BlockArgument barg = std::get<1>(inp_and_barg);
      ConstOp cst = inp.getDefiningOp<ConstOp>();
      if (!cst) return failure();

      // Guard against a null attribute before querying for splatness.
      auto cst_attr = cst.value().dyn_cast_or_null<DenseElementsAttr>();
      if (!cst_attr || !cst_attr.isSplat()) {
        return rewriter.notifyMatchFailure(op, "Must be splat constant.");
      }

      auto barg_shaped_type = barg.getType().dyn_cast<ShapedType>();
      if (!barg_shaped_type) return failure();

      auto barg_cst_attr =
          DenseElementsAttr::get(barg_shaped_type, cst_attr.getSplatValue());
      barg_cst_attrs.push_back(barg_cst_attr);
    }

    // Create new splat constants to replace block arguments.
    for (BlockArgument barg : bb.getArguments()) {
      int arg_idx = barg.getArgNumber();
      mhlo::ConstOp new_cst = rewriter.create<mhlo::ConstOp>(
          bb.front().getLoc(), barg.getType(), barg_cst_attrs[arg_idx]);
      barg.replaceAllUsesWith(new_cst);
    }
    return success();
  }
};

void ReduceOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                           MLIRContext* context) {
  results.insert<LowerBoolSplatConstantsIntoRegion>(context);
}

LogicalResult ReduceOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  ReduceOp::Adaptor adaptor(operands);
  auto inputs = adaptor.inputs();

  auto operand_type = inputs[0].getType().dyn_cast<RankedTensorType>();
  // Unranked operands are not supported at the moment.
  if (!operand_type) return failure();

  Location loc = this->getLoc();
  SmallVector<Value, 4> shape_values;
  SmallVector<int64_t, 4> dimensions(this->dimensions().getValues<int64_t>());
  shape_values.reserve(operand_type.getRank());
  Type shape_scalar_type = builder.getIndexType();
  auto to_shape_scalar_type = [&](Value v) {
    return MaybeCastTo(builder, loc, v, shape_scalar_type);
  };

  // Keep only the dimensions that are not reduced.
  for (const auto& element : llvm::enumerate(operand_type.getShape())) {
    int64_t idx = element.index();
    auto it = std::find(dimensions.begin(), dimensions.end(), idx);
    if (it != dimensions.end()) {
      continue;
    }
    Value value_dim = to_shape_scalar_type(
        builder.create<tensor::DimOp>(loc, inputs[0], element.index()));
    shape_values.push_back(value_dim);
  }

  Value output_shape = builder.create<tensor::FromElementsOp>(
      loc, shape_scalar_type, shape_values);
  for (size_t i = 0; i < inputs.size(); ++i) {
    reifiedReturnShapes.push_back(output_shape);
  }

  return success();
}

//===----------------------------------------------------------------------===//
// RngNormalOp
//===----------------------------------------------------------------------===//

LogicalResult RngNormalOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  return rngInferReturnTypeComponents(context, location, operands, attributes,
                                      regions, inferredReturnShapes);
}

//===----------------------------------------------------------------------===//
// RngUniformOp
//===----------------------------------------------------------------------===//

LogicalResult RngUniformOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  return rngInferReturnTypeComponents(context, location, operands, attributes,
                                      regions, inferredReturnShapes);
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(SelectOp op) {
  // TODO(jpienaar): Update to allow broadcastable and unranked inputs. This
  // corresponds to the client side HLO.
  return success();
}

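// Folds select when both branches are the same value or when the predicate is
// an i1 splat constant. Illustrative example (invented for this comment):
//   %p = mhlo.constant dense<true> : tensor<4xi1>
//   %r = "mhlo.select"(%p, %a, %b)
//       : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// folds to %a.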
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
  if (on_true() == on_false()) {
    return on_true();
  }

  auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!predicate) {
    return {};
  }

  auto predicateTy = predicate.getType().cast<ShapedType>();
  if (!predicateTy.getElementType().isInteger(1)) {
    return {};
  }

  if (predicate.isSplat()) {
    return predicate.getSplatValue<APInt>().getBoolValue() ? on_true()
                                                           : on_false();
  }

  return {};
}

2529 // Makes it such that a SelectOp that is a non-root operation in a DRR infers
2530 // the return type based on operand type.
inferReturnTypes(MLIRContext *,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)2531 LogicalResult SelectOp::inferReturnTypes(
2532     MLIRContext*, Optional<Location> location, ValueRange operands,
2533     DictionaryAttr attributes, RegionRange regions,
2534     SmallVectorImpl<Type>& inferredReturnTypes) {
2535   auto x_type = operands[1].getType();
2536   auto y_type = operands[2].getType();
2537   auto x_tensor = x_type.cast<TensorType>();
2538   auto y_tensor = y_type.cast<TensorType>();
2539 
2540   // Check for type compatibility in the select op. This requires that the two
2541   // non-predicate operands:
2542   //   (a) have the same element type
2543   //   (b) have compatible shapes (i.e. the same shape and/or at least one
2544   //       dynamic shape)
2545   if (x_tensor.getElementType() != y_tensor.getElementType() ||
2546       failed(mlir::verifyCompatibleShape(x_type, y_type))) {
2547     return emitOptionalError(location, "incompatible operand types: ", x_type,
2548                              " and ", y_type);
2549   }
2550 
2551   // TODO(lucyfox): Support output shape inference when operands have compatible
2552   // shapes. (The output shape should be the most general of the operand shapes
2553   // at each dimension.) For now, handle the straightforward cases and fail
2554   // otherwise. When this is fully implemented, this logic should move into
2555   // reusable functionality in MLIR Core.
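  // For example (hypothetical types): (tensor<2xf32>, tensor<2xf32>) infers
  // tensor<2xf32>, and (tensor<*xf32>, tensor<2xf32>) infers tensor<*xf32>;
  // (tensor<2xf32>, tensor<?xf32>) is rejected below for now even though the
  // shapes are compatible.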
  Type output_type;
  if (x_type == y_type || !x_tensor.hasRank()) {
    output_type = x_type;
  } else if (!y_tensor.hasRank()) {
    output_type = y_type;
  } else {
    return emitOptionalError(location,
                             "currently unsupported operand types: ", x_type,
                             " and ", y_type);
  }
  inferredReturnTypes.assign({output_type});
  return success();
}

LogicalResult SelectOp::inferReturnTypeComponents(
    mlir::MLIRContext* ctx, llvm::Optional<mlir::Location> loc,
    ValueShapeRange operands, mlir::DictionaryAttr attributes,
    mlir::RegionRange regions,
    llvm::SmallVectorImpl<mlir::ShapedTypeComponents>&
        inferredShapedTypeComponents) {
  llvm::SmallVector<Type, 4> inferredReturnTypes;
  const LogicalResult infer_types_status = inferReturnTypes(
      ctx, loc, operands, attributes, regions, inferredReturnTypes);
  if (infer_types_status.failed()) return infer_types_status;

  if (inferredReturnTypes.size() != 1) return failure();

  auto result_tensor_type =
      inferredReturnTypes[0].dyn_cast_or_null<TensorType>();
  if (!result_tensor_type) return failure();

  mlir::Type element_type =
      operands[1].getType().cast<TensorType>().getElementType();
  inferredShapedTypeComponents.push_back(
      {result_tensor_type.getShape(), element_type});

  return success();
}

LogicalResult SelectOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  // For `hlo.select`, the first operand may be a scalar.
  return deriveShapeFromOperand(&builder, getOperation(), operands[1],
                                &reifiedReturnShapes);
}

//===----------------------------------------------------------------------===//
// SetDimensionSizeOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(SetDimensionSizeOp op) {
  if (auto size = op.size().getType().dyn_cast<RankedTensorType>()) {
    if (size.getRank() != 0)
      return op.emitOpError() << "size operand should be of rank 0";
  }

  return VerifyDimAttr(op);
}

OpFoldResult SetDimensionSizeOp::fold(ArrayRef<Attribute> operands) {
  DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  if (input) return input;

  DenseElementsAttr size = operands[1].dyn_cast_or_null<DenseElementsAttr>();
  if (!size || !size.isSplat()) return {};

  auto ty = getType().dyn_cast<RankedTensorType>();
  if (!ty) return {};

  int64_t dim_size = ty.getDimSize(dimension());
  if (dim_size == size.getSplatValue().cast<IntegerAttr>().getInt())
    return operand();
  return {};
}

//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(PadOp op) {
  auto input_type = op.operand().getType().cast<RankedTensorType>();
  auto pad_type = op.padding_value().getType().cast<RankedTensorType>();

  if (pad_type.getRank() != 0) {
    return op.emitOpError(
        llvm::formatv("padding value type should be a rank-0 "
                      "tensor, is rank {0}",
                      pad_type.getRank()));
  }

  const auto& padding_low = op.edge_padding_low();
  if (padding_low.getType().getNumElements() != input_type.getRank()) {
    return op.emitOpError(llvm::formatv(
        "edge_padding_low length ({0}) must match operand rank ({1})",
        padding_low.getType().getNumElements(), input_type.getRank()));
  }

  const auto& padding_high = op.edge_padding_high();
  if (padding_high.getType().getNumElements() != input_type.getRank()) {
    return op.emitOpError(llvm::formatv(
        "edge_padding_high length ({0}) must match operand rank ({1})",
        padding_high.getType().getNumElements(), input_type.getRank()));
  }

  const auto& padding_interior = op.interior_padding();
  if (padding_interior.getType().getNumElements() != input_type.getRank()) {
    return op.emitOpError(llvm::formatv(
        "interior_padding length ({0}) must match operand rank ({1})",
        padding_interior.getType().getNumElements(), input_type.getRank()));
  }

  auto input_shape = input_type.getShape();
  auto output_shape =
      op.getResult().getType().cast<RankedTensorType>().getShape();
  if (input_shape.size() != output_shape.size()) {
    return op.emitOpError(
        llvm::formatv("operand rank ({0}) and result rank ({1}) should match",
                      input_shape.size(), output_shape.size()));
  }

  for (int i = 0, e = input_shape.size(); i < e; i++) {
    int64_t padding_low_val = padding_low.getValue<IntegerAttr>(i).getInt();
    int64_t padding_high_val = padding_high.getValue<IntegerAttr>(i).getInt();
    int64_t padding_interior_val =
        padding_interior.getValue<IntegerAttr>(i).getInt();
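    // Worked example (hypothetical values): input_shape[i] = 3,
    // padding_low = 1, padding_high = 2, interior = 1 gives an expected
    // output dimension of 3 + 1 + 2 + max(3 - 1, 0) * 1 = 8.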
    int64_t expected_output =
        input_shape[i] + padding_low_val + padding_high_val +
        std::max<int64_t>(input_shape[i] - 1, 0LL) * padding_interior_val;
    if (expected_output != output_shape[i]) {
      return op.emitOpError(llvm::formatv(
          "expected output shape's dimension #{0} to be {1} but found {2}", i,
          expected_output, output_shape[i]));
    }
  }

  return success();
}

OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
  // If all padding is zero then it is an identity pad.
  auto is_zero = [](const APInt& i) { return i == 0; };
  if (llvm::all_of(edge_padding_low().getIntValues(), is_zero) &&
      llvm::all_of(edge_padding_high().getIntValues(), is_zero) &&
      llvm::all_of(interior_padding().getIntValues(), is_zero))
    return operand();

  // If any padding is negative then it isn't supported by the folder (yet).
  auto is_negative = [](const APInt& i) { return i.slt(0); };
  if (llvm::any_of(edge_padding_low().getIntValues(), is_negative) ||
      llvm::any_of(edge_padding_high().getIntValues(), is_negative) ||
      llvm::any_of(interior_padding().getIntValues(), is_negative))
    return {};

  DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  DenseElementsAttr padding = operands[1].dyn_cast_or_null<DenseElementsAttr>();
  RankedTensorType return_type = getType().dyn_cast_or_null<RankedTensorType>();
  if (!input || !input.getType().hasRank() || !padding || !return_type ||
      !return_type.hasStaticShape())
    return {};

  // Fill the full result tensor with the padding value.
  llvm::SmallVector<Attribute, 4> result(return_type.getNumElements(),
                                         padding.getValue({}));

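  // Advances a multi-dimensional index over `shape` in row-major order. For
  // example (hypothetical): with shape [2, 3] the sequence is
  // [0,0] -> [0,1] -> [0,2] -> [1,0] -> [1,1] -> [1,2].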
  auto next_index = [](llvm::SmallVector<uint64_t, 8>& index,
                       llvm::ArrayRef<int64_t> shape) {
    for (int64_t i = index.size() - 1; i >= 0; --i) {
      ++index[i];
      if (index[i] < shape[i]) return;
      index[i] = 0;
    }
  };

  // Iterate over all elements of the input tensor and copy each one to its
  // correct location in the output tensor.
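  // For example (hypothetical, 1-D): with edge_padding_low = 1 and
  // interior_padding = 2, input element i lands at output offset 1 + 3 * i.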
  llvm::SmallVector<uint64_t, 8> index(input.getType().getRank(), 0);
  uint64_t num_elements = input.getNumElements();
  for (uint64_t operand_idx = 0; operand_idx < num_elements; operand_idx++) {
    uint64_t result_idx = 0;
    uint64_t idx_multiplier = 1;
    for (int64_t i = index.size() - 1; i >= 0; --i) {
      result_idx +=
          (edge_padding_low().getValue<int64_t>({uint64_t(i)}) +
           index[i] *
               (interior_padding().getValue<int64_t>({uint64_t(i)}) + 1)) *
          idx_multiplier;
      idx_multiplier *= return_type.getDimSize(i);
    }
    result[result_idx] = input.getValue(index);
    next_index(index, input.getType().getShape());
  }
  return DenseElementsAttr::get(return_type, result);
}

//===----------------------------------------------------------------------===//
// DynamicPadOp
//===----------------------------------------------------------------------===//

void DynamicPadOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<DPadToPad>(context);
}

static LogicalResult Verify(DynamicPadOp op) {
  auto input_type = op.operand().getType().dyn_cast<RankedTensorType>();
  // If operand is unranked, there is very little to verify statically.
  if (!input_type) return success();
  int input_rank = input_type.getRank();

  auto pad_type = op.padding_value().getType().cast<RankedTensorType>();
  if (pad_type.getRank() != 0) {
    return op.emitOpError() << "padding value type should be a rank-0 tensor";
  }

  auto padding_low_type =
      op.edge_padding_low().getType().cast<RankedTensorType>();
  if (padding_low_type.getNumElements() != input_rank) {
    return op.emitOpError()
           << "edge_padding_low length(" << padding_low_type.getNumElements()
           << ") must match operand rank(" << input_rank << ").";
  }

  auto padding_high_type =
      op.edge_padding_high().getType().cast<RankedTensorType>();
  if (padding_high_type.getNumElements() != input_rank) {
    return op.emitOpError()
           << "edge_padding_high length(" << padding_high_type.getNumElements()
           << ") must match operand rank(" << input_rank << ").";
  }

  auto interior_padding_type =
      op.interior_padding().getType().cast<RankedTensorType>();
  if (interior_padding_type.getNumElements() != input_rank) {
    return op.emitOpError()
           << "interior_padding length("
           << interior_padding_type.getNumElements()
           << ") must match operand rank(" << input_rank << ").";
  }

  auto output_type = op.getResult().getType().dyn_cast<RankedTensorType>();
  // If result is unranked, there is very little to verify statically.
  if (!output_type) return success();
  int output_rank = output_type.getRank();
  if (input_rank != output_rank) {
    return op.emitOpError() << "operand rank(" << input_rank
                            << ") must match result rank(" << output_rank
                            << ").";
  }

  return success();
}

LogicalResult DynamicPadOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  DynamicPadOp::Adaptor adaptor(operands);
  Value operand = adaptor.operand();
  Value edge_padding_low = adaptor.edge_padding_low();
  Value edge_padding_high = adaptor.edge_padding_high();
  Value interior_padding = adaptor.interior_padding();

  auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
  // Unranked pad is not supported at the moment.
  if (!operand_type) return failure();

  auto loc = this->getLoc();
  SmallVector<Value, 4> shape_values;
  shape_values.reserve(operand_type.getRank());
  Type shape_scalar_type =
      edge_padding_low.getType().cast<ShapedType>().getElementType();

  auto to_shape_scalar_type = [&](Value v) {
    return MaybeCastTo(builder, loc, v, shape_scalar_type);
  };

  Value zero = to_shape_scalar_type(builder.create<ConstantIndexOp>(loc, 0));
  Value one = to_shape_scalar_type(builder.create<ConstantIndexOp>(loc, 1));

  for (int idx : llvm::seq<int>(0, operand_type.getShape().size())) {
    Value value_dim =
        to_shape_scalar_type(builder.create<tensor::DimOp>(loc, operand, idx));
    Value offset = builder.create<ConstantIndexOp>(loc, idx);
    Value value_low =
        builder.create<tensor::ExtractOp>(loc, edge_padding_low, offset);
    Value value_high =
        builder.create<tensor::ExtractOp>(loc, edge_padding_high, offset);
    Value value_interior =
        builder.create<tensor::ExtractOp>(loc, interior_padding, offset);
    // output_size = input_size + padding_low + padding_high +
    //               interior * max(input_size - 1, 0)
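    // For example (hypothetical values): input_size = 4, low = 1, high = 1,
    // interior = 2 gives 4 + 1 + 1 + 2 * max(4 - 1, 0) = 12.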
    Value value_dim_less_than_one =
        builder.create<CmpIOp>(loc, CmpIPredicate::slt, value_dim, one);
    Value interior_size = builder.create<MulIOp>(
        loc, value_interior,
        builder.create<mlir::SelectOp>(
            loc, value_dim_less_than_one, zero,
            builder.create<SubIOp>(loc, value_dim, one)));
    shape_values.push_back(builder.create<AddIOp>(
        loc,
        builder.create<AddIOp>(
            loc, builder.create<AddIOp>(loc, interior_size, value_dim),
            value_low),
        value_high));
  }

  reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
      loc, shape_scalar_type, shape_values));

  return success();
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(ReshapeOp op) {
  // If the operand type is dynamically shaped there is nothing to verify.
  auto operand_ty = op.operand().getType().dyn_cast<RankedTensorType>();
  if (!operand_ty || !operand_ty.hasStaticShape()) return success();

  // If the operand type is statically shaped (not required) the number of
  // elements must match that of the result type.
  auto result_ty = op.getType().cast<RankedTensorType>();
  assert(result_ty && result_ty.hasStaticShape() &&
         "result type must be statically shaped");
  int64_t num_result_elements = result_ty.getNumElements();
  int64_t num_operand_elements = operand_ty.getNumElements();
  if (num_result_elements != num_operand_elements)
    return op.emitOpError()
           << "number of output elements (" << num_result_elements
           << ") doesn't match expected number of elements ("
           << num_operand_elements << ")";

  return success();
}

OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
  if (getOperand().getType() == getType()) {
    return getOperand();
  }

  if (auto prev_op = getOperand().getDefiningOp<ReshapeOp>()) {
    setOperand(prev_op.getOperand());
    return getResult();
  }

  if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
    return elements.reshape(getResult().getType().cast<ShapedType>());
  }

  return {};
}

void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                            MLIRContext* context) {
  results.insert<IdentityBroadcastReshape, IdentityBroadcastInDimReshape,
                 EliminateRedundantReshape, EliminateIdentityReshape>(context);
}

//===----------------------------------------------------------------------===//
// ReplicaId Op
//===----------------------------------------------------------------------===//

LogicalResult ReplicaIdOp::inferReturnTypes(
    MLIRContext* context, Optional<Location>, ValueRange operands,
    DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  inferredReturnTypes.push_back(RankedTensorType::get(
      /*shape=*/{}, IntegerType::get(context, 32, IntegerType::Unsigned)));
  return success();
}

//===----------------------------------------------------------------------===//
// If Op
//===----------------------------------------------------------------------===//

static LogicalResult VerifyConditionalBranch(Operation* op, Region& region,
                                             Value operand,
                                             llvm::Twine branchName,
                                             llvm::Twine operandName) {
  mlir::Block& entryBlock = region.front();
  if (entryBlock.getNumArguments() != 1)
    return op->emitOpError()
           << branchName << " block should have a single argument, but found "
           << entryBlock.getNumArguments();

  Type operandType = operand.getType();
  Type branchArgType = entryBlock.getArgument(0).getType();
  if (branchArgType != operandType)
    return op->emitOpError()
           << operandName << " type (" << operandType << ") does not match "
           << branchName << " block arg type (" << branchArgType << ")";
  TypeRange branchReturnTypes = entryBlock.getTerminator()->getOperandTypes();
  if (branchReturnTypes != op->getResultTypes())
    return op->emitOpError()
           << branchName << " returned types (" << branchReturnTypes
           << ") do not match op result types (" << op->getResultTypes() << ")";

  return success();
}

static LogicalResult Verify(IfOp op) {
  if (failed(VerifyConditionalBranch(op, op.true_branch(), op.true_arg(),
                                     /*branchName=*/"true_branch",
                                     /*operandName=*/"true_arg"))) {
    return failure();
  }

  if (failed(VerifyConditionalBranch(op, op.false_branch(), op.false_arg(),
                                     /*branchName=*/"false_branch",
                                     /*operandName=*/"false_arg"))) {
    return failure();
  }
  return success();
}

static LogicalResult InlineIfConstantCondition(IfOp ifOp,
                                               PatternRewriter& rewriter) {
  DenseIntElementsAttr pred_attr;
  if (!matchPattern(ifOp.pred(), m_Constant(&pred_attr))) return failure();

  if (pred_attr.getSplatValue<BoolAttr>().getValue()) {
    ReplaceOpWithRegion(rewriter, ifOp, ifOp.true_branch(), ifOp.true_arg());
  } else {
    ReplaceOpWithRegion(rewriter, ifOp, ifOp.false_branch(), ifOp.false_arg());
  }
  return success();
}

void IfOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                       MLIRContext* context) {
  results.add(&InlineIfConstantCondition);
}

//===----------------------------------------------------------------------===//
// Case Op
//===----------------------------------------------------------------------===//

static LogicalResult Verify(CaseOp op) {
  auto num_branches = op.branches().size();
  if (op.branch_operands().size() != num_branches)
    return op.emitOpError() << "number of branches (" << num_branches
                            << ") does not match number of branch operands ("
                            << op.branch_operands().size() << ")";

  for (unsigned i = 0; i < num_branches; ++i)
    if (failed(VerifyConditionalBranch(
            op, op.branches()[i], op.branch_operands()[i],
            /*branchName=*/"branch " + Twine(i),
            /*operandName=*/"branch_operand " + Twine(i))))
      return failure();

  return success();
}

static LogicalResult InlineCaseConstantCondition(CaseOp caseOp,
                                                 PatternRewriter& rewriter) {
  DenseIntElementsAttr index_attr;
  if (!matchPattern(caseOp.index(), m_Constant(&index_attr))) {
    return failure();
  }
  int64_t index =
      index_attr.getSplatValue<IntegerAttr>().getValue().getSExtValue();
  // For an OOB index, the last branch is executed as the default branch:
  // https://www.tensorflow.org/xla/operation_semantics#conditional
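  // For example (hypothetical): a three-branch case with a constant index of
  // 7 (or -1) is inlined as branch 2.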
  if (index < 0 || index >= caseOp.getNumRegions())
    index = caseOp.getNumRegions() - 1;

  Region& region = caseOp.getRegion(index);
  if (!llvm::hasSingleElement(region)) return failure();
  ReplaceOpWithRegion(rewriter, caseOp, region,
                      caseOp.branch_operands()[index]);
  return success();
}

void CaseOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                         MLIRContext* context) {
  results.add(&InlineCaseConstantCondition);
}

//===----------------------------------------------------------------------===//
// SqrtOp
//===----------------------------------------------------------------------===//

OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
  auto val = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  if (!val) return {};

  auto type = getElementTypeOrSelf(getType());
  if (!type.isF32() && !type.isF64()) return {};

  auto shaped_type = getType().cast<ShapedType>();
  if (!shaped_type.hasStaticShape()) return {};

  int bit_width = type.getIntOrFloatBitWidth();
  llvm::SmallVector<APFloat, 4> values;
  values.reserve(val.getNumElements());
  for (auto it : val.getFloatValues()) {
    double value = bit_width == 32 ? it.convertToFloat() : it.convertToDouble();
    if (value < 0) return {};
    value = std::sqrt(value);
    if (bit_width == 32)
      values.emplace_back(static_cast<float>(value));
    else
      values.emplace_back(value);
  }
  return DenseFPElementsAttr::get(shaped_type, values);
}

//===----------------------------------------------------------------------===//
// UnaryOps
//===----------------------------------------------------------------------===//

template <typename Op, typename ElementType = Type, typename ValType,
          typename Convert>
static Attribute UnaryFolder(Op* op, ArrayRef<Attribute> attrs) {
  if (!attrs[0]) return {};

  DenseElementsAttr val = attrs[0].dyn_cast<DenseElementsAttr>();
  if (!val) return {};

  ShapedType type = op->getType().template cast<ShapedType>();
  if (!type.hasStaticShape()) {
    return {};
  }

  Type etype = type.getElementType();

  // Only fold if the element type matches the expected element type.
  if (!etype.isa<ElementType>()) {
    return {};
  }

  SmallVector<ValType, 6> values;
  values.reserve(val.getNumElements());
  for (const auto v : val.getValues<ValType>()) {
    values.push_back(Convert()(v));
  }

  return DenseElementsAttr::get(type, values);
}

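// Rounds half away from zero. For example (hypothetical): 2.5 rounds to 3.0
// and -2.5 rounds to -3.0.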
struct round {
  APFloat operator()(const APFloat& f) {
    APFloat r = f;
    r.roundToIntegral(llvm::RoundingMode::NearestTiesToAway);
    return r;
  }
};

struct logical_not {
  APInt operator()(const APInt& i) {
    return APInt(i.getBitWidth(), static_cast<uint64_t>(!i));
  }
};

template <typename FloatOrInt>
struct sign {
  APFloat compute(const APFloat& f) {
    if (f.isZero() || f.isNaN()) return f;
    double value = f.isNegative() ? -1.0 : 1.0;
    APFloat val(value);
    bool unused;
    val.convert(f.getSemantics(), APFloat::rmNearestTiesToEven, &unused);
    return val;
  }

  APInt compute(const APInt& i) {
    APInt r = i;
    if (r == 0) return r;
    if (r.isNegative()) {
      return APInt(r.getBitWidth(), -1, /*isSigned=*/true);
    }
    return APInt(r.getBitWidth(), 1, /*isSigned=*/true);
  }

  FloatOrInt operator()(const FloatOrInt& fi) { return compute(fi); }
};

#define UNARY_FOLDER(Op, Func)                                                \
  OpFoldResult Op::fold(ArrayRef<Attribute> attrs) {                          \
    if (getElementTypeOrSelf(getType()).isa<FloatType>())                     \
      return UnaryFolder<Op, FloatType, APFloat, Func<APFloat>>(this, attrs); \
    if (getElementTypeOrSelf(getType()).isa<IntegerType>())                   \
      return UnaryFolder<Op, IntegerType, APInt, Func<APInt>>(this, attrs);   \
    return {};                                                                \
  }

#define UNARY_FOLDER_INT(Op, Func)                                   \
  OpFoldResult Op::fold(ArrayRef<Attribute> attrs) {                 \
    if (getElementTypeOrSelf(getType()).isa<IntegerType>())          \
      return UnaryFolder<Op, IntegerType, APInt, Func>(this, attrs); \
    return {};                                                       \
  }

#define UNARY_FOLDER_FLOAT(Op, Func)                                 \
  OpFoldResult Op::fold(ArrayRef<Attribute> attrs) {                 \
    if (getElementTypeOrSelf(getType()).isa<FloatType>())            \
      return UnaryFolder<Op, FloatType, APFloat, Func>(this, attrs); \
    return {};                                                       \
  }

UNARY_FOLDER(NegOp, std::negate);
UNARY_FOLDER(SignOp, sign);
UNARY_FOLDER_INT(NotOp, logical_not);
UNARY_FOLDER_FLOAT(RoundOp, round);

#undef UNARY_FOLDER
#undef UNARY_FOLDER_INT
#undef UNARY_FOLDER_FLOAT

//===----------------------------------------------------------------------===//
// BinaryOps
//===----------------------------------------------------------------------===//

namespace {

// Updates the element type of a (presumed) tensor type 'x', returning either
// an UnrankedTensorType or a RankedTensorType with the new element type.
static Type UpdateResultElementType(Builder* builder, Type x,
                                    Type element_type) {
  auto x_ranked = x.dyn_cast<RankedTensorType>();
  if (!x_ranked) {
    return UnrankedTensorType::get(element_type);
  }

  auto shape_x = x_ranked.getShape();
  return RankedTensorType::get(shape_x, element_type);
}
}  // namespace

template <typename Op, typename ElementType = Type, typename ValType,
          typename Convert>
static Attribute BinaryFolder(Op* op, ArrayRef<Attribute> attrs) {
  if (!attrs[0] || !attrs[1]) return {};

  DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>();
  DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>();
  if (!lhs || !rhs) return {};

  ShapedType type = op->getType().template cast<ShapedType>();
  if (!type.hasStaticShape()) {
    return {};
  }

  Type etype = type.getElementType();

  // Only fold if the element type matches the expected element type.
  if (!etype.isa<ElementType>()) {
    return {};
  }

  SmallVector<ValType, 6> values;
  values.reserve(lhs.getNumElements());
  for (const auto zip :
       llvm::zip(lhs.getValues<ValType>(), rhs.getValues<ValType>())) {
    values.push_back(Convert()(std::get<0>(zip), std::get<1>(zip)));
  }

  return DenseElementsAttr::get(type, values);
}

template <typename T>
struct divide : std::divides<T> {};

template <>
struct divide<APInt> {
  APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
};

template <typename T>
struct remainder : std::modulus<T> {};

template <>
struct remainder<APInt> {
  APInt operator()(const APInt& a, const APInt& b) const { return a.srem(b); }
};

template <>
struct remainder<APFloat> {
  APFloat operator()(const APFloat& a, const APFloat& b) const {
    APFloat result(a);
    result.remainder(b);
    return result;
  }
};

template <typename T>
struct max {
  T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
};

template <>
struct max<APInt> {
  APInt operator()(const APInt& a, const APInt& b) const {
    return llvm::APIntOps::smax(a, b);
  }
};

template <typename T>
struct min {
  T operator()(const T& a, const T& b) const { return std::min<T>(a, b); }
};

template <>
struct min<APInt> {
  APInt operator()(const APInt& a, const APInt& b) const {
    return llvm::APIntOps::smin(a, b);
  }
};

#define BINARY_FOLDER(Op, Func)                                                \
  OpFoldResult Op::fold(ArrayRef<Attribute> attrs) {                           \
    if (getElementTypeOrSelf(getType()).isa<FloatType>())                      \
      return BinaryFolder<Op, FloatType, APFloat, Func<APFloat>>(this, attrs); \
    if (getElementTypeOrSelf(getType()).isa<IntegerType>())                    \
      return BinaryFolder<Op, IntegerType, APInt, Func<APInt>>(this, attrs);   \
    return {};                                                                 \
  }

// Addition, subtraction and multiplication use the std:: versions of the ops.
// Due to the other ops behaving differently in signed vs unsigned integers,
// APInts need a special implementation. Currently, it replicates signed int
// op behavior.
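// For example (hypothetical): with i32 operands, folding -7 / 2 uses
// APInt::sdiv and yields -3 (truncating signed division).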
BINARY_FOLDER(AddOp, std::plus);
BINARY_FOLDER(SubOp, std::minus);
BINARY_FOLDER(MulOp, std::multiplies);
BINARY_FOLDER(DivOp, divide);
BINARY_FOLDER(RemOp, remainder);
BINARY_FOLDER(MaxOp, max);
BINARY_FOLDER(MinOp, min);

#undef BINARY_FOLDER

//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//

// Returns output dimension size for slice result for the given arguments.
// Returns -1 if arguments are illegal.
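// For example (hypothetical values): input_dim = 10, start = 2, end = 8,
// stride = 3 yields ceil((8 - 2) / 3) = 2 output elements.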
static int64_t InferSliceDim(int64_t input_dim, int64_t start, int64_t end,
                             int64_t stride) {
  if (input_dim == -1 || start < 0 || start > end || end > input_dim ||
      stride == 0)
    return -1;

  return llvm::divideCeil(end - start, stride);
}

LogicalResult SliceOp::inferReturnTypes(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type>& inferredReturnTypes) {
  SliceOpAdaptor slice(operands, attributes);
  // TODO(jpienaar): Update this code after refactoring verify.
  if (failed(slice.verify(location.getValueOr(UnknownLoc::get(context))))) {
    return failure();
  }

  Type ty = slice.operand().getType();
  RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
  if (!ranked_ty) {
    // The operand type is unranked, so the best we can infer for the result
    // type is an unranked tensor with the same element type as the operand
    // type.
    inferredReturnTypes.assign({ty});
    return success();
  }

  ShapedType attr_ty = slice.start_indices().getType();
  if (attr_ty.getRank() != 1) {
    return emitOptionalError(location, "start_indices has rank ",
                             attr_ty.getRank(), " instead of required rank 1");
  }

  int64_t rank = ranked_ty.getRank();
  if (attr_ty.getNumElements() != rank) {
    return emitOptionalError(
        location, "the number of elements in start_indices (",
        attr_ty.getNumElements(), ") does not match the rank of the operand (",
        rank, ")");
  }

  if (!attr_ty.getElementType().isSignlessInteger(64) ||
      slice.limit_indices().getType() != attr_ty ||
      slice.strides().getType() != attr_ty) {
    // Unfortunately we can't rely on the AllTypesMatch trait for the SliceOp
    // having been verified at this point. Emit an error message that matches
    // the one that would be reported by AllTypesMatch for a more consistent
    // user experience.
    // TODO(b/171567182): Clean this up after AllTypesMatch has been refactored.
    return emitOptionalError(location,
                             "failed to verify that all of {start_indices, "
                             "limit_indices, strides} have same type");
  }

  SmallVector<int64_t, 4> start(slice.start_indices().getValues<int64_t>());
  SmallVector<int64_t, 4> limit(slice.limit_indices().getValues<int64_t>());
  SmallVector<int64_t, 4> stride_vals(slice.strides().getValues<int64_t>());

  SmallVector<int64_t, 4> shape;
  shape.reserve(rank);
  for (int64_t i = 0, e = rank; i != e; i++) {
    shape.push_back(InferSliceDim(ranked_ty.getDimSize(i), start[i], limit[i],
                                  stride_vals[i]));
  }
  inferredReturnTypes.assign(
      {RankedTensorType::get(shape, ranked_ty.getElementType())});
  return success();
}

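// Recursively copies the strided elements of a row-major buffer, one
// dimension at a time. For example (hypothetical): a 2x4 buffer sliced with
// starts = [0, 1], limits = [2, 4], strides = [1, 2] copies the elements at
// flattened offsets 1, 3, 5, and 7.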
template <typename I, typename E>
static void SliceElements(I values, ArrayRef<int64_t> sizes,
                          ArrayRef<int64_t> starts, ArrayRef<int64_t> limits,
                          ArrayRef<int64_t> strides,
                          llvm::SmallVectorImpl<E>* out_values) {
  assert(starts.size() == limits.size());
  assert(starts.size() == strides.size());
  if (starts.empty()) return;

  int64_t start = starts.front();
  int64_t limit = limits.front();
  int64_t stride = strides.front();
  if (starts.size() == 1) {
    for (int i = start; i < limit; i += stride) {
      out_values->push_back(*(values + i));
    }
    return;
  }

  for (; start < limit; start += stride) {
    auto begin = values + start * sizes.front();
    SliceElements<I, E>(begin, sizes.drop_front(), starts.drop_front(),
                        limits.drop_front(), strides.drop_front(), out_values);
  }
}

template <typename I, typename E>
static Attribute FoldSlice(SliceOp* op, I values) {
  auto start = llvm::to_vector<6>(op->start_indices().getValues<int64_t>());
  auto limit = llvm::to_vector<6>(op->limit_indices().getValues<int64_t>());
  auto stride = llvm::to_vector<6>(op->strides().getValues<int64_t>());

  auto operand_type = op->operand().getType().cast<ShapedType>();
  if (!operand_type.hasStaticShape()) return {};

  auto shape = operand_type.getShape();
  int64_t count = operand_type.getNumElements();
  if (count == 0) {
    return DenseElementsAttr::get<E>(
        op->getResult().getType().cast<ShapedType>(),
        /*list=*/{});
  }

  // Compute the striding for each dimension.
  llvm::SmallVector<int64_t, 6> sizes;
  sizes.reserve(shape.size());
  for (auto v : shape) {
    count = count / v;
    sizes.push_back(count);
  }

  llvm::SmallVector<E, 6> out_values;
  // Upper bound: the result has at most as many elements as the operand.
  out_values.reserve(operand_type.getNumElements());
  SliceElements<I, E>(values, sizes, start, limit, stride, &out_values);

  return DenseElementsAttr::get(op->getResult().getType().cast<ShapedType>(),
                                out_values);
}

OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
  // Check if the SliceOp is a no-op.
  auto operand_type = getOperand().getType().cast<ShapedType>();
  auto result_type = getResult().getType().cast<ShapedType>();

  if (operand_type.hasStaticShape() && result_type.hasStaticShape() &&
      (operand_type.getShape() == result_type.getShape())) {
    return getOperand();
  }

  if (operands.empty() || !operands.front()) return {};

  // Evaluate for statically valued inputs.
  DenseElementsAttr elements = operands.front().dyn_cast<DenseElementsAttr>();
  if (!elements) return {};

  auto etype = elements.getType().getElementType();
  if (etype.isa<IntegerType>()) {
    return FoldSlice<DenseElementsAttr::IntElementIterator, APInt>(
        this, elements.getIntValues().begin());
  } else if (etype.isa<FloatType>()) {
    return FoldSlice<
        llvm::mapped_iterator<DenseElementsAttr::IntElementIterator,
                              std::function<APFloat(const APInt&)>>,
        APFloat>(this, elements.getFloatValues().begin());
  }

  return {};
}

namespace {
// In cases where a concat is fed into a slice, it is possible the concat
// can be simplified or bypassed. This checks which inputs to the concat are
// used by the slice, either reducing the number of concatenated values or
// removing the concat entirely.
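// For example (hypothetical shapes): slicing [4, 10) along the concat
// dimension of concat(A: 4 elems, B: 4 elems, C: 4 elems) only reads B and C,
// so the concat is rebuilt from {B, C} and the slice bounds shift down by 4.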
struct SimplifyConcatSlice : public OpRewritePattern<SliceOp> {
  using OpRewritePattern<SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SliceOp slice,
                                PatternRewriter& rewriter) const override {
    auto result_ty = slice.getType().cast<ShapedType>();
    if (!result_ty.hasStaticShape()) {
      return failure();
    }

    auto slice_input = slice.operand();
    auto slice_input_ty = slice_input.getType().cast<ShapedType>();
    auto concat = slice_input.getDefiningOp<ConcatenateOp>();
    if (!concat) {
      return failure();
    }

    auto dimension = concat.dimension();

    auto start = slice.start_indices().getIntValues();
    auto limit = slice.limit_indices().getIntValues();

    auto slice_start = (*(start.begin() + dimension)).getSExtValue();
    auto slice_limit = (*(limit.begin() + dimension)).getSExtValue();

    // We need to determine what inputs from the concat affect the slice, and
    // how the bounds of the slice need to be updated for the minimally required
    // inputs.
    int64_t running_size = 0;
    int64_t front_offset = slice_input_ty.getShape()[dimension];

    auto subset_start = concat.operand_end();
    auto subset_end = concat.operand_end();
    for (auto it = concat.operand_begin(); it < concat.operand_end(); ++it) {
      auto input = *it;
      ShapedType input_ty = input.getType().cast<ShapedType>();
      if (input_ty.isDynamicDim(dimension)) {
        return failure();
      }
      auto dim_size = input_ty.getShape()[dimension];

      // If this position is in the slice, it's the start of the subset and we
      // need to update the start and limit values.
      if (running_size + dim_size > slice_start &&
          subset_start == concat.operand_end()) {
        subset_start = it;
        front_offset = running_size;
      }

      // Determine the last required offset.
      if (running_size < slice_limit) {
        subset_end = it + 1;
      }

      running_size += dim_size;
    }

    auto subset_size = subset_end - subset_start;
    // We need all inputs so no optimization.
    if (subset_size == concat.getNumOperands()) {
      return failure();
    }

    // If there's nothing to slice that means the output is an empty tensor and
    // there is dead code. We do nothing here and rely on other passes to clean
    // this up.
    if (subset_size == 0) {
      return failure();
    }

    if (subset_size > 1 && !concat.getResult().hasOneUse()) {
      return failure();
    }

    auto concat_range = OperandRange(subset_start, subset_end);
    auto new_concat = rewriter.create<ConcatenateOp>(
        concat.getLoc(), concat_range, concat.dimension());

    llvm::SmallVector<APInt, 6> new_start(start);
    llvm::SmallVector<APInt, 6> new_limit(limit);
    new_start[dimension] -= front_offset;
    new_limit[dimension] -= front_offset;

    auto attr_type = slice.start_indices().getType().cast<ShapedType>();
    auto create = rewriter.create<SliceOp>(
        slice.getLoc(), new_concat,
        DenseIntElementsAttr::get(attr_type, new_start),
        DenseIntElementsAttr::get(attr_type, new_limit), slice.strides());
    rewriter.replaceOp(slice, create.getResult());
    return success();
  }
};
}  // namespace

void SliceOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                          MLIRContext* context) {
  results.insert<SimplifyConcatSlice>(context);
}

//===----------------------------------------------------------------------===//
// SortOp
//===----------------------------------------------------------------------===//

void SortOp::build(OpBuilder& builder, OperationState& state,
                   ValueRange operands, int64_t dimension, bool is_stable) {
  state.addOperands(operands);
  state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
  state.addAttribute("is_stable", builder.getBoolAttr(is_stable));

  for (Value operand : operands) state.addTypes(operand.getType());

  state.addRegion();
}

static LogicalResult Verify(SortOp op) {
  Operation::operand_range operands = op.operands();
  if (operands.empty()) return op.emitOpError("requires at least one input");

  // TODO(antiagainst): verify partially dynamic shapes
  if (llvm::all_of(operands, [](Value operand) {
        return operand.getType().cast<ShapedType>().hasRank();
      })) {
    ArrayRef<int64_t> input_shape =
        (*operands.begin()).getType().cast<ShapedType>().getShape();

    if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value operand) {
          return operand.getType().cast<ShapedType>().getShape() != input_shape;
        }))
      return op.emitOpError("requires all inputs to have the same dimensions");

    int64_t rank = input_shape.size();
    int64_t cmp_dim = op.dimension();
    if (cmp_dim < -rank || cmp_dim >= rank)
      return op.emitOpError("dimension attribute value must be in range [-")
             << rank << ", " << rank << "), but found " << cmp_dim;
  }

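  // For example (hypothetical): sorting two f32 tensors requires a comparator
  // of type (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) ->
  // tensor<i1>, i.e. two scalar (rank-0) arguments per sorted operand.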
  Block& block = op.comparator().front();
  size_t num_operands = op.getOperation()->getNumOperands();
  if (block.getNumArguments() != 2 * num_operands)
    return op.emitOpError("comparator block should have ")
           << 2 * num_operands << " arguments";

  for (auto indexed_operand : llvm::enumerate(operands)) {
    int index = indexed_operand.index();
    Type element_type =
        indexed_operand.value().getType().cast<ShapedType>().getElementType();
    Type tensor_type = RankedTensorType::get({}, element_type);
    for (int i : {2 * index, 2 * index + 1}) {
      Type arg_type = block.getArgument(i).getType();
      if (arg_type != tensor_type)
        return op.emitOpError("comparator block argument #")
               << i << " should be of type " << tensor_type << " but got "
               << arg_type;
    }
  }

  return success();
}

/// Drops an operand if its result is unused and the corresponding pair of
/// comparator block arguments is unused as well.
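/// For example (hypothetical): in a two-input sort where result #1 and
/// comparator arguments #2 and #3 have no uses, the sort is rebuilt with only
/// the first operand.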
static LogicalResult SortDropEmptyUseArgs(SortOp op,
                                          PatternRewriter& rewriter) {
  DenseSet<unsigned> erased_args;
  unsigned num_operands = op.getNumOperands();
  for (unsigned i = 0; i < num_operands; ++i) {
    if (!op.getResult(i).use_empty()) continue;
    Block& block = op.comparator().front();
    if (!block.getArgument(i * 2).use_empty()) continue;
    if (!block.getArgument(i * 2 + 1).use_empty()) continue;
    erased_args.insert(i);
  }
  if (erased_args.empty()) return failure();

  SmallVector<Value> new_operands;
  SmallVector<unsigned> erased_block_args;
  for (auto en : llvm::enumerate(op.operands())) {
    if (erased_args.contains(en.index())) {
      erased_block_args.push_back(en.index() * 2);
      erased_block_args.push_back(en.index() * 2 + 1);
    } else {
      new_operands.push_back(en.value());
    }
  }

  auto new_op = rewriter.create<SortOp>(op.getLoc(), new_operands,
                                        op.dimension(), op.is_stable());
  Region& region = new_op.comparator();
  rewriter.inlineRegionBefore(op.comparator(), region, region.end());
  region.front().eraseArguments(erased_block_args);

  SmallVector<Value> results;
  for (unsigned i = 0, j = 0; i < num_operands; ++i) {
    if (erased_args.contains(i)) {
      results.push_back({});
    } else {
      results.push_back(new_op.getResult(j++));
    }
  }
  rewriter.replaceOp(op, results);

  return success();
}

void SortOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                         MLIRContext* /*context*/) {
  results.insert(SortDropEmptyUseArgs);
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
  for (auto it : llvm::enumerate(permutation().getValues<APInt>())) {
    if (it.index() != it.value()) {
      return {};
    }
  }
  return getOperand();
}

static LogicalResult Verify(TransposeOp op) {
  // permutation is an attribute of the op so it has static shape.
  auto permutationType = op.permutation().getType();
  auto permutationRank = permutationType.getRank();
  if (permutationRank != 1) {
    return op.emitOpError(llvm::formatv(
        "permutation has rank {0} instead of rank 1", permutationRank));
  }
  auto permutationSize = permutationType.getNumElements();

  auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
  if (operandType) {
    auto operandRank = operandType.getRank();
    if (operandRank != permutationSize) {
      return op.emitOpError(llvm::formatv(
          "operand rank ({0}) does not match permutation size ({1})",
          operandRank, permutationSize));
    }
  }

  auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
  if (resultType) {
    auto resultRank = resultType.getRank();
    if (resultRank != permutationSize) {
      return op.emitOpError(llvm::formatv(
          "result rank ({0}) does not match permutation size ({1})", resultRank,
          permutationSize));
    }
  }

  if (!resultType || !operandType) return success();

  auto operandRank = operandType.getRank();
  SmallVector<int64_t, 4> expectedShape(operandRank);
  for (int i = 0; i != operandRank; ++i) {
    auto permutedDim = op.permutation().getValue<IntegerAttr>(i).getInt();
    expectedShape[i] = operandType.getDimSize(permutedDim);
  }

  auto expectedType =
      RankedTensorType::get(expectedShape, resultType.getElementType());
  if (failed(verifyCompatibleShape(resultType, expectedType))) {
    return op.emitOpError(llvm::formatv(
        "result type {0} is incompatible with the expected type {1}",
        resultType, expectedType));
  }

  return success();
}

LogicalResult TransposeOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  TransposeOp::Adaptor adaptor(operands);
  Value operand = adaptor.operand();

  auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
  // Unranked type is not supported at the moment.
  if (!operand_type) return failure();

  Location loc = this->getLoc();
  SmallVector<int64_t, 4> permutation(this->permutation().getValues<int64_t>());
  SmallVector<Value, 4> shape_values(permutation.size());

  Type shape_scalar_type = builder.getIndexType();
  auto to_shape_scalar_type = [&](Value v) {
    return MaybeCastTo(builder, loc, v, shape_scalar_type);
  };

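  // Result dimension j takes the size of input dimension permutation[j], so
  // each input dimension idx feeds the result dimension at idx's position
  // within the permutation. For example (hypothetical): with permutation
  // [2, 0, 1], input dim 0 supplies result dim 1.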
  for (const auto& element : llvm::enumerate(operand_type.getShape())) {
    int64_t idx = element.index();
    auto it = std::find(permutation.begin(), permutation.end(), idx);
    Value value_dim = to_shape_scalar_type(
        builder.create<tensor::DimOp>(loc, operand, element.index()));
    shape_values[std::distance(permutation.begin(), it)] = value_dim;
  }

  Value output_shape = builder.create<tensor::FromElementsOp>(
      loc, shape_scalar_type, shape_values);
  reifiedReturnShapes.push_back(output_shape);

  return success();
}

//===----------------------------------------------------------------------===//
// TriangularSolveOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(TriangularSolveOp op) {
  auto a_type = op.a().getType().dyn_cast<RankedTensorType>();

  // Skip verifier if 'a' is an unranked tensor.
  if (!a_type) return success();

  // Check that 'a' has rank >= 2.
  auto a_rank = a_type.getRank();
  if (a_rank < 2)
    return op.emitOpError()
           << "operand 'a' must have rank >= 2, but got " << a_type;

  // The two minor dimensions of 'a' must have equal size.
  if (a_type.getDimSize(a_rank - 2) != a_type.getDimSize(a_rank - 1))
    return op.emitOpError() << "two minor dimensions of operand 'a' must have "
                               "equal size, but got "
                            << a_type;

  auto b_type = op.b().getType().dyn_cast<RankedTensorType>();
  // If 'b' is unranked, skip the remaining checks.
  if (!b_type) return success();

  // Check that 'a' and 'b' have the same rank.
  auto b_rank = b_type.getRank();
  if (a_rank != b_rank)
    return op.emitOpError() << "operands must have equal rank, but got "
                            << a_type << " and " << b_type;

  // The shared dimension of 'a' and 'b' should match.
  if (a_type.getDimSize(a_rank - 1) !=
      b_type.getDimSize(b_rank - (op.left_side() ? 2 : 1)))
    return op.emitOpError() << "shared dimension of operands 'a' and 'b' does "
                               "not match, but got "
                            << a_type << " and " << b_type;

  // The leading batch dimensions of 'a' and 'b' must be equal.
  auto a_batch_dims = a_type.getShape().drop_back(2);
  auto b_batch_dims = b_type.getShape().drop_back(2);
  if (a_batch_dims != b_batch_dims)
    return op.emitOpError()
           << "leading batch dimensions of the operands must be same, but got "
           << a_type << " and " << b_type;

  // Result and operand 'b' must have the same shape.
  auto result_type = op.getType().dyn_cast<RankedTensorType>();
  if (!result_type) return success();
  if (result_type != b_type)
    return op.emitOpError()
           << "result and operand 'b' must have same shape, but got "
           << result_type << " and " << b_type;
  return success();
}

//===----------------------------------------------------------------------===//
// GetTupleElementOp
//===----------------------------------------------------------------------===//

void GetTupleElementOp::build(OpBuilder& builder, OperationState& result,
                              Value tuple, int32_t index) {
  if (auto tuple_type = tuple.getType().dyn_cast<TupleType>()) {
    auto element_type = tuple_type.getType(index);
    build(builder, result, element_type, tuple,
          builder.getI32IntegerAttr(index));
    return;
  }

  build(builder, result, tuple.getType(), tuple,
        builder.getI32IntegerAttr(index));
}

//===----------------------------------------------------------------------===//
// TupleOp
//===----------------------------------------------------------------------===//

void TupleOp::build(OpBuilder& builder, OperationState& result,
                    ValueRange values) {
  SmallVector<Type, 4> types;
  types.reserve(values.size());
  for (auto val : values) {
    types.push_back(val.getType());
  }

  build(builder, result, builder.getTupleType(types), values);
}

//===----------------------------------------------------------------------===//
// UnaryEinsumOp
//===----------------------------------------------------------------------===//

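// Canonicalizes away the unary einsum by rewriting it as a binary einsum (the
// UnaryEinsumToEinsum pattern pairs the operand with a trivial constant), so
// later passes only have to handle the binary form.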
void UnaryEinsumOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<UnaryEinsumToEinsum>(context);
}

//===----------------------------------------------------------------------===//
// CompareOp
//===----------------------------------------------------------------------===//

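// The result of a compare op has the shape of its operands with an i1 element
// type, e.g. comparing two tensor<4xf32> values yields a tensor<4xi1>.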
void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
                      Value rhs, StringAttr comparison_direction,
                      StringAttr compare_type) {
  auto new_type =
      UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type());
  build(builder, result, new_type, lhs, rhs, comparison_direction,
        compare_type);
}

LogicalResult CompareOp::inferReturnTypeComponents(
    mlir::MLIRContext* ctx, llvm::Optional<mlir::Location>,
    ValueShapeRange operands, mlir::DictionaryAttr, mlir::RegionRange,
    llvm::SmallVectorImpl<mlir::ShapedTypeComponents>& inferredReturnTypes) {
  OpBuilder builder(ctx);
  auto arg_ty = operands.front().getType().cast<TensorType>();
  inferredReturnTypes.push_back({arg_ty.getShape(), builder.getI1Type()});
  return success();
}

LogicalResult CompareOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  return deriveShapeFromOperand(&builder, getOperation(), operands.front(),
                                &reifiedReturnShapes);
}

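// APInt carries no signedness of its own, so the std comparators cannot be
// applied to it directly; these specializations compare with the signed
// predicates (slt/sle/sgt/sge), which the compare folder below relies on for
// the LT/LE/GT/GE directions.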
template <typename T>
struct less : std::less<T> {};

template <>
struct less<APInt> {
  bool operator()(const APInt& a, const APInt& b) const { return a.slt(b); }
};

template <typename T>
struct less_equal : std::less_equal<T> {};

template <>
struct less_equal<APInt> {
  bool operator()(const APInt& a, const APInt& b) const { return a.sle(b); }
};

template <typename T>
struct greater : std::greater<T> {};

template <>
struct greater<APInt> {
  bool operator()(const APInt& a, const APInt& b) const { return a.sgt(b); }
};

template <typename T>
struct greater_equal : std::greater_equal<T> {};

template <>
struct greater_equal<APInt> {
  bool operator()(const APInt& a, const APInt& b) const { return a.sge(b); }
};

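// Folds an elementwise compare of two constant operands with element type
// ElementType by applying Convert to each pair of scalars, producing a
// boolean constant of the result type. Returns a null attribute whenever the
// fold does not apply (non-constant operands, a dynamic shape, or a
// mismatched element type).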
template <typename Op, typename ElementType, typename SrcType, typename Convert>
static Attribute CompareFolder(CompareOp op, ArrayRef<Attribute> attrs) {
  if (!attrs[0] || !attrs[1]) return {};

  DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>();
  DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>();
  if (!lhs || !rhs) return {};

  ShapedType operand_type =
      op.getOperand(0).getType().template cast<ShapedType>();
  if (!operand_type.hasStaticShape()) {
    return {};
  }

  if (!operand_type.getElementType().isa<ElementType>()) {
    return {};
  }

  SmallVector<bool, 6> values;
  values.reserve(lhs.getNumElements());
  for (const auto zip :
       llvm::zip(lhs.getValues<SrcType>(), rhs.getValues<SrcType>())) {
    values.push_back(Convert()(std::get<0>(zip), std::get<1>(zip)));
  }

  auto result_ty = op.getType().cast<ShapedType>();
  return DenseElementsAttr::get(result_ty, values);
}

OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands) {
  auto result_ty = getType().cast<ShapedType>();
  if (!result_ty.hasStaticShape()) return {};

  auto direction = comparison_direction();
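  // Comparing a value against itself folds only for non-float element types:
  // for floats, a NaN operand compares unordered, so x == x is not always
  // true.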
  if (lhs() == rhs() && !getElementTypeOrSelf(lhs()).isa<FloatType>()) {
    if (direction == "LE" || direction == "EQ" || direction == "GE") {
      return DenseIntElementsAttr::get(result_ty, {true});
    }
    return DenseIntElementsAttr::get(result_ty, {false});
  }

  auto op_el_type = lhs().getType().cast<ShapedType>().getElementType();
  // Fold tensor<*xi1> != false to just return tensor<*xi1>.
  if (direction == "NE" && op_el_type.isInteger(1)) {
    DenseIntElementsAttr cst_attr;
    if (matchPattern(lhs(), m_Constant(&cst_attr))) {
      if (cst_attr.isSplat() && !cst_attr.getSplatValue<bool>()) {
        return rhs();
      }
    }

    if (matchPattern(rhs(), m_Constant(&cst_attr))) {
      if (cst_attr.isSplat() && !cst_attr.getSplatValue<bool>()) {
        return lhs();
      }
    }
  }

  // Fold tensor<*xi1> == true to just return tensor<*xi1>.
  if (direction == "EQ" && op_el_type.isInteger(1)) {
    DenseIntElementsAttr cst_attr;
    if (matchPattern(lhs(), m_Constant(&cst_attr))) {
      if (cst_attr.isSplat() && cst_attr.getSplatValue<bool>()) {
        return rhs();
      }
    }

    if (matchPattern(rhs(), m_Constant(&cst_attr))) {
      if (cst_attr.isSplat() && cst_attr.getSplatValue<bool>()) {
        return lhs();
      }
    }
  }

  if (!operands[0] || !operands[1]) {
    return {};
  }

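// Tries the elementwise constant folder for the given comparison direction,
// first over float operands and then over integer operands.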
#define COMPARE_FOLDER(Op, comparison, Func)                                \
  if (direction == comparison) {                                            \
    if (auto folded = CompareFolder<Op, FloatType, APFloat, Func<APFloat>>( \
            *this, operands))                                               \
      return folded;                                                        \
    if (auto folded = CompareFolder<Op, IntegerType, APInt, Func<APInt>>(   \
            *this, operands))                                               \
      return folded;                                                        \
  }

  COMPARE_FOLDER(CompareOp, "EQ", std::equal_to);
  COMPARE_FOLDER(CompareOp, "NE", std::not_equal_to);
  COMPARE_FOLDER(CompareOp, "LT", less);
  COMPARE_FOLDER(CompareOp, "LE", less_equal);
  COMPARE_FOLDER(CompareOp, "GT", greater);
  COMPARE_FOLDER(CompareOp, "GE", greater_equal);
#undef COMPARE_FOLDER

  return {};
}

//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//

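// Interprets the ops in `region` by folding them over the constant `inputs`,
// and returns the operands of the region's ReturnOp as attributes. Returns an
// empty vector if the number of inputs does not match the region's arguments,
// or if any op fails to fold to a constant.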
llvm::SmallVector<Attribute, 4> evaluateMhloRegion(Region& region,
                                                   ArrayRef<Attribute> inputs) {
  if (region.getNumArguments() != inputs.size()) return {};

  llvm::DenseMap<Value, Attribute> values;
  values.reserve(region.getNumArguments());
  for (auto it : llvm::zip(region.getArguments(), inputs)) {
    values.try_emplace(std::get<0>(it), std::get<1>(it));
  }

  for (auto& op : region.getOps()) {
    llvm::SmallVector<Attribute, 4> operand_attrs;
    for (auto& operand : op.getOpOperands()) {
      operand_attrs.push_back(values.lookup(operand.get()));
    }
    if (isa<ReturnOp>(op)) return operand_attrs;

    llvm::SmallVector<OpFoldResult, 4> results;
    if (failed(op.fold(operand_attrs, results))) return {};
    for (auto it : llvm::zip(op.getResults(), results)) {
      if (!std::get<1>(it).is<Attribute>()) return {};
      values.insert({std::get<0>(it), std::get<1>(it).get<Attribute>()});
    }
  }
  return {};
}

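// Constant-folds a scatter whose base, indices, and update operands are all
// constant by simulating the op: every element of the update tensor is routed
// to its location in the base tensor and combined with the existing value by
// evaluating the update computation region.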
OpFoldResult ScatterOp::fold(ArrayRef<Attribute> operands) {
  auto base = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  auto index = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
  auto update = operands[2].dyn_cast_or_null<DenseElementsAttr>();
  if (!base || !index || !update) return {};

  auto base_type = base.getType().dyn_cast<RankedTensorType>();
  auto index_type = index.getType().dyn_cast<RankedTensorType>();
  auto update_type = update.getType().dyn_cast<RankedTensorType>();
  if (!base_type || !index_type || !update_type) return {};

  // Add the virtual trailing dimension of size 1 if index_vector_dim equals
  // the rank of index_type.
  const int64_t index_vector_dim =
      scatter_dimension_numbers().index_vector_dim().getInt();
  if (index_vector_dim == index_type.getRank()) {
    auto index_shape = index_type.getShape().vec();
    index_shape.push_back(1);
    index_type =
        RankedTensorType::get(index_shape, index_type.getElementType());
    index = index.reshape(index_type).cast<DenseIntElementsAttr>();
  }

  // Increments the multi-dimensional index vector within the per-dimension
  // limits given by shape; returns false once the index rolls over back to
  // all zeros, and true otherwise.
  auto next_index = [](llvm::SmallVector<uint64_t, 8>& index,
                       llvm::ArrayRef<int64_t> shape) {
    for (int64_t i = index.size() - 1; i >= 0; --i) {
      ++index[i];
      if (index[i] < shape[i]) return true;
      index[i] = 0;
    }
    return false;
  };

  // Iterate over all elements of the update tensor, then find the
  // corresponding value in the indices tensor to determine which location we
  // have to update in the base/result tensor.
  llvm::SmallVector<Attribute, 8> results(base.getValues<Attribute>());
  llvm::SmallVector<uint64_t, 8> update_index(update_type.getRank(), 0);
  llvm::SmallVector<uint64_t, 8> index_index;
  index_index.reserve(index_type.getRank());
  // Use a signed index so the bounds check further down can detect negative
  // scatter indices.
  llvm::SmallVector<int64_t, 8> base_index;
  base_index.reserve(base_type.getRank());
  do {
    // Compute the index for the slice of the indices tensor for this update
    // value.
    index_index.clear();
    if (index_vector_dim == 0) index_index.push_back(0);
    for (int64_t i = 0; i < update_index.size(); ++i) {
      if (llvm::count(scatter_dimension_numbers().update_window_dims(), i) == 0)
        index_index.push_back(update_index[i]);
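      // Leave a placeholder slot once we cross the index_vector_dim position;
      // it is overwritten below when the index vector components are read.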
      if (index_index.size() == index_vector_dim) index_index.push_back(0);
    }

    // Compute the index for the given update value in the base tensor.
    base_index.assign(base_type.getRank(), 0);
    uint64_t index_count = index_type.getShape()[index_vector_dim];
    for (uint64_t i = 0; i < index_count; ++i) {
      uint64_t operand_dim = scatter_dimension_numbers()
                                 .scatter_dims_to_operand_dims()
                                 .getValue<APInt>({i})
                                 .getSExtValue();
      index_index[index_vector_dim] = i;
      base_index[operand_dim] +=
          index.getValue<APInt>(index_index).getSExtValue();
    }
    uint64_t update_window_dim_index = 0;
    for (uint64_t i = 0; i < base_index.size(); ++i) {
      if (llvm::count(scatter_dimension_numbers().inserted_window_dims(), i))
        continue;
      base_index[i] +=
          update_index[scatter_dimension_numbers()
                           .update_window_dims()
                           .getValue<APInt>({update_window_dim_index})
                           .getSExtValue()];
      update_window_dim_index++;
    }

    // Compute the linear index for the index into the base tensor.
    int64_t linear_base_index = 0;
    int64_t linear_base_index_multiplier = 1;
    for (int64_t i = base_index.size() - 1; i >= 0; --i) {
      // Out-of-bounds indices have backend-specific behavior, so avoid
      // folding them.
      if (base_index[i] < 0 || base_index[i] >= base_type.getShape()[i])
        return {};
      linear_base_index += base_index[i] * linear_base_index_multiplier;
      linear_base_index_multiplier *= base_type.getShape()[i];
    }

    // Evaluate the update computation and replace the value at the computed
    // location in the base tensor with the newly computed attribute.
    auto lhs = DenseElementsAttr::get(
        RankedTensorType::get({}, base_type.getElementType()),
        results[linear_base_index]);
    auto rhs = DenseElementsAttr::get(
        RankedTensorType::get({}, base_type.getElementType()),
        update.getValue<Attribute>(update_index));
    auto new_value = evaluateMhloRegion(update_computation(), {lhs, rhs});
    if (new_value.size() != 1 || !new_value[0]) return {};
    results[linear_base_index] =
        new_value[0].cast<DenseElementsAttr>().getValue<Attribute>({});
  } while (next_index(update_index, update_type.getShape()));

  return DenseElementsAttr::get(base_type, results);
}

using mlir::hlo::parseWindowAttributes;
using mlir::hlo::printWindowAttributes;

}  // namespace mhlo
}  // namespace mlir

#define GET_OP_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"

namespace mlir {
namespace mhlo {

//===----------------------------------------------------------------------===//
// mhlo Dialect Interfaces
//===----------------------------------------------------------------------===//

namespace {
struct HLOInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Allow all call operations to be inlined.
  bool isLegalToInline(Operation* call, Operation* callable,
                       bool wouldBeCloned) const final {
    return true;
  }
  // We don't have any special restrictions on what can be inlined into
  // destination regions (e.g. while/conditional bodies). Always allow it.
  bool isLegalToInline(Region* dest, Region* src, bool wouldBeCloned,
                       BlockAndValueMapping& valueMapping) const final {
    return true;
  }
  // Operations in the mhlo dialect are always legal to inline since they are
  // pure.
  bool isLegalToInline(Operation*, Region*, bool,
                       BlockAndValueMapping&) const final {
    return true;
  }
};
}  // end anonymous namespace

//===----------------------------------------------------------------------===//
// mhlo Dialect Constructor
//===----------------------------------------------------------------------===//

MhloDialect::MhloDialect(MLIRContext* context)
    : Dialect(getDialectNamespace(), context, TypeID::get<MhloDialect>()) {
  addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
      >();
  addInterfaces<HLOInlinerInterface>();
  addTypes<TokenType>();
  context->loadDialect<tensor::TensorDialect>();
}

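// The only custom type defined by the dialect is the token type, written
// `!mhlo.token`; any other keyword is rejected as an unknown mhlo type.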
Type MhloDialect::parseType(DialectAsmParser& parser) const {
  StringRef data_type;
  if (parser.parseKeyword(&data_type)) return Type();

  if (data_type == "token") return TokenType::get(getContext());
  parser.emitError(parser.getNameLoc()) << "unknown mhlo type: " << data_type;
  return nullptr;
}

void MhloDialect::printType(Type type, DialectAsmPrinter& os) const {
  if (type.isa<TokenType>()) {
    os << "token";
    return;
  }
  os << "<unknown mhlo type>";
}

//===----------------------------------------------------------------------===//
// Shape inference
//===----------------------------------------------------------------------===//

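// Reifies the result shape of an op whose result shape equals the shape of
// `operand` by emitting a single shape.shape_of on that operand.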
LogicalResult deriveShapeFromOperand(
    OpBuilder* builder, Operation* op, Value operand,
    SmallVectorImpl<Value>* reifiedReturnShapes) {
  auto shaped_ty = operand.getType().dyn_cast<ShapedType>();
  if (!shaped_ty) {
    op->emitOpError() << "operand is not a shaped type";
    return failure();
  }
  reifiedReturnShapes->assign(
      {builder->create<shape::ShapeOfOp>(op->getLoc(), operand)});
  return success();
}

}  // namespace mhlo
}  // namespace mlir