1 //===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unrolling.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "PassDetail.h"
13 #include "mlir/Analysis/LoopAnalysis.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/Passes.h"
16 #include "mlir/IR/AffineExpr.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/Transforms/LoopUtils.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/Support/CommandLine.h"
22 #include "llvm/Support/Debug.h"
23
24 using namespace mlir;
25
26 #define DEBUG_TYPE "affine-loop-unroll"
27
28 namespace {
29
30 // TODO: this is really a test pass and should be moved out of dialect
31 // transforms.
32
33 /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a
34 /// full unroll threshold was specified, in which case, fully unrolls all loops
35 /// with trip count less than the specified threshold. The latter is for testing
36 /// purposes, especially for testing outer loop unrolling.
37 struct LoopUnroll : public AffineLoopUnrollBase<LoopUnroll> {
38 // Callback to obtain unroll factors; if this has a callable target, takes
39 // precedence over command-line argument or passed argument.
40 const std::function<unsigned(AffineForOp)> getUnrollFactor;
41
LoopUnroll__anon33d5d8310111::LoopUnroll42 LoopUnroll() : getUnrollFactor(nullptr) {}
LoopUnroll__anon33d5d8310111::LoopUnroll43 LoopUnroll(const LoopUnroll &other)
44 : AffineLoopUnrollBase<LoopUnroll>(other),
45 getUnrollFactor(other.getUnrollFactor) {}
LoopUnroll__anon33d5d8310111::LoopUnroll46 explicit LoopUnroll(
47 Optional<unsigned> unrollFactor = None, bool unrollUpToFactor = false,
48 bool unrollFull = false,
49 const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr)
50 : getUnrollFactor(getUnrollFactor) {
51 if (unrollFactor)
52 this->unrollFactor = *unrollFactor;
53 this->unrollUpToFactor = unrollUpToFactor;
54 this->unrollFull = unrollFull;
55 }
56
57 void runOnFunction() override;
58
59 /// Unroll this for op. Returns failure if nothing was done.
60 LogicalResult runOnAffineForOp(AffineForOp forOp);
61 };
62 } // end anonymous namespace
63
64 /// Returns true if no other affine.for ops are nested within.
isInnermostAffineForOp(AffineForOp forOp)65 static bool isInnermostAffineForOp(AffineForOp forOp) {
66 // Only for the innermost affine.for op's.
67 bool isInnermost = true;
68 forOp.walk([&](AffineForOp thisForOp) {
69 // Since this is a post order walk, we are able to conclude here.
70 isInnermost = (thisForOp == forOp);
71 return WalkResult::interrupt();
72 });
73 return isInnermost;
74 }
75
76 /// Gathers loops that have no affine.for's nested within.
gatherInnermostLoops(FuncOp f,SmallVectorImpl<AffineForOp> & loops)77 static void gatherInnermostLoops(FuncOp f,
78 SmallVectorImpl<AffineForOp> &loops) {
79 f.walk([&](AffineForOp forOp) {
80 if (isInnermostAffineForOp(forOp))
81 loops.push_back(forOp);
82 });
83 }
84
runOnFunction()85 void LoopUnroll::runOnFunction() {
86 if (unrollFull && unrollFullThreshold.hasValue()) {
87 // Store short loops as we walk.
88 SmallVector<AffineForOp, 4> loops;
89
90 // Gathers all loops with trip count <= minTripCount. Do a post order walk
91 // so that loops are gathered from innermost to outermost (or else unrolling
92 // an outer one may delete gathered inner ones).
93 getFunction().walk([&](AffineForOp forOp) {
94 Optional<uint64_t> tripCount = getConstantTripCount(forOp);
95 if (tripCount.hasValue() && tripCount.getValue() <= unrollFullThreshold)
96 loops.push_back(forOp);
97 });
98 for (auto forOp : loops)
99 loopUnrollFull(forOp);
100 return;
101 }
102
103 // If the call back is provided, we will recurse until no loops are found.
104 FuncOp func = getFunction();
105 SmallVector<AffineForOp, 4> loops;
106 for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) {
107 loops.clear();
108 gatherInnermostLoops(func, loops);
109 if (loops.empty())
110 break;
111 bool unrolled = false;
112 for (auto forOp : loops)
113 unrolled |= succeeded(runOnAffineForOp(forOp));
114 if (!unrolled)
115 // Break out if nothing was unrolled.
116 break;
117 }
118 }
119
120 /// Unrolls a 'affine.for' op. Returns success if the loop was unrolled,
121 /// failure otherwise. The default unroll factor is 4.
runOnAffineForOp(AffineForOp forOp)122 LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
123 // Use the function callback if one was provided.
124 if (getUnrollFactor)
125 return loopUnrollByFactor(forOp, getUnrollFactor(forOp));
126 // Unroll completely if full loop unroll was specified.
127 if (unrollFull)
128 return loopUnrollFull(forOp);
129 // Otherwise, unroll by the given unroll factor.
130 if (unrollUpToFactor)
131 return loopUnrollUpToFactor(forOp, unrollFactor);
132 return loopUnrollByFactor(forOp, unrollFactor);
133 }
134
createLoopUnrollPass(int unrollFactor,bool unrollUpToFactor,bool unrollFull,const std::function<unsigned (AffineForOp)> & getUnrollFactor)135 std::unique_ptr<OperationPass<FuncOp>> mlir::createLoopUnrollPass(
136 int unrollFactor, bool unrollUpToFactor, bool unrollFull,
137 const std::function<unsigned(AffineForOp)> &getUnrollFactor) {
138 return std::make_unique<LoopUnroll>(
139 unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor),
140 unrollUpToFactor, unrollFull, getUnrollFactor);
141 }
142