/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines the operations used in the MHLO dialect.

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <algorithm>
#include <functional>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
#include "mlir-hlo/utils/convert_op_folder.h"
#include "mlir-hlo/utils/hlo_utils.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/InliningUtils.h"

namespace mlir {
#include "hlo_patterns.cc.inc"
}  // namespace mlir

namespace mlir {
namespace mhlo {

Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value,
                                            Type type, Location loc) {
  // HLO dialect constants only support ElementsAttr, unlike the standard
  // dialect constant op, which supports all attribute kinds.
  if (value.isa<ElementsAttr>())
    return builder.create<mhlo::ConstOp>(loc, type, value.cast<ElementsAttr>());
  return nullptr;
}

template <typename T>
static LogicalResult Verify(T op) {
  return success();
}

namespace {

//===----------------------------------------------------------------------===//
// Utilities for the canonicalize patterns
//===----------------------------------------------------------------------===//

// Verifies that the dimension attribute for the op correctly indexes into the
// operand or result shape.
template <typename OpT>
static LogicalResult VerifyDimAttr(OpT op) {
  int64_t rank = -1;
  if (auto ty = op.operand().getType().template dyn_cast<RankedTensorType>()) {
    rank = ty.getRank();
  } else if (auto ty = op.getType().template dyn_cast<RankedTensorType>()) {
    rank = ty.getRank();
  } else {
    return success();
  }

  int64_t dim = op.dimension();
  if (dim < 0 || dim >= rank)
    return op.emitOpError() << "requires dimension attribute in range [0, "
                            << rank << "); found (" << dim << ")";
  return success();
}

// Returns a 1D 64-bit dense elements attribute with the given values.
DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
                                        Builder* builder) {
  RankedTensorType ty = RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, values);
}

// Given the start indices and slice sizes for a dynamic-slice that can be
// converted to a static slice, returns the limits for the static slice.
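// For example (illustrative values, not taken from a test): start_indices =
// [1, 2] with slice_sizes = [3, 4] yields limits [4, 6].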
DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices,
                                      DenseIntElementsAttr slice_sizes,
                                      Builder* builder) {
  SmallVector<int64_t, 4> slice_limits;
  for (int64_t i = 0; i < slice_sizes.getNumElements(); ++i) {
    int64_t start_index = start_indices.getValue<IntegerAttr>(i).getInt();
    int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
    slice_limits.push_back(start_index + slice_size);
  }
  return GetI64ElementsAttr(slice_limits, builder);
}

/// Replaces the given op with the contents of the given single-block region,
/// using the operands of the block terminator to replace operation results.
static void ReplaceOpWithRegion(PatternRewriter& rewriter, Operation* op,
                                Region& region, ValueRange blockArgs = {}) {
  assert(llvm::hasSingleElement(region) && "expected single-block region");
  Block* block = &region.front();
  Operation* terminator = block->getTerminator();
  ValueRange results = terminator->getOperands();
  rewriter.mergeBlockBefore(block, op, blockArgs);
  rewriter.replaceOp(op, results);
  rewriter.eraseOp(terminator);
}

#include "mhlo_canonicalize.inc"

// Common shape function helper for RngNormal and RngUniform.
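// For example (illustrative): with a constant shape operand of
// dense<[2, 3]> : tensor<2xi64> the inferred component is 2x3xELEM; with a
// non-constant but ranked tensor<2xi64> shape operand it is ?x?xELEM.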
static LogicalResult rngInferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  if (operands.size() != 3)
    return emitOptionalError(location, "expected 3 operands");

  SmallVector<int64_t> shapeVector;
  Value shapeOperand = operands[2];
  auto shapeOperandType = shapeOperand.getType().cast<ShapedType>();
  Type elementType = getElementTypeOrSelf(operands[1]);

  // Match constant shape arguments.
  DenseIntElementsAttr shape;
  if (!matchPattern(shapeOperand, m_Constant(&shape))) {
    if (!shapeOperandType.hasRank()) {
      inferredReturnShapes.emplace_back(elementType);
      return success();
    }
    if (shapeOperandType.getRank() != 1)
      return emitOptionalError(location, "shape operand required to be 1D");
    int size = shapeOperandType.getDimSize(0);
    if (size == ShapedType::kDynamicSize) {
      inferredReturnShapes.emplace_back(elementType);
      return success();
    }
    shapeVector.resize(size, ShapedType::kDynamicSize);
    inferredReturnShapes.emplace_back(shapeVector, elementType);
    return success();
  }

  shapeVector.reserve(shape.size());
  for (const APInt& dimSize : shape.getIntValues())
    shapeVector.push_back(dimSize.getSExtValue());
  inferredReturnShapes.emplace_back(shapeVector, elementType);
  return success();
}

// Returns a new scalar integer value having type `type`. Here `type` must be
// an integer or index type.
Value MaybeCastTo(OpBuilder& b, Location loc, Value value, Type type) {
  if (type == value.getType()) return value;
  assert(type.isIndex() || value.getType().isIndex());
  return b.create<IndexCastOp>(loc, value, type);
}

}  // namespace

//===----------------------------------------------------------------------===//
// ReduceScatterOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(ReduceScatterOp op) {
  return mlir::hlo::VerifyReduceScatter(
      op,
      /*operand_types=*/{op.operand().getType()},
      /*result_types=*/{op.getType()},
      /*scatter_dimension=*/op.scatter_dimension());
}

//===----------------------------------------------------------------------===//
// ConstOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");

  // Return the held attribute value.
  return value();
}

// Builds a constant op with the specified attribute `value`.
void ConstOp::build(OpBuilder& builder, OperationState& result,
                    Attribute value) {
  Type type;
  if (auto elemAttr = value.dyn_cast<ElementsAttr>()) {
    type = elemAttr.getType();
  } else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
             value.isa<IntegerAttr>()) {
    // All XLA types must be tensor types. In the build() method, we want to
    // provide more flexibility by allowing attributes of scalar types. But we
    // need to wrap them in an ElementsAttr to construct valid XLA constants.
    type = RankedTensorType::get(/*shape=*/{}, value.getType());
    value = DenseElementsAttr::get(type.cast<TensorType>(), value);
  }

  // TODO: support other XLA specific types.
  assert(type && "unsupported attribute type for building mhlo.constant");
  result.types.push_back(type);
  result.addAttribute("value", value);
}

//===----------------------------------------------------------------------===//
// DotGeneralOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(DotGeneralOp op) {
  auto dot_dimension_numbers = op.dot_dimension_numbers();
  int64_t lhs_batching_dimensions_size = llvm::size(
      dot_dimension_numbers.lhs_batching_dimensions().getValues<int64_t>());
  int64_t rhs_batching_dimensions_size = llvm::size(
      dot_dimension_numbers.rhs_batching_dimensions().getValues<int64_t>());
  if (lhs_batching_dimensions_size != rhs_batching_dimensions_size) {
    return op.emitError()
           << "lhs and rhs should have the same number of batching dimensions";
  }
  int64_t lhs_contracting_dimensions_size = llvm::size(
      dot_dimension_numbers.lhs_contracting_dimensions().getValues<int64_t>());
  int64_t rhs_contracting_dimensions_size = llvm::size(
      dot_dimension_numbers.rhs_contracting_dimensions().getValues<int64_t>());
  if (lhs_contracting_dimensions_size != rhs_contracting_dimensions_size) {
    return op.emitError() << "lhs and rhs should have the same number of "
                             "contracting dimensions";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

// Converts gather ops to slice ops in case we have a single set of constant
// indices.
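// For example (illustrative IR, attributes simplified):
//
//   %idx = mhlo.constant dense<2> : tensor<1xi64>
//   %0 = "mhlo.gather"(%arg, %idx) {...} : (tensor<8x4xf32>, tensor<1xi64>)
//            -> tensor<1x4xf32>
//
// can be rewritten as a static mhlo.slice of %arg (plus an mhlo.reshape when
// collapsed_slice_dims is non-empty).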
struct GatherSlice : public OpRewritePattern<GatherOp> {
  using OpRewritePattern<GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter& rewriter) const override {
    DenseIntElementsAttr index;
    if (!matchPattern(gather.start_indices(), m_Constant(&index)))
      return failure();

    const auto& dnums = gather.dimension_numbers();
    if (dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1)
      return failure();

    // TODO(tberghammer): Remove when the verifier catches this case, which is
    // invalid if all of the previous conditions hold.
    if (index.getNumElements() != dnums.start_index_map().getNumElements())
      return failure();

    auto slice_end =
        llvm::to_vector<8>(gather.slice_sizes().getValues<int64_t>());
    llvm::SmallVector<int64_t, 8> slice_start(slice_end.size(), 0);
    for (auto it : llvm::zip(dnums.start_index_map().getIntValues(),
                             index.getIntValues())) {
      int64_t map_index = std::get<0>(it).getSExtValue();
      int64_t offset = std::get<1>(it).getSExtValue();
      slice_start[map_index] += offset;
      slice_end[map_index] += offset;
    }

    llvm::SmallVector<int64_t, 8> slice_stride(slice_end.size(), 1);
    llvm::SmallVector<int64_t, 8> slice_shape(slice_end.size());
    for (size_t i = 0; i < slice_end.size(); ++i) {
      slice_shape[i] = slice_end[i] - slice_start[i];
    }
    Type element_type = gather.getType().cast<TensorType>().getElementType();
    auto slice_type = RankedTensorType::get(slice_shape, element_type);
    Value result = rewriter.create<SliceOp>(
        gather.getLoc(), slice_type, gather.getOperand(0),
        GetI64ElementsAttr(slice_start, &rewriter),
        GetI64ElementsAttr(slice_end, &rewriter),
        GetI64ElementsAttr(slice_stride, &rewriter));

    if (dnums.collapsed_slice_dims().getNumElements() > 0) {
      auto collapsed_slice_dims = llvm::to_vector<8>(llvm::map_range(
          dnums.collapsed_slice_dims().getIntValues(),
          [](const llvm::APInt& i) { return i.getSExtValue(); }));
      llvm::SmallVector<int64_t, 8> reshape_shape;
      for (size_t i = 0; i < slice_shape.size(); ++i) {
        if (llvm::count(collapsed_slice_dims, i) == 0) {
          reshape_shape.push_back(slice_shape[i]);
        }
      }
      auto reshape_type = RankedTensorType::get(reshape_shape, element_type);
      result =
          rewriter.create<ReshapeOp>(gather.getLoc(), reshape_type, result);
    }

    result.setType(gather.getType());
    rewriter.replaceOp(gather, result);
    return success();
  }
};

void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                           MLIRContext* context) {
  results.insert<GatherSlice>(context);
}

namespace {

// Following https://www.tensorflow.org/xla/operation_semantics#gather, the
// bounds for the output array along dimension i are computed as follows:
// (1) If i is present in batch_dims (i.e. is equal to batch_dims[k] for some
//     k), then we pick the corresponding dimension bounds out of
//     start_indices.shape, skipping index_vector_dim (i.e. pick
//     start_indices.shape.dims[k] if k < index_vector_dim and
//     start_indices.shape.dims[k+1] otherwise).
// (2) If i is present in offset_dims (i.e. equal to offset_dims[k] for some
//     k), then we pick the corresponding bound out of slice_sizes after
//     accounting for collapsed_slice_dims (i.e. we pick
//     adjusted_slice_sizes[k], where adjusted_slice_sizes is slice_sizes with
//     the bounds at indices collapsed_slice_dims removed).
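//
// For example (illustrative): with start_indices of shape [4, 3],
// index_vector_dim = 1, offset_dims = [1], collapsed_slice_dims = [0] and
// slice_sizes = [1, 8], we get batch_dims = [0] and adjusted_slice_sizes =
// [8], so the output bounds are [4, 8].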

void GetSliceSizeValues(GatherOp* gather, OpBuilder& builder, Location loc,
                        ValueRange operands,
                        SmallVectorImpl<Value>& slice_sizes) {
  for (int64_t val : gather->slice_sizes().getValues<int64_t>()) {
    slice_sizes.push_back(builder.create<ConstantIndexOp>(loc, val));
  }
}

void GetSliceSizeValues(DynamicGatherOp* d_gather, OpBuilder& builder,
                        Location loc, ValueRange operands,
                        SmallVectorImpl<Value>& slice_size_values) {
  DynamicGatherOp::Adaptor adaptor(operands);
  Value slice_sizes = adaptor.slice_sizes();
  auto slice_sizes_ty = slice_sizes.getType().cast<ShapedType>();
  for (int64_t i = 0; i < slice_sizes_ty.getDimSize(0); ++i) {
    Value idx = builder.create<ConstantIndexOp>(loc, i);
    slice_size_values.push_back(
        builder.create<tensor::ExtractOp>(loc, slice_sizes, idx));
  }
}

template <typename Op>
LogicalResult GatherShapeInferImpl(
    Op* op, OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  // Unranked gather results are not supported a.t.m.
  auto result_ty =
      op->getResult().getType().template dyn_cast<RankedTensorType>();
  if (!result_ty) return failure();

  typename Op::Adaptor adaptor(operands);
  Value start_indices = adaptor.start_indices();

  Location loc = op->getLoc();
  int result_rank = result_ty.getRank();
  Type shape_scalar_type =
      start_indices.getType().cast<ShapedType>().getElementType();
  auto to_shape_scalar_type = [&](Value v) {
    return MaybeCastTo(builder, loc, v, shape_scalar_type);
  };

  auto dimension_numbers = op->dimension_numbers();
  SmallVector<int64_t, 4> collapsed_slice_dims(
      dimension_numbers.collapsed_slice_dims().template getValues<int64_t>());
  SmallVector<int64_t, 4> offset_dims(
      dimension_numbers.offset_dims().template getValues<int64_t>());
  int64_t index_vector_dim =
      dimension_numbers.index_vector_dim().getValue().getSExtValue();

  SmallVector<Value, 4> slice_sizes;
  GetSliceSizeValues(op, builder, loc, operands, slice_sizes);
  // Convert to `shape_scalar_type`.
  llvm::transform(slice_sizes, slice_sizes.begin(),
                  [&](Value v) { return to_shape_scalar_type(v); });

  // We label dimensions in the output array that are not in offset_dims as
  // batch_dims.
  SmallVector<int64_t, 4> batch_dims;
  for (int64_t i = 0; i < result_rank; ++i) {
    if (std::find(offset_dims.begin(), offset_dims.end(), i) ==
        offset_dims.end()) {
      batch_dims.push_back(i);
    }
  }
  // adjusted_slice_sizes is slice_sizes with the bounds at indices
  // collapsed_slice_dims removed.
  SmallVector<Value, 4> adjusted_slice_sizes;
  for (int64_t i = 0; i < slice_sizes.size(); ++i) {
    if (std::find(collapsed_slice_dims.begin(), collapsed_slice_dims.end(),
                  i) == collapsed_slice_dims.end()) {
      adjusted_slice_sizes.push_back(slice_sizes[i]);
    }
  }

  SmallVector<Value, 4> shape_values;
  shape_values.reserve(result_rank);
  for (int64_t i = 0; i < result_rank; ++i) {
    auto iter = std::find(batch_dims.begin(), batch_dims.end(), i);
    if (iter != batch_dims.end()) {
      // i is present in batch_dims.
      int64_t k = std::distance(batch_dims.begin(), iter);
      if (k < index_vector_dim) {
        shape_values.push_back(to_shape_scalar_type(
            builder.create<tensor::DimOp>(loc, start_indices, k)));
      } else {
        shape_values.push_back(to_shape_scalar_type(
            builder.create<tensor::DimOp>(loc, start_indices, k + 1)));
      }
    } else {
      // i is present in offset_dims.
      auto offset_dims_iter =
          std::find(offset_dims.begin(), offset_dims.end(), i);
      assert(offset_dims_iter != offset_dims.end());
      int64_t k = std::distance(offset_dims.begin(), offset_dims_iter);
      assert(k < adjusted_slice_sizes.size());
      shape_values.push_back(adjusted_slice_sizes[k]);
    }
  }

  Value output_shape = builder.create<tensor::FromElementsOp>(
      loc, shape_scalar_type, shape_values);
  reifiedReturnShapes.push_back(output_shape);

  return success();
}

}  // namespace

LogicalResult GatherOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  return GatherShapeInferImpl(this, builder, operands, reifiedReturnShapes);
}

//===----------------------------------------------------------------------===//
// DynamicGatherOp
//===----------------------------------------------------------------------===//

LogicalResult DynamicGatherOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  return GatherShapeInferImpl(this, builder, operands, reifiedReturnShapes);
}

//===----------------------------------------------------------------------===//
// GetDimensionSizeOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(GetDimensionSizeOp op) { return VerifyDimAttr(op); }

/// Folds get_dimension_size when the requested shape dimension is static.
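/// For example (illustrative): get_dimension_size with dimension = 1 on a
/// tensor<4x8xf32> operand folds to dense<8> : tensor<i32>.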
OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
  RankedTensorType type = operand().getType().dyn_cast<RankedTensorType>();
  if (!type) return {};

  int32_t dim = dimension();
  if (type.isDynamic(dim)) return {};
  // The result type is always a 0-d i32 tensor.
  return DenseIntElementsAttr::get<int32_t>(
      getResult().getType().cast<RankedTensorType>(), type.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// IotaOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(IotaOp op) {
  auto shape = op.getType().cast<ShapedType>();
  if (!shape.hasRank()) return success();

  if (shape.getRank() == 0)
    return op.emitOpError() << "does not support scalars.";

  auto iota_dimension = op.iota_dimension();
  if (iota_dimension >= shape.getRank() || iota_dimension < 0)
    return op.emitOpError() << "iota dimension cannot go beyond the output "
                               "rank or be negative.";
  return success();
}

// Iota operations across multiple dimensions can be reduced to an iota and a
// ranked broadcast.
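// For example (illustrative IR):
//
//   %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x4xf32>
//
// becomes a 1-D iota followed by a broadcast:
//
//   %1 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
//   %0 = "mhlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<1> :
//            tensor<1xi64>} : (tensor<4xf32>) -> tensor<3x4xf32>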
struct IotaBroadcast : public OpRewritePattern<IotaOp> {
  using OpRewritePattern<IotaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IotaOp iota,
                                PatternRewriter& rewriter) const override {
    auto result_ty = iota.getType().cast<ShapedType>();
    if (!result_ty.hasRank() || result_ty.getRank() < 2) {
      return failure();
    }

    auto iota_dimension = iota.iota_dimension();

    auto iota_type = RankedTensorType::get(
        {result_ty.getDimSize(iota_dimension)}, result_ty.getElementType());

    auto new_iota = rewriter.create<IotaOp>(iota.getLoc(), iota_type,
                                            rewriter.getI64IntegerAttr(0));

    auto broadcast_attr = DenseIntElementsAttr::get(
        RankedTensorType::get({1}, rewriter.getIntegerType(64)),
        {iota_dimension});
    rewriter.replaceOpWithNewOp<BroadcastInDimOp>(iota, result_ty, new_iota,
                                                  broadcast_attr);
    return success();
  }
};

void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                         MLIRContext* context) {
  results.insert<IotaBroadcast>(context);
}

OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
  auto dimension = iota_dimension();
  auto result_ty = getResult().getType().cast<ShapedType>();
  if (result_ty.hasRank() && result_ty.getDimSize(dimension) == 1) {
    Builder builder(getContext());
    return builder.getZeroAttr(result_ty);
  }

  return {};
}

//===----------------------------------------------------------------------===//
// DynamicIotaOp
//===----------------------------------------------------------------------===//

namespace {

struct DynamicIotaIsStatic : public OpRewritePattern<DynamicIotaOp> {
  using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicIotaOp iota,
                                PatternRewriter& rewriter) const override {
    auto result_ty = iota.getType().cast<ShapedType>();
    if (!result_ty.hasStaticShape()) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<IotaOp>(iota, result_ty, iota.iota_dimension());
    return success();
  }
};

// Dynamic iota operations across multiple dimensions can be reduced to an
// iota and a ranked broadcast.
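// This mirrors the IotaBroadcast pattern above, except that the size of the
// iota dimension is sliced out of the dynamic output shape (illustrative
// summary of the rewrite below).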
struct DynamicIotaBroadcast : public OpRewritePattern<DynamicIotaOp> {
  using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicIotaOp iota,
                                PatternRewriter& rewriter) const override {
    auto result_ty = iota.getType().cast<ShapedType>();
    if (!result_ty.hasRank() || result_ty.getRank() < 2) {
      return failure();
    }

    auto iota_dimension = iota.iota_dimension();
    auto iota_dimension_int = iota_dimension;

    auto converted_shape = rewriter.create<IndexCastOp>(
        iota.getLoc(),
        RankedTensorType::get(
            iota.output_shape().getType().cast<ShapedType>().getShape(),
            rewriter.getI64Type()),
        iota.output_shape());

    auto sliced_shape = rewriter.create<SliceOp>(
        iota.getLoc(), converted_shape,
        GetI64ElementsAttr(iota_dimension_int, &rewriter),
        GetI64ElementsAttr(iota_dimension_int + 1, &rewriter),
        GetI64ElementsAttr(1, &rewriter));

    auto converted_sliced_shape = rewriter.create<IndexCastOp>(
        iota.getLoc(),
        RankedTensorType::get(
            {1},
            iota.output_shape().getType().cast<ShapedType>().getElementType()),
        sliced_shape);

    auto iota_type = RankedTensorType::get(
        {result_ty.getDimSize(iota_dimension_int)}, result_ty.getElementType());

    auto new_iota = rewriter.create<DynamicIotaOp>(
        iota.getLoc(), iota_type, converted_sliced_shape,
        rewriter.getI64IntegerAttr(0));

    auto broadcast_attr = DenseIntElementsAttr::get(
        RankedTensorType::get({1}, rewriter.getIntegerType(64)),
        {iota_dimension});
    rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
        iota, result_ty, new_iota, iota.output_shape(), broadcast_attr);
    return success();
  }
};

}  // namespace

void DynamicIotaOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<DynamicIotaIsStatic>(context);
  results.insert<DynamicIotaBroadcast>(context);
}

static Value castToIndexTensor(OpBuilder& builder, Location loc,
                               Value shape_op) {
  ShapedType result_ty = shape::getExtentTensorType(
      builder.getContext(),
      shape_op.getType().cast<ShapedType>().getDimSize(0));
  if (shape_op.getType() == result_ty) return shape_op;  // Nothing to do.
  return builder.create<IndexCastOp>(loc, shape_op, result_ty);
}

LogicalResult DynamicIotaOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  DynamicIotaOp::Adaptor adaptor(operands);
  reifiedReturnShapes.push_back(
      castToIndexTensor(builder, getLoc(), adaptor.output_shape()));
  return success();
}

//===----------------------------------------------------------------------===//
// DynamicUpdateSliceOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(DynamicUpdateSliceOp op) {
  OperandRange indices = op.start_indices();
  if (indices.size() <= 1) return success();

  // Note: start_indices is constrained to Variadic<HLO_ScalarIntTensor>, so it
  // is OK to cast indices to ShapedType here.
  auto idx_tensor = indices.take_front().front().getType().cast<ShapedType>();
  Type first_elem_ty = idx_tensor.getElementType();
  Type elem_ty;

  for (auto idx : llvm::drop_begin(indices, 1)) {
    idx_tensor = idx.getType().cast<ShapedType>();
    elem_ty = idx_tensor.getElementType();

    if (first_elem_ty != elem_ty) {
      return op.emitOpError() << "start indices must have same element type "
                                 "(encountered mismatch: "
                              << first_elem_ty << " vs " << elem_ty << ")";
    }
  }
  return success();
}

OpFoldResult DynamicUpdateSliceOp::fold(ArrayRef<Attribute> operands) {
  auto operand_shape = this->operand().getType().cast<RankedTensorType>();
  auto update_shape = this->update().getType().cast<RankedTensorType>();

  if (operand_shape != update_shape || !operand_shape.hasStaticShape()) {
    return {};
  }

  // Ensure that the indices are all constant zeros; a full-shape update at
  // offset zero simply replaces the operand. For non-constant indices the op
  // does not fold, so that incorrect user input is not silently hidden.
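  // For example (illustrative): if operand and update are both
  // tensor<4x4xf32> and every start index is mhlo.constant dense<0>, the op
  // folds to the update value.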
  for (Value index : this->start_indices()) {
    DenseIntElementsAttr de_attr;
    if (!matchPattern(index, m_Constant(&de_attr))) return {};
    int start_val = de_attr.getSplatValue<IntegerAttr>().getInt();
    if (start_val != 0) return {};
  }
  return this->update();
}

//===----------------------------------------------------------------------===//
// AbsOp
//===----------------------------------------------------------------------===//

LogicalResult AbsOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  auto operand_ty = (*operands.begin()).getType().cast<ShapedType>();
  Type element_ty = operand_ty.getElementType();
  if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
    element_ty = complex_ty.getElementType();
  }

  Type result_ty;
  if (operand_ty.hasRank()) {
    result_ty = RankedTensorType::get(operand_ty.getShape(), element_ty);
  } else {
    result_ty = UnrankedTensorType::get(element_ty);
  }
  inferredReturnTypes.push_back(result_ty);
  return success();
}

//===----------------------------------------------------------------------===//
// CollectivePermuteOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(CollectivePermuteOp op) {
  return mlir::hlo::VerifyCollectivePermuteSourceTargetPairs(
      op, op.source_target_pairs());
}

//===----------------------------------------------------------------------===//
// ConvertOp
//===----------------------------------------------------------------------===//

void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand,
                      Type result_element_ty) {
  Type result_ty;
  Type operand_ty = operand.getType();
  if (auto ranked_ty = operand_ty.dyn_cast<RankedTensorType>()) {
    result_ty = RankedTensorType::get(ranked_ty.getShape(), result_element_ty);
  } else {
    result_ty = UnrankedTensorType::get(result_element_ty);
  }
  build(builder, result, result_ty, operand);
}

OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
  auto operand_ty = getOperand().getType().cast<TensorType>();
  auto result_ty = getResult().getType().cast<TensorType>();
  if (operand_ty == result_ty) return getOperand();

  // If the result has a non-static shape, a convert op is still necessary to
  // go from a static shape to a non-static shape.
  if (!result_ty.hasStaticShape()) return {};

  // TODO(hinsu): Handle unsigned types.
  if (operand_ty.getElementType().isUnsignedInteger() ||
      result_ty.getElementType().isUnsignedInteger()) {
    return {};
  }

  // If the operand is constant, we can do the conversion now.
  if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
    return hlo::ConvertElementsAttr(elementsAttr,
                                    getElementTypeOrSelf(getResult()));
  }

  return {};
}

void ConvertOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                            MLIRContext* context) {
  results.insert<EliminateIdentityConvert>(context);
}

//===----------------------------------------------------------------------===//
// DequantizeOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(DequantizeOp op) {
  auto input_type = op.input().getType().dyn_cast<ShapedType>();
  auto output_type = op.output().getType().dyn_cast<ShapedType>();
  if (!input_type || !output_type) {
    return op.emitError() << "requires ranked input and output types.";
  }
  auto input_shape = input_type.getShape();
  auto output_shape = output_type.getShape().vec();
  if (op.transpose_output()) {
    std::reverse(output_shape.begin(), output_shape.end());
  }

  // Check that the input and output ranks match and that all dimensions but
  // the last are equal.
  if (input_shape.size() != output_shape.size() ||
      !std::equal(input_shape.begin(),
                  std::next(input_shape.begin(), input_shape.size() - 1),
                  output_shape.begin())) {
    return op.emitError() << "mismatched dimensions.";
  }

  // Check that the last dimension of the output is 2x or 4x that of the
  // input, depending on whether the unpacked values are 16 or 8 bits wide.
  int input_last_dim = *input_shape.rbegin();
  int output_last_dim = *output_shape.rbegin();
  int scale_factor = op.is_16bits() ? 2 : 4;
  if (output_last_dim != scale_factor * input_last_dim) {
    return op.emitError() << "last dimension of output should be "
                          << scale_factor << "x of the input.";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// GetTupleElementOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(GetTupleElementOp op) {
  auto indexVal = op.index();
  auto operandType = op.getOperand().getType().cast<TupleType>();
  if (indexVal >= operandType.size()) {
    return op.emitOpError(
        llvm::formatv("index {0} is out of bounds of operand with size {1}",
                      indexVal, operandType.size()));
  }

  auto expectedType = operandType.getType(indexVal);
  if (op.getType() != expectedType) {
    return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}",
                                        op.getType(), expectedType));
  }
  return success();
}

OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
  if (auto tuple_op = getOperand().getDefiningOp<mhlo::TupleOp>()) {
    return tuple_op.getOperand(index());
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TupleOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(TupleOp op) {
  auto opType = op.getType().dyn_cast<TupleType>();
  if (!opType) return op.emitOpError("tuple op with non-tuple result");
  if (op.getNumOperands() != opType.size())
    return op.emitOpError(
        "number of operands to tuple expected to match number of types in "
        "resultant tuple type");
  for (auto it : llvm::enumerate(
           llvm::zip_first(op.getOperandTypes(), opType.getTypes()))) {
    if (std::get<0>(it.value()) != std::get<1>(it.value()))
      return op.emitOpError("has return type mismatch at ")
             << it.index() << "th value (" << std::get<0>(it.value())
             << " != " << std::get<1>(it.value()) << ")";
  }
  return success();
}

namespace {

// Pattern for unpacking and repacking the same tuple.
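// For example (illustrative IR):
//
//   %0 = "mhlo.get_tuple_element"(%t) {index = 0 : i32}
//   %1 = "mhlo.get_tuple_element"(%t) {index = 1 : i32}
//   %2 = "mhlo.tuple"(%0, %1)
//
// folds to %t when the repacked tuple has the same type as %t.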
struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> {
  using OpRewritePattern<TupleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TupleOp op,
                                PatternRewriter& rewriter) const override {
    if (op.val().empty()) return failure();

    Value first_element = op.val().front();
    auto first_element_op = first_element.getDefiningOp<GetTupleElementOp>();
    if (!first_element_op || first_element_op.indexAttr().getInt() != 0)
      return failure();

    Value tuple_predecessor = first_element_op.getOperand();
    if (tuple_predecessor.getType() != op.getType()) return failure();

    for (auto element_and_idx : llvm::enumerate(op.val().drop_front(1))) {
      auto element_op =
          element_and_idx.value().getDefiningOp<GetTupleElementOp>();
      if (!element_op ||
          element_op.indexAttr().getInt() != element_and_idx.index() + 1 ||
          element_op.getOperand() != tuple_predecessor)
        return failure();
    }

    rewriter.replaceOp(op, tuple_predecessor);
    return success();
  }
};

}  // namespace

void TupleOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                          MLIRContext* context) {
  results.insert<UnpackRepackSameTuple>(context);
}

//===----------------------------------------------------------------------===//
// AllToAllOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(AllToAllOp op) {
  // If the operand is ranked, the size of the split dimension must be a
  // multiple of the split count.
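  // For example (illustrative): splitting dimension 0 of a tensor<8x4xf32>
  // with split_count = 4 is valid (8 % 4 == 0), whereas split_count = 3 would
  // be rejected.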
  auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
  if (!type) return success();
  auto split_dim_size = type.getDimSize(op.split_dimension());
  auto split_count = op.split_count();
  if (split_dim_size % split_count != 0) {
    return op.emitError() << "split dimension has size " << split_dim_size
                          << ", expected to be a multiple of split_count "
                          << split_count;
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AllGatherOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(AllGatherOp op) {
  // If operand and result are both ranked, then the size of the gather
  // dimension in the result should be a multiple of the size of the gather
  // dimension in the operand.
  auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
  auto resultType = op.getType().dyn_cast<RankedTensorType>();
  uint64_t allGatherDimIndex = op.all_gather_dim();
  if (!operandType || !resultType ||
      operandType.isDynamicDim(allGatherDimIndex) ||
      resultType.isDynamicDim(allGatherDimIndex))
    return success();
  if (operandType.getDimSize(allGatherDimIndex) == 0)
    return op.emitOpError() << "operand gather dimension cannot be zero.";
  if ((resultType.getDimSize(allGatherDimIndex) %
       operandType.getDimSize(allGatherDimIndex)) != 0)
    return op.emitOpError()
           << "result gather dimension has size "
           << resultType.getDimSize(allGatherDimIndex)
           << ", expected to be a multiple of operand gather dimension size "
           << operandType.getDimSize(allGatherDimIndex);

  return success();
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

// TODO(b/129012527) These should be expressed as type constraints.
static LogicalResult Verify(BroadcastOp op) {
  auto sizes = op.broadcast_sizes();
  auto sizesType = sizes.getType();
  auto sizesRank = sizesType.getRank();
  if (sizesRank != 1) {
    return op.emitOpError(llvm::formatv(
        "broadcast_sizes has rank {0} instead of rank 1", sizesRank));
  }

  auto resultType = op.getResult().getType().cast<RankedTensorType>();
  auto resultRank = resultType.getRank();
  auto operandType = op.operand().getType().cast<RankedTensorType>();
  auto operandRank = operandType.getRank();
  auto sizesSize = sizesType.getNumElements();
  auto expectedRank = operandRank + sizesSize;

  if (resultRank != expectedRank) {
    return op.emitOpError(
        llvm::formatv("result rank ({0}) does not match operand rank "
                      "({1}) plus size of broadcast_sizes ({2})",
                      resultRank, operandRank, sizesSize));
  }

  llvm::SmallVector<int64_t, 10> expectedShape(sizes.getValues<int64_t>());

  auto operandShape = operandType.getShape();
  expectedShape.insert(expectedShape.end(), operandShape.begin(),
                       operandShape.end());

  auto resultShape = resultType.getShape();
  if (resultShape != llvm::makeArrayRef(expectedShape)) {
    return op.emitOpError(llvm::formatv(
        "result has shape [{0}] instead of [{1}]",
        llvm::make_range(resultShape.begin(), resultShape.end()),
        llvm::make_range(expectedShape.begin(), expectedShape.end())));
  }

  return success();
}

//===----------------------------------------------------------------------===//
// BroadcastInDimOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(BroadcastInDimOp op) {
  auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
  if (!operandType) {
    // The following verification checks all depend on knowing the rank of
    // the operand. Bail out now if we don't know the rank of the operand.
    return success();
  }

  auto operandRank = operandType.getRank();
  if (!op.broadcast_dimensions()) {
    if (operandRank == 0) {
      return success();
    }
    return op.emitOpError(
        llvm::formatv("broadcast_dimensions is absent, but required because "
                      "operand has non-zero rank ({0})",
                      operandRank));
  }

  auto dimensions = op.broadcast_dimensions();
  auto dimensionsType = op.broadcast_dimensions().getType();
  auto dimensionsRank = dimensionsType.getRank();
  if (dimensionsRank != 1) {
    return op.emitOpError(llvm::formatv(
        "broadcast_dimensions has rank {0} instead of rank 1", dimensionsRank));
  }

  auto dimensionsSize = dimensionsType.getNumElements();
  if (dimensionsSize != operandRank) {
    return op.emitOpError(llvm::formatv(
        "broadcast_dimensions size ({0}) does not match operand rank ({1})",
        dimensionsSize, operandRank));
  }

  auto resultType = op.getResult().getType().cast<RankedTensorType>();
  auto resultRank = resultType.getRank();
  if (resultRank < operandRank) {
    return op.emitOpError(
        llvm::formatv("result rank ({0}) is less than operand rank ({1})",
                      resultRank, operandRank));
  }

  for (int i = 0; i != dimensionsSize; ++i) {
    auto dimIndex = dimensions.getValue<int64_t>(i);
    if (dimIndex >= resultRank) {
      return op.emitOpError(
          llvm::formatv("broadcast_dimensions contains invalid value {0} for "
                        "result with rank {1}",
                        dimIndex, resultRank));
    }

    if (!operandType.isDynamicDim(i)) {
      auto dimSize = operandType.getDimSize(i);
      auto resultDimSize = resultType.getDimSize(dimIndex);
      if (dimSize != 1 && dimSize != resultDimSize) {
        return op.emitOpError(
            llvm::formatv("size of operand dimension {0} ({1}) is not equal to "
                          "1 or size of result dimension {2} ({3})",
                          i, dimSize, dimIndex, resultDimSize));
      }
    }
  }

  return success();
}

OpFoldResult BroadcastInDimOp::fold(ArrayRef<Attribute> attrs) {
  auto type = getType().cast<RankedTensorType>();
  if (type == getOperand().getType()) {
    auto broadcast_values = broadcast_dimensions().getValues<int64_t>();
    if (!std::equal(broadcast_values.begin(), broadcast_values.end(),
                    llvm::seq<int64_t>(0, type.getRank()).begin())) {
      return {};
    }
    return getOperand();
  }

  // Constant fold when the operand is a splat tensor attribute.
  if (!attrs[0] || !type.hasStaticShape()) return {};
  auto splatOperandAttr = attrs[0].dyn_cast<SplatElementsAttr>();
  if (!splatOperandAttr) return {};
  // MLIR core bug (https://bugs.llvm.org/show_bug.cgi?id=46588): dense element
  // attribute iterator not implemented for complex element types.
  if (type.getElementType().isa<ComplexType>()) return {};
  return SplatElementsAttr::get(type, splatOperandAttr.getSplatValue());
}

//===----------------------------------------------------------------------===//
// DynamicBroadcastInDimOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(DynamicBroadcastInDimOp op) {
  auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
  auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();

  // If either the operand or the result is unranked, there is very little
  // to verify statically.
  if (!operandType || !resultType) {
    return success();
  }

  auto outputDimensionsType =
      op.output_dimensions().getType().cast<RankedTensorType>();
  auto outputDimensionsSize = outputDimensionsType.getDimSize(0);
  auto operandRank = operandType.getRank();
  auto resultRank = resultType.getRank();

  // Verify broadcast_dimensions.
  auto bcastDimensions = op.broadcast_dimensions();
  auto bcastDimensionsType = op.broadcast_dimensions().getType();
  auto bcastDimensionsRank = bcastDimensionsType.getRank();
  // TODO(laurenzo): Update the BroadcastDimAttr to constrain its rank to 1.
  if (bcastDimensionsRank != 1) {
    return op.emitOpError(
        llvm::formatv("broadcast_dimensions has rank {0} instead of rank 1",
                      bcastDimensionsRank));
  }

  auto bcastDimensionsSize = bcastDimensionsType.getNumElements();
  if (bcastDimensionsSize != operandRank) {
    return op.emitOpError(llvm::formatv(
        "broadcast_dimensions size ({0}) does not match operand rank ({1})",
        bcastDimensionsSize, operandRank));
  }

  if (resultRank < operandRank) {
    return op.emitOpError(
        llvm::formatv("result rank ({0}) is less than operand rank ({1})",
                      resultRank, operandRank));
  }

  for (int i = 0; i != bcastDimensionsSize; ++i) {
    auto dimIndex = bcastDimensions.getValue<int64_t>(i);
    if (dimIndex >= resultRank) {
      return op.emitOpError(
          llvm::formatv("broadcast_dimensions contains invalid value {0} for "
                        "result with rank {1}",
                        dimIndex, resultRank));
    }

    auto dimSize = operandType.getDimSize(i);
    auto resultDimSize = resultType.getDimSize(dimIndex);
    // Note: verifyCompatibleShape doesn't consider size-1 broadcasting, so we
    // add a manual check for this.
    if (dimSize != 1 && failed(verifyCompatibleShape(dimSize, resultDimSize))) {
      return op.emitOpError(
          llvm::formatv("size of operand dimension {0} ({1}) is not compatible "
                        "with size of result dimension {2} ({3})",
                        i, dimSize, dimIndex, resultDimSize));
    }
  }

  if (outputDimensionsSize != resultRank) {
    return op.emitOpError(
        llvm::formatv("result rank ({0}) is not equal to number of output "
                      "dimensions ({1})",
                      resultRank, outputDimensionsSize));
  }

  return success();
}

namespace {
// If a DynamicBroadcastInDimOp is not actually dynamic, use an ordinary
// BroadcastInDimOp.
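// For example (illustrative): a dynamic_broadcast_in_dim whose result type is
// the static tensor<3x4xf32> no longer needs its output_dimensions operand
// and can become a plain broadcast_in_dim with the same broadcast_dimensions.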
class DynamicBroadcastInDimOpNotActuallyDynamic
    : public OpRewritePattern<DynamicBroadcastInDimOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op,
                                PatternRewriter& rewriter) const override {
    auto type = op.getType().dyn_cast<RankedTensorType>();
    if (!type || !type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op, "requires static shape");
    }
    rewriter.replaceOpWithNewOp<BroadcastInDimOp>(
        op, op.getType(), op.operand(), op.broadcast_dimensions());
    return success();
  }
};

class ChainedDynamicBroadcastInDimCanonicalization
    : public OpRewritePattern<DynamicBroadcastInDimOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast,
                                PatternRewriter& rewriter) const override {
    auto preceding_bcast =
        bcast.operand().getDefiningOp<DynamicBroadcastInDimOp>();
    if (!preceding_bcast) return failure();

    // Compose broadcast dimensions.
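    // For example (illustrative): if the preceding broadcast maps operand
    // dimension 0 to intermediate dimension 1 (dims = [1]) and this broadcast
    // maps intermediate dimension 1 to result dimension 2 (dims = [0, 2]),
    // the composed mapping sends operand dimension 0 to result dimension 2.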
    DenseIntElementsAttr preceding_bcast_dims =
        preceding_bcast.broadcast_dimensions();
    DenseIntElementsAttr bcast_dims = bcast.broadcast_dimensions();
    SmallVector<APInt, 4> composition;
    for (APInt preceding_dim : preceding_bcast_dims) {
      auto composed_dim = bcast_dims.getValue({preceding_dim.getZExtValue()})
                              .cast<IntegerAttr>();
      composition.push_back(composed_dim.getValue());
    }
    auto composed_bcast_dims =
        DenseIntElementsAttr::get(preceding_bcast_dims.getType(), composition);

    rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
        bcast, bcast.getType(), preceding_bcast.operand(),
        bcast.output_dimensions(), composed_bcast_dims);
    return success();
  }
};
}  // namespace

void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<ChainedDynamicBroadcastInDimCanonicalization,
                 DynamicBroadcastInDimOpNotActuallyDynamic,
                 DynamicBroadcastToOwnShape_1, DynamicBroadcastToOwnShape_2,
                 DynamicBroadcastToOwnShape_3, DynamicBroadcastToOwnShape_4>(
      context);
}

LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  DynamicBroadcastInDimOp::Adaptor adaptor(operands);
  reifiedReturnShapes.push_back(
      castToIndexTensor(builder, getLoc(), adaptor.output_dimensions()));
  return success();
}

//===----------------------------------------------------------------------===//
// ClampOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(ClampOp op) {
  auto operandType = op.operand().getType().cast<RankedTensorType>();
  auto operandShape = operandType.getShape();
  auto minType = op.min().getType().cast<RankedTensorType>();

  auto minShape = minType.getShape();
  if (minShape != operandShape && minType.getRank() != 0) {
    return op.emitOpError(llvm::formatv(
        "min shape [{0}] is not scalar and does not match operand shape [{1}]",
        llvm::make_range(minShape.begin(), minShape.end()),
        llvm::make_range(operandShape.begin(), operandShape.end())));
  }

  auto maxType = op.max().getType().cast<RankedTensorType>();
  auto maxShape = maxType.getShape();
  if (maxShape != operandShape && maxType.getRank() != 0) {
    return op.emitOpError(llvm::formatv(
        "max shape [{0}] is not scalar and does not match operand shape [{1}]",
        llvm::make_range(maxShape.begin(), maxShape.end()),
        llvm::make_range(operandShape.begin(), operandShape.end())));
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ComplexOp
//===----------------------------------------------------------------------===//

LogicalResult ComplexOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  auto type = operands[0].getType();
  auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
  Type result_ty;
  if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
    result_ty = RankedTensorType::get(ranked_type.getShape(), element_ty);
  } else if (type.isa<UnrankedTensorType>()) {
    result_ty = UnrankedTensorType::get(element_ty);
  } else {
    result_ty = element_ty;
  }
  inferredReturnTypes.push_back(result_ty);
  return success();
}

OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
  auto real_op = getOperand(0).getDefiningOp<mhlo::RealOp>();
  auto imag_op = getOperand(1).getDefiningOp<mhlo::ImagOp>();
  if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) {
    return real_op.getOperand();
  }

  return {};
}

1297 //===----------------------------------------------------------------------===//
1298 // ImagOp
1299 //===----------------------------------------------------------------------===//
1300
1301 namespace {
CreateRealType(Type type)1302 Type CreateRealType(Type type) {
1303 auto element_ty = getElementTypeOrSelf(type);
1304 if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
1305 element_ty = complex_ty.getElementType();
1306 }
1307
1308 if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
1309 return RankedTensorType::get(ranked_type.getShape(), element_ty);
1310 } else if (type.dyn_cast<UnrankedTensorType>()) {
1311 return UnrankedTensorType::get(element_ty);
1312 }
1313
1314 return element_ty;
1315 }
1316 } // namespace
1317
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)1318 LogicalResult ImagOp::inferReturnTypes(
1319 MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
1320 RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
1321 inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
1322 return success();
1323 }
1324
fold(ArrayRef<Attribute> operands)1325 OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
1326 if (auto complex_op = getOperand().getDefiningOp<mhlo::ComplexOp>()) {
1327 return complex_op.getOperand(1);
1328 }
1329
1330 return {};
1331 }

//===----------------------------------------------------------------------===//
// IsFiniteOp
//===----------------------------------------------------------------------===//

TensorType getSameShapeTensorType(TensorType tensor_type, Type element_type) {
  if (auto ranked_tensor_ty = tensor_type.dyn_cast<RankedTensorType>()) {
    return RankedTensorType::get(ranked_tensor_ty.getShape(), element_type);
  }
  if (auto unranked_tensor_ty = tensor_type.dyn_cast<UnrankedTensorType>()) {
    return UnrankedTensorType::get(element_type);
  }
  llvm_unreachable("unhandled type");
}

LogicalResult IsFiniteOp::inferReturnTypes(
    MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  auto arg_ty = operands.front().getType().cast<TensorType>();
  Builder b(ctx);
  inferredReturnTypes.push_back(getSameShapeTensorType(arg_ty, b.getI1Type()));
  return success();
}

//===----------------------------------------------------------------------===//
// RealOp
//===----------------------------------------------------------------------===//

LogicalResult RealOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
  return success();
}

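// Folds real(complex(a, b)) to a, the first operand of the defining
// mhlo.complex op.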
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
  if (auto complex_op = getOperand().getDefiningOp<mhlo::ComplexOp>()) {
    return complex_op.getOperand(0);
  }

  return {};
}

//===----------------------------------------------------------------------===//
// ConcatenateOp
//===----------------------------------------------------------------------===//

namespace {
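// Removes operands whose size along the concatenation dimension is zero.
// Schematic IR (illustrative only, not taken from a test):
//   "mhlo.concatenate"(%a, %empty) {dimension = 0 : i64}
//       : (tensor<3xf32>, tensor<0xf32>) -> tensor<3xf32>
// is rewritten to concatenate only %a.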
class ConcatenateOperandRemoval : public OpRewritePattern<ConcatenateOp> {
 public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ConcatenateOp op,
                                PatternRewriter& rewriter) const override {
    auto axis = op.dimension();
    llvm::SmallVector<Value, 6> new_operands;
    for (auto operand : op.getOperands()) {
      auto ty = operand.getType().cast<ShapedType>();
      if (ty.getDimSize(axis) != 0) {
        new_operands.push_back(operand);
      }
    }

    if (!new_operands.empty() && new_operands.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<ConcatenateOp>(op, op.getResult().getType(),
                                                 new_operands, op.dimension());
      return success();
    }

    return failure();
  }
};
}  // namespace

LogicalResult ConcatenateOp::inferReturnTypes(
    MLIRContext*, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type>& inferredReturnTypes) {
  if (operands.empty()) {
    return failure();
  }

  auto dimension_attr = attributes.get("dimension").cast<IntegerAttr>();
  auto dimension = dimension_attr.getInt();

  auto first_type = (*operands.begin()).getType().cast<ShapedType>();
  auto out_element = first_type.getElementType();

  for (auto operand : operands.getTypes()) {
    auto element_type = getElementTypeOrSelf(operand);
    if (element_type != out_element) {
      return failure();
    }
  }

  // Find the first ranked input to determine the output rank.
  for (auto type : operands.getTypes()) {
    auto shaped_type = type.cast<ShapedType>();
    if (shaped_type.hasRank()) {
      first_type = shaped_type;
      break;
    }
  }

  // If all inputs are unranked, the result must be unranked.
  if (!first_type.hasRank()) {
    inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
    return success();
  }

  if (first_type.getRank() == 0)
    return emitOptionalError(location, "rank-0 values cannot be concatenated");

  auto out_shape = llvm::to_vector<6>(first_type.getShape());

  // Determine what the non-concatenate dimensions should be.
  for (auto type : operands.getTypes()) {
    auto shaped_ty = type.cast<ShapedType>();
    if (!shaped_ty.hasRank()) {
      continue;
    }

    for (auto it : llvm::enumerate(shaped_ty.getShape())) {
      // If the output dimension is still dynamic, take the (possibly static)
      // size from this operand.
      if (ShapedType::isDynamic(out_shape[it.index()])) {
        out_shape[it.index()] = it.value();
      }
    }
  }

  out_shape[dimension] = 0;

  for (auto operand : operands.getTypes()) {
    auto type = operand.cast<ShapedType>();
    if (!type.hasRank()) {
      inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
      return success();
    }

    // If any operand's concat dimension is dynamic, the output dimension is
    // dynamic as well.
    auto dim = type.getShape()[dimension];
    if (ShapedType::isDynamic(dim)) {
      out_shape[dimension] = ShapedType::kDynamicSize;
      break;
    }

    out_shape[dimension] += dim;
  }

  inferredReturnTypes.push_back(RankedTensorType::get(out_shape, out_element));

  return success();
}

void ConcatenateOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<ConcatenateOperandRemoval>(context);
}

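// Folds a concatenate of constants by copying elements. Treating the result
// as [top_size, bottom_size] around the concat axis, a worked example
// (illustrative only): concatenating tensor<2x1xi32> and tensor<2x2xi32>
// along axis 1 gives top_size = 2 and per-operand bottom sizes of 1 and 2,
// so each iteration appends one element of the first operand followed by two
// elements of the second, once per row.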
template <typename T>
static Attribute foldConcatenateHelper(ConcatenateOp* op,
                                       ArrayRef<Attribute> operands) {
  auto axis = op->dimension();
  auto type = op->getType().cast<ShapedType>();

  SmallVector<T, 6> values;
  auto shape = type.getShape();

  size_t top_size = 1;
  for (int i = 0, e = axis; i < e; i++) {
    top_size = top_size * shape[i];
  }

  for (size_t i = 0; i < top_size; i++) {
    for (auto operand : operands) {
      DenseElementsAttr attr = operand.cast<DenseElementsAttr>();
      size_t bottom_size = attr.getNumElements() / top_size;
      auto iter = attr.getValues<T>().begin() + i * bottom_size;
      values.append(iter, iter + bottom_size);
    }
  }

  return DenseElementsAttr::get(type, values);
}

static Attribute foldConcatenate(ConcatenateOp* op,
                                 ArrayRef<Attribute> operands) {
  for (auto operand : operands) {
    if (!operand) return {};
  }

  auto type = op->getResult().getType().cast<ShapedType>();
  auto etype = type.getElementType();
  if (etype.isa<IntegerType>()) {
    return foldConcatenateHelper<APInt>(op, operands);
  }

  if (etype.isa<FloatType>()) {
    return foldConcatenateHelper<APFloat>(op, operands);
  }

  return {};
}

OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
  if (getNumOperands() == 1) return getOperand(0);

  ShapedType type = getResult().getType().cast<ShapedType>();
  if (!type.hasStaticShape()) return {};

  auto axis = dimension();
  if (auto attr = foldConcatenate(this, operands)) {
    return attr;
  }

  // If every operand is empty along the concat axis, the result is a constant
  // empty tensor.
  for (auto operand : getOperands()) {
    auto ty = operand.getType().cast<ShapedType>();
    if (ty.getDimSize(axis) != 0) {
      return {};
    }
  }

  return DenseElementsAttr::get(type, ArrayRef<Attribute>());
}

static LogicalResult Verify(ConcatenateOp op) {
  Type element_type = getElementTypeOrSelf(op.getOperand(0).getType());
  RankedTensorType first_ranked_type;
  int num_operands = op.getNumOperands();
  for (int i = 0; i < num_operands; i++) {
    auto second_type = op.getOperand(i).getType().dyn_cast<ShapedType>();
    if (second_type.getElementType() != element_type) {
      return op.emitOpError(
          llvm::formatv("operands (0) and ({0}) do not match element type", i));
    }

    if (!second_type.hasRank()) {
      continue;
    }

    if (!first_ranked_type) {
      first_ranked_type = second_type.cast<RankedTensorType>();
      continue;
    }

    if (first_ranked_type.getRank() != second_type.getRank()) {
      return op.emitOpError(
          llvm::formatv("operands (0) and ({0}) do not match rank", i));
    }

    auto first_shape = first_ranked_type.getShape();
    auto second_shape = second_type.getShape();
    for (int d = 0; d < first_ranked_type.getRank(); ++d) {
      if (first_shape[d] != second_shape[d] && d != op.dimension()) {
        return op.emitOpError(llvm::formatv(
            "operands (0) and ({0}) non-concat dimensions do not match "
            "({1}) != ({2})",
            i, llvm::make_range(first_shape.begin(), first_shape.end()),
            llvm::make_range(second_shape.begin(), second_shape.end())));
      }
    }
  }
  return success();
}

LogicalResult ConcatenateOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  ConcatenateOp::Adaptor adaptor(operands);
  auto inputs = adaptor.val();

  auto operand_type = inputs[0].getType().dyn_cast<RankedTensorType>();
  // Unranked operands are not supported at the moment.
  if (!operand_type) return failure();

  Location loc = this->getLoc();
  Type shape_scalar_type = builder.getIndexType();
  auto to_shape_scalar_type = [&](Value v) {
    return MaybeCastTo(builder, loc, v, shape_scalar_type);
  };

  SmallVector<SmallVector<Value, 4>, 4> all_shape_values;
  for (size_t input_id = 0; input_id < inputs.size(); ++input_id) {
    Value operand = inputs[input_id];
    auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
    if (!operand_type) return failure();

    SmallVector<Value, 4> shape_vals;
    for (const auto& element : llvm::enumerate(operand_type.getShape())) {
      Value value_dim = to_shape_scalar_type(
          builder.create<tensor::DimOp>(loc, operand, element.index()));
      shape_vals.push_back(value_dim);
    }
    all_shape_values.emplace_back(std::move(shape_vals));
  }

  int axis = this->dimension();
  auto& shape_values = all_shape_values[0];
  for (size_t vec_id = 1; vec_id < all_shape_values.size(); ++vec_id) {
    auto& other_shape_values = all_shape_values[vec_id];
    if (other_shape_values.size() != shape_values.size()) {
      this->emitOpError()
          << "Concatenate expects all operands to be of the same rank";
      return failure();
    }
    shape_values[axis] = builder.create<AddIOp>(loc, shape_values[axis],
                                                other_shape_values[axis]);
  }

  Value output_shape = builder.create<tensor::FromElementsOp>(
      loc, shape_scalar_type, shape_values);
  reifiedReturnShapes.push_back(output_shape);

  return success();
}

//===----------------------------------------------------------------------===//
// DynamicReshapeOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(DynamicReshapeOp op) {
  auto result_type = op.result().getType().dyn_cast<RankedTensorType>();
  auto output_shape_type =
      op.output_shape().getType().dyn_cast<RankedTensorType>();
  if (result_type && output_shape_type && output_shape_type.hasStaticShape() &&
      output_shape_type.getDimSize(0) != result_type.getRank()) {
    return op.emitError() << "output should have a rank equal to the number of "
                             "elements in output_shape";
  }
  return success();
}

LogicalResult DynamicReshapeOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  DynamicReshapeOp::Adaptor adaptor(operands);
  reifiedReturnShapes.push_back(
      castToIndexTensor(builder, getLoc(), adaptor.output_shape()));
  return success();
}

namespace {
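// Rewrites a dynamic_reshape whose result type is fully static into a plain
// reshape. Schematic IR (illustrative only, not taken from a test):
//   %0 = "mhlo.dynamic_reshape"(%arg, %shape)
//       : (tensor<?xf32>, tensor<2xindex>) -> tensor<2x3xf32>
// becomes
//   %0 = "mhlo.reshape"(%arg) : (tensor<?xf32>) -> tensor<2x3xf32>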
class DynamicReshapeOpNotActuallyDynamic
    : public OpRewritePattern<DynamicReshapeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(DynamicReshapeOp op,
                                PatternRewriter& rewriter) const override {
    auto type = op.result().getType().dyn_cast<RankedTensorType>();
    if (!type || !type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op, "requires static shape tensor");
    }
    rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), op.operand());
    return success();
  }
};

// Canonicalizes
// %0 = some_op(%tensor)
// %1 = "mhlo.dynamic_reshape"(%0, %shape)
//      (tensor<?xT>, tensor<1xindex>) -> tensor<?xT>
// ... uses of %1.
//
// into
//
// ... uses of %0.
// This canonicalization is only correct if %shape always matches the runtime
// shape of %0, i.e. the reshape is a no-op on well-formed inputs!
// TODO(b/178779691): Use a more sophisticated canonicalization that preserves
// errors in input, and still allows us to get rid of redundant reshapes.
class RemoveRedundantRank1DynamicReshape
    : public OpRewritePattern<DynamicReshapeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(DynamicReshapeOp op,
                                PatternRewriter& rewriter) const override {
    auto type = op.result().getType().dyn_cast<RankedTensorType>();
    if (!type || type.getRank() != 1 || type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "requires rank 1 result with dynamic dimension");
    }
    auto operand_type = op.operand().getType().dyn_cast<RankedTensorType>();
    if (!operand_type || operand_type.getRank() != 1 ||
        operand_type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "requires rank 1 operand with dynamic dimension");
    }
    rewriter.replaceOp(op, {op.operand()});
    return success();
  }
};

// Canonicalizes
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
// %1 = same_operands_and_result_shape_op(%tensor)
// %2 = "mhlo.dynamic_reshape"(%1, %shape)
// ... uses of %2.
//
// into
//
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
// %1 = same_operands_and_result_shape_op(%tensor)
// ... uses of %1.
class DynamicReshapeOpSameShapeOpResult
    : public OpRewritePattern<DynamicReshapeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicReshapeOp op,
                                PatternRewriter& rewriter) const override {
    Operation* def_op = op.operand().getDefiningOp();
    if (!def_op ||
        !def_op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) {
      return failure();
    }
    Operation* input_def_op = def_op->getOperand(0).getDefiningOp();
    if (!input_def_op) {
      return failure();
    }
    auto reshape = dyn_cast<DynamicReshapeOp>(*input_def_op);
    if (reshape && reshape.output_shape() == op.output_shape()) {
      rewriter.replaceOp(op, {def_op->getResult(0)});
      return success();
    }
    return failure();
  }
};
}  // namespace

void DynamicReshapeOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  // clang-format off
  results.insert<
      DynamicReshapeOpNotActuallyDynamic,
      DynamicReshapeOpSameShapeOpResult,
      RemoveRedundantDynamicBroadcast,
      RemoveRedundantDynamicReshape,
      RemoveRedundantRank1DynamicReshape,
      ShapeOfDynamicReshape
    >(context);
  // clang-format on
}

//===----------------------------------------------------------------------===//
// DynamicSliceOp
//===----------------------------------------------------------------------===//

namespace {
// Canonicalizes DynamicSlice ops that can be replaced with Slice ops instead.
// This canonicalization applies when all of the `start_indices` inputs are
// compile-time constants, so they can be packed into a constant attribute.
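// Schematic IR (illustrative only, not taken from a test):
//   %start = mhlo.constant dense<1> : tensor<i64>
//   %0 = "mhlo.dynamic-slice"(%arg, %start)
//       {slice_sizes = dense<2> : tensor<1xi64>}
//       : (tensor<4xf32>, tensor<i64>) -> tensor<2xf32>
// becomes an "mhlo.slice" with start_indices = [1], limit_indices = [3],
// strides = [1].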
struct DynamicSliceToSlice : public OpRewritePattern<DynamicSliceOp> {
  using OpRewritePattern<DynamicSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicSliceOp dynamic_slice,
                                PatternRewriter& rewriter) const override {
    Value input = dynamic_slice.operand();
    auto input_tensor = input.getType().dyn_cast<RankedTensorType>();
    if (!input_tensor) return failure();

    SmallVector<int64_t, 4> temp_start_indices;
    for (Value start : dynamic_slice.start_indices()) {
      APInt val;
      if (!matchPattern(start, m_ConstantInt(&val))) {
        return failure();
      }
      temp_start_indices.push_back(val.getSExtValue());
    }

    // At this point we've determined that the start indices are all constants;
    // pack them into a single tensor.
    auto loc = dynamic_slice.getLoc();
    int64_t input_rank = input_tensor.getRank();
    auto slice_start_indices =
        GetI64ElementsAttr(temp_start_indices, &rewriter);
    DenseIntElementsAttr slice_limits = BuildSliceLimits(
        slice_start_indices, dynamic_slice.slice_sizes(), &rewriter);
    DenseIntElementsAttr slice_strides =
        GetI64ElementsAttr(SmallVector<int64_t, 4>(input_rank, 1), &rewriter);
    auto result = rewriter.create<SliceOp>(loc, input, slice_start_indices,
                                           slice_limits, slice_strides);
    rewriter.replaceOp(dynamic_slice, {result});
    return success();
  }
};

}  // namespace

void DynamicSliceOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<DynamicSliceToSlice>(context);
}

// Verifies that the number of slice sizes and the number of start indices
// match.
static LogicalResult Verify(DynamicSliceOp op) {
  int num_slice_sizes = op.slice_sizes().getNumElements();
  int num_start_indices = op.start_indices().size();
  if (num_start_indices != num_slice_sizes) {
    return op.emitOpError()
           << "has mismatched number of slice sizes (" << num_slice_sizes
           << ") and number of start indices (" << num_start_indices << ")";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// RealDynamicSliceOp
//===----------------------------------------------------------------------===//
// Verifies that operand rank matches start_indices/limit_indices/strides size.
static LogicalResult Verify(RealDynamicSliceOp op) {
  auto input_type = op.operand().getType().dyn_cast<RankedTensorType>();
  // If operand is unranked, there is very little to verify statically.
  if (!input_type) return success();
  int input_rank = input_type.getRank();

  auto start_type = op.start_indices().getType().cast<RankedTensorType>();
  auto limit_type = op.limit_indices().getType().cast<RankedTensorType>();
  auto strides_type = op.strides().getType().cast<RankedTensorType>();

  if (input_rank != start_type.getNumElements()) {
    return op.emitOpError() << "has mismatched number of operand rank ("
                            << input_rank << ") and start_indices size ("
                            << start_type.getNumElements() << ")";
  }

  if (input_rank != limit_type.getNumElements()) {
    return op.emitOpError() << "has mismatched number of operand rank ("
                            << input_rank << ") and limit_indices size ("
                            << limit_type.getNumElements() << ")";
  }

  if (input_rank != strides_type.getNumElements()) {
    return op.emitOpError()
           << "has mismatched number of operand rank (" << input_rank
           << ") and strides size (" << strides_type.getNumElements() << ")";
  }

  return success();
}

namespace {
// Canonicalizes RealDynamicSlice ops that can be replaced with Slice ops
// instead. This canonicalization applies when the `start_indices`,
// `limit_indices` and `strides` inputs are all compile-time constants.
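// Schematic IR (illustrative only, not taken from a test):
//   %0 = "mhlo.real_dynamic_slice"(%arg, %start, %limit, %stride)
//       : (tensor<4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>)
//       -> tensor<2xf32>
// with constant index operands and static shapes becomes a static
// "mhlo.slice".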
struct RealDynamicSliceIsStatic : public OpRewritePattern<RealDynamicSliceOp> {
  using OpRewritePattern<RealDynamicSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(RealDynamicSliceOp real_dynamic_slice,
                                PatternRewriter& rewriter) const override {
    Location loc = real_dynamic_slice.getLoc();
    Value input = real_dynamic_slice.operand();
    Value output = real_dynamic_slice.result();
    auto input_ty = input.getType().dyn_cast<RankedTensorType>();
    auto output_ty = output.getType().dyn_cast<RankedTensorType>();

    if (!input_ty || !output_ty || !input_ty.hasStaticShape() ||
        !output_ty.hasStaticShape()) {
      return failure();
    }

    int64_t input_rank = input_ty.getRank();

    auto start_val = real_dynamic_slice.start_indices();
    auto limit_val = real_dynamic_slice.limit_indices();
    auto stride_val = real_dynamic_slice.strides();
    auto start_op = start_val.getDefiningOp<mlir::ConstantOp>();
    auto limit_op = limit_val.getDefiningOp<mlir::ConstantOp>();
    auto stride_op = stride_val.getDefiningOp<mlir::ConstantOp>();
    if (!start_op || !limit_op || !stride_op) return failure();

    auto start_attr =
        start_op.getValue().dyn_cast_or_null<DenseIntElementsAttr>();
    auto limit_attr =
        limit_op.getValue().dyn_cast_or_null<DenseIntElementsAttr>();
    auto stride_attr =
        stride_op.getValue().dyn_cast_or_null<DenseIntElementsAttr>();
    if (!start_attr || !limit_attr || !stride_attr) return failure();

    SmallVector<int64_t, 4> temp_start_indices;
    SmallVector<int64_t, 4> temp_limit_indices;
    SmallVector<int64_t, 4> temp_stride;
    for (int64_t dim_idx = 0; dim_idx < input_rank; dim_idx++) {
      int64_t start = start_attr.getValue<IntegerAttr>(dim_idx).getInt();
      temp_start_indices.push_back(start);
      int64_t limit = limit_attr.getValue<IntegerAttr>(dim_idx).getInt();
      temp_limit_indices.push_back(limit);
      int64_t stride = stride_attr.getValue<IntegerAttr>(dim_idx).getInt();
      temp_stride.push_back(stride);
    }

    DenseIntElementsAttr slice_start_indices =
        GetI64ElementsAttr(temp_start_indices, &rewriter);
    DenseIntElementsAttr slice_limit_indices =
        GetI64ElementsAttr(temp_limit_indices, &rewriter);
    DenseIntElementsAttr slice_strides =
        GetI64ElementsAttr(temp_stride, &rewriter);
    auto result = rewriter.create<SliceOp>(loc, input, slice_start_indices,
                                           slice_limit_indices, slice_strides);
    rewriter.replaceOp(real_dynamic_slice, {result});
    return success();
  }
};
}  // namespace

void RealDynamicSliceOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<RealDynamicSliceIsStatic, RealDSliceToSlice>(context);
}

LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  RealDynamicSliceOp::Adaptor adaptor(operands);
  Value operand = adaptor.operand();
  Value start_indices = adaptor.start_indices();
  Value limit_indices = adaptor.limit_indices();
  Value strides = adaptor.strides();

  auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
  // Unranked operands are not supported at the moment.
  if (!operand_type) return failure();

  Location loc = this->getLoc();
  SmallVector<Value, 4> shape_values;
  shape_values.reserve(operand_type.getRank());
  Type shape_scalar_type =
      start_indices.getType().cast<ShapedType>().getElementType();
  Value one = builder.create<ConstantIndexOp>(loc, 1);
  one = MaybeCastTo(builder, loc, one, shape_scalar_type);
  for (const auto& element : llvm::enumerate(operand_type.getShape())) {
    Value offset = builder.create<ConstantIndexOp>(loc, element.index());
    Value value_start =
        builder.create<tensor::ExtractOp>(loc, start_indices, offset);
    Value value_limit =
        builder.create<tensor::ExtractOp>(loc, limit_indices, offset);
    Value value_stride =
        builder.create<tensor::ExtractOp>(loc, strides, offset);
    // size = (limit - start + stride - 1) / stride
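    // Worked example (illustrative): start = 1, limit = 7, stride = 2 gives
    // size = (7 - 1 + 2 - 1) / 2 = 3, i.e. elements 1, 3 and 5.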
    shape_values.push_back(builder.create<SignedDivIOp>(
        loc,
        builder.create<SubIOp>(
            loc,
            builder.create<AddIOp>(
                loc, value_stride,
                builder.create<SubIOp>(loc, value_limit, value_start)),
            one),
        value_stride));
  }

  reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
      loc, shape_scalar_type, shape_values));

  return success();
}

//===----------------------------------------------------------------------===//
// InfeedOp
//===----------------------------------------------------------------------===//

// Checks that the result type is of the form `tuple<any_type, token>`.
static LogicalResult Verify(InfeedOp op) {
  auto result_ty = op.getResult().getType().cast<TupleType>();
  auto subtypes = result_ty.getTypes();
  if (subtypes.size() != 2)
    return op.emitOpError()
           << "result is expected to be a tuple of size 2, but got "
           << subtypes.size();
  if (!subtypes[1].isa<TokenType>())
    return op.emitOpError() << "second element of result tuple is expected to "
                               "be of token type, but got "
                            << subtypes[1];
  return success();
}

//===----------------------------------------------------------------------===//
// Logical Ops
//===----------------------------------------------------------------------===//

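// The logical-op folders below handle identical operands, splat identity and
// annihilator constants (e.g. x AND 0 => 0, x AND ~0 => x), and element-wise
// folding of two constants. Worked example (illustrative):
// and(dense<[1, 0]>, dense<[1, 1]>) folds to dense<[1, 0]>.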
OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
  if (lhs() == rhs()) return lhs();

  auto rType = getType().cast<ShapedType>();
  auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();

  if (lhsVal && lhsVal.isSplat()) {
    if (lhsVal.getSplatValue()
            .cast<IntegerAttr>()
            .getValue()
            .isAllOnesValue()) {
      return rhs();
    }

    if (lhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
      return lhsVal;
    }
  }

  if (rhsVal && rhsVal.isSplat()) {
    if (rhsVal.getSplatValue()
            .cast<IntegerAttr>()
            .getValue()
            .isAllOnesValue()) {
      return lhs();
    }

    if (rhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
      return rhsVal;
    }
  }

  if (!rhsVal || !lhsVal) return {};

  llvm::SmallVector<APInt, 4> values;
  values.reserve(rhsVal.getNumElements());
  for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) {
    values.push_back(std::get<0>(it) & std::get<1>(it));
  }

  return DenseIntElementsAttr::get(rType, values);
}

OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
  if (lhs() == rhs()) return lhs();

  auto rType = getType().cast<ShapedType>();
  auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();

  if (lhsVal && lhsVal.isSplat()) {
    if (lhsVal.getSplatValue()
            .cast<IntegerAttr>()
            .getValue()
            .isAllOnesValue()) {
      return lhsVal;
    }

    if (lhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
      return rhs();
    }
  }

  if (rhsVal && rhsVal.isSplat()) {
    if (rhsVal.getSplatValue()
            .cast<IntegerAttr>()
            .getValue()
            .isAllOnesValue()) {
      return rhsVal;
    }

    if (rhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
      return lhs();
    }
  }

  if (!rhsVal || !lhsVal) return {};

  llvm::SmallVector<APInt, 4> values;
  values.reserve(rhsVal.getNumElements());
  for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) {
    values.push_back(std::get<0>(it) | std::get<1>(it));
  }

  return DenseIntElementsAttr::get(rType, values);
}

OpFoldResult XorOp::fold(ArrayRef<Attribute> operands) {
  auto rType = getType().cast<ShapedType>();
  if (lhs() == rhs()) {
    Builder builder(getContext());
    return builder.getZeroAttr(rType);
  }

  auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();

  if (lhsVal && lhsVal.isSplat()) {
    if (lhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
      return rhs();
    }
  }

  if (rhsVal && rhsVal.isSplat()) {
    if (rhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
      return lhs();
    }
  }

  if (!rhsVal || !lhsVal) return {};

  llvm::SmallVector<APInt, 4> values;
  values.reserve(rhsVal.getNumElements());
  for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) {
    values.push_back(std::get<0>(it) ^ std::get<1>(it));
  }

  return DenseIntElementsAttr::get(rType, values);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(MapOp op) {
  // Checks that the number of `operands` matches the arity of the map
  // `computation` region.
  auto& computation_block = op.computation().front();
  auto computation_args = computation_block.getArguments();
  if (op.operands().size() != computation_args.size())
    return op.emitOpError()
           << "expects number of operands to match the arity "
              "of map computation, but got: "
           << op.operands().size() << " and " << computation_args.size();

  // The parameters of the computation should all be scalars and match the
  // element type of the operands.
  auto operand_type = op.operands()[0].getType().cast<TensorType>();
  auto operand_elem_ty = operand_type.getElementType();

  for (auto indexed_arg : llvm::enumerate(computation_args)) {
    auto arg_type = indexed_arg.value().getType().dyn_cast<TensorType>();
    if (!arg_type || arg_type.getRank() != 0)
      return op.emitOpError()
             << "computation arguments must be 0-rank tensor, but got: arg #"
             << indexed_arg.index() << " of type "
             << indexed_arg.value().getType();
    if (arg_type.getElementType() != operand_elem_ty) {
      return op.emitOpError()
             << "element type of operands and computation arguments must "
                "match, but got: "
             << operand_elem_ty << " and " << arg_type.getElementType();
    }
  }

  // The mapped computation must return a single output.
  auto computation_outputs = computation_block.getTerminator()->getOperands();
  if (computation_outputs.size() != 1)
    return op.emitOpError()
           << "computation must return single output, but got: "
           << computation_outputs.size();

  // The output of the computation must be a scalar with the same element type
  // as the op result.
  auto computation_output_type =
      computation_outputs[0].getType().dyn_cast<TensorType>();
  if (!computation_output_type || computation_output_type.getRank() != 0)
    return op.emitOpError()
           << "computation must return 0-rank tensor, but got: "
           << computation_outputs[0].getType();

  auto result_type = op.getType().cast<TensorType>();
  if (computation_output_type.getElementType() != result_type.getElementType())
    return op.emitOpError() << "element type of result and computation output "
                               "must match, but got: "
                            << result_type.getElementType() << " and "
                            << computation_output_type.getElementType();

  // Checks that the requested map dimension numbers are monotonically
  // increasing.
  auto values = op.dimensions().getValues<int64_t>();
  auto dimensions = std::vector<int64_t>{values.begin(), values.end()};
  for (int i = 0, e = dimensions.size(); i < e; ++i) {
    if (dimensions[i] != i)
      return op.emitOpError() << "requires monotonically increasing dimension "
                                 "numbers, but got: "
                              << op.dimensions();
  }

  // Checks that the number of dimensions of the operands matches the size of
  // `dimensions`, since we currently only support mapping across all
  // dimensions, i.e. scalar map functions.
  if (operand_type.hasRank()) {
    if (dimensions.size() != operand_type.getShape().size())
      return op.emitOpError()
             << "applied to a subset of dimensions currently not supported: "
                "operand dimensions = "
             << operand_type.getShape().size()
             << ", requested map dimensions size = " << dimensions.size();
  }

  return success();
}

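// Folds a map whose computation immediately returns one of its block
// arguments to the corresponding map operand.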
OpFoldResult MapOp::fold(ArrayRef<Attribute> operands) {
  mlir::Block& bb = computation().front();
  mlir::Operation& front_op = bb.front();

  auto ret_op = mlir::dyn_cast<ReturnOp>(front_op);
  if (!ret_op) return nullptr;
  if (ret_op.results().size() != 1) return nullptr;

  for (mlir::BlockArgument barg : bb.getArguments()) {
    if (barg == ret_op.results()[0]) return getOperands()[barg.getArgNumber()];
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// RecvOp
//===----------------------------------------------------------------------===//

// Checks that the result type is of the form `tuple<any_type, mhlo::token>`.
static LogicalResult Verify(RecvOp op) {
  auto result_ty = op.getResult().getType().cast<TupleType>();
  auto subtypes = result_ty.getTypes();
  if (subtypes.size() != 2)
    return op.emitOpError()
           << "result is expected to be a tuple of size 2, but got "
           << subtypes.size();
  if (!subtypes[1].isa<TokenType>())
    return op.emitOpError() << "second element of result tuple is expected to "
                               "be of token type, but got "
                            << subtypes[1];
  return success();
}

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

OpFoldResult CopyOp::fold(ArrayRef<Attribute> operands) { return getOperand(); }

//===----------------------------------------------------------------------===//
// ReduceWindowOp
//===----------------------------------------------------------------------===//

// For reduce-window, all `inputs` need to have compatible shapes.
static LogicalResult Verify(ReduceWindowOp op) {
  if (failed(verifyCompatibleShapes(op.inputs().getTypes())))
    return op.emitOpError() << "requires same shape for all inputs";
  return success();
}

// Gets the operation used for the reduction applied to the `result_index`th
// result. It is expected to be a binary operation that consumes the
// `result_index`th and `result_index + inputs().size()`th arguments of the
// body.
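// For example (illustrative), for a reduce-window with two inputs whose body
// is
//   ^bb0(%x0, %x1, %y0, %y1):
//     %0 = mhlo.add %x0, %y0 : ...
//     %1 = mhlo.maximum %x1, %y1 : ...
//     "mhlo.return"(%0, %1) : ...
// getReductionOp(0) returns the add (args 0 and 2) and getReductionOp(1)
// returns the maximum (args 1 and 3).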
Operation* ReduceWindowOp::getReductionOp(int result_index) {
  auto return_op = cast<ReturnOp>(body().front().getTerminator());
  Operation* compute_op = return_op.results()[result_index].getDefiningOp();
  if (compute_op->getNumOperands() != 2) return nullptr;
  auto arg0 = compute_op->getOperand(0).dyn_cast<BlockArgument>();
  auto arg1 = compute_op->getOperand(1).dyn_cast<BlockArgument>();
  if (!arg0 || !arg1) return nullptr;
  int arg0_num = arg0.getArgNumber();
  int arg1_num = arg1.getArgNumber();
  int other_arg_index = result_index + inputs().size();
  if (arg0_num == result_index && arg1_num == other_arg_index)
    return compute_op;
  if (arg0_num == other_arg_index && arg1_num == result_index &&
      compute_op->hasTrait<mlir::OpTrait::IsCommutative>())
    return compute_op;
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ReverseOp
//===----------------------------------------------------------------------===//

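// Folds a reverse to its operand when there are no dimensions to reverse, or
// when every reversed dimension has size 1 (reversing a unit dimension is a
// no-op).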
OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
  auto input = operand();

  // No dimensions to reverse.
  if (dimensions().getNumElements() == 0) return input;

  // If any reversed dimension has a size other than 1, the fold does not
  // apply.
  auto shaped_type = input.getType().cast<ShapedType>();
  for (auto dim : dimensions().getValues<APInt>()) {
    if (shaped_type.getDimSize(dim.getLimitedValue()) != 1) {
      return nullptr;
    }
  }

  return input;
}

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

// Returns the result type after reducing operand of the given type across the
// specified dimensions.
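// For example (illustrative): reducing tensor<4x5xf32> across dimension 1
// yields tensor<4xf32>; reducing an unranked tensor yields an unranked tensor
// with the same element type.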
static TensorType GetReduceResultType(Type operand_ty,
                                      DenseIntElementsAttr dimensions,
                                      Builder* builder) {
  Type element_ty = getElementTypeOrSelf(operand_ty);

  auto ranked_ty = operand_ty.dyn_cast<RankedTensorType>();
  if (!ranked_ty) return UnrankedTensorType::get(element_ty);

  int64_t rank = ranked_ty.getRank();
  llvm::SmallVector<bool, 4> dims_mask(rank, false);
  for (int64_t dim : dimensions.getValues<int64_t>()) dims_mask[dim] = true;

  SmallVector<int64_t, 4> shape;
  for (int64_t i = 0; i < rank; ++i) {
    if (!dims_mask[i]) shape.push_back(ranked_ty.getDimSize(i));
  }

  return RankedTensorType::get(shape, element_ty);
}

void ReduceOp::build(OpBuilder& builder, OperationState& state,
                     ValueRange inputs, ValueRange init_values,
                     DenseIntElementsAttr dimensions) {
  SmallVector<Type, 1> result_ty;
  result_ty.reserve(inputs.size());

  for (Value input : inputs) {
    result_ty.push_back(
        GetReduceResultType(input.getType(), dimensions, &builder));
  }
  build(builder, state, result_ty, inputs, init_values, dimensions);
}

LogicalResult ReduceOp::fold(ArrayRef<Attribute> operands,
                             SmallVectorImpl<OpFoldResult>& results) {
  // No dimensions to reduce.
  if (dimensions().getNumElements() == 0) {
    for (Value input : this->inputs()) {
      results.push_back(input);
    }
    return success();
  }

  // If all values returned from the ReduceOp region are defined outside the
  // region, replace the ReduceOp with those values.
  mlir::Block& bb = this->body().front();
  SmallVector<Value> replaced_results;
  if (auto ret_op = mlir::dyn_cast<ReturnOp>(bb.back())) {
    for (Value result : ret_op.results()) {
      if (result.getParentRegion() == ret_op->getParentRegion())
        return failure();
      replaced_results.push_back(result);
    }

    results.insert(results.end(), replaced_results.begin(),
                   replaced_results.end());
    return success();
  }

  return failure();
}

// Enable constant folding to occur within the region of the ReduceOp
// by replacing block argument uses with constants if:
//  1. All the ReduceOp operands are splat constants.
//  2. The ReduceOp region consists of a single logical AND or logical OR.
// The pattern leverages the idempotent property of the AND and OR operators
// to determine the value of a reduction on splat constants. Other boolean
// operators do not have this property, and need separate patterns to resolve
// reductions of their splat constants.
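// Schematic example (illustrative only, not taken from a test): a reduce over
// a splat constant dense<true> whose body is a single mhlo.and has its block
// arguments replaced by splat constants, after which the region's AND folds to
// dense<true> regardless of the reduced shape, since x AND x == x.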
struct LowerBoolSplatConstantsIntoRegion : public OpRewritePattern<ReduceOp> {
  using OpRewritePattern<ReduceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReduceOp op,
                                PatternRewriter& rewriter) const override {
    mlir::Block& bb = op.body().front();

    // Ensure only a compute op and a return op exist, and that the compute op
    // is an AND or OR op.
    if (bb.getOperations().size() != 2) return failure();
    if (!mlir::isa<mhlo::AndOp, mhlo::OrOp>(bb.front())) return failure();

    // Ensure all operands are splat constants.
    SmallVector<DenseElementsAttr, 4> barg_cst_attrs;
    for (auto inp_and_barg : llvm::zip(op.getOperands(), bb.getArguments())) {
      Value inp = std::get<0>(inp_and_barg);
      BlockArgument barg = std::get<1>(inp_and_barg);
      ConstOp cst = inp.getDefiningOp<ConstOp>();
      if (!cst) return failure();

      auto cst_attr = cst.value().dyn_cast_or_null<DenseElementsAttr>();
      if (!cst_attr || !cst_attr.isSplat()) {
        return rewriter.notifyMatchFailure(op, "Must be splat constant.");
      }

      auto barg_shaped_type = barg.getType().dyn_cast<ShapedType>();
      if (!barg_shaped_type) return failure();

      auto barg_cst_attr =
          DenseElementsAttr::get(barg_shaped_type, cst_attr.getSplatValue());
      barg_cst_attrs.push_back(barg_cst_attr);
    }

    // Create new splat constants to replace block arguments.
    for (BlockArgument barg : bb.getArguments()) {
      int arg_idx = barg.getArgNumber();
      mhlo::ConstOp new_cst = rewriter.create<mhlo::ConstOp>(
          bb.front().getLoc(), barg.getType(), barg_cst_attrs[arg_idx]);
      barg.replaceAllUsesWith(new_cst);
    }
    return success();
  }
};

void ReduceOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                           MLIRContext* context) {
  results.insert<LowerBoolSplatConstantsIntoRegion>(context);
}

LogicalResult ReduceOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  ReduceOp::Adaptor adaptor(operands);
  auto inputs = adaptor.inputs();

  auto operand_type = inputs[0].getType().dyn_cast<RankedTensorType>();
  // Unranked operands are not supported at the moment.
  if (!operand_type) return failure();

  Location loc = this->getLoc();
  SmallVector<Value, 4> shape_values;
  SmallVector<int64_t, 4> dimensions(this->dimensions().getValues<int64_t>());
  shape_values.reserve(operand_type.getRank());
  Type shape_scalar_type = builder.getIndexType();
  auto to_shape_scalar_type = [&](Value v) {
    return MaybeCastTo(builder, loc, v, shape_scalar_type);
  };

  for (const auto& element : llvm::enumerate(operand_type.getShape())) {
    int64_t idx = element.index();
    auto it = std::find(dimensions.begin(), dimensions.end(), idx);
    // Reduced dimensions do not appear in the result shape.
    if (it != dimensions.end()) {
      continue;
    }
    Value value_dim = to_shape_scalar_type(
        builder.create<tensor::DimOp>(loc, inputs[0], element.index()));
    shape_values.push_back(value_dim);
  }

  Value output_shape = builder.create<tensor::FromElementsOp>(
      loc, shape_scalar_type, shape_values);
  for (size_t i = 0; i < inputs.size(); ++i) {
    reifiedReturnShapes.push_back(output_shape);
  }

  return success();
}

//===----------------------------------------------------------------------===//
// RngNormalOp
//===----------------------------------------------------------------------===//

LogicalResult RngNormalOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  return rngInferReturnTypeComponents(context, location, operands, attributes,
                                      regions, inferredReturnShapes);
}

//===----------------------------------------------------------------------===//
// RngUniformOp
//===----------------------------------------------------------------------===//

LogicalResult RngUniformOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  return rngInferReturnTypeComponents(context, location, operands, attributes,
                                      regions, inferredReturnShapes);
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(SelectOp op) {
  // TODO(jpienaar): Update to allow broadcastable and unranked inputs. This
  // corresponds to the client side HLO.
  return success();
}

OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
  if (on_true() == on_false()) {
    return on_true();
  }

  auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!predicate) {
    return {};
  }

  auto predicateTy = predicate.getType().cast<ShapedType>();
  if (!predicateTy.getElementType().isInteger(1)) {
    return {};
  }

  if (predicate.isSplat()) {
    return predicate.getSplatValue<APInt>().getBoolValue() ? on_true()
                                                           : on_false();
  }

  return {};
}

// Makes it such that a SelectOp that is a non-root operation in a DRR infers
// the return type based on operand type.
LogicalResult SelectOp::inferReturnTypes(
    MLIRContext*, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type>& inferredReturnTypes) {
  auto x_type = operands[1].getType();
  auto y_type = operands[2].getType();
  auto x_tensor = x_type.cast<TensorType>();
  auto y_tensor = y_type.cast<TensorType>();

  // Check for type compatibility in the select op. This requires that the two
  // non-predicate operands:
  //   (a) have the same element type
  //   (b) have compatible shapes (i.e. the same shape and/or at least one
  //       dynamic shape)
  if (x_tensor.getElementType() != y_tensor.getElementType() ||
      failed(mlir::verifyCompatibleShape(x_type, y_type))) {
    return emitOptionalError(location, "incompatible operand types: ", x_type,
                             " and ", y_type);
  }

  // TODO(lucyfox): Support output shape inference when operands have
  // compatible shapes. (The output shape should be the most general of the
  // operand shapes at each dimension.) For now, handle the straightforward
  // cases and fail otherwise. When this is fully implemented, this logic
  // should move into reusable functionality in MLIR Core.
  Type output_type;
  if (x_type == y_type || !x_tensor.hasRank()) {
    output_type = x_type;
  } else if (!y_tensor.hasRank()) {
    output_type = y_type;
  } else {
    return emitOptionalError(location,
                             "currently unsupported operand types: ", x_type,
                             " and ", y_type);
  }
  inferredReturnTypes.assign({output_type});
  return success();
}

LogicalResult SelectOp::inferReturnTypeComponents(
    mlir::MLIRContext* ctx, llvm::Optional<mlir::Location> loc,
    ValueShapeRange operands, mlir::DictionaryAttr attributes,
    mlir::RegionRange regions,
    llvm::SmallVectorImpl<mlir::ShapedTypeComponents>&
        inferredShapedTypeComponents) {
  llvm::SmallVector<Type, 4> inferredReturnTypes;
  const LogicalResult infer_types_status = inferReturnTypes(
      ctx, loc, operands, attributes, regions, inferredReturnTypes);
  if (infer_types_status.failed()) return infer_types_status;

  if (inferredReturnTypes.size() != 1) return failure();

  auto result_tensor_type =
      inferredReturnTypes[0].dyn_cast_or_null<TensorType>();
  if (!result_tensor_type) return failure();

  mlir::Type element_type =
      operands[1].getType().cast<TensorType>().getElementType();
  inferredShapedTypeComponents.push_back(
      {result_tensor_type.getShape(), element_type});

  return success();
}

LogicalResult SelectOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  // For `hlo.select`, the first operand may be a scalar.
  return deriveShapeFromOperand(&builder, getOperation(), operands[1],
                                &reifiedReturnShapes);
}

//===----------------------------------------------------------------------===//
// SetDimensionSizeOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(SetDimensionSizeOp op) {
  if (auto size = op.size().getType().dyn_cast<RankedTensorType>()) {
    if (size.getRank() != 0)
      return op.emitOpError() << "size operand should be of rank-0";
  }

  return VerifyDimAttr(op);
}

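// Folds the op when the operand is already a constant, or when the requested
// size equals the static size of the chosen dimension (in which case the op
// is a no-op).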
OpFoldResult SetDimensionSizeOp::fold(ArrayRef<Attribute> operands) {
  DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  if (input) return input;

  DenseElementsAttr size = operands[1].dyn_cast_or_null<DenseElementsAttr>();
  if (!size || !size.isSplat()) return {};

  auto ty = getType().dyn_cast<RankedTensorType>();
  if (!ty) return {};

  int64_t dim_size = ty.getDimSize(dimension());
  if (dim_size == size.getSplatValue().cast<IntegerAttr>().getInt())
    return operand();
  return {};
}

//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(PadOp op) {
  auto input_type = op.operand().getType().cast<RankedTensorType>();
  auto pad_type = op.padding_value().getType().cast<RankedTensorType>();

  if (pad_type.getRank() != 0) {
    return op.emitOpError(
        llvm::formatv("padding value type should be a rank-0 "
                      "tensor, is rank {0}",
                      pad_type.getRank()));
  }

  const auto& padding_low = op.edge_padding_low();
  if (padding_low.getType().getNumElements() != input_type.getRank()) {
    return op.emitOpError(llvm::formatv(
        "edge_padding_low length ({0}) must match operand rank ({1})",
        padding_low.getType().getNumElements(), input_type.getRank()));
  }

  const auto& padding_high = op.edge_padding_high();
  if (padding_high.getType().getNumElements() != input_type.getRank()) {
    return op.emitOpError(llvm::formatv(
        "edge_padding_high length ({0}) must match operand rank ({1})",
        padding_high.getType().getNumElements(), input_type.getRank()));
  }

  const auto& padding_interior = op.interior_padding();
  if (padding_interior.getType().getNumElements() != input_type.getRank()) {
    return op.emitOpError(llvm::formatv(
        "interior_padding length ({0}) must match operand rank ({1})",
        padding_interior.getType().getNumElements(), input_type.getRank()));
  }

  auto input_shape = input_type.getShape();
  auto output_shape =
      op.getResult().getType().cast<RankedTensorType>().getShape();
  if (input_shape.size() != output_shape.size()) {
    return op.emitOpError(
        llvm::formatv("operand rank ({0}) and result rank ({1}) should match",
                      input_shape.size(), output_shape.size()));
  }

  for (int i = 0, e = input_shape.size(); i < e; i++) {
    int64_t padding_low_val = padding_low.getValue<IntegerAttr>(i).getInt();
    int64_t padding_high_val = padding_high.getValue<IntegerAttr>(i).getInt();
    int64_t padding_interior_val =
        padding_interior.getValue<IntegerAttr>(i).getInt();
    int64_t expected_output =
        input_shape[i] + padding_low_val + padding_high_val +
        std::max<int64_t>(input_shape[i] - 1, 0LL) * padding_interior_val;
    if (expected_output != output_shape[i]) {
      return op.emitOpError(llvm::formatv(
          "expected output shape's dimension #{0} to be {1} but found {2}", i,
          expected_output, output_shape[i]));
    }
  }

  return success();
}

OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
  // If all padding is zero then it is an identity pad.
  auto is_zero = [](const APInt& i) { return i == 0; };
  if (llvm::all_of(edge_padding_low().getIntValues(), is_zero) &&
      llvm::all_of(edge_padding_high().getIntValues(), is_zero) &&
      llvm::all_of(interior_padding().getIntValues(), is_zero))
    return operand();

  // If any padding is negative then it isn't supported by the folder (yet).
  auto is_negative = [](const APInt& i) { return i.slt(0); };
  if (llvm::any_of(edge_padding_low().getIntValues(), is_negative) ||
      llvm::any_of(edge_padding_high().getIntValues(), is_negative) ||
      llvm::any_of(interior_padding().getIntValues(), is_negative))
    return {};

  DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  DenseElementsAttr padding = operands[1].dyn_cast_or_null<DenseElementsAttr>();
  RankedTensorType return_type = getType().dyn_cast_or_null<RankedTensorType>();
  if (!input || !input.getType().hasRank() || !padding || !return_type ||
      !return_type.hasStaticShape())
    return {};

  // Fill the full result tensor with the padding value.
  llvm::SmallVector<Attribute, 4> result(return_type.getNumElements(),
                                         padding.getValue({}));

  auto next_index = [](llvm::SmallVector<uint64_t, 8>& index,
                       llvm::ArrayRef<int64_t> shape) {
    for (int64_t i = index.size() - 1; i >= 0; --i) {
      ++index[i];
      if (index[i] < shape[i]) return;
      index[i] = 0;
    }
  };

  // Iterate over all elements of the input tensor and copy it to the correct
  // location in the output tensor.
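  // Worked example (illustrative): for a 1-d input of size 3 with
  // edge_padding_low = 2 and interior_padding = 1, input index i maps to
  // result index 2 + i * 2, i.e. result indices 2, 4 and 6.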
2732 llvm::SmallVector<uint64_t, 8> index(input.getType().getRank(), 0);
2733 uint64_t num_elements = input.getNumElements();
2734 for (uint64_t operand_idx = 0; operand_idx < num_elements; operand_idx++) {
2735 uint64_t result_idx = 0;
2736 uint64_t idx_multiplyer = 1;
2737 for (int64_t i = index.size() - 1; i >= 0; --i) {
2738 result_idx +=
2739 (edge_padding_low().getValue<int64_t>({uint64_t(i)}) +
2740 index[i] *
2741 (interior_padding().getValue<int64_t>({uint64_t(i)}) + 1)) *
2742 idx_multiplyer;
2743 idx_multiplyer *= return_type.getDimSize(i);
2744 }
2745 result[result_idx] = input.getValue(index);
2746 next_index(index, input.getType().getShape());
2747 }
2748 return DenseElementsAttr::get(return_type, result);
2749 }
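// Sketch of the index mapping above (illustrative, not from the original
// source): with edge_padding_low = 1, edge_padding_high = 1, and
// interior_padding = 1 in one dimension, operand element i lands at result
// index 1 + i * 2, so an operand [a, b, c] padded with value p folds to
// [p, a, p, b, p, c, p].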
2750
2751 //===----------------------------------------------------------------------===//
2752 // DynamicPadOp
2753 //===----------------------------------------------------------------------===//
2754
2755 void DynamicPadOp::getCanonicalizationPatterns(
2756 OwningRewritePatternList& results, MLIRContext* context) {
2757 results.insert<DPadToPad>(context);
2758 }
2759
2760 static LogicalResult Verify(DynamicPadOp op) {
2761 auto input_type = op.operand().getType().dyn_cast<RankedTensorType>();
2762 // If operand is unranked, there is very little to verify statically.
2763 if (!input_type) return success();
2764 int input_rank = input_type.getRank();
2765
2766 auto pad_type = op.padding_value().getType().cast<RankedTensorType>();
2767 if (pad_type.getRank() != 0) {
2768 return op.emitOpError() << "padding value type should be a rank-0 tensor";
2769 }
2770
2771 auto padding_low_type =
2772 op.edge_padding_low().getType().cast<RankedTensorType>();
2773 if (padding_low_type.getNumElements() != input_rank) {
2774 return op.emitOpError()
2775 << "edge_padding_low length(" << padding_low_type.getNumElements()
2776 << ") must match operand rank(" << input_rank << ").";
2777 }
2778
2779 auto padding_high_type =
2780 op.edge_padding_high().getType().cast<RankedTensorType>();
2781 if (padding_high_type.getNumElements() != input_rank) {
2782 return op.emitOpError()
2783 << "edge_padding_high length(" << padding_high_type.getNumElements()
2784 << ") must match operand rank(" << input_rank << ").";
2785 }
2786
2787 auto interior_padding_type =
2788 op.interior_padding().getType().cast<RankedTensorType>();
2789 if (interior_padding_type.getNumElements() != input_rank) {
2790 return op.emitOpError()
2791 << "interior_padding length("
2792 << interior_padding_type.getNumElements()
2793 << ") must match operand rank(" << input_rank << ").";
2794 }
2795
2796 auto output_type = op.getResult().getType().dyn_cast<RankedTensorType>();
2797 // If result is unranked, there is very little to verify statically.
2798 if (!output_type) return success();
2799 int output_rank = output_type.getRank();
2800 if (input_rank != output_rank) {
2801 return op.emitOpError() << "operand rank(" << input_rank
2802 << ") must match result rank(" << output_rank << ").";
2803 }
2804
2805 return success();
2806 }
2807
2808 LogicalResult DynamicPadOp::reifyReturnTypeShapes(
2809 OpBuilder& builder, ValueRange operands,
2810 SmallVectorImpl<Value>& reifiedReturnShapes) {
2811 DynamicPadOp::Adaptor adaptor(operands);
2812 Value operand = adaptor.operand();
2813 Value edge_padding_low = adaptor.edge_padding_low();
2814 Value edge_padding_high = adaptor.edge_padding_high();
2815 Value interior_padding = adaptor.interior_padding();
2816
2817 auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
2818 // Unranked pad is not supported at the moment.
2819 if (!operand_type) return failure();
2820
2821 auto loc = this->getLoc();
2822 SmallVector<Value, 4> shape_values;
2823 shape_values.reserve(operand_type.getRank());
2824 Type shape_scalar_type =
2825 edge_padding_low.getType().cast<ShapedType>().getElementType();
2826
2827 auto to_shape_scalar_type = [&](Value v) {
2828 return MaybeCastTo(builder, loc, v, shape_scalar_type);
2829 };
2830
2831 Value zero = to_shape_scalar_type(builder.create<ConstantIndexOp>(loc, 0));
2832 Value one = to_shape_scalar_type(builder.create<ConstantIndexOp>(loc, 1));
2833
2834 for (int idx : llvm::seq<int>(0, operand_type.getShape().size())) {
2835 Value value_dim =
2836 to_shape_scalar_type(builder.create<tensor::DimOp>(loc, operand, idx));
2837 Value offset = builder.create<ConstantIndexOp>(loc, idx);
2838 Value value_low =
2839 builder.create<tensor::ExtractOp>(loc, edge_padding_low, offset);
2840 Value value_high =
2841 builder.create<tensor::ExtractOp>(loc, edge_padding_high, offset);
2842 Value value_interior =
2843 builder.create<tensor::ExtractOp>(loc, interior_padding, offset);
2844 // output_size = input_size + padding_low + padding_high + interior *
2845 // max(input_size - 1, 0)
2846 Value value_dim_less_than_one =
2847 builder.create<CmpIOp>(loc, CmpIPredicate::slt, value_dim, one);
2848 Value interior_size = builder.create<MulIOp>(
2849 loc, value_interior,
2850 builder.create<mlir::SelectOp>(
2851 loc, value_dim_less_than_one, zero,
2852 builder.create<SubIOp>(loc, value_dim, one)));
2853 shape_values.push_back(builder.create<AddIOp>(
2854 loc,
2855 builder.create<AddIOp>(
2856 loc, builder.create<AddIOp>(loc, interior_size, value_dim),
2857 value_low),
2858 value_high));
2859 }
2860
2861 reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
2862 loc, shape_scalar_type, shape_values));
2863
2864 return success();
2865 }
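// Roughly, per dimension the code above materializes IR like the following
// (illustrative pseudo-IR, names invented):
//   %dim  = tensor.dim %operand, %idx
//   %low  = tensor.extract %edge_padding_low[%idx]
//   %high = tensor.extract %edge_padding_high[%idx]
//   %int  = tensor.extract %interior_padding[%idx]
//   %size = %dim + %low + %high + %int * max(%dim - 1, 0)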
2866
2867 //===----------------------------------------------------------------------===//
2868 // ReshapeOp
2869 //===----------------------------------------------------------------------===//
2870
2871 static LogicalResult Verify(ReshapeOp op) {
2872 // If the operand type is dynamically shaped there is nothing to verify.
2873 auto operand_ty = op.operand().getType().dyn_cast<RankedTensorType>();
2874 if (!operand_ty || !operand_ty.hasStaticShape()) return success();
2875
2876 // If the operand type is statically shaped (not required) the number of
2877 // elements must match that of the result type.
2878 auto result_ty = op.getType().cast<RankedTensorType>();
2879 assert(result_ty && result_ty.hasStaticShape() &&
2880 "result type must be statically shaped");
2881 int64_t num_result_elements = result_ty.getNumElements();
2882 int64_t num_operand_elements = operand_ty.getNumElements();
2883 if (num_result_elements != num_operand_elements)
2884 return op.emitOpError()
2885 << "number of output elements (" << num_result_elements
2886 << ") doesn't match expected number of elements ("
2887 << num_operand_elements << ")";
2888
2889 return success();
2890 }
2891
2892 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
2893 if (getOperand().getType() == getType()) {
2894 return getOperand();
2895 }
2896
2897 if (auto prev_op = getOperand().getDefiningOp<ReshapeOp>()) {
2898 setOperand(prev_op.getOperand());
2899 return getResult();
2900 }
2901
2902 if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
2903 return elements.reshape(getResult().getType().cast<ShapedType>());
2904 }
2905
2906 return {};
2907 }
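// Examples of the folds above (illustrative): a reshape of x : tensor<4xf32>
// to tensor<4xf32> folds to x; reshape(reshape(x)) folds to a single reshape
// of x; and a reshape of a constant folds to a reshaped constant attribute.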
2908
2909 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
2910 MLIRContext* context) {
2911 results.insert<IdentityBroadcastReshape, IdentityBroadcastInDimReshape,
2912 EliminateRedundantReshape, EliminateIdentityReshape>(context);
2913 }
2914
2915 //===----------------------------------------------------------------------===//
2916 // ReplicaId Op
2917 //===----------------------------------------------------------------------===//
2918
2919 LogicalResult ReplicaIdOp::inferReturnTypes(
2920 MLIRContext* context, Optional<Location>, ValueRange operands,
2921 DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
2922 inferredReturnTypes.push_back(RankedTensorType::get(
2923 /*shape=*/{}, IntegerType::get(context, 32, IntegerType::Unsigned)));
2924 return success();
2925 }
2926
2927 //===----------------------------------------------------------------------===//
2928 // If Op
2929 //===----------------------------------------------------------------------===//
2930
2931 static LogicalResult VerifyConditionalBranch(Operation* op, Region& region,
2932 Value operand,
2933 llvm::Twine branchName,
2934 llvm::Twine operandName) {
2935 mlir::Block& entryBlock = region.front();
2936 if (entryBlock.getNumArguments() != 1)
2937 return op->emitOpError()
2938 << branchName << " block should have a single argument, but found "
2939 << entryBlock.getNumArguments();
2940
2941 Type operandType = operand.getType();
2942 Type branchArgType = entryBlock.getArgument(0).getType();
2943 if (branchArgType != operandType)
2944 return op->emitOpError()
2945 << operandName << " type (" << operandType << ") does not match "
2946 << branchName << " block arg type (" << branchArgType << ")";
2947 TypeRange branchReturnTypes = entryBlock.getTerminator()->getOperandTypes();
2948 if (branchReturnTypes != op->getResultTypes())
2949 return op->emitOpError()
2950 << branchName << " returned types (" << branchReturnTypes
2951 << ") do not match op result types (" << op->getResultTypes() << ")";
2952
2953 return success();
2954 }
2955
2956 static LogicalResult Verify(IfOp op) {
2957 if (failed(VerifyConditionalBranch(op, op.true_branch(), op.true_arg(),
2958 /*branchName=*/"true_branch",
2959 /*operandName=*/"true_arg"))) {
2960 return failure();
2961 }
2962
2963 if (failed(VerifyConditionalBranch(op, op.false_branch(), op.false_arg(),
2964 /*branchName=*/"false_branch",
2965 /*operandName=*/"false_arg"))) {
2966 return failure();
2967 }
2968 return success();
2969 }
2970
2971 static LogicalResult InlineIfConstantCondition(IfOp ifOp,
2972 PatternRewriter& rewriter) {
2973 DenseIntElementsAttr pred_attr;
2974 if (!matchPattern(ifOp.pred(), m_Constant(&pred_attr))) return failure();
2975
2976 if (pred_attr.getSplatValue<BoolAttr>().getValue()) {
2977 ReplaceOpWithRegion(rewriter, ifOp, ifOp.true_branch(), ifOp.true_arg());
2978 } else {
2979 ReplaceOpWithRegion(rewriter, ifOp, ifOp.false_branch(), ifOp.false_arg());
2980 }
2981 return success();
2982 }
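// Illustrative example (not from the original source): if the predicate is a
// constant splat such as
//   %pred = mhlo.constant dense<true> : tensor<i1>
// then "mhlo.if"(%pred, ...) is replaced by the inlined contents of its
// true_branch region.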
2983
2984 void IfOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
2985 MLIRContext* context) {
2986 results.add(&InlineIfConstantCondition);
2987 }
2988
2989 //===----------------------------------------------------------------------===//
2990 // Case Op
2991 //===----------------------------------------------------------------------===//
2992
2993 static LogicalResult Verify(CaseOp op) {
2994 auto num_branches = op.branches().size();
2995 if (op.branch_operands().size() != num_branches)
2996 return op.emitOpError() << " number of branches (" << num_branches
2997 << ") does not match number of branch operands ("
2998 << op.branch_operands().size() << ")";
2999
3000 for (unsigned i = 0; i < num_branches; ++i)
3001 if (failed(VerifyConditionalBranch(
3002 op, op.branches()[i], op.branch_operands()[i],
3003 /*branchName=*/"branch " + Twine(i),
3004 /*operandName=*/"branch_operand " + Twine(i))))
3005 return failure();
3006
3007 return success();
3008 }
3009
3010 static LogicalResult InlineCaseConstantCondition(CaseOp caseOp,
3011 PatternRewriter& rewriter) {
3012 DenseIntElementsAttr index_attr;
3013 if (!matchPattern(caseOp.index(), m_Constant(&index_attr))) {
3014 return failure();
3015 }
3016 int64_t index =
3017 index_attr.getSplatValue<IntegerAttr>().getValue().getSExtValue();
3018 // For an OOB index, the last branch is executed as the default branch:
3019 // https://www.tensorflow.org/xla/operation_semantics#conditional
3020 if (index < 0 || index >= caseOp.getNumRegions())
3021 index = caseOp.getNumRegions() - 1;
3022
3023 Region& region = caseOp.getRegion(index);
3024 if (!llvm::hasSingleElement(region)) return failure();
3025 ReplaceOpWithRegion(rewriter, caseOp, region,
3026 caseOp.branch_operands()[index]);
3027 return success();
3028 }
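// Illustrative example (not from the original source): with a constant index
// dense<7> and 3 branches, the index is clamped to 2, and the case op is
// replaced by the inlined contents of its last branch region.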
3029
3030 void CaseOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
3031 MLIRContext* context) {
3032 results.add(&InlineCaseConstantCondition);
3033 }
3034
3035 //===----------------------------------------------------------------------===//
3036 // SqrtOp
3037 //===----------------------------------------------------------------------===//
3038
3039 OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
3040 auto val = operands[0].dyn_cast_or_null<DenseElementsAttr>();
3041 if (!val) return {};
3042
3043 auto type = getElementTypeOrSelf(getType());
3044 if (!type.isF32() && !type.isF64()) return {};
3045
3046 auto shaped_type = getType().cast<ShapedType>();
3047 if (!shaped_type.hasStaticShape()) return {};
3048
3049 int bit_width = type.getIntOrFloatBitWidth();
3050 llvm::SmallVector<APFloat, 4> values;
3051 values.reserve(val.getNumElements());
3052 for (auto it : val.getFloatValues()) {
3053 double value = bit_width == 32 ? it.convertToFloat() : it.convertToDouble();
3054 if (value < 0) return {};
3055 value = std::sqrt(value);
3056 if (bit_width == 32)
3057 values.emplace_back(static_cast<float>(value));
3058 else
3059 values.emplace_back(value);
3060 }
3061 return DenseFPElementsAttr::get(shaped_type, values);
3062 }
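// Illustrative example (not from the original source): a constant operand
// dense<[1.0, 4.0, 9.0]> folds to dense<[1.0, 2.0, 3.0]>; any negative
// element blocks the fold, since the result would be NaN.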
3063
3064 //===----------------------------------------------------------------------===//
3065 // UnaryOps
3066 //===----------------------------------------------------------------------===//
3067
3068 template <typename Op, typename ElementType = Type, typename ValType,
3069 typename Convert>
3070 static Attribute UnaryFolder(Op* op, ArrayRef<Attribute> attrs) {
3071 if (!attrs[0]) return {};
3072
3073 DenseElementsAttr val = attrs[0].dyn_cast<DenseElementsAttr>();
3074 if (!val) return {};
3075
3076 ShapedType type = op->getType().template cast<ShapedType>();
3077 if (!type.hasStaticShape()) {
3078 return {};
3079 }
3080
3081 Type etype = type.getElementType();
3082
3083 // Only fold when the element type matches the folder's expected type.
3084 if (!etype.isa<ElementType>()) {
3085 return {};
3086 }
3087
3088 SmallVector<ValType, 6> values;
3089 values.reserve(val.getNumElements());
3090 for (const auto v : val.getValues<ValType>()) {
3091 values.push_back(Convert()(v));
3092 }
3093
3094 return DenseElementsAttr::get(type, values);
3095 }
3096
3097 struct round {
3098 APFloat operator()(const APFloat& f) {
3099 APFloat r = f;
3100 r.roundToIntegral(llvm::RoundingMode::NearestTiesToAway);
3101 return r;
3102 }
3103 };
3104
3105 struct logical_not {
3106 APInt operator()(const APInt& i) {
3107 return APInt(i.getBitWidth(), static_cast<uint64_t>(!i));
3108 }
3109 };
3110
3111 template <typename FloatOrInt>
3112 struct sign {
3113 APFloat compute(const APFloat& f) {
3114 if (f.isZero() || f.isNaN()) return f;
3115 double value = f.isNegative() ? -1.0 : 1.0;
3116 APFloat val(value);
3117 bool unused;
3118 val.convert(f.getSemantics(), APFloat::rmNearestTiesToEven, &unused);
3119 return val;
3120 }
3121
3122 APInt compute(const APInt& i) {
3123 APInt r = i;
3124 if (r == 0) return r;
3125 if (r.isNegative()) {
3126 return APInt(r.getBitWidth(), -1, /*isSigned=*/true);
3127 }
3128 return APInt(r.getBitWidth(), 1, /*isSigned=*/true);
3129 }
3130
3131 FloatOrInt operator()(const FloatOrInt& fi) { return compute(fi); }
3132 };
3133
3134 #define UNARY_FOLDER(Op, Func) \
3135 OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
3136 if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
3137 return UnaryFolder<Op, FloatType, APFloat, Func<APFloat>>(this, attrs); \
3138 if (getElementTypeOrSelf(getType()).isa<IntegerType>()) \
3139 return UnaryFolder<Op, IntegerType, APInt, Func<APInt>>(this, attrs); \
3140 return {}; \
3141 }
3142
3143 #define UNARY_FOLDER_INT(Op, Func) \
3144 OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
3145 if (getElementTypeOrSelf(getType()).isa<IntegerType>()) \
3146 return UnaryFolder<Op, IntegerType, APInt, Func>(this, attrs); \
3147 return {}; \
3148 }
3149
3150 #define UNARY_FOLDER_FLOAT(Op, Func) \
3151 OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
3152 if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
3153 return UnaryFolder<Op, FloatType, APFloat, Func>(this, attrs); \
3154 return {}; \
3155 }
3156
3157 UNARY_FOLDER(NegOp, std::negate);
3158 UNARY_FOLDER(SignOp, sign);
3159 UNARY_FOLDER_INT(NotOp, logical_not);
3160 UNARY_FOLDER_FLOAT(RoundOp, round);
3161
3162 #undef UNARY_FOLDER
3163 #undef UNARY_FOLDER_INT
3164 #undef UNARY_FOLDER_FLOAT
3165
3166 //===----------------------------------------------------------------------===//
3167 // BinaryOps
3168 //===----------------------------------------------------------------------===//
3169
3170 namespace {
3171
3172 // Updates the element type of a (presumed) tensor type 'x', returning either
3173 // a new UnrankedTensorType or RankedTensorType with the updated element type.
3174 static Type UpdateResultElementType(Builder* builder, Type x,
3175 Type element_type) {
3176 auto x_ranked = x.dyn_cast<RankedTensorType>();
3177 if (!x_ranked) {
3178 return UnrankedTensorType::get(element_type);
3179 }
3180
3181 auto shape_x = x_ranked.getShape();
3182 return RankedTensorType::get(shape_x, element_type);
3183 }
3184 } // namespace
3185
3186 template <typename Op, typename ElementType = Type, typename ValType,
3187 typename Convert>
3188 static Attribute BinaryFolder(Op* op, ArrayRef<Attribute> attrs) {
3189 if (!attrs[0] || !attrs[1]) return {};
3190
3191 DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>();
3192 DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>();
3193 if (!lhs || !rhs) return {};
3194
3195 ShapedType type = op->getType().template cast<ShapedType>();
3196 if (!type.hasStaticShape()) {
3197 return {};
3198 }
3199
3200 Type etype = type.getElementType();
3201
3202 // Only fold when the element type matches the folder's expected type.
3203 if (!etype.isa<ElementType>()) {
3204 return {};
3205 }
3206
3207 SmallVector<ValType, 6> values;
3208 values.reserve(lhs.getNumElements());
3209 for (const auto zip :
3210 llvm::zip(lhs.getValues<ValType>(), rhs.getValues<ValType>())) {
3211 values.push_back(Convert()(std::get<0>(zip), std::get<1>(zip)));
3212 }
3213
3214 return DenseElementsAttr::get(type, values);
3215 }
3216
3217 template <typename T>
3218 struct divide : std::divides<T> {};
3219
3220 template <>
3221 struct divide<APInt> {
3222 APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
3223 };
3224
3225 template <typename T>
3226 struct remainder : std::modulus<T> {};
3227
3228 template <>
3229 struct remainder<APInt> {
3230 APInt operator()(const APInt& a, const APInt& b) const { return a.srem(b); }
3231 };
3232
3233 template <>
3234 struct remainder<APFloat> {
3235 APFloat operator()(const APFloat& a, const APFloat& b) const {
3236 APFloat result(a);
3237 result.remainder(b);
3238 return result;
3239 }
3240 };
3241
3242 template <typename T>
3243 struct max {
3244 T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
3245 };
3246
3247 template <>
3248 struct max<APInt> {
3249 APInt operator()(const APInt& a, const APInt& b) const {
3250 return llvm::APIntOps::smax(a, b);
3251 }
3252 };
3253
3254 template <typename T>
3255 struct min {
3256 T operator()(const T& a, const T& b) const { return std::min<T>(a, b); }
3257 };
3258
3259 template <>
3260 struct min<APInt> {
3261 APInt operator()(const APInt& a, const APInt& b) const {
3262 return llvm::APIntOps::smin(a, b);
3263 }
3264 };
3265
3266 #define BINARY_FOLDER(Op, Func) \
3267 OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
3268 if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
3269 return BinaryFolder<Op, FloatType, APFloat, Func<APFloat>>(this, attrs); \
3270 if (getElementTypeOrSelf(getType()).isa<IntegerType>()) \
3271 return BinaryFolder<Op, IntegerType, APInt, Func<APInt>>(this, attrs); \
3272 return {}; \
3273 }
3274
3275 // Addition, subtraction and multiplication use the std:: versions of the
3276 // ops. Because the remaining ops behave differently for signed vs. unsigned
3277 // integers, APInt needs a special implementation. Currently, it replicates
3278 // signed integer behavior.
3279 BINARY_FOLDER(AddOp, std::plus);
3280 BINARY_FOLDER(SubOp, std::minus);
3281 BINARY_FOLDER(MulOp, std::multiplies);
3282 BINARY_FOLDER(DivOp, divide);
3283 BINARY_FOLDER(RemOp, remainder);
3284 BINARY_FOLDER(MaxOp, max);
3285 BINARY_FOLDER(MinOp, min);
3286
3287 #undef BINARY_FOLDER
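// Illustrative examples of the folders above (not from the original source):
// mhlo.add of constants dense<[1, 2]> and dense<[3, 4]> folds to
// dense<[4, 6]>, and mhlo.divide on integer constants folds with
// APInt::sdiv, i.e. signed division.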
3288
3289 //===----------------------------------------------------------------------===//
3290 // SliceOp
3291 //===----------------------------------------------------------------------===//
3292
3293 // Returns output dimension size for slice result for the given arguments.
3294 // Returns -1 if arguments are illegal.
3295 static int64_t InferSliceDim(int64_t input_dim, int64_t start, int64_t end,
3296 int64_t stride) {
3297 if (input_dim == -1 || start < 0 || start > end || end > input_dim ||
3298 stride == 0)
3299 return -1;
3300
3301 return llvm::divideCeil(end - start, stride);
3302 }
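// Illustrative example: InferSliceDim(/*input_dim=*/10, /*start=*/1,
// /*end=*/8, /*stride=*/3) returns ceil(7 / 3) = 3, i.e. elements 1, 4, 7.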
3303
3304 LogicalResult SliceOp::inferReturnTypes(
3305 MLIRContext* context, Optional<Location> location, ValueRange operands,
3306 DictionaryAttr attributes, RegionRange regions,
3307 SmallVectorImpl<Type>& inferredReturnTypes) {
3308 SliceOpAdaptor slice(operands, attributes);
3309 // TODO(jpienaar): Update this code after refactoring verify.
3310 if (failed(slice.verify(location.getValueOr(UnknownLoc::get(context))))) {
3311 return failure();
3312 }
3313
3314 Type ty = slice.operand().getType();
3315 RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
3316 if (!ranked_ty) {
3317 // The operand type is unranked, so the best we can infer for the result
3318 // type is an unranked tensor with the same element type as the operand
3319 // type.
3320 inferredReturnTypes.assign({ty});
3321 return success();
3322 }
3323
3324 ShapedType attr_ty = slice.start_indices().getType();
3325 if (attr_ty.getRank() != 1) {
3326 return emitOptionalError(location, "start_indices has rank ",
3327 attr_ty.getRank(), " instead of required rank 1");
3328 }
3329
3330 int64_t rank = ranked_ty.getRank();
3331 if (attr_ty.getNumElements() != rank) {
3332 return emitOptionalError(
3333 location, "the number of elements in start_indices (",
3334 attr_ty.getNumElements(), ") does not match the rank of the operand (",
3335 rank, ")");
3336 }
3337
3338 if (!attr_ty.getElementType().isSignlessInteger(64) ||
3339 slice.limit_indices().getType() != attr_ty ||
3340 slice.strides().getType() != attr_ty) {
3341 // Unfortunately we can't rely on the AllTypesMatch trait for the SliceOp
3342 // having been verified at this point. Emit an error message that matches
3343 // the one that would be reported by AllTypesMatch for a more consistent
3344 // user experience.
3345 // TODO(b/171567182): Clean this up after AllTypesMatch has been refactored.
3346 return emitOptionalError(location,
3347 "failed to verify that all of {start_indices, "
3348 "limit_indices, strides} have same type");
3349 }
3350
3351 SmallVector<int64_t, 4> start(slice.start_indices().getValues<int64_t>());
3352 SmallVector<int64_t, 4> limit(slice.limit_indices().getValues<int64_t>());
3353 SmallVector<int64_t, 4> stride_vals(slice.strides().getValues<int64_t>());
3354
3355 SmallVector<int64_t, 4> shape;
3356 shape.reserve(rank);
3357 for (int64_t i = 0, e = rank; i != e; i++) {
3358 shape.push_back(InferSliceDim(ranked_ty.getDimSize(i), start[i], limit[i],
3359 stride_vals[i]));
3360 }
3361 inferredReturnTypes.assign(
3362 {RankedTensorType::get(shape, ranked_ty.getElementType())});
3363 return success();
3364 }
3365
3366 template <typename I, typename E>
3367 static void SliceElements(I values, ArrayRef<int64_t> sizes,
3368 ArrayRef<int64_t> starts, ArrayRef<int64_t> limits,
3369 ArrayRef<int64_t> strides,
3370 llvm::SmallVectorImpl<E>* out_values) {
3371 assert(starts.size() == limits.size());
3372 assert(starts.size() == strides.size());
3373 if (starts.empty()) return;
3374
3375 int64_t start = starts.front();
3376 int64_t limit = limits.front();
3377 int64_t stride = strides.front();
3378 if (starts.size() == 1) {
3379 for (int i = start; i < limit; i += stride) {
3380 out_values->push_back(*(values + i));
3381 }
3382 return;
3383 }
3384
3385 for (; start < limit; start += stride) {
3386 auto begin = values + start * sizes.front();
3387 SliceElements<I, E>(begin, sizes.drop_front(), starts.drop_front(),
3388 limits.drop_front(), strides.drop_front(), out_values);
3389 }
3390 }
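// Illustrative example (not from the original source): for a 2x4 input with
// starts = [0, 1], limits = [2, 4], strides = [1, 2], SliceElements visits
// rows 0 and 1 and, within each row, columns 1 and 3.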
3391
3392 template <typename I, typename E>
3393 static Attribute FoldSlice(SliceOp* op, I values) {
3394 auto start = llvm::to_vector<6>(op->start_indices().getValues<int64_t>());
3395 auto limit = llvm::to_vector<6>(op->limit_indices().getValues<int64_t>());
3396 auto stride = llvm::to_vector<6>(op->strides().getValues<int64_t>());
3397
3398 auto result_type = op->operand().getType().cast<ShapedType>();
3399 if (!result_type.hasStaticShape()) return {};
3400
3401 auto shape = result_type.getShape();
3402 int64_t count = result_type.getNumElements();
3403 if (count == 0) {
3404 return DenseElementsAttr::get<E>(
3405 op->getResult().getType().cast<ShapedType>(),
3406 /*list=*/{});
3407 }
3408
3409 // Compute the striding for each dimension.
3410 llvm::SmallVector<int64_t, 6> sizes;
3411 sizes.reserve(shape.size());
3412 for (auto v : shape) {
3413 count = count / v;
3414 sizes.push_back(count);
3415 }
3416
3417 llvm::SmallVector<E, 6> out_values;
3418 out_values.reserve(result_type.getNumElements());
3419 SliceElements<I, E>(values, sizes, start, limit, stride, &out_values);
3420
3421 return DenseElementsAttr::get(op->getResult().getType().cast<ShapedType>(),
3422 out_values);
3423 }
3424
3425 OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
3426 // Check if the SliceOp is a no-op.
3427 auto operand_type = getOperand().getType().cast<ShapedType>();
3428 auto result_type = getResult().getType().cast<ShapedType>();
3429
3430 if (operand_type.hasStaticShape() && result_type.hasStaticShape() &&
3431 (operand_type.getShape() == result_type.getShape())) {
3432 return getOperand();
3433 }
3434
3435 if (operands.empty() || !operands.front()) return {};
3436
3437 // Evaluate for statically valued inputs.
3438 DenseElementsAttr elements = operands.front().dyn_cast<DenseElementsAttr>();
3439 if (!elements) return {};
3440
3441 auto etype = elements.getType().getElementType();
3442 if (etype.isa<IntegerType>()) {
3443 return FoldSlice<DenseElementsAttr::IntElementIterator, APInt>(
3444 this, elements.getIntValues().begin());
3445 } else if (etype.isa<FloatType>()) {
3446 return FoldSlice<
3447 llvm::mapped_iterator<DenseElementsAttr::IntElementIterator,
3448 std::function<APFloat(const APInt&)>>,
3449 APFloat>(this, elements.getFloatValues().begin());
3450 }
3451
3452 return {};
3453 }
3454
3455 namespace {
3456 // In cases where a concat is fed into a slice, it is possible the concat
3457 // can be simplified or bypassed. This checks which inputs to the concat are
3458 // used by the slice, either reducing the number of concatenated values or
3459 // removing the concat entirely.
3460 struct SimplifyConcatSlice : public OpRewritePattern<SliceOp> {
3461 using OpRewritePattern<SliceOp>::OpRewritePattern;
3462
3463 LogicalResult matchAndRewrite(SliceOp slice,
3464 PatternRewriter& rewriter) const override {
3465 auto result_ty = slice.getType().cast<ShapedType>();
3466 if (!result_ty.hasStaticShape()) {
3467 return failure();
3468 }
3469
3470 auto slice_input = slice.operand();
3471 auto slice_input_ty = slice_input.getType().cast<ShapedType>();
3472 auto concat = slice_input.getDefiningOp<ConcatenateOp>();
3473 if (!concat) {
3474 return failure();
3475 }
3476
3477 auto dimension = concat.dimension();
3478
3479 auto start = slice.start_indices().getIntValues();
3480 auto limit = slice.limit_indices().getIntValues();
3481
3482 auto slice_start = (*(start.begin() + dimension)).getSExtValue();
3483 auto slice_limit = (*(limit.begin() + dimension)).getSExtValue();
3484
3485 // We need to determine what inputs from the concat affect the slice, and
3486 // how the bounds of the slice need to be updated for the minimally required
3487 // inputs.
3488 int64_t running_size = 0;
3489 int64_t front_offset = slice_input_ty.getShape()[dimension];
3490
3491 auto subset_start = concat.operand_end();
3492 auto subset_end = concat.operand_end();
3493 for (auto it = concat.operand_begin(); it < concat.operand_end(); ++it) {
3494 auto input = *it;
3495 ShapedType input_ty = input.getType().cast<ShapedType>();
3496 if (input_ty.isDynamicDim(dimension)) {
3497 return failure();
3498 }
3499 auto dim_size = input_ty.getShape()[dimension];
3500
3501 // If this position is in the slice, it's the start of the subset and we
3502 // need to update the start and limit values.
3503 if (running_size + dim_size > slice_start &&
3504 subset_start == concat.operand_end()) {
3505 subset_start = it;
3506 front_offset = running_size;
3507 }
3508
3509 // Determine the last required offset.
3510 if (running_size < slice_limit) {
3511 subset_end = it + 1;
3512 }
3513
3514 running_size += dim_size;
3515 }
3516
3517 auto subset_size = subset_end - subset_start;
3518 // We need all inputs so no optimization.
3519 if (subset_size == concat.getNumOperands()) {
3520 return failure();
3521 }
3522
3523 // If there's nothing to slice that means the output is an empty tensor and
3524 // there is dead code. We do nothing here and rely on other passes to clean
3525 // this up.
3526 if (subset_size == 0) {
3527 return failure();
3528 }
3529
3530 if (subset_size > 1 && !concat.getResult().hasOneUse()) {
3531 return failure();
3532 }
3533
3534 auto concat_range = OperandRange(subset_start, subset_end);
3535 auto new_concat = rewriter.create<ConcatenateOp>(
3536 concat.getLoc(), concat_range, concat.dimension());
3537
3538 llvm::SmallVector<APInt, 6> new_start(start);
3539 llvm::SmallVector<APInt, 6> new_limit(limit);
3540 new_start[dimension] -= front_offset;
3541 new_limit[dimension] -= front_offset;
3542
3543 auto attr_type = slice.start_indices().getType().cast<ShapedType>();
3544 auto create = rewriter.create<SliceOp>(
3545 slice.getLoc(), new_concat,
3546 DenseIntElementsAttr::get(attr_type, new_start),
3547 DenseIntElementsAttr::get(attr_type, new_limit), slice.strides());
3548 rewriter.replaceOp(slice, create.getResult());
3549 return success();
3550 }
3551 };
3552 } // namespace
3553
3554 void SliceOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
3555 MLIRContext* context) {
3556 results.insert<SimplifyConcatSlice>(context);
3557 }
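// Illustrative example of SimplifyConcatSlice (not from the original source):
// slicing only the range covered by %b in
//   concatenate(%a : tensor<2xf32>, %b : tensor<3xf32>) {dimension = 0}
// rewrites to a slice over %b alone, with start/limit shifted down by the
// size of %a.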
3558
3559 //===----------------------------------------------------------------------===//
3560 // SortOp
3561 //===----------------------------------------------------------------------===//
3562
3563 void SortOp::build(OpBuilder& builder, OperationState& state,
3564 ValueRange operands, int64_t dimension, bool is_stable) {
3565 state.addOperands(operands);
3566 state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
3567 state.addAttribute("is_stable", builder.getBoolAttr(is_stable));
3568
3569 for (Value operand : operands) state.addTypes(operand.getType());
3570
3571 state.addRegion();
3572 }
3573
3574 static LogicalResult Verify(SortOp op) {
3575 Operation::operand_range operands = op.operands();
3576 if (operands.empty()) return op.emitOpError("requires at least one input");
3577
3578 // TODO(antiagainst): verify partially dynamic shapes
3579 if (llvm::all_of(operands, [](Value operand) {
3580 return operand.getType().cast<ShapedType>().hasRank();
3581 })) {
3582 ArrayRef<int64_t> input_shape =
3583 (*operands.begin()).getType().cast<ShapedType>().getShape();
3584
3585 if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value operand) {
3586 return operand.getType().cast<ShapedType>().getShape() != input_shape;
3587 }))
3588 return op.emitOpError("requires all inputs to have the same dimensions");
3589
3590 int64_t rank = input_shape.size();
3591 int64_t cmp_dim = op.dimension();
3592 if (cmp_dim < -rank || cmp_dim >= rank)
3593 return op.emitOpError("dimension attribute value must be in range [-")
3594 << rank << ", " << rank << "), but found " << cmp_dim;
3595 }
3596
3597 Block& block = op.comparator().front();
3598 size_t num_operands = op.getOperation()->getNumOperands();
3599 if (block.getNumArguments() != 2 * num_operands)
3600 return op.emitOpError("comparator block should have ")
3601 << 2 * num_operands << " arguments";
3602
3603 for (auto indexed_operand : llvm::enumerate(operands)) {
3604 int index = indexed_operand.index();
3605 Type element_type =
3606 indexed_operand.value().getType().cast<ShapedType>().getElementType();
3607 Type tensor_type = RankedTensorType::get({}, element_type);
3608 for (int i : {2 * index, 2 * index + 1}) {
3609 Type arg_type = block.getArgument(i).getType();
3610 if (arg_type != tensor_type)
3611 return op.emitOpError("comparator block argument #")
3612 << i << " should be of type " << tensor_type << " but got "
3613 << arg_type;
3614 }
3615 }
3616
3617 return success();
3618 }
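// Illustrative example (not from the original source): sorting two
// tensor<16xf32> operands requires a comparator block with four rank-0
// arguments (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>), two per
// operand.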
3619
3620 /// Drops those operands whose results are unused and whose corresponding
3621 /// arguments in op.comparator() are unused.
3622 static LogicalResult SortDropEmptyUseArgs(SortOp op,
3623 PatternRewriter& rewriter) {
3624 DenseSet<unsigned> erased_args;
3625 unsigned num_operands = op.getNumOperands();
3626 for (unsigned i = 0; i < num_operands; ++i) {
3627 if (!op.getResult(i).use_empty()) continue;
3628 Block& block = op.comparator().front();
3629 if (!block.getArgument(i * 2).use_empty()) continue;
3630 if (!block.getArgument(i * 2 + 1).use_empty()) continue;
3631 erased_args.insert(i);
3632 }
3633 if (erased_args.empty()) return failure();
3634
3635 SmallVector<Value> new_operands;
3636 SmallVector<unsigned> erased_block_args;
3637 for (auto en : llvm::enumerate(op.operands())) {
3638 if (erased_args.contains(en.index())) {
3639 erased_block_args.push_back(en.index() * 2);
3640 erased_block_args.push_back(en.index() * 2 + 1);
3641 } else {
3642 new_operands.push_back(en.value());
3643 }
3644 }
3645
3646 auto new_op = rewriter.create<SortOp>(op.getLoc(), new_operands,
3647 op.dimension(), op.is_stable());
3648 Region& region = new_op.comparator();
3649 rewriter.inlineRegionBefore(op.comparator(), region, region.end());
3650 region.front().eraseArguments(erased_block_args);
3651
3652 SmallVector<Value> results;
3653 for (unsigned i = 0, j = 0; i < num_operands; ++i) {
3654 if (erased_args.contains(i)) {
3655 results.push_back({});
3656 } else {
3657 results.push_back(new_op.getResult(j++));
3658 }
3659 }
3660 rewriter.replaceOp(op, results);
3661
3662 return success();
3663 }
3664
3665 void SortOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
3666 MLIRContext* /*context*/) {
3667 results.insert(SortDropEmptyUseArgs);
3668 }
3669
3670 //===----------------------------------------------------------------------===//
3671 // TransposeOp
3672 //===----------------------------------------------------------------------===//
3673
3674 OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
3675 for (auto it : llvm::enumerate(permutation().getValues<APInt>())) {
3676 if (it.index() != it.value()) {
3677 return {};
3678 }
3679 }
3680 return getOperand();
3681 }
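// Illustrative example: a transpose with the identity permutation
// dense<[0, 1, 2]> folds to its operand; any other permutation is left as-is.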
3682
3683 static LogicalResult Verify(TransposeOp op) {
3684 // permutation is an attribute of the op so it has static shape.
3685 auto permutationType = op.permutation().getType();
3686 auto permutationRank = permutationType.getRank();
3687 if (permutationRank != 1) {
3688 return op.emitOpError(llvm::formatv(
3689 "permutation has rank {0} instead of rank 1", permutationRank));
3690 }
3691 auto permutationSize = permutationType.getNumElements();
3692
3693 auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
3694 if (operandType) {
3695 auto operandRank = operandType.getRank();
3696 if (operandRank != permutationSize) {
3697 return op.emitOpError(llvm::formatv(
3698 "operand rank ({0}) does not match permutation size ({1})",
3699 operandRank, permutationSize));
3700 }
3701 }
3702
3703 auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
3704 if (resultType) {
3705 auto resultRank = resultType.getRank();
3706 if (resultRank != permutationSize) {
3707 return op.emitOpError(llvm::formatv(
3708 "result rank ({0}) does not match permutation size ({1})", resultRank,
3709 permutationSize));
3710 }
3711 }
3712
3713 if (!resultType || !operandType) return success();
3714
3715 auto operandRank = operandType.getRank();
3716 SmallVector<int64_t, 4> expectedShape(operandRank);
3717 for (int i = 0; i != operandRank; ++i) {
3718 auto permutedDim = op.permutation().getValue<IntegerAttr>(i).getInt();
3719 expectedShape[i] = operandType.getDimSize(permutedDim);
3720 }
3721
3722 auto expectedType =
3723 RankedTensorType::get(expectedShape, resultType.getElementType());
3724 if (failed(verifyCompatibleShape(resultType, expectedType))) {
3725 return op.emitOpError(llvm::formatv(
3726 "result type {0} is incompatible with the expected type {1}",
3727 resultType, expectedType));
3728 }
3729
3730 return success();
3731 }
3732
3733 LogicalResult TransposeOp::reifyReturnTypeShapes(
3734 OpBuilder& builder, ValueRange operands,
3735 SmallVectorImpl<Value>& reifiedReturnShapes) {
3736 TransposeOp::Adaptor adaptor(operands);
3737 Value operand = adaptor.operand();
3738
3739 auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
3740 // Unranked types are not supported at the moment.
3741 if (!operand_type) return failure();
3742
3743 Location loc = this->getLoc();
3744 SmallVector<int64_t, 4> permutation(this->permutation().getValues<int64_t>());
3745 SmallVector<Value, 4> shape_values(permutation.size());
3746
3747 Type shape_scalar_type = builder.getIndexType();
3748 auto to_shape_scalar_type = [&](Value v) {
3749 return MaybeCastTo(builder, loc, v, shape_scalar_type);
3750 };
3751
3752 for (const auto& element : llvm::enumerate(operand_type.getShape())) {
3753 int64_t idx = element.index();
3754 auto it = std::find(permutation.begin(), permutation.end(), idx);
3755 Value value_dim = to_shape_scalar_type(
3756 builder.create<tensor::DimOp>(loc, operand, element.index()));
3757 shape_values[std::distance(permutation.begin(), it)] = value_dim;
3758 }
3759
3760 Value output_shape = builder.create<tensor::FromElementsOp>(
3761 loc, shape_scalar_type, shape_values);
3762 reifiedReturnShapes.push_back(output_shape);
3763
3764 return success();
3765 }
3766
3767 //===----------------------------------------------------------------------===//
3768 // TriangularSolveOp
3769 //===----------------------------------------------------------------------===//
3770
3771 static LogicalResult Verify(TriangularSolveOp op) {
3772 auto a_type = op.a().getType().dyn_cast<RankedTensorType>();
3773
3774 // Skip verifier if a is unranked tensor.
3775 if (!a_type) return success();
3776
3777 // Check that 'a' has rank >= 2.
3778 auto a_rank = a_type.getRank();
3779 if (a_rank < 2)
3780 return op.emitOpError()
3781 << "operand 'a' must have rank >= 2, but got " << a_type;
3782
3783 // The two minor dimensions of a must have same size.
3784 if (a_type.getDimSize(a_rank - 2) != a_type.getDimSize(a_rank - 1))
3785 return op.emitOpError() << "two minor dimensions of operand 'a' must have "
3786 "equal size, but got "
3787 << a_type;
3788
3789 auto b_type = op.b().getType().dyn_cast<RankedTensorType>();
3790 // If b is unranked skip remaining checks.
3791 if (!b_type) return success();
3792
3793 // Check that a and b have same rank.
3794 auto b_rank = b_type.getRank();
3795 if (a_rank != b_rank)
3796 return op.emitOpError() << "operands must have equal rank, but got "
3797 << a_type << " and " << b_type;
3798
3799 // The shared dimension of a and b should match.
3800 if (a_type.getDimSize(a_rank - 1) !=
3801 b_type.getDimSize(b_rank - (op.left_side() ? 2 : 1)))
3802 return op.emitOpError() << "shared dimension of operands 'a' and 'b' does "
3803 "not match, but got "
3804 << a_type << " and " << b_type;
3805
3806 // The leading batch dimensions of a and b must be equal.
3807 auto a_batch_dims = a_type.getShape().drop_back(2);
3808 auto b_batch_dims = b_type.getShape().drop_back(2);
3809 if (a_batch_dims != b_batch_dims)
3810 return op.emitOpError()
3811 << "leading batch dimensions of the operands must be the same, but got "
3812 << a_type << " and " << b_type;
3813
3814 // Result and argument b must have same shape.
3815 auto result_type = op.getType().dyn_cast<RankedTensorType>();
3816 if (!result_type) return success();
3817 if (result_type != b_type)
3818 return op.emitOpError()
3819 << "result and operand 'b' must have same shape, but got "
3820 << result_type << " and " << b_type;
3821 return success();
3822 }
3823
3824 //===----------------------------------------------------------------------===//
3825 // GetTupleElementOp
3826 //===----------------------------------------------------------------------===//
3827
3828 void GetTupleElementOp::build(OpBuilder& builder, OperationState& result,
3829 Value tuple, int32_t index) {
3830 if (auto tuple_type = tuple.getType().dyn_cast<TupleType>()) {
3831 auto element_type = tuple_type.getType(index);
3832 build(builder, result, element_type, tuple,
3833 builder.getI32IntegerAttr(index));
3834 return;
3835 }
3836
3837 build(builder, result, tuple.getType(), tuple,
3838 builder.getI32IntegerAttr(index));
3839 }
3840
3841 //===----------------------------------------------------------------------===//
3842 // TupleOp
3843 //===----------------------------------------------------------------------===//
3844
3845 void TupleOp::build(OpBuilder& builder, OperationState& result,
3846 ValueRange values) {
3847 SmallVector<Type, 4> types;
3848 types.reserve(values.size());
3849 for (auto val : values) {
3850 types.push_back(val.getType());
3851 }
3852
3853 build(builder, result, builder.getTupleType(types), values);
3854 }
3855
3856 //===----------------------------------------------------------------------===//
3857 // UnaryEinsumOp
3858 //===----------------------------------------------------------------------===//
3859
3860 void UnaryEinsumOp::getCanonicalizationPatterns(
3861 OwningRewritePatternList& results, MLIRContext* context) {
3862 results.insert<UnaryEinsumToEinsum>(context);
3863 }
3864
3865 //===----------------------------------------------------------------------===//
3866 // CompareOp
3867 //===----------------------------------------------------------------------===//
3868
3869 void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
3870 Value rhs, StringAttr comparison_direction,
3871 StringAttr compare_type) {
3872 auto new_type =
3873 UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type());
3874 build(builder, result, new_type, lhs, rhs, comparison_direction,
3875 compare_type);
3876 }
3877
3878 LogicalResult CompareOp::inferReturnTypeComponents(
3879 mlir::MLIRContext* ctx, llvm::Optional<mlir::Location>,
3880 ValueShapeRange operands, mlir::DictionaryAttr, mlir::RegionRange,
3881 llvm::SmallVectorImpl<mlir::ShapedTypeComponents>& inferredReturnTypes) {
3882 OpBuilder builder(ctx);
3883 auto arg_ty = operands.front().getType().cast<TensorType>();
3884 inferredReturnTypes.push_back({arg_ty.getShape(), builder.getI1Type()});
3885 return success();
3886 }
3887
3888 LogicalResult CompareOp::reifyReturnTypeShapes(
3889 OpBuilder& builder, ValueRange operands,
3890 SmallVectorImpl<Value>& reifiedReturnShapes) {
3891 return deriveShapeFromOperand(&builder, getOperation(), operands.front(),
3892 &reifiedReturnShapes);
3893 }
3894
3895 template <typename T>
3896 struct less : std::less<T> {};
3897
3898 template <>
3899 struct less<APInt> {
3900 bool operator()(const APInt& a, const APInt& b) const { return a.slt(b); }
3901 };
3902
3903 template <typename T>
3904 struct less_equal : std::less_equal<T> {};
3905
3906 template <>
3907 struct less_equal<APInt> {
3908 bool operator()(const APInt& a, const APInt& b) const { return a.sle(b); }
3909 };
3910
3911 template <typename T>
3912 struct greater : std::greater<T> {};
3913
3914 template <>
3915 struct greater<APInt> {
3916 bool operator()(const APInt& a, const APInt& b) const { return a.sgt(b); }
3917 };
3918
3919 template <typename T>
3920 struct greater_equal : std::greater_equal<T> {};
3921
3922 template <>
3923 struct greater_equal<APInt> {
3924 bool operator()(const APInt& a, const APInt& b) const { return a.sge(b); }
3925 };
3926
3927 template <typename Op, typename ElementType, typename SrcType, typename Convert>
3928 static Attribute CompareFolder(CompareOp op, ArrayRef<Attribute> attrs) {
3929 if (!attrs[0] || !attrs[1]) return {};
3930
3931 DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>();
3932 DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>();
3933 if (!lhs || !rhs) return {};
3934
3935 ShapedType operand_type =
3936 op.getOperand(0).getType().template cast<ShapedType>();
3937 if (!operand_type.hasStaticShape()) {
3938 return {};
3939 }
3940
3941 if (!operand_type.getElementType().isa<ElementType>()) {
3942 return {};
3943 }
3944
3945 SmallVector<bool, 6> values;
3946 values.reserve(lhs.getNumElements());
3947 for (const auto zip :
3948 llvm::zip(lhs.getValues<SrcType>(), rhs.getValues<SrcType>())) {
3949 values.push_back(Convert()(std::get<0>(zip), std::get<1>(zip)));
3950 }
3951
3952 auto result_ty = op.getType().cast<ShapedType>();
3953 return DenseElementsAttr::get(result_ty, values);
3954 }
3955
3956 OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands) {
3957 auto result_ty = getType().cast<ShapedType>();
3958 if (!result_ty.hasStaticShape()) return {};
3959
3960 auto direction = comparison_direction();
3961 if (lhs() == rhs() && !getElementTypeOrSelf(lhs()).isa<FloatType>()) {
3962 if (direction == "LE" || direction == "EQ" || direction == "GE") {
3963 return DenseIntElementsAttr::get(result_ty, {true});
3964 }
3965 return DenseIntElementsAttr::get(result_ty, {false});
3966 }
3967
3968 auto op_el_type = lhs().getType().cast<ShapedType>().getElementType();
3969 // Fold tensor<*xi1> != false to just return tensor<*xi1>
3970 if (direction == "NE" && op_el_type.isInteger(1)) {
3971 DenseIntElementsAttr cst_attr;
3972 if (matchPattern(lhs(), m_Constant(&cst_attr))) {
3973 if (cst_attr.isSplat() && !cst_attr.getSplatValue<bool>()) {
3974 return rhs();
3975 }
3976 }
3977
3978 if (matchPattern(rhs(), m_Constant(&cst_attr))) {
3979 if (cst_attr.isSplat() && !cst_attr.getSplatValue<bool>()) {
3980 return lhs();
3981 }
3982 }
3983 }
3984
3985 // Fold tensor<*xi1> == True to just return tensor<*xi1>
3986 if (direction == "EQ" && op_el_type.isInteger(1)) {
3987 DenseIntElementsAttr cst_attr;
3988 if (matchPattern(lhs(), m_Constant(&cst_attr))) {
3989 if (cst_attr.isSplat() && cst_attr.getSplatValue<bool>()) {
3990 return rhs();
3991 }
3992 }
3993
3994 if (matchPattern(rhs(), m_Constant(&cst_attr))) {
3995 if (cst_attr.isSplat() && cst_attr.getSplatValue<bool>()) {
3996 return lhs();
3997 }
3998 }
3999 }
4000
4001 if (!operands[0] || !operands[1]) {
4002 return {};
4003 }
4004
4005 #define COMPARE_FOLDER(Op, comparison, Func) \
4006 if (direction == comparison) { \
4007 if (auto folded = CompareFolder<Op, FloatType, APFloat, Func<APFloat>>( \
4008 *this, operands)) \
4009 return folded; \
4010 if (auto folded = CompareFolder<Op, IntegerType, APInt, Func<APInt>>( \
4011 *this, operands)) \
4012 return folded; \
4013 }
4014
4015 COMPARE_FOLDER(CompareOp, "EQ", std::equal_to);
4016 COMPARE_FOLDER(CompareOp, "NE", std::not_equal_to);
4017 COMPARE_FOLDER(CompareOp, "LT", less);
4018 COMPARE_FOLDER(CompareOp, "LE", less_equal);
4019 COMPARE_FOLDER(CompareOp, "GT", greater);
4020 COMPARE_FOLDER(CompareOp, "GE", greater_equal);
4021 #undef COMPARE_FOLDER
4022
4023 return {};
4024 }
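// Illustrative examples of the folds above (not from the original source):
// "EQ"(%x, %x) on an integer tensor folds to a true splat; "NE"(%x, c) with a
// false splat constant c on an i1 tensor folds to %x; and a comparison of two
// constants folds elementwise via CompareFolder.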
4025
4026 //===----------------------------------------------------------------------===//
4027 // ScatterOp
4028 //===----------------------------------------------------------------------===//
4029
4030 llvm::SmallVector<Attribute, 4> evaluateMhloRegion(Region& region,
4031 ArrayRef<Attribute> inputs) {
4032 if (region.getNumArguments() != inputs.size()) return {};
4033
4034 llvm::DenseMap<Value, Attribute> values;
4035 values.reserve(region.getNumArguments());
4036 for (auto it : llvm::zip(region.getArguments(), inputs)) {
4037 values.try_emplace(std::get<0>(it), std::get<1>(it));
4038 }
4039
4040 for (auto& op : region.getOps()) {
4041 llvm::SmallVector<Attribute, 4> inputs;
4042 for (auto& operand : op.getOpOperands()) {
4043 inputs.push_back(values.lookup(operand.get()));
4044 }
4045 if (isa<ReturnOp>(op)) return inputs;
4046
4047 llvm::SmallVector<OpFoldResult, 4> results;
4048 if (failed(op.fold(inputs, results))) return {};
4049 for (auto it : llvm::zip(op.getResults(), results)) {
4050 if (!std::get<1>(it).is<Attribute>()) return {};
4051 values.insert({std::get<0>(it), std::get<1>(it).get<Attribute>()});
4052 }
4053 }
4054 return {};
4055 }
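// Sketch of how the folder below uses evaluateMhloRegion (illustrative): for
// an update_computation { %r = mhlo.add(%lhs, %rhs); return %r } evaluated on
// scalar attributes dense<2> and dense<3>, it returns {dense<5>}.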
4056
4057 OpFoldResult ScatterOp::fold(ArrayRef<Attribute> operands) {
4058 auto base = operands[0].dyn_cast_or_null<DenseElementsAttr>();
4059 auto index = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
4060 auto update = operands[2].dyn_cast_or_null<DenseElementsAttr>();
4061 if (!base || !index || !update) return {};
4062
4063 auto base_type = base.getType().dyn_cast<RankedTensorType>();
4064 auto index_type = index.getType().dyn_cast<RankedTensorType>();
4065 auto update_type = update.getType().dyn_cast<RankedTensorType>();
4066 if (!base_type || !index_type || !update_type) return {};
4067
4068 // Add the virtual trailing dimension of size 1 if index_vector_dim equals
4069 // index_type.rank.
4070 const int64_t index_vector_dim =
4071 scatter_dimension_numbers().index_vector_dim().getInt();
4072 if (index_vector_dim == index_type.getRank()) {
4073 auto index_shape = index_type.getShape().vec();
4074 index_shape.push_back(1);
4075 index_type =
4076 RankedTensorType::get(index_shape, index_type.getElementType());
4077 index = index.reshape(index_type).cast<DenseIntElementsAttr>();
4078 }
4079
4080 // Increments the multi-dimensional index vector based on the limits for
4081 // each dimension specified by shape; returns false if the index rolled
4082 // around, true otherwise.
4083 auto next_index = [](llvm::SmallVector<uint64_t, 8>& index,
4084 llvm::ArrayRef<int64_t> shape) {
4085 for (int64_t i = index.size() - 1; i >= 0; --i) {
4086 ++index[i];
4087 if (index[i] < shape[i]) return true;
4088 index[i] = 0;
4089 }
4090 return false;
4091 };
4092
4093 // Iterate over all elements of the update tensor, then find the corresponding
4094 // value in the indices tensor to determine which location we have to update
4095 // in the base/result tensor.
4096 llvm::SmallVector<Attribute, 8> results(base.getValues<Attribute>());
4097 llvm::SmallVector<uint64_t, 8> update_index(update_type.getRank(), 0);
4098 llvm::SmallVector<uint64_t, 8> index_index;
4099 index_index.reserve(index_type.getRank());
4100 llvm::SmallVector<uint64_t, 8> base_index;
4101 base_index.reserve(base_type.getRank());
4102 do {
4103 // Compute the index for the slice of the indices tensor for this update
4104 // value.
4105 index_index.clear();
4106 if (index_vector_dim == 0) index_index.push_back(0);
4107 for (int64_t i = 0; i < update_index.size(); ++i) {
4108 if (llvm::count(scatter_dimension_numbers().update_window_dims(), i) == 0)
4109 index_index.push_back(update_index[i]);
4110 if (index_index.size() == index_vector_dim) index_index.push_back(0);
4111 }
4112
4113 // Compute the index for the given update value in the base tensor.
4114 base_index.assign(base_type.getRank(), 0);
4115 uint64_t index_count = index_type.getShape()[index_vector_dim];
4116 for (uint64_t i = 0; i < index_count; ++i) {
4117 uint64_t operand_dim = scatter_dimension_numbers()
4118 .scatter_dims_to_operand_dims()
4119 .getValue<APInt>({i})
4120 .getSExtValue();
4121 index_index[index_vector_dim] = i;
4122 base_index[operand_dim] +=
4123 index.getValue<APInt>(index_index).getSExtValue();
4124 }
4125 uint64_t update_window_dim_index = 0;
4126 for (uint64_t i = 0; i < base_index.size(); ++i) {
4127 if (llvm::count(scatter_dimension_numbers().inserted_window_dims(), i))
4128 continue;
4129 base_index[i] +=
4130 update_index[scatter_dimension_numbers()
4131 .update_window_dims()
4132 .getValue<APInt>({update_window_dim_index})
4133 .getSExtValue()];
4134 update_window_dim_index++;
4135 }
4136
4137 // Compute the linear index for the index into the base tensor.
4138 int64_t linear_base_index = 0;
4139 int64_t linear_base_index_multiplier = 1;
4140 for (int64_t i = base_index.size() - 1; i >= 0; --i) {
4141 // Out-of-bounds indices have backend-specific behavior, so avoid folding.
4142 if (base_index[i] < 0 || base_index[i] >= base_type.getShape()[i])
4143 return {};
4144 linear_base_index += base_index[i] * linear_base_index_multiplier;
4145 linear_base_index_multiplier *= base_type.getShape()[i];
4146 }
4147
4148 // Evaluate update computation and update the value with the newly computed
4149 // attribute in the base tensor.
4150 auto lhs = DenseElementsAttr::get(
4151 RankedTensorType::get({}, base_type.getElementType()),
4152 results[linear_base_index]);
4153 auto rhs = DenseElementsAttr::get(
4154 RankedTensorType::get({}, base_type.getElementType()),
4155 update.getValue<Attribute>(update_index));
4156 auto new_value = evaluateMhloRegion(update_computation(), {lhs, rhs});
4157 if (new_value.size() != 1 || !new_value[0]) return {};
4158 results[linear_base_index] =
4159 new_value[0].cast<DenseElementsAttr>().getValue<Attribute>({});
4160 } while (next_index(update_index, update_type.getShape()));
4161
4162 return DenseElementsAttr::get(base_type, results);
4163 }
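// Illustrative example (not from the original source): scattering an update
// value 9 at index 1 into base [0, 0, 0] with an 'add' update_computation
// folds to [0, 9, 0], since update_computation is applied to the old value
// and the update (0 + 9).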
4164
4165 using mlir::hlo::parseWindowAttributes;
4166 using mlir::hlo::printWindowAttributes;
4167
4168 } // namespace mhlo
4169 } // namespace mlir
4170
4171 #define GET_OP_CLASSES
4172 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
4173
4174 namespace mlir {
4175 namespace mhlo {
4176
4177 //===----------------------------------------------------------------------===//
4178 // mhlo Dialect Interfaces
4179 //===----------------------------------------------------------------------===//
4180
4181 namespace {
4182 struct HLOInlinerInterface : public DialectInlinerInterface {
4183 using DialectInlinerInterface::DialectInlinerInterface;
4184
4185 // Allow all call operations to be inlined.
4186 bool isLegalToInline(Operation* call, Operation* callable,
4187 bool wouldBeCloned) const final {
4188 return true;
4189 }
4190 // We don't have any special restrictions on what can be inlined into
4191 // destination regions (e.g. while/conditional bodies). Always allow it.
4192 bool isLegalToInline(Region* dest, Region* src, bool wouldBeCloned,
4193 BlockAndValueMapping& valueMapping) const final {
4194 return true;
4195 }
4196 // Operations in mhlo dialect are always legal to inline since they are
4197 // pure.
4198 bool isLegalToInline(Operation*, Region*, bool,
4199 BlockAndValueMapping&) const final {
4200 return true;
4201 }
4202 };
4203 } // end anonymous namespace
4204
4205 //===----------------------------------------------------------------------===//
4206 // mhlo Dialect Constructor
4207 //===----------------------------------------------------------------------===//
4208
4209 MhloDialect::MhloDialect(MLIRContext* context)
4210 : Dialect(getDialectNamespace(), context, TypeID::get<MhloDialect>()) {
4211 addOperations<
4212 #define GET_OP_LIST
4213 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
4214 >();
4215 addInterfaces<HLOInlinerInterface>();
4216 addTypes<TokenType>();
4217 context->loadDialect<tensor::TensorDialect>();
4218 }
4219
4220 Type MhloDialect::parseType(DialectAsmParser& parser) const {
4221 StringRef data_type;
4222 if (parser.parseKeyword(&data_type)) return Type();
4223
4224 if (data_type == "token") return TokenType::get(getContext());
4225 parser.emitError(parser.getNameLoc()) << "unknown mhlo type: " << data_type;
4226 return nullptr;
4227 }
4228
4229 void MhloDialect::printType(Type type, DialectAsmPrinter& os) const {
4230 if (type.isa<TokenType>()) {
4231 os << "token";
4232 return;
4233 }
4234 os << "<unknown mhlo type>";
4235 }
4236
4237 //===----------------------------------------------------------------------===//
4238 // Shape inference
4239 //===----------------------------------------------------------------------===//
4240
4241 LogicalResult deriveShapeFromOperand(
4242 OpBuilder* builder, Operation* op, Value operand,
4243 SmallVectorImpl<Value>* reifiedReturnShapes) {
4244 auto shaped_ty = operand.getType().dyn_cast<ShapedType>();
4245 if (!shaped_ty) {
4246 op->emitOpError() << "operand is not a shaped type";
4247 return failure();
4248 }
4249 reifiedReturnShapes->assign(
4250 {builder->create<shape::ShapeOfOp>(op->getLoc(), operand)});
4251 return success();
4252 }
4253
4254 } // namespace mhlo
4255 } // namespace mlir
4256