1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with hoisting invariant operations
10 // in the context of Linalg transformations.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/SCF/SCF.h"
18 #include "mlir/Dialect/SCF/Utils.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/Dialect/Vector/VectorUtils.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/Dominance.h"
24 #include "mlir/Transforms/LoopUtils.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "llvm/Support/Debug.h"
27
28 #define DEBUG_TYPE "linalg-hoisting"
29
30 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
31
32 using namespace mlir;
33 using namespace mlir::linalg;
34
35 using llvm::dbgs;
36
hoistViewAllocOps(FuncOp func)37 void mlir::linalg::hoistViewAllocOps(FuncOp func) {
38 bool changed = true;
39 while (changed) {
40 changed = false;
41 func.walk([&changed](Operation *op) {
42 if (!isa<AllocOp, AllocaOp, DeallocOp>(op))
43 return;
44
45 LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *op << "\n");
46 auto loop = dyn_cast<scf::ForOp>(op->getParentOp());
47 LLVM_DEBUG(DBGS() << "Parent op: " << *op->getParentOp() << "\n");
48
49 // Only hoist out of immediately enclosing scf::ForOp.
50 if (!loop)
51 return;
52
53 // If any operand is defined inside the loop don't hoist.
54 if (llvm::any_of(op->getOperands(), [&](Value v) {
55 return !loop.isDefinedOutsideOfLoop(v);
56 }))
57 return;
58
59 LLVM_DEBUG(DBGS() << "All operands defined outside \n");
60
61 // If alloc has other uses than ViewLikeOp and DeallocOp don't hoist.
62 Value v;
63 if (op->getNumResults() > 0) {
64 assert(op->getNumResults() == 1 && "Unexpected multi-result alloc");
65 v = op->getResult(0);
66 }
67 if (v && !llvm::all_of(v.getUses(), [&](OpOperand &operand) {
68 return isa<ViewLikeOpInterface, DeallocOp>(operand.getOwner());
69 })) {
70 LLVM_DEBUG(DBGS() << "Found non view-like or dealloc use: bail\n");
71 return;
72 }
73
74 // Move AllocOp before the loop.
75 if (isa<AllocOp, AllocaOp>(op))
76 loop.moveOutOfLoop({op});
77 else // Move DeallocOp outside of the loop.
78 op->moveAfter(loop);
79 changed = true;
80 });
81 }
82 }
83
hoistRedundantVectorTransfers(FuncOp func)84 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
85 bool changed = true;
86 while (changed) {
87 changed = false;
88
89 func.walk([&](vector::TransferReadOp transferRead) {
90 LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
91 << *transferRead.getOperation() << "\n");
92 auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
93 LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
94 << "\n");
95 if (!loop)
96 return WalkResult::advance();
97
98 if (failed(moveLoopInvariantCode(
99 cast<LoopLikeOpInterface>(loop.getOperation()))))
100 llvm_unreachable(
101 "Unexpected failure to move invariant code out of loop");
102
103 LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
104 << "\n");
105
106 llvm::SetVector<Operation *> forwardSlice;
107 getForwardSlice(transferRead, &forwardSlice);
108
109 // Look for the last TransferWriteOp in the forwardSlice of
110 // `transferRead` that operates on the same memref.
111 vector::TransferWriteOp transferWrite;
112 for (auto *sliceOp : llvm::reverse(forwardSlice)) {
113 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
114 if (!candidateWrite || candidateWrite.memref() != transferRead.memref())
115 continue;
116 transferWrite = candidateWrite;
117 }
118
119 // All operands of the TransferRead must be defined outside of the loop.
120 for (auto operand : transferRead.getOperands())
121 if (!loop.isDefinedOutsideOfLoop(operand))
122 return WalkResult::advance();
123
124 // Only hoist transfer_read / transfer_write pairs for now.
125 if (!transferWrite)
126 return WalkResult::advance();
127
128 LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
129 << "\n");
130
131 // Approximate aliasing by checking that:
132 // 1. indices are the same,
133 // 2. no other operations in the loop access the same memref except
134 // for transfer_read/transfer_write accessing statically disjoint
135 // slices.
136 if (transferRead.indices() != transferWrite.indices() &&
137 transferRead.getVectorType() == transferWrite.getVectorType())
138 return WalkResult::advance();
139
140 // TODO: may want to memoize this information for performance but it
141 // likely gets invalidated often.
142 DominanceInfo dom(loop);
143 if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
144 return WalkResult::advance();
145 for (auto &use : transferRead.memref().getUses()) {
146 if (!dom.properlyDominates(loop, use.getOwner()))
147 continue;
148 if (use.getOwner() == transferRead.getOperation() ||
149 use.getOwner() == transferWrite.getOperation())
150 continue;
151 if (auto transferWriteUse =
152 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
153 if (!isDisjointTransferSet(
154 cast<VectorTransferOpInterface>(transferWrite.getOperation()),
155 cast<VectorTransferOpInterface>(
156 transferWriteUse.getOperation())))
157 return WalkResult::advance();
158 } else if (auto transferReadUse =
159 dyn_cast<vector::TransferReadOp>(use.getOwner())) {
160 if (!isDisjointTransferSet(
161 cast<VectorTransferOpInterface>(transferWrite.getOperation()),
162 cast<VectorTransferOpInterface>(
163 transferReadUse.getOperation())))
164 return WalkResult::advance();
165 } else {
166 // Unknown use, we cannot prove that it doesn't alias with the
167 // transferRead/transferWrite operations.
168 return WalkResult::advance();
169 }
170 }
171
172 // Hoist read before.
173 if (failed(loop.moveOutOfLoop({transferRead})))
174 llvm_unreachable(
175 "Unexpected failure to move transfer read out of loop");
176
177 // Hoist write after.
178 transferWrite->moveAfter(loop);
179
180 // Rewrite `loop` with new yields by cloning and erase the original loop.
181 OpBuilder b(transferRead);
182 auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(),
183 transferWrite.vector());
184
185 // Transfer write has been hoisted, need to update the written value to
186 // the value yielded by the newForOp.
187 transferWrite.vector().replaceAllUsesWith(
188 newForOp.getResults().take_back()[0]);
189
190 changed = true;
191 loop.erase();
192 // Need to interrupt and restart because erasing the loop messes up the
193 // walk.
194 return WalkResult::interrupt();
195 });
196 }
197 }
198