• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright © 2022 Collabora, Ltd.
2 // SPDX-License-Identifier: MIT
3 
//! A set of usizes, represented as a bit vector
//!
//! In addition to some basic operations like `insert()` and `remove()`, this
//! module also lets you write expressions on sets that are lazily evaluated. To
//! do so, call `.s(..)` on the set to reference the bitset in a
//! lazily-evaluated `BitSetStream`, and then use typical binary operators on
//! the `BitSetStream`s.
//! ```rust
//! let a = BitSet::new();
//! let b = BitSet::new();
//! let mut c = BitSet::new();
//!
//! c.assign(a.s(..) | b.s(..));
//! c ^= a.s(..);
//! ```
//! Supported binary operations are `&`, `|`, `^`, `-`. Note that there is no
//! unary negation, because that would result in an infinite result set. For
//! patterns like `a & !b`, instead use set subtraction `a - b`.
22 
23 use std::cmp::{max, min};
24 use std::ops::{
25     BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, RangeFull,
26     Sub, SubAssign,
27 };
28 
29 #[derive(Clone)]
30 pub struct BitSet {
31     words: Vec<u32>,
32 }
33 
34 impl BitSet {
new() -> BitSet35     pub fn new() -> BitSet {
36         BitSet { words: Vec::new() }
37     }
38 
reserve_words(&mut self, words: usize)39     fn reserve_words(&mut self, words: usize) {
40         if self.words.len() < words {
41             self.words.resize(words, 0);
42         }
43     }
44 
reserve(&mut self, bits: usize)45     pub fn reserve(&mut self, bits: usize) {
46         self.reserve_words(bits.div_ceil(32));
47     }
48 
clear(&mut self)49     pub fn clear(&mut self) {
50         for w in self.words.iter_mut() {
51             *w = 0;
52         }
53     }
54 
get(&self, idx: usize) -> bool55     pub fn get(&self, idx: usize) -> bool {
56         let w = idx / 32;
57         let b = idx % 32;
58         if w < self.words.len() {
59             self.words[w] & (1_u32 << b) != 0
60         } else {
61             false
62         }
63     }
64 
is_empty(&self) -> bool65     pub fn is_empty(&self) -> bool {
66         for w in self.words.iter() {
67             if *w != 0 {
68                 return false;
69             }
70         }
71         true
72     }
73 
iter(&self) -> impl '_ + Iterator<Item = usize>74     pub fn iter(&self) -> impl '_ + Iterator<Item = usize> {
75         BitSetIter::new(self)
76     }
77 
next_unset(&self, start: usize) -> usize78     pub fn next_unset(&self, start: usize) -> usize {
79         if start >= self.words.len() * 32 {
80             return start;
81         }
82 
83         let mut w = start / 32;
84         let mut mask = !(u32::MAX << (start % 32));
85         while w < self.words.len() {
86             let b = (self.words[w] | mask).trailing_ones();
87             if b < 32 {
88                 return w * 32 + usize::try_from(b).unwrap();
89             }
90             mask = 0;
91             w += 1;
92         }
93         self.words.len() * 32
94     }
95 
insert(&mut self, idx: usize) -> bool96     pub fn insert(&mut self, idx: usize) -> bool {
97         let w = idx / 32;
98         let b = idx % 32;
99         self.reserve_words(w + 1);
100         let exists = self.words[w] & (1_u32 << b) != 0;
101         self.words[w] |= 1_u32 << b;
102         !exists
103     }
104 
remove(&mut self, idx: usize) -> bool105     pub fn remove(&mut self, idx: usize) -> bool {
106         let w = idx / 32;
107         let b = idx % 32;
108         self.reserve_words(w + 1);
109         let exists = self.words[w] & (1_u32 << b) != 0;
110         self.words[w] &= !(1_u32 << b);
111         exists
112     }
113 
114     /// Evaluate an expression and store its value in self
assign<B>(&mut self, value: BitSetStream<B>) where B: BitSetStreamTrait,115     pub fn assign<B>(&mut self, value: BitSetStream<B>)
116     where
117         B: BitSetStreamTrait,
118     {
119         let mut value = value.0;
120         let len = value.len();
121         self.words.clear();
122         self.words.resize_with(len, || value.next());
123         for _ in 0..16 {
124             debug_assert_eq!(value.next(), 0);
125         }
126     }
127 
128     /// Calculate the union of self and an expression, and store the result in
129     /// self.
130     ///
131     /// Returns true if the value of self changes, or false otherwise. If you
132     /// don't need the return value of this function, consider using the `|=`
133     /// operator instead.
union_with<B>(&mut self, other: BitSetStream<B>) -> bool where B: BitSetStreamTrait,134     pub fn union_with<B>(&mut self, other: BitSetStream<B>) -> bool
135     where
136         B: BitSetStreamTrait,
137     {
138         let mut other = other.0;
139         let mut added_bits = false;
140         let other_len = other.len();
141         self.reserve_words(other_len);
142         for w in 0..other_len {
143             let uw = self.words[w] | other.next();
144             if uw != self.words[w] {
145                 added_bits = true;
146                 self.words[w] = uw;
147             }
148         }
149         added_bits
150     }
151 
s<'a>( &'a self, _: RangeFull, ) -> BitSetStream<impl 'a + BitSetStreamTrait>152     pub fn s<'a>(
153         &'a self,
154         _: RangeFull,
155     ) -> BitSetStream<impl 'a + BitSetStreamTrait> {
156         BitSetStream(BitSetStreamFromBitSet {
157             iter: self.words.iter().copied(),
158         })
159     }
160 }
161 
162 impl Default for BitSet {
default() -> BitSet163     fn default() -> BitSet {
164         BitSet::new()
165     }
166 }
167 
168 impl FromIterator<usize> for BitSet {
from_iter<T>(iter: T) -> Self where T: IntoIterator<Item = usize>,169     fn from_iter<T>(iter: T) -> Self
170     where
171         T: IntoIterator<Item = usize>,
172     {
173         let mut res = BitSet::new();
174         for i in iter {
175             res.insert(i);
176         }
177         res
178     }
179 }
180 
/// A lazily-evaluated sequence of 32-bit words describing a set.
pub trait BitSetStreamTrait {
    /// Produce the next word of the stream.
    ///
    /// Once `len()` words have been produced, every further call is
    /// guaranteed to return 0.
    fn next(&mut self) -> u32;

