• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Structures for affine/polyhedral analysis of affine dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/AffineStructures.h"
14 #include "mlir/Analysis/Presburger/Simplex.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/AffineExprVisitor.h"
19 #include "mlir/IR/IntegerSet.h"
20 #include "mlir/Support/LLVM.h"
21 #include "mlir/Support/MathExtras.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/raw_ostream.h"
25 
26 #define DEBUG_TYPE "affine-structures"
27 
28 using namespace mlir;
29 using llvm::SmallDenseMap;
30 using llvm::SmallDenseSet;
31 
32 namespace {
33 
34 // See comments for SimpleAffineExprFlattener.
35 // An AffineExprFlattener extends a SimpleAffineExprFlattener by recording
36 // constraint information associated with mod's, floordiv's, and ceildiv's
37 // in FlatAffineConstraints 'localVarCst'.
38 struct AffineExprFlattener : public SimpleAffineExprFlattener {
39 public:
40   // Constraints connecting newly introduced local variables (for mod's and
41   // div's) to existing (dimensional and symbolic) ones. These are always
42   // inequalities.
43   FlatAffineConstraints localVarCst;
44 
AffineExprFlattener__anon0aa6b4c00111::AffineExprFlattener45   AffineExprFlattener(unsigned nDims, unsigned nSymbols, MLIRContext *ctx)
46       : SimpleAffineExprFlattener(nDims, nSymbols) {
47     localVarCst.reset(nDims, nSymbols, /*numLocals=*/0);
48   }
49 
50 private:
51   // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
52   // The local identifier added is always a floordiv of a pure add/mul affine
53   // function of other identifiers, coefficients of which are specified in
54   // `dividend' and with respect to the positive constant `divisor'. localExpr
55   // is the simplified tree expression (AffineExpr) corresponding to the
56   // quantifier.
addLocalFloorDivId__anon0aa6b4c00111::AffineExprFlattener57   void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
58                           AffineExpr localExpr) override {
59     SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
60     // Update localVarCst.
61     localVarCst.addLocalFloorDiv(dividend, divisor);
62   }
63 };
64 
65 } // end anonymous namespace
66 
67 // Flattens the expressions in map. Returns failure if 'expr' was unable to be
68 // flattened (i.e., semi-affine expressions not handled yet).
69 static LogicalResult
getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs,unsigned numDims,unsigned numSymbols,std::vector<SmallVector<int64_t,8>> * flattenedExprs,FlatAffineConstraints * localVarCst)70 getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
71                         unsigned numSymbols,
72                         std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
73                         FlatAffineConstraints *localVarCst) {
74   if (exprs.empty()) {
75     localVarCst->reset(numDims, numSymbols);
76     return success();
77   }
78 
79   AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext());
80   // Use the same flattener to simplify each expression successively. This way
81   // local identifiers / expressions are shared.
82   for (auto expr : exprs) {
83     if (!expr.isPureAffine())
84       return failure();
85 
86     flattener.walkPostOrder(expr);
87   }
88 
89   assert(flattener.operandExprStack.size() == exprs.size());
90   flattenedExprs->clear();
91   flattenedExprs->assign(flattener.operandExprStack.begin(),
92                          flattener.operandExprStack.end());
93 
94   if (localVarCst)
95     localVarCst->clearAndCopyFrom(flattener.localVarCst);
96 
97   return success();
98 }
99 
100 // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
101 // be flattened (semi-affine expressions not handled yet).
102 LogicalResult
getFlattenedAffineExpr(AffineExpr expr,unsigned numDims,unsigned numSymbols,SmallVectorImpl<int64_t> * flattenedExpr,FlatAffineConstraints * localVarCst)103 mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
104                              unsigned numSymbols,
105                              SmallVectorImpl<int64_t> *flattenedExpr,
106                              FlatAffineConstraints *localVarCst) {
107   std::vector<SmallVector<int64_t, 8>> flattenedExprs;
108   LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
109                                                 &flattenedExprs, localVarCst);
110   *flattenedExpr = flattenedExprs[0];
111   return ret;
112 }
113 
114 /// Flattens the expressions in map. Returns failure if 'expr' was unable to be
115 /// flattened (i.e., semi-affine expressions not handled yet).
getFlattenedAffineExprs(AffineMap map,std::vector<SmallVector<int64_t,8>> * flattenedExprs,FlatAffineConstraints * localVarCst)116 LogicalResult mlir::getFlattenedAffineExprs(
117     AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
118     FlatAffineConstraints *localVarCst) {
119   if (map.getNumResults() == 0) {
120     localVarCst->reset(map.getNumDims(), map.getNumSymbols());
121     return success();
122   }
123   return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
124                                    map.getNumSymbols(), flattenedExprs,
125                                    localVarCst);
126 }
127 
getFlattenedAffineExprs(IntegerSet set,std::vector<SmallVector<int64_t,8>> * flattenedExprs,FlatAffineConstraints * localVarCst)128 LogicalResult mlir::getFlattenedAffineExprs(
129     IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
130     FlatAffineConstraints *localVarCst) {
131   if (set.getNumConstraints() == 0) {
132     localVarCst->reset(set.getNumDims(), set.getNumSymbols());
133     return success();
134   }
135   return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
136                                    set.getNumSymbols(), flattenedExprs,
137                                    localVarCst);
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // FlatAffineConstraints.
142 //===----------------------------------------------------------------------===//
143 
144 // Copy constructor.
FlatAffineConstraints(const FlatAffineConstraints & other)145 FlatAffineConstraints::FlatAffineConstraints(
146     const FlatAffineConstraints &other) {
147   numReservedCols = other.numReservedCols;
148   numDims = other.getNumDimIds();
149   numSymbols = other.getNumSymbolIds();
150   numIds = other.getNumIds();
151 
152   auto otherIds = other.getIds();
153   ids.reserve(numReservedCols);
154   ids.append(otherIds.begin(), otherIds.end());
155 
156   unsigned numReservedEqualities = other.getNumReservedEqualities();
157   unsigned numReservedInequalities = other.getNumReservedInequalities();
158 
159   equalities.reserve(numReservedEqualities * numReservedCols);
160   inequalities.reserve(numReservedInequalities * numReservedCols);
161 
162   for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
163     addInequality(other.getInequality(r));
164   }
165   for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
166     addEquality(other.getEquality(r));
167   }
168 }
169 
170 // Clones this object.
clone() const171 std::unique_ptr<FlatAffineConstraints> FlatAffineConstraints::clone() const {
172   return std::make_unique<FlatAffineConstraints>(*this);
173 }
174 
175 // Construct from an IntegerSet.
FlatAffineConstraints(IntegerSet set)176 FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
177     : numReservedCols(set.getNumInputs() + 1),
178       numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()),
179       numSymbols(set.getNumSymbols()) {
180   equalities.reserve(set.getNumEqualities() * numReservedCols);
181   inequalities.reserve(set.getNumInequalities() * numReservedCols);
182   ids.resize(numIds, None);
183 
184   // Flatten expressions and add them to the constraint system.
185   std::vector<SmallVector<int64_t, 8>> flatExprs;
186   FlatAffineConstraints localVarCst;
187   if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) {
188     assert(false && "flattening unimplemented for semi-affine integer sets");
189     return;
190   }
191   assert(flatExprs.size() == set.getNumConstraints());
192   for (unsigned l = 0, e = localVarCst.getNumLocalIds(); l < e; l++) {
193     addLocalId(getNumLocalIds());
194   }
195 
196   for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
197     const auto &flatExpr = flatExprs[i];
198     assert(flatExpr.size() == getNumCols());
199     if (set.getEqFlags()[i]) {
200       addEquality(flatExpr);
201     } else {
202       addInequality(flatExpr);
203     }
204   }
205   // Add the other constraints involving local id's from flattening.
206   append(localVarCst);
207 }
208 
reset(unsigned numReservedInequalities,unsigned numReservedEqualities,unsigned newNumReservedCols,unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals,ArrayRef<Value> idArgs)209 void FlatAffineConstraints::reset(unsigned numReservedInequalities,
210                                   unsigned numReservedEqualities,
211                                   unsigned newNumReservedCols,
212                                   unsigned newNumDims, unsigned newNumSymbols,
213                                   unsigned newNumLocals,
214                                   ArrayRef<Value> idArgs) {
215   assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
216          "minimum 1 column");
217   numReservedCols = newNumReservedCols;
218   numDims = newNumDims;
219   numSymbols = newNumSymbols;
220   numIds = numDims + numSymbols + newNumLocals;
221   assert(idArgs.empty() || idArgs.size() == numIds);
222 
223   clearConstraints();
224   if (numReservedEqualities >= 1)
225     equalities.reserve(newNumReservedCols * numReservedEqualities);
226   if (numReservedInequalities >= 1)
227     inequalities.reserve(newNumReservedCols * numReservedInequalities);
228   if (idArgs.empty()) {
229     ids.resize(numIds, None);
230   } else {
231     ids.assign(idArgs.begin(), idArgs.end());
232   }
233 }
234 
reset(unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals,ArrayRef<Value> idArgs)235 void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols,
236                                   unsigned newNumLocals,
237                                   ArrayRef<Value> idArgs) {
238   reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
239         newNumSymbols, newNumLocals, idArgs);
240 }
241 
append(const FlatAffineConstraints & other)242 void FlatAffineConstraints::append(const FlatAffineConstraints &other) {
243   assert(other.getNumCols() == getNumCols());
244   assert(other.getNumDimIds() == getNumDimIds());
245   assert(other.getNumSymbolIds() == getNumSymbolIds());
246 
247   inequalities.reserve(inequalities.size() +
248                        other.getNumInequalities() * numReservedCols);
249   equalities.reserve(equalities.size() +
250                      other.getNumEqualities() * numReservedCols);
251 
252   for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
253     addInequality(other.getInequality(r));
254   }
255   for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
256     addEquality(other.getEquality(r));
257   }
258 }
259 
addLocalId(unsigned pos)260 void FlatAffineConstraints::addLocalId(unsigned pos) {
261   addId(IdKind::Local, pos);
262 }
263 
addDimId(unsigned pos,Value id)264 void FlatAffineConstraints::addDimId(unsigned pos, Value id) {
265   addId(IdKind::Dimension, pos, id);
266 }
267 
addSymbolId(unsigned pos,Value id)268 void FlatAffineConstraints::addSymbolId(unsigned pos, Value id) {
269   addId(IdKind::Symbol, pos, id);
270 }
271 
272 /// Adds a dimensional identifier. The added column is initialized to
273 /// zero.
addId(IdKind kind,unsigned pos,Value id)274 void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value id) {
275   if (kind == IdKind::Dimension)
276     assert(pos <= getNumDimIds());
277   else if (kind == IdKind::Symbol)
278     assert(pos <= getNumSymbolIds());
279   else
280     assert(pos <= getNumLocalIds());
281 
282   unsigned oldNumReservedCols = numReservedCols;
283 
284   // Check if a resize is necessary.
285   if (getNumCols() + 1 > numReservedCols) {
286     equalities.resize(getNumEqualities() * (getNumCols() + 1));
287     inequalities.resize(getNumInequalities() * (getNumCols() + 1));
288     numReservedCols++;
289   }
290 
291   int absolutePos;
292 
293   if (kind == IdKind::Dimension) {
294     absolutePos = pos;
295     numDims++;
296   } else if (kind == IdKind::Symbol) {
297     absolutePos = pos + getNumDimIds();
298     numSymbols++;
299   } else {
300     absolutePos = pos + getNumDimIds() + getNumSymbolIds();
301   }
302   numIds++;
303 
304   // Note that getNumCols() now will already return the new size, which will be
305   // at least one.
306   int numInequalities = static_cast<int>(getNumInequalities());
307   int numEqualities = static_cast<int>(getNumEqualities());
308   int numCols = static_cast<int>(getNumCols());
309   for (int r = numInequalities - 1; r >= 0; r--) {
310     for (int c = numCols - 2; c >= 0; c--) {
311       if (c < absolutePos)
312         atIneq(r, c) = inequalities[r * oldNumReservedCols + c];
313       else
314         atIneq(r, c + 1) = inequalities[r * oldNumReservedCols + c];
315     }
316     atIneq(r, absolutePos) = 0;
317   }
318 
319   for (int r = numEqualities - 1; r >= 0; r--) {
320     for (int c = numCols - 2; c >= 0; c--) {
321       // All values in column absolutePositions < absolutePos have the same
322       // coordinates in the 2-d view of the coefficient buffer.
323       if (c < absolutePos)
324         atEq(r, c) = equalities[r * oldNumReservedCols + c];
325       else
326         // Those at absolutePosition >= absolutePos, get a shifted
327         // absolutePosition.
328         atEq(r, c + 1) = equalities[r * oldNumReservedCols + c];
329     }
330     // Initialize added dimension to zero.
331     atEq(r, absolutePos) = 0;
332   }
333 
334   // If an 'id' is provided, insert it; otherwise use None.
335   if (id)
336     ids.insert(ids.begin() + absolutePos, id);
337   else
338     ids.insert(ids.begin() + absolutePos, None);
339   assert(ids.size() == getNumIds());
340 }
341 
342 /// Checks if two constraint systems are in the same space, i.e., if they are
343 /// associated with the same set of identifiers, appearing in the same order.
areIdsAligned(const FlatAffineConstraints & A,const FlatAffineConstraints & B)344 static bool areIdsAligned(const FlatAffineConstraints &A,
345                           const FlatAffineConstraints &B) {
346   return A.getNumDimIds() == B.getNumDimIds() &&
347          A.getNumSymbolIds() == B.getNumSymbolIds() &&
348          A.getNumIds() == B.getNumIds() && A.getIds().equals(B.getIds());
349 }
350 
351 /// Calls areIdsAligned to check if two constraint systems have the same set
352 /// of identifiers in the same order.
areIdsAlignedWithOther(const FlatAffineConstraints & other)353 bool FlatAffineConstraints::areIdsAlignedWithOther(
354     const FlatAffineConstraints &other) {
355   return areIdsAligned(*this, other);
356 }
357 
358 /// Checks if the SSA values associated with `cst''s identifiers are unique.
359 static bool LLVM_ATTRIBUTE_UNUSED
areIdsUnique(const FlatAffineConstraints & cst)360 areIdsUnique(const FlatAffineConstraints &cst) {
361   SmallPtrSet<Value, 8> uniqueIds;
362   for (auto id : cst.getIds()) {
363     if (id.hasValue() && !uniqueIds.insert(id.getValue()).second)
364       return false;
365   }
366   return true;
367 }
368 
369 /// Merge and align the identifiers of A and B starting at 'offset', so that
370 /// both constraint systems get the union of the contained identifiers that is
371 /// dimension-wise and symbol-wise unique; both constraint systems are updated
372 /// so that they have the union of all identifiers, with A's original
373 /// identifiers appearing first followed by any of B's identifiers that didn't
374 /// appear in A. Local identifiers of each system are by design separate/local
375 /// and are placed one after other (A's followed by B's).
376 //  Eg: Input: A has ((%i %j) [%M %N]) and B has (%k, %j) [%P, %N, %M])
377 //      Output: both A, B have (%i, %j, %k) [%M, %N, %P]
378 //
mergeAndAlignIds(unsigned offset,FlatAffineConstraints * A,FlatAffineConstraints * B)379 static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A,
380                              FlatAffineConstraints *B) {
381   assert(offset <= A->getNumDimIds() && offset <= B->getNumDimIds());
382   // A merge/align isn't meaningful if a cst's ids aren't distinct.
383   assert(areIdsUnique(*A) && "A's id values aren't unique");
384   assert(areIdsUnique(*B) && "B's id values aren't unique");
385 
386   assert(std::all_of(A->getIds().begin() + offset,
387                      A->getIds().begin() + A->getNumDimAndSymbolIds(),
388                      [](Optional<Value> id) { return id.hasValue(); }));
389 
390   assert(std::all_of(B->getIds().begin() + offset,
391                      B->getIds().begin() + B->getNumDimAndSymbolIds(),
392                      [](Optional<Value> id) { return id.hasValue(); }));
393 
394   // Place local id's of A after local id's of B.
395   for (unsigned l = 0, e = A->getNumLocalIds(); l < e; l++) {
396     B->addLocalId(0);
397   }
398   for (unsigned t = 0, e = B->getNumLocalIds() - A->getNumLocalIds(); t < e;
399        t++) {
400     A->addLocalId(A->getNumLocalIds());
401   }
402 
403   SmallVector<Value, 4> aDimValues, aSymValues;
404   A->getIdValues(offset, A->getNumDimIds(), &aDimValues);
405   A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &aSymValues);
406   {
407     // Merge dims from A into B.
408     unsigned d = offset;
409     for (auto aDimValue : aDimValues) {
410       unsigned loc;
411       if (B->findId(aDimValue, &loc)) {
412         assert(loc >= offset && "A's dim appears in B's aligned range");
413         assert(loc < B->getNumDimIds() &&
414                "A's dim appears in B's non-dim position");
415         B->swapId(d, loc);
416       } else {
417         B->addDimId(d);
418         B->setIdValue(d, aDimValue);
419       }
420       d++;
421     }
422 
423     // Dimensions that are in B, but not in A, are added at the end.
424     for (unsigned t = A->getNumDimIds(), e = B->getNumDimIds(); t < e; t++) {
425       A->addDimId(A->getNumDimIds());
426       A->setIdValue(A->getNumDimIds() - 1, B->getIdValue(t));
427     }
428   }
429   {
430     // Merge symbols: merge A's symbols into B first.
431     unsigned s = B->getNumDimIds();
432     for (auto aSymValue : aSymValues) {
433       unsigned loc;
434       if (B->findId(aSymValue, &loc)) {
435         assert(loc >= B->getNumDimIds() && loc < B->getNumDimAndSymbolIds() &&
436                "A's symbol appears in B's non-symbol position");
437         B->swapId(s, loc);
438       } else {
439         B->addSymbolId(s - B->getNumDimIds());
440         B->setIdValue(s, aSymValue);
441       }
442       s++;
443     }
444     // Symbols that are in B, but not in A, are added at the end.
445     for (unsigned t = A->getNumDimAndSymbolIds(),
446                   e = B->getNumDimAndSymbolIds();
447          t < e; t++) {
448       A->addSymbolId(A->getNumSymbolIds());
449       A->setIdValue(A->getNumDimAndSymbolIds() - 1, B->getIdValue(t));
450     }
451   }
452   assert(areIdsAligned(*A, *B) && "IDs expected to be aligned");
453 }
454 
455 // Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'.
mergeAndAlignIdsWithOther(unsigned offset,FlatAffineConstraints * other)456 void FlatAffineConstraints::mergeAndAlignIdsWithOther(
457     unsigned offset, FlatAffineConstraints *other) {
458   mergeAndAlignIds(offset, this, other);
459 }
460 
461 // This routine may add additional local variables if the flattened expression
462 // corresponding to the map has such variables due to mod's, ceildiv's, and
463 // floordiv's in it.
composeMap(const AffineValueMap * vMap)464 LogicalResult FlatAffineConstraints::composeMap(const AffineValueMap *vMap) {
465   std::vector<SmallVector<int64_t, 8>> flatExprs;
466   FlatAffineConstraints localCst;
467   if (failed(getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs,
468                                      &localCst))) {
469     LLVM_DEBUG(llvm::dbgs()
470                << "composition unimplemented for semi-affine maps\n");
471     return failure();
472   }
473   assert(flatExprs.size() == vMap->getNumResults());
474 
475   // Add localCst information.
476   if (localCst.getNumLocalIds() > 0) {
477     localCst.setIdValues(0, /*end=*/localCst.getNumDimAndSymbolIds(),
478                          /*values=*/vMap->getOperands());
479     // Align localCst and this.
480     mergeAndAlignIds(/*offset=*/0, &localCst, this);
481     // Finally, append localCst to this constraint set.
482     append(localCst);
483   }
484 
485   // Add dimensions corresponding to the map's results.
486   for (unsigned t = 0, e = vMap->getNumResults(); t < e; t++) {
487     // TODO: Consider using a batched version to add a range of IDs.
488     addDimId(0);
489   }
490 
491   // We add one equality for each result connecting the result dim of the map to
492   // the other identifiers.
493   // For eg: if the expression is 16*i0 + i1, and this is the r^th
494   // iteration/result of the value map, we are adding the equality:
495   //  d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we
496   //  add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
497   for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
498     const auto &flatExpr = flatExprs[r];
499     assert(flatExpr.size() >= vMap->getNumOperands() + 1);
500 
501     // eqToAdd is the equality corresponding to the flattened affine expression.
502     SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
503     // Set the coefficient for this result to one.
504     eqToAdd[r] = 1;
505 
506     // Dims and symbols.
507     for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) {
508       unsigned loc;
509       bool ret = findId(vMap->getOperand(i), &loc);
510       assert(ret && "value map's id can't be found");
511       (void)ret;
512       // Negate 'eq[r]' since the newly added dimension will be set to this one.
513       eqToAdd[loc] = -flatExpr[i];
514     }
515     // Local vars common to eq and localCst are at the beginning.
516     unsigned j = getNumDimIds() + getNumSymbolIds();
517     unsigned end = flatExpr.size() - 1;
518     for (unsigned i = vMap->getNumOperands(); i < end; i++, j++) {
519       eqToAdd[j] = -flatExpr[i];
520     }
521 
522     // Constant term.
523     eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
524 
525     // Add the equality connecting the result of the map to this constraint set.
526     addEquality(eqToAdd);
527   }
528 
529   return success();
530 }
531 
532 // Similar to composeMap except that no Value's need be associated with the
533 // constraint system nor are they looked at -- since the dimensions and
534 // symbols of 'other' are expected to correspond 1:1 to 'this' system. It
535 // is thus not convenient to share code with composeMap.
composeMatchingMap(AffineMap other)536 LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
537   assert(other.getNumDims() == getNumDimIds() && "dim mismatch");
538   assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
539 
540   std::vector<SmallVector<int64_t, 8>> flatExprs;
541   FlatAffineConstraints localCst;
542   if (failed(getFlattenedAffineExprs(other, &flatExprs, &localCst))) {
543     LLVM_DEBUG(llvm::dbgs()
544                << "composition unimplemented for semi-affine maps\n");
545     return failure();
546   }
547   assert(flatExprs.size() == other.getNumResults());
548 
549   // Add localCst information.
550   if (localCst.getNumLocalIds() > 0) {
551     // Place local id's of A after local id's of B.
552     for (unsigned l = 0, e = localCst.getNumLocalIds(); l < e; l++) {
553       addLocalId(0);
554     }
555     // Finally, append localCst to this constraint set.
556     append(localCst);
557   }
558 
559   // Add dimensions corresponding to the map's results.
560   for (unsigned t = 0, e = other.getNumResults(); t < e; t++) {
561     addDimId(0);
562   }
563 
564   // We add one equality for each result connecting the result dim of the map to
565   // the other identifiers.
566   // For eg: if the expression is 16*i0 + i1, and this is the r^th
567   // iteration/result of the value map, we are adding the equality:
568   //  d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we
569   //  add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
570   for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
571     const auto &flatExpr = flatExprs[r];
572     assert(flatExpr.size() >= other.getNumInputs() + 1);
573 
574     // eqToAdd is the equality corresponding to the flattened affine expression.
575     SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
576     // Set the coefficient for this result to one.
577     eqToAdd[r] = 1;
578 
579     // Dims and symbols.
580     for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
581       // Negate 'eq[r]' since the newly added dimension will be set to this one.
582       eqToAdd[e + i] = -flatExpr[i];
583     }
584     // Local vars common to eq and localCst are at the beginning.
585     unsigned j = getNumDimIds() + getNumSymbolIds();
586     unsigned end = flatExpr.size() - 1;
587     for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
588       eqToAdd[j] = -flatExpr[i];
589     }
590 
591     // Constant term.
592     eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
593 
594     // Add the equality connecting the result of the map to this constraint set.
595     addEquality(eqToAdd);
596   }
597 
598   return success();
599 }
600 
601 // Turn a dimension into a symbol.
turnDimIntoSymbol(FlatAffineConstraints * cst,Value id)602 static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value id) {
603   unsigned pos;
604   if (cst->findId(id, &pos) && pos < cst->getNumDimIds()) {
605     cst->swapId(pos, cst->getNumDimIds() - 1);
606     cst->setDimSymbolSeparation(cst->getNumSymbolIds() + 1);
607   }
608 }
609 
610 // Turn a symbol into a dimension.
turnSymbolIntoDim(FlatAffineConstraints * cst,Value id)611 static void turnSymbolIntoDim(FlatAffineConstraints *cst, Value id) {
612   unsigned pos;
613   if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() &&
614       pos < cst->getNumDimAndSymbolIds()) {
615     cst->swapId(pos, cst->getNumDimIds());
616     cst->setDimSymbolSeparation(cst->getNumSymbolIds() - 1);
617   }
618 }
619 
620 // Changes all symbol identifiers which are loop IVs to dim identifiers.
convertLoopIVSymbolsToDims()621 void FlatAffineConstraints::convertLoopIVSymbolsToDims() {
622   // Gather all symbols which are loop IVs.
623   SmallVector<Value, 4> loopIVs;
624   for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) {
625     if (ids[i].hasValue() && getForInductionVarOwner(ids[i].getValue()))
626       loopIVs.push_back(ids[i].getValue());
627   }
628   // Turn each symbol in 'loopIVs' into a dim identifier.
629   for (auto iv : loopIVs) {
630     turnSymbolIntoDim(this, iv);
631   }
632 }
633 
addInductionVarOrTerminalSymbol(Value id)634 void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value id) {
635   if (containsId(id))
636     return;
637 
638   // Caller is expected to fully compose map/operands if necessary.
639   assert((isTopLevelValue(id) || isForInductionVar(id)) &&
640          "non-terminal symbol / loop IV expected");
641   // Outer loop IVs could be used in forOp's bounds.
642   if (auto loop = getForInductionVarOwner(id)) {
643     addDimId(getNumDimIds(), id);
644     if (failed(this->addAffineForOpDomain(loop)))
645       LLVM_DEBUG(
646           loop.emitWarning("failed to add domain info to constraint system"));
647     return;
648   }
649   // Add top level symbol.
650   addSymbolId(getNumSymbolIds(), id);
651   // Check if the symbol is a constant.
652   if (auto constOp = id.getDefiningOp<ConstantIndexOp>())
653     setIdToConstant(id, constOp.getValue());
654 }
655 
addAffineForOpDomain(AffineForOp forOp)656 LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) {
657   unsigned pos;
658   // Pre-condition for this method.
659   if (!findId(forOp.getInductionVar(), &pos)) {
660     assert(false && "Value not found");
661     return failure();
662   }
663 
664   int64_t step = forOp.getStep();
665   if (step != 1) {
666     if (!forOp.hasConstantLowerBound())
667       forOp.emitWarning("domain conservatively approximated");
668     else {
669       // Add constraints for the stride.
670       // (iv - lb) % step = 0 can be written as:
671       // (iv - lb) - step * q = 0 where q = (iv - lb) / step.
672       // Add local variable 'q' and add the above equality.
673       // The first constraint is q = (iv - lb) floordiv step
674       SmallVector<int64_t, 8> dividend(getNumCols(), 0);
675       int64_t lb = forOp.getConstantLowerBound();
676       dividend[pos] = 1;
677       dividend.back() -= lb;
678       addLocalFloorDiv(dividend, step);
679       // Second constraint: (iv - lb) - step * q = 0.
680       SmallVector<int64_t, 8> eq(getNumCols(), 0);
681       eq[pos] = 1;
682       eq.back() -= lb;
683       // For the local var just added above.
684       eq[getNumCols() - 2] = -step;
685       addEquality(eq);
686     }
687   }
688 
689   if (forOp.hasConstantLowerBound()) {
690     addConstantLowerBound(pos, forOp.getConstantLowerBound());
691   } else {
692     // Non-constant lower bound case.
693     if (failed(addLowerOrUpperBound(pos, forOp.getLowerBoundMap(),
694                                     forOp.getLowerBoundOperands(),
695                                     /*eq=*/false, /*lower=*/true)))
696       return failure();
697   }
698 
699   if (forOp.hasConstantUpperBound()) {
700     addConstantUpperBound(pos, forOp.getConstantUpperBound() - 1);
701     return success();
702   }
703   // Non-constant upper bound case.
704   return addLowerOrUpperBound(pos, forOp.getUpperBoundMap(),
705                               forOp.getUpperBoundOperands(),
706                               /*eq=*/false, /*lower=*/false);
707 }
708 
addAffineIfOpDomain(AffineIfOp ifOp)709 void FlatAffineConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
710   // Create the base constraints from the integer set attached to ifOp.
711   FlatAffineConstraints cst(ifOp.getIntegerSet());
712 
713   // Bind ids in the constraints to ifOp operands.
714   SmallVector<Value, 4> operands = ifOp.getOperands();
715   cst.setIdValues(0, cst.getNumDimAndSymbolIds(), operands);
716 
717   // Merge the constraints from ifOp to the current domain. We need first merge
718   // and align the IDs from both constraints, and then append the constraints
719   // from the ifOp into the current one.
720   mergeAndAlignIdsWithOther(0, &cst);
721   append(cst);
722 }
723 
724 // Searches for a constraint with a non-zero coefficient at 'colIdx' in
725 // equality (isEq=true) or inequality (isEq=false) constraints.
726 // Returns true and sets row found in search in 'rowIdx'.
727 // Returns false otherwise.
findConstraintWithNonZeroAt(const FlatAffineConstraints & cst,unsigned colIdx,bool isEq,unsigned * rowIdx)728 static bool findConstraintWithNonZeroAt(const FlatAffineConstraints &cst,
729                                         unsigned colIdx, bool isEq,
730                                         unsigned *rowIdx) {
731   assert(colIdx < cst.getNumCols() && "position out of bounds");
732   auto at = [&](unsigned rowIdx) -> int64_t {
733     return isEq ? cst.atEq(rowIdx, colIdx) : cst.atIneq(rowIdx, colIdx);
734   };
735   unsigned e = isEq ? cst.getNumEqualities() : cst.getNumInequalities();
736   for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) {
737     if (at(*rowIdx) != 0) {
738       return true;
739     }
740   }
741   return false;
742 }
743 
744 // Normalizes the coefficient values across all columns in 'rowIDx' by their
745 // GCD in equality or inequality constraints as specified by 'isEq'.
746 template <bool isEq>
normalizeConstraintByGCD(FlatAffineConstraints * constraints,unsigned rowIdx)747 static void normalizeConstraintByGCD(FlatAffineConstraints *constraints,
748                                      unsigned rowIdx) {
749   auto at = [&](unsigned colIdx) -> int64_t {
750     return isEq ? constraints->atEq(rowIdx, colIdx)
751                 : constraints->atIneq(rowIdx, colIdx);
752   };
753   uint64_t gcd = std::abs(at(0));
754   for (unsigned j = 1, e = constraints->getNumCols(); j < e; ++j) {
755     gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(at(j)));
756   }
757   if (gcd > 0 && gcd != 1) {
758     for (unsigned j = 0, e = constraints->getNumCols(); j < e; ++j) {
759       int64_t v = at(j) / static_cast<int64_t>(gcd);
760       isEq ? constraints->atEq(rowIdx, j) = v
761            : constraints->atIneq(rowIdx, j) = v;
762     }
763   }
764 }
765 
normalizeConstraintsByGCD()766 void FlatAffineConstraints::normalizeConstraintsByGCD() {
767   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
768     normalizeConstraintByGCD</*isEq=*/true>(this, i);
769   }
770   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
771     normalizeConstraintByGCD</*isEq=*/false>(this, i);
772   }
773 }
774 
hasConsistentState() const775 bool FlatAffineConstraints::hasConsistentState() const {
776   if (inequalities.size() != getNumInequalities() * numReservedCols)
777     return false;
778   if (equalities.size() != getNumEqualities() * numReservedCols)
779     return false;
780   if (ids.size() != getNumIds())
781     return false;
782 
783   // Catches errors where numDims, numSymbols, numIds aren't consistent.
784   if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds)
785     return false;
786 
787   return true;
788 }
789 
790 /// Checks all rows of equality/inequality constraints for trivial
791 /// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced
792 /// after elimination. Returns 'true' if an invalid constraint is found;
793 /// 'false' otherwise.
hasInvalidConstraint() const794 bool FlatAffineConstraints::hasInvalidConstraint() const {
795   assert(hasConsistentState());
796   auto check = [&](bool isEq) -> bool {
797     unsigned numCols = getNumCols();
798     unsigned numRows = isEq ? getNumEqualities() : getNumInequalities();
799     for (unsigned i = 0, e = numRows; i < e; ++i) {
800       unsigned j;
801       for (j = 0; j < numCols - 1; ++j) {
802         int64_t v = isEq ? atEq(i, j) : atIneq(i, j);
803         // Skip rows with non-zero variable coefficients.
804         if (v != 0)
805           break;
806       }
807       if (j < numCols - 1) {
808         continue;
809       }
810       // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'.
811       // Example invalid constraints include: '1 == 0' or '-1 >= 0'
812       int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1);
813       if ((isEq && v != 0) || (!isEq && v < 0)) {
814         return true;
815       }
816     }
817     return false;
818   };
819   if (check(/*isEq=*/true))
820     return true;
821   return check(/*isEq=*/false);
822 }
823 
824 // Eliminate identifier from constraint at 'rowIdx' based on coefficient at
825 // pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be
826 // updated as they have already been eliminated.
eliminateFromConstraint(FlatAffineConstraints * constraints,unsigned rowIdx,unsigned pivotRow,unsigned pivotCol,unsigned elimColStart,bool isEq)827 static void eliminateFromConstraint(FlatAffineConstraints *constraints,
828                                     unsigned rowIdx, unsigned pivotRow,
829                                     unsigned pivotCol, unsigned elimColStart,
830                                     bool isEq) {
831   // Skip if equality 'rowIdx' if same as 'pivotRow'.
832   if (isEq && rowIdx == pivotRow)
833     return;
834   auto at = [&](unsigned i, unsigned j) -> int64_t {
835     return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j);
836   };
837   int64_t leadCoeff = at(rowIdx, pivotCol);
838   // Skip if leading coefficient at 'rowIdx' is already zero.
839   if (leadCoeff == 0)
840     return;
841   int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol);
842   int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1;
843   int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff);
844   int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff));
845   int64_t rowMultiplier = lcm / std::abs(leadCoeff);
846 
847   unsigned numCols = constraints->getNumCols();
848   for (unsigned j = 0; j < numCols; ++j) {
849     // Skip updating column 'j' if it was just eliminated.
850     if (j >= elimColStart && j < pivotCol)
851       continue;
852     int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) +
853                 rowMultiplier * at(rowIdx, j);
854     isEq ? constraints->atEq(rowIdx, j) = v
855          : constraints->atIneq(rowIdx, j) = v;
856   }
857 }
858 
859 // Remove coefficients in column range [colStart, colLimit) in place.
860 // This removes in data in the specified column range, and copies any
861 // remaining valid data into place.
shiftColumnsToLeft(FlatAffineConstraints * constraints,unsigned colStart,unsigned colLimit,bool isEq)862 static void shiftColumnsToLeft(FlatAffineConstraints *constraints,
863                                unsigned colStart, unsigned colLimit,
864                                bool isEq) {
865   assert(colLimit <= constraints->getNumIds());
866   if (colLimit <= colStart)
867     return;
868 
869   unsigned numCols = constraints->getNumCols();
870   unsigned numRows = isEq ? constraints->getNumEqualities()
871                           : constraints->getNumInequalities();
872   unsigned numToEliminate = colLimit - colStart;
873   for (unsigned r = 0, e = numRows; r < e; ++r) {
874     for (unsigned c = colLimit; c < numCols; ++c) {
875       if (isEq) {
876         constraints->atEq(r, c - numToEliminate) = constraints->atEq(r, c);
877       } else {
878         constraints->atIneq(r, c - numToEliminate) = constraints->atIneq(r, c);
879       }
880     }
881   }
882 }
883 
884 // Removes identifiers in column range [idStart, idLimit), and copies any
885 // remaining valid data into place, and updates member variables.
removeIdRange(unsigned idStart,unsigned idLimit)886 void FlatAffineConstraints::removeIdRange(unsigned idStart, unsigned idLimit) {
887   assert(idLimit < getNumCols() && "invalid id limit");
888 
889   if (idStart >= idLimit)
890     return;
891 
892   // We are going to be removing one or more identifiers from the range.
893   assert(idStart < numIds && "invalid idStart position");
894 
895   // TODO: Make 'removeIdRange' a lambda called from here.
896   // Remove eliminated identifiers from equalities.
897   shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/true);
898 
899   // Remove eliminated identifiers from inequalities.
900   shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/false);
901 
902   // Update members numDims, numSymbols and numIds.
903   unsigned numDimsEliminated = 0;
904   unsigned numLocalsEliminated = 0;
905   unsigned numColsEliminated = idLimit - idStart;
906   if (idStart < numDims) {
907     numDimsEliminated = std::min(numDims, idLimit) - idStart;
908   }
909   // Check how many local id's were removed. Note that our identifier order is
910   // [dims, symbols, locals]. Local id start at position numDims + numSymbols.
911   if (idLimit > numDims + numSymbols) {
912     numLocalsEliminated = std::min(
913         idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds());
914   }
915   unsigned numSymbolsEliminated =
916       numColsEliminated - numDimsEliminated - numLocalsEliminated;
917 
918   numDims -= numDimsEliminated;
919   numSymbols -= numSymbolsEliminated;
920   numIds = numIds - numColsEliminated;
921 
922   ids.erase(ids.begin() + idStart, ids.begin() + idLimit);
923 
924   // No resize necessary. numReservedCols remains the same.
925 }
926 
927 /// Returns the position of the identifier that has the minimum <number of lower
928 /// bounds> times <number of upper bounds> from the specified range of
929 /// identifiers [start, end). It is often best to eliminate in the increasing
930 /// order of these counts when doing Fourier-Motzkin elimination since FM adds
931 /// that many new constraints.
getBestIdToEliminate(const FlatAffineConstraints & cst,unsigned start,unsigned end)932 static unsigned getBestIdToEliminate(const FlatAffineConstraints &cst,
933                                      unsigned start, unsigned end) {
934   assert(start < cst.getNumIds() && end < cst.getNumIds() + 1);
935 
936   auto getProductOfNumLowerUpperBounds = [&](unsigned pos) {
937     unsigned numLb = 0;
938     unsigned numUb = 0;
939     for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
940       if (cst.atIneq(r, pos) > 0) {
941         ++numLb;
942       } else if (cst.atIneq(r, pos) < 0) {
943         ++numUb;
944       }
945     }
946     return numLb * numUb;
947   };
948 
949   unsigned minLoc = start;
950   unsigned min = getProductOfNumLowerUpperBounds(start);
951   for (unsigned c = start + 1; c < end; c++) {
952     unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c);
953     if (numLbUbProduct < min) {
954       min = numLbUbProduct;
955       minLoc = c;
956     }
957   }
958   return minLoc;
959 }
960 
961 // Checks for emptiness of the set by eliminating identifiers successively and
962 // using the GCD test (on all equality constraints) and checking for trivially
963 // invalid constraints. Returns 'true' if the constraint system is found to be
964 // empty; false otherwise.
isEmpty() const965 bool FlatAffineConstraints::isEmpty() const {
966   if (isEmptyByGCDTest() || hasInvalidConstraint())
967     return true;
968 
969   // First, eliminate as many identifiers as possible using Gaussian
970   // elimination.
971   FlatAffineConstraints tmpCst(*this);
972   unsigned currentPos = 0;
973   while (currentPos < tmpCst.getNumIds()) {
974     tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds());
975     ++currentPos;
976     // We check emptiness through trivial checks after eliminating each ID to
977     // detect emptiness early. Since the checks isEmptyByGCDTest() and
978     // hasInvalidConstraint() are linear time and single sweep on the constraint
979     // buffer, this appears reasonable - but can optimize in the future.
980     if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest())
981       return true;
982   }
983 
984   // Eliminate the remaining using FM.
985   for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) {
986     tmpCst.FourierMotzkinEliminate(
987         getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds()));
988     // Check for a constraint explosion. This rarely happens in practice, but
989     // this check exists as a safeguard against improperly constructed
990     // constraint systems or artificially created arbitrarily complex systems
991     // that aren't the intended use case for FlatAffineConstraints. This is
992     // needed since FM has a worst case exponential complexity in theory.
993     if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) {
994       LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n");
995       return false;
996     }
997 
998     // FM wouldn't have modified the equalities in any way. So no need to again
999     // run GCD test. Check for trivial invalid constraints.
1000     if (tmpCst.hasInvalidConstraint())
1001       return true;
1002   }
1003   return false;
1004 }
1005 
1006 // Runs the GCD test on all equality constraints. Returns 'true' if this test
1007 // fails on any equality. Returns 'false' otherwise.
1008 // This test can be used to disprove the existence of a solution. If it returns
1009 // true, no integer solution to the equality constraints can exist.
1010 //
1011 // GCD test definition:
1012 //
1013 // The equality constraint:
1014 //
1015 //  c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0
1016 //
1017 // has an integer solution iff:
1018 //
1019 //  GCD of c_1, c_2, ..., c_n divides c_0.
1020 //
isEmptyByGCDTest() const1021 bool FlatAffineConstraints::isEmptyByGCDTest() const {
1022   assert(hasConsistentState());
1023   unsigned numCols = getNumCols();
1024   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1025     uint64_t gcd = std::abs(atEq(i, 0));
1026     for (unsigned j = 1; j < numCols - 1; ++j) {
1027       gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j)));
1028     }
1029     int64_t v = std::abs(atEq(i, numCols - 1));
1030     if (gcd > 0 && (v % gcd != 0)) {
1031       return true;
1032     }
1033   }
1034   return false;
1035 }
1036 
1037 // First, try the GCD test heuristic.
1038 //
1039 // If that doesn't find the set empty, check if the set is unbounded. If it is,
1040 // we cannot use the GBR algorithm and we conservatively return false.
1041 //
1042 // If the set is bounded, we use the complete emptiness check for this case
1043 // provided by Simplex::findIntegerSample(), which gives a definitive answer.
isIntegerEmpty() const1044 bool FlatAffineConstraints::isIntegerEmpty() const {
1045   if (isEmptyByGCDTest())
1046     return true;
1047 
1048   Simplex simplex(*this);
1049   if (simplex.isUnbounded())
1050     return false;
1051   return !simplex.findIntegerSample().hasValue();
1052 }
1053 
1054 Optional<SmallVector<int64_t, 8>>
findIntegerSample() const1055 FlatAffineConstraints::findIntegerSample() const {
1056   return Simplex(*this).findIntegerSample();
1057 }
1058 
1059 /// Helper to evaluate an affine expression at a point.
1060 /// The expression is a list of coefficients for the dimensions followed by the
1061 /// constant term.
valueAt(ArrayRef<int64_t> expr,ArrayRef<int64_t> point)1062 static int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) {
1063   assert(expr.size() == 1 + point.size() &&
1064          "Dimensionalities of point and expresion don't match!");
1065   int64_t value = expr.back();
1066   for (unsigned i = 0; i < point.size(); ++i)
1067     value += expr[i] * point[i];
1068   return value;
1069 }
1070 
1071 /// A point satisfies an equality iff the value of the equality at the
1072 /// expression is zero, and it satisfies an inequality iff the value of the
1073 /// inequality at that point is non-negative.
containsPoint(ArrayRef<int64_t> point) const1074 bool FlatAffineConstraints::containsPoint(ArrayRef<int64_t> point) const {
1075   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1076     if (valueAt(getEquality(i), point) != 0)
1077       return false;
1078   }
1079   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1080     if (valueAt(getInequality(i), point) < 0)
1081       return false;
1082   }
1083   return true;
1084 }
1085 
1086 /// Tightens inequalities given that we are dealing with integer spaces. This is
1087 /// analogous to the GCD test but applied to inequalities. The constant term can
1088 /// be reduced to the preceding multiple of the GCD of the coefficients, i.e.,
1089 ///  64*i - 100 >= 0  =>  64*i - 128 >= 0 (since 'i' is an integer). This is a
1090 /// fast method - linear in the number of coefficients.
1091 // Example on how this affects practical cases: consider the scenario:
1092 // 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield
1093 // j >= 100 instead of the tighter (exact) j >= 128.
GCDTightenInequalities()1094 void FlatAffineConstraints::GCDTightenInequalities() {
1095   unsigned numCols = getNumCols();
1096   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1097     uint64_t gcd = std::abs(atIneq(i, 0));
1098     for (unsigned j = 1; j < numCols - 1; ++j) {
1099       gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atIneq(i, j)));
1100     }
1101     if (gcd > 0 && gcd != 1) {
1102       int64_t gcdI = static_cast<int64_t>(gcd);
1103       // Tighten the constant term and normalize the constraint by the GCD.
1104       atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcdI);
1105       for (unsigned j = 0, e = numCols - 1; j < e; ++j)
1106         atIneq(i, j) /= gcdI;
1107     }
1108   }
1109 }
1110 
1111 // Eliminates all identifier variables in column range [posStart, posLimit).
1112 // Returns the number of variables eliminated.
gaussianEliminateIds(unsigned posStart,unsigned posLimit)1113 unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
1114                                                      unsigned posLimit) {
1115   // Return if identifier positions to eliminate are out of range.
1116   assert(posLimit <= numIds);
1117   assert(hasConsistentState());
1118 
1119   if (posStart >= posLimit)
1120     return 0;
1121 
1122   GCDTightenInequalities();
1123 
1124   unsigned pivotCol = 0;
1125   for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) {
1126     // Find a row which has a non-zero coefficient in column 'j'.
1127     unsigned pivotRow;
1128     if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/true,
1129                                      &pivotRow)) {
1130       // No pivot row in equalities with non-zero at 'pivotCol'.
1131       if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/false,
1132                                        &pivotRow)) {
1133         // If inequalities are also non-zero in 'pivotCol', it can be
1134         // eliminated.
1135         continue;
1136       }
1137       break;
1138     }
1139 
1140     // Eliminate identifier at 'pivotCol' from each equality row.
1141     for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1142       eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
1143                               /*isEq=*/true);
1144       normalizeConstraintByGCD</*isEq=*/true>(this, i);
1145     }
1146 
1147     // Eliminate identifier at 'pivotCol' from each inequality row.
1148     for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
1149       eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
1150                               /*isEq=*/false);
1151       normalizeConstraintByGCD</*isEq=*/false>(this, i);
1152     }
1153     removeEquality(pivotRow);
1154     GCDTightenInequalities();
1155   }
1156   // Update position limit based on number eliminated.
1157   posLimit = pivotCol;
1158   // Remove eliminated columns from all constraints.
1159   removeIdRange(posStart, posLimit);
1160   return posLimit - posStart;
1161 }
1162 
1163 // Detect the identifier at 'pos' (say id_r) as modulo of another identifier
1164 // (say id_n) w.r.t a constant. When this happens, another identifier (say id_q)
1165 // could be detected as the floordiv of n. For eg:
1166 // id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3    <=>
1167 //                          id_r = id_n mod 4, id_q = id_n floordiv 4.
1168 // lbConst and ubConst are the constant lower and upper bounds for 'pos' -
1169 // pre-detected at the caller.
detectAsMod(const FlatAffineConstraints & cst,unsigned pos,int64_t lbConst,int64_t ubConst,SmallVectorImpl<AffineExpr> * memo)1170 static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
1171                         int64_t lbConst, int64_t ubConst,
1172                         SmallVectorImpl<AffineExpr> *memo) {
1173   assert(pos < cst.getNumIds() && "invalid position");
1174 
1175   // Check if 0 <= id_r <= divisor - 1 and if id_r is equal to
1176   // id_n - divisor * id_q. If these are true, then id_n becomes the dividend
1177   // and id_q the quotient when dividing id_n by the divisor.
1178 
1179   if (lbConst != 0 || ubConst < 1)
1180     return false;
1181 
1182   int64_t divisor = ubConst + 1;
1183 
1184   // Now check for: id_r =  id_n - divisor * id_q. As an example, we
1185   // are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0.
1186   unsigned seenQuotient = 0, seenDividend = 0;
1187   int quotientPos = -1, dividendPos = -1;
1188   for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
1189     // id_n should have coeff 1 or -1.
1190     if (std::abs(cst.atEq(r, pos)) != 1)
1191       continue;
1192     // constant term should be 0.
1193     if (cst.atEq(r, cst.getNumCols() - 1) != 0)
1194       continue;
1195     unsigned c, f;
1196     int quotientSign = 1, dividendSign = 1;
1197     for (c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) {
1198       if (c == pos)
1199         continue;
1200       // The coefficient of the quotient should be +/-divisor.
1201       // TODO: could be extended to detect an affine function for the quotient
1202       // (i.e., the coeff could be a non-zero multiple of divisor).
1203       int64_t v = cst.atEq(r, c) * cst.atEq(r, pos);
1204       if (v == divisor || v == -divisor) {
1205         seenQuotient++;
1206         quotientPos = c;
1207         quotientSign = v > 0 ? 1 : -1;
1208       }
1209       // The coefficient of the dividend should be +/-1.
1210       // TODO: could be extended to detect an affine function of the other
1211       // identifiers as the dividend.
1212       else if (v == -1 || v == 1) {
1213         seenDividend++;
1214         dividendPos = c;
1215         dividendSign = v < 0 ? 1 : -1;
1216       } else if (cst.atEq(r, c) != 0) {
1217         // Cannot be inferred as a mod since the constraint has a coefficient
1218         // for an identifier that's neither a unit nor the divisor (see TODOs
1219         // above).
1220         break;
1221       }
1222     }
1223     if (c < f)
1224       // Cannot be inferred as a mod since the constraint has a coefficient for
1225       // an identifier that's neither a unit nor the divisor (see TODOs above).
1226       continue;
1227 
1228     // We are looking for exactly one identifier as the dividend.
1229     if (seenDividend == 1 && seenQuotient >= 1) {
1230       if (!(*memo)[dividendPos])
1231         return false;
1232       // Successfully detected a mod.
1233       (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
1234       auto ub = cst.getConstantUpperBound(dividendPos);
1235       if (ub.hasValue() && ub.getValue() < divisor)
1236         // The mod can be optimized away.
1237         (*memo)[pos] = (*memo)[dividendPos] * dividendSign;
1238       else
1239         (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
1240 
1241       if (seenQuotient == 1 && !(*memo)[quotientPos])
1242         // Successfully detected a floordiv as well.
1243         (*memo)[quotientPos] =
1244             (*memo)[dividendPos].floorDiv(divisor) * quotientSign;
1245       return true;
1246     }
1247   }
1248   return false;
1249 }
1250 
1251 /// Gather all lower and upper bounds of the identifier at `pos`, and
1252 /// optionally any equalities on it. In addition, the bounds are to be
1253 /// independent of identifiers in position range [`offset`, `offset` + `num`).
getLowerAndUpperBoundIndices(unsigned pos,SmallVectorImpl<unsigned> * lbIndices,SmallVectorImpl<unsigned> * ubIndices,SmallVectorImpl<unsigned> * eqIndices,unsigned offset,unsigned num) const1254 void FlatAffineConstraints::getLowerAndUpperBoundIndices(
1255     unsigned pos, SmallVectorImpl<unsigned> *lbIndices,
1256     SmallVectorImpl<unsigned> *ubIndices, SmallVectorImpl<unsigned> *eqIndices,
1257     unsigned offset, unsigned num) const {
1258   assert(pos < getNumIds() && "invalid position");
1259   assert(offset + num < getNumCols() && "invalid range");
1260 
1261   // Checks for a constraint that has a non-zero coeff for the identifiers in
1262   // the position range [offset, offset + num) while ignoring `pos`.
1263   auto containsConstraintDependentOnRange = [&](unsigned r, bool isEq) {
1264     unsigned c, f;
1265     auto cst = isEq ? getEquality(r) : getInequality(r);
1266     for (c = offset, f = offset + num; c < f; ++c) {
1267       if (c == pos)
1268         continue;
1269       if (cst[c] != 0)
1270         break;
1271     }
1272     return c < f;
1273   };
1274 
1275   // Gather all lower bounds and upper bounds of the variable. Since the
1276   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
1277   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
1278   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1279     // The bounds are to be independent of [offset, offset + num) columns.
1280     if (containsConstraintDependentOnRange(r, /*isEq=*/false))
1281       continue;
1282     if (atIneq(r, pos) >= 1) {
1283       // Lower bound.
1284       lbIndices->push_back(r);
1285     } else if (atIneq(r, pos) <= -1) {
1286       // Upper bound.
1287       ubIndices->push_back(r);
1288     }
1289   }
1290 
1291   // An equality is both a lower and upper bound. Record any equalities
1292   // involving the pos^th identifier.
1293   if (!eqIndices)
1294     return;
1295 
1296   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1297     if (atEq(r, pos) == 0)
1298       continue;
1299     if (containsConstraintDependentOnRange(r, /*isEq=*/true))
1300       continue;
1301     eqIndices->push_back(r);
1302   }
1303 }
1304 
1305 /// Check if the pos^th identifier can be expressed as a floordiv of an affine
1306 /// function of other identifiers (where the divisor is a positive constant)
1307 /// given the initial set of expressions in `exprs`. If it can be, the
1308 /// corresponding position in `exprs` is set as the detected affine expr. For
1309 /// eg: 4q <= i + j <= 4q + 3   <=>   q = (i + j) floordiv 4. An equality can
1310 /// also yield a floordiv: eg.  4q = i + j <=> q = (i + j) floordiv 4. 32q + 28
1311 /// <= i <= 32q + 31 => q = i floordiv 32.
detectAsFloorDiv(const FlatAffineConstraints & cst,unsigned pos,MLIRContext * context,SmallVectorImpl<AffineExpr> & exprs)1312 static bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
1313                              MLIRContext *context,
1314                              SmallVectorImpl<AffineExpr> &exprs) {
1315   assert(pos < cst.getNumIds() && "invalid position");
1316 
1317   SmallVector<unsigned, 4> lbIndices, ubIndices;
1318   cst.getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices);
1319 
1320   // Check if any lower bound, upper bound pair is of the form:
1321   // divisor * id >=  expr - (divisor - 1)    <-- Lower bound for 'id'
1322   // divisor * id <=  expr                    <-- Upper bound for 'id'
1323   // Then, 'id' is equivalent to 'expr floordiv divisor'.  (where divisor > 1).
1324   //
1325   // For example, if -32*k + 16*i + j >= 0
1326   //                  32*k - 16*i - j + 31 >= 0   <=>
1327   //             k = ( 16*i + j ) floordiv 32
1328   unsigned seenDividends = 0;
1329   for (auto ubPos : ubIndices) {
1330     for (auto lbPos : lbIndices) {
1331       // Check if the lower bound's constant term is divisor - 1. The
1332       // 'divisor' here is cst.atIneq(lbPos, pos) and we already know that it's
1333       // positive (since cst.Ineq(lbPos, ...) is a lower bound expr for 'pos'.
1334       int64_t divisor = cst.atIneq(lbPos, pos);
1335       int64_t lbConstTerm = cst.atIneq(lbPos, cst.getNumCols() - 1);
1336       if (lbConstTerm != divisor - 1)
1337         continue;
1338       // Check if upper bound's constant term is 0.
1339       if (cst.atIneq(ubPos, cst.getNumCols() - 1) != 0)
1340         continue;
1341       // For the remaining part, check if the lower bound expr's coeff's are
1342       // negations of corresponding upper bound ones'.
1343       unsigned c, f;
1344       for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
1345         if (cst.atIneq(lbPos, c) != -cst.atIneq(ubPos, c))
1346           break;
1347         if (c != pos && cst.atIneq(lbPos, c) != 0)
1348           seenDividends++;
1349       }
1350       // Lb coeff's aren't negative of ub coeff's (for the non constant term
1351       // part).
1352       if (c < f)
1353         continue;
1354       if (seenDividends >= 1) {
1355         // Construct the dividend expression.
1356         auto dividendExpr = getAffineConstantExpr(0, context);
1357         unsigned c, f;
1358         for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
1359           if (c == pos)
1360             continue;
1361           int64_t ubVal = cst.atIneq(ubPos, c);
1362           if (ubVal == 0)
1363             continue;
1364           if (!exprs[c])
1365             break;
1366           dividendExpr = dividendExpr + ubVal * exprs[c];
1367         }
1368         // Expression can't be constructed as it depends on a yet unknown
1369         // identifier.
1370         // TODO: Visit/compute the identifiers in an order so that this doesn't
1371         // happen. More complex but much more efficient.
1372         if (c < f)
1373           continue;
1374         // Successfully detected the floordiv.
1375         exprs[pos] = dividendExpr.floorDiv(divisor);
1376         return true;
1377       }
1378     }
1379   }
1380   return false;
1381 }
1382 
1383 // Fills an inequality row with the value 'val'.
fillInequality(FlatAffineConstraints * cst,unsigned r,int64_t val)1384 static inline void fillInequality(FlatAffineConstraints *cst, unsigned r,
1385                                   int64_t val) {
1386   for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
1387     cst->atIneq(r, c) = val;
1388   }
1389 }
1390 
1391 // Negates an inequality.
negateInequality(FlatAffineConstraints * cst,unsigned r)1392 static inline void negateInequality(FlatAffineConstraints *cst, unsigned r) {
1393   for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
1394     cst->atIneq(r, c) = -cst->atIneq(r, c);
1395   }
1396 }
1397 
1398 // A more complex check to eliminate redundant inequalities. Uses FourierMotzkin
1399 // to check if a constraint is redundant.
removeRedundantInequalities()1400 void FlatAffineConstraints::removeRedundantInequalities() {
1401   SmallVector<bool, 32> redun(getNumInequalities(), false);
1402   // To check if an inequality is redundant, we replace the inequality by its
1403   // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting
1404   // system is empty. If it is, the inequality is redundant.
1405   FlatAffineConstraints tmpCst(*this);
1406   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1407     // Change the inequality to its complement.
1408     negateInequality(&tmpCst, r);
1409     tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--;
1410     if (tmpCst.isEmpty()) {
1411       redun[r] = true;
1412       // Zero fill the redundant inequality.
1413       fillInequality(this, r, /*val=*/0);
1414       fillInequality(&tmpCst, r, /*val=*/0);
1415     } else {
1416       // Reverse the change (to avoid recreating tmpCst each time).
1417       tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++;
1418       negateInequality(&tmpCst, r);
1419     }
1420   }
1421 
1422   // Scan to get rid of all rows marked redundant, in-place.
1423   auto copyRow = [&](unsigned src, unsigned dest) {
1424     if (src == dest)
1425       return;
1426     for (unsigned c = 0, e = getNumCols(); c < e; c++) {
1427       atIneq(dest, c) = atIneq(src, c);
1428     }
1429   };
1430   unsigned pos = 0;
1431   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1432     if (!redun[r])
1433       copyRow(r, pos++);
1434   }
1435   inequalities.resize(numReservedCols * pos);
1436 }
1437 
1438 // A more complex check to eliminate redundant inequalities and equalities. Uses
1439 // Simplex to check if a constraint is redundant.
removeRedundantConstraints()1440 void FlatAffineConstraints::removeRedundantConstraints() {
1441   // First, we run GCDTightenInequalities. This allows us to catch some
1442   // constraints which are not redundant when considering rational solutions
1443   // but are redundant in terms of integer solutions.
1444   GCDTightenInequalities();
1445   Simplex simplex(*this);
1446   simplex.detectRedundant();
1447 
1448   auto copyInequality = [&](unsigned src, unsigned dest) {
1449     if (src == dest)
1450       return;
1451     for (unsigned c = 0, e = getNumCols(); c < e; c++)
1452       atIneq(dest, c) = atIneq(src, c);
1453   };
1454   unsigned pos = 0;
1455   unsigned numIneqs = getNumInequalities();
1456   // Scan to get rid of all inequalities marked redundant, in-place. In Simplex,
1457   // the first constraints added are the inequalities.
1458   for (unsigned r = 0; r < numIneqs; r++) {
1459     if (!simplex.isMarkedRedundant(r))
1460       copyInequality(r, pos++);
1461   }
1462   inequalities.resize(numReservedCols * pos);
1463 
1464   // Scan to get rid of all equalities marked redundant, in-place. In Simplex,
1465   // after the inequalities, a pair of constraints for each equality is added.
1466   // An equality is redundant if both the inequalities in its pair are
1467   // redundant.
1468   auto copyEquality = [&](unsigned src, unsigned dest) {
1469     if (src == dest)
1470       return;
1471     for (unsigned c = 0, e = getNumCols(); c < e; c++)
1472       atEq(dest, c) = atEq(src, c);
1473   };
1474   pos = 0;
1475   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1476     if (!(simplex.isMarkedRedundant(numIneqs + 2 * r) &&
1477           simplex.isMarkedRedundant(numIneqs + 2 * r + 1)))
1478       copyEquality(r, pos++);
1479   }
1480   equalities.resize(numReservedCols * pos);
1481 }
1482 
getLowerAndUpperBound(unsigned pos,unsigned offset,unsigned num,unsigned symStartPos,ArrayRef<AffineExpr> localExprs,MLIRContext * context) const1483 std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
1484     unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
1485     ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {
1486   assert(pos + offset < getNumDimIds() && "invalid dim start pos");
1487   assert(symStartPos >= (pos + offset) && "invalid sym start pos");
1488   assert(getNumLocalIds() == localExprs.size() &&
1489          "incorrect local exprs count");
1490 
1491   SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
1492   getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices, &eqIndices,
1493                                offset, num);
1494 
1495   /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos).
1496   auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
1497     b.clear();
1498     for (unsigned i = 0, e = a.size(); i < e; ++i) {
1499       if (i < offset || i >= offset + num)
1500         b.push_back(a[i]);
1501     }
1502   };
1503 
1504   SmallVector<int64_t, 8> lb, ub;
1505   SmallVector<AffineExpr, 4> lbExprs;
1506   unsigned dimCount = symStartPos - num;
1507   unsigned symCount = getNumDimAndSymbolIds() - symStartPos;
1508   lbExprs.reserve(lbIndices.size() + eqIndices.size());
1509   // Lower bound expressions.
1510   for (auto idx : lbIndices) {
1511     auto ineq = getInequality(idx);
1512     // Extract the lower bound (in terms of other coeff's + const), i.e., if
1513     // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
1514     // - 1.
1515     addCoeffs(ineq, lb);
1516     std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
1517     auto expr =
1518         getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context);
1519     // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor
1520     int64_t divisor = std::abs(ineq[pos + offset]);
1521     expr = (expr + divisor - 1).floorDiv(divisor);
1522     lbExprs.push_back(expr);
1523   }
1524 
1525   SmallVector<AffineExpr, 4> ubExprs;
1526   ubExprs.reserve(ubIndices.size() + eqIndices.size());
1527   // Upper bound expressions.
1528   for (auto idx : ubIndices) {
1529     auto ineq = getInequality(idx);
1530     // Extract the upper bound (in terms of other coeff's + const).
1531     addCoeffs(ineq, ub);
1532     auto expr =
1533         getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context);
1534     expr = expr.floorDiv(std::abs(ineq[pos + offset]));
1535     // Upper bound is exclusive.
1536     ubExprs.push_back(expr + 1);
1537   }
1538 
1539   // Equalities. It's both a lower and a upper bound.
1540   SmallVector<int64_t, 4> b;
1541   for (auto idx : eqIndices) {
1542     auto eq = getEquality(idx);
1543     addCoeffs(eq, b);
1544     if (eq[pos + offset] > 0)
1545       std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>());
1546 
1547     // Extract the upper bound (in terms of other coeff's + const).
1548     auto expr =
1549         getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
1550     expr = expr.floorDiv(std::abs(eq[pos + offset]));
1551     // Upper bound is exclusive.
1552     ubExprs.push_back(expr + 1);
1553     // Lower bound.
1554     expr =
1555         getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
1556     expr = expr.ceilDiv(std::abs(eq[pos + offset]));
1557     lbExprs.push_back(expr);
1558   }
1559 
1560   auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context);
1561   auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context);
1562 
1563   return {lbMap, ubMap};
1564 }
1565 
1566 /// Computes the lower and upper bounds of the first 'num' dimensional
1567 /// identifiers (starting at 'offset') as affine maps of the remaining
1568 /// identifiers (dimensional and symbolic identifiers). Local identifiers are
1569 /// themselves explicitly computed as affine functions of other identifiers in
1570 /// this process if needed.
getSliceBounds(unsigned offset,unsigned num,MLIRContext * context,SmallVectorImpl<AffineMap> * lbMaps,SmallVectorImpl<AffineMap> * ubMaps)1571 void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
1572                                            MLIRContext *context,
1573                                            SmallVectorImpl<AffineMap> *lbMaps,
1574                                            SmallVectorImpl<AffineMap> *ubMaps) {
1575   assert(num < getNumDimIds() && "invalid range");
1576 
1577   // Basic simplification.
1578   normalizeConstraintsByGCD();
1579 
1580   LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
1581                           << " identifiers\n");
1582   LLVM_DEBUG(dump());
1583 
1584   // Record computed/detected identifiers.
1585   SmallVector<AffineExpr, 8> memo(getNumIds());
1586   // Initialize dimensional and symbolic identifiers.
1587   for (unsigned i = 0, e = getNumDimIds(); i < e; i++) {
1588     if (i < offset)
1589       memo[i] = getAffineDimExpr(i, context);
1590     else if (i >= offset + num)
1591       memo[i] = getAffineDimExpr(i - num, context);
1592   }
1593   for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++)
1594     memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context);
1595 
1596   bool changed;
1597   do {
1598     changed = false;
1599     // Identify yet unknown identifiers as constants or mod's / floordiv's of
1600     // other identifiers if possible.
1601     for (unsigned pos = 0; pos < getNumIds(); pos++) {
1602       if (memo[pos])
1603         continue;
1604 
1605       auto lbConst = getConstantLowerBound(pos);
1606       auto ubConst = getConstantUpperBound(pos);
1607       if (lbConst.hasValue() && ubConst.hasValue()) {
1608         // Detect equality to a constant.
1609         if (lbConst.getValue() == ubConst.getValue()) {
1610           memo[pos] = getAffineConstantExpr(lbConst.getValue(), context);
1611           changed = true;
1612           continue;
1613         }
1614 
1615         // Detect an identifier as modulo of another identifier w.r.t a
1616         // constant.
1617         if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
1618                         &memo)) {
1619           changed = true;
1620           continue;
1621         }
1622       }
1623 
1624       // Detect an identifier as a floordiv of an affine function of other
1625       // identifiers (divisor is a positive constant).
1626       if (detectAsFloorDiv(*this, pos, context, memo)) {
1627         changed = true;
1628         continue;
1629       }
1630 
1631       // Detect an identifier as an expression of other identifiers.
1632       unsigned idx;
1633       if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx)) {
1634         continue;
1635       }
1636 
1637       // Build AffineExpr solving for identifier 'pos' in terms of all others.
1638       auto expr = getAffineConstantExpr(0, context);
1639       unsigned j, e;
1640       for (j = 0, e = getNumIds(); j < e; ++j) {
1641         if (j == pos)
1642           continue;
1643         int64_t c = atEq(idx, j);
1644         if (c == 0)
1645           continue;
1646         // If any of the involved IDs hasn't been found yet, we can't proceed.
1647         if (!memo[j])
1648           break;
1649         expr = expr + memo[j] * c;
1650       }
1651       if (j < e)
1652         // Can't construct expression as it depends on a yet uncomputed
1653         // identifier.
1654         continue;
1655 
1656       // Add constant term to AffineExpr.
1657       expr = expr + atEq(idx, getNumIds());
1658       int64_t vPos = atEq(idx, pos);
1659       assert(vPos != 0 && "expected non-zero here");
1660       if (vPos > 0)
1661         expr = (-expr).floorDiv(vPos);
1662       else
1663         // vPos < 0.
1664         expr = expr.floorDiv(-vPos);
1665       // Successfully constructed expression.
1666       memo[pos] = expr;
1667       changed = true;
1668     }
1669     // This loop is guaranteed to reach a fixed point - since once an
1670     // identifier's explicit form is computed (in memo[pos]), it's not updated
1671     // again.
1672   } while (changed);
1673 
1674   // Set the lower and upper bound maps for all the identifiers that were
1675   // computed as affine expressions of the rest as the "detected expr" and
1676   // "detected expr + 1" respectively; set the undetected ones to null.
1677   Optional<FlatAffineConstraints> tmpClone;
1678   for (unsigned pos = 0; pos < num; pos++) {
1679     unsigned numMapDims = getNumDimIds() - num;
1680     unsigned numMapSymbols = getNumSymbolIds();
1681     AffineExpr expr = memo[pos + offset];
1682     if (expr)
1683       expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);
1684 
1685     AffineMap &lbMap = (*lbMaps)[pos];
1686     AffineMap &ubMap = (*ubMaps)[pos];
1687 
1688     if (expr) {
1689       lbMap = AffineMap::get(numMapDims, numMapSymbols, expr);
1690       ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + 1);
1691     } else {
1692       // TODO: Whenever there are local identifiers in the dependence
1693       // constraints, we'll conservatively over-approximate, since we don't
1694       // always explicitly compute them above (in the while loop).
1695       if (getNumLocalIds() == 0) {
1696         // Work on a copy so that we don't update this constraint system.
1697         if (!tmpClone) {
1698           tmpClone.emplace(FlatAffineConstraints(*this));
1699           // Removing redundant inequalities is necessary so that we don't get
1700           // redundant loop bounds.
1701           tmpClone->removeRedundantInequalities();
1702         }
1703         std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound(
1704             pos, offset, num, getNumDimIds(), /*localExprs=*/{}, context);
1705       }
1706 
1707       // If the above fails, we'll just use the constant lower bound and the
1708       // constant upper bound (if they exist) as the slice bounds.
1709       // TODO: being conservative for the moment in cases that
1710       // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
1711       // fixed (b/126426796).
1712       if (!lbMap || lbMap.getNumResults() > 1) {
1713         LLVM_DEBUG(llvm::dbgs()
1714                    << "WARNING: Potentially over-approximating slice lb\n");
1715         auto lbConst = getConstantLowerBound(pos + offset);
1716         if (lbConst.hasValue()) {
1717           lbMap = AffineMap::get(
1718               numMapDims, numMapSymbols,
1719               getAffineConstantExpr(lbConst.getValue(), context));
1720         }
1721       }
1722       if (!ubMap || ubMap.getNumResults() > 1) {
1723         LLVM_DEBUG(llvm::dbgs()
1724                    << "WARNING: Potentially over-approximating slice ub\n");
1725         auto ubConst = getConstantUpperBound(pos + offset);
1726         if (ubConst.hasValue()) {
1727           (ubMap) = AffineMap::get(
1728               numMapDims, numMapSymbols,
1729               getAffineConstantExpr(ubConst.getValue() + 1, context));
1730         }
1731       }
1732     }
1733     LLVM_DEBUG(llvm::dbgs()
1734                << "lb map for pos = " << Twine(pos + offset) << ", expr: ");
1735     LLVM_DEBUG(lbMap.dump(););
1736     LLVM_DEBUG(llvm::dbgs()
1737                << "ub map for pos = " << Twine(pos + offset) << ", expr: ");
1738     LLVM_DEBUG(ubMap.dump(););
1739   }
1740 }
1741 
1742 LogicalResult
addLowerOrUpperBound(unsigned pos,AffineMap boundMap,ValueRange boundOperands,bool eq,bool lower)1743 FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
1744                                             ValueRange boundOperands, bool eq,
1745                                             bool lower) {
1746   assert(pos < getNumDimAndSymbolIds() && "invalid position");
1747   // Equality follows the logic of lower bound except that we add an equality
1748   // instead of an inequality.
1749   assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
1750   if (eq)
1751     lower = true;
1752 
1753   // Fully compose map and operands; canonicalize and simplify so that we
1754   // transitively get to terminal symbols or loop IVs.
1755   auto map = boundMap;
1756   SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end());
1757   fullyComposeAffineMapAndOperands(&map, &operands);
1758   map = simplifyAffineMap(map);
1759   canonicalizeMapAndOperands(&map, &operands);
1760   for (auto operand : operands)
1761     addInductionVarOrTerminalSymbol(operand);
1762 
1763   FlatAffineConstraints localVarCst;
1764   std::vector<SmallVector<int64_t, 8>> flatExprs;
1765   if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst))) {
1766     LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n");
1767     return failure();
1768   }
1769 
1770   // Merge and align with localVarCst.
1771   if (localVarCst.getNumLocalIds() > 0) {
1772     // Set values for localVarCst.
1773     localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), operands);
1774     for (auto operand : operands) {
1775       unsigned pos;
1776       if (findId(operand, &pos)) {
1777         if (pos >= getNumDimIds() && pos < getNumDimAndSymbolIds()) {
1778           // If the local var cst has this as a dim, turn it into its symbol.
1779           turnDimIntoSymbol(&localVarCst, operand);
1780         } else if (pos < getNumDimIds()) {
1781           // Or vice versa.
1782           turnSymbolIntoDim(&localVarCst, operand);
1783         }
1784       }
1785     }
1786     mergeAndAlignIds(/*offset=*/0, this, &localVarCst);
1787     append(localVarCst);
1788   }
1789 
1790   // Record positions of the operands in the constraint system. Need to do
1791   // this here since the constraint system changes after a bound is added.
1792   SmallVector<unsigned, 8> positions;
1793   unsigned numOperands = operands.size();
1794   for (auto operand : operands) {
1795     unsigned pos;
1796     if (!findId(operand, &pos))
1797       assert(0 && "expected to be found");
1798     positions.push_back(pos);
1799   }
1800 
1801   for (const auto &flatExpr : flatExprs) {
1802     SmallVector<int64_t, 4> ineq(getNumCols(), 0);
1803     ineq[pos] = lower ? 1 : -1;
1804     // Dims and symbols.
1805     for (unsigned j = 0, e = map.getNumInputs(); j < e; j++) {
1806       ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j];
1807     }
1808     // Copy over the local id coefficients.
1809     unsigned numLocalIds = flatExpr.size() - 1 - numOperands;
1810     for (unsigned jj = 0, j = getNumIds() - numLocalIds; jj < numLocalIds;
1811          jj++, j++) {
1812       ineq[j] =
1813           lower ? -flatExpr[numOperands + jj] : flatExpr[numOperands + jj];
1814     }
1815     // Constant term.
1816     ineq[getNumCols() - 1] =
1817         lower ? -flatExpr[flatExpr.size() - 1]
1818               // Upper bound in flattenedExpr is an exclusive one.
1819               : flatExpr[flatExpr.size() - 1] - 1;
1820     eq ? addEquality(ineq) : addInequality(ineq);
1821   }
1822   return success();
1823 }
1824 
1825 // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
1826 // bounds in 'ubMaps' to each value in `values' that appears in the constraint
1827 // system. Note that both lower/upper bounds share the same operand list
1828 // 'operands'.
1829 // This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and
1830 // skips any null AffineMaps in 'lbMaps' or 'ubMaps'.
1831 // Note that both lower/upper bounds use operands from 'operands'.
1832 // Returns failure for unimplemented cases such as semi-affine expressions or
1833 // expressions with mod/floordiv.
addSliceBounds(ArrayRef<Value> values,ArrayRef<AffineMap> lbMaps,ArrayRef<AffineMap> ubMaps,ArrayRef<Value> operands)1834 LogicalResult FlatAffineConstraints::addSliceBounds(ArrayRef<Value> values,
1835                                                     ArrayRef<AffineMap> lbMaps,
1836                                                     ArrayRef<AffineMap> ubMaps,
1837                                                     ArrayRef<Value> operands) {
1838   assert(values.size() == lbMaps.size());
1839   assert(lbMaps.size() == ubMaps.size());
1840 
1841   for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
1842     unsigned pos;
1843     if (!findId(values[i], &pos))
1844       continue;
1845 
1846     AffineMap lbMap = lbMaps[i];
1847     AffineMap ubMap = ubMaps[i];
1848     assert(!lbMap || lbMap.getNumInputs() == operands.size());
1849     assert(!ubMap || ubMap.getNumInputs() == operands.size());
1850 
1851     // Check if this slice is just an equality along this dimension.
1852     if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
1853         ubMap.getNumResults() == 1 &&
1854         lbMap.getResult(0) + 1 == ubMap.getResult(0)) {
1855       if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true,
1856                                       /*lower=*/true)))
1857         return failure();
1858       continue;
1859     }
1860 
1861     if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
1862                                              /*lower=*/true)))
1863       return failure();
1864 
1865     if (ubMap && failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false,
1866                                              /*lower=*/false)))
1867       return failure();
1868   }
1869   return success();
1870 }
1871 
addEquality(ArrayRef<int64_t> eq)1872 void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) {
1873   assert(eq.size() == getNumCols());
1874   unsigned offset = equalities.size();
1875   equalities.resize(equalities.size() + numReservedCols);
1876   std::copy(eq.begin(), eq.end(), equalities.begin() + offset);
1877 }
1878 
addInequality(ArrayRef<int64_t> inEq)1879 void FlatAffineConstraints::addInequality(ArrayRef<int64_t> inEq) {
1880   assert(inEq.size() == getNumCols());
1881   unsigned offset = inequalities.size();
1882   inequalities.resize(inequalities.size() + numReservedCols);
1883   std::copy(inEq.begin(), inEq.end(), inequalities.begin() + offset);
1884 }
1885 
addConstantLowerBound(unsigned pos,int64_t lb)1886 void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) {
1887   assert(pos < getNumCols());
1888   unsigned offset = inequalities.size();
1889   inequalities.resize(inequalities.size() + numReservedCols);
1890   std::fill(inequalities.begin() + offset,
1891             inequalities.begin() + offset + getNumCols(), 0);
1892   inequalities[offset + pos] = 1;
1893   inequalities[offset + getNumCols() - 1] = -lb;
1894 }
1895 
addConstantUpperBound(unsigned pos,int64_t ub)1896 void FlatAffineConstraints::addConstantUpperBound(unsigned pos, int64_t ub) {
1897   assert(pos < getNumCols());
1898   unsigned offset = inequalities.size();
1899   inequalities.resize(inequalities.size() + numReservedCols);
1900   std::fill(inequalities.begin() + offset,
1901             inequalities.begin() + offset + getNumCols(), 0);
1902   inequalities[offset + pos] = -1;
1903   inequalities[offset + getNumCols() - 1] = ub;
1904 }
1905 
addConstantLowerBound(ArrayRef<int64_t> expr,int64_t lb)1906 void FlatAffineConstraints::addConstantLowerBound(ArrayRef<int64_t> expr,
1907                                                   int64_t lb) {
1908   assert(expr.size() == getNumCols());
1909   unsigned offset = inequalities.size();
1910   inequalities.resize(inequalities.size() + numReservedCols);
1911   std::fill(inequalities.begin() + offset,
1912             inequalities.begin() + offset + getNumCols(), 0);
1913   std::copy(expr.begin(), expr.end(), inequalities.begin() + offset);
1914   inequalities[offset + getNumCols() - 1] += -lb;
1915 }
1916 
addConstantUpperBound(ArrayRef<int64_t> expr,int64_t ub)1917 void FlatAffineConstraints::addConstantUpperBound(ArrayRef<int64_t> expr,
1918                                                   int64_t ub) {
1919   assert(expr.size() == getNumCols());
1920   unsigned offset = inequalities.size();
1921   inequalities.resize(inequalities.size() + numReservedCols);
1922   std::fill(inequalities.begin() + offset,
1923             inequalities.begin() + offset + getNumCols(), 0);
1924   for (unsigned i = 0, e = getNumCols(); i < e; i++) {
1925     inequalities[offset + i] = -expr[i];
1926   }
1927   inequalities[offset + getNumCols() - 1] += ub;
1928 }
1929 
1930 /// Adds a new local identifier as the floordiv of an affine function of other
1931 /// identifiers, the coefficients of which are provided in 'dividend' and with
1932 /// respect to a positive constant 'divisor'. Two constraints are added to the
1933 /// system to capture equivalence with the floordiv.
1934 ///      q = expr floordiv c    <=>   c*q <= expr <= c*q + c - 1.
addLocalFloorDiv(ArrayRef<int64_t> dividend,int64_t divisor)1935 void FlatAffineConstraints::addLocalFloorDiv(ArrayRef<int64_t> dividend,
1936                                              int64_t divisor) {
1937   assert(dividend.size() == getNumCols() && "incorrect dividend size");
1938   assert(divisor > 0 && "positive divisor expected");
1939 
1940   addLocalId(getNumLocalIds());
1941 
1942   // Add two constraints for this new identifier 'q'.
1943   SmallVector<int64_t, 8> bound(dividend.size() + 1);
1944 
1945   // dividend - q * divisor >= 0
1946   std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1,
1947             bound.begin());
1948   bound.back() = dividend.back();
1949   bound[getNumIds() - 1] = -divisor;
1950   addInequality(bound);
1951 
1952   // -dividend +qdivisor * q + divisor - 1 >= 0
1953   std::transform(bound.begin(), bound.end(), bound.begin(),
1954                  std::negate<int64_t>());
1955   bound[bound.size() - 1] += divisor - 1;
1956   addInequality(bound);
1957 }
1958 
findId(Value id,unsigned * pos) const1959 bool FlatAffineConstraints::findId(Value id, unsigned *pos) const {
1960   unsigned i = 0;
1961   for (const auto &mayBeId : ids) {
1962     if (mayBeId.hasValue() && mayBeId.getValue() == id) {
1963       *pos = i;
1964       return true;
1965     }
1966     i++;
1967   }
1968   return false;
1969 }
1970 
containsId(Value id) const1971 bool FlatAffineConstraints::containsId(Value id) const {
1972   return llvm::any_of(ids, [&](const Optional<Value> &mayBeId) {
1973     return mayBeId.hasValue() && mayBeId.getValue() == id;
1974   });
1975 }
1976 
swapId(unsigned posA,unsigned posB)1977 void FlatAffineConstraints::swapId(unsigned posA, unsigned posB) {
1978   assert(posA < getNumIds() && "invalid position A");
1979   assert(posB < getNumIds() && "invalid position B");
1980 
1981   if (posA == posB)
1982     return;
1983 
1984   for (unsigned r = 0, e = getNumInequalities(); r < e; r++)
1985     std::swap(atIneq(r, posA), atIneq(r, posB));
1986   for (unsigned r = 0, e = getNumEqualities(); r < e; r++)
1987     std::swap(atEq(r, posA), atEq(r, posB));
1988   std::swap(getId(posA), getId(posB));
1989 }
1990 
setDimSymbolSeparation(unsigned newSymbolCount)1991 void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
1992   assert(newSymbolCount <= numDims + numSymbols &&
1993          "invalid separation position");
1994   numDims = numDims + numSymbols - newSymbolCount;
1995   numSymbols = newSymbolCount;
1996 }
1997 
1998 /// Sets the specified identifier to a constant value.
setIdToConstant(unsigned pos,int64_t val)1999 void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) {
2000   unsigned offset = equalities.size();
2001   equalities.resize(equalities.size() + numReservedCols);
2002   std::fill(equalities.begin() + offset,
2003             equalities.begin() + offset + getNumCols(), 0);
2004   equalities[offset + pos] = 1;
2005   equalities[offset + getNumCols() - 1] = -val;
2006 }
2007 
2008 /// Sets the specified identifier to a constant value; asserts if the id is not
2009 /// found.
setIdToConstant(Value id,int64_t val)2010 void FlatAffineConstraints::setIdToConstant(Value id, int64_t val) {
2011   unsigned pos;
2012   if (!findId(id, &pos))
2013     // This is a pre-condition for this method.
2014     assert(0 && "id not found");
2015   setIdToConstant(pos, val);
2016 }
2017 
removeEquality(unsigned pos)2018 void FlatAffineConstraints::removeEquality(unsigned pos) {
2019   unsigned numEqualities = getNumEqualities();
2020   assert(pos < numEqualities);
2021   unsigned outputIndex = pos * numReservedCols;
2022   unsigned inputIndex = (pos + 1) * numReservedCols;
2023   unsigned numElemsToCopy = (numEqualities - pos - 1) * numReservedCols;
2024   std::copy(equalities.begin() + inputIndex,
2025             equalities.begin() + inputIndex + numElemsToCopy,
2026             equalities.begin() + outputIndex);
2027   assert(equalities.size() >= numReservedCols);
2028   equalities.resize(equalities.size() - numReservedCols);
2029 }
2030 
removeInequality(unsigned pos)2031 void FlatAffineConstraints::removeInequality(unsigned pos) {
2032   unsigned numInequalities = getNumInequalities();
2033   assert(pos < numInequalities && "invalid position");
2034   unsigned outputIndex = pos * numReservedCols;
2035   unsigned inputIndex = (pos + 1) * numReservedCols;
2036   unsigned numElemsToCopy = (numInequalities - pos - 1) * numReservedCols;
2037   std::copy(inequalities.begin() + inputIndex,
2038             inequalities.begin() + inputIndex + numElemsToCopy,
2039             inequalities.begin() + outputIndex);
2040   assert(inequalities.size() >= numReservedCols);
2041   inequalities.resize(inequalities.size() - numReservedCols);
2042 }
2043 
2044 /// Finds an equality that equates the specified identifier to a constant.
2045 /// Returns the position of the equality row. If 'symbolic' is set to true,
2046 /// symbols are also treated like a constant, i.e., an affine function of the
2047 /// symbols is also treated like a constant. Returns -1 if such an equality
2048 /// could not be found.
findEqualityToConstant(const FlatAffineConstraints & cst,unsigned pos,bool symbolic=false)2049 static int findEqualityToConstant(const FlatAffineConstraints &cst,
2050                                   unsigned pos, bool symbolic = false) {
2051   assert(pos < cst.getNumIds() && "invalid position");
2052   for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
2053     int64_t v = cst.atEq(r, pos);
2054     if (v * v != 1)
2055       continue;
2056     unsigned c;
2057     unsigned f = symbolic ? cst.getNumDimIds() : cst.getNumIds();
2058     // This checks for zeros in all positions other than 'pos' in [0, f)
2059     for (c = 0; c < f; c++) {
2060       if (c == pos)
2061         continue;
2062       if (cst.atEq(r, c) != 0) {
2063         // Dependent on another identifier.
2064         break;
2065       }
2066     }
2067     if (c == f)
2068       // Equality is free of other identifiers.
2069       return r;
2070   }
2071   return -1;
2072 }
2073 
setAndEliminate(unsigned pos,int64_t constVal)2074 void FlatAffineConstraints::setAndEliminate(unsigned pos, int64_t constVal) {
2075   assert(pos < getNumIds() && "invalid position");
2076   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2077     atIneq(r, getNumCols() - 1) += atIneq(r, pos) * constVal;
2078   }
2079   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2080     atEq(r, getNumCols() - 1) += atEq(r, pos) * constVal;
2081   }
2082   removeId(pos);
2083 }
2084 
constantFoldId(unsigned pos)2085 LogicalResult FlatAffineConstraints::constantFoldId(unsigned pos) {
2086   assert(pos < getNumIds() && "invalid position");
2087   int rowIdx;
2088   if ((rowIdx = findEqualityToConstant(*this, pos)) == -1)
2089     return failure();
2090 
2091   // atEq(rowIdx, pos) is either -1 or 1.
2092   assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
2093   int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
2094   setAndEliminate(pos, constVal);
2095   return success();
2096 }
2097 
constantFoldIdRange(unsigned pos,unsigned num)2098 void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) {
2099   for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) {
2100     if (failed(constantFoldId(t)))
2101       t++;
2102   }
2103 }
2104 
2105 /// Returns the extent (upper bound - lower bound) of the specified
2106 /// identifier if it is found to be a constant; returns None if it's not a
2107 /// constant. This methods treats symbolic identifiers specially, i.e.,
2108 /// it looks for constant differences between affine expressions involving
2109 /// only the symbolic identifiers. See comments at function definition for
2110 /// example. 'lb', if provided, is set to the lower bound associated with the
2111 /// constant difference. Note that 'lb' is purely symbolic and thus will contain
2112 /// the coefficients of the symbolic identifiers and the constant coefficient.
2113 //  Egs: 0 <= i <= 15, return 16.
2114 //       s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
2115 //       s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
2116 //       s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
2117 //       ceil(s0 - 7 / 8) = floor(s0 / 8)).
getConstantBoundOnDimSize(unsigned pos,SmallVectorImpl<int64_t> * lb,int64_t * boundFloorDivisor,SmallVectorImpl<int64_t> * ub,unsigned * minLbPos,unsigned * minUbPos) const2118 Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
2119     unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *boundFloorDivisor,
2120     SmallVectorImpl<int64_t> *ub, unsigned *minLbPos,
2121     unsigned *minUbPos) const {
2122   assert(pos < getNumDimIds() && "Invalid identifier position");
2123 
2124   // Find an equality for 'pos'^th identifier that equates it to some function
2125   // of the symbolic identifiers (+ constant).
2126   int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true);
2127   if (eqPos != -1) {
2128     auto eq = getEquality(eqPos);
2129     // If the equality involves a local var, punt for now.
2130     // TODO: this can be handled in the future by using the explicit
2131     // representation of the local vars.
2132     if (!std::all_of(eq.begin() + getNumDimAndSymbolIds(), eq.end() - 1,
2133                      [](int64_t coeff) { return coeff == 0; }))
2134       return None;
2135 
2136     // This identifier can only take a single value.
2137     if (lb) {
2138       // Set lb to that symbolic value.
2139       lb->resize(getNumSymbolIds() + 1);
2140       if (ub)
2141         ub->resize(getNumSymbolIds() + 1);
2142       for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) {
2143         int64_t v = atEq(eqPos, pos);
2144         // atEq(eqRow, pos) is either -1 or 1.
2145         assert(v * v == 1);
2146         (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimIds() + c) / -v
2147                          : -atEq(eqPos, getNumDimIds() + c) / v;
2148         // Since this is an equality, ub = lb.
2149         if (ub)
2150           (*ub)[c] = (*lb)[c];
2151       }
2152       assert(boundFloorDivisor &&
2153              "both lb and divisor or none should be provided");
2154       *boundFloorDivisor = 1;
2155     }
2156     if (minLbPos)
2157       *minLbPos = eqPos;
2158     if (minUbPos)
2159       *minUbPos = eqPos;
2160     return 1;
2161   }
2162 
2163   // Check if the identifier appears at all in any of the inequalities.
2164   unsigned r, e;
2165   for (r = 0, e = getNumInequalities(); r < e; r++) {
2166     if (atIneq(r, pos) != 0)
2167       break;
2168   }
2169   if (r == e)
2170     // If it doesn't, there isn't a bound on it.
2171     return None;
2172 
2173   // Positions of constraints that are lower/upper bounds on the variable.
2174   SmallVector<unsigned, 4> lbIndices, ubIndices;
2175 
2176   // Gather all symbolic lower bounds and upper bounds of the variable, i.e.,
2177   // the bounds can only involve symbolic (and local) identifiers. Since the
2178   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
2179   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
2180   getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices,
2181                                /*eqIndices=*/nullptr, /*offset=*/0,
2182                                /*num=*/getNumDimIds());
2183 
2184   Optional<int64_t> minDiff = None;
2185   unsigned minLbPosition = 0, minUbPosition = 0;
2186   for (auto ubPos : ubIndices) {
2187     for (auto lbPos : lbIndices) {
2188       // Look for a lower bound and an upper bound that only differ by a
2189       // constant, i.e., pairs of the form  0 <= c_pos - f(c_i's) <= diffConst.
2190       // For example, if ii is the pos^th variable, we are looking for
2191       // constraints like ii >= i, ii <= ii + 50, 50 being the difference. The
2192       // minimum among all such constant differences is kept since that's the
2193       // constant bounding the extent of the pos^th variable.
2194       unsigned j, e;
2195       for (j = 0, e = getNumCols() - 1; j < e; j++)
2196         if (atIneq(ubPos, j) != -atIneq(lbPos, j)) {
2197           break;
2198         }
2199       if (j < getNumCols() - 1)
2200         continue;
2201       int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) +
2202                                  atIneq(lbPos, getNumCols() - 1) + 1,
2203                              atIneq(lbPos, pos));
2204       if (minDiff == None || diff < minDiff) {
2205         minDiff = diff;
2206         minLbPosition = lbPos;
2207         minUbPosition = ubPos;
2208       }
2209     }
2210   }
2211   if (lb && minDiff.hasValue()) {
2212     // Set lb to the symbolic lower bound.
2213     lb->resize(getNumSymbolIds() + 1);
2214     if (ub)
2215       ub->resize(getNumSymbolIds() + 1);
2216     // The lower bound is the ceildiv of the lb constraint over the coefficient
2217     // of the variable at 'pos'. We express the ceildiv equivalently as a floor
2218     // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
2219     // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
2220     *boundFloorDivisor = atIneq(minLbPosition, pos);
2221     assert(*boundFloorDivisor == -atIneq(minUbPosition, pos));
2222     for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) {
2223       (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c);
2224     }
2225     if (ub) {
2226       for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++)
2227         (*ub)[c] = atIneq(minUbPosition, getNumDimIds() + c);
2228     }
2229     // The lower bound leads to a ceildiv while the upper bound is a floordiv
2230     // whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val +
2231     // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
2232     // the constant term for the lower bound.
2233     (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1;
2234   }
2235   if (minLbPos)
2236     *minLbPos = minLbPosition;
2237   if (minUbPos)
2238     *minUbPos = minUbPosition;
2239   return minDiff;
2240 }
2241 
2242 template <bool isLower>
2243 Optional<int64_t>
computeConstantLowerOrUpperBound(unsigned pos)2244 FlatAffineConstraints::computeConstantLowerOrUpperBound(unsigned pos) {
2245   assert(pos < getNumIds() && "invalid position");
2246   // Project to 'pos'.
2247   projectOut(0, pos);
2248   projectOut(1, getNumIds() - 1);
2249   // Check if there's an equality equating the '0'^th identifier to a constant.
2250   int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false);
2251   if (eqRowIdx != -1)
2252     // atEq(rowIdx, 0) is either -1 or 1.
2253     return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0);
2254 
2255   // Check if the identifier appears at all in any of the inequalities.
2256   unsigned r, e;
2257   for (r = 0, e = getNumInequalities(); r < e; r++) {
2258     if (atIneq(r, 0) != 0)
2259       break;
2260   }
2261   if (r == e)
2262     // If it doesn't, there isn't a bound on it.
2263     return None;
2264 
2265   Optional<int64_t> minOrMaxConst = None;
2266 
2267   // Take the max across all const lower bounds (or min across all constant
2268   // upper bounds).
2269   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2270     if (isLower) {
2271       if (atIneq(r, 0) <= 0)
2272         // Not a lower bound.
2273         continue;
2274     } else if (atIneq(r, 0) >= 0) {
2275       // Not an upper bound.
2276       continue;
2277     }
2278     unsigned c, f;
2279     for (c = 0, f = getNumCols() - 1; c < f; c++)
2280       if (c != 0 && atIneq(r, c) != 0)
2281         break;
2282     if (c < getNumCols() - 1)
2283       // Not a constant bound.
2284       continue;
2285 
2286     int64_t boundConst =
2287         isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0))
2288                 : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0));
2289     if (isLower) {
2290       if (minOrMaxConst == None || boundConst > minOrMaxConst)
2291         minOrMaxConst = boundConst;
2292     } else {
2293       if (minOrMaxConst == None || boundConst < minOrMaxConst)
2294         minOrMaxConst = boundConst;
2295     }
2296   }
2297   return minOrMaxConst;
2298 }
2299 
2300 Optional<int64_t>
getConstantLowerBound(unsigned pos) const2301 FlatAffineConstraints::getConstantLowerBound(unsigned pos) const {
2302   FlatAffineConstraints tmpCst(*this);
2303   return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
2304 }
2305 
2306 Optional<int64_t>
getConstantUpperBound(unsigned pos) const2307 FlatAffineConstraints::getConstantUpperBound(unsigned pos) const {
2308   FlatAffineConstraints tmpCst(*this);
2309   return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
2310 }
2311 
2312 // A simple (naive and conservative) check for hyper-rectangularity.
isHyperRectangular(unsigned pos,unsigned num) const2313 bool FlatAffineConstraints::isHyperRectangular(unsigned pos,
2314                                                unsigned num) const {
2315   assert(pos < getNumCols() - 1);
2316   // Check for two non-zero coefficients in the range [pos, pos + sum).
2317   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2318     unsigned sum = 0;
2319     for (unsigned c = pos; c < pos + num; c++) {
2320       if (atIneq(r, c) != 0)
2321         sum++;
2322     }
2323     if (sum > 1)
2324       return false;
2325   }
2326   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2327     unsigned sum = 0;
2328     for (unsigned c = pos; c < pos + num; c++) {
2329       if (atEq(r, c) != 0)
2330         sum++;
2331     }
2332     if (sum > 1)
2333       return false;
2334   }
2335   return true;
2336 }
2337 
print(raw_ostream & os) const2338 void FlatAffineConstraints::print(raw_ostream &os) const {
2339   assert(hasConsistentState());
2340   os << "\nConstraints (" << getNumDimIds() << " dims, " << getNumSymbolIds()
2341      << " symbols, " << getNumLocalIds() << " locals), (" << getNumConstraints()
2342      << " constraints)\n";
2343   os << "(";
2344   for (unsigned i = 0, e = getNumIds(); i < e; i++) {
2345     if (ids[i] == None)
2346       os << "None ";
2347     else
2348       os << "Value ";
2349   }
2350   os << " const)\n";
2351   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
2352     for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2353       os << atEq(i, j) << " ";
2354     }
2355     os << "= 0\n";
2356   }
2357   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
2358     for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2359       os << atIneq(i, j) << " ";
2360     }
2361     os << ">= 0\n";
2362   }
2363   os << '\n';
2364 }
2365 
dump() const2366 void FlatAffineConstraints::dump() const { print(llvm::errs()); }
2367 
2368 /// Removes duplicate constraints, trivially true constraints, and constraints
2369 /// that can be detected as redundant as a result of differing only in their
2370 /// constant term part. A constraint of the form <non-negative constant> >= 0 is
2371 /// considered trivially true.
2372 //  Uses a DenseSet to hash and detect duplicates followed by a linear scan to
2373 //  remove duplicates in place.
removeTrivialRedundancy()2374 void FlatAffineConstraints::removeTrivialRedundancy() {
2375   GCDTightenInequalities();
2376   normalizeConstraintsByGCD();
2377 
2378   // A map used to detect redundancy stemming from constraints that only differ
2379   // in their constant term. The value stored is <row position, const term>
2380   // for a given row.
2381   SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>>
2382       rowsWithoutConstTerm;
2383   // To unique rows.
2384   SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;
2385 
2386   // Check if constraint is of the form <non-negative-constant> >= 0.
2387   auto isTriviallyValid = [&](unsigned r) -> bool {
2388     for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) {
2389       if (atIneq(r, c) != 0)
2390         return false;
2391     }
2392     return atIneq(r, getNumCols() - 1) >= 0;
2393   };
2394 
2395   // Detect and mark redundant constraints.
2396   SmallVector<bool, 256> redunIneq(getNumInequalities(), false);
2397   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2398     int64_t *rowStart = inequalities.data() + numReservedCols * r;
2399     auto row = ArrayRef<int64_t>(rowStart, getNumCols());
2400     if (isTriviallyValid(r) || !rowSet.insert(row).second) {
2401       redunIneq[r] = true;
2402       continue;
2403     }
2404 
2405     // Among constraints that only differ in the constant term part, mark
2406     // everything other than the one with the smallest constant term redundant.
2407     // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the
2408     // former two are redundant).
2409     int64_t constTerm = atIneq(r, getNumCols() - 1);
2410     auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1);
2411     const auto &ret =
2412         rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}});
2413     if (!ret.second) {
2414       // Check if the other constraint has a higher constant term.
2415       auto &val = ret.first->second;
2416       if (val.second > constTerm) {
2417         // The stored row is redundant. Mark it so, and update with this one.
2418         redunIneq[val.first] = true;
2419         val = {r, constTerm};
2420       } else {
2421         // The one stored makes this one redundant.
2422         redunIneq[r] = true;
2423       }
2424     }
2425   }
2426 
2427   auto copyRow = [&](unsigned src, unsigned dest) {
2428     if (src == dest)
2429       return;
2430     for (unsigned c = 0, e = getNumCols(); c < e; c++) {
2431       atIneq(dest, c) = atIneq(src, c);
2432     }
2433   };
2434 
2435   // Scan to get rid of all rows marked redundant, in-place.
2436   unsigned pos = 0;
2437   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2438     if (!redunIneq[r])
2439       copyRow(r, pos++);
2440   }
2441   inequalities.resize(numReservedCols * pos);
2442 
2443   // TODO: consider doing this for equalities as well, but probably not worth
2444   // the savings.
2445 }
2446 
clearAndCopyFrom(const FlatAffineConstraints & other)2447 void FlatAffineConstraints::clearAndCopyFrom(
2448     const FlatAffineConstraints &other) {
2449   FlatAffineConstraints copy(other);
2450   std::swap(*this, copy);
2451   assert(copy.getNumIds() == copy.getIds().size());
2452 }
2453 
removeId(unsigned pos)2454 void FlatAffineConstraints::removeId(unsigned pos) {
2455   removeIdRange(pos, pos + 1);
2456 }
2457 
2458 static std::pair<unsigned, unsigned>
getNewNumDimsSymbols(unsigned pos,const FlatAffineConstraints & cst)2459 getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) {
2460   unsigned numDims = cst.getNumDimIds();
2461   unsigned numSymbols = cst.getNumSymbolIds();
2462   unsigned newNumDims, newNumSymbols;
2463   if (pos < numDims) {
2464     newNumDims = numDims - 1;
2465     newNumSymbols = numSymbols;
2466   } else if (pos < numDims + numSymbols) {
2467     assert(numSymbols >= 1);
2468     newNumDims = numDims;
2469     newNumSymbols = numSymbols - 1;
2470   } else {
2471     newNumDims = numDims;
2472     newNumSymbols = numSymbols;
2473   }
2474   return {newNumDims, newNumSymbols};
2475 }
2476 
2477 #undef DEBUG_TYPE
2478 #define DEBUG_TYPE "fm"
2479 
2480 /// Eliminates identifier at the specified position using Fourier-Motzkin
2481 /// variable elimination. This technique is exact for rational spaces but
2482 /// conservative (in "rare" cases) for integer spaces. The operation corresponds
2483 /// to a projection operation yielding the (convex) set of integer points
2484 /// contained in the rational shadow of the set. An emptiness test that relies
2485 /// on this method will guarantee emptiness, i.e., it disproves the existence of
2486 /// a solution if it says it's empty.
2487 /// If a non-null isResultIntegerExact is passed, it is set to true if the
2488 /// result is also integer exact. If it's set to false, the obtained solution
2489 /// *may* not be exact, i.e., it may contain integer points that do not have an
2490 /// integer pre-image in the original set.
2491 ///
2492 /// Eg:
2493 /// j >= 0, j <= i + 1
2494 /// i >= 0, i <= N + 1
2495 /// Eliminating i yields,
2496 ///   j >= 0, 0 <= N + 1, j - 1 <= N + 1
2497 ///
2498 /// If darkShadow = true, this method computes the dark shadow on elimination;
2499 /// the dark shadow is a convex integer subset of the exact integer shadow. A
2500 /// non-empty dark shadow proves the existence of an integer solution. The
2501 /// elimination in such a case could however be an under-approximation, and thus
2502 /// should not be used for scanning sets or used by itself for dependence
2503 /// checking.
2504 ///
2505 /// Eg: 2-d set, * represents grid points, 'o' represents a point in the set.
2506 ///            ^
2507 ///            |
2508 ///            | * * * * o o
2509 ///         i  | * * o o o o
2510 ///            | o * * * * *
2511 ///            --------------->
2512 ///                 j ->
2513 ///
2514 /// Eliminating i from this system (projecting on the j dimension):
2515 /// rational shadow / integer light shadow:  1 <= j <= 6
2516 /// dark shadow:                             3 <= j <= 6
2517 /// exact integer shadow:                    j = 1 \union  3 <= j <= 6
2518 /// holes/splinters:                         j = 2
2519 ///
2520 /// darkShadow = false, isResultIntegerExact = nullptr are default values.
2521 // TODO: a slight modification to yield dark shadow version of FM (tightened),
2522 // which can prove the existence of a solution if there is one.
FourierMotzkinEliminate(unsigned pos,bool darkShadow,bool * isResultIntegerExact)2523 void FlatAffineConstraints::FourierMotzkinEliminate(
2524     unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
2525   LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n");
2526   LLVM_DEBUG(dump());
2527   assert(pos < getNumIds() && "invalid position");
2528   assert(hasConsistentState());
2529 
2530   // Check if this identifier can be eliminated through a substitution.
2531   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2532     if (atEq(r, pos) != 0) {
2533       // Use Gaussian elimination here (since we have an equality).
2534       LogicalResult ret = gaussianEliminateId(pos);
2535       (void)ret;
2536       assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed");
2537       LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
2538       LLVM_DEBUG(dump());
2539       return;
2540     }
2541   }
2542 
2543   // A fast linear time tightening.
2544   GCDTightenInequalities();
2545 
2546   // Check if the identifier appears at all in any of the inequalities.
2547   unsigned r, e;
2548   for (r = 0, e = getNumInequalities(); r < e; r++) {
2549     if (atIneq(r, pos) != 0)
2550       break;
2551   }
2552   if (r == getNumInequalities()) {
2553     // If it doesn't appear, just remove the column and return.
2554     // TODO: refactor removeColumns to use it from here.
2555     removeId(pos);
2556     LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
2557     LLVM_DEBUG(dump());
2558     return;
2559   }
2560 
2561   // Positions of constraints that are lower bounds on the variable.
2562   SmallVector<unsigned, 4> lbIndices;
2563   // Positions of constraints that are lower bounds on the variable.
2564   SmallVector<unsigned, 4> ubIndices;
2565   // Positions of constraints that do not involve the variable.
2566   std::vector<unsigned> nbIndices;
2567   nbIndices.reserve(getNumInequalities());
2568 
2569   // Gather all lower bounds and upper bounds of the variable. Since the
2570   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
2571   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
2572   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
2573     if (atIneq(r, pos) == 0) {
2574       // Id does not appear in bound.
2575       nbIndices.push_back(r);
2576     } else if (atIneq(r, pos) >= 1) {
2577       // Lower bound.
2578       lbIndices.push_back(r);
2579     } else {
2580       // Upper bound.
2581       ubIndices.push_back(r);
2582     }
2583   }
2584 
2585   // Set the number of dimensions, symbols in the resulting system.
2586   const auto &dimsSymbols = getNewNumDimsSymbols(pos, *this);
2587   unsigned newNumDims = dimsSymbols.first;
2588   unsigned newNumSymbols = dimsSymbols.second;
2589 
2590   SmallVector<Optional<Value>, 8> newIds;
2591   newIds.reserve(numIds - 1);
2592   newIds.append(ids.begin(), ids.begin() + pos);
2593   newIds.append(ids.begin() + pos + 1, ids.end());
2594 
2595   /// Create the new system which has one identifier less.
2596   FlatAffineConstraints newFac(
2597       lbIndices.size() * ubIndices.size() + nbIndices.size(),
2598       getNumEqualities(), getNumCols() - 1, newNumDims, newNumSymbols,
2599       /*numLocals=*/getNumIds() - 1 - newNumDims - newNumSymbols, newIds);
2600 
2601   assert(newFac.getIds().size() == newFac.getNumIds());
2602 
2603   // This will be used to check if the elimination was integer exact.
2604   unsigned lcmProducts = 1;
2605 
2606   // Let x be the variable we are eliminating.
2607   // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note
2608   // that c_l, c_u >= 1) we have:
2609   // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u
2610   // We thus generate a constraint:
2611   // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub.
2612   // Note if c_l = c_u = 1, all integer points captured by the resulting
2613   // constraint correspond to integer points in the original system (i.e., they
2614   // have integer pre-images). Hence, if the lcm's are all 1, the elimination is
2615   // integer exact.
2616   for (auto ubPos : ubIndices) {
2617     for (auto lbPos : lbIndices) {
2618       SmallVector<int64_t, 4> ineq;
2619       ineq.reserve(newFac.getNumCols());
2620       int64_t lbCoeff = atIneq(lbPos, pos);
2621       // Note that in the comments above, ubCoeff is the negation of the
2622       // coefficient in the canonical form as the view taken here is that of the
2623       // term being moved to the other size of '>='.
2624       int64_t ubCoeff = -atIneq(ubPos, pos);
2625       // TODO: refactor this loop to avoid all branches inside.
2626       for (unsigned l = 0, e = getNumCols(); l < e; l++) {
2627         if (l == pos)
2628           continue;
2629         assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified");
2630         int64_t lcm = mlir::lcm(lbCoeff, ubCoeff);
2631         ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) +
2632                        atIneq(lbPos, l) * (lcm / lbCoeff));
2633         lcmProducts *= lcm;
2634       }
2635       if (darkShadow) {
2636         // The dark shadow is a convex subset of the exact integer shadow. If
2637         // there is a point here, it proves the existence of a solution.
2638         ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1;
2639       }
2640       // TODO: we need to have a way to add inequalities in-place in
2641       // FlatAffineConstraints instead of creating and copying over.
2642       newFac.addInequality(ineq);
2643     }
2644   }
2645 
2646   LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1)
2647                           << "\n");
2648   if (lcmProducts == 1 && isResultIntegerExact)
2649     *isResultIntegerExact = true;
2650 
2651   // Copy over the constraints not involving this variable.
2652   for (auto nbPos : nbIndices) {
2653     SmallVector<int64_t, 4> ineq;
2654     ineq.reserve(getNumCols() - 1);
2655     for (unsigned l = 0, e = getNumCols(); l < e; l++) {
2656       if (l == pos)
2657         continue;
2658       ineq.push_back(atIneq(nbPos, l));
2659     }
2660     newFac.addInequality(ineq);
2661   }
2662 
2663   assert(newFac.getNumConstraints() ==
2664          lbIndices.size() * ubIndices.size() + nbIndices.size());
2665 
2666   // Copy over the equalities.
2667   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
2668     SmallVector<int64_t, 4> eq;
2669     eq.reserve(newFac.getNumCols());
2670     for (unsigned l = 0, e = getNumCols(); l < e; l++) {
2671       if (l == pos)
2672         continue;
2673       eq.push_back(atEq(r, l));
2674     }
2675     newFac.addEquality(eq);
2676   }
2677 
2678   // GCD tightening and normalization allows detection of more trivially
2679   // redundant constraints.
2680   newFac.GCDTightenInequalities();
2681   newFac.normalizeConstraintsByGCD();
2682   newFac.removeTrivialRedundancy();
2683   clearAndCopyFrom(newFac);
2684   LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
2685   LLVM_DEBUG(dump());
2686 }
2687 
2688 #undef DEBUG_TYPE
2689 #define DEBUG_TYPE "affine-structures"
2690 
projectOut(unsigned pos,unsigned num)2691 void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
2692   if (num == 0)
2693     return;
2694 
2695   // 'pos' can be at most getNumCols() - 2 if num > 0.
2696   assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position");
2697   assert(pos + num < getNumCols() && "invalid range");
2698 
2699   // Eliminate as many identifiers as possible using Gaussian elimination.
2700   unsigned currentPos = pos;
2701   unsigned numToEliminate = num;
2702   unsigned numGaussianEliminated = 0;
2703 
2704   while (currentPos < getNumIds()) {
2705     unsigned curNumEliminated =
2706         gaussianEliminateIds(currentPos, currentPos + numToEliminate);
2707     ++currentPos;
2708     numToEliminate -= curNumEliminated + 1;
2709     numGaussianEliminated += curNumEliminated;
2710   }
2711 
2712   // Eliminate the remaining using Fourier-Motzkin.
2713   for (unsigned i = 0; i < num - numGaussianEliminated; i++) {
2714     unsigned numToEliminate = num - numGaussianEliminated - i;
2715     FourierMotzkinEliminate(
2716         getBestIdToEliminate(*this, pos, pos + numToEliminate));
2717   }
2718 
2719   // Fast/trivial simplifications.
2720   GCDTightenInequalities();
2721   // Normalize constraints after tightening since the latter impacts this, but
2722   // not the other way round.
2723   normalizeConstraintsByGCD();
2724 }
2725 
projectOut(Value id)2726 void FlatAffineConstraints::projectOut(Value id) {
2727   unsigned pos;
2728   bool ret = findId(id, &pos);
2729   assert(ret);
2730   (void)ret;
2731   FourierMotzkinEliminate(pos);
2732 }
2733 
clearConstraints()2734 void FlatAffineConstraints::clearConstraints() {
2735   equalities.clear();
2736   inequalities.clear();
2737 }
2738 
2739 namespace {
2740 
2741 enum BoundCmpResult { Greater, Less, Equal, Unknown };
2742 
2743 /// Compares two affine bounds whose coefficients are provided in 'first' and
2744 /// 'second'. The last coefficient is the constant term.
compareBounds(ArrayRef<int64_t> a,ArrayRef<int64_t> b)2745 static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
2746   assert(a.size() == b.size());
2747 
2748   // For the bounds to be comparable, their corresponding identifier
2749   // coefficients should be equal; the constant terms are then compared to
2750   // determine less/greater/equal.
2751 
2752   if (!std::equal(a.begin(), a.end() - 1, b.begin()))
2753     return Unknown;
2754 
2755   if (a.back() == b.back())
2756     return Equal;
2757 
2758   return a.back() < b.back() ? Less : Greater;
2759 }
2760 } // namespace
2761 
2762 // Returns constraints that are common to both A & B.
getCommonConstraints(const FlatAffineConstraints & A,const FlatAffineConstraints & B,FlatAffineConstraints & C)2763 static void getCommonConstraints(const FlatAffineConstraints &A,
2764                                  const FlatAffineConstraints &B,
2765                                  FlatAffineConstraints &C) {
2766   C.reset(A.getNumDimIds(), A.getNumSymbolIds(), A.getNumLocalIds());
2767   // A naive O(n^2) check should be enough here given the input sizes.
2768   for (unsigned r = 0, e = A.getNumInequalities(); r < e; ++r) {
2769     for (unsigned s = 0, f = B.getNumInequalities(); s < f; ++s) {
2770       if (A.getInequality(r) == B.getInequality(s)) {
2771         C.addInequality(A.getInequality(r));
2772         break;
2773       }
2774     }
2775   }
2776   for (unsigned r = 0, e = A.getNumEqualities(); r < e; ++r) {
2777     for (unsigned s = 0, f = B.getNumEqualities(); s < f; ++s) {
2778       if (A.getEquality(r) == B.getEquality(s)) {
2779         C.addEquality(A.getEquality(r));
2780         break;
2781       }
2782     }
2783   }
2784 }
2785 
2786 // Computes the bounding box with respect to 'other' by finding the min of the
2787 // lower bounds and the max of the upper bounds along each of the dimensions.
2788 LogicalResult
unionBoundingBox(const FlatAffineConstraints & otherCst)2789 FlatAffineConstraints::unionBoundingBox(const FlatAffineConstraints &otherCst) {
2790   assert(otherCst.getNumDimIds() == numDims && "dims mismatch");
2791   assert(otherCst.getIds()
2792              .slice(0, getNumDimIds())
2793              .equals(getIds().slice(0, getNumDimIds())) &&
2794          "dim values mismatch");
2795   assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here");
2796   assert(getNumLocalIds() == 0 && "local ids not supported yet here");
2797 
2798   // Align `other` to this.
2799   Optional<FlatAffineConstraints> otherCopy;
2800   if (!areIdsAligned(*this, otherCst)) {
2801     otherCopy.emplace(FlatAffineConstraints(otherCst));
2802     mergeAndAlignIds(/*offset=*/numDims, this, &otherCopy.getValue());
2803   }
2804 
2805   const auto &otherAligned = otherCopy ? *otherCopy : otherCst;
2806 
2807   // Get the constraints common to both systems; these will be added as is to
2808   // the union.
2809   FlatAffineConstraints commonCst;
2810   getCommonConstraints(*this, otherAligned, commonCst);
2811 
2812   std::vector<SmallVector<int64_t, 8>> boundingLbs;
2813   std::vector<SmallVector<int64_t, 8>> boundingUbs;
2814   boundingLbs.reserve(2 * getNumDimIds());
2815   boundingUbs.reserve(2 * getNumDimIds());
2816 
2817   // To hold lower and upper bounds for each dimension.
2818   SmallVector<int64_t, 4> lb, otherLb, ub, otherUb;
2819   // To compute min of lower bounds and max of upper bounds for each dimension.
2820   SmallVector<int64_t, 4> minLb(getNumSymbolIds() + 1);
2821   SmallVector<int64_t, 4> maxUb(getNumSymbolIds() + 1);
2822   // To compute final new lower and upper bounds for the union.
2823   SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols());
2824 
2825   int64_t lbFloorDivisor, otherLbFloorDivisor;
2826   for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
2827     auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub);
2828     if (!extent.hasValue())
2829       // TODO: symbolic extents when necessary.
2830       // TODO: handle union if a dimension is unbounded.
2831       return failure();
2832 
2833     auto otherExtent = otherAligned.getConstantBoundOnDimSize(
2834         d, &otherLb, &otherLbFloorDivisor, &otherUb);
2835     if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor)
2836       // TODO: symbolic extents when necessary.
2837       return failure();
2838 
2839     assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
2840 
2841     auto res = compareBounds(lb, otherLb);
2842     // Identify min.
2843     if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
2844       minLb = lb;
2845       // Since the divisor is for a floordiv, we need to convert to ceildiv,
2846       // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=>
2847       // div * i >= expr - div + 1.
2848       minLb.back() -= lbFloorDivisor - 1;
2849     } else if (res == BoundCmpResult::Greater) {
2850       minLb = otherLb;
2851       minLb.back() -= otherLbFloorDivisor - 1;
2852     } else {
2853       // Uncomparable - check for constant lower/upper bounds.
2854       auto constLb = getConstantLowerBound(d);
2855       auto constOtherLb = otherAligned.getConstantLowerBound(d);
2856       if (!constLb.hasValue() || !constOtherLb.hasValue())
2857         return failure();
2858       std::fill(minLb.begin(), minLb.end(), 0);
2859       minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue());
2860     }
2861 
2862     // Do the same for ub's but max of upper bounds. Identify max.
2863     auto uRes = compareBounds(ub, otherUb);
2864     if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) {
2865       maxUb = ub;
2866     } else if (uRes == BoundCmpResult::Less) {
2867       maxUb = otherUb;
2868     } else {
2869       // Uncomparable - check for constant lower/upper bounds.
2870       auto constUb = getConstantUpperBound(d);
2871       auto constOtherUb = otherAligned.getConstantUpperBound(d);
2872       if (!constUb.hasValue() || !constOtherUb.hasValue())
2873         return failure();
2874       std::fill(maxUb.begin(), maxUb.end(), 0);
2875       maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue());
2876     }
2877 
2878     std::fill(newLb.begin(), newLb.end(), 0);
2879     std::fill(newUb.begin(), newUb.end(), 0);
2880 
2881     // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor,
2882     // and so it's the divisor for newLb and newUb as well.
2883     newLb[d] = lbFloorDivisor;
2884     newUb[d] = -lbFloorDivisor;
2885     // Copy over the symbolic part + constant term.
2886     std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds());
2887     std::transform(newLb.begin() + getNumDimIds(), newLb.end(),
2888                    newLb.begin() + getNumDimIds(), std::negate<int64_t>());
2889     std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds());
2890 
2891     boundingLbs.push_back(newLb);
2892     boundingUbs.push_back(newUb);
2893   }
2894 
2895   // Clear all constraints and add the lower/upper bounds for the bounding box.
2896   clearConstraints();
2897   for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
2898     addInequality(boundingLbs[d]);
2899     addInequality(boundingUbs[d]);
2900   }
2901 
2902   // Add the constraints that were common to both systems.
2903   append(commonCst);
2904   removeTrivialRedundancy();
2905 
2906   // TODO: copy over pure symbolic constraints from this and 'other' over to the
2907   // union (since the above are just the union along dimensions); we shouldn't
2908   // be discarding any other constraints on the symbols.
2909 
2910   return success();
2911 }
2912 
2913 /// Compute an explicit representation for local vars. For all systems coming
2914 /// from MLIR integer sets, maps, or expressions where local vars were
2915 /// introduced to model floordivs and mods, this always succeeds.
computeLocalVars(const FlatAffineConstraints & cst,SmallVectorImpl<AffineExpr> & memo,MLIRContext * context)2916 static LogicalResult computeLocalVars(const FlatAffineConstraints &cst,
2917                                       SmallVectorImpl<AffineExpr> &memo,
2918                                       MLIRContext *context) {
2919   unsigned numDims = cst.getNumDimIds();
2920   unsigned numSyms = cst.getNumSymbolIds();
2921 
2922   // Initialize dimensional and symbolic identifiers.
2923   for (unsigned i = 0; i < numDims; i++)
2924     memo[i] = getAffineDimExpr(i, context);
2925   for (unsigned i = numDims, e = numDims + numSyms; i < e; i++)
2926     memo[i] = getAffineSymbolExpr(i - numDims, context);
2927 
2928   bool changed;
2929   do {
2930     // Each time `changed` is true at the end of this iteration, one or more
2931     // local vars would have been detected as floordivs and set in memo; so the
2932     // number of null entries in memo[...] strictly reduces; so this converges.
2933     changed = false;
2934     for (unsigned i = 0, e = cst.getNumLocalIds(); i < e; ++i)
2935       if (!memo[numDims + numSyms + i] &&
2936           detectAsFloorDiv(cst, /*pos=*/numDims + numSyms + i, context, memo))
2937         changed = true;
2938   } while (changed);
2939 
2940   ArrayRef<AffineExpr> localExprs =
2941       ArrayRef<AffineExpr>(memo).take_back(cst.getNumLocalIds());
2942   return success(
2943       llvm::all_of(localExprs, [](AffineExpr expr) { return expr; }));
2944 }
2945 
getIneqAsAffineValueMap(unsigned pos,unsigned ineqPos,AffineValueMap & vmap,MLIRContext * context) const2946 void FlatAffineConstraints::getIneqAsAffineValueMap(
2947     unsigned pos, unsigned ineqPos, AffineValueMap &vmap,
2948     MLIRContext *context) const {
2949   unsigned numDims = getNumDimIds();
2950   unsigned numSyms = getNumSymbolIds();
2951 
2952   assert(pos < numDims && "invalid position");
2953   assert(ineqPos < getNumInequalities() && "invalid inequality position");
2954 
2955   // Get expressions for local vars.
2956   SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
2957   if (failed(computeLocalVars(*this, memo, context)))
2958     assert(false &&
2959            "one or more local exprs do not have an explicit representation");
2960   auto localExprs = ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
2961 
2962   // Compute the AffineExpr lower/upper bound for this inequality.
2963   ArrayRef<int64_t> inequality = getInequality(ineqPos);
2964   SmallVector<int64_t, 8> bound;
2965   bound.reserve(getNumCols() - 1);
2966   // Everything other than the coefficient at `pos`.
2967   bound.append(inequality.begin(), inequality.begin() + pos);
2968   bound.append(inequality.begin() + pos + 1, inequality.end());
2969 
2970   if (inequality[pos] > 0)
2971     // Lower bound.
2972     std::transform(bound.begin(), bound.end(), bound.begin(),
2973                    std::negate<int64_t>());
2974   else
2975     // Upper bound (which is exclusive).
2976     bound.back() += 1;
2977 
2978   // Convert to AffineExpr (tree) form.
2979   auto boundExpr = getAffineExprFromFlatForm(bound, numDims - 1, numSyms,
2980                                              localExprs, context);
2981 
2982   // Get the values to bind to this affine expr (all dims and symbols).
2983   SmallVector<Value, 4> operands;
2984   getIdValues(0, pos, &operands);
2985   SmallVector<Value, 4> trailingOperands;
2986   getIdValues(pos + 1, getNumDimAndSymbolIds(), &trailingOperands);
2987   operands.append(trailingOperands.begin(), trailingOperands.end());
2988   vmap.reset(AffineMap::get(numDims - 1, numSyms, boundExpr), operands);
2989 }
2990 
2991 /// Returns true if the pos^th column is all zero for both inequalities and
2992 /// equalities..
isColZero(const FlatAffineConstraints & cst,unsigned pos)2993 static bool isColZero(const FlatAffineConstraints &cst, unsigned pos) {
2994   unsigned rowPos;
2995   return !findConstraintWithNonZeroAt(cst, pos, /*isEq=*/false, &rowPos) &&
2996          !findConstraintWithNonZeroAt(cst, pos, /*isEq=*/true, &rowPos);
2997 }
2998 
getAsIntegerSet(MLIRContext * context) const2999 IntegerSet FlatAffineConstraints::getAsIntegerSet(MLIRContext *context) const {
3000   if (getNumConstraints() == 0)
3001     // Return universal set (always true): 0 == 0.
3002     return IntegerSet::get(getNumDimIds(), getNumSymbolIds(),
3003                            getAffineConstantExpr(/*constant=*/0, context),
3004                            /*eqFlags=*/true);
3005 
3006   // Construct local references.
3007   SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
3008 
3009   if (failed(computeLocalVars(*this, memo, context))) {
3010     // Check if the local variables without an explicit representation have
3011     // zero coefficients everywhere.
3012     for (unsigned i = getNumDimAndSymbolIds(), e = getNumIds(); i < e; ++i) {
3013       if (!memo[i] && !isColZero(*this, /*pos=*/i)) {
3014         LLVM_DEBUG(llvm::dbgs() << "one or more local exprs do not have an "
3015                                    "explicit representation");
3016         return IntegerSet();
3017       }
3018     }
3019   }
3020 
3021   ArrayRef<AffineExpr> localExprs =
3022       ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
3023 
3024   // Construct the IntegerSet from the equalities/inequalities.
3025   unsigned numDims = getNumDimIds();
3026   unsigned numSyms = getNumSymbolIds();
3027 
3028   SmallVector<bool, 16> eqFlags(getNumConstraints());
3029   std::fill(eqFlags.begin(), eqFlags.begin() + getNumEqualities(), true);
3030   std::fill(eqFlags.begin() + getNumEqualities(), eqFlags.end(), false);
3031 
3032   SmallVector<AffineExpr, 8> exprs;
3033   exprs.reserve(getNumConstraints());
3034 
3035   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
3036     exprs.push_back(getAffineExprFromFlatForm(getEquality(i), numDims, numSyms,
3037                                               localExprs, context));
3038   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
3039     exprs.push_back(getAffineExprFromFlatForm(getInequality(i), numDims,
3040                                               numSyms, localExprs, context));
3041   return IntegerSet::get(numDims, numSyms, exprs, eqFlags);
3042 }
3043 
3044 /// Find positions of inequalities and equalities that do not have a coefficient
3045 /// for [pos, pos + num) identifiers.
getIndependentConstraints(const FlatAffineConstraints & cst,unsigned pos,unsigned num,SmallVectorImpl<unsigned> & nbIneqIndices,SmallVectorImpl<unsigned> & nbEqIndices)3046 static void getIndependentConstraints(const FlatAffineConstraints &cst,
3047                                       unsigned pos, unsigned num,
3048                                       SmallVectorImpl<unsigned> &nbIneqIndices,
3049                                       SmallVectorImpl<unsigned> &nbEqIndices) {
3050   assert(pos < cst.getNumIds() && "invalid start position");
3051   assert(pos + num <= cst.getNumIds() && "invalid limit");
3052 
3053   for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
3054     // The bounds are to be independent of [offset, offset + num) columns.
3055     unsigned c;
3056     for (c = pos; c < pos + num; ++c) {
3057       if (cst.atIneq(r, c) != 0)
3058         break;
3059     }
3060     if (c == pos + num)
3061       nbIneqIndices.push_back(r);
3062   }
3063 
3064   for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
3065     // The bounds are to be independent of [offset, offset + num) columns.
3066     unsigned c;
3067     for (c = pos; c < pos + num; ++c) {
3068       if (cst.atEq(r, c) != 0)
3069         break;
3070     }
3071     if (c == pos + num)
3072       nbEqIndices.push_back(r);
3073   }
3074 }
3075 
removeIndependentConstraints(unsigned pos,unsigned num)3076 void FlatAffineConstraints::removeIndependentConstraints(unsigned pos,
3077                                                          unsigned num) {
3078   assert(pos + num <= getNumIds() && "invalid range");
3079 
3080   // Remove constraints that are independent of these identifiers.
3081   SmallVector<unsigned, 4> nbIneqIndices, nbEqIndices;
3082   getIndependentConstraints(*this, /*pos=*/0, num, nbIneqIndices, nbEqIndices);
3083 
3084   // Iterate in reverse so that indices don't have to be updated.
3085   // TODO: This method can be made more efficient (because removal of each
3086   // inequality leads to much shifting/copying in the underlying buffer).
3087   for (auto nbIndex : llvm::reverse(nbIneqIndices))
3088     removeInequality(nbIndex);
3089   for (auto nbIndex : llvm::reverse(nbEqIndices))
3090     removeEquality(nbIndex);
3091 }
3092