1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
14 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Identifier.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/ADT/StringSwitch.h"
22 #include "llvm/ADT/TypeSwitch.h"
23
24 using namespace mlir;
25 using namespace mlir::spirv;
26
27 // Pull in all enum utility function definitions
28 #include "mlir/Dialect/SPIRV/SPIRVEnums.cpp.inc"
29 // Pull in all enum type availability query function definitions
30 #include "mlir/Dialect/SPIRV/SPIRVEnumAvailability.cpp.inc"
31
32 //===----------------------------------------------------------------------===//
33 // Availability relationship
34 //===----------------------------------------------------------------------===//
35
getImpliedExtensions(Version version)36 ArrayRef<Extension> spirv::getImpliedExtensions(Version version) {
37 // Note: the following lists are from "Appendix A: Changes" of the spec.
38
39 #define V_1_3_IMPLIED_EXTS \
40 Extension::SPV_KHR_shader_draw_parameters, Extension::SPV_KHR_16bit_storage, \
41 Extension::SPV_KHR_device_group, Extension::SPV_KHR_multiview, \
42 Extension::SPV_KHR_storage_buffer_storage_class, \
43 Extension::SPV_KHR_variable_pointers
44
45 #define V_1_4_IMPLIED_EXTS \
46 Extension::SPV_KHR_no_integer_wrap_decoration, \
47 Extension::SPV_GOOGLE_decorate_string, \
48 Extension::SPV_GOOGLE_hlsl_functionality1, \
49 Extension::SPV_KHR_float_controls
50
51 #define V_1_5_IMPLIED_EXTS \
52 Extension::SPV_KHR_8bit_storage, Extension::SPV_EXT_descriptor_indexing, \
53 Extension::SPV_EXT_shader_viewport_index_layer, \
54 Extension::SPV_EXT_physical_storage_buffer, \
55 Extension::SPV_KHR_physical_storage_buffer, \
56 Extension::SPV_KHR_vulkan_memory_model
57
58 switch (version) {
59 default:
60 return {};
61 case Version::V_1_3: {
62 // The following manual ArrayRef constructor call is to satisfy GCC 5.
63 static const Extension exts[] = {V_1_3_IMPLIED_EXTS};
64 return ArrayRef<Extension>(exts, llvm::array_lengthof(exts));
65 }
66 case Version::V_1_4: {
67 static const Extension exts[] = {V_1_3_IMPLIED_EXTS, V_1_4_IMPLIED_EXTS};
68 return ArrayRef<Extension>(exts, llvm::array_lengthof(exts));
69 }
70 case Version::V_1_5: {
71 static const Extension exts[] = {V_1_3_IMPLIED_EXTS, V_1_4_IMPLIED_EXTS,
72 V_1_5_IMPLIED_EXTS};
73 return ArrayRef<Extension>(exts, llvm::array_lengthof(exts));
74 }
75 }
76
77 #undef V_1_5_IMPLIED_EXTS
78 #undef V_1_4_IMPLIED_EXTS
79 #undef V_1_3_IMPLIED_EXTS
80 }
81
82 // Pull in utility function definition for implied capabilities
83 #include "mlir/Dialect/SPIRV/SPIRVCapabilityImplication.inc"
84
85 SmallVector<Capability, 0>
getRecursiveImpliedCapabilities(Capability cap)86 spirv::getRecursiveImpliedCapabilities(Capability cap) {
87 ArrayRef<Capability> directCaps = getDirectImpliedCapabilities(cap);
88 llvm::SetVector<Capability, SmallVector<Capability, 0>> allCaps(
89 directCaps.begin(), directCaps.end());
90
91 // TODO: This is insufficient; find a better way to handle this
92 // (e.g., using static lists) if this turns out to be a bottleneck.
93 for (unsigned i = 0; i < allCaps.size(); ++i)
94 for (Capability c : getDirectImpliedCapabilities(allCaps[i]))
95 allCaps.insert(c);
96
97 return allCaps.takeVector();
98 }
99
100 //===----------------------------------------------------------------------===//
101 // ArrayType
102 //===----------------------------------------------------------------------===//
103
104 struct spirv::detail::ArrayTypeStorage : public TypeStorage {
105 using KeyTy = std::tuple<Type, unsigned, unsigned>;
106
constructspirv::detail::ArrayTypeStorage107 static ArrayTypeStorage *construct(TypeStorageAllocator &allocator,
108 const KeyTy &key) {
109 return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
110 }
111
operator ==spirv::detail::ArrayTypeStorage112 bool operator==(const KeyTy &key) const {
113 return key == KeyTy(elementType, elementCount, stride);
114 }
115
ArrayTypeStoragespirv::detail::ArrayTypeStorage116 ArrayTypeStorage(const KeyTy &key)
117 : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
118 stride(std::get<2>(key)) {}
119
120 Type elementType;
121 unsigned elementCount;
122 unsigned stride;
123 };
124
get(Type elementType,unsigned elementCount)125 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
126 assert(elementCount && "ArrayType needs at least one element");
127 return Base::get(elementType.getContext(), elementType, elementCount,
128 /*stride=*/0);
129 }
130
get(Type elementType,unsigned elementCount,unsigned stride)131 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
132 unsigned stride) {
133 assert(elementCount && "ArrayType needs at least one element");
134 return Base::get(elementType.getContext(), elementType, elementCount, stride);
135 }
136
getNumElements() const137 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
138
getElementType() const139 Type ArrayType::getElementType() const { return getImpl()->elementType; }
140
getArrayStride() const141 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
142
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)143 void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
144 Optional<StorageClass> storage) {
145 getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
146 }
147
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)148 void ArrayType::getCapabilities(
149 SPIRVType::CapabilityArrayRefVector &capabilities,
150 Optional<StorageClass> storage) {
151 getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
152 }
153
getSizeInBytes()154 Optional<int64_t> ArrayType::getSizeInBytes() {
155 auto elementType = getElementType().cast<SPIRVType>();
156 Optional<int64_t> size = elementType.getSizeInBytes();
157 if (!size)
158 return llvm::None;
159 return (*size + getArrayStride()) * getNumElements();
160 }
161
162 //===----------------------------------------------------------------------===//
163 // CompositeType
164 //===----------------------------------------------------------------------===//
165
classof(Type type)166 bool CompositeType::classof(Type type) {
167 if (auto vectorType = type.dyn_cast<VectorType>())
168 return isValid(vectorType);
169 return type
170 .isa<spirv::ArrayType, spirv::CooperativeMatrixNVType, spirv::MatrixType,
171 spirv::RuntimeArrayType, spirv::StructType>();
172 }
173
isValid(VectorType type)174 bool CompositeType::isValid(VectorType type) {
175 switch (type.getNumElements()) {
176 case 2:
177 case 3:
178 case 4:
179 case 8:
180 case 16:
181 break;
182 default:
183 return false;
184 }
185 return type.getRank() == 1 && type.getElementType().isa<ScalarType>();
186 }
187
getElementType(unsigned index) const188 Type CompositeType::getElementType(unsigned index) const {
189 return TypeSwitch<Type, Type>(*this)
190 .Case<ArrayType, CooperativeMatrixNVType, RuntimeArrayType, VectorType>(
191 [](auto type) { return type.getElementType(); })
192 .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
193 .Case<StructType>(
194 [index](StructType type) { return type.getElementType(index); })
195 .Default(
196 [](Type) -> Type { llvm_unreachable("invalid composite type"); });
197 }
198
getNumElements() const199 unsigned CompositeType::getNumElements() const {
200 if (auto arrayType = dyn_cast<ArrayType>())
201 return arrayType.getNumElements();
202 if (auto matrixType = dyn_cast<MatrixType>())
203 return matrixType.getNumColumns();
204 if (auto structType = dyn_cast<StructType>())
205 return structType.getNumElements();
206 if (auto vectorType = dyn_cast<VectorType>())
207 return vectorType.getNumElements();
208 if (isa<CooperativeMatrixNVType>()) {
209 llvm_unreachable(
210 "invalid to query number of elements of spirv::CooperativeMatrix type");
211 }
212 if (isa<RuntimeArrayType>()) {
213 llvm_unreachable(
214 "invalid to query number of elements of spirv::RuntimeArray type");
215 }
216 llvm_unreachable("invalid composite type");
217 }
218
hasCompileTimeKnownNumElements() const219 bool CompositeType::hasCompileTimeKnownNumElements() const {
220 return !isa<CooperativeMatrixNVType, RuntimeArrayType>();
221 }
222
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)223 void CompositeType::getExtensions(
224 SPIRVType::ExtensionArrayRefVector &extensions,
225 Optional<StorageClass> storage) {
226 TypeSwitch<Type>(*this)
227 .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
228 StructType>(
229 [&](auto type) { type.getExtensions(extensions, storage); })
230 .Case<VectorType>([&](VectorType type) {
231 return type.getElementType().cast<ScalarType>().getExtensions(
232 extensions, storage);
233 })
234 .Default([](Type) { llvm_unreachable("invalid composite type"); });
235 }
236
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)237 void CompositeType::getCapabilities(
238 SPIRVType::CapabilityArrayRefVector &capabilities,
239 Optional<StorageClass> storage) {
240 TypeSwitch<Type>(*this)
241 .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
242 StructType>(
243 [&](auto type) { type.getCapabilities(capabilities, storage); })
244 .Case<VectorType>([&](VectorType type) {
245 auto vecSize = getNumElements();
246 if (vecSize == 8 || vecSize == 16) {
247 static const Capability caps[] = {Capability::Vector16};
248 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
249 capabilities.push_back(ref);
250 }
251 return type.getElementType().cast<ScalarType>().getCapabilities(
252 capabilities, storage);
253 })
254 .Default([](Type) { llvm_unreachable("invalid composite type"); });
255 }
256
getSizeInBytes()257 Optional<int64_t> CompositeType::getSizeInBytes() {
258 if (auto arrayType = dyn_cast<ArrayType>())
259 return arrayType.getSizeInBytes();
260 if (auto structType = dyn_cast<StructType>())
261 return structType.getSizeInBytes();
262 if (auto vectorType = dyn_cast<VectorType>()) {
263 Optional<int64_t> elementSize =
264 vectorType.getElementType().cast<ScalarType>().getSizeInBytes();
265 if (!elementSize)
266 return llvm::None;
267 return *elementSize * vectorType.getNumElements();
268 }
269 return llvm::None;
270 }
271
272 //===----------------------------------------------------------------------===//
273 // CooperativeMatrixType
274 //===----------------------------------------------------------------------===//
275
276 struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
277 using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
278
279 static CooperativeMatrixTypeStorage *
constructspirv::detail::CooperativeMatrixTypeStorage280 construct(TypeStorageAllocator &allocator, const KeyTy &key) {
281 return new (allocator.allocate<CooperativeMatrixTypeStorage>())
282 CooperativeMatrixTypeStorage(key);
283 }
284
operator ==spirv::detail::CooperativeMatrixTypeStorage285 bool operator==(const KeyTy &key) const {
286 return key == KeyTy(elementType, scope, rows, columns);
287 }
288
CooperativeMatrixTypeStoragespirv::detail::CooperativeMatrixTypeStorage289 CooperativeMatrixTypeStorage(const KeyTy &key)
290 : elementType(std::get<0>(key)), rows(std::get<2>(key)),
291 columns(std::get<3>(key)), scope(std::get<1>(key)) {}
292
293 Type elementType;
294 unsigned rows;
295 unsigned columns;
296 Scope scope;
297 };
298
get(Type elementType,Scope scope,unsigned rows,unsigned columns)299 CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
300 Scope scope, unsigned rows,
301 unsigned columns) {
302 return Base::get(elementType.getContext(), elementType, scope, rows, columns);
303 }
304
getElementType() const305 Type CooperativeMatrixNVType::getElementType() const {
306 return getImpl()->elementType;
307 }
308
getScope() const309 Scope CooperativeMatrixNVType::getScope() const { return getImpl()->scope; }
310
getRows() const311 unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }
312
getColumns() const313 unsigned CooperativeMatrixNVType::getColumns() const {
314 return getImpl()->columns;
315 }
316
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)317 void CooperativeMatrixNVType::getExtensions(
318 SPIRVType::ExtensionArrayRefVector &extensions,
319 Optional<StorageClass> storage) {
320 getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
321 static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
322 ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
323 extensions.push_back(ref);
324 }
325
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)326 void CooperativeMatrixNVType::getCapabilities(
327 SPIRVType::CapabilityArrayRefVector &capabilities,
328 Optional<StorageClass> storage) {
329 getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
330 static const Capability caps[] = {Capability::CooperativeMatrixNV};
331 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
332 capabilities.push_back(ref);
333 }
334
335 //===----------------------------------------------------------------------===//
336 // ImageType
337 //===----------------------------------------------------------------------===//
338
getNumBits()339 template <typename T> static constexpr unsigned getNumBits() { return 0; }
getNumBits()340 template <> constexpr unsigned getNumBits<Dim>() {
341 static_assert((1 << 3) > getMaxEnumValForDim(),
342 "Not enough bits to encode Dim value");
343 return 3;
344 }
getNumBits()345 template <> constexpr unsigned getNumBits<ImageDepthInfo>() {
346 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
347 "Not enough bits to encode ImageDepthInfo value");
348 return 2;
349 }
getNumBits()350 template <> constexpr unsigned getNumBits<ImageArrayedInfo>() {
351 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
352 "Not enough bits to encode ImageArrayedInfo value");
353 return 1;
354 }
getNumBits()355 template <> constexpr unsigned getNumBits<ImageSamplingInfo>() {
356 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
357 "Not enough bits to encode ImageSamplingInfo value");
358 return 1;
359 }
getNumBits()360 template <> constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
361 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
362 "Not enough bits to encode ImageSamplerUseInfo value");
363 return 2;
364 }
getNumBits()365 template <> constexpr unsigned getNumBits<ImageFormat>() {
366 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
367 "Not enough bits to encode ImageFormat value");
368 return 6;
369 }
370
371 struct spirv::detail::ImageTypeStorage : public TypeStorage {
372 public:
373 using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
374 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
375
constructspirv::detail::ImageTypeStorage376 static ImageTypeStorage *construct(TypeStorageAllocator &allocator,
377 const KeyTy &key) {
378 return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
379 }
380
operator ==spirv::detail::ImageTypeStorage381 bool operator==(const KeyTy &key) const {
382 return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
383 samplerUseInfo, format);
384 }
385
ImageTypeStoragespirv::detail::ImageTypeStorage386 ImageTypeStorage(const KeyTy &key)
387 : elementType(std::get<0>(key)), dim(std::get<1>(key)),
388 depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
389 samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
390 format(std::get<6>(key)) {}
391
392 Type elementType;
393 Dim dim : getNumBits<Dim>();
394 ImageDepthInfo depthInfo : getNumBits<ImageDepthInfo>();
395 ImageArrayedInfo arrayedInfo : getNumBits<ImageArrayedInfo>();
396 ImageSamplingInfo samplingInfo : getNumBits<ImageSamplingInfo>();
397 ImageSamplerUseInfo samplerUseInfo : getNumBits<ImageSamplerUseInfo>();
398 ImageFormat format : getNumBits<ImageFormat>();
399 };
400
401 ImageType
get(std::tuple<Type,Dim,ImageDepthInfo,ImageArrayedInfo,ImageSamplingInfo,ImageSamplerUseInfo,ImageFormat> value)402 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
403 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
404 value) {
405 return Base::get(std::get<0>(value).getContext(), value);
406 }
407
getElementType() const408 Type ImageType::getElementType() const { return getImpl()->elementType; }
409
getDim() const410 Dim ImageType::getDim() const { return getImpl()->dim; }
411
getDepthInfo() const412 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
413
getArrayedInfo() const414 ImageArrayedInfo ImageType::getArrayedInfo() const {
415 return getImpl()->arrayedInfo;
416 }
417
getSamplingInfo() const418 ImageSamplingInfo ImageType::getSamplingInfo() const {
419 return getImpl()->samplingInfo;
420 }
421
getSamplerUseInfo() const422 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
423 return getImpl()->samplerUseInfo;
424 }
425
getImageFormat() const426 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
427
getExtensions(SPIRVType::ExtensionArrayRefVector &,Optional<StorageClass>)428 void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
429 Optional<StorageClass>) {
430 // Image types do not require extra extensions thus far.
431 }
432
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass>)433 void ImageType::getCapabilities(
434 SPIRVType::CapabilityArrayRefVector &capabilities, Optional<StorageClass>) {
435 if (auto dimCaps = spirv::getCapabilities(getDim()))
436 capabilities.push_back(*dimCaps);
437
438 if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
439 capabilities.push_back(*fmtCaps);
440 }
441
442 //===----------------------------------------------------------------------===//
443 // PointerType
444 //===----------------------------------------------------------------------===//
445
446 struct spirv::detail::PointerTypeStorage : public TypeStorage {
447 // (Type, StorageClass) as the key: Type stored in this struct, and
448 // StorageClass stored as TypeStorage's subclass data.
449 using KeyTy = std::pair<Type, StorageClass>;
450
constructspirv::detail::PointerTypeStorage451 static PointerTypeStorage *construct(TypeStorageAllocator &allocator,
452 const KeyTy &key) {
453 return new (allocator.allocate<PointerTypeStorage>())
454 PointerTypeStorage(key);
455 }
456
operator ==spirv::detail::PointerTypeStorage457 bool operator==(const KeyTy &key) const {
458 return key == KeyTy(pointeeType, storageClass);
459 }
460
PointerTypeStoragespirv::detail::PointerTypeStorage461 PointerTypeStorage(const KeyTy &key)
462 : pointeeType(key.first), storageClass(key.second) {}
463
464 Type pointeeType;
465 StorageClass storageClass;
466 };
467
get(Type pointeeType,StorageClass storageClass)468 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
469 return Base::get(pointeeType.getContext(), pointeeType, storageClass);
470 }
471
getPointeeType() const472 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
473
getStorageClass() const474 StorageClass PointerType::getStorageClass() const {
475 return getImpl()->storageClass;
476 }
477
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)478 void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
479 Optional<StorageClass> storage) {
480 // Use this pointer type's storage class because this pointer indicates we are
481 // using the pointee type in that specific storage class.
482 getPointeeType().cast<SPIRVType>().getExtensions(extensions,
483 getStorageClass());
484
485 if (auto scExts = spirv::getExtensions(getStorageClass()))
486 extensions.push_back(*scExts);
487 }
488
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)489 void PointerType::getCapabilities(
490 SPIRVType::CapabilityArrayRefVector &capabilities,
491 Optional<StorageClass> storage) {
492 // Use this pointer type's storage class because this pointer indicates we are
493 // using the pointee type in that specific storage class.
494 getPointeeType().cast<SPIRVType>().getCapabilities(capabilities,
495 getStorageClass());
496
497 if (auto scCaps = spirv::getCapabilities(getStorageClass()))
498 capabilities.push_back(*scCaps);
499 }
500
501 //===----------------------------------------------------------------------===//
502 // RuntimeArrayType
503 //===----------------------------------------------------------------------===//
504
505 struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
506 using KeyTy = std::pair<Type, unsigned>;
507
constructspirv::detail::RuntimeArrayTypeStorage508 static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator,
509 const KeyTy &key) {
510 return new (allocator.allocate<RuntimeArrayTypeStorage>())
511 RuntimeArrayTypeStorage(key);
512 }
513
operator ==spirv::detail::RuntimeArrayTypeStorage514 bool operator==(const KeyTy &key) const {
515 return key == KeyTy(elementType, stride);
516 }
517
RuntimeArrayTypeStoragespirv::detail::RuntimeArrayTypeStorage518 RuntimeArrayTypeStorage(const KeyTy &key)
519 : elementType(key.first), stride(key.second) {}
520
521 Type elementType;
522 unsigned stride;
523 };
524
get(Type elementType)525 RuntimeArrayType RuntimeArrayType::get(Type elementType) {
526 return Base::get(elementType.getContext(), elementType, /*stride=*/0);
527 }
528
get(Type elementType,unsigned stride)529 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
530 return Base::get(elementType.getContext(), elementType, stride);
531 }
532
getElementType() const533 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
534
getArrayStride() const535 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
536
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)537 void RuntimeArrayType::getExtensions(
538 SPIRVType::ExtensionArrayRefVector &extensions,
539 Optional<StorageClass> storage) {
540 getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
541 }
542
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)543 void RuntimeArrayType::getCapabilities(
544 SPIRVType::CapabilityArrayRefVector &capabilities,
545 Optional<StorageClass> storage) {
546 {
547 static const Capability caps[] = {Capability::Shader};
548 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
549 capabilities.push_back(ref);
550 }
551 getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
552 }
553
554 //===----------------------------------------------------------------------===//
555 // ScalarType
556 //===----------------------------------------------------------------------===//
557
classof(Type type)558 bool ScalarType::classof(Type type) {
559 if (auto floatType = type.dyn_cast<FloatType>()) {
560 return isValid(floatType);
561 }
562 if (auto intType = type.dyn_cast<IntegerType>()) {
563 return isValid(intType);
564 }
565 return false;
566 }
567
isValid(FloatType type)568 bool ScalarType::isValid(FloatType type) { return !type.isBF16(); }
569
isValid(IntegerType type)570 bool ScalarType::isValid(IntegerType type) {
571 switch (type.getWidth()) {
572 case 1:
573 case 8:
574 case 16:
575 case 32:
576 case 64:
577 return true;
578 default:
579 return false;
580 }
581 }
582
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)583 void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
584 Optional<StorageClass> storage) {
585 // 8- or 16-bit integer/floating-point numbers will require extra extensions
586 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
587 // SPV_KHR_8bit_storage for more details.
588 if (!storage)
589 return;
590
591 switch (*storage) {
592 case StorageClass::PushConstant:
593 case StorageClass::StorageBuffer:
594 case StorageClass::Uniform:
595 if (getIntOrFloatBitWidth() == 8) {
596 static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
597 ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
598 extensions.push_back(ref);
599 }
600 LLVM_FALLTHROUGH;
601 case StorageClass::Input:
602 case StorageClass::Output:
603 if (getIntOrFloatBitWidth() == 16) {
604 static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
605 ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
606 extensions.push_back(ref);
607 }
608 break;
609 default:
610 break;
611 }
612 }
613
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)614 void ScalarType::getCapabilities(
615 SPIRVType::CapabilityArrayRefVector &capabilities,
616 Optional<StorageClass> storage) {
617 unsigned bitwidth = getIntOrFloatBitWidth();
618
619 // 8- or 16-bit integer/floating-point numbers will require extra capabilities
620 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
621 // SPV_KHR_8bit_storage for more details.
622
623 #define STORAGE_CASE(storage, cap8, cap16) \
624 case StorageClass::storage: { \
625 if (bitwidth == 8) { \
626 static const Capability caps[] = {Capability::cap8}; \
627 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
628 capabilities.push_back(ref); \
629 } else if (bitwidth == 16) { \
630 static const Capability caps[] = {Capability::cap16}; \
631 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
632 capabilities.push_back(ref); \
633 } \
634 /* No requirements for other bitwidths */ \
635 return; \
636 }
637
638 // This part only handles the cases where special bitwidths appearing in
639 // interface storage classes.
640 if (storage) {
641 switch (*storage) {
642 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
643 STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
644 StorageBuffer16BitAccess);
645 STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
646 StorageUniform16);
647 case StorageClass::Input:
648 case StorageClass::Output: {
649 if (bitwidth == 16) {
650 static const Capability caps[] = {Capability::StorageInputOutput16};
651 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
652 capabilities.push_back(ref);
653 }
654 return;
655 }
656 default:
657 break;
658 }
659 }
660 #undef STORAGE_CASE
661
662 // For other non-interface storage classes, require a different set of
663 // capabilities for special bitwidths.
664
665 #define WIDTH_CASE(type, width) \
666 case width: { \
667 static const Capability caps[] = {Capability::type##width}; \
668 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
669 capabilities.push_back(ref); \
670 } break
671
672 if (auto intType = dyn_cast<IntegerType>()) {
673 switch (bitwidth) {
674 case 32:
675 case 1:
676 break;
677 WIDTH_CASE(Int, 8);
678 WIDTH_CASE(Int, 16);
679 WIDTH_CASE(Int, 64);
680 default:
681 llvm_unreachable("invalid bitwidth to getCapabilities");
682 }
683 } else {
684 assert(isa<FloatType>());
685 switch (bitwidth) {
686 case 32:
687 break;
688 WIDTH_CASE(Float, 16);
689 WIDTH_CASE(Float, 64);
690 default:
691 llvm_unreachable("invalid bitwidth to getCapabilities");
692 }
693 }
694
695 #undef WIDTH_CASE
696 }
697
getSizeInBytes()698 Optional<int64_t> ScalarType::getSizeInBytes() {
699 auto bitWidth = getIntOrFloatBitWidth();
700 // According to the SPIR-V spec:
701 // "There is no physical size or bit pattern defined for values with boolean
702 // type. If they are stored (in conjunction with OpVariable), they can only
703 // be used with logical addressing operations, not physical, and only with
704 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
705 // Private, Function, Input, and Output."
706 if (bitWidth == 1)
707 return llvm::None;
708 return bitWidth / 8;
709 }
710
711 //===----------------------------------------------------------------------===//
712 // SPIRVType
713 //===----------------------------------------------------------------------===//
714
classof(Type type)715 bool SPIRVType::classof(Type type) {
716 // Allow SPIR-V dialect types
717 if (llvm::isa<SPIRVDialect>(type.getDialect()))
718 return true;
719 if (type.isa<ScalarType>())
720 return true;
721 if (auto vectorType = type.dyn_cast<VectorType>())
722 return CompositeType::isValid(vectorType);
723 return false;
724 }
725
isScalarOrVector()726 bool SPIRVType::isScalarOrVector() {
727 return isIntOrFloat() || isa<VectorType>();
728 }
729
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)730 void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
731 Optional<StorageClass> storage) {
732 if (auto scalarType = dyn_cast<ScalarType>()) {
733 scalarType.getExtensions(extensions, storage);
734 } else if (auto compositeType = dyn_cast<CompositeType>()) {
735 compositeType.getExtensions(extensions, storage);
736 } else if (auto imageType = dyn_cast<ImageType>()) {
737 imageType.getExtensions(extensions, storage);
738 } else if (auto matrixType = dyn_cast<MatrixType>()) {
739 matrixType.getExtensions(extensions, storage);
740 } else if (auto ptrType = dyn_cast<PointerType>()) {
741 ptrType.getExtensions(extensions, storage);
742 } else {
743 llvm_unreachable("invalid SPIR-V Type to getExtensions");
744 }
745 }
746
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)747 void SPIRVType::getCapabilities(
748 SPIRVType::CapabilityArrayRefVector &capabilities,
749 Optional<StorageClass> storage) {
750 if (auto scalarType = dyn_cast<ScalarType>()) {
751 scalarType.getCapabilities(capabilities, storage);
752 } else if (auto compositeType = dyn_cast<CompositeType>()) {
753 compositeType.getCapabilities(capabilities, storage);
754 } else if (auto imageType = dyn_cast<ImageType>()) {
755 imageType.getCapabilities(capabilities, storage);
756 } else if (auto matrixType = dyn_cast<MatrixType>()) {
757 matrixType.getCapabilities(capabilities, storage);
758 } else if (auto ptrType = dyn_cast<PointerType>()) {
759 ptrType.getCapabilities(capabilities, storage);
760 } else {
761 llvm_unreachable("invalid SPIR-V Type to getCapabilities");
762 }
763 }
764
getSizeInBytes()765 Optional<int64_t> SPIRVType::getSizeInBytes() {
766 if (auto scalarType = dyn_cast<ScalarType>())
767 return scalarType.getSizeInBytes();
768 if (auto compositeType = dyn_cast<CompositeType>())
769 return compositeType.getSizeInBytes();
770 return llvm::None;
771 }
772
773 //===----------------------------------------------------------------------===//
774 // StructType
775 //===----------------------------------------------------------------------===//
776
777 /// Type storage for SPIR-V structure types:
778 ///
779 /// Structures are uniqued using:
780 /// - for identified structs:
781 /// - a string identifier;
782 /// - for literal structs:
783 /// - a list of member types;
784 /// - a list of member offset info;
785 /// - a list of member decoration info.
786 ///
787 /// Identified structures only have a mutable component consisting of:
788 /// - a list of member types;
789 /// - a list of member offset info;
790 /// - a list of member decoration info.
791 struct spirv::detail::StructTypeStorage : public TypeStorage {
792 /// Construct a storage object for an identified struct type. A struct type
793 /// associated with such storage must call StructType::trySetBody(...) later
794 /// in order to mutate the storage object providing the actual content.
StructTypeStoragespirv::detail::StructTypeStorage795 StructTypeStorage(StringRef identifier)
796 : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
797 numMemberDecorations(0), memberDecorationsInfo(nullptr),
798 identifier(identifier) {}
799
800 /// Construct a storage object for a literal struct type. A struct type
801 /// associated with such storage is immutable.
StructTypeStoragespirv::detail::StructTypeStorage802 StructTypeStorage(
803 unsigned numMembers, Type const *memberTypes,
804 StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
805 StructType::MemberDecorationInfo const *memberDecorationsInfo)
806 : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
807 numMembers(numMembers), numMemberDecorations(numMemberDecorations),
808 memberDecorationsInfo(memberDecorationsInfo), identifier(StringRef()) {}
809
810 /// A storage key is divided into 2 parts:
811 /// - for identified structs:
812 /// - a StringRef representing the struct identifier;
813 /// - for literal structs:
814 /// - an ArrayRef<Type> for member types;
815 /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
816 /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
817 /// info.
818 ///
819 /// An identified struct type is uniqued only by the first part (field 0)
820 /// of the key.
821 ///
822 /// A literal struct type is unqiued only by the second part (fields 1, 2, and
823 /// 3) of the key. The identifier field (field 0) must be empty.
824 using KeyTy =
825 std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
826 ArrayRef<StructType::MemberDecorationInfo>>;
827
828 /// For identified structs, return true if the given key contains the same
829 /// identifier.
830 ///
831 /// For literal structs, return true if the given key contains a matching list
832 /// of member types + offset info + decoration info.
operator ==spirv::detail::StructTypeStorage833 bool operator==(const KeyTy &key) const {
834 if (isIdentified()) {
835 // Identified types are uniqued by their identifier.
836 return getIdentifier() == std::get<0>(key);
837 }
838
839 return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
840 getMemberDecorationsInfo());
841 }
842
843 /// If the given key contains a non-empty identifier, this method constructs
844 /// an identified struct and leaves the rest of the struct type data to be set
845 /// through a later call to StructType::trySetBody(...).
846 ///
847 /// If, on the other hand, the key contains an empty identifier, a literal
848 /// struct is constructed using the other fields of the key.
constructspirv::detail::StructTypeStorage849 static StructTypeStorage *construct(TypeStorageAllocator &allocator,
850 const KeyTy &key) {
851 StringRef keyIdentifier = std::get<0>(key);
852
853 if (!keyIdentifier.empty()) {
854 StringRef identifier = allocator.copyInto(keyIdentifier);
855
856 // Identified StructType body/members will be set through trySetBody(...)
857 // later.
858 return new (allocator.allocate<StructTypeStorage>())
859 StructTypeStorage(identifier);
860 }
861
862 ArrayRef<Type> keyTypes = std::get<1>(key);
863
864 // Copy the member type and layout information into the bump pointer
865 const Type *typesList = nullptr;
866 if (!keyTypes.empty()) {
867 typesList = allocator.copyInto(keyTypes).data();
868 }
869
870 const StructType::OffsetInfo *offsetInfoList = nullptr;
871 if (!std::get<2>(key).empty()) {
872 ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
873 assert(keyOffsetInfo.size() == keyTypes.size() &&
874 "size of offset information must be same as the size of number of "
875 "elements");
876 offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
877 }
878
879 const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
880 unsigned numMemberDecorations = 0;
881 if (!std::get<3>(key).empty()) {
882 auto keyMemberDecorations = std::get<3>(key);
883 numMemberDecorations = keyMemberDecorations.size();
884 memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
885 }
886
887 return new (allocator.allocate<StructTypeStorage>())
888 StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
889 numMemberDecorations, memberDecorationList);
890 }
891
getMemberTypesspirv::detail::StructTypeStorage892 ArrayRef<Type> getMemberTypes() const {
893 return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
894 }
895
getOffsetInfospirv::detail::StructTypeStorage896 ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
897 if (offsetInfo) {
898 return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
899 }
900 return {};
901 }
902
getMemberDecorationsInfospirv::detail::StructTypeStorage903 ArrayRef<StructType::MemberDecorationInfo> getMemberDecorationsInfo() const {
904 if (memberDecorationsInfo) {
905 return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
906 numMemberDecorations);
907 }
908 return {};
909 }
910
getIdentifierspirv::detail::StructTypeStorage911 StringRef getIdentifier() const { return identifier; }
912
isIdentifiedspirv::detail::StructTypeStorage913 bool isIdentified() const { return !identifier.empty(); }
914
915 /// Sets the struct type content for identified structs. Calling this method
916 /// is only valid for identified structs.
917 ///
918 /// Fails under the following conditions:
919 /// - If called for a literal struct;
920 /// - If called for an identified struct whose body was set before (through a
921 /// call to this method) but with different contents from the passed
922 /// arguments.
mutatespirv::detail::StructTypeStorage923 LogicalResult mutate(
924 TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
925 ArrayRef<StructType::OffsetInfo> structOffsetInfo,
926 ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
927 if (!isIdentified())
928 return failure();
929
930 if (memberTypesAndIsBodySet.getInt() &&
931 (getMemberTypes() != structMemberTypes ||
932 getOffsetInfo() != structOffsetInfo ||
933 getMemberDecorationsInfo() != structMemberDecorationInfo))
934 return failure();
935
936 memberTypesAndIsBodySet.setInt(true);
937 numMembers = structMemberTypes.size();
938
939 // Copy the member type and layout information into the bump pointer.
940 if (!structMemberTypes.empty())
941 memberTypesAndIsBodySet.setPointer(
942 allocator.copyInto(structMemberTypes).data());
943
944 if (!structOffsetInfo.empty()) {
945 assert(structOffsetInfo.size() == structMemberTypes.size() &&
946 "size of offset information must be same as the size of number of "
947 "elements");
948 offsetInfo = allocator.copyInto(structOffsetInfo).data();
949 }
950
951 if (!structMemberDecorationInfo.empty()) {
952 numMemberDecorations = structMemberDecorationInfo.size();
953 memberDecorationsInfo =
954 allocator.copyInto(structMemberDecorationInfo).data();
955 }
956
957 return success();
958 }
959
960 llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
961 StructType::OffsetInfo const *offsetInfo;
962 unsigned numMembers;
963 unsigned numMemberDecorations;
964 StructType::MemberDecorationInfo const *memberDecorationsInfo;
965 StringRef identifier;
966 };
967
968 StructType
get(ArrayRef<Type> memberTypes,ArrayRef<StructType::OffsetInfo> offsetInfo,ArrayRef<StructType::MemberDecorationInfo> memberDecorations)969 StructType::get(ArrayRef<Type> memberTypes,
970 ArrayRef<StructType::OffsetInfo> offsetInfo,
971 ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
972 assert(!memberTypes.empty() && "Struct needs at least one member type");
973 // Sort the decorations.
974 SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
975 memberDecorations.begin(), memberDecorations.end());
976 llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
977 return Base::get(memberTypes.vec().front().getContext(),
978 /*identifier=*/StringRef(), memberTypes, offsetInfo,
979 sortedDecorations);
980 }
981
getIdentified(MLIRContext * context,StringRef identifier)982 StructType StructType::getIdentified(MLIRContext *context,
983 StringRef identifier) {
984 assert(!identifier.empty() &&
985 "StructType identifier must be non-empty string");
986
987 return Base::get(context, identifier, ArrayRef<Type>(),
988 ArrayRef<StructType::OffsetInfo>(),
989 ArrayRef<StructType::MemberDecorationInfo>());
990 }
991
getEmpty(MLIRContext * context,StringRef identifier)992 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
993 StructType newStructType = Base::get(
994 context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
995 ArrayRef<StructType::MemberDecorationInfo>());
996 // Set an empty body in case this is a identified struct.
997 if (newStructType.isIdentified() &&
998 failed(newStructType.trySetBody(
999 ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
1000 ArrayRef<StructType::MemberDecorationInfo>())))
1001 return StructType();
1002
1003 return newStructType;
1004 }
1005
getIdentifier() const1006 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
1007
isIdentified() const1008 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1009
getNumElements() const1010 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1011
getElementType(unsigned index) const1012 Type StructType::getElementType(unsigned index) const {
1013 assert(getNumElements() > index && "member index out of range");
1014 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1015 }
1016
getElementTypes() const1017 StructType::ElementTypeRange StructType::getElementTypes() const {
1018 return ElementTypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1019 getNumElements());
1020 }
1021
hasOffset() const1022 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1023
getMemberOffset(unsigned index) const1024 uint64_t StructType::getMemberOffset(unsigned index) const {
1025 assert(getNumElements() > index && "member index out of range");
1026 return getImpl()->offsetInfo[index];
1027 }
1028
getMemberDecorations(SmallVectorImpl<StructType::MemberDecorationInfo> & memberDecorations) const1029 void StructType::getMemberDecorations(
1030 SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorations)
1031 const {
1032 memberDecorations.clear();
1033 auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1034 memberDecorations.append(implMemberDecorations.begin(),
1035 implMemberDecorations.end());
1036 }
1037
getMemberDecorations(unsigned index,SmallVectorImpl<StructType::MemberDecorationInfo> & decorationsInfo) const1038 void StructType::getMemberDecorations(
1039 unsigned index,
1040 SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1041 assert(getNumElements() > index && "member index out of range");
1042 auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1043 decorationsInfo.clear();
1044 for (const auto &memberDecoration : memberDecorations) {
1045 if (memberDecoration.memberIndex == index) {
1046 decorationsInfo.push_back(memberDecoration);
1047 }
1048 if (memberDecoration.memberIndex > index) {
1049 // Early exit since the decorations are stored sorted.
1050 return;
1051 }
1052 }
1053 }
1054
1055 LogicalResult
trySetBody(ArrayRef<Type> memberTypes,ArrayRef<OffsetInfo> offsetInfo,ArrayRef<MemberDecorationInfo> memberDecorations)1056 StructType::trySetBody(ArrayRef<Type> memberTypes,
1057 ArrayRef<OffsetInfo> offsetInfo,
1058 ArrayRef<MemberDecorationInfo> memberDecorations) {
1059 return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1060 }
1061
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)1062 void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1063 Optional<StorageClass> storage) {
1064 for (Type elementType : getElementTypes())
1065 elementType.cast<SPIRVType>().getExtensions(extensions, storage);
1066 }
1067
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)1068 void StructType::getCapabilities(
1069 SPIRVType::CapabilityArrayRefVector &capabilities,
1070 Optional<StorageClass> storage) {
1071 for (Type elementType : getElementTypes())
1072 elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
1073 }
1074
hash_value(const StructType::MemberDecorationInfo & memberDecorationInfo)1075 llvm::hash_code spirv::hash_value(
1076 const StructType::MemberDecorationInfo &memberDecorationInfo) {
1077 return llvm::hash_combine(memberDecorationInfo.memberIndex,
1078 memberDecorationInfo.decoration);
1079 }
1080
1081 //===----------------------------------------------------------------------===//
1082 // MatrixType
1083 //===----------------------------------------------------------------------===//
1084
1085 struct spirv::detail::MatrixTypeStorage : public TypeStorage {
MatrixTypeStoragespirv::detail::MatrixTypeStorage1086 MatrixTypeStorage(Type columnType, uint32_t columnCount)
1087 : TypeStorage(), columnType(columnType), columnCount(columnCount) {}
1088
1089 using KeyTy = std::tuple<Type, uint32_t>;
1090
constructspirv::detail::MatrixTypeStorage1091 static MatrixTypeStorage *construct(TypeStorageAllocator &allocator,
1092 const KeyTy &key) {
1093
1094 // Initialize the memory using placement new.
1095 return new (allocator.allocate<MatrixTypeStorage>())
1096 MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1097 }
1098
operator ==spirv::detail::MatrixTypeStorage1099 bool operator==(const KeyTy &key) const {
1100 return key == KeyTy(columnType, columnCount);
1101 }
1102
1103 Type columnType;
1104 const uint32_t columnCount;
1105 };
1106
get(Type columnType,uint32_t columnCount)1107 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1108 return Base::get(columnType.getContext(), columnType, columnCount);
1109 }
1110
getChecked(Type columnType,uint32_t columnCount,Location location)1111 MatrixType MatrixType::getChecked(Type columnType, uint32_t columnCount,
1112 Location location) {
1113 return Base::getChecked(location, columnType, columnCount);
1114 }
1115
verifyConstructionInvariants(Location loc,Type columnType,uint32_t columnCount)1116 LogicalResult MatrixType::verifyConstructionInvariants(Location loc,
1117 Type columnType,
1118 uint32_t columnCount) {
1119 if (columnCount < 2 || columnCount > 4)
1120 return emitError(loc, "matrix can have 2, 3, or 4 columns only");
1121
1122 if (!isValidColumnType(columnType))
1123 return emitError(loc, "matrix columns must be vectors of floats");
1124
1125 /// The underlying vectors (columns) must be of size 2, 3, or 4
1126 ArrayRef<int64_t> columnShape = columnType.cast<VectorType>().getShape();
1127 if (columnShape.size() != 1)
1128 return emitError(loc, "matrix columns must be 1D vectors");
1129
1130 if (columnShape[0] < 2 || columnShape[0] > 4)
1131 return emitError(loc, "matrix columns must be of size 2, 3, or 4");
1132
1133 return success();
1134 }
1135
1136 /// Returns true if the matrix elements are vectors of float elements
isValidColumnType(Type columnType)1137 bool MatrixType::isValidColumnType(Type columnType) {
1138 if (auto vectorType = columnType.dyn_cast<VectorType>()) {
1139 if (vectorType.getElementType().isa<FloatType>())
1140 return true;
1141 }
1142 return false;
1143 }
1144
getColumnType() const1145 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1146
getElementType() const1147 Type MatrixType::getElementType() const {
1148 return getImpl()->columnType.cast<VectorType>().getElementType();
1149 }
1150
getNumColumns() const1151 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1152
getNumRows() const1153 unsigned MatrixType::getNumRows() const {
1154 return getImpl()->columnType.cast<VectorType>().getShape()[0];
1155 }
1156
getNumElements() const1157 unsigned MatrixType::getNumElements() const {
1158 return (getImpl()->columnCount) * getNumRows();
1159 }
1160
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)1161 void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1162 Optional<StorageClass> storage) {
1163 getColumnType().cast<SPIRVType>().getExtensions(extensions, storage);
1164 }
1165
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)1166 void MatrixType::getCapabilities(
1167 SPIRVType::CapabilityArrayRefVector &capabilities,
1168 Optional<StorageClass> storage) {
1169 {
1170 static const Capability caps[] = {Capability::Matrix};
1171 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
1172 capabilities.push_back(ref);
1173 }
1174 // Add any capabilities associated with the underlying vectors (i.e., columns)
1175 getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);
1176 }
1177