//===- MemRefDataFlowOpt.cpp - MemRef DataFlow Optimization pass ------ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to forward memref stores to loads, thereby
// potentially getting rid of intermediate memrefs entirely.
// TODO: In the future, similar techniques could be used to eliminate
// dead memref stores and perform more complex forwarding when support for
// SSA scalars live out of 'affine.for'/'affine.if' statements is available.
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SmallPtrSet.h"
#include <algorithm>

#define DEBUG_TYPE "memref-dataflow-opt"

using namespace mlir;

namespace {
// Store to load forwarding relies on three conditions:
//
// 1) the store and the load need to have mathematically equivalent affine
// access functions (checked after full composition of load/store operands);
// this implies that they access the same single memref element for all
// iterations of the common surrounding loop,
//
// 2) the store op should dominate the load op,
//
// 3) among all ops that satisfy both (1) and (2), the one that postdominates
// all store ops that have a dependence into the load is provably the last
// writer to the particular memref location being loaded at the load op, and
// its store value can be forwarded to the load. Note that the only dependences
// to be considered are those satisfied at the block* of the innermost common
// surrounding loop of the <store, load> pair being considered.
//
// (* A dependence being satisfied at a block: a dependence that is satisfied
// by virtue of the destination operation appearing textually / lexically after
// the source operation within the body of an 'affine.for' operation; thus, a
// dependence is always satisfied either by a loop or by a block.)
//
// The above conditions are simple to check and powerful for most cases in
// practice. They are sufficient but not necessary, since they do not reason
// about loops that are guaranteed to execute at least once, nor about
// forwarding from multiple source stores.
//
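// As a purely illustrative (hypothetical) example, forwarding applies to:
//
//   affine.for %i = 0 to 10 {
//     affine.store %v, %A[%i] : memref<10xf32>
//     %u = affine.load %A[%i] : memref<10xf32>    // %u is replaced by %v
//   }
//
// since the two accesses are equivalent, the store dominates the load, and
// the store postdominates every store with a dependence into the load.
//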
// TODO: more forwarding can be done when support for
// loop/conditional live-out SSA values is available.
// TODO: do general dead store elimination for memrefs. This pass
// currently eliminates a store only if no other loads/uses (other
// than dealloc) remain.
//
struct MemRefDataFlowOpt : public MemRefDataFlowOptBase<MemRefDataFlowOpt> {
  void runOnFunction() override;

  void forwardStoreToLoad(AffineReadOpInterface loadOp);

  // A list of memrefs that are potentially dead / could be eliminated.
  SmallPtrSet<Value, 4> memrefsToErase;
  // Load ops whose results were replaced by those forwarded from stores.
  SmallVector<Operation *, 8> loadOpsToErase;

  DominanceInfo *domInfo = nullptr;
  PostDominanceInfo *postDomInfo = nullptr;
};

} // end anonymous namespace

/// Creates a pass to perform optimizations relying on memref dataflow such as
/// store to load forwarding, elimination of dead stores, and dead allocs.
std::unique_ptr<OperationPass<FuncOp>> mlir::createMemRefDataFlowOptPass() {
  return std::make_unique<MemRefDataFlowOpt>();
}
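
// Illustration only (not from the original file): with a module-level
// PassManager 'pm', this pass would typically be scheduled on functions via
//   pm.addNestedPass<FuncOp>(mlir::createMemRefDataFlowOptPass());
// or selected with the -memref-dataflow-opt flag when running mlir-opt.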

// This is a straightforward implementation not optimized for speed. Optimize
// if needed.
void MemRefDataFlowOpt::forwardStoreToLoad(AffineReadOpInterface loadOp) {
  // First pass over the use list to get the minimum number of surrounding
  // loops common between the load op and the store op, with the min taken
  // across all store ops.
  SmallVector<Operation *, 8> storeOps;
  unsigned minSurroundingLoops = getNestingDepth(loadOp);
  for (auto *user : loadOp.getMemRef().getUsers()) {
    auto storeOp = dyn_cast<AffineWriteOpInterface>(user);
    if (!storeOp)
      continue;
    unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp);
    minSurroundingLoops = std::min(nsLoops, minSurroundingLoops);
    storeOps.push_back(storeOp);
  }

  // The list of store op candidates for forwarding that satisfy conditions
  // (1) and (2) above - they will be filtered later when checking (3).
  SmallVector<Operation *, 8> fwdingCandidates;

  // Store ops that have a dependence into the load (even if they aren't
  // forwarding candidates). Each forwarding candidate will be checked for
  // post-dominance on these. 'fwdingCandidates' is a subset of 'depSrcStores'.
  SmallVector<Operation *, 8> depSrcStores;

  for (auto *storeOp : storeOps) {
    MemRefAccess srcAccess(storeOp);
    MemRefAccess destAccess(loadOp);
    // Find stores that may be reaching the load.
    FlatAffineConstraints dependenceConstraints;
    unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp);
    unsigned d;
    // Dependences at loop depth <= minSurroundingLoops do NOT matter.
    for (d = nsLoops + 1; d > minSurroundingLoops; d--) {
      DependenceResult result = checkMemrefAccessDependence(
          srcAccess, destAccess, d, &dependenceConstraints,
          /*dependenceComponents=*/nullptr);
      if (hasDependence(result))
        break;
    }
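    // If no dependence from this store into the load was found at any depth
    // greater than minSurroundingLoops, this store is not considered to be
    // reaching the load; skip it.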
    if (d == minSurroundingLoops)
      continue;

    // Stores that *may* be reaching the load.
    depSrcStores.push_back(storeOp);

    // 1. Check if the store and the load have mathematically equivalent
    // affine access functions; this implies that they statically refer to the
    // same single memref element. As an example this filters out cases like:
    //     store %A[%i0 + 1]
    //     load %A[%i0]
    //     store %A[%M]
    //     load %A[%N]
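    // In contrast, a pair with identical access functions, e.g.,
    //     store %A[%i0]
    //     load %A[%i0]
    // passes this check.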
    // Use the AffineValueMap difference based memref access equality checking.
    if (srcAccess != destAccess)
      continue;

    // 2. The store has to dominate the load op to be a candidate.
    if (!domInfo->dominates(storeOp, loadOp))
      continue;

    // We now have a candidate for forwarding.
    fwdingCandidates.push_back(storeOp);
  }

  // 3. Of all the store ops that meet the above criteria, the store that
  // postdominates all 'depSrcStores' (if one exists) is the unique store
  // providing the value to the load, i.e., provably the last writer to that
  // memref location.
  // Note: this can be implemented in a cleaner way with postdominator tree
  // traversals. Consider this for the future if needed.
  Operation *lastWriteStoreOp = nullptr;
  for (auto *storeOp : fwdingCandidates) {
    if (llvm::all_of(depSrcStores, [&](Operation *depStore) {
          return postDomInfo->postDominates(storeOp, depStore);
        })) {
      lastWriteStoreOp = storeOp;
      break;
    }
  }
  if (!lastWriteStoreOp)
    return;

  // Perform the actual store to load forwarding.
  Value storeVal =
      cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore();
  loadOp.getValue().replaceAllUsesWith(storeVal);
  // Record the memref for a later sweep to optimize away.
  memrefsToErase.insert(loadOp.getMemRef());
  // Record this to erase later.
  loadOpsToErase.push_back(loadOp);
}

void MemRefDataFlowOpt::runOnFunction() {
  // Only supports single block functions at the moment.
  FuncOp f = getFunction();
  if (!llvm::hasSingleElement(f)) {
    markAllAnalysesPreserved();
    return;
  }

  domInfo = &getAnalysis<DominanceInfo>();
  postDomInfo = &getAnalysis<PostDominanceInfo>();

  loadOpsToErase.clear();
  memrefsToErase.clear();

  // Walk all loads and perform store to load forwarding.
  f.walk([&](AffineReadOpInterface loadOp) { forwardStoreToLoad(loadOp); });

  // Erase all load ops whose results were replaced with store fwd'ed ones.
  for (auto *loadOp : loadOpsToErase)
    loadOp->erase();

  // Check if the store fwd'ed memrefs are now left with only stores and can
  // thus be completely deleted. Note: the canonicalize pass should be able
  // to do this as well, but we'll do it here since we collected these anyway.
  for (auto memref : memrefsToErase) {
    // If the memref hasn't been alloc'ed in this function, skip.
    Operation *defOp = memref.getDefiningOp();
    if (!defOp || !isa<AllocOp>(defOp))
      // TODO: if the memref was returned by a 'call' operation, we
      // could still erase it if the call had no side-effects.
      continue;
    if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) {
          return !isa<AffineWriteOpInterface, DeallocOp>(ownerOp);
        }))
      continue;

    // Erase all stores, the dealloc, and the alloc on the memref.
    for (auto *user : llvm::make_early_inc_range(memref.getUsers()))
      user->erase();
    defOp->erase();
  }
}