• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/IntegerSet.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/Transforms/InliningUtils.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/SmallBitVector.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/Debug.h"
22 
23 using namespace mlir;
24 using llvm::dbgs;
25 
26 #define DEBUG_TYPE "affine-analysis"
27 
28 //===----------------------------------------------------------------------===//
29 // AffineDialect Interfaces
30 //===----------------------------------------------------------------------===//
31 
namespace {
/// This class defines the interface for handling inlining with affine
/// operations. It is registered on the dialect in
/// AffineDialect::initialize() below.
struct AffineInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  /// Returns true if the given region 'src' can be inlined into the region
  /// 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &valueMapping) const final {
    // Conservatively don't allow inlining into affine structures.
    return false;
  }

  /// Returns true if the given operation 'op', that is registered to this
  /// dialect, can be inlined into the given region, false otherwise.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       BlockAndValueMapping &valueMapping) const final {
    // Always allow inlining affine operations into the top-level region of a
    // function. There are some edge cases when inlining *into* affine
    // structures, but that is handled in the other 'isLegalToInline' hook
    // above.
    // TODO: We should be able to inline into other regions than functions.
    return isa<FuncOp>(region->getParentOp());
  }

  /// Affine regions should be analyzed recursively.
  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
};
} // end anonymous namespace
66 
67 //===----------------------------------------------------------------------===//
68 // AffineDialect
69 //===----------------------------------------------------------------------===//
70 
/// Registers the affine dialect's operations and interfaces. The DMA ops are
/// not ODS-generated, so they are listed explicitly ahead of the generated op
/// list pulled in via GET_OP_LIST.
void AffineDialect::initialize() {
  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
                >();
  addInterfaces<AffineInlinerInterface>();
}
78 
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type. Delegates to the standard dialect's ConstantOp;
/// folding infrastructure uses this hook to turn folded attributes back into
/// IR.
Operation *AffineDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<ConstantOp>(loc, type, value);
}
86 
87 /// A utility function to check if a value is defined at the top level of an
88 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
89 /// conservatively assume it is not top-level. A value of index type defined at
90 /// the top level is always a valid symbol.
isTopLevelValue(Value value)91 bool mlir::isTopLevelValue(Value value) {
92   if (auto arg = value.dyn_cast<BlockArgument>()) {
93     // The block owning the argument may be unlinked, e.g. when the surrounding
94     // region has not yet been attached to an Op, at which point the parent Op
95     // is null.
96     Operation *parentOp = arg.getOwner()->getParentOp();
97     return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
98   }
99   // The defining Op may live in an unlinked block so its parent Op may be null.
100   Operation *parentOp = value.getDefiningOp()->getParentOp();
101   return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
102 }
103 
104 /// A utility function to check if a value is defined at the top level of
105 /// `region` or is an argument of `region`. A value of index type defined at the
106 /// top level of a `AffineScope` region is always a valid symbol for all
107 /// uses in that region.
isTopLevelValue(Value value,Region * region)108 static bool isTopLevelValue(Value value, Region *region) {
109   if (auto arg = value.dyn_cast<BlockArgument>())
110     return arg.getParentRegion() == region;
111   return value.getDefiningOp()->getParentRegion() == region;
112 }
113 
114 /// Returns the closest region enclosing `op` that is held by an operation with
115 /// trait `AffineScope`; `nullptr` if there is no such region.
116 //  TODO: getAffineScope should be publicly exposed for affine passes/utilities.
getAffineScope(Operation * op)117 static Region *getAffineScope(Operation *op) {
118   auto *curOp = op;
119   while (auto *parentOp = curOp->getParentOp()) {
120     if (parentOp->hasTrait<OpTrait::AffineScope>())
121       return curOp->getParentRegion();
122     curOp = parentOp;
123   }
124   return nullptr;
125 }
126 
127 // A Value can be used as a dimension id iff it meets one of the following
128 // conditions:
129 // *) It is valid as a symbol.
130 // *) It is an induction variable.
131 // *) It is the result of affine apply operation with dimension id arguments.
isValidDim(Value value)132 bool mlir::isValidDim(Value value) {
133   // The value must be an index type.
134   if (!value.getType().isIndex())
135     return false;
136 
137   if (auto *defOp = value.getDefiningOp())
138     return isValidDim(value, getAffineScope(defOp));
139 
140   // This value has to be a block argument for an op that has the
141   // `AffineScope` trait or for an affine.for or affine.parallel.
142   auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
143   return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
144                       isa<AffineForOp, AffineParallelOp>(parentOp));
145 }
146 
// Value can be used as a dimension id iff it meets one of the following
// conditions:
// *) It is valid as a symbol.
// *) It is an induction variable.
// *) It is the result of an affine apply operation with dimension id operands.
// `region` provides the polyhedral symbol scope (the parent op of `region`
// defines where symbols are anchored).
bool mlir::isValidDim(Value value, Region *region) {
  // The value must be an index type.
  if (!value.getType().isIndex())
    return false;

  // All valid symbols are okay.
  if (isValidSymbol(value, region))
    return true;

  auto *op = value.getDefiningOp();
  if (!op) {
    // This value has to be a block argument for an affine.for or an
    // affine.parallel.
    auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
    return isa<AffineForOp, AffineParallelOp>(parentOp);
  }

  // Affine apply operation is ok if all of its operands are ok.
  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
    return applyOp.isValidDim(region);
  // The dim op is okay if its operand memref/tensor is defined at the top
  // level.
  if (auto dimOp = dyn_cast<DimOp>(op))
    return isTopLevelValue(dimOp.memrefOrTensor());
  return false;
}
178 
179 /// Returns true if the 'index' dimension of the `memref` defined by
180 /// `memrefDefOp` is a statically  shaped one or defined using a valid symbol
181 /// for `region`.
182 template <typename AnyMemRefDefOp>
isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp,unsigned index,Region * region)183 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
184                                     Region *region) {
185   auto memRefType = memrefDefOp.getType();
186   // Statically shaped.
187   if (!memRefType.isDynamicDim(index))
188     return true;
189   // Get the position of the dimension among dynamic dimensions;
190   unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
191   return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
192                        region);
193 }
194 
/// Returns true if the result of the dim op is a valid symbol for `region`.
static bool isDimOpValidSymbol(DimOp dimOp, Region *region) {
  // The dim op is okay if its operand memref/tensor is defined at the top
  // level.
  if (isTopLevelValue(dimOp.memrefOrTensor()))
    return true;

  // Conservatively handle remaining BlockArguments as non-valid symbols.
  // E.g. scf.for iterArgs.
  if (dimOp.memrefOrTensor().isa<BlockArgument>())
    return false;

  // The dim op is also okay if its operand memref/tensor is a view/subview
  // whose corresponding size is a valid symbol.
  Optional<int64_t> index = dimOp.getConstantIndex();
  assert(index.hasValue() &&
         "expect only `dim` operations with a constant index");
  int64_t i = index.getValue();
  // Dispatch on the producer of the memref/tensor: only view/subview/alloc
  // ops expose the queried size as an operand we can validate; anything else
  // is conservatively rejected.
  return TypeSwitch<Operation *, bool>(dimOp.memrefOrTensor().getDefiningOp())
      .Case<ViewOp, SubViewOp, AllocOp>(
          [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
      .Default([](Operation *) { return false; });
}
218 
219 // A value can be used as a symbol (at all its use sites) iff it meets one of
220 // the following conditions:
221 // *) It is a constant.
222 // *) Its defining op or block arg appearance is immediately enclosed by an op
223 //    with `AffineScope` trait.
224 // *) It is the result of an affine.apply operation with symbol operands.
225 // *) It is a result of the dim op on a memref whose corresponding size is a
226 //    valid symbol.
isValidSymbol(Value value)227 bool mlir::isValidSymbol(Value value) {
228   // The value must be an index type.
229   if (!value.getType().isIndex())
230     return false;
231 
232   // Check that the value is a top level value.
233   if (isTopLevelValue(value))
234     return true;
235 
236   if (auto *defOp = value.getDefiningOp())
237     return isValidSymbol(value, getAffineScope(defOp));
238 
239   return false;
240 }
241 
/// A value can be used as a symbol for `region` iff it meets one of the
/// following conditions:
/// *) It is a constant.
/// *) It is the result of an affine apply operation with symbol arguments.
/// *) It is a result of the dim op on a memref whose corresponding size is
///    a valid symbol.
/// *) It is defined at the top level of 'region' or is its argument.
/// *) It dominates `region`'s parent op.
/// If `region` is null, conservatively assume the symbol definition scope does
/// not exist and only accept the values that would be symbols regardless of
/// the surrounding region structure, i.e. the first three cases above.
bool mlir::isValidSymbol(Value value, Region *region) {
  // The value must be an index type.
  if (!value.getType().isIndex())
    return false;

  // A top-level value is a valid symbol.
  if (region && ::isTopLevelValue(value, region))
    return true;

  auto *defOp = value.getDefiningOp();
  if (!defOp) {
    // A block argument that is not a top-level value is a valid symbol if it
    // dominates region's parent op. Walk outward through non-isolated parent
    // regions to find a scope that treats it as top-level.
    if (region && !region->getParentOp()->isKnownIsolatedFromAbove())
      if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
        return isValidSymbol(value, parentOpRegion);
    return false;
  }

  // Constant operation is ok.
  Attribute operandCst;
  if (matchPattern(defOp, m_Constant(&operandCst)))
    return true;

  // Affine apply operation is ok if all of its operands are ok.
  if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
    return applyOp.isValidSymbol(region);

  // Dim op results could be valid symbols at any level.
  if (auto dimOp = dyn_cast<DimOp>(defOp))
    return isDimOpValidSymbol(dimOp, region);

  // Check for values dominating `region`'s parent op: recurse into the
  // enclosing region unless the parent op is isolated from above.
  if (region && !region->getParentOp()->isKnownIsolatedFromAbove())
    if (auto *parentRegion = region->getParentOp()->getParentRegion())
      return isValidSymbol(value, parentRegion);

  return false;
}
292 
293 // Returns true if 'value' is a valid index to an affine operation (e.g.
294 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
295 // `region` provides the polyhedral symbol scope. Returns false otherwise.
isValidAffineIndexOperand(Value value,Region * region)296 static bool isValidAffineIndexOperand(Value value, Region *region) {
297   return isValidDim(value, region) || isValidSymbol(value, region);
298 }
299 
300 /// Prints dimension and symbol list.
printDimAndSymbolList(Operation::operand_iterator begin,Operation::operand_iterator end,unsigned numDims,OpAsmPrinter & printer)301 static void printDimAndSymbolList(Operation::operand_iterator begin,
302                                   Operation::operand_iterator end,
303                                   unsigned numDims, OpAsmPrinter &printer) {
304   OperandRange operands(begin, end);
305   printer << '(' << operands.take_front(numDims) << ')';
306   if (operands.size() > numDims)
307     printer << '[' << operands.drop_front(numDims) << ']';
308 }
309 
/// Parses dimension and symbol list and returns true if parsing failed.
/// On success, `operands` holds the resolved dim operands followed by the
/// symbol operands (all of index type), and `numDims` is set to the number of
/// parenthesized dim operands for validation by the caller.
ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser,
                                        SmallVectorImpl<Value> &operands,
                                        unsigned &numDims) {
  SmallVector<OpAsmParser::OperandType, 8> opInfos;
  if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
    return failure();
  // Store number of dimensions for validation by caller.
  numDims = opInfos.size();

  // Parse the optional symbol operands ("[...]" may be absent entirely).
  auto indexTy = parser.getBuilder().getIndexType();
  return failure(parser.parseOperandList(
                     opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
                 parser.resolveOperands(opInfos, indexTy, operands));
}
326 
327 /// Utility function to verify that a set of operands are valid dimension and
328 /// symbol identifiers. The operands should be laid out such that the dimension
329 /// operands are before the symbol operands. This function returns failure if
330 /// there was an invalid operand. An operation is provided to emit any necessary
331 /// errors.
332 template <typename OpTy>
333 static LogicalResult
verifyDimAndSymbolIdentifiers(OpTy & op,Operation::operand_range operands,unsigned numDims)334 verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
335                               unsigned numDims) {
336   unsigned opIt = 0;
337   for (auto operand : operands) {
338     if (opIt++ < numDims) {
339       if (!isValidDim(operand, getAffineScope(op)))
340         return op.emitOpError("operand cannot be used as a dimension id");
341     } else if (!isValidSymbol(operand, getAffineScope(op))) {
342       return op.emitOpError("operand cannot be used as a symbol");
343     }
344   }
345   return success();
346 }
347 
348 //===----------------------------------------------------------------------===//
349 // AffineApplyOp
350 //===----------------------------------------------------------------------===//
351 
/// Returns an AffineValueMap bundling this op's affine map together with its
/// operands and result value.
AffineValueMap AffineApplyOp::getAffineValueMap() {
  return AffineValueMap(getAffineMap(), getOperands(), getResult());
}
355 
/// Parses an affine.apply op of the form:
///   affine.apply affine-map `(` dim-use-list `)` (`[` symbol-use-list `]`)?
///     attr-dict?
static ParseResult parseAffineApplyOp(OpAsmParser &parser,
                                      OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  AffineMapAttr mapAttr;
  unsigned numDims;
  if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
      parseDimAndSymbolList(parser, result.operands, numDims) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  auto map = mapAttr.getValue();

  // The parenthesized operand count must equal the map's dim count, and the
  // total operand count must equal dims + symbols.
  if (map.getNumDims() != numDims ||
      numDims + map.getNumSymbols() != result.operands.size()) {
    return parser.emitError(parser.getNameLoc(),
                            "dimension or symbol index mismatch");
  }

  // One index-typed result per map result.
  result.types.append(map.getNumResults(), indexTy);
  return success();
}
378 
/// Prints an affine.apply op: op name, map attribute, then the dim/symbol
/// operand lists, then remaining attributes (eliding the already-printed
/// "map").
static void print(OpAsmPrinter &p, AffineApplyOp op) {
  p << AffineApplyOp::getOperationName() << " " << op.mapAttr();
  printDimAndSymbolList(op.operand_begin(), op.operand_end(),
                        op.getAffineMap().getNumDims(), p);
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
}
385 
verify(AffineApplyOp op)386 static LogicalResult verify(AffineApplyOp op) {
387   // Check input and output dimensions match.
388   auto map = op.map();
389 
390   // Verify that operand count matches affine map dimension and symbol count.
391   if (op.getNumOperands() != map.getNumDims() + map.getNumSymbols())
392     return op.emitOpError(
393         "operand count and affine map dimension and symbol count must match");
394 
395   // Verify that the map only produces one result.
396   if (map.getNumResults() != 1)
397     return op.emitOpError("mapping must produce one value");
398 
399   return success();
400 }
401 
402 // The result of the affine apply operation can be used as a dimension id if all
403 // its operands are valid dimension ids.
isValidDim()404 bool AffineApplyOp::isValidDim() {
405   return llvm::all_of(getOperands(),
406                       [](Value op) { return mlir::isValidDim(op); });
407 }
408 
409 // The result of the affine apply operation can be used as a dimension id if all
410 // its operands are valid dimension ids with the parent operation of `region`
411 // defining the polyhedral scope for symbols.
isValidDim(Region * region)412 bool AffineApplyOp::isValidDim(Region *region) {
413   return llvm::all_of(getOperands(),
414                       [&](Value op) { return ::isValidDim(op, region); });
415 }
416 
417 // The result of the affine apply operation can be used as a symbol if all its
418 // operands are symbols.
isValidSymbol()419 bool AffineApplyOp::isValidSymbol() {
420   return llvm::all_of(getOperands(),
421                       [](Value op) { return mlir::isValidSymbol(op); });
422 }
423 
424 // The result of the affine apply operation can be used as a symbol in `region`
425 // if all its operands are symbols in `region`.
isValidSymbol(Region * region)426 bool AffineApplyOp::isValidSymbol(Region *region) {
427   return llvm::all_of(getOperands(), [&](Value operand) {
428     return mlir::isValidSymbol(operand, region);
429   });
430 }
431 
/// Folds an affine.apply: an identity dim/symbol map folds to the matching
/// operand; otherwise the map is constant-folded when all operands are
/// constant attributes.
OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
  auto map = getAffineMap();

  // Fold dims and symbols to existing values. A map whose single result is a
  // bare dim/symbol expression is the identity on the corresponding operand.
  auto expr = map.getResult(0);
  if (auto dim = expr.dyn_cast<AffineDimExpr>())
    return getOperand(dim.getPosition());
  if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
    // Symbol operands follow the dim operands in the operand list.
    return getOperand(map.getNumDims() + sym.getPosition());

  // Otherwise, default to folding the map (succeeds only when all operands
  // are constant attributes).
  SmallVector<Attribute, 1> result;
  if (failed(map.constantFold(operands, result)))
    return {};
  return result[0];
}
448 
renumberOneDim(Value v)449 AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value v) {
450   DenseMap<Value, unsigned>::iterator iterPos;
451   bool inserted = false;
452   std::tie(iterPos, inserted) =
453       dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size()));
454   if (inserted) {
455     reorderedDims.push_back(v);
456   }
457   return getAffineDimExpr(iterPos->second, v.getContext())
458       .cast<AffineDimExpr>();
459 }
460 
/// Merges the dims and symbols of `other` into this normalizer and returns
/// `other`'s affine map rewritten in terms of this normalizer's dim/symbol
/// numbering.
AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) {
  // Remap each of other's dims into this normalizer's numbering, reusing a
  // position when the underlying Value is already known here.
  SmallVector<AffineExpr, 8> dimRemapping;
  for (auto v : other.reorderedDims) {
    auto kvp = other.dimValueToPosition.find(v);
    if (dimRemapping.size() <= kvp->second)
      dimRemapping.resize(kvp->second + 1);
    dimRemapping[kvp->second] = renumberOneDim(kvp->first);
  }
  // Other's symbols are appended after ours, so each of its symbol indices is
  // shifted by our current symbol count.
  unsigned numSymbols = concatenatedSymbols.size();
  unsigned numOtherSymbols = other.concatenatedSymbols.size();
  SmallVector<AffineExpr, 8> symRemapping(numOtherSymbols);
  for (unsigned idx = 0; idx < numOtherSymbols; ++idx) {
    symRemapping[idx] =
        getAffineSymbolExpr(idx + numSymbols, other.affineMap.getContext());
  }
  concatenatedSymbols.insert(concatenatedSymbols.end(),
                             other.concatenatedSymbols.begin(),
                             other.concatenatedSymbols.end());
  auto map = other.affineMap;
  return map.replaceDimsAndSymbols(dimRemapping, symRemapping,
                                   reorderedDims.size(),
                                   concatenatedSymbols.size());
}
484 
485 // Gather the positions of the operands that are produced by an AffineApplyOp.
486 static llvm::SetVector<unsigned>
indicesFromAffineApplyOp(ArrayRef<Value> operands)487 indicesFromAffineApplyOp(ArrayRef<Value> operands) {
488   llvm::SetVector<unsigned> res;
489   for (auto en : llvm::enumerate(operands))
490     if (isa_and_nonnull<AffineApplyOp>(en.value().getDefiningOp()))
491       res.insert(en.index());
492   return res;
493 }
494 
// Support the special case of a symbol coming from an AffineApplyOp that needs
// to be composed into the current AffineApplyOp.
// This case is handled by rewriting all such symbols into dims for the purpose
// of allowing mathematical AffineMap composition.
// Returns an AffineMap where symbols that come from an AffineApplyOp have been
// rewritten as dims and are ordered after the original dims.
// TODO: This promotion makes AffineMap lose track of which
// symbols are represented as dims. This loss is static but can still be
// recovered dynamically (with `isValidSymbol`). Still this is annoying for the
// semi-affine map case. A dynamic canonicalization of all dims that are valid
// symbols (a.k.a `canonicalizePromotedSymbols`) into symbols helps and even
// results in better simplifications and foldings. But we should evaluate
// whether this behavior is what we really want after using more.
static AffineMap promoteComposedSymbolsAsDims(AffineMap map,
                                              ArrayRef<Value> symbols) {
  // Nothing to promote without symbol operands.
  if (symbols.empty()) {
    return map;
  }

  // Sanity check on symbols.
  for (auto sym : symbols) {
    assert(isValidSymbol(sym) && "Expected only valid symbols");
    (void)sym;
  }

  // Extract the symbol positions that come from an AffineApplyOp and
  // needs to be rewritten as dims.
  auto symPositions = indicesFromAffineApplyOp(symbols);
  if (symPositions.empty()) {
    return map;
  }

  // Create the new map by replacing each symbol at pos by the next new dim.
  // Promoted symbols become dims appended after the existing dims; the
  // remaining symbols are renumbered compactly from zero.
  unsigned numDims = map.getNumDims();
  unsigned numSymbols = map.getNumSymbols();
  unsigned numNewDims = 0;
  unsigned numNewSymbols = 0;
  SmallVector<AffineExpr, 8> symReplacements(numSymbols);
  for (unsigned i = 0; i < numSymbols; ++i) {
    symReplacements[i] =
        symPositions.count(i) > 0
            ? getAffineDimExpr(numDims + numNewDims++, map.getContext())
            : getAffineSymbolExpr(numNewSymbols++, map.getContext());
  }
  assert(numSymbols >= numNewDims);
  AffineMap newMap = map.replaceDimsAndSymbols(
      {}, symReplacements, numDims + numNewDims, numNewSymbols);

  return newMap;
}
545 
/// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to
/// keep a correspondence between the mathematical `map` and the `operands` of
/// a given AffineApplyOp. This correspondence is maintained by iterating over
/// the operands and forming an `auxiliaryMap` that can be composed
/// mathematically with `map`. To keep this correspondence in cases where
/// symbols are produced by affine.apply operations, we perform a local rewrite
/// of symbols as dims.
///
/// Rationale for locally rewriting symbols as dims:
/// ================================================
/// The mathematical composition of AffineMap must always concatenate symbols
/// because it does not have enough information to do otherwise. For example,
/// composing `(d0)[s0] -> (d0 + s0)` with itself must produce
/// `(d0)[s0, s1] -> (d0 + s0 + s1)`.
///
/// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when
/// applied to the same mlir::Value for both s0 and s1.
/// As a consequence mathematical composition of AffineMap always concatenates
/// symbols.
///
/// When AffineMaps are used in AffineApplyOp however, they may specify
/// composition via symbols, which is ambiguous mathematically. This corner case
/// is handled by locally rewriting such symbols that come from AffineApplyOp
/// into dims and composing through dims.
/// TODO: Composition via symbols comes at a significant code
/// complexity. Alternatively we should investigate whether we want to
/// explicitly disallow symbols coming from affine.apply and instead force the
/// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2
/// extra API calls for such uses, which haven't popped up until now) and the
/// benefit potentially big: simpler and more maintainable code for a
/// non-trivial, recursive, procedure.
AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
                                             ArrayRef<Value> operands)
    : AffineApplyNormalizer() {
  static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0");
  assert(map.getNumInputs() == operands.size() &&
         "number of operands does not match the number of map inputs");

  LLVM_DEBUG(map.print(dbgs() << "\nInput map: "));

  // Promote symbols that come from an AffineApplyOp to dims by rewriting the
  // map to always refer to:
  //   (dims, symbols coming from AffineApplyOp, other symbols).
  // The order of operands can remain unchanged.
  // This is a simplification that relies on 2 ordering properties:
  //   1. rewritten symbols always appear after the original dims in the map;
  //   2. operands are traversed in order and either dispatched to:
  //      a. auxiliaryExprs (dims and symbols rewritten as dims);
  //      b. concatenatedSymbols (all other symbols)
  // This allows operand order to remain unchanged.
  unsigned numDimsBeforeRewrite = map.getNumDims();
  map = promoteComposedSymbolsAsDims(map,
                                     operands.take_back(map.getNumSymbols()));

  LLVM_DEBUG(map.print(dbgs() << "\nRewritten map: "));

  SmallVector<AffineExpr, 8> auxiliaryExprs;
  // Bound the recursion: past kMaxAffineApplyDepth we stop composing through
  // affine.apply producers and dispatch operands as-is.
  bool furtherCompose = (affineApplyDepth() <= kMaxAffineApplyDepth);
  // We fully spell out the 2 cases below. In this particular instance a little
  // code duplication greatly improves readability.
  // Note that the first branch would disappear if we only supported full
  // composition (i.e. infinite kMaxAffineApplyDepth).
  if (!furtherCompose) {
    // 1. Only dispatch dims or symbols.
    for (auto en : llvm::enumerate(operands)) {
      auto t = en.value();
      assert(t.getType().isIndex());
      bool isDim = (en.index() < map.getNumDims());
      if (isDim) {
        // a. The mathematical composition of AffineMap composes dims.
        auxiliaryExprs.push_back(renumberOneDim(t));
      } else {
        // b. The mathematical composition of AffineMap concatenates symbols.
        //    We do the same for symbol operands.
        concatenatedSymbols.push_back(t);
      }
    }
  } else {
    assert(numDimsBeforeRewrite <= operands.size());
    // 2. Compose AffineApplyOps and dispatch dims or symbols.
    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
      auto t = operands[i];
      auto affineApply = t.getDefiningOp<AffineApplyOp>();
      if (affineApply) {
        // a. Compose affine.apply operations.
        LLVM_DEBUG(affineApply->print(
            dbgs() << "\nCompose AffineApplyOp recursively: "));
        AffineMap affineApplyMap = affineApply.getAffineMap();
        SmallVector<Value, 8> affineApplyOperands(
            affineApply.getOperands().begin(), affineApply.getOperands().end());
        // Recursively normalize the producing apply, then merge its dims and
        // symbols into this normalizer's numbering.
        AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands);

        LLVM_DEBUG(normalizer.affineMap.print(
            dbgs() << "\nRenumber into current normalizer: "));

        auto renumberedMap = renumber(normalizer);

        LLVM_DEBUG(
            renumberedMap.print(dbgs() << "\nRecursive composition yields: "));

        auxiliaryExprs.push_back(renumberedMap.getResult(0));
      } else {
        if (i < numDimsBeforeRewrite) {
          // b. The mathematical composition of AffineMap composes dims.
          auxiliaryExprs.push_back(renumberOneDim(t));
        } else {
          // c. The mathematical composition of AffineMap concatenates symbols.
          //    Note that the map composition will put symbols already present
          //    in the map before any symbols coming from the auxiliary map, so
          //    we insert them before any symbols that are due to renumbering,
          //    and after the proper symbols we have seen already.
          concatenatedSymbols.insert(
              std::next(concatenatedSymbols.begin(), numProperSymbols++), t);
        }
      }
    }
  }

  // Early exit if `map` is already composed.
  if (auxiliaryExprs.empty()) {
    affineMap = map;
    return;
  }

  assert(concatenatedSymbols.size() >= map.getNumSymbols() &&
         "Unexpected number of concatenated symbols");
  auto numDims = dimValueToPosition.size();
  auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols();
  auto auxiliaryMap =
      AffineMap::get(numDims, numSymbols, auxiliaryExprs, map.getContext());

  LLVM_DEBUG(map.print(dbgs() << "\nCompose map: "));
  LLVM_DEBUG(auxiliaryMap.print(dbgs() << "\nWith map: "));
  LLVM_DEBUG(map.compose(auxiliaryMap).print(dbgs() << "\nResult: "));

  // TODO: Disabling simplification results in major speed gains.
  // Another option is to cache the results as it is expected a lot of redundant
  // work is performed in practice.
  affineMap = simplifyAffineMap(map.compose(auxiliaryMap));

  LLVM_DEBUG(affineMap.print(dbgs() << "\nSimplified result: "));
  LLVM_DEBUG(dbgs() << "\n");
}
689 
normalize(AffineMap * otherMap,SmallVectorImpl<Value> * otherOperands)690 void AffineApplyNormalizer::normalize(AffineMap *otherMap,
691                                       SmallVectorImpl<Value> *otherOperands) {
692   AffineApplyNormalizer other(*otherMap, *otherOperands);
693   *otherMap = renumber(other);
694 
695   otherOperands->reserve(reorderedDims.size() + concatenatedSymbols.size());
696   otherOperands->assign(reorderedDims.begin(), reorderedDims.end());
697   otherOperands->append(concatenatedSymbols.begin(), concatenatedSymbols.end());
698 }
699 
700 /// Implements `map` and `operands` composition and simplification to support
701 /// `makeComposedAffineApply`. This can be called to achieve the same effects
702 /// on `map` and `operands` without creating an AffineApplyOp that needs to be
703 /// immediately deleted.
composeAffineMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)704 static void composeAffineMapAndOperands(AffineMap *map,
705                                         SmallVectorImpl<Value> *operands) {
706   AffineApplyNormalizer normalizer(*map, *operands);
707   auto normalizedMap = normalizer.getAffineMap();
708   auto normalizedOperands = normalizer.getOperands();
709   canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands);
710   *map = normalizedMap;
711   *operands = normalizedOperands;
712   assert(*map);
713 }
714 
fullyComposeAffineMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)715 void mlir::fullyComposeAffineMapAndOperands(AffineMap *map,
716                                             SmallVectorImpl<Value> *operands) {
717   while (llvm::any_of(*operands, [](Value v) {
718     return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
719   })) {
720     composeAffineMapAndOperands(map, operands);
721   }
722 }
723 
makeComposedAffineApply(OpBuilder & b,Location loc,AffineMap map,ArrayRef<Value> operands)724 AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
725                                             AffineMap map,
726                                             ArrayRef<Value> operands) {
727   AffineMap normalizedMap = map;
728   SmallVector<Value, 8> normalizedOperands(operands.begin(), operands.end());
729   composeAffineMapAndOperands(&normalizedMap, &normalizedOperands);
730   assert(normalizedMap);
731   return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands);
732 }
733 
// A symbol may appear as a dim in affine.apply operations. This function
// canonicalizes dims that are valid symbols into actual symbols.
//
// Promoted symbols are appended AFTER the pre-existing symbols, and the
// operand list is rewritten to match the new (dims..., old syms..., promoted
// syms...) input order.
template <class MapOrSet>
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
                                        SmallVectorImpl<Value> *operands) {
  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  auto *context = mapOrSet->getContext();
  // Operands that remain dims (and trailing symbol operands), in order.
  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(operands->size());
  // Operands promoted from dim to symbol position; appended at the end.
  SmallVector<Value, 8> remappedSymbols;
  remappedSymbols.reserve(operands->size());
  unsigned nextDim = 0;
  unsigned nextSym = 0;
  unsigned oldNumSyms = mapOrSet->getNumSymbols();
  // dimRemapping[i] holds the expression that old dim i is replaced with:
  // either a compacted dim expr or a new symbol expr past the old symbols.
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
    if (i < mapOrSet->getNumDims()) {
      if (isValidSymbol((*operands)[i])) {
        // This is a valid symbol that appears as a dim, canonicalize it.
        dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
        remappedSymbols.push_back((*operands)[i]);
      } else {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
      }
    } else {
      // Symbol operands are kept as-is (old symbols are not renumbered).
      resultOperands.push_back((*operands)[i]);
    }
  }

  // New operand order: remaining dims, old symbols, then promoted symbols.
  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
  *operands = resultOperands;
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
                                              oldNumSyms + nextSym);

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");
}
777 
// Works for either an affine map or an integer set.
//
// Canonicalizations performed, in order:
//   1. Promote dim operands that are valid symbols to symbol positions.
//   2. Drop dims/symbols that are unused by any expression.
//   3. Deduplicate repeated dim and symbol operands.
//   4. Fold constant symbol operands directly into the expressions.
// `operands` is rewritten to match the compacted input list.
template <class MapOrSet>
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
                                            SmallVectorImpl<Value> *operands) {
  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
                "Argument must be either of AffineMap or IntegerSet type");

  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);

  // Check to see what dims are used.
  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
  mapOrSet->walkExprs([&](AffineExpr expr) {
    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
      usedDims[dimExpr.getPosition()] = true;
    else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
      usedSyms[symExpr.getPosition()] = true;
  });

  auto *context = mapOrSet->getContext();

  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(operands->size());

  // Compact used dims, mapping duplicate operands to a single position.
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  unsigned nextDim = 0;
  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
    if (usedDims[i]) {
      // Remap dim positions for duplicate operands.
      auto it = seenDims.find((*operands)[i]);
      if (it == seenDims.end()) {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
        seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
      } else {
        dimRemapping[i] = it->second;
      }
    }
    // Unused dims get a null remapping expr and their operand is dropped.
  }
  // Compact used symbols likewise, folding constants along the way.
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
  unsigned nextSym = 0;
  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
    if (!usedSyms[i])
      continue;
    // Handle constant operands (only needed for symbolic operands since
    // constant operands in dimensional positions would have already been
    // promoted to symbolic positions above).
    IntegerAttr operandCst;
    if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
                     m_Constant(&operandCst))) {
      symRemapping[i] =
          getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
      continue;
    }
    // Remap symbol positions for duplicate operands.
    auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
    if (it == seenSymbols.end()) {
      symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
      resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
      seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
                                        symRemapping[i]));
    } else {
      symRemapping[i] = it->second;
    }
  }
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
                                              nextDim, nextSym);
  *operands = resultOperands;
}
855 
/// Canonicalizes `map` and its `operands` in place; see
/// canonicalizeMapOrSetAndOperands for the transformations applied.
void mlir::canonicalizeMapAndOperands(AffineMap *map,
                                      SmallVectorImpl<Value> *operands) {
  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
}
860 
/// Canonicalizes `set` and its `operands` in place; see
/// canonicalizeMapOrSetAndOperands for the transformations applied.
void mlir::canonicalizeSetAndOperands(IntegerSet *set,
                                      SmallVectorImpl<Value> *operands) {
  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
}
865 
namespace {
/// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
/// maps that supply results into them.
///
/// The pattern succeeds only if composition actually changes the map or the
/// operand list; otherwise it fails so the driver does not loop forever.
template <typename AffineOpTy>
struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
  using OpRewritePattern<AffineOpTy>::OpRewritePattern;

