//===- AffineAnalysis.cpp - Affine structures analysis routines -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous analysis routines for affine structures
// (expressions, maps, sets), and other utilities relying on such analysis.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "affine-analysis"

using namespace mlir;

using llvm::dbgs;

/// Returns the sequence of AffineApplyOp operations in 'affineApplyOps', which
/// are reachable via a search starting from 'operands', and ending at operands
/// which are not defined by AffineApplyOps.
// TODO: Add a method to AffineApplyOp which forward substitutes the
// AffineApplyOp into any user AffineApplyOps.
void mlir::getReachableAffineApplyOps(
    ArrayRef<Value> operands, SmallVectorImpl<Operation *> &affineApplyOps) {
  struct State {
    // The ssa value for this node in the DFS traversal.
    Value value;
    // The operand index of 'value' to explore next during DFS traversal.
    unsigned operandIndex;
  };
  SmallVector<State, 4> worklist;
  for (auto operand : operands) {
    worklist.push_back({operand, 0});
  }

  while (!worklist.empty()) {
    State &state = worklist.back();
    auto *opInst = state.value.getDefiningOp();
    // Note: getDefiningOp will return nullptr if the operand is not an
    // Operation (i.e. block argument), which is a terminator for the search.
    if (!isa_and_nonnull<AffineApplyOp>(opInst)) {
      worklist.pop_back();
      continue;
    }

    if (state.operandIndex == 0) {
      // Pre-Visit: Add 'opInst' to reachable sequence.
      affineApplyOps.push_back(opInst);
    }
    if (state.operandIndex < opInst->getNumOperands()) {
      // Visit: Add next 'affineApplyOp' operand to worklist.
      // Get next operand to visit at 'operandIndex'.
      auto nextOperand = opInst->getOperand(state.operandIndex);
      // Increment 'operandIndex' in 'state'.
      ++state.operandIndex;
      // Add 'nextOperand' to worklist.
      worklist.push_back({nextOperand, 0});
    } else {
      // Post-visit: done visiting all operands of the AffineApplyOp, pop it
      // off the stack.
      worklist.pop_back();
    }
  }
}

// Builds a system of constraints with dimensional identifiers corresponding to
// the loop IVs of the forOps appearing in that order. Any symbols found in
// the bound operands are added as symbols in the system. Returns failure for
// the yet unimplemented cases.
// TODO: Handle non-unit steps through local variables or stride information in
// FlatAffineConstraints. (E.g., by using iv - lb % step = 0 and/or by
// introducing a method in FlatAffineConstraints setExprStride(ArrayRef<int64_t>
// expr, int64_t stride)
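// For example, for the nest
//
//   affine.for %i = 0 to %N {
//     affine.for %j = 0 to 10 { ... }
//   }
//
// the system built over dims (%i, %j) and symbol %N is equivalent to
// {(i, j)[N] : i >= 0, N - i - 1 >= 0, j >= 0, 9 - j >= 0}.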
LogicalResult mlir::getIndexSet(MutableArrayRef<Operation *> ops,
                                FlatAffineConstraints *domain) {
  SmallVector<Value, 4> indices;
  SmallVector<AffineForOp, 8> forOps;

  for (Operation *op : ops) {
    assert((isa<AffineForOp, AffineIfOp>(op)) &&
           "ops should be either AffineForOp or AffineIfOp");
    if (AffineForOp forOp = dyn_cast<AffineForOp>(op))
      forOps.push_back(forOp);
  }
  extractForInductionVars(forOps, &indices);
  // Reset the domain while associating the Values in 'indices' with its
  // dimensions.
  domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
  for (Operation *op : ops) {
    // Add constraints from forOp's bounds.
    if (AffineForOp forOp = dyn_cast<AffineForOp>(op)) {
      if (failed(domain->addAffineForOpDomain(forOp)))
        return failure();
    } else if (AffineIfOp ifOp = dyn_cast<AffineIfOp>(op)) {
      domain->addAffineIfOpDomain(ifOp);
    }
  }
  return success();
}

/// Computes the iteration domain for 'op' and populates 'indexSet', which
/// encapsulates the constraints involving loops surrounding 'op' and
/// potentially involving any Function symbols. The dimensional identifiers in
/// 'indexSet' correspond to the loops surrounding 'op' from outermost to
/// innermost.
static LogicalResult getOpIndexSet(Operation *op,
                                   FlatAffineConstraints *indexSet) {
  SmallVector<Operation *, 4> ops;
  getEnclosingAffineForAndIfOps(*op, &ops);
  return getIndexSet(ops, indexSet);
}

namespace {
// ValuePositionMap manages the mapping from Values which represent dimension
// and symbol identifiers from 'src' and 'dst' access functions to positions
// in a new space where some Values are kept separate (using addSrc/DstValue)
// and some Values are merged (addSymbolValue).
// Position lookups return the absolute position in the new space which
// has the following format:
//
//   [src-dim-identifiers] [dst-dim-identifiers] [symbol-identifiers]
//
// Note: access function non-IV dimension identifiers (that have 'dimension'
// positions in the access function position space) are assigned as symbols
// in the output position space. Convenience access functions which look up
// a Value in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle
// the common case of resolving positions for all access function operands.
//
// TODO: Generalize this: could take a template parameter for the number of maps
// (3 in the current case), and lookups could take indices of maps to check. So
// getSrcDimOrSymPos would be "getPos(value, {0, 2})".
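// For example, with source IVs (%i0, %i1), destination IVs (%i2, %i3), and a
// shared symbol %M, the assigned positions are:
//   %i0 -> 0, %i1 -> 1, %i2 -> 2, %i3 -> 3, %M -> 4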
class ValuePositionMap {
public:
  void addSrcValue(Value value) {
    if (addValueAt(value, &srcDimPosMap, numSrcDims))
      ++numSrcDims;
  }
  void addDstValue(Value value) {
    if (addValueAt(value, &dstDimPosMap, numDstDims))
      ++numDstDims;
  }
  void addSymbolValue(Value value) {
    if (addValueAt(value, &symbolPosMap, numSymbols))
      ++numSymbols;
  }
  unsigned getSrcDimOrSymPos(Value value) const {
    return getDimOrSymPos(value, srcDimPosMap, 0);
  }
  unsigned getDstDimOrSymPos(Value value) const {
    return getDimOrSymPos(value, dstDimPosMap, numSrcDims);
  }
  unsigned getSymPos(Value value) const {
    auto it = symbolPosMap.find(value);
    assert(it != symbolPosMap.end());
    return numSrcDims + numDstDims + it->second;
  }

  unsigned getNumSrcDims() const { return numSrcDims; }
  unsigned getNumDstDims() const { return numDstDims; }
  unsigned getNumDims() const { return numSrcDims + numDstDims; }
  unsigned getNumSymbols() const { return numSymbols; }

private:
  bool addValueAt(Value value, DenseMap<Value, unsigned> *posMap,
                  unsigned position) {
    auto it = posMap->find(value);
    if (it == posMap->end()) {
      (*posMap)[value] = position;
      return true;
    }
    return false;
  }
  unsigned getDimOrSymPos(Value value,
                          const DenseMap<Value, unsigned> &dimPosMap,
                          unsigned dimPosOffset) const {
    auto it = dimPosMap.find(value);
    if (it != dimPosMap.end()) {
      return dimPosOffset + it->second;
    }
    it = symbolPosMap.find(value);
    assert(it != symbolPosMap.end());
    return numSrcDims + numDstDims + it->second;
  }

  unsigned numSrcDims = 0;
  unsigned numDstDims = 0;
  unsigned numSymbols = 0;
  DenseMap<Value, unsigned> srcDimPosMap;
  DenseMap<Value, unsigned> dstDimPosMap;
  DenseMap<Value, unsigned> symbolPosMap;
};
} // namespace

// Builds a map from Value to identifier position in a new merged identifier
// list, which is the result of merging dim/symbol lists from src/dst
// iteration domains, the format of which is as follows:
//
//   [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers, const_term]
//
// This method populates 'valuePosMap' with mappings from operand Values in
// 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain')
// to the position of these values in the merged list.
static void buildDimAndSymbolPositionMaps(
    const FlatAffineConstraints &srcDomain,
    const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
    const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap,
    FlatAffineConstraints *dependenceConstraints) {

  // IsDimState is a tri-state enum. It is used to distinguish three
  // different cases of the values passed to updateValuePosMap.
  // - When it is TRUE, we are certain that all values are dim values.
  // - When it is FALSE, we are certain that all values are symbol values.
  // - When it is UNKNOWN, we need to further check whether the value is from a
  //   loop IV to determine its type (dim or symbol).

  // We need this enumeration because sometimes we cannot determine whether a
  // Value is a symbol or a dim by the information from the Value itself. If a
  // Value appears in an affine map of a loop, we can determine whether it is a
  // dim or not by the function `isForInductionVar`. But when a Value is in the
  // affine set of an if-statement, there is no way to identify its category
  // (dim/symbol) by itself. Fortunately, the Values to be inserted into
  // `valuePosMap` come from `srcDomain` and `dstDomain`, and they hold such
  // information of Value category: `srcDomain` and `dstDomain` organize Values
  // by their category, such that the position of each Value stored in
  // `srcDomain` and `dstDomain` marks which category that Value belongs to.
  // Therefore, we can separate Values into dim and symbol groups before passing
  // them to the function `updateValuePosMap`. Specifically, when passing the
  // dim group, we set IsDimState to TRUE; otherwise, we set it to FALSE.
  // However, Values from the operands of `srcAccessMap` and `dstAccessMap` are
  // not explicitly categorized into dim or symbol, and we have to rely on
  // `isForInductionVar` to make the decision. IsDimState is set to UNKNOWN in
  // this case.
  enum IsDimState { TRUE, FALSE, UNKNOWN };

  // This function places each given Value (in `values`) under a respective
  // category in `valuePosMap`. Specifically, the placement rules are:
  // 1) If `isDim` is FALSE, then every value in `values` is inserted into
  //    `valuePosMap` as a symbol.
  // 2) If `isDim` is UNKNOWN and the value of the current iteration is NOT an
  //    induction variable of a for-loop, we treat it as a symbol as well.
  // 3) For other cases, we decide whether to add a value to the `src` or the
  //    `dst` section of the dim category simply by the boolean value `isSrc`.
  auto updateValuePosMap = [&](ArrayRef<Value> values, bool isSrc,
                               IsDimState isDim) {
    for (unsigned i = 0, e = values.size(); i < e; ++i) {
      auto value = values[i];
      if (isDim == FALSE || (isDim == UNKNOWN && !isForInductionVar(value))) {
        assert(isValidSymbol(value) &&
               "access operand has to be either a loop IV or a symbol");
        valuePosMap->addSymbolValue(value);
      } else {
        if (isSrc)
          valuePosMap->addSrcValue(value);
        else
          valuePosMap->addDstValue(value);
      }
    }
  };

  // Collect values from the src and dst domains. For each domain, we separate
  // the collected values into dim and symbol parts.
  SmallVector<Value, 4> srcDimValues, dstDimValues, srcSymbolValues,
      dstSymbolValues;
  srcDomain.getIdValues(0, srcDomain.getNumDimIds(), &srcDimValues);
  dstDomain.getIdValues(0, dstDomain.getNumDimIds(), &dstDimValues);
  srcDomain.getIdValues(srcDomain.getNumDimIds(),
                        srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
  dstDomain.getIdValues(dstDomain.getNumDimIds(),
                        dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues);

  // Update value position map with dim values from src iteration domain.
  updateValuePosMap(srcDimValues, /*isSrc=*/true, /*isDim=*/TRUE);
  // Update value position map with dim values from dst iteration domain.
  updateValuePosMap(dstDimValues, /*isSrc=*/false, /*isDim=*/TRUE);
  // Update value position map with symbols from src iteration domain.
  updateValuePosMap(srcSymbolValues, /*isSrc=*/true, /*isDim=*/FALSE);
  // Update value position map with symbols from dst iteration domain.
  updateValuePosMap(dstSymbolValues, /*isSrc=*/false, /*isDim=*/FALSE);
  // Update value position map with identifiers from src access function.
  updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true,
                    /*isDim=*/UNKNOWN);
  // Update value position map with identifiers from dst access function.
  updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false,
                    /*isDim=*/UNKNOWN);
}

// Sets up dependence constraints columns appropriately, in the format:
// [src-dim-ids, dst-dim-ids, symbol-ids, local-ids, const_term]
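// For example, with two src dims, two dst dims, one symbol and no locals, the
// system has 2 + 2 + 1 + 0 + 1 = 6 columns (the last one for the constant).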
static void initDependenceConstraints(
    const FlatAffineConstraints &srcDomain,
    const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
    const AffineValueMap &dstAccessMap, const ValuePositionMap &valuePosMap,
    FlatAffineConstraints *dependenceConstraints) {
  // Calculate number of equalities/inequalities and columns required to
  // initialize FlatAffineConstraints for 'dependenceDomain'.
  unsigned numIneq =
      srcDomain.getNumInequalities() + dstDomain.getNumInequalities();
  AffineMap srcMap = srcAccessMap.getAffineMap();
  assert(srcMap.getNumResults() == dstAccessMap.getAffineMap().getNumResults());
  unsigned numEq = srcMap.getNumResults();
  unsigned numDims = srcDomain.getNumDimIds() + dstDomain.getNumDimIds();
  unsigned numSymbols = valuePosMap.getNumSymbols();
  unsigned numLocals = srcDomain.getNumLocalIds() + dstDomain.getNumLocalIds();
  unsigned numIds = numDims + numSymbols + numLocals;
  unsigned numCols = numIds + 1;

  // Set flat affine constraints sizes and reserve space for constraints.
  dependenceConstraints->reset(numIneq, numEq, numCols, numDims, numSymbols,
                               numLocals);

  // Set values corresponding to dependence constraint identifiers.
  SmallVector<Value, 4> srcLoopIVs, dstLoopIVs;
  srcDomain.getIdValues(0, srcDomain.getNumDimIds(), &srcLoopIVs);
  dstDomain.getIdValues(0, dstDomain.getNumDimIds(), &dstLoopIVs);

  dependenceConstraints->setIdValues(0, srcLoopIVs.size(), srcLoopIVs);
  dependenceConstraints->setIdValues(
      srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs);

  // Set values for the symbolic identifier dimensions. `isSymbolDetermined`
  // indicates whether we are certain that the `values` passed in are all
  // symbols. If `isSymbolDetermined` is true, then we treat every Value in
  // `values` as a symbol; otherwise, we let the function `isForInductionVar`
  // distinguish whether a Value in `values` is a symbol or not.
  auto setSymbolIds = [&](ArrayRef<Value> values,
                          bool isSymbolDetermined = true) {
    for (auto value : values) {
      if (isSymbolDetermined || !isForInductionVar(value)) {
        assert(isValidSymbol(value) && "expected symbol");
        dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value);
      }
    }
  };

  // We are uncertain about whether all operands in `srcAccessMap` and
  // `dstAccessMap` are symbols, so we set `isSymbolDetermined` to false.
  setSymbolIds(srcAccessMap.getOperands(), /*isSymbolDetermined=*/false);
  setSymbolIds(dstAccessMap.getOperands(), /*isSymbolDetermined=*/false);

  SmallVector<Value, 8> srcSymbolValues, dstSymbolValues;
  srcDomain.getIdValues(srcDomain.getNumDimIds(),
                        srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
  dstDomain.getIdValues(dstDomain.getNumDimIds(),
                        dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues);
  // Since we only take symbol Values out of `srcDomain` and `dstDomain`,
  // `isSymbolDetermined` is kept at its default value: true.
  setSymbolIds(srcSymbolValues);
  setSymbolIds(dstSymbolValues);

  for (unsigned i = 0, e = dependenceConstraints->getNumDimAndSymbolIds();
       i < e; i++)
    assert(dependenceConstraints->getIds()[i].hasValue());
}

// Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into
// 'dependenceDomain'.
// Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a
// srcDomain/dstDomain Value maps.
static void addDomainConstraints(const FlatAffineConstraints &srcDomain,
                                 const FlatAffineConstraints &dstDomain,
                                 const ValuePositionMap &valuePosMap,
                                 FlatAffineConstraints *dependenceDomain) {
  unsigned depNumDimsAndSymbolIds = dependenceDomain->getNumDimAndSymbolIds();

  SmallVector<int64_t, 4> cst(dependenceDomain->getNumCols());

  auto addDomain = [&](bool isSrc, bool isEq, unsigned localOffset) {
    const FlatAffineConstraints &domain = isSrc ? srcDomain : dstDomain;
    unsigned numCsts =
        isEq ? domain.getNumEqualities() : domain.getNumInequalities();
    unsigned numDimAndSymbolIds = domain.getNumDimAndSymbolIds();
    auto at = [&](unsigned i, unsigned j) -> int64_t {
      return isEq ? domain.atEq(i, j) : domain.atIneq(i, j);
    };
    auto map = [&](unsigned i) -> int64_t {
      return isSrc ? valuePosMap.getSrcDimOrSymPos(domain.getIdValue(i))
                   : valuePosMap.getDstDimOrSymPos(domain.getIdValue(i));
    };

    for (unsigned i = 0; i < numCsts; ++i) {
      // Zero fill.
      std::fill(cst.begin(), cst.end(), 0);
      // Set coefficients for identifiers corresponding to domain.
      for (unsigned j = 0; j < numDimAndSymbolIds; ++j)
        cst[map(j)] = at(i, j);
      // Local terms.
      for (unsigned j = 0, e = domain.getNumLocalIds(); j < e; j++)
        cst[depNumDimsAndSymbolIds + localOffset + j] =
            at(i, numDimAndSymbolIds + j);
      // Set constant term.
      cst[cst.size() - 1] = at(i, domain.getNumCols() - 1);
      // Add constraint.
      if (isEq)
        dependenceDomain->addEquality(cst);
      else
        dependenceDomain->addInequality(cst);
    }
  };

  // Add equalities from src domain.
  addDomain(/*isSrc=*/true, /*isEq=*/true, /*localOffset=*/0);
  // Add inequalities from src domain.
  addDomain(/*isSrc=*/true, /*isEq=*/false, /*localOffset=*/0);
  // Add equalities from dst domain.
  addDomain(/*isSrc=*/false, /*isEq=*/true,
            /*localOffset=*/srcDomain.getNumLocalIds());
  // Add inequalities from dst domain.
  addDomain(/*isSrc=*/false, /*isEq=*/false,
            /*localOffset=*/srcDomain.getNumLocalIds());
}

// Adds equality constraints that equate src and dst access functions
// represented by 'srcAccessMap' and 'dstAccessMap' for each result.
// Requires that 'srcAccessMap' and 'dstAccessMap' have the same result count.
// For example, given the following two access functions to a 2D memref:
//
//   Source access function:
//     (a0 * d0 + a1 * s0 + a2, b0 * d0 + b1 * s0 + b2)
//
//   Destination access function:
//     (c0 * d0 + c1 * s0 + c2, f0 * d0 + f1 * s0 + f2)
//
// This method constructs the following equality constraints in
// 'dependenceDomain', by equating the access functions for each result
// (i.e. each memref dim). Notice that 'd0' for the destination access function
// is mapped into 'd1' in the equality constraint:
//
//   d0      d1       s0          c
//   --      --       --          --
//   a0     -c0      (a1 - c1)   (a2 - c2) = 0
//   b0     -f0      (b1 - f1)   (b2 - f2) = 0
//
// Returns failure if any AffineExpr cannot be flattened (due to it being
// semi-affine). Returns success otherwise.
static LogicalResult
addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
                           const AffineValueMap &dstAccessMap,
                           const ValuePositionMap &valuePosMap,
                           FlatAffineConstraints *dependenceDomain) {
  AffineMap srcMap = srcAccessMap.getAffineMap();
  AffineMap dstMap = dstAccessMap.getAffineMap();
  assert(srcMap.getNumResults() == dstMap.getNumResults());
  unsigned numResults = srcMap.getNumResults();

  unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols();
  ArrayRef<Value> srcOperands = srcAccessMap.getOperands();

  unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols();
  ArrayRef<Value> dstOperands = dstAccessMap.getOperands();

  std::vector<SmallVector<int64_t, 8>> srcFlatExprs;
  std::vector<SmallVector<int64_t, 8>> destFlatExprs;
  FlatAffineConstraints srcLocalVarCst, destLocalVarCst;
  // Get flattened expressions for the source and destination maps.
  if (failed(getFlattenedAffineExprs(srcMap, &srcFlatExprs, &srcLocalVarCst)) ||
      failed(getFlattenedAffineExprs(dstMap, &destFlatExprs, &destLocalVarCst)))
    return failure();

  unsigned domNumLocalIds = dependenceDomain->getNumLocalIds();
  unsigned srcNumLocalIds = srcLocalVarCst.getNumLocalIds();
  unsigned dstNumLocalIds = destLocalVarCst.getNumLocalIds();
  unsigned numLocalIdsToAdd = srcNumLocalIds + dstNumLocalIds;
  for (unsigned i = 0; i < numLocalIdsToAdd; i++) {
    dependenceDomain->addLocalId(dependenceDomain->getNumLocalIds());
  }

  unsigned numDims = dependenceDomain->getNumDimIds();
  unsigned numSymbols = dependenceDomain->getNumSymbolIds();
  unsigned numSrcLocalIds = srcLocalVarCst.getNumLocalIds();
  unsigned newLocalIdOffset = numDims + numSymbols + domNumLocalIds;

  // Equality to add.
  SmallVector<int64_t, 8> eq(dependenceDomain->getNumCols());
  for (unsigned i = 0; i < numResults; ++i) {
    // Zero fill.
    std::fill(eq.begin(), eq.end(), 0);

    // Flattened AffineExpr for src result 'i'.
    const auto &srcFlatExpr = srcFlatExprs[i];
    // Set identifier coefficients from src access function.
    for (unsigned j = 0, e = srcOperands.size(); j < e; ++j)
      eq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = srcFlatExpr[j];
    // Local terms.
    for (unsigned j = 0, e = srcNumLocalIds; j < e; j++)
      eq[newLocalIdOffset + j] = srcFlatExpr[srcNumIds + j];
    // Set constant term.
    eq[eq.size() - 1] = srcFlatExpr[srcFlatExpr.size() - 1];

    // Flattened AffineExpr for dest result 'i'.
    const auto &destFlatExpr = destFlatExprs[i];
    // Set identifier coefficients from dst access function.
    for (unsigned j = 0, e = dstOperands.size(); j < e; ++j)
      eq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] -= destFlatExpr[j];
    // Local terms.
    for (unsigned j = 0, e = dstNumLocalIds; j < e; j++)
      eq[newLocalIdOffset + numSrcLocalIds + j] = -destFlatExpr[dstNumIds + j];
    // Set constant term.
    eq[eq.size() - 1] -= destFlatExpr[destFlatExpr.size() - 1];

    // Add equality constraint.
    dependenceDomain->addEquality(eq);
  }

  // Add equality constraints for any operands that are defined by constant ops.
  auto addEqForConstOperands = [&](ArrayRef<Value> operands) {
    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
      if (isForInductionVar(operands[i]))
        continue;
      auto symbol = operands[i];
      assert(isValidSymbol(symbol));
      // Check if the symbol is a constant.
      if (auto cOp = symbol.getDefiningOp<ConstantIndexOp>())
        dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol),
                                          cOp.getValue());
    }
  };

  // Add equality constraints for any src symbols defined by constant ops.
  addEqForConstOperands(srcOperands);
  // Add equality constraints for any dst symbols defined by constant ops.
  addEqForConstOperands(dstOperands);

  // By construction (see flattener), local var constraints will not have any
  // equalities.
  assert(srcLocalVarCst.getNumEqualities() == 0 &&
         destLocalVarCst.getNumEqualities() == 0);
  // Add inequalities from srcLocalVarCst and destLocalVarCst into the
  // dependence domain.
  SmallVector<int64_t, 8> ineq(dependenceDomain->getNumCols());
  for (unsigned r = 0, e = srcLocalVarCst.getNumInequalities(); r < e; r++) {
    std::fill(ineq.begin(), ineq.end(), 0);

    // Set identifier coefficients from src local var constraints.
    for (unsigned j = 0, e = srcOperands.size(); j < e; ++j)
      ineq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] =
          srcLocalVarCst.atIneq(r, j);
    // Local terms.
    for (unsigned j = 0, e = srcNumLocalIds; j < e; j++)
      ineq[newLocalIdOffset + j] = srcLocalVarCst.atIneq(r, srcNumIds + j);
    // Set constant term.
    ineq[ineq.size() - 1] =
        srcLocalVarCst.atIneq(r, srcLocalVarCst.getNumCols() - 1);
    dependenceDomain->addInequality(ineq);
  }

  for (unsigned r = 0, e = destLocalVarCst.getNumInequalities(); r < e; r++) {
    std::fill(ineq.begin(), ineq.end(), 0);
    // Set identifier coefficients from dest local var constraints.
    for (unsigned j = 0, e = dstOperands.size(); j < e; ++j)
      ineq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] =
          destLocalVarCst.atIneq(r, j);
    // Local terms.
    for (unsigned j = 0, e = dstNumLocalIds; j < e; j++)
      ineq[newLocalIdOffset + numSrcLocalIds + j] =
          destLocalVarCst.atIneq(r, dstNumIds + j);
    // Set constant term.
    ineq[ineq.size() - 1] =
        destLocalVarCst.atIneq(r, destLocalVarCst.getNumCols() - 1);

    dependenceDomain->addInequality(ineq);
  }
  return success();
}

// Returns the number of outer loops common to 'src/dstDomain'.
// Loops common to 'src/dst' domains are added to 'commonLoops' if non-null.
static unsigned
getNumCommonLoops(const FlatAffineConstraints &srcDomain,
                  const FlatAffineConstraints &dstDomain,
                  SmallVectorImpl<AffineForOp> *commonLoops = nullptr) {
  // Find the number of common loops shared by src and dst accesses.
  unsigned minNumLoops =
      std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds());
  unsigned numCommonLoops = 0;
  for (unsigned i = 0; i < minNumLoops; ++i) {
    if (!isForInductionVar(srcDomain.getIdValue(i)) ||
        !isForInductionVar(dstDomain.getIdValue(i)) ||
        srcDomain.getIdValue(i) != dstDomain.getIdValue(i))
      break;
    if (commonLoops != nullptr)
      commonLoops->push_back(getForInductionVarOwner(srcDomain.getIdValue(i)));
    ++numCommonLoops;
  }
  if (commonLoops != nullptr)
    assert(commonLoops->size() == numCommonLoops);
  return numCommonLoops;
}

/// Returns the Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
static Block *getCommonBlock(const MemRefAccess &srcAccess,
                             const MemRefAccess &dstAccess,
                             const FlatAffineConstraints &srcDomain,
                             unsigned numCommonLoops) {
  // Get the chain of ancestor blocks to the given `MemRefAccess` instance. The
  // search terminates when either an op with the `AffineScope` trait or
  // `endBlock` is reached.
  auto getChainOfAncestorBlocks = [&](const MemRefAccess &access,
                                      SmallVector<Block *, 4> &ancestorBlocks,
                                      Block *endBlock = nullptr) {
    Block *currBlock = access.opInst->getBlock();
    // The loop terminates when currBlock is nullptr, equals endBlock, or its
    // parent operation holds an affine scope.
    while (currBlock && currBlock != endBlock &&
           !currBlock->getParentOp()->hasTrait<OpTrait::AffineScope>()) {
      ancestorBlocks.push_back(currBlock);
      currBlock = currBlock->getParentOp()->getBlock();
    }
  };

  if (numCommonLoops == 0) {
    Block *block = srcAccess.opInst->getBlock();
    while (!llvm::isa<FuncOp>(block->getParentOp())) {
      block = block->getParentOp()->getBlock();
    }
    return block;
  }
  Value commonForIV = srcDomain.getIdValue(numCommonLoops - 1);
  AffineForOp forOp = getForInductionVarOwner(commonForIV);
  assert(forOp && "commonForValue was not an induction variable");

  // Find the closest common block including those in AffineIf.
  SmallVector<Block *, 4> srcAncestorBlocks, dstAncestorBlocks;
  getChainOfAncestorBlocks(srcAccess, srcAncestorBlocks, forOp.getBody());
  getChainOfAncestorBlocks(dstAccess, dstAncestorBlocks, forOp.getBody());

  Block *commonBlock = forOp.getBody();
  for (int i = srcAncestorBlocks.size() - 1, j = dstAncestorBlocks.size() - 1;
       i >= 0 && j >= 0 && srcAncestorBlocks[i] == dstAncestorBlocks[j];
       i--, j--)
    commonBlock = srcAncestorBlocks[i];

  return commonBlock;
}

// Returns true if the ancestor operation of 'srcAccess' appears before the
// ancestor operation of 'dstAccess' in their common ancestral block. Returns
// false otherwise.
// Note that because 'srcAccess' or 'dstAccess' may be nested in conditionals,
// the comparison is made in their closest common ancestral block rather than
// in a common loop body, hence the name 'srcAppearsBeforeDstInAncestralBlock'.
// Note that 'numCommonLoops' is the number of contiguous surrounding outer
// loops.
static bool srcAppearsBeforeDstInAncestralBlock(
    const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
    const FlatAffineConstraints &srcDomain, unsigned numCommonLoops) {
  // Get Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
  auto *commonBlock =
      getCommonBlock(srcAccess, dstAccess, srcDomain, numCommonLoops);
  // Check the dominance relationship between the respective ancestors of the
  // src and dst in the Block of the innermost among the common loops.
  auto *srcInst = commonBlock->findAncestorOpInBlock(*srcAccess.opInst);
  assert(srcInst != nullptr);
  auto *dstInst = commonBlock->findAncestorOpInBlock(*dstAccess.opInst);
  assert(dstInst != nullptr);

  // Determine whether dstInst comes after srcInst.
  return srcInst->isBeforeInBlock(dstInst);
}

// Adds ordering constraints to 'dependenceDomain' based on the number of loops
// common to 'src/dstDomain' and the requested 'loopDepth'.
// Note that 'loopDepth' cannot exceed the number of common loops plus one.
// EX: Given a loop nest of depth 2 with IVs 'i' and 'j':
// *) If 'loopDepth == 1' then one constraint is added: i' >= i + 1
// *) If 'loopDepth == 2' then two constraints are added: i == i' and
//    j' >= j + 1
// *) If 'loopDepth == 3' then two constraints are added: i == i' and j == j'
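// The inequality added at depth 'loopDepth' forces the dependence to be
// carried by that loop (the dst IV is at least the src IV plus one), while the
// equalities on the outer IVs pin both accesses to the same outer loop
// iterations.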
static void addOrderingConstraints(const FlatAffineConstraints &srcDomain,
                                   const FlatAffineConstraints &dstDomain,
                                   unsigned loopDepth,
                                   FlatAffineConstraints *dependenceDomain) {
  unsigned numCols = dependenceDomain->getNumCols();
  SmallVector<int64_t, 4> eq(numCols);
  unsigned numSrcDims = srcDomain.getNumDimIds();
  unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
  unsigned numCommonLoopConstraints = std::min(numCommonLoops, loopDepth);
  for (unsigned i = 0; i < numCommonLoopConstraints; ++i) {
    std::fill(eq.begin(), eq.end(), 0);
    eq[i] = -1;
    eq[i + numSrcDims] = 1;
    if (i == loopDepth - 1) {
      eq[numCols - 1] = -1;
      dependenceDomain->addInequality(eq);
    } else {
      dependenceDomain->addEquality(eq);
    }
  }
}

// Computes distance and direction vectors in 'dependenceComponents', by adding
// variables to 'dependenceDomain' which represent the difference of the IVs,
// eliminating all other variables, and reading off distance vectors from
// equality constraints (if possible), and direction vectors from inequalities.
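// For example, for a dependence from iteration %i to iteration %i + 2 of the
// same loop, the variable introduced for that loop equals (dst IV - src IV),
// and after projection its constant bounds are lb = ub = 2, i.e. a dependence
// distance of 2 at that depth.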
static void computeDirectionVector(
    const FlatAffineConstraints &srcDomain,
    const FlatAffineConstraints &dstDomain, unsigned loopDepth,
    FlatAffineConstraints *dependenceDomain,
    SmallVector<DependenceComponent, 2> *dependenceComponents) {
  // Find the number of common loops shared by src and dst accesses.
  SmallVector<AffineForOp, 4> commonLoops;
  unsigned numCommonLoops =
      getNumCommonLoops(srcDomain, dstDomain, &commonLoops);
  if (numCommonLoops == 0)
    return;
  // Compute direction vectors for requested loop depth.
  unsigned numIdsToEliminate = dependenceDomain->getNumIds();
  // Add new variables to 'dependenceDomain' to represent the direction
  // constraints for each shared loop.
  for (unsigned j = 0; j < numCommonLoops; ++j) {
    dependenceDomain->addDimId(j);
  }

  // Add equality constraints for each common loop, setting the newly
  // introduced variable at column 'j' to the 'dst' IV minus the 'src' IV.
  SmallVector<int64_t, 4> eq;
  eq.resize(dependenceDomain->getNumCols());
  unsigned numSrcDims = srcDomain.getNumDimIds();
  // Constraint variables format:
  // [num-common-loops][num-src-dim-ids][num-dst-dim-ids][num-symbols][constant]
  for (unsigned j = 0; j < numCommonLoops; ++j) {
    std::fill(eq.begin(), eq.end(), 0);
    eq[j] = 1;
    eq[j + numCommonLoops] = 1;
    eq[j + numCommonLoops + numSrcDims] = -1;
    dependenceDomain->addEquality(eq);
  }

  // Eliminate all variables other than the direction variables just added.
  dependenceDomain->projectOut(numCommonLoops, numIdsToEliminate);

  // Scan each common loop variable column and set direction vectors based
  // on the eliminated constraint system.
  dependenceComponents->resize(numCommonLoops);
  for (unsigned j = 0; j < numCommonLoops; ++j) {
    (*dependenceComponents)[j].op = commonLoops[j].getOperation();
    auto lbConst = dependenceDomain->getConstantLowerBound(j);
    (*dependenceComponents)[j].lb =
        lbConst.getValueOr(std::numeric_limits<int64_t>::min());
    auto ubConst = dependenceDomain->getConstantUpperBound(j);
    (*dependenceComponents)[j].ub =
        ubConst.getValueOr(std::numeric_limits<int64_t>::max());
  }
}

// Populates 'accessMap' with composition of AffineApplyOps reachable from
// indices of MemRefAccess.
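// For example, for "affine.load %m[%a]" where
// %a = affine.apply affine_map<(d0) -> (d0 + 1)>(%i), the composed access map
// is (d0) -> (d0 + 1) with the single operand %i.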
void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
  // Get affine map from AffineLoad/Store.
  AffineMap map;
  if (auto loadOp = dyn_cast<AffineReadOpInterface>(opInst))
    map = loadOp.getAffineMap();
  else
    map = cast<AffineWriteOpInterface>(opInst).getAffineMap();

  SmallVector<Value, 8> operands(indices.begin(), indices.end());
  fullyComposeAffineMapAndOperands(&map, &operands);
  map = simplifyAffineMap(map);
  canonicalizeMapAndOperands(&map, &operands);
  accessMap->reset(map, operands);
}

// Builds a flat affine constraint system to check if there exists a dependence
// between memref accesses 'srcAccess' and 'dstAccess'.
// Returns 'NoDependence' if the accesses can be definitively shown not to
// access the same element.
// Returns 'HasDependence' if the accesses do access the same element.
// Returns 'Failure' if an error or unsupported case was encountered.
// If a dependence exists, returns in 'dependenceComponents' a direction
// vector for the dependence, with a component for each loop IV in loops
// common to both accesses (see Dependence in AffineAnalysis.h for details).
//
// The memref access dependence check is comprised of the following steps:
// *) Compute access functions for each access. Access functions are computed
//    using AffineValueMaps initialized with the indices from an access, then
//    composed with AffineApplyOps reachable from operands of that access,
//    until operands of the AffineValueMap are loop IVs or symbols.
// *) Build iteration domain constraints for each access. Iteration domain
//    constraints are pairs of inequality constraints representing the
//    upper/lower loop bounds for each AffineForOp in the loop nest associated
//    with each access.
// *) Build dimension and symbol position maps for each access, which map
//    Values from access functions and iteration domains to their position
//    in the merged constraint system built by this method.
//
// This method builds a constraint system with the following column format:
//
//   [src-dim-identifiers, dst-dim-identifiers, symbols, constant]
//
// For example, given the following MLIR code with "source" and "destination"
// accesses to the same memref '%m', and symbols %M, %N, %K:
//
//   affine.for %i0 = 0 to 100 {
//     affine.for %i1 = 0 to 50 {
//       %a0 = affine.apply
//         (d0, d1) -> (d0 * 2 - d1 * 4 + s1, d1 * 3 - s0) (%i0, %i1)[%M, %N]
//       // Source memref access.
//       store %v0, %m[%a0#0, %a0#1] : memref<4x4xf32>
//     }
//   }
//
//   affine.for %i2 = 0 to 100 {
//     affine.for %i3 = 0 to 50 {
//       %a1 = affine.apply
//         (d0, d1) -> (d0 * 7 + d1 * 9 - s1, d1 * 11 + s0) (%i2, %i3)[%K, %M]
//       // Destination memref access.
//       %v1 = load %m[%a1#0, %a1#1] : memref<4x4xf32>
//     }
//   }
//
// The access functions would be the following:
//
//   src: (%i0 * 2 - %i1 * 4 + %N, %i1 * 3 - %M)
//   dst: (%i2 * 7 + %i3 * 9 - %M, %i3 * 11 + %K)
//
// The iteration domains for the src/dst accesses would be the following
// (upper bounds of affine.for are exclusive):
//
//   src: 0 <= %i0 <= 99, 0 <= %i1 <= 49
//   dst: 0 <= %i2 <= 99, 0 <= %i3 <= 49
//
// The symbols used by both accesses would be assigned to a canonical position
// order which will be used in the dependence constraint system:
//
//   symbol name: %M  %N  %K
//   symbol  pos:  0   1   2
//
// Equality constraints are built by equating each result of the src and dst
// access functions. For this example, the following two equality constraints
// will be added to the dependence constraint system:
//
//   [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const]
//       2        -4        -7        -9       1     1     0     0     = 0
//       0         3         0       -11      -1     0    -1     0     = 0
//
// Inequality constraints from the iteration domains will be merged into
// the dependence constraint system:
//
//   [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const]
//       1         0         0         0       0     0     0     0    >= 0
//      -1         0         0         0       0     0     0    99    >= 0
//       0         1         0         0       0     0     0     0    >= 0
//       0        -1         0         0       0     0     0    49    >= 0
//       0         0         1         0       0     0     0     0    >= 0
//       0         0        -1         0       0     0     0    99    >= 0
//       0         0         0         1       0     0     0     0    >= 0
//       0         0         0        -1       0     0     0    49    >= 0
//
//
// TODO: Support AffineExprs mod/floordiv/ceildiv.
DependenceResult mlir::checkMemrefAccessDependence(
    const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
    unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
    SmallVector<DependenceComponent, 2> *dependenceComponents, bool allowRAR) {
  LLVM_DEBUG(llvm::dbgs() << "Checking for dependence at depth: "
                          << Twine(loopDepth) << " between:\n";);
  LLVM_DEBUG(srcAccess.opInst->dump(););
  LLVM_DEBUG(dstAccess.opInst->dump(););

  // Return 'NoDependence' if these accesses do not access the same memref.
  if (srcAccess.memref != dstAccess.memref)
    return DependenceResult::NoDependence;

  // Return 'NoDependence' if neither of these accesses is an
  // AffineWriteOpInterface.
  if (!allowRAR && !isa<AffineWriteOpInterface>(srcAccess.opInst) &&
      !isa<AffineWriteOpInterface>(dstAccess.opInst))
    return DependenceResult::NoDependence;

  // Get composed access function for 'srcAccess'.
  AffineValueMap srcAccessMap;
  srcAccess.getAccessMap(&srcAccessMap);

  // Get composed access function for 'dstAccess'.
  AffineValueMap dstAccessMap;
  dstAccess.getAccessMap(&dstAccessMap);

  // Get iteration domain for the 'srcAccess' operation.
  FlatAffineConstraints srcDomain;
  if (failed(getOpIndexSet(srcAccess.opInst, &srcDomain)))
    return DependenceResult::Failure;

  // Get iteration domain for 'dstAccess' operation.
  FlatAffineConstraints dstDomain;
  if (failed(getOpIndexSet(dstAccess.opInst, &dstDomain)))
    return DependenceResult::Failure;

  // Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor
  // operation of 'srcAccess' does not properly dominate the ancestor
  // operation of 'dstAccess' in the same common operation block.
  // Note: this check is skipped if 'allowRAR' is true, because RAR
  // deps can exist irrespective of lexicographic ordering b/w src and dst.
  unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
  assert(loopDepth <= numCommonLoops + 1);
  if (!allowRAR && loopDepth > numCommonLoops &&
      !srcAppearsBeforeDstInAncestralBlock(srcAccess, dstAccess, srcDomain,
                                           numCommonLoops)) {
    return DependenceResult::NoDependence;
  }
  // Build dim and symbol position maps for each access from access operand
  // Value to position in merged constraint system.
  ValuePositionMap valuePosMap;
  buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap,
                                dstAccessMap, &valuePosMap,
                                dependenceConstraints);
  initDependenceConstraints(srcDomain, dstDomain, srcAccessMap, dstAccessMap,
                            valuePosMap, dependenceConstraints);

  assert(valuePosMap.getNumDims() ==
         srcDomain.getNumDimIds() + dstDomain.getNumDimIds());

  // Create memref access constraint by equating src/dst access functions.
  // Note that this check is conservative, and will fail in the future when
  // local variables for mod/div exprs are supported.
  if (failed(addMemRefAccessConstraints(srcAccessMap, dstAccessMap, valuePosMap,
                                        dependenceConstraints)))
    return DependenceResult::Failure;

  // Add 'src' happens before 'dst' ordering constraints.
  addOrderingConstraints(srcDomain, dstDomain, loopDepth,
                         dependenceConstraints);
  // Add src and dst domain constraints.
  addDomainConstraints(srcDomain, dstDomain, valuePosMap,
                       dependenceConstraints);

  // Return 'NoDependence' if the solution space is empty: no dependence.
  if (dependenceConstraints->isEmpty()) {
    return DependenceResult::NoDependence;
  }

  // Compute dependence direction vector and return true.
  if (dependenceComponents != nullptr) {
    computeDirectionVector(srcDomain, dstDomain, loopDepth,
                           dependenceConstraints, dependenceComponents);
  }

  LLVM_DEBUG(llvm::dbgs() << "Dependence polyhedron:\n");
  LLVM_DEBUG(dependenceConstraints->dump());
  return DependenceResult::HasDependence;
}

/// Gathers dependence components for dependences between all ops in loop nest
/// rooted at 'forOp' at loop depths in range [1, maxLoopDepth].
void mlir::getDependenceComponents(
    AffineForOp forOp, unsigned maxLoopDepth,
    std::vector<SmallVector<DependenceComponent, 2>> *depCompsVec) {
  // Collect all load and store ops in loop nest rooted at 'forOp'.
  SmallVector<Operation *, 8> loadAndStoreOps;
  forOp->walk([&](Operation *op) {
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
      loadAndStoreOps.push_back(op);
  });

  unsigned numOps = loadAndStoreOps.size();
  for (unsigned d = 1; d <= maxLoopDepth; ++d) {
    for (unsigned i = 0; i < numOps; ++i) {
      auto *srcOp = loadAndStoreOps[i];
      MemRefAccess srcAccess(srcOp);
      for (unsigned j = 0; j < numOps; ++j) {
        auto *dstOp = loadAndStoreOps[j];
        MemRefAccess dstAccess(dstOp);

        FlatAffineConstraints dependenceConstraints;
        SmallVector<DependenceComponent, 2> depComps;
        // TODO: Explore whether it would be profitable to pre-compute and
        // store deps instead of repeatedly checking.
        DependenceResult result = checkMemrefAccessDependence(
            srcAccess, dstAccess, d, &dependenceConstraints, &depComps);
        if (hasDependence(result))
          depCompsVec->push_back(depComps);
      }
    }
  }
}