1 // Copyright (c) 2016 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // This file provides a class hierarchy for representing SPIR-V types. 16 17 #ifndef SOURCE_OPT_TYPES_H_ 18 #define SOURCE_OPT_TYPES_H_ 19 20 #include <map> 21 #include <memory> 22 #include <set> 23 #include <string> 24 #include <unordered_map> 25 #include <unordered_set> 26 #include <utility> 27 #include <vector> 28 29 #include "source/latest_version_spirv_header.h" 30 #include "source/opt/instruction.h" 31 #include "source/util/small_vector.h" 32 #include "spirv-tools/libspirv.h" 33 34 namespace spvtools { 35 namespace opt { 36 namespace analysis { 37 38 class Void; 39 class Bool; 40 class Integer; 41 class Float; 42 class Vector; 43 class Matrix; 44 class Image; 45 class Sampler; 46 class SampledImage; 47 class Array; 48 class RuntimeArray; 49 class Struct; 50 class Opaque; 51 class Pointer; 52 class Function; 53 class Event; 54 class DeviceEvent; 55 class ReserveId; 56 class Queue; 57 class Pipe; 58 class ForwardPointer; 59 class PipeStorage; 60 class NamedBarrier; 61 class AccelerationStructureNV; 62 class CooperativeMatrixNV; 63 class RayQueryKHR; 64 65 // Abstract class for a SPIR-V type. It has a bunch of As<sublcass>() methods, 66 // which is used as a way to probe the actual <subclass>. 67 class Type { 68 public: 69 typedef std::set<std::pair<const Pointer*, const Pointer*>> IsSameCache; 70 71 using SeenTypes = spvtools::utils::SmallVector<const Type*, 8>; 72 73 // Available subtypes. 74 // 75 // When adding a new derived class of Type, please add an entry to the enum. 76 enum Kind { 77 kVoid, 78 kBool, 79 kInteger, 80 kFloat, 81 kVector, 82 kMatrix, 83 kImage, 84 kSampler, 85 kSampledImage, 86 kArray, 87 kRuntimeArray, 88 kStruct, 89 kOpaque, 90 kPointer, 91 kFunction, 92 kEvent, 93 kDeviceEvent, 94 kReserveId, 95 kQueue, 96 kPipe, 97 kForwardPointer, 98 kPipeStorage, 99 kNamedBarrier, 100 kAccelerationStructureNV, 101 kCooperativeMatrixNV, 102 kRayQueryKHR, 103 kLast 104 }; 105 Type(Kind k)106 Type(Kind k) : kind_(k) {} 107 108 virtual ~Type() = default; 109 110 // Attaches a decoration directly on this type. AddDecoration(std::vector<uint32_t> && d)111 void AddDecoration(std::vector<uint32_t>&& d) { 112 decorations_.push_back(std::move(d)); 113 } 114 // Returns the decorations on this type as a string. 115 std::string GetDecorationStr() const; 116 // Returns true if this type has exactly the same decorations as |that| type. 117 bool HasSameDecorations(const Type* that) const; 118 // Returns true if this type is exactly the same as |that| type, including 119 // decorations. IsSame(const Type * that)120 bool IsSame(const Type* that) const { 121 IsSameCache seen; 122 return IsSameImpl(that, &seen); 123 } 124 125 // Returns true if this type is exactly the same as |that| type, including 126 // decorations. |seen| is the set of |Pointer*| pair that are currently being 127 // compared in a parent call to |IsSameImpl|. 128 virtual bool IsSameImpl(const Type* that, IsSameCache* seen) const = 0; 129 130 // Returns a human-readable string to represent this type. 131 virtual std::string str() const = 0; 132 kind()133 Kind kind() const { return kind_; } decorations()134 const std::vector<std::vector<uint32_t>>& decorations() const { 135 return decorations_; 136 } 137 138 // Returns true if there is no decoration on this type. For struct types, 139 // returns true only when there is no decoration for both the struct type 140 // and the struct members. decoration_empty()141 virtual bool decoration_empty() const { return decorations_.empty(); } 142 143 // Creates a clone of |this|. 144 std::unique_ptr<Type> Clone() const; 145 146 // Returns a clone of |this| minus any decorations. 147 std::unique_ptr<Type> RemoveDecorations() const; 148 149 // Returns true if this type must be unique. 150 // 151 // If variable pointers are allowed, then pointers are not required to be 152 // unique. 153 // TODO(alanbaker): Update this if variable pointers become a core feature. 154 bool IsUniqueType(bool allowVariablePointers = false) const; 155 156 bool operator==(const Type& other) const; 157 158 // Returns the hash value of this type. 159 size_t HashValue() const; 160 161 size_t ComputeHashValue(size_t hash, SeenTypes* seen) const; 162 163 // Returns the number of components in a composite type. Returns 0 for a 164 // non-composite type. 165 uint64_t NumberOfComponents() const; 166 167 // A bunch of methods for casting this type to a given type. Returns this if the 168 // cast can be done, nullptr otherwise. 169 // clang-format off 170 #define DeclareCastMethod(target) \ 171 virtual target* As##target() { return nullptr; } \ 172 virtual const target* As##target() const { return nullptr; } 173 DeclareCastMethod(Void) 174 DeclareCastMethod(Bool) 175 DeclareCastMethod(Integer) 176 DeclareCastMethod(Float) 177 DeclareCastMethod(Vector) 178 DeclareCastMethod(Matrix) 179 DeclareCastMethod(Image) 180 DeclareCastMethod(Sampler) 181 DeclareCastMethod(SampledImage) 182 DeclareCastMethod(Array) 183 DeclareCastMethod(RuntimeArray) 184 DeclareCastMethod(Struct) 185 DeclareCastMethod(Opaque) 186 DeclareCastMethod(Pointer) 187 DeclareCastMethod(Function) 188 DeclareCastMethod(Event) 189 DeclareCastMethod(DeviceEvent) 190 DeclareCastMethod(ReserveId) 191 DeclareCastMethod(Queue) 192 DeclareCastMethod(Pipe) 193 DeclareCastMethod(ForwardPointer) 194 DeclareCastMethod(PipeStorage) 195 DeclareCastMethod(NamedBarrier) 196 DeclareCastMethod(AccelerationStructureNV) 197 DeclareCastMethod(CooperativeMatrixNV) 198 DeclareCastMethod(RayQueryKHR) 199 #undef DeclareCastMethod 200 201 protected: 202 // Add any type-specific state to |hash| and returns new hash. 203 virtual size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const = 0; 204 205 protected: 206 // Decorations attached to this type. Each decoration is encoded as a vector 207 // of uint32_t numbers. The first uint32_t number is the decoration value, 208 // and the rest are the parameters to the decoration (if exists). 209 std::vector<std::vector<uint32_t>> decorations_; 210 211 private: 212 // Removes decorations on this type. For struct types, also removes element 213 // decorations. ClearDecorations()214 virtual void ClearDecorations() { decorations_.clear(); } 215 216 Kind kind_; 217 }; 218 // clang-format on 219 220 class Integer : public Type { 221 public: Integer(uint32_t w,bool is_signed)222 Integer(uint32_t w, bool is_signed) 223 : Type(kInteger), width_(w), signed_(is_signed) {} 224 Integer(const Integer&) = default; 225 226 std::string str() const override; 227 AsInteger()228 Integer* AsInteger() override { return this; } AsInteger()229 const Integer* AsInteger() const override { return this; } width()230 uint32_t width() const { return width_; } IsSigned()231 bool IsSigned() const { return signed_; } 232 233 size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; 234 235 private: 236 bool IsSameImpl(const Type* that, IsSameCache*) const override; 237 238 uint32_t width_; // bit width 239 bool signed_; // true if this integer is signed 240 }; 241 242 class Float : public Type { 243 public: Float(uint32_t w)244 Float(uint32_t w) : Type(kFloat), width_(w) {} 245 Float(const Float&) = default; 246 247 std::string str() const override; 248 AsFloat()249 Float* AsFloat() override { return this; } AsFloat()250 const Float* AsFloat() const override { return this; } width()251 uint32_t width() const { return width_; } 252 253 size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; 254 255 private: 256 bool IsSameImpl(const Type* that, IsSameCache*) const override; 257 258 uint32_t width_; // bit width 259 }; 260 261 class Vector : public Type { 262 public: 263 Vector(const Type* element_type, uint32_t count); 264 Vector(const Vector&) = default; 265 266 std::string str() const override; element_type()267 const Type* element_type() const { return element_type_; } element_count()268 uint32_t element_count() const { return count_; } 269 AsVector()270 Vector* AsVector() override { return this; } AsVector()271 const Vector* AsVector() const override { return this; } 272 273 size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; 274 275 private: 276 bool IsSameImpl(const Type* that, IsSameCache*) const override; 277 278 const Type* element_type_; 279 uint32_t count_; 280 }; 281 282 class Matrix : public Type { 283 public: 284 Matrix(const Type* element_type, uint32_t count); 285 Matrix(const Matrix&) = default; 286 287 std::string str() const override; element_type()288 const Type* element_type() const { return element_type_; } element_count()289 uint32_t element_count() const { return count_; } 290 AsMatrix()291 Matrix* AsMatrix() override { return this; } AsMatrix()292 const Matrix* AsMatrix() const override { return this; } 293 294 size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; 295 296 private: 297 bool IsSameImpl(const Type* that, IsSameCache*) const override; 298 299 const Type* element_type_; 300 uint32_t count_; 301 }; 302 303 class Image : public Type { 304 public: 305 Image(Type* type, SpvDim dimen, uint32_t d, bool array, bool multisample, 306 uint32_t sampling, SpvImageFormat f, 307 SpvAccessQualifier qualifier = SpvAccessQualifierReadOnly); 308 Image(const Image&) = default; 309 310 std::string str() const override; 311 AsImage()312 Image* AsImage() override { return this; } AsImage()313 const Image* AsImage() const override { return this; } 314 sampled_type()315 const Type* sampled_type() const { return sampled_type_; } dim()316 SpvDim dim() const { return dim_; } depth()317 uint32_t depth() const { return depth_; } is_arrayed()318 bool is_arrayed() const { return arrayed_; } is_multisampled()319 bool is_multisampled() const { return ms_; } sampled()320 uint32_t sampled() const { return sampled_; } format()321 SpvImageFormat format() const { return format_; } access_qualifier()322 SpvAccessQualifier access_qualifier() const { return access_qualifier_; } 323 324 size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; 325 326 private: 327 bool IsSameImpl(const Type* that, IsSameCache*) const override; 328 329 Type* sampled_type_; 330 SpvDim dim_; 331 uint32_t depth_; 332 bool arrayed_; 333 bool ms_; 334 uint32_t sampled_; 335 SpvImageFormat format_; 336 SpvAccessQualifier access_qualifier_; 337 }; 338 339 class SampledImage : public Type { 340 public: SampledImage(Type * image)341 SampledImage(Type* image) : Type(kSampledImage), image_type_(image) {} 342 SampledImage(const SampledImage&) = default; 343 344 std::string str() const override; 345 AsSampledImage()346 SampledImage* AsSampledImage() override { return this; } AsSampledImage()347 const SampledImage* AsSampledImage() const override { return this; } 348 image_type()349 const Type* image_type() const { return image_type_; } 350 351 size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; 352 353 private: 354 bool IsSameImpl(const Type* that, IsSameCache*) const override; 355 Type* image_type_; 356 }; 357 358 class Array : public Type { 359 public: 360 // Data about the length operand, that helps us distinguish between one 361 // array length and another. 362 struct LengthInfo { 363 // The result id of the instruction defining the length. 364 const uint32_t id; 365 enum Case : uint32_t { 366 kConstant = 0, 367 kConstantWithSpecId = 1, 368 kDefiningId = 2 369 }; 370 // Extra words used to distinshish one array length and another. 371 // - if OpConstant, then it's 0, then the words in the literal constant 372 // value. 373 // - if OpSpecConstant, then it's 1, then the SpecID decoration if there 374 // is one, followed by the words in the literal constant value. 375 // The spec might not be overridden, in which case we'll end up using 376 // the literal value. 377 // - Otherwise, it's an OpSpecConsant, and this 2, then the ID (again). 378 const std::vector<uint32_t> words; 379 }; 380 381 // Constructs an array type with given element and length. If the length 382 // is an OpSpecConstant, then |spec_id| should be its SpecId decoration. 383 Array(const Type* element_type, const LengthInfo& length_info_arg); 384 Array(const Array&) = default; 385 386 std::string str() const override; element_type()387 const Type* element_type() const { return element_type_; } LengthId()388 uint32_t LengthId() const { return length_info_.id; } length_info()389 const LengthInfo& length_info() const { return length_info_; } 390 AsArray()391 Array* AsArray() override { return this; } AsArray()392 const Array* AsArray() const override { return this; } 393 394 size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; 395 396 void ReplaceElementType(const Type* element_type); 397 LengthInfo GetConstantLengthInfo(uint32_t const_id, uint32_t length) const; 398 399 private: 400 bool IsSameImpl(const Type* that, IsSameCache*) const override; 401 402 const Type* element_type_; 403 const LengthInfo length_info_; 404 }; 405 406 class RuntimeArray : public Type { 407 public: 408 RuntimeArray(const Type* element_type); 409 RuntimeArray(const RuntimeArray&) = default; 410 411 std::string str() const override; element_type()412 const Type* element_type() const { return element_type_; } 413 AsRuntimeArray()414 RuntimeArray* AsRuntimeArray() override { return this; } AsRuntimeArray()415 const RuntimeArray* AsRuntimeArray() const override { return this; } 416 417 size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; 418 419 void ReplaceElementType(const Type* element_type); 420 421 private: 422 bool IsSameImpl(const Type* that, IsSameCache*) const override; 423 424 const Type* element_type_; 425 }; 426 427 class Struct : public Type { 428 public: 429 Struct(const std::vector<const Type*>& element_types); 430 Struct(const Struct&) = default; 431 432 // Adds a decoration to the member at the given index. The first word is the 433 // decoration enum, and the remaining words, if any, are its operands. 434 void AddMemberDecoration(uint32_t index, std::vector<uint32_t>&& decoration); 435 436 std::string str() const override; element_types()437 const std::vector<const Type*>& element_types() const { 438 return element_types_; 439 } element_types()440 std::vector<const Type*>& element_types() { return element_types_; } decoration_empty()441 bool decoration_empty() const override { 442 return decorations_.empty() && element_decorations_.empty(); 443 } 444 445 const std::map<uint32_t, std::vector<std::vector<uint32_t>>>& element_decorations()446 element_decorations() const { 447 return element_decorations_; 448 } 449 AsStruct()450 Struct* AsStruct() override { return this; } AsStruct()451 const Struct* AsStruct() const override { return this; } 452 453 size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; 454 455 private: 456 bool IsSameImpl(const Type* that, IsSameCache*) const override; 457 ClearDecorations()458 void ClearDecorations() override { 459 decorations_.clear(); 460 element_decorations_.clear(); 461 } 462 463 std::vector<const Type*> element_types_; 464 // We can attach decorations to struct members and that should not affect the 465 // underlying element type. So we need an extra data structure here to keep 466 // track of element type decorations. They must be stored in an ordered map 467 // because |GetExtraHashWords| will traverse the structure. It must have a 468 // fixed order in order to hash to the same value every time. 469 std::map<uint32_t, std::vector<std::vector<uint32_t>>> element_decorations_; 470 }; 471 472 class Opaque : public Type { 473 public: Opaque(std::string n)474 Opaque(std::string n) : Type(kOpaque), name_(std::move(n)) {} 475 Opaque(const Opaque&) = default; 476 477 std::string str() const override; 478 AsOpaque()479 Opaque* AsOpaque() override { return this; } AsOpaque()480 const Opaque* AsOpaque() const override { return this; } 481 name()482 const std::string& name() const { return name_; } 483 484 size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; 485 486 private: 487 bool IsSameImpl(const Type* that, IsSameCache*) const override; 488 489 std::string name_; 490 }; 491 492 class Pointer : public Type { 493 public: 494 Pointer(const Type* pointee, SpvStorageClass sc); 495 Pointer(const Pointer&) = default; 496 497 std::string str() const override; pointee_type()498 const Type* pointee_type() const { return pointee_type_; } storage_class()499 SpvStorageClass storage_class() const { return storage_class_; } 500 AsPointer()501 Pointer* AsPointer() override { return this; } AsPointer()502 const Pointer* AsPointer() const override { return this; } 503 504 size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; 505 506 void SetPointeeType(const Type* type); 507 508 private: 509 bool IsSameImpl(const Type* that, IsSameCache*) const override; 510 511 const Type* pointee_type_; 512 SpvStorageClass storage_class_; 513 }; 514 515 class Function : public Type { 516 public: 517 Function(const Type* ret_type, const std::vector<const Type*>& params); 518 Function(const Type* ret_type, std::vector<const Type*>& params); 519 Function(const Function&) = default; 520 521 std::string str() const override; 522 AsFunction()523 Function* AsFunction() override { return this; } AsFunction()524 const Function* AsFunction() const override { return this; } 525 return_type()526 const Type* return_type() const { return return_type_; } param_types()527 const std::vector<const Type*>& param_types() const { return param_types_; } param_types()528 std::vector<const Type*>& param_types() { return param_types_; } 529 530 size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; 531 532 void SetReturnType(const Type* type); 533 534 private: 535 bool IsSameImpl(const Type* that, IsSameCache*) const override; 536 537 const Type* return_type_; 538 std::vector<const Type*> param_types_; 539 }; 540 541 class Pipe : public Type { 542 public: Pipe(SpvAccessQualifier qualifier)543 Pipe(SpvAccessQualifier qualifier) 544 : Type(kPipe), access_qualifier_(qualifier) {} 545 Pipe(const Pipe&) = default; 546 547 std::string str() const override; 548 AsPipe()549 Pipe* AsPipe() override { return this; } AsPipe()550 const Pipe* AsPipe() const override { return this; } 551 access_qualifier()552 SpvAccessQualifier access_qualifier() const { return access_qualifier_; } 553 554 size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; 555 556 private: 557 bool IsSameImpl(const Type* that, IsSameCache*) const override; 558 559 SpvAccessQualifier access_qualifier_; 560 }; 561 562 class ForwardPointer : public Type { 563 public: ForwardPointer(uint32_t id,SpvStorageClass sc)564 ForwardPointer(uint32_t id, SpvStorageClass sc) 565 : Type(kForwardPointer), 566 target_id_(id), 567 storage_class_(sc), 568 pointer_(nullptr) {} 569 ForwardPointer(const ForwardPointer&) = default; 570 target_id()571 uint32_t target_id() const { return target_id_; } SetTargetPointer(const Pointer * pointer)572 void SetTargetPointer(const Pointer* pointer) { pointer_ = pointer; } storage_class()573 SpvStorageClass storage_class() const { return storage_class_; } target_pointer()574 const Pointer* target_pointer() const { return pointer_; } 575 576 std::string str() const override; 577 AsForwardPointer()578 ForwardPointer* AsForwardPointer() override { return this; } AsForwardPointer()579 const ForwardPointer* AsForwardPointer() const override { return this; } 580 581 size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; 582 583 private: 584 bool IsSameImpl(const Type* that, IsSameCache*) const override; 585 586 uint32_t target_id_; 587 SpvStorageClass storage_class_; 588 const Pointer* pointer_; 589 }; 590 591 class CooperativeMatrixNV : public Type { 592 public: 593 CooperativeMatrixNV(const Type* type, const uint32_t scope, 594 const uint32_t rows, const uint32_t columns); 595 CooperativeMatrixNV(const CooperativeMatrixNV&) = default; 596 597 std::string str() const override; 598 AsCooperativeMatrixNV()599 CooperativeMatrixNV* AsCooperativeMatrixNV() override { return this; } AsCooperativeMatrixNV()600 const CooperativeMatrixNV* AsCooperativeMatrixNV() const override { 601 return this; 602 } 603 604 size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; 605 component_type()606 const Type* component_type() const { return component_type_; } scope_id()607 uint32_t scope_id() const { return scope_id_; } rows_id()608 uint32_t rows_id() const { return rows_id_; } columns_id()609 uint32_t columns_id() const { return columns_id_; } 610 611 private: 612 bool IsSameImpl(const Type* that, IsSameCache*) const override; 613 614 const Type* component_type_; 615 const uint32_t scope_id_; 616 const uint32_t rows_id_; 617 const uint32_t columns_id_; 618 }; 619 620 #define DefineParameterlessType(type, name) \ 621 class type : public Type { \ 622 public: \ 623 type() : Type(k##type) {} \ 624 type(const type&) = default; \ 625 \ 626 std::string str() const override { return #name; } \ 627 \ 628 type* As##type() override { return this; } \ 629 const type* As##type() const override { return this; } \ 630 \ 631 size_t ComputeExtraStateHash(size_t hash, SeenTypes*) const override { \ 632 return hash; \ 633 } \ 634 \ 635 private: \ 636 bool IsSameImpl(const Type* that, IsSameCache*) const override { \ 637 return that->As##type() && HasSameDecorations(that); \ 638 } \ 639 } 640 DefineParameterlessType(Void, void); 641 DefineParameterlessType(Bool, bool); 642 DefineParameterlessType(Sampler, sampler); 643 DefineParameterlessType(Event, event); 644 DefineParameterlessType(DeviceEvent, device_event); 645 DefineParameterlessType(ReserveId, reserve_id); 646 DefineParameterlessType(Queue, queue); 647 DefineParameterlessType(PipeStorage, pipe_storage); 648 DefineParameterlessType(NamedBarrier, named_barrier); 649 DefineParameterlessType(AccelerationStructureNV, accelerationStructureNV); 650 DefineParameterlessType(RayQueryKHR, rayQueryKHR); 651 #undef DefineParameterlessType 652 653 } // namespace analysis 654 } // namespace opt 655 } // namespace spvtools 656 657 #endif // SOURCE_OPT_TYPES_H_ 658