1 //===- VectorToROCDL.cpp - Vector to ROCDL lowering passes ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to generate ROCDLIR operations for higher-level
10 // Vector operations.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
15
16 #include "../PassDetail.h"
17 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
18 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
19 #include "mlir/Dialect/GPU/GPUDialect.h"
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h"
23 #include "mlir/Dialect/Vector/VectorOps.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Transforms/DialectConversion.h"
26
27 using namespace mlir;
28 using namespace mlir::vector;
29
replaceTransferOpWithMubuf(ConversionPatternRewriter & rewriter,ArrayRef<Value> operands,LLVMTypeConverter & typeConverter,Location loc,TransferReadOp xferOp,LLVM::LLVMType & vecTy,Value & dwordConfig,Value & vindex,Value & offsetSizeInBytes,Value & glc,Value & slc)30 static LogicalResult replaceTransferOpWithMubuf(
31 ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
32 LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp,
33 LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex,
34 Value &offsetSizeInBytes, Value &glc, Value &slc) {
35 rewriter.replaceOpWithNewOp<ROCDL::MubufLoadOp>(
36 xferOp, vecTy, dwordConfig, vindex, offsetSizeInBytes, glc, slc);
37 return success();
38 }
39
replaceTransferOpWithMubuf(ConversionPatternRewriter & rewriter,ArrayRef<Value> operands,LLVMTypeConverter & typeConverter,Location loc,TransferWriteOp xferOp,LLVM::LLVMType & vecTy,Value & dwordConfig,Value & vindex,Value & offsetSizeInBytes,Value & glc,Value & slc)40 static LogicalResult replaceTransferOpWithMubuf(
41 ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
42 LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp,
43 LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex,
44 Value &offsetSizeInBytes, Value &glc, Value &slc) {
45 auto adaptor = TransferWriteOpAdaptor(operands);
46 rewriter.replaceOpWithNewOp<ROCDL::MubufStoreOp>(xferOp, adaptor.vector(),
47 dwordConfig, vindex,
48 offsetSizeInBytes, glc, slc);
49 return success();
50 }
51
52 namespace {
53 /// Conversion pattern that converts a 1-D vector transfer read/write.
54 /// Note that this conversion pass only converts vector x2 or x4 f32
55 /// types. For unsupported cases, they will fall back to the vector to
56 /// llvm conversion pattern.
57 template <typename ConcreteOp>
58 class VectorTransferConversion : public ConvertToLLVMPattern {
59 public:
VectorTransferConversion(MLIRContext * context,LLVMTypeConverter & typeConv)60 explicit VectorTransferConversion(MLIRContext *context,
61 LLVMTypeConverter &typeConv)
62 : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
63 typeConv) {}
64
65 LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const66 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
67 ConversionPatternRewriter &rewriter) const override {
68 auto xferOp = cast<ConcreteOp>(op);
69 typename ConcreteOp::Adaptor adaptor(operands);
70
71 if (xferOp.getVectorType().getRank() > 1 ||
72 llvm::size(xferOp.indices()) == 0)
73 return failure();
74
75 if (!xferOp.permutation_map().isMinorIdentity())
76 return failure();
77
78 // Have it handled in vector->llvm conversion pass.
79 if (!xferOp.isMaskedDim(0))
80 return failure();
81
82 auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); };
83 LLVM::LLVMType vecTy =
84 toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
85 unsigned vecWidth = vecTy.getVectorNumElements();
86 Location loc = op->getLoc();
87
88 // The backend result vector scalarization have trouble scalarize
89 // <1 x ty> result, exclude the x1 width from the lowering.
90 if (vecWidth != 2 && vecWidth != 4)
91 return failure();
92
93 // Obtain dataPtr and elementType from the memref.
94 MemRefType memRefType = xferOp.getMemRefType();
95 // MUBUF instruction operate only on addresspace 0(unified) or 1(global)
96 // In case of 3(LDS): fall back to vector->llvm pass
97 // In case of 5(VGPR): wrong
98 if ((memRefType.getMemorySpace() != 0) &&
99 (memRefType.getMemorySpace() != 1))
100 return failure();
101
102 // Note that the dataPtr starts at the offset address specified by
103 // indices, so no need to calculate offset size in bytes again in
104 // the MUBUF instruction.
105 Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
106 adaptor.indices(), rewriter);
107
108 // 1. Create and fill a <4 x i32> dwordConfig with:
109 // 1st two elements holding the address of dataPtr.
110 // 3rd element: -1.
111 // 4th element: 0x27000.
112 SmallVector<int32_t, 4> constConfigAttr{0, 0, -1, 0x27000};
113 Type i32Ty = rewriter.getIntegerType(32);
114 VectorType i32Vecx4 = VectorType::get(4, i32Ty);
115 Value constConfig = rewriter.create<LLVM::ConstantOp>(
116 loc, toLLVMTy(i32Vecx4),
117 DenseElementsAttr::get(i32Vecx4, ArrayRef<int32_t>(constConfigAttr)));
118
119 // Treat first two element of <4 x i32> as i64, and save the dataPtr
120 // to it.
121 Type i64Ty = rewriter.getIntegerType(64);
122 Value i64x2Ty = rewriter.create<LLVM::BitcastOp>(
123 loc,
124 LLVM::LLVMType::getVectorTy(
125 toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), 2),
126 constConfig);
127 Value dataPtrAsI64 = rewriter.create<LLVM::PtrToIntOp>(
128 loc, toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), dataPtr);
129 Value zero = createIndexConstant(rewriter, loc, 0);
130 Value dwordConfig = rewriter.create<LLVM::InsertElementOp>(
131 loc,
132 LLVM::LLVMType::getVectorTy(
133 toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), 2),
134 i64x2Ty, dataPtrAsI64, zero);
135 dwordConfig =
136 rewriter.create<LLVM::BitcastOp>(loc, toLLVMTy(i32Vecx4), dwordConfig);
137
138 // 2. Rewrite op as a buffer read or write.
139 Value int1False = rewriter.create<LLVM::ConstantOp>(
140 loc, toLLVMTy(rewriter.getIntegerType(1)),
141 rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
142 Value int32Zero = rewriter.create<LLVM::ConstantOp>(
143 loc, toLLVMTy(i32Ty),
144 rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0));
145 return replaceTransferOpWithMubuf(
146 rewriter, operands, *getTypeConverter(), loc, xferOp, vecTy,
147 dwordConfig, int32Zero, int32Zero, int1False, int1False);
148 }
149 };
150 } // end anonymous namespace
151
populateVectorToROCDLConversionPatterns(LLVMTypeConverter & converter,OwningRewritePatternList & patterns)152 void mlir::populateVectorToROCDLConversionPatterns(
153 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
154 MLIRContext *ctx = converter.getDialect()->getContext();
155 patterns.insert<VectorTransferConversion<TransferReadOp>,
156 VectorTransferConversion<TransferWriteOp>>(ctx, converter);
157 }
158
159 namespace {
160 struct LowerVectorToROCDLPass
161 : public ConvertVectorToROCDLBase<LowerVectorToROCDLPass> {
162 void runOnOperation() override;
163 };
164 } // namespace
165
runOnOperation()166 void LowerVectorToROCDLPass::runOnOperation() {
167 LLVMTypeConverter converter(&getContext());
168 OwningRewritePatternList patterns;
169
170 populateVectorToROCDLConversionPatterns(converter, patterns);
171 populateStdToLLVMConversionPatterns(converter, patterns);
172
173 LLVMConversionTarget target(getContext());
174 target.addLegalDialect<ROCDL::ROCDLDialect>();
175
176 if (failed(
177 applyPartialConversion(getOperation(), target, std::move(patterns))))
178 signalPassFailure();
179 }
180
181 std::unique_ptr<OperationPass<ModuleOp>>
createConvertVectorToROCDLPass()182 mlir::createConvertVectorToROCDLPass() {
183 return std::make_unique<LowerVectorToROCDLPass>();
184 }
185