• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef BASE_CONTAINERS_ENUM_SET_H_
6 #define BASE_CONTAINERS_ENUM_SET_H_
7 
#include <bitset>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <type_traits>
#include <utility>

#include "base/check.h"
#include "base/check_op.h"
#include "base/memory/raw_ptr.h"
16 
17 namespace base {
18 
// Forward declarations needed for friend declarations.
//
// EnumSet names the Union(), Intersection() and Difference() function
// templates in friend declarations, so the class template and the three
// function templates are declared here, ahead of the class definition.  The
// function templates are defined after the class, near the end of this file.
template <typename E, E MinEnumValue, E MaxEnumValue>
class EnumSet;

// Set union of |set1| and |set2|.
template <typename E, E Min, E Max>
EnumSet<E, Min, Max> Union(EnumSet<E, Min, Max> set1,
                           EnumSet<E, Min, Max> set2);

// Set intersection of |set1| and |set2|.
template <typename E, E Min, E Max>
EnumSet<E, Min, Max> Intersection(EnumSet<E, Min, Max> set1,
                                  EnumSet<E, Min, Max> set2);

// Set difference: the values of |set1| that are not in |set2|.
template <typename E, E Min, E Max>
EnumSet<E, Min, Max> Difference(EnumSet<E, Min, Max> set1,
                                EnumSet<E, Min, Max> set2);
34 
35 // An EnumSet is a set that can hold enum values between a min and a
36 // max value (inclusive of both).  It's essentially a wrapper around
37 // std::bitset<> with stronger type enforcement, more descriptive
38 // member function names, and an iterator interface.
39 //
40 // If you're working with enums with a small number of possible values
41 // (say, fewer than 64), you can efficiently pass around an EnumSet
42 // for that enum around by value.
43 
44 template <typename E, E MinEnumValue, E MaxEnumValue>
45 class EnumSet {
46  private:
47   static_assert(
48       std::is_enum<E>::value,
49       "First template parameter of EnumSet must be an enumeration type");
50   using enum_underlying_type = std::underlying_type_t<E>;
51 
InRange(E value)52   static constexpr bool InRange(E value) {
53     return (value >= MinEnumValue) && (value <= MaxEnumValue);
54   }
55 
GetUnderlyingValue(E value)56   static constexpr enum_underlying_type GetUnderlyingValue(E value) {
57     return static_cast<enum_underlying_type>(value);
58   }
59 
60  public:
61   using EnumType = E;
62   static const E kMinValue = MinEnumValue;
63   static const E kMaxValue = MaxEnumValue;
64   static const size_t kValueCount =
65       GetUnderlyingValue(kMaxValue) - GetUnderlyingValue(kMinValue) + 1;
66 
67   static_assert(kMinValue < kMaxValue, "min value must be less than max value");
68 
69  private:
70   // Declaration needed by Iterator.
71   using EnumBitSet = std::bitset<kValueCount>;
72 
73  public:
74   // Iterator is a forward-only read-only iterator for EnumSet. It follows the
75   // common STL input iterator interface (like std::unordered_set).
76   //
77   // Example usage, using a range-based for loop:
78   //
79   // EnumSet<SomeType> enums;
80   // for (SomeType val : enums) {
81   //   Process(val);
82   // }
83   //
84   // Or using an explicit iterator (not recommended):
85   //
86   // for (EnumSet<...>::Iterator it = enums.begin(); it != enums.end(); it++) {
87   //   Process(*it);
88   // }
89   //
90   // The iterator must not be outlived by the set. In particular, the following
91   // is an error:
92   //
93   // EnumSet<...> SomeFn() { ... }
94   //
95   // /* ERROR */
96   // for (EnumSet<...>::Iterator it = SomeFun().begin(); ...
97   //
98   // Also, there are no guarantees as to what will happen if you
99   // modify an EnumSet while traversing it with an iterator.
100   class Iterator {
101    public:
Iterator()102     Iterator() : enums_(nullptr), i_(kValueCount) {}
103     ~Iterator() = default;
104 
105     bool operator==(const Iterator& other) const { return i_ == other.i_; }
106 
107     bool operator!=(const Iterator& other) const { return !(*this == other); }
108 
109     E operator*() const {
110       DCHECK(Good());
111       return FromIndex(i_);
112     }
113 
114     Iterator& operator++() {
115       DCHECK(Good());
116       // If there are no more set elements in the bitset, this will result in an
117       // index equal to kValueCount, which is equivalent to EnumSet.end().
118       i_ = FindNext(i_ + 1);
119 
120       return *this;
121     }
122 
123     Iterator operator++(int) {
124       DCHECK(Good());
125       Iterator old(*this);
126 
127       // If there are no more set elements in the bitset, this will result in an
128       // index equal to kValueCount, which is equivalent to EnumSet.end().
129       i_ = FindNext(i_ + 1);
130 
131       return std::move(old);
132     }
133 
134    private:
135     friend Iterator EnumSet::begin() const;
136 
Iterator(const EnumBitSet & enums)137     explicit Iterator(const EnumBitSet& enums)
138         : enums_(&enums), i_(FindNext(0)) {}
139 
140     // Returns true iff the iterator points to an EnumSet and it
141     // hasn't yet traversed the EnumSet entirely.
Good()142     bool Good() const { return enums_ && i_ < kValueCount && enums_->test(i_); }
143 
FindNext(size_t i)144     size_t FindNext(size_t i) {
145       while ((i < kValueCount) && !enums_->test(i)) {
146         ++i;
147       }
148       return i;
149     }
150 
151     raw_ptr<const EnumBitSet, DanglingUntriaged> enums_;
152     size_t i_;
153   };
154 
155   EnumSet() = default;
156 
157   ~EnumSet() = default;
158 
single_val_bitstring(E val)159   static constexpr uint64_t single_val_bitstring(E val) {
160     const uint64_t bitstring = 1;
161     const size_t shift_amount = ToIndex(val);
162     CHECK_LT(shift_amount, sizeof(bitstring) * 8);
163     return bitstring << shift_amount;
164   }
165 
166   template <class... T>
bitstring(T...values)167   static constexpr uint64_t bitstring(T... values) {
168     uint64_t converted[] = {single_val_bitstring(values)...};
169     uint64_t result = 0;
170     for (uint64_t e : converted)
171       result |= e;
172     return result;
173   }
174 
175   template <class... T>
EnumSet(E head,T...tail)176   constexpr EnumSet(E head, T... tail)
177       : EnumSet(EnumBitSet(bitstring(head, tail...))) {}
178 
179   // Returns an EnumSet with all possible values.
All()180   static constexpr EnumSet All() {
181     return EnumSet(EnumBitSet((1ULL << kValueCount) - 1));
182   }
183 
184   // Returns an EnumSet with all the values from start to end, inclusive.
FromRange(E start,E end)185   static constexpr EnumSet FromRange(E start, E end) {
186     CHECK_LE(start, end);
187     return EnumSet(EnumBitSet(
188         ((single_val_bitstring(end)) - (single_val_bitstring(start))) |
189         (single_val_bitstring(end))));
190   }
191 
192   // Copy constructor and assignment welcome.
193 
194   // Bitmask operations.
195   //
196   // This bitmask is 0-based and the value of the Nth bit depends on whether
197   // the set contains an enum element of integer value N.
198   //
199   // These may only be used if Min >= 0 and Max < 64.
200 
201   // Returns an EnumSet constructed from |bitmask|.
FromEnumBitmask(const uint64_t bitmask)202   static constexpr EnumSet FromEnumBitmask(const uint64_t bitmask) {
203     static_assert(GetUnderlyingValue(kMaxValue) < 64,
204                   "The highest enum value must be < 64 for FromEnumBitmask ");
205     static_assert(GetUnderlyingValue(kMinValue) >= 0,
206                   "The lowest enum value must be >= 0 for FromEnumBitmask ");
207     return EnumSet(EnumBitSet(bitmask >> GetUnderlyingValue(kMinValue)));
208   }
209   // Returns a bitmask for the EnumSet.
ToEnumBitmask()210   uint64_t ToEnumBitmask() const {
211     static_assert(GetUnderlyingValue(kMaxValue) < 64,
212                   "The highest enum value must be < 64 for ToEnumBitmask ");
213     static_assert(GetUnderlyingValue(kMinValue) >= 0,
214                   "The lowest enum value must be >= 0 for FromEnumBitmask ");
215     return enums_.to_ullong() << GetUnderlyingValue(kMinValue);
216   }
217 
218   // Set operations.  Put, Retain, and Remove are basically
219   // self-mutating versions of Union, Intersection, and Difference
220   // (defined below).
221 
222   // Adds the given value (which must be in range) to our set.
Put(E value)223   void Put(E value) { enums_.set(ToIndex(value)); }
224 
225   // Adds all values in the given set to our set.
PutAll(EnumSet other)226   void PutAll(EnumSet other) { enums_ |= other.enums_; }
227 
228   // Adds all values in the given range to our set, inclusive.
PutRange(E start,E end)229   void PutRange(E start, E end) {
230     CHECK_LE(start, end);
231     size_t endIndexInclusive = ToIndex(end);
232     for (size_t current = ToIndex(start); current <= endIndexInclusive;
233          ++current) {
234       enums_.set(current);
235     }
236   }
237 
238   // There's no real need for a Retain(E) member function.
239 
240   // Removes all values not in the given set from our set.
RetainAll(EnumSet other)241   void RetainAll(EnumSet other) { enums_ &= other.enums_; }
242 
243   // If the given value is in range, removes it from our set.
Remove(E value)244   void Remove(E value) {
245     if (InRange(value)) {
246       enums_.reset(ToIndex(value));
247     }
248   }
249 
250   // Removes all values in the given set from our set.
RemoveAll(EnumSet other)251   void RemoveAll(EnumSet other) { enums_ &= ~other.enums_; }
252 
253   // Removes all values from our set.
Clear()254   void Clear() { enums_.reset(); }
255 
256   // Conditionally puts or removes `value`, based on `should_be_present`.
PutOrRemove(E value,bool should_be_present)257   void PutOrRemove(E value, bool should_be_present) {
258     if (should_be_present) {
259       Put(value);
260     } else {
261       Remove(value);
262     }
263   }
264 
265   // Returns true iff the given value is in range and a member of our set.
Has(E value)266   constexpr bool Has(E value) const {
267     return InRange(value) && enums_[ToIndex(value)];
268   }
269 
270   // Returns true iff the given set is a subset of our set.
HasAll(EnumSet other)271   bool HasAll(EnumSet other) const {
272     return (enums_ & other.enums_) == other.enums_;
273   }
274 
275   // Returns true if the given set contains any value of our set.
HasAny(EnumSet other)276   bool HasAny(EnumSet other) const {
277     return (enums_ & other.enums_).count() > 0;
278   }
279 
280   // Returns true iff our set is empty.
Empty()281   bool Empty() const { return !enums_.any(); }
282 
283   // Returns how many values our set has.
Size()284   size_t Size() const { return enums_.count(); }
285 
286   // Returns an iterator pointing to the first element (if any).
begin()287   Iterator begin() const { return Iterator(enums_); }
288 
289   // Returns an iterator that does not point to any element, but to the position
290   // that follows the last element in the set.
end()291   Iterator end() const { return Iterator(); }
292 
293   // Returns true iff our set and the given set contain exactly the same values.
294   bool operator==(const EnumSet& other) const { return enums_ == other.enums_; }
295 
296   // Returns true iff our set and the given set do not contain exactly the same
297   // values.
298   bool operator!=(const EnumSet& other) const { return enums_ != other.enums_; }
299 
300  private:
301   friend EnumSet Union<E, MinEnumValue, MaxEnumValue>(EnumSet set1,
302                                                       EnumSet set2);
303   friend EnumSet Intersection<E, MinEnumValue, MaxEnumValue>(EnumSet set1,
304                                                              EnumSet set2);
305   friend EnumSet Difference<E, MinEnumValue, MaxEnumValue>(EnumSet set1,
306                                                            EnumSet set2);
307 
308   // A bitset can't be constexpr constructed if it has size > 64, since the
309   // constexpr constructor uses a uint64_t. If your EnumSet has > 64 values, you
310   // can safely remove the constepxr qualifiers from this file, at the cost of
311   // some minor optimizations.
EnumSet(EnumBitSet enums)312   explicit constexpr EnumSet(EnumBitSet enums) : enums_(enums) {
313     static_assert(kValueCount <= 64,
314                   "Max number of enum values is 64 for constexpr constructor");
315   }
316 
317   // Converts a value to/from an index into |enums_|.
ToIndex(E value)318   static constexpr size_t ToIndex(E value) {
319     CHECK(InRange(value));
320     return static_cast<size_t>(GetUnderlyingValue(value)) -
321            static_cast<size_t>(GetUnderlyingValue(MinEnumValue));
322   }
323 
FromIndex(size_t i)324   static E FromIndex(size_t i) {
325     DCHECK_LT(i, kValueCount);
326     return static_cast<E>(GetUnderlyingValue(MinEnumValue) + i);
327   }
328 
329   EnumBitSet enums_;
330 };
331 
// Out-of-class definitions of EnumSet's static constants. Their values come
// from the in-class initializers; these namespace-scope definitions are what
// allow the constants to be odr-used (for example, bound to a const
// reference).
template <typename E, E MinEnumValue, E MaxEnumValue>
const E EnumSet<E, MinEnumValue, MaxEnumValue>::kMinValue;

template <typename E, E MinEnumValue, E MaxEnumValue>
const E EnumSet<E, MinEnumValue, MaxEnumValue>::kMaxValue;

template <typename E, E MinEnumValue, E MaxEnumValue>
const size_t EnumSet<E, MinEnumValue, MaxEnumValue>::kValueCount;
340 
341 // The usual set operations.
342 
343 template <typename E, E Min, E Max>
Union(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)344 EnumSet<E, Min, Max> Union(EnumSet<E, Min, Max> set1,
345                            EnumSet<E, Min, Max> set2) {
346   return EnumSet<E, Min, Max>(set1.enums_ | set2.enums_);
347 }
348 
349 template <typename E, E Min, E Max>
Intersection(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)350 EnumSet<E, Min, Max> Intersection(EnumSet<E, Min, Max> set1,
351                                   EnumSet<E, Min, Max> set2) {
352   return EnumSet<E, Min, Max>(set1.enums_ & set2.enums_);
353 }
354 
355 template <typename E, E Min, E Max>
Difference(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)356 EnumSet<E, Min, Max> Difference(EnumSet<E, Min, Max> set1,
357                                 EnumSet<E, Min, Max> set2) {
358   return EnumSet<E, Min, Max>(set1.enums_ & ~set2.enums_);
359 }
360 
361 }  // namespace base
362 
363 #endif  // BASE_CONTAINERS_ENUM_SET_H_
364