1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
#include "mlir/IR/AffineExpr.h"
#include "AffineExprDetail.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/STLExtras.h"

#include <cstdlib>
#include <utility>
17
18 using namespace mlir;
19 using namespace mlir::detail;
20
getContext() const21 MLIRContext *AffineExpr::getContext() const { return expr->context; }
22
getKind() const23 AffineExprKind AffineExpr::getKind() const { return expr->kind; }
24
25 /// Walk all of the AffineExprs in this subgraph in postorder.
walk(std::function<void (AffineExpr)> callback) const26 void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
27 struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
28 std::function<void(AffineExpr)> callback;
29
30 AffineExprWalker(std::function<void(AffineExpr)> callback)
31 : callback(callback) {}
32
33 void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
34 void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
35 void visitDimExpr(AffineDimExpr expr) { callback(expr); }
36 void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
37 };
38
39 AffineExprWalker(callback).walkPostOrder(*this);
40 }
41
42 // Dispatch affine expression construction based on kind.
getAffineBinaryOpExpr(AffineExprKind kind,AffineExpr lhs,AffineExpr rhs)43 AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
44 AffineExpr rhs) {
45 if (kind == AffineExprKind::Add)
46 return lhs + rhs;
47 if (kind == AffineExprKind::Mul)
48 return lhs * rhs;
49 if (kind == AffineExprKind::FloorDiv)
50 return lhs.floorDiv(rhs);
51 if (kind == AffineExprKind::CeilDiv)
52 return lhs.ceilDiv(rhs);
53 if (kind == AffineExprKind::Mod)
54 return lhs % rhs;
55
56 llvm_unreachable("unknown binary operation on affine expressions");
57 }
58
59 /// This method substitutes any uses of dimensions and symbols (e.g.
60 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
61 AffineExpr
replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,ArrayRef<AffineExpr> symReplacements) const62 AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
63 ArrayRef<AffineExpr> symReplacements) const {
64 switch (getKind()) {
65 case AffineExprKind::Constant:
66 return *this;
67 case AffineExprKind::DimId: {
68 unsigned dimId = cast<AffineDimExpr>().getPosition();
69 if (dimId >= dimReplacements.size())
70 return *this;
71 return dimReplacements[dimId];
72 }
73 case AffineExprKind::SymbolId: {
74 unsigned symId = cast<AffineSymbolExpr>().getPosition();
75 if (symId >= symReplacements.size())
76 return *this;
77 return symReplacements[symId];
78 }
79 case AffineExprKind::Add:
80 case AffineExprKind::Mul:
81 case AffineExprKind::FloorDiv:
82 case AffineExprKind::CeilDiv:
83 case AffineExprKind::Mod:
84 auto binOp = cast<AffineBinaryOpExpr>();
85 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
86 auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
87 auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
88 if (newLHS == lhs && newRHS == rhs)
89 return *this;
90 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
91 }
92 llvm_unreachable("Unknown AffineExpr");
93 }
94
95 /// Replace symbols[0 .. numDims - 1] by symbols[shift .. shift + numDims - 1].
shiftSymbols(unsigned numSymbols,unsigned shift) const96 AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift) const {
97 SmallVector<AffineExpr, 4> symbols;
98 for (unsigned idx = 0; idx < numSymbols; ++idx)
99 symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
100 return replaceDimsAndSymbols({}, symbols);
101 }
102
103 /// Sparse replace method. Return the modified expression tree.
104 AffineExpr
replace(const DenseMap<AffineExpr,AffineExpr> & map) const105 AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
106 auto it = map.find(*this);
107 if (it != map.end())
108 return it->second;
109 switch (getKind()) {
110 default:
111 return *this;
112 case AffineExprKind::Add:
113 case AffineExprKind::Mul:
114 case AffineExprKind::FloorDiv:
115 case AffineExprKind::CeilDiv:
116 case AffineExprKind::Mod:
117 auto binOp = cast<AffineBinaryOpExpr>();
118 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
119 auto newLHS = lhs.replace(map);
120 auto newRHS = rhs.replace(map);
121 if (newLHS == lhs && newRHS == rhs)
122 return *this;
123 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
124 }
125 llvm_unreachable("Unknown AffineExpr");
126 }
127
128 /// Sparse replace method. Return the modified expression tree.
replace(AffineExpr expr,AffineExpr replacement) const129 AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const {
130 DenseMap<AffineExpr, AffineExpr> map;
131 map.insert(std::make_pair(expr, replacement));
132 return replace(map);
133 }
134 /// Returns true if this expression is made out of only symbols and
135 /// constants (no dimensional identifiers).
isSymbolicOrConstant() const136 bool AffineExpr::isSymbolicOrConstant() const {
137 switch (getKind()) {
138 case AffineExprKind::Constant:
139 return true;
140 case AffineExprKind::DimId:
141 return false;
142 case AffineExprKind::SymbolId:
143 return true;
144
145 case AffineExprKind::Add:
146 case AffineExprKind::Mul:
147 case AffineExprKind::FloorDiv:
148 case AffineExprKind::CeilDiv:
149 case AffineExprKind::Mod: {
150 auto expr = this->cast<AffineBinaryOpExpr>();
151 return expr.getLHS().isSymbolicOrConstant() &&
152 expr.getRHS().isSymbolicOrConstant();
153 }
154 }
155 llvm_unreachable("Unknown AffineExpr");
156 }
157
158 /// Returns true if this is a pure affine expression, i.e., multiplication,
159 /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
isPureAffine() const160 bool AffineExpr::isPureAffine() const {
161 switch (getKind()) {
162 case AffineExprKind::SymbolId:
163 case AffineExprKind::DimId:
164 case AffineExprKind::Constant:
165 return true;
166 case AffineExprKind::Add: {
167 auto op = cast<AffineBinaryOpExpr>();
168 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
169 }
170
171 case AffineExprKind::Mul: {
172 // TODO: Canonicalize the constants in binary operators to the RHS when
173 // possible, allowing this to merge into the next case.
174 auto op = cast<AffineBinaryOpExpr>();
175 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
176 (op.getLHS().template isa<AffineConstantExpr>() ||
177 op.getRHS().template isa<AffineConstantExpr>());
178 }
179 case AffineExprKind::FloorDiv:
180 case AffineExprKind::CeilDiv:
181 case AffineExprKind::Mod: {
182 auto op = cast<AffineBinaryOpExpr>();
183 return op.getLHS().isPureAffine() &&
184 op.getRHS().template isa<AffineConstantExpr>();
185 }
186 }
187 llvm_unreachable("Unknown AffineExpr");
188 }
189
190 // Returns the greatest known integral divisor of this affine expression.
getLargestKnownDivisor() const191 int64_t AffineExpr::getLargestKnownDivisor() const {
192 AffineBinaryOpExpr binExpr(nullptr);
193 switch (getKind()) {
194 case AffineExprKind::SymbolId:
195 LLVM_FALLTHROUGH;
196 case AffineExprKind::DimId:
197 return 1;
198 case AffineExprKind::Constant:
199 return std::abs(this->cast<AffineConstantExpr>().getValue());
200 case AffineExprKind::Mul: {
201 binExpr = this->cast<AffineBinaryOpExpr>();
202 return binExpr.getLHS().getLargestKnownDivisor() *
203 binExpr.getRHS().getLargestKnownDivisor();
204 }
205 case AffineExprKind::Add:
206 LLVM_FALLTHROUGH;
207 case AffineExprKind::FloorDiv:
208 case AffineExprKind::CeilDiv:
209 case AffineExprKind::Mod: {
210 binExpr = cast<AffineBinaryOpExpr>();
211 return llvm::GreatestCommonDivisor64(
212 binExpr.getLHS().getLargestKnownDivisor(),
213 binExpr.getRHS().getLargestKnownDivisor());
214 }
215 }
216 llvm_unreachable("Unknown AffineExpr");
217 }
218
isMultipleOf(int64_t factor) const219 bool AffineExpr::isMultipleOf(int64_t factor) const {
220 AffineBinaryOpExpr binExpr(nullptr);
221 uint64_t l, u;
222 switch (getKind()) {
223 case AffineExprKind::SymbolId:
224 LLVM_FALLTHROUGH;
225 case AffineExprKind::DimId:
226 return factor * factor == 1;
227 case AffineExprKind::Constant:
228 return cast<AffineConstantExpr>().getValue() % factor == 0;
229 case AffineExprKind::Mul: {
230 binExpr = cast<AffineBinaryOpExpr>();
231 // It's probably not worth optimizing this further (to not traverse the
232 // whole sub-tree under - it that would require a version of isMultipleOf
233 // that on a 'false' return also returns the largest known divisor).
234 return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
235 (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
236 (l * u) % factor == 0;
237 }
238 case AffineExprKind::Add:
239 case AffineExprKind::FloorDiv:
240 case AffineExprKind::CeilDiv:
241 case AffineExprKind::Mod: {
242 binExpr = cast<AffineBinaryOpExpr>();
243 return llvm::GreatestCommonDivisor64(
244 binExpr.getLHS().getLargestKnownDivisor(),
245 binExpr.getRHS().getLargestKnownDivisor()) %
246 factor ==
247 0;
248 }
249 }
250 llvm_unreachable("Unknown AffineExpr");
251 }
252
isFunctionOfDim(unsigned position) const253 bool AffineExpr::isFunctionOfDim(unsigned position) const {
254 if (getKind() == AffineExprKind::DimId) {
255 return *this == mlir::getAffineDimExpr(position, getContext());
256 }
257 if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
258 return expr.getLHS().isFunctionOfDim(position) ||
259 expr.getRHS().isFunctionOfDim(position);
260 }
261 return false;
262 }
263
AffineBinaryOpExpr(AffineExpr::ImplType * ptr)264 AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
265 : AffineExpr(ptr) {}
getLHS() const266 AffineExpr AffineBinaryOpExpr::getLHS() const {
267 return static_cast<ImplType *>(expr)->lhs;
268 }
getRHS() const269 AffineExpr AffineBinaryOpExpr::getRHS() const {
270 return static_cast<ImplType *>(expr)->rhs;
271 }
272
AffineDimExpr(AffineExpr::ImplType * ptr)273 AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {}
getPosition() const274 unsigned AffineDimExpr::getPosition() const {
275 return static_cast<ImplType *>(expr)->position;
276 }
277
278 /// Returns true if the expression is divisible by the given symbol with
279 /// position `symbolPos`. The argument `opKind` specifies here what kind of
280 /// division or mod operation called this division. It helps in implementing the
281 /// commutative property of the floordiv and ceildiv operations. If the argument
282 ///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
283 /// operation, then the commutative property can be used otherwise, the floordiv
284 /// operation is not divisible. The same argument holds for ceildiv operation.
isDivisibleBySymbol(AffineExpr expr,unsigned symbolPos,AffineExprKind opKind)285 static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
286 AffineExprKind opKind) {
287 // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
288 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
289 opKind == AffineExprKind::CeilDiv) &&
290 "unexpected opKind");
291 switch (expr.getKind()) {
292 case AffineExprKind::Constant:
293 if (expr.cast<AffineConstantExpr>().getValue())
294 return false;
295 return true;
296 case AffineExprKind::DimId:
297 return false;
298 case AffineExprKind::SymbolId:
299 return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos);
300 // Checks divisibility by the given symbol for both operands.
301 case AffineExprKind::Add: {
302 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
303 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
304 isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
305 }
306 // Checks divisibility by the given symbol for both operands. Consider the
307 // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
308 // this is a division by s1 and both the operands of modulo are divisible by
309 // s1 but it is not divisible by s1 always. The third argument is
310 // `AffineExprKind::Mod` for this reason.
311 case AffineExprKind::Mod: {
312 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
313 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
314 AffineExprKind::Mod) &&
315 isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
316 AffineExprKind::Mod);
317 }
318 // Checks if any of the operand divisible by the given symbol.
319 case AffineExprKind::Mul: {
320 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
321 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
322 isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
323 }
324 // Floordiv and ceildiv are divisible by the given symbol when the first
325 // operand is divisible, and the affine expression kind of the argument expr
326 // is same as the argument `opKind`. This can be inferred from commutative
327 // property of floordiv and ceildiv operations and are as follow:
328 // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
329 // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
330 // It will fail if operations are not same. For example:
331 // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
332 case AffineExprKind::FloorDiv:
333 case AffineExprKind::CeilDiv: {
334 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
335 if (opKind != expr.getKind())
336 return false;
337 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
338 }
339 }
340 llvm_unreachable("Unknown AffineExpr");
341 }
342
343 /// Divides the given expression by the given symbol at position `symbolPos`. It
344 /// considers the divisibility condition is checked before calling itself. A
345 /// null expression is returned whenever the divisibility condition fails.
symbolicDivide(AffineExpr expr,unsigned symbolPos,AffineExprKind opKind)346 static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
347 AffineExprKind opKind) {
348 // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
349 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
350 opKind == AffineExprKind::CeilDiv) &&
351 "unexpected opKind");
352 switch (expr.getKind()) {
353 case AffineExprKind::Constant:
354 if (expr.cast<AffineConstantExpr>().getValue() != 0)
355 return nullptr;
356 return getAffineConstantExpr(0, expr.getContext());
357 case AffineExprKind::DimId:
358 return nullptr;
359 case AffineExprKind::SymbolId:
360 return getAffineConstantExpr(1, expr.getContext());
361 // Dividing both operands by the given symbol.
362 case AffineExprKind::Add: {
363 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
364 return getAffineBinaryOpExpr(
365 expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
366 symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
367 }
368 // Dividing both operands by the given symbol.
369 case AffineExprKind::Mod: {
370 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
371 return getAffineBinaryOpExpr(
372 expr.getKind(),
373 symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
374 symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
375 }
376 // Dividing any of the operand by the given symbol.
377 case AffineExprKind::Mul: {
378 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
379 if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
380 return binaryExpr.getLHS() *
381 symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
382 return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
383 binaryExpr.getRHS();
384 }
385 // Dividing first operand only by the given symbol.
386 case AffineExprKind::FloorDiv:
387 case AffineExprKind::CeilDiv: {
388 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
389 return getAffineBinaryOpExpr(
390 expr.getKind(),
391 symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
392 binaryExpr.getRHS());
393 }
394 }
395 llvm_unreachable("Unknown AffineExpr");
396 }
397
398 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
399 /// operations when the second operand simplifies to a symbol and the first
400 /// operand is divisible by that symbol. It can be applied to any semi-affine
401 /// expression. Returned expression can either be a semi-affine or pure affine
402 /// expression.
simplifySemiAffine(AffineExpr expr)403 static AffineExpr simplifySemiAffine(AffineExpr expr) {
404 switch (expr.getKind()) {
405 case AffineExprKind::Constant:
406 case AffineExprKind::DimId:
407 case AffineExprKind::SymbolId:
408 return expr;
409 case AffineExprKind::Add:
410 case AffineExprKind::Mul: {
411 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
412 return getAffineBinaryOpExpr(expr.getKind(),
413 simplifySemiAffine(binaryExpr.getLHS()),
414 simplifySemiAffine(binaryExpr.getRHS()));
415 }
416 // Check if the simplification of the second operand is a symbol, and the
417 // first operand is divisible by it. If the operation is a modulo, a constant
418 // zero expression is returned. In the case of floordiv and ceildiv, the
419 // symbol from the simplification of the second operand divides the first
420 // operand. Otherwise, simplification is not possible.
421 case AffineExprKind::FloorDiv:
422 case AffineExprKind::CeilDiv:
423 case AffineExprKind::Mod: {
424 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
425 AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS());
426 AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS());
427 AffineSymbolExpr symbolExpr =
428 simplifySemiAffine(binaryExpr.getRHS()).dyn_cast<AffineSymbolExpr>();
429 if (!symbolExpr)
430 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
431 unsigned symbolPos = symbolExpr.getPosition();
432 if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()))
433 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
434 if (expr.getKind() == AffineExprKind::Mod)
435 return getAffineConstantExpr(0, expr.getContext());
436 return symbolicDivide(sLHS, symbolPos, expr.getKind());
437 }
438 }
439 llvm_unreachable("Unknown AffineExpr");
440 }
441
getAffineDimOrSymbol(AffineExprKind kind,unsigned position,MLIRContext * context)442 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
443 MLIRContext *context) {
444 auto assignCtx = [context](AffineDimExprStorage *storage) {
445 storage->context = context;
446 };
447
448 StorageUniquer &uniquer = context->getAffineUniquer();
449 return uniquer.get<AffineDimExprStorage>(
450 assignCtx, static_cast<unsigned>(kind), position);
451 }
452
getAffineDimExpr(unsigned position,MLIRContext * context)453 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
454 return getAffineDimOrSymbol(AffineExprKind::DimId, position, context);
455 }
456
AffineSymbolExpr(AffineExpr::ImplType * ptr)457 AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr)
458 : AffineExpr(ptr) {}
getPosition() const459 unsigned AffineSymbolExpr::getPosition() const {
460 return static_cast<ImplType *>(expr)->position;
461 }
462
getAffineSymbolExpr(unsigned position,MLIRContext * context)463 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
464 return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
465 ;
466 }
467
AffineConstantExpr(AffineExpr::ImplType * ptr)468 AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr)
469 : AffineExpr(ptr) {}
getValue() const470 int64_t AffineConstantExpr::getValue() const {
471 return static_cast<ImplType *>(expr)->constant;
472 }
473
operator ==(int64_t v) const474 bool AffineExpr::operator==(int64_t v) const {
475 return *this == getAffineConstantExpr(v, getContext());
476 }
477
getAffineConstantExpr(int64_t constant,MLIRContext * context)478 AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
479 auto assignCtx = [context](AffineConstantExprStorage *storage) {
480 storage->context = context;
481 };
482
483 StorageUniquer &uniquer = context->getAffineUniquer();
484 return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
485 }
486
487 /// Simplify add expression. Return nullptr if it can't be simplified.
simplifyAdd(AffineExpr lhs,AffineExpr rhs)488 static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
489 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
490 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
491 // Fold if both LHS, RHS are a constant.
492 if (lhsConst && rhsConst)
493 return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
494 lhs.getContext());
495
496 // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
497 // If only one of them is a symbolic expressions, make it the RHS.
498 if (lhs.isa<AffineConstantExpr>() ||
499 (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
500 return rhs + lhs;
501 }
502
503 // At this point, if there was a constant, it would be on the right.
504
505 // Addition with a zero is a noop, return the other input.
506 if (rhsConst) {
507 if (rhsConst.getValue() == 0)
508 return lhs;
509 }
510 // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
511 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
512 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
513 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
514 return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
515 }
516
517 // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
518 // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
519 // respective multiplicands.
520 Optional<int64_t> rLhsConst, rRhsConst;
521 AffineExpr firstExpr, secondExpr;
522 AffineConstantExpr rLhsConstExpr;
523 auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>();
524 if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
525 (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
526 rLhsConst = rLhsConstExpr.getValue();
527 firstExpr = lBinOpExpr.getLHS();
528 } else {
529 rLhsConst = 1;
530 firstExpr = lhs;
531 }
532
533 auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
534 AffineConstantExpr rRhsConstExpr;
535 if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
536 (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
537 rRhsConst = rRhsConstExpr.getValue();
538 secondExpr = rBinOpExpr.getLHS();
539 } else {
540 rRhsConst = 1;
541 secondExpr = rhs;
542 }
543
544 if (rLhsConst && rRhsConst && firstExpr == secondExpr)
545 return getAffineBinaryOpExpr(
546 AffineExprKind::Mul, firstExpr,
547 getAffineConstantExpr(rLhsConst.getValue() + rRhsConst.getValue(),
548 lhs.getContext()));
549
550 // When doing successive additions, bring constant to the right: turn (d0 + 2)
551 // + d1 into (d0 + d1) + 2.
552 if (lBin && lBin.getKind() == AffineExprKind::Add) {
553 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
554 return lBin.getLHS() + rhs + lrhs;
555 }
556 }
557
558 // Detect and transform "expr - c * (expr floordiv c)" to "expr mod c". This
559 // leads to a much more efficient form when 'c' is a power of two, and in
560 // general a more compact and readable form.
561
562 // Process '(expr floordiv c) * (-c)'.
563 if (!rBinOpExpr)
564 return nullptr;
565
566 auto lrhs = rBinOpExpr.getLHS();
567 auto rrhs = rBinOpExpr.getRHS();
568
569 // Process lrhs, which is 'expr floordiv c'.
570 AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
571 if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
572 return nullptr;
573
574 auto llrhs = lrBinOpExpr.getLHS();
575 auto rlrhs = lrBinOpExpr.getRHS();
576
577 if (lhs == llrhs && rlrhs == -rrhs) {
578 return lhs % rlrhs;
579 }
580 return nullptr;
581 }
582
operator +(int64_t v) const583 AffineExpr AffineExpr::operator+(int64_t v) const {
584 return *this + getAffineConstantExpr(v, getContext());
585 }
operator +(AffineExpr other) const586 AffineExpr AffineExpr::operator+(AffineExpr other) const {
587 if (auto simplified = simplifyAdd(*this, other))
588 return simplified;
589
590 StorageUniquer &uniquer = getContext()->getAffineUniquer();
591 return uniquer.get<AffineBinaryOpExprStorage>(
592 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
593 }
594
595 /// Simplify a multiply expression. Return nullptr if it can't be simplified.
simplifyMul(AffineExpr lhs,AffineExpr rhs)596 static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
597 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
598 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
599
600 if (lhsConst && rhsConst)
601 return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
602 lhs.getContext());
603
604 assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());
605
606 // Canonicalize the mul expression so that the constant/symbolic term is the
607 // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
608 // constant. (Note that a constant is trivially symbolic).
609 if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) {
610 // At least one of them has to be symbolic.
611 return rhs * lhs;
612 }
613
614 // At this point, if there was a constant, it would be on the right.
615
616 // Multiplication with a one is a noop, return the other input.
617 if (rhsConst) {
618 if (rhsConst.getValue() == 1)
619 return lhs;
620 // Multiplication with zero.
621 if (rhsConst.getValue() == 0)
622 return rhsConst;
623 }
624
625 // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
626 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
627 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
628 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
629 return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
630 }
631
632 // When doing successive multiplication, bring constant to the right: turn (d0
633 // * 2) * d1 into (d0 * d1) * 2.
634 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
635 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
636 return (lBin.getLHS() * rhs) * lrhs;
637 }
638 }
639
640 return nullptr;
641 }
642
operator *(int64_t v) const643 AffineExpr AffineExpr::operator*(int64_t v) const {
644 return *this * getAffineConstantExpr(v, getContext());
645 }
operator *(AffineExpr other) const646 AffineExpr AffineExpr::operator*(AffineExpr other) const {
647 if (auto simplified = simplifyMul(*this, other))
648 return simplified;
649
650 StorageUniquer &uniquer = getContext()->getAffineUniquer();
651 return uniquer.get<AffineBinaryOpExprStorage>(
652 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
653 }
654
655 // Unary minus, delegate to operator*.
operator -() const656 AffineExpr AffineExpr::operator-() const {
657 return *this * getAffineConstantExpr(-1, getContext());
658 }
659
660 // Delegate to operator+.
operator -(int64_t v) const661 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
operator -(AffineExpr other) const662 AffineExpr AffineExpr::operator-(AffineExpr other) const {
663 return *this + (-other);
664 }
665
simplifyFloorDiv(AffineExpr lhs,AffineExpr rhs)666 static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
667 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
668 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
669
670 // mlir floordiv by zero or negative numbers is undefined and preserved as is.
671 if (!rhsConst || rhsConst.getValue() < 1)
672 return nullptr;
673
674 if (lhsConst)
675 return getAffineConstantExpr(
676 floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
677
678 // Fold floordiv of a multiply with a constant that is a multiple of the
679 // divisor. Eg: (i * 128) floordiv 64 = i * 2.
680 if (rhsConst == 1)
681 return lhs;
682
683 // Simplify (expr * const) floordiv divConst when expr is known to be a
684 // multiple of divConst.
685 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
686 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
687 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
688 // rhsConst is known to be a positive constant.
689 if (lrhs.getValue() % rhsConst.getValue() == 0)
690 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
691 }
692 }
693
694 // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
695 // known to be a multiple of divConst.
696 if (lBin && lBin.getKind() == AffineExprKind::Add) {
697 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
698 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
699 // rhsConst is known to be a positive constant.
700 if (llhsDiv % rhsConst.getValue() == 0 ||
701 lrhsDiv % rhsConst.getValue() == 0)
702 return lBin.getLHS().floorDiv(rhsConst.getValue()) +
703 lBin.getRHS().floorDiv(rhsConst.getValue());
704 }
705
706 return nullptr;
707 }
708
floorDiv(uint64_t v) const709 AffineExpr AffineExpr::floorDiv(uint64_t v) const {
710 return floorDiv(getAffineConstantExpr(v, getContext()));
711 }
floorDiv(AffineExpr other) const712 AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
713 if (auto simplified = simplifyFloorDiv(*this, other))
714 return simplified;
715
716 StorageUniquer &uniquer = getContext()->getAffineUniquer();
717 return uniquer.get<AffineBinaryOpExprStorage>(
718 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
719 other);
720 }
721
simplifyCeilDiv(AffineExpr lhs,AffineExpr rhs)722 static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
723 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
724 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
725
726 if (!rhsConst || rhsConst.getValue() < 1)
727 return nullptr;
728
729 if (lhsConst)
730 return getAffineConstantExpr(
731 ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
732
733 // Fold ceildiv of a multiply with a constant that is a multiple of the
734 // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
735 if (rhsConst.getValue() == 1)
736 return lhs;
737
738 // Simplify (expr * const) ceildiv divConst when const is known to be a
739 // multiple of divConst.
740 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
741 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
742 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
743 // rhsConst is known to be a positive constant.
744 if (lrhs.getValue() % rhsConst.getValue() == 0)
745 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
746 }
747 }
748
749 return nullptr;
750 }
751
ceilDiv(uint64_t v) const752 AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
753 return ceilDiv(getAffineConstantExpr(v, getContext()));
754 }
ceilDiv(AffineExpr other) const755 AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
756 if (auto simplified = simplifyCeilDiv(*this, other))
757 return simplified;
758
759 StorageUniquer &uniquer = getContext()->getAffineUniquer();
760 return uniquer.get<AffineBinaryOpExprStorage>(
761 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
762 other);
763 }
764
simplifyMod(AffineExpr lhs,AffineExpr rhs)765 static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
766 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
767 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
768
769 // mod w.r.t zero or negative numbers is undefined and preserved as is.
770 if (!rhsConst || rhsConst.getValue() < 1)
771 return nullptr;
772
773 if (lhsConst)
774 return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
775 lhs.getContext());
776
777 // Fold modulo of an expression that is known to be a multiple of a constant
778 // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
779 // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
780 if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
781 return getAffineConstantExpr(0, lhs.getContext());
782
783 // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
784 // known to be a multiple of divConst.
785 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
786 if (lBin && lBin.getKind() == AffineExprKind::Add) {
787 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
788 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
789 // rhsConst is known to be a positive constant.
790 if (llhsDiv % rhsConst.getValue() == 0)
791 return lBin.getRHS() % rhsConst.getValue();
792 if (lrhsDiv % rhsConst.getValue() == 0)
793 return lBin.getLHS() % rhsConst.getValue();
794 }
795
796 return nullptr;
797 }
798
operator %(uint64_t v) const799 AffineExpr AffineExpr::operator%(uint64_t v) const {
800 return *this % getAffineConstantExpr(v, getContext());
801 }
operator %(AffineExpr other) const802 AffineExpr AffineExpr::operator%(AffineExpr other) const {
803 if (auto simplified = simplifyMod(*this, other))
804 return simplified;
805
806 StorageUniquer &uniquer = getContext()->getAffineUniquer();
807 return uniquer.get<AffineBinaryOpExprStorage>(
808 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
809 }
810
compose(AffineMap map) const811 AffineExpr AffineExpr::compose(AffineMap map) const {
812 SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(),
813 map.getResults().end());
814 return replaceDimsAndSymbols(dimReplacements, {});
815 }
operator <<(raw_ostream & os,AffineExpr expr)816 raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
817 expr.print(os);
818 return os;
819 }
820
821 /// Constructs an affine expression from a flat ArrayRef. If there are local
822 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
823 /// products expression, `localExprs` is expected to have the AffineExpr
824 /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
825 /// in the format [dims, symbols, locals, constant term].
getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,unsigned numDims,unsigned numSymbols,ArrayRef<AffineExpr> localExprs,MLIRContext * context)826 AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
827 unsigned numDims,
828 unsigned numSymbols,
829 ArrayRef<AffineExpr> localExprs,
830 MLIRContext *context) {
831 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
832 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
833 "unexpected number of local expressions");
834
835 auto expr = getAffineConstantExpr(0, context);
836 // Dimensions and symbols.
837 for (unsigned j = 0; j < numDims + numSymbols; j++) {
838 if (flatExprs[j] == 0)
839 continue;
840 auto id = j < numDims ? getAffineDimExpr(j, context)
841 : getAffineSymbolExpr(j - numDims, context);
842 expr = expr + id * flatExprs[j];
843 }
844
845 // Local identifiers.
846 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
847 j++) {
848 if (flatExprs[j] == 0)
849 continue;
850 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
851 expr = expr + term;
852 }
853
854 // Constant term.
855 int64_t constTerm = flatExprs[flatExprs.size() - 1];
856 if (constTerm != 0)
857 expr = expr + constTerm;
858 return expr;
859 }
860
SimpleAffineExprFlattener(unsigned numDims,unsigned numSymbols)861 SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
862 unsigned numSymbols)
863 : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
864 operandExprStack.reserve(8);
865 }
866
visitMulExpr(AffineBinaryOpExpr expr)867 void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
868 assert(operandExprStack.size() >= 2);
869 // This is a pure affine expr; the RHS will be a constant.
870 assert(expr.getRHS().isa<AffineConstantExpr>());
871 // Get the RHS constant.
872 auto rhsConst = operandExprStack.back()[getConstantIndex()];
873 operandExprStack.pop_back();
874 // Update the LHS in place instead of pop and push.
875 auto &lhs = operandExprStack.back();
876 for (unsigned i = 0, e = lhs.size(); i < e; i++) {
877 lhs[i] *= rhsConst;
878 }
879 }
880
visitAddExpr(AffineBinaryOpExpr expr)881 void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
882 assert(operandExprStack.size() >= 2);
883 const auto &rhs = operandExprStack.back();
884 auto &lhs = operandExprStack[operandExprStack.size() - 2];
885 assert(lhs.size() == rhs.size());
886 // Update the LHS in place.
887 for (unsigned i = 0, e = rhs.size(); i < e; i++) {
888 lhs[i] += rhs[i];
889 }
890 // Pop off the RHS.
891 operandExprStack.pop_back();
892 }
893
894 //
895 // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
896 //
897 // A mod expression "expr mod c" is thus flattened by introducing a new local
898 // variable q (= expr floordiv c), such that expr mod c is replaced with
899 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
visitModExpr(AffineBinaryOpExpr expr)900 void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
901 assert(operandExprStack.size() >= 2);
902 // This is a pure affine expr; the RHS will be a constant.
903 assert(expr.getRHS().isa<AffineConstantExpr>());
904 auto rhsConst = operandExprStack.back()[getConstantIndex()];
905 operandExprStack.pop_back();
906 auto &lhs = operandExprStack.back();
907 // TODO: handle modulo by zero case when this issue is fixed
908 // at the other places in the IR.
909 assert(rhsConst > 0 && "RHS constant has to be positive");
910
911 // Check if the LHS expression is a multiple of modulo factor.
912 unsigned i, e;
913 for (i = 0, e = lhs.size(); i < e; i++)
914 if (lhs[i] % rhsConst != 0)
915 break;
916 // If yes, modulo expression here simplifies to zero.
917 if (i == lhs.size()) {
918 std::fill(lhs.begin(), lhs.end(), 0);
919 return;
920 }
921
922 // Add a local variable for the quotient, i.e., expr % c is replaced by
923 // (expr - q * c) where q = expr floordiv c. Do this while canceling out
924 // the GCD of expr and c.
925 SmallVector<int64_t, 8> floorDividend(lhs);
926 uint64_t gcd = rhsConst;
927 for (unsigned i = 0, e = lhs.size(); i < e; i++)
928 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
929 // Simplify the numerator and the denominator.
930 if (gcd != 1) {
931 for (unsigned i = 0, e = floorDividend.size(); i < e; i++)
932 floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd);
933 }
934 int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
935
936 // Construct the AffineExpr form of the floordiv to store in localExprs.
937 MLIRContext *context = expr.getContext();
938 auto dividendExpr = getAffineExprFromFlatForm(
939 floorDividend, numDims, numSymbols, localExprs, context);
940 auto divisorExpr = getAffineConstantExpr(floorDivisor, context);
941 auto floorDivExpr = dividendExpr.floorDiv(divisorExpr);
942 int loc;
943 if ((loc = findLocalId(floorDivExpr)) == -1) {
944 addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
945 // Set result at top of stack to "lhs - rhsConst * q".
946 lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
947 } else {
948 // Reuse the existing local id.
949 lhs[getLocalVarStartIndex() + loc] = -rhsConst;
950 }
951 }
952
visitCeilDivExpr(AffineBinaryOpExpr expr)953 void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
954 visitDivExpr(expr, /*isCeil=*/true);
955 }
visitFloorDivExpr(AffineBinaryOpExpr expr)956 void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
957 visitDivExpr(expr, /*isCeil=*/false);
958 }
959
visitDimExpr(AffineDimExpr expr)960 void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
961 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
962 auto &eq = operandExprStack.back();
963 assert(expr.getPosition() < numDims && "Inconsistent number of dims");
964 eq[getDimStartIndex() + expr.getPosition()] = 1;
965 }
966
visitSymbolExpr(AffineSymbolExpr expr)967 void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
968 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
969 auto &eq = operandExprStack.back();
970 assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
971 eq[getSymbolStartIndex() + expr.getPosition()] = 1;
972 }
973
visitConstantExpr(AffineConstantExpr expr)974 void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
975 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
976 auto &eq = operandExprStack.back();
977 eq[getConstantIndex()] = expr.getValue();
978 }
979
980 // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
981 // A floordiv is thus flattened by introducing a new local variable q, and
982 // replacing that expression with 'q' while adding the constraints
983 // c * q <= expr <= c * q + c - 1 to localVarCst (done by
984 // FlatAffineConstraints::addLocalFloorDiv).
985 //
986 // A ceildiv is similarly flattened:
987 // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
visitDivExpr(AffineBinaryOpExpr expr,bool isCeil)988 void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
989 bool isCeil) {
990 assert(operandExprStack.size() >= 2);
991 assert(expr.getRHS().isa<AffineConstantExpr>());
992
993 // This is a pure affine expr; the RHS is a positive constant.
994 int64_t rhsConst = operandExprStack.back()[getConstantIndex()];
995 // TODO: handle division by zero at the same time the issue is
996 // fixed at other places.
997 assert(rhsConst > 0 && "RHS constant has to be positive");
998 operandExprStack.pop_back();
999 auto &lhs = operandExprStack.back();
1000
1001 // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1002 // common divisors of the numerator and denominator.
1003 uint64_t gcd = std::abs(rhsConst);
1004 for (unsigned i = 0, e = lhs.size(); i < e; i++)
1005 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
1006 // Simplify the numerator and the denominator.
1007 if (gcd != 1) {
1008 for (unsigned i = 0, e = lhs.size(); i < e; i++)
1009 lhs[i] = lhs[i] / static_cast<int64_t>(gcd);
1010 }
1011 int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1012 // If the divisor becomes 1, the updated LHS is the result. (The
1013 // divisor can't be negative since rhsConst is positive).
1014 if (divisor == 1)
1015 return;
1016
1017 // If the divisor cannot be simplified to one, we will have to retain
1018 // the ceil/floor expr (simplified up until here). Add an existential
1019 // quantifier to express its result, i.e., expr1 div expr2 is replaced
1020 // by a new identifier, q.
1021 MLIRContext *context = expr.getContext();
1022 auto a =
1023 getAffineExprFromFlatForm(lhs, numDims, numSymbols, localExprs, context);
1024 auto b = getAffineConstantExpr(divisor, context);
1025
1026 int loc;
1027 auto divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1028 if ((loc = findLocalId(divExpr)) == -1) {
1029 if (!isCeil) {
1030 SmallVector<int64_t, 8> dividend(lhs);
1031 addLocalFloorDivId(dividend, divisor, divExpr);
1032 } else {
1033 // lhs ceildiv c <=> (lhs + c - 1) floordiv c
1034 SmallVector<int64_t, 8> dividend(lhs);
1035 dividend.back() += divisor - 1;
1036 addLocalFloorDivId(dividend, divisor, divExpr);
1037 }
1038 }
1039 // Set the expression on stack to the local var introduced to capture the
1040 // result of the division (floor or ceil).
1041 std::fill(lhs.begin(), lhs.end(), 0);
1042 if (loc == -1)
1043 lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1044 else
1045 lhs[getLocalVarStartIndex() + loc] = 1;
1046 }
1047
1048 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1049 // The local identifier added is always a floordiv of a pure add/mul affine
1050 // function of other identifiers, coefficients of which are specified in
1051 // dividend and with respect to a positive constant divisor. localExpr is the
1052 // simplified tree expression (AffineExpr) corresponding to the quantifier.
addLocalFloorDivId(ArrayRef<int64_t> dividend,int64_t divisor,AffineExpr localExpr)1053 void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
1054 int64_t divisor,
1055 AffineExpr localExpr) {
1056 assert(divisor > 0 && "positive constant divisor expected");
1057 for (auto &subExpr : operandExprStack)
1058 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1059 localExprs.push_back(localExpr);
1060 numLocals++;
1061 // dividend and divisor are not used here; an override of this method uses it.
1062 }
1063
findLocalId(AffineExpr localExpr)1064 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1065 SmallVectorImpl<AffineExpr>::iterator it;
1066 if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
1067 return -1;
1068 return it - localExprs.begin();
1069 }
1070
1071 /// Simplify the affine expression by flattening it and reconstructing it.
simplifyAffineExpr(AffineExpr expr,unsigned numDims,unsigned numSymbols)1072 AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
1073 unsigned numSymbols) {
1074 // Simplify semi-affine expressions separately.
1075 if (!expr.isPureAffine())
1076 expr = simplifySemiAffine(expr);
1077 if (!expr.isPureAffine())
1078 return expr;
1079
1080 SimpleAffineExprFlattener flattener(numDims, numSymbols);
1081 flattener.walkPostOrder(expr);
1082 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1083 auto simplifiedExpr =
1084 getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1085 flattener.localExprs, expr.getContext());
1086 flattener.operandExprStack.pop_back();
1087 assert(flattener.operandExprStack.empty());
1088
1089 return simplifiedExpr;
1090 }
1091