• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- LoopSpecialization.cpp - scf.parallel/scf.for specialization -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Specializes parallel loops and for loops for easier unrolling and
10 // vectorization.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/SCF/Passes.h"
17 #include "mlir/Dialect/SCF/SCF.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/AffineExpr.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21 
22 using namespace mlir;
23 using scf::ForOp;
24 using scf::ParallelOp;
25 
26 /// Rewrite a parallel loop with bounds defined by an affine.min with a constant
27 /// into 2 loops after checking if the bounds are equal to that constant. This
28 /// is beneficial if the loop will almost always have the constant bound and
29 /// that version can be fully unrolled and vectorized.
specializeParallelLoopForUnrolling(ParallelOp op)30 static void specializeParallelLoopForUnrolling(ParallelOp op) {
31   SmallVector<int64_t, 2> constantIndices;
32   constantIndices.reserve(op.upperBound().size());
33   for (auto bound : op.upperBound()) {
34     auto minOp = bound.getDefiningOp<AffineMinOp>();
35     if (!minOp)
36       return;
37     int64_t minConstant = std::numeric_limits<int64_t>::max();
38     for (AffineExpr expr : minOp.map().getResults()) {
39       if (auto constantIndex = expr.dyn_cast<AffineConstantExpr>())
40         minConstant = std::min(minConstant, constantIndex.getValue());
41     }
42     if (minConstant == std::numeric_limits<int64_t>::max())
43       return;
44     constantIndices.push_back(minConstant);
45   }
46 
47   OpBuilder b(op);
48   BlockAndValueMapping map;
49   Value cond;
50   for (auto bound : llvm::zip(op.upperBound(), constantIndices)) {
51     Value constant = b.create<ConstantIndexOp>(op.getLoc(), std::get<1>(bound));
52     Value cmp = b.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq,
53                                  std::get<0>(bound), constant);
54     cond = cond ? b.create<AndOp>(op.getLoc(), cond, cmp) : cmp;
55     map.map(std::get<0>(bound), constant);
56   }
57   auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
58   ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
59   ifOp.getElseBodyBuilder().clone(*op.getOperation());
60   op.erase();
61 }
62 
63 /// Rewrite a for loop with bounds defined by an affine.min with a constant into
64 /// 2 loops after checking if the bounds are equal to that constant. This is
65 /// beneficial if the loop will almost always have the constant bound and that
66 /// version can be fully unrolled and vectorized.
specializeForLoopForUnrolling(ForOp op)67 static void specializeForLoopForUnrolling(ForOp op) {
68   auto bound = op.upperBound();
69   auto minOp = bound.getDefiningOp<AffineMinOp>();
70   if (!minOp)
71     return;
72   int64_t minConstant = std::numeric_limits<int64_t>::max();
73   for (AffineExpr expr : minOp.map().getResults()) {
74     if (auto constantIndex = expr.dyn_cast<AffineConstantExpr>())
75       minConstant = std::min(minConstant, constantIndex.getValue());
76   }
77   if (minConstant == std::numeric_limits<int64_t>::max())
78     return;
79 
80   OpBuilder b(op);
81   BlockAndValueMapping map;
82   Value constant = b.create<ConstantIndexOp>(op.getLoc(), minConstant);
83   Value cond =
84       b.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq, bound, constant);
85   map.map(bound, constant);
86   auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
87   ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
88   ifOp.getElseBodyBuilder().clone(*op.getOperation());
89   op.erase();
90 }
91 
92 namespace {
93 struct ParallelLoopSpecialization
94     : public SCFParallelLoopSpecializationBase<ParallelLoopSpecialization> {
runOnFunction__anon5d2e62270111::ParallelLoopSpecialization95   void runOnFunction() override {
96     getFunction().walk(
97         [](ParallelOp op) { specializeParallelLoopForUnrolling(op); });
98   }
99 };
100 
101 struct ForLoopSpecialization
102     : public SCFForLoopSpecializationBase<ForLoopSpecialization> {
runOnFunction__anon5d2e62270111::ForLoopSpecialization103   void runOnFunction() override {
104     getFunction().walk([](ForOp op) { specializeForLoopForUnrolling(op); });
105   }
106 };
107 } // namespace
108 
createParallelLoopSpecializationPass()109 std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
110   return std::make_unique<ParallelLoopSpecialization>();
111 }
112 
createForLoopSpecializationPass()113 std::unique_ptr<Pass> mlir::createForLoopSpecializationPass() {
114   return std::make_unique<ForLoopSpecialization>();
115 }
116