• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include <stddef.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <type_traits>
#include <utility>
#include <vector>
26 
27 #ifndef SOURCE_ENUM_SET_H_
28 #define SOURCE_ENUM_SET_H_
29 
30 #include "source/latest_version_spirv_header.h"
31 
32 namespace spvtools {
33 
// This container is optimized to store and retrieve unsigned enum values.
// The storage model is a sorted vector of fixed-size bitset buckets. For small
// enums (max value < 64), all operations are O(1).
//
// - Enums are stored in buckets (64 contiguous values max per bucket).
// - Bucket ranges don't overlap, but don't have to be contiguous.
// - Enums are packed into 64-bit buckets, using 1 bit per enum value.
// - A bucket is removed as soon as its last value is erased, so every stored
//   bucket holds at least one value.
//
// Example:
//  - MyEnum { A = 0, B = 1, C = 64, D = 65 }
//  - 2 buckets are required:
//      - bucket 0, storing values in the range [ 0;  64[
//      - bucket 1, storing values in the range [64; 128[
//
// - Buckets are stored in a sorted vector (sorted by bucket range).
// - Retrieval is done by computing the theoretical bucket index using the enum
//   value, and doing a linear scan towards lower indices from this position.
// - Insertion is done by retrieving the bucket and either:
//   - inserting a new bucket in the sorted vector when no bucket has a
//     compatible range.
//   - setting the corresponding bit in the bucket.
//   This means insertion in the middle/beginning can cause a memmove when no
//   bucket is available. In our case, this happens at most 23 times for the
//   largest enum we have (Opcodes).
//
// Iteration visits the stored values in increasing numeric order.
template <typename T>
class EnumSet {
 private:
  using BucketType = uint64_t;
  using ElementType = std::underlying_type_t<T>;
  static_assert(std::is_enum_v<T>, "EnumSets only works with enums.");
  static_assert(std::is_signed_v<ElementType> == false,
                "EnumSet doesn't supports signed enums.");

  // Each bucket can hold up to `kBucketSize` distinct, contiguous enum values.
  // The first value a bucket can hold must be aligned on `kBucketSize`.
  struct Bucket {
    // Bit mask storing `kBucketSize` enums: bit `i` is set when the enum
    // `start + i` is in the set.
    BucketType data;
    // 1st enum this bucket can represent (a multiple of `kBucketSize`).
    T start;

    friend bool operator==(const Bucket& lhs, const Bucket& rhs) {
      return lhs.start == rhs.start && lhs.data == rhs.data;
    }
  };

  // How many distinct values can a bucket hold? 1 bit per value.
  static constexpr size_t kBucketSize = sizeof(BucketType) * 8ULL;

 public:
  // Forward iterator over the values of an EnumSet, in increasing order.
  // An iterator stores a bucket index, so inserting or erasing values (which
  // can add or remove buckets) may make existing iterators refer to a
  // different element.
  class Iterator {
   public:
    typedef Iterator self_type;
    typedef T value_type;
    // NOTE: operator*() returns by value, so `reference` does not match the
    // dereference type. Kept unchanged for source compatibility.
    typedef T& reference;
    typedef T* pointer;
    typedef std::forward_iterator_tag iterator_category;
    // Iterator requirements mandate a signed difference type.
    typedef std::ptrdiff_t difference_type;

    // Forward iterators must be default-constructible. A default-constructed
    // iterator refers to no set; it must be assigned a valid iterator before
    // being incremented or dereferenced.
    Iterator() = default;
    Iterator(const Iterator& other) = default;
    Iterator& operator=(const Iterator& other) = default;

    // Advances to the next stored value, or to the end() position.
    // Incrementing an end() iterator leaves it at end().
    Iterator& operator++() {
      do {
        if (bucketIndex_ >= set_->buckets_.size()) {
          bucketIndex_ = set_->buckets_.size();
          bucketOffset_ = 0;
          break;
        }

        if (bucketOffset_ + 1 == kBucketSize) {
          bucketOffset_ = 0;
          ++bucketIndex_;
        } else {
          ++bucketOffset_;
        }
      } while (bucketIndex_ < set_->buckets_.size() &&
               !set_->HasEnumAt(bucketIndex_, bucketOffset_));
      return *this;
    }

    Iterator operator++(int) {
      Iterator old = *this;
      ++(*this);
      return old;
    }

    // Returns the enum value the iterator points to (by value).
    T operator*() const {
      assert(set_->HasEnumAt(bucketIndex_, bucketOffset_) &&
             "operator*() called on an invalid iterator.");
      return GetValueFromBucket(set_->buckets_[bucketIndex_], bucketOffset_);
    }

    bool operator!=(const Iterator& other) const {
      return set_ != other.set_ || bucketOffset_ != other.bucketOffset_ ||
             bucketIndex_ != other.bucketIndex_;
    }

    bool operator==(const Iterator& other) const {
      return !(operator!=(other));
    }

   private:
    Iterator(const EnumSet* set, size_t bucketIndex, ElementType bucketOffset)
        : set_(set), bucketIndex_(bucketIndex), bucketOffset_(bucketOffset) {}

    const EnumSet* set_ = nullptr;
    // Index of the bucket in the vector.
    size_t bucketIndex_ = 0;
    // Offset in bits in the current bucket.
    ElementType bucketOffset_ = 0;

    friend class EnumSet;
  };

  // Required to allow the use of std::inserter.
  using value_type = T;
  using const_iterator = Iterator;
  using iterator = Iterator;

 public:
  // Returns an iterator to the smallest stored value, or end() if empty.
  iterator cbegin() const noexcept {
    auto it = iterator(this, /* bucketIndex= */ 0, /* bucketOffset= */ 0);
    if (buckets_.empty()) {
      return it;
    }

    // The iterator has the logic to find the next valid bit. If the value at
    // offset 0 of the first bucket is not stored, use it to find the next one.
    if (!HasEnumAt(it.bucketIndex_, it.bucketOffset_)) {
      ++it;
    }

    return it;
  }

  iterator begin() const noexcept { return cbegin(); }

  // Returns the past-the-end iterator.
  iterator cend() const noexcept {
    return iterator(this, buckets_.size(), /* bucketOffset= */ 0);
  }

  iterator end() const noexcept { return cend(); }

  // Creates an empty set.
  EnumSet() : buckets_(), size_(0) {}

  // Creates a set and store `value` in it.
  EnumSet(T value) : EnumSet() { insert(value); }

  // Creates a set and stores each `values` in it.
  EnumSet(std::initializer_list<T> values) : EnumSet() {
    for (auto item : values) {
      insert(item);
    }
  }

  // Creates a set, and insert `count` enum values pointed by `array` in it.
  EnumSet(ElementType count, const T* array) : EnumSet() {
    for (ElementType i = 0; i < count; i++) {
      insert(array[i]);
    }
  }

  // Creates a set initialized with the content of the range [begin; end[.
  template <class InputIt>
  EnumSet(InputIt begin, InputIt end) : EnumSet() {
    for (; begin != end; ++begin) {
      insert(*begin);
    }
  }

  // Copies the EnumSet `other` into a new EnumSet.
  EnumSet(const EnumSet& other)
      : buckets_(other.buckets_), size_(other.size_) {}

  // Moves the EnumSet `other` into a new EnumSet. `other` is left empty
  // (no buckets, size 0), so its size() stays consistent with its content.
  EnumSet(EnumSet&& other) noexcept
      : buckets_(std::move(other.buckets_)),
        size_(std::exchange(other.size_, 0)) {
    // Do not rely on the moved-from vector's state: make `other` empty.
    other.buckets_.clear();
  }

  // Deep-copies the EnumSet `other` into this EnumSet.
  EnumSet& operator=(const EnumSet& other) {
    buckets_ = other.buckets_;
    size_ = other.size_;
    return *this;
  }

  // Moves the EnumSet `other` into this EnumSet. `other` is left empty.
  // Without this overload, `a = std::move(b)` would fall back to a deep copy.
  EnumSet& operator=(EnumSet&& other) noexcept {
    if (this != &other) {
      buckets_ = std::move(other.buckets_);
      other.buckets_.clear();
      size_ = std::exchange(other.size_, 0);
    }
    return *this;
  }

  // Matches std::unordered_set::insert behavior: returns an iterator to
  // `value` and whether it was newly inserted.
  std::pair<iterator, bool> insert(const T& value) {
    const size_t index = FindBucketForValue(value);
    const ElementType offset = ComputeBucketOffset(value);

    if (index >= buckets_.size() ||
        buckets_[index].start != ComputeBucketStart(value)) {
      // No bucket covers `value` yet: create one at the insertion point.
      size_ += 1;
      InsertBucketFor(index, value);
      return std::make_pair(Iterator(this, index, offset), true);
    }

    auto& bucket = buckets_[index];
    const auto mask = ComputeMaskForValue(value);
    if (bucket.data & mask) {
      return std::make_pair(Iterator(this, index, offset), false);
    }

    size_ += 1;
    bucket.data |= mask;
    return std::make_pair(Iterator(this, index, offset), true);
  }

  // Inserts `value` in the set if possible.
  // Similar to `std::unordered_set::insert`, except the hint is ignored.
  // Returns an iterator to the inserted element, or the element preventing
  // insertion.
  iterator insert(const_iterator, const T& value) {
    return insert(value).first;
  }

  // Inserts `value` in the set if possible.
  // Similar to `std::unordered_set::insert`, except the hint is ignored.
  // Returns an iterator to the inserted element, or the element preventing
  // insertion.
  iterator insert(const_iterator, T&& value) { return insert(value).first; }

  // Inserts all the values in the range [`first`; `last`[.
  // Similar to `std::unordered_set::insert`.
  template <class InputIt>
  void insert(InputIt first, InputIt last) {
    for (auto it = first; it != last; ++it) {
      insert(*it);
    }
  }

  // Removes the value `value` from the set.
  // Similar to `std::unordered_set::erase`.
  // Returns the number of erased elements (0 or 1).
  size_t erase(const T& value) {
    const size_t index = FindBucketForValue(value);
    if (index >= buckets_.size() ||
        buckets_[index].start != ComputeBucketStart(value)) {
      return 0;
    }

    auto& bucket = buckets_[index];
    const auto mask = ComputeMaskForValue(value);
    if (!(bucket.data & mask)) {
      return 0;
    }

    size_ -= 1;
    bucket.data &= ~mask;
    if (bucket.data == 0) {
      // Empty buckets are never kept, so iteration never sees one.
      buckets_.erase(buckets_.cbegin() + index);
    }
    return 1;
  }

  // Returns true if `value` is present in the set.
  bool contains(T value) const {
    const size_t index = FindBucketForValue(value);
    if (index >= buckets_.size() ||
        buckets_[index].start != ComputeBucketStart(value)) {
      return false;
    }
    return (buckets_[index].data & ComputeMaskForValue(value)) != 0;
  }

  // Returns 1 if `value` is present in the set, 0 otherwise.
  inline size_t count(T value) const { return contains(value) ? 1 : 0; }

  // Returns true if the set holds no values.
  inline bool empty() const { return size_ == 0; }

  // Returns the number of enums stored in this set.
  size_t size() const { return size_; }

  // Returns true if this set contains at least one value contained in `in_set`.
  // Note: If `in_set` is empty, this function returns true.
  bool HasAnyOf(const EnumSet<T>& in_set) const {
    if (in_set.empty()) {
      return true;
    }

    auto lhs = buckets_.cbegin();
    auto rhs = in_set.buckets_.cbegin();

    // Both bucket vectors are sorted by `start`: walk them in lockstep.
    while (lhs != buckets_.cend() && rhs != in_set.buckets_.cend()) {
      if (lhs->start == rhs->start) {
        if (lhs->data & rhs->data) {
          // At least 1 bit is shared. Early return.
          return true;
        }

        ++lhs;
        ++rhs;
        continue;
      }

      // LHS bucket is smaller than the current RHS bucket. Catching up on RHS.
      if (lhs->start < rhs->start) {
        ++lhs;
        continue;
      }

      // Otherwise, RHS needs to catch up on LHS.
      ++rhs;
    }

    return false;
  }

 private:
  // Returns the index of the last bucket in which `value` could be stored.
  static constexpr inline size_t ComputeLargestPossibleBucketIndexFor(T value) {
    return static_cast<size_t>(value) / kBucketSize;
  }

  // Returns the smallest enum value that could be contained in the same bucket
  // as `value`.
  static constexpr inline T ComputeBucketStart(T value) {
    return static_cast<T>(kBucketSize *
                          ComputeLargestPossibleBucketIndexFor(value));
  }

  // Returns the index of the bit that corresponds to `value` in the bucket.
  static constexpr inline ElementType ComputeBucketOffset(T value) {
    return static_cast<ElementType>(value) % kBucketSize;
  }

  // Returns the bitmask used to represent the enum `value` in its bucket.
  static constexpr inline BucketType ComputeMaskForValue(T value) {
    return BucketType{1} << ComputeBucketOffset(value);
  }

  // Returns the `enum` stored in `bucket` at `offset`.
  // `offset` is the bit-offset in the bucket storage.
  static constexpr inline T GetValueFromBucket(const Bucket& bucket,
                                               BucketType offset) {
    return static_cast<T>(static_cast<ElementType>(bucket.start) + offset);
  }

  // For a given enum `value`, returns the index of the first bucket whose
  // start is >= the start of `value`'s bucket (a lower bound). That is the
  // bucket holding `value` if it exists, otherwise the index at which a new
  // bucket for `value` must be inserted to keep the vector sorted.
  size_t FindBucketForValue(T value) const {
    // Set is empty, insert at 0.
    if (buckets_.size() == 0) {
      return 0;
    }

    const T wanted_start = ComputeBucketStart(value);
    assert(buckets_.size() > 0 &&
           "Size must not be 0 here. Has the code above changed?");
    // Bucket `i` starts at or after `i * kBucketSize`, so the wanted bucket
    // can't be past this index.
    size_t index = std::min(buckets_.size() - 1,
                            ComputeLargestPossibleBucketIndexFor(value));

    // Scan left while buckets start at or after `wanted_start`; stop on the
    // first bucket that starts strictly before it.
    for (; buckets_[index].start >= wanted_start; index--) {
      if (index == 0) {
        return 0;
      }
    }

    return index + 1;
  }

  // Creates a new bucket to store `value` and inserts it at `index`.
  // If the `index` is past the end, the bucket is inserted at the end of the
  // vector.
  void InsertBucketFor(size_t index, T value) {
    const T bucket_start = ComputeBucketStart(value);
    const Bucket bucket = {ComputeMaskForValue(value), bucket_start};
    [[maybe_unused]] const auto it =
        buckets_.insert(buckets_.begin() + index, bucket);
    // The vector must remain strictly sorted by bucket start.
    assert(std::next(it) == buckets_.end() ||
           std::next(it)->start > bucket_start);
    assert(it == buckets_.begin() || std::prev(it)->start < bucket_start);
  }

  // Returns true if the bucket at `bucketIndex` stores the enum at
  // `bucketOffset`, false otherwise.
  bool HasEnumAt(size_t bucketIndex, BucketType bucketOffset) const {
    assert(bucketIndex < buckets_.size());
    assert(bucketOffset < kBucketSize);
    return (buckets_[bucketIndex].data & (BucketType{1} << bucketOffset)) != 0;
  }

  // Returns true if `lhs` and `rhs` hold the exact same values.
  friend bool operator==(const EnumSet& lhs, const EnumSet& rhs) {
    if (lhs.size_ != rhs.size_) {
      return false;
    }

    if (lhs.buckets_.size() != rhs.buckets_.size()) {
      return false;
    }
    return lhs.buckets_ == rhs.buckets_;
  }

  // Returns true if `lhs` and `rhs` hold at least 1 different value.
  friend bool operator!=(const EnumSet& lhs, const EnumSet& rhs) {
    return !(lhs == rhs);
  }

  // Storage for the buckets, sorted by `Bucket::start`; never holds an empty
  // bucket.
  std::vector<Bucket> buckets_;
  // How many enums is this set storing.
  size_t size_ = 0;
};
466 
467 // A set of spv::Capability.
468 using CapabilitySet = EnumSet<spv::Capability>;
469 
470 }  // namespace spvtools
471 
472 #endif  // SOURCE_ENUM_SET_H_
473