• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef BASE_CONTAINERS_ENUM_SET_H_
6 #define BASE_CONTAINERS_ENUM_SET_H_
7 
8 #include <bitset>
9 #include <cstddef>
10 #include <initializer_list>
11 #include <optional>
12 #include <string>
13 #include <type_traits>
14 #include <utility>
15 
16 #include "base/check.h"
17 #include "base/check_op.h"
18 #include "base/memory/raw_ptr.h"
19 #include "build/build_config.h"
20 
21 namespace base {
22 
// EnumSet befriends the set-operation function templates below so that they
// can reach its private bitset. A friend declaration naming a template
// specialization requires the template to be declared first, and the function
// templates in turn mention EnumSet, so all four are declared up front.
template <typename E, E MinEnumValue, E MaxEnumValue>
class EnumSet;

// Returns the set of values present in either operand.
template <typename E, E Min, E Max>
constexpr EnumSet<E, Min, Max> Union(EnumSet<E, Min, Max> set1,
                                     EnumSet<E, Min, Max> set2);

// Returns the set of values present in both operands.
template <typename E, E Min, E Max>
constexpr EnumSet<E, Min, Max> Intersection(EnumSet<E, Min, Max> set1,
                                            EnumSet<E, Min, Max> set2);

// Returns the set of values present in `set1` but not in `set2`.
template <typename E, E Min, E Max>
constexpr EnumSet<E, Min, Max> Difference(EnumSet<E, Min, Max> set1,
                                          EnumSet<E, Min, Max> set2);
38 
// An EnumSet is a set that can hold enum values between a min and a
// max value (inclusive of both).  It's essentially a wrapper around
// std::bitset<> with stronger type enforcement, more descriptive
// member function names, and an iterator interface.
//
// If you're working with enums with a small number of possible values
// (say, fewer than 64), you can efficiently pass an EnumSet for that
// enum around by value.
47 
48 template <typename E, E MinEnumValue, E MaxEnumValue>
49 class EnumSet {
50  private:
51   static_assert(
52       std::is_enum_v<E>,
53       "First template parameter of EnumSet must be an enumeration type");
54   using enum_underlying_type = std::underlying_type_t<E>;
55 
InRange(E value)56   static constexpr bool InRange(E value) {
57     return (value >= MinEnumValue) && (value <= MaxEnumValue);
58   }
59 
GetUnderlyingValue(E value)60   static constexpr enum_underlying_type GetUnderlyingValue(E value) {
61     return static_cast<enum_underlying_type>(value);
62   }
63 
64  public:
65   using EnumType = E;
66   static const E kMinValue = MinEnumValue;
67   static const E kMaxValue = MaxEnumValue;
68   static const size_t kValueCount =
69       GetUnderlyingValue(kMaxValue) - GetUnderlyingValue(kMinValue) + 1;
70 
71   static_assert(kMinValue <= kMaxValue,
72                 "min value must be no greater than max value");
73 
74  private:
75   // Declaration needed by Iterator.
76   using EnumBitSet = std::bitset<kValueCount>;
77 
78  public:
79   // Iterator is a forward-only read-only iterator for EnumSet. It follows the
80   // common STL input iterator interface (like std::unordered_set).
81   //
82   // Example usage, using a range-based for loop:
83   //
84   // EnumSet<SomeType> enums;
85   // for (SomeType val : enums) {
86   //   Process(val);
87   // }
88   //
89   // Or using an explicit iterator (not recommended):
90   //
91   // for (EnumSet<...>::Iterator it = enums.begin(); it != enums.end(); it++) {
92   //   Process(*it);
93   // }
94   //
95   // The iterator must not be outlived by the set. In particular, the following
96   // is an error:
97   //
98   // EnumSet<...> SomeFn() { ... }
99   //
100   // /* ERROR */
101   // for (EnumSet<...>::Iterator it = SomeFun().begin(); ...
102   //
103   // Also, there are no guarantees as to what will happen if you
104   // modify an EnumSet while traversing it with an iterator.
105   class Iterator {
106    public:
107     using value_type = EnumType;
108     using size_type = size_t;
109     using difference_type = ptrdiff_t;
110     using pointer = EnumType*;
111     using reference = EnumType&;
112     using iterator_category = std::forward_iterator_tag;
113 
Iterator()114     Iterator() : enums_(nullptr), i_(kValueCount) {}
115     ~Iterator() = default;
116 
117     Iterator(const Iterator&) = default;
118     Iterator& operator=(const Iterator&) = default;
119 
120     Iterator(Iterator&&) = default;
121     Iterator& operator=(Iterator&&) = default;
122 
123     friend bool operator==(const Iterator& lhs, const Iterator& rhs) {
124       return lhs.i_ == rhs.i_;
125     }
126 
127     value_type operator*() const {
128       DCHECK(Good());
129       return FromIndex(i_);
130     }
131 
132     Iterator& operator++() {
133       DCHECK(Good());
134       // If there are no more set elements in the bitset, this will result in an
135       // index equal to kValueCount, which is equivalent to EnumSet.end().
136       i_ = FindNext(i_ + 1);
137 
138       return *this;
139     }
140 
141     Iterator operator++(int) {
142       DCHECK(Good());
143       Iterator old(*this);
144 
145       // If there are no more set elements in the bitset, this will result in an
146       // index equal to kValueCount, which is equivalent to EnumSet.end().
147       i_ = FindNext(i_ + 1);
148 
149       return std::move(old);
150     }
151 
152    private:
153     friend Iterator EnumSet::begin() const;
154 
Iterator(const EnumBitSet & enums)155     explicit Iterator(const EnumBitSet& enums)
156         : enums_(&enums), i_(FindNext(0)) {}
157 
158     // Returns true iff the iterator points to an EnumSet and it
159     // hasn't yet traversed the EnumSet entirely.
Good()160     bool Good() const { return enums_ && i_ < kValueCount && enums_->test(i_); }
161 
FindNext(size_t i)162     size_t FindNext(size_t i) {
163       while ((i < kValueCount) && !enums_->test(i)) {
164         ++i;
165       }
166       return i;
167     }
168 
169     raw_ptr<const EnumBitSet> enums_;
170     size_t i_;
171   };
172 
173   EnumSet() = default;
174 
175   ~EnumSet() = default;
176 
EnumSet(std::initializer_list<E> values)177   constexpr EnumSet(std::initializer_list<E> values) {
178     if (std::is_constant_evaluated()) {
179       enums_ = bitstring(values);
180     } else {
181       for (E value : values) {
182         Put(value);
183       }
184     }
185   }
186 
187   // Returns an EnumSet with all values between kMinValue and kMaxValue, which
188   // also contains undefined enum values if the enum in question has gaps
189   // between kMinValue and kMaxValue.
All()190   static constexpr EnumSet All() {
191     if (std::is_constant_evaluated()) {
192       if (kValueCount == 0) {
193         return EnumSet();
194       }
195       // Since `1 << kValueCount` may trigger shift-count-overflow warning if
196       // the `kValueCount` is 64, instead of returning `(1 << kValueCount) - 1`,
197       // the bitmask will be constructed from two parts: the most significant
198       // bits and the remaining.
199       uint64_t mask = 1ULL << (kValueCount - 1);
200       return EnumSet(EnumBitSet(mask - 1 + mask));
201     } else {
202       // When `kValueCount` is greater than 64, we can't use the constexpr path,
203       // and we will build an `EnumSet` value by value.
204       EnumSet enum_set;
205       for (size_t value = 0; value < kValueCount; ++value) {
206         enum_set.Put(FromIndex(value));
207       }
208       return enum_set;
209     }
210   }
211 
212   // Returns an EnumSet with all the values from start to end, inclusive.
FromRange(E start,E end)213   static constexpr EnumSet FromRange(E start, E end) {
214     CHECK_LE(start, end);
215     return EnumSet(EnumBitSet(
216         ((single_val_bitstring(end)) - (single_val_bitstring(start))) |
217         (single_val_bitstring(end))));
218   }
219 
220   // Copy constructor and assignment welcome.
221 
222   // Bitmask operations.
223   //
224   // This bitmask is 0-based and the value of the Nth bit depends on whether
225   // the set contains an enum element of integer value N.
226   //
227   // These may only be used if Min >= 0 and Max < 64.
228 
229   // Returns an EnumSet constructed from |bitmask|.
FromEnumBitmask(const uint64_t bitmask)230   static constexpr EnumSet FromEnumBitmask(const uint64_t bitmask) {
231     static_assert(GetUnderlyingValue(kMaxValue) < 64,
232                   "The highest enum value must be < 64 for FromEnumBitmask ");
233     static_assert(GetUnderlyingValue(kMinValue) >= 0,
234                   "The lowest enum value must be >= 0 for FromEnumBitmask ");
235     return EnumSet(EnumBitSet(bitmask >> GetUnderlyingValue(kMinValue)));
236   }
237   // Returns a bitmask for the EnumSet.
ToEnumBitmask()238   uint64_t ToEnumBitmask() const {
239     static_assert(GetUnderlyingValue(kMaxValue) < 64,
240                   "The highest enum value must be < 64 for ToEnumBitmask ");
241     static_assert(GetUnderlyingValue(kMinValue) >= 0,
242                   "The lowest enum value must be >= 0 for FromEnumBitmask ");
243     return enums_.to_ullong() << GetUnderlyingValue(kMinValue);
244   }
245 
246   // Returns a uint64_t bit mask representing the values within the range
247   // [64*n, 64*n + 63] of the EnumSet.
GetNth64bitWordBitmask(size_t n)248   std::optional<uint64_t> GetNth64bitWordBitmask(size_t n) const {
249     // If the EnumSet contains less than n 64-bit masks, return std::nullopt.
250     if (GetUnderlyingValue(kMaxValue) / 64 < n) {
251       return std::nullopt;
252     }
253 
254     std::bitset<kValueCount> mask = ~uint64_t{0};
255     std::bitset<kValueCount> bits = enums_;
256     if (GetUnderlyingValue(kMinValue) < n * 64) {
257       bits >>= n * 64 - GetUnderlyingValue(kMinValue);
258     }
259     uint64_t result = (bits & mask).to_ullong();
260     if (GetUnderlyingValue(kMinValue) > n * 64) {
261       result <<= GetUnderlyingValue(kMinValue) - n * 64;
262     }
263     return result;
264   }
265 
266   // Set operations.  Put, Retain, and Remove are basically
267   // self-mutating versions of Union, Intersection, and Difference
268   // (defined below).
269 
270   // Adds the given value (which must be in range) to our set.
Put(E value)271   void Put(E value) { enums_.set(ToIndex(value)); }
272 
273   // Adds all values in the given set to our set.
PutAll(EnumSet other)274   void PutAll(EnumSet other) { enums_ |= other.enums_; }
275 
276   // Adds all values in the given range to our set, inclusive.
PutRange(E start,E end)277   void PutRange(E start, E end) {
278     CHECK_LE(start, end);
279     size_t endIndexInclusive = ToIndex(end);
280     for (size_t current = ToIndex(start); current <= endIndexInclusive;
281          ++current) {
282       enums_.set(current);
283     }
284   }
285 
286   // There's no real need for a Retain(E) member function.
287 
288   // Removes all values not in the given set from our set.
RetainAll(EnumSet other)289   void RetainAll(EnumSet other) { enums_ &= other.enums_; }
290 
291   // If the given value is in range, removes it from our set.
Remove(E value)292   void Remove(E value) {
293     if (InRange(value)) {
294       enums_.reset(ToIndex(value));
295     }
296   }
297 
298   // Removes all values in the given set from our set.
RemoveAll(EnumSet other)299   void RemoveAll(EnumSet other) { enums_ &= ~other.enums_; }
300 
301   // Removes all values from our set.
Clear()302   void Clear() { enums_.reset(); }
303 
304   // Conditionally puts or removes `value`, based on `should_be_present`.
PutOrRemove(E value,bool should_be_present)305   void PutOrRemove(E value, bool should_be_present) {
306     if (should_be_present) {
307       Put(value);
308     } else {
309       Remove(value);
310     }
311   }
312 
313   // Returns true iff the given value is in range and a member of our set.
Has(E value)314   constexpr bool Has(E value) const {
315     return InRange(value) && enums_[ToIndex(value)];
316   }
317 
318   // Returns true iff the given set is a subset of our set.
HasAll(EnumSet other)319   bool HasAll(EnumSet other) const {
320     return (enums_ & other.enums_) == other.enums_;
321   }
322 
323   // Returns true if the given set contains any value of our set.
HasAny(EnumSet other)324   bool HasAny(EnumSet other) const {
325     return (enums_ & other.enums_).count() > 0;
326   }
327 
328   // Returns true iff our set is empty.
empty()329   bool empty() const { return !enums_.any(); }
330 
331   // Returns how many values our set has.
size()332   size_t size() const { return enums_.count(); }
333 
334   // Returns an iterator pointing to the first element (if any).
begin()335   Iterator begin() const { return Iterator(enums_); }
336 
337   // Returns an iterator that does not point to any element, but to the position
338   // that follows the last element in the set.
end()339   Iterator end() const { return Iterator(); }
340 
341   // Returns true iff our set and the given set contain exactly the same values.
342   friend bool operator==(const EnumSet&, const EnumSet&) = default;
343 
ToString()344   std::string ToString() const { return enums_.to_string(); }
345 
346  private:
347   friend constexpr EnumSet Union<E, MinEnumValue, MaxEnumValue>(EnumSet set1,
348                                                                 EnumSet set2);
349   friend constexpr EnumSet Intersection<E, MinEnumValue, MaxEnumValue>(
350       EnumSet set1,
351       EnumSet set2);
352   friend constexpr EnumSet Difference<E, MinEnumValue, MaxEnumValue>(
353       EnumSet set1,
354       EnumSet set2);
355 
bitstring(const std::initializer_list<E> & values)356   static constexpr uint64_t bitstring(const std::initializer_list<E>& values) {
357     uint64_t result = 0;
358     for (E value : values) {
359       result |= single_val_bitstring(value);
360     }
361     return result;
362   }
363 
single_val_bitstring(E val)364   static constexpr uint64_t single_val_bitstring(E val) {
365     const uint64_t bitstring = 1;
366     const size_t shift_amount = ToIndex(val);
367     CHECK_LT(shift_amount, sizeof(bitstring) * 8);
368     return bitstring << shift_amount;
369   }
370 
371   // A bitset can't be constexpr constructed if it has size > 64, since the
372   // constexpr constructor uses a uint64_t. If your EnumSet has > 64 values, you
373   // can safely remove the constepxr qualifiers from this file, at the cost of
374   // some minor optimizations.
EnumSet(EnumBitSet enums)375   explicit constexpr EnumSet(EnumBitSet enums) : enums_(enums) {
376     if (std::is_constant_evaluated()) {
377       CHECK(kValueCount <= 64)
378           << "Max number of enum values is 64 for constexpr constructor";
379     }
380   }
381 
382   // Converts a value to/from an index into |enums_|.
ToIndex(E value)383   static constexpr size_t ToIndex(E value) {
384     CHECK(InRange(value));
385     return static_cast<size_t>(GetUnderlyingValue(value)) -
386            static_cast<size_t>(GetUnderlyingValue(MinEnumValue));
387   }
388 
FromIndex(size_t i)389   static E FromIndex(size_t i) {
390     DCHECK_LT(i, kValueCount);
391     return static_cast<E>(GetUnderlyingValue(MinEnumValue) + i);
392   }
393 
394   EnumBitSet enums_;
395 };
396 
397 template <typename E, E MinEnumValue, E MaxEnumValue>
398 const E EnumSet<E, MinEnumValue, MaxEnumValue>::kMinValue;
399 
400 template <typename E, E MinEnumValue, E MaxEnumValue>
401 const E EnumSet<E, MinEnumValue, MaxEnumValue>::kMaxValue;
402 
403 template <typename E, E MinEnumValue, E MaxEnumValue>
404 const size_t EnumSet<E, MinEnumValue, MaxEnumValue>::kValueCount;
405 
406 // The usual set operations.
407 
408 template <typename E, E Min, E Max>
Union(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)409 constexpr EnumSet<E, Min, Max> Union(EnumSet<E, Min, Max> set1,
410                                      EnumSet<E, Min, Max> set2) {
411   return EnumSet<E, Min, Max>(set1.enums_ | set2.enums_);
412 }
413 
414 template <typename E, E Min, E Max>
Intersection(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)415 constexpr EnumSet<E, Min, Max> Intersection(EnumSet<E, Min, Max> set1,
416                                             EnumSet<E, Min, Max> set2) {
417   return EnumSet<E, Min, Max>(set1.enums_ & set2.enums_);
418 }
419 
420 template <typename E, E Min, E Max>
Difference(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)421 constexpr EnumSet<E, Min, Max> Difference(EnumSet<E, Min, Max> set1,
422                                           EnumSet<E, Min, Max> set2) {
423   return EnumSet<E, Min, Max>(set1.enums_ & ~set2.enums_);
424 }
425 
426 }  // namespace base
427 
428 #endif  // BASE_CONTAINERS_ENUM_SET_H_
429