1 //===- AffineMap.cpp - MLIR Affine Map Classes ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/AffineMap.h"
10 #include "AffineMapDetail.h"
11 #include "mlir/IR/BuiltinAttributes.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/Support/LogicalResult.h"
14 #include "mlir/Support/MathExtras.h"
15 #include "llvm/ADT/SmallSet.h"
16 #include "llvm/ADT/StringRef.h"
17 #include "llvm/Support/raw_ostream.h"
18
19 using namespace mlir;
20
21 namespace {
22
23 // AffineExprConstantFolder evaluates an affine expression using constant
24 // operands passed in 'operandConsts'. Returns an IntegerAttr attribute
25 // representing the constant value of the affine expression evaluated on
26 // constant 'operandConsts', or nullptr if it can't be folded.
27 class AffineExprConstantFolder {
28 public:
AffineExprConstantFolder(unsigned numDims,ArrayRef<Attribute> operandConsts)29 AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts)
30 : numDims(numDims), operandConsts(operandConsts) {}
31
32 /// Attempt to constant fold the specified affine expr, or return null on
33 /// failure.
constantFold(AffineExpr expr)34 IntegerAttr constantFold(AffineExpr expr) {
35 if (auto result = constantFoldImpl(expr))
36 return IntegerAttr::get(IndexType::get(expr.getContext()), *result);
37 return nullptr;
38 }
39
40 private:
constantFoldImpl(AffineExpr expr)41 Optional<int64_t> constantFoldImpl(AffineExpr expr) {
42 switch (expr.getKind()) {
43 case AffineExprKind::Add:
44 return constantFoldBinExpr(
45 expr, [](int64_t lhs, int64_t rhs) { return lhs + rhs; });
46 case AffineExprKind::Mul:
47 return constantFoldBinExpr(
48 expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
49 case AffineExprKind::Mod:
50 return constantFoldBinExpr(
51 expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); });
52 case AffineExprKind::FloorDiv:
53 return constantFoldBinExpr(
54 expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); });
55 case AffineExprKind::CeilDiv:
56 return constantFoldBinExpr(
57 expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); });
58 case AffineExprKind::Constant:
59 return expr.cast<AffineConstantExpr>().getValue();
60 case AffineExprKind::DimId:
61 if (auto attr = operandConsts[expr.cast<AffineDimExpr>().getPosition()]
62 .dyn_cast_or_null<IntegerAttr>())
63 return attr.getInt();
64 return llvm::None;
65 case AffineExprKind::SymbolId:
66 if (auto attr = operandConsts[numDims +
67 expr.cast<AffineSymbolExpr>().getPosition()]
68 .dyn_cast_or_null<IntegerAttr>())
69 return attr.getInt();
70 return llvm::None;
71 }
72 llvm_unreachable("Unknown AffineExpr");
73 }
74
75 // TODO: Change these to operate on APInts too.
constantFoldBinExpr(AffineExpr expr,int64_t (* op)(int64_t,int64_t))76 Optional<int64_t> constantFoldBinExpr(AffineExpr expr,
77 int64_t (*op)(int64_t, int64_t)) {
78 auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
79 if (auto lhs = constantFoldImpl(binOpExpr.getLHS()))
80 if (auto rhs = constantFoldImpl(binOpExpr.getRHS()))
81 return op(*lhs, *rhs);
82 return llvm::None;
83 }
84
85 // The number of dimension operands in AffineMap containing this expression.
86 unsigned numDims;
87 // The constant valued operands used to evaluate this AffineExpr.
88 ArrayRef<Attribute> operandConsts;
89 };
90
91 } // end anonymous namespace
92
93 /// Returns a single constant result affine map.
getConstantMap(int64_t val,MLIRContext * context)94 AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
95 return get(/*dimCount=*/0, /*symbolCount=*/0,
96 {getAffineConstantExpr(val, context)});
97 }
98
99 /// Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most
100 /// minor dimensions.
getMinorIdentityMap(unsigned dims,unsigned results,MLIRContext * context)101 AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results,
102 MLIRContext *context) {
103 assert(dims >= results && "Dimension mismatch");
104 auto id = AffineMap::getMultiDimIdentityMap(dims, context);
105 return AffineMap::get(dims, 0, id.getResults().take_back(results), context);
106 }
107
isMinorIdentity() const108 bool AffineMap::isMinorIdentity() const {
109 return *this ==
110 getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
111 }
112
113 /// Returns an AffineMap representing a permutation.
getPermutationMap(ArrayRef<unsigned> permutation,MLIRContext * context)114 AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
115 MLIRContext *context) {
116 assert(!permutation.empty() &&
117 "Cannot create permutation map from empty permutation vector");
118 SmallVector<AffineExpr, 4> affExprs;
119 for (auto index : permutation)
120 affExprs.push_back(getAffineDimExpr(index, context));
121 auto m = std::max_element(permutation.begin(), permutation.end());
122 auto permutationMap = AffineMap::get(*m + 1, 0, affExprs, context);
123 assert(permutationMap.isPermutation() && "Invalid permutation vector");
124 return permutationMap;
125 }
126
127 template <typename AffineExprContainer>
getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,int64_t & maxDim,int64_t & maxSym)128 static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
129 int64_t &maxDim, int64_t &maxSym) {
130 for (const auto &exprs : exprsList) {
131 for (auto expr : exprs) {
132 expr.walk([&maxDim, &maxSym](AffineExpr e) {
133 if (auto d = e.dyn_cast<AffineDimExpr>())
134 maxDim = std::max(maxDim, static_cast<int64_t>(d.getPosition()));
135 if (auto s = e.dyn_cast<AffineSymbolExpr>())
136 maxSym = std::max(maxSym, static_cast<int64_t>(s.getPosition()));
137 });
138 }
139 }
140 }
141
142 template <typename AffineExprContainer>
143 static SmallVector<AffineMap, 4>
inferFromExprList(ArrayRef<AffineExprContainer> exprsList)144 inferFromExprList(ArrayRef<AffineExprContainer> exprsList) {
145 assert(!exprsList.empty());
146 assert(!exprsList[0].empty());
147 auto context = exprsList[0][0].getContext();
148 int64_t maxDim = -1, maxSym = -1;
149 getMaxDimAndSymbol(exprsList, maxDim, maxSym);
150 SmallVector<AffineMap, 4> maps;
151 maps.reserve(exprsList.size());
152 for (const auto &exprs : exprsList)
153 maps.push_back(AffineMap::get(/*dimCount=*/maxDim + 1,
154 /*symbolCount=*/maxSym + 1, exprs, context));
155 return maps;
156 }
157
158 SmallVector<AffineMap, 4>
inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList)159 AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList) {
160 return ::inferFromExprList(exprsList);
161 }
162
163 SmallVector<AffineMap, 4>
inferFromExprList(ArrayRef<SmallVector<AffineExpr,4>> exprsList)164 AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList) {
165 return ::inferFromExprList(exprsList);
166 }
167
getMultiDimIdentityMap(unsigned numDims,MLIRContext * context)168 AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
169 MLIRContext *context) {
170 SmallVector<AffineExpr, 4> dimExprs;
171 dimExprs.reserve(numDims);
172 for (unsigned i = 0; i < numDims; ++i)
173 dimExprs.push_back(mlir::getAffineDimExpr(i, context));
174 return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs, context);
175 }
176
getContext() const177 MLIRContext *AffineMap::getContext() const { return map->context; }
178
isIdentity() const179 bool AffineMap::isIdentity() const {
180 if (getNumDims() != getNumResults())
181 return false;
182 ArrayRef<AffineExpr> results = getResults();
183 for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) {
184 auto expr = results[i].dyn_cast<AffineDimExpr>();
185 if (!expr || expr.getPosition() != i)
186 return false;
187 }
188 return true;
189 }
190
isEmpty() const191 bool AffineMap::isEmpty() const {
192 return getNumDims() == 0 && getNumSymbols() == 0 && getNumResults() == 0;
193 }
194
isSingleConstant() const195 bool AffineMap::isSingleConstant() const {
196 return getNumResults() == 1 && getResult(0).isa<AffineConstantExpr>();
197 }
198
getSingleConstantResult() const199 int64_t AffineMap::getSingleConstantResult() const {
200 assert(isSingleConstant() && "map must have a single constant result");
201 return getResult(0).cast<AffineConstantExpr>().getValue();
202 }
203
getNumDims() const204 unsigned AffineMap::getNumDims() const {
205 assert(map && "uninitialized map storage");
206 return map->numDims;
207 }
getNumSymbols() const208 unsigned AffineMap::getNumSymbols() const {
209 assert(map && "uninitialized map storage");
210 return map->numSymbols;
211 }
getNumResults() const212 unsigned AffineMap::getNumResults() const {
213 assert(map && "uninitialized map storage");
214 return map->results.size();
215 }
getNumInputs() const216 unsigned AffineMap::getNumInputs() const {
217 assert(map && "uninitialized map storage");
218 return map->numDims + map->numSymbols;
219 }
220
getResults() const221 ArrayRef<AffineExpr> AffineMap::getResults() const {
222 assert(map && "uninitialized map storage");
223 return map->results;
224 }
getResult(unsigned idx) const225 AffineExpr AffineMap::getResult(unsigned idx) const {
226 assert(map && "uninitialized map storage");
227 return map->results[idx];
228 }
229
getDimPosition(unsigned idx) const230 unsigned AffineMap::getDimPosition(unsigned idx) const {
231 return getResult(idx).cast<AffineDimExpr>().getPosition();
232 }
233
234 /// Folds the results of the application of an affine map on the provided
235 /// operands to a constant if possible. Returns false if the folding happens,
236 /// true otherwise.
237 LogicalResult
constantFold(ArrayRef<Attribute> operandConstants,SmallVectorImpl<Attribute> & results) const238 AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
239 SmallVectorImpl<Attribute> &results) const {
240 // Attempt partial folding.
241 SmallVector<int64_t, 2> integers;
242 partialConstantFold(operandConstants, &integers);
243
244 // If all expressions folded to a constant, populate results with attributes
245 // containing those constants.
246 if (integers.empty())
247 return failure();
248
249 auto range = llvm::map_range(integers, [this](int64_t i) {
250 return IntegerAttr::get(IndexType::get(getContext()), i);
251 });
252 results.append(range.begin(), range.end());
253 return success();
254 }
255
256 AffineMap
partialConstantFold(ArrayRef<Attribute> operandConstants,SmallVectorImpl<int64_t> * results) const257 AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants,
258 SmallVectorImpl<int64_t> *results) const {
259 assert(getNumInputs() == operandConstants.size());
260
261 // Fold each of the result expressions.
262 AffineExprConstantFolder exprFolder(getNumDims(), operandConstants);
263 SmallVector<AffineExpr, 4> exprs;
264 exprs.reserve(getNumResults());
265
266 for (auto expr : getResults()) {
267 auto folded = exprFolder.constantFold(expr);
268 // If did not fold to a constant, keep the original expression, and clear
269 // the integer results vector.
270 if (folded) {
271 exprs.push_back(
272 getAffineConstantExpr(folded.getInt(), folded.getContext()));
273 if (results)
274 results->push_back(folded.getInt());
275 } else {
276 exprs.push_back(expr);
277 if (results) {
278 results->clear();
279 results = nullptr;
280 }
281 }
282 }
283
284 return get(getNumDims(), getNumSymbols(), exprs, getContext());
285 }
286
287 /// Walk all of the AffineExpr's in this mapping. Each node in an expression
288 /// tree is visited in postorder.
walkExprs(std::function<void (AffineExpr)> callback) const289 void AffineMap::walkExprs(std::function<void(AffineExpr)> callback) const {
290 for (auto expr : getResults())
291 expr.walk(callback);
292 }
293
294 /// This method substitutes any uses of dimensions and symbols (e.g.
295 /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified
296 /// expression mapping. Because this can be used to eliminate dims and
297 /// symbols, the client needs to specify the number of dims and symbols in
298 /// the result. The returned map always has the same number of results.
replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,ArrayRef<AffineExpr> symReplacements,unsigned numResultDims,unsigned numResultSyms) const299 AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
300 ArrayRef<AffineExpr> symReplacements,
301 unsigned numResultDims,
302 unsigned numResultSyms) const {
303 SmallVector<AffineExpr, 8> results;
304 results.reserve(getNumResults());
305 for (auto expr : getResults())
306 results.push_back(
307 expr.replaceDimsAndSymbols(dimReplacements, symReplacements));
308
309 return get(numResultDims, numResultSyms, results, getContext());
310 }
311
compose(AffineMap map)312 AffineMap AffineMap::compose(AffineMap map) {
313 assert(getNumDims() == map.getNumResults() && "Number of results mismatch");
314 // Prepare `map` by concatenating the symbols and rewriting its exprs.
315 unsigned numDims = map.getNumDims();
316 unsigned numSymbolsThisMap = getNumSymbols();
317 unsigned numSymbols = numSymbolsThisMap + map.getNumSymbols();
318 SmallVector<AffineExpr, 8> newDims(numDims);
319 for (unsigned idx = 0; idx < numDims; ++idx) {
320 newDims[idx] = getAffineDimExpr(idx, getContext());
321 }
322 SmallVector<AffineExpr, 8> newSymbols(numSymbols);
323 for (unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) {
324 newSymbols[idx - numSymbolsThisMap] =
325 getAffineSymbolExpr(idx, getContext());
326 }
327 auto newMap =
328 map.replaceDimsAndSymbols(newDims, newSymbols, numDims, numSymbols);
329 SmallVector<AffineExpr, 8> exprs;
330 exprs.reserve(getResults().size());
331 for (auto expr : getResults())
332 exprs.push_back(expr.compose(newMap));
333 return AffineMap::get(numDims, numSymbols, exprs, map.getContext());
334 }
335
compose(ArrayRef<int64_t> values)336 SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) {
337 assert(getNumSymbols() == 0 && "Expected symbol-less map");
338 SmallVector<AffineExpr, 4> exprs;
339 exprs.reserve(values.size());
340 MLIRContext *ctx = getContext();
341 for (auto v : values)
342 exprs.push_back(getAffineConstantExpr(v, ctx));
343 auto resMap = compose(AffineMap::get(0, 0, exprs, ctx));
344 SmallVector<int64_t, 4> res;
345 res.reserve(resMap.getNumResults());
346 for (auto e : resMap.getResults())
347 res.push_back(e.cast<AffineConstantExpr>().getValue());
348 return res;
349 }
350
isProjectedPermutation()351 bool AffineMap::isProjectedPermutation() {
352 if (getNumSymbols() > 0)
353 return false;
354 SmallVector<bool, 8> seen(getNumInputs(), false);
355 for (auto expr : getResults()) {
356 if (auto dim = expr.dyn_cast<AffineDimExpr>()) {
357 if (seen[dim.getPosition()])
358 return false;
359 seen[dim.getPosition()] = true;
360 continue;
361 }
362 return false;
363 }
364 return true;
365 }
366
isPermutation()367 bool AffineMap::isPermutation() {
368 if (getNumDims() != getNumResults())
369 return false;
370 return isProjectedPermutation();
371 }
372
getSubMap(ArrayRef<unsigned> resultPos)373 AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) {
374 SmallVector<AffineExpr, 4> exprs;
375 exprs.reserve(resultPos.size());
376 for (auto idx : resultPos)
377 exprs.push_back(getResult(idx));
378 return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
379 }
380
getMajorSubMap(unsigned numResults)381 AffineMap AffineMap::getMajorSubMap(unsigned numResults) {
382 if (numResults == 0)
383 return AffineMap();
384 if (numResults > getNumResults())
385 return *this;
386 return getSubMap(llvm::to_vector<4>(llvm::seq<unsigned>(0, numResults)));
387 }
388
getMinorSubMap(unsigned numResults)389 AffineMap AffineMap::getMinorSubMap(unsigned numResults) {
390 if (numResults == 0)
391 return AffineMap();
392 if (numResults > getNumResults())
393 return *this;
394 return getSubMap(llvm::to_vector<4>(
395 llvm::seq<unsigned>(getNumResults() - numResults, getNumResults())));
396 }
397
simplifyAffineMap(AffineMap map)398 AffineMap mlir::simplifyAffineMap(AffineMap map) {
399 SmallVector<AffineExpr, 8> exprs;
400 for (auto e : map.getResults()) {
401 exprs.push_back(
402 simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols()));
403 }
404 return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs,
405 map.getContext());
406 }
407
removeDuplicateExprs(AffineMap map)408 AffineMap mlir::removeDuplicateExprs(AffineMap map) {
409 auto results = map.getResults();
410 SmallVector<AffineExpr, 4> uniqueExprs(results.begin(), results.end());
411 uniqueExprs.erase(std::unique(uniqueExprs.begin(), uniqueExprs.end()),
412 uniqueExprs.end());
413 return AffineMap::get(map.getNumDims(), map.getNumSymbols(), uniqueExprs,
414 map.getContext());
415 }
416
inversePermutation(AffineMap map)417 AffineMap mlir::inversePermutation(AffineMap map) {
418 if (map.isEmpty())
419 return map;
420 assert(map.getNumSymbols() == 0 && "expected map without symbols");
421 SmallVector<AffineExpr, 4> exprs(map.getNumDims());
422 for (auto en : llvm::enumerate(map.getResults())) {
423 auto expr = en.value();
424 // Skip non-permutations.
425 if (auto d = expr.dyn_cast<AffineDimExpr>()) {
426 if (exprs[d.getPosition()])
427 continue;
428 exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext());
429 }
430 }
431 SmallVector<AffineExpr, 4> seenExprs;
432 seenExprs.reserve(map.getNumDims());
433 for (auto expr : exprs)
434 if (expr)
435 seenExprs.push_back(expr);
436 if (seenExprs.size() != map.getNumInputs())
437 return AffineMap();
438 return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext());
439 }
440
concatAffineMaps(ArrayRef<AffineMap> maps)441 AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
442 unsigned numResults = 0, numDims = 0, numSymbols = 0;
443 for (auto m : maps)
444 numResults += m.getNumResults();
445 SmallVector<AffineExpr, 8> results;
446 results.reserve(numResults);
447 for (auto m : maps) {
448 for (auto res : m.getResults())
449 results.push_back(res.shiftSymbols(m.getNumSymbols(), numSymbols));
450
451 numSymbols += m.getNumSymbols();
452 numDims = std::max(m.getNumDims(), numDims);
453 }
454 return AffineMap::get(numDims, numSymbols, results,
455 maps.front().getContext());
456 }
457
getProjectedMap(AffineMap map,ArrayRef<unsigned> projectedDimensions)458 AffineMap mlir::getProjectedMap(AffineMap map,
459 ArrayRef<unsigned> projectedDimensions) {
460 DenseSet<unsigned> projectedDims(projectedDimensions.begin(),
461 projectedDimensions.end());
462 MLIRContext *context = map.getContext();
463 SmallVector<AffineExpr, 4> resultExprs;
464 for (auto dim : enumerate(llvm::seq<unsigned>(0, map.getNumDims()))) {
465 if (!projectedDims.count(dim.value()))
466 resultExprs.push_back(getAffineDimExpr(dim.index(), context));
467 else
468 resultExprs.push_back(getAffineConstantExpr(0, context));
469 }
470 return map.compose(AffineMap::get(
471 map.getNumDims() - projectedDimensions.size(), 0, resultExprs, context));
472 }
473
474 //===----------------------------------------------------------------------===//
475 // MutableAffineMap.
476 //===----------------------------------------------------------------------===//
477
MutableAffineMap(AffineMap map)478 MutableAffineMap::MutableAffineMap(AffineMap map)
479 : numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
480 context(map.getContext()) {
481 for (auto result : map.getResults())
482 results.push_back(result);
483 }
484
reset(AffineMap map)485 void MutableAffineMap::reset(AffineMap map) {
486 results.clear();
487 numDims = map.getNumDims();
488 numSymbols = map.getNumSymbols();
489 context = map.getContext();
490 for (auto result : map.getResults())
491 results.push_back(result);
492 }
493
isMultipleOf(unsigned idx,int64_t factor) const494 bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
495 if (results[idx].isMultipleOf(factor))
496 return true;
497
498 // TODO: use simplifyAffineExpr and FlatAffineConstraints to
499 // complete this (for a more powerful analysis).
500 return false;
501 }
502
503 // Simplifies the result affine expressions of this map. The expressions have to
504 // be pure for the simplification implemented.
simplify()505 void MutableAffineMap::simplify() {
506 // Simplify each of the results if possible.
507 // TODO: functional-style map
508 for (unsigned i = 0, e = getNumResults(); i < e; i++) {
509 results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols);
510 }
511 }
512
getAffineMap() const513 AffineMap MutableAffineMap::getAffineMap() const {
514 return AffineMap::get(numDims, numSymbols, results, context);
515 }
516