//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/Dialect/Vector/VectorOps.h"
15 #include "mlir/Dialect/Vector/VectorTransforms.h"
16 #include "mlir/Dialect/Vector/VectorUtils.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Dominance.h"
19 #include "llvm/ADT/StringRef.h"
20 #include "llvm/Support/Debug.h"
21 
22 #define DEBUG_TYPE "vector-transfer-opt"
23 
24 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
25 
26 using namespace mlir;
27 
/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
findAncestorOpInRegion(Region * region,Operation * op)30 static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
31   for (; op != nullptr && op->getParentRegion() != region;
32        op = op->getParentOp())
33     ;
34   return op;
35 }
36 
37 namespace {
38 
39 class TransferOptimization {
40 public:
TransferOptimization(FuncOp func)41   TransferOptimization(FuncOp func) : dominators(func), postDominators(func) {}
42   void deadStoreOp(vector::TransferWriteOp);
43   void storeToLoadForwarding(vector::TransferReadOp);
removeDeadOp()44   void removeDeadOp() {
45     for (Operation *op : opToErase)
46       op->erase();
47     opToErase.clear();
48   }
49 
50 private:
51   bool isReachable(Operation *start, Operation *dest);
52   DominanceInfo dominators;
53   PostDominanceInfo postDominators;
54   std::vector<Operation *> opToErase;
55 };
56 
/// Return true if there is a path from start operation to dest operation,
/// otherwise return false. The operations have to be in the same region.
isReachable(Operation * start,Operation * dest)59 bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
60   assert(start->getParentRegion() == dest->getParentRegion() &&
61          "This function only works for ops i the same region");
62   // Simple case where the start op dominate the destination.
63   if (dominators.dominates(start, dest))
64     return true;
65   Block *startBlock = start->getBlock();
66   Block *destBlock = dest->getBlock();
67   SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
68                                     startBlock->succ_end());
69   SmallPtrSet<Block *, 32> visited;
70   while (!worklist.empty()) {
71     Block *bb = worklist.pop_back_val();
72     if (!visited.insert(bb).second)
73       continue;
74     if (dominators.dominates(bb, destBlock))
75       return true;
76     worklist.append(bb->succ_begin(), bb->succ_end());
77   }
78   return false;
79 }
80 
/// For a transfer_write to fully overwrite another transfer_write it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they are all post-dominating the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we found such an overwriting transfer_write we know that the original
/// transfer_write is dead if all reads that can be reached from the potentially
/// dead transfer_write are dominated by the overwriting transfer_write.
deadStoreOp(vector::TransferWriteOp write)92 void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
93   LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
94                     << "\n");
95   llvm::SmallVector<Operation *, 8> reads;
96   Operation *firstOverwriteCandidate = nullptr;
97   for (auto *user : write.memref().getUsers()) {
98     if (user == write.getOperation())
99       continue;
100     if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
101       // Check candidate that can override the store.
102       if (write.indices() == nextWrite.indices() &&
103           write.getVectorType() == nextWrite.getVectorType() &&
104           write.permutation_map() == write.permutation_map() &&
105           postDominators.postDominates(nextWrite, write)) {
106         if (firstOverwriteCandidate == nullptr ||
107             postDominators.postDominates(firstOverwriteCandidate, nextWrite))
108           firstOverwriteCandidate = nextWrite;
109         else
110           assert(
111               postDominators.postDominates(nextWrite, firstOverwriteCandidate));
112       }
113     } else {
114       if (auto read = dyn_cast<vector::TransferReadOp>(user)) {
115         // Don't need to consider disjoint reads.
116         if (isDisjointTransferSet(
117                 cast<VectorTransferOpInterface>(write.getOperation()),
118                 cast<VectorTransferOpInterface>(read.getOperation())))
119           continue;
120       }
121       reads.push_back(user);
122     }
123   }
124   if (firstOverwriteCandidate == nullptr)
125     return;
126   Region *topRegion = firstOverwriteCandidate->getParentRegion();
127   Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
128   assert(writeAncestor &&
129          "write op should be recursively part of the top region");
130 
131   for (Operation *read : reads) {
132     Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
133     // TODO: if the read and write have the same ancestor we could recurse in
134     // the region to know if the read is reachable with more precision.
135     if (readAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
136       continue;
137     if (!dominators.dominates(firstOverwriteCandidate, read)) {
138       LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " << *read
139                         << "\n");
140       return;
141     }
142   }
143   LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
144                     << " overwritten by: " << *firstOverwriteCandidate << "\n");
145   opToErase.push_back(write.getOperation());
146 }
147 
/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
/// transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they are all dominating the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates as this will be the
/// last transfer_write executed before the transfer_read.
/// If we found such a candidate we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
storeToLoadForwarding(vector::TransferReadOp read)159 void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
160   if (read.hasMaskedDim())
161     return;
162   LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
163                     << "\n");
164   SmallVector<Operation *, 8> blockingWrites;
165   vector::TransferWriteOp lastwrite = nullptr;
166   for (Operation *user : read.memref().getUsers()) {
167     if (isa<vector::TransferReadOp>(user))
168       continue;
169     if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
170       // If there is a write, but we can prove that it is disjoint we can ignore
171       // the write.
172       if (isDisjointTransferSet(
173               cast<VectorTransferOpInterface>(write.getOperation()),
174               cast<VectorTransferOpInterface>(read.getOperation())))
175         continue;
176       if (dominators.dominates(write, read) && !write.hasMaskedDim() &&
177           write.indices() == read.indices() &&
178           write.getVectorType() == read.getVectorType() &&
179           write.permutation_map() == read.permutation_map()) {
180         if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
181           lastwrite = write;
182         else
183           assert(dominators.dominates(write, lastwrite));
184         continue;
185       }
186     }
187     blockingWrites.push_back(user);
188   }
189 
190   if (lastwrite == nullptr)
191     return;
192 
193   Region *topRegion = lastwrite->getParentRegion();
194   Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
195   assert(readAncestor &&
196          "read op should be recursively part of the top region");
197 
198   for (Operation *write : blockingWrites) {
199     Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
200     // TODO: if the store and read have the same ancestor we could recurse in
201     // the region to know if the read is reachable with more precision.
202     if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
203       continue;
204     if (!postDominators.postDominates(lastwrite, write)) {
205       LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
206                         << *write << "\n");
207       return;
208     }
209   }
210 
211   LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
212                     << " to: " << *read.getOperation() << "\n");
213   read.replaceAllUsesWith(lastwrite.vector());
214   opToErase.push_back(read.getOperation());
215 }
216 
217 } // namespace
218 
transferOpflowOpt(FuncOp func)219 void mlir::vector::transferOpflowOpt(FuncOp func) {
220   TransferOptimization opt(func);
221   // Run store to load forwarding first since it can expose more dead store
222   // opportunity.
223   func.walk(
224       [&](vector::TransferReadOp read) { opt.storeToLoadForwarding(read); });
225   opt.removeDeadOp();
226   func.walk([&](vector::TransferWriteOp write) { opt.deadStoreOp(write); });
227   opt.removeDeadOp();
228 }
229