/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines the operations used in the MHLO dialect.

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <algorithm>
#include <functional>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
#include "mlir-hlo/utils/convert_op_folder.h"
#include "mlir-hlo/utils/hlo_utils.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/InliningUtils.h"

namespace mlir {
#include "hlo_patterns.cc.inc"
}  // namespace mlir

namespace mlir {
namespace mhlo {

Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value,
                                            Type type, Location loc) {
  // HLO dialect constants only support ElementsAttr unlike standard dialect
  // constant which supports all attributes.
  if (value.isa<ElementsAttr>())
    return builder.create<mhlo::ConstOp>(loc, type, value.cast<ElementsAttr>());
  return nullptr;
}

template <typename T>
static LogicalResult Verify(T op) {
  return success();
}

namespace {

//===----------------------------------------------------------------------===//
// Utilities for the canonicalize patterns
//===----------------------------------------------------------------------===//

// Verifies that the dimension attribute for the op correctly indexes into the
// operand or result shape.
template <typename OpT>
static LogicalResult VerifyDimAttr(OpT op) {
  int64_t rank = -1;
  if (auto ty = op.operand().getType().template dyn_cast<RankedTensorType>()) {
    rank = ty.getRank();
  } else if (auto ty = op.getType().template dyn_cast<RankedTensorType>()) {
    rank = ty.getRank();
  } else {
    return success();
  }

  int64_t dim = op.dimension();
  if (dim < 0 || dim >= rank)
    return op.emitOpError() << "requires dimension attribute in range [0, "
                            << rank << "); found (" << dim << ")";
  return success();
}

// Returns 1D 64-bit dense elements attribute with the given values.
DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
                                        Builder* builder) {
  RankedTensorType ty = RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, values);
}

// Given the start indices and slice sizes for a dynamic-slice that can be
// converted to a static slice, returns the limits for the static slice.
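// For example, start indices [1, 2] with slice sizes [3, 4] produce the
// exclusive limits [4, 6].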
DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices,
                                      DenseIntElementsAttr slice_sizes,
                                      Builder* builder) {
  SmallVector<int64_t, 4> slice_limits;
  for (int64_t i = 0; i < slice_sizes.getNumElements(); ++i) {
    int64_t start_index = start_indices.getValue<IntegerAttr>(i).getInt();
    int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
    slice_limits.push_back(start_index + slice_size);
  }
  return GetI64ElementsAttr(slice_limits, builder);
}

#include "mhlo_canonicalize.inc"
}  // namespace

//===----------------------------------------------------------------------===//
// ConstOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");

  // Return the held attribute value.
  return value();
}

// Builds a constant op with the specified attribute `value`.
void ConstOp::build(OpBuilder& builder, OperationState& result,
                    Attribute value) {
  Type type;
  if (auto elemAttr = value.dyn_cast<ElementsAttr>()) {
    type = elemAttr.getType();
  } else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
             value.isa<IntegerAttr>()) {
    // All XLA types must be tensor types. In the build() method, we want to
    // provide more flexibility by allowing attributes of scalar types. But we
    // need to wrap it up with ElementsAttr to construct valid XLA constants.
    type = RankedTensorType::get(/*shape=*/{}, value.getType());
    value = DenseElementsAttr::get(type.cast<TensorType>(), value);
  }

  // TODO: support other XLA specific types.
  assert(type && "unsupported attribute type for building mhlo.constant");
  result.types.push_back(type);
  result.addAttribute("value", value);
}

//===----------------------------------------------------------------------===//
// DotGeneralOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(DotGeneralOp op) {
  auto dot_dimension_numbers = op.dot_dimension_numbers();
  int64_t lhs_batching_dimensions_size = llvm::size(
      dot_dimension_numbers.lhs_batching_dimensions().getValues<int64_t>());
  int64_t rhs_batching_dimensions_size = llvm::size(
      dot_dimension_numbers.rhs_batching_dimensions().getValues<int64_t>());
  if (lhs_batching_dimensions_size != rhs_batching_dimensions_size) {
    return op.emitError()
           << "lhs and rhs should have the same number of batching dimensions";
  }
  int64_t lhs_contracting_dimensions_size = llvm::size(
      dot_dimension_numbers.lhs_contracting_dimensions().getValues<int64_t>());
  int64_t rhs_contracting_dimensions_size = llvm::size(
      dot_dimension_numbers.rhs_contracting_dimensions().getValues<int64_t>());
  if (lhs_contracting_dimensions_size != rhs_contracting_dimensions_size) {
    return op.emitError() << "lhs and rhs should have the same number of "
                             "contracting dimensions";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

// Converts gather ops to slice ops in case we have a single set of constant
// indices.
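// As a sketch (attribute syntax abbreviated; %operand and %indices are
// hypothetical values): a gather whose start_indices are a single constant
// row, e.g.
//   %indices = mhlo.constant dense<[1]> : tensor<1xi64>
//   %0 = "mhlo.gather"(%operand, %indices) {slice_sizes = dense<[1, 4]>, ...}
// becomes a static mhlo.slice over the mapped operand dimensions, followed
// by an mhlo.reshape when collapsed_slice_dims is non-empty.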
struct GatherSlice : public OpRewritePattern<GatherOp> {
  using OpRewritePattern<GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter& rewriter) const override {
    DenseIntElementsAttr index;
    if (!matchPattern(gather.start_indices(), m_Constant(&index)))
      return failure();

    const auto& dnums = gather.dimension_numbers();
    if (dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1)
      return failure();

    // TODO(tberghammer): Remove this once the verifier catches this case,
    // which is invalid if all of the previous conditions hold.
    if (index.getNumElements() != dnums.start_index_map().getNumElements())
      return failure();

    auto slice_end =
        llvm::to_vector<8>(gather.slice_sizes().getValues<int64_t>());
    llvm::SmallVector<int64_t, 8> slice_start(slice_end.size(), 0);
    for (auto it : llvm::zip(dnums.start_index_map().getIntValues(),
                             index.getIntValues())) {
      int64_t map_index = std::get<0>(it).getSExtValue();
      int64_t offset = std::get<1>(it).getSExtValue();
      slice_start[map_index] += offset;
      slice_end[map_index] += offset;
    }

    llvm::SmallVector<int64_t, 8> slice_stride(slice_end.size(), 1);
    llvm::SmallVector<int64_t, 8> slice_shape(slice_end.size());
    for (int64_t i = 0; i < slice_end.size(); ++i) {
      slice_shape[i] = slice_end[i] - slice_start[i];
    }
    Type element_type = gather.getType().cast<TensorType>().getElementType();
    auto slice_type = RankedTensorType::get(slice_shape, element_type);
    Value result = rewriter.create<SliceOp>(
        gather.getLoc(), slice_type, gather.getOperand(0),
        GetI64ElementsAttr(slice_start, &rewriter),
        GetI64ElementsAttr(slice_end, &rewriter),
        GetI64ElementsAttr(slice_stride, &rewriter));

    if (dnums.collapsed_slice_dims().getNumElements() > 0) {
      auto collapsed_slice_dims = llvm::to_vector<8>(llvm::map_range(
          dnums.collapsed_slice_dims().getIntValues(),
          [](const llvm::APInt& i) { return i.getSExtValue(); }));
      llvm::SmallVector<int64_t, 8> reshape_shape;
      for (int64_t i = 0; i < slice_shape.size(); ++i) {
        if (llvm::count(collapsed_slice_dims, i) == 0) {
          reshape_shape.push_back(slice_shape[i]);
        }
      }
      auto reshape_type = RankedTensorType::get(reshape_shape, element_type);
      result =
          rewriter.create<ReshapeOp>(gather.getLoc(), reshape_type, result);
    }

    result.setType(gather.getType());
    rewriter.replaceOp(gather, result);
    return success();
  }
};

void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                           MLIRContext* context) {
  results.insert<GatherSlice>(context);
}

//===----------------------------------------------------------------------===//
// GetDimensionSizeOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(GetDimensionSizeOp op) { return VerifyDimAttr(op); }

/// Folds get_dimension_size when the queried shape dimension is static.
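/// For example, given an operand of type tensor<4x?xf32>, dimension = 0 folds
/// to dense<4> : tensor<i32>, while dimension = 1 stays unfolded because that
/// size is only known at runtime.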
OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
  RankedTensorType type = operand().getType().dyn_cast<RankedTensorType>();
  if (!type) return {};

  int32_t dim = dimension();
  if (type.isDynamic(dim)) return {};
  // The result type is always a 0-d i32 tensor.
  return DenseIntElementsAttr::get<int32_t>(
      getResult().getType().cast<RankedTensorType>(), type.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// IotaOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(IotaOp op) {
  auto shape = op.getType().cast<ShapedType>();
  if (!shape.hasRank()) return success();

  if (shape.getRank() == 0)
    return op.emitOpError() << "does not support scalars.";

  auto iota_dimension = op.iota_dimension();
  if (iota_dimension >= shape.getRank() || iota_dimension < 0)
    return op.emitOpError() << "iota dimension cannot go beyond the output "
                               "rank or be negative.";
  return success();
}

// Iota operations across multiple dimensions can be reduced to an iota and a
// ranked broadcast.
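// As a sketch of the rewrite (types chosen for illustration):
//   %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3xi32>
// becomes
//   %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<3xi32>
//   %1 = "mhlo.broadcast_in_dim"(%0)
//        {broadcast_dimensions = dense<[1]> : tensor<1xi64>}
//        : (tensor<3xi32>) -> tensor<2x3xi32>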
struct IotaBroadcast : public OpRewritePattern<IotaOp> {
  using OpRewritePattern<IotaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IotaOp iota,
                                PatternRewriter& rewriter) const override {
    auto result_ty = iota.getType().cast<ShapedType>();
    if (!result_ty.hasRank() || result_ty.getRank() < 2) {
      return failure();
    }

    auto iota_dimension = iota.iota_dimension();

    auto iota_type = RankedTensorType::get(
        {result_ty.getDimSize(iota_dimension)}, result_ty.getElementType());

    auto new_iota = rewriter.create<IotaOp>(iota.getLoc(), iota_type,
                                            rewriter.getI64IntegerAttr(0));

    auto broadcast_attr = DenseIntElementsAttr::get(
        RankedTensorType::get({1}, rewriter.getIntegerType(64)),
        {iota_dimension});
    rewriter.replaceOpWithNewOp<BroadcastInDimOp>(iota, result_ty, new_iota,
                                                  broadcast_attr);
    return success();
  }
};

void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                         MLIRContext* context) {
  results.insert<IotaBroadcast>(context);
}

OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
  auto dimension = iota_dimension();
  auto result_ty = getResult().getType().cast<ShapedType>();
  if (result_ty.hasRank() && result_ty.getDimSize(dimension) == 1) {
    Builder builder(getContext());
    return builder.getZeroAttr(result_ty);
  }

  return {};
}

//===----------------------------------------------------------------------===//
// DynamicIotaOp
//===----------------------------------------------------------------------===//

namespace {

struct DynamicIotaIsStatic : public OpRewritePattern<DynamicIotaOp> {
  using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicIotaOp iota,
                                PatternRewriter& rewriter) const override {
    auto result_ty = iota.getType().cast<ShapedType>();
    if (!result_ty.hasStaticShape()) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<IotaOp>(iota, result_ty, iota.iota_dimension());
    return success();
  }
};

// Dynamic Iota operations across multiple dimensions can be reduced to an iota
// and a ranked broadcast.
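// As a sketch: a rank-2 dynamic_iota with iota_dimension = 1 is rewritten
// into a 1-D dynamic_iota whose shape operand is the dimension-1 entry
// sliced out of output_shape, followed by a dynamic_broadcast_in_dim back to
// the full result shape with broadcast_dimensions = dense<[1]>.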
struct DynamicIotaBroadcast : public OpRewritePattern<DynamicIotaOp> {
  using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicIotaOp iota,
                                PatternRewriter& rewriter) const override {
    auto result_ty = iota.getType().cast<ShapedType>();
    if (!result_ty.hasRank() || result_ty.getRank() < 2) {
      return failure();
    }

    auto iota_dimension = iota.iota_dimension();
    auto iota_dimension_int = iota_dimension;

    auto converted_shape = rewriter.create<IndexCastOp>(
        iota.getLoc(),
        RankedTensorType::get(
            iota.output_shape().getType().cast<ShapedType>().getShape(),
            rewriter.getI64Type()),
        iota.output_shape());

    auto sliced_shape = rewriter.create<SliceOp>(
        iota.getLoc(), converted_shape,
        GetI64ElementsAttr(iota_dimension_int, &rewriter),
        GetI64ElementsAttr(iota_dimension_int + 1, &rewriter),
        GetI64ElementsAttr(1, &rewriter));

    auto converted_sliced_shape = rewriter.create<IndexCastOp>(
        iota.getLoc(),
        RankedTensorType::get(
            {1},
            iota.output_shape().getType().cast<ShapedType>().getElementType()),
        sliced_shape);

    auto iota_type = RankedTensorType::get(
        {result_ty.getDimSize(iota_dimension_int)}, result_ty.getElementType());

    auto new_iota = rewriter.create<DynamicIotaOp>(
        iota.getLoc(), iota_type, converted_sliced_shape,
        rewriter.getI64IntegerAttr(0));

    auto broadcast_attr = DenseIntElementsAttr::get(
        RankedTensorType::get({1}, rewriter.getIntegerType(64)),
        {iota_dimension});
    rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
        iota, result_ty, new_iota, iota.output_shape(), broadcast_attr);
    return success();
  }
};

}  // namespace

void DynamicIotaOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<DynamicIotaIsStatic>(context);
  results.insert<DynamicIotaBroadcast>(context);
}

//===----------------------------------------------------------------------===//
// DynamicUpdateSliceOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(DynamicUpdateSliceOp op) {
  OperandRange indices = op.start_indices();
  if (indices.size() <= 1) return success();

  // Note: start_indices is constrained to Variadic<HLO_ScalarIntTensor>, so it
  // is OK to cast indices to ShapedType here.
  auto idx_tensor = indices.take_front().front().getType().cast<ShapedType>();
  Type first_elem_ty = idx_tensor.getElementType();
  Type elem_ty;

  for (auto idx : llvm::drop_begin(indices, 1)) {
    idx_tensor = idx.getType().cast<ShapedType>();
    elem_ty = idx_tensor.getElementType();

    if (first_elem_ty != elem_ty) {
      return op.emitOpError() << "start indices must have same element type "
                                 "(encountered mismatch: "
                              << first_elem_ty << " vs " << elem_ty << ")";
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AbsOp
//===----------------------------------------------------------------------===//

LogicalResult AbsOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  auto operand_ty = (*operands.begin()).getType().cast<ShapedType>();
  Type element_ty = operand_ty.getElementType();
  if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
    element_ty = complex_ty.getElementType();
  }

  Type result_ty;
  if (operand_ty.hasRank()) {
    result_ty = RankedTensorType::get(operand_ty.getShape(), element_ty);
  } else {
    result_ty = UnrankedTensorType::get(element_ty);
  }
  inferredReturnTypes.push_back(result_ty);
  return success();
}

//===----------------------------------------------------------------------===//
// CollectivePermuteOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(CollectivePermuteOp op) {
  // Check that the source-target pairs attribute is an Nx2 tensor.
  auto type = op.source_target_pairs().getType().dyn_cast<RankedTensorType>();
  if (type.getRank() != 2)
    return op.emitError() << "expect source_target_pairs attribute to be of "
                             "rank 2, but got rank "
                          << type.getRank();
  if (type.getShape()[1] != 2)
    return op.emitError()
           << "expect source_target_pairs attribute of shape (N, 2), but got ("
           << type.getShape() << ")";
  // Check source target pairs for duplicate sources or targets.
  llvm::DenseSet<int64_t> sources;
  llvm::DenseSet<int64_t> targets;
  for (auto i = op.source_target_pairs().begin(),
            e = op.source_target_pairs().end();
       i != e; ++i) {
    auto val = (*i).getSExtValue();
    if (i.getIndex() % 2 == 0) {
      bool is_unique = sources.insert(val).second;
      if (!is_unique) return op.emitError() << "duplicate sources not allowed.";
    } else {
      bool is_unique = targets.insert(val).second;
      if (!is_unique) return op.emitError() << "duplicate targets not allowed.";
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// ConvertOp
//===----------------------------------------------------------------------===//

void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand,
                      Type result_element_ty) {
  Type result_ty;
  Type operand_ty = operand.getType();
  if (auto ranked_ty = operand_ty.dyn_cast<RankedTensorType>()) {
    result_ty = RankedTensorType::get(ranked_ty.getShape(), result_element_ty);
  } else {
    result_ty = UnrankedTensorType::get(result_element_ty);
  }
  build(builder, result, result_ty, operand);
}

OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
  auto operand_ty = getOperand().getType().cast<TensorType>();
  auto result_ty = getResult().getType().cast<TensorType>();
  if (operand_ty == result_ty) return getOperand();

  // If the result has non-static shape, a convert op is necessary to go from
  // static shape to non-static shape.
  if (!result_ty.hasStaticShape()) return {};

  // TODO(hinsu): Handle unsigned types.
  if (operand_ty.getElementType().isUnsignedInteger() ||
      result_ty.getElementType().isUnsignedInteger()) {
    return {};
  }

  // If the operand is constant, we can do the conversion now.
  if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
    return hlo::ConvertElementsAttr(elementsAttr,
                                    getElementTypeOrSelf(getResult()));
  }

  return {};
}

//===----------------------------------------------------------------------===//
// DequantizeOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(DequantizeOp op) {
  auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
  auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
  if (!input_type || !output_type) {
    return op.emitError() << "requires ranked input and output.";
  }
  auto input_shape = input_type.getShape();
  auto output_shape = output_type.getShape().vec();
  if (op.transpose_output()) {
    std::reverse(output_shape.begin(), output_shape.end());
  }

  // Check that the input and output ranks match, and that all dimensions
  // except the last are equal.
  if (input_shape.size() != output_shape.size() ||
      !std::equal(input_shape.begin(),
                  std::next(input_shape.begin(), input_shape.size() - 1),
                  output_shape.begin())) {
    return op.emitError() << "mismatched dimensions.";
  }

  // Check that the last dimension of the output is 2x or 4x that of the input,
  // depending on whether the unpacked input is 16 or 8 bits.
  int input_last_dim = *input_shape.rbegin();
  int output_last_dim = *output_shape.rbegin();
  int scale_factor = op.is_16bits() ? 2 : 4;
  if (output_last_dim != scale_factor * input_last_dim) {
    return op.emitError() << "last dimension of output should be "
                          << scale_factor << "x of the input.";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// GetTupleElementOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(GetTupleElementOp op) {
  auto indexVal = op.index();
  auto operandType = op.getOperand().getType().cast<TupleType>();
  if (indexVal >= operandType.size()) {
    return op.emitOpError(
        llvm::formatv("index {0} is out of bounds of operand with size {1}",
                      indexVal, operandType.size()));
  }

  auto expectedType = operandType.getType(indexVal);
  if (op.getType() != expectedType) {
    return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}",
                                        op.getType(), expectedType));
  }
  return success();
}

OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
  if (auto tupleOp =
          dyn_cast_or_null<mhlo::TupleOp>(getOperand().getDefiningOp())) {
    return tupleOp.getOperand(index());
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TupleOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(TupleOp op) {
  SmallVector<Type, 8> operandTypes = {op.operand_type_begin(),
                                       op.operand_type_end()};
  auto expectedType = TupleType::get(op.getContext(), operandTypes);
  if (op.getType() != expectedType) {
    return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}",
                                        op.getType(), expectedType));
  }
  return success();
}

namespace {

// Pattern for unpacking and repacking the same tuple.
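// As a sketch (%t is a placeholder tuple value):
//   %0 = "mhlo.get_tuple_element"(%t) {index = 0 : i32}
//   %1 = "mhlo.get_tuple_element"(%t) {index = 1 : i32}
//   %2 = "mhlo.tuple"(%0, %1)
// is replaced by %t itself, provided the tuple types match.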
struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> {
  using OpRewritePattern<TupleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TupleOp op,
                                PatternRewriter& rewriter) const override {
    if (op.val().empty()) return failure();

    Value first_element = op.val().front();
    auto first_element_op =
        dyn_cast_or_null<GetTupleElementOp>(first_element.getDefiningOp());
    if (!first_element_op || first_element_op.indexAttr().getInt() != 0)
      return failure();

    Value tuple_predecessor = first_element_op.getOperand();
    if (tuple_predecessor.getType() != op.getType()) return failure();

    for (auto element_and_idx : llvm::enumerate(op.val().drop_front(1))) {
      auto element_op = dyn_cast_or_null<GetTupleElementOp>(
          element_and_idx.value().getDefiningOp());
      if (!element_op ||
          element_op.indexAttr().getInt() != element_and_idx.index() + 1 ||
          element_op.getOperand() != tuple_predecessor)
        return failure();
    }

    rewriter.replaceOp(op, tuple_predecessor);
    return success();
  }
};

}  // namespace

void TupleOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                          MLIRContext* context) {
  results.insert<UnpackRepackSameTuple>(context);
}

//===----------------------------------------------------------------------===//
// AllToAllOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(AllToAllOp op) {
  // If operand is ranked, size of split dimension should be a multiple of split
  // count.
  auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
  if (!type) return success();
  auto split_dim_size = type.getDimSize(op.split_dimension());
  auto split_count = op.split_count();
  if (split_dim_size % split_count != 0) {
    return op.emitError() << "split dimension has size " << split_dim_size
                          << ", expected to be a multiple of split_count "
                          << split_count;
  }
  return success();
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

// TODO(b/129012527) These should be expressed as type constraints.
static LogicalResult Verify(BroadcastOp op) {
  auto sizes = op.broadcast_sizes();
  auto sizesType = sizes.getType();
  auto sizesRank = sizesType.getRank();
  if (sizesRank != 1) {
    return op.emitOpError(llvm::formatv(
        "broadcast_sizes has rank {0} instead of rank 1", sizesRank));
  }

  auto resultType = op.getResult().getType().cast<RankedTensorType>();
  auto resultRank = resultType.getRank();
  auto operandType = op.operand().getType().cast<RankedTensorType>();
  auto operandRank = operandType.getRank();
  auto sizesSize = sizesType.getNumElements();
  auto expectedRank = operandRank + sizesSize;

  if (resultRank != expectedRank) {
    return op.emitOpError(
        llvm::formatv("result rank ({0}) does not match operand rank "
                      "({1}) plus size of broadcast_sizes ({2})",
                      resultRank, operandRank, sizesSize));
  }

  llvm::SmallVector<int64_t, 10> expectedShape(sizes.getValues<int64_t>());

  auto operandShape = operandType.getShape();
  expectedShape.insert(expectedShape.end(), operandShape.begin(),
                       operandShape.end());

  auto resultShape = resultType.getShape();
  if (resultShape != llvm::makeArrayRef(expectedShape)) {
    return op.emitOpError(llvm::formatv(
        "result has shape [{0}] instead of [{1}]",
        llvm::make_range(resultShape.begin(), resultShape.end()),
        llvm::make_range(expectedShape.begin(), expectedShape.end())));
  }

  return success();
}

//===----------------------------------------------------------------------===//
// BroadcastInDimOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(BroadcastInDimOp op) {
  auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
  if (!operandType) {
    // The following verification checks all depend on knowing the rank of
    // the operand. Bail out now if we don't know the rank of the operand.
    return success();
  }

  auto operandRank = operandType.getRank();
  if (!op.broadcast_dimensions()) {
    if (operandRank == 0) {
      return success();
    }
    return op.emitOpError(
        llvm::formatv("broadcast_dimensions is absent, but required because "
                      "operand has non-zero rank ({0})",
                      operandRank));
  }

  auto dimensions = op.broadcast_dimensions();
  auto dimensionsType = op.broadcast_dimensions().getType();
  auto dimensionsRank = dimensionsType.getRank();
  if (dimensionsRank != 1) {
    return op.emitOpError(llvm::formatv(
        "broadcast_dimensions has rank {0} instead of rank 1", dimensionsRank));
  }

  auto dimensionsSize = dimensionsType.getNumElements();
  if (dimensionsSize != operandRank) {
    return op.emitOpError(llvm::formatv(
        "broadcast_dimensions size ({0}) does not match operand rank ({1})",
        dimensionsSize, operandRank));
  }

  auto resultType = op.getResult().getType().cast<RankedTensorType>();
  auto resultRank = resultType.getRank();
  if (resultRank < operandRank) {
    return op.emitOpError(
        llvm::formatv("result rank ({0}) is less than operand rank ({1})",
                      resultRank, operandRank));
  }

  for (int i = 0; i != dimensionsSize; ++i) {
    auto dimIndex = dimensions.getValue<int64_t>(i);
    if (dimIndex >= resultRank) {
      return op.emitOpError(
          llvm::formatv("broadcast_dimensions contains invalid value {0} for "
                        "result with rank {1}",
                        dimIndex, resultRank));
    }

    if (!operandType.isDynamicDim(i)) {
      auto dimSize = operandType.getDimSize(i);
      auto resultDimSize = resultType.getDimSize(dimIndex);
      if (dimSize != 1 && dimSize != resultDimSize) {
        return op.emitOpError(
            llvm::formatv("size of operand dimension {0} ({1}) is not equal to "
                          "1 or size of result dimension {2} ({3})",
                          i, dimSize, dimIndex, resultDimSize));
      }
    }
  }

  return success();
}

OpFoldResult BroadcastInDimOp::fold(ArrayRef<Attribute> attrs) {
  auto type = getType().cast<RankedTensorType>();
  if (type == getOperand().getType()) {
    auto broadcast_values = broadcast_dimensions().getValues<int64_t>();
    if (!std::equal(broadcast_values.begin(), broadcast_values.end(),
                    llvm::seq<int64_t>(0, type.getRank()).begin())) {
      return {};
    }
    return getOperand();
  }

  // Constant fold when an operand is a splat tensor attribute.
  if (!attrs[0] || !type.hasStaticShape()) return {};
  auto splatOperandAttr = attrs[0].dyn_cast<SplatElementsAttr>();
  if (!splatOperandAttr) return {};
  // MLIR core bug (https://bugs.llvm.org/show_bug.cgi?id=46588): dense element
  // attribute iterator not implemented for complex element types.
  if (type.getElementType().isa<ComplexType>()) return {};
  return SplatElementsAttr::get(type, splatOperandAttr.getSplatValue());
}

//===----------------------------------------------------------------------===//
// DynamicBroadcastInDimOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(DynamicBroadcastInDimOp op) {
  auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
  auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();

  // If either the operand or result are unranked, there is very little
  // to verify statically.
  if (!operandType || !resultType) {
    return success();
  }

  auto outputDimensionsType =
      op.output_dimensions().getType().cast<RankedTensorType>();
  auto outputDimensionsSize = outputDimensionsType.getDimSize(0);
  auto operandRank = operandType.getRank();
  auto resultRank = resultType.getRank();

  // Verify broadcast_dimensions.
  auto bcastDimensions = op.broadcast_dimensions();
  auto bcastDimensionsType = op.broadcast_dimensions().getType();
  auto bcastDimensionsRank = bcastDimensionsType.getRank();
  // TODO(laurenzo): Update the BroadcastDimAttr to constrain its rank to 1.
  if (bcastDimensionsRank != 1) {
    return op.emitOpError(
        llvm::formatv("broadcast_dimensions has rank {0} instead of rank 1",
                      bcastDimensionsRank));
  }

  auto bcastDimensionsSize = bcastDimensionsType.getNumElements();
  if (bcastDimensionsSize != operandRank) {
    return op.emitOpError(llvm::formatv(
        "broadcast_dimensions size ({0}) does not match operand rank ({1})",
        bcastDimensionsSize, operandRank));
  }

  if (resultRank < operandRank) {
    return op.emitOpError(
        llvm::formatv("result rank ({0}) is less than operand rank ({1})",
                      resultRank, operandRank));
  }

  for (int i = 0; i != bcastDimensionsSize; ++i) {
    auto dimIndex = bcastDimensions.getValue<int64_t>(i);
    if (dimIndex >= resultRank) {
      return op.emitOpError(
          llvm::formatv("broadcast_dimensions contains invalid value {0} for "
                        "result with rank {1}",
                        dimIndex, resultRank));
    }

    auto dimSize = operandType.getDimSize(i);
    auto resultDimSize = resultType.getDimSize(dimIndex);
    // Note: verifyCompatibleShapes doesn't consider size-1 broadcasting, so we
    // add a manual check for this.
    if (dimSize != 1 && failed(verifyCompatibleShape(dimSize, resultDimSize))) {
      return op.emitOpError(
          llvm::formatv("size of operand dimension {0} ({1}) is not compatible "
                        "with size of result dimension {2} ({3})",
                        i, dimSize, dimIndex, resultDimSize));
    }
  }

  if (outputDimensionsSize != resultRank) {
    return op.emitOpError(
        llvm::formatv("result rank ({0}) is not equal to number of output "
                      "dimensions ({1})",
                      resultRank, outputDimensionsSize));
  }

  return success();
}

// If a DynamicBroadcastInDimOp is not actually dynamic, use an ordinary
// BroadcastInDimOp.
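// As a sketch (shapes chosen for illustration): with a fully static result
// type such as
//   %0 = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape) {...}
//        : (tensor<4xf32>, tensor<2xindex>) -> tensor<3x4xf32>
// the op is replaced by a plain mhlo.broadcast_in_dim and the now-redundant
// %shape operand is dropped.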
class DynamicBroadcastInDimOpNotActuallyDynamic
    : public OpRewritePattern<DynamicBroadcastInDimOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op,
                                PatternRewriter& rewriter) const override {
    auto type = op.getType().dyn_cast<RankedTensorType>();
    if (!type || !type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op, "requires static shape");
    }
    rewriter.replaceOpWithNewOp<BroadcastInDimOp>(
        op, op.getType(), op.operand(), op.broadcast_dimensions());
    return success();
  }
};

void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<DynamicBroadcastInDimOpNotActuallyDynamic,
                 DynamicBroadcastToOwnShape_1, DynamicBroadcastToOwnShape_2>(
      context);
}

//===----------------------------------------------------------------------===//
// ClampOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(ClampOp op) {
  auto operandType = op.operand().getType().cast<RankedTensorType>();
  auto operandShape = operandType.getShape();
  auto minType = op.min().getType().cast<RankedTensorType>();

  auto minShape = minType.getShape();
  if (minShape != operandShape && minType.getRank() != 0) {
    return op.emitOpError(llvm::formatv(
        "min shape [{0}] is not scalar and does not match operand shape [{1}]",
        llvm::make_range(minShape.begin(), minShape.end()),
        llvm::make_range(operandShape.begin(), operandShape.end())));
  }

  auto maxType = op.max().getType().cast<RankedTensorType>();
  auto maxShape = maxType.getShape();
  if (maxShape != operandShape && maxType.getRank() != 0) {
    return op.emitOpError(llvm::formatv(
        "max shape [{0}] is not scalar and does not match operand shape [{1}]",
        llvm::make_range(maxShape.begin(), maxShape.end()),
        llvm::make_range(operandShape.begin(), operandShape.end())));
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ComplexOp
//===----------------------------------------------------------------------===//

LogicalResult ComplexOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  auto type = operands[0].getType();
  auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
  Type result_ty;
  if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
    result_ty = RankedTensorType::get(ranked_type.getShape(), element_ty);
  } else if (type.isa<UnrankedTensorType>()) {
    result_ty = UnrankedTensorType::get(element_ty);
  } else {
    result_ty = element_ty;
  }
  inferredReturnTypes.push_back(result_ty);
  return success();
}

OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
  auto real_op = dyn_cast_or_null<mhlo::RealOp>(getOperand(0).getDefiningOp());
  auto imag_op = dyn_cast_or_null<mhlo::ImagOp>(getOperand(1).getDefiningOp());
  if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) {
    return real_op.getOperand();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// ImagOp
//===----------------------------------------------------------------------===//

namespace {
Type CreateRealType(Type type) {
  auto element_ty = getElementTypeOrSelf(type);
  if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
    element_ty = complex_ty.getElementType();
  }

  if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
    return RankedTensorType::get(ranked_type.getShape(), element_ty);
  } else if (type.dyn_cast<UnrankedTensorType>()) {
    return UnrankedTensorType::get(element_ty);
  }

  return element_ty;
}
}  // namespace

LogicalResult ImagOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
  return success();
}

OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
  if (auto complex_op =
          dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
    return complex_op.getOperand(1);
  }

  return {};
}

//===----------------------------------------------------------------------===//
// IsFiniteOp
//===----------------------------------------------------------------------===//

TensorType getSameShapeTensorType(TensorType tensor_type, Type element_type) {
  if (auto ranked_tensor_ty = tensor_type.dyn_cast<RankedTensorType>()) {
    return RankedTensorType::get(ranked_tensor_ty.getShape(), element_type);
  }
  if (auto unranked_tensor_ty = tensor_type.dyn_cast<UnrankedTensorType>()) {
    return UnrankedTensorType::get(element_type);
  }
  llvm_unreachable("unhandled type");
}

LogicalResult IsFiniteOp::inferReturnTypes(
    MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  auto arg_ty = operands.front().getType().cast<TensorType>();
  Builder b(ctx);
  inferredReturnTypes.push_back(getSameShapeTensorType(arg_ty, b.getI1Type()));
  return success();
}

//===----------------------------------------------------------------------===//
// RealOp
//===----------------------------------------------------------------------===//

LogicalResult RealOp::inferReturnTypes(
    MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
  return success();
}

OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
  if (auto complex_op =
          dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
    return complex_op.getOperand(0);
  }

  return {};
}

//===----------------------------------------------------------------------===//
// ConcatenateOp
//===----------------------------------------------------------------------===//

namespace {
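// Removes concatenate operands whose size along the concatenation dimension
// is zero, since they contribute no elements to the result. For example,
// concatenating tensor<0x3xf32> and tensor<2x3xf32> along dimension 0 can
// simply drop the first operand.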
class ConcatenateOperandRemoval : public OpRewritePattern<ConcatenateOp> {
 public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ConcatenateOp op,
                                PatternRewriter& rewriter) const override {
    auto axis = op.dimension();
    llvm::SmallVector<Value, 6> new_operands;
    for (auto operand : op.getOperands()) {
      auto ty = operand.getType().cast<ShapedType>();
      if (ty.getDimSize(axis) != 0) {
        new_operands.push_back(operand);
      }
    }

    if (!new_operands.empty() && new_operands.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<ConcatenateOp>(op, op.getResult().getType(),
                                                 new_operands, op.dimension());
      return success();
    }

    return failure();
  }
};
}  // namespace

LogicalResult ConcatenateOp::inferReturnTypes(
    MLIRContext*, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type>& inferredReturnTypes) {
  if (operands.empty()) {
    return failure();
  }

  auto dimension_attr = attributes.get("dimension").cast<IntegerAttr>();
  auto dimension = dimension_attr.getInt();

  auto first_type = (*operands.begin()).getType().cast<ShapedType>();
  auto out_element = first_type.getElementType();

  for (auto operand : operands.getTypes()) {
    auto element_type = getElementTypeOrSelf(operand);
    if (element_type != out_element) {
      return failure();
    }
  }

  // Find the first ranked input to determine the output rank.
  for (auto type : operands.getTypes()) {
    auto shaped_type = type.cast<ShapedType>();
    if (shaped_type.hasRank()) {
      first_type = shaped_type;
      break;
    }
  }

  // If all inputs are unranked, the result must be unranked.
  if (!first_type.hasRank()) {
    inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
    return success();
  }

  if (first_type.getRank() == 0)
    return emitOptionalError(location, "rank-0 values cannot be concatenated");

  auto out_shape = llvm::to_vector<6>(first_type.getShape());

  // Determine what the non-concatenate dimensions should be.
  for (auto type : operands.getTypes()) {
    auto shaped_ty = type.cast<ShapedType>();
    if (!shaped_ty.hasRank()) {
      continue;
    }

    for (auto it : llvm::enumerate(shaped_ty.getShape())) {
      // If the output dimension is still dynamic, refine it with this
      // operand's (possibly static) dimension size.
      if (ShapedType::isDynamic(out_shape[it.index()])) {
        out_shape[it.index()] = it.value();
      }
    }
  }

  out_shape[dimension] = 0;

  for (auto operand : operands.getTypes()) {
    auto type = operand.cast<ShapedType>();
    if (!type.hasRank()) {
      inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
      return success();
    }

    // If the dimension is dynamic we know the output dimension is dynamic.
    auto dim = type.getShape()[dimension];
    if (dim == -1) {
      out_shape[dimension] = -1;
      break;
    }

    out_shape[dimension] += dim;
  }

  inferredReturnTypes.push_back(RankedTensorType::get(out_shape, out_element));

  return success();
}

void ConcatenateOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<ConcatenateOperandRemoval>(context);
}

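// Folds a concatenate of constant operands by interleaving their elements:
// for every index over the dimensions in front of the concat axis, the
// corresponding chunk of each operand is appended in turn. For example,
// concatenating dense<[[0, 1]]> and dense<[[2, 3]]> along dimension 1 folds
// to dense<[[0, 1, 2, 3]]>.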
template <typename T>
static Attribute foldConcatenateHelper(ConcatenateOp* op,
                                       ArrayRef<Attribute> operands) {
  auto axis = op->dimension();
  auto type = op->getType().cast<ShapedType>();

  SmallVector<T, 6> values;
  auto shape = type.getShape();

  size_t top_size = 1;
  for (int i = 0, e = axis; i < e; i++) {
    top_size = top_size * shape[i];
  }

  for (size_t i = 0; i < top_size; i++) {
    for (auto operand : operands) {
      DenseElementsAttr attr = operand.cast<DenseElementsAttr>();
      size_t bottom_size = attr.getNumElements() / top_size;
      auto iter = attr.getValues<T>().begin() + i * bottom_size;
      values.append(iter, iter + bottom_size);
    }
  }

  return DenseElementsAttr::get(type, values);
}

static Attribute foldConcatenate(ConcatenateOp* op,
                                 ArrayRef<Attribute> operands) {
  for (auto operand : operands) {
    if (!operand) return {};
  }

  auto type = op->getResult().getType().cast<ShapedType>();
  auto etype = type.getElementType();
  if (etype.isa<IntegerType>()) {
    return foldConcatenateHelper<APInt>(op, operands);
  }

  if (etype.isa<FloatType>()) {
    return foldConcatenateHelper<APFloat>(op, operands);
  }

  return {};
}

OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
  if (getNumOperands() == 1) return getOperand(0);

  ShapedType type = getResult().getType().cast<ShapedType>();
  if (!type.hasStaticShape()) return {};

  auto axis = dimension();
  if (auto attr = foldConcatenate(this, operands)) {
    return attr;
  }

  for (auto operand : getOperands()) {
    auto ty = operand.getType().cast<ShapedType>();
    if (ty.getDimSize(axis) != 0) {
      return {};
    }
  }

  return DenseElementsAttr::get(type, ArrayRef<Attribute>());
}

static LogicalResult Verify(ConcatenateOp op) {
  Type element_type = getElementTypeOrSelf(op.getOperand(0).getType());
  RankedTensorType first_ranked_type;
  int num_operands = op.getNumOperands();
  for (int i = 0; i < num_operands; i++) {
    auto second_type = op.getOperand(i).getType().dyn_cast<ShapedType>();
    if (second_type.getElementType() != element_type) {
      return op.emitOpError(
          llvm::formatv("operands (0) and ({0}) do not match element type", i));
    }

    if (!second_type.hasRank()) {
      continue;
    }

    if (!first_ranked_type) {
      first_ranked_type = second_type.cast<RankedTensorType>();
      continue;
    }

    if (first_ranked_type.getRank() != second_type.getRank()) {
      return op.emitOpError(
          llvm::formatv("operands (0) and ({0}) do not match rank", i));
    }

    // Compare against the first ranked operand's shape, not the current
    // operand's own shape, so mismatches are actually detected.
    auto first_shape = first_ranked_type.getShape();
    auto second_shape = second_type.getShape();
    for (int d = 0; d < first_ranked_type.getRank(); ++d) {
      if (first_shape[d] != second_shape[d] && d != op.dimension()) {
        return op.emitOpError(llvm::formatv(
            "operands (0) and ({0}) non-concat dimensions do not match "
            "({1}) != ({2})",
            i, llvm::make_range(first_shape.begin(), first_shape.end()),
            llvm::make_range(second_shape.begin(), second_shape.end())));
      }
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// DynamicReshapeOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(DynamicReshapeOp op) {
  auto result_type = op.result().getType().dyn_cast<RankedTensorType>();
  auto output_shape_type =
      op.output_shape().getType().dyn_cast<RankedTensorType>();
  if (result_type && output_shape_type && output_shape_type.hasStaticShape() &&
      output_shape_type.getDimSize(0) != result_type.getRank()) {
    return op.emitError() << "output should have a rank equal to the number of "
                             "elements in output_shape";
  }
  return success();
}

namespace {
class DynamicReshapeOpNotActuallyDynamic
    : public OpRewritePattern<DynamicReshapeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(DynamicReshapeOp op,
                                PatternRewriter& rewriter) const override {
    auto type = op.result().getType().dyn_cast<RankedTensorType>();
    if (!type || !type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op, "requires static shape tensor");
    }
    rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), op.operand());
    return success();
  }
};

// Canonicalizes
// %0 = some_op(%tensor)
// %1 = "mhlo.dynamic_reshape"(%0, %shape)
//      (tensor<?xT>, tensor<1xindex>) -> tensor<?xT>
// ... uses of %1.
//
// into
//
// ... uses of %0.
// This canonicalization is only correct if the input is correct!
// TODO(b/178779691): Use a more sophisticated canonicalization that preserves
// errors in input, and still allows us to get rid of redundant reshapes.
class RemoveRedundantRank1DynamicReshape
    : public OpRewritePattern<DynamicReshapeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(DynamicReshapeOp op,
                                PatternRewriter& rewriter) const override {
    auto type = op.result().getType().dyn_cast<RankedTensorType>();
    if (!type || type.getRank() != 1 || type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "requires rank 1 shape tensor with dynamic dimension");
    }
    auto operand_type = op.operand().getType().dyn_cast<RankedTensorType>();
    if (!operand_type || operand_type.getRank() != 1 ||
        operand_type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "requires rank 1 shape tensor with dynamic dimension");
    }
    rewriter.replaceOp(op, {op.operand()});
    return success();
  }
};

// Canonicalizes
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
// %1 = same_operands_and_result_shape_op(%tensor)
// %2 = "mhlo.dynamic_reshape"(%1, %shape)
// ... uses of %2.
//
// into
//
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
// %1 = same_operands_and_result_shape_op(%tensor)
// ... uses of %1.
class DynamicReshapeOpSameShapeOpResult
    : public OpRewritePattern<DynamicReshapeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicReshapeOp op,
                                PatternRewriter& rewriter) const override {
    Operation* def_op = op.operand().getDefiningOp();
    if (!def_op || !def_op->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
      return failure();
    }
    Operation* input_def_op = def_op->getOperand(0).getDefiningOp();
    if (!input_def_op) {
      return failure();
    }
    auto reshape = dyn_cast<DynamicReshapeOp>(*input_def_op);
    if (reshape && reshape.output_shape() == op.output_shape()) {
      rewriter.replaceOp(op, {def_op->getResult(0)});
      return success();
    }
    return failure();
  }
};
}  // namespace

void DynamicReshapeOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  // clang-format off
  results.insert<
      DynamicReshapeOpNotActuallyDynamic,
      DynamicReshapeOpSameShapeOpResult,
      RemoveRedundantDynamicBroadcast,
      RemoveRedundantDynamicReshape,
      RemoveRedundantRank1DynamicReshape,
      ShapeOfDynamicReshape
    >(context);
  // clang-format on
}


//===----------------------------------------------------------------------===//
// DynamicSliceOp
//===----------------------------------------------------------------------===//

namespace {
// Canonicalizes DynamicSlice ops that can be replaced instead with Slice ops.
// This canonicalization is applied in the case where the `begin` input values
// are compile-time constants and thus can be packed into a tensor.
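// As a sketch (a hypothetical 1-D example): when every start index comes
// from a constant, e.g.
//   %c1 = mhlo.constant dense<1> : tensor<i64>
//   %0 = "mhlo.dynamic-slice"(%arg, %c1) {slice_sizes = dense<[2]>}
// the op becomes a static mhlo.slice with start [1], limit [3] (start plus
// slice size), and stride [1].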
struct DynamicSliceToSlice : public OpRewritePattern<DynamicSliceOp> {
  using OpRewritePattern<DynamicSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicSliceOp dynamic_slice,
                                PatternRewriter& rewriter) const override {
    Value input = dynamic_slice.operand();
    auto input_tensor = input.getType().dyn_cast<RankedTensorType>();
    if (!input_tensor) return failure();

    SmallVector<int64_t, 4> temp_start_indices;
    for (Value start : dynamic_slice.start_indices()) {
      APInt val;
      if (!matchPattern(start, m_ConstantInt(&val))) {
        return failure();
      }
      temp_start_indices.push_back(val.getSExtValue());
    }

    // At this point we've determined that the start indices are all constants;
    // pack them into a single tensor.
    auto loc = dynamic_slice.getLoc();
    int64_t input_rank = input_tensor.getRank();
    auto slice_start_indices =
        GetI64ElementsAttr(temp_start_indices, &rewriter);
    DenseIntElementsAttr slice_limits = BuildSliceLimits(
        slice_start_indices, dynamic_slice.slice_sizes(), &rewriter);
    DenseIntElementsAttr slice_strides =
        GetI64ElementsAttr(SmallVector<int64_t, 4>(input_rank, 1), &rewriter);
    auto result = rewriter.create<SliceOp>(loc, input, slice_start_indices,
                                           slice_limits, slice_strides);
    rewriter.replaceOp(dynamic_slice, {result});
    return success();
  }
};

}  // namespace

void DynamicSliceOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<DynamicSliceToSlice>(context);
}

// Verifies that the number of slice sizes and the number of start indices
// match.
static LogicalResult Verify(DynamicSliceOp op) {
  int num_slice_sizes = op.slice_sizes().getNumElements();
  int num_start_indices = op.start_indices().size();
  if (num_start_indices != num_slice_sizes) {
    return op.emitOpError()
           << "has mismatched number of slice sizes (" << num_slice_sizes
           << ") and number of start indices (" << num_start_indices << ")";
  }
  return success();
}

1466
1467 //===----------------------------------------------------------------------===//
1468 // InfeedOp
1469 //===----------------------------------------------------------------------===//
1470
1471 // Checks that the result type is of the form `tuple< any_type, token >`.
Verify(InfeedOp op)1472 static LogicalResult Verify(InfeedOp op) {
1473 auto result_ty = op.getResult().getType().cast<TupleType>();
1474 auto subtypes = result_ty.getTypes();
1475 if (subtypes.size() != 2)
1476 return op.emitOpError()
1477 << "result is expected to be a tuple of size 2, but got "
1478 << subtypes.size();
1479 if (!subtypes[1].isa<TokenType>())
1480 return op.emitOpError() << "second element of result tuple is expected to "
1481 "be of token type, but got "
1482 << subtypes[1];
1483 return success();
1484 }
1485
1486 //===----------------------------------------------------------------------===//
1487 // Logical Ops
1488 //===----------------------------------------------------------------------===//
1489
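// The folders below use the usual boolean/bitwise identities, writing ~0 for
// the all-ones value:
//   and(x, x) = x    and(x, ~0) = x    and(x, 0) = 0
//   or(x, x)  = x    or(x, ~0)  = ~0   or(x, 0)  = x
//   xor(x, x) = 0    xor(x, 0)  = x
// When both operands are constants, the result is evaluated elementwise.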
OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
  if (lhs() == rhs()) return lhs();

  auto rType = getType().cast<ShapedType>();
  auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();

  if (lhsVal && lhsVal.isSplat()) {
    if (lhsVal.getSplatValue()
            .cast<IntegerAttr>()
            .getValue()
            .isAllOnesValue()) {
      return rhs();
    }

    if (lhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
      return lhsVal;
    }
  }

  if (rhsVal && rhsVal.isSplat()) {
    if (rhsVal.getSplatValue()
            .cast<IntegerAttr>()
            .getValue()
            .isAllOnesValue()) {
      return lhs();
    }

    if (rhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
      return rhsVal;
    }
  }

  if (!rhsVal || !lhsVal) return {};

  llvm::SmallVector<APInt, 4> values;
  values.reserve(rhsVal.getNumElements());
  for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) {
    values.push_back(std::get<0>(it) & std::get<1>(it));
  }

  return DenseIntElementsAttr::get(rType, values);
}

OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
  if (lhs() == rhs()) return lhs();

  auto rType = getType().cast<ShapedType>();
  auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();

  if (lhsVal && lhsVal.isSplat()) {
    if (lhsVal.getSplatValue()
            .cast<IntegerAttr>()
            .getValue()
            .isAllOnesValue()) {
      return lhsVal;
    }

    if (lhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
      return rhs();
    }
  }

  if (rhsVal && rhsVal.isSplat()) {
    if (rhsVal.getSplatValue()
            .cast<IntegerAttr>()
            .getValue()
            .isAllOnesValue()) {
      return rhsVal;
    }

    if (rhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
      return lhs();
    }
  }

  if (!rhsVal || !lhsVal) return {};

  llvm::SmallVector<APInt, 4> values;
  values.reserve(rhsVal.getNumElements());
  for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) {
    values.push_back(std::get<0>(it) | std::get<1>(it));
  }

  return DenseIntElementsAttr::get(rType, values);
}

OpFoldResult XorOp::fold(ArrayRef<Attribute> operands) {
  auto rType = getType().cast<ShapedType>();
  if (lhs() == rhs()) {
    Builder builder(getContext());
    return builder.getZeroAttr(rType);
  }

  auto lhsVal = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  auto rhsVal = operands[1].dyn_cast_or_null<DenseElementsAttr>();

  if (lhsVal && lhsVal.isSplat()) {
    if (lhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
      return rhs();
    }
  }

  if (rhsVal && rhsVal.isSplat()) {
    if (rhsVal.getSplatValue().cast<IntegerAttr>().getValue().isNullValue()) {
      return lhs();
    }
  }

  if (!rhsVal || !lhsVal) return {};

  llvm::SmallVector<APInt, 4> values;
  values.reserve(rhsVal.getNumElements());
  for (auto it : llvm::zip(rhsVal.getIntValues(), lhsVal.getIntValues())) {
    values.push_back(std::get<0>(it) ^ std::get<1>(it));
  }

  return DenseIntElementsAttr::get(rType, values);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

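// For example (illustrative IR; the printed syntax here is approximate), a
// valid map that adds two 4-element tensors elementwise:
//
//   %0 = "mhlo.map"(%arg0, %arg1) ( {
//   ^bb0(%a: tensor<f32>, %b: tensor<f32>):
//     %1 = mhlo.add %a, %b : tensor<f32>
//     "mhlo.return"(%1) : (tensor<f32>) -> ()
//   }) {dimensions = dense<0> : tensor<1xi64>}
//      : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>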
static LogicalResult Verify(MapOp op) {
  // Checks that the number of `operands` matches the arity of the map
  // `computation` region.
  auto& computation_block = op.computation().front();
  auto computation_args = computation_block.getArguments();
  if (op.operands().size() != computation_args.size())
    return op.emitOpError()
           << "expects number of operands to match the arity "
              "of map computation, but got: "
           << op.operands().size() << " and " << computation_args.size();

  // The parameters of computation should all be scalars and match the element
  // type of operands.
  auto operand_type = op.operands()[0].getType().cast<TensorType>();
  auto operand_elem_ty = operand_type.getElementType();

  for (auto indexed_arg : llvm::enumerate(computation_args)) {
    auto arg_type = indexed_arg.value().getType().dyn_cast<TensorType>();
    if (!arg_type || arg_type.getRank() != 0)
      return op.emitOpError()
             << "computation arguments must be 0-rank tensor, but got: arg #"
             << indexed_arg.index() << " of type "
             << indexed_arg.value().getType();
    if (arg_type.getElementType() != operand_elem_ty) {
      return op.emitOpError()
             << "element type of operands and computation arguments must "
                "match, but got: "
             << operand_elem_ty << " and " << arg_type.getElementType();
    }
  }

  // The mapped computation must return a single output.
  auto computation_outputs = computation_block.getTerminator()->getOperands();
  if (computation_outputs.size() != 1)
    return op.emitOpError()
           << "computation must return single output, but got: "
           << computation_outputs.size();

  // The output of computation must be scalar and have the same element type
  // as op result.
  auto computation_output_type =
      computation_outputs[0].getType().dyn_cast<TensorType>();
  if (!computation_output_type || computation_output_type.getRank() != 0)
    return op.emitOpError()
           << "computation must return 0-rank tensor, but got: "
           << computation_outputs[0].getType();

  auto result_type = op.getType().cast<TensorType>();
  if (computation_output_type.getElementType() != result_type.getElementType())
    return op.emitOpError() << "element type of result and computation output "
                               "must match, but got: "
                            << result_type.getElementType() << " and "
                            << computation_output_type.getElementType();

  // Checks that the requested map dimension numbers are monotonically
  // increasing.
  auto values = op.dimensions().getValues<int64_t>();
  auto dimensions = std::vector<int64_t>{values.begin(), values.end()};
  for (int i = 0, e = dimensions.size(); i < e; ++i) {
    if (dimensions[i] != i)
      return op.emitOpError() << "requires monotonically increasing dimension "
                                 "numbers, but got: "
                              << op.dimensions();
  }

  // Checks that the number of dimensions of the operands matches the size of
  // `dimensions`, since we currently only support mapping across all
  // dimensions: i.e., scalar map functions.
  if (operand_type.hasRank()) {
    if (dimensions.size() != operand_type.getShape().size())
      return op.emitOpError()
             << "applied to a subset of dimensions currently not supported: "
                "operand dimensions = "
             << operand_type.getShape().size()
             << ", requested map dimensions size = " << dimensions.size();
  }

  return success();
}

//===----------------------------------------------------------------------===//
// RecvOp
//===----------------------------------------------------------------------===//

// Checks that the result type is of the form `tuple<any_type, mhlo::token>`.
static LogicalResult Verify(RecvOp op) {
  auto result_ty = op.getResult().getType().cast<TupleType>();
  auto subtypes = result_ty.getTypes();
  if (subtypes.size() != 2)
    return op.emitOpError()
           << "result is expected to be a tuple of size 2, but got "
           << subtypes.size();
  if (!subtypes[1].isa<TokenType>())
    return op.emitOpError() << "second element of result tuple is expected to "
                               "be of token type, but got "
                            << subtypes[1];
  return success();
}

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

OpFoldResult CopyOp::fold(ArrayRef<Attribute> operands) { return getOperand(); }

//===----------------------------------------------------------------------===//
// ReverseOp
//===----------------------------------------------------------------------===//

OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
  auto input = operand();

  // No dimensions to reverse.
  if (dimensions().getNumElements() == 0) return input;

  llvm::SmallVector<APInt, 5> new_dims;
  new_dims.reserve(dimensions().getNumElements());

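  // The reversal is also a no-op if every dimension being reversed has
  // size 1.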
  auto shaped_type = input.getType().cast<ShapedType>();
  for (auto dim : dimensions().getValues<APInt>()) {
    if (shaped_type.getDimSize(dim.getLimitedValue()) != 1) {
      return nullptr;
    }
  }

  return input;
}

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

// Returns the result type after reducing operand of the given type across the
// specified dimensions.
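// For example, reducing a tensor<4x8xf32> across dimension 1 yields a
// tensor<4xf32>, and an unranked operand yields an unranked result.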
static TensorType GetReduceResultType(Type operand_ty,
                                      DenseIntElementsAttr dimensions,
                                      Builder* builder) {
  Type element_ty = getElementTypeOrSelf(operand_ty);

  auto ranked_ty = operand_ty.dyn_cast<RankedTensorType>();
  if (!ranked_ty) return UnrankedTensorType::get(element_ty);

  int64_t rank = ranked_ty.getRank();
  llvm::SmallVector<bool, 4> dims_mask(rank, false);
  for (int64_t dim : dimensions.getValues<int64_t>()) dims_mask[dim] = true;

  SmallVector<int64_t, 4> shape;
  for (int64_t i = 0; i < rank; ++i) {
    if (!dims_mask[i]) shape.push_back(ranked_ty.getDimSize(i));
  }

  return RankedTensorType::get(shape, element_ty);
}

void ReduceOp::build(OpBuilder& builder, OperationState& state,
                     ValueRange operands, ValueRange init_values,
                     DenseIntElementsAttr dimensions) {
  SmallVector<Type, 1> result_ty;
  result_ty.reserve(operands.size());

  for (Value operand : operands) {
    result_ty.push_back(
        GetReduceResultType(operand.getType(), dimensions, &builder));
  }
  build(builder, state, result_ty, operands, init_values, dimensions);
}

LogicalResult ReduceOp::fold(ArrayRef<Attribute> operands,
                             SmallVectorImpl<OpFoldResult>& results) {
  // No dimensions to reduce.
  if (dimensions().getNumElements() == 0) {
    for (Value input : this->operands()) {
      results.push_back(input);
    }
    return success();
  }
  return failure();
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(SelectOp op) {
  // TODO(jpienaar): Update to allow broadcastable and unranked inputs. This
  // corresponds to the client side HLO.
  return success();
}

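// select(p, x, x) folds to x, and a constant splat predicate selects one of
// the branches wholesale; non-splat constant predicates are not folded here.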
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
  if (on_true() == on_false()) {
    return on_true();
  }

  auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!predicate) {
    return {};
  }

  auto predicateTy = predicate.getType().cast<ShapedType>();
  if (!predicateTy.getElementType().isInteger(1)) {
    return {};
  }

  if (predicate.isSplat()) {
    return predicate.getSplatValue<APInt>().getBoolValue() ? on_true()
                                                           : on_false();
  }

  return {};
}

// Makes it such that a SelectOp that is a non-root operation in a DRR infers
// the return type based on operand type.
LogicalResult SelectOp::inferReturnTypes(
    MLIRContext*, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type>& inferredReturnTypes) {
  auto x_type = operands[1].getType();
  auto y_type = operands[2].getType();
  auto x_tensor = x_type.cast<TensorType>();
  auto y_tensor = y_type.cast<TensorType>();

  // Check for type compatibility in the select op. This requires that the two
  // non-predicate operands:
  //   (a) have the same element type
  //   (b) have compatible shapes (i.e. the same shape and/or at least one
  //       dynamic shape)
  if (x_tensor.getElementType() != y_tensor.getElementType() ||
      failed(mlir::verifyCompatibleShape(x_type, y_type))) {
    return emitOptionalError(location, "incompatible operand types: ", x_type,
                             " and ", y_type);
  }

  // TODO(lucyfox): Support output shape inference when operands have
  // compatible shapes. (The output shape should be the most general of the
  // operand shapes at each dimension.) For now, handle the straightforward
  // cases and fail otherwise. When this is fully implemented, this logic
  // should move into reusable functionality in MLIR Core.
  Type output_type;
  if (x_type == y_type || !x_tensor.hasRank()) {
    output_type = x_type;
  } else if (!y_tensor.hasRank()) {
    output_type = y_type;
  } else {
    return emitOptionalError(location,
                             "currently unsupported operand types: ", x_type,
                             " and ", y_type);
  }
  inferredReturnTypes.assign({output_type});
  return success();
}

LogicalResult SelectOp::inferReturnTypeComponents(
    mlir::MLIRContext*, llvm::Optional<mlir::Location>, mlir::ValueRange,
    mlir::DictionaryAttr, mlir::RegionRange,
    llvm::SmallVectorImpl<mlir::ShapedTypeComponents>&) {
  // TODO(b/168772852)
  return failure();
}

LogicalResult SelectOp::reifyReturnTypeShapes(
    OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
  return deriveShapeFromFirstOperand(&builder, getOperation(),
                                     &reifiedReturnShapes);
}

//===----------------------------------------------------------------------===//
// SetDimensionSizeOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(SetDimensionSizeOp op) {
  if (auto size = op.size().getType().dyn_cast<RankedTensorType>()) {
    if (size.getRank() != 0)
      return op.emitOpError() << "size operand should be of rank-0";
  }

  return VerifyDimAttr(op);
}

OpFoldResult SetDimensionSizeOp::fold(ArrayRef<Attribute> operands) {
  DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  if (input) return input;

  DenseElementsAttr size = operands[1].dyn_cast_or_null<DenseElementsAttr>();
  if (!size || !size.isSplat()) return {};

  auto ty = getType().dyn_cast<RankedTensorType>();
  if (!ty) return {};

  int64_t dim_size = ty.getDimSize(dimension());
  if (dim_size == size.getSplatValue().cast<IntegerAttr>().getInt())
    return operand();
  return {};
}

//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(PadOp op) {
  auto input_type = op.operand().getType().cast<RankedTensorType>();
  auto pad_type = op.padding_value().getType().cast<RankedTensorType>();

  if (pad_type.getRank() != 0) {
    return op.emitOpError(
        llvm::formatv("padding value type should be a rank-0 "
                      "tensor, is rank {0}",
                      pad_type.getRank()));
  }

  const auto& padding_low = op.edge_padding_low();
  if (padding_low.getType().getNumElements() != input_type.getRank()) {
    return op.emitOpError(llvm::formatv(
        "edge_padding_low length ({0}) must match operand rank ({1})",
        padding_low.getType().getNumElements(), input_type.getRank()));
  }

  const auto& padding_high = op.edge_padding_high();
  if (padding_high.getType().getNumElements() != input_type.getRank()) {
    return op.emitOpError(llvm::formatv(
        "edge_padding_high length ({0}) must match operand rank ({1})",
        padding_high.getType().getNumElements(), input_type.getRank()));
  }

  const auto& padding_interior = op.interior_padding();
  if (padding_interior.getType().getNumElements() != input_type.getRank()) {
    return op.emitOpError(llvm::formatv(
        "interior_padding length ({0}) must match operand rank ({1})",
        padding_interior.getType().getNumElements(), input_type.getRank()));
  }

  auto input_shape = input_type.getShape();
  auto output_shape =
      op.getResult().getType().cast<RankedTensorType>().getShape();
  if (input_shape.size() != output_shape.size()) {
    return op.emitOpError(
        llvm::formatv("operand rank ({0}) and result rank ({1}) should match",
                      input_shape.size(), output_shape.size()));
  }

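  // Each output dimension must equal
  //   input + edge_padding_low + edge_padding_high
  //         + max(input - 1, 0) * interior_padding;
  // e.g. an input dimension of 3 with low = 1, high = 2, interior = 1 gives
  // 3 + 1 + 2 + 2 * 1 = 8.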
  for (int i = 0, e = input_shape.size(); i < e; i++) {
    int padding_low_val = padding_low.getValue<IntegerAttr>(i).getInt();
    int padding_high_val = padding_high.getValue<IntegerAttr>(i).getInt();
    int padding_interior_val =
        padding_interior.getValue<IntegerAttr>(i).getInt();
    int expected_output =
        input_shape[i] + padding_low_val + padding_high_val +
        std::max<int64_t>(input_shape[i] - 1, 0LL) * padding_interior_val;
    if (expected_output != output_shape[i]) {
      return op.emitOpError(llvm::formatv(
          "expected output shape's dimension #{0} to be {1} but found {2}", i,
          expected_output, output_shape[i]));
    }
  }

  return success();
}

OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
  // If all padding is zero then it is an identity pad.
  auto is_zero = [](const APInt& i) { return i == 0; };
  if (llvm::all_of(edge_padding_low().getIntValues(), is_zero) &&
      llvm::all_of(edge_padding_high().getIntValues(), is_zero) &&
      llvm::all_of(interior_padding().getIntValues(), is_zero))
    return operand();

  // If any padding is negative then it isn't supported by the folder (yet).
  auto is_negative = [](const APInt& i) { return i.slt(0); };
  if (llvm::any_of(edge_padding_low().getIntValues(), is_negative) ||
      llvm::any_of(edge_padding_high().getIntValues(), is_negative) ||
      llvm::any_of(interior_padding().getIntValues(), is_negative))
    return {};

  DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  DenseElementsAttr padding =
      operands[1].dyn_cast_or_null<DenseElementsAttr>();
  RankedTensorType return_type =
      getType().dyn_cast_or_null<RankedTensorType>();
  if (!input || !input.getType().hasRank() || !padding || !return_type ||
      !return_type.hasStaticShape())
    return {};

  // Fill the full result tensor with the padding value.
  llvm::SmallVector<Attribute, 4> result(return_type.getNumElements(),
                                         padding.getValue({}));

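  // Advances a multi-dimensional index through `shape` in row-major order,
  // wrapping each trailing position back to zero when it reaches its bound.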
  auto next_index = [](llvm::SmallVector<uint64_t, 8>& index,
                       llvm::ArrayRef<int64_t> shape) {
    for (int64_t i = index.size() - 1; i >= 0; --i) {
      ++index[i];
      if (index[i] < shape[i]) return;
      index[i] = 0;
    }
  };

  // Iterate over all elements of the input tensor and copy each one to its
  // correct location in the output tensor.
  llvm::SmallVector<uint64_t, 8> index(input.getType().getRank(), 0);
  uint64_t num_elements = input.getNumElements();
  for (uint64_t operand_idx = 0; operand_idx < num_elements; operand_idx++) {
    uint64_t result_idx = 0;
    uint64_t idx_multiplier = 1;
    for (int64_t i = index.size() - 1; i >= 0; --i) {
      result_idx +=
          (edge_padding_low().getValue<int64_t>({uint64_t(i)}) +
           index[i] *
               (interior_padding().getValue<int64_t>({uint64_t(i)}) + 1)) *
          idx_multiplier;
      idx_multiplier *= return_type.getDimSize(i);
    }
    result[result_idx] = input.getValue(index);
    next_index(index, input.getType().getShape());
  }
  return DenseElementsAttr::get(return_type, result);
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(ReshapeOp op) {
  // If the operand type is dynamically shaped there is nothing to verify.
  auto operand_ty = op.operand().getType().dyn_cast<RankedTensorType>();
  if (!operand_ty || !operand_ty.hasStaticShape()) return success();

  // If the operand type is statically shaped (not required) the number of
  // elements must match that of the result type.
  auto result_ty = op.getType().cast<RankedTensorType>();
  assert(result_ty && result_ty.hasStaticShape() &&
         "result type must be statically shaped");
  int64_t num_result_elements = result_ty.getNumElements();
  int64_t num_operand_elements = operand_ty.getNumElements();
  if (num_result_elements != num_operand_elements)
    return op.emitOpError()
           << "number of output elements (" << num_result_elements
           << ") doesn't match expected number of elements ("
           << num_operand_elements << ")";

  return success();
}

OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
  if (getOperand().getType() == getType()) {
    return getOperand();
  }

  if (auto prev_op =
          dyn_cast_or_null<ReshapeOp>(getOperand().getDefiningOp())) {
    setOperand(prev_op.getOperand());
    return getResult();
  }

  if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
    return elements.reshape(getResult().getType().cast<ShapedType>());
  }

  return {};
}

void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                            MLIRContext* context) {
  results.insert<IdentityBroadcastReshape, IdentityBroadcastInDimReshape>(
      context);
}

//===----------------------------------------------------------------------===//
// ReplicaId Op
//===----------------------------------------------------------------------===//

LogicalResult ReplicaIdOp::inferReturnTypes(
    MLIRContext* context, Optional<Location>, ValueRange operands,
    DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  inferredReturnTypes.push_back(RankedTensorType::get(
      /*shape=*/{}, IntegerType::get(context, 32, IntegerType::Unsigned)));
  return success();
}

//===----------------------------------------------------------------------===//
// Case Op
//===----------------------------------------------------------------------===//

static LogicalResult Verify(CaseOp op) {
  auto num_branches = op.branches().size();
  if (op.branch_operands().size() != num_branches)
    return op.emitOpError() << "expects number of branches " << num_branches
                            << " to be same as number of branch operands "
                            << op.branch_operands().size();

  MutableArrayRef<Region> branches = op.branches();
  OperandRange branch_operands = op.branch_operands();
  for (unsigned i = 0; i < num_branches; ++i) {
    mlir::Region& branch_region = branches[i];
    if (branch_region.empty())
      return op.emitOpError() << "cannot have empty regions";
    mlir::Block& entry_block = branch_region.front();
    if (entry_block.getNumArguments() != 1)
      return op.emitOpError()
             << "expects branch regions to have single argument, but found "
             << entry_block.getNumArguments() << " for branch " << i;
    auto operand = branch_operands[i];
    if (entry_block.getArgument(0).getType() != operand.getType())
      return op.emitOpError()
             << "expects operand " << i + 1 << " to be of type "
             << entry_block.getArgument(0).getType() << ", but found "
             << operand.getType();
    WalkResult walker = branch_region.walk([&](ReturnOp return_op) {
      if (return_op.getOperands().getTypes() != op.getResultTypes())
        return WalkResult::interrupt();
      return WalkResult::advance();
    });
    if (walker.wasInterrupted())
      return op.emitOpError()
             << "branch " << i
             << " returned values do not match op result types";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// SqrtOp
//===----------------------------------------------------------------------===//

OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
  auto val = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  if (!val) return {};

  auto type = getElementTypeOrSelf(getType());
  if (!type.isF32() && !type.isF64()) return {};

  auto shaped_type = getType().cast<ShapedType>();
  if (!shaped_type.hasStaticShape()) return {};

  int bit_width = type.getIntOrFloatBitWidth();
  llvm::SmallVector<APFloat, 4> values;
  values.reserve(val.getNumElements());
  for (auto it : val.getFloatValues()) {
    double value = bit_width == 32 ? it.convertToFloat() : it.convertToDouble();
    if (value < 0) return {};
    value = std::sqrt(value);
    if (bit_width == 32)
      values.emplace_back(static_cast<float>(value));
    else
      values.emplace_back(value);
  }
  return DenseFPElementsAttr::get(shaped_type, values);
}

//===----------------------------------------------------------------------===//
// UnaryOps
//===----------------------------------------------------------------------===//

template <typename Op, typename ElementType = Type, typename ValType,
          typename Convert>
static Attribute UnaryFolder(Op* op, ArrayRef<Attribute> attrs) {
  if (!attrs[0]) return {};

  DenseElementsAttr val = attrs[0].dyn_cast<DenseElementsAttr>();
  if (!val) return {};

  ShapedType type = op->getType().template cast<ShapedType>();
  if (!type.hasStaticShape()) {
    return {};
  }

  Type etype = type.getElementType();

  // Fold only when the element type matches the kind this instantiation
  // handles.
  if (!etype.isa<ElementType>()) {
    return {};
  }

  SmallVector<ValType, 6> values;
  values.reserve(val.getNumElements());
  for (const auto v : val.getValues<ValType>()) {
    values.push_back(Convert()(v));
  }

  return DenseElementsAttr::get(type, values);
}

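// Rounds to the nearest integer, with ties rounded away from zero.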
struct round {
  APFloat operator()(const APFloat& f) {
    APFloat r = f;
    r.roundToIntegral(llvm::RoundingMode::NearestTiesToAway);
    return r;
  }
};

#define UNARY_FOLDER(Op, Func)                                                \
  OpFoldResult Op::fold(ArrayRef<Attribute> attrs) {                          \
    if (getElementTypeOrSelf(getType()).isa<FloatType>())                     \
      return UnaryFolder<Op, FloatType, APFloat, Func<APFloat>>(this, attrs); \
    if (getElementTypeOrSelf(getType()).isa<IntegerType>())                   \
      return UnaryFolder<Op, IntegerType, APInt, Func<APInt>>(this, attrs);   \
    return {};                                                                \
  }

#define UNARY_FOLDER_FLOAT(Op, Func)                                 \
  OpFoldResult Op::fold(ArrayRef<Attribute> attrs) {                 \
    if (getElementTypeOrSelf(getType()).isa<FloatType>())            \
      return UnaryFolder<Op, FloatType, APFloat, Func>(this, attrs); \
    return {};                                                       \
  }

UNARY_FOLDER(NegOp, std::negate);
UNARY_FOLDER_FLOAT(RoundOp, round);

//===----------------------------------------------------------------------===//
// BinaryOps
//===----------------------------------------------------------------------===//

namespace {

// Updates the element type of a (presumed) tensor type 'x', returning either
// a new UnrankedTensorType or RankedTensorType with the given element type.
static Type UpdateResultElementType(Builder* builder, Type x,
                                    Type element_type) {
  auto x_ranked = x.dyn_cast<RankedTensorType>();
  if (!x_ranked) {
    return UnrankedTensorType::get(element_type);
  }

  auto shape_x = x_ranked.getShape();
  return RankedTensorType::get(shape_x, element_type);
}
}  // namespace

template <typename Op, typename ElementType = Type, typename ValType,
          typename Convert>
static Attribute BinaryFolder(Op* op, ArrayRef<Attribute> attrs) {
  if (!attrs[0] || !attrs[1]) return {};

  DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>();
  DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>();
  if (!lhs || !rhs) return {};

  ShapedType type = op->getType().template cast<ShapedType>();
  if (!type.hasStaticShape()) {
    return {};
  }

  Type etype = type.getElementType();

  // Fold only when the element type matches the kind this instantiation
  // handles.
  if (!etype.isa<ElementType>()) {
    return {};
  }

  SmallVector<ValType, 6> values;
  values.reserve(lhs.getNumElements());
  for (const auto zip :
       llvm::zip(lhs.getValues<ValType>(), rhs.getValues<ValType>())) {
    values.push_back(Convert()(std::get<0>(zip), std::get<1>(zip)));
  }

  return DenseElementsAttr::get(type, values);
}

template <typename T>
struct divide : std::divides<T> {};

template <>
struct divide<APInt> {
  APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
};

template <typename T>
struct remainder : std::modulus<T> {};

template <>
struct remainder<APInt> {
  APInt operator()(const APInt& a, const APInt& b) const { return a.srem(b); }
};

template <>
struct remainder<APFloat> {
  APFloat operator()(const APFloat& a, const APFloat& b) const {
    APFloat result(a);
    result.remainder(b);
    return result;
  }
};

template <typename T>
struct max {
  T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
};

template <>
struct max<APInt> {
  APInt operator()(const APInt& a, const APInt& b) const {
    return llvm::APIntOps::smax(a, b);
  }
};

template <typename T>
struct min {
  T operator()(const T& a, const T& b) const { return std::min<T>(a, b); }
};

template <>
struct min<APInt> {
  APInt operator()(const APInt& a, const APInt& b) const {
    return llvm::APIntOps::smin(a, b);
  }
};

#define BINARY_FOLDER(Op, Func)                                                \
  OpFoldResult Op::fold(ArrayRef<Attribute> attrs) {                           \
    if (getElementTypeOrSelf(getType()).isa<FloatType>())                      \
      return BinaryFolder<Op, FloatType, APFloat, Func<APFloat>>(this, attrs); \
    if (getElementTypeOrSelf(getType()).isa<IntegerType>())                    \
      return BinaryFolder<Op, IntegerType, APInt, Func<APInt>>(this, attrs);   \
    return {};                                                                 \
  }

// Addition, subtraction and multiplication use the std:: versions of the ops.
// Because the remaining ops behave differently for signed vs. unsigned
// integers, APInt needs a special implementation; currently it replicates
// signed integer behavior.
BINARY_FOLDER(AddOp, std::plus);
BINARY_FOLDER(SubOp, std::minus);
BINARY_FOLDER(MulOp, std::multiplies);
BINARY_FOLDER(DivOp, divide);
BINARY_FOLDER(RemOp, remainder);
BINARY_FOLDER(MaxOp, max);
BINARY_FOLDER(MinOp, min);

#undef BINARY_FOLDER

//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//

// Returns the output dimension size of the slice result for the given
// arguments. Returns -1 if the arguments are illegal.
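// For example, input_dim = 10, start = 2, end = 8, stride = 3 selects
// indices {2, 5}, so the inferred size is ceil((8 - 2) / 3) = 2.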
static int64_t InferSliceDim(int64_t input_dim, int64_t start, int64_t end,
                             int64_t stride) {
  if (input_dim == -1 || start < 0 || start > end || end > input_dim ||
      stride == 0)
    return -1;

  return llvm::divideCeil(end - start, stride);
}

LogicalResult SliceOp::inferReturnTypes(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type>& inferredReturnTypes) {
  SliceOpAdaptor slice(operands, attributes);
  // TODO(jpienaar): Update this code after refactoring verify.
  if (failed(slice.verify(location.getValueOr(UnknownLoc::get(context))))) {
    return failure();
  }

  Type ty = slice.operand().getType();
  RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
  if (!ranked_ty) {
    // The operand type is unranked, so the best we can infer for the result
    // type is an unranked tensor with the same element type as the operand
    // type.
    inferredReturnTypes.assign({ty});
    return success();
  }

  ShapedType attr_ty = slice.start_indices().getType();
  if (attr_ty.getRank() != 1) {
    return emitOptionalError(location, "start_indices has rank ",
                             attr_ty.getRank(), " instead of required rank 1");
  }

  int64_t rank = ranked_ty.getRank();
  if (attr_ty.getNumElements() != rank) {
    return emitOptionalError(
        location, "the number of elements in start_indices (",
        attr_ty.getNumElements(), ") does not match the rank of the operand (",
        rank, ")");
  }

  if (!attr_ty.getElementType().isSignlessInteger(64) ||
      slice.limit_indices().getType() != attr_ty ||
      slice.strides().getType() != attr_ty) {
    // Unfortunately we can't rely on the AllTypesMatch trait for the SliceOp
    // having been verified at this point. Emit an error message that matches
    // the one that would be reported by AllTypesMatch for a more consistent
    // user experience.
    // TODO(b/171567182): Clean this up after AllTypesMatch has been
    // refactored.
    return emitOptionalError(location,
                             "failed to verify that all of {start_indices, "
                             "limit_indices, strides} have same type");
  }

  SmallVector<int64_t, 4> start(slice.start_indices().getValues<int64_t>());
  SmallVector<int64_t, 4> limit(slice.limit_indices().getValues<int64_t>());
  SmallVector<int64_t, 4> stride_vals(slice.strides().getValues<int64_t>());

  SmallVector<int64_t, 4> shape;
  shape.reserve(rank);
  for (int64_t i = 0, e = rank; i != e; i++) {
    shape.push_back(InferSliceDim(ranked_ty.getDimSize(i), start[i], limit[i],
                                  stride_vals[i]));
  }
  inferredReturnTypes.assign(
      {RankedTensorType::get(shape, ranked_ty.getElementType())});
  return success();
}

template <typename I, typename E>
static void SliceElements(I values, ArrayRef<int64_t> sizes,
                          ArrayRef<int64_t> starts, ArrayRef<int64_t> limits,
                          ArrayRef<int64_t> strides,
                          llvm::SmallVectorImpl<E>* out_values) {
  assert(starts.size() == limits.size());
  assert(starts.size() == strides.size());
  if (starts.empty()) return;

  int64_t start = starts.front();
  int64_t limit = limits.front();
  int64_t stride = strides.front();
  if (starts.size() == 1) {
    for (int i = start; i < limit; i += stride) {
      out_values->push_back(*(values + i));
    }
    return;
  }

  for (; start < limit; start += stride) {
    auto begin = values + start * sizes.front();
    SliceElements<I, E>(begin, sizes.drop_front(), starts.drop_front(),
                        limits.drop_front(), strides.drop_front(), out_values);
  }
}

template <typename I, typename E>
static Attribute FoldSlice(SliceOp* op, I values) {
  auto start = llvm::to_vector<6>(op->start_indices().getValues<int64_t>());
  auto limit = llvm::to_vector<6>(op->limit_indices().getValues<int64_t>());
  auto stride = llvm::to_vector<6>(op->strides().getValues<int64_t>());

  auto operand_type = op->operand().getType().cast<ShapedType>();
  if (!operand_type.hasStaticShape()) return {};

  auto shape = operand_type.getShape();
  int64_t count = operand_type.getNumElements();
  if (count == 0) {
    return DenseElementsAttr::get<E>(
        op->getResult().getType().cast<ShapedType>(),
        /*list=*/{});
  }

  // Compute the striding for each dimension.
  llvm::SmallVector<int64_t, 6> sizes;
  sizes.reserve(shape.size());
  for (auto v : shape) {
    count = count / v;
    sizes.push_back(count);
  }

  llvm::SmallVector<E, 6> out_values;
  out_values.reserve(operand_type.getNumElements());
  SliceElements<I, E>(values, sizes, start, limit, stride, &out_values);

  return DenseElementsAttr::get(op->getResult().getType().cast<ShapedType>(),
                                out_values);
}

OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
  // Check if the SliceOp is a no-op.
  auto operand_type = getOperand().getType().cast<ShapedType>();
  auto result_type = getResult().getType().cast<ShapedType>();

  if (operand_type.hasStaticShape() && result_type.hasStaticShape() &&
      (operand_type.getShape() == result_type.getShape())) {
    return getOperand();
  }

  if (operands.empty() || !operands.front()) return {};

  // Evaluate for statically valued inputs.
  DenseElementsAttr elements = operands.front().dyn_cast<DenseElementsAttr>();
  if (!elements) return {};

  auto etype = elements.getType().getElementType();
  if (etype.isa<IntegerType>()) {
    return FoldSlice<DenseElementsAttr::IntElementIterator, APInt>(
        this, elements.getIntValues().begin());
  } else if (etype.isa<FloatType>()) {
    return FoldSlice<
        llvm::mapped_iterator<DenseElementsAttr::IntElementIterator,
                              std::function<APFloat(const APInt&)>>,
        APFloat>(this, elements.getFloatValues().begin());
  }

  return {};
}

namespace {
// In cases where a concat is fed into a slice, it is possible the concat
// can be simplified or bypassed. This checks which inputs to the concat are
// used by the slice, either reducing the number of concatenated values or
// entirely removing the concat.
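//
// For example (illustrative IR; the attribute syntax here is approximate):
//   %c = "mhlo.concatenate"(%a, %b) {dimension = 0 : i64}
//       : (tensor<2xf32>, tensor<3xf32>) -> tensor<5xf32>
//   %s = "mhlo.slice"(%c) {start_indices = dense<3>, limit_indices = dense<5>,
//                          strides = dense<1>}
// reads only %b, so it can be rewritten as a slice of %b alone with the
// start/limit indices rebased to {1} and {3}.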
struct SimplifyConcatSlice : public OpRewritePattern<SliceOp> {
  using OpRewritePattern<SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SliceOp slice,
                                PatternRewriter& rewriter) const override {
    auto result_ty = slice.getType().cast<ShapedType>();
    if (!result_ty.hasStaticShape()) {
      return failure();
    }

    auto slice_input = slice.operand();
    auto slice_input_ty = slice_input.getType().cast<ShapedType>();
    auto concat = dyn_cast_or_null<ConcatenateOp>(slice_input.getDefiningOp());
    if (!concat) {
      return failure();
    }

    auto dimension = concat.dimension();

    auto start = slice.start_indices().getIntValues();
    auto limit = slice.limit_indices().getIntValues();

    auto slice_start = (*(start.begin() + dimension)).getSExtValue();
    auto slice_limit = (*(limit.begin() + dimension)).getSExtValue();

    // We need to determine what inputs from the concat affect the slice, and
    // how the bounds of the slice need to be updated for the minimally
    // required inputs.
    int64_t running_size = 0;
    int64_t front_offset = slice_input_ty.getShape()[dimension];

    auto subset_start = concat.operand_end();
    auto subset_end = concat.operand_end();
    for (auto it = concat.operand_begin(); it < concat.operand_end(); ++it) {
      auto input = *it;
      ShapedType input_ty = input.getType().cast<ShapedType>();
      if (input_ty.isDynamicDim(dimension)) {
        return failure();
      }
      auto dim_size = input_ty.getShape()[dimension];

      // If this position is in the slice, it's the start of the subset and we
      // need to update the start and limit values.
      if (running_size + dim_size > slice_start &&
          subset_start == concat.operand_end()) {
        subset_start = it;
        front_offset = running_size;
      }

      // Determine the last required offset.
      if (running_size < slice_limit) {
        subset_end = it + 1;
      }

      running_size += dim_size;
    }

    auto subset_size = subset_end - subset_start;
    // We need all inputs, so there is no optimization.
    if (subset_size == concat.getNumOperands()) {
      return failure();
    }

    if (subset_size > 1 && !concat.getResult().hasOneUse()) {
      return failure();
    }

    auto concat_range = OperandRange(subset_start, subset_end);
    auto new_concat = rewriter.create<ConcatenateOp>(
        concat.getLoc(), concat_range, concat.dimension());

    llvm::SmallVector<APInt, 6> new_start(start);
    llvm::SmallVector<APInt, 6> new_limit(limit);
    new_start[dimension] -= front_offset;
    new_limit[dimension] -= front_offset;

    auto attr_type = slice.start_indices().getType().cast<ShapedType>();
    auto create = rewriter.create<SliceOp>(
        slice.getLoc(), new_concat,
        DenseIntElementsAttr::get(attr_type, new_start),
        DenseIntElementsAttr::get(attr_type, new_limit), slice.strides());
    rewriter.replaceOp(slice, create.getResult());
    return success();
  }
};
}  // namespace

void SliceOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                          MLIRContext* context) {
  results.insert<SimplifyConcatSlice>(context);
}

//===----------------------------------------------------------------------===//
// SortOp
//===----------------------------------------------------------------------===//

void SortOp::build(OpBuilder& builder, OperationState& state,
                   ValueRange operands, int64_t dimension, bool is_stable) {
  state.addOperands(operands);
  state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
  state.addAttribute("is_stable", builder.getBoolAttr(is_stable));

  for (Value operand : operands) state.addTypes(operand.getType());

  state.addRegion();
}

static LogicalResult Verify(SortOp op) {
  Operation::operand_range operands = op.operands();
  if (operands.empty()) return op.emitOpError("requires at least one input");

  // TODO(antiagainst): verify partially dynamic shapes
  if (llvm::all_of(operands, [](Value operand) {
        return operand.getType().cast<ShapedType>().hasRank();
      })) {
    ArrayRef<int64_t> input_shape =
        (*operands.begin()).getType().cast<ShapedType>().getShape();

    if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value operand) {
          return operand.getType().cast<ShapedType>().getShape() !=
                 input_shape;
        }))
      return op.emitOpError("requires all inputs to have the same dimensions");

    int64_t rank = input_shape.size();
    int64_t cmp_dim = op.dimension();
    if (cmp_dim < -rank || cmp_dim >= rank)
      return op.emitOpError("dimension attribute value must be in range [-")
             << rank << ", " << rank << "), but found " << cmp_dim;
  }

  Block& block = op.comparator().front();
  size_t num_operands = op.getOperation()->getNumOperands();
  if (block.getNumArguments() != 2 * num_operands)
    return op.emitOpError("comparator block should have ")
           << 2 * num_operands << " arguments";

  for (auto indexed_operand : llvm::enumerate(operands)) {
    int index = indexed_operand.index();
    Type element_type =
        indexed_operand.value().getType().cast<ShapedType>().getElementType();
    Type tensor_type = RankedTensorType::get({}, element_type);
    for (int i : {2 * index, 2 * index + 1}) {
      Type arg_type = block.getArgument(i).getType();
      if (arg_type != tensor_type)
        return op.emitOpError("comparator block argument #")
               << i << " should be of type " << tensor_type << " but got "
               << arg_type;
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

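// Folds transpose(x, [0, 1, ..., n-1]), i.e. the identity permutation, to x.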
fold(ArrayRef<Attribute> operands)2672 OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
2673 for (auto it : llvm::enumerate(permutation().getValues<APInt>())) {
2674 if (it.index() != it.value()) {
2675 return {};
2676 }
2677 }
2678 return getOperand();
2679 }
2680
Verify(TransposeOp op)2681 static LogicalResult Verify(TransposeOp op) {
2682 // permutation is an attribute of the op so it has static shape.
2683 auto permutationType = op.permutation().getType();
2684 auto permutationRank = permutationType.getRank();
2685 if (permutationRank != 1) {
2686 return op.emitOpError(llvm::formatv(
2687 "permutation has rank {0} instead of rank 1", permutationRank));
2688 }
2689 auto permutationSize = permutationType.getNumElements();
2690
2691 auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
2692 if (operandType) {
2693 auto operandRank = operandType.getRank();
2694 if (operandRank != permutationSize) {
2695 return op.emitOpError(llvm::formatv(
2696 "operand rank ({0}) does not match permutation size ({1})",
2697 operandRank, permutationSize));
2698 }
2699 }
2700
2701 auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
2702 if (resultType) {
2703 auto resultRank = resultType.getRank();
2704 if (resultRank != permutationSize) {
2705 return op.emitOpError(llvm::formatv(
2706 "result rank ({0}) does not match permutation size ({1})", resultRank,
2707 permutationSize));
2708 }
2709 }
2710
2711 if (!resultType || !operandType) return success();
2712
2713 auto operandRank = operandType.getRank();
2714 SmallVector<int64_t, 4> expectedShape(operandRank);
2715 for (int i = 0; i != operandRank; ++i) {
2716 auto permutedDim = op.permutation().getValue<IntegerAttr>(i).getInt();
2717 expectedShape[i] = operandType.getDimSize(permutedDim);
2718 }
2719
2720 auto expectedType =
2721 RankedTensorType::get(expectedShape, resultType.getElementType());
2722 if (failed(verifyCompatibleShape(resultType, expectedType))) {
2723 return op.emitOpError(llvm::formatv(
2724 "result type {0} is incompatible with the expected type {1}",
2725 resultType, expectedType));
2726 }
2727
2728 return success();
2729 }
2730
2731 //===----------------------------------------------------------------------===//
2732 // TriangularSolveOp
2733 //===----------------------------------------------------------------------===//
2734
Verify(TriangularSolveOp op)2735 static LogicalResult Verify(TriangularSolveOp op) {
2736 auto a_type = op.a().getType().dyn_cast<RankedTensorType>();
2737
2738 // Skip verifier if a is unranked tensor.
2739 if (!a_type) return success();
2740
2741 // Check that a should have rank >= 2
2742 auto a_rank = a_type.getRank();
2743 if (a_rank < 2)
2744 return op.emitOpError()
2745 << "operand 'a' must have rank >= 2, but got " << a_type;
2746
2747 // The two minor dimensions of a must have same size.
2748 if (a_type.getDimSize(a_rank - 2) != a_type.getDimSize(a_rank - 1))
2749 return op.emitOpError() << "two minor dimensions of operand 'a' must have "
2750 "equal size, but got "
2751 << a_type;
2752
2753 auto b_type = op.b().getType().dyn_cast<RankedTensorType>();
2754 // If b is unranked skip remaining checks.
2755 if (!b_type) return success();
2756
2757 // Check that a and b have same rank.
2758 auto b_rank = b_type.getRank();
2759 if (a_rank != b_rank)
2760 return op.emitOpError() << "operands must have equal rank, but got "
2761 << a_type << " and " << b_type;
2762
2763 // The shared dimension of a and b should match.
2764 if (a_type.getDimSize(a_rank - 1) !=
2765 b_type.getDimSize(b_rank - (op.left_side() ? 2 : 1)))
2766 return op.emitOpError() << "shared dimension of operands 'a' and 'b' does "
2767 "not match, but got "
2768 << a_type << " and " << b_type;
2769
2770 // The leading batch dimensions of a and b must be equal.
2771 auto a_batch_dims = a_type.getShape().drop_back(2);
2772 auto b_batch_dims = b_type.getShape().drop_back(2);
2773 if (a_batch_dims != b_batch_dims)
2774 return op.emitOpError()
2775 << "leading batch dimensions of the operands must be same, but got "
2776 << a_type << " and " << b_type;
2777
2778 // Result and argument b must have same shape.
2779 auto result_type = op.getType().dyn_cast<RankedTensorType>();
2780 if (!result_type) return success();
2781 if (result_type != b_type)
2782 return op.emitOpError()
2783 << "result and operand 'b' must have same shape, but got "
2784 << result_type << " and " << b_type;
2785 return success();
2786 }
2787
2788 //===----------------------------------------------------------------------===//
2789 // GetTupleElementOp
2790 //===----------------------------------------------------------------------===//
2791
build(OpBuilder & builder,OperationState & result,Value tuple,int32_t index)2792 void GetTupleElementOp::build(OpBuilder& builder, OperationState& result,
2793 Value tuple, int32_t index) {
2794 if (auto tuple_type = tuple.getType().dyn_cast<TupleType>()) {
2795 auto element_type = tuple_type.getType(index);
2796 build(builder, result, element_type, tuple,
2797 builder.getI32IntegerAttr(index));
2798 return;
2799 }
2800
2801 build(builder, result, tuple.getType(), tuple,
2802 builder.getI32IntegerAttr(index));
2803 }
2804
2805 //===----------------------------------------------------------------------===//
2806 // TupleOp
2807 //===----------------------------------------------------------------------===//
2808
build(OpBuilder & builder,OperationState & result,ValueRange values)2809 void TupleOp::build(OpBuilder& builder, OperationState& result,
2810 ValueRange values) {
2811 SmallVector<Type, 4> types;
2812 types.reserve(values.size());
2813 for (auto val : values) {
2814 types.push_back(val.getType());
2815 }
2816
2817 build(builder, result, builder.getTupleType(types), values);
2818 }
2819
2820 //===----------------------------------------------------------------------===//
2821 // UnaryEinsumOp
2822 //===----------------------------------------------------------------------===//
2823
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2824 void UnaryEinsumOp::getCanonicalizationPatterns(
2825 OwningRewritePatternList& results, MLIRContext* context) {
2826 results.insert<UnaryEinsumToEinsum>(context);
2827 }
2828
2829 //===----------------------------------------------------------------------===//
2830 // CompareOp
2831 //===----------------------------------------------------------------------===//
2832
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,StringAttr comparison_direction,StringAttr compare_type)2833 void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
2834 Value rhs, StringAttr comparison_direction,
2835 StringAttr compare_type) {
2836 auto new_type =
2837 UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type());
2838 build(builder, result, new_type, lhs, rhs, comparison_direction,
2839 compare_type);
2840 }

LogicalResult CompareOp::inferReturnTypeComponents(
    mlir::MLIRContext*, llvm::Optional<mlir::Location>, mlir::ValueRange,
    mlir::DictionaryAttr, mlir::RegionRange,
    llvm::SmallVectorImpl<mlir::ShapedTypeComponents>&) {
  // TODO(b/168772852)
  return failure();
}

LogicalResult CompareOp::reifyReturnTypeShapes(
    OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
  return deriveShapeFromFirstOperand(&builder, getOperation(),
                                     &reifiedReturnShapes);
}

template <typename T>
struct less : std::less<T> {};

template <>
struct less<APInt> {
  bool operator()(const APInt& a, const APInt& b) const { return a.slt(b); }
};

template <typename T>
struct less_equal : std::less_equal<T> {};

template <>
struct less_equal<APInt> {
  bool operator()(const APInt& a, const APInt& b) const { return a.sle(b); }
};

template <typename T>
struct greater : std::greater<T> {};

template <>
struct greater<APInt> {
  bool operator()(const APInt& a, const APInt& b) const { return a.sgt(b); }
};

template <typename T>
struct greater_equal : std::greater_equal<T> {};

template <>
struct greater_equal<APInt> {
  bool operator()(const APInt& a, const APInt& b) const { return a.sge(b); }
};
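
// The APInt specializations above are needed because APInt does not provide
// the relational operators that std::less and friends rely on; an ordered
// comparison on APInt must pick a signedness explicitly, and this folder
// uses the signed predicates (slt/sle/sgt/sge). For example,
//   less<APInt>()(APInt(8, 255), APInt(8, 1))
// is true, because 255 is -1 under the signed interpretation.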

template <typename Op, typename ElementType, typename SrcType, typename Convert>
static Attribute CompareFolder(CompareOp op, ArrayRef<Attribute> attrs) {
  if (!attrs[0] || !attrs[1]) return {};

  DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>();
  DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>();
  if (!lhs || !rhs) return {};

  ShapedType operand_type =
      op.getOperand(0).getType().template cast<ShapedType>();
  if (!operand_type.hasStaticShape()) {
    return {};
  }

  if (!operand_type.getElementType().isa<ElementType>()) {
    return {};
  }

  SmallVector<bool, 6> values;
  values.reserve(lhs.getNumElements());
  for (const auto zip :
       llvm::zip(lhs.getValues<SrcType>(), rhs.getValues<SrcType>())) {
    values.push_back(Convert()(std::get<0>(zip), std::get<1>(zip)));
  }

  auto result_ty = op.getType().cast<ShapedType>();
  return DenseElementsAttr::get(result_ty, values);
}
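
// For example (hypothetical constants): folding an "LT" compare of dense i32
// operands [0, 2, 4] and [1, 2, 3] applies less<APInt> elementwise and
// produces the dense i1 attribute [true, false, false].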

OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands) {
  auto result_ty = getType().cast<ShapedType>();
  if (!result_ty.hasStaticShape()) return {};

  auto direction = comparison_direction();
  if (lhs() == rhs() && !getElementTypeOrSelf(lhs()).isa<FloatType>()) {
    if (direction == "LE" || direction == "EQ" || direction == "GE") {
      return DenseIntElementsAttr::get(result_ty, {true});
    }
    return DenseIntElementsAttr::get(result_ty, {false});
  }
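
  // The shortcut above handles, e.g., comparing a non-float value against
  // itself: for an integer-typed %x, "EQ"/"LE"/"GE" fold to an all-true
  // constant (and the other directions to all-false) even when %x is not a
  // constant. Floats are excluded because NaN does not compare equal to
  // itself.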

  if (!operands[0] || !operands[1]) {
    return {};
  }

#define COMPARE_FOLDER(Op, comparison, Func)                                \
  if (direction == comparison) {                                            \
    if (auto folded = CompareFolder<Op, FloatType, APFloat, Func<APFloat>>( \
            *this, operands))                                               \
      return folded;                                                        \
    if (auto folded = CompareFolder<Op, IntegerType, APInt, Func<APInt>>(   \
            *this, operands))                                               \
      return folded;                                                        \
  }

  COMPARE_FOLDER(CompareOp, "EQ", std::equal_to);
  COMPARE_FOLDER(CompareOp, "NE", std::not_equal_to);
  COMPARE_FOLDER(CompareOp, "LT", less);
  COMPARE_FOLDER(CompareOp, "LE", less_equal);
  COMPARE_FOLDER(CompareOp, "GT", greater);
  COMPARE_FOLDER(CompareOp, "GE", greater_equal);
#undef COMPARE_FOLDER

  return {};
}

//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//

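// Interprets the ops in `region` as a pure function from attributes to
// attributes: the region arguments are bound to `inputs`, each op is folded
// with the attributes computed for its operands, and the attributes reaching
// the ReturnOp are returned. An empty vector signals that some op could not
// be folded to an attribute.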
llvm::SmallVector<Attribute, 4> evaluateMhloRegion(Region& region,
                                                   ArrayRef<Attribute> inputs) {
  if (region.getNumArguments() != inputs.size()) return {};

  llvm::DenseMap<Value, Attribute> values;
  values.reserve(region.getNumArguments());
  for (auto it : llvm::zip(region.getArguments(), inputs)) {
    values.try_emplace(std::get<0>(it), std::get<1>(it));
  }

  for (auto& op : region.getOps()) {
    llvm::SmallVector<Attribute, 4> op_inputs;
    for (auto& operand : op.getOpOperands()) {
      op_inputs.push_back(values.lookup(operand.get()));
    }
    if (isa<ReturnOp>(op)) return op_inputs;

    llvm::SmallVector<OpFoldResult, 4> results;
    if (failed(op.fold(op_inputs, results))) return {};
    for (auto it : llvm::zip(op.getResults(), results)) {
      if (!std::get<1>(it).is<Attribute>()) return {};
      values.insert({std::get<0>(it), std::get<1>(it).get<Attribute>()});
    }
  }
  return {};
}

OpFoldResult ScatterOp::fold(ArrayRef<Attribute> operands) {
  auto base = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  auto index = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
  auto update = operands[2].dyn_cast_or_null<DenseElementsAttr>();
  if (!base || !index || !update) return {};

  auto base_type = base.getType().dyn_cast<RankedTensorType>();
  auto index_type = index.getType().dyn_cast<RankedTensorType>();
  auto update_type = update.getType().dyn_cast<RankedTensorType>();
  if (!base_type || !index_type || !update_type) return {};

  // Add the virtual trailing dimension of size 1 if index_vector_dim equals
  // the rank of index_type.
  const int64_t index_vector_dim =
      scatter_dimension_numbers().index_vector_dim().getInt();
  if (index_vector_dim == index_type.getRank()) {
    auto index_shape = index_type.getShape().vec();
    index_shape.push_back(1);
    index_type =
        RankedTensorType::get(index_shape, index_type.getElementType());
    index = index.reshape(index_type).cast<DenseIntElementsAttr>();
  }

  // Increments the multi-dimensional index vector based on the limits for
  // each dimension specified by shape. Returns false if the index rolled
  // around, and true otherwise.
  auto next_index = [](llvm::SmallVector<uint64_t, 8>& index,
                       llvm::ArrayRef<int64_t> shape) {
    for (int64_t i = index.size() - 1; i >= 0; --i) {
      ++index[i];
      if (index[i] < shape[i]) return true;
      index[i] = 0;
    }
    return false;
  };
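
  // A worked example of the odometer behaviour above (hypothetical shape):
  // with shape [2, 3], successive calls advance the index
  // [0,0] -> [0,1] -> [0,2] -> [1,0] -> [1,1] -> [1,2], and the call made at
  // [1,2] wraps back to [0,0] and returns false, terminating the do-loop
  // below.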

  // Iterate over all elements of the update tensor; for each, look up the
  // corresponding slice of the indices tensor to determine which location to
  // update in the base/result tensor.
  llvm::SmallVector<Attribute, 8> results(base.getValues<Attribute>());
  llvm::SmallVector<uint64_t, 8> update_index(update_type.getRank(), 0);
  llvm::SmallVector<uint64_t, 8> index_index;
  index_index.reserve(index_type.getRank());
  llvm::SmallVector<uint64_t, 8> base_index;
  base_index.reserve(base_type.getRank());
  do {
    // Compute the index for the slice of the indices tensor for this update
    // value.
    index_index.clear();
    if (index_vector_dim == 0) index_index.push_back(0);
    for (int64_t i = 0; i < update_index.size(); ++i) {
      if (llvm::count(scatter_dimension_numbers().update_window_dims(), i) ==
          0)
        index_index.push_back(update_index[i]);
      if (index_index.size() == index_vector_dim) index_index.push_back(0);
    }

    // Compute the index for the given update value in the base tensor.
    base_index.assign(base_type.getRank(), 0);
    uint64_t index_count = index_type.getShape()[index_vector_dim];
    for (uint64_t i = 0; i < index_count; ++i) {
      uint64_t operand_dim = scatter_dimension_numbers()
                                 .scatter_dims_to_operand_dims()
                                 .getValue<APInt>({i})
                                 .getSExtValue();
      index_index[index_vector_dim] = i;
      base_index[operand_dim] +=
          index.getValue<APInt>(index_index).getSExtValue();
    }
    uint64_t update_window_dim_index = 0;
    for (uint64_t i = 0; i < base_index.size(); ++i) {
      if (llvm::count(scatter_dimension_numbers().inserted_window_dims(), i))
        continue;
      base_index[i] +=
          update_index[scatter_dimension_numbers()
                           .update_window_dims()
                           .getValue<APInt>({update_window_dim_index})
                           .getSExtValue()];
      update_window_dim_index++;
    }

    // Compute the linear index for the position in the base tensor.
    int64_t linear_base_index = 0;
    int64_t linear_base_index_multiplier = 1;
    for (int64_t i = base_index.size() - 1; i >= 0; --i) {
      // Out-of-bounds indices have backend-specific behaviour, so avoid
      // folding them.
      if (base_index[i] < 0 || base_index[i] >= base_type.getShape()[i])
        return {};
      linear_base_index += base_index[i] * linear_base_index_multiplier;
      linear_base_index_multiplier *= base_type.getShape()[i];
    }
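
    // For instance (hypothetical values): with base shape [4, 3] and
    // base_index [2, 1], the row-major linearization above computes
    // 1 * 1 + 2 * 3 = 7.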

    // Evaluate the update computation and replace the value at the computed
    // location in the base tensor with the newly folded attribute.
    auto lhs = DenseElementsAttr::get(
        RankedTensorType::get({}, base_type.getElementType()),
        results[linear_base_index]);
    auto rhs = DenseElementsAttr::get(
        RankedTensorType::get({}, base_type.getElementType()),
        update.getValue<Attribute>(update_index));
    auto new_value = evaluateMhloRegion(update_computation(), {lhs, rhs});
    if (new_value.size() != 1 || !new_value[0]) return {};
    results[linear_base_index] =
        new_value[0].cast<DenseElementsAttr>().getValue<Attribute>({});
  } while (next_index(update_index, update_type.getShape()));

  return DenseElementsAttr::get(base_type, results);
}
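
// As a concrete illustration of the folder above (a hypothetical case, with
// the usual dimension numbers for a 1-D scatter): scattering the update [10]
// at index [[1]] into the constant base [0, 0, 0], with an update computation
// that returns its second argument, folds the whole op to the constant
// [0, 10, 0].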

} // namespace mhlo
} // namespace mlir

#define GET_OP_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"

namespace mlir {
namespace mhlo {

//===----------------------------------------------------------------------===//
// mhlo Dialect Interfaces
//===----------------------------------------------------------------------===//

namespace {
struct HLOInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Allow all call operations to be inlined.
  bool isLegalToInline(Operation* call, Operation* callable,
                       bool wouldBeCloned) const final {
    return true;
  }
  // We don't have any special restrictions on what can be inlined into
  // destination regions (e.g. while/conditional bodies). Always allow it.
  bool isLegalToInline(Region* dest, Region* src, bool wouldBeCloned,
                       BlockAndValueMapping& valueMapping) const final {
    return true;
  }
  // Operations in the mhlo dialect are always legal to inline since they are
  // pure.
  bool isLegalToInline(Operation*, Region*, bool,
                       BlockAndValueMapping&) const final {
    return true;
  }
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// mhlo Dialect Constructor
//===----------------------------------------------------------------------===//

MhloDialect::MhloDialect(MLIRContext* context)
    : Dialect(getDialectNamespace(), context, TypeID::get<MhloDialect>()) {
  addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
      >();
  addInterfaces<HLOInlinerInterface>();
  addTypes<TokenType>();
  context->loadDialect<tensor::TensorDialect>();
}

Type MhloDialect::parseType(DialectAsmParser& parser) const {
  StringRef data_type;
  if (parser.parseKeyword(&data_type)) return Type();

  if (data_type == "token") return TokenType::get(getContext());
  parser.emitError(parser.getNameLoc()) << "unknown mhlo type: " << data_type;
  return nullptr;
}

void MhloDialect::printType(Type type, DialectAsmPrinter& os) const {
  if (type.isa<TokenType>()) {
    os << "token";
    return;
  }
  os << "<unknown mhlo type>";
}
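
// Together, parseType and printType round-trip the only custom type this
// dialect defines; in the textual IR it appears as `!mhlo.token` (the
// `!mhlo.` prefix is handled by the framework, so only the `token` keyword
// is parsed and printed here).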

//===----------------------------------------------------------------------===//
// Shape inference
//===----------------------------------------------------------------------===//

LogicalResult deriveShapeFromFirstOperand(
    OpBuilder* builder, Operation* op,
    SmallVectorImpl<Value>* reifiedReturnShapes) {
  Value operand = op->getOperand(0);
  ShapedType operand_type = operand.getType().dyn_cast<ShapedType>();
  if (!operand_type) {
    op->emitOpError() << "first operand is not a shaped type";
    return failure();
  }
  auto loc = op->getLoc();
  SmallVector<Value, 4> shape_values;
  shape_values.reserve(operand_type.getRank());
  auto shape_scalar_type = builder->getIntegerType(64);
  for (auto element : llvm::enumerate(operand_type.getShape())) {
    if (element.value() == ShapedType::kDynamicSize) {
      Value dim = builder->create<DimOp>(loc, operand, element.index());
      shape_values.push_back(
          builder->create<IndexCastOp>(loc, dim, shape_scalar_type));
    } else {
      shape_values.push_back(builder->create<ConstantOp>(
          loc, builder->getI64IntegerAttr(element.value())));
    }
  }
  *reifiedReturnShapes = SmallVector<Value, 1>{
      builder->create<tensor::FromElementsOp>(loc, shape_values)};
  return success();
}
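
// A sketch of the IR this helper emits (hypothetical, for a first operand of
// type tensor<?x4xf32>): the dynamic extent is materialized with a std.dim
// followed by an index_cast to i64, the static extent with a constant, and
// the extents are packed into a rank-1 shape tensor, roughly like
//   %d0 = dim %arg0, 0 : tensor<?x4xf32>
//   %e0 = index_cast %d0 : index to i64
//   %e1 = constant 4 : i64
//   %shape = tensor.from_elements %e0, %e1 : tensor<2xi64>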

} // namespace mhlo
} // namespace mlir