//===- PresburgerSet.cpp - MLIR PresburgerSet Class -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Analysis/PresburgerSet.h"
10 #include "mlir/Analysis/Presburger/Simplex.h"
11 #include "llvm/ADT/STLExtras.h"
12 #include "llvm/ADT/SmallBitVector.h"
13
14 using namespace mlir;
15
PresburgerSet(const FlatAffineConstraints & fac)16 PresburgerSet::PresburgerSet(const FlatAffineConstraints &fac)
17 : nDim(fac.getNumDimIds()), nSym(fac.getNumSymbolIds()) {
18 unionFACInPlace(fac);
19 }
20
getNumFACs() const21 unsigned PresburgerSet::getNumFACs() const {
22 return flatAffineConstraints.size();
23 }
24
getNumDims() const25 unsigned PresburgerSet::getNumDims() const { return nDim; }
26
getNumSyms() const27 unsigned PresburgerSet::getNumSyms() const { return nSym; }
28
29 ArrayRef<FlatAffineConstraints>
getAllFlatAffineConstraints() const30 PresburgerSet::getAllFlatAffineConstraints() const {
31 return flatAffineConstraints;
32 }
33
34 const FlatAffineConstraints &
getFlatAffineConstraints(unsigned index) const35 PresburgerSet::getFlatAffineConstraints(unsigned index) const {
36 assert(index < flatAffineConstraints.size() && "index out of bounds!");
37 return flatAffineConstraints[index];
38 }
39
40 /// Assert that the FlatAffineConstraints and PresburgerSet live in
41 /// compatible spaces.
assertDimensionsCompatible(const FlatAffineConstraints & fac,const PresburgerSet & set)42 static void assertDimensionsCompatible(const FlatAffineConstraints &fac,
43 const PresburgerSet &set) {
44 assert(fac.getNumDimIds() == set.getNumDims() &&
45 "Number of dimensions of the FlatAffineConstraints and PresburgerSet"
46 "do not match!");
47 assert(fac.getNumSymbolIds() == set.getNumSyms() &&
48 "Number of symbols of the FlatAffineConstraints and PresburgerSet"
49 "do not match!");
50 }
51
52 /// Assert that the two PresburgerSets live in compatible spaces.
assertDimensionsCompatible(const PresburgerSet & setA,const PresburgerSet & setB)53 static void assertDimensionsCompatible(const PresburgerSet &setA,
54 const PresburgerSet &setB) {
55 assert(setA.getNumDims() == setB.getNumDims() &&
56 "Number of dimensions of the PresburgerSets do not match!");
57 assert(setA.getNumSyms() == setB.getNumSyms() &&
58 "Number of symbols of the PresburgerSets do not match!");
59 }
60
61 /// Mutate this set, turning it into the union of this set and the given
62 /// FlatAffineConstraints.
unionFACInPlace(const FlatAffineConstraints & fac)63 void PresburgerSet::unionFACInPlace(const FlatAffineConstraints &fac) {
64 assertDimensionsCompatible(fac, *this);
65 flatAffineConstraints.push_back(fac);
66 }
67
68 /// Mutate this set, turning it into the union of this set and the given set.
69 ///
70 /// This is accomplished by simply adding all the FACs of the given set to this
71 /// set.
unionSetInPlace(const PresburgerSet & set)72 void PresburgerSet::unionSetInPlace(const PresburgerSet &set) {
73 assertDimensionsCompatible(set, *this);
74 for (const FlatAffineConstraints &fac : set.flatAffineConstraints)
75 unionFACInPlace(fac);
76 }
77
78 /// Return the union of this set and the given set.
unionSet(const PresburgerSet & set) const79 PresburgerSet PresburgerSet::unionSet(const PresburgerSet &set) const {
80 assertDimensionsCompatible(set, *this);
81 PresburgerSet result = *this;
82 result.unionSetInPlace(set);
83 return result;
84 }
85
86 /// A point is contained in the union iff any of the parts contain the point.
containsPoint(ArrayRef<int64_t> point) const87 bool PresburgerSet::containsPoint(ArrayRef<int64_t> point) const {
88 for (const FlatAffineConstraints &fac : flatAffineConstraints) {
89 if (fac.containsPoint(point))
90 return true;
91 }
92 return false;
93 }
94
getUniverse(unsigned nDim,unsigned nSym)95 PresburgerSet PresburgerSet::getUniverse(unsigned nDim, unsigned nSym) {
96 PresburgerSet result(nDim, nSym);
97 result.unionFACInPlace(FlatAffineConstraints::getUniverse(nDim, nSym));
98 return result;
99 }
100
getEmptySet(unsigned nDim,unsigned nSym)101 PresburgerSet PresburgerSet::getEmptySet(unsigned nDim, unsigned nSym) {
102 return PresburgerSet(nDim, nSym);
103 }
104
105 // Return the intersection of this set with the given set.
106 //
107 // We directly compute (S_1 or S_2 ...) and (T_1 or T_2 ...)
108 // as (S_1 and T_1) or (S_1 and T_2) or ...
intersect(const PresburgerSet & set) const109 PresburgerSet PresburgerSet::intersect(const PresburgerSet &set) const {
110 assertDimensionsCompatible(set, *this);
111
112 PresburgerSet result(nDim, nSym);
113 for (const FlatAffineConstraints &csA : flatAffineConstraints) {
114 for (const FlatAffineConstraints &csB : set.flatAffineConstraints) {
115 FlatAffineConstraints intersection(csA);
116 intersection.append(csB);
117 if (!intersection.isEmpty())
118 result.unionFACInPlace(std::move(intersection));
119 }
120 }
121 return result;
122 }
123
124 /// Return `coeffs` with all the elements negated.
getNegatedCoeffs(ArrayRef<int64_t> coeffs)125 static SmallVector<int64_t, 8> getNegatedCoeffs(ArrayRef<int64_t> coeffs) {
126 SmallVector<int64_t, 8> negatedCoeffs;
127 negatedCoeffs.reserve(coeffs.size());
128 for (int64_t coeff : coeffs)
129 negatedCoeffs.emplace_back(-coeff);
130 return negatedCoeffs;
131 }
132
133 /// Return the complement of the given inequality.
134 ///
135 /// The complement of a_1 x_1 + ... + a_n x_ + c >= 0 is
136 /// a_1 x_1 + ... + a_n x_ + c < 0, i.e., -a_1 x_1 - ... - a_n x_ - c - 1 >= 0.
getComplementIneq(ArrayRef<int64_t> ineq)137 static SmallVector<int64_t, 8> getComplementIneq(ArrayRef<int64_t> ineq) {
138 SmallVector<int64_t, 8> coeffs;
139 coeffs.reserve(ineq.size());
140 for (int64_t coeff : ineq)
141 coeffs.emplace_back(-coeff);
142 --coeffs.back();
143 return coeffs;
144 }
145
146 /// Return the set difference b \ s and accumulate the result into `result`.
147 /// `simplex` must correspond to b.
148 ///
149 /// In the following, V denotes union, ^ denotes intersection, \ denotes set
150 /// difference and ~ denotes complement.
151 /// Let b be the FlatAffineConstraints and s = (V_i s_i) be the set. We want
152 /// b \ (V_i s_i).
153 ///
154 /// Let s_i = ^_j s_ij, where each s_ij is a single inequality. To compute
155 /// b \ s_i = b ^ ~s_i, we partition s_i based on the first violated inequality:
156 /// ~s_i = (~s_i1) V (s_i1 ^ ~s_i2) V (s_i1 ^ s_i2 ^ ~s_i3) V ...
157 /// And the required result is (b ^ ~s_i1) V (b ^ s_i1 ^ ~s_i2) V ...
158 /// We recurse by subtracting V_{j > i} S_j from each of these parts and
159 /// returning the union of the results. Each equality is handled as a
160 /// conjunction of two inequalities.
161 ///
162 /// As a heuristic, we try adding all the constraints and check if simplex
163 /// says that the intersection is empty. Also, in the process we find out that
164 /// some constraints are redundant. These redundant constraints are ignored.
subtractRecursively(FlatAffineConstraints & b,Simplex & simplex,const PresburgerSet & s,unsigned i,PresburgerSet & result)165 static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
166 const PresburgerSet &s, unsigned i,
167 PresburgerSet &result) {
168 if (i == s.getNumFACs()) {
169 result.unionFACInPlace(b);
170 return;
171 }
172 const FlatAffineConstraints &sI = s.getFlatAffineConstraints(i);
173 unsigned initialSnapshot = simplex.getSnapshot();
174 unsigned offset = simplex.numConstraints();
175 simplex.intersectFlatAffineConstraints(sI);
176
177 if (simplex.isEmpty()) {
178 /// b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
179 simplex.rollback(initialSnapshot);
180 subtractRecursively(b, simplex, s, i + 1, result);
181 return;
182 }
183
184 simplex.detectRedundant();
185 llvm::SmallBitVector isMarkedRedundant;
186 for (unsigned j = 0; j < 2 * sI.getNumEqualities() + sI.getNumInequalities();
187 j++)
188 isMarkedRedundant.push_back(simplex.isMarkedRedundant(offset + j));
189
190 simplex.rollback(initialSnapshot);
191
192 // Recurse with the part b ^ ~ineq. Note that b is modified throughout
193 // subtractRecursively. At the time this function is called, the current b is
194 // actually equal to b ^ s_i1 ^ s_i2 ^ ... ^ s_ij, and ineq is the next
195 // inequality, s_{i,j+1}. This function recurses into the next level i + 1
196 // with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}.
197 auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
198 size_t snapshot = simplex.getSnapshot();
199 b.addInequality(ineq);
200 simplex.addInequality(ineq);
201 subtractRecursively(b, simplex, s, i + 1, result);
202 b.removeInequality(b.getNumInequalities() - 1);
203 simplex.rollback(snapshot);
204 };
205
206 // For each inequality ineq, we first recurse with the part where ineq
207 // is not satisfied, and then add the ineq to b and simplex because
208 // ineq must be satisfied by all later parts.
209 auto processInequality = [&](ArrayRef<int64_t> ineq) {
210 recurseWithInequality(getComplementIneq(ineq));
211 b.addInequality(ineq);
212 simplex.addInequality(ineq);
213 };
214
215 // processInequality appends some additional constraints to b. We want to
216 // rollback b to its initial state before returning, which we will do by
217 // removing all constraints beyond the original number of inequalities
218 // and equalities, so we store these counts first.
219 unsigned originalNumIneqs = b.getNumInequalities();
220 unsigned originalNumEqs = b.getNumEqualities();
221
222 for (unsigned j = 0, e = sI.getNumInequalities(); j < e; j++) {
223 if (isMarkedRedundant[j])
224 continue;
225 processInequality(sI.getInequality(j));
226 }
227
228 offset = sI.getNumInequalities();
229 for (unsigned j = 0, e = sI.getNumEqualities(); j < e; ++j) {
230 const ArrayRef<int64_t> &coeffs = sI.getEquality(j);
231 // Same as the above loop for inequalities, done once each for the positive
232 // and negative inequalities that make up this equality.
233 if (!isMarkedRedundant[offset + 2 * j])
234 processInequality(coeffs);
235 if (!isMarkedRedundant[offset + 2 * j + 1])
236 processInequality(getNegatedCoeffs(coeffs));
237 }
238
239 // Rollback b and simplex to their initial states.
240 for (unsigned i = b.getNumInequalities(); i > originalNumIneqs; --i)
241 b.removeInequality(i - 1);
242
243 for (unsigned i = b.getNumEqualities(); i > originalNumEqs; --i)
244 b.removeEquality(i - 1);
245
246 simplex.rollback(initialSnapshot);
247 }
248
249 /// Return the set difference fac \ set.
250 ///
251 /// The FAC here is modified in subtractRecursively, so it cannot be a const
252 /// reference even though it is restored to its original state before returning
253 /// from that function.
getSetDifference(FlatAffineConstraints fac,const PresburgerSet & set)254 PresburgerSet PresburgerSet::getSetDifference(FlatAffineConstraints fac,
255 const PresburgerSet &set) {
256 assertDimensionsCompatible(fac, set);
257 if (fac.isEmptyByGCDTest())
258 return PresburgerSet::getEmptySet(fac.getNumDimIds(),
259 fac.getNumSymbolIds());
260
261 PresburgerSet result(fac.getNumDimIds(), fac.getNumSymbolIds());
262 Simplex simplex(fac);
263 subtractRecursively(fac, simplex, set, 0, result);
264 return result;
265 }
266
267 /// Return the complement of this set.
complement() const268 PresburgerSet PresburgerSet::complement() const {
269 return getSetDifference(
270 FlatAffineConstraints::getUniverse(getNumDims(), getNumSyms()), *this);
271 }
272
273 /// Return the result of subtract the given set from this set, i.e.,
274 /// return `this \ set`.
subtract(const PresburgerSet & set) const275 PresburgerSet PresburgerSet::subtract(const PresburgerSet &set) const {
276 assertDimensionsCompatible(set, *this);
277 PresburgerSet result(nDim, nSym);
278 // We compute (V_i t_i) \ (V_i set_i) as V_i (t_i \ V_i set_i).
279 for (const FlatAffineConstraints &fac : flatAffineConstraints)
280 result.unionSetInPlace(getSetDifference(fac, set));
281 return result;
282 }
283
284 /// Return true if all the sets in the union are known to be integer empty,
285 /// false otherwise.
isIntegerEmpty() const286 bool PresburgerSet::isIntegerEmpty() const {
287 assert(nSym == 0 && "isIntegerEmpty is intended for non-symbolic sets");
288 // The set is empty iff all of the disjuncts are empty.
289 for (const FlatAffineConstraints &fac : flatAffineConstraints) {
290 if (!fac.isIntegerEmpty())
291 return false;
292 }
293 return true;
294 }
295
findIntegerSample(SmallVectorImpl<int64_t> & sample)296 bool PresburgerSet::findIntegerSample(SmallVectorImpl<int64_t> &sample) {
297 assert(nSym == 0 && "findIntegerSample is intended for non-symbolic sets");
298 // A sample exists iff any of the disjuncts contains a sample.
299 for (const FlatAffineConstraints &fac : flatAffineConstraints) {
300 if (Optional<SmallVector<int64_t, 8>> opt = fac.findIntegerSample()) {
301 sample = std::move(*opt);
302 return true;
303 }
304 }
305 return false;
306 }
307
print(raw_ostream & os) const308 void PresburgerSet::print(raw_ostream &os) const {
309 os << getNumFACs() << " FlatAffineConstraints:\n";
310 for (const FlatAffineConstraints &fac : flatAffineConstraints) {
311 fac.print(os);
312 os << '\n';
313 }
314 }
315
dump() const316 void PresburgerSet::dump() const { print(llvm::errs()); }
317