• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of ICU4X. For terms of use, please see the file
2 // called LICENSE at the top level of the ICU4X source tree
3 // (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
4 
5 use super::*;
6 use crate::error::ZeroTrieBuildError;
7 use alloc::vec;
8 use alloc::vec::Vec;
9 
/// Search budget for the level-2 parameter (q) under a single level-1 parameter (p).
///
/// To speed up the search algorithm, we limit the number of times the level-2 parameter (q)
/// can hit its max value (initially Q_FAST_MAX) before we try the next level-1 parameter (p).
/// `find` counts those hits and gives up on the current `p` as soon as the count is
/// strictly greater than this constant.
///
/// In practice, this has a small impact on the resulting perfect hash: about
/// 1 in 10000 hash maps fall back to the slow path.
const MAX_L2_SEARCH_MISSES: usize = 24;
15 
16 /// Directly compute the perfect hash function.
17 ///
18 /// Returns `(p, [q_0, q_1, ..., q_(N-1)])`, or an error if the PHF could not be computed.
19 #[allow(unused_labels)] // for readability
find(bytes: &[u8]) -> Result<(u8, Vec<u8>), ZeroTrieBuildError>20 pub fn find(bytes: &[u8]) -> Result<(u8, Vec<u8>), ZeroTrieBuildError> {
21     let n_usize = bytes.len();
22 
23     let mut p = 0u8;
24     let mut qq = vec![0u8; n_usize];
25 
26     let mut bqs = vec![0u8; n_usize];
27     let mut seen = vec![false; n_usize];
28     let max_allowable_p = P_FAST_MAX;
29     let mut max_allowable_q = Q_FAST_MAX;
30 
31     #[allow(non_snake_case)]
32     let N = if n_usize > 0 && n_usize < 256 {
33         n_usize as u8
34     } else {
35         debug_assert!(n_usize == 0 || n_usize == 256);
36         return Ok((p, qq));
37     };
38 
39     'p_loop: loop {
40         let mut buckets: Vec<(usize, Vec<u8>)> = (0..n_usize).map(|i| (i, vec![])).collect();
41         for byte in bytes {
42             let l1 = f1(*byte, p, N) as usize;
43             buckets[l1].1.push(*byte);
44         }
45         buckets.sort_by_key(|(_, v)| -(v.len() as isize));
46         // println!("New P: p={p:?}, buckets={buckets:?}");
47         let mut i = 0;
48         let mut num_max_q = 0;
49         bqs.fill(0);
50         seen.fill(false);
51         'q_loop: loop {
52             if i == buckets.len() {
53                 for (local_j, real_j) in buckets.iter().map(|(j, _)| *j).enumerate() {
54                     qq[real_j] = bqs[local_j];
55                 }
56                 // println!("Success: p={p:?}, num_max_q={num_max_q:?}, bqs={bqs:?}, qq={qq:?}");
57                 // if num_max_q > 0 {
58                 //     println!("num_max_q={num_max_q:?}");
59                 // }
60                 return Ok((p, qq));
61             }
62             let mut bucket = buckets[i].1.as_slice();
63             'byte_loop: for (j, byte) in bucket.iter().enumerate() {
64                 let l2 = f2(*byte, bqs[i], N) as usize;
65                 if seen[l2] {
66                     // println!("Skipping Q: p={p:?}, i={i:?}, byte={byte:}, q={i:?}, l2={:?}", f2(*byte, bqs[i], N));
67                     for k_byte in &bucket[0..j] {
68                         let l2 = f2(*k_byte, bqs[i], N) as usize;
69                         assert!(seen[l2]);
70                         seen[l2] = false;
71                     }
72                     'reset_loop: loop {
73                         if bqs[i] < max_allowable_q {
74                             bqs[i] += 1;
75                             continue 'q_loop;
76                         }
77                         num_max_q += 1;
78                         bqs[i] = 0;
79                         if i == 0 || num_max_q > MAX_L2_SEARCH_MISSES {
80                             if p == max_allowable_p && max_allowable_q != Q_REAL_MAX {
81                                 // println!("Could not solve fast function: trying again: {bytes:?}");
82                                 max_allowable_q = Q_REAL_MAX;
83                                 p = 0;
84                                 continue 'p_loop;
85                             } else if p == max_allowable_p {
86                                 // If a fallback algorithm for `p` is added, relax this assertion
87                                 // and re-run the loop with a higher `max_allowable_p`.
88                                 debug_assert_eq!(max_allowable_p, P_REAL_MAX);
89                                 // println!("Could not solve PHF function");
90                                 return Err(ZeroTrieBuildError::CouldNotSolvePerfectHash);
91                             } else {
92                                 p += 1;
93                                 continue 'p_loop;
94                             }
95                         }
96                         i -= 1;
97                         bucket = buckets[i].1.as_slice();
98                         for byte in bucket {
99                             let l2 = f2(*byte, bqs[i], N) as usize;
100                             assert!(seen[l2]);
101                             seen[l2] = false;
102                         }
103                     }
104                 } else {
105                     // println!("Marking as seen: i={i:?}, byte={byte:}, l2={:?}", f2(*byte, bqs[i], N));
106                     let l2 = f2(*byte, bqs[i], N) as usize;
107                     seen[l2] = true;
108                 }
109             }
110             // println!("Found Q: i={i:?}, q={:?}", bqs[i]);
111             i += 1;
112         }
113     }
114 }
115 
116 impl PerfectByteHashMap<Vec<u8>> {
117     /// Computes a new [`PerfectByteHashMap`].
118     ///
119     /// (this is a doc-hidden API)
try_new(keys: &[u8]) -> Result<Self, ZeroTrieBuildError>120     pub fn try_new(keys: &[u8]) -> Result<Self, ZeroTrieBuildError> {
121         let n_usize = keys.len();
122         let n = n_usize as u8;
123         let (p, mut qq) = find(keys)?;
124         let mut keys_permuted = vec![0; n_usize];
125         for key in keys {
126             let l1 = f1(*key, p, n) as usize;
127             let q = qq[l1];
128             let l2 = f2(*key, q, n) as usize;
129             keys_permuted[l2] = *key;
130         }
131         let mut result = Vec::with_capacity(n_usize * 2 + 1);
132         result.push(p);
133         result.append(&mut qq);
134         result.append(&mut keys_permuted);
135         Ok(Self(result))
136     }
137 }
138 
#[cfg(test)]
mod tests {
    use super::*;

