• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unrolling.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "PassDetail.h"
13 #include "mlir/Analysis/LoopAnalysis.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/Passes.h"
16 #include "mlir/IR/AffineExpr.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/Transforms/LoopUtils.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/Support/CommandLine.h"
22 #include "llvm/Support/Debug.h"
23 
24 using namespace mlir;
25 
26 #define DEBUG_TYPE "affine-loop-unroll"
27 
28 namespace {
29 
30 // TODO: this is really a test pass and should be moved out of dialect
31 // transforms.
32 
33 /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a
34 /// full unroll threshold was specified, in which case, fully unrolls all loops
35 /// with trip count less than the specified threshold. The latter is for testing
36 /// purposes, especially for testing outer loop unrolling.
37 struct LoopUnroll : public AffineLoopUnrollBase<LoopUnroll> {
38   // Callback to obtain unroll factors; if this has a callable target, takes
39   // precedence over command-line argument or passed argument.
40   const std::function<unsigned(AffineForOp)> getUnrollFactor;
41 
LoopUnroll__anon33d5d8310111::LoopUnroll42   LoopUnroll() : getUnrollFactor(nullptr) {}
LoopUnroll__anon33d5d8310111::LoopUnroll43   LoopUnroll(const LoopUnroll &other)
44       : AffineLoopUnrollBase<LoopUnroll>(other),
45         getUnrollFactor(other.getUnrollFactor) {}
LoopUnroll__anon33d5d8310111::LoopUnroll46   explicit LoopUnroll(
47       Optional<unsigned> unrollFactor = None, bool unrollUpToFactor = false,
48       bool unrollFull = false,
49       const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr)
50       : getUnrollFactor(getUnrollFactor) {
51     if (unrollFactor)
52       this->unrollFactor = *unrollFactor;
53     this->unrollUpToFactor = unrollUpToFactor;
54     this->unrollFull = unrollFull;
55   }
56 
57   void runOnFunction() override;
58 
59   /// Unroll this for op. Returns failure if nothing was done.
60   LogicalResult runOnAffineForOp(AffineForOp forOp);
61 };
62 } // end anonymous namespace
63 
64 /// Returns true if no other affine.for ops are nested within.
isInnermostAffineForOp(AffineForOp forOp)65 static bool isInnermostAffineForOp(AffineForOp forOp) {
66   // Only for the innermost affine.for op's.
67   bool isInnermost = true;
68   forOp.walk([&](AffineForOp thisForOp) {
69     // Since this is a post order walk, we are able to conclude here.
70     isInnermost = (thisForOp == forOp);
71     return WalkResult::interrupt();
72   });
73   return isInnermost;
74 }
75 
76 /// Gathers loops that have no affine.for's nested within.
gatherInnermostLoops(FuncOp f,SmallVectorImpl<AffineForOp> & loops)77 static void gatherInnermostLoops(FuncOp f,
78                                  SmallVectorImpl<AffineForOp> &loops) {
79   f.walk([&](AffineForOp forOp) {
80     if (isInnermostAffineForOp(forOp))
81       loops.push_back(forOp);
82   });
83 }
84 
runOnFunction()85 void LoopUnroll::runOnFunction() {
86   if (unrollFull && unrollFullThreshold.hasValue()) {
87     // Store short loops as we walk.
88     SmallVector<AffineForOp, 4> loops;
89 
90     // Gathers all loops with trip count <= minTripCount. Do a post order walk
91     // so that loops are gathered from innermost to outermost (or else unrolling
92     // an outer one may delete gathered inner ones).
93     getFunction().walk([&](AffineForOp forOp) {
94       Optional<uint64_t> tripCount = getConstantTripCount(forOp);
95       if (tripCount.hasValue() && tripCount.getValue() <= unrollFullThreshold)
96         loops.push_back(forOp);
97     });
98     for (auto forOp : loops)
99       loopUnrollFull(forOp);
100     return;
101   }
102 
103   // If the call back is provided, we will recurse until no loops are found.
104   FuncOp func = getFunction();
105   SmallVector<AffineForOp, 4> loops;
106   for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) {
107     loops.clear();
108     gatherInnermostLoops(func, loops);
109     if (loops.empty())
110       break;
111     bool unrolled = false;
112     for (auto forOp : loops)
113       unrolled |= succeeded(runOnAffineForOp(forOp));
114     if (!unrolled)
115       // Break out if nothing was unrolled.
116       break;
117   }
118 }
119 
120 /// Unrolls a 'affine.for' op. Returns success if the loop was unrolled,
121 /// failure otherwise. The default unroll factor is 4.
runOnAffineForOp(AffineForOp forOp)122 LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
123   // Use the function callback if one was provided.
124   if (getUnrollFactor)
125     return loopUnrollByFactor(forOp, getUnrollFactor(forOp));
126   // Unroll completely if full loop unroll was specified.
127   if (unrollFull)
128     return loopUnrollFull(forOp);
129   // Otherwise, unroll by the given unroll factor.
130   if (unrollUpToFactor)
131     return loopUnrollUpToFactor(forOp, unrollFactor);
132   return loopUnrollByFactor(forOp, unrollFactor);
133 }
134 
createLoopUnrollPass(int unrollFactor,bool unrollUpToFactor,bool unrollFull,const std::function<unsigned (AffineForOp)> & getUnrollFactor)135 std::unique_ptr<OperationPass<FuncOp>> mlir::createLoopUnrollPass(
136     int unrollFactor, bool unrollUpToFactor, bool unrollFull,
137     const std::function<unsigned(AffineForOp)> &getUnrollFactor) {
138   return std::make_unique<LoopUnroll>(
139       unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor),
140       unrollUpToFactor, unrollFull, getUnrollFactor);
141 }
142