// This is adapted from `fallback.rs` from rust-memchr. It's modified to return
// the 'inverse' of memchr, i.e. finding the first byte not in the provided
// set. This is simple for the 1-byte case.

use core::{cmp, usize};

const USIZE_BYTES: usize = core::mem::size_of::<usize>();

// The number of bytes to loop at in one iteration of memchr/memrchr.
const LOOP_SIZE: usize = 2 * USIZE_BYTES;

/// Repeat the given byte into a word size number. That is, every 8 bits
/// is equivalent to the given byte. For example, if `b` is `\x4E` or
/// `01001110` in binary, then the returned value on a 32-bit system would be:
/// `01001110_01001110_01001110_01001110`.
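///
/// This works because `usize::MAX / 255` is the word with the low bit of
/// every byte set (`0x01010101` on a 32-bit system), so multiplying it by
/// `b` copies `b` into each byte of the word.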
#[inline(always)]
fn repeat_byte(b: u8) -> usize {
    (b as usize) * (usize::MAX / 255)
}

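/// Return the first index not matching the byte `n1` in `haystack`, or
/// `None` if every byte matches. For example, `inv_memchr(b'a', b"aza")`
/// returns `Some(1)`.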
pub fn inv_memchr(n1: u8, haystack: &[u8]) -> Option<usize> {
    let vn1 = repeat_byte(n1);
    let confirm = |byte| byte != n1;
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    let align = USIZE_BYTES - 1;
    let start_ptr = haystack.as_ptr();

    unsafe {
        let end_ptr = haystack.as_ptr().add(haystack.len());
        let mut ptr = start_ptr;

        if haystack.len() < USIZE_BYTES {
            return forward_search(start_ptr, end_ptr, ptr, confirm);
        }

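        // Peek at the first word of the haystack. If any byte in it differs
        // from `n1`, the XOR against the repeated byte is nonzero and the
        // answer lies in this word, so find it with the byte-at-a-time scan.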
        let chunk = read_unaligned_usize(ptr);
        if (chunk ^ vn1) != 0 {
            return forward_search(start_ptr, end_ptr, ptr, confirm);
        }

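        // All bytes in the first word matched, so round `ptr` up to the next
        // word-aligned address. Any bytes skipped here were already checked
        // by the unaligned read above.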
        ptr = ptr.add(USIZE_BYTES - (start_ptr as usize & align));
        debug_assert!(ptr > start_ptr);
        debug_assert!(end_ptr.sub(USIZE_BYTES) >= start_ptr);
        while loop_size == LOOP_SIZE && ptr <= end_ptr.sub(loop_size) {
            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

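            // Check two words per iteration. A nonzero XOR means one of
            // these words contains a byte that differs from `n1`.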
            let a = *(ptr as *const usize);
            let b = *(ptr.add(USIZE_BYTES) as *const usize);
            let eqa = (a ^ vn1) != 0;
            let eqb = (b ^ vn1) != 0;
            if eqa || eqb {
                break;
            }
            ptr = ptr.add(LOOP_SIZE);
        }
        forward_search(start_ptr, end_ptr, ptr, confirm)
    }
}

/// Return the last index not matching the byte `n1` in `haystack`.
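/// Returns `None` if every byte matches. For example,
/// `inv_memrchr(b'a', b"zaz")` returns `Some(2)`.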
pub fn inv_memrchr(n1: u8, haystack: &[u8]) -> Option<usize> {
    let vn1 = repeat_byte(n1);
    let confirm = |byte| byte != n1;
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    let align = USIZE_BYTES - 1;
    let start_ptr = haystack.as_ptr();

    unsafe {
        let end_ptr = haystack.as_ptr().add(haystack.len());
        let mut ptr = end_ptr;

        if haystack.len() < USIZE_BYTES {
            return reverse_search(start_ptr, end_ptr, ptr, confirm);
        }

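        // Peek at the last word of the haystack. If any byte in it differs
        // from `n1`, the answer lies in this word, so find it with the
        // byte-at-a-time scan from the end.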
        let chunk = read_unaligned_usize(ptr.sub(USIZE_BYTES));
        if (chunk ^ vn1) != 0 {
            return reverse_search(start_ptr, end_ptr, ptr, confirm);
        }

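        // All bytes in the last word matched, so round `ptr` down to a
        // word-aligned address. Any bytes skipped here were already checked
        // by the unaligned read above.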
        ptr = ptr.sub(end_ptr as usize & align);
        debug_assert!(start_ptr <= ptr && ptr <= end_ptr);
        while loop_size == LOOP_SIZE && ptr >= start_ptr.add(loop_size) {
            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

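            // Check two words per iteration, walking backwards. A nonzero
            // XOR means one of these words contains a byte that differs
            // from `n1`.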
            let a = *(ptr.sub(2 * USIZE_BYTES) as *const usize);
            let b = *(ptr.sub(1 * USIZE_BYTES) as *const usize);
            let eqa = (a ^ vn1) != 0;
            let eqb = (b ^ vn1) != 0;
            if eqa || eqb {
                break;
            }
            ptr = ptr.sub(loop_size);
        }
        reverse_search(start_ptr, end_ptr, ptr, confirm)
    }
}

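/// Byte-at-a-time scan from `ptr` forwards to `end_ptr`, returning the offset
/// (relative to `start_ptr`) of the first byte for which `confirm` returns
/// true, or `None` if there is no such byte.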
#[inline(always)]
unsafe fn forward_search<F: Fn(u8) -> bool>(
    start_ptr: *const u8,
    end_ptr: *const u8,
    mut ptr: *const u8,
    confirm: F,
) -> Option<usize> {
    debug_assert!(start_ptr <= ptr);
    debug_assert!(ptr <= end_ptr);

    while ptr < end_ptr {
        if confirm(*ptr) {
            return Some(sub(ptr, start_ptr));
        }
        ptr = ptr.offset(1);
    }
    None
}

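/// Byte-at-a-time scan from `ptr` backwards to `start_ptr`, returning the
/// offset (relative to `start_ptr`) of the last byte before `ptr` for which
/// `confirm` returns true, or `None` if there is no such byte.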
#[inline(always)]
unsafe fn reverse_search<F: Fn(u8) -> bool>(
    start_ptr: *const u8,
    end_ptr: *const u8,
    mut ptr: *const u8,
    confirm: F,
) -> Option<usize> {
    debug_assert!(start_ptr <= ptr);
    debug_assert!(ptr <= end_ptr);

    while ptr > start_ptr {
        ptr = ptr.offset(-1);
        if confirm(*ptr) {
            return Some(sub(ptr, start_ptr));
        }
    }
    None
}

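/// Read a `usize`-sized chunk from `ptr` without requiring alignment.
///
/// Safety: `ptr` must be valid for a read of `USIZE_BYTES` bytes.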
unsafe fn read_unaligned_usize(ptr: *const u8) -> usize {
    (ptr as *const usize).read_unaligned()
}

/// Subtract `b` from `a` and return the difference. `a` should be greater than
/// or equal to `b`.
fn sub(a: *const u8, b: *const u8) -> usize {
    debug_assert!(a >= b);
    (a as usize) - (b as usize)
}

/// Safe wrapper around `forward_search`
#[inline]
pub(crate) fn forward_search_bytes<F: Fn(u8) -> bool>(
    s: &[u8],
    confirm: F,
) -> Option<usize> {
    unsafe {
        let start = s.as_ptr();
        let end = start.add(s.len());
        forward_search(start, end, start, confirm)
    }
}

/// Safe wrapper around `reverse_search`
#[inline]
pub(crate) fn reverse_search_bytes<F: Fn(u8) -> bool>(
    s: &[u8],
    confirm: F,
) -> Option<usize> {
    unsafe {
        let start = s.as_ptr();
        let end = start.add(s.len());
        reverse_search(start, end, end, confirm)
    }
}

#[cfg(all(test, feature = "std"))]
mod tests {
    use super::{inv_memchr, inv_memrchr};

    // Search string, search byte, inv_memchr result, inv_memrchr result.
    // These are expanded into a much larger set of tests in `build_tests`.
    const TESTS: &[(&[u8], u8, usize, usize)] = &[
        (b"z", b'a', 0, 0),
        (b"zz", b'a', 0, 1),
        (b"aza", b'a', 1, 1),
        (b"zaz", b'a', 0, 2),
        (b"zza", b'a', 0, 1),
        (b"zaa", b'a', 0, 0),
        (b"zzz", b'a', 0, 2),
    ];

    type TestCase = (Vec<u8>, u8, Option<(usize, usize)>);

    fn build_tests() -> Vec<TestCase> {
        #[cfg(not(miri))]
        const MAX_PER: usize = 515;
        #[cfg(miri)]
        const MAX_PER: usize = 10;

        let mut result = vec![];
        for &(search, byte, fwd_pos, rev_pos) in TESTS {
            result.push((search.to_vec(), byte, Some((fwd_pos, rev_pos))));
            for i in 1..MAX_PER {
                // add a bunch of copies of the search byte to the end.
                let mut suffixed: Vec<u8> = search.into();
                suffixed.extend(std::iter::repeat(byte).take(i));
                result.push((suffixed, byte, Some((fwd_pos, rev_pos))));

                // add a bunch of copies of the search byte to the start.
                let mut prefixed: Vec<u8> =
                    std::iter::repeat(byte).take(i).collect();
                prefixed.extend(search);
                result.push((
                    prefixed,
                    byte,
                    Some((fwd_pos + i, rev_pos + i)),
                ));

                // add a bunch of copies of the search byte to both ends.
                let mut surrounded: Vec<u8> =
                    std::iter::repeat(byte).take(i).collect();
                surrounded.extend(search);
                surrounded.extend(std::iter::repeat(byte).take(i));
                result.push((
                    surrounded,
                    byte,
                    Some((fwd_pos + i, rev_pos + i)),
                ));
            }
        }

        // build non-matching tests for several sizes
        for i in 0..MAX_PER {
            result.push((
                std::iter::repeat(b'\0').take(i).collect(),
                b'\0',
                None,
            ));
        }

        result
    }

    #[test]
    fn test_inv_memchr() {
        use crate::{ByteSlice, B};

        #[cfg(not(miri))]
        const MAX_OFFSET: usize = 130;
        #[cfg(miri)]
        const MAX_OFFSET: usize = 13;

        for (search, byte, matching) in build_tests() {
            assert_eq!(
                inv_memchr(byte, &search),
                matching.map(|m| m.0),
                "inv_memchr when searching for {:?} in {:?}",
                byte as char,
                // better printing
                B(&search).as_bstr(),
            );
            assert_eq!(
                inv_memrchr(byte, &search),
                matching.map(|m| m.1),
                "inv_memrchr when searching for {:?} in {:?}",
                byte as char,
                // better printing
                B(&search).as_bstr(),
            );
            // Test a rather large number of offsets for potential alignment
            // issues.
            for offset in 1..MAX_OFFSET {
                if offset >= search.len() {
                    break;
                }
                // If this would cause us to shift the results off the end,
                // skip it so that we don't have to recompute them.
                if let Some((f, r)) = matching {
                    if offset > f || offset > r {
                        break;
                    }
                }
                let realigned = &search[offset..];

                let forward_pos = matching.map(|m| m.0 - offset);
                let reverse_pos = matching.map(|m| m.1 - offset);

                assert_eq!(
                    inv_memchr(byte, &realigned),
                    forward_pos,
                    "inv_memchr when searching (realigned by {}) for {:?} in {:?}",
                    offset,
                    byte as char,
                    realigned.as_bstr(),
                );
                assert_eq!(
                    inv_memrchr(byte, &realigned),
                    reverse_pos,
                    "inv_memrchr when searching (realigned by {}) for {:?} in {:?}",
                    offset,
                    byte as char,
                    realigned.as_bstr(),
                );
            }
        }
    }
}