• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef BASE_CONTAINERS_ENUM_SET_H_
6 #define BASE_CONTAINERS_ENUM_SET_H_
7 
#include <bitset>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <type_traits>
#include <utility>

#include "base/check.h"
#include "base/check_op.h"
#include "base/memory/raw_ptr.h"
#include "build/build_config.h"
18 
19 namespace base {
20 
// Forward declarations. The EnumSet class template befriends the three
// set-operation function templates below, so both the class template and the
// function templates have to be declared before the class is defined.
template <typename Enum, Enum Lo, Enum Hi>
class EnumSet;

template <typename Enum, Enum Lo, Enum Hi>
constexpr EnumSet<Enum, Lo, Hi> Union(EnumSet<Enum, Lo, Hi> set1,
                                      EnumSet<Enum, Lo, Hi> set2);

template <typename Enum, Enum Lo, Enum Hi>
constexpr EnumSet<Enum, Lo, Hi> Intersection(EnumSet<Enum, Lo, Hi> set1,
                                             EnumSet<Enum, Lo, Hi> set2);

template <typename Enum, Enum Lo, Enum Hi>
constexpr EnumSet<Enum, Lo, Hi> Difference(EnumSet<Enum, Lo, Hi> set1,
                                           EnumSet<Enum, Lo, Hi> set2);
36 
37 // An EnumSet is a set that can hold enum values between a min and a
38 // max value (inclusive of both).  It's essentially a wrapper around
39 // std::bitset<> with stronger type enforcement, more descriptive
40 // member function names, and an iterator interface.
41 //
42 // If you're working with enums with a small number of possible values
43 // (say, fewer than 64), you can efficiently pass around an EnumSet
44 // for that enum around by value.
45 
46 template <typename E, E MinEnumValue, E MaxEnumValue>
47 class EnumSet {
48  private:
49   static_assert(
50       std::is_enum_v<E>,
51       "First template parameter of EnumSet must be an enumeration type");
52   using enum_underlying_type = std::underlying_type_t<E>;
53 
InRange(E value)54   static constexpr bool InRange(E value) {
55     return (value >= MinEnumValue) && (value <= MaxEnumValue);
56   }
57 
GetUnderlyingValue(E value)58   static constexpr enum_underlying_type GetUnderlyingValue(E value) {
59     return static_cast<enum_underlying_type>(value);
60   }
61 
62  public:
63   using EnumType = E;
64   static const E kMinValue = MinEnumValue;
65   static const E kMaxValue = MaxEnumValue;
66   static const size_t kValueCount =
67       GetUnderlyingValue(kMaxValue) - GetUnderlyingValue(kMinValue) + 1;
68 
69   static_assert(kMinValue <= kMaxValue,
70                 "min value must be no greater than max value");
71 
72  private:
73   // Declaration needed by Iterator.
74   using EnumBitSet = std::bitset<kValueCount>;
75 
76  public:
77   // Iterator is a forward-only read-only iterator for EnumSet. It follows the
78   // common STL input iterator interface (like std::unordered_set).
79   //
80   // Example usage, using a range-based for loop:
81   //
82   // EnumSet<SomeType> enums;
83   // for (SomeType val : enums) {
84   //   Process(val);
85   // }
86   //
87   // Or using an explicit iterator (not recommended):
88   //
89   // for (EnumSet<...>::Iterator it = enums.begin(); it != enums.end(); it++) {
90   //   Process(*it);
91   // }
92   //
93   // The iterator must not be outlived by the set. In particular, the following
94   // is an error:
95   //
96   // EnumSet<...> SomeFn() { ... }
97   //
98   // /* ERROR */
99   // for (EnumSet<...>::Iterator it = SomeFun().begin(); ...
100   //
101   // Also, there are no guarantees as to what will happen if you
102   // modify an EnumSet while traversing it with an iterator.
103   class Iterator {
104    public:
Iterator()105     Iterator() : enums_(nullptr), i_(kValueCount) {}
106     ~Iterator() = default;
107 
108     friend bool operator==(const Iterator& lhs, const Iterator& rhs) {
109       return lhs.i_ == rhs.i_;
110     }
111 
112     E operator*() const {
113       DCHECK(Good());
114       return FromIndex(i_);
115     }
116 
117     Iterator& operator++() {
118       DCHECK(Good());
119       // If there are no more set elements in the bitset, this will result in an
120       // index equal to kValueCount, which is equivalent to EnumSet.end().
121       i_ = FindNext(i_ + 1);
122 
123       return *this;
124     }
125 
126     Iterator operator++(int) {
127       DCHECK(Good());
128       Iterator old(*this);
129 
130       // If there are no more set elements in the bitset, this will result in an
131       // index equal to kValueCount, which is equivalent to EnumSet.end().
132       i_ = FindNext(i_ + 1);
133 
134       return std::move(old);
135     }
136 
137    private:
138     friend Iterator EnumSet::begin() const;
139 
Iterator(const EnumBitSet & enums)140     explicit Iterator(const EnumBitSet& enums)
141         : enums_(&enums), i_(FindNext(0)) {}
142 
143     // Returns true iff the iterator points to an EnumSet and it
144     // hasn't yet traversed the EnumSet entirely.
Good()145     bool Good() const { return enums_ && i_ < kValueCount && enums_->test(i_); }
146 
FindNext(size_t i)147     size_t FindNext(size_t i) {
148       while ((i < kValueCount) && !enums_->test(i)) {
149         ++i;
150       }
151       return i;
152     }
153 
154     const raw_ptr<const EnumBitSet> enums_;
155     size_t i_;
156   };
157 
158   EnumSet() = default;
159 
160   ~EnumSet() = default;
161 
EnumSet(std::initializer_list<E> values)162   constexpr EnumSet(std::initializer_list<E> values) {
163     if (std::is_constant_evaluated()) {
164       enums_ = bitstring(values);
165     } else {
166       for (E value : values) {
167         Put(value);
168       }
169     }
170   }
171 
172   // Returns an EnumSet with all values between kMinValue and kMaxValue, which
173   // also contains undefined enum values if the enum in question has gaps
174   // between kMinValue and kMaxValue.
All()175   static constexpr EnumSet All() {
176     if (std::is_constant_evaluated()) {
177       if (kValueCount == 0) {
178         return EnumSet();
179       }
180       // Since `1 << kValueCount` may trigger shift-count-overflow warning if
181       // the `kValueCount` is 64, instead of returning `(1 << kValueCount) - 1`,
182       // the bitmask will be constructed from two parts: the most significant
183       // bits and the remaining.
184       uint64_t mask = 1ULL << (kValueCount - 1);
185       return EnumSet(EnumBitSet(mask - 1 + mask));
186     } else {
187       // When `kValueCount` is greater than 64, we can't use the constexpr path,
188       // and we will build an `EnumSet` value by value.
189       EnumSet enum_set;
190       for (size_t value = 0; value < kValueCount; ++value) {
191         enum_set.Put(FromIndex(value));
192       }
193       return enum_set;
194     }
195   }
196 
197   // Returns an EnumSet with all the values from start to end, inclusive.
FromRange(E start,E end)198   static constexpr EnumSet FromRange(E start, E end) {
199     CHECK_LE(start, end);
200     return EnumSet(EnumBitSet(
201         ((single_val_bitstring(end)) - (single_val_bitstring(start))) |
202         (single_val_bitstring(end))));
203   }
204 
205   // Copy constructor and assignment welcome.
206 
207   // Bitmask operations.
208   //
209   // This bitmask is 0-based and the value of the Nth bit depends on whether
210   // the set contains an enum element of integer value N.
211   //
212   // These may only be used if Min >= 0 and Max < 64.
213 
214   // Returns an EnumSet constructed from |bitmask|.
FromEnumBitmask(const uint64_t bitmask)215   static constexpr EnumSet FromEnumBitmask(const uint64_t bitmask) {
216     static_assert(GetUnderlyingValue(kMaxValue) < 64,
217                   "The highest enum value must be < 64 for FromEnumBitmask ");
218     static_assert(GetUnderlyingValue(kMinValue) >= 0,
219                   "The lowest enum value must be >= 0 for FromEnumBitmask ");
220     return EnumSet(EnumBitSet(bitmask >> GetUnderlyingValue(kMinValue)));
221   }
222   // Returns a bitmask for the EnumSet.
ToEnumBitmask()223   uint64_t ToEnumBitmask() const {
224     static_assert(GetUnderlyingValue(kMaxValue) < 64,
225                   "The highest enum value must be < 64 for ToEnumBitmask ");
226     static_assert(GetUnderlyingValue(kMinValue) >= 0,
227                   "The lowest enum value must be >= 0 for FromEnumBitmask ");
228     return enums_.to_ullong() << GetUnderlyingValue(kMinValue);
229   }
230 
231   // Set operations.  Put, Retain, and Remove are basically
232   // self-mutating versions of Union, Intersection, and Difference
233   // (defined below).
234 
235   // Adds the given value (which must be in range) to our set.
Put(E value)236   void Put(E value) { enums_.set(ToIndex(value)); }
237 
238   // Adds all values in the given set to our set.
PutAll(EnumSet other)239   void PutAll(EnumSet other) { enums_ |= other.enums_; }
240 
241   // Adds all values in the given range to our set, inclusive.
PutRange(E start,E end)242   void PutRange(E start, E end) {
243     CHECK_LE(start, end);
244     size_t endIndexInclusive = ToIndex(end);
245     for (size_t current = ToIndex(start); current <= endIndexInclusive;
246          ++current) {
247       enums_.set(current);
248     }
249   }
250 
251   // There's no real need for a Retain(E) member function.
252 
253   // Removes all values not in the given set from our set.
RetainAll(EnumSet other)254   void RetainAll(EnumSet other) { enums_ &= other.enums_; }
255 
256   // If the given value is in range, removes it from our set.
Remove(E value)257   void Remove(E value) {
258     if (InRange(value)) {
259       enums_.reset(ToIndex(value));
260     }
261   }
262 
263   // Removes all values in the given set from our set.
RemoveAll(EnumSet other)264   void RemoveAll(EnumSet other) { enums_ &= ~other.enums_; }
265 
266   // Removes all values from our set.
Clear()267   void Clear() { enums_.reset(); }
268 
269   // Conditionally puts or removes `value`, based on `should_be_present`.
PutOrRemove(E value,bool should_be_present)270   void PutOrRemove(E value, bool should_be_present) {
271     if (should_be_present) {
272       Put(value);
273     } else {
274       Remove(value);
275     }
276   }
277 
278   // Returns true iff the given value is in range and a member of our set.
Has(E value)279   constexpr bool Has(E value) const {
280     return InRange(value) && enums_[ToIndex(value)];
281   }
282 
283   // Returns true iff the given set is a subset of our set.
HasAll(EnumSet other)284   bool HasAll(EnumSet other) const {
285     return (enums_ & other.enums_) == other.enums_;
286   }
287 
288   // Returns true if the given set contains any value of our set.
HasAny(EnumSet other)289   bool HasAny(EnumSet other) const {
290     return (enums_ & other.enums_).count() > 0;
291   }
292 
293   // Returns true iff our set is empty.
Empty()294   bool Empty() const { return !enums_.any(); }
295 
296   // Returns how many values our set has.
Size()297   size_t Size() const { return enums_.count(); }
298 
299   // Returns an iterator pointing to the first element (if any).
begin()300   Iterator begin() const { return Iterator(enums_); }
301 
302   // Returns an iterator that does not point to any element, but to the position
303   // that follows the last element in the set.
end()304   Iterator end() const { return Iterator(); }
305 
306   // Returns true iff our set and the given set contain exactly the same values.
307   friend bool operator==(const EnumSet&, const EnumSet&) = default;
308 
309  private:
310   friend constexpr EnumSet Union<E, MinEnumValue, MaxEnumValue>(EnumSet set1,
311                                                                 EnumSet set2);
312   friend constexpr EnumSet Intersection<E, MinEnumValue, MaxEnumValue>(
313       EnumSet set1,
314       EnumSet set2);
315   friend constexpr EnumSet Difference<E, MinEnumValue, MaxEnumValue>(
316       EnumSet set1,
317       EnumSet set2);
318 
bitstring(const std::initializer_list<E> & values)319   static constexpr uint64_t bitstring(const std::initializer_list<E>& values) {
320     uint64_t result = 0;
321     for (E value : values) {
322       result |= single_val_bitstring(value);
323     }
324     return result;
325   }
326 
single_val_bitstring(E val)327   static constexpr uint64_t single_val_bitstring(E val) {
328     const uint64_t bitstring = 1;
329     const size_t shift_amount = ToIndex(val);
330     CHECK_LT(shift_amount, sizeof(bitstring) * 8);
331     return bitstring << shift_amount;
332   }
333 
334   // A bitset can't be constexpr constructed if it has size > 64, since the
335   // constexpr constructor uses a uint64_t. If your EnumSet has > 64 values, you
336   // can safely remove the constepxr qualifiers from this file, at the cost of
337   // some minor optimizations.
EnumSet(EnumBitSet enums)338   explicit constexpr EnumSet(EnumBitSet enums) : enums_(enums) {
339     if (std::is_constant_evaluated()) {
340       CHECK(kValueCount <= 64)
341           << "Max number of enum values is 64 for constexpr constructor";
342     }
343   }
344 
345   // Converts a value to/from an index into |enums_|.
ToIndex(E value)346   static constexpr size_t ToIndex(E value) {
347     CHECK(InRange(value));
348     return static_cast<size_t>(GetUnderlyingValue(value)) -
349            static_cast<size_t>(GetUnderlyingValue(MinEnumValue));
350   }
351 
FromIndex(size_t i)352   static E FromIndex(size_t i) {
353     DCHECK_LT(i, kValueCount);
354     return static_cast<E>(GetUnderlyingValue(MinEnumValue) + i);
355   }
356 
357   EnumBitSet enums_;
358 };
359 
360 template <typename E, E MinEnumValue, E MaxEnumValue>
361 const E EnumSet<E, MinEnumValue, MaxEnumValue>::kMinValue;
362 
363 template <typename E, E MinEnumValue, E MaxEnumValue>
364 const E EnumSet<E, MinEnumValue, MaxEnumValue>::kMaxValue;
365 
366 template <typename E, E MinEnumValue, E MaxEnumValue>
367 const size_t EnumSet<E, MinEnumValue, MaxEnumValue>::kValueCount;
368 
369 // The usual set operations.
370 
371 template <typename E, E Min, E Max>
Union(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)372 constexpr EnumSet<E, Min, Max> Union(EnumSet<E, Min, Max> set1,
373                                      EnumSet<E, Min, Max> set2) {
374   return EnumSet<E, Min, Max>(set1.enums_ | set2.enums_);
375 }
376 
377 template <typename E, E Min, E Max>
Intersection(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)378 constexpr EnumSet<E, Min, Max> Intersection(EnumSet<E, Min, Max> set1,
379                                             EnumSet<E, Min, Max> set2) {
380   return EnumSet<E, Min, Max>(set1.enums_ & set2.enums_);
381 }
382 
383 template <typename E, E Min, E Max>
Difference(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)384 constexpr EnumSet<E, Min, Max> Difference(EnumSet<E, Min, Max> set1,
385                                           EnumSet<E, Min, Max> set2) {
386   return EnumSet<E, Min, Max>(set1.enums_ & ~set2.enums_);
387 }
388 
389 }  // namespace base
390 
391 #endif  // BASE_CONTAINERS_ENUM_SET_H_
392