//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto vectorType = into.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<InsertOp>(loc, from, into, offset);
  return rewriter.create<vector::InsertElementOp>(
      loc, vectorType, from, into,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto vectorType = vector.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<ExtractOp>(loc, vector, offset);
  return rewriter.create<vector::ExtractElementOp>(
      loc, vectorType.getElementType(), vector,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
// TODO: Better support for attribute subtype forwarding + slicing.
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

// Helper that returns a vector comparison that constructs a mask:
//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
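//
// For example, with n = 4, an offset `%o`, and a bound `%b`, this emits
// roughly the following (indicative IR, with 32-bit index optimizations
// enabled):
//     %0 = constant dense<[0, 1, 2, 3]> : vector<4xi32>
//     %1 = index_cast %o : index to i32
//     %2 = splat %1 : vector<4xi32>
//     %3 = addi %2, %0 : vector<4xi32>
//     %4 = index_cast %b : index to i32
//     %5 = splat %4 : vector<4xi32>
//     %6 = cmpi "slt", %3, %5 : vector<4xi32>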
static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
                                   Operation *op, bool enableIndexOptimizations,
                                   int64_t dim, Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If we can assume all indices fit in 32-bit, we perform the vector
  // comparison in 32-bit to get a higher degree of SIMD parallelism.
  // Otherwise we perform the vector comparison using 64-bit indices.
  Value indices;
  Type idxType;
  if (enableIndexOptimizations) {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI32VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
    idxType = rewriter.getI32Type();
  } else {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI64VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
    idxType = rewriter.getI64Type();
  }
  // Add in an offset if requested.
  if (off) {
    Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
}

// Helper that returns data layout alignment of an operation with memref.
template <typename T>
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
                                 unsigned &align) {
  Type elementTy =
      typeConverter.convertType(op.getMemRefType().getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
                                     typeConverter.getDataLayout());
  return success();
}

// Helper that returns the base address of a memref.
static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
                             Value memref, MemRefType memRefType, Value &base) {
  // Inspect stride and offset structure.
  //
  // TODO: flat memory only for now, generalize
  //
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
      offset != 0 || memRefType.getMemorySpace() != 0)
    return failure();
  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
  return success();
}

// Helper that returns a pointer given a memref base.
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
                                Location loc, Value memref,
                                MemRefType memRefType, Value &ptr) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
  return success();
}

// Helper that returns a bit-casted pointer given a memref base.
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
                                Location loc, Value memref,
                                MemRefType memRefType, Type type, Value &ptr) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = type.template cast<LLVM::LLVMType>().getPointerTo();
  base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
  return success();
}

// Helper that returns vector of pointers given a memref base and an index
// vector.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value indices,
                                    MemRefType memRefType, VectorType vType,
                                    Type iType, Value &ptrs) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferReadOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferReadOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);

  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
  if (!vecTy)
    return failure();

  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferWriteOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                             align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferWriteOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
  return success();
}

static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
}

static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
}

namespace {

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
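///
/// Indicative example (the exact assembly format is defined in VectorOps.td):
/// ```
///   %C = vector.matrix_multiply %A, %B
///     { lhs_rows = 4: i32, lhs_columns = 16: i32, rhs_columns = 3: i32 }
///     : (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
/// ```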
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMatmulOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto matmulOp = cast<vector::MatmulOp>(op);
    auto adaptor = vector::MatmulOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        op, typeConverter->convertType(matmulOp.res().getType()), adaptor.lhs(),
        adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
        matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
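///
/// Indicative example (the exact assembly format is defined in VectorOps.td):
/// ```
///   %1 = vector.flat_transpose %0 { rows = 4: i32, columns = 4: i32 }
///     : vector<16xf32> -> vector<16xf32>
/// ```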
class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFlatTransposeOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto transOp = cast<vector::FlatTransposeOp>(op);
    auto adaptor = vector::FlatTransposeOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Conversion pattern for a vector.maskedload.
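/// Lowered to llvm.intr.masked.load on a pointer bitcast from the memref
/// base. Indicative example (the exact assembly format is defined in
/// VectorOps.td):
/// ```
///   %0 = vector.maskedload %base, %mask, %pass_thru
///     : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
/// ```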
class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMaskedLoadOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto load = cast<vector::MaskedLoadOp>(op);
    auto adaptor = vector::MaskedLoadOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), load, align)))
      return failure();

    auto vtype = typeConverter->convertType(load.getResultVectorType());
    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
                          vtype, ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
        load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.maskedstore.
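/// Lowered to llvm.intr.masked.store. Indicative example (the exact assembly
/// format is defined in VectorOps.td):
/// ```
///   vector.maskedstore %base, %mask, %value
///     : memref<?xf32>, vector<16xi1>, vector<16xf32>
/// ```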
class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMaskedStoreOpConversion(MLIRContext *context,
                                         LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto store = cast<vector::MaskedStoreOp>(op);
    auto adaptor = vector::MaskedStoreOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), store, align)))
      return failure();

    auto vtype = typeConverter->convertType(store.getValueVectorType());
    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
                          vtype, ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
        store, adaptor.value(), ptr, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.gather.
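/// Lowered to llvm.intr.masked.gather over a vector of pointers computed from
/// the base address and the index vector. Indicative example (the exact
/// assembly format is defined in VectorOps.td):
/// ```
///   %g = vector.gather %base, %indices, %mask, %pass_thru
///     : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>)
///       -> vector<16xf32>
/// ```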
class VectorGatherOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorGatherOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto gather = cast<vector::GatherOp>(op);
    auto adaptor = vector::GatherOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), gather, align)))
      return failure();

    // Get index ptrs.
    VectorType vType = gather.getResultVectorType();
    Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              gather.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
        adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
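/// Lowered to llvm.intr.masked.scatter, the store-side dual of the gather
/// lowering above. Indicative example (the exact assembly format is defined
/// in VectorOps.td):
/// ```
///   vector.scatter %base, %indices, %mask, %value
///     : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
/// ```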
class VectorScatterOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorScatterOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto scatter = cast<vector::ScatterOp>(op);
    auto adaptor = vector::ScatterOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), scatter, align)))
      return failure();

    // Get index ptrs.
    VectorType vType = scatter.getValueVectorType();
    Type iType = scatter.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              scatter.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.value(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
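/// Lowered to llvm.intr.masked.expandload. Indicative example (the exact
/// assembly format is defined in VectorOps.td):
/// ```
///   %0 = vector.expandload %base, %mask, %pass_thru
///     : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
/// ```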
class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExpandLoadOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto expand = cast<vector::ExpandLoadOp>(op);
    auto adaptor = vector::ExpandLoadOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
                          ptr)))
      return failure();

    auto vType = expand.getResultVectorType();
    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        op, typeConverter->convertType(vType), ptr, adaptor.mask(),
        adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
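/// Lowered to llvm.intr.masked.compressstore. Indicative example (the exact
/// assembly format is defined in VectorOps.td):
/// ```
///   vector.compressstore %base, %mask, %value
///     : memref<?xf32>, vector<8xi1>, vector<8xf32>
/// ```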
class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorCompressStoreOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto compress = cast<vector::CompressStoreOp>(op);
    auto adaptor = vector::CompressStoreOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(),
                          compress.getMemRefType(), ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        op, adaptor.value(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
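/// Integer reductions map onto the llvm.intr.vector.reduce.* family; FP add
/// and mul additionally take an (optional) accumulator operand. Indicative
/// examples (the exact assembly format is defined in VectorOps.td):
/// ```
///   %1 = vector.reduction "add", %0 : vector<16xf32> into f32
///   %2 = vector.reduction "mul", %0, %acc : vector<16xf32> into f32
/// ```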
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorReductionOpConversion(MLIRContext *context,
                                       LLVMTypeConverter &typeConverter,
                                       bool reassociateFPRed)
      : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
                             typeConverter),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<vector::ReductionOp>(op);
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
            op, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
            op, llvmType, operands[0]);
      else if (kind == "min" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            op, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            op, llvmType, operands[0]);
      else if (kind == "max" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            op, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
            op, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
            op, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max.
    if (kind == "add") {
      // Optional accumulator (or zero).
      Value acc = operands.size() > 1 ? operands[1]
                                      : rewriter.create<LLVM::ConstantOp>(
                                            op->getLoc(), llvmType,
                                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          op, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "mul") {
      // Optional accumulator (or one).
      Value acc = operands.size() > 1
                      ? operands[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            op->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          op, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "min")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(op, llvmType,
                                                            operands[0]);
    else if (kind == "max")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(op, llvmType,
                                                            operands[0]);
    else
      return failure();
    return success();
  }

private:
  const bool reassociateFPReductions;
};

/// Conversion pattern for a vector.create_mask (1-D only).
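/// Indicative example: `%m = vector.create_mask %size : vector<16xi1>` sets
/// the first min(%size, 16) lanes, built here via buildVectorComparison as a
/// compare of the constant vector [0, .., 15] against a splat of %size.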
class VectorCreateMaskOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter,
                                        bool enableIndexOpt)
      : ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context,
                             typeConverter),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = op->getResult(0).getType().cast<VectorType>();
    int64_t rank = dstType.getRank();
    if (rank == 1) {
      rewriter.replaceOp(
          op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
                                    dstType.getDimSize(0), operands[0]));
      return success();
    }
    return failure();
  }

private:
  const bool enableIndexOptimizations;
};

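/// Conversion pattern for a vector.shuffle. A rank-1 shuffle of two
/// identically typed vectors maps directly onto llvm.shufflevector; all other
/// cases are expanded into per-position extract/insert pairs.
/// Indicative example (the exact assembly format is defined in VectorOps.td):
/// ```
///   %2 = vector.shuffle %0, %1 [0, 8, 1, 9] : vector<8xf32>, vector<8xf32>
/// ```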
class VectorShuffleOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorShuffleOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto shuffleOp = cast<vector::ShuffleOp>(op);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(op, shuffle);
      return success();
    }

    // For all other cases, extract and insert the values one at a time.
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 llvmType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(op, insert);
    return success();
  }
};

class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractElementOpConversion(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpAdaptor(operands);
    auto extractEltOp = cast<vector::ExtractElementOp>(op);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        op, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

class VectorExtractOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto extractOp = cast<vector::ExtractOp>(op);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(op, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(op, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
///    -> !llvm<"<8 x float>">
/// ```
class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFMAOp1DConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpAdaptor(operands);
    vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};

class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertElementOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpAdaptor(operands);
    auto insertEltOp = cast<vector::InsertElementOp>(op);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
    return success();
  }
};

class VectorInsertOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto insertOp = cast<vector::InsertOp>(op);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(op, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(op, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
// another pattern that converts InsertStridedSlice for operands of the same
// rank.
//
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. In this case:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//   destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. In this case, the op is unrolled along its leading
// dimension:
//   1. each slice of the source vector is extracted
//   2. if the slice is itself a vector, a rank-reduced InsertStridedSlice op
//   is created to insert it into the matching destination subvector
//   3. the resulting slice is inserted into the proper place in the result
//   4. the op is replaced by the fully rebuilt result vector.
// The rank-reduced InsertStridedSlice ops from step 2. are picked up again by
// this same pattern; the recursion is bounded since the rank strictly
// decreases.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
      : OpRewritePattern<InsertStridedSliceOp>(ctx) {
    // This pattern creates recursive InsertStridedSliceOp, but the recursion is
    // bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. extract the proper subvector (or element) from source
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from destination
        // Otherwise we are at the element level and no need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /* dropFront=*/1),
            getI64SubArray(op.strides(), /* dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
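/// For example, an identity-layout memref<4x8xf32> yields strides [8, 1].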
static llvm::Optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return None;
  if (!strides.empty() && strides.back() != 1)
    return None;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getAffineMaps().empty() ||
      memRefType.getAffineMaps().front().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from the shape. This can only
  // ever work in static cases because MemRefType is underspecified: it cannot
  // represent contiguous dynamic shapes other than with an empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 2; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
      return None;
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return None;
  }
  return strides;
}

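/// Conversion pattern for a vector.type_cast, which reinterprets a statically
/// shaped memref of scalars as a memref of a single vector by rebuilding the
/// memref descriptor in the target type. Indicative example (the exact
/// assembly format is defined in VectorOps.td):
/// ```
///   %vm = vector.type_cast %m : memref<8x8xf32> to memref<vector<8x8xf32>>
/// ```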
1120 class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
1121 public:
VectorTypeCastOpConversion(MLIRContext * context,LLVMTypeConverter & typeConverter)1122   explicit VectorTypeCastOpConversion(MLIRContext *context,
1123                                       LLVMTypeConverter &typeConverter)
1124       : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
1125                              typeConverter) {}
1126 
1127   LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1128   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1129                   ConversionPatternRewriter &rewriter) const override {
1130     auto loc = op->getLoc();
1131     vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
1132     MemRefType sourceMemRefType =
1133         castOp.getOperand().getType().cast<MemRefType>();
1134     MemRefType targetMemRefType =
1135         castOp.getResult().getType().cast<MemRefType>();
1136 
1137     // Only static shape casts supported atm.
1138     if (!sourceMemRefType.hasStaticShape() ||
1139         !targetMemRefType.hasStaticShape())
1140       return failure();
1141 
1142     auto llvmSourceDescriptorTy =
1143         operands[0].getType().dyn_cast<LLVM::LLVMType>();
1144     if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
1145       return failure();
1146     MemRefDescriptor sourceMemRef(operands[0]);
1147 
1148     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
1149                                       .dyn_cast_or_null<LLVM::LLVMType>();
1150     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
1151       return failure();
1152 
1153     // Only contiguous source buffers supported atm.
1154     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1155     if (!sourceStrides)
1156       return failure();
1157     auto targetStrides = computeContiguousStrides(targetMemRefType);
1158     if (!targetStrides)
1159       return failure();
1160     // Only support static strides for now, regardless of contiguity.
1161     if (llvm::any_of(*targetStrides, [](int64_t stride) {
1162           return ShapedType::isDynamicStrideOrOffset(stride);
1163         }))
1164       return failure();
1165 
1166     auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
1167 
1168     // Create descriptor.
1169     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1170     Type llvmTargetElementTy = desc.getElementPtrType();
1171     // Set allocated ptr.
1172     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1173     allocated =
1174         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
1175     desc.setAllocatedPtr(rewriter, loc, allocated);
1176     // Set aligned ptr.
1177     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1178     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
1179     desc.setAlignedPtr(rewriter, loc, ptr);
1180     // Fill offset 0.
1181     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1182     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
1183     desc.setOffset(rewriter, loc, zero);
1184 
1185     // Fill size and stride descriptors in memref.
1186     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
1187       int64_t index = indexedSize.index();
1188       auto sizeAttr =
1189           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1190       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1191       desc.setSize(rewriter, loc, index, size);
1192       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1193                                                 (*targetStrides)[index]);
1194       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1195       desc.setStride(rewriter, loc, index, stride);
1196     }
1197 
1198     rewriter.replaceOp(op, {desc});
1199     return success();
1200   }
1201 };
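
// Illustration (hedged; the exact descriptor contents depend on the type
// converter): lowering
//
//   %1 = vector.type_cast %0 : memref<8x8xf32> to memref<vector<8x8xf32>>
//
// reuses the allocated and aligned pointers of %0, bitcast to a pointer to
// the LLVM vector type, and sets the offset to 0; since the result memref is
// rank-0, the size/stride loop above is a no-op for this example.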

/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// sequence of:
/// 1. Get the source/dst address as an LLVM vector pointer.
/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 4. Create a mask where offsetVector is compared against the memref upper
///    bound.
/// 5. Rewrite the op as a masked read or write.
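///
/// For example (schematic; assuming a hypothetical 1-D source), a masked read
///
///   %v = vector.transfer_read %A[%i], %pad : memref<?xf32>, vector<8xf32>
///
/// compares [ %i + 0 .. %i + 7 ] against the broadcast dimension size of %A
/// to build the mask and then issues a masked vector load from the bitcast
/// vector pointer, splatting %pad into the disabled lanes.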
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv,
                                    bool enableIndexOpt)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       op->getContext()))
      return failure();
    // Only contiguous source memrefs supported atm.
    auto strides = computeContiguousStrides(xferOp.getMemRefType());
    if (!strides)
      return failure();

    auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); };

    Location loc = op->getLoc();
    MemRefType memRefType = xferOp.getMemRefType();

    if (auto memrefVectorElementType =
            memRefType.getElementType().dyn_cast<VectorType>()) {
      // Memref has vector element type.
      if (memrefVectorElementType.getElementType() !=
          xferOp.getVectorType().getElementType())
        return failure();
#ifndef NDEBUG
      // Check that the memref vector type is a suffix of 'vectorType'.
      unsigned memrefVecEltRank = memrefVectorElementType.getRank();
      unsigned resultVecRank = xferOp.getVectorType().getRank();
      assert(memrefVecEltRank <= resultVecRank);
      // TODO: Move this to isSuffix in Vector/Utils.h.
      unsigned rankOffset = resultVecRank - memrefVecEltRank;
      auto memrefVecEltShape = memrefVectorElementType.getShape();
      auto resultVecShape = xferOp.getVectorType().getShape();
      for (unsigned i = 0; i < memrefVecEltRank; ++i)
        assert(memrefVecEltShape[i] == resultVecShape[rankOffset + i] &&
               "memref vector element shape should match suffix of vector "
               "result shape.");
#endif // ifndef NDEBUG
    }

    // 1. Get the source/dst address as an LLVM vector pointer.
    //    The vector pointer is always in address space 0, therefore an
    //    addrspacecast is needed when the source/dst memref is not in
    //    address space 0.
    // TODO: support alignment when possible.
    Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
                                         adaptor.indices(), rewriter);
    auto vecTy =
        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
    Value vectorDataPtr;
    if (memRefType.getMemorySpace() == 0)
      vectorDataPtr =
          rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
    else
      vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, vecTy.getPointerTo(), dataPtr);

    if (!xferOp.isMaskedDim(0))
      return replaceTransferOpWithLoadOrStore(
          rewriter, *getTypeConverter(), loc, xferOp, operands, vectorDataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    // 4. Let dim be the memref dimension and compute the vector comparison
    //    mask:
    //      [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    //
    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
    //       dimensions here.
    unsigned vecWidth = vecTy.getVectorNumElements();
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value off = xferOp.indices()[lastIndex];
    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
    Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations,
                                       vecWidth, dim, &off);

    // 5. Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, *getTypeConverter(), loc,
                                       xferOp, operands, vectorDataPtr, mask);
  }

private:
  const bool enableIndexOptimizations;
};

class VectorPrintOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorPrintOpConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
                             typeConverter) {}

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of vectors of any shape and rank. Due to the full unrolling,
  // however, this approach is less suited for very large vectors.
  //
  // TODO: rely solely on libc in future? something else?
  //
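  // For example, printing a vector<2xf32> unrolls into the following calls
  // to the runtime helpers declared further below (a sketch; %0 and %1 stand
  // for the two extracted scalars):
  //
  //   printOpen(); printF32(%0); printComma(); printF32(%1); printClose();
  //   printNewline();
  //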
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto printOp = cast<vector::PrintOp>(op);
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer = getPrintFloat(op);
    } else if (eltType.isF64()) {
      printer = getPrintDouble(op);
    } else if (eltType.isIndex()) {
      printer = getPrintU64(op);
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = getPrintU64(op);
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = getPrintI64(op);
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank,
              conversion);
    emitCall(rewriter, op->getLoc(), getPrintNewline(op));
    rewriter.eraseOp(op);
    return success();
  }

private:
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank, PrintConversion conversion) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<ZeroExtendIOp>(
            loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<SignExtendIOp>(
            loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc, getPrintOpen(op));
    Operation *printComma = getPrintComma(op);
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter->convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, getPrintClose(op));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(),
                                  rewriter.getSymbolRefAttr(ref), params);
  }

  // Helper for printer method declaration (first hit) and lookup.
  static Operation *getPrint(Operation *op, StringRef name,
                             ArrayRef<LLVM::LLVMType> params) {
    auto module = op->getParentOfType<ModuleOp>();
    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
    if (func)
      return func;
    OpBuilder moduleBuilder(module.getBodyRegion());
    return moduleBuilder.create<LLVM::LLVMFuncOp>(
        op->getLoc(), name,
        LLVM::LLVMType::getFunctionTy(
            LLVM::LLVMType::getVoidTy(op->getContext()), params,
            /*isVarArg=*/false));
  }
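
  // E.g., the first call to getPrintFloat() below is expected to insert a
  // module-level declaration along the lines of (exact syntax depends on the
  // LLVM dialect version):
  //
  //   llvm.func @printF32(!llvm.float)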

  // Helpers for method names.
  Operation *getPrintI64(Operation *op) const {
    return getPrint(op, "printI64",
                    LLVM::LLVMType::getInt64Ty(op->getContext()));
  }
  Operation *getPrintU64(Operation *op) const {
    return getPrint(op, "printU64",
                    LLVM::LLVMType::getInt64Ty(op->getContext()));
  }
  Operation *getPrintFloat(Operation *op) const {
    return getPrint(op, "printF32",
                    LLVM::LLVMType::getFloatTy(op->getContext()));
  }
  Operation *getPrintDouble(Operation *op) const {
    return getPrint(op, "printF64",
                    LLVM::LLVMType::getDoubleTy(op->getContext()));
  }
  Operation *getPrintOpen(Operation *op) const {
    return getPrint(op, "printOpen", {});
  }
  Operation *getPrintClose(Operation *op) const {
    return getPrint(op, "printClose", {});
  }
  Operation *getPrintComma(Operation *op) const {
    return getPrint(op, "printComma", {});
  }
  Operation *getPrintNewline(Operation *op) const {
    return getPrint(op, "printNewline", {});
  }
};

/// Progressive lowering of ExtractStridedSliceOp to either:
///   1. a direct shuffle when there is a single offset, or
///   2. extract + lower-rank strided_slice + insert for the n-D case.
class VectorExtractStridedSliceOpConversion
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
      : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
    // This pattern creates recursive ExtractStridedSliceOp, but the recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();

    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    auto elemType = dstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    // A single offset can be shuffled more efficiently.
    if (op.offsets().getValue().size() == 1) {
      SmallVector<int64_t, 4> offsets;
      offsets.reserve(size);
      for (int64_t off = offset, e = offset + size * stride; off < e;
           off += stride)
        offsets.push_back(off);
      rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
                                             op.vector(),
                                             rewriter.getI64ArrayAttr(offsets));
      return success();
    }

    // Extract/insert on a lower ranked extract strided slice op.
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      Value one = extractOne(rewriter, loc, op.vector(), off);
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, one, getI64SubArray(op.offsets(), /*dropFront=*/1),
          getI64SubArray(op.sizes(), /*dropFront=*/1),
          getI64SubArray(op.strides(), /*dropFront=*/1));
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    rewriter.replaceOp(op, res);
    return success();
  }
};
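
// For example (illustrative only), with a single offset dimension,
//
//   vector.extract_strided_slice %v
//       {offsets = [2], sizes = [4], strides = [1]}
//     : vector<8xf32> to vector<4xf32>
//
// is rewritten as a vector.shuffle of %v with itself using the mask
// [2, 3, 4, 5], which later lowers to a single llvm.shufflevector.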

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    bool reassociateFPReductions, bool enableIndexOptimizations) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorExtractStridedSliceOpConversion>(ctx);
  patterns.insert<VectorReductionOpConversion>(
      ctx, converter, reassociateFPReductions);
  patterns.insert<VectorCreateMaskOpConversion,
                  VectorTransferConversion<TransferReadOp>,
                  VectorTransferConversion<TransferWriteOp>>(
      ctx, converter, enableIndexOptimizations);
  patterns
      .insert<VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTypeCastOpConversion,
              VectorMaskedLoadOpConversion,
              VectorMaskedStoreOpConversion,
              VectorGatherOpConversion,
              VectorScatterOpConversion,
              VectorExpandLoadOpConversion,
              VectorCompressStoreOpConversion>(ctx, converter);
  // clang-format on
}
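
// A minimal usage sketch (hedged; this mirrors what a conversion pass is
// expected to do, and the target/variable names are illustrative):
//
//   LLVMTypeConverter converter(ctx);
//   OwningRewritePatternList patterns;
//   populateVectorToLLVMConversionPatterns(
//       converter, patterns, /*reassociateFPReductions=*/false,
//       /*enableIndexOptimizations=*/true);
//   LLVMConversionTarget target(*ctx);
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();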

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.insert<VectorMatmulOpConversion>(ctx, converter);
  patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
}