1 //===- AffineOps.h - MLIR Affine Operations -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines convenience types for working with Affine operations 10 // in the MLIR operation set. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H 15 #define MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H 16 17 #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h" 18 #include "mlir/Dialect/StandardOps/IR/Ops.h" 19 #include "mlir/IR/AffineMap.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/BuiltinTypes.h" 22 #include "mlir/IR/Dialect.h" 23 #include "mlir/IR/OpDefinition.h" 24 #include "mlir/Interfaces/LoopLikeInterface.h" 25 #include "mlir/Interfaces/SideEffectInterfaces.h" 26 27 namespace mlir { 28 class AffineApplyOp; 29 class AffineBound; 30 class AffineDimExpr; 31 class AffineValueMap; 32 class AffineYieldOp; 33 class FlatAffineConstraints; 34 class OpBuilder; 35 36 /// A utility function to check if a value is defined at the top level of an 37 /// op with trait `AffineScope` or is a region argument for such an op. A value 38 /// of index type defined at the top level is always a valid symbol for all its 39 /// uses. 40 bool isTopLevelValue(Value value); 41 42 /// AffineDmaStartOp starts a non-blocking DMA operation that transfers data 43 /// from a source memref to a destination memref. The source and destination 44 /// memref need not be of the same dimensionality, but need to have the same 45 /// elemental type. The operands include the source and destination memref's 46 /// each followed by its indices, size of the data transfer in terms of the 47 /// number of elements (of the elemental type of the memref), a tag memref with 48 /// its indices, and optionally at the end, a stride and a 49 /// number_of_elements_per_stride arguments. The tag location is used by an 50 /// AffineDmaWaitOp to check for completion. The indices of the source memref, 51 /// destination memref, and the tag memref have the same restrictions as any 52 /// affine.load/store. In particular, index for each memref dimension must be an 53 /// affine expression of loop induction variables and symbols. 54 /// The optional stride arguments should be of 'index' type, and specify a 55 /// stride for the slower memory space (memory space with a lower memory space 56 /// id), transferring chunks of number_of_elements_per_stride every stride until 57 /// %num_elements are transferred. Either both or no stride arguments should be 58 /// specified. The value of 'num_elements' must be a multiple of 59 /// 'number_of_elements_per_stride'. 60 // 61 // For example, a DmaStartOp operation that transfers 256 elements of a memref 62 // '%src' in memory space 0 at indices [%i + 3, %j] to memref '%dst' in memory 63 // space 1 at indices [%k + 7, %l], would be specified as follows: 64 // 65 // %num_elements = constant 256 66 // %idx = constant 0 : index 67 // %tag = alloc() : memref<1xi32, 4> 68 // affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx], 69 // %num_elements : 70 // memref<40x128xf32, 0>, memref<2x1024xf32, 1>, memref<1xi32, 2> 71 // 72 // If %stride and %num_elt_per_stride are specified, the DMA is expected to 73 // transfer %num_elt_per_stride elements every %stride elements apart from 74 // memory space 0 until %num_elements are transferred. 75 // 76 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%idx], %num_elements, 77 // %stride, %num_elt_per_stride : ... 78 // 79 // TODO: add additional operands to allow source and destination striding, and 80 // multiple stride levels (possibly using AffineMaps to specify multiple levels 81 // of striding). 82 // TODO: Consider replacing src/dst memref indices with view memrefs. 83 class AffineDmaStartOp 84 : public Op<AffineDmaStartOp, OpTrait::MemRefsNormalizable, 85 OpTrait::VariadicOperands, OpTrait::ZeroResult> { 86 public: 87 using Op::Op; 88 89 static void build(OpBuilder &builder, OperationState &result, Value srcMemRef, 90 AffineMap srcMap, ValueRange srcIndices, Value destMemRef, 91 AffineMap dstMap, ValueRange destIndices, Value tagMemRef, 92 AffineMap tagMap, ValueRange tagIndices, Value numElements, 93 Value stride = nullptr, Value elementsPerStride = nullptr); 94 95 /// Returns the operand index of the src memref. getSrcMemRefOperandIndex()96 unsigned getSrcMemRefOperandIndex() { return 0; } 97 98 /// Returns the source MemRefType for this DMA operation. getSrcMemRef()99 Value getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); } getSrcMemRefType()100 MemRefType getSrcMemRefType() { 101 return getSrcMemRef().getType().cast<MemRefType>(); 102 } 103 104 /// Returns the rank (number of indices) of the source MemRefType. getSrcMemRefRank()105 unsigned getSrcMemRefRank() { return getSrcMemRefType().getRank(); } 106 107 /// Returns the affine map used to access the src memref. getSrcMap()108 AffineMap getSrcMap() { return getSrcMapAttr().getValue(); } getSrcMapAttr()109 AffineMapAttr getSrcMapAttr() { 110 return getAttr(getSrcMapAttrName()).cast<AffineMapAttr>(); 111 } 112 113 /// Returns the source memref affine map indices for this DMA operation. getSrcIndices()114 operand_range getSrcIndices() { 115 return {operand_begin() + getSrcMemRefOperandIndex() + 1, 116 operand_begin() + getSrcMemRefOperandIndex() + 1 + 117 getSrcMap().getNumInputs()}; 118 } 119 120 /// Returns the memory space of the src memref. getSrcMemorySpace()121 unsigned getSrcMemorySpace() { 122 return getSrcMemRef().getType().cast<MemRefType>().getMemorySpace(); 123 } 124 125 /// Returns the operand index of the dst memref. getDstMemRefOperandIndex()126 unsigned getDstMemRefOperandIndex() { 127 return getSrcMemRefOperandIndex() + 1 + getSrcMap().getNumInputs(); 128 } 129 130 /// Returns the destination MemRefType for this DMA operations. getDstMemRef()131 Value getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); } getDstMemRefType()132 MemRefType getDstMemRefType() { 133 return getDstMemRef().getType().cast<MemRefType>(); 134 } 135 136 /// Returns the rank (number of indices) of the destination MemRefType. getDstMemRefRank()137 unsigned getDstMemRefRank() { 138 return getDstMemRef().getType().cast<MemRefType>().getRank(); 139 } 140 141 /// Returns the memory space of the src memref. getDstMemorySpace()142 unsigned getDstMemorySpace() { 143 return getDstMemRef().getType().cast<MemRefType>().getMemorySpace(); 144 } 145 146 /// Returns the affine map used to access the dst memref. getDstMap()147 AffineMap getDstMap() { return getDstMapAttr().getValue(); } getDstMapAttr()148 AffineMapAttr getDstMapAttr() { 149 return getAttr(getDstMapAttrName()).cast<AffineMapAttr>(); 150 } 151 152 /// Returns the destination memref indices for this DMA operation. getDstIndices()153 operand_range getDstIndices() { 154 return {operand_begin() + getDstMemRefOperandIndex() + 1, 155 operand_begin() + getDstMemRefOperandIndex() + 1 + 156 getDstMap().getNumInputs()}; 157 } 158 159 /// Returns the operand index of the tag memref. getTagMemRefOperandIndex()160 unsigned getTagMemRefOperandIndex() { 161 return getDstMemRefOperandIndex() + 1 + getDstMap().getNumInputs(); 162 } 163 164 /// Returns the Tag MemRef for this DMA operation. getTagMemRef()165 Value getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); } getTagMemRefType()166 MemRefType getTagMemRefType() { 167 return getTagMemRef().getType().cast<MemRefType>(); 168 } 169 170 /// Returns the rank (number of indices) of the tag MemRefType. getTagMemRefRank()171 unsigned getTagMemRefRank() { 172 return getTagMemRef().getType().cast<MemRefType>().getRank(); 173 } 174 175 /// Returns the affine map used to access the tag memref. getTagMap()176 AffineMap getTagMap() { return getTagMapAttr().getValue(); } getTagMapAttr()177 AffineMapAttr getTagMapAttr() { 178 return getAttr(getTagMapAttrName()).cast<AffineMapAttr>(); 179 } 180 181 /// Returns the tag memref indices for this DMA operation. getTagIndices()182 operand_range getTagIndices() { 183 return {operand_begin() + getTagMemRefOperandIndex() + 1, 184 operand_begin() + getTagMemRefOperandIndex() + 1 + 185 getTagMap().getNumInputs()}; 186 } 187 188 /// Returns the number of elements being transferred by this DMA operation. getNumElements()189 Value getNumElements() { 190 return getOperand(getTagMemRefOperandIndex() + 1 + 191 getTagMap().getNumInputs()); 192 } 193 194 /// Returns the AffineMapAttr associated with 'memref'. getAffineMapAttrForMemRef(Value memref)195 NamedAttribute getAffineMapAttrForMemRef(Value memref) { 196 if (memref == getSrcMemRef()) 197 return {Identifier::get(getSrcMapAttrName(), getContext()), 198 getSrcMapAttr()}; 199 else if (memref == getDstMemRef()) 200 return {Identifier::get(getDstMapAttrName(), getContext()), 201 getDstMapAttr()}; 202 assert(memref == getTagMemRef() && 203 "DmaStartOp expected source, destination or tag memref"); 204 return {Identifier::get(getTagMapAttrName(), getContext()), 205 getTagMapAttr()}; 206 } 207 208 /// Returns true if this is a DMA from a faster memory space to a slower one. isDestMemorySpaceFaster()209 bool isDestMemorySpaceFaster() { 210 return (getSrcMemorySpace() < getDstMemorySpace()); 211 } 212 213 /// Returns true if this is a DMA from a slower memory space to a faster one. isSrcMemorySpaceFaster()214 bool isSrcMemorySpaceFaster() { 215 // Assumes that a lower number is for a slower memory space. 216 return (getDstMemorySpace() < getSrcMemorySpace()); 217 } 218 219 /// Given a DMA start operation, returns the operand position of either the 220 /// source or destination memref depending on the one that is at the higher 221 /// level of the memory hierarchy. Asserts failure if neither is true. getFasterMemPos()222 unsigned getFasterMemPos() { 223 assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster()); 224 return isSrcMemorySpaceFaster() ? 0 : getDstMemRefOperandIndex(); 225 } 226 getSrcMapAttrName()227 static StringRef getSrcMapAttrName() { return "src_map"; } getDstMapAttrName()228 static StringRef getDstMapAttrName() { return "dst_map"; } getTagMapAttrName()229 static StringRef getTagMapAttrName() { return "tag_map"; } 230 getOperationName()231 static StringRef getOperationName() { return "affine.dma_start"; } 232 static ParseResult parse(OpAsmParser &parser, OperationState &result); 233 void print(OpAsmPrinter &p); 234 LogicalResult verify(); 235 LogicalResult fold(ArrayRef<Attribute> cstOperands, 236 SmallVectorImpl<OpFoldResult> &results); 237 238 /// Returns true if this DMA operation is strided, returns false otherwise. isStrided()239 bool isStrided() { 240 return getNumOperands() != 241 getTagMemRefOperandIndex() + 1 + getTagMap().getNumInputs() + 1; 242 } 243 244 /// Returns the stride value for this DMA operation. getStride()245 Value getStride() { 246 if (!isStrided()) 247 return nullptr; 248 return getOperand(getNumOperands() - 1 - 1); 249 } 250 251 /// Returns the number of elements to transfer per stride for this DMA op. getNumElementsPerStride()252 Value getNumElementsPerStride() { 253 if (!isStrided()) 254 return nullptr; 255 return getOperand(getNumOperands() - 1); 256 } 257 }; 258 259 /// AffineDmaWaitOp blocks until the completion of a DMA operation associated 260 /// with the tag element '%tag[%index]'. %tag is a memref, and %index has to be 261 /// an index with the same restrictions as any load/store index. In particular, 262 /// index for each memref dimension must be an affine expression of loop 263 /// induction variables and symbols. %num_elements is the number of elements 264 /// associated with the DMA operation. For example: 265 // 266 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %num_elements : 267 // memref<2048xf32, 0>, memref<256xf32, 1>, memref<1xi32, 2> 268 // ... 269 // ... 270 // affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2> 271 // 272 class AffineDmaWaitOp 273 : public Op<AffineDmaWaitOp, OpTrait::MemRefsNormalizable, 274 OpTrait::VariadicOperands, OpTrait::ZeroResult> { 275 public: 276 using Op::Op; 277 278 static void build(OpBuilder &builder, OperationState &result, Value tagMemRef, 279 AffineMap tagMap, ValueRange tagIndices, Value numElements); 280 getOperationName()281 static StringRef getOperationName() { return "affine.dma_wait"; } 282 283 // Returns the Tag MemRef associated with the DMA operation being waited on. getTagMemRef()284 Value getTagMemRef() { return getOperand(0); } getTagMemRefType()285 MemRefType getTagMemRefType() { 286 return getTagMemRef().getType().cast<MemRefType>(); 287 } 288 289 /// Returns the affine map used to access the tag memref. getTagMap()290 AffineMap getTagMap() { return getTagMapAttr().getValue(); } getTagMapAttr()291 AffineMapAttr getTagMapAttr() { 292 return getAttr(getTagMapAttrName()).cast<AffineMapAttr>(); 293 } 294 295 // Returns the tag memref index for this DMA operation. getTagIndices()296 operand_range getTagIndices() { 297 return {operand_begin() + 1, 298 operand_begin() + 1 + getTagMap().getNumInputs()}; 299 } 300 301 // Returns the rank (number of indices) of the tag memref. getTagMemRefRank()302 unsigned getTagMemRefRank() { 303 return getTagMemRef().getType().cast<MemRefType>().getRank(); 304 } 305 306 /// Returns the AffineMapAttr associated with 'memref'. getAffineMapAttrForMemRef(Value memref)307 NamedAttribute getAffineMapAttrForMemRef(Value memref) { 308 assert(memref == getTagMemRef()); 309 return {Identifier::get(getTagMapAttrName(), getContext()), 310 getTagMapAttr()}; 311 } 312 313 /// Returns the number of elements transferred in the associated DMA op. getNumElements()314 Value getNumElements() { return getOperand(1 + getTagMap().getNumInputs()); } 315 getTagMapAttrName()316 static StringRef getTagMapAttrName() { return "tag_map"; } 317 static ParseResult parse(OpAsmParser &parser, OperationState &result); 318 void print(OpAsmPrinter &p); 319 LogicalResult verify(); 320 LogicalResult fold(ArrayRef<Attribute> cstOperands, 321 SmallVectorImpl<OpFoldResult> &results); 322 }; 323 324 /// Returns true if the given Value can be used as a dimension id in the region 325 /// of the closest surrounding op that has the trait `AffineScope`. 326 bool isValidDim(Value value); 327 328 /// Returns true if the given Value can be used as a dimension id in `region`, 329 /// i.e., for all its uses in `region`. 330 bool isValidDim(Value value, Region *region); 331 332 /// Returns true if the given value can be used as a symbol in the region of the 333 /// closest surrounding op that has the trait `AffineScope`. 334 bool isValidSymbol(Value value); 335 336 /// Returns true if the given Value can be used as a symbol for `region`, i.e., 337 /// for all its uses in `region`. 338 bool isValidSymbol(Value value, Region *region); 339 340 /// Parses dimension and symbol list and returns true if parsing failed. 341 ParseResult parseDimAndSymbolList(OpAsmParser &parser, 342 SmallVectorImpl<Value> &operands, 343 unsigned &numDims); 344 345 /// Modifies both `map` and `operands` in-place so as to: 346 /// 1. drop duplicate operands 347 /// 2. drop unused dims and symbols from map 348 /// 3. promote valid symbols to symbolic operands in case they appeared as 349 /// dimensional operands 350 /// 4. propagate constant operands and drop them 351 void canonicalizeMapAndOperands(AffineMap *map, 352 SmallVectorImpl<Value> *operands); 353 354 /// Canonicalizes an integer set the same way canonicalizeMapAndOperands does 355 /// for affine maps. 356 void canonicalizeSetAndOperands(IntegerSet *set, 357 SmallVectorImpl<Value> *operands); 358 359 /// Returns a composed AffineApplyOp by composing `map` and `operands` with 360 /// other AffineApplyOps supplying those operands. The operands of the resulting 361 /// AffineApplyOp do not change the length of AffineApplyOp chains. 362 AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, 363 ArrayRef<Value> operands); 364 365 /// Given an affine map `map` and its input `operands`, this method composes 366 /// into `map`, maps of AffineApplyOps whose results are the values in 367 /// `operands`, iteratively until no more of `operands` are the result of an 368 /// AffineApplyOp. When this function returns, `map` becomes the composed affine 369 /// map, and each Value in `operands` is guaranteed to be either a loop IV or a 370 /// terminal symbol, i.e., a symbol defined at the top level or a block/function 371 /// argument. 372 void fullyComposeAffineMapAndOperands(AffineMap *map, 373 SmallVectorImpl<Value> *operands); 374 375 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.h.inc" 376 377 #define GET_OP_CLASSES 378 #include "mlir/Dialect/Affine/IR/AffineOps.h.inc" 379 380 /// Returns true if the provided value is the induction variable of a 381 /// AffineForOp. 382 bool isForInductionVar(Value val); 383 384 /// Returns the loop parent of an induction variable. If the provided value is 385 /// not an induction variable, then return nullptr. 386 AffineForOp getForInductionVarOwner(Value val); 387 388 /// Extracts the induction variables from a list of AffineForOps and places them 389 /// in the output argument `ivs`. 390 void extractForInductionVars(ArrayRef<AffineForOp> forInsts, 391 SmallVectorImpl<Value> *ivs); 392 393 /// Builds a perfect nest of affine "for" loops, i.e. each loop except the 394 /// innermost only contains another loop and a terminator. The loops iterate 395 /// from "lbs" to "ubs" with "steps". The body of the innermost loop is 396 /// populated by calling "bodyBuilderFn" and providing it with an OpBuilder, a 397 /// Location and a list of loop induction variables. 398 void buildAffineLoopNest(OpBuilder &builder, Location loc, 399 ArrayRef<int64_t> lbs, ArrayRef<int64_t> ubs, 400 ArrayRef<int64_t> steps, 401 function_ref<void(OpBuilder &, Location, ValueRange)> 402 bodyBuilderFn = nullptr); 403 void buildAffineLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, 404 ValueRange ubs, ArrayRef<int64_t> steps, 405 function_ref<void(OpBuilder &, Location, ValueRange)> 406 bodyBuilderFn = nullptr); 407 408 /// AffineBound represents a lower or upper bound in the for operation. 409 /// This class does not own the underlying operands. Instead, it refers 410 /// to the operands stored in the AffineForOp. Its life span should not exceed 411 /// that of the for operation it refers to. 412 class AffineBound { 413 public: getAffineForOp()414 AffineForOp getAffineForOp() { return op; } getMap()415 AffineMap getMap() { return map; } 416 getNumOperands()417 unsigned getNumOperands() { return opEnd - opStart; } getOperand(unsigned idx)418 Value getOperand(unsigned idx) { return op.getOperand(opStart + idx); } 419 420 using operand_iterator = AffineForOp::operand_iterator; 421 using operand_range = AffineForOp::operand_range; 422 operand_begin()423 operand_iterator operand_begin() { return op.operand_begin() + opStart; } operand_end()424 operand_iterator operand_end() { return op.operand_begin() + opEnd; } getOperands()425 operand_range getOperands() { return {operand_begin(), operand_end()}; } 426 427 private: 428 // 'affine.for' operation that contains this bound. 429 AffineForOp op; 430 // Start and end positions of this affine bound operands in the list of 431 // the containing 'affine.for' operation operands. 432 unsigned opStart, opEnd; 433 // Affine map for this bound. 434 AffineMap map; 435 AffineBound(AffineForOp op,unsigned opStart,unsigned opEnd,AffineMap map)436 AffineBound(AffineForOp op, unsigned opStart, unsigned opEnd, AffineMap map) 437 : op(op), opStart(opStart), opEnd(opEnd), map(map) {} 438 439 friend class AffineForOp; 440 }; 441 442 /// An `AffineApplyNormalizer` is a helper class that supports renumbering 443 /// operands of AffineApplyOp. This acts as a reindexing map of Value to 444 /// positional dims or symbols and allows simplifications such as: 445 /// 446 /// ```mlir 447 /// %1 = affine.apply (d0, d1) -> (d0 - d1) (%0, %0) 448 /// ``` 449 /// 450 /// into: 451 /// 452 /// ```mlir 453 /// %1 = affine.apply () -> (0) 454 /// ``` 455 struct AffineApplyNormalizer { 456 AffineApplyNormalizer(AffineMap map, ArrayRef<Value> operands); 457 458 /// Returns the AffineMap resulting from normalization. getAffineMapAffineApplyNormalizer459 AffineMap getAffineMap() { return affineMap; } 460 getOperandsAffineApplyNormalizer461 SmallVector<Value, 8> getOperands() { 462 SmallVector<Value, 8> res(reorderedDims); 463 res.append(concatenatedSymbols.begin(), concatenatedSymbols.end()); 464 return res; 465 } 466 getNumSymbolsAffineApplyNormalizer467 unsigned getNumSymbols() { return concatenatedSymbols.size(); } getNumDimsAffineApplyNormalizer468 unsigned getNumDims() { return reorderedDims.size(); } 469 470 /// Normalizes 'otherMap' and its operands 'otherOperands' to map to this 471 /// normalizer's coordinate space. 472 void normalize(AffineMap *otherMap, SmallVectorImpl<Value> *otherOperands); 473 474 private: 475 /// Helper function to insert `v` into the coordinate system of the current 476 /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding 477 /// renumbered position. 478 AffineDimExpr renumberOneDim(Value v); 479 480 /// Given an `other` normalizer, this rewrites `other.affineMap` in the 481 /// coordinate system of the current AffineApplyNormalizer. 482 /// Returns the rewritten AffineMap and updates the dims and symbols of 483 /// `this`. 484 AffineMap renumber(const AffineApplyNormalizer &other); 485 486 /// Maps of Value to position in `affineMap`. 487 DenseMap<Value, unsigned> dimValueToPosition; 488 489 /// Ordered dims and symbols matching positional dims and symbols in 490 /// `affineMap`. 491 SmallVector<Value, 8> reorderedDims; 492 SmallVector<Value, 8> concatenatedSymbols; 493 494 /// The number of symbols in concatenated symbols that belong to the original 495 /// map as opposed to those concatendated during map composition. 496 unsigned numProperSymbols; 497 498 AffineMap affineMap; 499 500 /// Used with RAII to control the depth at which AffineApply are composed 501 /// recursively. Only accepts depth 1 for now to allow a behavior where a 502 /// newly composed AffineApplyOp does not increase the length of the chain of 503 /// AffineApplyOps. Full composition is implemented iteratively on top of 504 /// this behavior. affineApplyDepthAffineApplyNormalizer505 static unsigned &affineApplyDepth() { 506 static thread_local unsigned depth = 0; 507 return depth; 508 } 509 static constexpr unsigned kMaxAffineApplyDepth = 1; 510 AffineApplyNormalizerAffineApplyNormalizer511 AffineApplyNormalizer() : numProperSymbols(0) { affineApplyDepth()++; } 512 513 public: ~AffineApplyNormalizerAffineApplyNormalizer514 ~AffineApplyNormalizer() { affineApplyDepth()--; } 515 }; 516 517 } // end namespace mlir 518 519 #endif 520