//===- VectorToSCF.h - Utils to convert from the vector dialect -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_
#define MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_

#include "mlir/IR/PatternMatch.h"

#include <memory>

namespace mlir {
class MLIRContext;
class OwningRewritePatternList;
class Pass;

/// Control whether unrolling is used when lowering vector transfer ops to SCF.
///
/// Case 1:
/// =======
/// When `unroll` is false, a temporary buffer is created through which
/// individual 1-D vectors are staged. This is consistent with the lack of an
/// LLVM instruction to dynamically index into an aggregate (see the Vector
/// dialect lowering to LLVM deep dive).
/// An instruction such as:
/// ```
///    vector.transfer_write %vec, %A[%base, %base] :
///      vector<17x15xf32>, memref<?x?xf32>
/// ```
/// Lowers to pseudo-IR resembling:
/// ```
///    %0 = alloc() : memref<17xvector<15xf32>>
///    %1 = vector.type_cast %0 :
///      memref<17xvector<15xf32>> to memref<vector<17x15xf32>>
///    store %vec, %1[] : memref<vector<17x15xf32>>
///    %dim = dim %A, 0 : memref<?x?xf32>
///    affine.for %I = 0 to 17 {
///      %add = affine.apply %I + %base
///      %cmp = cmpi "slt", %add, %dim : index
///      scf.if %cmp {
///        %vec_1d = load %0[%I] : memref<17xvector<15xf32>>
///        vector.transfer_write %vec_1d, %A[%add, %base] :
///          vector<15xf32>, memref<?x?xf32>
///      }
///    }
/// ```
///
/// Case 2:
/// =======
/// When `unroll` is true, the temporary buffer is skipped and static indices
/// into aggregates can be used (see the Vector dialect lowering to LLVM deep
/// dive).
/// An instruction such as:
/// ```
///    vector.transfer_write %vec, %A[%base, %base] :
///      vector<3x15xf32>, memref<?x?xf32>
/// ```
/// Lowers to pseudo-IR resembling:
/// ```
///    %0 = vector.extract %arg2[0] : vector<3x15xf32>
///    vector.transfer_write %0, %arg0[%arg1, %arg1] :
///      vector<15xf32>, memref<?x?xf32>
///    %1 = affine.apply #map1()[%arg1]
///    %2 = vector.extract %arg2[1] : vector<3x15xf32>
///    vector.transfer_write %2, %arg0[%1, %arg1] :
///      vector<15xf32>, memref<?x?xf32>
///    %3 = affine.apply #map2()[%arg1]
///    %4 = vector.extract %arg2[2] : vector<3x15xf32>
///    vector.transfer_write %4, %arg0[%3, %arg1] :
///      vector<15xf32>, memref<?x?xf32>
/// ```
struct VectorTransferToSCFOptions {
  bool unroll = false;
  VectorTransferToSCFOptions &setUnroll(bool u) {
    unroll = u;
    return *this;
  }
};
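
// A minimal usage sketch (illustrative only; `context` is an assumed
// `MLIRContext *`, and the declarations used below appear later in this
// header):
//
//   VectorTransferToSCFOptions options;
//   options.setUnroll(true); // unrolled lowering, no temporary buffer
//   OwningRewritePatternList patterns;
//   populateVectorToSCFConversionPatterns(patterns, context, options);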

/// Implements lowering of TransferReadOp and TransferWriteOp to a
/// proper abstraction for the hardware.
///
/// There are multiple cases.
///
/// Case A: Permutation Map does not permute or broadcast.
/// ======================================================
///
/// Progressive lowering occurs to 1-D vector transfer ops according to the
/// description in `VectorTransferToSCFOptions`.
///
/// Case B: Permutation Map permutes and/or broadcasts.
/// ====================================================
///
/// This path will be progressively deprecated and folded into the case above
/// by using vector broadcast and transpose operations.
///
/// This path only emits a simple loop nest that performs clipped pointwise
/// copies from a remote memref to a locally allocated buffer.
///
/// Consider the case:
///
/// ```mlir
///    // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into
///    // vector<32x256xf32> and pad with %f0 to handle the boundary case:
///    %f0 = constant 0.0 : f32
///    scf.for %i0 = 0 to %0 {
///      scf.for %i1 = 0 to %1 step %c256 {
///        scf.for %i2 = 0 to %2 step %c32 {
///          %v = vector.transfer_read %A[%i0, %i1, %i2], %f0
///            {permutation_map: (d0, d1, d2) -> (d2, d1)} :
///            memref<?x?x?xf32>, vector<32x256xf32>
///    }}}
/// ```
///
/// The rewriters construct loops and indices that access MemRef A in a pattern
/// resembling the following (while guaranteeing an always full-tile
/// abstraction):
///
/// ```mlir
///    scf.for %d2 = 0 to %c32 {
///      scf.for %d1 = 0 to %c256 {
///        %s = %A[%i0, %i1 + %d1, %i2 + %d2] : f32
///        %tmp[%d2, %d1] = %s
///      }
///    }
/// ```
///
/// In the current state, only a clipping transfer is implemented by `clip`,
/// which creates individual indexing expressions of the form:
///
/// ```mlir-dsc
///    auto condMax = i + ii < N;
///    auto max = std_select(condMax, i + ii, N - one);
///    auto cond = i + ii < zero;
///    std_select(cond, zero, max);
/// ```
///
/// In the future, clipping should not be the only way; instead, we should
/// load vectors and mask them. Similarly, on the write side, a
/// load/mask/store sequence should implement the read-modify-write (RMW)
/// behavior.
///
/// Lowers TransferOp into a combination of:
///   1. local memory allocation;
///   2. perfect loop nest over:
///      a. scalar load/stores from local buffers (viewed as a scalar memref);
///      b. scalar store/load to original memref (with clipping);
///   3. vector_load/store;
///   4. local memory deallocation.
///
/// Minor variations occur depending on whether a TransferReadOp or
/// a TransferWriteOp is rewritten.
template <typename TransferOpTy>
struct VectorTransferRewriter : public RewritePattern {
  explicit VectorTransferRewriter(VectorTransferToSCFOptions options,
                                  MLIRContext *context);

  /// Used for staging the transfer in a local buffer.
  MemRefType tmpMemRefType(TransferOpTy transfer) const;

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

  /// See description of `VectorTransferToSCFOptions`.
  VectorTransferToSCFOptions options;
};

/// Collect a set of patterns to convert from the Vector dialect to SCF + std.
void populateVectorToSCFConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context,
    const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions());

/// Create a pass to convert a subset of vector ops to SCF.
std::unique_ptr<Pass> createConvertVectorToSCFPass(
    const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions());

} // namespace mlir

#endif // MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_