//===- Ops.h - Standard MLIR Operations -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines convenience types for working with standard operations
// in the MLIR operation set.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H
#define MLIR_DIALECT_STANDARDOPS_IR_OPS_H

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/StandardOps/IR/OpsEnums.h.inc"

namespace mlir {
class AffineMap;
class Builder;
class FuncOp;
class OpBuilder;

raw_ostream &operator<<(raw_ostream &os, Range &range);

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                        OpBuilder &b, Location loc);
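// Illustrative usage sketch (not part of this header; `builder`, `loc`, and
// `subViewOp` are assumed to be in scope, with `subViewOp` implementing
// OffsetSizeAndStrideOpInterface):
//
//   SmallVector<Range, 8> ranges = getOrCreateRanges(subViewOp, builder, loc);
//   for (const Range &range : ranges) {
//     // Each entry carries `offset`, `size`, and `stride` Values.
//   }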

#define GET_OP_CLASSES
#include "mlir/Dialect/StandardOps/IR/Ops.h.inc"

#include "mlir/Dialect/StandardOps/IR/OpsDialect.h.inc"

/// This is a refinement of the "constant" op for the case where it is
/// returning a float value of FloatType.
///
///   %1 = "std.constant"() {value = 42.0 : bf16} : () -> bf16
///
class ConstantFloatOp : public ConstantOp {
public:
  using ConstantOp::ConstantOp;

  /// Builds a constant float op producing a float of the specified type.
  static void build(OpBuilder &builder, OperationState &result,
                    const APFloat &value, FloatType type);

  APFloat getValue() {
    return (*this)->getAttrOfType<FloatAttr>("value").getValue();
  }

  static bool classof(Operation *op);
};
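// A minimal usage sketch (illustrative, assuming an OpBuilder `builder` and a
// Location `loc` are in scope):
//
//   auto cst = builder.create<ConstantFloatOp>(loc, APFloat(42.0f),
//                                              builder.getF32Type());
//   APFloat value = cst.getValue(); // value.convertToFloat() == 42.0f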

/// This is a refinement of the "constant" op for the case where it is
/// returning an integer value of IntegerType.
///
///   %1 = "std.constant"() {value = 42 : i32} : () -> i32
///
class ConstantIntOp : public ConstantOp {
public:
  using ConstantOp::ConstantOp;

  /// Build a constant int op producing an integer of the specified width.
  static void build(OpBuilder &builder, OperationState &result, int64_t value,
                    unsigned width);

  /// Build a constant int op producing an integer with the specified type,
  /// which must be an integer type.
  static void build(OpBuilder &builder, OperationState &result, int64_t value,
                    Type type);

  int64_t getValue() {
    return (*this)->getAttrOfType<IntegerAttr>("value").getInt();
  }

  static bool classof(Operation *op);
};
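// A minimal usage sketch (illustrative, assuming `builder` and `loc` as
// above; the arguments forward to the first build overload):
//
//   auto cst = builder.create<ConstantIntOp>(loc, /*value=*/42, /*width=*/32);
//   int64_t value = cst.getValue(); // value == 42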

/// This is a refinement of the "constant" op for the case where it is
/// returning an integer value of Index type.
///
///   %1 = "std.constant"() {value = 99 : index} : () -> index
///
class ConstantIndexOp : public ConstantOp {
public:
  using ConstantOp::ConstantOp;

  /// Build a constant int op producing an index.
  static void build(OpBuilder &builder, OperationState &result, int64_t value);

  int64_t getValue() {
    return (*this)->getAttrOfType<IntegerAttr>("value").getInt();
  }

  static bool classof(Operation *op);
};
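// A minimal usage sketch (illustrative, assuming `builder` and `loc` as
// above):
//
//   auto zero = builder.create<ConstantIndexOp>(loc, 0);
//   int64_t value = zero.getValue(); // value == 0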

// DmaStartOp starts a non-blocking DMA operation that transfers data from a
// source memref to a destination memref. The source and destination memrefs
// need not be of the same dimensionality, but need to have the same element
// type. The operands include the source and destination memrefs, each followed
// by its indices, the size of the data transfer in terms of the number of
// elements (of the element type of the memref), a tag memref with its indices,
// and optionally, at the end, a stride and a number_of_elements_per_stride
// argument. The tag location is used by a DmaWaitOp to check for completion.
// The indices of the source memref, destination memref, and the tag memref
// have the same restrictions as any load/store. The optional stride arguments
// should be of 'index' type, and specify a stride for the slower memory space
// (memory space with a lower memory space id), transferring chunks of
// number_of_elements_per_stride every stride until %num_elements are
// transferred. Either both or neither of the stride arguments should be
// specified.
//
// For example, a DmaStartOp operation that transfers 256 elements of a memref
// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory
// space 1 at indices [%k, %l], would be specified as follows:
//
//   %num_elements = constant 256 : index
//   %idx = constant 0 : index
//   %tag = alloc() : memref<1 x i32, (d0) -> (d0), 2>
//   dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
//     memref<40 x 128 x f32, (d0) -> (d0), 0>,
//     memref<2 x 1024 x f32, (d0) -> (d0), 1>,
//     memref<1 x i32, (d0) -> (d0), 2>
//
//   If %stride and %num_elt_per_stride are specified, the DMA is expected to
//   transfer %num_elt_per_stride elements every %stride elements apart from
//   memory space 0 until %num_elements are transferred.
//
//   dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
//             %num_elt_per_stride :
//
// TODO: add additional operands to allow source and destination striding, and
// multiple stride levels.
// TODO: Consider replacing src/dst memref indices with view memrefs.
class DmaStartOp
    : public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
  using Op::Op;

  static void build(OpBuilder &builder, OperationState &result, Value srcMemRef,
                    ValueRange srcIndices, Value destMemRef,
                    ValueRange destIndices, Value numElements, Value tagMemRef,
                    ValueRange tagIndices, Value stride = nullptr,
                    Value elementsPerStride = nullptr);

  // Returns the source memref for this DMA operation.
  Value getSrcMemRef() { return getOperand(0); }
  // Returns the rank (number of indices) of the source MemRefType.
  unsigned getSrcMemRefRank() {
    return getSrcMemRef().getType().cast<MemRefType>().getRank();
  }
  // Returns the source memref indices for this DMA operation.
  operand_range getSrcIndices() {
    return {(*this)->operand_begin() + 1,
            (*this)->operand_begin() + 1 + getSrcMemRefRank()};
  }

  // Returns the destination memref for this DMA operation.
  Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
  // Returns the rank (number of indices) of the destination MemRefType.
  unsigned getDstMemRefRank() {
    return getDstMemRef().getType().cast<MemRefType>().getRank();
  }
  unsigned getSrcMemorySpace() {
    return getSrcMemRef().getType().cast<MemRefType>().getMemorySpace();
  }
  unsigned getDstMemorySpace() {
    return getDstMemRef().getType().cast<MemRefType>().getMemorySpace();
  }

  // Returns the destination memref indices for this DMA operation.
  operand_range getDstIndices() {
    return {(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1,
            (*this)->operand_begin() + 1 + getSrcMemRefRank() + 1 +
                getDstMemRefRank()};
  }

  // Returns the number of elements being transferred by this DMA operation.
  Value getNumElements() {
    return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
  }

  // Returns the tag memref for this DMA operation.
  Value getTagMemRef() {
    return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
  }
  // Returns the rank (number of indices) of the tag MemRefType.
  unsigned getTagMemRefRank() {
    return getTagMemRef().getType().cast<MemRefType>().getRank();
  }

  // Returns the tag memref indices for this DMA operation.
  operand_range getTagIndices() {
    unsigned tagIndexStartPos =
        1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
    return {(*this)->operand_begin() + tagIndexStartPos,
            (*this)->operand_begin() + tagIndexStartPos + getTagMemRefRank()};
  }

  /// Returns true if this is a DMA from a slower memory space to a faster one.
  bool isDestMemorySpaceFaster() {
    return (getSrcMemorySpace() < getDstMemorySpace());
  }

  /// Returns true if this is a DMA from a faster memory space to a slower one.
  bool isSrcMemorySpaceFaster() {
    // Assumes that a lower number is for a slower memory space.
    return (getDstMemorySpace() < getSrcMemorySpace());
  }

  /// Given a DMA start operation, returns the operand position of either the
  /// source or destination memref, whichever is at the higher level of the
  /// memory hierarchy. Asserts if neither space is faster than the other.
  unsigned getFasterMemPos() {
    assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
    return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
  }

  static StringRef getOperationName() { return "std.dma_start"; }
  static ParseResult parse(OpAsmParser &parser, OperationState &result);
  void print(OpAsmPrinter &p);
  LogicalResult verify();

  LogicalResult fold(ArrayRef<Attribute> cstOperands,
                     SmallVectorImpl<OpFoldResult> &results);

  bool isStrided() {
    return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() +
                                   1 + 1 + getTagMemRefRank();
  }

  Value getStride() {
    if (!isStrided())
      return nullptr;
    return getOperand(getNumOperands() - 1 - 1);
  }

  Value getNumElementsPerStride() {
    if (!isStrided())
      return nullptr;
    return getOperand(getNumOperands() - 1);
  }
};
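// A minimal build sketch (illustrative; `builder` and `loc` as above, `src`,
// `dst`, and `tag` are memref-typed Values, and `i`, `j`, `k`, `l`, `idx`,
// and `numElements` are index-typed Values):
//
//   builder.create<DmaStartOp>(loc, src, ValueRange{i, j}, dst,
//                              ValueRange{k, l}, numElements, tag,
//                              ValueRange{idx});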

// DmaWaitOp blocks until the completion of a DMA operation associated with the
// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
// with the same restrictions as any load/store index. %num_elements is the
// number of elements associated with the DMA operation. For example:
//
//   dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] :
//     memref<2048 x f32, (d0) -> (d0), 0>,
//     memref<256 x f32, (d0) -> (d0), 1>,
//     memref<1 x i32, (d0) -> (d0), 2>
//   ...
//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2>
//
class DmaWaitOp
    : public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
  using Op::Op;

  static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
                    ValueRange tagIndices, Value numElements);

  static StringRef getOperationName() { return "std.dma_wait"; }

  // Returns the tag memref associated with the DMA operation being waited on.
  Value getTagMemRef() { return getOperand(0); }

  // Returns the tag memref indices for this DMA operation.
  operand_range getTagIndices() {
    return {(*this)->operand_begin() + 1,
            (*this)->operand_begin() + 1 + getTagMemRefRank()};
  }

  // Returns the rank (number of indices) of the tag memref.
  unsigned getTagMemRefRank() {
    return getTagMemRef().getType().cast<MemRefType>().getRank();
  }

  // Returns the number of elements transferred by the associated DMA operation.
  Value getNumElements() { return getOperand(1 + getTagMemRefRank()); }

  static ParseResult parse(OpAsmParser &parser, OperationState &result);
  void print(OpAsmPrinter &p);
  LogicalResult fold(ArrayRef<Attribute> cstOperands,
                     SmallVectorImpl<OpFoldResult> &results);
  LogicalResult verify();
};
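// A minimal build sketch (illustrative; names as in the DmaStartOp example
// above):
//
//   builder.create<DmaWaitOp>(loc, tag, ValueRange{idx}, numElements);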

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the vector of booleans
/// that specifies which of the entries of `originalShape` are kept to obtain
/// `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when e.g. computing MemRef strides under
/// rank-reducing operations. Return None if `reducedShape` cannot be obtained
/// by dropping only `1` entries in `originalShape`.
llvm::Optional<SmallVector<bool, 4>>
computeRankReductionMask(ArrayRef<int64_t> originalShape,
                         ArrayRef<int64_t> reducedShape);
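// A worked example (illustrative, following the contract above): with
// originalShape = [1, 8, 1, 16] and reducedShape = [8, 16], the returned mask
// is [false, true, false, true]; the two unit dimensions are dropped and the
// remaining entries project onto the reduced shape. With originalShape =
// [1, 8] and reducedShape = [4], None is returned, since [4] cannot be
// obtained by dropping only `1` entries.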

/// Determines whether MemRefCastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref_cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref_cast operations. Such foldable
/// memref_cast operations are typically inserted when `view` and `subview`
/// ops are canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and the
///    same element type and rank.
/// 2. each of the source's size, offset, or stride has more static information
///    than the corresponding result's size, offset, or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```mlir
///   %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool canFoldIntoConsumerOp(MemRefCastOp castOp);
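// A hedged sketch of how a folding hook might use this predicate (assuming
// `operand` is a Value consumed by the op being folded):
//
//   if (auto castOp = operand.getDefiningOp<MemRefCastOp>())
//     if (canFoldIntoConsumerOp(castOp))
//       operand = castOp.getOperand(); // consume the pre-cast value instead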

/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
/// Determines whether TensorCastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor_cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor_cast operations. Such foldable tensor_cast
/// operations are typically inserted when `subtensor` ops are canonicalized,
/// to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank.
/// 2. the source tensor type has more static information than the result's.
///
/// Example:
/// ```mlir
///   %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool canFoldIntoConsumerOp(TensorCastOp castOp);

/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates.
bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
                       const APInt &rhs);

/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
/// comparison predicates.
bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
                       const APFloat &rhs);
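// A minimal usage sketch (illustrative; the CmpIPredicate values come from
// the generated OpsEnums.h.inc pulled in above):
//
//   APInt a(/*numBits=*/32, 3), b(/*numBits=*/32, 5);
//   bool lt = applyCmpPredicate(CmpIPredicate::slt, a, b); // lt == true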
} // end namespace mlir

#endif // MLIR_DIALECT_STANDARDOPS_IR_OPS_H