1 //===- DependenceAnalysis.cpp - Dependence analysis on SSA views ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements view-based alias and dependence analyses.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/Debug.h"
19
20 #define DEBUG_TYPE "linalg-dependence-analysis"
21
22 using namespace mlir;
23 using namespace mlir::linalg;
24
25 using llvm::dbgs;
26
find(Value v)27 Value Aliases::find(Value v) {
28 if (v.isa<BlockArgument>())
29 return v;
30
31 auto it = aliases.find(v);
32 if (it != aliases.end()) {
33 assert(it->getSecond().getType().isa<BaseMemRefType>() &&
34 "Memref expected");
35 return it->getSecond();
36 }
37
38 while (true) {
39 if (v.isa<BlockArgument>())
40 return v;
41
42 Operation *defOp = v.getDefiningOp();
43 if (!defOp)
44 return v;
45
46 if (isa<TensorToMemrefOp>(defOp))
47 return v;
48
49 if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(defOp)) {
50 // Collect all memory effects on `v`.
51 SmallVector<MemoryEffects::EffectInstance, 1> effects;
52 memEffect.getEffectsOnValue(v, effects);
53
54 // If we have the 'Allocate' memory effect on `v`, then `v` should be the
55 // original buffer.
56 if (llvm::any_of(
57 effects, [](const MemoryEffects::EffectInstance &instance) {
58 return isa<MemoryEffects::Allocate>(instance.getEffect());
59 }))
60 return v;
61 }
62
63 if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(defOp)) {
64 auto it =
65 aliases.insert(std::make_pair(v, find(viewLikeOp.getViewSource())));
66 return it.first->second;
67 }
68
69 llvm::errs() << "View alias analysis reduces to: " << v << "\n";
70 llvm_unreachable("unsupported view alias case");
71 }
72 }
73
getDependenceTypeStr(DependenceType depType)74 StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) {
75 switch (depType) {
76 case LinalgDependenceGraph::DependenceType::RAW:
77 return "RAW";
78 case LinalgDependenceGraph::DependenceType::RAR:
79 return "RAR";
80 case LinalgDependenceGraph::DependenceType::WAR:
81 return "WAR";
82 case LinalgDependenceGraph::DependenceType::WAW:
83 return "WAW";
84 default:
85 break;
86 }
87 llvm_unreachable("Unexpected DependenceType");
88 }
89
90 LinalgDependenceGraph
buildDependenceGraph(Aliases & aliases,FuncOp f)91 LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) {
92 SmallVector<LinalgOp, 8> linalgOps;
93 f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
94 return LinalgDependenceGraph(aliases, linalgOps);
95 }
96
LinalgDependenceGraph(Aliases & aliases,ArrayRef<LinalgOp> ops)97 LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
98 ArrayRef<LinalgOp> ops)
99 : aliases(aliases), linalgOps(ops.begin(), ops.end()) {
100 for (auto en : llvm::enumerate(linalgOps)) {
101 linalgOpPositions.insert(
102 std::make_pair(en.value().getOperation(), en.index()));
103 }
104 for (unsigned i = 0, e = ops.size(); i < e; ++i) {
105 for (unsigned j = i + 1; j < e; ++j) {
106 addDependencesBetween(ops[i], ops[j]);
107 }
108 }
109 }
110
addDependenceElem(DependenceType dt,LinalgOpView indexingOpView,LinalgOpView dependentOpView)111 void LinalgDependenceGraph::addDependenceElem(DependenceType dt,
112 LinalgOpView indexingOpView,
113 LinalgOpView dependentOpView) {
114 LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t ("
115 << *indexingOpView.op << ", " << indexingOpView.operandIndex
116 << ") -> \n\t\t(" << *dependentOpView.op << ", "
117 << dependentOpView.operandIndex << ")");
118 dependencesFromGraphs[dt][indexingOpView.op].push_back(
119 LinalgDependenceGraphElem{dependentOpView, indexingOpView, dt});
120 dependencesIntoGraphs[dt][dependentOpView.op].push_back(
121 LinalgDependenceGraphElem{indexingOpView, dependentOpView, dt});
122 }
123
124 LinalgDependenceGraph::dependence_range
getDependencesFrom(LinalgOp src,LinalgDependenceGraph::DependenceType dt) const125 LinalgDependenceGraph::getDependencesFrom(
126 LinalgOp src, LinalgDependenceGraph::DependenceType dt) const {
127 return getDependencesFrom(src.getOperation(), dt);
128 }
129
130 LinalgDependenceGraph::dependence_range
getDependencesFrom(Operation * src,LinalgDependenceGraph::DependenceType dt) const131 LinalgDependenceGraph::getDependencesFrom(
132 Operation *src, LinalgDependenceGraph::DependenceType dt) const {
133 auto iter = dependencesFromGraphs[dt].find(src);
134 if (iter == dependencesFromGraphs[dt].end())
135 return llvm::make_range(nullptr, nullptr);
136 return llvm::make_range(iter->second.begin(), iter->second.end());
137 }
138
139 LinalgDependenceGraph::dependence_range
getDependencesInto(LinalgOp dst,LinalgDependenceGraph::DependenceType dt) const140 LinalgDependenceGraph::getDependencesInto(
141 LinalgOp dst, LinalgDependenceGraph::DependenceType dt) const {
142 return getDependencesInto(dst.getOperation(), dt);
143 }
144
145 LinalgDependenceGraph::dependence_range
getDependencesInto(Operation * dst,LinalgDependenceGraph::DependenceType dt) const146 LinalgDependenceGraph::getDependencesInto(
147 Operation *dst, LinalgDependenceGraph::DependenceType dt) const {
148 auto iter = dependencesIntoGraphs[dt].find(dst);
149 if (iter == dependencesIntoGraphs[dt].end())
150 return llvm::make_range(nullptr, nullptr);
151 return llvm::make_range(iter->second.begin(), iter->second.end());
152 }
153
addDependencesBetween(LinalgOp src,LinalgOp dst)154 void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
155 for (auto srcView : llvm::enumerate(src.getOutputBuffers())) { // W
156 unsigned srcIndex =
157 src.getOperandIndexForOutputIndex(srcView.index()).getValue();
158 // RAW graph
159 for (auto dstView : llvm::enumerate(dst.getInputBuffers())) { // R
160 if (aliases.alias(srcView.value(),
161 dstView.value())) { // if alias, fill RAW
162 unsigned dstIndex =
163 dst.getOperandIndexForInputIndex(dstView.index()).getValue();
164 addDependenceElem(DependenceType::RAW,
165 LinalgOpView{src.getOperation(), srcIndex},
166 LinalgOpView{dst.getOperation(), dstIndex});
167 }
168 }
169 // WAW graph
170 for (auto dstView : llvm::enumerate(dst.getOutputBuffers())) { // W
171 if (aliases.alias(srcView.value(),
172 dstView.value())) { // if alias, fill WAW
173 unsigned dstIndex =
174 dst.getOperandIndexForOutputIndex(dstView.index()).getValue();
175 addDependenceElem(DependenceType::WAW,
176 LinalgOpView{src.getOperation(), srcIndex},
177 LinalgOpView{dst.getOperation(), dstIndex});
178 }
179 }
180 }
181 for (auto srcView : llvm::enumerate(src.getInputBuffers())) { // R
182 unsigned srcIndex =
183 src.getOperandIndexForInputIndex(srcView.index()).getValue();
184 // RAR graph
185 for (auto dstView : llvm::enumerate(dst.getInputBuffers())) { // R
186 if (aliases.alias(srcView.value(),
187 dstView.value())) { // if alias, fill RAR
188 unsigned dstIndex =
189 dst.getOperandIndexForInputIndex(dstView.index()).getValue();
190 addDependenceElem(DependenceType::RAR,
191 LinalgOpView{src.getOperation(), srcIndex},
192 LinalgOpView{dst.getOperation(), dstIndex});
193 }
194 }
195 // WAR graph
196 for (auto dstView : llvm::enumerate(dst.getOutputBuffers())) { // W
197 if (aliases.alias(srcView.value(),
198 dstView.value())) { // if alias, fill WAR
199 unsigned dstIndex =
200 dst.getOperandIndexForOutputIndex(dstView.index()).getValue();
201 addDependenceElem(DependenceType::WAR,
202 LinalgOpView{src.getOperation(), srcIndex},
203 LinalgOpView{dst.getOperation(), dstIndex});
204 }
205 }
206 }
207 }
208
209 SmallVector<Operation *, 8>
findCoveringDependences(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp) const210 LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp,
211 LinalgOp dstLinalgOp) const {
212 return findOperationsWithCoveringDependences(
213 srcLinalgOp, dstLinalgOp, nullptr,
214 {DependenceType::WAW, DependenceType::WAR, DependenceType::RAW});
215 }
216
findCoveringWrites(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,Value view) const217 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringWrites(
218 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
219 return findOperationsWithCoveringDependences(
220 srcLinalgOp, dstLinalgOp, view,
221 {DependenceType::WAW, DependenceType::WAR});
222 }
223
findCoveringReads(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,Value view) const224 SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringReads(
225 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
226 return findOperationsWithCoveringDependences(
227 srcLinalgOp, dstLinalgOp, view,
228 {DependenceType::RAR, DependenceType::RAW});
229 }
230
231 SmallVector<Operation *, 8>
findOperationsWithCoveringDependences(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,Value view,ArrayRef<DependenceType> types) const232 LinalgDependenceGraph::findOperationsWithCoveringDependences(
233 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view,
234 ArrayRef<DependenceType> types) const {
235 auto *src = srcLinalgOp.getOperation();
236 auto *dst = dstLinalgOp.getOperation();
237 auto srcPos = linalgOpPositions.lookup(src);
238 auto dstPos = linalgOpPositions.lookup(dst);
239 assert(srcPos < dstPos && "expected dst after src in IR traversal order");
240
241 SmallVector<Operation *, 8> res;
242 // Consider an intermediate interleaved `interim` op, look for any dependence
243 // to an aliasing view on a src -> op -> dst path.
244 // TODO: we are not considering paths yet, just interleaved positions.
245 for (auto dt : types) {
246 for (auto dependence : getDependencesFrom(src, dt)) {
247 auto interimPos = linalgOpPositions.lookup(dependence.dependentOpView.op);
248 // Skip if not interleaved.
249 if (interimPos >= dstPos || interimPos <= srcPos)
250 continue;
251 linalg::LinalgOp consumer =
252 cast<linalg::LinalgOp>(dependence.indexingOpView.op);
253 Value consumerView =
254 consumer.getShapedOperand(dependence.indexingOpView.operandIndex);
255 if (view && !aliases.alias(view, consumerView))
256 continue;
257 auto *op = dependence.dependentOpView.op;
258 LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
259 << getDependenceTypeStr(dt) << ": " << *src << " -> "
260 << *op << " on " << consumerView);
261 res.push_back(op);
262 }
263 }
264 return res;
265 }
266
hasDependenceFrom(LinalgOp srcLinalgOp,LinalgOp dstLinalgOp,ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const267 bool LinalgDependenceGraph::hasDependenceFrom(
268 LinalgOp srcLinalgOp, LinalgOp dstLinalgOp,
269 ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
270 for (auto dep : depTypes) {
271 for (auto dependence : getDependencesInto(dstLinalgOp, dep)) {
272 if (dependence.dependentOpView.op == srcLinalgOp)
273 return true;
274 }
275 }
276 return false;
277 }
278
hasDependentOperationsFrom(LinalgOp linalgOp,ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const279 bool LinalgDependenceGraph::hasDependentOperationsFrom(
280 LinalgOp linalgOp,
281 ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
282 for (auto dep : depTypes) {
283 if (!getDependencesFrom(linalgOp, dep).empty())
284 return true;
285 }
286 return false;
287 }
288
hasDependentOperationsInto(LinalgOp linalgOp,ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const289 bool LinalgDependenceGraph::hasDependentOperationsInto(
290 LinalgOp linalgOp,
291 ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
292 for (auto dep : depTypes) {
293 if (!getDependencesInto(linalgOp, dep).empty())
294 return true;
295 }
296 return false;
297 }
298
hasDependentOperations(LinalgOp linalgOp,ArrayRef<DependenceType> depTypes) const299 bool LinalgDependenceGraph::hasDependentOperations(
300 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
301 return hasDependentOperationsInto(linalgOp, depTypes) ||
302 hasDependentOperationsFrom(linalgOp, depTypes);
303 }
304
305 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
getDependentOperationsInto(LinalgOp linalgOp,ArrayRef<DependenceType> depTypes) const306 LinalgDependenceGraph::getDependentOperationsInto(
307 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
308 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
309 dependentOperations;
310 for (auto dependenceType : depTypes) {
311 auto dependencies = getDependencesInto(linalgOp, dependenceType);
312 dependentOperations.append(dependencies.begin(), dependencies.end());
313 }
314 return dependentOperations;
315 }
316
317 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
getDependentOperationsFrom(LinalgOp linalgOp,ArrayRef<DependenceType> depTypes) const318 LinalgDependenceGraph::getDependentOperationsFrom(
319 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
320 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
321 dependentOperations;
322 for (auto dependenceType : depTypes) {
323 auto dependencies = getDependencesFrom(linalgOp, dependenceType);
324 dependentOperations.append(dependencies.begin(), dependencies.end());
325 }
326 return dependentOperations;
327 }
328
329 /// Returns all dependent operations (into and from) given `operation`.
330 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
getDependentOperations(LinalgOp linalgOp,ArrayRef<DependenceType> depTypes) const331 LinalgDependenceGraph::getDependentOperations(
332 LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
333 SmallVector<LinalgDependenceGraphElem, 2> dependentOperations =
334 getDependentOperationsInto(linalgOp, depTypes);
335 SmallVector<LinalgDependenceGraphElem, 2> t =
336 getDependentOperationsFrom(linalgOp, depTypes);
337 dependentOperations.append(t.begin(), t.end());
338 return dependentOperations;
339 }
340