• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Traits.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/DialectImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/Transforms/InliningUtils.h"
18 #include "llvm/ADT/SmallString.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/Support/raw_ostream.h"
21 
22 using namespace mlir;
23 using namespace mlir::shape;
24 
25 namespace {
26 #include "ShapeCanonicalization.inc"
27 }
28 
getExtentTensorType(MLIRContext * ctx)29 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
30   return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
31 }
32 
isErrorPropagationPossible(TypeRange operandTypes)33 static bool isErrorPropagationPossible(TypeRange operandTypes) {
34   for (Type ty : operandTypes)
35     if (ty.isa<SizeType>() || ty.isa<ShapeType>() || ty.isa<ValueShapeType>())
36       return true;
37   return false;
38 }
39 
verifySizeOrIndexOp(Operation * op)40 static LogicalResult verifySizeOrIndexOp(Operation *op) {
41   assert(op != nullptr && op->getNumResults() == 1);
42   Type resultTy = op->getResultTypes().front();
43   if (isErrorPropagationPossible(op->getOperandTypes())) {
44     if (!resultTy.isa<SizeType>())
45       return op->emitOpError()
46              << "if at least one of the operands can hold error values then "
47                 "the result must be of type `size` to propagate them";
48   }
49   return success();
50 }
51 
verifyShapeOrExtentTensorOp(Operation * op)52 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
53   assert(op != nullptr && op->getNumResults() == 1);
54   Type resultTy = op->getResultTypes().front();
55   if (isErrorPropagationPossible(op->getOperandTypes())) {
56     if (!resultTy.isa<ShapeType>())
57       return op->emitOpError()
58              << "if at least one of the operands can hold error values then "
59                 "the result must be of type `shape` to propagate them";
60   }
61   return success();
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // InlinerInterface
66 //===----------------------------------------------------------------------===//
67 
namespace {
/// This class defines the interface for inlining shape dialect ops.
/// Shape ops carry no dialect-specific inlining restrictions, so both hooks
/// unconditionally allow inlining.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
} // namespace
89 
/// Registers the dialect's operations, types, and interfaces.
void ShapeDialect::initialize() {
  // Register all ops generated from ShapeOps.td.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
           WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with an unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
103 
/// Materializes a single constant operation for a folded attribute/type pair.
/// Returns nullptr when `type` is not one this dialect knows how to
/// materialize.
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  // Both `!shape.shape` and the canonical extent tensor type are represented
  // by dense index-element attributes.
  if (type.isa<ShapeType>() ||
      type == getExtentTensorType(builder.getContext()))
    return builder.create<ConstShapeOp>(loc, type,
                                        value.cast<DenseIntElementsAttr>());
  if (type.isa<SizeType>())
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  if (type.isa<WitnessType>())
    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
  // Plain `index` constants are materialized via the standard dialect.
  if (type.isa<IndexType>())
    return builder.create<ConstantOp>(loc, type, value);
  return nullptr;
}
119 
120 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const121 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
122   StringRef keyword;
123   if (parser.parseKeyword(&keyword))
124     return Type();
125 
126   if (keyword == "component")
127     return ComponentType::get(getContext());
128   if (keyword == "element")
129     return ElementType::get(getContext());
130   if (keyword == "shape")
131     return ShapeType::get(getContext());
132   if (keyword == "size")
133     return SizeType::get(getContext());
134   if (keyword == "value_shape")
135     return ValueShapeType::get(getContext());
136   if (keyword == "witness")
137     return WitnessType::get(getContext());
138 
139   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
140   return Type();
141 }
142 
143 /// Print a type registered to this dialect.
/// Prints a shape dialect type as its bare keyword; the dialect prefix is
/// emitted by the caller. Must be kept in sync with parseType above.
void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<ComponentType>([&](Type) { os << "component"; })
      .Case<ElementType>([&](Type) { os << "element"; })
      .Case<ShapeType>([&](Type) { os << "shape"; })
      .Case<SizeType>([&](Type) { os << "size"; })
      .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
      .Case<WitnessType>([&](Type) { os << "witness"; })
      .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
}
154 
155 //===----------------------------------------------------------------------===//
156 // AnyOp
157 //===----------------------------------------------------------------------===//
158 
159 // TODO: Canonicalization should be implemented for shapes that can be
160 // determined through mixtures of the known dimensions of the inputs.
fold(ArrayRef<Attribute> operands)161 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
162   // Only the last operand is checked because AnyOp is commutative.
163   if (operands.back())
164     return operands.back();
165 
166   return nullptr;
167 }
168 
169 //===----------------------------------------------------------------------===//
170 // AssumingOp
171 //===----------------------------------------------------------------------===//
172 
/// Custom parser for `shape.assuming`: a witness operand, an optional
/// `-> (types)` result list, the body region, then an optional attr-dict.
static ParseResult parseAssumingOp(OpAsmParser &parser,
                                   OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  // The single operand is always witness-typed, so its type is not spelled
  // out in the assembly.
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
199 
/// Custom printer for `shape.assuming`. The result type list and the block
/// terminator are only printed when the op actually yields results.
static void print(OpAsmPrinter &p, AssumingOp op) {
  bool yieldsResults = !op.results().empty();

  p << AssumingOp::getOperationName() << " " << op.witness();
  if (yieldsResults) {
    p << " -> (" << op.getResultTypes() << ")";
  }
  p.printRegion(op.doRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/yieldsResults);
  p.printOptionalAttrDict(op.getAttrs());
}
212 
namespace {
// Removes AssumingOp with a passing witness and inlines the region.
struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    // Only applies when the witness is a constant passing witness; a failing
    // or unknown witness must keep the assuming region intact.
    auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
    if (!witness || !witness.passingAttr())
      return failure();

    AssumingOp::inlineRegionIntoParent(op, rewriter);
    return success();
  }
};
} // namespace
229 
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)230 void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
231                                              MLIRContext *context) {
232   // If taking a passing witness, inline region.
233   patterns.insert<AssumingWithTrue>(context);
234 }
235 
236 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
void AssumingOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  // AssumingOp has unconditional control flow into the region and back to the
  // parent, so return the correct RegionSuccessor purely based on the index
  // being None or 0.
  if (index.hasValue()) {
    // Control leaves the body region and flows to the op's results.
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  // Control enters the single body region from the op itself.
  regions.push_back(RegionSuccessor(&doRegion()));
}
250 
/// Splices the body of `op` into its parent block and erases the op. The
/// yield's operands replace the op's results. The sequence of rewriter calls
/// below is order-sensitive.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  // Split the parent block at the op so the region body can be spliced in
  // between the two halves.
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
270 
271 //===----------------------------------------------------------------------===//
272 // AssumingAllOp
273 //===----------------------------------------------------------------------===//
274 
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)275 void AssumingAllOp::getCanonicalizationPatterns(
276     OwningRewritePatternList &patterns, MLIRContext *context) {
277   patterns.insert<AssumingAllOneOp>(context);
278 }
279 
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    // NOTE: this mutates the op in place while folding, dropping the operand
    // that was just examined.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(true, getContext());
}
300 
verify(AssumingAllOp op)301 static LogicalResult verify(AssumingAllOp op) {
302   // Ensure that AssumingAllOp contains at least one operand
303   if (op.getNumOperands() == 0)
304     return op.emitOpError("no operands specified");
305 
306   return success();
307 }
308 
309 //===----------------------------------------------------------------------===//
310 // BroadcastOp
311 //===----------------------------------------------------------------------===//
312 
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (!operands[1])
    return nullptr;

  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  // Broadcasting with a rank-0 (scalar) shape is the identity on the other
  // operand, even when that operand is not constant.
  if (rhsShape.empty())
    return lhs();

  if (!operands[0])
    return nullptr;

  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  if (lhsShape.empty())
    return rhs();

  SmallVector<int64_t, 6> resultShape;
  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;
  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}
