1 // This file is part of ICU4X. For terms of use, please see the file
2 // called LICENSE at the top level of the ICU4X source tree
3 // (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
4
5 use super::*;
6 use crate::error::ZeroTrieBuildError;
7 use alloc::vec;
8 use alloc::vec::Vec;
9
/// Cap on how often, while a single level-1 parameter (`p`) is being tried, a bucket may run
/// through the entire range of its level-2 parameter (`q`, bounded by `Q_FAST_MAX` on the first
/// pass) without success. Once the count exceeds this cap, the search abandons the current `p`
/// and moves on to the next one.
///
/// The cap keeps the search fast. Its effect on the resulting perfect hash is small: roughly
/// 1 in 10000 hash maps end up falling back to the slow path.
const MAX_L2_SEARCH_MISSES: usize = 24;
15
16 /// Directly compute the perfect hash function.
17 ///
18 /// Returns `(p, [q_0, q_1, ..., q_(N-1)])`, or an error if the PHF could not be computed.
19 #[allow(unused_labels)] // for readability
find(bytes: &[u8]) -> Result<(u8, Vec<u8>), ZeroTrieBuildError>20 pub fn find(bytes: &[u8]) -> Result<(u8, Vec<u8>), ZeroTrieBuildError> {
21 let n_usize = bytes.len();
22
23 let mut p = 0u8;
24 let mut qq = vec![0u8; n_usize];
25
26 let mut bqs = vec![0u8; n_usize];
27 let mut seen = vec![false; n_usize];
28 let max_allowable_p = P_FAST_MAX;
29 let mut max_allowable_q = Q_FAST_MAX;
30
31 #[allow(non_snake_case)]
32 let N = if n_usize > 0 && n_usize < 256 {
33 n_usize as u8
34 } else {
35 debug_assert!(n_usize == 0 || n_usize == 256);
36 return Ok((p, qq));
37 };
38
39 'p_loop: loop {
40 let mut buckets: Vec<(usize, Vec<u8>)> = (0..n_usize).map(|i| (i, vec![])).collect();
41 for byte in bytes {
42 let l1 = f1(*byte, p, N) as usize;
43 buckets[l1].1.push(*byte);
44 }
45 buckets.sort_by_key(|(_, v)| -(v.len() as isize));
46 // println!("New P: p={p:?}, buckets={buckets:?}");
47 let mut i = 0;
48 let mut num_max_q = 0;
49 bqs.fill(0);
50 seen.fill(false);
51 'q_loop: loop {
52 if i == buckets.len() {
53 for (local_j, real_j) in buckets.iter().map(|(j, _)| *j).enumerate() {
54 qq[real_j] = bqs[local_j];
55 }
56 // println!("Success: p={p:?}, num_max_q={num_max_q:?}, bqs={bqs:?}, qq={qq:?}");
57 // if num_max_q > 0 {
58 // println!("num_max_q={num_max_q:?}");
59 // }
60 return Ok((p, qq));
61 }
62 let mut bucket = buckets[i].1.as_slice();
63 'byte_loop: for (j, byte) in bucket.iter().enumerate() {
64 let l2 = f2(*byte, bqs[i], N) as usize;
65 if seen[l2] {
66 // println!("Skipping Q: p={p:?}, i={i:?}, byte={byte:}, q={i:?}, l2={:?}", f2(*byte, bqs[i], N));
67 for k_byte in &bucket[0..j] {
68 let l2 = f2(*k_byte, bqs[i], N) as usize;
69 assert!(seen[l2]);
70 seen[l2] = false;
71 }
72 'reset_loop: loop {
73 if bqs[i] < max_allowable_q {
74 bqs[i] += 1;
75 continue 'q_loop;
76 }
77 num_max_q += 1;
78 bqs[i] = 0;
79 if i == 0 || num_max_q > MAX_L2_SEARCH_MISSES {
80 if p == max_allowable_p && max_allowable_q != Q_REAL_MAX {
81 // println!("Could not solve fast function: trying again: {bytes:?}");
82 max_allowable_q = Q_REAL_MAX;
83 p = 0;
84 continue 'p_loop;
85 } else if p == max_allowable_p {
86 // If a fallback algorithm for `p` is added, relax this assertion
87 // and re-run the loop with a higher `max_allowable_p`.
88 debug_assert_eq!(max_allowable_p, P_REAL_MAX);
89 // println!("Could not solve PHF function");
90 return Err(ZeroTrieBuildError::CouldNotSolvePerfectHash);
91 } else {
92 p += 1;
93 continue 'p_loop;
94 }
95 }
96 i -= 1;
97 bucket = buckets[i].1.as_slice();
98 for byte in bucket {
99 let l2 = f2(*byte, bqs[i], N) as usize;
100 assert!(seen[l2]);
101 seen[l2] = false;
102 }
103 }
104 } else {
105 // println!("Marking as seen: i={i:?}, byte={byte:}, l2={:?}", f2(*byte, bqs[i], N));
106 let l2 = f2(*byte, bqs[i], N) as usize;
107 seen[l2] = true;
108 }
109 }
110 // println!("Found Q: i={i:?}, q={:?}", bqs[i]);
111 i += 1;
112 }
113 }
114 }
115
116 impl PerfectByteHashMap<Vec<u8>> {
117 /// Computes a new [`PerfectByteHashMap`].
118 ///
119 /// (this is a doc-hidden API)
try_new(keys: &[u8]) -> Result<Self, ZeroTrieBuildError>120 pub fn try_new(keys: &[u8]) -> Result<Self, ZeroTrieBuildError> {
121 let n_usize = keys.len();
122 let n = n_usize as u8;
123 let (p, mut qq) = find(keys)?;
124 let mut keys_permuted = vec![0; n_usize];
125 for key in keys {
126 let l1 = f1(*key, p, n) as usize;
127 let q = qq[l1];
128 let l2 = f2(*key, q, n) as usize;
129 keys_permuted[l2] = *key;
130 }
131 let mut result = Vec::with_capacity(n_usize * 2 + 1);
132 result.push(p);
133 result.append(&mut qq);
134 result.append(&mut keys_permuted);
135 Ok(Self(result))
136 }
137 }
138
#[cfg(test)]
mod tests {
    use super::*;

