• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2016 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/types.h"
16 
17 #include <algorithm>
18 #include <cassert>
19 #include <cstdint>
20 #include <sstream>
21 #include <string>
22 #include <unordered_set>
23 
24 #include "source/util/hash_combine.h"
25 #include "source/util/make_unique.h"
26 
27 namespace spvtools {
28 namespace opt {
29 namespace analysis {
30 
31 using spvtools::utils::hash_combine;
using U32VecVec = std::vector<std::vector<uint32_t>>;

namespace {

// Returns true if |a| and |b| contain the same inner vectors, regardless of
// the order in which the inner vectors are stored (multiset equality).
//
// The inner vectors are ordered by their full contents before the pairwise
// comparison.  Ordering by the first word alone is not enough: two inner
// vectors that share a first word could end up in different relative
// positions in |a| and |b|, reporting a mismatch for equal inputs.  Comparing
// full contents also keeps empty inner vectors well defined.
bool CompareTwoVectors(const U32VecVec& a, const U32VecVec& b) {
  const size_t size = a.size();
  if (size != b.size()) return false;

  if (size == 0) return true;
  if (size == 1) return a.front() == b.front();

  // Sort pointers rather than the vectors themselves so that the inputs are
  // left untouched and no inner vector is copied.
  std::vector<const std::vector<uint32_t>*> a_ptrs, b_ptrs;
  a_ptrs.reserve(size);
  b_ptrs.reserve(size);
  for (size_t i = 0; i < size; ++i) {
    a_ptrs.push_back(&a[i]);
    b_ptrs.push_back(&b[i]);
  }

  // Lexicographic comparison of the pointed-to vectors.
  const auto cmp = [](const std::vector<uint32_t>* m,
                      const std::vector<uint32_t>* n) { return *m < *n; };

  std::sort(a_ptrs.begin(), a_ptrs.end(), cmp);
  std::sort(b_ptrs.begin(), b_ptrs.end(), cmp);

  for (size_t i = 0; i < size; ++i) {
    if (*a_ptrs[i] != *b_ptrs[i]) return false;
  }
  return true;
}

}  // namespace
67 
GetDecorationStr() const68 std::string Type::GetDecorationStr() const {
69   std::ostringstream oss;
70   oss << "[[";
71   for (const auto& decoration : decorations_) {
72     oss << "(";
73     for (size_t i = 0; i < decoration.size(); ++i) {
74       oss << (i > 0 ? ", " : "");
75       oss << decoration.at(i);
76     }
77     oss << ")";
78   }
79   oss << "]]";
80   return oss.str();
81 }
82 
HasSameDecorations(const Type * that) const83 bool Type::HasSameDecorations(const Type* that) const {
84   return CompareTwoVectors(decorations_, that->decorations_);
85 }
86 
IsUniqueType() const87 bool Type::IsUniqueType() const {
88   switch (kind_) {
89     case kPointer:
90     case kStruct:
91     case kArray:
92     case kRuntimeArray:
93       return false;
94     default:
95       return true;
96   }
97 }
98 
Clone() const99 std::unique_ptr<Type> Type::Clone() const {
100   std::unique_ptr<Type> type;
101   switch (kind_) {
102 #define DeclareKindCase(kind)                   \
103   case k##kind:                                 \
104     type = MakeUnique<kind>(*this->As##kind()); \
105     break
106     DeclareKindCase(Void);
107     DeclareKindCase(Bool);
108     DeclareKindCase(Integer);
109     DeclareKindCase(Float);
110     DeclareKindCase(Vector);
111     DeclareKindCase(Matrix);
112     DeclareKindCase(Image);
113     DeclareKindCase(Sampler);
114     DeclareKindCase(SampledImage);
115     DeclareKindCase(Array);
116     DeclareKindCase(RuntimeArray);
117     DeclareKindCase(Struct);
118     DeclareKindCase(Opaque);
119     DeclareKindCase(Pointer);
120     DeclareKindCase(Function);
121     DeclareKindCase(Event);
122     DeclareKindCase(DeviceEvent);
123     DeclareKindCase(ReserveId);
124     DeclareKindCase(Queue);
125     DeclareKindCase(Pipe);
126     DeclareKindCase(ForwardPointer);
127     DeclareKindCase(PipeStorage);
128     DeclareKindCase(NamedBarrier);
129     DeclareKindCase(AccelerationStructureNV);
130     DeclareKindCase(CooperativeMatrixNV);
131     DeclareKindCase(CooperativeMatrixKHR);
132     DeclareKindCase(RayQueryKHR);
133     DeclareKindCase(HitObjectNV);
134 #undef DeclareKindCase
135     default:
136       assert(false && "Unhandled type");
137   }
138   return type;
139 }
140 
RemoveDecorations() const141 std::unique_ptr<Type> Type::RemoveDecorations() const {
142   std::unique_ptr<Type> type(Clone());
143   type->ClearDecorations();
144   return type;
145 }
146 
operator ==(const Type & other) const147 bool Type::operator==(const Type& other) const {
148   if (kind_ != other.kind_) return false;
149 
150   switch (kind_) {
151 #define DeclareKindCase(kind) \
152   case k##kind:               \
153     return As##kind()->IsSame(&other)
154     DeclareKindCase(Void);
155     DeclareKindCase(Bool);
156     DeclareKindCase(Integer);
157     DeclareKindCase(Float);
158     DeclareKindCase(Vector);
159     DeclareKindCase(Matrix);
160     DeclareKindCase(Image);
161     DeclareKindCase(Sampler);
162     DeclareKindCase(SampledImage);
163     DeclareKindCase(Array);
164     DeclareKindCase(RuntimeArray);
165     DeclareKindCase(Struct);
166     DeclareKindCase(Opaque);
167     DeclareKindCase(Pointer);
168     DeclareKindCase(Function);
169     DeclareKindCase(Event);
170     DeclareKindCase(DeviceEvent);
171     DeclareKindCase(ReserveId);
172     DeclareKindCase(Queue);
173     DeclareKindCase(Pipe);
174     DeclareKindCase(ForwardPointer);
175     DeclareKindCase(PipeStorage);
176     DeclareKindCase(NamedBarrier);
177     DeclareKindCase(AccelerationStructureNV);
178     DeclareKindCase(CooperativeMatrixNV);
179     DeclareKindCase(CooperativeMatrixKHR);
180     DeclareKindCase(RayQueryKHR);
181     DeclareKindCase(HitObjectNV);
182 #undef DeclareKindCase
183     default:
184       assert(false && "Unhandled type");
185       return false;
186   }
187 }
188 
ComputeHashValue(size_t hash,SeenTypes * seen) const189 size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const {
190   // Linear search through a dense, cache coherent vector is faster than O(log
191   // n) search in a complex data structure (eg std::set) for the generally small
192   // number of nodes.  It also skips the overhead of an new/delete per Type
193   // (when inserting/removing from a set).
194   if (std::find(seen->begin(), seen->end(), this) != seen->end()) {
195     return hash;
196   }
197 
198   seen->push_back(this);
199 
200   hash = hash_combine(hash, uint32_t(kind_));
201   for (const auto& d : decorations_) {
202     hash = hash_combine(hash, d);
203   }
204 
205   switch (kind_) {
206 #define DeclareKindCase(type)                             \
207   case k##type:                                           \
208     hash = As##type()->ComputeExtraStateHash(hash, seen); \
209     break
210     DeclareKindCase(Void);
211     DeclareKindCase(Bool);
212     DeclareKindCase(Integer);
213     DeclareKindCase(Float);
214     DeclareKindCase(Vector);
215     DeclareKindCase(Matrix);
216     DeclareKindCase(Image);
217     DeclareKindCase(Sampler);
218     DeclareKindCase(SampledImage);
219     DeclareKindCase(Array);
220     DeclareKindCase(RuntimeArray);
221     DeclareKindCase(Struct);
222     DeclareKindCase(Opaque);
223     DeclareKindCase(Pointer);
224     DeclareKindCase(Function);
225     DeclareKindCase(Event);
226     DeclareKindCase(DeviceEvent);
227     DeclareKindCase(ReserveId);
228     DeclareKindCase(Queue);
229     DeclareKindCase(Pipe);
230     DeclareKindCase(ForwardPointer);
231     DeclareKindCase(PipeStorage);
232     DeclareKindCase(NamedBarrier);
233     DeclareKindCase(AccelerationStructureNV);
234     DeclareKindCase(CooperativeMatrixNV);
235     DeclareKindCase(CooperativeMatrixKHR);
236     DeclareKindCase(RayQueryKHR);
237     DeclareKindCase(HitObjectNV);
238 #undef DeclareKindCase
239     default:
240       assert(false && "Unhandled type");
241       break;
242   }
243 
244   seen->pop_back();
245   return hash;
246 }
247 
HashValue() const248 size_t Type::HashValue() const {
249   SeenTypes seen;
250   return ComputeHashValue(0, &seen);
251 }
252 
NumberOfComponents() const253 uint64_t Type::NumberOfComponents() const {
254   switch (kind()) {
255     case kVector:
256       return AsVector()->element_count();
257     case kMatrix:
258       return AsMatrix()->element_count();
259     case kArray: {
260       Array::LengthInfo length_info = AsArray()->length_info();
261       if (length_info.words[0] != Array::LengthInfo::kConstant) {
262         return UINT64_MAX;
263       }
264       assert(length_info.words.size() <= 3 &&
265              "The size of the array could not fit size_t.");
266       uint64_t length = 0;
267       length |= length_info.words[1];
268       if (length_info.words.size() > 2) {
269         length |= static_cast<uint64_t>(length_info.words[2]) << 32;
270       }
271       return length;
272     }
273     case kRuntimeArray:
274       return UINT64_MAX;
275     case kStruct:
276       return AsStruct()->element_types().size();
277     default:
278       return 0;
279   }
280 }
281 
IsSameImpl(const Type * that,IsSameCache *) const282 bool Integer::IsSameImpl(const Type* that, IsSameCache*) const {
283   const Integer* it = that->AsInteger();
284   return it && width_ == it->width_ && signed_ == it->signed_ &&
285          HasSameDecorations(that);
286 }
287 
str() const288 std::string Integer::str() const {
289   std::ostringstream oss;
290   oss << (signed_ ? "s" : "u") << "int" << width_;
291   return oss.str();
292 }
293 
ComputeExtraStateHash(size_t hash,SeenTypes *) const294 size_t Integer::ComputeExtraStateHash(size_t hash, SeenTypes*) const {
295   return hash_combine(hash, width_, signed_);
296 }
297 
IsSameImpl(const Type * that,IsSameCache *) const298 bool Float::IsSameImpl(const Type* that, IsSameCache*) const {
299   const Float* ft = that->AsFloat();
300   return ft && width_ == ft->width_ && HasSameDecorations(that);
301 }
302 
str() const303 std::string Float::str() const {
304   std::ostringstream oss;
305   oss << "float" << width_;
306   return oss.str();
307 }
308 
ComputeExtraStateHash(size_t hash,SeenTypes *) const309 size_t Float::ComputeExtraStateHash(size_t hash, SeenTypes*) const {
310   return hash_combine(hash, width_);
311 }
312 
Vector(const Type * type,uint32_t count)313 Vector::Vector(const Type* type, uint32_t count)
314     : Type(kVector), element_type_(type), count_(count) {
315   assert(type->AsBool() || type->AsInteger() || type->AsFloat());
316 }
317 
IsSameImpl(const Type * that,IsSameCache * seen) const318 bool Vector::IsSameImpl(const Type* that, IsSameCache* seen) const {
319   const Vector* vt = that->AsVector();
320   if (!vt) return false;
321   return count_ == vt->count_ &&
322          element_type_->IsSameImpl(vt->element_type_, seen) &&
323          HasSameDecorations(that);
324 }
325 
str() const326 std::string Vector::str() const {
327   std::ostringstream oss;
328   oss << "<" << element_type_->str() << ", " << count_ << ">";
329   return oss.str();
330 }
331 
ComputeExtraStateHash(size_t hash,SeenTypes * seen) const332 size_t Vector::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
333   // prefer form that doesn't require push/pop from stack: add state and
334   // make tail call.
335   hash = hash_combine(hash, count_);
336   return element_type_->ComputeHashValue(hash, seen);
337 }
338 
Matrix(const Type * type,uint32_t count)339 Matrix::Matrix(const Type* type, uint32_t count)
340     : Type(kMatrix), element_type_(type), count_(count) {
341   assert(type->AsVector());
342 }
343 
IsSameImpl(const Type * that,IsSameCache * seen) const344 bool Matrix::IsSameImpl(const Type* that, IsSameCache* seen) const {
345   const Matrix* mt = that->AsMatrix();
346   if (!mt) return false;
347   return count_ == mt->count_ &&
348          element_type_->IsSameImpl(mt->element_type_, seen) &&
349          HasSameDecorations(that);
350 }
351 
str() const352 std::string Matrix::str() const {
353   std::ostringstream oss;
354   oss << "<" << element_type_->str() << ", " << count_ << ">";
355   return oss.str();
356 }
357 
ComputeExtraStateHash(size_t hash,SeenTypes * seen) const358 size_t Matrix::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
359   hash = hash_combine(hash, count_);
360   return element_type_->ComputeHashValue(hash, seen);
361 }
362 
Image(Type * type,spv::Dim dimen,uint32_t d,bool array,bool multisample,uint32_t sampling,spv::ImageFormat f,spv::AccessQualifier qualifier)363 Image::Image(Type* type, spv::Dim dimen, uint32_t d, bool array,
364              bool multisample, uint32_t sampling, spv::ImageFormat f,
365              spv::AccessQualifier qualifier)
366     : Type(kImage),
367       sampled_type_(type),
368       dim_(dimen),
369       depth_(d),
370       arrayed_(array),
371       ms_(multisample),
372       sampled_(sampling),
373       format_(f),
374       access_qualifier_(qualifier) {
375   // TODO(antiagainst): check sampled_type
376 }
377 
IsSameImpl(const Type * that,IsSameCache * seen) const378 bool Image::IsSameImpl(const Type* that, IsSameCache* seen) const {
379   const Image* it = that->AsImage();
380   if (!it) return false;
381   return dim_ == it->dim_ && depth_ == it->depth_ && arrayed_ == it->arrayed_ &&
382          ms_ == it->ms_ && sampled_ == it->sampled_ && format_ == it->format_ &&
383          access_qualifier_ == it->access_qualifier_ &&
384          sampled_type_->IsSameImpl(it->sampled_type_, seen) &&
385          HasSameDecorations(that);
386 }
387 
str() const388 std::string Image::str() const {
389   std::ostringstream oss;
390   oss << "image(" << sampled_type_->str() << ", " << uint32_t(dim_) << ", "
391       << depth_ << ", " << arrayed_ << ", " << ms_ << ", " << sampled_ << ", "
392       << uint32_t(format_) << ", " << uint32_t(access_qualifier_) << ")";
393   return oss.str();
394 }
395 
ComputeExtraStateHash(size_t hash,SeenTypes * seen) const396 size_t Image::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
397   hash = hash_combine(hash, uint32_t(dim_), depth_, arrayed_, ms_, sampled_,
398                       uint32_t(format_), uint32_t(access_qualifier_));
399   return sampled_type_->ComputeHashValue(hash, seen);
400 }
401 
IsSameImpl(const Type * that,IsSameCache * seen) const402 bool SampledImage::IsSameImpl(const Type* that, IsSameCache* seen) const {
403   const SampledImage* sit = that->AsSampledImage();
404   if (!sit) return false;
405   return image_type_->IsSameImpl(sit->image_type_, seen) &&
406          HasSameDecorations(that);
407 }
408 
str() const409 std::string SampledImage::str() const {
410   std::ostringstream oss;
411   oss << "sampled_image(" << image_type_->str() << ")";
412   return oss.str();
413 }
414 
ComputeExtraStateHash(size_t hash,SeenTypes * seen) const415 size_t SampledImage::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
416   return image_type_->ComputeHashValue(hash, seen);
417 }
418 
// Builds an array of |type| with the length described by |length_info_arg|.
// |type| is asserted to be non-null and not void.
Array::Array(const Type* type, const Array::LengthInfo& length_info_arg)
    : Type(kArray), element_type_(type), length_info_(length_info_arg) {
  assert(type != nullptr);
  assert(type->AsVoid() == nullptr);
  // The first word says which case we're in, and it is always followed by at
  // least one more word.
  assert(length_info_arg.words.size() > 1);
}
427 
IsSameImpl(const Type * that,IsSameCache * seen) const428 bool Array::IsSameImpl(const Type* that, IsSameCache* seen) const {
429   const Array* at = that->AsArray();
430   if (!at) return false;
431   bool is_same = element_type_->IsSameImpl(at->element_type_, seen);
432   is_same = is_same && HasSameDecorations(that);
433   is_same = is_same && (length_info_.words == at->length_info_.words);
434   return is_same;
435 }
436 
str() const437 std::string Array::str() const {
438   std::ostringstream oss;
439   oss << "[" << element_type_->str() << ", id(" << LengthId() << "), words(";
440   const char* spacer = "";
441   for (auto w : length_info_.words) {
442     oss << spacer << w;
443     spacer = ",";
444   }
445   oss << ")]";
446   return oss.str();
447 }
448 
ComputeExtraStateHash(size_t hash,SeenTypes * seen) const449 size_t Array::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
450   hash = hash_combine(hash, length_info_.words);
451   return element_type_->ComputeHashValue(hash, seen);
452 }
453 
ReplaceElementType(const Type * type)454 void Array::ReplaceElementType(const Type* type) { element_type_ = type; }
455 
GetConstantLengthInfo(uint32_t const_id,uint32_t length) const456 Array::LengthInfo Array::GetConstantLengthInfo(uint32_t const_id,
457                                                uint32_t length) const {
458   std::vector<uint32_t> extra_words{LengthInfo::Case::kConstant, length};
459   return {const_id, extra_words};
460 }
461 
RuntimeArray(const Type * type)462 RuntimeArray::RuntimeArray(const Type* type)
463     : Type(kRuntimeArray), element_type_(type) {
464   assert(!type->AsVoid());
465 }
466 
IsSameImpl(const Type * that,IsSameCache * seen) const467 bool RuntimeArray::IsSameImpl(const Type* that, IsSameCache* seen) const {
468   const RuntimeArray* rat = that->AsRuntimeArray();
469   if (!rat) return false;
470   return element_type_->IsSameImpl(rat->element_type_, seen) &&
471          HasSameDecorations(that);
472 }
473 
str() const474 std::string RuntimeArray::str() const {
475   std::ostringstream oss;
476   oss << "[" << element_type_->str() << "]";
477   return oss.str();
478 }
479 
ComputeExtraStateHash(size_t hash,SeenTypes * seen) const480 size_t RuntimeArray::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
481   return element_type_->ComputeHashValue(hash, seen);
482 }
483 
ReplaceElementType(const Type * type)484 void RuntimeArray::ReplaceElementType(const Type* type) {
485   element_type_ = type;
486 }
487 
Struct(const std::vector<const Type * > & types)488 Struct::Struct(const std::vector<const Type*>& types)
489     : Type(kStruct), element_types_(types) {
490   for (const auto* t : types) {
491     (void)t;
492     assert(!t->AsVoid());
493   }
494 }
495 
AddMemberDecoration(uint32_t index,std::vector<uint32_t> && decoration)496 void Struct::AddMemberDecoration(uint32_t index,
497                                  std::vector<uint32_t>&& decoration) {
498   if (index >= element_types_.size()) {
499     assert(0 && "index out of bound");
500     return;
501   }
502 
503   element_decorations_[index].push_back(std::move(decoration));
504 }
505 
IsSameImpl(const Type * that,IsSameCache * seen) const506 bool Struct::IsSameImpl(const Type* that, IsSameCache* seen) const {
507   const Struct* st = that->AsStruct();
508   if (!st) return false;
509   if (element_types_.size() != st->element_types_.size()) return false;
510   const auto size = element_decorations_.size();
511   if (size != st->element_decorations_.size()) return false;
512   if (!HasSameDecorations(that)) return false;
513 
514   for (size_t i = 0; i < element_types_.size(); ++i) {
515     if (!element_types_[i]->IsSameImpl(st->element_types_[i], seen))
516       return false;
517   }
518   for (const auto& p : element_decorations_) {
519     if (st->element_decorations_.count(p.first) == 0) return false;
520     if (!CompareTwoVectors(p.second, st->element_decorations_.at(p.first)))
521       return false;
522   }
523   return true;
524 }
525 
str() const526 std::string Struct::str() const {
527   std::ostringstream oss;
528   oss << "{";
529   const size_t count = element_types_.size();
530   for (size_t i = 0; i < count; ++i) {
531     oss << element_types_[i]->str();
532     if (i + 1 != count) oss << ", ";
533   }
534   oss << "}";
535   return oss.str();
536 }
537 
ComputeExtraStateHash(size_t hash,SeenTypes * seen) const538 size_t Struct::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
539   for (auto* t : element_types_) {
540     hash = t->ComputeHashValue(hash, seen);
541   }
542   for (const auto& pair : element_decorations_) {
543     hash = hash_combine(hash, pair.first, pair.second);
544   }
545   return hash;
546 }
547 
IsSameImpl(const Type * that,IsSameCache *) const548 bool Opaque::IsSameImpl(const Type* that, IsSameCache*) const {
549   const Opaque* ot = that->AsOpaque();
550   if (!ot) return false;
551   return name_ == ot->name_ && HasSameDecorations(that);
552 }
553 
str() const554 std::string Opaque::str() const {
555   std::ostringstream oss;
556   oss << "opaque('" << name_ << "')";
557   return oss.str();
558 }
559 
ComputeExtraStateHash(size_t hash,SeenTypes *) const560 size_t Opaque::ComputeExtraStateHash(size_t hash, SeenTypes*) const {
561   return hash_combine(hash, name_);
562 }
563 
Pointer(const Type * type,spv::StorageClass sc)564 Pointer::Pointer(const Type* type, spv::StorageClass sc)
565     : Type(kPointer), pointee_type_(type), storage_class_(sc) {}
566 
IsSameImpl(const Type * that,IsSameCache * seen) const567 bool Pointer::IsSameImpl(const Type* that, IsSameCache* seen) const {
568   const Pointer* pt = that->AsPointer();
569   if (!pt) return false;
570   if (storage_class_ != pt->storage_class_) return false;
571   auto p = seen->insert(std::make_pair(this, that->AsPointer()));
572   if (!p.second) {
573     return true;
574   }
575   bool same_pointee = pointee_type_->IsSameImpl(pt->pointee_type_, seen);
576   seen->erase(p.first);
577   if (!same_pointee) {
578     return false;
579   }
580   return HasSameDecorations(that);
581 }
582 
str() const583 std::string Pointer::str() const {
584   std::ostringstream os;
585   os << pointee_type_->str() << " " << static_cast<uint32_t>(storage_class_)
586      << "*";
587   return os.str();
588 }
589 
ComputeExtraStateHash(size_t hash,SeenTypes * seen) const590 size_t Pointer::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
591   hash = hash_combine(hash, uint32_t(storage_class_));
592   return pointee_type_->ComputeHashValue(hash, seen);
593 }
594 
SetPointeeType(const Type * type)595 void Pointer::SetPointeeType(const Type* type) { pointee_type_ = type; }
596 
Function(const Type * ret_type,const std::vector<const Type * > & params)597 Function::Function(const Type* ret_type, const std::vector<const Type*>& params)
598     : Type(kFunction), return_type_(ret_type), param_types_(params) {}
599 
Function(const Type * ret_type,std::vector<const Type * > & params)600 Function::Function(const Type* ret_type, std::vector<const Type*>& params)
601     : Type(kFunction), return_type_(ret_type), param_types_(params) {}
602 
IsSameImpl(const Type * that,IsSameCache * seen) const603 bool Function::IsSameImpl(const Type* that, IsSameCache* seen) const {
604   const Function* ft = that->AsFunction();
605   if (!ft) return false;
606   if (!return_type_->IsSameImpl(ft->return_type_, seen)) return false;
607   if (param_types_.size() != ft->param_types_.size()) return false;
608   for (size_t i = 0; i < param_types_.size(); ++i) {
609     if (!param_types_[i]->IsSameImpl(ft->param_types_[i], seen)) return false;
610   }
611   return HasSameDecorations(that);
612 }
613 
str() const614 std::string Function::str() const {
615   std::ostringstream oss;
616   const size_t count = param_types_.size();
617   oss << "(";
618   for (size_t i = 0; i < count; ++i) {
619     oss << param_types_[i]->str();
620     if (i + 1 != count) oss << ", ";
621   }
622   oss << ") -> " << return_type_->str();
623   return oss.str();
624 }
625 
ComputeExtraStateHash(size_t hash,SeenTypes * seen) const626 size_t Function::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
627   for (const auto* t : param_types_) {
628     hash = t->ComputeHashValue(hash, seen);
629   }
630   return return_type_->ComputeHashValue(hash, seen);
631 }
632 
SetReturnType(const Type * type)633 void Function::SetReturnType(const Type* type) { return_type_ = type; }
634 
IsSameImpl(const Type * that,IsSameCache *) const635 bool Pipe::IsSameImpl(const Type* that, IsSameCache*) const {
636   const Pipe* pt = that->AsPipe();
637   if (!pt) return false;
638   return access_qualifier_ == pt->access_qualifier_ && HasSameDecorations(that);
639 }
640 
str() const641 std::string Pipe::str() const {
642   std::ostringstream oss;
643   oss << "pipe(" << uint32_t(access_qualifier_) << ")";
644   return oss.str();
645 }
646 
ComputeExtraStateHash(size_t hash,SeenTypes *) const647 size_t Pipe::ComputeExtraStateHash(size_t hash, SeenTypes*) const {
648   return hash_combine(hash, uint32_t(access_qualifier_));
649 }
650 
IsSameImpl(const Type * that,IsSameCache *) const651 bool ForwardPointer::IsSameImpl(const Type* that, IsSameCache*) const {
652   const ForwardPointer* fpt = that->AsForwardPointer();
653   if (!fpt) return false;
654   return (pointer_ && fpt->pointer_ ? *pointer_ == *fpt->pointer_
655                                     : target_id_ == fpt->target_id_) &&
656          storage_class_ == fpt->storage_class_ && HasSameDecorations(that);
657 }
658 
str() const659 std::string ForwardPointer::str() const {
660   std::ostringstream oss;
661   oss << "forward_pointer(";
662   if (pointer_ != nullptr) {
663     oss << pointer_->str();
664   } else {
665     oss << target_id_;
666   }
667   oss << ")";
668   return oss.str();
669 }
670 
ComputeExtraStateHash(size_t hash,SeenTypes * seen) const671 size_t ForwardPointer::ComputeExtraStateHash(size_t hash,
672                                              SeenTypes* seen) const {
673   hash = hash_combine(hash, target_id_, uint32_t(storage_class_));
674   if (pointer_) hash = pointer_->ComputeHashValue(hash, seen);
675   return hash;
676 }
677 
CooperativeMatrixNV(const Type * type,const uint32_t scope,const uint32_t rows,const uint32_t columns)678 CooperativeMatrixNV::CooperativeMatrixNV(const Type* type, const uint32_t scope,
679                                          const uint32_t rows,
680                                          const uint32_t columns)
681     : Type(kCooperativeMatrixNV),
682       component_type_(type),
683       scope_id_(scope),
684       rows_id_(rows),
685       columns_id_(columns) {
686   assert(type != nullptr);
687   assert(scope != 0);
688   assert(rows != 0);
689   assert(columns != 0);
690 }
691 
str() const692 std::string CooperativeMatrixNV::str() const {
693   std::ostringstream oss;
694   oss << "<" << component_type_->str() << ", " << scope_id_ << ", " << rows_id_
695       << ", " << columns_id_ << ">";
696   return oss.str();
697 }
698 
ComputeExtraStateHash(size_t hash,SeenTypes * seen) const699 size_t CooperativeMatrixNV::ComputeExtraStateHash(size_t hash,
700                                                   SeenTypes* seen) const {
701   hash = hash_combine(hash, scope_id_, rows_id_, columns_id_);
702   return component_type_->ComputeHashValue(hash, seen);
703 }
704 
IsSameImpl(const Type * that,IsSameCache * seen) const705 bool CooperativeMatrixNV::IsSameImpl(const Type* that,
706                                      IsSameCache* seen) const {
707   const CooperativeMatrixNV* mt = that->AsCooperativeMatrixNV();
708   if (!mt) return false;
709   return component_type_->IsSameImpl(mt->component_type_, seen) &&
710          scope_id_ == mt->scope_id_ && rows_id_ == mt->rows_id_ &&
711          columns_id_ == mt->columns_id_ && HasSameDecorations(that);
712 }
713 
CooperativeMatrixKHR(const Type * type,const uint32_t scope,const uint32_t rows,const uint32_t columns,const uint32_t use)714 CooperativeMatrixKHR::CooperativeMatrixKHR(const Type* type,
715                                            const uint32_t scope,
716                                            const uint32_t rows,
717                                            const uint32_t columns,
718                                            const uint32_t use)
719     : Type(kCooperativeMatrixKHR),
720       component_type_(type),
721       scope_id_(scope),
722       rows_id_(rows),
723       columns_id_(columns),
724       use_id_(use) {
725   assert(type != nullptr);
726   assert(scope != 0);
727   assert(rows != 0);
728   assert(columns != 0);
729 }
730 
str() const731 std::string CooperativeMatrixKHR::str() const {
732   std::ostringstream oss;
733   oss << "<" << component_type_->str() << ", " << scope_id_ << ", " << rows_id_
734       << ", " << columns_id_ << ", " << use_id_ << ">";
735   return oss.str();
736 }
737 
ComputeExtraStateHash(size_t hash,SeenTypes * seen) const738 size_t CooperativeMatrixKHR::ComputeExtraStateHash(size_t hash,
739                                                    SeenTypes* seen) const {
740   hash = hash_combine(hash, scope_id_, rows_id_, columns_id_, use_id_);
741   return component_type_->ComputeHashValue(hash, seen);
742 }
743 
IsSameImpl(const Type * that,IsSameCache * seen) const744 bool CooperativeMatrixKHR::IsSameImpl(const Type* that,
745                                       IsSameCache* seen) const {
746   const CooperativeMatrixKHR* mt = that->AsCooperativeMatrixKHR();
747   if (!mt) return false;
748   return component_type_->IsSameImpl(mt->component_type_, seen) &&
749          scope_id_ == mt->scope_id_ && rows_id_ == mt->rows_id_ &&
750          columns_id_ == mt->columns_id_ && HasSameDecorations(that);
751 }
752 
753 }  // namespace analysis
754 }  // namespace opt
755 }  // namespace spvtools
756