1 // Copyright (c) 2023 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include <stddef.h> 16 17 #include <algorithm> 18 #include <cassert> 19 #include <cstdint> 20 #include <functional> 21 #include <initializer_list> 22 #include <iterator> 23 #include <limits> 24 #include <type_traits> 25 #include <vector> 26 27 #ifndef SOURCE_ENUM_SET_H_ 28 #define SOURCE_ENUM_SET_H_ 29 30 #include "source/latest_version_spirv_header.h" 31 32 namespace spvtools { 33 34 // This container is optimized to store and retrieve unsigned enum values. 35 // The base model for this implementation is an open-addressing hashtable with 36 // linear probing. For small enums (max index < 64), all operations are O(1). 37 // 38 // - Enums are stored in buckets (64 contiguous values max per bucket) 39 // - Buckets ranges don't overlap, but don't have to be contiguous. 40 // - Enums are packed into 64-bits buckets, using 1 bit per enum value. 41 // 42 // Example: 43 // - MyEnum { A = 0, B = 1, C = 64, D = 65 } 44 // - 2 buckets are required: 45 // - bucket 0, storing values in the range [ 0; 64[ 46 // - bucket 1, storing values in the range [64; 128[ 47 // 48 // - Buckets are stored in a sorted vector (sorted by bucket range). 49 // - Retrieval is done by computing the theoretical bucket index using the enum 50 // value, and 51 // doing a linear scan from this position. 52 // - Insertion is done by retrieving the bucket and either: 53 // - inserting a new bucket in the sorted vector when no buckets has a 54 // compatible range. 55 // - setting the corresponding bit in the bucket. 56 // This means insertion in the middle/beginning can cause a memmove when no 57 // bucket is available. In our case, this happens at most 23 times for the 58 // largest enum we have (Opcodes). 59 template <typename T> 60 class EnumSet { 61 private: 62 using BucketType = uint64_t; 63 using ElementType = std::underlying_type_t<T>; 64 static_assert(std::is_enum_v<T>, "EnumSets only works with enums."); 65 static_assert(std::is_signed_v<ElementType> == false, 66 "EnumSet doesn't supports signed enums."); 67 68 // Each bucket can hold up to `kBucketSize` distinct, contiguous enum values. 69 // The first value a bucket can hold must be aligned on `kBucketSize`. 70 struct Bucket { 71 // bit mask to store `kBucketSize` enums. 72 BucketType data; 73 // 1st enum this bucket can represent. 74 T start; 75 76 friend bool operator==(const Bucket& lhs, const Bucket& rhs) { 77 return lhs.start == rhs.start && lhs.data == rhs.data; 78 } 79 }; 80 81 // How many distinct values can a bucket hold? 1 bit per value. 82 static constexpr size_t kBucketSize = sizeof(BucketType) * 8ULL; 83 84 public: 85 class Iterator { 86 public: 87 typedef Iterator self_type; 88 typedef T value_type; 89 typedef T& reference; 90 typedef T* pointer; 91 typedef std::forward_iterator_tag iterator_category; 92 typedef size_t difference_type; 93 Iterator(const Iterator & other)94 Iterator(const Iterator& other) 95 : set_(other.set_), 96 bucketIndex_(other.bucketIndex_), 97 bucketOffset_(other.bucketOffset_) {} 98 99 Iterator& operator++() { 100 do { 101 if (bucketIndex_ >= set_->buckets_.size()) { 102 bucketIndex_ = set_->buckets_.size(); 103 bucketOffset_ = 0; 104 break; 105 } 106 107 if (bucketOffset_ + 1 == kBucketSize) { 108 bucketOffset_ = 0; 109 ++bucketIndex_; 110 } else { 111 ++bucketOffset_; 112 } 113 114 } while (bucketIndex_ < set_->buckets_.size() && 115 !set_->HasEnumAt(bucketIndex_, bucketOffset_)); 116 return *this; 117 } 118 119 Iterator operator++(int) { 120 Iterator old = *this; 121 operator++(); 122 return old; 123 } 124 125 T operator*() const { 126 assert(set_->HasEnumAt(bucketIndex_, bucketOffset_) && 127 "operator*() called on an invalid iterator."); 128 return GetValueFromBucket(set_->buckets_[bucketIndex_], bucketOffset_); 129 } 130 131 bool operator!=(const Iterator& other) const { 132 return set_ != other.set_ || bucketOffset_ != other.bucketOffset_ || 133 bucketIndex_ != other.bucketIndex_; 134 } 135 136 bool operator==(const Iterator& other) const { 137 return !(operator!=(other)); 138 } 139 140 Iterator& operator=(const Iterator& other) { 141 set_ = other.set_; 142 bucketIndex_ = other.bucketIndex_; 143 bucketOffset_ = other.bucketOffset_; 144 return *this; 145 } 146 147 private: Iterator(const EnumSet * set,size_t bucketIndex,ElementType bucketOffset)148 Iterator(const EnumSet* set, size_t bucketIndex, ElementType bucketOffset) 149 : set_(set), bucketIndex_(bucketIndex), bucketOffset_(bucketOffset) {} 150 151 private: 152 const EnumSet* set_ = nullptr; 153 // Index of the bucket in the vector. 154 size_t bucketIndex_ = 0; 155 // Offset in bits in the current bucket. 156 ElementType bucketOffset_ = 0; 157 158 friend class EnumSet; 159 }; 160 161 // Required to allow the use of std::inserter. 162 using value_type = T; 163 using const_iterator = Iterator; 164 using iterator = Iterator; 165 166 public: cbegin()167 iterator cbegin() const noexcept { 168 auto it = iterator(this, /* bucketIndex= */ 0, /* bucketOffset= */ 0); 169 if (buckets_.size() == 0) { 170 return it; 171 } 172 173 // The iterator has the logic to find the next valid bit. If the value 0 174 // is not stored, use it to find the next valid bit. 175 if (!HasEnumAt(it.bucketIndex_, it.bucketOffset_)) { 176 ++it; 177 } 178 179 return it; 180 } 181 begin()182 iterator begin() const noexcept { return cbegin(); } 183 cend()184 iterator cend() const noexcept { 185 return iterator(this, buckets_.size(), /* bucketOffset= */ 0); 186 } 187 end()188 iterator end() const noexcept { return cend(); } 189 190 // Creates an empty set. EnumSet()191 EnumSet() : buckets_(0), size_(0) {} 192 193 // Creates a set and store `value` in it. EnumSet(T value)194 EnumSet(T value) : EnumSet() { insert(value); } 195 196 // Creates a set and stores each `values` in it. EnumSet(std::initializer_list<T> values)197 EnumSet(std::initializer_list<T> values) : EnumSet() { 198 for (auto item : values) { 199 insert(item); 200 } 201 } 202 203 // Creates a set, and insert `count` enum values pointed by `array` in it. EnumSet(ElementType count,const T * array)204 EnumSet(ElementType count, const T* array) : EnumSet() { 205 for (ElementType i = 0; i < count; i++) { 206 insert(array[i]); 207 } 208 } 209 210 // Creates a set initialized with the content of the range [begin; end[. 211 template <class InputIt> EnumSet(InputIt begin,InputIt end)212 EnumSet(InputIt begin, InputIt end) : EnumSet() { 213 for (; begin != end; ++begin) { 214 insert(*begin); 215 } 216 } 217 218 // Copies the EnumSet `other` into a new EnumSet. EnumSet(const EnumSet & other)219 EnumSet(const EnumSet& other) 220 : buckets_(other.buckets_), size_(other.size_) {} 221 222 // Moves the EnumSet `other` into a new EnumSet. EnumSet(EnumSet && other)223 EnumSet(EnumSet&& other) 224 : buckets_(std::move(other.buckets_)), size_(other.size_) {} 225 226 // Deep-copies the EnumSet `other` into this EnumSet. 227 EnumSet& operator=(const EnumSet& other) { 228 buckets_ = other.buckets_; 229 size_ = other.size_; 230 return *this; 231 } 232 233 // Matches std::unordered_set::insert behavior. insert(const T & value)234 std::pair<iterator, bool> insert(const T& value) { 235 const size_t index = FindBucketForValue(value); 236 const ElementType offset = ComputeBucketOffset(value); 237 238 if (index >= buckets_.size() || 239 buckets_[index].start != ComputeBucketStart(value)) { 240 size_ += 1; 241 InsertBucketFor(index, value); 242 return std::make_pair(Iterator(this, index, offset), true); 243 } 244 245 auto& bucket = buckets_[index]; 246 const auto mask = ComputeMaskForValue(value); 247 if (bucket.data & mask) { 248 return std::make_pair(Iterator(this, index, offset), false); 249 } 250 251 size_ += 1; 252 bucket.data |= ComputeMaskForValue(value); 253 return std::make_pair(Iterator(this, index, offset), true); 254 } 255 256 // Inserts `value` in the set if possible. 257 // Similar to `std::unordered_set::insert`, except the hint is ignored. 258 // Returns an iterator to the inserted element, or the element preventing 259 // insertion. insert(const_iterator,const T & value)260 iterator insert(const_iterator, const T& value) { 261 return insert(value).first; 262 } 263 264 // Inserts `value` in the set if possible. 265 // Similar to `std::unordered_set::insert`, except the hint is ignored. 266 // Returns an iterator to the inserted element, or the element preventing 267 // insertion. insert(const_iterator,T && value)268 iterator insert(const_iterator, T&& value) { return insert(value).first; } 269 270 // Inserts all the values in the range [`first`; `last[. 271 // Similar to `std::unordered_set::insert`. 272 template <class InputIt> insert(InputIt first,InputIt last)273 void insert(InputIt first, InputIt last) { 274 for (auto it = first; it != last; ++it) { 275 insert(*it); 276 } 277 } 278 279 // Removes the value `value` into the set. 280 // Similar to `std::unordered_set::erase`. 281 // Returns the number of erased elements. erase(const T & value)282 size_t erase(const T& value) { 283 const size_t index = FindBucketForValue(value); 284 if (index >= buckets_.size() || 285 buckets_[index].start != ComputeBucketStart(value)) { 286 return 0; 287 } 288 289 auto& bucket = buckets_[index]; 290 const auto mask = ComputeMaskForValue(value); 291 if (!(bucket.data & mask)) { 292 return 0; 293 } 294 295 size_ -= 1; 296 bucket.data &= ~mask; 297 if (bucket.data == 0) { 298 buckets_.erase(buckets_.cbegin() + index); 299 } 300 return 1; 301 } 302 303 // Returns true if `value` is present in the set. contains(T value)304 bool contains(T value) const { 305 const size_t index = FindBucketForValue(value); 306 if (index >= buckets_.size() || 307 buckets_[index].start != ComputeBucketStart(value)) { 308 return false; 309 } 310 auto& bucket = buckets_[index]; 311 return bucket.data & ComputeMaskForValue(value); 312 } 313 314 // Returns the 1 if `value` is present in the set, `0` otherwise. count(T value)315 inline size_t count(T value) const { return contains(value) ? 1 : 0; } 316 317 // Returns true if the set is holds no values. empty()318 inline bool empty() const { return size_ == 0; } 319 320 // Returns the number of enums stored in this set. size()321 size_t size() const { return size_; } 322 323 // Returns true if this set contains at least one value contained in `in_set`. 324 // Note: If `in_set` is empty, this function returns true. HasAnyOf(const EnumSet<T> & in_set)325 bool HasAnyOf(const EnumSet<T>& in_set) const { 326 if (in_set.empty()) { 327 return true; 328 } 329 330 auto lhs = buckets_.cbegin(); 331 auto rhs = in_set.buckets_.cbegin(); 332 333 while (lhs != buckets_.cend() && rhs != in_set.buckets_.cend()) { 334 if (lhs->start == rhs->start) { 335 if (lhs->data & rhs->data) { 336 // At least 1 bit is shared. Early return. 337 return true; 338 } 339 340 lhs++; 341 rhs++; 342 continue; 343 } 344 345 // LHS bucket is smaller than the current RHS bucket. Catching up on RHS. 346 if (lhs->start < rhs->start) { 347 lhs++; 348 continue; 349 } 350 351 // Otherwise, RHS needs to catch up on LHS. 352 rhs++; 353 } 354 355 return false; 356 } 357 358 private: 359 // Returns the index of the last bucket in which `value` could be stored. ComputeLargestPossibleBucketIndexFor(T value)360 static constexpr inline size_t ComputeLargestPossibleBucketIndexFor(T value) { 361 return static_cast<size_t>(value) / kBucketSize; 362 } 363 364 // Returns the smallest enum value that could be contained in the same bucket 365 // as `value`. ComputeBucketStart(T value)366 static constexpr inline T ComputeBucketStart(T value) { 367 return static_cast<T>(kBucketSize * 368 ComputeLargestPossibleBucketIndexFor(value)); 369 } 370 371 // Returns the index of the bit that corresponds to `value` in the bucket. ComputeBucketOffset(T value)372 static constexpr inline ElementType ComputeBucketOffset(T value) { 373 return static_cast<ElementType>(value) % kBucketSize; 374 } 375 376 // Returns the bitmask used to represent the enum `value` in its bucket. ComputeMaskForValue(T value)377 static constexpr inline BucketType ComputeMaskForValue(T value) { 378 return 1ULL << ComputeBucketOffset(value); 379 } 380 381 // Returns the `enum` stored in `bucket` at `offset`. 382 // `offset` is the bit-offset in the bucket storage. GetValueFromBucket(const Bucket & bucket,BucketType offset)383 static constexpr inline T GetValueFromBucket(const Bucket& bucket, 384 BucketType offset) { 385 return static_cast<T>(static_cast<ElementType>(bucket.start) + offset); 386 } 387 388 // For a given enum `value`, finds the bucket index that could contain this 389 // value. If no such bucket is found, the index at which the new bucket should 390 // be inserted is returned. FindBucketForValue(T value)391 size_t FindBucketForValue(T value) const { 392 // Set is empty, insert at 0. 393 if (buckets_.size() == 0) { 394 return 0; 395 } 396 397 const T wanted_start = ComputeBucketStart(value); 398 assert(buckets_.size() > 0 && 399 "Size must not be 0 here. Has the code above changed?"); 400 size_t index = std::min(buckets_.size() - 1, 401 ComputeLargestPossibleBucketIndexFor(value)); 402 403 // This loops behaves like std::upper_bound with a reverse iterator. 404 // Buckets are sorted. 3 main cases: 405 // - The bucket matches 406 // => returns the bucket index. 407 // - The found bucket is larger 408 // => scans left until it finds the correct bucket, or insertion point. 409 // - The found bucket is smaller 410 // => We are at the end, so we return past-end index for insertion. 411 for (; buckets_[index].start >= wanted_start; index--) { 412 if (index == 0) { 413 return 0; 414 } 415 } 416 417 return index + 1; 418 } 419 420 // Creates a new bucket to store `value` and inserts it at `index`. 421 // If the `index` is past the end, the bucket is inserted at the end of the 422 // vector. InsertBucketFor(size_t index,T value)423 void InsertBucketFor(size_t index, T value) { 424 const T bucket_start = ComputeBucketStart(value); 425 Bucket bucket = {1ULL << ComputeBucketOffset(value), bucket_start}; 426 auto it = buckets_.emplace(buckets_.begin() + index, std::move(bucket)); 427 #if defined(NDEBUG) 428 (void)it; // Silencing unused variable warning. 429 #else 430 assert(std::next(it) == buckets_.end() || 431 std::next(it)->start > bucket_start); 432 assert(it == buckets_.begin() || std::prev(it)->start < bucket_start); 433 #endif 434 } 435 436 // Returns true if the bucket at `bucketIndex/ stores the enum at 437 // `bucketOffset`, false otherwise. HasEnumAt(size_t bucketIndex,BucketType bucketOffset)438 bool HasEnumAt(size_t bucketIndex, BucketType bucketOffset) const { 439 assert(bucketIndex < buckets_.size()); 440 assert(bucketOffset < kBucketSize); 441 return buckets_[bucketIndex].data & (1ULL << bucketOffset); 442 } 443 444 // Returns true if `lhs` and `rhs` hold the exact same values. 445 friend bool operator==(const EnumSet& lhs, const EnumSet& rhs) { 446 if (lhs.size_ != rhs.size_) { 447 return false; 448 } 449 450 if (lhs.buckets_.size() != rhs.buckets_.size()) { 451 return false; 452 } 453 return lhs.buckets_ == rhs.buckets_; 454 } 455 456 // Returns true if `lhs` and `rhs` hold at least 1 different value. 457 friend bool operator!=(const EnumSet& lhs, const EnumSet& rhs) { 458 return !(lhs == rhs); 459 } 460 461 // Storage for the buckets. 462 std::vector<Bucket> buckets_; 463 // How many enums is this set storing. 464 size_t size_ = 0; 465 }; 466 467 // A set of spv::Capability. 468 using CapabilitySet = EnumSet<spv::Capability>; 469 470 } // namespace spvtools 471 472 #endif // SOURCE_ENUM_SET_H_ 473