• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- VectorToROCDL.cpp - Vector to ROCDL lowering passes ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to generate ROCDLIR operations for higher-level
10 // Vector operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
15 
16 #include "../PassDetail.h"
17 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
18 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
19 #include "mlir/Dialect/GPU/GPUDialect.h"
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h"
23 #include "mlir/Dialect/Vector/VectorOps.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Transforms/DialectConversion.h"
26 
27 using namespace mlir;
28 using namespace mlir::vector;
29 
replaceTransferOpWithMubuf(ConversionPatternRewriter & rewriter,ArrayRef<Value> operands,LLVMTypeConverter & typeConverter,Location loc,TransferReadOp xferOp,LLVM::LLVMType & vecTy,Value & dwordConfig,Value & vindex,Value & offsetSizeInBytes,Value & glc,Value & slc)30 static LogicalResult replaceTransferOpWithMubuf(
31     ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
32     LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp,
33     LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex,
34     Value &offsetSizeInBytes, Value &glc, Value &slc) {
35   rewriter.replaceOpWithNewOp<ROCDL::MubufLoadOp>(
36       xferOp, vecTy, dwordConfig, vindex, offsetSizeInBytes, glc, slc);
37   return success();
38 }
39 
replaceTransferOpWithMubuf(ConversionPatternRewriter & rewriter,ArrayRef<Value> operands,LLVMTypeConverter & typeConverter,Location loc,TransferWriteOp xferOp,LLVM::LLVMType & vecTy,Value & dwordConfig,Value & vindex,Value & offsetSizeInBytes,Value & glc,Value & slc)40 static LogicalResult replaceTransferOpWithMubuf(
41     ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
42     LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp,
43     LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex,
44     Value &offsetSizeInBytes, Value &glc, Value &slc) {
45   auto adaptor = TransferWriteOpAdaptor(operands);
46   rewriter.replaceOpWithNewOp<ROCDL::MubufStoreOp>(xferOp, adaptor.vector(),
47                                                    dwordConfig, vindex,
48                                                    offsetSizeInBytes, glc, slc);
49   return success();
50 }
51 
52 namespace {
53 /// Conversion pattern that converts a 1-D vector transfer read/write.
54 /// Note that this conversion pass only converts vector x2 or x4 f32
55 /// types. For unsupported cases, they will fall back to the vector to
56 /// llvm conversion pattern.
57 template <typename ConcreteOp>
58 class VectorTransferConversion : public ConvertToLLVMPattern {
59 public:
VectorTransferConversion(MLIRContext * context,LLVMTypeConverter & typeConv)60   explicit VectorTransferConversion(MLIRContext *context,
61                                     LLVMTypeConverter &typeConv)
62       : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
63                              typeConv) {}
64 
65   LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const66   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
67                   ConversionPatternRewriter &rewriter) const override {
68     auto xferOp = cast<ConcreteOp>(op);
69     typename ConcreteOp::Adaptor adaptor(operands);
70 
71     if (xferOp.getVectorType().getRank() > 1 ||
72         llvm::size(xferOp.indices()) == 0)
73       return failure();
74 
75     if (!xferOp.permutation_map().isMinorIdentity())
76       return failure();
77 
78     // Have it handled in vector->llvm conversion pass.
79     if (!xferOp.isMaskedDim(0))
80       return failure();
81 
82     auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); };
83     LLVM::LLVMType vecTy =
84         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
85     unsigned vecWidth = vecTy.getVectorNumElements();
86     Location loc = op->getLoc();
87 
88     // The backend result vector scalarization have trouble scalarize
89     // <1 x ty> result, exclude the x1 width from the lowering.
90     if (vecWidth != 2 && vecWidth != 4)
91       return failure();
92 
93     // Obtain dataPtr and elementType from the memref.
94     MemRefType memRefType = xferOp.getMemRefType();
95     // MUBUF instruction operate only on addresspace 0(unified) or 1(global)
96     // In case of 3(LDS): fall back to vector->llvm pass
97     // In case of 5(VGPR): wrong
98     if ((memRefType.getMemorySpace() != 0) &&
99         (memRefType.getMemorySpace() != 1))
100       return failure();
101 
102     // Note that the dataPtr starts at the offset address specified by
103     // indices, so no need to calculate offset size in bytes again in
104     // the MUBUF instruction.
105     Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
106                                          adaptor.indices(), rewriter);
107 
108     // 1. Create and fill a <4 x i32> dwordConfig with:
109     //    1st two elements holding the address of dataPtr.
110     //    3rd element: -1.
111     //    4th element: 0x27000.
112     SmallVector<int32_t, 4> constConfigAttr{0, 0, -1, 0x27000};
113     Type i32Ty = rewriter.getIntegerType(32);
114     VectorType i32Vecx4 = VectorType::get(4, i32Ty);
115     Value constConfig = rewriter.create<LLVM::ConstantOp>(
116         loc, toLLVMTy(i32Vecx4),
117         DenseElementsAttr::get(i32Vecx4, ArrayRef<int32_t>(constConfigAttr)));
118 
119     // Treat first two element of <4 x i32> as i64, and save the dataPtr
120     // to it.
121     Type i64Ty = rewriter.getIntegerType(64);
122     Value i64x2Ty = rewriter.create<LLVM::BitcastOp>(
123         loc,
124         LLVM::LLVMType::getVectorTy(
125             toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), 2),
126         constConfig);
127     Value dataPtrAsI64 = rewriter.create<LLVM::PtrToIntOp>(
128         loc, toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), dataPtr);
129     Value zero = createIndexConstant(rewriter, loc, 0);
130     Value dwordConfig = rewriter.create<LLVM::InsertElementOp>(
131         loc,
132         LLVM::LLVMType::getVectorTy(
133             toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), 2),
134         i64x2Ty, dataPtrAsI64, zero);
135     dwordConfig =
136         rewriter.create<LLVM::BitcastOp>(loc, toLLVMTy(i32Vecx4), dwordConfig);
137 
138     // 2. Rewrite op as a buffer read or write.
139     Value int1False = rewriter.create<LLVM::ConstantOp>(
140         loc, toLLVMTy(rewriter.getIntegerType(1)),
141         rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
142     Value int32Zero = rewriter.create<LLVM::ConstantOp>(
143         loc, toLLVMTy(i32Ty),
144         rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0));
145     return replaceTransferOpWithMubuf(
146         rewriter, operands, *getTypeConverter(), loc, xferOp, vecTy,
147         dwordConfig, int32Zero, int32Zero, int1False, int1False);
148   }
149 };
150 } // end anonymous namespace
151 
populateVectorToROCDLConversionPatterns(LLVMTypeConverter & converter,OwningRewritePatternList & patterns)152 void mlir::populateVectorToROCDLConversionPatterns(
153     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
154   MLIRContext *ctx = converter.getDialect()->getContext();
155   patterns.insert<VectorTransferConversion<TransferReadOp>,
156                   VectorTransferConversion<TransferWriteOp>>(ctx, converter);
157 }
158 
159 namespace {
160 struct LowerVectorToROCDLPass
161     : public ConvertVectorToROCDLBase<LowerVectorToROCDLPass> {
162   void runOnOperation() override;
163 };
164 } // namespace
165 
runOnOperation()166 void LowerVectorToROCDLPass::runOnOperation() {
167   LLVMTypeConverter converter(&getContext());
168   OwningRewritePatternList patterns;
169 
170   populateVectorToROCDLConversionPatterns(converter, patterns);
171   populateStdToLLVMConversionPatterns(converter, patterns);
172 
173   LLVMConversionTarget target(getContext());
174   target.addLegalDialect<ROCDL::ROCDLDialect>();
175 
176   if (failed(
177           applyPartialConversion(getOperation(), target, std::move(patterns))))
178     signalPassFailure();
179 }
180 
181 std::unique_ptr<OperationPass<ModuleOp>>
createConvertVectorToROCDLPass()182 mlir::createConvertVectorToROCDLPass() {
183   return std::make_unique<LowerVectorToROCDLPass>();
184 }
185