//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto vectorType = into.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<InsertOp>(loc, from, into, offset);
  return rewriter.create<vector::InsertElementOp>(
      loc, vectorType, from, into,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto vectorType = vector.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<ExtractOp>(loc, vector, offset);
  return rewriter.create<vector::ExtractElementOp>(
      loc, vectorType.getElementType(), vector,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
// TODO: Better support for attribute subtype forwarding + slicing.
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

// Helper that returns a vector comparison that constructs a mask:
//   mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
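//
// For example (an illustrative sketch, not IR emitted verbatim): with n = 4
// lanes, offset o = 1 and bound b = 3, the mask computes as
//   [0,1,2,3] + [1,1,1,1] < [3,3,3,3] = [1,1,0,0],
// i.e. only the first two lanes are active.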
static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
                                   Operation *op, bool enableIndexOptimizations,
                                   int64_t dim, Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If we can assume all indices fit in 32-bit, we perform the vector
  // comparison in 32-bit to get a higher degree of SIMD parallelism.
  // Otherwise we perform the vector comparison using 64-bit indices.
  Value indices;
  Type idxType;
  if (enableIndexOptimizations) {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI32VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
    idxType = rewriter.getI32Type();
  } else {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI64VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
    idxType = rewriter.getI64Type();
  }
  // Add in an offset if requested.
  if (off) {
    Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
}

// Helper that returns data layout alignment of an operation with memref.
template <typename T>
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
                                 unsigned &align) {
  Type elementTy =
      typeConverter.convertType(op.getMemRefType().getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
                                     typeConverter.getDataLayout());
  return success();
}

// Helper that returns the base address of a memref.
static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
                             Value memref, MemRefType memRefType, Value &base) {
  // Inspect stride and offset structure.
  //
  // TODO: flat memory only for now, generalize
  //
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
      offset != 0 || memRefType.getMemorySpace() != 0)
    return failure();
  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
  return success();
}

// Helper that returns a pointer given a memref base.
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
                                Location loc, Value memref,
                                MemRefType memRefType, Value &ptr) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
  return success();
}

// Helper that returns a bit-casted pointer given a memref base.
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
                                Location loc, Value memref,
                                MemRefType memRefType, Type type, Value &ptr) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = type.template cast<LLVM::LLVMType>().getPointerTo();
  base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
  return success();
}
// Helper that returns a vector of pointers, given a memref base and an index
// vector.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value indices,
                                    MemRefType memRefType, VectorType vType,
                                    Type iType, Value &ptrs) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferReadOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferReadOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);

  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
  if (!vecTy)
    return failure();

  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferWriteOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                             align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferWriteOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
  return success();
}

static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
}

static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
}

namespace {

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
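///
/// For example (a sketch; shapes chosen for illustration), a 2x3 * 3x4
/// multiply carried on flattened 1-D vectors:
/// ```
///   %c = vector.matrix_multiply %a, %b
///          {lhs_rows = 2 : i32, lhs_columns = 3 : i32, rhs_columns = 4 : i32}
///        : (vector<6xf32>, vector<12xf32>) -> vector<8xf32>
/// ```
/// becomes an llvm.intr.matrix.multiply with the same attributes on the
/// converted operand/result types.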
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMatmulOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto matmulOp = cast<vector::MatmulOp>(op);
    auto adaptor = vector::MatmulOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        op, typeConverter->convertType(matmulOp.res().getType()), adaptor.lhs(),
        adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
        matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
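///
/// For example (a sketch), transposing a 2x4 matrix stored as a flat vector:
/// ```
///   %t = vector.flat_transpose %m {rows = 2 : i32, columns = 4 : i32}
///        : vector<8xf32> -> vector<8xf32>
/// ```
/// becomes an llvm.intr.matrix.transpose with the same rows/columns
/// attributes.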
class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFlatTransposeOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto transOp = cast<vector::FlatTransposeOp>(op);
    auto adaptor = vector::FlatTransposeOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Conversion pattern for a vector.maskedload.
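/// The base is resolved to an LLVM pointer bit-cast to a vector pointer, and
/// the op maps 1-1 onto llvm.intr.masked.load: inactive lanes take their
/// value from the pass-through operand. A sketch (16-lane shape assumed):
/// ```
///   %l = vector.maskedload %base, %mask, %pass_thru
///        : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
/// ```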
class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMaskedLoadOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto load = cast<vector::MaskedLoadOp>(op);
    auto adaptor = vector::MaskedLoadOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), load, align)))
      return failure();

    auto vtype = typeConverter->convertType(load.getResultVectorType());
    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
                          vtype, ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
        load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.maskedstore.
class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMaskedStoreOpConversion(MLIRContext *context,
                                         LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto store = cast<vector::MaskedStoreOp>(op);
    auto adaptor = vector::MaskedStoreOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), store, align)))
      return failure();

    auto vtype = typeConverter->convertType(store.getValueVectorType());
    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
                          vtype, ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
        store, adaptor.value(), ptr, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.gather.
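/// Each active lane i reads base[indices[i]]; the lowering materializes a
/// vector of pointers with a single GEP over the index vector and emits an
/// llvm.intr.masked.gather. A sketch (16-lane shape assumed):
/// ```
///   %g = vector.gather %base, %indices, %mask, %pass_thru
///        : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>)
///          -> vector<16xf32>
/// ```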
class VectorGatherOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorGatherOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto gather = cast<vector::GatherOp>(op);
    auto adaptor = vector::GatherOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), gather, align)))
      return failure();

    // Get index ptrs.
    VectorType vType = gather.getResultVectorType();
    Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              gather.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
        adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorScatterOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto scatter = cast<vector::ScatterOp>(op);
    auto adaptor = vector::ScatterOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), scatter, align)))
      return failure();

    // Get index ptrs.
    VectorType vType = scatter.getValueVectorType();
    Type iType = scatter.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              scatter.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.value(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
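/// An expanding load reads a *compressed* run of consecutive elements from
/// the base and distributes them over the active lanes of the result
/// (inactive lanes take the pass-through value). It maps 1-1 onto
/// llvm.intr.masked.expandload which, unlike llvm.intr.masked.load, carries
/// no alignment attribute.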
class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExpandLoadOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto expand = cast<vector::ExpandLoadOp>(op);
    auto adaptor = vector::ExpandLoadOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
                          ptr)))
      return failure();

    auto vType = expand.getResultVectorType();
    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        op, typeConverter->convertType(vType), ptr, adaptor.mask(),
        adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorCompressStoreOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto compress = cast<vector::CompressStoreOp>(op);
    auto adaptor = vector::CompressStoreOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(),
                          compress.getMemRefType(), ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        op, adaptor.value(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
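/// Each reduction kind maps onto the corresponding llvm.intr.vector.reduce.*
/// intrinsic; for integers, "min"/"max" additionally select the signed or
/// unsigned variant based on the element type. A sketch:
/// ```
///   %0 = vector.reduction "add", %v : vector<16xi32> into i32
/// ```
/// becomes an LLVM::vector_reduce_add on the converted operand.
/// Floating-point "add"/"mul" also thread through an optional accumulator
/// operand and the `reassociateFPReductions` flag.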
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorReductionOpConversion(MLIRContext *context,
                                       LLVMTypeConverter &typeConverter,
                                       bool reassociateFPRed)
      : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
                             typeConverter),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<vector::ReductionOp>(op);
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
            op, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
            op, llvmType, operands[0]);
      else if (kind == "min" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            op, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            op, llvmType, operands[0]);
      else if (kind == "max" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            op, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
            op, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
            op, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max.
    if (kind == "add") {
      // Optional accumulator (or zero).
      Value acc = operands.size() > 1 ? operands[1]
                                      : rewriter.create<LLVM::ConstantOp>(
                                            op->getLoc(), llvmType,
                                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          op, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "mul") {
      // Optional accumulator (or one).
      Value acc = operands.size() > 1
                      ? operands[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            op->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          op, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "min")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(op, llvmType,
                                                            operands[0]);
    else if (kind == "max")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(op, llvmType,
                                                            operands[0]);
    else
      return failure();
    return success();
  }

private:
  const bool reassociateFPReductions;
};

/// Conversion pattern for a vector.create_mask (1-D only).
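/// Lowered via buildVectorComparison, e.g. (a sketch):
/// ```
///   %m = vector.create_mask %b : vector<4xi1>
/// ```
/// yields the comparison [0,1,2,3] < [b,b,b,b], i.e. the first min(b,4)
/// lanes are set.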
class VectorCreateMaskOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter,
                                        bool enableIndexOpt)
      : ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context,
                             typeConverter),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = op->getResult(0).getType().cast<VectorType>();
    int64_t rank = dstType.getRank();
    if (rank == 1) {
      rewriter.replaceOp(
          op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
                                    dstType.getDimSize(0), operands[0]));
      return success();
    }
    return failure();
  }

private:
  const bool enableIndexOptimizations;
};
/// Conversion pattern for a vector.shuffle.
class VectorShuffleOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorShuffleOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto shuffleOp = cast<vector::ShuffleOp>(op);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(op, shuffle);
      return success();
    }

    // For all other cases, insert the individual values one by one.
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 llvmType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(op, insert);
    return success();
  }
};

/// Conversion pattern for a vector.extractelement.
class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractElementOpConversion(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpAdaptor(operands);
    auto extractEltOp = cast<vector::ExtractElementOp>(op);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        op, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

/// Conversion pattern for a vector.extract.
class VectorExtractOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto extractOp = cast<vector::ExtractOp>(op);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of a vector from an array (only requires
    // extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(op, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of an element from the 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(op, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///   vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///   llvm.intr.fmuladd %va, %va, %va:
///     (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
///     -> !llvm<"<8 x float>">
/// ```
class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFMAOp1DConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpAdaptor(operands);
    vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};

/// Conversion pattern for a vector.insertelement.
class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertElementOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpAdaptor(operands);
    auto insertEltOp = cast<vector::InsertElementOp>(op);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
    return success();
  }
};

/// Conversion pattern for a vector.insert.
class VectorInsertOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto insertOp = cast<vector::InsertOp>(op);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires
    // insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(op, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(op, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///   %r = splat %f0: vector<2x4xf32>
///   %va = vector.extractvalue %a[0] : vector<2x4xf32>
///   %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///   %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///   %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///   %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///   %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
///   %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
///   %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
///   %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///   %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///   // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
// another pattern that converts InsertStridedSlice for operands of the same
// rank.
//
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. In this case:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//      destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
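//
// For example (an illustrative sketch): inserting %src : vector<4xf32> into
// %dst : vector<2x3x4xf32> at offsets [1, 0, 0] extracts the vector<4xf32>
// at %dst[1, 0], creates a same-rank InsertStridedSlice of %src into it at
// offset [0], and re-inserts the result into %dst at [1, 0].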
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. In this case, we reduce the problem rank by rank:
//   1. each slice of the source along the most major dimension is extracted;
//   2. if the slice is itself a vector, the matching slice is extracted from
//      the destination and a new, rank-reduced InsertStridedSlice op inserts
//      the source slice into it;
//   3. the resulting slice is inserted back into the result at the proper
//      offset.
// The rank-reduced InsertStridedSlice ops from step 2. are picked up again by
// this same pattern; the recursion is bounded because the rank strictly
// decreases.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
      : OpRewritePattern<InsertStridedSliceOp>(ctx) {
    // This pattern creates recursive InsertStridedSliceOp, but the recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. extract the proper subvector (or element) from source
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from
        // destination. Otherwise we are at the element level and no need to
        // recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /*dropFront=*/1),
            getI64SubArray(op.strides(), /*dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
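///
/// For example, memref<4x8xf32> with the identity layout has strides [8, 1]
/// and is contiguous, whereas a layout yielding strides [16, 1] for the same
/// shape leaves padding between rows and is rejected.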
static llvm::Optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return None;
  if (!strides.empty() && strides.back() != 1)
    return None;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getAffineMaps().empty() ||
      memRefType.getAffineMaps().front().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from the shapes. This can only
  // ever work in static cases because MemRefType is underspecified to
  // represent contiguous dynamic shapes in other ways than with just
  // empty/identity layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
      return None;
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return None;
  }
  return strides;
}

/// Conversion pattern for a vector.type_cast.
class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTypeCastOpConversion(MLIRContext *context,
                                      LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMType>();
    if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(op, {desc});
    return success();
  }
};

/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// sequence of:
/// 1. Get the source/dst address as an LLVM vector pointer.
/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 4. Create a mask where offsetVector is compared against memref upper bound.
/// 5. Rewrite op as a masked read or write.
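///
/// E.g. (an illustrative sketch) a masked read of vector<8xf32> at index %i
/// from a memref of size %n becomes an llvm.intr.masked.load through a
/// bit-cast vector pointer, with mask
///   [ i + 0 .. i + 7 ] < [ n .. n ]
/// and the padding value splatted into the pass-through operand.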
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv,
                                    bool enableIndexOpt)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       op->getContext()))
      return failure();
    // Only contiguous source memrefs supported atm.
    auto strides = computeContiguousStrides(xferOp.getMemRefType());
    if (!strides)
      return failure();

    auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); };

    Location loc = op->getLoc();
    MemRefType memRefType = xferOp.getMemRefType();

    if (auto memrefVectorElementType =
            memRefType.getElementType().dyn_cast<VectorType>()) {
      // Memref has vector element type.
      if (memrefVectorElementType.getElementType() !=
          xferOp.getVectorType().getElementType())
        return failure();
#ifndef NDEBUG
      // Check that the memref vector type is a suffix of `vectorType`.
      unsigned memrefVecEltRank = memrefVectorElementType.getRank();
      unsigned resultVecRank = xferOp.getVectorType().getRank();
      assert(memrefVecEltRank <= resultVecRank);
      // TODO: Move this to isSuffix in Vector/Utils.h.
      unsigned rankOffset = resultVecRank - memrefVecEltRank;
      auto memrefVecEltShape = memrefVectorElementType.getShape();
      auto resultVecShape = xferOp.getVectorType().getShape();
      for (unsigned i = 0; i < memrefVecEltRank; ++i)
        assert(memrefVecEltShape[i] == resultVecShape[rankOffset + i] &&
1260 "memref vector element shape should match suffix of vector "
1261 "result shape.");
1262 #endif // ifndef NDEBUG
1263 }
1264
1265 // 1. Get the source/dst address as an LLVM vector pointer.
    //    The vector pointer is always in address space 0, so an addrspacecast
    //    is needed when the source/dst memref lives in a different address
    //    space.
1269 // TODO: support alignment when possible.
1270 Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
1271 adaptor.indices(), rewriter);
1272 auto vecTy =
1273 toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
1274 Value vectorDataPtr;
1275 if (memRefType.getMemorySpace() == 0)
1276 vectorDataPtr =
1277 rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
1278 else
1279 vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
1280 loc, vecTy.getPointerTo(), dataPtr);
1281
1282 if (!xferOp.isMaskedDim(0))
1283 return replaceTransferOpWithLoadOrStore(
1284 rewriter, *getTypeConverter(), loc, xferOp, operands, vectorDataPtr);
1285
1286 // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
1287 // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1288 // 4. Let dim the memref dimension, compute the vector comparison mask:
1289 // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
1290 //
1291 // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1292 // dimensions here.
    unsigned vecWidth = vecTy.getVectorNumElements();
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value off = xferOp.indices()[lastIndex];
    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
    Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations,
                                       vecWidth, dim, &off);

    // 5. Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, *getTypeConverter(), loc,
                                       xferOp, operands, vectorDataPtr, mask);
  }

private:
  const bool enableIndexOptimizations;
};

class VectorPrintOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorPrintOpConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
                             typeConverter) {}

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of vectors of any shape and rank. Due to the full unrolling,
  // however, this approach is less suited for very large vectors.
  //
  // TODO: rely solely on libc in the future? something else?
  //
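  // For example, printing %v : vector<2xf32> unrolls into the call sequence
  //   printOpen(); printF32(%v[0]); printComma(); printF32(%v[1]);
  //   printClose(); printNewline();
  //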
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto printOp = cast<vector::PrintOp>(op);
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer = getPrintFloat(op);
    } else if (eltType.isF64()) {
      printer = getPrintDouble(op);
    } else if (eltType.isIndex()) {
      printer = getPrintU64(op);
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = getPrintU64(op);
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = getPrintI64(op);
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank,
              conversion);
    emitCall(rewriter, op->getLoc(), getPrintNewline(op));
    rewriter.eraseOp(op);
    return success();
  }

private:
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

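  // Recursively unrolls one rank of the printed value: at rank 0 a single
  // (possibly sign-/zero-extended) scalar is printed; at higher ranks an
  // opening bracket, the comma-separated sub-vectors, and a closing bracket
  // are emitted.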
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank, PrintConversion conversion) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<ZeroExtendIOp>(
            loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<SignExtendIOp>(
            loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc, getPrintOpen(op));
    Operation *printComma = getPrintComma(op);
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter->convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, getPrintClose(op));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(),
                                  rewriter.getSymbolRefAttr(ref), params);
  }

  // Helper for printer method declaration (first hit) and lookup.
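  // On first use this materializes a module-level declaration along the
  // lines of (illustrative sketch of the printed form):
  //   llvm.func @printF32(!llvm.float)
  // Subsequent calls find the existing declaration via symbol lookup.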
  static Operation *getPrint(Operation *op, StringRef name,
                             ArrayRef<LLVM::LLVMType> params) {
    auto module = op->getParentOfType<ModuleOp>();
    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
    if (func)
      return func;
    OpBuilder moduleBuilder(module.getBodyRegion());
    return moduleBuilder.create<LLVM::LLVMFuncOp>(
        op->getLoc(), name,
        LLVM::LLVMType::getFunctionTy(
            LLVM::LLVMType::getVoidTy(op->getContext()), params,
            /*isVarArg=*/false));
  }

  // Helpers for method names.
  Operation *getPrintI64(Operation *op) const {
    return getPrint(op, "printI64",
                    LLVM::LLVMType::getInt64Ty(op->getContext()));
  }
  Operation *getPrintU64(Operation *op) const {
    return getPrint(op, "printU64",
                    LLVM::LLVMType::getInt64Ty(op->getContext()));
  }
  Operation *getPrintFloat(Operation *op) const {
    return getPrint(op, "printF32",
                    LLVM::LLVMType::getFloatTy(op->getContext()));
  }
  Operation *getPrintDouble(Operation *op) const {
    return getPrint(op, "printF64",
                    LLVM::LLVMType::getDoubleTy(op->getContext()));
  }
  Operation *getPrintOpen(Operation *op) const {
    return getPrint(op, "printOpen", {});
  }
  Operation *getPrintClose(Operation *op) const {
    return getPrint(op, "printClose", {});
  }
  Operation *getPrintComma(Operation *op) const {
    return getPrint(op, "printComma", {});
  }
  Operation *getPrintNewline(Operation *op) const {
    return getPrint(op, "printNewline", {});
  }
};

/// Progressive lowering of ExtractStridedSliceOp to either:
///   1. express a single-offset extract as a direct shuffle.
///   2. extract + lower-rank strided_slice + insert for the n-D case.
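/// For example, extracting a vector<2xf32> slice at offset 2 with stride 1
/// from a vector<4xf32> becomes a single shuffle with mask [2, 3].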
class VectorExtractStridedSliceOpConversion
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
      : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
    // This pattern creates recursive ExtractStridedSliceOp, but the recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();

    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    auto elemType = dstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    // Single offset can be more efficiently shuffled.
    if (op.offsets().getValue().size() == 1) {
      SmallVector<int64_t, 4> offsets;
      offsets.reserve(size);
      for (int64_t off = offset, e = offset + size * stride; off < e;
           off += stride)
        offsets.push_back(off);
      rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
                                             op.vector(),
                                             rewriter.getI64ArrayAttr(offsets));
      return success();
    }

    // Extract/insert on a lower ranked extract strided slice op.
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      Value one = extractOne(rewriter, loc, op.vector(), off);
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, one, getI64SubArray(op.offsets(), /*dropFront=*/1),
          getI64SubArray(op.sizes(), /*dropFront=*/1),
          getI64SubArray(op.strides(), /*dropFront=*/1));
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    rewriter.replaceOp(op, res);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
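/// A conversion pass would typically seed its pattern list with these
/// patterns before running the conversion, e.g. (minimal sketch, assuming a
/// pre-configured `converter`):
///   OwningRewritePatternList patterns;
///   populateVectorToLLVMConversionPatterns(
///       converter, patterns, /*reassociateFPReductions=*/false,
///       /*enableIndexOptimizations=*/true);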
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    bool reassociateFPReductions, bool enableIndexOptimizations) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorExtractStridedSliceOpConversion>(ctx);
  patterns.insert<VectorReductionOpConversion>(
      ctx, converter, reassociateFPReductions);
  patterns.insert<VectorCreateMaskOpConversion,
                  VectorTransferConversion<TransferReadOp>,
                  VectorTransferConversion<TransferWriteOp>>(
      ctx, converter, enableIndexOptimizations);
  patterns
      .insert<VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTypeCastOpConversion,
              VectorMaskedLoadOpConversion,
              VectorMaskedStoreOpConversion,
              VectorGatherOpConversion,
              VectorScatterOpConversion,
              VectorExpandLoadOpConversion,
              VectorCompressStoreOpConversion>(ctx, converter);
  // clang-format on
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.insert<VectorMatmulOpConversion>(ctx, converter);
  patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
}