// This file is part of ICU4X. For terms of use, please see the file
// called LICENSE at the top level of the ICU4X source tree
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).

use super::*;
use crate::error::ZeroTrieBuildError;
use alloc::vec;
use alloc::vec::Vec;

/// To speed up the search algorithm, we limit the number of times the level-2 parameter (q)
/// can hit its max value (initially Q_FAST_MAX) before we try the next level-1 parameter (p).
/// In practice, this has a small impact on the resulting perfect hash, resulting in about
/// 1 in 10000 hash maps that fall back to the slow path.
const MAX_L2_SEARCH_MISSES: usize = 24;

/// Directly compute the perfect hash function.
///
/// Returns `(p, [q_0, q_1, ..., q_(N-1)])`, or an error if the PHF could not be computed.
#[allow(unused_labels)] // for readability
pub fn find(bytes: &[u8]) -> Result<(u8, Vec<u8>), ZeroTrieBuildError> {
    let n_usize = bytes.len();
    let mut p = 0u8;
    let mut qq = vec![0u8; n_usize];
    let mut bqs = vec![0u8; n_usize];
    let mut seen = vec![false; n_usize];
    let max_allowable_p = P_FAST_MAX;
    let mut max_allowable_q = Q_FAST_MAX;
    #[allow(non_snake_case)]
    let N = if n_usize > 0 && n_usize < 256 {
        n_usize as u8
    } else {
        debug_assert!(n_usize == 0 || n_usize == 256);
        return Ok((p, qq));
    };
    // Outer loop: try successive level-1 parameters `p`. For each `p`, the bytes are
    // distributed into buckets by `f1`, and we then search for a level-2 parameter `q`
    // per bucket such that `f2` maps every byte to a distinct slot.
    'p_loop: loop {
        let mut buckets: Vec<(usize, Vec<u8>)> = (0..n_usize).map(|i| (i, vec![])).collect();
        for byte in bytes {
            let l1 = f1(*byte, p, N) as usize;
            buckets[l1].1.push(*byte);
        }
        // Largest buckets are the hardest to place, so process them first.
        buckets.sort_by_key(|(_, v)| -(v.len() as isize));
        // println!("New P: p={p:?}, buckets={buckets:?}");
        let mut i = 0;
        let mut num_max_q = 0;
        bqs.fill(0);
        seen.fill(false);
        'q_loop: loop {
            if i == buckets.len() {
                // All buckets placed: copy the q values from sorted order back to the
                // original bucket indices.
                for (local_j, real_j) in buckets.iter().map(|(j, _)| *j).enumerate() {
                    qq[real_j] = bqs[local_j];
                }
                // println!("Success: p={p:?}, num_max_q={num_max_q:?}, bqs={bqs:?}, qq={qq:?}");
                // if num_max_q > 0 {
                //     println!("num_max_q={num_max_q:?}");
                // }
                return Ok((p, qq));
            }
            let mut bucket = buckets[i].1.as_slice();
            'byte_loop: for (j, byte) in bucket.iter().enumerate() {
                let l2 = f2(*byte, bqs[i], N) as usize;
                if seen[l2] {
                    // println!("Skipping Q: p={p:?}, i={i:?}, byte={byte:}, q={i:?}, l2={:?}", f2(*byte, bqs[i], N));
                    // Collision: release the slots already claimed by this bucket,
                    // then advance its q, or backtrack if q is exhausted.
                    for k_byte in &bucket[0..j] {
                        let l2 = f2(*k_byte, bqs[i], N) as usize;
                        assert!(seen[l2]);
                        seen[l2] = false;
                    }
                    'reset_loop: loop {
                        if bqs[i] < max_allowable_q {
                            bqs[i] += 1;
                            continue 'q_loop;
                        }
                        num_max_q += 1;
                        bqs[i] = 0;
                        if i == 0 || num_max_q > MAX_L2_SEARCH_MISSES {
                            if p == max_allowable_p && max_allowable_q != Q_REAL_MAX {
                                // println!("Could not solve fast function: trying again: {bytes:?}");
                                max_allowable_q = Q_REAL_MAX;
                                p = 0;
                                continue 'p_loop;
                            } else if p == max_allowable_p {
                                // If a fallback algorithm for `p` is added, relax this assertion
                                // and re-run the loop with a higher `max_allowable_p`.
                                debug_assert_eq!(max_allowable_p, P_REAL_MAX);
                                // println!("Could not solve PHF function");
                                return Err(ZeroTrieBuildError::CouldNotSolvePerfectHash);
                            } else {
                                p += 1;
                                continue 'p_loop;
                            }
                        }
                        // Backtrack: release the previous bucket's slots and retry it
                        // with its next q value.
                        i -= 1;
                        bucket = buckets[i].1.as_slice();
                        for byte in bucket {
                            let l2 = f2(*byte, bqs[i], N) as usize;
                            assert!(seen[l2]);
                            seen[l2] = false;
                        }
                    }
                } else {
                    // println!("Marking as seen: i={i:?}, byte={byte:}, l2={:?}", f2(*byte, bqs[i], N));
                    let l2 = f2(*byte, bqs[i], N) as usize;
                    seen[l2] = true;
                }
            }
            // println!("Found Q: i={i:?}, q={:?}", bqs[i]);
            i += 1;
        }
    }
}

impl PerfectByteHashMap<Vec<u8>> {
    /// Computes a new [`PerfectByteHashMap`].
    ///
    /// (this is a doc-hidden API)
    pub fn try_new(keys: &[u8]) -> Result<Self, ZeroTrieBuildError> {
        let n_usize = keys.len();
        let n = n_usize as u8;
        let (p, mut qq) = find(keys)?;
        let mut keys_permuted = vec![0; n_usize];
        for key in keys {
            let l1 = f1(*key, p, n) as usize;
            let q = qq[l1];
            let l2 = f2(*key, q, n) as usize;
            keys_permuted[l2] = *key;
        }
        let mut result = Vec::with_capacity(n_usize * 2 + 1);
        result.push(p);
        result.append(&mut qq);
        result.append(&mut keys_permuted);
        Ok(Self(result))
    }
}
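// The following module is an illustrative sketch (not part of the upstream test suite):
// it exercises the documented contract of `find`, namely that the returned `(p, qq)`
// pair, fed through `f1` and `f2`, assigns each distinct input byte a distinct slot in
// `0..N`. The input bytes used here are arbitrary example values.
#[cfg(test)]
mod find_contract_sketch {
    use super::*;

    #[test]
    fn find_output_is_a_perfect_hash() {
        // Ten distinct bytes; any distinct byte set of length 1..=255 would do.
        let bytes = b"abcdexyz19";
        #[allow(non_snake_case)]
        let N = bytes.len() as u8;
        let (p, qq) = find(bytes).unwrap();
        let mut seen = [false; 10];
        for byte in bytes {
            // Two-level lookup: `p` selects the bucket, `qq[l1]` resolves the slot.
            let l1 = f1(*byte, p, N) as usize;
            let l2 = f2(*byte, qq[l1], N) as usize;
            assert!(!seen[l2], "distinct keys must land in distinct slots");
            seen[l2] = true;
        }
        // Exactly N slots, all occupied: the mapping is a bijection onto 0..N.
        assert!(seen.iter().all(|&s| s));
    }
}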
#[cfg(test)]
mod tests {
    use super::*;
    extern crate std;
    use std::print;
    use std::println;

    fn print_byte_to_stdout(byte: u8) {
        let c = char::from(byte);
        if c.is_ascii_alphanumeric() {
            print!("'{c}'");
        } else {
            print!("0x{byte:X}");
        }
    }

    fn random_alphanums(seed: u64, len: usize) -> Vec<u8> {
        use rand::seq::SliceRandom;
        use rand::SeedableRng;
        const BYTES: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
        let mut rng = rand_pcg::Lcg64Xsh32::seed_from_u64(seed);
        BYTES.choose_multiple(&mut rng, len).copied().collect()
    }

    #[test]
    fn test_random_distributions() {
        let mut p_distr = vec![0; 256];
        let mut q_distr = vec![0; 256];
        for len in 0..50 {
            for seed in 0..50 {
                let bytes = random_alphanums(seed, len);
                let (p, qq) = find(bytes.as_slice()).unwrap();
                p_distr[p as usize] += 1;
                for q in qq {
                    q_distr[q as usize] += 1;
                }
            }
        }
        println!("p_distr: {p_distr:?}");
        println!("q_distr: {q_distr:?}");
        let fast_p = p_distr[0..=P_FAST_MAX as usize].iter().sum::<usize>();
        let slow_p = p_distr[(P_FAST_MAX + 1) as usize..].iter().sum::<usize>();
        let fast_q = q_distr[0..=Q_FAST_MAX as usize].iter().sum::<usize>();
        let slow_q = q_distr[(Q_FAST_MAX + 1) as usize..].iter().sum::<usize>();
        assert_eq!(2500, fast_p);
        assert_eq!(0, slow_p);
        assert_eq!(61247, fast_q);
        assert_eq!(3, slow_q);

        let bytes = random_alphanums(0, 16);
        #[allow(non_snake_case)]
        let N = u8::try_from(bytes.len()).unwrap();
        let (p, qq) = find(bytes.as_slice()).unwrap();
        println!("Results:");
        for byte in bytes.iter() {
            print_byte_to_stdout(*byte);
            let l1 = f1(*byte, p, N) as usize;
            let q = qq[l1];
            let l2 = f2(*byte, q, N) as usize;
            println!(" => l1 {l1} => q {q} => l2 {l2}");
        }
    }
}
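// Another illustrative sketch (not from the upstream suite) showing the byte layout
// produced by `try_new`: one `p` byte, then the `N` per-bucket `q` bytes, then the `N`
// keys permuted into their hash slots. It assumes the wrapped byte store is directly
// accessible from this module, as in the `Self(result)` constructor above.
#[cfg(test)]
mod layout_sketch {
    use super::*;

    #[test]
    fn try_new_layout_matches_find() {
        // Arbitrary distinct keys chosen for the example.
        let keys = b"abcdexyz";
        let n = keys.len();
        let map = PerfectByteHashMap::<Vec<u8>>::try_new(keys).unwrap();
        // Assumption: the inner store is the tuple field written by `Self(result)`.
        let data = map.0.as_slice();
        // Layout: [p, q_0, ..., q_(N-1), key_slot_0, ..., key_slot_(N-1)]
        assert_eq!(data.len(), 2 * n + 1);
        let (p, qq, permuted) = (data[0], &data[1..=n], &data[n + 1..]);
        for key in keys {
            // Each original key sits at the slot selected by the two-level hash.
            let l1 = f1(*key, p, n as u8) as usize;
            let l2 = f2(*key, qq[l1], n as u8) as usize;
            assert_eq!(permuted[l2], *key);
        }
    }
}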