• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- SPIRVAttributes.cpp - SPIR-V attribute definitions -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
10 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
11 #include "mlir/IR/Builders.h"
12 
13 using namespace mlir;
14 
15 //===----------------------------------------------------------------------===//
16 // DictionaryAttr derived attributes
17 //===----------------------------------------------------------------------===//
18 
19 #include "mlir/Dialect/SPIRV/TargetAndABI.cpp.inc"
20 
21 namespace mlir {
22 
23 //===----------------------------------------------------------------------===//
24 // Attribute storage classes
25 //===----------------------------------------------------------------------===//
26 
27 namespace spirv {
28 namespace detail {
29 
30 struct InterfaceVarABIAttributeStorage : public AttributeStorage {
31   using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
32 
InterfaceVarABIAttributeStoragemlir::spirv::detail::InterfaceVarABIAttributeStorage33   InterfaceVarABIAttributeStorage(Attribute descriptorSet, Attribute binding,
34                                   Attribute storageClass)
35       : descriptorSet(descriptorSet), binding(binding),
36         storageClass(storageClass) {}
37 
operator ==mlir::spirv::detail::InterfaceVarABIAttributeStorage38   bool operator==(const KeyTy &key) const {
39     return std::get<0>(key) == descriptorSet && std::get<1>(key) == binding &&
40            std::get<2>(key) == storageClass;
41   }
42 
43   static InterfaceVarABIAttributeStorage *
constructmlir::spirv::detail::InterfaceVarABIAttributeStorage44   construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
45     return new (allocator.allocate<InterfaceVarABIAttributeStorage>())
46         InterfaceVarABIAttributeStorage(std::get<0>(key), std::get<1>(key),
47                                         std::get<2>(key));
48   }
49 
50   Attribute descriptorSet;
51   Attribute binding;
52   Attribute storageClass;
53 };
54 
55 struct VerCapExtAttributeStorage : public AttributeStorage {
56   using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
57 
VerCapExtAttributeStoragemlir::spirv::detail::VerCapExtAttributeStorage58   VerCapExtAttributeStorage(Attribute version, Attribute capabilities,
59                             Attribute extensions)
60       : version(version), capabilities(capabilities), extensions(extensions) {}
61 
operator ==mlir::spirv::detail::VerCapExtAttributeStorage62   bool operator==(const KeyTy &key) const {
63     return std::get<0>(key) == version && std::get<1>(key) == capabilities &&
64            std::get<2>(key) == extensions;
65   }
66 
67   static VerCapExtAttributeStorage *
constructmlir::spirv::detail::VerCapExtAttributeStorage68   construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
69     return new (allocator.allocate<VerCapExtAttributeStorage>())
70         VerCapExtAttributeStorage(std::get<0>(key), std::get<1>(key),
71                                   std::get<2>(key));
72   }
73 
74   Attribute version;
75   Attribute capabilities;
76   Attribute extensions;
77 };
78 
79 struct TargetEnvAttributeStorage : public AttributeStorage {
80   using KeyTy = std::tuple<Attribute, Vendor, DeviceType, uint32_t, Attribute>;
81 
TargetEnvAttributeStoragemlir::spirv::detail::TargetEnvAttributeStorage82   TargetEnvAttributeStorage(Attribute triple, Vendor vendorID,
83                             DeviceType deviceType, uint32_t deviceID,
84                             Attribute limits)
85       : triple(triple), limits(limits), vendorID(vendorID),
86         deviceType(deviceType), deviceID(deviceID) {}
87 
operator ==mlir::spirv::detail::TargetEnvAttributeStorage88   bool operator==(const KeyTy &key) const {
89     return key ==
90            std::make_tuple(triple, vendorID, deviceType, deviceID, limits);
91   }
92 
93   static TargetEnvAttributeStorage *
constructmlir::spirv::detail::TargetEnvAttributeStorage94   construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
95     return new (allocator.allocate<TargetEnvAttributeStorage>())
96         TargetEnvAttributeStorage(std::get<0>(key), std::get<1>(key),
97                                   std::get<2>(key), std::get<3>(key),
98                                   std::get<4>(key));
99   }
100 
101   Attribute triple;
102   Attribute limits;
103   Vendor vendorID;
104   DeviceType deviceType;
105   uint32_t deviceID;
106 };
107 } // namespace detail
108 } // namespace spirv
109 } // namespace mlir
110 
111 //===----------------------------------------------------------------------===//
112 // InterfaceVarABIAttr
113 //===----------------------------------------------------------------------===//
114 
115 spirv::InterfaceVarABIAttr
get(uint32_t descriptorSet,uint32_t binding,Optional<spirv::StorageClass> storageClass,MLIRContext * context)116 spirv::InterfaceVarABIAttr::get(uint32_t descriptorSet, uint32_t binding,
117                                 Optional<spirv::StorageClass> storageClass,
118                                 MLIRContext *context) {
119   Builder b(context);
120   auto descriptorSetAttr = b.getI32IntegerAttr(descriptorSet);
121   auto bindingAttr = b.getI32IntegerAttr(binding);
122   auto storageClassAttr =
123       storageClass ? b.getI32IntegerAttr(static_cast<uint32_t>(*storageClass))
124                    : IntegerAttr();
125   return get(descriptorSetAttr, bindingAttr, storageClassAttr);
126 }
127 
128 spirv::InterfaceVarABIAttr
get(IntegerAttr descriptorSet,IntegerAttr binding,IntegerAttr storageClass)129 spirv::InterfaceVarABIAttr::get(IntegerAttr descriptorSet, IntegerAttr binding,
130                                 IntegerAttr storageClass) {
131   assert(descriptorSet && binding);
132   MLIRContext *context = descriptorSet.getContext();
133   return Base::get(context, descriptorSet, binding, storageClass);
134 }
135 
getKindName()136 StringRef spirv::InterfaceVarABIAttr::getKindName() {
137   return "interface_var_abi";
138 }
139 
getBinding()140 uint32_t spirv::InterfaceVarABIAttr::getBinding() {
141   return getImpl()->binding.cast<IntegerAttr>().getInt();
142 }
143 
getDescriptorSet()144 uint32_t spirv::InterfaceVarABIAttr::getDescriptorSet() {
145   return getImpl()->descriptorSet.cast<IntegerAttr>().getInt();
146 }
147 
getStorageClass()148 Optional<spirv::StorageClass> spirv::InterfaceVarABIAttr::getStorageClass() {
149   if (getImpl()->storageClass)
150     return static_cast<spirv::StorageClass>(
151         getImpl()->storageClass.cast<IntegerAttr>().getValue().getZExtValue());
152   return llvm::None;
153 }
154 
verifyConstructionInvariants(Location loc,IntegerAttr descriptorSet,IntegerAttr binding,IntegerAttr storageClass)155 LogicalResult spirv::InterfaceVarABIAttr::verifyConstructionInvariants(
156     Location loc, IntegerAttr descriptorSet, IntegerAttr binding,
157     IntegerAttr storageClass) {
158   if (!descriptorSet.getType().isSignlessInteger(32))
159     return emitError(loc, "expected 32-bit integer for descriptor set");
160 
161   if (!binding.getType().isSignlessInteger(32))
162     return emitError(loc, "expected 32-bit integer for binding");
163 
164   if (storageClass) {
165     if (auto storageClassAttr = storageClass.cast<IntegerAttr>()) {
166       auto storageClassValue =
167           spirv::symbolizeStorageClass(storageClassAttr.getInt());
168       if (!storageClassValue)
169         return emitError(loc, "unknown storage class");
170     } else {
171       return emitError(loc, "expected valid storage class");
172     }
173   }
174 
175   return success();
176 }
177 
178 //===----------------------------------------------------------------------===//
179 // VerCapExtAttr
180 //===----------------------------------------------------------------------===//
181 
get(spirv::Version version,ArrayRef<spirv::Capability> capabilities,ArrayRef<spirv::Extension> extensions,MLIRContext * context)182 spirv::VerCapExtAttr spirv::VerCapExtAttr::get(
183     spirv::Version version, ArrayRef<spirv::Capability> capabilities,
184     ArrayRef<spirv::Extension> extensions, MLIRContext *context) {
185   Builder b(context);
186 
187   auto versionAttr = b.getI32IntegerAttr(static_cast<uint32_t>(version));
188 
189   SmallVector<Attribute, 4> capAttrs;
190   capAttrs.reserve(capabilities.size());
191   for (spirv::Capability cap : capabilities)
192     capAttrs.push_back(b.getI32IntegerAttr(static_cast<uint32_t>(cap)));
193 
194   SmallVector<Attribute, 4> extAttrs;
195   extAttrs.reserve(extensions.size());
196   for (spirv::Extension ext : extensions)
197     extAttrs.push_back(b.getStringAttr(spirv::stringifyExtension(ext)));
198 
199   return get(versionAttr, b.getArrayAttr(capAttrs), b.getArrayAttr(extAttrs));
200 }
201 
get(IntegerAttr version,ArrayAttr capabilities,ArrayAttr extensions)202 spirv::VerCapExtAttr spirv::VerCapExtAttr::get(IntegerAttr version,
203                                                ArrayAttr capabilities,
204                                                ArrayAttr extensions) {
205   assert(version && capabilities && extensions);
206   MLIRContext *context = version.getContext();
207   return Base::get(context, version, capabilities, extensions);
208 }
209 
getKindName()210 StringRef spirv::VerCapExtAttr::getKindName() { return "vce"; }
211 
getVersion()212 spirv::Version spirv::VerCapExtAttr::getVersion() {
213   return static_cast<spirv::Version>(
214       getImpl()->version.cast<IntegerAttr>().getValue().getZExtValue());
215 }
216 
ext_iterator(ArrayAttr::iterator it)217 spirv::VerCapExtAttr::ext_iterator::ext_iterator(ArrayAttr::iterator it)
218     : llvm::mapped_iterator<ArrayAttr::iterator,
219                             spirv::Extension (*)(Attribute)>(
220           it, [](Attribute attr) {
221             return *symbolizeExtension(attr.cast<StringAttr>().getValue());
222           }) {}
223 
getExtensions()224 spirv::VerCapExtAttr::ext_range spirv::VerCapExtAttr::getExtensions() {
225   auto range = getExtensionsAttr().getValue();
226   return {ext_iterator(range.begin()), ext_iterator(range.end())};
227 }
228 
getExtensionsAttr()229 ArrayAttr spirv::VerCapExtAttr::getExtensionsAttr() {
230   return getImpl()->extensions.cast<ArrayAttr>();
231 }
232 
cap_iterator(ArrayAttr::iterator it)233 spirv::VerCapExtAttr::cap_iterator::cap_iterator(ArrayAttr::iterator it)
234     : llvm::mapped_iterator<ArrayAttr::iterator,
235                             spirv::Capability (*)(Attribute)>(
236           it, [](Attribute attr) {
237             return *symbolizeCapability(
238                 attr.cast<IntegerAttr>().getValue().getZExtValue());
239           }) {}
240 
getCapabilities()241 spirv::VerCapExtAttr::cap_range spirv::VerCapExtAttr::getCapabilities() {
242   auto range = getCapabilitiesAttr().getValue();
243   return {cap_iterator(range.begin()), cap_iterator(range.end())};
244 }
245 
getCapabilitiesAttr()246 ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() {
247   return getImpl()->capabilities.cast<ArrayAttr>();
248 }
249 
verifyConstructionInvariants(Location loc,IntegerAttr version,ArrayAttr capabilities,ArrayAttr extensions)250 LogicalResult spirv::VerCapExtAttr::verifyConstructionInvariants(
251     Location loc, IntegerAttr version, ArrayAttr capabilities,
252     ArrayAttr extensions) {
253   if (!version.getType().isSignlessInteger(32))
254     return emitError(loc, "expected 32-bit integer for version");
255 
256   if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) {
257         if (auto intAttr = attr.dyn_cast<IntegerAttr>())
258           if (spirv::symbolizeCapability(intAttr.getValue().getZExtValue()))
259             return true;
260         return false;
261       }))
262     return emitError(loc, "unknown capability in capability list");
263 
264   if (!llvm::all_of(extensions.getValue(), [](Attribute attr) {
265         if (auto strAttr = attr.dyn_cast<StringAttr>())
266           if (spirv::symbolizeExtension(strAttr.getValue()))
267             return true;
268         return false;
269       }))
270     return emitError(loc, "unknown extension in extension list");
271 
272   return success();
273 }
274 
275 //===----------------------------------------------------------------------===//
276 // TargetEnvAttr
277 //===----------------------------------------------------------------------===//
278 
get(spirv::VerCapExtAttr triple,Vendor vendorID,DeviceType deviceType,uint32_t deviceID,DictionaryAttr limits)279 spirv::TargetEnvAttr spirv::TargetEnvAttr::get(spirv::VerCapExtAttr triple,
280                                                Vendor vendorID,
281                                                DeviceType deviceType,
282                                                uint32_t deviceID,
283                                                DictionaryAttr limits) {
284   assert(triple && limits && "expected valid triple and limits");
285   MLIRContext *context = triple.getContext();
286   return Base::get(context, triple, vendorID, deviceType, deviceID, limits);
287 }
288 
getKindName()289 StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }
290 
getTripleAttr() const291 spirv::VerCapExtAttr spirv::TargetEnvAttr::getTripleAttr() const {
292   return getImpl()->triple.cast<spirv::VerCapExtAttr>();
293 }
294 
getVersion() const295 spirv::Version spirv::TargetEnvAttr::getVersion() const {
296   return getTripleAttr().getVersion();
297 }
298 
getExtensions()299 spirv::VerCapExtAttr::ext_range spirv::TargetEnvAttr::getExtensions() {
300   return getTripleAttr().getExtensions();
301 }
302 
getExtensionsAttr()303 ArrayAttr spirv::TargetEnvAttr::getExtensionsAttr() {
304   return getTripleAttr().getExtensionsAttr();
305 }
306 
getCapabilities()307 spirv::VerCapExtAttr::cap_range spirv::TargetEnvAttr::getCapabilities() {
308   return getTripleAttr().getCapabilities();
309 }
310 
getCapabilitiesAttr()311 ArrayAttr spirv::TargetEnvAttr::getCapabilitiesAttr() {
312   return getTripleAttr().getCapabilitiesAttr();
313 }
314 
getVendorID() const315 spirv::Vendor spirv::TargetEnvAttr::getVendorID() const {
316   return getImpl()->vendorID;
317 }
318 
getDeviceType() const319 spirv::DeviceType spirv::TargetEnvAttr::getDeviceType() const {
320   return getImpl()->deviceType;
321 }
322 
getDeviceID() const323 uint32_t spirv::TargetEnvAttr::getDeviceID() const {
324   return getImpl()->deviceID;
325 }
326 
getResourceLimits() const327 spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() const {
328   return getImpl()->limits.cast<spirv::ResourceLimitsAttr>();
329 }
330 
verifyConstructionInvariants(Location loc,spirv::VerCapExtAttr,spirv::Vendor,spirv::DeviceType,uint32_t,DictionaryAttr limits)331 LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants(
332     Location loc, spirv::VerCapExtAttr /*triple*/, spirv::Vendor /*vendorID*/,
333     spirv::DeviceType /*deviceType*/, uint32_t /*deviceID*/,
334     DictionaryAttr limits) {
335   if (!limits.isa<spirv::ResourceLimitsAttr>())
336     return emitError(loc, "expected spirv::ResourceLimitsAttr for limits");
337 
338   return success();
339 }
340