1 //===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/SPIRV/TargetAndABI.h"
10 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/FunctionSupport.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/IR/SymbolTable.h"
15
16 using namespace mlir;
17
18 //===----------------------------------------------------------------------===//
19 // TargetEnv
20 //===----------------------------------------------------------------------===//
21
TargetEnv(spirv::TargetEnvAttr targetAttr)22 spirv::TargetEnv::TargetEnv(spirv::TargetEnvAttr targetAttr)
23 : targetAttr(targetAttr) {
24 for (spirv::Extension ext : targetAttr.getExtensions())
25 givenExtensions.insert(ext);
26
27 // Add extensions implied by the current version.
28 for (spirv::Extension ext :
29 spirv::getImpliedExtensions(targetAttr.getVersion()))
30 givenExtensions.insert(ext);
31
32 for (spirv::Capability cap : targetAttr.getCapabilities()) {
33 givenCapabilities.insert(cap);
34
35 // Add capabilities implied by the current capability.
36 for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap))
37 givenCapabilities.insert(c);
38 }
39 }
40
getVersion() const41 spirv::Version spirv::TargetEnv::getVersion() const {
42 return targetAttr.getVersion();
43 }
44
allows(spirv::Capability capability) const45 bool spirv::TargetEnv::allows(spirv::Capability capability) const {
46 return givenCapabilities.count(capability);
47 }
48
49 Optional<spirv::Capability>
allows(ArrayRef<spirv::Capability> caps) const50 spirv::TargetEnv::allows(ArrayRef<spirv::Capability> caps) const {
51 const auto *chosen = llvm::find_if(caps, [this](spirv::Capability cap) {
52 return givenCapabilities.count(cap);
53 });
54 if (chosen != caps.end())
55 return *chosen;
56 return llvm::None;
57 }
58
allows(spirv::Extension extension) const59 bool spirv::TargetEnv::allows(spirv::Extension extension) const {
60 return givenExtensions.count(extension);
61 }
62
63 Optional<spirv::Extension>
allows(ArrayRef<spirv::Extension> exts) const64 spirv::TargetEnv::allows(ArrayRef<spirv::Extension> exts) const {
65 const auto *chosen = llvm::find_if(exts, [this](spirv::Extension ext) {
66 return givenExtensions.count(ext);
67 });
68 if (chosen != exts.end())
69 return *chosen;
70 return llvm::None;
71 }
72
getVendorID() const73 spirv::Vendor spirv::TargetEnv::getVendorID() const {
74 return targetAttr.getVendorID();
75 }
76
getDeviceType() const77 spirv::DeviceType spirv::TargetEnv::getDeviceType() const {
78 return targetAttr.getDeviceType();
79 }
80
getDeviceID() const81 uint32_t spirv::TargetEnv::getDeviceID() const {
82 return targetAttr.getDeviceID();
83 }
84
getResourceLimits() const85 spirv::ResourceLimitsAttr spirv::TargetEnv::getResourceLimits() const {
86 return targetAttr.getResourceLimits();
87 }
88
getContext() const89 MLIRContext *spirv::TargetEnv::getContext() const {
90 return targetAttr.getContext();
91 }
92
93 //===----------------------------------------------------------------------===//
94 // Utility functions
95 //===----------------------------------------------------------------------===//
96
getInterfaceVarABIAttrName()97 StringRef spirv::getInterfaceVarABIAttrName() {
98 return "spv.interface_var_abi";
99 }
100
101 spirv::InterfaceVarABIAttr
getInterfaceVarABIAttr(unsigned descriptorSet,unsigned binding,Optional<spirv::StorageClass> storageClass,MLIRContext * context)102 spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
103 Optional<spirv::StorageClass> storageClass,
104 MLIRContext *context) {
105 return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass,
106 context);
107 }
108
needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr)109 bool spirv::needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr) {
110 for (spirv::Capability cap : targetAttr.getCapabilities()) {
111 if (cap == spirv::Capability::Kernel)
112 return false;
113 if (cap == spirv::Capability::Shader)
114 return true;
115 }
116 return false;
117 }
118
getEntryPointABIAttrName()119 StringRef spirv::getEntryPointABIAttrName() { return "spv.entry_point_abi"; }
120
121 spirv::EntryPointABIAttr
getEntryPointABIAttr(ArrayRef<int32_t> localSize,MLIRContext * context)122 spirv::getEntryPointABIAttr(ArrayRef<int32_t> localSize, MLIRContext *context) {
123 assert(localSize.size() == 3);
124 return spirv::EntryPointABIAttr::get(
125 DenseElementsAttr::get<int32_t>(
126 VectorType::get(3, IntegerType::get(32, context)), localSize)
127 .cast<DenseIntElementsAttr>(),
128 context);
129 }
130
lookupEntryPointABI(Operation * op)131 spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) {
132 while (op && !op->hasTrait<OpTrait::FunctionLike>())
133 op = op->getParentOp();
134 if (!op)
135 return {};
136
137 if (auto attr = op->getAttrOfType<spirv::EntryPointABIAttr>(
138 spirv::getEntryPointABIAttrName()))
139 return attr;
140
141 return {};
142 }
143
lookupLocalWorkGroupSize(Operation * op)144 DenseIntElementsAttr spirv::lookupLocalWorkGroupSize(Operation *op) {
145 if (auto entryPoint = spirv::lookupEntryPointABI(op))
146 return entryPoint.local_size();
147
148 return {};
149 }
150
151 spirv::ResourceLimitsAttr
getDefaultResourceLimits(MLIRContext * context)152 spirv::getDefaultResourceLimits(MLIRContext *context) {
153 // All the fields have default values. Here we just provide a nicer way to
154 // construct a default resource limit attribute.
155 return spirv::ResourceLimitsAttr ::get(
156 /*max_compute_shared_memory_size=*/nullptr,
157 /*max_compute_workgroup_invocations=*/nullptr,
158 /*max_compute_workgroup_size=*/nullptr,
159 /*subgroup_size=*/nullptr,
160 /*cooperative_matrix_properties_nv=*/nullptr, context);
161 }
162
getTargetEnvAttrName()163 StringRef spirv::getTargetEnvAttrName() { return "spv.target_env"; }
164
getDefaultTargetEnv(MLIRContext * context)165 spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) {
166 auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0,
167 {spirv::Capability::Shader},
168 ArrayRef<Extension>(), context);
169 return spirv::TargetEnvAttr::get(triple, spirv::Vendor::Unknown,
170 spirv::DeviceType::Unknown,
171 spirv::TargetEnvAttr::kUnknownDeviceID,
172 spirv::getDefaultResourceLimits(context));
173 }
174
lookupTargetEnv(Operation * op)175 spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) {
176 while (op) {
177 op = SymbolTable::getNearestSymbolTable(op);
178 if (!op)
179 break;
180
181 if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>(
182 spirv::getTargetEnvAttrName()))
183 return attr;
184
185 op = op->getParentOp();
186 }
187
188 return {};
189 }
190
lookupTargetEnvOrDefault(Operation * op)191 spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) {
192 if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op))
193 return attr;
194
195 return getDefaultTargetEnv(op->getContext());
196 }
197
198 spirv::AddressingModel
getAddressingModel(spirv::TargetEnvAttr targetAttr)199 spirv::getAddressingModel(spirv::TargetEnvAttr targetAttr) {
200 for (spirv::Capability cap : targetAttr.getCapabilities()) {
201 // TODO: Physical64 is hard-coded here, but some information should come
202 // from TargetEnvAttr to selected between Physical32 and Physical64.
203 if (cap == Capability::Kernel)
204 return spirv::AddressingModel::Physical64;
205 }
206 // Logical addressing doesn't need any capabilities so return it as default.
207 return spirv::AddressingModel::Logical;
208 }
209
210 FailureOr<spirv::ExecutionModel>
getExecutionModel(spirv::TargetEnvAttr targetAttr)211 spirv::getExecutionModel(spirv::TargetEnvAttr targetAttr) {
212 for (spirv::Capability cap : targetAttr.getCapabilities()) {
213 if (cap == spirv::Capability::Kernel)
214 return spirv::ExecutionModel::Kernel;
215 if (cap == spirv::Capability::Shader)
216 return spirv::ExecutionModel::GLCompute;
217 }
218 return failure();
219 }
220
221 FailureOr<spirv::MemoryModel>
getMemoryModel(spirv::TargetEnvAttr targetAttr)222 spirv::getMemoryModel(spirv::TargetEnvAttr targetAttr) {
223 for (spirv::Capability cap : targetAttr.getCapabilities()) {
224 if (cap == spirv::Capability::Addresses)
225 return spirv::MemoryModel::OpenCL;
226 if (cap == spirv::Capability::Shader)
227 return spirv::MemoryModel::GLSL450;
228 }
229 return failure();
230 }
231