• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2016 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // This file provides a class hierarchy for representing SPIR-V types.
16 
17 #ifndef SOURCE_OPT_TYPES_H_
18 #define SOURCE_OPT_TYPES_H_
19 
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <string>
24 #include <unordered_map>
25 #include <unordered_set>
26 #include <utility>
27 #include <vector>
28 
29 #include "source/latest_version_spirv_header.h"
30 #include "source/opt/instruction.h"
31 #include "source/util/small_vector.h"
32 #include "spirv-tools/libspirv.h"
33 
34 namespace spvtools {
35 namespace opt {
36 namespace analysis {
37 
38 class Void;
39 class Bool;
40 class Integer;
41 class Float;
42 class Vector;
43 class Matrix;
44 class Image;
45 class Sampler;
46 class SampledImage;
47 class Array;
48 class RuntimeArray;
49 class Struct;
50 class Opaque;
51 class Pointer;
52 class Function;
53 class Event;
54 class DeviceEvent;
55 class ReserveId;
56 class Queue;
57 class Pipe;
58 class ForwardPointer;
59 class PipeStorage;
60 class NamedBarrier;
61 class AccelerationStructureNV;
62 class CooperativeMatrixNV;
63 class RayQueryKHR;
64 
65 // Abstract class for a SPIR-V type. It has a bunch of As<sublcass>() methods,
66 // which is used as a way to probe the actual <subclass>.
67 class Type {
68  public:
69   typedef std::set<std::pair<const Pointer*, const Pointer*>> IsSameCache;
70 
71   using SeenTypes = spvtools::utils::SmallVector<const Type*, 8>;
72 
73   // Available subtypes.
74   //
75   // When adding a new derived class of Type, please add an entry to the enum.
76   enum Kind {
77     kVoid,
78     kBool,
79     kInteger,
80     kFloat,
81     kVector,
82     kMatrix,
83     kImage,
84     kSampler,
85     kSampledImage,
86     kArray,
87     kRuntimeArray,
88     kStruct,
89     kOpaque,
90     kPointer,
91     kFunction,
92     kEvent,
93     kDeviceEvent,
94     kReserveId,
95     kQueue,
96     kPipe,
97     kForwardPointer,
98     kPipeStorage,
99     kNamedBarrier,
100     kAccelerationStructureNV,
101     kCooperativeMatrixNV,
102     kRayQueryKHR,
103     kLast
104   };
105 
Type(Kind k)106   Type(Kind k) : kind_(k) {}
107 
108   virtual ~Type() = default;
109 
110   // Attaches a decoration directly on this type.
AddDecoration(std::vector<uint32_t> && d)111   void AddDecoration(std::vector<uint32_t>&& d) {
112     decorations_.push_back(std::move(d));
113   }
114   // Returns the decorations on this type as a string.
115   std::string GetDecorationStr() const;
116   // Returns true if this type has exactly the same decorations as |that| type.
117   bool HasSameDecorations(const Type* that) const;
118   // Returns true if this type is exactly the same as |that| type, including
119   // decorations.
IsSame(const Type * that)120   bool IsSame(const Type* that) const {
121     IsSameCache seen;
122     return IsSameImpl(that, &seen);
123   }
124 
125   // Returns true if this type is exactly the same as |that| type, including
126   // decorations.  |seen| is the set of |Pointer*| pair that are currently being
127   // compared in a parent call to |IsSameImpl|.
128   virtual bool IsSameImpl(const Type* that, IsSameCache* seen) const = 0;
129 
130   // Returns a human-readable string to represent this type.
131   virtual std::string str() const = 0;
132 
kind()133   Kind kind() const { return kind_; }
decorations()134   const std::vector<std::vector<uint32_t>>& decorations() const {
135     return decorations_;
136   }
137 
138   // Returns true if there is no decoration on this type. For struct types,
139   // returns true only when there is no decoration for both the struct type
140   // and the struct members.
decoration_empty()141   virtual bool decoration_empty() const { return decorations_.empty(); }
142 
143   // Creates a clone of |this|.
144   std::unique_ptr<Type> Clone() const;
145 
146   // Returns a clone of |this| minus any decorations.
147   std::unique_ptr<Type> RemoveDecorations() const;
148 
149   // Returns true if this type must be unique.
150   //
151   // If variable pointers are allowed, then pointers are not required to be
152   // unique.
153   // TODO(alanbaker): Update this if variable pointers become a core feature.
154   bool IsUniqueType(bool allowVariablePointers = false) const;
155 
156   bool operator==(const Type& other) const;
157 
158   // Returns the hash value of this type.
159   size_t HashValue() const;
160 
161   size_t ComputeHashValue(size_t hash, SeenTypes* seen) const;
162 
163   // Returns the number of components in a composite type.  Returns 0 for a
164   // non-composite type.
165   uint64_t NumberOfComponents() const;
166 
167 // A bunch of methods for casting this type to a given type. Returns this if the
168 // cast can be done, nullptr otherwise.
169 // clang-format off
170 #define DeclareCastMethod(target)                  \
171   virtual target* As##target() { return nullptr; } \
172   virtual const target* As##target() const { return nullptr; }
173   DeclareCastMethod(Void)
174   DeclareCastMethod(Bool)
175   DeclareCastMethod(Integer)
176   DeclareCastMethod(Float)
177   DeclareCastMethod(Vector)
178   DeclareCastMethod(Matrix)
179   DeclareCastMethod(Image)
180   DeclareCastMethod(Sampler)
181   DeclareCastMethod(SampledImage)
182   DeclareCastMethod(Array)
183   DeclareCastMethod(RuntimeArray)
184   DeclareCastMethod(Struct)
185   DeclareCastMethod(Opaque)
186   DeclareCastMethod(Pointer)
187   DeclareCastMethod(Function)
188   DeclareCastMethod(Event)
189   DeclareCastMethod(DeviceEvent)
190   DeclareCastMethod(ReserveId)
191   DeclareCastMethod(Queue)
192   DeclareCastMethod(Pipe)
193   DeclareCastMethod(ForwardPointer)
194   DeclareCastMethod(PipeStorage)
195   DeclareCastMethod(NamedBarrier)
196   DeclareCastMethod(AccelerationStructureNV)
197   DeclareCastMethod(CooperativeMatrixNV)
198   DeclareCastMethod(RayQueryKHR)
199 #undef DeclareCastMethod
200 
201 protected:
202   // Add any type-specific state to |hash| and returns new hash.
203   virtual size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const = 0;
204 
205  protected:
206   // Decorations attached to this type. Each decoration is encoded as a vector
207   // of uint32_t numbers. The first uint32_t number is the decoration value,
208   // and the rest are the parameters to the decoration (if exists).
209   std::vector<std::vector<uint32_t>> decorations_;
210 
211  private:
212   // Removes decorations on this type. For struct types, also removes element
213   // decorations.
ClearDecorations()214   virtual void ClearDecorations() { decorations_.clear(); }
215 
216   Kind kind_;
217 };
218 // clang-format on
219 
220 class Integer : public Type {
221  public:
Integer(uint32_t w,bool is_signed)222   Integer(uint32_t w, bool is_signed)
223       : Type(kInteger), width_(w), signed_(is_signed) {}
224   Integer(const Integer&) = default;
225 
226   std::string str() const override;
227 
AsInteger()228   Integer* AsInteger() override { return this; }
AsInteger()229   const Integer* AsInteger() const override { return this; }
width()230   uint32_t width() const { return width_; }
IsSigned()231   bool IsSigned() const { return signed_; }
232 
233   size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
234 
235  private:
236   bool IsSameImpl(const Type* that, IsSameCache*) const override;
237 
238   uint32_t width_;  // bit width
239   bool signed_;     // true if this integer is signed
240 };
241 
242 class Float : public Type {
243  public:
Float(uint32_t w)244   Float(uint32_t w) : Type(kFloat), width_(w) {}
245   Float(const Float&) = default;
246 
247   std::string str() const override;
248 
AsFloat()249   Float* AsFloat() override { return this; }
AsFloat()250   const Float* AsFloat() const override { return this; }
width()251   uint32_t width() const { return width_; }
252 
253   size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
254 
255  private:
256   bool IsSameImpl(const Type* that, IsSameCache*) const override;
257 
258   uint32_t width_;  // bit width
259 };
260 
261 class Vector : public Type {
262  public:
263   Vector(const Type* element_type, uint32_t count);
264   Vector(const Vector&) = default;
265 
266   std::string str() const override;
element_type()267   const Type* element_type() const { return element_type_; }
element_count()268   uint32_t element_count() const { return count_; }
269 
AsVector()270   Vector* AsVector() override { return this; }
AsVector()271   const Vector* AsVector() const override { return this; }
272 
273   size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
274 
275  private:
276   bool IsSameImpl(const Type* that, IsSameCache*) const override;
277 
278   const Type* element_type_;
279   uint32_t count_;
280 };
281 
282 class Matrix : public Type {
283  public:
284   Matrix(const Type* element_type, uint32_t count);
285   Matrix(const Matrix&) = default;
286 
287   std::string str() const override;
element_type()288   const Type* element_type() const { return element_type_; }
element_count()289   uint32_t element_count() const { return count_; }
290 
AsMatrix()291   Matrix* AsMatrix() override { return this; }
AsMatrix()292   const Matrix* AsMatrix() const override { return this; }
293 
294   size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
295 
296  private:
297   bool IsSameImpl(const Type* that, IsSameCache*) const override;
298 
299   const Type* element_type_;
300   uint32_t count_;
301 };
302 
303 class Image : public Type {
304  public:
305   Image(Type* type, SpvDim dimen, uint32_t d, bool array, bool multisample,
306         uint32_t sampling, SpvImageFormat f,
307         SpvAccessQualifier qualifier = SpvAccessQualifierReadOnly);
308   Image(const Image&) = default;
309 
310   std::string str() const override;
311 
AsImage()312   Image* AsImage() override { return this; }
AsImage()313   const Image* AsImage() const override { return this; }
314 
sampled_type()315   const Type* sampled_type() const { return sampled_type_; }
dim()316   SpvDim dim() const { return dim_; }
depth()317   uint32_t depth() const { return depth_; }
is_arrayed()318   bool is_arrayed() const { return arrayed_; }
is_multisampled()319   bool is_multisampled() const { return ms_; }
sampled()320   uint32_t sampled() const { return sampled_; }
format()321   SpvImageFormat format() const { return format_; }
access_qualifier()322   SpvAccessQualifier access_qualifier() const { return access_qualifier_; }
323 
324   size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
325 
326  private:
327   bool IsSameImpl(const Type* that, IsSameCache*) const override;
328 
329   Type* sampled_type_;
330   SpvDim dim_;
331   uint32_t depth_;
332   bool arrayed_;
333   bool ms_;
334   uint32_t sampled_;
335   SpvImageFormat format_;
336   SpvAccessQualifier access_qualifier_;
337 };
338 
339 class SampledImage : public Type {
340  public:
SampledImage(Type * image)341   SampledImage(Type* image) : Type(kSampledImage), image_type_(image) {}
342   SampledImage(const SampledImage&) = default;
343 
344   std::string str() const override;
345 
AsSampledImage()346   SampledImage* AsSampledImage() override { return this; }
AsSampledImage()347   const SampledImage* AsSampledImage() const override { return this; }
348 
image_type()349   const Type* image_type() const { return image_type_; }
350 
351   size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
352 
353  private:
354   bool IsSameImpl(const Type* that, IsSameCache*) const override;
355   Type* image_type_;
356 };
357 
358 class Array : public Type {
359  public:
360   // Data about the length operand, that helps us distinguish between one
361   // array length and another.
362   struct LengthInfo {
363     // The result id of the instruction defining the length.
364     const uint32_t id;
365     enum Case : uint32_t {
366       kConstant = 0,
367       kConstantWithSpecId = 1,
368       kDefiningId = 2
369     };
370     // Extra words used to distinshish one array length and another.
371     //  - if OpConstant, then it's 0, then the words in the literal constant
372     //    value.
373     //  - if OpSpecConstant, then it's 1, then the SpecID decoration if there
374     //    is one, followed by the words in the literal constant value.
375     //    The spec might not be overridden, in which case we'll end up using
376     //    the literal value.
377     //  - Otherwise, it's an OpSpecConsant, and this 2, then the ID (again).
378     const std::vector<uint32_t> words;
379   };
380 
381   // Constructs an array type with given element and length.  If the length
382   // is an OpSpecConstant, then |spec_id| should be its SpecId decoration.
383   Array(const Type* element_type, const LengthInfo& length_info_arg);
384   Array(const Array&) = default;
385 
386   std::string str() const override;
element_type()387   const Type* element_type() const { return element_type_; }
LengthId()388   uint32_t LengthId() const { return length_info_.id; }
length_info()389   const LengthInfo& length_info() const { return length_info_; }
390 
AsArray()391   Array* AsArray() override { return this; }
AsArray()392   const Array* AsArray() const override { return this; }
393 
394   size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
395 
396   void ReplaceElementType(const Type* element_type);
397   LengthInfo GetConstantLengthInfo(uint32_t const_id, uint32_t length) const;
398 
399  private:
400   bool IsSameImpl(const Type* that, IsSameCache*) const override;
401 
402   const Type* element_type_;
403   const LengthInfo length_info_;
404 };
405 
406 class RuntimeArray : public Type {
407  public:
408   RuntimeArray(const Type* element_type);
409   RuntimeArray(const RuntimeArray&) = default;
410 
411   std::string str() const override;
element_type()412   const Type* element_type() const { return element_type_; }
413 
AsRuntimeArray()414   RuntimeArray* AsRuntimeArray() override { return this; }
AsRuntimeArray()415   const RuntimeArray* AsRuntimeArray() const override { return this; }
416 
417   size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
418 
419   void ReplaceElementType(const Type* element_type);
420 
421  private:
422   bool IsSameImpl(const Type* that, IsSameCache*) const override;
423 
424   const Type* element_type_;
425 };
426 
427 class Struct : public Type {
428  public:
429   Struct(const std::vector<const Type*>& element_types);
430   Struct(const Struct&) = default;
431 
432   // Adds a decoration to the member at the given index.  The first word is the
433   // decoration enum, and the remaining words, if any, are its operands.
434   void AddMemberDecoration(uint32_t index, std::vector<uint32_t>&& decoration);
435 
436   std::string str() const override;
element_types()437   const std::vector<const Type*>& element_types() const {
438     return element_types_;
439   }
element_types()440   std::vector<const Type*>& element_types() { return element_types_; }
decoration_empty()441   bool decoration_empty() const override {
442     return decorations_.empty() && element_decorations_.empty();
443   }
444 
445   const std::map<uint32_t, std::vector<std::vector<uint32_t>>>&
element_decorations()446   element_decorations() const {
447     return element_decorations_;
448   }
449 
AsStruct()450   Struct* AsStruct() override { return this; }
AsStruct()451   const Struct* AsStruct() const override { return this; }
452 
453   size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
454 
455  private:
456   bool IsSameImpl(const Type* that, IsSameCache*) const override;
457 
ClearDecorations()458   void ClearDecorations() override {
459     decorations_.clear();
460     element_decorations_.clear();
461   }
462 
463   std::vector<const Type*> element_types_;
464   // We can attach decorations to struct members and that should not affect the
465   // underlying element type. So we need an extra data structure here to keep
466   // track of element type decorations.  They must be stored in an ordered map
467   // because |GetExtraHashWords| will traverse the structure.  It must have a
468   // fixed order in order to hash to the same value every time.
469   std::map<uint32_t, std::vector<std::vector<uint32_t>>> element_decorations_;
470 };
471 
472 class Opaque : public Type {
473  public:
Opaque(std::string n)474   Opaque(std::string n) : Type(kOpaque), name_(std::move(n)) {}
475   Opaque(const Opaque&) = default;
476 
477   std::string str() const override;
478 
AsOpaque()479   Opaque* AsOpaque() override { return this; }
AsOpaque()480   const Opaque* AsOpaque() const override { return this; }
481 
name()482   const std::string& name() const { return name_; }
483 
484   size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
485 
486  private:
487   bool IsSameImpl(const Type* that, IsSameCache*) const override;
488 
489   std::string name_;
490 };
491 
492 class Pointer : public Type {
493  public:
494   Pointer(const Type* pointee, SpvStorageClass sc);
495   Pointer(const Pointer&) = default;
496 
497   std::string str() const override;
pointee_type()498   const Type* pointee_type() const { return pointee_type_; }
storage_class()499   SpvStorageClass storage_class() const { return storage_class_; }
500 
AsPointer()501   Pointer* AsPointer() override { return this; }
AsPointer()502   const Pointer* AsPointer() const override { return this; }
503 
504   size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
505 
506   void SetPointeeType(const Type* type);
507 
508  private:
509   bool IsSameImpl(const Type* that, IsSameCache*) const override;
510 
511   const Type* pointee_type_;
512   SpvStorageClass storage_class_;
513 };
514 
515 class Function : public Type {
516  public:
517   Function(const Type* ret_type, const std::vector<const Type*>& params);
518   Function(const Type* ret_type, std::vector<const Type*>& params);
519   Function(const Function&) = default;
520 
521   std::string str() const override;
522 
AsFunction()523   Function* AsFunction() override { return this; }
AsFunction()524   const Function* AsFunction() const override { return this; }
525 
return_type()526   const Type* return_type() const { return return_type_; }
param_types()527   const std::vector<const Type*>& param_types() const { return param_types_; }
param_types()528   std::vector<const Type*>& param_types() { return param_types_; }
529 
530   size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
531 
532   void SetReturnType(const Type* type);
533 
534  private:
535   bool IsSameImpl(const Type* that, IsSameCache*) const override;
536 
537   const Type* return_type_;
538   std::vector<const Type*> param_types_;
539 };
540 
541 class Pipe : public Type {
542  public:
Pipe(SpvAccessQualifier qualifier)543   Pipe(SpvAccessQualifier qualifier)
544       : Type(kPipe), access_qualifier_(qualifier) {}
545   Pipe(const Pipe&) = default;
546 
547   std::string str() const override;
548 
AsPipe()549   Pipe* AsPipe() override { return this; }
AsPipe()550   const Pipe* AsPipe() const override { return this; }
551 
access_qualifier()552   SpvAccessQualifier access_qualifier() const { return access_qualifier_; }
553 
554   size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
555 
556  private:
557   bool IsSameImpl(const Type* that, IsSameCache*) const override;
558 
559   SpvAccessQualifier access_qualifier_;
560 };
561 
562 class ForwardPointer : public Type {
563  public:
ForwardPointer(uint32_t id,SpvStorageClass sc)564   ForwardPointer(uint32_t id, SpvStorageClass sc)
565       : Type(kForwardPointer),
566         target_id_(id),
567         storage_class_(sc),
568         pointer_(nullptr) {}
569   ForwardPointer(const ForwardPointer&) = default;
570 
target_id()571   uint32_t target_id() const { return target_id_; }
SetTargetPointer(const Pointer * pointer)572   void SetTargetPointer(const Pointer* pointer) { pointer_ = pointer; }
storage_class()573   SpvStorageClass storage_class() const { return storage_class_; }
target_pointer()574   const Pointer* target_pointer() const { return pointer_; }
575 
576   std::string str() const override;
577 
AsForwardPointer()578   ForwardPointer* AsForwardPointer() override { return this; }
AsForwardPointer()579   const ForwardPointer* AsForwardPointer() const override { return this; }
580 
581   size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
582 
583  private:
584   bool IsSameImpl(const Type* that, IsSameCache*) const override;
585 
586   uint32_t target_id_;
587   SpvStorageClass storage_class_;
588   const Pointer* pointer_;
589 };
590 
591 class CooperativeMatrixNV : public Type {
592  public:
593   CooperativeMatrixNV(const Type* type, const uint32_t scope,
594                       const uint32_t rows, const uint32_t columns);
595   CooperativeMatrixNV(const CooperativeMatrixNV&) = default;
596 
597   std::string str() const override;
598 
AsCooperativeMatrixNV()599   CooperativeMatrixNV* AsCooperativeMatrixNV() override { return this; }
AsCooperativeMatrixNV()600   const CooperativeMatrixNV* AsCooperativeMatrixNV() const override {
601     return this;
602   }
603 
604   size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
605 
component_type()606   const Type* component_type() const { return component_type_; }
scope_id()607   uint32_t scope_id() const { return scope_id_; }
rows_id()608   uint32_t rows_id() const { return rows_id_; }
columns_id()609   uint32_t columns_id() const { return columns_id_; }
610 
611  private:
612   bool IsSameImpl(const Type* that, IsSameCache*) const override;
613 
614   const Type* component_type_;
615   const uint32_t scope_id_;
616   const uint32_t rows_id_;
617   const uint32_t columns_id_;
618 };
619 
620 #define DefineParameterlessType(type, name)                                \
621   class type : public Type {                                               \
622    public:                                                                 \
623     type() : Type(k##type) {}                                              \
624     type(const type&) = default;                                           \
625                                                                            \
626     std::string str() const override { return #name; }                     \
627                                                                            \
628     type* As##type() override { return this; }                             \
629     const type* As##type() const override { return this; }                 \
630                                                                            \
631     size_t ComputeExtraStateHash(size_t hash, SeenTypes*) const override { \
632       return hash;                                                         \
633     }                                                                      \
634                                                                            \
635    private:                                                                \
636     bool IsSameImpl(const Type* that, IsSameCache*) const override {       \
637       return that->As##type() && HasSameDecorations(that);                 \
638     }                                                                      \
639   }
640 DefineParameterlessType(Void, void);
641 DefineParameterlessType(Bool, bool);
642 DefineParameterlessType(Sampler, sampler);
643 DefineParameterlessType(Event, event);
644 DefineParameterlessType(DeviceEvent, device_event);
645 DefineParameterlessType(ReserveId, reserve_id);
646 DefineParameterlessType(Queue, queue);
647 DefineParameterlessType(PipeStorage, pipe_storage);
648 DefineParameterlessType(NamedBarrier, named_barrier);
649 DefineParameterlessType(AccelerationStructureNV, accelerationStructureNV);
650 DefineParameterlessType(RayQueryKHR, rayQueryKHR);
651 #undef DefineParameterlessType
652 
653 }  // namespace analysis
654 }  // namespace opt
655 }  // namespace spvtools
656 
657 #endif  // SOURCE_OPT_TYPES_H_
658