//===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" #include "../PassDetail.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; namespace { class ConvertCstrBroadcastableOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op, PatternRewriter &rewriter) const override { if (op.getType().isa() || op.lhs().getType().isa() || op.rhs().getType().isa()) { return rewriter.notifyMatchFailure( op, "cannot convert error-propagating shapes"); } auto loc = op.getLoc(); Value zero = rewriter.create(loc, 0); Value one = rewriter.create(loc, 1); // Find smaller and greater rank and extent tensor. Value lhsRank = rewriter.create(loc, op.lhs(), zero); Value rhsRank = rewriter.create(loc, op.rhs(), zero); Value lhsRankULE = rewriter.create(loc, CmpIPredicate::ule, lhsRank, rhsRank); Type indexTy = rewriter.getIndexType(); Value lesserRank = rewriter.create(loc, lhsRankULE, lhsRank, rhsRank); Value greaterRank = rewriter.create(loc, lhsRankULE, rhsRank, lhsRank); Value lesserRankOperand = rewriter.create(loc, lhsRankULE, op.lhs(), op.rhs()); Value greaterRankOperand = rewriter.create(loc, lhsRankULE, op.rhs(), op.lhs()); Value rankDiff = rewriter.create(loc, indexTy, greaterRank, lesserRank); // Generate code to compare the shapes extent by extent, and emit errors for // non-broadcast-compatible shapes. // Two extents are broadcast-compatible if // 1. they are both equal, or // 2. at least one of them is 1. rewriter.create( loc, rankDiff, greaterRank, one, llvm::None, [&](OpBuilder &b, Location loc, Value iv, ValueRange) { Value greaterRankOperandExtent = b.create( loc, greaterRankOperand, ValueRange{iv}); Value ivShifted = b.create(loc, indexTy, iv, rankDiff); Value lesserRankOperandExtent = b.create( loc, lesserRankOperand, ValueRange{ivShifted}); Value greaterRankOperandExtentIsOne = b.create( loc, CmpIPredicate::eq, greaterRankOperandExtent, one); Value lesserRankOperandExtentIsOne = b.create( loc, CmpIPredicate::eq, lesserRankOperandExtent, one); Value extentsAgree = b.create(loc, CmpIPredicate::eq, greaterRankOperandExtent, lesserRankOperandExtent); auto broadcastIsValid = b.create(loc, b.getI1Type(), extentsAgree, b.create(loc, greaterRankOperandExtentIsOne, lesserRankOperandExtentIsOne)); b.create(loc, broadcastIsValid, "invalid broadcast"); b.create(loc); }); rewriter.replaceOpWithNewOp(op, true); return success(); } }; } // namespace namespace { class ConvertCstrRequireOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(shape::CstrRequireOp op, PatternRewriter &rewriter) const override { rewriter.create(op.getLoc(), op.pred(), op.msgAttr()); rewriter.replaceOpWithNewOp(op, true); return success(); } }; } // namespace void mlir::populateConvertShapeConstraintsConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { patterns.insert(ctx); patterns.insert(ctx); } namespace { // This pass eliminates shape constraints from the program, converting them to // eager (side-effecting) error handling code. After eager error handling code // is emitted, witnesses are satisfied, so they are replace with // `shape.const_witness true`. class ConvertShapeConstraints : public ConvertShapeConstraintsBase { void runOnOperation() override { auto func = getOperation(); auto *context = &getContext(); OwningRewritePatternList patterns; populateConvertShapeConstraintsConversionPatterns(patterns, context); if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) return signalPassFailure(); } }; } // namespace std::unique_ptr> mlir::createConvertShapeConstraintsPass() { return std::make_unique(); }