• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/SCF/Passes.h"
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/SCF/Transforms.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/BlockAndValueMapping.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/OpDefinition.h"
21 
22 using namespace mlir;
23 using namespace mlir::scf;
24 
25 /// Verify there are no nested ParallelOps.
hasNestedParallelOp(ParallelOp ploop)26 static bool hasNestedParallelOp(ParallelOp ploop) {
27   auto walkResult =
28       ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
29   return walkResult.wasInterrupted();
30 }
31 
32 /// Verify equal iteration spaces.
equalIterationSpaces(ParallelOp firstPloop,ParallelOp secondPloop)33 static bool equalIterationSpaces(ParallelOp firstPloop,
34                                  ParallelOp secondPloop) {
35   if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
36     return false;
37 
38   auto matchOperands = [&](const OperandRange &lhs,
39                            const OperandRange &rhs) -> bool {
40     // TODO: Extend this to support aliases and equal constants.
41     return std::equal(lhs.begin(), lhs.end(), rhs.begin());
42   };
43   return matchOperands(firstPloop.lowerBound(), secondPloop.lowerBound()) &&
44          matchOperands(firstPloop.upperBound(), secondPloop.upperBound()) &&
45          matchOperands(firstPloop.step(), secondPloop.step());
46 }
47 
48 /// Checks if the parallel loops have mixed access to the same buffers. Returns
49 /// `true` if the first parallel loop writes to the same indices that the second
50 /// loop reads.
haveNoReadsAfterWriteExceptSameIndex(ParallelOp firstPloop,ParallelOp secondPloop,const BlockAndValueMapping & firstToSecondPloopIndices)51 static bool haveNoReadsAfterWriteExceptSameIndex(
52     ParallelOp firstPloop, ParallelOp secondPloop,
53     const BlockAndValueMapping &firstToSecondPloopIndices) {
54   DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
55   firstPloop.getBody()->walk([&](StoreOp store) {
56     bufferStores[store.getMemRef()].push_back(store.indices());
57   });
58   auto walkResult = secondPloop.getBody()->walk([&](LoadOp load) {
59     // Stop if the memref is defined in secondPloop body. Careful alias analysis
60     // is needed.
61     auto *memrefDef = load.getMemRef().getDefiningOp();
62     if (memrefDef && memrefDef->getBlock() == load->getBlock())
63       return WalkResult::interrupt();
64 
65     auto write = bufferStores.find(load.getMemRef());
66     if (write == bufferStores.end())
67       return WalkResult::advance();
68 
69     // Allow only single write access per buffer.
70     if (write->second.size() != 1)
71       return WalkResult::interrupt();
72 
73     // Check that the load indices of secondPloop coincide with store indices of
74     // firstPloop for the same memrefs.
75     auto storeIndices = write->second.front();
76     auto loadIndices = load.indices();
77     if (storeIndices.size() != loadIndices.size())
78       return WalkResult::interrupt();
79     for (int i = 0, e = storeIndices.size(); i < e; ++i) {
80       if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
81           loadIndices[i])
82         return WalkResult::interrupt();
83     }
84     return WalkResult::advance();
85   });
86   return !walkResult.wasInterrupted();
87 }
88 
89 /// Analyzes dependencies in the most primitive way by checking simple read and
90 /// write patterns.
91 static LogicalResult
verifyDependencies(ParallelOp firstPloop,ParallelOp secondPloop,const BlockAndValueMapping & firstToSecondPloopIndices)92 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
93                    const BlockAndValueMapping &firstToSecondPloopIndices) {
94   if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
95                                             firstToSecondPloopIndices))
96     return failure();
97 
98   BlockAndValueMapping secondToFirstPloopIndices;
99   secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
100                                 firstPloop.getBody()->getArguments());
101   return success(haveNoReadsAfterWriteExceptSameIndex(
102       secondPloop, firstPloop, secondToFirstPloopIndices));
103 }
104 
105 static bool
isFusionLegal(ParallelOp firstPloop,ParallelOp secondPloop,const BlockAndValueMapping & firstToSecondPloopIndices)106 isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
107               const BlockAndValueMapping &firstToSecondPloopIndices) {
108   return !hasNestedParallelOp(firstPloop) &&
109          !hasNestedParallelOp(secondPloop) &&
110          equalIterationSpaces(firstPloop, secondPloop) &&
111          succeeded(verifyDependencies(firstPloop, secondPloop,
112                                       firstToSecondPloopIndices));
113 }
114 
115 /// Prepends operations of firstPloop's body into secondPloop's body.
fuseIfLegal(ParallelOp firstPloop,ParallelOp secondPloop,OpBuilder b)116 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
117                         OpBuilder b) {
118   BlockAndValueMapping firstToSecondPloopIndices;
119   firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
120                                 secondPloop.getBody()->getArguments());
121 
122   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
123     return;
124 
125   b.setInsertionPointToStart(secondPloop.getBody());
126   for (auto &op : firstPloop.getBody()->without_terminator())
127     b.clone(op, firstToSecondPloopIndices);
128   firstPloop.erase();
129 }
130 
naivelyFuseParallelOps(Region & region)131 void mlir::scf::naivelyFuseParallelOps(Region &region) {
132   OpBuilder b(region);
133   // Consider every single block and attempt to fuse adjacent loops.
134   for (auto &block : region) {
135     SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
136     // Not using `walk()` to traverse only top-level parallel loops and also
137     // make sure that there are no side-effecting ops between the parallel
138     // loops.
139     bool noSideEffects = true;
140     for (auto &op : block) {
141       if (auto ploop = dyn_cast<ParallelOp>(op)) {
142         if (noSideEffects) {
143           ploopChains.back().push_back(ploop);
144         } else {
145           ploopChains.push_back({ploop});
146           noSideEffects = true;
147         }
148         continue;
149       }
150       // TODO: Handle region side effects properly.
151       noSideEffects &=
152           MemoryEffectOpInterface::hasNoEffect(&op) && op.getNumRegions() == 0;
153     }
154     for (ArrayRef<ParallelOp> ploops : ploopChains) {
155       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
156         fuseIfLegal(ploops[i], ploops[i + 1], b);
157     }
158   }
159 }
160 
161 namespace {
162 struct ParallelLoopFusion
163     : public SCFParallelLoopFusionBase<ParallelLoopFusion> {
runOnOperation__anona7b89f290511::ParallelLoopFusion164   void runOnOperation() override {
165     getOperation()->walk([&](Operation *child) {
166       for (Region &region : child->getRegions())
167         naivelyFuseParallelOps(region);
168     });
169   }
170 };
171 } // namespace
172 
createParallelLoopFusionPass()173 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
174   return std::make_unique<ParallelLoopFusion>();
175 }
176