    extern crate std;
    use std::print;
    use std::println;

    /// Prints `byte` to stdout without a trailing newline: ASCII letters and digits are
    /// printed as a quoted character, every other value as uppercase hexadecimal.
    fn print_byte_to_stdout(byte: u8) {
        if !byte.is_ascii_alphanumeric() {
            print!("0x{byte:X}");
            return;
        }
        let c = char::from(byte);
        print!("'{c}'");
    }

    /// Returns `len` ASCII alphanumeric bytes chosen from a 62-character alphabet with a
    /// PCG generator seeded from `seed`, so the same arguments always give the same bytes.
    fn random_alphanums(seed: u64, len: usize) -> Vec<u8> {
        use rand::seq::SliceRandom;
        use rand::SeedableRng;
        const ALPHANUMS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
        let mut generator = rand_pcg::Lcg64Xsh32::seed_from_u64(seed);
        let picked = ALPHANUMS.choose_multiple(&mut generator, len);
        picked.copied().collect()
    }

    /// Runs `find` over 50 lengths x 50 seeds and pins how often the resulting `p` and `q`
    /// parameters fall inside / outside the fast ranges, then prints one worked example.
    #[test]
    fn test_random_distributions() {
        // Histograms of the p and q values produced by `find`, indexed by parameter value.
        let mut p_distr = vec![0; 256];
        let mut q_distr = vec![0; 256];
        for len in 0..50 {
            for seed in 0..50 {
                let sample = random_alphanums(seed, len);
                let (p, qq) = find(&sample).unwrap();
                p_distr[p as usize] += 1;
                for &q in &qq {
                    q_distr[q as usize] += 1;
                }
            }
        }
        println!("p_distr: {p_distr:?}");
        println!("q_distr: {q_distr:?}");

        // Split each histogram at the end of the fast range and total both halves.
        let (p_fast_part, p_slow_part) = p_distr.split_at((P_FAST_MAX + 1) as usize);
        let (q_fast_part, q_slow_part) = q_distr.split_at((Q_FAST_MAX + 1) as usize);
        let fast_p = p_fast_part.iter().sum::<usize>();
        let slow_p = p_slow_part.iter().sum::<usize>();
        let fast_q = q_fast_part.iter().sum::<usize>();
        let slow_q = q_slow_part.iter().sum::<usize>();

        assert_eq!(2500, fast_p);
        assert_eq!(0, slow_p);
        assert_eq!(61247, fast_q);
        assert_eq!(3, slow_q);

        // Worked example: show where each of 16 bytes lands.
        let bytes = random_alphanums(0, 16);

        #[allow(non_snake_case)]
        let N = u8::try_from(bytes.len()).unwrap();

        let (p, qq) = find(&bytes).unwrap();

        println!("Results:");
        for &byte in &bytes {
            print_byte_to_stdout(byte);
            let l1 = f1(byte, p, N) as usize;
            let q = qq[l1];
            let l2 = f2(byte, q, N) as usize;
            println!(" => l1 {l1} => q {q} => l2 {l2}");
        }
    }
}
208