  /// Replace the affine op with another instance of it with the supplied
  /// map and mapOperands.
  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
                       AffineMap map, ArrayRef<Value> mapOperands) const;

  LogicalResult matchAndRewrite(AffineOpTy affineOp,
                                PatternRewriter &rewriter) const override {
    static_assert(llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
                                  AffineStoreOp, AffineApplyOp, AffineMinOp,
                                  AffineMaxOp>::value,
                  "affine load/store/apply/prefetch/min/max op expected");
    auto map = affineOp.getAffineMap();
    AffineMap oldMap = map;
    auto oldOperands = affineOp.getMapOperands();
    SmallVector<Value, 8> resultOperands(oldOperands);
    composeAffineMapAndOperands(&map, &resultOperands);
    // No change after composition: report failure to avoid rewriting in place
    // with identical contents.
    if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
                                    resultOperands.begin()))
      return failure();

    replaceAffineOp(rewriter, affineOp, map, resultOperands);
    return success();
  }
};

// Specialize the template to account for the different build signatures for
// affine load, store, and apply ops.
template <>
void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
                                            mapOperands);
}
template <>
void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
      prefetch, prefetch.memref(), map, mapOperands, prefetch.localityHint(),
      prefetch.isWrite(), prefetch.isDataCache());
}
template <>
void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineStoreOp>(
      store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
}