338 
339 //===----------------------------------------------------------------------===//
340 // ConcatOp
341 //===----------------------------------------------------------------------===//
342 
fold(ArrayRef<Attribute> operands)343 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
344   if (!operands[0] || !operands[1])
345     return nullptr;
346   auto lhsShape = llvm::to_vector<6>(
347       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
348   auto rhsShape = llvm::to_vector<6>(
349       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
350   SmallVector<int64_t, 6> resultShape;
351   resultShape.append(lhsShape.begin(), lhsShape.end());
352   resultShape.append(rhsShape.begin(), rhsShape.end());
353   Builder builder(getContext());
354   return builder.getIndexTensorAttr(resultShape);
355 }
356 
357 //===----------------------------------------------------------------------===//
358 // ConstShapeOp
359 //===----------------------------------------------------------------------===//
360 
/// Prints `shape.const_shape` with its extents as a bracketed list, e.g.
/// `shape.const_shape [1, 2, 3] : !shape.shape`.
static void print(OpAsmPrinter &p, ConstShapeOp &op) {
  p << "shape.const_shape ";
  // The shape attribute is printed in custom form below, so elide it from the
  // generic attribute dictionary.
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(op.shape().getValues<int64_t>(), p,
                  [&](int64_t i) { p << i; });
  p << "] : ";
  p.printType(op.getType());
}
370 
/// Custom parser for `shape.const_shape`: an optional attr-dict, a bracketed
/// extent list, and a trailing `: type`.
static ParseResult parseConstShapeOp(OpAsmParser &parser,
                                     OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    // Every element of the bracketed list must be an integer extent.
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  // Store the extents in the canonical dense index tensor form.
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}
400 
fold(ArrayRef<Attribute>)401 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
402 
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)403 void ConstShapeOp::getCanonicalizationPatterns(
404     OwningRewritePatternList &patterns, MLIRContext *context) {
405   patterns.insert<TensorCastConstShape>(context);
406 }
407 
408 //===----------------------------------------------------------------------===//
409 // CstrBroadcastableOp
410 //===----------------------------------------------------------------------===//
411 
412 namespace {
413 // Given an input shape Value, try to obtain the shape's values.
getShapeVec(Value input,SmallVectorImpl<int64_t> & shapeValues)414 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
415   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
416     auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
417     if (!type.hasRank())
418       return failure();
419     shapeValues = llvm::to_vector<6>(type.getShape());
420     return success();
421   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
422     shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
423     return success();
424   } else {
425     return failure();
426   }
427 }
428 } // namespace
429 
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)430 void CstrBroadcastableOp::getCanonicalizationPatterns(
431     OwningRewritePatternList &patterns, MLIRContext *context) {
432   // Canonicalization patterns have overlap with the considerations during
433   // folding in case additional shape information is inferred at some point that
434   // does not result in folding.
435   patterns.insert<CstrBroadcastableEqOps>(context);
436 }
437 
fold(ArrayRef<Attribute> operands)438 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
439   // Both operands are not needed if one is a scalar.
440   if (operands[0] &&
441       operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
442     return BoolAttr::get(true, getContext());
443   if (operands[1] &&
444       operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
445     return BoolAttr::get(true, getContext());
446 
447   if (operands[0] && operands[1]) {
448     auto lhsShape = llvm::to_vector<6>(
449         operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
450     auto rhsShape = llvm::to_vector<6>(
451         operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
452     SmallVector<int64_t, 6> resultShape;
453     if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
454       return BoolAttr::get(true, getContext());
455   }
456 
457   // Lastly, see if folding can be completed based on what constraints are known
458   // on the input shapes.
459   SmallVector<int64_t, 6> lhsShape, rhsShape;
460   if (failed(getShapeVec(lhs(), lhsShape)))
461     return nullptr;
462   if (failed(getShapeVec(rhs(), rhsShape)))
463     return nullptr;
464 
465   if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
466     return BoolAttr::get(true, getContext());
467 
468   // Because a failing witness result here represents an eventual assertion
469   // failure, we do not replace it with a constant witness.
470   return nullptr;
471 }
472 
473 //===----------------------------------------------------------------------===//
474 // CstrEqOp
475 //===----------------------------------------------------------------------===//
476 
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)477 void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
478                                            MLIRContext *context) {
479   // If inputs are equal, return passing witness
480   patterns.insert<CstrEqEqOps>(context);
481 }
482 
fold(ArrayRef<Attribute> operands)483 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
484   if (llvm::all_of(operands,
485                    [&](Attribute a) { return a && a == operands[0]; }))
486     return BoolAttr::get(true, getContext());
487 
488   // Because a failing witness result here represents an eventual assertion
489   // failure, we do not try to replace it with a constant witness. Similarly, we
490   // cannot if there are any non-const inputs.
491   return nullptr;
492 }
493 
494 //===----------------------------------------------------------------------===//
495 // ConstSizeOp
496 //===----------------------------------------------------------------------===//
497 
build(OpBuilder & builder,OperationState & result,int64_t value)498 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
499                         int64_t value) {
500   build(builder, result, builder.getIndexAttr(value));
501 }
502 
fold(ArrayRef<Attribute>)503 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
504 
/// Suggests readable SSA names like `%c3` for constant size results.
void ConstSizeOp::getAsmResultNames(
    llvm::function_ref<void(Value, StringRef)> setNameFn) {
  SmallString<4> buffer;
  llvm::raw_svector_ostream os(buffer);
  os << "c" << value();
  setNameFn(getResult(), os.str());
}
512 
513 //===----------------------------------------------------------------------===//
514 // ConstWitnessOp
515 //===----------------------------------------------------------------------===//
516 
fold(ArrayRef<Attribute>)517 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
518 
519 //===----------------------------------------------------------------------===//
520 // CstrRequireOp
521 //===----------------------------------------------------------------------===//
522 
fold(ArrayRef<Attribute> operands)523 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
524   return operands[0];
525 }
526 
527 //===----------------------------------------------------------------------===//
528 // ShapeEqOp
529 //===----------------------------------------------------------------------===//
530 
fold(ArrayRef<Attribute> operands)531 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
532   auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
533   if (lhs == nullptr)
534     return {};
535   auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
536   if (rhs == nullptr)
537     return {};
538   return BoolAttr::get(lhs == rhs, getContext());
539 }
540 
541 //===----------------------------------------------------------------------===//
542 // IndexToSizeOp
543 //===----------------------------------------------------------------------===//
544 
fold(ArrayRef<Attribute> operands)545 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
546   // Constant values of both types, `shape.size` and `index`, are represented as
547   // `IntegerAttr`s which makes constant folding simple.
548   if (Attribute arg = operands[0])
549     return arg;
550   return {};
551 }
552 
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)553 void IndexToSizeOp::getCanonicalizationPatterns(
554     OwningRewritePatternList &patterns, MLIRContext *context) {
555   patterns.insert<SizeToIndexToSizeCanonicalization>(context);
556 }
557 
558 //===----------------------------------------------------------------------===//
559 // FromExtentsOp
560 //===----------------------------------------------------------------------===//
561 
fold(ArrayRef<Attribute> operands)562 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
563   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
564     return nullptr;
565   SmallVector<int64_t, 6> extents;
566   for (auto attr : operands)
567     extents.push_back(attr.cast<IntegerAttr>().getInt());
568   Builder builder(getContext());
569   return builder.getIndexTensorAttr(extents);
570 }
571 
572 //===----------------------------------------------------------------------===//
573 // FunctionLibraryOp
574 //===----------------------------------------------------------------------===//
575 
/// Builds a function library with the given symbol name and an empty body
/// region (with its implicit terminator).
void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  ensureTerminator(*result.addRegion(), builder, result.location);
  // Attach the symbol name so the library can be referenced by name.
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}
582 
/// Looks up the shape function registered for `op`'s name in the `mapping`
/// attribute. Returns nullptr when no mapping entry exists or the symbol
/// cannot be resolved.
FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
  auto attr = mapping()
                  .get(op->getName().getIdentifier())
                  .dyn_cast_or_null<FlatSymbolRefAttr>();
  if (!attr)
    return nullptr;
  return lookupSymbol<FuncOp>(attr);
}
591 
/// Custom parser: symbol name, optional attr-dict with keyword, the body
/// region, then a required `mapping` dictionary attribute.
ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
                                   OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  // Add the implicit terminator if it was elided in the textual form.
  FunctionLibraryOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                      result.location);
  if (parser.parseKeyword("mapping"))
    return failure();

  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}
