//===- Ops.h - Standard MLIR Operations -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines convenience types for working with standard operations
// in the MLIR operation set.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H
#define MLIR_DIALECT_STANDARDOPS_IR_OPS_H

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/StandardOps/IR/OpsEnums.h.inc"

namespace mlir {
class AffineMap;
class Builder;
class FuncOp;
class OpBuilder;

raw_ostream &operator<<(raw_ostream &os, Range &range);

/// Return the list of Range (i.e. offset, size and stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                        OpBuilder &b, Location loc);

#define GET_OP_CLASSES
#include "mlir/Dialect/StandardOps/IR/Ops.h.inc"

#include "mlir/Dialect/StandardOps/IR/OpsDialect.h.inc"

/// This is a refinement of the "constant" op for the case where it is
/// returning a float value of FloatType.
///
///   %1 = "std.constant"(){value: 42.0} : () -> bf16
///
class ConstantFloatOp : public ConstantOp {
public:
  using ConstantOp::ConstantOp;

  /// Builds a constant float op producing a float of the specified type.
  static void build(OpBuilder &builder, OperationState &result,
                    const APFloat &value, FloatType type);

  APFloat getValue() {
    return (*this)->getAttrOfType<FloatAttr>("value").getValue();
  }

  static bool classof(Operation *op);
};

/// This is a refinement of the "constant" op for the case where it is
/// returning an integer value of IntegerType.
///
///   %1 = "std.constant"(){value: 42} : () -> i32
///
class ConstantIntOp : public ConstantOp {
public:
  using ConstantOp::ConstantOp;

  /// Build a constant int op producing an integer of the specified width.
  static void build(OpBuilder &builder, OperationState &result, int64_t value,
                    unsigned width);

  /// Build a constant int op producing an integer with the specified type,
  /// which must be an integer type.
  static void build(OpBuilder &builder, OperationState &result, int64_t value,
                    Type type);

  int64_t getValue() {
    return (*this)->getAttrOfType<IntegerAttr>("value").getInt();
  }

  static bool classof(Operation *op);
};

/// This is a refinement of the "constant" op for the case where it is
/// returning an integer value of Index type.
///
///   %1 = "std.constant"(){value: 99} : () -> index
///
class ConstantIndexOp : public ConstantOp {
public:
  using ConstantOp::ConstantOp;

  /// Build a constant int op producing an index.
  static void build(OpBuilder &builder, OperationState &result, int64_t value);

  int64_t getValue() {
    return (*this)->getAttrOfType<IntegerAttr>("value").getInt();
  }

  static bool classof(Operation *op);
};
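
/// Example (a minimal sketch; `buildExampleConstants` is a hypothetical helper
/// shown for illustration, not part of the dialect API): each refinement above
/// dispatches to its typed `build` method via `OpBuilder::create`.
inline void buildExampleConstants(OpBuilder &builder, Location loc) {
  // %0 = "std.constant"(){value: 42.0} : () -> f32
  builder.create<ConstantFloatOp>(loc, APFloat(42.0f), builder.getF32Type());
  // %1 = "std.constant"(){value: 42} : () -> i32
  builder.create<ConstantIntOp>(loc, /*value=*/42, /*width=*/32);
  // %2 = "std.constant"(){value: 99} : () -> index
  builder.create<ConstantIndexOp>(loc, /*value=*/99);
}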

// DmaStartOp starts a non-blocking DMA operation that transfers data from a
// source memref to a destination memref. The source and destination memrefs
// need not be of the same dimensionality, but need to have the same elemental
// type. The operands include the source and destination memrefs, each followed
// by its indices, the size of the data transfer in terms of the number of
// elements (of the elemental type of the memref), a tag memref with its
// indices, and optionally, at the end, a stride and a
// number_of_elements_per_stride argument. The tag location is used by a
// DmaWaitOp to check for completion. The indices of the source memref,
// destination memref, and tag memref have the same restrictions as any
// load/store. The optional stride arguments should be of 'index' type and
// specify a stride for the slower memory space (the memory space with the
// lower memory space id), transferring chunks of
// number_of_elements_per_stride every stride until %num_elements are
// transferred. Either both or neither of the stride arguments should be
// specified.
//
// For example, a DmaStartOp operation that transfers 256 elements of a memref
// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory
// space 1 at indices [%k, %l] would be specified as follows:
//
//   %num_elements = constant 256 : index
//   %idx = constant 0 : index
//   %tag = alloc() : memref<1 x i32, (d0) -> (d0), 2>
//   dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
//     memref<40 x 128 x f32, (d0) -> (d0), 0>,
//     memref<2 x 1024 x f32, (d0) -> (d0), 1>,
//     memref<1 x i32, (d0) -> (d0), 2>
//
// If %stride and %num_elt_per_stride are specified, the DMA is expected to
// transfer %num_elt_per_stride elements every %stride elements apart from
// memory space 0 until %num_elements are transferred.
//
//   dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
//             %num_elt_per_stride :
//
// TODO: add additional operands to allow source and destination striding, and
// multiple stride levels.
// TODO: Consider replacing src/dst memref indices with view memrefs.
class DmaStartOp
    : public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
  using Op::Op;

  static void build(OpBuilder &builder, OperationState &result,
                    Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                    ValueRange destIndices, Value numElements, Value tagMemRef,
                    ValueRange tagIndices, Value stride = nullptr,
                    Value elementsPerStride = nullptr);
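
  // Operands are laid out in order as:
  //   [srcMemRef, srcIndices..., dstMemRef, dstIndices..., numElements,
  //    tagMemRef, tagIndices..., (stride, numElementsPerStride)?]
  // Because the operand list is variadic, the accessors below recover each
  // operand's position from the ranks of the memref operands.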

  // Returns the source MemRefType for this DMA operation.
  Value getSrcMemRef() { return getOperand(0); }
  // Returns the rank (number of indices) of the source MemRefType.
  unsigned getSrcMemRefRank() {
    return getSrcMemRef().getType().cast<MemRefType>().getRank();
  }
  // Returns the source memref indices for this DMA operation.
  operand_range getSrcIndices() {
    return {(*this)->operand_begin() + 1,
            (*this)->operand_begin() + 1 + getSrcMemRefRank()};
  }

  // Returns the destination MemRefType for this DMA operation.
  Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
  // Returns the rank (number of indices) of the destination MemRefType.
  unsigned getDstMemRefRank() {
    return getDstMemRef().getType().cast<MemRefType>().getRank();
  }
  unsigned getSrcMemorySpace() {
    return getSrcMemRef().getType().cast<MemRefType>().getMemorySpace();
  }
  unsigned getDstMemorySpace() {
    return getDstMemRef().getType().cast<MemRefType>().getMemorySpace();
  }

  // Returns the destination memref indices for this DMA operation.
  operand_range getDstIndices() {
    return {(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1,
            (*this)->operand_begin() + 1 + getSrcMemRefRank() + 1 +
                getDstMemRefRank()};
  }

  // Returns the number of elements being transferred by this DMA operation.
  Value getNumElements() {
    return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
  }

  // Returns the Tag MemRef for this DMA operation.
  Value getTagMemRef() {
    return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
  }
  // Returns the rank (number of indices) of the tag MemRefType.
  unsigned getTagMemRefRank() {
    return getTagMemRef().getType().cast<MemRefType>().getRank();
  }

  // Returns the tag memref indices for this DMA operation.
  operand_range getTagIndices() {
    unsigned tagIndexStartPos =
        1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
    return {(*this)->operand_begin() + tagIndexStartPos,
            (*this)->operand_begin() + tagIndexStartPos + getTagMemRefRank()};
  }

  /// Returns true if this is a DMA from a faster memory space to a slower one.
  bool isDestMemorySpaceFaster() {
    return (getSrcMemorySpace() < getDstMemorySpace());
  }

  /// Returns true if this is a DMA from a slower memory space to a faster one.
  bool isSrcMemorySpaceFaster() {
    // Assumes that a lower number is for a slower memory space.
    return (getDstMemorySpace() < getSrcMemorySpace());
  }

  /// Given a DMA start operation, returns the operand position of either the
  /// source or destination memref, depending on which one is at the higher
  /// level of the memory hierarchy. Asserts if the two memory spaces are the
  /// same.
  unsigned getFasterMemPos() {
    assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
    return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
  }

  static StringRef getOperationName() { return "std.dma_start"; }
  static ParseResult parse(OpAsmParser &parser, OperationState &result);
  void print(OpAsmPrinter &p);
  LogicalResult verify();

  LogicalResult fold(ArrayRef<Attribute> cstOperands,
                     SmallVectorImpl<OpFoldResult> &results);

  bool isStrided() {
    return getNumOperands() !=
           1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1 +
               getTagMemRefRank();
  }

  Value getStride() {
    if (!isStrided())
      return nullptr;
    return getOperand(getNumOperands() - 1 - 1);
  }

  Value getNumElementsPerStride() {
    if (!isStrided())
      return nullptr;
    return getOperand(getNumOperands() - 1);
  }
};
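
/// Example (a minimal sketch; the helper name and its operands are
/// hypothetical): issuing the non-strided DMA from the documentation above,
/// i.e. `dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx]`.
inline void buildExampleDmaStart(OpBuilder &builder, Location loc, Value src,
                                 Value i, Value j, Value dst, Value k, Value l,
                                 Value tag) {
  Value numElements = builder.create<ConstantIndexOp>(loc, 256);
  Value idx = builder.create<ConstantIndexOp>(loc, 0);
  // Omitting the trailing `stride` and `elementsPerStride` operands builds the
  // non-strided form.
  builder.create<DmaStartOp>(loc, src, ValueRange{i, j}, dst, ValueRange{k, l},
                             numElements, tag, ValueRange{idx});
}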

// DmaWaitOp blocks until the completion of a DMA operation associated with the
// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
// with the same restrictions as any load/store index. %num_elements is the
// number of elements associated with the DMA operation. For example:
//
//   dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] :
//     memref<2048 x f32, (d0) -> (d0), 0>,
//     memref<256 x f32, (d0) -> (d0), 1>,
//     memref<1 x i32, (d0) -> (d0), 2>
//   ...
//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2>
//
class DmaWaitOp
    : public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
  using Op::Op;

  static void build(OpBuilder &builder, OperationState &result,
                    Value tagMemRef, ValueRange tagIndices, Value numElements);

  static StringRef getOperationName() { return "std.dma_wait"; }

  // Returns the Tag MemRef associated with the DMA operation being waited on.
  Value getTagMemRef() { return getOperand(0); }

  // Returns the tag memref indices for this DMA operation.
  operand_range getTagIndices() {
    return {(*this)->operand_begin() + 1,
            (*this)->operand_begin() + 1 + getTagMemRefRank()};
  }

  // Returns the rank (number of indices) of the tag memref.
  unsigned getTagMemRefRank() {
    return getTagMemRef().getType().cast<MemRefType>().getRank();
  }

  // Returns the number of elements transferred by the associated DMA
  // operation.
  Value getNumElements() { return getOperand(1 + getTagMemRefRank()); }

  static ParseResult parse(OpAsmParser &parser, OperationState &result);
  void print(OpAsmPrinter &p);
  LogicalResult fold(ArrayRef<Attribute> cstOperands,
                     SmallVectorImpl<OpFoldResult> &results);
  LogicalResult verify();
};
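
/// Example (a minimal sketch; the helper name and its operands are
/// hypothetical): blocking until the DMA issued through `%tag[%idx]` above
/// has completed.
inline void buildExampleDmaWait(OpBuilder &builder, Location loc, Value tag,
                                Value idx, Value numElements) {
  // dma_wait %tag[%idx], %num_elements
  builder.create<DmaWaitOp>(loc, tag, ValueRange{idx}, numElements);
}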

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the vector of booleans
/// that specifies which of the entries of `originalShape` are kept to obtain
/// `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when e.g. computing MemRef strides under
/// rank-reducing operations. Return None if `reducedShape` cannot be obtained
/// by dropping only `1` entries in `originalShape`.
llvm::Optional<SmallVector<bool, 4>>
computeRankReductionMask(ArrayRef<int64_t> originalShape,
                         ArrayRef<int64_t> reducedShape);

/// Determines whether MemRefCastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref_cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref_cast operations. Such foldable
/// memref_cast operations are typically inserted as `view` and `subview` ops
/// and are canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and the
///    same element type and rank.
/// 2. each of the source's size, offset and stride has more static
///    information than the corresponding size, offset and stride of the
///    result.
///
/// Example 1:
/// ```mlir
///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```mlir
///   %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool canFoldIntoConsumerOp(MemRefCastOp castOp);

/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
/// Determines whether TensorCastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor_cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor_cast operations. Such foldable tensor_cast
/// operations are typically inserted as `subtensor` ops and are canonicalized,
/// to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank.
/// 2. the source type has more static information than the result type.
///
/// Example:
/// ```mlir
///   %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool canFoldIntoConsumerOp(TensorCastOp castOp);

/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates.
bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
                       const APInt &rhs);

/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
/// comparison predicates.
bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
                       const APFloat &rhs);
} // end namespace mlir

#endif // MLIR_DIALECT_STANDARDOPS_IR_OPS_H