1 //===- VectorOps.h - MLIR Vector Dialect Operations -------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the Vector dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_VECTOR_VECTOROPS_H 14 #define MLIR_DIALECT_VECTOR_VECTOROPS_H 15 16 #include "mlir/IR/AffineMap.h" 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/Dialect.h" 20 #include "mlir/IR/OpDefinition.h" 21 #include "mlir/Interfaces/SideEffectInterfaces.h" 22 #include "mlir/Interfaces/VectorInterfaces.h" 23 #include "mlir/Interfaces/ViewLikeInterface.h" 24 25 namespace mlir { 26 class MLIRContext; 27 class OwningRewritePatternList; 28 namespace vector { 29 30 /// Collect a set of vector-to-vector canonicalization patterns. 31 void populateVectorToVectorCanonicalizationPatterns( 32 OwningRewritePatternList &patterns, MLIRContext *context); 33 34 /// Collect a set of vector-to-vector transformation patterns. 35 void populateVectorToVectorTransformationPatterns( 36 OwningRewritePatternList &patterns, MLIRContext *context); 37 38 /// Collect a set of vector slices transformation patterns: 39 /// ExtractSlicesOpLowering, InsertSlicesOpLowering 40 /// Useful for clients that want to express all vector "slices" 41 /// ops in terms of more elementary vector "slice" ops. If all 42 /// "produced" tuple values are "consumed" (the most common 43 /// use for "slices" ops), this lowering removes all tuple related 44 /// operations as well (through DCE and folding). If tuple values 45 /// "leak" coming in, however, some tuple related ops will remain. 46 void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns, 47 MLIRContext *context); 48 49 /// Enum to control the lowering of `vector.contract` operations. 50 enum class VectorContractLowering { 51 /// Progressively lower to finer grained `vector.contract` and dot-products. 52 Dot = 0, 53 /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics. 54 Matmul = 1, 55 /// Lower to `vector.outerproduct`. 56 OuterProduct = 2, 57 }; 58 /// Enum to control the lowering of `vector.transpose` operations. 59 enum class VectorTransposeLowering { 60 /// Lower transpose into element-wise extract and inserts. 61 EltWise = 0, 62 /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix 63 /// intrinsics. 64 Flat = 1, 65 }; 66 /// Enum to control the splitting of `vector.transfer` operations into masked 67 /// and unmasked variants. 68 enum class VectorTransferSplit { 69 /// Do not split vector transfer operations. 70 None = 0, 71 /// Split using masked + unmasked vector.transfer operations. 72 VectorTransfer = 1, 73 /// Split using a unmasked vector.transfer + linalg.fill + linalg.copy 74 /// operations. 75 LinalgCopy = 2, 76 /// Do not split vector transfer operation but instead mark it as "unmasked". 77 ForceUnmasked = 3 78 }; 79 /// Structure to control the behavior of vector transform patterns. 80 struct VectorTransformsOptions { 81 /// Option to control the lowering of vector.contract. 82 VectorContractLowering vectorContractLowering = VectorContractLowering::Dot; 83 VectorTransformsOptions & setVectorTransformsOptionsVectorTransformsOptions84 setVectorTransformsOptions(VectorContractLowering opt) { 85 vectorContractLowering = opt; 86 return *this; 87 } 88 /// Option to control the lowering of vector.transpose. 89 VectorTransposeLowering vectorTransposeLowering = 90 VectorTransposeLowering::EltWise; 91 VectorTransformsOptions & setVectorTransposeLoweringVectorTransformsOptions92 setVectorTransposeLowering(VectorTransposeLowering opt) { 93 vectorTransposeLowering = opt; 94 return *this; 95 } 96 /// Option to control the splitting of vector transfers. 97 VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None; setVectorTransferSplitVectorTransformsOptions98 VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) { 99 vectorTransferSplit = opt; 100 return *this; 101 } 102 }; 103 104 /// Collect a set of transformation patterns that are related to contracting 105 /// or expanding vector operations: 106 /// ContractionOpLowering, 107 /// ShapeCastOp2DDownCastRewritePattern, 108 /// ShapeCastOp2DUpCastRewritePattern 109 /// BroadcastOpLowering, 110 /// TransposeOpLowering 111 /// OuterproductOpLowering 112 /// These transformation express higher level vector ops in terms of more 113 /// elementary extraction, insertion, reduction, product, and broadcast ops. 114 void populateVectorContractLoweringPatterns( 115 OwningRewritePatternList &patterns, MLIRContext *context, 116 VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions()); 117 118 /// Returns the integer type required for subscripts in the vector dialect. 119 IntegerType getVectorSubscriptType(Builder &builder); 120 121 /// Returns an integer array attribute containing the given values using 122 /// the integer type required for subscripts in the vector dialect. 123 ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values); 124 125 namespace impl { 126 /// Build the default minor identity map suitable for a vector transfer. This 127 /// also handles the case memref<... x vector<...>> -> vector<...> in which the 128 /// rank of the identity map must take the vector element type into account. 129 AffineMap getTransferMinorIdentityMap(MemRefType memRefType, 130 VectorType vectorType); 131 } // namespace impl 132 } // end namespace vector 133 } // end namespace mlir 134 135 #define GET_OP_CLASSES 136 #include "mlir/Dialect/Vector/VectorOps.h.inc" 137 #include "mlir/Dialect/Vector/VectorOpsDialect.h.inc" 138 139 #endif // MLIR_DIALECT_VECTOR_VECTOROPS_H 140