1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifndef BASE_CONTAINERS_ENUM_SET_H_
6 #define BASE_CONTAINERS_ENUM_SET_H_
7
8 #include <bitset>
9 #include <cstddef>
10 #include <initializer_list>
11 #include <optional>
12 #include <string>
13 #include <type_traits>
14 #include <utility>
15
16 #include "base/check.h"
17 #include "base/check_op.h"
18 #include "base/memory/raw_ptr.h"
19 #include "build/build_config.h"
20
21 namespace base {
22
// EnumSet and its free-function set operations are declared ahead of the
// class definition so that EnumSet can befriend the matching instantiations
// of Union, Intersection and Difference.
template <typename Enum, Enum LowValue, Enum HighValue>
class EnumSet;

template <typename Enum, Enum Low, Enum High>
constexpr EnumSet<Enum, Low, High> Union(EnumSet<Enum, Low, High> set1,
                                         EnumSet<Enum, Low, High> set2);

template <typename Enum, Enum Low, Enum High>
constexpr EnumSet<Enum, Low, High> Intersection(EnumSet<Enum, Low, High> set1,
                                                EnumSet<Enum, Low, High> set2);

template <typename Enum, Enum Low, Enum High>
constexpr EnumSet<Enum, Low, High> Difference(EnumSet<Enum, Low, High> set1,
                                              EnumSet<Enum, Low, High> set2);
38
// An EnumSet is a set that can hold enum values between a min and a
// max value (inclusive of both). It's essentially a wrapper around
// std::bitset<> with stronger type enforcement, more descriptive
// member function names, and an iterator interface.
//
// If you're working with enums with a small number of possible values
// (say, fewer than 64), you can efficiently pass an EnumSet for that
// enum around by value.
47
48 template <typename E, E MinEnumValue, E MaxEnumValue>
49 class EnumSet {
50 private:
51 static_assert(
52 std::is_enum_v<E>,
53 "First template parameter of EnumSet must be an enumeration type");
54 using enum_underlying_type = std::underlying_type_t<E>;
55
InRange(E value)56 static constexpr bool InRange(E value) {
57 return (value >= MinEnumValue) && (value <= MaxEnumValue);
58 }
59
GetUnderlyingValue(E value)60 static constexpr enum_underlying_type GetUnderlyingValue(E value) {
61 return static_cast<enum_underlying_type>(value);
62 }
63
64 public:
65 using EnumType = E;
66 static const E kMinValue = MinEnumValue;
67 static const E kMaxValue = MaxEnumValue;
68 static const size_t kValueCount =
69 GetUnderlyingValue(kMaxValue) - GetUnderlyingValue(kMinValue) + 1;
70
71 static_assert(kMinValue <= kMaxValue,
72 "min value must be no greater than max value");
73
74 private:
75 // Declaration needed by Iterator.
76 using EnumBitSet = std::bitset<kValueCount>;
77
78 public:
79 // Iterator is a forward-only read-only iterator for EnumSet. It follows the
80 // common STL input iterator interface (like std::unordered_set).
81 //
82 // Example usage, using a range-based for loop:
83 //
84 // EnumSet<SomeType> enums;
85 // for (SomeType val : enums) {
86 // Process(val);
87 // }
88 //
89 // Or using an explicit iterator (not recommended):
90 //
91 // for (EnumSet<...>::Iterator it = enums.begin(); it != enums.end(); it++) {
92 // Process(*it);
93 // }
94 //
95 // The iterator must not be outlived by the set. In particular, the following
96 // is an error:
97 //
98 // EnumSet<...> SomeFn() { ... }
99 //
100 // /* ERROR */
101 // for (EnumSet<...>::Iterator it = SomeFun().begin(); ...
102 //
103 // Also, there are no guarantees as to what will happen if you
104 // modify an EnumSet while traversing it with an iterator.
105 class Iterator {
106 public:
107 using value_type = EnumType;
108 using size_type = size_t;
109 using difference_type = ptrdiff_t;
110 using pointer = EnumType*;
111 using reference = EnumType&;
112 using iterator_category = std::forward_iterator_tag;
113
Iterator()114 Iterator() : enums_(nullptr), i_(kValueCount) {}
115 ~Iterator() = default;
116
117 Iterator(const Iterator&) = default;
118 Iterator& operator=(const Iterator&) = default;
119
120 Iterator(Iterator&&) = default;
121 Iterator& operator=(Iterator&&) = default;
122
123 friend bool operator==(const Iterator& lhs, const Iterator& rhs) {
124 return lhs.i_ == rhs.i_;
125 }
126
127 value_type operator*() const {
128 DCHECK(Good());
129 return FromIndex(i_);
130 }
131
132 Iterator& operator++() {
133 DCHECK(Good());
134 // If there are no more set elements in the bitset, this will result in an
135 // index equal to kValueCount, which is equivalent to EnumSet.end().
136 i_ = FindNext(i_ + 1);
137
138 return *this;
139 }
140
141 Iterator operator++(int) {
142 DCHECK(Good());
143 Iterator old(*this);
144
145 // If there are no more set elements in the bitset, this will result in an
146 // index equal to kValueCount, which is equivalent to EnumSet.end().
147 i_ = FindNext(i_ + 1);
148
149 return std::move(old);
150 }
151
152 private:
153 friend Iterator EnumSet::begin() const;
154
Iterator(const EnumBitSet & enums)155 explicit Iterator(const EnumBitSet& enums)
156 : enums_(&enums), i_(FindNext(0)) {}
157
158 // Returns true iff the iterator points to an EnumSet and it
159 // hasn't yet traversed the EnumSet entirely.
Good()160 bool Good() const { return enums_ && i_ < kValueCount && enums_->test(i_); }
161
FindNext(size_t i)162 size_t FindNext(size_t i) {
163 while ((i < kValueCount) && !enums_->test(i)) {
164 ++i;
165 }
166 return i;
167 }
168
169 raw_ptr<const EnumBitSet> enums_;
170 size_t i_;
171 };
172
173 EnumSet() = default;
174
175 ~EnumSet() = default;
176
EnumSet(std::initializer_list<E> values)177 constexpr EnumSet(std::initializer_list<E> values) {
178 if (std::is_constant_evaluated()) {
179 enums_ = bitstring(values);
180 } else {
181 for (E value : values) {
182 Put(value);
183 }
184 }
185 }
186
187 // Returns an EnumSet with all values between kMinValue and kMaxValue, which
188 // also contains undefined enum values if the enum in question has gaps
189 // between kMinValue and kMaxValue.
All()190 static constexpr EnumSet All() {
191 if (std::is_constant_evaluated()) {
192 if (kValueCount == 0) {
193 return EnumSet();
194 }
195 // Since `1 << kValueCount` may trigger shift-count-overflow warning if
196 // the `kValueCount` is 64, instead of returning `(1 << kValueCount) - 1`,
197 // the bitmask will be constructed from two parts: the most significant
198 // bits and the remaining.
199 uint64_t mask = 1ULL << (kValueCount - 1);
200 return EnumSet(EnumBitSet(mask - 1 + mask));
201 } else {
202 // When `kValueCount` is greater than 64, we can't use the constexpr path,
203 // and we will build an `EnumSet` value by value.
204 EnumSet enum_set;
205 for (size_t value = 0; value < kValueCount; ++value) {
206 enum_set.Put(FromIndex(value));
207 }
208 return enum_set;
209 }
210 }
211
212 // Returns an EnumSet with all the values from start to end, inclusive.
FromRange(E start,E end)213 static constexpr EnumSet FromRange(E start, E end) {
214 CHECK_LE(start, end);
215 return EnumSet(EnumBitSet(
216 ((single_val_bitstring(end)) - (single_val_bitstring(start))) |
217 (single_val_bitstring(end))));
218 }
219
220 // Copy constructor and assignment welcome.
221
222 // Bitmask operations.
223 //
224 // This bitmask is 0-based and the value of the Nth bit depends on whether
225 // the set contains an enum element of integer value N.
226 //
227 // These may only be used if Min >= 0 and Max < 64.
228
229 // Returns an EnumSet constructed from |bitmask|.
FromEnumBitmask(const uint64_t bitmask)230 static constexpr EnumSet FromEnumBitmask(const uint64_t bitmask) {
231 static_assert(GetUnderlyingValue(kMaxValue) < 64,
232 "The highest enum value must be < 64 for FromEnumBitmask ");
233 static_assert(GetUnderlyingValue(kMinValue) >= 0,
234 "The lowest enum value must be >= 0 for FromEnumBitmask ");
235 return EnumSet(EnumBitSet(bitmask >> GetUnderlyingValue(kMinValue)));
236 }
237 // Returns a bitmask for the EnumSet.
ToEnumBitmask()238 uint64_t ToEnumBitmask() const {
239 static_assert(GetUnderlyingValue(kMaxValue) < 64,
240 "The highest enum value must be < 64 for ToEnumBitmask ");
241 static_assert(GetUnderlyingValue(kMinValue) >= 0,
242 "The lowest enum value must be >= 0 for FromEnumBitmask ");
243 return enums_.to_ullong() << GetUnderlyingValue(kMinValue);
244 }
245
246 // Returns a uint64_t bit mask representing the values within the range
247 // [64*n, 64*n + 63] of the EnumSet.
GetNth64bitWordBitmask(size_t n)248 std::optional<uint64_t> GetNth64bitWordBitmask(size_t n) const {
249 // If the EnumSet contains less than n 64-bit masks, return std::nullopt.
250 if (GetUnderlyingValue(kMaxValue) / 64 < n) {
251 return std::nullopt;
252 }
253
254 std::bitset<kValueCount> mask = ~uint64_t{0};
255 std::bitset<kValueCount> bits = enums_;
256 if (GetUnderlyingValue(kMinValue) < n * 64) {
257 bits >>= n * 64 - GetUnderlyingValue(kMinValue);
258 }
259 uint64_t result = (bits & mask).to_ullong();
260 if (GetUnderlyingValue(kMinValue) > n * 64) {
261 result <<= GetUnderlyingValue(kMinValue) - n * 64;
262 }
263 return result;
264 }
265
266 // Set operations. Put, Retain, and Remove are basically
267 // self-mutating versions of Union, Intersection, and Difference
268 // (defined below).
269
270 // Adds the given value (which must be in range) to our set.
Put(E value)271 void Put(E value) { enums_.set(ToIndex(value)); }
272
273 // Adds all values in the given set to our set.
PutAll(EnumSet other)274 void PutAll(EnumSet other) { enums_ |= other.enums_; }
275
276 // Adds all values in the given range to our set, inclusive.
PutRange(E start,E end)277 void PutRange(E start, E end) {
278 CHECK_LE(start, end);
279 size_t endIndexInclusive = ToIndex(end);
280 for (size_t current = ToIndex(start); current <= endIndexInclusive;
281 ++current) {
282 enums_.set(current);
283 }
284 }
285
286 // There's no real need for a Retain(E) member function.
287
288 // Removes all values not in the given set from our set.
RetainAll(EnumSet other)289 void RetainAll(EnumSet other) { enums_ &= other.enums_; }
290
291 // If the given value is in range, removes it from our set.
Remove(E value)292 void Remove(E value) {
293 if (InRange(value)) {
294 enums_.reset(ToIndex(value));
295 }
296 }
297
298 // Removes all values in the given set from our set.
RemoveAll(EnumSet other)299 void RemoveAll(EnumSet other) { enums_ &= ~other.enums_; }
300
301 // Removes all values from our set.
Clear()302 void Clear() { enums_.reset(); }
303
304 // Conditionally puts or removes `value`, based on `should_be_present`.
PutOrRemove(E value,bool should_be_present)305 void PutOrRemove(E value, bool should_be_present) {
306 if (should_be_present) {
307 Put(value);
308 } else {
309 Remove(value);
310 }
311 }
312
313 // Returns true iff the given value is in range and a member of our set.
Has(E value)314 constexpr bool Has(E value) const {
315 return InRange(value) && enums_[ToIndex(value)];
316 }
317
318 // Returns true iff the given set is a subset of our set.
HasAll(EnumSet other)319 bool HasAll(EnumSet other) const {
320 return (enums_ & other.enums_) == other.enums_;
321 }
322
323 // Returns true if the given set contains any value of our set.
HasAny(EnumSet other)324 bool HasAny(EnumSet other) const {
325 return (enums_ & other.enums_).count() > 0;
326 }
327
328 // Returns true iff our set is empty.
empty()329 bool empty() const { return !enums_.any(); }
330
331 // Returns how many values our set has.
size()332 size_t size() const { return enums_.count(); }
333
334 // Returns an iterator pointing to the first element (if any).
begin()335 Iterator begin() const { return Iterator(enums_); }
336
337 // Returns an iterator that does not point to any element, but to the position
338 // that follows the last element in the set.
end()339 Iterator end() const { return Iterator(); }
340
341 // Returns true iff our set and the given set contain exactly the same values.
342 friend bool operator==(const EnumSet&, const EnumSet&) = default;
343
ToString()344 std::string ToString() const { return enums_.to_string(); }
345
346 private:
347 friend constexpr EnumSet Union<E, MinEnumValue, MaxEnumValue>(EnumSet set1,
348 EnumSet set2);
349 friend constexpr EnumSet Intersection<E, MinEnumValue, MaxEnumValue>(
350 EnumSet set1,
351 EnumSet set2);
352 friend constexpr EnumSet Difference<E, MinEnumValue, MaxEnumValue>(
353 EnumSet set1,
354 EnumSet set2);
355
bitstring(const std::initializer_list<E> & values)356 static constexpr uint64_t bitstring(const std::initializer_list<E>& values) {
357 uint64_t result = 0;
358 for (E value : values) {
359 result |= single_val_bitstring(value);
360 }
361 return result;
362 }
363
single_val_bitstring(E val)364 static constexpr uint64_t single_val_bitstring(E val) {
365 const uint64_t bitstring = 1;
366 const size_t shift_amount = ToIndex(val);
367 CHECK_LT(shift_amount, sizeof(bitstring) * 8);
368 return bitstring << shift_amount;
369 }
370
371 // A bitset can't be constexpr constructed if it has size > 64, since the
372 // constexpr constructor uses a uint64_t. If your EnumSet has > 64 values, you
373 // can safely remove the constepxr qualifiers from this file, at the cost of
374 // some minor optimizations.
EnumSet(EnumBitSet enums)375 explicit constexpr EnumSet(EnumBitSet enums) : enums_(enums) {
376 if (std::is_constant_evaluated()) {
377 CHECK(kValueCount <= 64)
378 << "Max number of enum values is 64 for constexpr constructor";
379 }
380 }
381
382 // Converts a value to/from an index into |enums_|.
ToIndex(E value)383 static constexpr size_t ToIndex(E value) {
384 CHECK(InRange(value));
385 return static_cast<size_t>(GetUnderlyingValue(value)) -
386 static_cast<size_t>(GetUnderlyingValue(MinEnumValue));
387 }
388
FromIndex(size_t i)389 static E FromIndex(size_t i) {
390 DCHECK_LT(i, kValueCount);
391 return static_cast<E>(GetUnderlyingValue(MinEnumValue) + i);
392 }
393
394 EnumBitSet enums_;
395 };
396
397 template <typename E, E MinEnumValue, E MaxEnumValue>
398 const E EnumSet<E, MinEnumValue, MaxEnumValue>::kMinValue;
399
400 template <typename E, E MinEnumValue, E MaxEnumValue>
401 const E EnumSet<E, MinEnumValue, MaxEnumValue>::kMaxValue;
402
403 template <typename E, E MinEnumValue, E MaxEnumValue>
404 const size_t EnumSet<E, MinEnumValue, MaxEnumValue>::kValueCount;
405
406 // The usual set operations.
407
408 template <typename E, E Min, E Max>
Union(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)409 constexpr EnumSet<E, Min, Max> Union(EnumSet<E, Min, Max> set1,
410 EnumSet<E, Min, Max> set2) {
411 return EnumSet<E, Min, Max>(set1.enums_ | set2.enums_);
412 }
413
414 template <typename E, E Min, E Max>
Intersection(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)415 constexpr EnumSet<E, Min, Max> Intersection(EnumSet<E, Min, Max> set1,
416 EnumSet<E, Min, Max> set2) {
417 return EnumSet<E, Min, Max>(set1.enums_ & set2.enums_);
418 }
419
420 template <typename E, E Min, E Max>
Difference(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)421 constexpr EnumSet<E, Min, Max> Difference(EnumSet<E, Min, Max> set1,
422 EnumSet<E, Min, Max> set2) {
423 return EnumSet<E, Min, Max>(set1.enums_ & ~set2.enums_);
424 }
425
426 } // namespace base
427
428 #endif // BASE_CONTAINERS_ENUM_SET_H_
429