//===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous analysis routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Utils.h"

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/IntegerSet.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "analysis-utils"

using namespace mlir;

using llvm::SmallDenseMap;

/// Populates 'loops' with IVs of the loops surrounding 'op' ordered from
/// the outermost 'affine.for' operation to the innermost one.
void mlir::getLoopIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops) {
  auto *currOp = op.getParentOp();
  AffineForOp currAffineForOp;
  // Traverse up the hierarchy collecting all 'affine.for' operations while
  // skipping over 'affine.if' operations.
  while (currOp && ((currAffineForOp = dyn_cast<AffineForOp>(currOp)) ||
                    isa<AffineIfOp>(currOp))) {
    if (currAffineForOp)
      loops->push_back(currAffineForOp);
    currOp = currOp->getParentOp();
  }
  std::reverse(loops->begin(), loops->end());
}
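
// For example (illustrative IR), given
//   affine.for %i = 0 to 32 {
//     affine.if #set(%i) {
//       affine.for %j = 0 to 32 {
//         "op"() : () -> ()
//       }
//     }
//   }
// getLoopIVs on "op" populates 'loops' with the 'affine.for' ops of %i and
// %j, in that order; the intervening 'affine.if' is skipped.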

/// Populates 'ops' with IVs of the loops surrounding `op`, along with
/// `affine.if` operations interleaved between these loops, ordered from the
/// outermost `affine.for` operation to the innermost one.
void mlir::getEnclosingAffineForAndIfOps(Operation &op,
                                         SmallVectorImpl<Operation *> *ops) {
  ops->clear();
  Operation *currOp = op.getParentOp();

  // Traverse up the hierarchy collecting all `affine.for` and `affine.if`
  // operations.
  while (currOp && (isa<AffineIfOp, AffineForOp>(currOp))) {
    ops->push_back(currOp);
    currOp = currOp->getParentOp();
  }
  std::reverse(ops->begin(), ops->end());
}
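
// For example, for "op" in the nest shown above getLoopIVs, this populates
// 'ops' with the %i loop, the 'affine.if', and the %j loop, in that order.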

// Populates 'cst' with FlatAffineConstraints which represent slice bounds.
LogicalResult
ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
  assert(!lbOperands.empty());
  // Adds src 'ivs' as dimension identifiers in 'cst'.
  unsigned numDims = ivs.size();
  // Adds operands (dst ivs and symbols) as symbols in 'cst'.
  unsigned numSymbols = lbOperands[0].size();

  SmallVector<Value, 4> values(ivs);
  // Append 'ivs' then 'operands' to 'values'.
  values.append(lbOperands[0].begin(), lbOperands[0].end());
  cst->reset(numDims, numSymbols, 0, values);

  // Add loop bound constraints for values which are loop IVs and equality
  // constraints for symbols which are constants.
  for (const auto &value : values) {
    assert(cst->containsId(value) && "value expected to be present");
    if (isValidSymbol(value)) {
      // Check if the symbol is a constant.
      if (auto cOp = value.getDefiningOp<ConstantIndexOp>())
        cst->setIdToConstant(value, cOp.getValue());
    } else if (auto loop = getForInductionVarOwner(value)) {
      if (failed(cst->addAffineForOpDomain(loop)))
        return failure();
    }
  }

  // Add slice bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]'.
  LogicalResult ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]);
  assert(succeeded(ret) &&
         "should not fail as we never have semi-affine slice maps");
  (void)ret;
  return success();
}

// Clears state bounds and operand state.
void ComputationSliceState::clearBounds() {
  lbs.clear();
  ubs.clear();
  lbOperands.clear();
  ubOperands.clear();
}

void ComputationSliceState::dump() const {
  llvm::errs() << "\tIVs:\n";
  for (Value iv : ivs)
    llvm::errs() << "\t\t" << iv << "\n";

  llvm::errs() << "\tLBs:\n";
  for (auto &en : llvm::enumerate(lbs)) {
    llvm::errs() << "\t\t" << en.value() << "\n";
    llvm::errs() << "\t\tOperands:\n";
    for (Value lbOp : lbOperands[en.index()])
      llvm::errs() << "\t\t\t" << lbOp << "\n";
  }

  llvm::errs() << "\tUBs:\n";
  for (auto &en : llvm::enumerate(ubs)) {
    llvm::errs() << "\t\t" << en.value() << "\n";
    llvm::errs() << "\t\tOperands:\n";
    for (Value ubOp : ubOperands[en.index()])
      llvm::errs() << "\t\t\t" << ubOp << "\n";
  }
}

unsigned MemRefRegion::getRank() const {
  return memref.getType().cast<MemRefType>().getRank();
}

Optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
    SmallVectorImpl<int64_t> *shape, std::vector<SmallVector<int64_t, 4>> *lbs,
    SmallVectorImpl<int64_t> *lbDivisors) const {
  auto memRefType = memref.getType().cast<MemRefType>();
  unsigned rank = memRefType.getRank();
  if (shape)
    shape->reserve(rank);

  assert(rank == cst.getNumDimIds() && "inconsistent memref region");

  // Use a copy of the region constraints that has upper/lower bounds for each
  // memref dimension with static size added to guard against potential
  // over-approximation from projection or union bounding box. We cannot add
  // these on the region itself since they might just be redundant constraints
  // that would require non-trivial means to eliminate.
  FlatAffineConstraints cstWithShapeBounds(cst);
  for (unsigned r = 0; r < rank; r++) {
    cstWithShapeBounds.addConstantLowerBound(r, 0);
    int64_t dimSize = memRefType.getDimSize(r);
    if (ShapedType::isDynamic(dimSize))
      continue;
    cstWithShapeBounds.addConstantUpperBound(r, dimSize - 1);
  }

  // Find a constant upper bound on the extent of this memref region along each
  // dimension.
  int64_t numElements = 1;
  int64_t diffConstant;
  int64_t lbDivisor;
  for (unsigned d = 0; d < rank; d++) {
    SmallVector<int64_t, 4> lb;
    Optional<int64_t> diff =
        cstWithShapeBounds.getConstantBoundOnDimSize(d, &lb, &lbDivisor);
    if (diff.hasValue()) {
      diffConstant = diff.getValue();
      assert(lbDivisor > 0);
    } else {
      // If no constant bound is found, then it can always be bounded by the
      // memref's dim size if the latter has a constant size along this dim.
      auto dimSize = memRefType.getDimSize(d);
      if (dimSize == -1)
        return None;
      diffConstant = dimSize;
      // Lower bound becomes 0.
      lb.resize(cstWithShapeBounds.getNumSymbolIds() + 1, 0);
      lbDivisor = 1;
    }
    numElements *= diffConstant;
    if (lbs) {
      lbs->push_back(lb);
      assert(lbDivisors && "both lbs and lbDivisors or none");
      lbDivisors->push_back(lbDivisor);
    }
    if (shape) {
      shape->push_back(diffConstant);
    }
  }
  return numElements;
}
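
// For example (illustrative): for a region of memref<256x256xf32> with
// constraints 0 <= d0 <= 255 and %i <= d1 <= %i + 7, the bounding shape is
// [256, 8] and the returned element count is 256 * 8 = 2048.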

void MemRefRegion::getLowerAndUpperBound(unsigned pos, AffineMap &lbMap,
                                         AffineMap &ubMap) const {
  assert(pos < cst.getNumDimIds() && "invalid position");
  auto memRefType = memref.getType().cast<MemRefType>();
  unsigned rank = memRefType.getRank();

  assert(rank == cst.getNumDimIds() && "inconsistent memref region");

  auto boundPairs = cst.getLowerAndUpperBound(
      pos, /*offset=*/0, /*num=*/rank, cst.getNumDimAndSymbolIds(),
      /*localExprs=*/{}, memRefType.getContext());
  lbMap = boundPairs.first;
  ubMap = boundPairs.second;
  assert(lbMap && "lower bound for a region must exist");
  assert(ubMap && "upper bound for a region must exist");
  assert(lbMap.getNumInputs() == cst.getNumDimAndSymbolIds() - rank);
  assert(ubMap.getNumInputs() == cst.getNumDimAndSymbolIds() - rank);
}

LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
  assert(memref == other.memref);
  return cst.unionBoundingBox(*other.getConstraints());
}

/// Computes the memory region accessed by this memref with the region
/// represented as constraints symbolic/parametric in 'loopDepth' loops
/// surrounding opInst and any additional Function symbols.
//  For example, the memref region for this load operation at loopDepth = 1 will
//  be as below:
//
//    affine.for %i = 0 to 32 {
//      affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
//        load %A[%ii]
//      }
//    }
//
// region:  {memref = %A, write = false, {%i <= m0 <= %i + 7} }
// The last field is a 2-d FlatAffineConstraints symbolic in %i.
//
// TODO: extend this to any other memref dereferencing ops
// (dma_start, dma_wait).
LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
                                    const ComputationSliceState *sliceState,
                                    bool addMemRefDimBounds) {
  assert((isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) &&
         "affine read/write op expected");

  MemRefAccess access(op);
  memref = access.memref;
  write = access.isStore();

  unsigned rank = access.getRank();

  LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op
                          << " depth: " << loopDepth << "\n";);

  // 0-d memrefs.
  if (rank == 0) {
    SmallVector<AffineForOp, 4> ivs;
    getLoopIVs(*op, &ivs);
    assert(loopDepth <= ivs.size() && "invalid 'loopDepth'");
    // The first 'loopDepth' IVs are symbols for this region.
    ivs.resize(loopDepth);
    SmallVector<Value, 4> regionSymbols;
    extractForInductionVars(ivs, &regionSymbols);
    // A 0-d memref has a 0-d region.
    cst.reset(rank, loopDepth, /*numLocals=*/0, regionSymbols);
    return success();
  }

  // Build the constraints for this region.
  AffineValueMap accessValueMap;
  access.getAccessMap(&accessValueMap);
  AffineMap accessMap = accessValueMap.getAffineMap();

  unsigned numDims = accessMap.getNumDims();
  unsigned numSymbols = accessMap.getNumSymbols();
  unsigned numOperands = accessValueMap.getNumOperands();
  // Merge operands with slice operands.
  SmallVector<Value, 4> operands;
  operands.resize(numOperands);
  for (unsigned i = 0; i < numOperands; ++i)
    operands[i] = accessValueMap.getOperand(i);

  if (sliceState != nullptr) {
    operands.reserve(operands.size() + sliceState->lbOperands[0].size());
    // Append slice operands to 'operands' as symbols.
    for (auto extraOperand : sliceState->lbOperands[0]) {
      if (!llvm::is_contained(operands, extraOperand)) {
        operands.push_back(extraOperand);
        numSymbols++;
      }
    }
  }
  // We'll first associate the dims and symbols of the access map to the dims
  // and symbols resp. of cst. This will change below once cst is
  // fully constructed.
  cst.reset(numDims, numSymbols, 0, operands);

  // Add inequalities for loop lower/upper bounds, and equality constraints
  // for symbols which are constants.
  for (unsigned i = 0; i < numDims + numSymbols; ++i) {
    auto operand = operands[i];
    if (auto loop = getForInductionVarOwner(operand)) {
      // Note that cst can now have more dimensions than accessMap if the
      // bounds expressions involve outer loops or other symbols.
      // TODO: rewrite this to use getInstIndexSet; this way
      // conditionals will be handled when the latter supports it.
      if (failed(cst.addAffineForOpDomain(loop)))
        return failure();
    } else {
      // Has to be a valid symbol.
      auto symbol = operand;
      assert(isValidSymbol(symbol));
      // Check if the symbol is a constant.
      if (auto *op = symbol.getDefiningOp()) {
        if (auto constOp = dyn_cast<ConstantIndexOp>(op)) {
          cst.setIdToConstant(symbol, constOp.getValue());
        }
      }
    }
  }

  // Add lower/upper bounds on loop IVs using bounds from 'sliceState'.
  if (sliceState != nullptr) {
    // Add dim and symbol slice operands.
    for (auto operand : sliceState->lbOperands[0]) {
      cst.addInductionVarOrTerminalSymbol(operand);
    }
    // Add upper/lower bounds from 'sliceState' to 'cst'.
    LogicalResult ret =
        cst.addSliceBounds(sliceState->ivs, sliceState->lbs, sliceState->ubs,
                           sliceState->lbOperands[0]);
    assert(succeeded(ret) &&
           "should not fail as we never have semi-affine slice maps");
    (void)ret;
  }

  // Add access function equalities to connect loop IVs to data dimensions.
  if (failed(cst.composeMap(&accessValueMap))) {
    op->emitError("getMemRefRegion: compose affine map failed");
    LLVM_DEBUG(accessValueMap.getAffineMap().dump());
    return failure();
  }

  // Set all identifiers appearing after the first 'rank' identifiers as
  // symbolic identifiers - so that the ones corresponding to the memref
  // dimensions are the dimensional identifiers for the memref region.
  cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - rank);

  // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
  // this memref region is symbolic.
  SmallVector<AffineForOp, 4> enclosingIVs;
  getLoopIVs(*op, &enclosingIVs);
  assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
  enclosingIVs.resize(loopDepth);
  SmallVector<Value, 4> ids;
  cst.getIdValues(cst.getNumDimIds(), cst.getNumDimAndSymbolIds(), &ids);
  for (auto id : ids) {
    AffineForOp iv;
    if ((iv = getForInductionVarOwner(id)) &&
        !llvm::is_contained(enclosingIVs, iv)) {
      cst.projectOut(id);
    }
  }

  // Project out any local variables (these would have been added for any
  // mod/divs).
  cst.projectOut(cst.getNumDimAndSymbolIds(), cst.getNumLocalIds());

  // Constant fold any symbolic identifiers.
  cst.constantFoldIdRange(/*pos=*/cst.getNumDimIds(),
                          /*num=*/cst.getNumSymbolIds());

  assert(cst.getNumDimIds() == rank && "unexpected MemRefRegion format");

  // Add upper/lower bounds for each memref dimension with static size
  // to guard against potential over-approximation from projection.
  // TODO: Support dynamic memref dimensions.
  if (addMemRefDimBounds) {
    auto memRefType = memref.getType().cast<MemRefType>();
    for (unsigned r = 0; r < rank; r++) {
      cst.addConstantLowerBound(/*pos=*/r, /*lb=*/0);
      if (memRefType.isDynamicDim(r))
        continue;
      cst.addConstantUpperBound(/*pos=*/r, memRefType.getDimSize(r) - 1);
    }
  }
  cst.removeTrivialRedundancy();

  LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
  LLVM_DEBUG(cst.dump());
  return success();
}

static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}
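
// For example, the element size is 4 bytes for memref<16xf32> and 16 bytes
// for memref<16xvector<4xf32>> (4 x 32 bits = 128 bits, rounded up to bytes).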

// Returns the size of the region in bytes.
Optional<int64_t> MemRefRegion::getRegionSize() {
  auto memRefType = memref.getType().cast<MemRefType>();

  auto layoutMaps = memRefType.getAffineMaps();
  if (layoutMaps.size() > 1 ||
      (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
    LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
    return None;
  }

  // Compute the extents of the buffer.
  Optional<int64_t> numElements = getConstantBoundingSizeAndShape();
  if (!numElements.hasValue()) {
    LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
    return None;
  }
  return getMemRefEltSizeInBytes(memRefType) * numElements.getValue();
}

/// Returns the size of memref data in bytes if it's statically shaped, None
/// otherwise. If the memref's element has vector type, the size of the vector
/// is taken into account as well.
//  TODO: improve/complete this when we have target data.
Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
  if (!memRefType.hasStaticShape())
    return None;
  auto elementType = memRefType.getElementType();
  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
    return None;

  uint64_t sizeInBytes = getMemRefEltSizeInBytes(memRefType);
  for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
    sizeInBytes = sizeInBytes * memRefType.getDimSize(i);
  }
  return sizeInBytes;
}
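
// For example, memref<4x8xf32> is 4 * 8 * 4 = 128 bytes, while
// memref<?x8xf32> yields None since its shape is not static.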

template <typename LoadOrStoreOp>
LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
                                            bool emitError) {
  static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
                                AffineWriteOpInterface>::value,
                "argument should be either an AffineReadOpInterface or an "
                "AffineWriteOpInterface");

  Operation *op = loadOrStoreOp.getOperation();
  MemRefRegion region(op->getLoc());
  if (failed(region.compute(op, /*loopDepth=*/0, /*sliceState=*/nullptr,
                            /*addMemRefDimBounds=*/false)))
    return success();

  LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
  LLVM_DEBUG(region.getConstraints()->dump());

  bool outOfBounds = false;
  unsigned rank = loadOrStoreOp.getMemRefType().getRank();

  // For each dimension, check for out of bounds.
  for (unsigned r = 0; r < rank; r++) {
    FlatAffineConstraints ucst(*region.getConstraints());

    // Intersect the memory region with a constraint capturing an out of bounds
    // access (above the upper bound or below the lower bound), and check if
    // the constraint system is feasible. If it is, there is at least one point
    // out of bounds.
    int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r);
    // TODO: handle dynamic dim sizes.
    if (dimSize == -1)
      continue;

    // Check for overflow: d_i >= memref dim size.
    ucst.addConstantLowerBound(r, dimSize);
    outOfBounds = !ucst.isEmpty();
    if (outOfBounds && emitError) {
      loadOrStoreOp.emitOpError()
          << "memref out of upper bound access along dimension #" << (r + 1);
    }

    // Check for a negative index: d_i <= -1.
    FlatAffineConstraints lcst(*region.getConstraints());
    lcst.addConstantUpperBound(r, -1);
    outOfBounds = !lcst.isEmpty();
    if (outOfBounds && emitError) {
      loadOrStoreOp.emitOpError()
          << "memref out of lower bound access along dimension #" << (r + 1);
    }
  }
  return failure(outOfBounds);
}

// Explicitly instantiate the template so that the compiler knows we need them!
template LogicalResult
mlir::boundCheckLoadOrStoreOp(AffineReadOpInterface loadOp, bool emitError);
template LogicalResult
mlir::boundCheckLoadOrStoreOp(AffineWriteOpInterface storeOp, bool emitError);
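
// For example (illustrative IR): with %A = alloc() : memref<8xf32>, bound
// checking
//   affine.for %i = 0 to 16 {
//     %v = affine.load %A[%i] : memref<8xf32>
//   }
// finds the access region [0, 15] feasible beyond the dim size 8, and reports
// "memref out of upper bound access along dimension #1" on the load.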

// Returns in 'positions' the Block positions of 'op' in each ancestor
// Block from the Block containing 'op', stopping at 'limitBlock'.
static void findInstPosition(Operation *op, Block *limitBlock,
                             SmallVectorImpl<unsigned> *positions) {
  Block *block = op->getBlock();
  while (block != limitBlock) {
    // FIXME: This algorithm is unnecessarily O(n) and should be improved to
    // not rely on linear scans.
    int instPosInBlock = std::distance(block->begin(), op->getIterator());
    positions->push_back(instPosInBlock);
    op = block->getParentOp();
    block = op->getBlock();
  }
  std::reverse(positions->begin(), positions->end());
}

// Returns the Operation in a possibly nested set of Blocks, where the
// position of the operation is represented by 'positions', which has a
// Block position for each level of nesting.
static Operation *getInstAtPosition(ArrayRef<unsigned> positions,
                                    unsigned level, Block *block) {
  unsigned i = 0;
  for (auto &op : *block) {
    if (i != positions[level]) {
      ++i;
      continue;
    }
    if (level == positions.size() - 1)
      return &op;
    if (auto childAffineForOp = dyn_cast<AffineForOp>(op))
      return getInstAtPosition(positions, level + 1,
                               childAffineForOp.getBody());

    for (auto &region : op.getRegions()) {
      for (auto &b : region)
        if (auto *ret = getInstAtPosition(positions, level + 1, &b))
          return ret;
    }
    return nullptr;
  }
  return nullptr;
}

// Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
static LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs,
                                            FlatAffineConstraints *cst) {
  for (unsigned i = 0, e = cst->getNumDimIds(); i < e; ++i) {
    auto value = cst->getIdValue(i);
    if (ivs.count(value) == 0) {
      assert(isForInductionVar(value));
      auto loop = getForInductionVarOwner(value);
      if (failed(cst->addAffineForOpDomain(loop)))
        return failure();
    }
  }
  return success();
}

/// Returns the innermost common loop depth for the set of operations in 'ops'.
// TODO: Move this to LoopUtils.
unsigned mlir::getInnermostCommonLoopDepth(
    ArrayRef<Operation *> ops, SmallVectorImpl<AffineForOp> *surroundingLoops) {
  unsigned numOps = ops.size();
  assert(numOps > 0 && "Expected at least one operation");

  std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
  for (unsigned i = 0; i < numOps; ++i) {
    getLoopIVs(*ops[i], &loops[i]);
    loopDepthLimit =
        std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
  }

  unsigned loopDepth = 0;
  for (unsigned d = 0; d < loopDepthLimit; ++d) {
    unsigned i;
    for (i = 1; i < numOps; ++i) {
      if (loops[i - 1][d] != loops[i][d])
        return loopDepth;
    }
    if (surroundingLoops)
      surroundingLoops->push_back(loops[i - 1][d]);
    ++loopDepth;
  }
  return loopDepth;
}
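
// For example, if op0 is nested under loops (%i, %j) and op1 under (%i, %k),
// the innermost common loop depth is 1, and 'surroundingLoops' (if provided)
// receives the %i loop.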

/// Computes in 'sliceUnion' the union of all slice bounds computed at
/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'.
/// Returns 'success' if the union was computed, 'failure' otherwise.
LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
                                      ArrayRef<Operation *> opsB,
                                      unsigned loopDepth,
                                      unsigned numCommonLoops,
                                      bool isBackwardSlice,
                                      ComputationSliceState *sliceUnion) {
  // Compute the union of slice bounds between all pairs in 'opsA' and
  // 'opsB' in 'sliceUnionCst'.
  FlatAffineConstraints sliceUnionCst;
  assert(sliceUnionCst.getNumDimAndSymbolIds() == 0);
  std::vector<std::pair<Operation *, Operation *>> dependentOpPairs;
  for (unsigned i = 0, numOpsA = opsA.size(); i < numOpsA; ++i) {
    MemRefAccess srcAccess(opsA[i]);
    for (unsigned j = 0, numOpsB = opsB.size(); j < numOpsB; ++j) {
      MemRefAccess dstAccess(opsB[j]);
      if (srcAccess.memref != dstAccess.memref)
        continue;
      // Check if 'loopDepth' exceeds nesting depth of src/dst ops.
      if ((!isBackwardSlice && loopDepth > getNestingDepth(opsA[i])) ||
          (isBackwardSlice && loopDepth > getNestingDepth(opsB[j]))) {
        LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
        return failure();
      }

      bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
                              isa<AffineReadOpInterface>(dstAccess.opInst);
      FlatAffineConstraints dependenceConstraints;
      // Check dependence between 'srcAccess' and 'dstAccess'.
      DependenceResult result = checkMemrefAccessDependence(
          srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
          &dependenceConstraints, /*dependenceComponents=*/nullptr,
          /*allowRAR=*/readReadAccesses);
      if (result.value == DependenceResult::Failure) {
        LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n");
        return failure();
      }
      if (result.value == DependenceResult::NoDependence)
        continue;
      dependentOpPairs.push_back({opsA[i], opsB[j]});

      // Compute slice bounds for 'srcAccess' and 'dstAccess'.
      ComputationSliceState tmpSliceState;
      mlir::getComputationSliceState(opsA[i], opsB[j], &dependenceConstraints,
                                     loopDepth, isBackwardSlice,
                                     &tmpSliceState);

      if (sliceUnionCst.getNumDimAndSymbolIds() == 0) {
        // Initialize 'sliceUnionCst' with the bounds computed in previous step.
        if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Unable to compute slice bound constraints\n");
          return failure();
        }
        assert(sliceUnionCst.getNumDimAndSymbolIds() > 0);
        continue;
      }

      // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
      FlatAffineConstraints tmpSliceCst;
      if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Unable to compute slice bound constraints\n");
        return failure();
      }

      // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
      if (!sliceUnionCst.areIdsAlignedWithOther(tmpSliceCst)) {

        // Pre-constraint id alignment: record loop IVs used in each constraint
        // system.
        SmallPtrSet<Value, 8> sliceUnionIVs;
        for (unsigned k = 0, l = sliceUnionCst.getNumDimIds(); k < l; ++k)
          sliceUnionIVs.insert(sliceUnionCst.getIdValue(k));
        SmallPtrSet<Value, 8> tmpSliceIVs;
        for (unsigned k = 0, l = tmpSliceCst.getNumDimIds(); k < l; ++k)
          tmpSliceIVs.insert(tmpSliceCst.getIdValue(k));

        sliceUnionCst.mergeAndAlignIdsWithOther(/*offset=*/0, &tmpSliceCst);

        // Post-constraint id alignment: add loop IV bounds missing after
        // id alignment to constraint systems. This can occur if one constraint
        // system uses a loop IV that is not used by the other. The call
        // to unionBoundingBox below expects constraints for each loop IV, even
        // if they are the unsliced full loop bounds added here.
        if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
          return failure();
        if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
          return failure();
      }
      // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
      if (sliceUnionCst.getNumLocalIds() > 0 ||
          tmpSliceCst.getNumLocalIds() > 0 ||
          failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Unable to compute union bounding box of slice bounds\n");
        return failure();
      }
    }
  }

  // Empty union.
  if (sliceUnionCst.getNumDimAndSymbolIds() == 0)
    return failure();

  // Gather loops surrounding ops from loop nest where slice will be inserted.
  SmallVector<Operation *, 4> ops;
  for (auto &dep : dependentOpPairs) {
    ops.push_back(isBackwardSlice ? dep.second : dep.first);
  }
  SmallVector<AffineForOp, 4> surroundingLoops;
  unsigned innermostCommonLoopDepth =
      getInnermostCommonLoopDepth(ops, &surroundingLoops);
  if (loopDepth > innermostCommonLoopDepth) {
    LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
    return failure();
  }

  // Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
  unsigned numSliceLoopIVs = sliceUnionCst.getNumDimIds();

  // Convert any dst loop IVs which are symbol identifiers to dim identifiers.
  sliceUnionCst.convertLoopIVSymbolsToDims();
  sliceUnion->clearBounds();
  sliceUnion->lbs.resize(numSliceLoopIVs, AffineMap());
  sliceUnion->ubs.resize(numSliceLoopIVs, AffineMap());

  // Get slice bounds from slice union constraints 'sliceUnionCst'.
  sliceUnionCst.getSliceBounds(/*offset=*/0, numSliceLoopIVs,
                               opsA[0]->getContext(), &sliceUnion->lbs,
                               &sliceUnion->ubs);

  // Add slice bound operands of union.
  SmallVector<Value, 4> sliceBoundOperands;
  sliceUnionCst.getIdValues(numSliceLoopIVs,
                            sliceUnionCst.getNumDimAndSymbolIds(),
                            &sliceBoundOperands);

  // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'.
  sliceUnion->ivs.clear();
  sliceUnionCst.getIdValues(0, numSliceLoopIVs, &sliceUnion->ivs);

  // Set loop nest insertion point to block start at 'loopDepth'.
  sliceUnion->insertPoint =
      isBackwardSlice
          ? surroundingLoops[loopDepth - 1].getBody()->begin()
          : std::prev(surroundingLoops[loopDepth - 1].getBody()->end());

  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
  // canonicalization.
  sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
  sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
  return success();
}

const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
// Computes slice bounds by projecting out any loop IVs from
// 'dependenceConstraints' at depth greater than 'loopDepth', and computes slice
// bounds in 'sliceState' which represent one loop nest's IVs in terms of
// the other loop nest's IVs, symbols and constants (using 'isBackwardSlice').
void mlir::getComputationSliceState(
    Operation *depSourceOp, Operation *depSinkOp,
    FlatAffineConstraints *dependenceConstraints, unsigned loopDepth,
    bool isBackwardSlice, ComputationSliceState *sliceState) {
  // Get loop nest surrounding src operation.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getLoopIVs(*depSourceOp, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Get loop nest surrounding dst operation.
  SmallVector<AffineForOp, 4> dstLoopIVs;
  getLoopIVs(*depSinkOp, &dstLoopIVs);
  unsigned numDstLoopIVs = dstLoopIVs.size();

  assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) ||
         (isBackwardSlice && loopDepth <= numDstLoopIVs));

  // Project out dimensions other than those up to 'loopDepth'.
  unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
  unsigned num =
      isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
  dependenceConstraints->projectOut(pos, num);

  // Add slice loop IV values to 'sliceState'.
  unsigned offset = isBackwardSlice ? 0 : loopDepth;
  unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
  dependenceConstraints->getIdValues(offset, offset + numSliceLoopIVs,
                                     &sliceState->ivs);

  // Set up lower/upper bound affine maps for the slice.
  sliceState->lbs.resize(numSliceLoopIVs, AffineMap());
  sliceState->ubs.resize(numSliceLoopIVs, AffineMap());

  // Get bounds for slice IVs in terms of other IVs, symbols, and constants.
  dependenceConstraints->getSliceBounds(offset, numSliceLoopIVs,
                                        depSourceOp->getContext(),
                                        &sliceState->lbs, &sliceState->ubs);

  // Set up bound operands for the slice's lower and upper bounds.
  SmallVector<Value, 4> sliceBoundOperands;
  unsigned numDimsAndSymbols = dependenceConstraints->getNumDimAndSymbolIds();
  for (unsigned i = 0; i < numDimsAndSymbols; ++i) {
    if (i < offset || i >= offset + numSliceLoopIVs) {
      sliceBoundOperands.push_back(dependenceConstraints->getIdValue(i));
    }
  }

  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
  // canonicalization.
  sliceState->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
  sliceState->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);

  // Set the slice's insertion point to the block start at 'loopDepth'.
  sliceState->insertPoint =
      isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin()
                      : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());

  llvm::SmallDenseSet<Value, 8> sequentialLoops;
  if (isa<AffineReadOpInterface>(depSourceOp) &&
      isa<AffineReadOpInterface>(depSinkOp)) {
    // For read-read access pairs, clear any slice bounds on sequential loops.
    // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
    getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
                       &sequentialLoops);
  }
  // Clear all sliced loop bounds beginning at the first sequential loop or the
  // first loop with a slice fusion barrier attribute.
  // TODO: Use MemRef read/write regions instead of
  // using 'kSliceFusionBarrierAttrName'.
  auto getSliceLoop = [&](unsigned i) {
    return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
  };
  for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
    Value iv = getSliceLoop(i).getInductionVar();
    if (sequentialLoops.count(iv) == 0 &&
        getSliceLoop(i).getAttr(kSliceFusionBarrierAttrName) == nullptr)
      continue;
    for (unsigned j = i; j < numSliceLoopIVs; ++j) {
      sliceState->lbs[j] = AffineMap();
      sliceState->ubs[j] = AffineMap();
    }
    break;
  }
}

/// Creates a computation slice of the loop nest surrounding 'srcOpInst',
/// updates the slice loop bounds with any non-null bound maps specified in
/// 'sliceState', and inserts this slice into the loop nest surrounding
/// 'dstOpInst' at loop depth 'dstLoopDepth'.
// TODO: extend the slicing utility to compute slices that
// aren't necessarily a one-to-one relation between the source and destination.
// The relation between the source and destination could be many-to-many in
// general.
// TODO: the slice computation is incorrect in the cases
// where the dependence from the source to the destination does not cover the
// entire destination index set. Subtract out the dependent destination
// iterations from the destination index set and check for emptiness --- this
// is one solution.
AffineForOp
mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst,
                                     unsigned dstLoopDepth,
                                     ComputationSliceState *sliceState) {
  // Get loop nest surrounding src operation.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Get loop nest surrounding dst operation.
  SmallVector<AffineForOp, 4> dstLoopIVs;
  getLoopIVs(*dstOpInst, &dstLoopIVs);
  unsigned dstLoopIVsSize = dstLoopIVs.size();
  if (dstLoopDepth > dstLoopIVsSize) {
    dstOpInst->emitError("invalid destination loop depth");
    return AffineForOp();
  }

  // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'.
  SmallVector<unsigned, 4> positions;
  // TODO: This code is incorrect since srcLoopIVs can be 0-d.
  findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);

  // Clone the src loop nest and insert it at the beginning of the operation
  // block of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
  auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
  OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
  auto sliceLoopNest =
      cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));

  Operation *sliceInst =
      getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
  // Get loop nest surrounding 'sliceInst'.
  SmallVector<AffineForOp, 4> sliceSurroundingLoops;
  getLoopIVs(*sliceInst, &sliceSurroundingLoops);

  // Sanity check.
  unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
  (void)sliceSurroundingLoopsSize;
  assert(dstLoopDepth + numSrcLoopIVs >= sliceSurroundingLoopsSize);
  unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
  (void)sliceLoopLimit;
  assert(sliceLoopLimit >= sliceSurroundingLoopsSize);

  // Update loop bounds for loops in 'sliceLoopNest'.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
    if (AffineMap lbMap = sliceState->lbs[i])
      forOp.setLowerBound(sliceState->lbOperands[i], lbMap);
    if (AffineMap ubMap = sliceState->ubs[i])
      forOp.setUpperBound(sliceState->ubOperands[i], ubMap);
  }
  return sliceLoopNest;
}

// Constructs a MemRefAccess, populating it with the memref, its indices, and
// the op from 'loadOrStoreOpInst'.
MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
  if (auto loadOp = dyn_cast<AffineReadOpInterface>(loadOrStoreOpInst)) {
    memref = loadOp.getMemRef();
    opInst = loadOrStoreOpInst;
    auto loadMemrefType = loadOp.getMemRefType();
    indices.reserve(loadMemrefType.getRank());
    for (auto index : loadOp.getMapOperands()) {
      indices.push_back(index);
    }
  } else {
    assert(isa<AffineWriteOpInterface>(loadOrStoreOpInst) &&
           "Affine read/write op expected");
    auto storeOp = cast<AffineWriteOpInterface>(loadOrStoreOpInst);
    opInst = loadOrStoreOpInst;
    memref = storeOp.getMemRef();
    auto storeMemrefType = storeOp.getMemRefType();
    indices.reserve(storeMemrefType.getRank());
    for (auto index : storeOp.getMapOperands()) {
      indices.push_back(index);
    }
  }
}

unsigned MemRefAccess::getRank() const {
  return memref.getType().cast<MemRefType>().getRank();
}

bool MemRefAccess::isStore() const {
  return isa<AffineWriteOpInterface>(opInst);
}

/// Returns the nesting depth of this statement, i.e., the number of loops
/// surrounding this statement.
unsigned mlir::getNestingDepth(Operation *op) {
  Operation *currOp = op;
  unsigned depth = 0;
  while ((currOp = currOp->getParentOp())) {
    if (isa<AffineForOp>(currOp))
      depth++;
  }
  return depth;
}
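
// For example, an op nested inside two 'affine.for' loops has nesting depth 2;
// note that only 'affine.for' ops contribute to the depth ('affine.if' does
// not).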

/// Equal if both affine accesses are provably equivalent (at compile
/// time) when considering the memref, the affine maps and their respective
/// operands. The equality of access functions + operands is checked by
/// subtracting fully composed value maps, and then simplifying the difference
/// using the expression flattener.
/// TODO: this does not account for aliasing of memrefs.
bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
  if (memref != rhs.memref)
    return false;

  AffineValueMap diff, thisMap, rhsMap;
  getAccessMap(&thisMap);
  rhs.getAccessMap(&rhsMap);
  AffineValueMap::difference(thisMap, rhsMap, &diff);
  return llvm::all_of(diff.getAffineMap().getResults(),
                      [](AffineExpr e) { return e == 0; });
}
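
// For example, accesses %A[%i + 1] and %A[1 + %i] compare equal since their
// difference simplifies to 0, while %A[%i] and %A[%j] with distinct IVs do
// not.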

/// Returns the number of surrounding loops common to both A and B, where each
/// nest is listed from the outermost to the innermost loop.
unsigned mlir::getNumCommonSurroundingLoops(Operation &A, Operation &B) {
  SmallVector<AffineForOp, 4> loopsA, loopsB;
  getLoopIVs(A, &loopsA);
  getLoopIVs(B, &loopsB);

  unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
  unsigned numCommonLoops = 0;
  for (unsigned i = 0; i < minNumLoops; ++i) {
    if (loopsA[i].getOperation() != loopsB[i].getOperation())
      break;
    ++numCommonLoops;
  }
  return numCommonLoops;
}
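
// For example, given
//   affine.for %i ... {
//     affine.for %j ... { ... A ... }
//     affine.for %k ... { ... B ... }
//   }
// A and B share only the %i loop, so the result is 1.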

static Optional<int64_t> getMemoryFootprintBytes(Block &block,
                                                 Block::iterator start,
                                                 Block::iterator end,
                                                 int memorySpace) {
  SmallDenseMap<Value, std::unique_ptr<MemRefRegion>, 4> regions;

  // Walk this range of operations to gather all memory regions.
  auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
    if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst)) {
      // Neither a load nor a store op.
      return WalkResult::advance();
    }

    // Compute the memref region symbolic in any IVs enclosing this block.
    auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
    if (failed(
            region->compute(opInst,
                            /*loopDepth=*/getNestingDepth(&*block.begin())))) {
      return opInst->emitError("error obtaining memory region\n");
    }

    auto it = regions.find(region->memref);
    if (it == regions.end()) {
      regions[region->memref] = std::move(region);
    } else if (failed(it->second->unionBoundingBox(*region))) {
      return opInst->emitWarning(
          "getMemoryFootprintBytes: unable to perform a union on a memory "
          "region");
    }
    return WalkResult::advance();
  });
  if (result.wasInterrupted())
    return None;

  int64_t totalSizeInBytes = 0;
  for (const auto &region : regions) {
    Optional<int64_t> size = region.second->getRegionSize();
    if (!size.hasValue())
      return None;
    totalSizeInBytes += size.getValue();
  }
  return totalSizeInBytes;
}

Optional<int64_t> mlir::getMemoryFootprintBytes(AffineForOp forOp,
                                                int memorySpace) {
  auto *forInst = forOp.getOperation();
  return ::getMemoryFootprintBytes(
      *forInst->getBlock(), Block::iterator(forInst),
      std::next(Block::iterator(forInst)), memorySpace);
}
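
// For example, a loop nest that reads all of one memref<32x32xf32> and writes
// all of another has a footprint of 2 * 32 * 32 * 4 = 8192 bytes.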

/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
/// at 'forOp'.
void mlir::getSequentialLoops(AffineForOp forOp,
                              llvm::SmallDenseSet<Value, 8> *sequentialLoops) {
  forOp->walk([&](Operation *op) {
    if (auto innerFor = dyn_cast<AffineForOp>(op))
      if (!isLoopParallel(innerFor))
        sequentialLoops->insert(innerFor.getInductionVar());
  });
}

/// Returns true if 'forOp' is parallel.
bool mlir::isLoopParallel(AffineForOp forOp) {
  // Collect all load and store ops in the loop nest rooted at 'forOp'.
  SmallVector<Operation *, 8> loadAndStoreOpInsts;
  auto walkResult = forOp.walk([&](Operation *opInst) -> WalkResult {
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst))
      loadAndStoreOpInsts.push_back(opInst);
    else if (!isa<AffineForOp, AffineYieldOp, AffineIfOp>(opInst) &&
             !MemoryEffectOpInterface::hasNoEffect(opInst))
      return WalkResult::interrupt();

    return WalkResult::advance();
  });

  // Stop early if the loop has unknown ops with side effects.
  if (walkResult.wasInterrupted())
    return false;

  // Dep check depth would be number of enclosing loops + 1.
  unsigned depth = getNestingDepth(forOp) + 1;

  // Check dependences between all pairs of ops in 'loadAndStoreOpInsts'.
  for (auto *srcOpInst : loadAndStoreOpInsts) {
    MemRefAccess srcAccess(srcOpInst);
    for (auto *dstOpInst : loadAndStoreOpInsts) {
      MemRefAccess dstAccess(dstOpInst);
      FlatAffineConstraints dependenceConstraints;
      DependenceResult result = checkMemrefAccessDependence(
          srcAccess, dstAccess, depth, &dependenceConstraints,
          /*dependenceComponents=*/nullptr);
      if (result.value != DependenceResult::NoDependence)
        return false;
    }
  }
  return true;
}
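
// For example (illustrative IR):
//   affine.for %i = 0 to 10 {
//     %v = affine.load %A[%i] : memref<10xf32>
//     affine.store %v, %B[%i] : memref<10xf32>
//   }
// is parallel since no iteration touches another iteration's elements,
// whereas a loop storing to %B[0] in every iteration is not.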

IntegerSet mlir::simplifyIntegerSet(IntegerSet set) {
  FlatAffineConstraints fac(set);
  if (fac.isEmpty())
    return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
                                   set.getContext());
  fac.removeTrivialRedundancy();

  auto simplifiedSet = fac.getAsIntegerSet(set.getContext());
  assert(simplifiedSet && "guaranteed to succeed while roundtripping");
  return simplifiedSet;
}
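
// For example, the set (d0) : (d0 >= 0, -d0 - 1 >= 0) is infeasible and
// simplifies to the canonical empty set.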