1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifndef BASE_CONTAINERS_ENUM_SET_H_
6 #define BASE_CONTAINERS_ENUM_SET_H_
7
#include <bitset>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <type_traits>
#include <utility>
12
13 #include "base/check.h"
14 #include "base/check_op.h"
15 #include "base/memory/raw_ptr.h"
16
17 namespace base {
18
// EnumSet declares the set-operation templates below as friends. They (and the
// class itself) must be declared up front so that those friend declarations
// inside the class body can refer to these exact templates.
template <typename E, E MinEnumValue, E MaxEnumValue>
class EnumSet;

template <typename Enum, Enum Lo, Enum Hi>
EnumSet<Enum, Lo, Hi> Union(EnumSet<Enum, Lo, Hi> lhs,
                            EnumSet<Enum, Lo, Hi> rhs);

template <typename Enum, Enum Lo, Enum Hi>
EnumSet<Enum, Lo, Hi> Intersection(EnumSet<Enum, Lo, Hi> lhs,
                                   EnumSet<Enum, Lo, Hi> rhs);

template <typename Enum, Enum Lo, Enum Hi>
EnumSet<Enum, Lo, Hi> Difference(EnumSet<Enum, Lo, Hi> lhs,
                                 EnumSet<Enum, Lo, Hi> rhs);
34
35 // An EnumSet is a set that can hold enum values between a min and a
36 // max value (inclusive of both). It's essentially a wrapper around
37 // std::bitset<> with stronger type enforcement, more descriptive
38 // member function names, and an iterator interface.
39 //
40 // If you're working with enums with a small number of possible values
// (say, fewer than 64), you can efficiently pass an EnumSet for that
// enum around by value.
43
44 template <typename E, E MinEnumValue, E MaxEnumValue>
45 class EnumSet {
46 private:
47 static_assert(
48 std::is_enum<E>::value,
49 "First template parameter of EnumSet must be an enumeration type");
50 using enum_underlying_type = std::underlying_type_t<E>;
51
InRange(E value)52 static constexpr bool InRange(E value) {
53 return (value >= MinEnumValue) && (value <= MaxEnumValue);
54 }
55
GetUnderlyingValue(E value)56 static constexpr enum_underlying_type GetUnderlyingValue(E value) {
57 return static_cast<enum_underlying_type>(value);
58 }
59
60 public:
61 using EnumType = E;
62 static const E kMinValue = MinEnumValue;
63 static const E kMaxValue = MaxEnumValue;
64 static const size_t kValueCount =
65 GetUnderlyingValue(kMaxValue) - GetUnderlyingValue(kMinValue) + 1;
66
67 static_assert(kMinValue < kMaxValue, "min value must be less than max value");
68
69 private:
70 // Declaration needed by Iterator.
71 using EnumBitSet = std::bitset<kValueCount>;
72
73 public:
74 // Iterator is a forward-only read-only iterator for EnumSet. It follows the
75 // common STL input iterator interface (like std::unordered_set).
76 //
77 // Example usage, using a range-based for loop:
78 //
79 // EnumSet<SomeType> enums;
80 // for (SomeType val : enums) {
81 // Process(val);
82 // }
83 //
84 // Or using an explicit iterator (not recommended):
85 //
86 // for (EnumSet<...>::Iterator it = enums.begin(); it != enums.end(); it++) {
87 // Process(*it);
88 // }
89 //
90 // The iterator must not be outlived by the set. In particular, the following
91 // is an error:
92 //
93 // EnumSet<...> SomeFn() { ... }
94 //
95 // /* ERROR */
96 // for (EnumSet<...>::Iterator it = SomeFun().begin(); ...
97 //
98 // Also, there are no guarantees as to what will happen if you
99 // modify an EnumSet while traversing it with an iterator.
100 class Iterator {
101 public:
Iterator()102 Iterator() : enums_(nullptr), i_(kValueCount) {}
103 ~Iterator() = default;
104
105 bool operator==(const Iterator& other) const { return i_ == other.i_; }
106
107 bool operator!=(const Iterator& other) const { return !(*this == other); }
108
109 E operator*() const {
110 DCHECK(Good());
111 return FromIndex(i_);
112 }
113
114 Iterator& operator++() {
115 DCHECK(Good());
116 // If there are no more set elements in the bitset, this will result in an
117 // index equal to kValueCount, which is equivalent to EnumSet.end().
118 i_ = FindNext(i_ + 1);
119
120 return *this;
121 }
122
123 Iterator operator++(int) {
124 DCHECK(Good());
125 Iterator old(*this);
126
127 // If there are no more set elements in the bitset, this will result in an
128 // index equal to kValueCount, which is equivalent to EnumSet.end().
129 i_ = FindNext(i_ + 1);
130
131 return std::move(old);
132 }
133
134 private:
135 friend Iterator EnumSet::begin() const;
136
Iterator(const EnumBitSet & enums)137 explicit Iterator(const EnumBitSet& enums)
138 : enums_(&enums), i_(FindNext(0)) {}
139
140 // Returns true iff the iterator points to an EnumSet and it
141 // hasn't yet traversed the EnumSet entirely.
Good()142 bool Good() const { return enums_ && i_ < kValueCount && enums_->test(i_); }
143
FindNext(size_t i)144 size_t FindNext(size_t i) {
145 while ((i < kValueCount) && !enums_->test(i)) {
146 ++i;
147 }
148 return i;
149 }
150
151 raw_ptr<const EnumBitSet, DanglingUntriaged> enums_;
152 size_t i_;
153 };
154
155 EnumSet() = default;
156
157 ~EnumSet() = default;
158
single_val_bitstring(E val)159 static constexpr uint64_t single_val_bitstring(E val) {
160 const uint64_t bitstring = 1;
161 const size_t shift_amount = ToIndex(val);
162 CHECK_LT(shift_amount, sizeof(bitstring) * 8);
163 return bitstring << shift_amount;
164 }
165
166 template <class... T>
bitstring(T...values)167 static constexpr uint64_t bitstring(T... values) {
168 uint64_t converted[] = {single_val_bitstring(values)...};
169 uint64_t result = 0;
170 for (uint64_t e : converted)
171 result |= e;
172 return result;
173 }
174
175 template <class... T>
EnumSet(E head,T...tail)176 constexpr EnumSet(E head, T... tail)
177 : EnumSet(EnumBitSet(bitstring(head, tail...))) {}
178
179 // Returns an EnumSet with all possible values.
All()180 static constexpr EnumSet All() {
181 return EnumSet(EnumBitSet((1ULL << kValueCount) - 1));
182 }
183
184 // Returns an EnumSet with all the values from start to end, inclusive.
FromRange(E start,E end)185 static constexpr EnumSet FromRange(E start, E end) {
186 CHECK_LE(start, end);
187 return EnumSet(EnumBitSet(
188 ((single_val_bitstring(end)) - (single_val_bitstring(start))) |
189 (single_val_bitstring(end))));
190 }
191
192 // Copy constructor and assignment welcome.
193
194 // Bitmask operations.
195 //
196 // This bitmask is 0-based and the value of the Nth bit depends on whether
197 // the set contains an enum element of integer value N.
198 //
199 // These may only be used if Min >= 0 and Max < 64.
200
201 // Returns an EnumSet constructed from |bitmask|.
FromEnumBitmask(const uint64_t bitmask)202 static constexpr EnumSet FromEnumBitmask(const uint64_t bitmask) {
203 static_assert(GetUnderlyingValue(kMaxValue) < 64,
204 "The highest enum value must be < 64 for FromEnumBitmask ");
205 static_assert(GetUnderlyingValue(kMinValue) >= 0,
206 "The lowest enum value must be >= 0 for FromEnumBitmask ");
207 return EnumSet(EnumBitSet(bitmask >> GetUnderlyingValue(kMinValue)));
208 }
209 // Returns a bitmask for the EnumSet.
ToEnumBitmask()210 uint64_t ToEnumBitmask() const {
211 static_assert(GetUnderlyingValue(kMaxValue) < 64,
212 "The highest enum value must be < 64 for ToEnumBitmask ");
213 static_assert(GetUnderlyingValue(kMinValue) >= 0,
214 "The lowest enum value must be >= 0 for FromEnumBitmask ");
215 return enums_.to_ullong() << GetUnderlyingValue(kMinValue);
216 }
217
218 // Set operations. Put, Retain, and Remove are basically
219 // self-mutating versions of Union, Intersection, and Difference
220 // (defined below).
221
222 // Adds the given value (which must be in range) to our set.
Put(E value)223 void Put(E value) { enums_.set(ToIndex(value)); }
224
225 // Adds all values in the given set to our set.
PutAll(EnumSet other)226 void PutAll(EnumSet other) { enums_ |= other.enums_; }
227
228 // Adds all values in the given range to our set, inclusive.
PutRange(E start,E end)229 void PutRange(E start, E end) {
230 CHECK_LE(start, end);
231 size_t endIndexInclusive = ToIndex(end);
232 for (size_t current = ToIndex(start); current <= endIndexInclusive;
233 ++current) {
234 enums_.set(current);
235 }
236 }
237
238 // There's no real need for a Retain(E) member function.
239
240 // Removes all values not in the given set from our set.
RetainAll(EnumSet other)241 void RetainAll(EnumSet other) { enums_ &= other.enums_; }
242
243 // If the given value is in range, removes it from our set.
Remove(E value)244 void Remove(E value) {
245 if (InRange(value)) {
246 enums_.reset(ToIndex(value));
247 }
248 }
249
250 // Removes all values in the given set from our set.
RemoveAll(EnumSet other)251 void RemoveAll(EnumSet other) { enums_ &= ~other.enums_; }
252
253 // Removes all values from our set.
Clear()254 void Clear() { enums_.reset(); }
255
256 // Conditionally puts or removes `value`, based on `should_be_present`.
PutOrRemove(E value,bool should_be_present)257 void PutOrRemove(E value, bool should_be_present) {
258 if (should_be_present) {
259 Put(value);
260 } else {
261 Remove(value);
262 }
263 }
264
265 // Returns true iff the given value is in range and a member of our set.
Has(E value)266 constexpr bool Has(E value) const {
267 return InRange(value) && enums_[ToIndex(value)];
268 }
269
270 // Returns true iff the given set is a subset of our set.
HasAll(EnumSet other)271 bool HasAll(EnumSet other) const {
272 return (enums_ & other.enums_) == other.enums_;
273 }
274
275 // Returns true if the given set contains any value of our set.
HasAny(EnumSet other)276 bool HasAny(EnumSet other) const {
277 return (enums_ & other.enums_).count() > 0;
278 }
279
280 // Returns true iff our set is empty.
Empty()281 bool Empty() const { return !enums_.any(); }
282
283 // Returns how many values our set has.
Size()284 size_t Size() const { return enums_.count(); }
285
286 // Returns an iterator pointing to the first element (if any).
begin()287 Iterator begin() const { return Iterator(enums_); }
288
289 // Returns an iterator that does not point to any element, but to the position
290 // that follows the last element in the set.
end()291 Iterator end() const { return Iterator(); }
292
293 // Returns true iff our set and the given set contain exactly the same values.
294 bool operator==(const EnumSet& other) const { return enums_ == other.enums_; }
295
296 // Returns true iff our set and the given set do not contain exactly the same
297 // values.
298 bool operator!=(const EnumSet& other) const { return enums_ != other.enums_; }
299
300 private:
301 friend EnumSet Union<E, MinEnumValue, MaxEnumValue>(EnumSet set1,
302 EnumSet set2);
303 friend EnumSet Intersection<E, MinEnumValue, MaxEnumValue>(EnumSet set1,
304 EnumSet set2);
305 friend EnumSet Difference<E, MinEnumValue, MaxEnumValue>(EnumSet set1,
306 EnumSet set2);
307
308 // A bitset can't be constexpr constructed if it has size > 64, since the
309 // constexpr constructor uses a uint64_t. If your EnumSet has > 64 values, you
310 // can safely remove the constepxr qualifiers from this file, at the cost of
311 // some minor optimizations.
EnumSet(EnumBitSet enums)312 explicit constexpr EnumSet(EnumBitSet enums) : enums_(enums) {
313 static_assert(kValueCount <= 64,
314 "Max number of enum values is 64 for constexpr constructor");
315 }
316
317 // Converts a value to/from an index into |enums_|.
ToIndex(E value)318 static constexpr size_t ToIndex(E value) {
319 CHECK(InRange(value));
320 return static_cast<size_t>(GetUnderlyingValue(value)) -
321 static_cast<size_t>(GetUnderlyingValue(MinEnumValue));
322 }
323
FromIndex(size_t i)324 static E FromIndex(size_t i) {
325 DCHECK_LT(i, kValueCount);
326 return static_cast<E>(GetUnderlyingValue(MinEnumValue) + i);
327 }
328
329 EnumBitSet enums_;
330 };
331
332 template <typename E, E MinEnumValue, E MaxEnumValue>
333 const E EnumSet<E, MinEnumValue, MaxEnumValue>::kMinValue;
334
335 template <typename E, E MinEnumValue, E MaxEnumValue>
336 const E EnumSet<E, MinEnumValue, MaxEnumValue>::kMaxValue;
337
338 template <typename E, E MinEnumValue, E MaxEnumValue>
339 const size_t EnumSet<E, MinEnumValue, MaxEnumValue>::kValueCount;
340
341 // The usual set operations.
342
343 template <typename E, E Min, E Max>
Union(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)344 EnumSet<E, Min, Max> Union(EnumSet<E, Min, Max> set1,
345 EnumSet<E, Min, Max> set2) {
346 return EnumSet<E, Min, Max>(set1.enums_ | set2.enums_);
347 }
348
349 template <typename E, E Min, E Max>
Intersection(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)350 EnumSet<E, Min, Max> Intersection(EnumSet<E, Min, Max> set1,
351 EnumSet<E, Min, Max> set2) {
352 return EnumSet<E, Min, Max>(set1.enums_ & set2.enums_);
353 }
354
355 template <typename E, E Min, E Max>
Difference(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)356 EnumSet<E, Min, Max> Difference(EnumSet<E, Min, Max> set1,
357 EnumSet<E, Min, Max> set2) {
358 return EnumSet<E, Min, Max>(set1.enums_ & ~set2.enums_);
359 }
360
361 } // namespace base
362
363 #endif // BASE_CONTAINERS_ENUM_SET_H_
364