//===- LoopSpecialization.cpp - scf.parallel/scf.for specialization -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Specializes parallel loops and for loops for easier unrolling and
10 // vectorization.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/SCF/Passes.h"
17 #include "mlir/Dialect/SCF/SCF.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/AffineExpr.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21
22 using namespace mlir;
23 using scf::ForOp;
24 using scf::ParallelOp;
25
26 /// Rewrite a parallel loop with bounds defined by an affine.min with a constant
27 /// into 2 loops after checking if the bounds are equal to that constant. This
28 /// is beneficial if the loop will almost always have the constant bound and
29 /// that version can be fully unrolled and vectorized.
specializeParallelLoopForUnrolling(ParallelOp op)30 static void specializeParallelLoopForUnrolling(ParallelOp op) {
31 SmallVector<int64_t, 2> constantIndices;
32 constantIndices.reserve(op.upperBound().size());
33 for (auto bound : op.upperBound()) {
34 auto minOp = bound.getDefiningOp<AffineMinOp>();
35 if (!minOp)
36 return;
37 int64_t minConstant = std::numeric_limits<int64_t>::max();
38 for (AffineExpr expr : minOp.map().getResults()) {
39 if (auto constantIndex = expr.dyn_cast<AffineConstantExpr>())
40 minConstant = std::min(minConstant, constantIndex.getValue());
41 }
42 if (minConstant == std::numeric_limits<int64_t>::max())
43 return;
44 constantIndices.push_back(minConstant);
45 }
46
47 OpBuilder b(op);
48 BlockAndValueMapping map;
49 Value cond;
50 for (auto bound : llvm::zip(op.upperBound(), constantIndices)) {
51 Value constant = b.create<ConstantIndexOp>(op.getLoc(), std::get<1>(bound));
52 Value cmp = b.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq,
53 std::get<0>(bound), constant);
54 cond = cond ? b.create<AndOp>(op.getLoc(), cond, cmp) : cmp;
55 map.map(std::get<0>(bound), constant);
56 }
57 auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
58 ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
59 ifOp.getElseBodyBuilder().clone(*op.getOperation());
60 op.erase();
61 }
62
63 /// Rewrite a for loop with bounds defined by an affine.min with a constant into
64 /// 2 loops after checking if the bounds are equal to that constant. This is
65 /// beneficial if the loop will almost always have the constant bound and that
66 /// version can be fully unrolled and vectorized.
specializeForLoopForUnrolling(ForOp op)67 static void specializeForLoopForUnrolling(ForOp op) {
68 auto bound = op.upperBound();
69 auto minOp = bound.getDefiningOp<AffineMinOp>();
70 if (!minOp)
71 return;
72 int64_t minConstant = std::numeric_limits<int64_t>::max();
73 for (AffineExpr expr : minOp.map().getResults()) {
74 if (auto constantIndex = expr.dyn_cast<AffineConstantExpr>())
75 minConstant = std::min(minConstant, constantIndex.getValue());
76 }
77 if (minConstant == std::numeric_limits<int64_t>::max())
78 return;
79
80 OpBuilder b(op);
81 BlockAndValueMapping map;
82 Value constant = b.create<ConstantIndexOp>(op.getLoc(), minConstant);
83 Value cond =
84 b.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq, bound, constant);
85 map.map(bound, constant);
86 auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
87 ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
88 ifOp.getElseBodyBuilder().clone(*op.getOperation());
89 op.erase();
90 }
91
92 namespace {
93 struct ParallelLoopSpecialization
94 : public SCFParallelLoopSpecializationBase<ParallelLoopSpecialization> {
runOnFunction__anon5d2e62270111::ParallelLoopSpecialization95 void runOnFunction() override {
96 getFunction().walk(
97 [](ParallelOp op) { specializeParallelLoopForUnrolling(op); });
98 }
99 };
100
101 struct ForLoopSpecialization
102 : public SCFForLoopSpecializationBase<ForLoopSpecialization> {
runOnFunction__anon5d2e62270111::ForLoopSpecialization103 void runOnFunction() override {
104 getFunction().walk([](ForOp op) { specializeForLoopForUnrolling(op); });
105 }
106 };
107 } // namespace
108
createParallelLoopSpecializationPass()109 std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
110 return std::make_unique<ParallelLoopSpecialization>();
111 }
112
createForLoopSpecializationPass()113 std::unique_ptr<Pass> mlir::createForLoopSpecializationPass() {
114 return std::make_unique<ForLoopSpecialization>();
115 }
116