// Generic version for ops that don't have extra operands.
template <typename AffineOpTy>
void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
    PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
}
} // end anonymous namespace.
932 
/// Registers the map-composition canonicalization (SimplifyAffineOp) for
/// affine.apply.
void AffineApplyOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyAffineOp<AffineApplyOp>>(context);
}
937 
938 //===----------------------------------------------------------------------===//
939 // Common canonicalization pattern support logic
940 //===----------------------------------------------------------------------===//
941 
942 /// This is a common class used for patterns of the form
943 /// "someop(memrefcast) -> someop".  It folds the source of any memref_cast
944 /// into the root operation directly.
foldMemRefCast(Operation * op)945 static LogicalResult foldMemRefCast(Operation *op) {
946   bool folded = false;
947   for (OpOperand &operand : op->getOpOperands()) {
948     auto cast = operand.get().getDefiningOp<MemRefCastOp>();
949     if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
950       operand.set(cast.getOperand());
951       folded = true;
952     }
953   }
954   return success(folded);
955 }
956 
957 //===----------------------------------------------------------------------===//
958 // AffineDmaStartOp
959 //===----------------------------------------------------------------------===//
960 
961 // TODO: Check that map operands are loop IVs or symbols.
/// Builds an affine.dma_start op. Operand order is significant and mirrors
/// the accessors: src memref + its map operands, dst memref + its map
/// operands, tag memref + its map operands, the element count, and optionally
/// the stride pair. The three affine maps are stored as attributes.
void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
                             Value srcMemRef, AffineMap srcMap,
                             ValueRange srcIndices, Value destMemRef,
                             AffineMap dstMap, ValueRange destIndices,
                             Value tagMemRef, AffineMap tagMap,
                             ValueRange tagIndices, Value numElements,
                             Value stride, Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addAttribute(getSrcMapAttrName(), AffineMapAttr::get(srcMap));
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addAttribute(getDstMapAttrName(), AffineMapAttr::get(dstMap));
  result.addOperands(destIndices);
  result.addOperands(tagMemRef);
  result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
  result.addOperands(tagIndices);
  result.addOperands(numElements);
  // Stride operands are optional; both are appended together when present.
  if (stride) {
    result.addOperands({stride, elementsPerStride});
  }
}
983 
/// Prints the op in the form parsed by AffineDmaStartOp::parse:
///   affine.dma_start %src[map], %dst[map], %tag[map], %num [, %stride,
///   %elts_per_stride] : src-type, dst-type, tag-type
void AffineDmaStartOp::print(OpAsmPrinter &p) {
  p << "affine.dma_start " << getSrcMemRef() << '[';
  p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
  p << "], " << getDstMemRef() << '[';
  p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
  p << "], " << getTagMemRef() << '[';
  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
  p << "], " << getNumElements();
  // Stride and elements-per-stride are only printed for strided DMAs.
  if (isStrided()) {
    p << ", " << getStride();
    p << ", " << getNumElementsPerStride();
  }
  p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
    << getTagMemRefType();
}
999 
1000 // Parse AffineDmaStartOp.
1001 // Ex:
1002 //   affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1003 //     %stride, %num_elt_per_stride
1004 //       : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1005 //
/// Parses an affine.dma_start op; see the example form in the comment above.
/// Fails (with a diagnostic) on missing components, a stride list that is not
/// exactly two operands, a type list that is not exactly three memref types,
/// or map operand counts that disagree with the parsed maps.
ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  AffineMapAttr srcMapAttr;
  SmallVector<OpAsmParser::OperandType, 4> srcMapOperands;
  OpAsmParser::OperandType dstMemRefInfo;
  AffineMapAttr dstMapAttr;
  SmallVector<OpAsmParser::OperandType, 4> dstMapOperands;
  OpAsmParser::OperandType tagMemRefInfo;
  AffineMapAttr tagMapAttr;
  SmallVector<OpAsmParser::OperandType, 4> tagMapOperands;
  OpAsmParser::OperandType numElementsInfo;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) dst memref followed by its affine maps operands (in square brackets).
  // *) src memref followed by its affine map operands (in square brackets).
  // *) tag memref followed by its affine map operands (in square brackets).
  // *) number of elements transferred by DMA operation.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
                                    getSrcMapAttrName(), result.attributes) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
                                    getDstMapAttrName(), result.attributes) ||
      parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
                                    getTagMapAttrName(), result.attributes) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo)) {
    return failure();
  }
  // Strided form requires exactly two extra operands (stride, elts/stride).
  if (!strideInfo.empty() && strideInfo.size() != 2) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }
  bool isStrided = strideInfo.size() == 2;

  if (parser.parseColonTypeList(types))
    return failure();

  // Exactly three types: src, dst, and tag memref types, in that order.
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "expected three types");

  // Resolve operands in the same order build() adds them.
  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
      parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
      parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  // Check that src/dst/tag operand counts match their map.numInputs.
  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
      dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
      tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    return parser.emitError(parser.getNameLoc(),
                            "memref operand count not equal to map.numInputs");
  return success();
}
1078 
/// Verifies an affine.dma_start op: the three memref operands must have
/// memref type and distinct src/dst memory spaces, the operand count must
/// match the maps (plus the 3 memrefs and element count, optionally plus the
/// 2 stride operands), and every index must be a valid affine dim/symbol of
/// 'index' type within the enclosing affine scope.
LogicalResult AffineDmaStartOp::verify() {
  if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
    return emitOpError("expected DMA source to be of memref type");
  if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
    return emitOpError("expected DMA destination to be of memref type");
  if (!getOperand(getTagMemRefOperandIndex()).getType().isa<MemRefType>())
    return emitOpError("expected DMA tag to be of memref type");

  // DMAs from different memory spaces supported.
  if (getSrcMemorySpace() == getDstMemorySpace()) {
    return emitOpError("DMA should be between different memory spaces");
  }
  // Total operands: map inputs + (src, dst, tag memrefs) + numElements,
  // optionally + (stride, elementsPerStride).
  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
                              getDstMap().getNumInputs() +
                              getTagMap().getNumInputs();
  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
      getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
    return emitOpError("incorrect number of operands");
  }

  Region *scope = getAffineScope(*this);
  for (auto idx : getSrcIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("src index to dma_start must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError("src index must be a dimension or symbol identifier");
  }
  for (auto idx : getDstIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("dst index to dma_start must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError("dst index must be a dimension or symbol identifier");
  }
  for (auto idx : getTagIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("tag index to dma_start must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError("tag index must be a dimension or symbol identifier");
  }
  return success();
}
1120 
/// Folds memref_cast feeding any operand into the op itself.
LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}
1126 
1127 //===----------------------------------------------------------------------===//
1128 // AffineDmaWaitOp
1129 //===----------------------------------------------------------------------===//
1130 
1131 // TODO: Check that map operands are loop IVs or symbols.
/// Builds an affine.dma_wait op. Operand order is significant: tag memref,
/// tag map operands, then the element count; the tag map is an attribute.
void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
                            Value tagMemRef, AffineMap tagMap,
                            ValueRange tagIndices, Value numElements) {
  result.addOperands(tagMemRef);
  result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
  result.addOperands(tagIndices);
  result.addOperands(numElements);
}
1140 
/// Prints the op in the form parsed by AffineDmaWaitOp::parse:
///   affine.dma_wait %tag[map], %num : tag-memref-type
void AffineDmaWaitOp::print(OpAsmPrinter &p) {
  p << "affine.dma_wait " << getTagMemRef() << '[';
  SmallVector<Value, 2> operands(getTagIndices());
  p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
  p << "], ";
  p.printOperand(getNumElements());
  p << " : " << getTagMemRef().getType();
}
1149 
1150 // Parse AffineDmaWaitOp.
1151 // Eg:
1152 //   affine.dma_wait %tag[%index], %num_elements
1153 //     : memref<1 x i32, (d0) -> (d0), 4>
1154 //
/// Parses an affine.dma_wait op; see the example form in the comment above.
/// Fails if the tag type is not a memref or the tag map operand count does
/// not match the parsed map.
ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType tagMemRefInfo;
  AffineMapAttr tagMapAttr;
  SmallVector<OpAsmParser::OperandType, 2> tagMapOperands;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::OperandType numElementsInfo;

  // Parse tag memref, its map operands, and dma size.
  if (parser.parseOperand(tagMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
                                    getTagMapAttrName(), result.attributes) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
      parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  if (!type.isa<MemRefType>())
    return parser.emitError(parser.getNameLoc(),
                            "expected tag to be of memref type");

  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    return parser.emitError(parser.getNameLoc(),
                            "tag memref operand count != to map.numInputs");
  return success();
}
1184 
/// Verifies an affine.dma_wait op: the tag operand must have memref type and
/// every tag index must be a valid affine dim/symbol of 'index' type within
/// the enclosing affine scope.
LogicalResult AffineDmaWaitOp::verify() {
  if (!getOperand(0).getType().isa<MemRefType>())
    return emitOpError("expected DMA tag to be of memref type");
  Region *scope = getAffineScope(*this);
  for (auto idx : getTagIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("index to dma_wait must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError("index must be a dimension or symbol identifier");
  }
  return success();
}
1197 
/// Folds memref_cast feeding any operand into the op itself.
LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                                    SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}
1203 
1204 //===----------------------------------------------------------------------===//
1205 // AffineForOp
1206 //===----------------------------------------------------------------------===//
1207 
/// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
/// bodyBuilder are empty/null, we include default terminator op.
///
/// Operand layout is: lower bound operands, upper bound operands, then the
/// loop-carried iter args; the bound maps and step are stored as attributes.
/// The body block's arguments are the induction variable (index type)
/// followed by one argument per iter arg.
void AffineForOp::build(OpBuilder &builder, OperationState &result,
                        ValueRange lbOperands, AffineMap lbMap,
                        ValueRange ubOperands, AffineMap ubMap, int64_t step,
                        ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
  assert(((!lbMap && lbOperands.empty()) ||
          lbOperands.size() == lbMap.getNumInputs()) &&
         "lower bound operand count does not match the affine map");
  assert(((!ubMap && ubOperands.empty()) ||
          ubOperands.size() == ubMap.getNumInputs()) &&
         "upper bound operand count does not match the affine map");
  assert(step > 0 && "step has to be a positive integer constant");

  // The op's results mirror the loop-carried values, one per iter arg.
  for (Value val : iterArgs)
    result.addTypes(val.getType());

  // Add an attribute for the step.
  result.addAttribute(getStepAttrName(),
                      builder.getIntegerAttr(builder.getIndexType(), step));

  // Add the lower bound.
  result.addAttribute(getLowerBoundAttrName(), AffineMapAttr::get(lbMap));
  result.addOperands(lbOperands);

  // Add the upper bound.
  result.addAttribute(getUpperBoundAttrName(), AffineMapAttr::get(ubMap));
  result.addOperands(ubOperands);

  result.addOperands(iterArgs);
  // Create a region and a block for the body.  The argument of the region is
  // the loop induction variable.
  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();
  Value inductionVar = bodyBlock.addArgument(builder.getIndexType());
  for (Value val : iterArgs)
    bodyBlock.addArgument(val.getType());

  // Create the default terminator if the builder is not provided and if the
  // iteration arguments are not provided. Otherwise, leave this to the caller
  // because we don't know which values to return from the loop.
  if (iterArgs.empty() && !bodyBuilder) {
    ensureTerminator(*bodyRegion, builder, result.location);
  } else if (bodyBuilder) {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(&bodyBlock);
    bodyBuilder(builder, result.location, inductionVar,
                bodyBlock.getArguments().drop_front());
  }
}
1259 
build(OpBuilder & builder,OperationState & result,int64_t lb,int64_t ub,int64_t step,ValueRange iterArgs,BodyBuilderFn bodyBuilder)1260 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
1261                         int64_t ub, int64_t step, ValueRange iterArgs,
1262                         BodyBuilderFn bodyBuilder) {
1263   auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1264   auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1265   return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1266                bodyBuilder);
1267 }
1268 
/// Verifies an affine.for op: the body leads with an index-typed induction
/// variable argument, bound operands are valid affine dims/symbols, and the
/// result count matches both the iter operands and the body's carried block
/// arguments.
static LogicalResult verify(AffineForOp op) {
  // Check that the body defines as single block argument for the induction
  // variable.
  auto *body = op.getBody();
  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
    return op.emitOpError(
        "expected body to have a single index argument for the "
        "induction variable");

  // Verify that the bound operands are valid dimension/symbols.
  /// Lower bound.
  if (op.getLowerBoundMap().getNumInputs() > 0)
    if (failed(
            verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(),
                                          op.getLowerBoundMap().getNumDims())))
      return failure();
  /// Upper bound.
  if (op.getUpperBoundMap().getNumInputs() > 0)
    if (failed(
            verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(),
                                          op.getUpperBoundMap().getNumDims())))
      return failure();

  // Loops without results carry no iter args; nothing further to check.
  unsigned opNumResults = op.getNumResults();
  if (opNumResults == 0)
    return success();

  // If ForOp defines values, check that the number and types of the defined
  // values match ForOp initial iter operands and backedge basic block
  // arguments.
  if (op.getNumIterOperands() != opNumResults)
    return op.emitOpError(
        "mismatch between the number of loop-carried values and results");
  if (op.getNumRegionIterArgs() != opNumResults)
    return op.emitOpError(
        "mismatch between the number of basic block args and results");

  return success();
}
1308 
/// Parse a for operation loop bounds. Accepts three forms: a bare SSA value
/// (stored as a symbol identity map), an affine map attribute followed by a
/// dim-and-symbol operand list, or a bare integer constant (stored as a
/// constant map).
static ParseResult parseBound(bool isLower, OperationState &result,
                              OpAsmParser &p) {
  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
  // the map has multiple results.
  bool failedToParsedMinMax =
      failed(p.parseOptionalKeyword(isLower ? "max" : "min"));

  auto &builder = p.getBuilder();
  auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName()
                               : AffineForOp::getUpperBoundAttrName();

  // Parse ssa-id as identity map.
  SmallVector<OpAsmParser::OperandType, 1> boundOpInfos;
  if (p.parseOperandList(boundOpInfos))
    return failure();

  if (!boundOpInfos.empty()) {
    // Check that only one operand was parsed.
    if (boundOpInfos.size() > 1)
      return p.emitError(p.getNameLoc(),
                         "expected only one loop bound operand");

    // TODO: improve error message when SSA value is not of index type.
    // Currently it is 'use of value ... expects different type than prior uses'
    if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
                         result.operands))
      return failure();

    // Create an identity map using symbol id. This representation is optimized
    // for storage. Analysis passes may expand it into a multi-dimensional map
    // if desired.
    AffineMap map = builder.getSymbolIdentityMap();
    result.addAttribute(boundAttrName, AffineMapAttr::get(map));
    return success();
  }

  // Get the attribute location.
  llvm::SMLoc attrLoc = p.getCurrentLocation();

  Attribute boundAttr;
  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrName,
                       result.attributes))
    return failure();

  // Parse full form - affine map followed by dim and symbol list.
  if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
    unsigned currentNumOperands = result.operands.size();
    unsigned numDims;
    if (parseDimAndSymbolList(p, result.operands, numDims))
      return failure();

    auto map = affineMapAttr.getValue();
    if (map.getNumDims() != numDims)
      return p.emitError(
          p.getNameLoc(),
          "dim operand count and affine map dim count must match");

    // Operands beyond the dims must exactly cover the map's symbols.
    unsigned numDimAndSymbolOperands =
        result.operands.size() - currentNumOperands;
    if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
      return p.emitError(
          p.getNameLoc(),
          "symbol operand count and affine map symbol count must match");

    // If the map has multiple results, make sure that we parsed the min/max
    // prefix.
    if (map.getNumResults() > 1 && failedToParsedMinMax) {
      if (isLower) {
        return p.emitError(attrLoc, "lower loop bound affine map with "
                                    "multiple results requires 'max' prefix");
      }
      return p.emitError(attrLoc, "upper loop bound affine map with multiple "
                                  "results requires 'min' prefix");
    }
    return success();
  }

  // Parse custom assembly form.
  if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
    // The bound parsed as a bare integer: drop the generic attribute and
    // re-register it as a constant affine map under the bound attribute name.
    result.attributes.pop_back();
    result.addAttribute(
        boundAttrName,
        AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
    return success();
  }

  return p.emitError(
      p.getNameLoc(),
      "expected valid affine map representation for loop bounds");
}
1400 
parseAffineForOp(OpAsmParser & parser,OperationState & result)1401 static ParseResult parseAffineForOp(OpAsmParser &parser,
1402                                     OperationState &result) {
1403   auto &builder = parser.getBuilder();
1404   OpAsmParser::OperandType inductionVariable;
1405   // Parse the induction variable followed by '='.
1406   if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
1407     return failure();
1408 
1409   // Parse loop bounds.
1410   if (parseBound(/*isLower=*/true, result, parser) ||
1411       parser.parseKeyword("to", " between bounds") ||
1412       parseBound(/*isLower=*/false, result, parser))
1413     return failure();
1414 
1415   // Parse the optional loop step, we default to 1 if one is not present.
1416   if (parser.parseOptionalKeyword("step")) {
1417     result.addAttribute(
1418         AffineForOp::getStepAttrName(),
1419         builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
1420   } else {
1421     llvm::SMLoc stepLoc = parser.getCurrentLocation();
1422     IntegerAttr stepAttr;
1423     if (parser.parseAttribute(stepAttr, builder.getIndexType(),
1424                               AffineForOp::getStepAttrName().data(),
1425                               result.attributes))
1426       return failure();
1427 
1428     if (stepAttr.getValue().getSExtValue() < 0)
1429       return parser.emitError(
1430           stepLoc,
1431           "expected step to be representable as a positive signed integer");
1432   }
1433 
1434   // Parse the optional initial iteration arguments.
1435   SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
1436   SmallVector<Type, 4> argTypes;
1437   regionArgs.push_back(inductionVariable);
1438 
1439   if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
1440     // Parse assignment list and results type list.
1441     if (parser.parseAssignmentList(regionArgs, operands) ||
1442         parser.parseArrowTypeList(result.types))
1443       return failure();
1444     // Resolve input operands.
1445     for (auto operandType : llvm::zip(operands, result.types))
1446       if (parser.resolveOperand(std::get<0>(operandType),
1447                                 std::get<1>(operandType), result.operands))
1448         return failure();
1449   }
1450   // Induction variable.
1451   Type indexType = builder.getIndexType();
1452   argTypes.push_back(indexType);
1453   // Loop carried variables.
1454   argTypes.append(result.types.begin(), result.types.end());
1455   // Parse the body region.
1456   Region *body = result.addRegion();
1457   if (regionArgs.size() != argTypes.size())
1458     return parser.emitError(
1459         parser.getNameLoc(),
1460         "mismatch between the number of loop-carried values and results");
1461   if (parser.parseRegion(*body, regionArgs, argTypes))
1462     return failure();
1463 
1464   AffineForOp::ensureTerminator(*body, builder, result.location);
1465 
1466   // Parse the optional attribute list.
1467   return parser.parseOptionalAttrDict(result.attributes);
1468 }
1469 
/// Prints a lower/upper loop bound. Uses the short custom assembly form only
/// for cases that round-trip losslessly (zero-operand constant maps and
/// single-symbol identity maps); everything else is printed as a full map with
/// its dim/symbol operand list, prefixed by 'min'/'max' when multi-result.
static void printBound(AffineMapAttr boundMap,
                       Operation::operand_range boundOperands,
                       const char *prefix, OpAsmPrinter &p) {
  AffineMap map = boundMap.getValue();

  // Check if this bound should be printed using custom assembly form.
  // The decision to restrict printing custom assembly form to trivial cases
  // comes from the will to roundtrip MLIR binary -> text -> binary in a
  // lossless way.
  // Therefore, custom assembly form parsing and printing is only supported for
  // zero-operand constant maps and single symbol operand identity maps.
  if (map.getNumResults() == 1) {
    AffineExpr expr = map.getResult(0);

    // Print constant bound.
    if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
      if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
        p << constExpr.getValue();
        return;
      }
    }

    // Print bound that consists of a single SSA symbol if the map is over a
    // single symbol.
    if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
      if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
        p.printOperand(*boundOperands.begin());
        return;
      }
    }
  } else {
    // Map has multiple results. Print 'min' or 'max' prefix.
    p << prefix << ' ';
  }

  // Print the map and its operands.
  p << boundMap;
  printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
                        map.getNumDims(), p);
}
1510 
getNumIterOperands()1511 unsigned AffineForOp::getNumIterOperands() {
1512   AffineMap lbMap = getLowerBoundMapAttr().getValue();
1513   AffineMap ubMap = getUpperBoundMapAttr().getValue();
1514 
1515   return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
1516 }
1517 
/// Prints an affine.for operation. The lower bound gets the 'max' prefix and
/// the upper bound 'min' (emitted by printBound only for multi-result maps).
/// Step 1 and the bound/step attributes are elided from the output.
static void print(OpAsmPrinter &p, AffineForOp op) {
  p << op.getOperationName() << ' ';
  p.printOperand(op.getBody()->getArgument(0));
  p << " = ";
  printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p);
  p << " to ";
  printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p);

  // Step 1 is the default and is not printed.
  if (op.getStep() != 1)
    p << " step " << op.getStep();

  // Block terminators are printed only when the loop carries values, since
  // the yielded operands are then significant.
  bool printBlockTerminators = false;
  if (op.getNumIterOperands() > 0) {
    p << " iter_args(";
    auto regionArgs = op.getRegionIterArgs();
    auto operands = op.getIterOperands();

    llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
      p << std::get<0>(it) << " = " << std::get<1>(it);
    });
    p << ") -> (" << op.getResultTypes() << ")";
    printBlockTerminators = true;
  }

  p.printRegion(op.region(),
                /*printEntryBlockArgs=*/false, printBlockTerminators);
  p.printOptionalAttrDict(op.getAttrs(),
                          /*elidedAttrs=*/{op.getLowerBoundAttrName(),
                                           op.getUpperBoundAttrName(),
                                           op.getStepAttrName()});
}
1549 
/// Fold the constant bounds of a loop. A multi-result lower bound folds to the
/// max of its folded results and an upper bound to the min, matching bound
/// semantics. Succeeds if at least one bound was replaced by a constant.
static LogicalResult foldLoopBounds(AffineForOp forOp) {
  auto foldLowerOrUpperBound = [&forOp](bool lower) {
    // Check to see if each of the operands is the result of a constant.  If
    // so, get the value.  If not, ignore it.
    SmallVector<Attribute, 8> operandConstants;
    auto boundOperands =
        lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
    for (auto operand : boundOperands) {
      Attribute operandCst;
      // Leaves operandCst null when the operand is not a constant.
      matchPattern(operand, m_Constant(&operandCst));
      operandConstants.push_back(operandCst);
    }

    AffineMap boundMap =
        lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
    assert(boundMap.getNumResults() >= 1 &&
           "bound maps should have at least one result");
    SmallVector<Attribute, 4> foldedResults;
    if (failed(boundMap.constantFold(operandConstants, foldedResults)))
      return failure();

    // Compute the max or min as applicable over the results.
    assert(!foldedResults.empty() && "bounds should have at least one result");
    auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
    for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
      auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
      maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
                       : llvm::APIntOps::smin(maxOrMin, foldedResult);
    }
    lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
          : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
    return success();
  };

  // Try to fold the lower bound.
  bool folded = false;
  if (!forOp.hasConstantLowerBound())
    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));

  // Try to fold the upper bound.
  if (!forOp.hasConstantUpperBound())
    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
  return success(folded);
}
1595 
1596 /// Canonicalize the bounds of the given loop.
canonicalizeLoopBounds(AffineForOp forOp)1597 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
1598   SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
1599   SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
1600 
1601   auto lbMap = forOp.getLowerBoundMap();
1602   auto ubMap = forOp.getUpperBoundMap();
1603   auto prevLbMap = lbMap;
1604   auto prevUbMap = ubMap;
1605 
1606   canonicalizeMapAndOperands(&lbMap, &lbOperands);
1607   lbMap = removeDuplicateExprs(lbMap);
1608 
1609   canonicalizeMapAndOperands(&ubMap, &ubOperands);
1610   ubMap = removeDuplicateExprs(ubMap);
1611 
1612   // Any canonicalization change always leads to updated map(s).
1613   if (lbMap == prevLbMap && ubMap == prevUbMap)
1614     return failure();
1615 
1616   if (lbMap != prevLbMap)
1617     forOp.setLowerBound(lbOperands, lbMap);
1618   if (ubMap != prevUbMap)
1619     forOp.setUpperBound(ubOperands, ubMap);
1620   return success();
1621 }
1622 
1623 namespace {
1624 /// This is a pattern to fold trivially empty loops.
1625 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
1626   using OpRewritePattern<AffineForOp>::OpRewritePattern;
1627 
matchAndRewrite__anon15f1ca220d11::AffineForEmptyLoopFolder1628   LogicalResult matchAndRewrite(AffineForOp forOp,
1629                                 PatternRewriter &rewriter) const override {
1630     // Check that the body only contains a yield.
1631     if (!llvm::hasSingleElement(*forOp.getBody()))
1632       return failure();
1633     rewriter.eraseOp(forOp);
1634     return success();
1635   }
1636 };
1637 } // end anonymous namespace
1638 
/// Registers affine.for canonicalization patterns; currently only the removal
/// of trivially empty loops.
void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  results.insert<AffineForEmptyLoopFolder>(context);
}
1643 
/// Folds the loop in place by constant-folding and then canonicalizing its
/// bounds. No replacement values are produced ('results' stays empty);
/// success merely signals that the op was updated.
LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
                                SmallVectorImpl<OpFoldResult> &results) {
  bool folded = succeeded(foldLoopBounds(*this));
  folded |= succeeded(canonicalizeLoopBounds(*this));
  return success(folded);
}
1650 
getLowerBound()1651 AffineBound AffineForOp::getLowerBound() {
1652   auto lbMap = getLowerBoundMap();
1653   return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap);
1654 }
1655 
getUpperBound()1656 AffineBound AffineForOp::getUpperBound() {
1657   auto lbMap = getLowerBoundMap();
1658   auto ubMap = getUpperBoundMap();
1659   return AffineBound(AffineForOp(*this), lbMap.getNumInputs(),
1660                      lbMap.getNumInputs() + ubMap.getNumInputs(), ubMap);
1661 }
1662 
setLowerBound(ValueRange lbOperands,AffineMap map)1663 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
1664   assert(lbOperands.size() == map.getNumInputs());
1665   assert(map.getNumResults() >= 1 && "bound map has at least one result");
1666 
1667   SmallVector<Value, 4> newOperands(lbOperands.begin(), lbOperands.end());
1668 
1669   auto ubOperands = getUpperBoundOperands();
1670   newOperands.append(ubOperands.begin(), ubOperands.end());
1671   auto iterOperands = getIterOperands();
1672   newOperands.append(iterOperands.begin(), iterOperands.end());
1673   (*this)->setOperands(newOperands);
1674 
1675   setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
1676 }
1677 
setUpperBound(ValueRange ubOperands,AffineMap map)1678 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
1679   assert(ubOperands.size() == map.getNumInputs());
1680   assert(map.getNumResults() >= 1 && "bound map has at least one result");
1681 
1682   SmallVector<Value, 4> newOperands(getLowerBoundOperands());
1683   newOperands.append(ubOperands.begin(), ubOperands.end());
1684   auto iterOperands = getIterOperands();
1685   newOperands.append(iterOperands.begin(), iterOperands.end());
1686   (*this)->setOperands(newOperands);
1687 
1688   setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
1689 }
1690 
setLowerBoundMap(AffineMap map)1691 void AffineForOp::setLowerBoundMap(AffineMap map) {
1692   auto lbMap = getLowerBoundMap();
1693   assert(lbMap.getNumDims() == map.getNumDims() &&
1694          lbMap.getNumSymbols() == map.getNumSymbols());
1695   assert(map.getNumResults() >= 1 && "bound map has at least one result");
1696   (void)lbMap;
1697   setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
1698 }
1699 
setUpperBoundMap(AffineMap map)1700 void AffineForOp::setUpperBoundMap(AffineMap map) {
1701   auto ubMap = getUpperBoundMap();
1702   assert(ubMap.getNumDims() == map.getNumDims() &&
1703          ubMap.getNumSymbols() == map.getNumSymbols());
1704   assert(map.getNumResults() >= 1 && "bound map has at least one result");
1705   (void)ubMap;
1706   setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
1707 }
1708 
/// Returns true if the lower bound map reduces to a single constant result.
bool AffineForOp::hasConstantLowerBound() {
  return getLowerBoundMap().isSingleConstant();
}
1712 
/// Returns true if the upper bound map reduces to a single constant result.
bool AffineForOp::hasConstantUpperBound() {
  return getUpperBoundMap().isSingleConstant();
}
1716 
/// Returns the constant lower bound; only valid when hasConstantLowerBound().
int64_t AffineForOp::getConstantLowerBound() {
  return getLowerBoundMap().getSingleConstantResult();
}
1720 
/// Returns the constant upper bound; only valid when hasConstantUpperBound().
int64_t AffineForOp::getConstantUpperBound() {
  return getUpperBoundMap().getSingleConstantResult();
}
1724 
/// Sets the lower bound to the constant 'value' (a zero-operand constant map).
void AffineForOp::setConstantLowerBound(int64_t value) {
  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
}
1728 
/// Sets the upper bound to the constant 'value' (a zero-operand constant map).
void AffineForOp::setConstantUpperBound(int64_t value) {
  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
}
1732 
/// Returns the operands feeding the lower bound map: a prefix of the op's
/// operand list.
AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
  return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
}
1736 
getUpperBoundOperands()1737 AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
1738   return {operand_begin() + getLowerBoundMap().getNumInputs(),
1739           operand_begin() + getLowerBoundMap().getNumInputs() +
1740               getUpperBoundMap().getNumInputs()};
1741 }
1742 
matchingBoundOperandList()1743 bool AffineForOp::matchingBoundOperandList() {
1744   auto lbMap = getLowerBoundMap();
1745   auto ubMap = getUpperBoundMap();
1746   if (lbMap.getNumDims() != ubMap.getNumDims() ||
1747       lbMap.getNumSymbols() != ubMap.getNumSymbols())
1748     return false;
1749 
1750   unsigned numOperands = lbMap.getNumInputs();
1751   for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
1752     // Compare Value 's.
1753     if (getOperand(i) != getOperand(numOperands + i))
1754       return false;
1755   }
1756   return true;
1757 }
1758 
/// Returns the region containing the loop body.
Region &AffineForOp::getLoopBody() { return region(); }
1760 
/// Returns true if 'value' is defined outside of this loop's body region.
bool AffineForOp::isDefinedOutsideOfLoop(Value value) {
  return !region().isAncestor(value.getParentRegion());
}
1764 
/// Moves each operation in 'ops' to just before this loop, in order. No
/// legality checking is performed here; callers must ensure hoisting is safe.
LogicalResult AffineForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
  for (auto *op : ops)
    op->moveBefore(*this);
  return success();
}
1770 
/// Returns true if the provided value is the induction variable of a
/// AffineForOp.
bool mlir::isForInductionVar(Value val) {
  // A null AffineForOp from getForInductionVarOwner means "not an IV".
  return getForInductionVarOwner(val) != AffineForOp();
}
1776 
1777 /// Returns the loop parent of an induction variable. If the provided value is
1778 /// not an induction variable, then return nullptr.
getForInductionVarOwner(Value val)1779 AffineForOp mlir::getForInductionVarOwner(Value val) {
1780   auto ivArg = val.dyn_cast<BlockArgument>();
1781   if (!ivArg || !ivArg.getOwner())
1782     return AffineForOp();
1783   auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
1784   return dyn_cast<AffineForOp>(containingInst);
1785 }
1786 
1787 /// Extracts the induction variables from a list of AffineForOps and returns
1788 /// them.
extractForInductionVars(ArrayRef<AffineForOp> forInsts,SmallVectorImpl<Value> * ivs)1789 void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
1790                                    SmallVectorImpl<Value> *ivs) {
1791   ivs->reserve(forInsts.size());
1792   for (auto forInst : forInsts)
1793     ivs->push_back(forInst.getInductionVar());
1794 }
1795 
/// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
/// operations. The body builder, if any, is invoked once with all induction
/// variables; each loop body is always terminated with an affine.yield.
template <typename BoundListTy, typename LoopCreatorTy>
static void buildAffineLoopNestImpl(
    OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
    ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    LoopCreatorTy &&loopCreatorFn) {
  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");

  // If there are no loops to be constructed, construct the body anyway.
  // The guard restores the builder's insertion point once the nest is built.
  OpBuilder::InsertionGuard guard(builder);
  if (lbs.empty()) {
    if (bodyBuilderFn)
      bodyBuilderFn(builder, loc, ValueRange());
    return;
  }

  // Create the loops iteratively and store the induction variables.
  SmallVector<Value, 4> ivs;
  ivs.reserve(lbs.size());
  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
    // Callback for creating the loop body, always creates the terminator.
    auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
                        ValueRange iterArgs) {
      ivs.push_back(iv);
      // In the innermost loop, call the body builder.
      if (i == e - 1 && bodyBuilderFn) {
        OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
        bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
      }
      nestedBuilder.create<AffineYieldOp>(nestedLoc);
    };

    // Delegate actual loop creation to the callback in order to dispatch
    // between constant- and variable-bound loops.
    auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
    // Subsequent loops are created at the start of this loop's body.
    builder.setInsertionPointToStart(loop.getBody());
  }
}
1837 
/// Creates an affine loop from the bounds known to be constants. The loop
/// carries no iter args.
static AffineForOp
buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
                             int64_t ub, int64_t step,
                             AffineForOp::BodyBuilderFn bodyBuilderFn) {
  return builder.create<AffineForOp>(loc, lb, ub, step, /*iterArgs=*/llvm::None,
                                     bodyBuilderFn);
}
1846 
1847 /// Creates an affine loop from the bounds that may or may not be constants.
1848 static AffineForOp
buildAffineLoopFromValues(OpBuilder & builder,Location loc,Value lb,Value ub,int64_t step,AffineForOp::BodyBuilderFn bodyBuilderFn)1849 buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
1850                           int64_t step,
1851                           AffineForOp::BodyBuilderFn bodyBuilderFn) {
1852   auto lbConst = lb.getDefiningOp<ConstantIndexOp>();
1853   auto ubConst = ub.getDefiningOp<ConstantIndexOp>();
1854   if (lbConst && ubConst)
1855     return buildAffineLoopFromConstants(builder, loc, lbConst.getValue(),
1856                                         ubConst.getValue(), step,
1857                                         bodyBuilderFn);
1858   return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
1859                                      builder.getDimIdentityMap(), step,
1860                                      /*iterArgs=*/llvm::None, bodyBuilderFn);
1861 }
1862 
/// Builds a perfect affine loop nest with constant bounds; delegates loop
/// construction to buildAffineLoopFromConstants.
void mlir::buildAffineLoopNest(
    OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
    ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
                          buildAffineLoopFromConstants);
}
1870 
/// Builds a perfect affine loop nest with SSA-value bounds; delegates loop
/// construction to buildAffineLoopFromValues.
void mlir::buildAffineLoopNest(
    OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
    ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
                          buildAffineLoopFromValues);
}
1878 
1879 //===----------------------------------------------------------------------===//
1880 // AffineIfOp
1881 //===----------------------------------------------------------------------===//
1882 
1883 namespace {
1884 /// Remove else blocks that have nothing other than a zero value yield.
1885 struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
1886   using OpRewritePattern<AffineIfOp>::OpRewritePattern;
1887 
matchAndRewrite__anon15f1ca220f11::SimplifyDeadElse1888   LogicalResult matchAndRewrite(AffineIfOp ifOp,
1889                                 PatternRewriter &rewriter) const override {
1890     if (ifOp.elseRegion().empty() ||
1891         !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
1892       return failure();
1893 
1894     rewriter.startRootUpdate(ifOp);
1895     rewriter.eraseBlock(ifOp.getElseBlock());
1896     rewriter.finalizeRootUpdate(ifOp);
1897     return success();
1898   }
1899 };
1900 } // end anonymous namespace.
1901 
/// Verifies an affine.if operation: the presence of the 'condition' integer
/// set attribute, the operand count against the set's inputs, and the
/// dim/symbol validity of the operands.
static LogicalResult verify(AffineIfOp op) {
  // Verify that we have a condition attribute.
  auto conditionAttr =
      op->getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
  if (!conditionAttr)
    return op.emitOpError(
        "requires an integer set attribute named 'condition'");

  // Verify that there are enough operands for the condition.
  IntegerSet condition = conditionAttr.getValue();
  if (op.getNumOperands() != condition.getNumInputs())
    return op.emitOpError(
        "operand count and condition integer set dimension and "
        "symbol count must match");

  // Verify that the operands are valid dimension/symbols.
  if (failed(verifyDimAndSymbolIdentifiers(op, op.getOperands(),
                                           condition.getNumDims())))
    return failure();

  return success();
}
1924 
/// Parses an affine.if operation: condition integer-set attribute with its
/// dim/symbol operand list, optional result types, the 'then' region, an
/// optional 'else' region, and a trailing attribute dictionary.
static ParseResult parseAffineIfOp(OpAsmParser &parser,
                                   OperationState &result) {
  // Parse the condition attribute set.
  IntegerSetAttr conditionAttr;
  unsigned numDims;
  if (parser.parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(),
                            result.attributes) ||
      parseDimAndSymbolList(parser, result.operands, numDims))
    return failure();

  // Verify the condition operands.
  auto set = conditionAttr.getValue();
  if (set.getNumDims() != numDims)
    return parser.emitError(
        parser.getNameLoc(),
        "dim operand count and integer set dim count must match");
  if (numDims + set.getNumSymbols() != result.operands.size())
    return parser.emitError(
        parser.getNameLoc(),
        "symbol operand count and integer set symbol count must match");

  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Create the regions for 'then' and 'else'.  The latter must be created even
  // if it remains empty for the validity of the operation.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  // Parse the 'then' region.
  if (parser.parseRegion(*thenRegion, {}, {}))
    return failure();
  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
                               result.location);

  // If we find an 'else' keyword then parse the 'else' region.
  if (!parser.parseOptionalKeyword("else")) {
    if (parser.parseRegion(*elseRegion, {}, {}))
      return failure();
    AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
                                 result.location);
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
1975 
/// Prints an affine.if operation. Block terminators are printed only when the
/// op yields results (the integer-to-bool conversion of getNumResults());
/// the 'condition' attribute is elided since it is printed inline.
static void print(OpAsmPrinter &p, AffineIfOp op) {
  auto conditionAttr =
      op->getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
  p << "affine.if " << conditionAttr;
  printDimAndSymbolList(op.operand_begin(), op.operand_end(),
                        conditionAttr.getValue().getNumDims(), p);
  p.printOptionalArrowTypeList(op.getResultTypes());
  p.printRegion(op.thenRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/op.getNumResults());

  // Print the 'else' regions if it has any blocks.
  auto &elseRegion = op.elseRegion();
  if (!elseRegion.empty()) {
    p << " else";
    p.printRegion(elseRegion,
                  /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/op.getNumResults());
  }

  // Print the attribute list.
  p.printOptionalAttrDict(op.getAttrs(),
                          /*elidedAttrs=*/op.getConditionAttrName());
}
2000 
/// Returns the integer set stored in the 'condition' attribute.
IntegerSet AffineIfOp::getIntegerSet() {
  return (*this)
      ->getAttrOfType<IntegerSetAttr>(getConditionAttrName())
      .getValue();
}
/// Replaces the condition integer set; operands are left unchanged, so the
/// new set must be compatible with the existing operand list.
void AffineIfOp::setIntegerSet(IntegerSet newSet) {
  setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet));
}
2009 
/// Updates the condition set and its operands together so the two stay
/// mutually consistent.
void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
  setIntegerSet(set);
  (*this)->setOperands(operands);
}
2014 
/// Builds an affine.if constrained by integer set `set` over `args`.
/// A terminator is auto-inserted only in the zero-result case; value-yielding
/// affine.if ops require the caller to build explicit affine.yield ops.
void AffineIfOp::build(OpBuilder &builder, OperationState &result,
                       TypeRange resultTypes, IntegerSet set, ValueRange args,
                       bool withElseRegion) {
  // A value-yielding affine.if must have an 'else' region so both paths yield.
  assert(resultTypes.empty() || withElseRegion);
  result.addTypes(resultTypes);
  result.addOperands(args);
  result.addAttribute(getConditionAttrName(), IntegerSetAttr::get(set));

  // The 'then' region always gets an entry block.
  Region *thenRegion = result.addRegion();
  thenRegion->push_back(new Block());
  if (resultTypes.empty())
    AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);

  // The 'else' region is created but left empty unless requested.
  Region *elseRegion = result.addRegion();
  if (withElseRegion) {
    elseRegion->push_back(new Block());
    if (resultTypes.empty())
      AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
  }
}
2035 
/// Convenience builder for a zero-result affine.if; delegates to the typed
/// overload with an empty result-type list.
void AffineIfOp::build(OpBuilder &builder, OperationState &result,
                       IntegerSet set, ValueRange args, bool withElseRegion) {
  AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
                    withElseRegion);
}
2041 
/// Canonicalize an affine if op's conditional (integer set + operands).
/// This "fold" rewrites the op in place rather than producing values; it
/// returns success() only when the condition actually changed.
LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
                               SmallVectorImpl<OpFoldResult> &) {
  auto set = getIntegerSet();
  SmallVector<Value, 4> operands(getOperands());
  canonicalizeSetAndOperands(&set, &operands);

  // Any canonicalization change always leads to either a reduction in the
  // number of operands or a change in the number of symbolic operands
  // (promotion of dims to symbols).
  if (operands.size() < getIntegerSet().getNumInputs() ||
      set.getNumSymbols() > getIntegerSet().getNumSymbols()) {
    setConditional(set, operands);
    return success();
  }

  return failure();
}
2060 
/// Registers canonicalization patterns: removes the 'else' region when it is
/// known dead (SimplifyDeadElse).
void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
  results.insert<SimplifyDeadElse>(context);
}
2065 
2066 //===----------------------------------------------------------------------===//
2067 // AffineLoadOp
2068 //===----------------------------------------------------------------------===//
2069 
/// Builds an affine.load from a combined operand list where operands[0] is
/// the memref and the remaining values are the affine map operands.
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
                         AffineMap map, ValueRange operands) {
  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
  result.addOperands(operands);
  if (map)
    result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
  // The load's result type is the memref's element type.
  auto memrefType = operands[0].getType().cast<MemRefType>();
  result.types.push_back(memrefType.getElementType());
}
2079 
/// Builds an affine.load of `memref` indexed by `map` applied to
/// `mapOperands`. Operand order (memref first, then map operands) is part of
/// the op's definition.
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
                         Value memref, AffineMap map, ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(memref);
  result.addOperands(mapOperands);
  auto memrefType = memref.getType().cast<MemRefType>();
  result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
  result.types.push_back(memrefType.getElementType());
}
2089 
build(OpBuilder & builder,OperationState & result,Value memref,ValueRange indices)2090 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2091                          Value memref, ValueRange indices) {
2092   auto memrefType = memref.getType().cast<MemRefType>();
2093   int64_t rank = memrefType.getRank();
2094   // Create identity map for memrefs with at least one dimension or () -> ()
2095   // for zero-dimensional memrefs.
2096   auto map =
2097       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2098   build(builder, result, memref, map, indices);
2099 }
2100 
/// Parses an affine.load of the form:
///   affine.load %memref[map(ssa-ids)] {attrs} : memref-type
/// The || chain short-circuits on the first parse error, so the clause order
/// mirrors the textual syntax exactly.
static ParseResult parseAffineLoadOp(OpAsmParser &parser,
                                     OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  OpAsmParser::OperandType memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
  return failure(
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineLoadOp::getMapAttrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      // Operand order must match the op definition: memref, then indices.
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands) ||
      parser.addTypeToList(type.getElementType(), result.types));
}
2121 
/// Prints an affine.load as: affine.load %memref[map(operands)] : type,
/// eliding the map attribute from the trailing attribute dict since it is
/// rendered inline.
static void print(OpAsmPrinter &p, AffineLoadOp op) {
  p << "affine.load " << op.getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
    p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
  p << ']';
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
  p << " : " << op.getMemRefType();
}
2131 
/// Verify common indexing invariants of affine.load, affine.store,
/// affine.vector_load and affine.vector_store.
///
/// Checks that the map (when present) agrees with the memref rank and the
/// subscript count, and that every index operand is an 'index'-typed valid
/// affine dim or symbol within the enclosing affine scope.
static LogicalResult
verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
                       Operation::operand_range mapOperands,
                       MemRefType memrefType, unsigned numIndexOperands) {
  if (mapAttr) {
    AffineMap map = mapAttr.getValue();
    if (map.getNumResults() != memrefType.getRank())
      return op->emitOpError("affine map num results must equal memref rank");
    if (map.getNumInputs() != numIndexOperands)
      return op->emitOpError("expects as many subscripts as affine map inputs");
  } else {
    // Without a map, subscripts index the memref dimensions directly.
    if (memrefType.getRank() != numIndexOperands)
      return op->emitOpError(
          "expects the number of subscripts to be equal to memref rank");
  }

  Region *scope = getAffineScope(op);
  for (auto idx : mapOperands) {
    if (!idx.getType().isIndex())
      return op->emitOpError("index to load must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return op->emitOpError("index must be a dimension or symbol identifier");
  }

  return success();
}
2160 
verify(AffineLoadOp op)2161 LogicalResult verify(AffineLoadOp op) {
2162   auto memrefType = op.getMemRefType();
2163   if (op.getType() != memrefType.getElementType())
2164     return op.emitOpError("result type must match element type of memref");
2165 
2166   if (failed(verifyMemoryOpIndexing(
2167           op.getOperation(),
2168           op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
2169           op.getMapOperands(), memrefType,
2170           /*numIndexOperands=*/op.getNumOperands() - 1)))
2171     return failure();
2172 
2173   return success();
2174 }
2175 
/// Registers canonicalization patterns that compose the load's map and
/// operands (SimplifyAffineOp).
void AffineLoadOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyAffineOp<AffineLoadOp>>(context);
}
2180 
fold(ArrayRef<Attribute> cstOperands)2181 OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
2182   /// load(memrefcast) -> load
2183   if (succeeded(foldMemRefCast(*this)))
2184     return getResult();
2185   return OpFoldResult();
2186 }
2187 
2188 //===----------------------------------------------------------------------===//
2189 // AffineStoreOp
2190 //===----------------------------------------------------------------------===//
2191 
/// Builds an affine.store of `valueToStore` into `memref` at `map` applied to
/// `mapOperands`. Operand order (value, memref, map operands) is part of the
/// op's definition.
void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
                          Value valueToStore, Value memref, AffineMap map,
                          ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(valueToStore);
  result.addOperands(memref);
  result.addOperands(mapOperands);
  result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
}
2201 
2202 // Use identity map.
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,ValueRange indices)2203 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
2204                           Value valueToStore, Value memref,
2205                           ValueRange indices) {
2206   auto memrefType = memref.getType().cast<MemRefType>();
2207   int64_t rank = memrefType.getRank();
2208   // Create identity map for memrefs with at least one dimension or () -> ()
2209   // for zero-dimensional memrefs.
2210   auto map =
2211       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2212   build(builder, result, valueToStore, memref, map, indices);
2213 }
2214 
/// Parses an affine.store of the form:
///   affine.store %value, %memref[map(ssa-ids)] {attrs} : memref-type
/// The || chain short-circuits on the first parse error, so the clause order
/// mirrors the textual syntax exactly.
static ParseResult parseAffineStoreOp(OpAsmParser &parser,
                                      OperationState &result) {
  auto indexTy = parser.getBuilder().getIndexType();

  MemRefType type;
  OpAsmParser::OperandType storeValueInfo;
  OpAsmParser::OperandType memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
                 parser.parseOperand(memrefInfo) ||
                 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                               AffineStoreOp::getMapAttrName(),
                                               result.attributes) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 // Operand order must match the op: value, memref, indices.
                 parser.resolveOperand(storeValueInfo, type.getElementType(),
                                       result.operands) ||
                 parser.resolveOperand(memrefInfo, type, result.operands) ||
                 parser.resolveOperands(mapOperands, indexTy, result.operands));
}
2236 
/// Prints an affine.store as:
///   affine.store %value, %memref[map(operands)] : type
/// eliding the map attribute from the trailing dict since it prints inline.
static void print(OpAsmPrinter &p, AffineStoreOp op) {
  p << "affine.store " << op.getValueToStore();
  p << ", " << op.getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
    p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
  p << ']';
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
  p << " : " << op.getMemRefType();
}
2247 
verify(AffineStoreOp op)2248 LogicalResult verify(AffineStoreOp op) {
2249   // First operand must have same type as memref element type.
2250   auto memrefType = op.getMemRefType();
2251   if (op.getValueToStore().getType() != memrefType.getElementType())
2252     return op.emitOpError(
2253         "first operand must have same type memref element type");
2254 
2255   if (failed(verifyMemoryOpIndexing(
2256           op.getOperation(),
2257           op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
2258           op.getMapOperands(), memrefType,
2259           /*numIndexOperands=*/op.getNumOperands() - 2)))
2260     return failure();
2261 
2262   return success();
2263 }
2264 
/// Registers canonicalization patterns that compose the store's map and
/// operands (SimplifyAffineOp).
void AffineStoreOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyAffineOp<AffineStoreOp>>(context);
}
2269 
/// Folds store(memref_cast(x)) into a store directly to x; success() signals
/// the op was updated in place (stores produce no values to fold).
LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
                                  SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return foldMemRefCast(*this);
}
2275 
2276 //===----------------------------------------------------------------------===//
2277 // AffineMinMaxOpBase
2278 //===----------------------------------------------------------------------===//
2279 
2280 template <typename T>
verifyAffineMinMaxOp(T op)2281 static LogicalResult verifyAffineMinMaxOp(T op) {
2282   // Verify that operand count matches affine map dimension and symbol count.
2283   if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
2284     return op.emitOpError(
2285         "operand count and affine map dimension and symbol count must match");
2286   return success();
2287 }
2288 
/// Prints an affine min/max op as: op-name map (dims)[symbols] {attrs},
/// with the symbol bracket elided when there are no symbol operands.
template <typename T>
static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
  p << op.getOperationName() << ' ' << op.getAttr(T::getMapAttrName());
  auto operands = op.getOperands();
  // Operands are ordered dims-then-symbols, matching the map's inputs.
  unsigned numDims = op.map().getNumDims();
  p << '(' << operands.take_front(numDims) << ')';

  if (operands.size() != numDims)
    p << '[' << operands.drop_front(numDims) << ']';
  p.printOptionalAttrDict(op.getAttrs(),
                          /*elidedAttrs=*/{T::getMapAttrName()});
}
2301 
2302 template <typename T>
parseAffineMinMaxOp(OpAsmParser & parser,OperationState & result)2303 static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
2304                                        OperationState &result) {
2305   auto &builder = parser.getBuilder();
2306   auto indexType = builder.getIndexType();
2307   SmallVector<OpAsmParser::OperandType, 8> dim_infos;
2308   SmallVector<OpAsmParser::OperandType, 8> sym_infos;
2309   AffineMapAttr mapAttr;
2310   return failure(
2311       parser.parseAttribute(mapAttr, T::getMapAttrName(), result.attributes) ||
2312       parser.parseOperandList(dim_infos, OpAsmParser::Delimiter::Paren) ||
2313       parser.parseOperandList(sym_infos,
2314                               OpAsmParser::Delimiter::OptionalSquare) ||
2315       parser.parseOptionalAttrDict(result.attributes) ||
2316       parser.resolveOperands(dim_infos, indexType, result.operands) ||
2317       parser.resolveOperands(sym_infos, indexType, result.operands) ||
2318       parser.addTypeToList(indexType, result.types));
2319 }
2320 
2321 /// Fold an affine min or max operation with the given operands. The operand
2322 /// list may contain nulls, which are interpreted as the operand not being a
2323 /// constant.
2324 template <typename T>
foldMinMaxOp(T op,ArrayRef<Attribute> operands)2325 static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
2326   static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
2327                 "expected affine min or max op");
2328 
2329   // Fold the affine map.
2330   // TODO: Fold more cases:
2331   // min(some_affine, some_affine + constant, ...), etc.
2332   SmallVector<int64_t, 2> results;
2333   auto foldedMap = op.map().partialConstantFold(operands, &results);
2334 
2335   // If some of the map results are not constant, try changing the map in-place.
2336   if (results.empty()) {
2337     // If the map is the same, report that folding did not happen.
2338     if (foldedMap == op.map())
2339       return {};
2340     op.setAttr("map", AffineMapAttr::get(foldedMap));
2341     return op.getResult();
2342   }
2343 
2344   // Otherwise, completely fold the op into a constant.
2345   auto resultIt = std::is_same<T, AffineMinOp>::value
2346                       ? std::min_element(results.begin(), results.end())
2347                       : std::max_element(results.begin(), results.end());
2348   if (resultIt == results.end())
2349     return {};
2350   return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
2351 }
2352 
2353 //===----------------------------------------------------------------------===//
2354 // AffineMinOp
2355 //===----------------------------------------------------------------------===//
2356 //
2357 //   %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
2358 //
2359 
/// Folds affine.min via the shared min/max folder.
OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
  return foldMinMaxOp(*this, operands);
}
2363 
/// Registers canonicalization patterns that compose affine.min's map and
/// operands (SimplifyAffineOp).
void AffineMinOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<SimplifyAffineOp<AffineMinOp>>(context);
}
2368 
2369 //===----------------------------------------------------------------------===//
2370 // AffineMaxOp
2371 //===----------------------------------------------------------------------===//
2372 //
2373 //   %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
2374 //
2375 
/// Folds affine.max via the shared min/max folder.
OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
  return foldMinMaxOp(*this, operands);
}
2379 
/// Registers canonicalization patterns that compose affine.max's map and
/// operands (SimplifyAffineOp).
void AffineMaxOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<SimplifyAffineOp<AffineMaxOp>>(context);
}
2384 
2385 //===----------------------------------------------------------------------===//
2386 // AffinePrefetchOp
2387 //===----------------------------------------------------------------------===//
2388 
2389 //
2390 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
2391 //
/// Parses an affine.prefetch of the form:
///   affine.prefetch %m[map(ids)], read|write, locality<N>, data|instr : type
/// The rw specifier and cache-type keywords are validated after parsing and
/// stored as boolean attributes (isWrite / isDataCache).
static ParseResult parseAffinePrefetchOp(OpAsmParser &parser,
                                         OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  OpAsmParser::OperandType memrefInfo;
  IntegerAttr hintInfo;
  // The locality hint is stored as an i32 attribute.
  auto i32Type = parser.getBuilder().getIntegerType(32);
  StringRef readOrWrite, cacheType;

  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
  if (parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffinePrefetchOp::getMapAttrName(),
                                    result.attributes) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(hintInfo, i32Type,
                            AffinePrefetchOp::getLocalityHintAttrName(),
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands))
    return failure();

  // Only 'read' or 'write' are legal rw specifiers.
  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(
      AffinePrefetchOp::getIsWriteAttrName(),
      parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));

  // Only 'data' or 'instr' are legal cache types.
  if (!cacheType.equals("data") && !cacheType.equals("instr"))
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(
      AffinePrefetchOp::getIsDataCacheAttrName(),
      parser.getBuilder().getBoolAttr(cacheType.equals("data")));

  return success();
}
2440 
print(OpAsmPrinter & p,AffinePrefetchOp op)2441 static void print(OpAsmPrinter &p, AffinePrefetchOp op) {
2442   p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '[';
2443   AffineMapAttr mapAttr = op->getAttrOfType<AffineMapAttr>(op.getMapAttrName());
2444   if (mapAttr) {
2445     SmallVector<Value, 2> operands(op.getMapOperands());
2446     p.printAffineMapOfSSAIds(mapAttr, operands);
2447   }
2448   p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", "
2449     << "locality<" << op.localityHint() << ">, "
2450     << (op.isDataCache() ? "data" : "instr");
2451   p.printOptionalAttrDict(
2452       op.getAttrs(),
2453       /*elidedAttrs=*/{op.getMapAttrName(), op.getLocalityHintAttrName(),
2454                        op.getIsDataCacheAttrName(), op.getIsWriteAttrName()});
2455   p << " : " << op.getMemRefType();
2456 }
2457 
/// Verifies an affine.prefetch: map results must match the memref rank, the
/// operand count must match the map inputs plus the memref, and every index
/// must be a valid affine dim or symbol in the enclosing affine scope.
static LogicalResult verify(AffinePrefetchOp op) {
  auto mapAttr = op->getAttrOfType<AffineMapAttr>(op.getMapAttrName());
  if (mapAttr) {
    AffineMap map = mapAttr.getValue();
    if (map.getNumResults() != op.getMemRefType().getRank())
      return op.emitOpError("affine.prefetch affine map num results must equal"
                            " memref rank");
    if (map.getNumInputs() + 1 != op.getNumOperands())
      return op.emitOpError("too few operands");
  } else {
    // Without a map, the memref itself must be the only operand.
    if (op.getNumOperands() != 1)
      return op.emitOpError("too few operands");
  }

  Region *scope = getAffineScope(op);
  for (auto idx : op.getMapOperands()) {
    if (!isValidAffineIndexOperand(idx, scope))
      return op.emitOpError("index must be a dimension or symbol identifier");
  }
  return success();
}
2479 
/// Registers canonicalization patterns that compose the prefetch's map and
/// operands (SimplifyAffineOp).
void AffinePrefetchOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  // prefetch(memrefcast) -> prefetch
  results.insert<SimplifyAffineOp<AffinePrefetchOp>>(context);
}
2485 
/// Folds prefetch(memref_cast(x)) into a prefetch directly on x; success()
/// signals the op was updated in place.
LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}
2491 
2492 //===----------------------------------------------------------------------===//
2493 // AffineParallelOp
2494 //===----------------------------------------------------------------------===//
2495 
/// Builds an affine.parallel over constant ranges: lower bounds are all zero
/// and upper bounds are the given constant `ranges`, with unit steps.
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
                             TypeRange resultTypes,
                             ArrayRef<AtomicRMWKind> reductions,
                             ArrayRef<int64_t> ranges) {
  // Lower bound: (0, 0, ..., 0) with no dims or symbols.
  SmallVector<AffineExpr, 8> lbExprs(ranges.size(),
                                     builder.getAffineConstantExpr(0));
  auto lbMap = AffineMap::get(0, 0, lbExprs, builder.getContext());
  // Upper bound: one constant expression per range.
  SmallVector<AffineExpr, 8> ubExprs;
  for (int64_t range : ranges)
    ubExprs.push_back(builder.getAffineConstantExpr(range));
  auto ubMap = AffineMap::get(0, 0, ubExprs, builder.getContext());
  build(builder, result, resultTypes, reductions, lbMap, /*lbArgs=*/{}, ubMap,
        /*ubArgs=*/{});
}
2510 
/// Builds an affine.parallel from lower/upper bound maps and their operands,
/// defaulting every step to 1.
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
                             TypeRange resultTypes,
                             ArrayRef<AtomicRMWKind> reductions,
                             AffineMap lbMap, ValueRange lbArgs,
                             AffineMap ubMap, ValueRange ubArgs) {
  auto numDims = lbMap.getNumResults();
  // Verify that the dimensionality of both maps are the same.
  assert(numDims == ubMap.getNumResults() &&
         "num dims and num results mismatch");
  // Make default step sizes of 1.
  SmallVector<int64_t, 8> steps(numDims, 1);
  build(builder, result, resultTypes, reductions, lbMap, lbArgs, ubMap, ubArgs,
        steps);
}
2525 
/// Builds an affine.parallel from bound maps, bound operands, and explicit
/// steps. Creates the body block with one index argument per dimension; a
/// terminator is auto-inserted only when the op yields no values.
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
                             TypeRange resultTypes,
                             ArrayRef<AtomicRMWKind> reductions,
                             AffineMap lbMap, ValueRange lbArgs,
                             AffineMap ubMap, ValueRange ubArgs,
                             ArrayRef<int64_t> steps) {
  auto numDims = lbMap.getNumResults();
  // Verify that the dimensionality of the maps matches the number of steps.
  assert(numDims == ubMap.getNumResults() &&
         "num dims and num results mismatch");
  assert(numDims == steps.size() && "num dims and num steps mismatch");

  result.addTypes(resultTypes);
  // Convert the reductions to integer attributes.
  SmallVector<Attribute, 4> reductionAttrs;
  for (AtomicRMWKind reduction : reductions)
    reductionAttrs.push_back(
        builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
  result.addAttribute(getReductionsAttrName(),
                      builder.getArrayAttr(reductionAttrs));
  result.addAttribute(getLowerBoundsMapAttrName(), AffineMapAttr::get(lbMap));
  result.addAttribute(getUpperBoundsMapAttrName(), AffineMapAttr::get(ubMap));
  result.addAttribute(getStepsAttrName(), builder.getI64ArrayAttr(steps));
  // Operand order is lower-bound args followed by upper-bound args.
  result.addOperands(lbArgs);
  result.addOperands(ubArgs);
  // Create a region and a block for the body.
  auto bodyRegion = result.addRegion();
  auto body = new Block();
  // Add all the block arguments.
  for (unsigned i = 0; i < numDims; ++i)
    body->addArgument(IndexType::get(builder.getContext()));
  bodyRegion->push_back(body);
  if (resultTypes.empty())
    ensureTerminator(*bodyRegion, builder, result.location);
}
2561 
/// Returns the loop body region (LoopLikeOpInterface hook).
Region &AffineParallelOp::getLoopBody() { return region(); }
2563 
/// Returns true if `value` is defined outside this op's body region
/// (LoopLikeOpInterface hook).
bool AffineParallelOp::isDefinedOutsideOfLoop(Value value) {
  return !region().isAncestor(value.getParentRegion());
}
2567 
/// Hoists `ops` immediately before this op (LoopLikeOpInterface hook).
/// Callers are responsible for ensuring the ops are safe to hoist.
LogicalResult AffineParallelOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
  for (Operation *op : ops)
    op->moveBefore(*this);
  return success();
}
2573 
/// Returns the number of parallel dimensions (one step per dimension).
unsigned AffineParallelOp::getNumDims() { return steps().size(); }
2575 
/// Returns the operands feeding the lower-bounds map; they occupy the front
/// of the operand list.
AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
  return getOperands().take_front(lowerBoundsMap().getNumInputs());
}
2579 
/// Returns the operands feeding the upper-bounds map; they follow the
/// lower-bound operands in the operand list.
AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
  return getOperands().drop_front(lowerBoundsMap().getNumInputs());
}
2583 
/// Returns the lower bounds as a map bundled with its operands.
AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
  return AffineValueMap(lowerBoundsMap(), getLowerBoundsOperands());
}
2587 
/// Returns the upper bounds as a map bundled with its operands.
AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
  return AffineValueMap(upperBoundsMap(), getUpperBoundsOperands());
}
2591 
/// Returns the per-dimension trip counts as an affine value map computed as
/// (upper bounds) - (lower bounds).
AffineValueMap AffineParallelOp::getRangesValueMap() {
  AffineValueMap out;
  AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
                             &out);
  return out;
}
2598 
getConstantRanges()2599 Optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
2600   // Try to convert all the ranges to constant expressions.
2601   SmallVector<int64_t, 8> out;
2602   AffineValueMap rangesValueMap = getRangesValueMap();
2603   out.reserve(rangesValueMap.getNumResults());
2604   for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
2605     auto expr = rangesValueMap.getResult(i);
2606     auto cst = expr.dyn_cast<AffineConstantExpr>();
2607     if (!cst)
2608       return llvm::None;
2609     out.push_back(cst.getValue());
2610   }
2611   return out;
2612 }
2613 
/// Returns the single block forming the loop body.
Block *AffineParallelOp::getBody() { return &region().front(); }
2615 
/// Returns a builder inserting just before the body's terminator.
OpBuilder AffineParallelOp::getBodyBuilder() {
  return OpBuilder(getBody(), std::prev(getBody()->end()));
}
2619 
/// Replaces the lower-bounds map and its operands. The full operand list is
/// rebuilt as (new lb operands) + (existing ub operands) since lb operands
/// occupy the front of the operand list.
void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
  assert(lbOperands.size() == map.getNumInputs() &&
         "operands to map must match number of inputs");
  assert(map.getNumResults() >= 1 && "bounds map has at least one result");

  // Snapshot the ub operands before the operand list is overwritten.
  auto ubOperands = getUpperBoundsOperands();

  SmallVector<Value, 4> newOperands(lbOperands);
  newOperands.append(ubOperands.begin(), ubOperands.end());
  (*this)->setOperands(newOperands);

  lowerBoundsMapAttr(AffineMapAttr::get(map));
}
2633 
/// Replaces the upper-bounds map and its operands. The full operand list is
/// rebuilt as (existing lb operands) + (new ub operands).
void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
  assert(ubOperands.size() == map.getNumInputs() &&
         "operands to map must match number of inputs");
  assert(map.getNumResults() >= 1 && "bounds map has at least one result");

  SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
  newOperands.append(ubOperands.begin(), ubOperands.end());
  (*this)->setOperands(newOperands);

  upperBoundsMapAttr(AffineMapAttr::get(map));
}
2645 
/// Replaces only the lower-bounds map; the new map must consume the same
/// dim/symbol operand counts so the existing operands remain valid.
void AffineParallelOp::setLowerBoundsMap(AffineMap map) {
  AffineMap lbMap = lowerBoundsMap();
  assert(lbMap.getNumDims() == map.getNumDims() &&
         lbMap.getNumSymbols() == map.getNumSymbols());
  (void)lbMap; // Only used by the assertion in release builds.
  lowerBoundsMapAttr(AffineMapAttr::get(map));
}
2653 
setUpperBoundsMap(AffineMap map)2654 void AffineParallelOp::setUpperBoundsMap(AffineMap map) {
2655   AffineMap ubMap = upperBoundsMap();
2656   assert(ubMap.getNumDims() == map.getNumDims() &&
2657          ubMap.getNumSymbols() == map.getNumSymbols());
2658   (void)ubMap;
2659   upperBoundsMapAttr(AffineMapAttr::get(map));
2660 }
2661 
getSteps()2662 SmallVector<int64_t, 8> AffineParallelOp::getSteps() {
2663   SmallVector<int64_t, 8> result;
2664   for (Attribute attr : steps()) {
2665     result.push_back(attr.cast<IntegerAttr>().getInt());
2666   }
2667   return result;
2668 }
2669 
setSteps(ArrayRef<int64_t> newSteps)2670 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
2671   stepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
2672 }
2673 
verify(AffineParallelOp op)2674 static LogicalResult verify(AffineParallelOp op) {
2675   auto numDims = op.getNumDims();
2676   if (op.lowerBoundsMap().getNumResults() != numDims ||
2677       op.upperBoundsMap().getNumResults() != numDims ||
2678       op.steps().size() != numDims ||
2679       op.getBody()->getNumArguments() != numDims)
2680     return op.emitOpError("region argument count and num results of upper "
2681                           "bounds, lower bounds, and steps must all match");
2682 
2683   if (op.reductions().size() != op.getNumResults())
2684     return op.emitOpError("a reduction must be specified for each output");
2685 
2686   // Verify reduction  ops are all valid
2687   for (Attribute attr : op.reductions()) {
2688     auto intAttr = attr.dyn_cast<IntegerAttr>();
2689     if (!intAttr || !symbolizeAtomicRMWKind(intAttr.getInt()))
2690       return op.emitOpError("invalid reduction attribute");
2691   }
2692 
2693   // Verify that the bound operands are valid dimension/symbols.
2694   /// Lower bounds.
2695   if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(),
2696                                            op.lowerBoundsMap().getNumDims())))
2697     return failure();
2698   /// Upper bounds.
2699   if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundsOperands(),
2700                                            op.upperBoundsMap().getNumDims())))
2701     return failure();
2702   return success();
2703 }
2704 
canonicalize()2705 LogicalResult AffineValueMap::canonicalize() {
2706   SmallVector<Value, 4> newOperands{operands};
2707   auto newMap = getAffineMap();
2708   composeAffineMapAndOperands(&newMap, &newOperands);
2709   if (newMap == getAffineMap() && newOperands == operands)
2710     return failure();
2711   reset(newMap, newOperands);
2712   return success();
2713 }
2714 
2715 /// Canonicalize the bounds of the given loop.
canonicalizeLoopBounds(AffineParallelOp op)2716 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
2717   AffineValueMap lb = op.getLowerBoundsValueMap();
2718   bool lbCanonicalized = succeeded(lb.canonicalize());
2719 
2720   AffineValueMap ub = op.getUpperBoundsValueMap();
2721   bool ubCanonicalized = succeeded(ub.canonicalize());
2722 
2723   // Any canonicalization change always leads to updated map(s).
2724   if (!lbCanonicalized && !ubCanonicalized)
2725     return failure();
2726 
2727   if (lbCanonicalized)
2728     op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
2729   if (ubCanonicalized)
2730     op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
2731 
2732   return success();
2733 }
2734 
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)2735 LogicalResult AffineParallelOp::fold(ArrayRef<Attribute> operands,
2736                                      SmallVectorImpl<OpFoldResult> &results) {
2737   return canonicalizeLoopBounds(*this);
2738 }
2739 
print(OpAsmPrinter & p,AffineParallelOp op)2740 static void print(OpAsmPrinter &p, AffineParallelOp op) {
2741   p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = (";
2742   p.printAffineMapOfSSAIds(op.lowerBoundsMapAttr(),
2743                            op.getLowerBoundsOperands());
2744   p << ") to (";
2745   p.printAffineMapOfSSAIds(op.upperBoundsMapAttr(),
2746                            op.getUpperBoundsOperands());
2747   p << ')';
2748   SmallVector<int64_t, 8> steps = op.getSteps();
2749   bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
2750   if (!elideSteps) {
2751     p << " step (";
2752     llvm::interleaveComma(steps, p);
2753     p << ')';
2754   }
2755   if (op.getNumResults()) {
2756     p << " reduce (";
2757     llvm::interleaveComma(op.reductions(), p, [&](auto &attr) {
2758       AtomicRMWKind sym =
2759           *symbolizeAtomicRMWKind(attr.template cast<IntegerAttr>().getInt());
2760       p << "\"" << stringifyAtomicRMWKind(sym) << "\"";
2761     });
2762     p << ") -> (" << op.getResultTypes() << ")";
2763   }
2764 
2765   p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
2766                 /*printBlockTerminators=*/op.getNumResults());
2767   p.printOptionalAttrDict(
2768       op.getAttrs(),
2769       /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrName(),
2770                        AffineParallelOp::getLowerBoundsMapAttrName(),
2771                        AffineParallelOp::getUpperBoundsMapAttrName(),
2772                        AffineParallelOp::getStepsAttrName()});
2773 }
2774 
//
// operation ::= `affine.parallel` `(` ssa-ids `)` `=` `(` map-of-ssa-ids `)`
//               `to` `(` map-of-ssa-ids `)` steps? region attr-dict?
// steps     ::= `step` `(` integer-literals `)`
//
parseAffineParallelOp(OpAsmParser & parser,OperationState & result)2780 static ParseResult parseAffineParallelOp(OpAsmParser &parser,
2781                                          OperationState &result) {
2782   auto &builder = parser.getBuilder();
2783   auto indexType = builder.getIndexType();
2784   AffineMapAttr lowerBoundsAttr, upperBoundsAttr;
2785   SmallVector<OpAsmParser::OperandType, 4> ivs;
2786   SmallVector<OpAsmParser::OperandType, 4> lowerBoundsMapOperands;
2787   SmallVector<OpAsmParser::OperandType, 4> upperBoundsMapOperands;
2788   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
2789                                      OpAsmParser::Delimiter::Paren) ||
2790       parser.parseEqual() ||
2791       parser.parseAffineMapOfSSAIds(
2792           lowerBoundsMapOperands, lowerBoundsAttr,
2793           AffineParallelOp::getLowerBoundsMapAttrName(), result.attributes,
2794           OpAsmParser::Delimiter::Paren) ||
2795       parser.resolveOperands(lowerBoundsMapOperands, indexType,
2796                              result.operands) ||
2797       parser.parseKeyword("to") ||
2798       parser.parseAffineMapOfSSAIds(
2799           upperBoundsMapOperands, upperBoundsAttr,
2800           AffineParallelOp::getUpperBoundsMapAttrName(), result.attributes,
2801           OpAsmParser::Delimiter::Paren) ||
2802       parser.resolveOperands(upperBoundsMapOperands, indexType,
2803                              result.operands))
2804     return failure();
2805 
2806   AffineMapAttr stepsMapAttr;
2807   NamedAttrList stepsAttrs;
2808   SmallVector<OpAsmParser::OperandType, 4> stepsMapOperands;
2809   if (failed(parser.parseOptionalKeyword("step"))) {
2810     SmallVector<int64_t, 4> steps(ivs.size(), 1);
2811     result.addAttribute(AffineParallelOp::getStepsAttrName(),
2812                         builder.getI64ArrayAttr(steps));
2813   } else {
2814     if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
2815                                       AffineParallelOp::getStepsAttrName(),
2816                                       stepsAttrs,
2817                                       OpAsmParser::Delimiter::Paren))
2818       return failure();
2819 
2820     // Convert steps from an AffineMap into an I64ArrayAttr.
2821     SmallVector<int64_t, 4> steps;
2822     auto stepsMap = stepsMapAttr.getValue();
2823     for (const auto &result : stepsMap.getResults()) {
2824       auto constExpr = result.dyn_cast<AffineConstantExpr>();
2825       if (!constExpr)
2826         return parser.emitError(parser.getNameLoc(),
2827                                 "steps must be constant integers");
2828       steps.push_back(constExpr.getValue());
2829     }
2830     result.addAttribute(AffineParallelOp::getStepsAttrName(),
2831                         builder.getI64ArrayAttr(steps));
2832   }
2833 
2834   // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
2835   // quoted strings are a member of the enum AtomicRMWKind.
2836   SmallVector<Attribute, 4> reductions;
2837   if (succeeded(parser.parseOptionalKeyword("reduce"))) {
2838     if (parser.parseLParen())
2839       return failure();
2840     do {
2841       // Parse a single quoted string via the attribute parsing, and then
2842       // verify it is a member of the enum and convert to it's integer
2843       // representation.
2844       StringAttr attrVal;
2845       NamedAttrList attrStorage;
2846       auto loc = parser.getCurrentLocation();
2847       if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
2848                                 attrStorage))
2849         return failure();
2850       llvm::Optional<AtomicRMWKind> reduction =
2851           symbolizeAtomicRMWKind(attrVal.getValue());
2852       if (!reduction)
2853         return parser.emitError(loc, "invalid reduction value: ") << attrVal;
2854       reductions.push_back(builder.getI64IntegerAttr(
2855           static_cast<int64_t>(reduction.getValue())));
2856       // While we keep getting commas, keep parsing.
2857     } while (succeeded(parser.parseOptionalComma()));
2858     if (parser.parseRParen())
2859       return failure();
2860   }
2861   result.addAttribute(AffineParallelOp::getReductionsAttrName(),
2862                       builder.getArrayAttr(reductions));
2863 
2864   // Parse return types of reductions (if any)
2865   if (parser.parseOptionalArrowTypeList(result.types))
2866     return failure();
2867 
2868   // Now parse the body.
2869   Region *body = result.addRegion();
2870   SmallVector<Type, 4> types(ivs.size(), indexType);
2871   if (parser.parseRegion(*body, ivs, types) ||
2872       parser.parseOptionalAttrDict(result.attributes))
2873     return failure();
2874 
2875   // Add a terminator if none was parsed.
2876   AffineParallelOp::ensureTerminator(*body, builder, result.location);
2877   return success();
2878 }
2879 
2880 //===----------------------------------------------------------------------===//
2881 // AffineYieldOp
2882 //===----------------------------------------------------------------------===//
2883 
verify(AffineYieldOp op)2884 static LogicalResult verify(AffineYieldOp op) {
2885   auto *parentOp = op->getParentOp();
2886   auto results = parentOp->getResults();
2887   auto operands = op.getOperands();
2888 
2889   if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
2890     return op.emitOpError() << "only terminates affine.if/for/parallel regions";
2891   if (parentOp->getNumResults() != op.getNumOperands())
2892     return op.emitOpError() << "parent of yield must have same number of "
2893                                "results as the yield operands";
2894   for (auto it : llvm::zip(results, operands)) {
2895     if (std::get<0>(it).getType() != std::get<1>(it).getType())
2896       return op.emitOpError()
2897              << "types mismatch between yield op and its parent";
2898   }
2899 
2900   return success();
2901 }
2902 
2903 //===----------------------------------------------------------------------===//
2904 // AffineVectorLoadOp
2905 //===----------------------------------------------------------------------===//
2906 
build(OpBuilder & builder,OperationState & result,VectorType resultType,AffineMap map,ValueRange operands)2907 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
2908                                VectorType resultType, AffineMap map,
2909                                ValueRange operands) {
2910   assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
2911   result.addOperands(operands);
2912   if (map)
2913     result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2914   result.types.push_back(resultType);
2915 }
2916 
build(OpBuilder & builder,OperationState & result,VectorType resultType,Value memref,AffineMap map,ValueRange mapOperands)2917 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
2918                                VectorType resultType, Value memref,
2919                                AffineMap map, ValueRange mapOperands) {
2920   assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2921   result.addOperands(memref);
2922   result.addOperands(mapOperands);
2923   result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2924   result.types.push_back(resultType);
2925 }
2926 
build(OpBuilder & builder,OperationState & result,VectorType resultType,Value memref,ValueRange indices)2927 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
2928                                VectorType resultType, Value memref,
2929                                ValueRange indices) {
2930   auto memrefType = memref.getType().cast<MemRefType>();
2931   int64_t rank = memrefType.getRank();
2932   // Create identity map for memrefs with at least one dimension or () -> ()
2933   // for zero-dimensional memrefs.
2934   auto map =
2935       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2936   build(builder, result, resultType, memref, map, indices);
2937 }
2938 
parseAffineVectorLoadOp(OpAsmParser & parser,OperationState & result)2939 static ParseResult parseAffineVectorLoadOp(OpAsmParser &parser,
2940                                            OperationState &result) {
2941   auto &builder = parser.getBuilder();
2942   auto indexTy = builder.getIndexType();
2943 
2944   MemRefType memrefType;
2945   VectorType resultType;
2946   OpAsmParser::OperandType memrefInfo;
2947   AffineMapAttr mapAttr;
2948   SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2949   return failure(
2950       parser.parseOperand(memrefInfo) ||
2951       parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2952                                     AffineVectorLoadOp::getMapAttrName(),
2953                                     result.attributes) ||
2954       parser.parseOptionalAttrDict(result.attributes) ||
2955       parser.parseColonType(memrefType) || parser.parseComma() ||
2956       parser.parseType(resultType) ||
2957       parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
2958       parser.resolveOperands(mapOperands, indexTy, result.operands) ||
2959       parser.addTypeToList(resultType, result.types));
2960 }
2961 
print(OpAsmPrinter & p,AffineVectorLoadOp op)2962 static void print(OpAsmPrinter &p, AffineVectorLoadOp op) {
2963   p << "affine.vector_load " << op.getMemRef() << '[';
2964   if (AffineMapAttr mapAttr =
2965           op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
2966     p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
2967   p << ']';
2968   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
2969   p << " : " << op.getMemRefType() << ", " << op.getType();
2970 }
2971 
2972 /// Verify common invariants of affine.vector_load and affine.vector_store.
verifyVectorMemoryOp(Operation * op,MemRefType memrefType,VectorType vectorType)2973 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
2974                                           VectorType vectorType) {
2975   // Check that memref and vector element types match.
2976   if (memrefType.getElementType() != vectorType.getElementType())
2977     return op->emitOpError(
2978         "requires memref and vector types of the same elemental type");
2979   return success();
2980 }
2981 
verify(AffineVectorLoadOp op)2982 static LogicalResult verify(AffineVectorLoadOp op) {
2983   MemRefType memrefType = op.getMemRefType();
2984   if (failed(verifyMemoryOpIndexing(
2985           op.getOperation(),
2986           op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
2987           op.getMapOperands(), memrefType,
2988           /*numIndexOperands=*/op.getNumOperands() - 1)))
2989     return failure();
2990 
2991   if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
2992                                   op.getVectorType())))
2993     return failure();
2994 
2995   return success();
2996 }
2997 
2998 //===----------------------------------------------------------------------===//
2999 // AffineVectorStoreOp
3000 //===----------------------------------------------------------------------===//
3001 
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,AffineMap map,ValueRange mapOperands)3002 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
3003                                 Value valueToStore, Value memref, AffineMap map,
3004                                 ValueRange mapOperands) {
3005   assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3006   result.addOperands(valueToStore);
3007   result.addOperands(memref);
3008   result.addOperands(mapOperands);
3009   result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
3010 }
3011 
3012 // Use identity map.
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,ValueRange indices)3013 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
3014                                 Value valueToStore, Value memref,
3015                                 ValueRange indices) {
3016   auto memrefType = memref.getType().cast<MemRefType>();
3017   int64_t rank = memrefType.getRank();
3018   // Create identity map for memrefs with at least one dimension or () -> ()
3019   // for zero-dimensional memrefs.
3020   auto map =
3021       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3022   build(builder, result, valueToStore, memref, map, indices);
3023 }
3024 
parseAffineVectorStoreOp(OpAsmParser & parser,OperationState & result)3025 static ParseResult parseAffineVectorStoreOp(OpAsmParser &parser,
3026                                             OperationState &result) {
3027   auto indexTy = parser.getBuilder().getIndexType();
3028 
3029   MemRefType memrefType;
3030   VectorType resultType;
3031   OpAsmParser::OperandType storeValueInfo;
3032   OpAsmParser::OperandType memrefInfo;
3033   AffineMapAttr mapAttr;
3034   SmallVector<OpAsmParser::OperandType, 1> mapOperands;
3035   return failure(
3036       parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3037       parser.parseOperand(memrefInfo) ||
3038       parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3039                                     AffineVectorStoreOp::getMapAttrName(),
3040                                     result.attributes) ||
3041       parser.parseOptionalAttrDict(result.attributes) ||
3042       parser.parseColonType(memrefType) || parser.parseComma() ||
3043       parser.parseType(resultType) ||
3044       parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
3045       parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
3046       parser.resolveOperands(mapOperands, indexTy, result.operands));
3047 }
3048 
print(OpAsmPrinter & p,AffineVectorStoreOp op)3049 static void print(OpAsmPrinter &p, AffineVectorStoreOp op) {
3050   p << "affine.vector_store " << op.getValueToStore();
3051   p << ", " << op.getMemRef() << '[';
3052   if (AffineMapAttr mapAttr =
3053           op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
3054     p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
3055   p << ']';
3056   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
3057   p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType();
3058 }
3059 
verify(AffineVectorStoreOp op)3060 static LogicalResult verify(AffineVectorStoreOp op) {
3061   MemRefType memrefType = op.getMemRefType();
3062   if (failed(verifyMemoryOpIndexing(
3063           op.getOperation(),
3064           op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
3065           op.getMapOperands(), memrefType,
3066           /*numIndexOperands=*/op.getNumOperands() - 2)))
3067     return failure();
3068 
3069   if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
3070                                   op.getVectorType())))
3071     return failure();
3072 
3073   return success();
3074 }
3075 
3076 //===----------------------------------------------------------------------===//
3077 // TableGen'd op method definitions
3078 //===----------------------------------------------------------------------===//
3079 
3080 #define GET_OP_CLASSES
3081 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
3082