//===- LowerABIAttributesPass.cpp - Lower SPIR-V ABI attributes -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to lower attributes that specify the shader ABI
10 // for the functions in the generated SPIR-V module.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "PassDetail.h"
15 #include "mlir/Dialect/SPIRV/LayoutUtils.h"
16 #include "mlir/Dialect/SPIRV/Passes.h"
17 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
19 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/SetVector.h"
22
23 using namespace mlir;
24
25 /// Creates a global variable for an argument based on the ABI info.
26 static spirv::GlobalVariableOp
createGlobalVarForEntryPointArgument(OpBuilder & builder,spirv::FuncOp funcOp,unsigned argIndex,spirv::InterfaceVarABIAttr abiInfo)27 createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
28 unsigned argIndex,
29 spirv::InterfaceVarABIAttr abiInfo) {
30 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
31 if (!spirvModule)
32 return nullptr;
33
34 OpBuilder::InsertionGuard moduleInsertionGuard(builder);
35 builder.setInsertionPoint(funcOp.getOperation());
36 std::string varName =
37 funcOp.getName().str() + "_arg_" + std::to_string(argIndex);
38
39 // Get the type of variable. If this is a scalar/vector type and has an ABI
40 // info create a variable of type !spv.ptr<!spv.struct<elementType>>. If not
41 // it must already be a !spv.ptr<!spv.struct<...>>.
42 auto varType = funcOp.getType().getInput(argIndex);
43 if (varType.cast<spirv::SPIRVType>().isScalarOrVector()) {
44 auto storageClass = abiInfo.getStorageClass();
45 if (!storageClass)
46 return nullptr;
47 varType =
48 spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
49 }
50 auto varPtrType = varType.cast<spirv::PointerType>();
51 auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>();
52
53 // Set the offset information.
54 varPointeeType =
55 VulkanLayoutUtils::decorateType(varPointeeType).cast<spirv::StructType>();
56
57 if (!varPointeeType)
58 return nullptr;
59
60 varType =
61 spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
62
63 return builder.create<spirv::GlobalVariableOp>(
64 funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(),
65 abiInfo.getBinding());
66 }
67
68 /// Gets the global variables that need to be specified as interface variable
69 /// with an spv.EntryPointOp. Traverses the body of a entry function to do so.
70 static LogicalResult
getInterfaceVariables(spirv::FuncOp funcOp,SmallVectorImpl<Attribute> & interfaceVars)71 getInterfaceVariables(spirv::FuncOp funcOp,
72 SmallVectorImpl<Attribute> &interfaceVars) {
73 auto module = funcOp->getParentOfType<spirv::ModuleOp>();
74 if (!module) {
75 return failure();
76 }
77 llvm::SetVector<Operation *> interfaceVarSet;
78
79 // TODO: This should in reality traverse the entry function
80 // call graph and collect all the interfaces. For now, just traverse the
81 // instructions in this function.
82 funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
83 auto var =
84 module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
85 // TODO: Per SPIR-V spec: "Before version 1.4, the interface’s
86 // storage classes are limited to the Input and Output storage classes.
87 // Starting with version 1.4, the interface’s storage classes are all
88 // storage classes used in declaring all global variables referenced by the
89 // entry point’s call tree." We should consider the target environment here.
90 switch (var.type().cast<spirv::PointerType>().getStorageClass()) {
91 case spirv::StorageClass::Input:
92 case spirv::StorageClass::Output:
93 interfaceVarSet.insert(var.getOperation());
94 break;
95 default:
96 break;
97 }
98 });
99 for (auto &var : interfaceVarSet) {
100 interfaceVars.push_back(SymbolRefAttr::get(
101 cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext()));
102 }
103 return success();
104 }
105
106 /// Lowers the entry point attribute.
lowerEntryPointABIAttr(spirv::FuncOp funcOp,OpBuilder & builder)107 static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
108 OpBuilder &builder) {
109 auto entryPointAttrName = spirv::getEntryPointABIAttrName();
110 auto entryPointAttr =
111 funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
112 if (!entryPointAttr) {
113 return failure();
114 }
115
116 OpBuilder::InsertionGuard moduleInsertionGuard(builder);
117 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
118 builder.setInsertionPoint(spirvModule.body().front().getTerminator());
119
120 // Adds the spv.EntryPointOp after collecting all the interface variables
121 // needed.
122 SmallVector<Attribute, 1> interfaceVars;
123 if (failed(getInterfaceVariables(funcOp, interfaceVars))) {
124 return failure();
125 }
126
127 spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnv(funcOp);
128 FailureOr<spirv::ExecutionModel> executionModel =
129 spirv::getExecutionModel(targetEnv);
130 if (failed(executionModel))
131 return funcOp.emitRemark("lower entry point failure: could not select "
132 "execution model based on 'spv.target_env'");
133
134 builder.create<spirv::EntryPointOp>(
135 funcOp.getLoc(), executionModel.getValue(), funcOp, interfaceVars);
136
137 // Specifies the spv.ExecutionModeOp.
138 auto localSizeAttr = entryPointAttr.local_size();
139 SmallVector<int32_t, 3> localSize(localSizeAttr.getValues<int32_t>());
140 builder.create<spirv::ExecutionModeOp>(
141 funcOp.getLoc(), funcOp, spirv::ExecutionMode::LocalSize, localSize);
142 funcOp.removeAttr(entryPointAttrName);
143 return success();
144 }
145
146 namespace {
147 /// A pattern to convert function signature according to interface variable ABI
148 /// attributes.
149 ///
150 /// Specifically, this pattern creates global variables according to interface
151 /// variable ABI attributes attached to function arguments and converts all
152 /// function argument uses to those global variables. This is necessary because
153 /// Vulkan requires all shader entry points to be of void(void) type.
154 class ProcessInterfaceVarABI final : public SPIRVOpLowering<spirv::FuncOp> {
155 public:
156 using SPIRVOpLowering<spirv::FuncOp>::SPIRVOpLowering;
157 LogicalResult
158 matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
159 ConversionPatternRewriter &rewriter) const override;
160 };
161
162 /// Pass to implement the ABI information specified as attributes.
163 class LowerABIAttributesPass final
164 : public SPIRVLowerABIAttributesBase<LowerABIAttributesPass> {
165 void runOnOperation() override;
166 };
167 } // namespace
168
matchAndRewrite(spirv::FuncOp funcOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const169 LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
170 spirv::FuncOp funcOp, ArrayRef<Value> operands,
171 ConversionPatternRewriter &rewriter) const {
172 if (!funcOp->getAttrOfType<spirv::EntryPointABIAttr>(
173 spirv::getEntryPointABIAttrName())) {
174 // TODO: Non-entry point functions are not handled.
175 return failure();
176 }
177 TypeConverter::SignatureConversion signatureConverter(
178 funcOp.getType().getNumInputs());
179
180 auto attrName = spirv::getInterfaceVarABIAttrName();
181 for (auto argType : llvm::enumerate(funcOp.getType().getInputs())) {
182 auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
183 argType.index(), attrName);
184 if (!abiInfo) {
185 // TODO: For non-entry point functions, it should be legal
186 // to pass around scalar/vector values and return a scalar/vector. For now
187 // non-entry point functions are not handled in this ABI lowering and will
188 // produce an error.
189 return failure();
190 }
191 spirv::GlobalVariableOp var = createGlobalVarForEntryPointArgument(
192 rewriter, funcOp, argType.index(), abiInfo);
193 if (!var)
194 return failure();
195
196 OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
197 rewriter.setInsertionPointToStart(&funcOp.front());
198 // Insert spirv::AddressOf and spirv::AccessChain operations.
199 Value replacement =
200 rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
201 // Check if the arg is a scalar or vector type. In that case, the value
202 // needs to be loaded into registers.
203 // TODO: This is loading value of the scalar into registers
204 // at the start of the function. It is probably better to do the load just
205 // before the use. There might be multiple loads and currently there is no
206 // easy way to replace all uses with a sequence of operations.
207 if (argType.value().cast<spirv::SPIRVType>().isScalarOrVector()) {
208 auto indexType = SPIRVTypeConverter::getIndexType(funcOp.getContext());
209 auto zero =
210 spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
211 auto loadPtr = rewriter.create<spirv::AccessChainOp>(
212 funcOp.getLoc(), replacement, zero.constant());
213 replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
214 }
215 signatureConverter.remapInput(argType.index(), replacement);
216 }
217 if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), typeConverter,
218 &signatureConverter)))
219 return failure();
220
221 // Creates a new function with the update signature.
222 rewriter.updateRootInPlace(funcOp, [&] {
223 funcOp.setType(rewriter.getFunctionType(
224 signatureConverter.getConvertedTypes(), llvm::None));
225 });
226 return success();
227 }
228
runOnOperation()229 void LowerABIAttributesPass::runOnOperation() {
230 // Uses the signature conversion methodology of the dialect conversion
231 // framework to implement the conversion.
232 spirv::ModuleOp module = getOperation();
233 MLIRContext *context = &getContext();
234
235 spirv::TargetEnv targetEnv(spirv::lookupTargetEnv(module));
236
237 SPIRVTypeConverter typeConverter(targetEnv);
238
239 // Insert a bitcast in the case of a pointer type change.
240 typeConverter.addSourceMaterialization([](OpBuilder &builder,
241 spirv::PointerType type,
242 ValueRange inputs, Location loc) {
243 if (inputs.size() != 1 || !inputs[0].getType().isa<spirv::PointerType>())
244 return Value();
245 return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
246 });
247
248 OwningRewritePatternList patterns;
249 patterns.insert<ProcessInterfaceVarABI>(context, typeConverter);
250
251 ConversionTarget target(*context);
252 // "Legal" function ops should have no interface variable ABI attributes.
253 target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) {
254 StringRef attrName = spirv::getInterfaceVarABIAttrName();
255 for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
256 if (op.getArgAttr(i, attrName))
257 return false;
258 return true;
259 });
260 // All other SPIR-V ops are legal.
261 target.markUnknownOpDynamicallyLegal([](Operation *op) {
262 return op->getDialect()->getNamespace() ==
263 spirv::SPIRVDialect::getDialectNamespace();
264 });
265 if (failed(applyPartialConversion(module, target, std::move(patterns))))
266 return signalPassFailure();
267
268 // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point
269 // attributes.
270 OpBuilder builder(context);
271 SmallVector<spirv::FuncOp, 1> entryPointFns;
272 auto entryPointAttrName = spirv::getEntryPointABIAttrName();
273 module.walk([&](spirv::FuncOp funcOp) {
274 if (funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) {
275 entryPointFns.push_back(funcOp);
276 }
277 });
278 for (auto fn : entryPointFns) {
279 if (failed(lowerEntryPointABIAttr(fn, builder))) {
280 return signalPassFailure();
281 }
282 }
283 }
284
285 std::unique_ptr<OperationPass<spirv::ModuleOp>>
createLowerABIAttributesPass()286 mlir::spirv::createLowerABIAttributesPass() {
287 return std::make_unique<LowerABIAttributesPass>();
288 }
289