1 //===- SPIRVAttributes.cpp - SPIR-V attribute definitions -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
10 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
11 #include "mlir/IR/Builders.h"
12
13 using namespace mlir;
14
//===----------------------------------------------------------------------===//
// DictionaryAttr-derived attributes
//===----------------------------------------------------------------------===//
18
19 #include "mlir/Dialect/SPIRV/TargetAndABI.cpp.inc"
20
21 namespace mlir {
22
23 //===----------------------------------------------------------------------===//
24 // Attribute storage classes
25 //===----------------------------------------------------------------------===//
26
27 namespace spirv {
28 namespace detail {
29
30 struct InterfaceVarABIAttributeStorage : public AttributeStorage {
31 using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
32
InterfaceVarABIAttributeStoragemlir::spirv::detail::InterfaceVarABIAttributeStorage33 InterfaceVarABIAttributeStorage(Attribute descriptorSet, Attribute binding,
34 Attribute storageClass)
35 : descriptorSet(descriptorSet), binding(binding),
36 storageClass(storageClass) {}
37
operator ==mlir::spirv::detail::InterfaceVarABIAttributeStorage38 bool operator==(const KeyTy &key) const {
39 return std::get<0>(key) == descriptorSet && std::get<1>(key) == binding &&
40 std::get<2>(key) == storageClass;
41 }
42
43 static InterfaceVarABIAttributeStorage *
constructmlir::spirv::detail::InterfaceVarABIAttributeStorage44 construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
45 return new (allocator.allocate<InterfaceVarABIAttributeStorage>())
46 InterfaceVarABIAttributeStorage(std::get<0>(key), std::get<1>(key),
47 std::get<2>(key));
48 }
49
50 Attribute descriptorSet;
51 Attribute binding;
52 Attribute storageClass;
53 };
54
55 struct VerCapExtAttributeStorage : public AttributeStorage {
56 using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
57
VerCapExtAttributeStoragemlir::spirv::detail::VerCapExtAttributeStorage58 VerCapExtAttributeStorage(Attribute version, Attribute capabilities,
59 Attribute extensions)
60 : version(version), capabilities(capabilities), extensions(extensions) {}
61
operator ==mlir::spirv::detail::VerCapExtAttributeStorage62 bool operator==(const KeyTy &key) const {
63 return std::get<0>(key) == version && std::get<1>(key) == capabilities &&
64 std::get<2>(key) == extensions;
65 }
66
67 static VerCapExtAttributeStorage *
constructmlir::spirv::detail::VerCapExtAttributeStorage68 construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
69 return new (allocator.allocate<VerCapExtAttributeStorage>())
70 VerCapExtAttributeStorage(std::get<0>(key), std::get<1>(key),
71 std::get<2>(key));
72 }
73
74 Attribute version;
75 Attribute capabilities;
76 Attribute extensions;
77 };
78
79 struct TargetEnvAttributeStorage : public AttributeStorage {
80 using KeyTy = std::tuple<Attribute, Vendor, DeviceType, uint32_t, Attribute>;
81
TargetEnvAttributeStoragemlir::spirv::detail::TargetEnvAttributeStorage82 TargetEnvAttributeStorage(Attribute triple, Vendor vendorID,
83 DeviceType deviceType, uint32_t deviceID,
84 Attribute limits)
85 : triple(triple), limits(limits), vendorID(vendorID),
86 deviceType(deviceType), deviceID(deviceID) {}
87
operator ==mlir::spirv::detail::TargetEnvAttributeStorage88 bool operator==(const KeyTy &key) const {
89 return key ==
90 std::make_tuple(triple, vendorID, deviceType, deviceID, limits);
91 }
92
93 static TargetEnvAttributeStorage *
constructmlir::spirv::detail::TargetEnvAttributeStorage94 construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
95 return new (allocator.allocate<TargetEnvAttributeStorage>())
96 TargetEnvAttributeStorage(std::get<0>(key), std::get<1>(key),
97 std::get<2>(key), std::get<3>(key),
98 std::get<4>(key));
99 }
100
101 Attribute triple;
102 Attribute limits;
103 Vendor vendorID;
104 DeviceType deviceType;
105 uint32_t deviceID;
106 };
107 } // namespace detail
108 } // namespace spirv
109 } // namespace mlir
110
111 //===----------------------------------------------------------------------===//
112 // InterfaceVarABIAttr
113 //===----------------------------------------------------------------------===//
114
115 spirv::InterfaceVarABIAttr
get(uint32_t descriptorSet,uint32_t binding,Optional<spirv::StorageClass> storageClass,MLIRContext * context)116 spirv::InterfaceVarABIAttr::get(uint32_t descriptorSet, uint32_t binding,
117 Optional<spirv::StorageClass> storageClass,
118 MLIRContext *context) {
119 Builder b(context);
120 auto descriptorSetAttr = b.getI32IntegerAttr(descriptorSet);
121 auto bindingAttr = b.getI32IntegerAttr(binding);
122 auto storageClassAttr =
123 storageClass ? b.getI32IntegerAttr(static_cast<uint32_t>(*storageClass))
124 : IntegerAttr();
125 return get(descriptorSetAttr, bindingAttr, storageClassAttr);
126 }
127
128 spirv::InterfaceVarABIAttr
get(IntegerAttr descriptorSet,IntegerAttr binding,IntegerAttr storageClass)129 spirv::InterfaceVarABIAttr::get(IntegerAttr descriptorSet, IntegerAttr binding,
130 IntegerAttr storageClass) {
131 assert(descriptorSet && binding);
132 MLIRContext *context = descriptorSet.getContext();
133 return Base::get(context, descriptorSet, binding, storageClass);
134 }
135
getKindName()136 StringRef spirv::InterfaceVarABIAttr::getKindName() {
137 return "interface_var_abi";
138 }
139
getBinding()140 uint32_t spirv::InterfaceVarABIAttr::getBinding() {
141 return getImpl()->binding.cast<IntegerAttr>().getInt();
142 }
143
getDescriptorSet()144 uint32_t spirv::InterfaceVarABIAttr::getDescriptorSet() {
145 return getImpl()->descriptorSet.cast<IntegerAttr>().getInt();
146 }
147
getStorageClass()148 Optional<spirv::StorageClass> spirv::InterfaceVarABIAttr::getStorageClass() {
149 if (getImpl()->storageClass)
150 return static_cast<spirv::StorageClass>(
151 getImpl()->storageClass.cast<IntegerAttr>().getValue().getZExtValue());
152 return llvm::None;
153 }
154
verifyConstructionInvariants(Location loc,IntegerAttr descriptorSet,IntegerAttr binding,IntegerAttr storageClass)155 LogicalResult spirv::InterfaceVarABIAttr::verifyConstructionInvariants(
156 Location loc, IntegerAttr descriptorSet, IntegerAttr binding,
157 IntegerAttr storageClass) {
158 if (!descriptorSet.getType().isSignlessInteger(32))
159 return emitError(loc, "expected 32-bit integer for descriptor set");
160
161 if (!binding.getType().isSignlessInteger(32))
162 return emitError(loc, "expected 32-bit integer for binding");
163
164 if (storageClass) {
165 if (auto storageClassAttr = storageClass.cast<IntegerAttr>()) {
166 auto storageClassValue =
167 spirv::symbolizeStorageClass(storageClassAttr.getInt());
168 if (!storageClassValue)
169 return emitError(loc, "unknown storage class");
170 } else {
171 return emitError(loc, "expected valid storage class");
172 }
173 }
174
175 return success();
176 }
177
178 //===----------------------------------------------------------------------===//
179 // VerCapExtAttr
180 //===----------------------------------------------------------------------===//
181
get(spirv::Version version,ArrayRef<spirv::Capability> capabilities,ArrayRef<spirv::Extension> extensions,MLIRContext * context)182 spirv::VerCapExtAttr spirv::VerCapExtAttr::get(
183 spirv::Version version, ArrayRef<spirv::Capability> capabilities,
184 ArrayRef<spirv::Extension> extensions, MLIRContext *context) {
185 Builder b(context);
186
187 auto versionAttr = b.getI32IntegerAttr(static_cast<uint32_t>(version));
188
189 SmallVector<Attribute, 4> capAttrs;
190 capAttrs.reserve(capabilities.size());
191 for (spirv::Capability cap : capabilities)
192 capAttrs.push_back(b.getI32IntegerAttr(static_cast<uint32_t>(cap)));
193
194 SmallVector<Attribute, 4> extAttrs;
195 extAttrs.reserve(extensions.size());
196 for (spirv::Extension ext : extensions)
197 extAttrs.push_back(b.getStringAttr(spirv::stringifyExtension(ext)));
198
199 return get(versionAttr, b.getArrayAttr(capAttrs), b.getArrayAttr(extAttrs));
200 }
201
get(IntegerAttr version,ArrayAttr capabilities,ArrayAttr extensions)202 spirv::VerCapExtAttr spirv::VerCapExtAttr::get(IntegerAttr version,
203 ArrayAttr capabilities,
204 ArrayAttr extensions) {
205 assert(version && capabilities && extensions);
206 MLIRContext *context = version.getContext();
207 return Base::get(context, version, capabilities, extensions);
208 }
209
getKindName()210 StringRef spirv::VerCapExtAttr::getKindName() { return "vce"; }
211
getVersion()212 spirv::Version spirv::VerCapExtAttr::getVersion() {
213 return static_cast<spirv::Version>(
214 getImpl()->version.cast<IntegerAttr>().getValue().getZExtValue());
215 }
216
ext_iterator(ArrayAttr::iterator it)217 spirv::VerCapExtAttr::ext_iterator::ext_iterator(ArrayAttr::iterator it)
218 : llvm::mapped_iterator<ArrayAttr::iterator,
219 spirv::Extension (*)(Attribute)>(
220 it, [](Attribute attr) {
221 return *symbolizeExtension(attr.cast<StringAttr>().getValue());
222 }) {}
223
getExtensions()224 spirv::VerCapExtAttr::ext_range spirv::VerCapExtAttr::getExtensions() {
225 auto range = getExtensionsAttr().getValue();
226 return {ext_iterator(range.begin()), ext_iterator(range.end())};
227 }
228
getExtensionsAttr()229 ArrayAttr spirv::VerCapExtAttr::getExtensionsAttr() {
230 return getImpl()->extensions.cast<ArrayAttr>();
231 }
232
cap_iterator(ArrayAttr::iterator it)233 spirv::VerCapExtAttr::cap_iterator::cap_iterator(ArrayAttr::iterator it)
234 : llvm::mapped_iterator<ArrayAttr::iterator,
235 spirv::Capability (*)(Attribute)>(
236 it, [](Attribute attr) {
237 return *symbolizeCapability(
238 attr.cast<IntegerAttr>().getValue().getZExtValue());
239 }) {}
240
getCapabilities()241 spirv::VerCapExtAttr::cap_range spirv::VerCapExtAttr::getCapabilities() {
242 auto range = getCapabilitiesAttr().getValue();
243 return {cap_iterator(range.begin()), cap_iterator(range.end())};
244 }
245
getCapabilitiesAttr()246 ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() {
247 return getImpl()->capabilities.cast<ArrayAttr>();
248 }
249
verifyConstructionInvariants(Location loc,IntegerAttr version,ArrayAttr capabilities,ArrayAttr extensions)250 LogicalResult spirv::VerCapExtAttr::verifyConstructionInvariants(
251 Location loc, IntegerAttr version, ArrayAttr capabilities,
252 ArrayAttr extensions) {
253 if (!version.getType().isSignlessInteger(32))
254 return emitError(loc, "expected 32-bit integer for version");
255
256 if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) {
257 if (auto intAttr = attr.dyn_cast<IntegerAttr>())
258 if (spirv::symbolizeCapability(intAttr.getValue().getZExtValue()))
259 return true;
260 return false;
261 }))
262 return emitError(loc, "unknown capability in capability list");
263
264 if (!llvm::all_of(extensions.getValue(), [](Attribute attr) {
265 if (auto strAttr = attr.dyn_cast<StringAttr>())
266 if (spirv::symbolizeExtension(strAttr.getValue()))
267 return true;
268 return false;
269 }))
270 return emitError(loc, "unknown extension in extension list");
271
272 return success();
273 }
274
275 //===----------------------------------------------------------------------===//
276 // TargetEnvAttr
277 //===----------------------------------------------------------------------===//
278
get(spirv::VerCapExtAttr triple,Vendor vendorID,DeviceType deviceType,uint32_t deviceID,DictionaryAttr limits)279 spirv::TargetEnvAttr spirv::TargetEnvAttr::get(spirv::VerCapExtAttr triple,
280 Vendor vendorID,
281 DeviceType deviceType,
282 uint32_t deviceID,
283 DictionaryAttr limits) {
284 assert(triple && limits && "expected valid triple and limits");
285 MLIRContext *context = triple.getContext();
286 return Base::get(context, triple, vendorID, deviceType, deviceID, limits);
287 }
288
getKindName()289 StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }
290
getTripleAttr() const291 spirv::VerCapExtAttr spirv::TargetEnvAttr::getTripleAttr() const {
292 return getImpl()->triple.cast<spirv::VerCapExtAttr>();
293 }
294
getVersion() const295 spirv::Version spirv::TargetEnvAttr::getVersion() const {
296 return getTripleAttr().getVersion();
297 }
298
getExtensions()299 spirv::VerCapExtAttr::ext_range spirv::TargetEnvAttr::getExtensions() {
300 return getTripleAttr().getExtensions();
301 }
302
getExtensionsAttr()303 ArrayAttr spirv::TargetEnvAttr::getExtensionsAttr() {
304 return getTripleAttr().getExtensionsAttr();
305 }
306
getCapabilities()307 spirv::VerCapExtAttr::cap_range spirv::TargetEnvAttr::getCapabilities() {
308 return getTripleAttr().getCapabilities();
309 }
310
getCapabilitiesAttr()311 ArrayAttr spirv::TargetEnvAttr::getCapabilitiesAttr() {
312 return getTripleAttr().getCapabilitiesAttr();
313 }
314
getVendorID() const315 spirv::Vendor spirv::TargetEnvAttr::getVendorID() const {
316 return getImpl()->vendorID;
317 }
318
getDeviceType() const319 spirv::DeviceType spirv::TargetEnvAttr::getDeviceType() const {
320 return getImpl()->deviceType;
321 }
322
getDeviceID() const323 uint32_t spirv::TargetEnvAttr::getDeviceID() const {
324 return getImpl()->deviceID;
325 }
326
getResourceLimits() const327 spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() const {
328 return getImpl()->limits.cast<spirv::ResourceLimitsAttr>();
329 }
330
verifyConstructionInvariants(Location loc,spirv::VerCapExtAttr,spirv::Vendor,spirv::DeviceType,uint32_t,DictionaryAttr limits)331 LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants(
332 Location loc, spirv::VerCapExtAttr /*triple*/, spirv::Vendor /*vendorID*/,
333 spirv::DeviceType /*deviceType*/, uint32_t /*deviceID*/,
334 DictionaryAttr limits) {
335 if (!limits.isa<spirv::ResourceLimitsAttr>())
336 return emitError(loc, "expected spirv::ResourceLimitsAttr for limits");
337
338 return success();
339 }
340