//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
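/// For example, if `op` is nested inside an scf.for that itself lives directly
/// in `region` (a hypothetical nesting), the scf.for operation is returned.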
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(FuncOp func) : dominators(func), postDominators(func) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      op->erase();
    opToErase.clear();
  }

private:
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

/// Return true if there is a path from the start operation to the dest
/// operation, otherwise return false. The operations have to be in the same
/// region.
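/// For example, in a hypothetical CFG ^bb0 -> ^bb1 -> ^bb2, an op in ^bb0 can
/// reach an op in ^bb2, but an op in ^bb2 cannot reach one in ^bb0.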
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices, vector type, and
///    permutation map.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates, as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads that can be reached from the potentially
/// dead transfer_write are dominated by the overwriting transfer_write.
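/// For example (illustrative IR only; the values, indices, and types below are
/// made up):
///
///   vector.transfer_write %v0, %m[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///   vector.transfer_write %v1, %m[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///
/// If no read of %m is reachable between the two writes, the first
/// transfer_write is dead and can be erased.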
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> reads;
  Operation *firstOverwriteCandidate = nullptr;
  for (auto *user : write.memref().getUsers()) {
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check whether this candidate fully overwrites the store.
      if (write.indices() == nextWrite.indices() &&
          write.getVectorType() == nextWrite.getVectorType() &&
          write.permutation_map() == nextWrite.permutation_map() &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
      }
    } else {
      if (auto read = dyn_cast<vector::TransferReadOp>(user)) {
        // Reads that are disjoint from the write don't need to be considered.
        if (isDisjointTransferSet(
                cast<VectorTransferOpInterface>(write.getOperation()),
                cast<VectorTransferOpInterface>(read.getOperation())))
          continue;
      }
      reads.push_back(user);
    }
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *read : reads) {
    Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
    // TODO: if the read and write have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (readAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, read)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " << *read
                        << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices, vector type, and
///    permutation map as the transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they all dominate the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates, as this will be the
/// last transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
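/// For example (illustrative IR only; the values, indices, and types below are
/// made up):
///
///   vector.transfer_write %v, %m[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///   %0 = vector.transfer_read %m[%c0, %c0], %f0
///       : memref<4x4xf32>, vector<4xf32>
///
/// If no other write to %m can reach the transfer_read, %0 can be replaced by
/// %v and the transfer_read erased.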
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasMaskedDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  for (Operation *user : read.memref().getUsers()) {
    if (isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation())))
        continue;
      if (dominators.dominates(write, read) && !write.hasMaskedDim() &&
          write.indices() == read.indices() &&
          write.getVectorType() == read.getVectorType() &&
          write.permutation_map() == read.permutation_map()) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Failed to do write-to-read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.vector());
  opToErase.push_back(read.getOperation());
}

} // namespace

void mlir::vector::transferOpflowOpt(FuncOp func) {
  TransferOptimization opt(func);
  // Run store-to-load forwarding first since it can expose more dead store
  // opportunities.
  func.walk(
      [&](vector::TransferReadOp read) { opt.storeToLoadForwarding(read); });
  opt.removeDeadOp();
  func.walk([&](vector::TransferWriteOp write) { opt.deadStoreOp(write); });
  opt.removeDeadOp();
}

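// A minimal sketch of how this entry point would typically be driven from a
// pass, assuming the PassWrapper/FunctionPass helpers from "mlir/Pass/Pass.h"
// are available; the pass struct below is hypothetical and shown only for
// illustration:
//
//   namespace {
//   struct TestVectorTransferOptPass
//       : public PassWrapper<TestVectorTransferOptPass, FunctionPass> {
//     void runOnFunction() override {
//       vector::transferOpflowOpt(getFunction());
//     }
//   };
//   } // namespace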