//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  AnyOp::Adaptor transformed(operands);

  // Replace `any` with its first operand. Any operand would be a valid
  // substitution.
  rewriter.replaceOp(op, {transformed.inputs().front()});
  return success();
}

namespace {
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    typename SrcOpTy::Adaptor transformed(operands);

    // For now, only error-free types are supported by this lowering.
    if (op.getType().template isa<SizeType>())
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
                                         transformed.rhs());
    return success();
  }
};
} // namespace

namespace {
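/// Converts `shape.broadcast` on `tensor<?xindex>` operands. The result is
/// materialized with `dynamic_tensor_from_elements` over the greater of the
/// two operand ranks: leading dimensions present only in the greater-rank
/// operand are taken as-is; for every overlapping dimension, the greater-rank
/// operand's extent is used unless it equals 1, in which case the lesser-rank
/// operand's extent is used. (A summary of the lowering below; it does not
/// verify that the operands actually broadcast.)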
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.getType().isa<ShapeType>())
    return failure();

  assert(!op.lhs().getType().isa<ShapeType>() &&
         !op.rhs().getType().isa<ShapeType>());

  auto loc = op.getLoc();
  BroadcastOp::Adaptor transformed(operands);
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);

  // Find smaller and greater rank and extent tensor.
  Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
  Value lhsRankULE =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
  Type indexTy = rewriter.getIndexType();
  Value lesserRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
  Value greaterRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
  // Erase the static ranks so that both extent tensors have the same type and
  // either of them can be chosen with a select.
  auto erasedRankType =
      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  Value rankErasedLhs =
      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
  Value rankErasedRhs =
      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
  Value lesserRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
  Value greaterRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);

  Value rankDiff =
      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
  rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
      op, getExtentTensorType(op.getContext()), ValueRange{greaterRank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value outputDimension = args[0];
        Value isUnchallengedDimension = b.create<CmpIOp>(
            loc, CmpIPredicate::ult, outputDimension, rankDiff);
        Value greaterRankOperandExtent = b.create<ExtractElementOp>(
            loc, greaterRankOperand, outputDimension);
        // The initial dimensions of the greater-rank operand are unchallenged,
        // so we can take them as-is. Otherwise, we need to do a comparison.
        // We need an actual branch here (instead of a select) because the
        // lesser-rank operand might be rank 0, so any extract_element would be
        // invalid.
        auto ifOp = b.create<IfOp>(
            loc, TypeRange{indexTy}, isUnchallengedDimension,
            [&](OpBuilder &b, Location loc) {
              b.create<scf::YieldOp>(loc, greaterRankOperandExtent);
            },
            [&](OpBuilder &b, Location loc) {
              // The broadcasting logic is:
              // - if one extent (here we arbitrarily choose the extent from
              //   the greater-rank operand) is equal to 1, then take the
              //   extent from the other operand;
              // - otherwise, take the extent as-is.
              // Note that this logic remains correct in the presence of
              // dimensions of zero extent.
              Value lesserRankOperandDimension =
                  b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
              Value lesserRankOperandExtent = b.create<ExtractElementOp>(
                  loc, lesserRankOperand,
                  ValueRange{lesserRankOperandDimension});
              Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
                  loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
              Value broadcastedExtent = b.create<SelectOp>(
                  loc, greaterRankOperandExtentIsOne, lesserRankOperandExtent,
                  greaterRankOperandExtent);
              b.create<scf::YieldOp>(loc, broadcastedExtent);
            });
        b.create<mlir::YieldOp>(loc, ifOp.getResult(0));
      });
  return success();
}

namespace {
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
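  // A constant shape such as `shape.const_shape [1, 2, 3]` thus lowers along
  // these lines (a sketch; SSA names are illustrative):
  //   %c1 = constant 1 : index
  //   %c2 = constant 2 : index
  //   %c3 = constant 3 : index
  //   %0 = tensor_from_elements %c1, %c2, %c3 : tensor<3xindex>
  //   %1 = tensor_cast %0 : tensor<3xindex> to tensor<?xindex>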
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.shape()) {
    extentOperands.push_back(
        rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type indexTy = rewriter.getIndexType();
  Value tensor =
      rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
  return success();
}

namespace {
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
  return success();
}

namespace {
struct IsBroadcastableOpConverter
    : public OpConversionPattern<IsBroadcastableOp> {
  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
    IsBroadcastableOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  IsBroadcastableOp::Adaptor transformed(operands);
  if (transformed.lhs().getType().isa<ShapeType>() ||
      transformed.rhs().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);

  // Find smaller and greater rank and extent tensor.
  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
  Value lhsRankULE =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
  Type indexTy = rewriter.getIndexType();
  Value lesserRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
  Value greaterRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
  auto erasedRankType =
      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  Value rankErasedLhs =
      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
  Value rankErasedRhs =
      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
  Value lesserRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
  Value greaterRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
  Value rankDiff =
      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
  Type i1Ty = rewriter.getI1Type();
  Value init =
      rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));

  // Determine if all overlapping extents are broadcastable.
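  // Two overlapping extents are broadcastable if either of them is 1 or if
  // they are equal; per dimension, the loop below computes
  //   broadcastable(a, b) = (a == 1) or (b == 1) or (a == b)
  // and carries the conjunction of these predicates as an i1 iteration
  // argument. Dimensions below `rankDiff` exist only in the greater-rank
  // operand and are always compatible.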
  auto reduceResult = rewriter.create<scf::ForOp>(
      loc, rankDiff, greaterRank, one, ValueRange{init},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        Value greaterRankOperandExtent = b.create<ExtractElementOp>(
            loc, greaterRankOperand, ValueRange{iv});
        Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
            loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
        Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
        Value lesserRankOperandExtent = b.create<ExtractElementOp>(
            loc, lesserRankOperand, ValueRange{ivShifted});
        Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
            loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
        Value extentsAreEqual =
            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
                             lesserRankOperandExtent);
        Value broadcastableExtents = b.create<AndOp>(
            loc, iterArgs[0],
            b.create<OrOp>(loc,
                           b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
                                          lesserRankOperandExtentIsOne),
                           extentsAreEqual));
        b.create<scf::YieldOp>(loc, broadcastableExtents);
      });
  rewriter.replaceOp(op, reduceResult.results().front());
  return success();
}

namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  GetExtentOp::Adaptor transformed(operands);

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<SizeType>())
    return failure();

  // Derive shape extent directly from shape origin if possible. This
  // circumvents the necessity to materialize the shape in memory.
  if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
    if (shapeOfOp.arg().getType().isa<ShapedType>()) {
      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
                                         transformed.dim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
                                                transformed.shape(),
                                                ValueRange{transformed.dim()});
  return success();
}

namespace {
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (op.getType().isa<SizeType>())
    return failure();

  shape::RankOp::Adaptor transformed(operands);
  rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern<shape::ReduceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
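  // As a sketch (names illustrative), a reduction
  //   %acc = shape.reduce(%shape, %init) : tensor<?xindex> -> index { ... }
  // becomes an `scf.for` over [0, rank(%shape)): each iteration extracts the
  // current extent, clones the reduce body with the index, extent, and
  // loop-carried accumulators substituted for its block arguments, and yields
  // the values returned by the body's terminator.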
  if (op.shape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  shape::ReduceOp::Adaptor transformed(operands);

  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.initVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares their size and, if equal, checks every extent for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = constant 1 : index
///   %true = constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = extract_element %arg0[%arg2] : tensor<?xindex>
///     %6 = extract_element %arg1[%arg2] : tensor<?xindex>
///     %7 = cmpi "eq", %5, %6 : index
///     %8 = and %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = constant false
///   scf.yield %false : i1
/// }
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.lhs().getType().isa<ShapeType>() ||
      op.rhs().getType().isa<ShapeType>()) {
    return failure();
  }

  ShapeEqOp::Adaptor transformed(operands);
  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
  Value eqRank =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
  Type i1Ty = rewriter.getI1Type();
  rewriter.replaceOpWithNewOp<IfOp>(
      op, i1Ty, eqRank,
      [&](OpBuilder &b, Location loc) {
        Value one = b.create<ConstantIndexOp>(loc, 1);
        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
        auto loop = b.create<scf::ForOp>(
            loc, zero, lhsRank, one, ValueRange{init},
            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
              Value conj = args[0];
              Value lhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
              Value rhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                lhsExtent, rhsExtent);
              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
            });
        b.create<scf::YieldOp>(loc, loop.getResults());
      },
      [&](OpBuilder &b, Location loc) {
        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
        b.create<scf::YieldOp>(loc, result);
      });
  return success();
}

namespace {
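/// Converts `shape.shape_of` to standard ops. For ranked operands all extents
/// are materialized eagerly with `tensor_from_elements`; for unranked operands
/// the extent tensor is built with `dynamic_tensor_from_elements` over the
/// operand's rank. As a sketch (SSA names illustrative),
///
/// %shape = shape.shape_of %arg : tensor<2x?xf32> -> tensor<?xindex>
///
/// becomes roughly
///
/// %c2 = constant 2 : index
/// %c1 = constant 1 : index
/// %0 = dim %arg, %c1 : tensor<2x?xf32>
/// %1 = tensor_from_elements %c2, %0 : tensor<2xindex>
/// %shape = tensor_cast %1 : tensor<2xindex> to tensor<?xindex>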
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<ShapeType>())
    return failure();

  // For ranked tensor arguments, lower to `tensor_from_elements`.
  auto loc = op.getLoc();
  ShapeOfOp::Adaptor transformed(operands);
  Value tensor = transformed.arg();
  Type tensorTy = tensor.getType();
  if (tensorTy.isa<RankedTensorType>()) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = rewriter.create<DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent = rewriter.create<ConstantIndexOp>(
            loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
        loc, rewriter.getIndexType(), extentValues);
    rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
                                              op.getType());
    return success();
  }

  // Lower to `dynamic_tensor_from_elements` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<DimOp>(loc, tensor, dim);
        b.create<mlir::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
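/// Converts `shape.to_extent_tensor` to a `tensor_cast`, a pure type change
/// with no runtime cost. Inputs that are not already ranked tensors produce a
/// match failure.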
#include "ShapeToStandard.cpp.inc" } // namespace namespace { /// Conversion pass. class ConvertShapeToStandardPass : public ConvertShapeToStandardBase { void runOnOperation() override; }; } // namespace void ConvertShapeToStandardPass::runOnOperation() { // Setup target legality. MLIRContext &ctx = getContext(); ConversionTarget target(ctx); target.addLegalDialect(); target.addLegalOp(); // Setup conversion patterns. OwningRewritePatternList patterns; populateShapeToStandardConversionPatterns(patterns, &ctx); // Apply conversion. auto module = getOperation(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } void mlir::populateShapeToStandardConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { // clang-format off populateWithGenerated(ctx, patterns); patterns.insert< AnyOpConversion, BinaryOpConversion, BinaryOpConversion, BroadcastOpConverter, ConstShapeOpConverter, ConstSizeOpConversion, IsBroadcastableOpConverter, GetExtentOpConverter, RankOpConverter, ReduceOpConverter, ShapeEqOpConverter, ShapeOfOpConversion, ToExtentTensorOpConversion>(ctx); // clang-format on } std::unique_ptr> mlir::createConvertShapeToStandardPass() { return std::make_unique(); }