//===- Shape.cpp - MLIR Shape Operations ----------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Traits.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; using namespace mlir::shape; namespace { #include "ShapeCanonicalization.inc" } RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) { return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx)); } static bool isErrorPropagationPossible(TypeRange operandTypes) { for (Type ty : operandTypes) if (ty.isa() || ty.isa() || ty.isa()) return true; return false; } static LogicalResult verifySizeOrIndexOp(Operation *op) { assert(op != nullptr && op->getNumResults() == 1); Type resultTy = op->getResultTypes().front(); if (isErrorPropagationPossible(op->getOperandTypes())) { if (!resultTy.isa()) return op->emitOpError() << "if at least one of the operands can hold error values then " "the result must be of type `size` to propagate them"; } return success(); } static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { assert(op != nullptr && op->getNumResults() == 1); Type resultTy = op->getResultTypes().front(); if (isErrorPropagationPossible(op->getOperandTypes())) { if (!resultTy.isa()) return op->emitOpError() << "if at least one of the operands can hold error values then " "the result must be of type `shape` to propagate them"; } return success(); } //===----------------------------------------------------------------------===// // InlinerInterface //===----------------------------------------------------------------------===// namespace { /// This class defines the interface for inlining shape dialect ops. struct ShapeInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; // Returns true if the given region 'src' can be inlined into the region // 'dest' that is attached to an operation registered to the current dialect. bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, BlockAndValueMapping &) const final { return true; } // Returns true if the given operation 'op', that is registered to this // dialect, can be inlined into the region 'dest' that is attached to an // operation registered to the current dialect. bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, BlockAndValueMapping &) const final { return true; } }; } // namespace void ShapeDialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" >(); addTypes(); addInterfaces(); // Allow unknown operations during prototyping and testing. As the dialect is // still evolving it makes it simple to start with an unregistered ops and // try different variants before actually defining the op. allowUnknownOperations(); } Operation *ShapeDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (type.isa() || type == getExtentTensorType(builder.getContext())) return builder.create(loc, type, value.cast()); if (type.isa()) return builder.create(loc, type, value.cast()); if (type.isa()) return builder.create(loc, type, value.cast()); if (type.isa()) return builder.create(loc, type, value); return nullptr; } /// Parse a type registered to this dialect. Type ShapeDialect::parseType(DialectAsmParser &parser) const { StringRef keyword; if (parser.parseKeyword(&keyword)) return Type(); if (keyword == "component") return ComponentType::get(getContext()); if (keyword == "element") return ElementType::get(getContext()); if (keyword == "shape") return ShapeType::get(getContext()); if (keyword == "size") return SizeType::get(getContext()); if (keyword == "value_shape") return ValueShapeType::get(getContext()); if (keyword == "witness") return WitnessType::get(getContext()); parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; return Type(); } /// Print a type registered to this dialect. void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { TypeSwitch(type) .Case([&](Type) { os << "component"; }) .Case([&](Type) { os << "element"; }) .Case([&](Type) { os << "shape"; }) .Case([&](Type) { os << "size"; }) .Case([&](Type) { os << "value_shape"; }) .Case([&](Type) { os << "witness"; }) .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); } //===----------------------------------------------------------------------===// // AnyOp //===----------------------------------------------------------------------===// // TODO: Canonicalization should be implemented for shapes that can be // determined through mixtures of the known dimensions of the inputs. OpFoldResult AnyOp::fold(ArrayRef operands) { // Only the last operand is checked because AnyOp is commutative. if (operands.back()) return operands.back(); return nullptr; } //===----------------------------------------------------------------------===// // AssumingOp //===----------------------------------------------------------------------===// static ParseResult parseAssumingOp(OpAsmParser &parser, OperationState &result) { result.regions.reserve(1); Region *doRegion = result.addRegion(); auto &builder = parser.getBuilder(); OpAsmParser::OperandType cond; if (parser.parseOperand(cond) || parser.resolveOperand(cond, builder.getType(), result.operands)) return failure(); // Parse optional results type list. if (parser.parseOptionalArrowTypeList(result.types)) return failure(); // Parse the region and add a terminator if elided. if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); // Parse the optional attribute list. if (parser.parseOptionalAttrDict(result.attributes)) return failure(); return success(); } static void print(OpAsmPrinter &p, AssumingOp op) { bool yieldsResults = !op.results().empty(); p << AssumingOp::getOperationName() << " " << op.witness(); if (yieldsResults) { p << " -> (" << op.getResultTypes() << ")"; } p.printRegion(op.doRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/yieldsResults); p.printOptionalAttrDict(op.getAttrs()); } namespace { // Removes AssumingOp with a passing witness and inlines the region. struct AssumingWithTrue : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AssumingOp op, PatternRewriter &rewriter) const override { auto witness = op.witness().getDefiningOp(); if (!witness || !witness.passingAttr()) return failure(); AssumingOp::inlineRegionIntoParent(op, rewriter); return success(); } }; } // namespace void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, MLIRContext *context) { // If taking a passing witness, inline region. patterns.insert(context); } // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td void AssumingOp::getSuccessorRegions( Optional index, ArrayRef operands, SmallVectorImpl ®ions) { // AssumingOp has unconditional control flow into the region and back to the // parent, so return the correct RegionSuccessor purely based on the index // being None or 0. if (index.hasValue()) { regions.push_back(RegionSuccessor(getResults())); return; } regions.push_back(RegionSuccessor(&doRegion())); } void AssumingOp::inlineRegionIntoParent(AssumingOp &op, PatternRewriter &rewriter) { auto *blockBeforeAssuming = rewriter.getInsertionBlock(); auto *assumingBlock = op.getBody(); auto initPosition = rewriter.getInsertionPoint(); auto *blockAfterAssuming = rewriter.splitBlock(blockBeforeAssuming, initPosition); // Remove the AssumingOp and AssumingYieldOp. auto &yieldOp = assumingBlock->back(); rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); rewriter.replaceOp(op, yieldOp.getOperands()); rewriter.eraseOp(&yieldOp); // Merge blocks together as there was no branching behavior from the // AssumingOp. rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); } //===----------------------------------------------------------------------===// // AssumingAllOp //===----------------------------------------------------------------------===// void AssumingAllOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert(context); } OpFoldResult AssumingAllOp::fold(ArrayRef operands) { // Iterate in reverse to first handle all constant operands. They are // guaranteed to be the tail of the inputs because this is commutative. for (int idx = operands.size() - 1; idx >= 0; idx--) { Attribute a = operands[idx]; // Cannot fold if any inputs are not constant; if (!a) return nullptr; // We do not need to keep statically known values after handling them in // this method. getOperation()->eraseOperand(idx); // Always false if any input is statically known false if (!a.cast().getValue()) return a; } // If this is reached, all inputs were statically known passing. return BoolAttr::get(true, getContext()); } static LogicalResult verify(AssumingAllOp op) { // Ensure that AssumingAllOp contains at least one operand if (op.getNumOperands() == 0) return op.emitOpError("no operands specified"); return success(); } //===----------------------------------------------------------------------===// // BroadcastOp //===----------------------------------------------------------------------===// OpFoldResult BroadcastOp::fold(ArrayRef operands) { if (!operands[1]) return nullptr; auto rhsShape = llvm::to_vector<6>( operands[1].cast().getValues()); if (rhsShape.empty()) return lhs(); if (!operands[0]) return nullptr; auto lhsShape = llvm::to_vector<6>( operands[0].cast().getValues()); if (lhsShape.empty()) return rhs(); SmallVector resultShape; // If the shapes are not compatible, we can't fold it. // TODO: Fold to an "error". if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) return nullptr; Builder builder(getContext()); return builder.getIndexTensorAttr(resultShape); } //===----------------------------------------------------------------------===// // ConcatOp //===----------------------------------------------------------------------===// OpFoldResult ConcatOp::fold(ArrayRef operands) { if (!operands[0] || !operands[1]) return nullptr; auto lhsShape = llvm::to_vector<6>( operands[0].cast().getValues()); auto rhsShape = llvm::to_vector<6>( operands[1].cast().getValues()); SmallVector resultShape; resultShape.append(lhsShape.begin(), lhsShape.end()); resultShape.append(rhsShape.begin(), rhsShape.end()); Builder builder(getContext()); return builder.getIndexTensorAttr(resultShape); } //===----------------------------------------------------------------------===// // ConstShapeOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, ConstShapeOp &op) { p << "shape.const_shape "; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); p << "["; interleaveComma(op.shape().getValues(), p, [&](int64_t i) { p << i; }); p << "] : "; p.printType(op.getType()); } static ParseResult parseConstShapeOp(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); // We piggy-back on ArrayAttr parsing, though we don't internally store the // shape as an ArrayAttr. // TODO: Implement custom parser and maybe make syntax a bit more concise. Attribute extentsRaw; NamedAttrList dummy; if (parser.parseAttribute(extentsRaw, "dummy", dummy)) return failure(); auto extentsArray = extentsRaw.dyn_cast(); if (!extentsArray) return failure(); SmallVector ints; for (Attribute extent : extentsArray) { IntegerAttr attr = extent.dyn_cast(); if (!attr) return failure(); ints.push_back(attr.getInt()); } Builder &builder = parser.getBuilder(); result.addAttribute("shape", builder.getIndexTensorAttr(ints)); Type resultTy; if (parser.parseColonType(resultTy)) return failure(); result.types.push_back(resultTy); return success(); } OpFoldResult ConstShapeOp::fold(ArrayRef) { return shapeAttr(); } void ConstShapeOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert(context); } //===----------------------------------------------------------------------===// // CstrBroadcastableOp //===----------------------------------------------------------------------===// namespace { // Given an input shape Value, try to obtain the shape's values. LogicalResult getShapeVec(Value input, SmallVectorImpl &shapeValues) { if (auto inputOp = input.getDefiningOp()) { auto type = inputOp.arg().getType().dyn_cast(); if (!type.hasRank()) return failure(); shapeValues = llvm::to_vector<6>(type.getShape()); return success(); } else if (auto inputOp = input.getDefiningOp()) { shapeValues = llvm::to_vector<6>(inputOp.shape().getValues()); return success(); } else { return failure(); } } } // namespace void CstrBroadcastableOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { // Canonicalization patterns have overlap with the considerations during // folding in case additional shape information is inferred at some point that // does not result in folding. patterns.insert(context); } OpFoldResult CstrBroadcastableOp::fold(ArrayRef operands) { // Both operands are not needed if one is a scalar. if (operands[0] && operands[0].cast().getNumElements() == 0) return BoolAttr::get(true, getContext()); if (operands[1] && operands[1].cast().getNumElements() == 0) return BoolAttr::get(true, getContext()); if (operands[0] && operands[1]) { auto lhsShape = llvm::to_vector<6>( operands[0].cast().getValues()); auto rhsShape = llvm::to_vector<6>( operands[1].cast().getValues()); SmallVector resultShape; if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) return BoolAttr::get(true, getContext()); } // Lastly, see if folding can be completed based on what constraints are known // on the input shapes. SmallVector lhsShape, rhsShape; if (failed(getShapeVec(lhs(), lhsShape))) return nullptr; if (failed(getShapeVec(rhs(), rhsShape))) return nullptr; if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) return BoolAttr::get(true, getContext()); // Because a failing witness result here represents an eventual assertion // failure, we do not replace it with a constant witness. return nullptr; } //===----------------------------------------------------------------------===// // CstrEqOp //===----------------------------------------------------------------------===// void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, MLIRContext *context) { // If inputs are equal, return passing witness patterns.insert(context); } OpFoldResult CstrEqOp::fold(ArrayRef operands) { if (llvm::all_of(operands, [&](Attribute a) { return a && a == operands[0]; })) return BoolAttr::get(true, getContext()); // Because a failing witness result here represents an eventual assertion // failure, we do not try to replace it with a constant witness. Similarly, we // cannot if there are any non-const inputs. return nullptr; } //===----------------------------------------------------------------------===// // ConstSizeOp //===----------------------------------------------------------------------===// void ConstSizeOp::build(OpBuilder &builder, OperationState &result, int64_t value) { build(builder, result, builder.getIndexAttr(value)); } OpFoldResult ConstSizeOp::fold(ArrayRef) { return valueAttr(); } void ConstSizeOp::getAsmResultNames( llvm::function_ref setNameFn) { SmallString<4> buffer; llvm::raw_svector_ostream os(buffer); os << "c" << value(); setNameFn(getResult(), os.str()); } //===----------------------------------------------------------------------===// // ConstWitnessOp //===----------------------------------------------------------------------===// OpFoldResult ConstWitnessOp::fold(ArrayRef) { return passingAttr(); } //===----------------------------------------------------------------------===// // CstrRequireOp //===----------------------------------------------------------------------===// OpFoldResult CstrRequireOp::fold(ArrayRef operands) { return operands[0]; } //===----------------------------------------------------------------------===// // ShapeEqOp //===----------------------------------------------------------------------===// OpFoldResult ShapeEqOp::fold(ArrayRef operands) { auto lhs = operands[0].dyn_cast_or_null(); if (lhs == nullptr) return {}; auto rhs = operands[1].dyn_cast_or_null(); if (rhs == nullptr) return {}; return BoolAttr::get(lhs == rhs, getContext()); } //===----------------------------------------------------------------------===// // IndexToSizeOp //===----------------------------------------------------------------------===// OpFoldResult IndexToSizeOp::fold(ArrayRef operands) { // Constant values of both types, `shape.size` and `index`, are represented as // `IntegerAttr`s which makes constant folding simple. if (Attribute arg = operands[0]) return arg; return {}; } void IndexToSizeOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert(context); } //===----------------------------------------------------------------------===// // FromExtentsOp //===----------------------------------------------------------------------===// OpFoldResult FromExtentsOp::fold(ArrayRef operands) { if (llvm::any_of(operands, [](Attribute a) { return !a; })) return nullptr; SmallVector extents; for (auto attr : operands) extents.push_back(attr.cast().getInt()); Builder builder(getContext()); return builder.getIndexTensorAttr(extents); } //===----------------------------------------------------------------------===// // FunctionLibraryOp //===----------------------------------------------------------------------===// void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, StringRef name) { ensureTerminator(*result.addRegion(), builder, result.location); result.attributes.push_back(builder.getNamedAttr( ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); } FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { auto attr = mapping() .get(op->getName().getIdentifier()) .dyn_cast_or_null(); if (!attr) return nullptr; return lookupSymbol(attr); } ParseResult parseFunctionLibraryOp(OpAsmParser &parser, OperationState &result) { // Parse the op name. StringAttr nameAttr; if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), result.attributes)) return failure(); if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) return failure(); auto *bodyRegion = result.addRegion(); if (parser.parseRegion(*bodyRegion)) return failure(); FunctionLibraryOp::ensureTerminator(*bodyRegion, parser.getBuilder(), result.location); if (parser.parseKeyword("mapping")) return failure(); DictionaryAttr mappingAttr; if (parser.parseAttribute(mappingAttr, parser.getBuilder().getType(), "mapping", result.attributes)) return failure(); return success(); } void print(OpAsmPrinter &p, FunctionLibraryOp op) { p << op.getOperationName() << ' '; p.printSymbolName(op.getName()); p.printOptionalAttrDictWithKeyword( op.getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"}); p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false); p << " mapping "; p.printAttributeWithoutType(op.mappingAttr()); } //===----------------------------------------------------------------------===// // GetExtentOp //===----------------------------------------------------------------------===// Optional GetExtentOp::getConstantDim() { if (auto constSizeOp = dim().getDefiningOp()) return constSizeOp.value().getLimitedValue(); if (auto constantOp = dim().getDefiningOp()) return constantOp.value().cast().getInt(); return llvm::None; } OpFoldResult GetExtentOp::fold(ArrayRef operands) { auto elements = operands[0].dyn_cast_or_null(); if (!elements) return nullptr; Optional dim = getConstantDim(); if (!dim.hasValue()) return nullptr; if (dim.getValue() >= elements.getNumElements()) return nullptr; return elements.getValue({(uint64_t)dim.getValue()}); } void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, int64_t dim) { auto loc = result.location; auto dimAttr = builder.getIndexAttr(dim); if (shape.getType().isa()) { Value dim = builder.create(loc, dimAttr); build(builder, result, builder.getType(), shape, dim); } else { Value dim = builder.create(loc, builder.getIndexType(), dimAttr); build(builder, result, builder.getIndexType(), shape, dim); } } //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// OpFoldResult shape::RankOp::fold(ArrayRef operands) { auto shape = operands[0].dyn_cast_or_null(); if (!shape) return {}; int64_t rank = shape.getNumElements(); Builder builder(getContext()); return builder.getIndexAttr(rank); } /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. /// Constant folding fails in cases where only the rank is constant, not the /// shape itself. /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. /// /// Example: /// /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> /// %rank = shape.rank %shape /// /// becomes /// /// %rank = shape.const_size 3 namespace { struct RankShapeOfCanonicalizationPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(shape::RankOp op, PatternRewriter &rewriter) const override { auto shapeOfOp = op.shape().getDefiningOp(); if (!shapeOfOp) return failure(); auto rankedTensorType = shapeOfOp.arg().getType().dyn_cast(); if (!rankedTensorType) return failure(); int64_t rank = rankedTensorType.getRank(); if (op.getType().isa()) { rewriter.replaceOpWithNewOp(op.getOperation(), rank); } else if (op.getType().isa()) { rewriter.replaceOpWithNewOp(op.getOperation(), rank); } else { return failure(); } return success(); } }; } // namespace void shape::RankOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert(context); } //===----------------------------------------------------------------------===// // NumElementsOp //===----------------------------------------------------------------------===// OpFoldResult NumElementsOp::fold(ArrayRef operands) { // Fold only when argument constant. Attribute shape = operands[0]; if (!shape) return {}; APInt product(64, 1); for (auto value : shape.cast()) product *= value; Builder builder(getContext()); return builder.getIndexAttr(product.getLimitedValue()); } void NumElementsOp::build(OpBuilder &builder, OperationState &result, Value shape) { if (shape.getType().isa()) { auto type = builder.getIndexType(); return build(builder, result, type, shape); } auto type = SizeType::get(builder.getContext()); return build(builder, result, type, shape); } //===----------------------------------------------------------------------===// // MulOp //===----------------------------------------------------------------------===// OpFoldResult MulOp::fold(ArrayRef operands) { auto lhs = operands[0].dyn_cast_or_null(); if (!lhs) return nullptr; auto rhs = operands[1].dyn_cast_or_null(); if (!rhs) return nullptr; APInt folded = lhs.getValue() * rhs.getValue(); Type indexTy = IndexType::get(getContext()); return IntegerAttr::get(indexTy, folded); } //===----------------------------------------------------------------------===// // ShapeOfOp //===----------------------------------------------------------------------===// OpFoldResult ShapeOfOp::fold(ArrayRef) { auto type = getOperand().getType().dyn_cast(); if (!type || !type.hasStaticShape()) return nullptr; Builder builder(getContext()); return builder.getIndexTensorAttr(type.getShape()); } void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) { Type type = arg.getType().isa() ? (Type)getExtentTensorType(builder.getContext()) : (Type)builder.getType(); return ShapeOfOp::build(builder, result, type, arg); } namespace { struct ShapeOfWithTensor : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(shape::ShapeOfOp op, PatternRewriter &rewriter) const override { if (!op.arg().getType().isa()) return failure(); if (op.getType().isa()) return failure(); rewriter.replaceOpWithNewOp(op.getOperation(), op.arg()); return success(); } }; } // namespace void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert(context); } //===----------------------------------------------------------------------===// // SizeToIndexOp //===----------------------------------------------------------------------===// OpFoldResult SizeToIndexOp::fold(ArrayRef operands) { // Constant values of both types, `shape.size` and `index`, are represented as // `IntegerAttr`s which makes constant folding simple. if (Attribute arg = operands[0]) return arg; return impl::foldCastOp(*this); } void SizeToIndexOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert(context); } //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// static LogicalResult verify(shape::YieldOp op) { auto *parentOp = op->getParentOp(); auto results = parentOp->getResults(); auto operands = op.getOperands(); if (parentOp->getNumResults() != op.getNumOperands()) return op.emitOpError() << "number of operands does not match number of " "results of its parent"; for (auto e : llvm::zip(results, operands)) if (std::get<0>(e).getType() != std::get<1>(e).getType()) return op.emitOpError() << "types mismatch between yield op and its parent"; return success(); } //===----------------------------------------------------------------------===// // SplitAtOp //===----------------------------------------------------------------------===// LogicalResult SplitAtOp::fold(ArrayRef operands, SmallVectorImpl &results) { if (!operands[0] || !operands[1]) return failure(); auto shapeVec = llvm::to_vector<6>( operands[0].cast().getValues()); auto shape = llvm::makeArrayRef(shapeVec); auto splitPoint = operands[1].cast().getInt(); // Verify that the split point is in the correct range. // TODO: Constant fold to an "error". int64_t rank = shape.size(); if (!(-rank <= splitPoint && splitPoint <= rank)) return failure(); if (splitPoint < 0) splitPoint += shape.size(); Builder builder(operands[0].getContext()); results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); return success(); } //===----------------------------------------------------------------------===// // ToExtentTensorOp //===----------------------------------------------------------------------===// OpFoldResult ToExtentTensorOp::fold(ArrayRef operands) { if (!operands[0]) return impl::foldCastOp(*this); Builder builder(getContext()); auto shape = llvm::to_vector<6>( operands[0].cast().getValues()); auto type = RankedTensorType::get({static_cast(shape.size())}, builder.getIndexType()); return DenseIntElementsAttr::get(type, shape); } //===----------------------------------------------------------------------===// // ReduceOp //===----------------------------------------------------------------------===// void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, ValueRange initVals) { result.addOperands(shape); result.addOperands(initVals); Region *bodyRegion = result.addRegion(); bodyRegion->push_back(new Block); Block &bodyBlock = bodyRegion->front(); bodyBlock.addArgument(builder.getIndexType()); Type elementType; if (auto tensorType = shape.getType().dyn_cast()) elementType = tensorType.getElementType(); else elementType = SizeType::get(builder.getContext()); bodyBlock.addArgument(elementType); for (Type initValType : initVals.getTypes()) { bodyBlock.addArgument(initValType); result.addTypes(initValType); } } static LogicalResult verify(ReduceOp op) { // Verify block arg types. Block &block = op.region().front(); // The block takes index, extent, and aggregated values as arguments. auto blockArgsCount = op.initVals().size() + 2; if (block.getNumArguments() != blockArgsCount) return op.emitOpError() << "ReduceOp body is expected to have " << blockArgsCount << " arguments"; // The first block argument is the index and must always be of type `index`. if (!block.getArgument(0).getType().isa()) return op.emitOpError( "argument 0 of ReduceOp body is expected to be of IndexType"); // The second block argument is the extent and must be of type `size` or // `index`, depending on whether the reduce operation is applied to a shape or // to an extent tensor. Type extentTy = block.getArgument(1).getType(); if (op.shape().getType().isa()) { if (!extentTy.isa()) return op.emitOpError("argument 1 of ReduceOp body is expected to be of " "SizeType if the ReduceOp operates on a ShapeType"); } else { if (!extentTy.isa()) return op.emitOpError( "argument 1 of ReduceOp body is expected to be of IndexType if the " "ReduceOp operates on an extent tensor"); } for (auto type : llvm::enumerate(op.initVals())) if (block.getArgument(type.index() + 2).getType() != type.value().getType()) return op.emitOpError() << "type mismatch between argument " << type.index() + 2 << " of ReduceOp body and initial value " << type.index(); return success(); } static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { // Parse operands. SmallVector operands; Type shapeOrExtentTensorType; if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, OpAsmParser::Delimiter::Paren) || parser.parseColonType(shapeOrExtentTensorType) || parser.parseOptionalArrowTypeList(result.types)) return failure(); // Resolve operands. auto initVals = llvm::makeArrayRef(operands).drop_front(); if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, result.operands) || parser.resolveOperands(initVals, result.types, parser.getNameLoc(), result.operands)) return failure(); // Parse the body. Region *body = result.addRegion(); if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) return failure(); // Parse attributes. if (parser.parseOptionalAttrDict(result.attributes)) return failure(); return success(); } static void print(OpAsmPrinter &p, ReduceOp op) { p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals() << ") : " << op.shape().getType(); p.printOptionalArrowTypeList(op.getResultTypes()); p.printRegion(op.region()); p.printOptionalAttrDict(op.getAttrs()); } #define GET_OP_CLASSES #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"