1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "PassDetail.h"
14 #include "mlir/Dialect/SCF/Passes.h"
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/SCF/Transforms.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/BlockAndValueMapping.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/OpDefinition.h"
21
22 using namespace mlir;
23 using namespace mlir::scf;
24
25 /// Verify there are no nested ParallelOps.
hasNestedParallelOp(ParallelOp ploop)26 static bool hasNestedParallelOp(ParallelOp ploop) {
27 auto walkResult =
28 ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
29 return walkResult.wasInterrupted();
30 }
31
32 /// Verify equal iteration spaces.
equalIterationSpaces(ParallelOp firstPloop,ParallelOp secondPloop)33 static bool equalIterationSpaces(ParallelOp firstPloop,
34 ParallelOp secondPloop) {
35 if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
36 return false;
37
38 auto matchOperands = [&](const OperandRange &lhs,
39 const OperandRange &rhs) -> bool {
40 // TODO: Extend this to support aliases and equal constants.
41 return std::equal(lhs.begin(), lhs.end(), rhs.begin());
42 };
43 return matchOperands(firstPloop.lowerBound(), secondPloop.lowerBound()) &&
44 matchOperands(firstPloop.upperBound(), secondPloop.upperBound()) &&
45 matchOperands(firstPloop.step(), secondPloop.step());
46 }
47
48 /// Checks if the parallel loops have mixed access to the same buffers. Returns
49 /// `true` if the first parallel loop writes to the same indices that the second
50 /// loop reads.
haveNoReadsAfterWriteExceptSameIndex(ParallelOp firstPloop,ParallelOp secondPloop,const BlockAndValueMapping & firstToSecondPloopIndices)51 static bool haveNoReadsAfterWriteExceptSameIndex(
52 ParallelOp firstPloop, ParallelOp secondPloop,
53 const BlockAndValueMapping &firstToSecondPloopIndices) {
54 DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
55 firstPloop.getBody()->walk([&](StoreOp store) {
56 bufferStores[store.getMemRef()].push_back(store.indices());
57 });
58 auto walkResult = secondPloop.getBody()->walk([&](LoadOp load) {
59 // Stop if the memref is defined in secondPloop body. Careful alias analysis
60 // is needed.
61 auto *memrefDef = load.getMemRef().getDefiningOp();
62 if (memrefDef && memrefDef->getBlock() == load->getBlock())
63 return WalkResult::interrupt();
64
65 auto write = bufferStores.find(load.getMemRef());
66 if (write == bufferStores.end())
67 return WalkResult::advance();
68
69 // Allow only single write access per buffer.
70 if (write->second.size() != 1)
71 return WalkResult::interrupt();
72
73 // Check that the load indices of secondPloop coincide with store indices of
74 // firstPloop for the same memrefs.
75 auto storeIndices = write->second.front();
76 auto loadIndices = load.indices();
77 if (storeIndices.size() != loadIndices.size())
78 return WalkResult::interrupt();
79 for (int i = 0, e = storeIndices.size(); i < e; ++i) {
80 if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
81 loadIndices[i])
82 return WalkResult::interrupt();
83 }
84 return WalkResult::advance();
85 });
86 return !walkResult.wasInterrupted();
87 }
88
89 /// Analyzes dependencies in the most primitive way by checking simple read and
90 /// write patterns.
91 static LogicalResult
verifyDependencies(ParallelOp firstPloop,ParallelOp secondPloop,const BlockAndValueMapping & firstToSecondPloopIndices)92 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
93 const BlockAndValueMapping &firstToSecondPloopIndices) {
94 if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
95 firstToSecondPloopIndices))
96 return failure();
97
98 BlockAndValueMapping secondToFirstPloopIndices;
99 secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
100 firstPloop.getBody()->getArguments());
101 return success(haveNoReadsAfterWriteExceptSameIndex(
102 secondPloop, firstPloop, secondToFirstPloopIndices));
103 }
104
105 static bool
isFusionLegal(ParallelOp firstPloop,ParallelOp secondPloop,const BlockAndValueMapping & firstToSecondPloopIndices)106 isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
107 const BlockAndValueMapping &firstToSecondPloopIndices) {
108 return !hasNestedParallelOp(firstPloop) &&
109 !hasNestedParallelOp(secondPloop) &&
110 equalIterationSpaces(firstPloop, secondPloop) &&
111 succeeded(verifyDependencies(firstPloop, secondPloop,
112 firstToSecondPloopIndices));
113 }
114
115 /// Prepends operations of firstPloop's body into secondPloop's body.
fuseIfLegal(ParallelOp firstPloop,ParallelOp secondPloop,OpBuilder b)116 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
117 OpBuilder b) {
118 BlockAndValueMapping firstToSecondPloopIndices;
119 firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
120 secondPloop.getBody()->getArguments());
121
122 if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
123 return;
124
125 b.setInsertionPointToStart(secondPloop.getBody());
126 for (auto &op : firstPloop.getBody()->without_terminator())
127 b.clone(op, firstToSecondPloopIndices);
128 firstPloop.erase();
129 }
130
naivelyFuseParallelOps(Region & region)131 void mlir::scf::naivelyFuseParallelOps(Region ®ion) {
132 OpBuilder b(region);
133 // Consider every single block and attempt to fuse adjacent loops.
134 for (auto &block : region) {
135 SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
136 // Not using `walk()` to traverse only top-level parallel loops and also
137 // make sure that there are no side-effecting ops between the parallel
138 // loops.
139 bool noSideEffects = true;
140 for (auto &op : block) {
141 if (auto ploop = dyn_cast<ParallelOp>(op)) {
142 if (noSideEffects) {
143 ploopChains.back().push_back(ploop);
144 } else {
145 ploopChains.push_back({ploop});
146 noSideEffects = true;
147 }
148 continue;
149 }
150 // TODO: Handle region side effects properly.
151 noSideEffects &=
152 MemoryEffectOpInterface::hasNoEffect(&op) && op.getNumRegions() == 0;
153 }
154 for (ArrayRef<ParallelOp> ploops : ploopChains) {
155 for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
156 fuseIfLegal(ploops[i], ploops[i + 1], b);
157 }
158 }
159 }
160
161 namespace {
162 struct ParallelLoopFusion
163 : public SCFParallelLoopFusionBase<ParallelLoopFusion> {
runOnOperation__anona7b89f290511::ParallelLoopFusion164 void runOnOperation() override {
165 getOperation()->walk([&](Operation *child) {
166 for (Region ®ion : child->getRegions())
167 naivelyFuseParallelOps(region);
168 });
169 }
170 };
171 } // namespace
172
createParallelLoopFusionPass()173 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
174 return std::make_unique<ParallelLoopFusion>();
175 }
176