1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifndef BASE_CONTAINERS_ENUM_SET_H_
6 #define BASE_CONTAINERS_ENUM_SET_H_
7
#include <bitset>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <type_traits>
#include <utility>

#include "base/check.h"
#include "base/check_op.h"
#include "base/memory/raw_ptr.h"
#include "build/build_config.h"
18
19 namespace base {
20
// The class template and the three free set operations are declared ahead of
// the class definition so that EnumSet can name Union(), Intersection() and
// Difference() as friends; those functions read the private bitset directly.
template <typename E, E MinEnumValue, E MaxEnumValue>
class EnumSet;

// Set union: values present in either argument.
template <typename E, E Min, E Max>
constexpr EnumSet<E, Min, Max> Union(EnumSet<E, Min, Max> set1,
                                     EnumSet<E, Min, Max> set2);

// Set intersection: values present in both arguments.
template <typename E, E Min, E Max>
constexpr EnumSet<E, Min, Max> Intersection(EnumSet<E, Min, Max> set1,
                                            EnumSet<E, Min, Max> set2);

// Set difference: values present in the first argument but not the second.
template <typename E, E Min, E Max>
constexpr EnumSet<E, Min, Max> Difference(EnumSet<E, Min, Max> set1,
                                          EnumSet<E, Min, Max> set2);
36
37 // An EnumSet is a set that can hold enum values between a min and a
38 // max value (inclusive of both). It's essentially a wrapper around
39 // std::bitset<> with stronger type enforcement, more descriptive
40 // member function names, and an iterator interface.
41 //
42 // If you're working with enums with a small number of possible values
43 // (say, fewer than 64), you can efficiently pass around an EnumSet
44 // for that enum around by value.
45
46 template <typename E, E MinEnumValue, E MaxEnumValue>
47 class EnumSet {
48 private:
49 static_assert(
50 std::is_enum_v<E>,
51 "First template parameter of EnumSet must be an enumeration type");
52 using enum_underlying_type = std::underlying_type_t<E>;
53
InRange(E value)54 static constexpr bool InRange(E value) {
55 return (value >= MinEnumValue) && (value <= MaxEnumValue);
56 }
57
GetUnderlyingValue(E value)58 static constexpr enum_underlying_type GetUnderlyingValue(E value) {
59 return static_cast<enum_underlying_type>(value);
60 }
61
62 public:
63 using EnumType = E;
64 static const E kMinValue = MinEnumValue;
65 static const E kMaxValue = MaxEnumValue;
66 static const size_t kValueCount =
67 GetUnderlyingValue(kMaxValue) - GetUnderlyingValue(kMinValue) + 1;
68
69 static_assert(kMinValue <= kMaxValue,
70 "min value must be no greater than max value");
71
72 private:
73 // Declaration needed by Iterator.
74 using EnumBitSet = std::bitset<kValueCount>;
75
76 public:
77 // Iterator is a forward-only read-only iterator for EnumSet. It follows the
78 // common STL input iterator interface (like std::unordered_set).
79 //
80 // Example usage, using a range-based for loop:
81 //
82 // EnumSet<SomeType> enums;
83 // for (SomeType val : enums) {
84 // Process(val);
85 // }
86 //
87 // Or using an explicit iterator (not recommended):
88 //
89 // for (EnumSet<...>::Iterator it = enums.begin(); it != enums.end(); it++) {
90 // Process(*it);
91 // }
92 //
93 // The iterator must not be outlived by the set. In particular, the following
94 // is an error:
95 //
96 // EnumSet<...> SomeFn() { ... }
97 //
98 // /* ERROR */
99 // for (EnumSet<...>::Iterator it = SomeFun().begin(); ...
100 //
101 // Also, there are no guarantees as to what will happen if you
102 // modify an EnumSet while traversing it with an iterator.
103 class Iterator {
104 public:
Iterator()105 Iterator() : enums_(nullptr), i_(kValueCount) {}
106 ~Iterator() = default;
107
108 friend bool operator==(const Iterator& lhs, const Iterator& rhs) {
109 return lhs.i_ == rhs.i_;
110 }
111
112 E operator*() const {
113 DCHECK(Good());
114 return FromIndex(i_);
115 }
116
117 Iterator& operator++() {
118 DCHECK(Good());
119 // If there are no more set elements in the bitset, this will result in an
120 // index equal to kValueCount, which is equivalent to EnumSet.end().
121 i_ = FindNext(i_ + 1);
122
123 return *this;
124 }
125
126 Iterator operator++(int) {
127 DCHECK(Good());
128 Iterator old(*this);
129
130 // If there are no more set elements in the bitset, this will result in an
131 // index equal to kValueCount, which is equivalent to EnumSet.end().
132 i_ = FindNext(i_ + 1);
133
134 return std::move(old);
135 }
136
137 private:
138 friend Iterator EnumSet::begin() const;
139
Iterator(const EnumBitSet & enums)140 explicit Iterator(const EnumBitSet& enums)
141 : enums_(&enums), i_(FindNext(0)) {}
142
143 // Returns true iff the iterator points to an EnumSet and it
144 // hasn't yet traversed the EnumSet entirely.
Good()145 bool Good() const { return enums_ && i_ < kValueCount && enums_->test(i_); }
146
FindNext(size_t i)147 size_t FindNext(size_t i) {
148 while ((i < kValueCount) && !enums_->test(i)) {
149 ++i;
150 }
151 return i;
152 }
153
154 const raw_ptr<const EnumBitSet> enums_;
155 size_t i_;
156 };
157
158 EnumSet() = default;
159
160 ~EnumSet() = default;
161
EnumSet(std::initializer_list<E> values)162 constexpr EnumSet(std::initializer_list<E> values) {
163 if (std::is_constant_evaluated()) {
164 enums_ = bitstring(values);
165 } else {
166 for (E value : values) {
167 Put(value);
168 }
169 }
170 }
171
172 // Returns an EnumSet with all values between kMinValue and kMaxValue, which
173 // also contains undefined enum values if the enum in question has gaps
174 // between kMinValue and kMaxValue.
All()175 static constexpr EnumSet All() {
176 if (std::is_constant_evaluated()) {
177 if (kValueCount == 0) {
178 return EnumSet();
179 }
180 // Since `1 << kValueCount` may trigger shift-count-overflow warning if
181 // the `kValueCount` is 64, instead of returning `(1 << kValueCount) - 1`,
182 // the bitmask will be constructed from two parts: the most significant
183 // bits and the remaining.
184 uint64_t mask = 1ULL << (kValueCount - 1);
185 return EnumSet(EnumBitSet(mask - 1 + mask));
186 } else {
187 // When `kValueCount` is greater than 64, we can't use the constexpr path,
188 // and we will build an `EnumSet` value by value.
189 EnumSet enum_set;
190 for (size_t value = 0; value < kValueCount; ++value) {
191 enum_set.Put(FromIndex(value));
192 }
193 return enum_set;
194 }
195 }
196
197 // Returns an EnumSet with all the values from start to end, inclusive.
FromRange(E start,E end)198 static constexpr EnumSet FromRange(E start, E end) {
199 CHECK_LE(start, end);
200 return EnumSet(EnumBitSet(
201 ((single_val_bitstring(end)) - (single_val_bitstring(start))) |
202 (single_val_bitstring(end))));
203 }
204
205 // Copy constructor and assignment welcome.
206
207 // Bitmask operations.
208 //
209 // This bitmask is 0-based and the value of the Nth bit depends on whether
210 // the set contains an enum element of integer value N.
211 //
212 // These may only be used if Min >= 0 and Max < 64.
213
214 // Returns an EnumSet constructed from |bitmask|.
FromEnumBitmask(const uint64_t bitmask)215 static constexpr EnumSet FromEnumBitmask(const uint64_t bitmask) {
216 static_assert(GetUnderlyingValue(kMaxValue) < 64,
217 "The highest enum value must be < 64 for FromEnumBitmask ");
218 static_assert(GetUnderlyingValue(kMinValue) >= 0,
219 "The lowest enum value must be >= 0 for FromEnumBitmask ");
220 return EnumSet(EnumBitSet(bitmask >> GetUnderlyingValue(kMinValue)));
221 }
222 // Returns a bitmask for the EnumSet.
ToEnumBitmask()223 uint64_t ToEnumBitmask() const {
224 static_assert(GetUnderlyingValue(kMaxValue) < 64,
225 "The highest enum value must be < 64 for ToEnumBitmask ");
226 static_assert(GetUnderlyingValue(kMinValue) >= 0,
227 "The lowest enum value must be >= 0 for FromEnumBitmask ");
228 return enums_.to_ullong() << GetUnderlyingValue(kMinValue);
229 }
230
231 // Set operations. Put, Retain, and Remove are basically
232 // self-mutating versions of Union, Intersection, and Difference
233 // (defined below).
234
235 // Adds the given value (which must be in range) to our set.
Put(E value)236 void Put(E value) { enums_.set(ToIndex(value)); }
237
238 // Adds all values in the given set to our set.
PutAll(EnumSet other)239 void PutAll(EnumSet other) { enums_ |= other.enums_; }
240
241 // Adds all values in the given range to our set, inclusive.
PutRange(E start,E end)242 void PutRange(E start, E end) {
243 CHECK_LE(start, end);
244 size_t endIndexInclusive = ToIndex(end);
245 for (size_t current = ToIndex(start); current <= endIndexInclusive;
246 ++current) {
247 enums_.set(current);
248 }
249 }
250
251 // There's no real need for a Retain(E) member function.
252
253 // Removes all values not in the given set from our set.
RetainAll(EnumSet other)254 void RetainAll(EnumSet other) { enums_ &= other.enums_; }
255
256 // If the given value is in range, removes it from our set.
Remove(E value)257 void Remove(E value) {
258 if (InRange(value)) {
259 enums_.reset(ToIndex(value));
260 }
261 }
262
263 // Removes all values in the given set from our set.
RemoveAll(EnumSet other)264 void RemoveAll(EnumSet other) { enums_ &= ~other.enums_; }
265
266 // Removes all values from our set.
Clear()267 void Clear() { enums_.reset(); }
268
269 // Conditionally puts or removes `value`, based on `should_be_present`.
PutOrRemove(E value,bool should_be_present)270 void PutOrRemove(E value, bool should_be_present) {
271 if (should_be_present) {
272 Put(value);
273 } else {
274 Remove(value);
275 }
276 }
277
278 // Returns true iff the given value is in range and a member of our set.
Has(E value)279 constexpr bool Has(E value) const {
280 return InRange(value) && enums_[ToIndex(value)];
281 }
282
283 // Returns true iff the given set is a subset of our set.
HasAll(EnumSet other)284 bool HasAll(EnumSet other) const {
285 return (enums_ & other.enums_) == other.enums_;
286 }
287
288 // Returns true if the given set contains any value of our set.
HasAny(EnumSet other)289 bool HasAny(EnumSet other) const {
290 return (enums_ & other.enums_).count() > 0;
291 }
292
293 // Returns true iff our set is empty.
Empty()294 bool Empty() const { return !enums_.any(); }
295
296 // Returns how many values our set has.
Size()297 size_t Size() const { return enums_.count(); }
298
299 // Returns an iterator pointing to the first element (if any).
begin()300 Iterator begin() const { return Iterator(enums_); }
301
302 // Returns an iterator that does not point to any element, but to the position
303 // that follows the last element in the set.
end()304 Iterator end() const { return Iterator(); }
305
306 // Returns true iff our set and the given set contain exactly the same values.
307 friend bool operator==(const EnumSet&, const EnumSet&) = default;
308
309 private:
310 friend constexpr EnumSet Union<E, MinEnumValue, MaxEnumValue>(EnumSet set1,
311 EnumSet set2);
312 friend constexpr EnumSet Intersection<E, MinEnumValue, MaxEnumValue>(
313 EnumSet set1,
314 EnumSet set2);
315 friend constexpr EnumSet Difference<E, MinEnumValue, MaxEnumValue>(
316 EnumSet set1,
317 EnumSet set2);
318
bitstring(const std::initializer_list<E> & values)319 static constexpr uint64_t bitstring(const std::initializer_list<E>& values) {
320 uint64_t result = 0;
321 for (E value : values) {
322 result |= single_val_bitstring(value);
323 }
324 return result;
325 }
326
single_val_bitstring(E val)327 static constexpr uint64_t single_val_bitstring(E val) {
328 const uint64_t bitstring = 1;
329 const size_t shift_amount = ToIndex(val);
330 CHECK_LT(shift_amount, sizeof(bitstring) * 8);
331 return bitstring << shift_amount;
332 }
333
334 // A bitset can't be constexpr constructed if it has size > 64, since the
335 // constexpr constructor uses a uint64_t. If your EnumSet has > 64 values, you
336 // can safely remove the constepxr qualifiers from this file, at the cost of
337 // some minor optimizations.
EnumSet(EnumBitSet enums)338 explicit constexpr EnumSet(EnumBitSet enums) : enums_(enums) {
339 if (std::is_constant_evaluated()) {
340 CHECK(kValueCount <= 64)
341 << "Max number of enum values is 64 for constexpr constructor";
342 }
343 }
344
345 // Converts a value to/from an index into |enums_|.
ToIndex(E value)346 static constexpr size_t ToIndex(E value) {
347 CHECK(InRange(value));
348 return static_cast<size_t>(GetUnderlyingValue(value)) -
349 static_cast<size_t>(GetUnderlyingValue(MinEnumValue));
350 }
351
FromIndex(size_t i)352 static E FromIndex(size_t i) {
353 DCHECK_LT(i, kValueCount);
354 return static_cast<E>(GetUnderlyingValue(MinEnumValue) + i);
355 }
356
357 EnumBitSet enums_;
358 };
359
360 template <typename E, E MinEnumValue, E MaxEnumValue>
361 const E EnumSet<E, MinEnumValue, MaxEnumValue>::kMinValue;
362
363 template <typename E, E MinEnumValue, E MaxEnumValue>
364 const E EnumSet<E, MinEnumValue, MaxEnumValue>::kMaxValue;
365
366 template <typename E, E MinEnumValue, E MaxEnumValue>
367 const size_t EnumSet<E, MinEnumValue, MaxEnumValue>::kValueCount;
368
369 // The usual set operations.
370
371 template <typename E, E Min, E Max>
Union(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)372 constexpr EnumSet<E, Min, Max> Union(EnumSet<E, Min, Max> set1,
373 EnumSet<E, Min, Max> set2) {
374 return EnumSet<E, Min, Max>(set1.enums_ | set2.enums_);
375 }
376
377 template <typename E, E Min, E Max>
Intersection(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)378 constexpr EnumSet<E, Min, Max> Intersection(EnumSet<E, Min, Max> set1,
379 EnumSet<E, Min, Max> set2) {
380 return EnumSet<E, Min, Max>(set1.enums_ & set2.enums_);
381 }
382
383 template <typename E, E Min, E Max>
Difference(EnumSet<E,Min,Max> set1,EnumSet<E,Min,Max> set2)384 constexpr EnumSet<E, Min, Max> Difference(EnumSet<E, Min, Max> set1,
385 EnumSet<E, Min, Max> set2) {
386 return EnumSet<E, Min, Max>(set1.enums_ & ~set2.enums_);
387 }
388
389 } // namespace base
390
391 #endif // BASE_CONTAINERS_ENUM_SET_H_
392