• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Pass/PassRegistry.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 
20 using namespace mlir;
21 
22 namespace {
23 class ConvertCstrBroadcastableOp
24     : public OpRewritePattern<shape::CstrBroadcastableOp> {
25 public:
26   using OpRewritePattern::OpRewritePattern;
matchAndRewrite(shape::CstrBroadcastableOp op,PatternRewriter & rewriter) const27   LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
28                                 PatternRewriter &rewriter) const override {
29     if (op.getType().isa<shape::ShapeType>() ||
30         op.lhs().getType().isa<shape::ShapeType>() ||
31         op.rhs().getType().isa<shape::ShapeType>()) {
32       return rewriter.notifyMatchFailure(
33           op, "cannot convert error-propagating shapes");
34     }
35 
36     auto loc = op.getLoc();
37     Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
38     Value one = rewriter.create<ConstantIndexOp>(loc, 1);
39 
40     // Find smaller and greater rank and extent tensor.
41     Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
42     Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
43     Value lhsRankULE =
44         rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
45     Type indexTy = rewriter.getIndexType();
46     Value lesserRank =
47         rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
48     Value greaterRank =
49         rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
50     Value lesserRankOperand =
51         rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
52     Value greaterRankOperand =
53         rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());
54 
55     Value rankDiff =
56         rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
57 
58     // Generate code to compare the shapes extent by extent, and emit errors for
59     // non-broadcast-compatible shapes.
60     // Two extents are broadcast-compatible if
61     // 1. they are both equal, or
62     // 2. at least one of them is 1.
63 
64     rewriter.create<scf::ForOp>(
65         loc, rankDiff, greaterRank, one, llvm::None,
66         [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
67           Value greaterRankOperandExtent = b.create<ExtractElementOp>(
68               loc, greaterRankOperand, ValueRange{iv});
69           Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
70           Value lesserRankOperandExtent = b.create<ExtractElementOp>(
71               loc, lesserRankOperand, ValueRange{ivShifted});
72 
73           Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
74               loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
75           Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
76               loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
77           Value extentsAgree =
78               b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
79                                lesserRankOperandExtent);
80           auto broadcastIsValid =
81               b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
82                              b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
83                                             lesserRankOperandExtentIsOne));
84           b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
85           b.create<scf::YieldOp>(loc);
86         });
87 
88     rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
89     return success();
90   }
91 };
92 } // namespace
93 
94 namespace {
95 class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
96 public:
97   using OpRewritePattern::OpRewritePattern;
matchAndRewrite(shape::CstrRequireOp op,PatternRewriter & rewriter) const98   LogicalResult matchAndRewrite(shape::CstrRequireOp op,
99                                 PatternRewriter &rewriter) const override {
100     rewriter.create<AssertOp>(op.getLoc(), op.pred(), op.msgAttr());
101     rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
102     return success();
103   }
104 };
105 } // namespace
106 
populateConvertShapeConstraintsConversionPatterns(OwningRewritePatternList & patterns,MLIRContext * ctx)107 void mlir::populateConvertShapeConstraintsConversionPatterns(
108     OwningRewritePatternList &patterns, MLIRContext *ctx) {
109   patterns.insert<ConvertCstrBroadcastableOp>(ctx);
110   patterns.insert<ConvertCstrRequireOp>(ctx);
111 }
112 
113 namespace {
114 // This pass eliminates shape constraints from the program, converting them to
115 // eager (side-effecting) error handling code. After eager error handling code
116 // is emitted, witnesses are satisfied, so they are replace with
117 // `shape.const_witness true`.
118 class ConvertShapeConstraints
119     : public ConvertShapeConstraintsBase<ConvertShapeConstraints> {
runOnOperation()120   void runOnOperation() override {
121     auto func = getOperation();
122     auto *context = &getContext();
123 
124     OwningRewritePatternList patterns;
125     populateConvertShapeConstraintsConversionPatterns(patterns, context);
126 
127     if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
128       return signalPassFailure();
129   }
130 };
131 } // namespace
132 
133 std::unique_ptr<OperationPass<FuncOp>>
createConvertShapeConstraintsPass()134 mlir::createConvertShapeConstraintsPass() {
135   return std::make_unique<ConvertShapeConstraints>();
136 }
137