619 
/// Custom printer mirroring parseFunctionLibraryOp.
void print(OpAsmPrinter &p, FunctionLibraryOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.getName());
  // The symbol name and mapping are printed explicitly below, so elide them
  // from the generic attribute dictionary.
  p.printOptionalAttrDictWithKeyword(
      op.getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
  p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(op.mappingAttr());
}
630 
631 //===----------------------------------------------------------------------===//
632 // GetExtentOp
633 //===----------------------------------------------------------------------===//
634 
/// Returns the `dim` operand's value when it is defined by a constant
/// (either `shape.const_size` or a standard `constant`), otherwise None.
Optional<int64_t> GetExtentOp::getConstantDim() {
  if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.value().getLimitedValue();
  if (auto constantOp = dim().getDefiningOp<ConstantOp>())
    return constantOp.value().cast<IntegerAttr>().getInt();
  return llvm::None;
}
642 
OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
  // Folds only when both the shape and the dimension index are constant.
  auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!elements)
    return nullptr;
  Optional<int64_t> dim = getConstantDim();
  if (!dim.hasValue())
    return nullptr;
  // Out-of-bounds accesses are not folded; they keep their runtime semantics.
  if (dim.getValue() >= elements.getNumElements())
    return nullptr;
  return elements.getValue({(uint64_t)dim.getValue()});
}
654 
build(OpBuilder & builder,OperationState & result,Value shape,int64_t dim)655 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
656                         int64_t dim) {
657   auto loc = result.location;
658   auto dimAttr = builder.getIndexAttr(dim);
659   if (shape.getType().isa<ShapeType>()) {
660     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
661     build(builder, result, builder.getType<SizeType>(), shape, dim);
662   } else {
663     Value dim =
664         builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
665     build(builder, result, builder.getIndexType(), shape, dim);
666   }
667 }
668 
669 //===----------------------------------------------------------------------===//
670 // RankOp
671 //===----------------------------------------------------------------------===//
672 
fold(ArrayRef<Attribute> operands)673 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
674   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
675   if (!shape)
676     return {};
677   int64_t rank = shape.getNumElements();
678   Builder builder(getContext());
679   return builder.getIndexAttr(rank);
680 }
681 
682 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
683 /// Constant folding fails in cases where only the rank is constant, not the
684 /// shape itself.
685 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
686 ///
687 /// Example:
688 ///
689 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
690 /// %rank = shape.rank %shape
691 ///
692 /// becomes
693 ///
694 /// %rank = shape.const_size 3
695 
namespace {
/// Canonicalizes `shape.rank(shape.shape_of(%t))` on a ranked tensor into a
/// constant, matching the result type (`index` or `!shape.size`).
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    // Only applies when the shape comes from shape_of of a ranked tensor.
    auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTensorType =
        shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace
722 
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)723 void shape::RankOp::getCanonicalizationPatterns(
724     OwningRewritePatternList &patterns, MLIRContext *context) {
725   patterns.insert<RankShapeOfCanonicalizationPattern>(context);
726 }
727 
728 //===----------------------------------------------------------------------===//
729 // NumElementsOp
730 //===----------------------------------------------------------------------===//
731 
fold(ArrayRef<Attribute> operands)732 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
733 
734   // Fold only when argument constant.
735   Attribute shape = operands[0];
736   if (!shape)
737     return {};
738 
739   APInt product(64, 1);
740   for (auto value : shape.cast<DenseIntElementsAttr>())
741     product *= value;
742   Builder builder(getContext());
743   return builder.getIndexAttr(product.getLimitedValue());
744 }
745 
build(OpBuilder & builder,OperationState & result,Value shape)746 void NumElementsOp::build(OpBuilder &builder, OperationState &result,
747                           Value shape) {
748   if (shape.getType().isa<ShapedType>()) {
749     auto type = builder.getIndexType();
750     return build(builder, result, type, shape);
751   }
752   auto type = SizeType::get(builder.getContext());
753   return build(builder, result, type, shape);
754 }
755 
756 //===----------------------------------------------------------------------===//
757 // MulOp
758 //===----------------------------------------------------------------------===//
759 
fold(ArrayRef<Attribute> operands)760 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
761   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
762   if (!lhs)
763     return nullptr;
764   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
765   if (!rhs)
766     return nullptr;
767   APInt folded = lhs.getValue() * rhs.getValue();
768   Type indexTy = IndexType::get(getContext());
769   return IntegerAttr::get(indexTy, folded);
770 }
771 
772 //===----------------------------------------------------------------------===//
773 // ShapeOfOp
774 //===----------------------------------------------------------------------===//
775 
fold(ArrayRef<Attribute>)776 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
777   auto type = getOperand().getType().dyn_cast<ShapedType>();
778   if (!type || !type.hasStaticShape())
779     return nullptr;
780   Builder builder(getContext());
781   return builder.getIndexTensorAttr(type.getShape());
782 }
783 
build(OpBuilder & builder,OperationState & result,Value arg)784 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
785   Type type = arg.getType().isa<ShapedType>()
786                   ? (Type)getExtentTensorType(builder.getContext())
787                   : (Type)builder.getType<ShapeType>();
788   return ShapeOfOp::build(builder, result, type, arg);
789 }
790 
791 namespace {
792 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
793   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
794 
matchAndRewrite__anonf9874f601011::ShapeOfWithTensor795   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
796                                 PatternRewriter &rewriter) const override {
797     if (!op.arg().getType().isa<ShapedType>())
798       return failure();
799     if (op.getType().isa<ShapedType>())
800       return failure();
801 
802     rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
803     return success();
804   }
805 };
806 } // namespace
807 
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)808 void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
809                                             MLIRContext *context) {
810   patterns.insert<ShapeOfWithTensor>(context);
811 }
812 
813 //===----------------------------------------------------------------------===//
814 // SizeToIndexOp
815 //===----------------------------------------------------------------------===//
816 
fold(ArrayRef<Attribute> operands)817 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
818   // Constant values of both types, `shape.size` and `index`, are represented as
819   // `IntegerAttr`s which makes constant folding simple.
820   if (Attribute arg = operands[0])
821     return arg;
822   return impl::foldCastOp(*this);
823 }
824 
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)825 void SizeToIndexOp::getCanonicalizationPatterns(
826     OwningRewritePatternList &patterns, MLIRContext *context) {
827   patterns.insert<IndexToSizeToIndexCanonicalization>(context);
828 }
829 
830 //===----------------------------------------------------------------------===//
831 // YieldOp
832 //===----------------------------------------------------------------------===//
833 
verify(shape::YieldOp op)834 static LogicalResult verify(shape::YieldOp op) {
835   auto *parentOp = op->getParentOp();
836   auto results = parentOp->getResults();
837   auto operands = op.getOperands();
838 
839   if (parentOp->getNumResults() != op.getNumOperands())
840     return op.emitOpError() << "number of operands does not match number of "
841                                "results of its parent";
842   for (auto e : llvm::zip(results, operands))
843     if (std::get<0>(e).getType() != std::get<1>(e).getType())
844       return op.emitOpError()
845              << "types mismatch between yield op and its parent";
846 
847   return success();
848 }
849 
850 //===----------------------------------------------------------------------===//
851 // SplitAtOp
852 //===----------------------------------------------------------------------===//
853 
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)854 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
855                               SmallVectorImpl<OpFoldResult> &results) {
856   if (!operands[0] || !operands[1])
857     return failure();
858   auto shapeVec = llvm::to_vector<6>(
859       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
860   auto shape = llvm::makeArrayRef(shapeVec);
861   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
862   // Verify that the split point is in the correct range.
863   // TODO: Constant fold to an "error".
864   int64_t rank = shape.size();
865   if (!(-rank <= splitPoint && splitPoint <= rank))
866     return failure();
867   if (splitPoint < 0)
868     splitPoint += shape.size();
869   Builder builder(operands[0].getContext());
870   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
871   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
872   return success();
873 }
874 
875 //===----------------------------------------------------------------------===//
876 // ToExtentTensorOp
877 //===----------------------------------------------------------------------===//
878 
fold(ArrayRef<Attribute> operands)879 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
880   if (!operands[0])
881     return impl::foldCastOp(*this);
882   Builder builder(getContext());
883   auto shape = llvm::to_vector<6>(
884       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
885   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
886                                     builder.getIndexType());
887   return DenseIntElementsAttr::get(type, shape);
888 }
889 
890 //===----------------------------------------------------------------------===//
891 // ReduceOp
892 //===----------------------------------------------------------------------===//
893 
build(OpBuilder & builder,OperationState & result,Value shape,ValueRange initVals)894 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
895                      ValueRange initVals) {
896   result.addOperands(shape);
897   result.addOperands(initVals);
898 
899   Region *bodyRegion = result.addRegion();
900   bodyRegion->push_back(new Block);
901   Block &bodyBlock = bodyRegion->front();
902   bodyBlock.addArgument(builder.getIndexType());
903 
904   Type elementType;
905   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
906     elementType = tensorType.getElementType();
907   else
908     elementType = SizeType::get(builder.getContext());
909   bodyBlock.addArgument(elementType);
910 
911   for (Type initValType : initVals.getTypes()) {
912     bodyBlock.addArgument(initValType);
913     result.addTypes(initValType);
914   }
915 }
916 
verify(ReduceOp op)917 static LogicalResult verify(ReduceOp op) {
918   // Verify block arg types.
919   Block &block = op.region().front();
920 
921   // The block takes index, extent, and aggregated values as arguments.
922   auto blockArgsCount = op.initVals().size() + 2;
923   if (block.getNumArguments() != blockArgsCount)
924     return op.emitOpError() << "ReduceOp body is expected to have "
925                             << blockArgsCount << " arguments";
926 
927   // The first block argument is the index and must always be of type `index`.
928   if (!block.getArgument(0).getType().isa<IndexType>())
929     return op.emitOpError(
930         "argument 0 of ReduceOp body is expected to be of IndexType");
931 
932   // The second block argument is the extent and must be of type `size` or
933   // `index`, depending on whether the reduce operation is applied to a shape or
934   // to an extent tensor.
935   Type extentTy = block.getArgument(1).getType();
936   if (op.shape().getType().isa<ShapeType>()) {
937     if (!extentTy.isa<SizeType>())
938       return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
939                             "SizeType if the ReduceOp operates on a ShapeType");
940   } else {
941     if (!extentTy.isa<IndexType>())
942       return op.emitOpError(
943           "argument 1 of ReduceOp body is expected to be of IndexType if the "
944           "ReduceOp operates on an extent tensor");
945   }
946 
947   for (auto type : llvm::enumerate(op.initVals()))
948     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
949       return op.emitOpError()
950              << "type mismatch between argument " << type.index() + 2
951              << " of ReduceOp body and initial value " << type.index();
952   return success();
953 }
954 
parseReduceOp(OpAsmParser & parser,OperationState & result)955 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
956   // Parse operands.
957   SmallVector<OpAsmParser::OperandType, 3> operands;
958   Type shapeOrExtentTensorType;
959   if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
960                               OpAsmParser::Delimiter::Paren) ||
961       parser.parseColonType(shapeOrExtentTensorType) ||
962       parser.parseOptionalArrowTypeList(result.types))
963     return failure();
964 
965   // Resolve operands.
966   auto initVals = llvm::makeArrayRef(operands).drop_front();
967   if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
968                             result.operands) ||
969       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
970                              result.operands))
971     return failure();
972 
973   // Parse the body.
974   Region *body = result.addRegion();
975   if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
976     return failure();
977 
978   // Parse attributes.
979   if (parser.parseOptionalAttrDict(result.attributes))
980     return failure();
981 
982   return success();
983 }
984 
print(OpAsmPrinter & p,ReduceOp op)985 static void print(OpAsmPrinter &p, ReduceOp op) {
986   p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
987     << ") : " << op.shape().getType();
988   p.printOptionalArrowTypeList(op.getResultTypes());
989   p.printRegion(op.region());
990   p.printOptionalAttrDict(op.getAttrs());
991 }
992 
993 #define GET_OP_CLASSES
994 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
995