    /// Report how many words the stream will output.
    fn len(&self) -> usize;
}
190 
191 struct BitSetStreamFromBitSet<T>
192 where
193     T: ExactSizeIterator<Item = u32>,
194 {
195     iter: T,
196 }
197 
198 impl<T> BitSetStreamTrait for BitSetStreamFromBitSet<T>
199 where
200     T: ExactSizeIterator<Item = u32>,
201 {
next(&mut self) -> u32202     fn next(&mut self) -> u32 {
203         self.iter.next().unwrap_or(0)
204     }
len(&self) -> usize205     fn len(&self) -> usize {
206         self.iter.len()
207     }
208 }
209 
210 pub struct BitSetStream<T>(T)
211 where
212     T: BitSetStreamTrait;
213 
214 impl<T> From<BitSetStream<T>> for BitSet
215 where
216     T: BitSetStreamTrait,
217 {
from(value: BitSetStream<T>) -> Self218     fn from(value: BitSetStream<T>) -> Self {
219         let mut out = BitSet::new();
220         out.assign(value);
221         out
222     }
223 }
224 
/// Generates the plumbing for one binary set operator.
///
/// For each invocation this emits:
/// * a `$Struct` stream combining two operand streams word by word,
/// * the `$BinOp` operator between two `BitSetStream`s, and
/// * the `$AssignBinOp` operator applying a stream to a `BitSet` in place.
///
/// The first closure-like argument computes one output word from the two
/// input words; the second computes the output word count from the two input
/// word counts.
macro_rules! binop {
    (
        $BinOp:ident,
        $bin_op:ident,
        $AssignBinOp:ident,
        $assign_bin_op:ident,
        $Struct:ident,
        |$a:ident, $b:ident| $next_impl:expr,
        |$a_len:ident, $b_len:ident| $len_impl:expr,
    ) => {
        pub struct $Struct<A: BitSetStreamTrait, B: BitSetStreamTrait> {
            a: A,
            b: B,
        }

        impl<A: BitSetStreamTrait, B: BitSetStreamTrait> BitSetStreamTrait
            for $Struct<A, B>
        {
            fn next(&mut self) -> u32 {
                // Both operands are always advanced so they stay in lockstep.
                let $a = self.a.next();
                let $b = self.b.next();
                $next_impl
            }

            fn len(&self) -> usize {
                let $a_len = self.a.len();
                let $b_len = self.b.len();
                $len_impl
            }
        }

        impl<A: BitSetStreamTrait, B: BitSetStreamTrait>
            $BinOp<BitSetStream<B>> for BitSetStream<A>
        {
            type Output = BitSetStream<$Struct<A, B>>;

            fn $bin_op(self, rhs: BitSetStream<B>) -> Self::Output {
                let combined = $Struct {
                    a: self.0,
                    b: rhs.0,
                };
                BitSetStream(combined)
            }
        }

        impl<B: BitSetStreamTrait> $AssignBinOp<BitSetStream<B>> for BitSet {
            fn $assign_bin_op(&mut self, rhs: BitSetStream<B>) {
                let mut stream = rhs.0;

                // Resize self to the word count the operator produces.
                let $a_len = self.words.len();
                let $b_len = stream.len();
                let word_count = $len_impl;
                self.words.resize(word_count, 0);

                for word in self.words.iter_mut() {
                    let $a = *word;
                    let $b = stream.next();
                    *word = $next_impl;
                }

                // Sanity check: words past the computed length must be zero.
                for _ in 0..16 {
                    debug_assert_eq!(
                        {
                            let $a = 0;
                            let $b = stream.next();
                            $next_impl
                        },
                        0
                    );
                }
            }
        }
    };
}
310 
311 binop!(
312     BitAnd,
313     bitand,
314     BitAndAssign,
315     bitand_assign,
316     BitSetStreamAnd,
317     |a, b| a & b,
318     |a, b| min(a, b),
319 );
320 
321 binop!(
322     BitOr,
323     bitor,
324     BitOrAssign,
325     bitor_assign,
326     BitSetStreamOr,
327     |a, b| a | b,
328     |a, b| max(a, b),
329 );
330 
331 binop!(
332     BitXor,
333     bitxor,
334     BitXorAssign,
335     bitxor_assign,
336     BitSetStreamXor,
337     |a, b| a ^ b,
338     |a, b| max(a, b),
339 );
340 
341 binop!(
342     Sub,
343     sub,
344     SubAssign,
345     sub_assign,
346     BitSetStreamSub,
347     |a, b| a & !b,
348     |a, _b| a,
349 );
350 
351 struct BitSetIter<'a> {
352     set: &'a BitSet,
353     w: usize,
354     mask: u32,
355 }
356 
357 impl<'a> BitSetIter<'a> {
new(set: &'a BitSet) -> Self358     fn new(set: &'a BitSet) -> Self {
359         Self {
360             set,
361             w: 0,
362             mask: u32::MAX,
363         }
364     }
365 }
366 
367 impl<'a> Iterator for BitSetIter<'a> {
368     type Item = usize;
369 
next(&mut self) -> Option<usize>370     fn next(&mut self) -> Option<usize> {
371         while self.w < self.set.words.len() {
372             let b = (self.set.words[self.w] & self.mask).trailing_zeros();
373             if b < 32 {
374                 self.mask &= !(1 << b);
375                 return Some(self.w * 32 + usize::try_from(b).unwrap());
376             }
377             self.mask = u32::MAX;
378             self.w += 1;
379         }
380         None
381     }
382 }
383 
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the members of `set` in ascending order.
    fn to_vec(set: &BitSet) -> Vec<usize> {
        set.iter().collect()
    }

    /// Exercises insert/get/remove/clear/is_empty on a single set.
    #[test]
    fn test_basic() {
        let mut set = BitSet::new();

        assert_eq!(to_vec(&set), &[]);
        assert!(set.is_empty());

        set.insert(0);

        assert_eq!(to_vec(&set), &[0]);

        set.insert(73);
        set.insert(1);

        assert_eq!(to_vec(&set), &[0, 1, 73]);
        assert!(!set.is_empty());

        assert!(set.get(73));
        assert!(!set.get(197));

        assert!(set.remove(1));
        assert!(!set.remove(7));

        let mut set2 = set.clone();
        assert_eq!(to_vec(&set), &[0, 73]);
        assert!(!set.is_empty());

        assert!(set.remove(0));
        assert!(set.remove(73));
        assert!(set.is_empty());

        set.clear();
        assert!(set.is_empty());

        set2.clear();
        assert!(set2.is_empty());
    }

    /// Fills a range, adds unrelated bits, and checks the first gap found
    /// from the start of the range is the end of the range.
    #[test]
    fn test_next_unset() {
        for test_range in
            &[0..0, 42..1337, 1337..1337, 31..32, 32..33, 63..64, 64..65]
        {
            let mut set = BitSet::new();
            for i in test_range.clone() {
                set.insert(i);
            }
            for extra_bit in [17, 34, 39] {
                assert!(test_range.end != extra_bit);
                set.insert(extra_bit);
            }
            assert_eq!(set.next_unset(test_range.start), test_range.end);
        }
    }

    #[test]
    fn test_from_iter() {
        let vec = vec![0, 1, 99];
        let set: BitSet = vec.clone().into_iter().collect();
        assert_eq!(to_vec(&set), vec);
    }

    #[test]
    fn test_or() {
        let a: BitSet = vec![9, 23, 18, 72].into_iter().collect();
        let b: BitSet = vec![7, 23, 1337].into_iter().collect();
        let expected = vec![7, 9, 18, 23, 72, 1337];

        assert_eq!(to_vec(&(a.s(..) | b.s(..)).into()), &expected[..]);
        assert_eq!(to_vec(&(b.s(..) | a.s(..)).into()), &expected[..]);

        let mut actual_1 = a.clone();
        actual_1 |= b.s(..);
        assert_eq!(to_vec(&actual_1), &expected[..]);

        let mut actual_2 = b.clone();
        actual_2 |= a.s(..);
        assert_eq!(to_vec(&actual_2), &expected[..]);

        // union_with reports whether any new bit was added.
        let mut actual_3 = a.clone();
        assert!(!actual_3.union_with(a.s(..)));
        assert!(actual_3.union_with(b.s(..)));
        assert_eq!(to_vec(&actual_3), &expected[..]);

        let mut actual_4 = b.clone();
        assert!(!actual_4.union_with(b.s(..)));
        assert!(actual_4.union_with(a.s(..)));
        assert_eq!(to_vec(&actual_4), &expected[..]);
    }

    #[test]
    fn test_and() {
        let a: BitSet = vec![1337, 42, 7, 1].into_iter().collect();
        let b: BitSet = vec![42, 783, 2, 7].into_iter().collect();
        let expected = vec![7, 42];

        assert_eq!(to_vec(&(a.s(..) & b.s(..)).into()), &expected[..]);
        assert_eq!(to_vec(&(b.s(..) & a.s(..)).into()), &expected[..]);

        let mut actual_1 = a.clone();
        actual_1 &= b.s(..);
        assert_eq!(to_vec(&actual_1), &expected[..]);

        let mut actual_2 = b.clone();
        actual_2 &= a.s(..);
        assert_eq!(to_vec(&actual_2), &expected[..]);
    }

    #[test]
    fn test_xor() {
        let a: BitSet = vec![1337, 42, 7, 1].into_iter().collect();
        let b: BitSet = vec![42, 127, 2, 7].into_iter().collect();
        let expected = vec![1, 2, 127, 1337];

        assert_eq!(to_vec(&(a.s(..) ^ b.s(..)).into()), &expected[..]);
        assert_eq!(to_vec(&(b.s(..) ^ a.s(..)).into()), &expected[..]);

        let mut actual_1 = a.clone();
        actual_1 ^= b.s(..);
        assert_eq!(to_vec(&actual_1), &expected[..]);

        let mut actual_2 = b.clone();
        actual_2 ^= a.s(..);
        assert_eq!(to_vec(&actual_2), &expected[..]);
    }

    #[test]
    fn test_sub() {
        let a: BitSet = vec![1337, 42, 7, 1].into_iter().collect();
        let b: BitSet = vec![42, 127, 2, 7].into_iter().collect();
        let expected_1 = vec![1, 1337];
        let expected_2 = vec![2, 127];

        assert_eq!(to_vec(&(a.s(..) - b.s(..)).into()), &expected_1[..]);
        assert_eq!(to_vec(&(b.s(..) - a.s(..)).into()), &expected_2[..]);

        let mut actual_1 = a.clone();
        actual_1 -= b.s(..);
        assert_eq!(to_vec(&actual_1), &expected_1[..]);

        let mut actual_2 = b.clone();
        actual_2 -= a.s(..);
        assert_eq!(to_vec(&actual_2), &expected_2[..]);
    }

    /// Applies a compound stream expression to an empty set.
    #[test]
    fn test_compound() {
        let a: BitSet = vec![1337, 42, 7, 1].into_iter().collect();
        let b: BitSet = vec![42, 127, 2, 7].into_iter().collect();
        let mut c = BitSet::new();

        c &= a.s(..) | b.s(..);
        assert!(c.is_empty());
    }
}
547