    extern crate std;
    use std::print;
    use std::println;

    /// Writes `byte` to stdout: quoted when it is an ASCII letter or digit, otherwise in hex.
    fn print_byte_to_stdout(byte: u8) {
        let c = char::from(byte);
        if !c.is_ascii_alphanumeric() {
            print!("0x{byte:X}");
            return;
        }
        print!("'{c}'");
    }

    /// Deterministically picks `len` alphanumeric bytes, without repetition, using `seed`.
    fn random_alphanums(seed: u64, len: usize) -> Vec<u8> {
        use rand::seq::SliceRandom;
        use rand::SeedableRng;
        const BYTES: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
        let mut generator = rand_pcg::Lcg64Xsh32::seed_from_u64(seed);
        let picked = BYTES.choose_multiple(&mut generator, len);
        picked.cloned().collect()
    }

    #[test]
    fn test_random_distributions() {
        // Histograms of the p and q values chosen by `find` over many random inputs.
        let mut p_distr = vec![0usize; 256];
        let mut q_distr = vec![0usize; 256];
        for len in 0..50 {
            for seed in 0..50 {
                let (p, qq) = find(&random_alphanums(seed, len)).unwrap();
                p_distr[usize::from(p)] += 1;
                qq.into_iter()
                    .for_each(|q| q_distr[usize::from(q)] += 1);
            }
        }
        println!("p_distr: {p_distr:?}");
        println!("q_distr: {q_distr:?}");

        // Split each histogram at the fast-path limit and total both halves.
        let (fast_p_counts, slow_p_counts) = p_distr.split_at(usize::from(P_FAST_MAX) + 1);
        let fast_p: usize = fast_p_counts.iter().sum();
        let slow_p: usize = slow_p_counts.iter().sum();
        let (fast_q_counts, slow_q_counts) = q_distr.split_at(usize::from(Q_FAST_MAX) + 1);
        let fast_q: usize = fast_q_counts.iter().sum();
        let slow_q: usize = slow_q_counts.iter().sum();

        assert_eq!(2500, fast_p);
        assert_eq!(0, slow_p);
        assert_eq!(61247, fast_q);
        assert_eq!(3, slow_q);

        let bytes = random_alphanums(0, 16);

        // Mirrors the `N` naming used by `find`.
        #[allow(non_snake_case)]
        let N = u8::try_from(bytes.len()).unwrap();

        let (p, qq) = find(&bytes).unwrap();

        println!("Results:");
        for &byte in &bytes {
            print_byte_to_stdout(byte);
            let l1 = f1(byte, p, N) as usize;
            let q = qq[l1];
            let l2 = f2(byte, q, N) as usize;
            println!(" => l1 {l1} => q {q} => l2 {l2}");
        }
    }
}
208