// This file is part of ICU4X. For terms of use, please see the file
// called LICENSE at the top level of the ICU4X source tree
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).

use core::cmp::Ordering;

use super::super::branch_meta::BranchMeta;
use super::store::NonConstLengthsStack;
use super::store::TrieBuilderStore;
use crate::builder::bytestr::ByteStr;
use crate::byte_phf::PerfectByteHashMapCacheOwned;
use crate::error::ZeroTrieBuildError;
use crate::options::*;
use crate::varint;
use alloc::borrow::Cow;
use alloc::vec::Vec;

/// A low-level builder for ZeroTrie. Supports all options.
pub(crate) struct ZeroTrieBuilder<S> {
    data: S,
    phf_cache: PerfectByteHashMapCacheOwned,
    options: ZeroTrieBuilderOptions,
}

impl<S: TrieBuilderStore> ZeroTrieBuilder<S> {
    /// Returns the trie data as a `Vec<u8>`.
    pub fn to_bytes(&self) -> Vec<u8> {
        self.data.atbs_to_bytes()
    }

    /// Prepends a byte to the front of the builder. If the byte is ASCII, an ASCII
    /// node is prepended. If it is non-ASCII and there is already a span node at the
    /// front, the span node is extended to cover the new byte; otherwise, a new span
    /// node is created. Non-ASCII bytes are accepted only when `AsciiMode::BinarySpans`
    /// is enabled. Returns the delta in length, which is either 1 or 2.
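    ///
    /// Illustrative sketch (assuming the span encoding where the lead byte is `0b101`
    /// plus a small length in the low bits): prepending the non-ASCII byte 0xC3 to an
    /// empty builder yields `[0b1010_0001, 0xC3]`, a span of length 1 (delta 2);
    /// prepending 0xE2 in front of that span extends it to `[0b1010_0010, 0xE2, 0xC3]`
    /// (delta 1).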
    fn prepend_ascii(&mut self, ascii: u8) -> Result<usize, ZeroTrieBuildError> {
        if ascii <= 127 {
            self.data.atbs_push_front(ascii);
            Ok(1)
        } else if matches!(self.options.ascii_mode, AsciiMode::BinarySpans) {
            if let Some(old_front) = self.data.atbs_pop_front() {
                let old_byte_len = self.data.atbs_len() + 1;
                if old_front & 0b11100000 == 0b10100000 {
                    // Extend an existing span
                    // Unwrap OK: there is a varint at this location in the buffer
                    #[allow(clippy::unwrap_used)]
                    let old_span_size =
                        varint::try_read_varint_meta3_from_tstore(old_front, &mut self.data)
                            .unwrap();
                    self.data.atbs_push_front(ascii);
                    let varint_array = varint::write_varint_meta3(old_span_size + 1);
                    self.data.atbs_extend_front(varint_array.as_slice());
                    self.data.atbs_bitor_assign(0, 0b10100000);
                    let new_byte_len = self.data.atbs_len();
                    return Ok(new_byte_len - old_byte_len);
                } else {
                    self.data.atbs_push_front(old_front);
                }
            }
            // Create a new span
            self.data.atbs_push_front(ascii);
            self.data.atbs_push_front(0b10100001);
            Ok(2)
        } else {
            Err(ZeroTrieBuildError::NonAsciiError)
        }
    }

    /// Prepends a value node to the front of the builder. Returns the
    /// delta in length, which depends on the size of the varint.
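    ///
    /// Illustrative sketch (assuming small values fit in the lead byte of the meta3
    /// varint): prepending the value 5 writes the single byte `0b0000_0101`, which is
    /// then OR'ed with the value-node marker `0b1000_0000` to give `0b1000_0101`, so
    /// the returned delta is 1. Larger values add continuation bytes.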
    #[must_use]
    fn prepend_value(&mut self, value: usize) -> usize {
        let varint_array = varint::write_varint_meta3(value);
        self.data.atbs_extend_front(varint_array.as_slice());
        self.data.atbs_bitor_assign(0, 0b10000000);
        varint_array.len()
    }

    /// Prepends a branch node to the front of the builder. Returns the
    /// delta in length, which depends on the size of the varint.
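    ///
    /// Illustrative sketch (assuming small values fit in the lead byte of the meta2
    /// varint): for a branch value of 3, one byte is written and OR'ed with the
    /// branch-node marker `0b1100_0000`, so the returned delta is 1.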
    #[must_use]
    fn prepend_branch(&mut self, value: usize) -> usize {
        let varint_array = varint::write_varint_meta2(value);
        self.data.atbs_extend_front(varint_array.as_slice());
        self.data.atbs_bitor_assign(0, 0b11000000);
        varint_array.len()
    }

    /// Prepends multiple arbitrary bytes to the front of the builder. Returns the
    /// delta in length, which is the length of the slice.
    #[must_use]
    fn prepend_slice(&mut self, s: &[u8]) -> usize {
        self.data.atbs_extend_front(s);
        s.len()
    }

    /// Builds a ZeroTrie from an iterator of (byte string, value) pairs. The items are
    /// first collected and sorted.
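    ///
    /// For example, an iterator yielding `[("abc", 0), ("ab", 1)]` is collected into a
    /// `Vec`, sorted to `[("ab", 1), ("abc", 0)]`, and then handed to
    /// `Self::from_sorted_tuple_slice_impl`.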
    pub fn from_bytes_iter<K: AsRef<[u8]>, I: IntoIterator<Item = (K, usize)>>(
        iter: I,
        options: ZeroTrieBuilderOptions,
    ) -> Result<Self, ZeroTrieBuildError> {
        let items = Vec::<(K, usize)>::from_iter(iter);
        let mut items = items
            .iter()
            .map(|(k, v)| (k.as_ref(), *v))
            .collect::<Vec<(&[u8], usize)>>();
        items.sort_by(|a, b| cmp_keys_values(options, *a, *b));
        let ascii_str_slice = items.as_slice();
        let byte_str_slice = ByteStr::from_byte_slice_with_value(ascii_str_slice);
        Self::from_sorted_tuple_slice_impl(byte_str_slice, options)
    }

    /// Builds a ZeroTrie with the given items and options. Assumes that the items are sorted,
    /// except for a case-insensitive trie where the items are re-sorted.
    ///
    /// # Panics
    ///
    /// May panic if the items are not sorted.
    pub fn from_sorted_tuple_slice(
        items: &[(&ByteStr, usize)],
        options: ZeroTrieBuilderOptions,
    ) -> Result<Self, ZeroTrieBuildError> {
        let mut items = Cow::Borrowed(items);
        if matches!(options.case_sensitivity, CaseSensitivity::IgnoreCase) {
            // We need to re-sort the items with our custom comparator.
            items.to_mut().sort_by(|a, b| {
                cmp_keys_values(options, (a.0.as_bytes(), a.1), (b.0.as_bytes(), b.1))
            });
        }
        Self::from_sorted_tuple_slice_impl(&items, options)
    }

    /// Internal constructor that does not re-sort the items.
    fn from_sorted_tuple_slice_impl(
        items: &[(&ByteStr, usize)],
        options: ZeroTrieBuilderOptions,
    ) -> Result<Self, ZeroTrieBuildError> {
        for ab in items.windows(2) {
            debug_assert!(cmp_keys_values(
                options,
                (ab[0].0.as_bytes(), ab[0].1),
                (ab[1].0.as_bytes(), ab[1].1)
            )
            .is_lt());
        }
        let mut result = Self {
            data: S::atbs_new_empty(),
            phf_cache: PerfectByteHashMapCacheOwned::new_empty(),
            options,
        };
        let total_size = result.create(items)?;
        debug_assert!(total_size == result.data.atbs_len());
        Ok(result)
    }

    /// The actual builder algorithm. For an explanation, see [`crate::builder`].
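    ///
    /// Rough shape (a sketch, not a specification): the items are processed back to
    /// front, prepending bytes to `data`. For the sorted items `["ab", "abc", "abd"]`,
    /// the builder first writes the value for `"abd"`, then the value for `"abc"`,
    /// joins the two under a branch node keyed on `{c, d}`, then prepends the value for
    /// `"ab"` and finally the shared prefix bytes `b` and `a` as ASCII nodes, so the
    /// finished trie reads front to back as
    /// `a, b, value("ab"), branch{c, d}, value("abc"), value("abd")`.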
    #[allow(clippy::unwrap_used)] // lots of indexing, but all indexes should be in range
    fn create(&mut self, all_items: &[(&ByteStr, usize)]) -> Result<usize, ZeroTrieBuildError> {
        let mut prefix_len = match all_items.last() {
            Some(x) => x.0.len(),
            // Empty slice:
            None => return Ok(0),
        };
        // Initialize the main loop to point at the last string.
        let mut lengths_stack = NonConstLengthsStack::new();
        let mut i = all_items.len() - 1;
        let mut j = all_items.len();
        let mut current_len = 0;
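        // Loop invariant (sketch): `all_items[i..j]` all share their first `prefix_len`
        // bytes, and the front `current_len` bytes of `data` encode the subtree being
        // built for that shared prefix; lengths of already-finished sibling subtrees
        // are tracked on `lengths_stack`.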
        // Start the main loop.
        loop {
            let item_i = all_items.get(i).unwrap();
            let item_j = all_items.get(j - 1).unwrap();
            debug_assert!(item_i.0.prefix_eq(item_j.0, prefix_len));
            // Check if we need to add a value node here.
            if item_i.0.len() == prefix_len {
                let len = self.prepend_value(item_i.1);
                current_len += len;
            }
            if prefix_len == 0 {
                // All done! Leave the main loop.
                break;
            }
            // Reduce the prefix length by 1 and recalculate i and j.
            prefix_len -= 1;
            let mut new_i = i;
            let mut new_j = j;
            let mut ascii_i = item_i.0.byte_at_or_panic(prefix_len);
            let mut ascii_j = item_j.0.byte_at_or_panic(prefix_len);
            debug_assert_eq!(ascii_i, ascii_j);
            let key_ascii = ascii_i;
            loop {
                if new_i == 0 {
                    break;
                }
                let candidate = all_items.get(new_i - 1).unwrap().0;
                if candidate.len() < prefix_len {
                    // Too short
                    break;
                }
                if item_i.0.prefix_eq(candidate, prefix_len) {
                    new_i -= 1;
                } else {
                    break;
                }
                if candidate.len() == prefix_len {
                    // A string that equals the prefix does not take part in the branch node.
                    break;
                }
                let candidate = candidate.byte_at_or_panic(prefix_len);
                if candidate != ascii_i {
                    ascii_i = candidate;
                }
            }
            loop {
                if new_j == all_items.len() {
                    break;
                }
                let candidate = all_items.get(new_j).unwrap().0;
                if candidate.len() < prefix_len {
                    // Too short
                    break;
                }
                if item_j.0.prefix_eq(candidate, prefix_len) {
                    new_j += 1;
                } else {
                    break;
                }
                if candidate.len() == prefix_len {
                    panic!("A shorter string should be earlier in the sequence");
                }
                let candidate = candidate.byte_at_or_panic(prefix_len);
                if candidate != ascii_j {
                    ascii_j = candidate;
                }
            }
            // If there are no different bytes at this prefix level, we can add an ASCII or Span
            // node and then continue to the next iteration of the main loop.
            if ascii_i == key_ascii && ascii_j == key_ascii {
                let len = self.prepend_ascii(key_ascii)?;
                current_len += len;
                if matches!(self.options.case_sensitivity, CaseSensitivity::IgnoreCase)
                    && i == new_i + 2
                {
                    // This can happen if two strings were picked up, each with a different case
                    return Err(ZeroTrieBuildError::MixedCase);
                }
                debug_assert!(
                    i == new_i || i == new_i + 1,
                    "only the exact prefix string can be picked up at this level: {}",
                    key_ascii
                );
                i = new_i;
                debug_assert_eq!(j, new_j);
                continue;
            }
            // If i and j changed, we are a target of a branch node.
            if ascii_j == key_ascii {
                // We are the _last_ target of a branch node.
                lengths_stack.push(BranchMeta {
                    ascii: key_ascii,
                    cumulative_length: current_len,
                    local_length: current_len,
                    count: 1,
                });
            } else {
                // We are _not the last_ target of a branch node.
                let BranchMeta {
                    cumulative_length,
                    count,
                    ..
                } = lengths_stack.peek_or_panic();
                lengths_stack.push(BranchMeta {
                    ascii: key_ascii,
                    cumulative_length: cumulative_length + current_len,
                    local_length: current_len,
                    count: count + 1,
                });
            }
            if ascii_i != key_ascii {
                // We are _not the first_ target of a branch node.
                // Set the cursor to the previous string and continue the loop.
                j = i;
                i -= 1;
                prefix_len = all_items.get(i).unwrap().0.len();
                current_len = 0;
                continue;
            }
            // Branch (first)
            // std::println!("lengths_stack: {lengths_stack:?}");
            let (total_length, total_count) = {
                let BranchMeta {
                    cumulative_length,
                    count,
                    ..
                } = lengths_stack.peek_or_panic();
                (cumulative_length, count)
            };
            let mut branch_metas = lengths_stack.pop_many_or_panic(total_count);
            let original_keys = branch_metas.map_to_ascii_bytes();
            if matches!(self.options.case_sensitivity, CaseSensitivity::IgnoreCase) {
                // Check to see if we have the same letter in two different cases
                let mut seen_ascii_alpha = [false; 26];
                for c in original_keys.as_const_slice().as_slice() {
                    if c.is_ascii_alphabetic() {
                        let i = (c.to_ascii_lowercase() - b'a') as usize;
                        if seen_ascii_alpha[i] {
                            return Err(ZeroTrieBuildError::MixedCase);
                        } else {
                            seen_ascii_alpha[i] = true;
                        }
                    }
                }
            }
            let use_phf = matches!(self.options.phf_mode, PhfMode::UsePhf);
            let opt_phf_vec = if total_count > 15 && use_phf {
                let phf_vec = self
                    .phf_cache
                    .try_get_or_insert(original_keys.as_const_slice().as_slice().to_vec())?;
                // Put everything in order via bubble sort
                // Note: branch_metas is stored in reverse order (0 = last element)
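                // Sketch of one swap, assuming `atbs_swap_ranges(a, b, c)` exchanges the
                // adjacent byte ranges [a, b) and [b, c): if the target at `start`
                // (spanning `a.local_length` bytes) hashes after the following target
                // (spanning `b.local_length` bytes), the two ranges trade places so that
                // the order of the data matches the PHF key order.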
                loop {
                    let mut l = total_count - 1;
                    let mut changes = 0;
                    let mut start = 0;
                    while l > 0 {
                        let a = *branch_metas.as_const_slice().get_or_panic(l);
                        let b = *branch_metas.as_const_slice().get_or_panic(l - 1);
                        let a_idx = phf_vec.keys().iter().position(|x| x == &a.ascii).unwrap();
                        let b_idx = phf_vec.keys().iter().position(|x| x == &b.ascii).unwrap();
                        if a_idx > b_idx {
                            // std::println!("{a:?} <=> {b:?} ({phf_vec:?})");
                            self.data.atbs_swap_ranges(
                                start,
                                start + a.local_length,
                                start + a.local_length + b.local_length,
                            );
                            branch_metas = branch_metas.swap_or_panic(l - 1, l);
                            start += b.local_length;
                            changes += 1;
                            // FIXME: fix the `length` field
                        } else {
                            start += a.local_length;
                        }
                        l -= 1;
                    }
                    if changes == 0 {
                        break;
                    }
                }
                Some(phf_vec)
            } else {
                None
            };
            // Write out the offset table
            current_len = total_length;
            const USIZE_BITS: usize = core::mem::size_of::<usize>() * 8;
            let w = (USIZE_BITS - (total_length.leading_zeros() as usize) - 1) / 8;
            if w > 3 && matches!(self.options.capacity_mode, CapacityMode::Normal) {
                return Err(ZeroTrieBuildError::CapacityExceeded);
            }
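            // Illustrative example: with three targets of local lengths [4, 7, 9], the
            // total of 20 fits in one byte, so w == 0 and a single pass prepends the two
            // offset bytes [4, 11] (cumulative lengths of all targets but the last).
            // Wider totals repeat the pass, filling in the higher-order bytes of each
            // offset.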
            let mut k = 0;
            while k <= w {
                self.data.atbs_prepend_n_zeros(total_count - 1);
                current_len += total_count - 1;
                let mut l = 0;
                let mut length_to_write = 0;
                while l < total_count {
                    let BranchMeta { local_length, .. } = *branch_metas
                        .as_const_slice()
                        .get_or_panic(total_count - l - 1);
                    let mut adjusted_length = length_to_write;
                    let mut m = 0;
                    while m < k {
                        adjusted_length >>= 8;
                        m += 1;
                    }
                    if l > 0 {
                        self.data.atbs_bitor_assign(l - 1, adjusted_length as u8);
                    }
                    l += 1;
                    length_to_write += local_length;
                }
                k += 1;
            }
            // Write out the lookup table
            assert!(0 < total_count && total_count <= 256);
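            // The branch head value packs the extra offset-table width `w` above bit 8
            // and the target count in the low 8 bits; a count of 256 is stored as 0,
            // which is unambiguous because a branch always has at least one target.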
            let branch_value = (w << 8) + (total_count & 0xff);
            if let Some(phf_vec) = opt_phf_vec {
                self.data.atbs_extend_front(phf_vec.as_bytes());
                let phf_len = phf_vec.as_bytes().len();
                let branch_len = self.prepend_branch(branch_value);
                current_len += phf_len + branch_len;
            } else {
                let search_len = self.prepend_slice(original_keys.as_slice());
                let branch_len = self.prepend_branch(branch_value);
                current_len += search_len + branch_len;
            }
            i = new_i;
            j = new_j;
        }
        assert!(lengths_stack.is_empty());
        Ok(current_len)
    }
}

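/// Compares two (key, value) pairs in the builder's sort order: keys are compared
/// bytewise, lowercased first when the trie is case-insensitive, with the value as a
/// tie-breaker.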
fn cmp_keys_values(
    options: ZeroTrieBuilderOptions,
    a: (&[u8], usize),
    b: (&[u8], usize),
) -> Ordering {
    if matches!(options.case_sensitivity, CaseSensitivity::Sensitive) {
        a.0.cmp(b.0)
    } else {
        let a_iter = a.0.iter().map(|x| x.to_ascii_lowercase());
        let b_iter = b.0.iter().map(|x| x.to_ascii_lowercase());
        Iterator::cmp(a_iter, b_iter)
    }
    .then_with(|| a.1.cmp(&b.1))
}