1 //===- SPIRVTypes.h - MLIR SPIR-V Types -------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file declares the types in the SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_SPIRV_SPIRVTYPES_H_ 14 #define MLIR_DIALECT_SPIRV_SPIRVTYPES_H_ 15 16 #include "mlir/IR/BuiltinTypes.h" 17 #include "mlir/IR/Diagnostics.h" 18 #include "mlir/IR/Location.h" 19 #include "mlir/IR/TypeSupport.h" 20 #include "mlir/IR/Types.h" 21 22 #include <tuple> 23 24 // Forward declare enum classes related to op availability. Their definitions 25 // are in the TableGen'erated SPIRVEnums.h.inc and can be referenced by other 26 // declarations in SPIRVEnums.h.inc. 27 namespace mlir { 28 namespace spirv { 29 enum class Version : uint32_t; 30 enum class Extension; 31 enum class Capability : uint32_t; 32 } // namespace spirv 33 } // namespace mlir 34 35 // Pull in all enum type definitions and utility function declarations 36 #include "mlir/Dialect/SPIRV/SPIRVEnums.h.inc" 37 // Pull in all enum type availability query function declarations 38 #include "mlir/Dialect/SPIRV/SPIRVEnumAvailability.h.inc" 39 40 namespace mlir { 41 namespace spirv { 42 /// Returns the implied extensions for the given version. These extensions are 43 /// incorporated into the current version so they are implicitly declared when 44 /// targeting the given version. 45 ArrayRef<Extension> getImpliedExtensions(Version version); 46 47 /// Returns the directly implied capabilities for the given capability. These 48 /// capabilities are implicitly declared by the given capability. 49 ArrayRef<Capability> getDirectImpliedCapabilities(Capability cap); 50 /// Returns the recursively implied capabilities for the given capability. These 51 /// capabilities are implicitly declared by the given capability. Compared to 52 /// the above function, this function collects implied capabilities recursively: 53 /// if an implicitly declared capability implicitly declares a third one, the 54 /// third one will also be returned. 55 SmallVector<Capability, 0> getRecursiveImpliedCapabilities(Capability cap); 56 57 namespace detail { 58 struct ArrayTypeStorage; 59 struct CooperativeMatrixTypeStorage; 60 struct ImageTypeStorage; 61 struct MatrixTypeStorage; 62 struct PointerTypeStorage; 63 struct RuntimeArrayTypeStorage; 64 struct StructTypeStorage; 65 66 } // namespace detail 67 68 // Base SPIR-V type for providing availability queries. 69 class SPIRVType : public Type { 70 public: 71 using Type::Type; 72 73 static bool classof(Type type); 74 75 bool isScalarOrVector(); 76 77 /// The extension requirements for each type are following the 78 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) 79 /// convention. 80 using ExtensionArrayRefVector = SmallVectorImpl<ArrayRef<Extension>>; 81 82 /// Appends to `extensions` the extensions needed for this type to appear in 83 /// the given `storage` class. This method does not guarantee the uniqueness 84 /// of extensions; the same extension may be appended multiple times. 85 void getExtensions(ExtensionArrayRefVector &extensions, 86 Optional<StorageClass> storage = llvm::None); 87 88 /// The capability requirements for each type are following the 89 /// ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D)) 90 /// convention. 91 using CapabilityArrayRefVector = SmallVectorImpl<ArrayRef<Capability>>; 92 93 /// Appends to `capabilities` the capabilities needed for this type to appear 94 /// in the given `storage` class. This method does not guarantee the 95 /// uniqueness of capabilities; the same capability may be appended multiple 96 /// times. 97 void getCapabilities(CapabilityArrayRefVector &capabilities, 98 Optional<StorageClass> storage = llvm::None); 99 100 /// Returns the size in bytes for each type. If no size can be calculated, 101 /// returns `llvm::None`. Note that if the type has explicit layout, it is 102 /// also taken into account in calculation. 103 Optional<int64_t> getSizeInBytes(); 104 }; 105 106 // SPIR-V scalar type: bool type, integer type, floating point type. 107 class ScalarType : public SPIRVType { 108 public: 109 using SPIRVType::SPIRVType; 110 111 static bool classof(Type type); 112 113 /// Returns true if the given integer type is valid for the SPIR-V dialect. 114 static bool isValid(FloatType); 115 /// Returns true if the given float type is valid for the SPIR-V dialect. 116 static bool isValid(IntegerType); 117 118 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 119 Optional<StorageClass> storage = llvm::None); 120 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 121 Optional<StorageClass> storage = llvm::None); 122 123 Optional<int64_t> getSizeInBytes(); 124 }; 125 126 // SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType. 127 class CompositeType : public SPIRVType { 128 public: 129 using SPIRVType::SPIRVType; 130 131 static bool classof(Type type); 132 133 /// Returns true if the given vector type is valid for the SPIR-V dialect. 134 static bool isValid(VectorType); 135 136 /// Return the number of elements of the type. This should only be called if 137 /// hasCompileTimeKnownNumElements is true. 138 unsigned getNumElements() const; 139 140 Type getElementType(unsigned) const; 141 142 /// Return true if the number of elements is known at compile time and is not 143 /// implementation dependent. 144 bool hasCompileTimeKnownNumElements() const; 145 146 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 147 Optional<StorageClass> storage = llvm::None); 148 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 149 Optional<StorageClass> storage = llvm::None); 150 151 Optional<int64_t> getSizeInBytes(); 152 }; 153 154 // SPIR-V array type 155 class ArrayType : public Type::TypeBase<ArrayType, CompositeType, 156 detail::ArrayTypeStorage> { 157 public: 158 using Base::Base; 159 160 static ArrayType get(Type elementType, unsigned elementCount); 161 162 /// Returns an array type with the given stride in bytes. 163 static ArrayType get(Type elementType, unsigned elementCount, 164 unsigned stride); 165 166 unsigned getNumElements() const; 167 168 Type getElementType() const; 169 170 /// Returns the array stride in bytes. 0 means no stride decorated on this 171 /// type. 172 unsigned getArrayStride() const; 173 174 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 175 Optional<StorageClass> storage = llvm::None); 176 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 177 Optional<StorageClass> storage = llvm::None); 178 179 /// Returns the array size in bytes. Since array type may have an explicit 180 /// stride declaration (in bytes), we also include it in the calculation. 181 Optional<int64_t> getSizeInBytes(); 182 }; 183 184 // SPIR-V image type 185 class ImageType 186 : public Type::TypeBase<ImageType, SPIRVType, detail::ImageTypeStorage> { 187 public: 188 using Base::Base; 189 190 static ImageType 191 get(Type elementType, Dim dim, 192 ImageDepthInfo depth = ImageDepthInfo::DepthUnknown, 193 ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed, 194 ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled, 195 ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown, 196 ImageFormat format = ImageFormat::Unknown) { 197 return ImageType::get( 198 std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo, 199 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>( 200 elementType, dim, depth, arrayed, samplingInfo, samplerUse, 201 format)); 202 } 203 204 static ImageType 205 get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo, 206 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>); 207 208 Type getElementType() const; 209 Dim getDim() const; 210 ImageDepthInfo getDepthInfo() const; 211 ImageArrayedInfo getArrayedInfo() const; 212 ImageSamplingInfo getSamplingInfo() const; 213 ImageSamplerUseInfo getSamplerUseInfo() const; 214 ImageFormat getImageFormat() const; 215 // TODO: Add support for Access qualifier 216 217 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 218 Optional<StorageClass> storage = llvm::None); 219 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 220 Optional<StorageClass> storage = llvm::None); 221 }; 222 223 // SPIR-V pointer type 224 class PointerType : public Type::TypeBase<PointerType, SPIRVType, 225 detail::PointerTypeStorage> { 226 public: 227 using Base::Base; 228 229 static PointerType get(Type pointeeType, StorageClass storageClass); 230 231 Type getPointeeType() const; 232 233 StorageClass getStorageClass() const; 234 235 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 236 Optional<StorageClass> storage = llvm::None); 237 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 238 Optional<StorageClass> storage = llvm::None); 239 }; 240 241 // SPIR-V run-time array type 242 class RuntimeArrayType 243 : public Type::TypeBase<RuntimeArrayType, SPIRVType, 244 detail::RuntimeArrayTypeStorage> { 245 public: 246 using Base::Base; 247 248 static RuntimeArrayType get(Type elementType); 249 250 /// Returns a runtime array type with the given stride in bytes. 251 static RuntimeArrayType get(Type elementType, unsigned stride); 252 253 Type getElementType() const; 254 255 /// Returns the array stride in bytes. 0 means no stride decorated on this 256 /// type. 257 unsigned getArrayStride() const; 258 259 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 260 Optional<StorageClass> storage = llvm::None); 261 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 262 Optional<StorageClass> storage = llvm::None); 263 }; 264 265 /// SPIR-V struct type. Two kinds of struct types are supported: 266 /// - Literal: a literal struct type is uniqued by its fields (types + offset 267 /// info + decoration info). 268 /// - Identified: an indentified struct type is uniqued by its string identifier 269 /// (name). This is useful in representing recursive structs. For example, the 270 /// following C struct: 271 /// 272 /// struct A { 273 /// A* next; 274 /// }; 275 /// 276 /// would be represented in MLIR as: 277 /// 278 /// !spv.struct<A, (!spv.ptr<!spv.struct<A>, Generic>)> 279 /// 280 /// In the above, expressing recursive struct types is accomplished by giving a 281 /// recursive struct a unique identified and using that identifier in the struct 282 /// definition for recursive references. 283 class StructType : public Type::TypeBase<StructType, CompositeType, 284 detail::StructTypeStorage> { 285 public: 286 using Base::Base; 287 288 // Type for specifying the offset of the struct members 289 using OffsetInfo = uint32_t; 290 291 // Type for specifying the decoration(s) on struct members 292 struct MemberDecorationInfo { 293 uint32_t memberIndex : 31; 294 uint32_t hasValue : 1; 295 Decoration decoration; 296 uint32_t decorationValue; 297 MemberDecorationInfoMemberDecorationInfo298 MemberDecorationInfo(uint32_t index, uint32_t hasValue, 299 Decoration decoration, uint32_t decorationValue) 300 : memberIndex(index), hasValue(hasValue), decoration(decoration), 301 decorationValue(decorationValue) {} 302 303 bool operator==(const MemberDecorationInfo &other) const { 304 return (this->memberIndex == other.memberIndex) && 305 (this->decoration == other.decoration) && 306 (this->decorationValue == other.decorationValue); 307 } 308 309 bool operator<(const MemberDecorationInfo &other) const { 310 return this->memberIndex < other.memberIndex || 311 (this->memberIndex == other.memberIndex && 312 static_cast<uint32_t>(this->decoration) < 313 static_cast<uint32_t>(other.decoration)); 314 } 315 }; 316 317 /// Construct a literal StructType with at least one member. 318 static StructType get(ArrayRef<Type> memberTypes, 319 ArrayRef<OffsetInfo> offsetInfo = {}, 320 ArrayRef<MemberDecorationInfo> memberDecorations = {}); 321 322 /// Construct an identified StructType. This creates a StructType whose body 323 /// (member types, offset info, and decorations) is not set yet. A call to 324 /// StructType::trySetBody(...) must follow when the StructType contents are 325 /// available (e.g. parsed or deserialized). 326 /// 327 /// Note: If another thread creates (or had already created) a struct with the 328 /// same identifier, that struct will be returned as a result. 329 static StructType getIdentified(MLIRContext *context, StringRef identifier); 330 331 /// Construct a (possibly identified) StructType with no members. 332 /// 333 /// Note: this method might fail in a multi-threaded setup if another thread 334 /// created an identified struct with the same identifier but with different 335 /// contents before returning. In which case, an empty (default-constructed) 336 /// StructType is returned. 337 static StructType getEmpty(MLIRContext *context, StringRef identifier = ""); 338 339 /// For literal structs, return an empty string. 340 /// For identified structs, return the struct's identifier. 341 StringRef getIdentifier() const; 342 343 /// Returns true if the StructType is identified. 344 bool isIdentified() const; 345 346 unsigned getNumElements() const; 347 348 Type getElementType(unsigned) const; 349 350 /// Range class for element types. 351 class ElementTypeRange 352 : public ::llvm::detail::indexed_accessor_range_base< 353 ElementTypeRange, const Type *, Type, Type, Type> { 354 private: 355 using RangeBaseT::RangeBaseT; 356 357 /// See `llvm::detail::indexed_accessor_range_base` for details. offset_base(const Type * object,ptrdiff_t index)358 static const Type *offset_base(const Type *object, ptrdiff_t index) { 359 return object + index; 360 } 361 /// See `llvm::detail::indexed_accessor_range_base` for details. dereference_iterator(const Type * object,ptrdiff_t index)362 static Type dereference_iterator(const Type *object, ptrdiff_t index) { 363 return object[index]; 364 } 365 366 /// Allow base class access to `offset_base` and `dereference_iterator`. 367 friend RangeBaseT; 368 }; 369 370 ElementTypeRange getElementTypes() const; 371 372 bool hasOffset() const; 373 374 uint64_t getMemberOffset(unsigned) const; 375 376 // Returns in `memberDecorations` the Decorations (apart from Offset) 377 // associated with all members of the StructType. 378 void getMemberDecorations(SmallVectorImpl<StructType::MemberDecorationInfo> 379 &memberDecorations) const; 380 381 // Returns in `decorationsInfo` all the Decorations (apart from Offset) 382 // associated with the `i`-th member of the StructType. 383 void getMemberDecorations(unsigned i, 384 SmallVectorImpl<StructType::MemberDecorationInfo> 385 &decorationsInfo) const; 386 387 /// Sets the contents of an incomplete identified StructType. This method must 388 /// be called only for identified StructTypes and it must be called only once 389 /// per instance. Otherwise, failure() is returned. 390 LogicalResult 391 trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {}, 392 ArrayRef<MemberDecorationInfo> memberDecorations = {}); 393 394 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 395 Optional<StorageClass> storage = llvm::None); 396 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 397 Optional<StorageClass> storage = llvm::None); 398 }; 399 400 llvm::hash_code 401 hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo); 402 403 // SPIR-V cooperative matrix type 404 class CooperativeMatrixNVType 405 : public Type::TypeBase<CooperativeMatrixNVType, CompositeType, 406 detail::CooperativeMatrixTypeStorage> { 407 public: 408 using Base::Base; 409 410 static CooperativeMatrixNVType get(Type elementType, Scope scope, 411 unsigned rows, unsigned columns); 412 Type getElementType() const; 413 414 /// Return the scope of the cooperative matrix. 415 Scope getScope() const; 416 /// return the number of rows of the matrix. 417 unsigned getRows() const; 418 /// return the number of columns of the matrix. 419 unsigned getColumns() const; 420 421 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 422 Optional<StorageClass> storage = llvm::None); 423 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 424 Optional<StorageClass> storage = llvm::None); 425 }; 426 427 // SPIR-V matrix type 428 class MatrixType : public Type::TypeBase<MatrixType, CompositeType, 429 detail::MatrixTypeStorage> { 430 public: 431 using Base::Base; 432 433 static MatrixType get(Type columnType, uint32_t columnCount); 434 435 static MatrixType getChecked(Type columnType, uint32_t columnCount, 436 Location location); 437 438 static LogicalResult verifyConstructionInvariants(Location loc, 439 Type columnType, 440 uint32_t columnCount); 441 442 /// Returns true if the matrix elements are vectors of float elements. 443 static bool isValidColumnType(Type columnType); 444 445 Type getColumnType() const; 446 447 /// Returns the number of rows. 448 unsigned getNumRows() const; 449 450 /// Returns the number of columns. 451 unsigned getNumColumns() const; 452 453 /// Returns total number of elements (rows*columns). 454 unsigned getNumElements() const; 455 456 /// Returns the elements' type (i.e, single element type). 457 Type getElementType() const; 458 459 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 460 Optional<StorageClass> storage = llvm::None); 461 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 462 Optional<StorageClass> storage = llvm::None); 463 }; 464 465 } // end namespace spirv 466 } // end namespace mlir 467 468 #endif // MLIR_DIALECT_SPIRV_SPIRVTYPES_H_ 469