1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Traits.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/DialectImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/Transforms/InliningUtils.h"
18 #include "llvm/ADT/SmallString.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/Support/raw_ostream.h"
21
22 using namespace mlir;
23 using namespace mlir::shape;
24
25 namespace {
26 #include "ShapeCanonicalization.inc"
27 }
28
getExtentTensorType(MLIRContext * ctx)29 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
30 return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
31 }
32
isErrorPropagationPossible(TypeRange operandTypes)33 static bool isErrorPropagationPossible(TypeRange operandTypes) {
34 for (Type ty : operandTypes)
35 if (ty.isa<SizeType>() || ty.isa<ShapeType>() || ty.isa<ValueShapeType>())
36 return true;
37 return false;
38 }
39
verifySizeOrIndexOp(Operation * op)40 static LogicalResult verifySizeOrIndexOp(Operation *op) {
41 assert(op != nullptr && op->getNumResults() == 1);
42 Type resultTy = op->getResultTypes().front();
43 if (isErrorPropagationPossible(op->getOperandTypes())) {
44 if (!resultTy.isa<SizeType>())
45 return op->emitOpError()
46 << "if at least one of the operands can hold error values then "
47 "the result must be of type `size` to propagate them";
48 }
49 return success();
50 }
51
verifyShapeOrExtentTensorOp(Operation * op)52 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
53 assert(op != nullptr && op->getNumResults() == 1);
54 Type resultTy = op->getResultTypes().front();
55 if (isErrorPropagationPossible(op->getOperandTypes())) {
56 if (!resultTy.isa<ShapeType>())
57 return op->emitOpError()
58 << "if at least one of the operands can hold error values then "
59 "the result must be of type `shape` to propagate them";
60 }
61 return success();
62 }
63
64 //===----------------------------------------------------------------------===//
65 // InlinerInterface
66 //===----------------------------------------------------------------------===//
67
68 namespace {
69 /// This class defines the interface for inlining shape dialect ops.
70 struct ShapeInlinerInterface : public DialectInlinerInterface {
71 using DialectInlinerInterface::DialectInlinerInterface;
72
73 // Returns true if the given region 'src' can be inlined into the region
74 // 'dest' that is attached to an operation registered to the current dialect.
isLegalToInline__anonf9874f600211::ShapeInlinerInterface75 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
76 BlockAndValueMapping &) const final {
77 return true;
78 }
79
80 // Returns true if the given operation 'op', that is registered to this
81 // dialect, can be inlined into the region 'dest' that is attached to an
82 // operation registered to the current dialect.
isLegalToInline__anonf9874f600211::ShapeInlinerInterface83 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
84 BlockAndValueMapping &) const final {
85 return true;
86 }
87 };
88 } // namespace
89
/// Registers the operations, types, and interfaces of the shape dialect and
/// allows unknown operations.
void ShapeDialect::initialize() {
  // The list of operations is supplied by including the op definitions file
  // with GET_OP_LIST defined.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
           WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving, this makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
103
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)104 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
105 Attribute value, Type type,
106 Location loc) {
107 if (type.isa<ShapeType>() ||
108 type == getExtentTensorType(builder.getContext()))
109 return builder.create<ConstShapeOp>(loc, type,
110 value.cast<DenseIntElementsAttr>());
111 if (type.isa<SizeType>())
112 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
113 if (type.isa<WitnessType>())
114 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
115 if (type.isa<IndexType>())
116 return builder.create<ConstantOp>(loc, type, value);
117 return nullptr;
118 }
119
120 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const121 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
122 StringRef keyword;
123 if (parser.parseKeyword(&keyword))
124 return Type();
125
126 if (keyword == "component")
127 return ComponentType::get(getContext());
128 if (keyword == "element")
129 return ElementType::get(getContext());
130 if (keyword == "shape")
131 return ShapeType::get(getContext());
132 if (keyword == "size")
133 return SizeType::get(getContext());
134 if (keyword == "value_shape")
135 return ValueShapeType::get(getContext());
136 if (keyword == "witness")
137 return WitnessType::get(getContext());
138
139 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
140 return Type();
141 }
142
143 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const144 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
145 TypeSwitch<Type>(type)
146 .Case<ComponentType>([&](Type) { os << "component"; })
147 .Case<ElementType>([&](Type) { os << "element"; })
148 .Case<ShapeType>([&](Type) { os << "shape"; })
149 .Case<SizeType>([&](Type) { os << "size"; })
150 .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
151 .Case<WitnessType>([&](Type) { os << "witness"; })
152 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
153 }
154
155 //===----------------------------------------------------------------------===//
156 // AnyOp
157 //===----------------------------------------------------------------------===//
158
159 // TODO: Canonicalization should be implemented for shapes that can be
160 // determined through mixtures of the known dimensions of the inputs.
fold(ArrayRef<Attribute> operands)161 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
162 // Only the last operand is checked because AnyOp is commutative.
163 if (operands.back())
164 return operands.back();
165
166 return nullptr;
167 }
168
169 //===----------------------------------------------------------------------===//
170 // AssumingOp
171 //===----------------------------------------------------------------------===//
172
parseAssumingOp(OpAsmParser & parser,OperationState & result)173 static ParseResult parseAssumingOp(OpAsmParser &parser,
174 OperationState &result) {
175 result.regions.reserve(1);
176 Region *doRegion = result.addRegion();
177
178 auto &builder = parser.getBuilder();
179 OpAsmParser::OperandType cond;
180 if (parser.parseOperand(cond) ||
181 parser.resolveOperand(cond, builder.getType<WitnessType>(),
182 result.operands))
183 return failure();
184
185 // Parse optional results type list.
186 if (parser.parseOptionalArrowTypeList(result.types))
187 return failure();
188
189 // Parse the region and add a terminator if elided.
190 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
191 return failure();
192 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
193
194 // Parse the optional attribute list.
195 if (parser.parseOptionalAttrDict(result.attributes))
196 return failure();
197 return success();
198 }
199
print(OpAsmPrinter & p,AssumingOp op)200 static void print(OpAsmPrinter &p, AssumingOp op) {
201 bool yieldsResults = !op.results().empty();
202
203 p << AssumingOp::getOperationName() << " " << op.witness();
204 if (yieldsResults) {
205 p << " -> (" << op.getResultTypes() << ")";
206 }
207 p.printRegion(op.doRegion(),
208 /*printEntryBlockArgs=*/false,
209 /*printBlockTerminators=*/yieldsResults);
210 p.printOptionalAttrDict(op.getAttrs());
211 }
212
213 namespace {
214 // Removes AssumingOp with a passing witness and inlines the region.
215 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
216 using OpRewritePattern<AssumingOp>::OpRewritePattern;
217
matchAndRewrite__anonf9874f600a11::AssumingWithTrue218 LogicalResult matchAndRewrite(AssumingOp op,
219 PatternRewriter &rewriter) const override {
220 auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
221 if (!witness || !witness.passingAttr())
222 return failure();
223
224 AssumingOp::inlineRegionIntoParent(op, rewriter);
225 return success();
226 }
227 };
228 } // namespace
229
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)230 void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
231 MLIRContext *context) {
232 // If taking a passing witness, inline region.
233 patterns.insert<AssumingWithTrue>(context);
234 }
235
236 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)237 void AssumingOp::getSuccessorRegions(
238 Optional<unsigned> index, ArrayRef<Attribute> operands,
239 SmallVectorImpl<RegionSuccessor> ®ions) {
240 // AssumingOp has unconditional control flow into the region and back to the
241 // parent, so return the correct RegionSuccessor purely based on the index
242 // being None or 0.
243 if (index.hasValue()) {
244 regions.push_back(RegionSuccessor(getResults()));
245 return;
246 }
247
248 regions.push_back(RegionSuccessor(&doRegion()));
249 }
250
inlineRegionIntoParent(AssumingOp & op,PatternRewriter & rewriter)251 void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
252 PatternRewriter &rewriter) {
253 auto *blockBeforeAssuming = rewriter.getInsertionBlock();
254 auto *assumingBlock = op.getBody();
255 auto initPosition = rewriter.getInsertionPoint();
256 auto *blockAfterAssuming =
257 rewriter.splitBlock(blockBeforeAssuming, initPosition);
258
259 // Remove the AssumingOp and AssumingYieldOp.
260 auto &yieldOp = assumingBlock->back();
261 rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
262 rewriter.replaceOp(op, yieldOp.getOperands());
263 rewriter.eraseOp(&yieldOp);
264
265 // Merge blocks together as there was no branching behavior from the
266 // AssumingOp.
267 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
268 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
269 }
270
271 //===----------------------------------------------------------------------===//
272 // AssumingAllOp
273 //===----------------------------------------------------------------------===//
274
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)275 void AssumingAllOp::getCanonicalizationPatterns(
276 OwningRewritePatternList &patterns, MLIRContext *context) {
277 patterns.insert<AssumingAllOneOp>(context);
278 }
279
fold(ArrayRef<Attribute> operands)280 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
281 // Iterate in reverse to first handle all constant operands. They are
282 // guaranteed to be the tail of the inputs because this is commutative.
283 for (int idx = operands.size() - 1; idx >= 0; idx--) {
284 Attribute a = operands[idx];
285 // Cannot fold if any inputs are not constant;
286 if (!a)
287 return nullptr;
288
289 // We do not need to keep statically known values after handling them in
290 // this method.
291 getOperation()->eraseOperand(idx);
292
293 // Always false if any input is statically known false
294 if (!a.cast<BoolAttr>().getValue())
295 return a;
296 }
297 // If this is reached, all inputs were statically known passing.
298 return BoolAttr::get(true, getContext());
299 }
300
verify(AssumingAllOp op)301 static LogicalResult verify(AssumingAllOp op) {
302 // Ensure that AssumingAllOp contains at least one operand
303 if (op.getNumOperands() == 0)
304 return op.emitOpError("no operands specified");
305
306 return success();
307 }
308
309 //===----------------------------------------------------------------------===//
310 // BroadcastOp
311 //===----------------------------------------------------------------------===//
312
fold(ArrayRef<Attribute> operands)313 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
314 if (!operands[1])
315 return nullptr;
316
317 auto rhsShape = llvm::to_vector<6>(
318 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
319 if (rhsShape.empty())
320 return lhs();
321
322 if (!operands[0])
323 return nullptr;
324
325 auto lhsShape = llvm::to_vector<6>(
326 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
327 if (lhsShape.empty())
328 return rhs();
329
330 SmallVector<int64_t, 6> resultShape;
331 // If the shapes are not compatible, we can't fold it.
332 // TODO: Fold to an "error".
333 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
334 return nullptr;
335 Builder builder(getContext());
336 return builder.getIndexTensorAttr(resultShape);
337 }
338
339 //===----------------------------------------------------------------------===//
340 // ConcatOp
341 //===----------------------------------------------------------------------===//
342
fold(ArrayRef<Attribute> operands)343 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
344 if (!operands[0] || !operands[1])
345 return nullptr;
346 auto lhsShape = llvm::to_vector<6>(
347 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
348 auto rhsShape = llvm::to_vector<6>(
349 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
350 SmallVector<int64_t, 6> resultShape;
351 resultShape.append(lhsShape.begin(), lhsShape.end());
352 resultShape.append(rhsShape.begin(), rhsShape.end());
353 Builder builder(getContext());
354 return builder.getIndexTensorAttr(resultShape);
355 }
356
357 //===----------------------------------------------------------------------===//
358 // ConstShapeOp
359 //===----------------------------------------------------------------------===//
360
print(OpAsmPrinter & p,ConstShapeOp & op)361 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
362 p << "shape.const_shape ";
363 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
364 p << "[";
365 interleaveComma(op.shape().getValues<int64_t>(), p,
366 [&](int64_t i) { p << i; });
367 p << "] : ";
368 p.printType(op.getType());
369 }
370
parseConstShapeOp(OpAsmParser & parser,OperationState & result)371 static ParseResult parseConstShapeOp(OpAsmParser &parser,
372 OperationState &result) {
373 if (parser.parseOptionalAttrDict(result.attributes))
374 return failure();
375 // We piggy-back on ArrayAttr parsing, though we don't internally store the
376 // shape as an ArrayAttr.
377 // TODO: Implement custom parser and maybe make syntax a bit more concise.
378 Attribute extentsRaw;
379 NamedAttrList dummy;
380 if (parser.parseAttribute(extentsRaw, "dummy", dummy))
381 return failure();
382 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
383 if (!extentsArray)
384 return failure();
385 SmallVector<int64_t, 6> ints;
386 for (Attribute extent : extentsArray) {
387 IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
388 if (!attr)
389 return failure();
390 ints.push_back(attr.getInt());
391 }
392 Builder &builder = parser.getBuilder();
393 result.addAttribute("shape", builder.getIndexTensorAttr(ints));
394 Type resultTy;
395 if (parser.parseColonType(resultTy))
396 return failure();
397 result.types.push_back(resultTy);
398 return success();
399 }
400
fold(ArrayRef<Attribute>)401 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
402
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)403 void ConstShapeOp::getCanonicalizationPatterns(
404 OwningRewritePatternList &patterns, MLIRContext *context) {
405 patterns.insert<TensorCastConstShape>(context);
406 }
407
408 //===----------------------------------------------------------------------===//
409 // CstrBroadcastableOp
410 //===----------------------------------------------------------------------===//
411
412 namespace {
413 // Given an input shape Value, try to obtain the shape's values.
getShapeVec(Value input,SmallVectorImpl<int64_t> & shapeValues)414 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
415 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
416 auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
417 if (!type.hasRank())
418 return failure();
419 shapeValues = llvm::to_vector<6>(type.getShape());
420 return success();
421 } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
422 shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
423 return success();
424 } else {
425 return failure();
426 }
427 }
428 } // namespace
429
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)430 void CstrBroadcastableOp::getCanonicalizationPatterns(
431 OwningRewritePatternList &patterns, MLIRContext *context) {
432 // Canonicalization patterns have overlap with the considerations during
433 // folding in case additional shape information is inferred at some point that
434 // does not result in folding.
435 patterns.insert<CstrBroadcastableEqOps>(context);
436 }
437
fold(ArrayRef<Attribute> operands)438 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
439 // Both operands are not needed if one is a scalar.
440 if (operands[0] &&
441 operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
442 return BoolAttr::get(true, getContext());
443 if (operands[1] &&
444 operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
445 return BoolAttr::get(true, getContext());
446
447 if (operands[0] && operands[1]) {
448 auto lhsShape = llvm::to_vector<6>(
449 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
450 auto rhsShape = llvm::to_vector<6>(
451 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
452 SmallVector<int64_t, 6> resultShape;
453 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
454 return BoolAttr::get(true, getContext());
455 }
456
457 // Lastly, see if folding can be completed based on what constraints are known
458 // on the input shapes.
459 SmallVector<int64_t, 6> lhsShape, rhsShape;
460 if (failed(getShapeVec(lhs(), lhsShape)))
461 return nullptr;
462 if (failed(getShapeVec(rhs(), rhsShape)))
463 return nullptr;
464
465 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
466 return BoolAttr::get(true, getContext());
467
468 // Because a failing witness result here represents an eventual assertion
469 // failure, we do not replace it with a constant witness.
470 return nullptr;
471 }
472
473 //===----------------------------------------------------------------------===//
474 // CstrEqOp
475 //===----------------------------------------------------------------------===//
476
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)477 void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
478 MLIRContext *context) {
479 // If inputs are equal, return passing witness
480 patterns.insert<CstrEqEqOps>(context);
481 }
482
fold(ArrayRef<Attribute> operands)483 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
484 if (llvm::all_of(operands,
485 [&](Attribute a) { return a && a == operands[0]; }))
486 return BoolAttr::get(true, getContext());
487
488 // Because a failing witness result here represents an eventual assertion
489 // failure, we do not try to replace it with a constant witness. Similarly, we
490 // cannot if there are any non-const inputs.
491 return nullptr;
492 }
493
494 //===----------------------------------------------------------------------===//
495 // ConstSizeOp
496 //===----------------------------------------------------------------------===//
497
build(OpBuilder & builder,OperationState & result,int64_t value)498 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
499 int64_t value) {
500 build(builder, result, builder.getIndexAttr(value));
501 }
502
fold(ArrayRef<Attribute>)503 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
504
getAsmResultNames(llvm::function_ref<void (Value,StringRef)> setNameFn)505 void ConstSizeOp::getAsmResultNames(
506 llvm::function_ref<void(Value, StringRef)> setNameFn) {
507 SmallString<4> buffer;
508 llvm::raw_svector_ostream os(buffer);
509 os << "c" << value();
510 setNameFn(getResult(), os.str());
511 }
512
513 //===----------------------------------------------------------------------===//
514 // ConstWitnessOp
515 //===----------------------------------------------------------------------===//
516
fold(ArrayRef<Attribute>)517 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
518
519 //===----------------------------------------------------------------------===//
520 // CstrRequireOp
521 //===----------------------------------------------------------------------===//
522
fold(ArrayRef<Attribute> operands)523 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
524 return operands[0];
525 }
526
527 //===----------------------------------------------------------------------===//
528 // ShapeEqOp
529 //===----------------------------------------------------------------------===//
530
fold(ArrayRef<Attribute> operands)531 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
532 auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
533 if (lhs == nullptr)
534 return {};
535 auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
536 if (rhs == nullptr)
537 return {};
538 return BoolAttr::get(lhs == rhs, getContext());
539 }
540
541 //===----------------------------------------------------------------------===//
542 // IndexToSizeOp
543 //===----------------------------------------------------------------------===//
544
fold(ArrayRef<Attribute> operands)545 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
546 // Constant values of both types, `shape.size` and `index`, are represented as
547 // `IntegerAttr`s which makes constant folding simple.
548 if (Attribute arg = operands[0])
549 return arg;
550 return {};
551 }
552
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)553 void IndexToSizeOp::getCanonicalizationPatterns(
554 OwningRewritePatternList &patterns, MLIRContext *context) {
555 patterns.insert<SizeToIndexToSizeCanonicalization>(context);
556 }
557
558 //===----------------------------------------------------------------------===//
559 // FromExtentsOp
560 //===----------------------------------------------------------------------===//
561
fold(ArrayRef<Attribute> operands)562 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
563 if (llvm::any_of(operands, [](Attribute a) { return !a; }))
564 return nullptr;
565 SmallVector<int64_t, 6> extents;
566 for (auto attr : operands)
567 extents.push_back(attr.cast<IntegerAttr>().getInt());
568 Builder builder(getContext());
569 return builder.getIndexTensorAttr(extents);
570 }
571
572 //===----------------------------------------------------------------------===//
573 // FunctionLibraryOp
574 //===----------------------------------------------------------------------===//
575
build(OpBuilder & builder,OperationState & result,StringRef name)576 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
577 StringRef name) {
578 ensureTerminator(*result.addRegion(), builder, result.location);
579 result.attributes.push_back(builder.getNamedAttr(
580 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
581 }
582
getShapeFunction(Operation * op)583 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
584 auto attr = mapping()
585 .get(op->getName().getIdentifier())
586 .dyn_cast_or_null<FlatSymbolRefAttr>();
587 if (!attr)
588 return nullptr;
589 return lookupSymbol<FuncOp>(attr);
590 }
591
parseFunctionLibraryOp(OpAsmParser & parser,OperationState & result)592 ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
593 OperationState &result) {
594 // Parse the op name.
595 StringAttr nameAttr;
596 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
597 result.attributes))
598 return failure();
599
600 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
601 return failure();
602
603 auto *bodyRegion = result.addRegion();
604 if (parser.parseRegion(*bodyRegion))
605 return failure();
606
607 FunctionLibraryOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
608 result.location);
609 if (parser.parseKeyword("mapping"))
610 return failure();
611
612 DictionaryAttr mappingAttr;
613 if (parser.parseAttribute(mappingAttr,
614 parser.getBuilder().getType<NoneType>(), "mapping",
615 result.attributes))
616 return failure();
617 return success();
618 }
619
print(OpAsmPrinter & p,FunctionLibraryOp op)620 void print(OpAsmPrinter &p, FunctionLibraryOp op) {
621 p << op.getOperationName() << ' ';
622 p.printSymbolName(op.getName());
623 p.printOptionalAttrDictWithKeyword(
624 op.getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
625 p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
626 /*printBlockTerminators=*/false);
627 p << " mapping ";
628 p.printAttributeWithoutType(op.mappingAttr());
629 }
630
631 //===----------------------------------------------------------------------===//
632 // GetExtentOp
633 //===----------------------------------------------------------------------===//
634
getConstantDim()635 Optional<int64_t> GetExtentOp::getConstantDim() {
636 if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
637 return constSizeOp.value().getLimitedValue();
638 if (auto constantOp = dim().getDefiningOp<ConstantOp>())
639 return constantOp.value().cast<IntegerAttr>().getInt();
640 return llvm::None;
641 }
642
fold(ArrayRef<Attribute> operands)643 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
644 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
645 if (!elements)
646 return nullptr;
647 Optional<int64_t> dim = getConstantDim();
648 if (!dim.hasValue())
649 return nullptr;
650 if (dim.getValue() >= elements.getNumElements())
651 return nullptr;
652 return elements.getValue({(uint64_t)dim.getValue()});
653 }
654
build(OpBuilder & builder,OperationState & result,Value shape,int64_t dim)655 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
656 int64_t dim) {
657 auto loc = result.location;
658 auto dimAttr = builder.getIndexAttr(dim);
659 if (shape.getType().isa<ShapeType>()) {
660 Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
661 build(builder, result, builder.getType<SizeType>(), shape, dim);
662 } else {
663 Value dim =
664 builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
665 build(builder, result, builder.getIndexType(), shape, dim);
666 }
667 }
668
669 //===----------------------------------------------------------------------===//
670 // RankOp
671 //===----------------------------------------------------------------------===//
672
fold(ArrayRef<Attribute> operands)673 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
674 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
675 if (!shape)
676 return {};
677 int64_t rank = shape.getNumElements();
678 Builder builder(getContext());
679 return builder.getIndexAttr(rank);
680 }
681
682 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
683 /// Constant folding fails in cases where only the rank is constant, not the
684 /// shape itself.
685 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
686 ///
687 /// Example:
688 ///
689 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
690 /// %rank = shape.rank %shape
691 ///
692 /// becomes
693 ///
694 /// %rank = shape.const_size 3
695
696 namespace {
697 struct RankShapeOfCanonicalizationPattern
698 : public OpRewritePattern<shape::RankOp> {
699 using OpRewritePattern<shape::RankOp>::OpRewritePattern;
700
matchAndRewrite__anonf9874f600f11::RankShapeOfCanonicalizationPattern701 LogicalResult matchAndRewrite(shape::RankOp op,
702 PatternRewriter &rewriter) const override {
703 auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
704 if (!shapeOfOp)
705 return failure();
706 auto rankedTensorType =
707 shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
708 if (!rankedTensorType)
709 return failure();
710 int64_t rank = rankedTensorType.getRank();
711 if (op.getType().isa<IndexType>()) {
712 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
713 } else if (op.getType().isa<shape::SizeType>()) {
714 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
715 } else {
716 return failure();
717 }
718 return success();
719 }
720 };
721 } // namespace
722
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)723 void shape::RankOp::getCanonicalizationPatterns(
724 OwningRewritePatternList &patterns, MLIRContext *context) {
725 patterns.insert<RankShapeOfCanonicalizationPattern>(context);
726 }
727
728 //===----------------------------------------------------------------------===//
729 // NumElementsOp
730 //===----------------------------------------------------------------------===//
731
fold(ArrayRef<Attribute> operands)732 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
733
734 // Fold only when argument constant.
735 Attribute shape = operands[0];
736 if (!shape)
737 return {};
738
739 APInt product(64, 1);
740 for (auto value : shape.cast<DenseIntElementsAttr>())
741 product *= value;
742 Builder builder(getContext());
743 return builder.getIndexAttr(product.getLimitedValue());
744 }
745
build(OpBuilder & builder,OperationState & result,Value shape)746 void NumElementsOp::build(OpBuilder &builder, OperationState &result,
747 Value shape) {
748 if (shape.getType().isa<ShapedType>()) {
749 auto type = builder.getIndexType();
750 return build(builder, result, type, shape);
751 }
752 auto type = SizeType::get(builder.getContext());
753 return build(builder, result, type, shape);
754 }
755
756 //===----------------------------------------------------------------------===//
757 // MulOp
758 //===----------------------------------------------------------------------===//
759
fold(ArrayRef<Attribute> operands)760 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
761 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
762 if (!lhs)
763 return nullptr;
764 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
765 if (!rhs)
766 return nullptr;
767 APInt folded = lhs.getValue() * rhs.getValue();
768 Type indexTy = IndexType::get(getContext());
769 return IntegerAttr::get(indexTy, folded);
770 }
771
772 //===----------------------------------------------------------------------===//
773 // ShapeOfOp
774 //===----------------------------------------------------------------------===//
775
fold(ArrayRef<Attribute>)776 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
777 auto type = getOperand().getType().dyn_cast<ShapedType>();
778 if (!type || !type.hasStaticShape())
779 return nullptr;
780 Builder builder(getContext());
781 return builder.getIndexTensorAttr(type.getShape());
782 }
783
build(OpBuilder & builder,OperationState & result,Value arg)784 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
785 Type type = arg.getType().isa<ShapedType>()
786 ? (Type)getExtentTensorType(builder.getContext())
787 : (Type)builder.getType<ShapeType>();
788 return ShapeOfOp::build(builder, result, type, arg);
789 }
790
791 namespace {
792 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
793 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
794
matchAndRewrite__anonf9874f601011::ShapeOfWithTensor795 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
796 PatternRewriter &rewriter) const override {
797 if (!op.arg().getType().isa<ShapedType>())
798 return failure();
799 if (op.getType().isa<ShapedType>())
800 return failure();
801
802 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
803 return success();
804 }
805 };
806 } // namespace
807
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)808 void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
809 MLIRContext *context) {
810 patterns.insert<ShapeOfWithTensor>(context);
811 }
812
813 //===----------------------------------------------------------------------===//
814 // SizeToIndexOp
815 //===----------------------------------------------------------------------===//
816
fold(ArrayRef<Attribute> operands)817 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
818 // Constant values of both types, `shape.size` and `index`, are represented as
819 // `IntegerAttr`s which makes constant folding simple.
820 if (Attribute arg = operands[0])
821 return arg;
822 return impl::foldCastOp(*this);
823 }
824
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)825 void SizeToIndexOp::getCanonicalizationPatterns(
826 OwningRewritePatternList &patterns, MLIRContext *context) {
827 patterns.insert<IndexToSizeToIndexCanonicalization>(context);
828 }
829
830 //===----------------------------------------------------------------------===//
831 // YieldOp
832 //===----------------------------------------------------------------------===//
833
verify(shape::YieldOp op)834 static LogicalResult verify(shape::YieldOp op) {
835 auto *parentOp = op->getParentOp();
836 auto results = parentOp->getResults();
837 auto operands = op.getOperands();
838
839 if (parentOp->getNumResults() != op.getNumOperands())
840 return op.emitOpError() << "number of operands does not match number of "
841 "results of its parent";
842 for (auto e : llvm::zip(results, operands))
843 if (std::get<0>(e).getType() != std::get<1>(e).getType())
844 return op.emitOpError()
845 << "types mismatch between yield op and its parent";
846
847 return success();
848 }
849
850 //===----------------------------------------------------------------------===//
851 // SplitAtOp
852 //===----------------------------------------------------------------------===//
853
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)854 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
855 SmallVectorImpl<OpFoldResult> &results) {
856 if (!operands[0] || !operands[1])
857 return failure();
858 auto shapeVec = llvm::to_vector<6>(
859 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
860 auto shape = llvm::makeArrayRef(shapeVec);
861 auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
862 // Verify that the split point is in the correct range.
863 // TODO: Constant fold to an "error".
864 int64_t rank = shape.size();
865 if (!(-rank <= splitPoint && splitPoint <= rank))
866 return failure();
867 if (splitPoint < 0)
868 splitPoint += shape.size();
869 Builder builder(operands[0].getContext());
870 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
871 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
872 return success();
873 }
874
875 //===----------------------------------------------------------------------===//
876 // ToExtentTensorOp
877 //===----------------------------------------------------------------------===//
878
fold(ArrayRef<Attribute> operands)879 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
880 if (!operands[0])
881 return impl::foldCastOp(*this);
882 Builder builder(getContext());
883 auto shape = llvm::to_vector<6>(
884 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
885 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
886 builder.getIndexType());
887 return DenseIntElementsAttr::get(type, shape);
888 }
889
890 //===----------------------------------------------------------------------===//
891 // ReduceOp
892 //===----------------------------------------------------------------------===//
893
build(OpBuilder & builder,OperationState & result,Value shape,ValueRange initVals)894 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
895 ValueRange initVals) {
896 result.addOperands(shape);
897 result.addOperands(initVals);
898
899 Region *bodyRegion = result.addRegion();
900 bodyRegion->push_back(new Block);
901 Block &bodyBlock = bodyRegion->front();
902 bodyBlock.addArgument(builder.getIndexType());
903
904 Type elementType;
905 if (auto tensorType = shape.getType().dyn_cast<TensorType>())
906 elementType = tensorType.getElementType();
907 else
908 elementType = SizeType::get(builder.getContext());
909 bodyBlock.addArgument(elementType);
910
911 for (Type initValType : initVals.getTypes()) {
912 bodyBlock.addArgument(initValType);
913 result.addTypes(initValType);
914 }
915 }
916
verify(ReduceOp op)917 static LogicalResult verify(ReduceOp op) {
918 // Verify block arg types.
919 Block &block = op.region().front();
920
921 // The block takes index, extent, and aggregated values as arguments.
922 auto blockArgsCount = op.initVals().size() + 2;
923 if (block.getNumArguments() != blockArgsCount)
924 return op.emitOpError() << "ReduceOp body is expected to have "
925 << blockArgsCount << " arguments";
926
927 // The first block argument is the index and must always be of type `index`.
928 if (!block.getArgument(0).getType().isa<IndexType>())
929 return op.emitOpError(
930 "argument 0 of ReduceOp body is expected to be of IndexType");
931
932 // The second block argument is the extent and must be of type `size` or
933 // `index`, depending on whether the reduce operation is applied to a shape or
934 // to an extent tensor.
935 Type extentTy = block.getArgument(1).getType();
936 if (op.shape().getType().isa<ShapeType>()) {
937 if (!extentTy.isa<SizeType>())
938 return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
939 "SizeType if the ReduceOp operates on a ShapeType");
940 } else {
941 if (!extentTy.isa<IndexType>())
942 return op.emitOpError(
943 "argument 1 of ReduceOp body is expected to be of IndexType if the "
944 "ReduceOp operates on an extent tensor");
945 }
946
947 for (auto type : llvm::enumerate(op.initVals()))
948 if (block.getArgument(type.index() + 2).getType() != type.value().getType())
949 return op.emitOpError()
950 << "type mismatch between argument " << type.index() + 2
951 << " of ReduceOp body and initial value " << type.index();
952 return success();
953 }
954
parseReduceOp(OpAsmParser & parser,OperationState & result)955 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
956 // Parse operands.
957 SmallVector<OpAsmParser::OperandType, 3> operands;
958 Type shapeOrExtentTensorType;
959 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
960 OpAsmParser::Delimiter::Paren) ||
961 parser.parseColonType(shapeOrExtentTensorType) ||
962 parser.parseOptionalArrowTypeList(result.types))
963 return failure();
964
965 // Resolve operands.
966 auto initVals = llvm::makeArrayRef(operands).drop_front();
967 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
968 result.operands) ||
969 parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
970 result.operands))
971 return failure();
972
973 // Parse the body.
974 Region *body = result.addRegion();
975 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
976 return failure();
977
978 // Parse attributes.
979 if (parser.parseOptionalAttrDict(result.attributes))
980 return failure();
981
982 return success();
983 }
984
print(OpAsmPrinter & p,ReduceOp op)985 static void print(OpAsmPrinter &p, ReduceOp op) {
986 p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
987 << ") : " << op.shape().getType();
988 p.printOptionalArrowTypeList(op.getResultTypes());
989 p.printRegion(op.region());
990 p.printOptionalAttrDict(op.getAttrs());
991 }
992
993 #define GET_OP_CLASSES
994 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
995