• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of ICU4X. For terms of use, please see the file
2 // called LICENSE at the top level of the ICU4X source tree
3 // (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
4 
5 use core::cmp::Ordering;
6 
7 use super::super::branch_meta::BranchMeta;
8 use super::store::NonConstLengthsStack;
9 use super::store::TrieBuilderStore;
10 use crate::builder::bytestr::ByteStr;
11 use crate::byte_phf::PerfectByteHashMapCacheOwned;
12 use crate::error::ZeroTrieBuildError;
13 use crate::options::*;
14 use crate::varint;
15 use alloc::borrow::Cow;
16 use alloc::vec::Vec;
17 
/// A low-level builder for ZeroTrie. Supports all options.
///
/// The builder's `impl` (which requires `S: TrieBuilderStore`) constructs the trie by
/// prepending nodes to the front of `data`, starting from the last key and working backward.
pub(crate) struct ZeroTrieBuilder<S> {
    /// Serialized trie bytes under construction; nodes are pushed onto its front.
    data: S,
    /// Perfect-hash cache, queried with the key bytes of a branch node when PHF branches
    /// are enabled (presumably so identical key sets reuse a computed hash — TODO confirm).
    phf_cache: PerfectByteHashMapCacheOwned,
    /// Build options: case sensitivity, non-ASCII handling, PHF usage and capacity limits.
    options: ZeroTrieBuilderOptions,
}
24 
25 impl<S: TrieBuilderStore> ZeroTrieBuilder<S> {
26     /// Returns the trie data as a `Vec<u8>`.
to_bytes(&self) -> Vec<u8>27     pub fn to_bytes(&self) -> Vec<u8> {
28         self.data.atbs_to_bytes()
29     }
30 
31     /// Prepends a byte value to the front of the builder. If it is ASCII, an ASCII
32     /// node is prepended. If it is non-ASCII, if there is already a span node at
33     /// the front, we modify the span node to add the new byte; otherwise, we create
34     /// a new span node. Returns the delta in length, which is either 1 or 2.
prepend_ascii(&mut self, ascii: u8) -> Result<usize, ZeroTrieBuildError>35     fn prepend_ascii(&mut self, ascii: u8) -> Result<usize, ZeroTrieBuildError> {
36         if ascii <= 127 {
37             self.data.atbs_push_front(ascii);
38             Ok(1)
39         } else if matches!(self.options.ascii_mode, AsciiMode::BinarySpans) {
40             if let Some(old_front) = self.data.atbs_pop_front() {
41                 let old_byte_len = self.data.atbs_len() + 1;
42                 if old_front & 0b11100000 == 0b10100000 {
43                     // Extend an existing span
44                     // Unwrap OK: there is a varint at this location in the buffer
45                     #[allow(clippy::unwrap_used)]
46                     let old_span_size =
47                         varint::try_read_varint_meta3_from_tstore(old_front, &mut self.data)
48                             .unwrap();
49                     self.data.atbs_push_front(ascii);
50                     let varint_array = varint::write_varint_meta3(old_span_size + 1);
51                     self.data.atbs_extend_front(varint_array.as_slice());
52                     self.data.atbs_bitor_assign(0, 0b10100000);
53                     let new_byte_len = self.data.atbs_len();
54                     return Ok(new_byte_len - old_byte_len);
55                 } else {
56                     self.data.atbs_push_front(old_front);
57                 }
58             }
59             // Create a new span
60             self.data.atbs_push_front(ascii);
61             self.data.atbs_push_front(0b10100001);
62             Ok(2)
63         } else {
64             Err(ZeroTrieBuildError::NonAsciiError)
65         }
66     }
67 
68     /// Prepends a value node to the front of the builder. Returns the
69     /// delta in length, which depends on the size of the varint.
70     #[must_use]
prepend_value(&mut self, value: usize) -> usize71     fn prepend_value(&mut self, value: usize) -> usize {
72         let varint_array = varint::write_varint_meta3(value);
73         self.data.atbs_extend_front(varint_array.as_slice());
74         self.data.atbs_bitor_assign(0, 0b10000000);
75         varint_array.len()
76     }
77 
78     /// Prepends a branch node to the front of the builder. Returns the
79     /// delta in length, which depends on the size of the varint.
80     #[must_use]
prepend_branch(&mut self, value: usize) -> usize81     fn prepend_branch(&mut self, value: usize) -> usize {
82         let varint_array = varint::write_varint_meta2(value);
83         self.data.atbs_extend_front(varint_array.as_slice());
84         self.data.atbs_bitor_assign(0, 0b11000000);
85         varint_array.len()
86     }
87 
88     /// Prepends multiple arbitrary bytes to the front of the builder. Returns the
89     /// delta in length, which is the length of the slice.
90     #[must_use]
prepend_slice(&mut self, s: &[u8]) -> usize91     fn prepend_slice(&mut self, s: &[u8]) -> usize {
92         self.data.atbs_extend_front(s);
93         s.len()
94     }
95 
96     /// Builds a ZeroTrie from an iterator of bytes. It first collects and sorts the iterator.
from_bytes_iter<K: AsRef<[u8]>, I: IntoIterator<Item = (K, usize)>>( iter: I, options: ZeroTrieBuilderOptions, ) -> Result<Self, ZeroTrieBuildError>97     pub fn from_bytes_iter<K: AsRef<[u8]>, I: IntoIterator<Item = (K, usize)>>(
98         iter: I,
99         options: ZeroTrieBuilderOptions,
100     ) -> Result<Self, ZeroTrieBuildError> {
101         let items = Vec::<(K, usize)>::from_iter(iter);
102         let mut items = items
103             .iter()
104             .map(|(k, v)| (k.as_ref(), *v))
105             .collect::<Vec<(&[u8], usize)>>();
106         items.sort_by(|a, b| cmp_keys_values(options, *a, *b));
107         let ascii_str_slice = items.as_slice();
108         let byte_str_slice = ByteStr::from_byte_slice_with_value(ascii_str_slice);
109         Self::from_sorted_tuple_slice_impl(byte_str_slice, options)
110     }
111 
112     /// Builds a ZeroTrie with the given items and options. Assumes that the items are sorted,
113     /// except for a case-insensitive trie where the items are re-sorted.
114     ///
115     /// # Panics
116     ///
117     /// May panic if the items are not sorted.
from_sorted_tuple_slice( items: &[(&ByteStr, usize)], options: ZeroTrieBuilderOptions, ) -> Result<Self, ZeroTrieBuildError>118     pub fn from_sorted_tuple_slice(
119         items: &[(&ByteStr, usize)],
120         options: ZeroTrieBuilderOptions,
121     ) -> Result<Self, ZeroTrieBuildError> {
122         let mut items = Cow::Borrowed(items);
123         if matches!(options.case_sensitivity, CaseSensitivity::IgnoreCase) {
124             // We need to re-sort the items with our custom comparator.
125             items.to_mut().sort_by(|a, b| {
126                 cmp_keys_values(options, (a.0.as_bytes(), a.1), (b.0.as_bytes(), b.1))
127             });
128         }
129         Self::from_sorted_tuple_slice_impl(&items, options)
130     }
131 
132     /// Internal constructor that does not re-sort the items.
from_sorted_tuple_slice_impl( items: &[(&ByteStr, usize)], options: ZeroTrieBuilderOptions, ) -> Result<Self, ZeroTrieBuildError>133     fn from_sorted_tuple_slice_impl(
134         items: &[(&ByteStr, usize)],
135         options: ZeroTrieBuilderOptions,
136     ) -> Result<Self, ZeroTrieBuildError> {
137         for ab in items.windows(2) {
138             debug_assert!(cmp_keys_values(
139                 options,
140                 (ab[0].0.as_bytes(), ab[0].1),
141                 (ab[1].0.as_bytes(), ab[1].1)
142             )
143             .is_lt());
144         }
145         let mut result = Self {
146             data: S::atbs_new_empty(),
147             phf_cache: PerfectByteHashMapCacheOwned::new_empty(),
148             options,
149         };
150         let total_size = result.create(items)?;
151         debug_assert!(total_size == result.data.atbs_len());
152         Ok(result)
153     }
154 
155     /// The actual builder algorithm. For an explanation, see [`crate::builder`].
156     #[allow(clippy::unwrap_used)] // lots of indexing, but all indexes should be in range
create(&mut self, all_items: &[(&ByteStr, usize)]) -> Result<usize, ZeroTrieBuildError>157     fn create(&mut self, all_items: &[(&ByteStr, usize)]) -> Result<usize, ZeroTrieBuildError> {
158         let mut prefix_len = match all_items.last() {
159             Some(x) => x.0.len(),
160             // Empty slice:
161             None => return Ok(0),
162         };
163         // Initialize the main loop to point at the last string.
164         let mut lengths_stack = NonConstLengthsStack::new();
165         let mut i = all_items.len() - 1;
166         let mut j = all_items.len();
167         let mut current_len = 0;
168         // Start the main loop.
169         loop {
170             let item_i = all_items.get(i).unwrap();
171             let item_j = all_items.get(j - 1).unwrap();
172             debug_assert!(item_i.0.prefix_eq(item_j.0, prefix_len));
173             // Check if we need to add a value node here.
174             if item_i.0.len() == prefix_len {
175                 let len = self.prepend_value(item_i.1);
176                 current_len += len;
177             }
178             if prefix_len == 0 {
179                 // All done! Leave the main loop.
180                 break;
181             }
182             // Reduce the prefix length by 1 and recalculate i and j.
183             prefix_len -= 1;
184             let mut new_i = i;
185             let mut new_j = j;
186             let mut ascii_i = item_i.0.byte_at_or_panic(prefix_len);
187             let mut ascii_j = item_j.0.byte_at_or_panic(prefix_len);
188             debug_assert_eq!(ascii_i, ascii_j);
189             let key_ascii = ascii_i;
190             loop {
191                 if new_i == 0 {
192                     break;
193                 }
194                 let candidate = all_items.get(new_i - 1).unwrap().0;
195                 if candidate.len() < prefix_len {
196                     // Too short
197                     break;
198                 }
199                 if item_i.0.prefix_eq(candidate, prefix_len) {
200                     new_i -= 1;
201                 } else {
202                     break;
203                 }
204                 if candidate.len() == prefix_len {
205                     // A string that equals the prefix does not take part in the branch node.
206                     break;
207                 }
208                 let candidate = candidate.byte_at_or_panic(prefix_len);
209                 if candidate != ascii_i {
210                     ascii_i = candidate;
211                 }
212             }
213             loop {
214                 if new_j == all_items.len() {
215                     break;
216                 }
217                 let candidate = all_items.get(new_j).unwrap().0;
218                 if candidate.len() < prefix_len {
219                     // Too short
220                     break;
221                 }
222                 if item_j.0.prefix_eq(candidate, prefix_len) {
223                     new_j += 1;
224                 } else {
225                     break;
226                 }
227                 if candidate.len() == prefix_len {
228                     panic!("A shorter string should be earlier in the sequence");
229                 }
230                 let candidate = candidate.byte_at_or_panic(prefix_len);
231                 if candidate != ascii_j {
232                     ascii_j = candidate;
233                 }
234             }
235             // If there are no different bytes at this prefix level, we can add an ASCII or Span
236             // node and then continue to the next iteration of the main loop.
237             if ascii_i == key_ascii && ascii_j == key_ascii {
238                 let len = self.prepend_ascii(key_ascii)?;
239                 current_len += len;
240                 if matches!(self.options.case_sensitivity, CaseSensitivity::IgnoreCase)
241                     && i == new_i + 2
242                 {
243                     // This can happen if two strings were picked up, each with a different case
244                     return Err(ZeroTrieBuildError::MixedCase);
245                 }
246                 debug_assert!(
247                     i == new_i || i == new_i + 1,
248                     "only the exact prefix string can be picked up at this level: {}",
249                     key_ascii
250                 );
251                 i = new_i;
252                 debug_assert_eq!(j, new_j);
253                 continue;
254             }
255             // If i and j changed, we are a target of a branch node.
256             if ascii_j == key_ascii {
257                 // We are the _last_ target of a branch node.
258                 lengths_stack.push(BranchMeta {
259                     ascii: key_ascii,
260                     cumulative_length: current_len,
261                     local_length: current_len,
262                     count: 1,
263                 });
264             } else {
265                 // We are the _not the last_ target of a branch node.
266                 let BranchMeta {
267                     cumulative_length,
268                     count,
269                     ..
270                 } = lengths_stack.peek_or_panic();
271                 lengths_stack.push(BranchMeta {
272                     ascii: key_ascii,
273                     cumulative_length: cumulative_length + current_len,
274                     local_length: current_len,
275                     count: count + 1,
276                 });
277             }
278             if ascii_i != key_ascii {
279                 // We are _not the first_ target of a branch node.
280                 // Set the cursor to the previous string and continue the loop.
281                 j = i;
282                 i -= 1;
283                 prefix_len = all_items.get(i).unwrap().0.len();
284                 current_len = 0;
285                 continue;
286             }
287             // Branch (first)
288             // std::println!("lengths_stack: {lengths_stack:?}");
289             let (total_length, total_count) = {
290                 let BranchMeta {
291                     cumulative_length,
292                     count,
293                     ..
294                 } = lengths_stack.peek_or_panic();
295                 (cumulative_length, count)
296             };
297             let mut branch_metas = lengths_stack.pop_many_or_panic(total_count);
298             let original_keys = branch_metas.map_to_ascii_bytes();
299             if matches!(self.options.case_sensitivity, CaseSensitivity::IgnoreCase) {
300                 // Check to see if we have the same letter in two different cases
301                 let mut seen_ascii_alpha = [false; 26];
302                 for c in original_keys.as_const_slice().as_slice() {
303                     if c.is_ascii_alphabetic() {
304                         let i = (c.to_ascii_lowercase() - b'a') as usize;
305                         if seen_ascii_alpha[i] {
306                             return Err(ZeroTrieBuildError::MixedCase);
307                         } else {
308                             seen_ascii_alpha[i] = true;
309                         }
310                     }
311                 }
312             }
313             let use_phf = matches!(self.options.phf_mode, PhfMode::UsePhf);
314             let opt_phf_vec = if total_count > 15 && use_phf {
315                 let phf_vec = self
316                     .phf_cache
317                     .try_get_or_insert(original_keys.as_const_slice().as_slice().to_vec())?;
318                 // Put everything in order via bubble sort
319                 // Note: branch_metas is stored in reverse order (0 = last element)
320                 loop {
321                     let mut l = total_count - 1;
322                     let mut changes = 0;
323                     let mut start = 0;
324                     while l > 0 {
325                         let a = *branch_metas.as_const_slice().get_or_panic(l);
326                         let b = *branch_metas.as_const_slice().get_or_panic(l - 1);
327                         let a_idx = phf_vec.keys().iter().position(|x| x == &a.ascii).unwrap();
328                         let b_idx = phf_vec.keys().iter().position(|x| x == &b.ascii).unwrap();
329                         if a_idx > b_idx {
330                             // std::println!("{a:?} <=> {b:?} ({phf_vec:?})");
331                             self.data.atbs_swap_ranges(
332                                 start,
333                                 start + a.local_length,
334                                 start + a.local_length + b.local_length,
335                             );
336                             branch_metas = branch_metas.swap_or_panic(l - 1, l);
337                             start += b.local_length;
338                             changes += 1;
339                             // FIXME: fix the `length` field
340                         } else {
341                             start += a.local_length;
342                         }
343                         l -= 1;
344                     }
345                     if changes == 0 {
346                         break;
347                     }
348                 }
349                 Some(phf_vec)
350             } else {
351                 None
352             };
353             // Write out the offset table
354             current_len = total_length;
355             const USIZE_BITS: usize = core::mem::size_of::<usize>() * 8;
356             let w = (USIZE_BITS - (total_length.leading_zeros() as usize) - 1) / 8;
357             if w > 3 && matches!(self.options.capacity_mode, CapacityMode::Normal) {
358                 return Err(ZeroTrieBuildError::CapacityExceeded);
359             }
360             let mut k = 0;
361             while k <= w {
362                 self.data.atbs_prepend_n_zeros(total_count - 1);
363                 current_len += total_count - 1;
364                 let mut l = 0;
365                 let mut length_to_write = 0;
366                 while l < total_count {
367                     let BranchMeta { local_length, .. } = *branch_metas
368                         .as_const_slice()
369                         .get_or_panic(total_count - l - 1);
370                     let mut adjusted_length = length_to_write;
371                     let mut m = 0;
372                     while m < k {
373                         adjusted_length >>= 8;
374                         m += 1;
375                     }
376                     if l > 0 {
377                         self.data.atbs_bitor_assign(l - 1, adjusted_length as u8);
378                     }
379                     l += 1;
380                     length_to_write += local_length;
381                 }
382                 k += 1;
383             }
384             // Write out the lookup table
385             assert!(0 < total_count && total_count <= 256);
386             let branch_value = (w << 8) + (total_count & 0xff);
387             if let Some(phf_vec) = opt_phf_vec {
388                 self.data.atbs_extend_front(phf_vec.as_bytes());
389                 let phf_len = phf_vec.as_bytes().len();
390                 let branch_len = self.prepend_branch(branch_value);
391                 current_len += phf_len + branch_len;
392             } else {
393                 let search_len = self.prepend_slice(original_keys.as_slice());
394                 let branch_len = self.prepend_branch(branch_value);
395                 current_len += search_len + branch_len;
396             }
397             i = new_i;
398             j = new_j;
399         }
400         assert!(lengths_stack.is_empty());
401         Ok(current_len)
402     }
403 }
404 
cmp_keys_values( options: ZeroTrieBuilderOptions, a: (&[u8], usize), b: (&[u8], usize), ) -> Ordering405 fn cmp_keys_values(
406     options: ZeroTrieBuilderOptions,
407     a: (&[u8], usize),
408     b: (&[u8], usize),
409 ) -> Ordering {
410     if matches!(options.case_sensitivity, CaseSensitivity::Sensitive) {
411         a.0.cmp(b.0)
412     } else {
413         let a_iter = a.0.iter().map(|x| x.to_ascii_lowercase());
414         let b_iter = b.0.iter().map(|x| x.to_ascii_lowercase());
415         Iterator::cmp(a_iter, b_iter)
416     }
417     .then_with(|| a.1.cmp(&b.1))
418 }
419