use core::mem;

// The following ~400 lines of code exist for exactly one purpose, which is
// to optimize this code:
//
//     byte_slice.iter().position(|&b| b > 0x7F).unwrap_or(byte_slice.len())
//
// Yes... Overengineered is a word that comes to mind, but this is effectively
// a very similar problem to memchr, and virtually nobody has been able to
// resist optimizing the crap out of that (except for perhaps the BSD and MUSL
// folks). In particular, this routine makes a very common case (ASCII) very
// fast, which seems worth it. We do stop short of adding AVX variants of the
// code below in order to retain our sanity and also to avoid needing to deal
// with runtime target feature detection. RESIST!
//
// In order to understand the SIMD version below, it would be good to read this
// comment describing how my memchr routine works:
// https://github.com/BurntSushi/rust-memchr/blob/b0a29f267f4a7fad8ffcc8fe8377a06498202883/src/x86/sse2.rs#L19-L106
//
// The primary difference with memchr is that for ASCII, we can do a bit less
// work. In particular, we don't need to detect the presence of a specific
// byte, but rather, whether any byte has its most significant bit set. That
// means we can effectively skip the _mm_cmpeq_epi8 step and jump straight to
// _mm_movemask_epi8.

#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const USIZE_BYTES: usize = mem::size_of::<usize>();
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const FALLBACK_LOOP_SIZE: usize = 2 * USIZE_BYTES;

// This is a mask where the most significant bit of each byte in the usize
// is set. We test this bit to determine whether a character is ASCII or not.
// Namely, a single byte is regarded as an ASCII codepoint if and only if its
// most significant bit is not set.
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const ASCII_MASK_U64: u64 = 0x8080808080808080;
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const ASCII_MASK: usize = ASCII_MASK_U64 as usize;
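
// As a small illustration (an editorial sketch, not part of the original
// comments) of how this mask is used by the fallback below: read 8 bytes
// into a usize and AND it with ASCII_MASK. Any surviving bit marks a byte
// whose most significant bit is set, i.e. a non-ASCII byte. On a 64-bit
// little endian target:
//
//     let chunk = usize::from_le_bytes(*b"abc\xE2\x98\x83de");
//     let mask = chunk & ASCII_MASK;
//     // The lowest surviving bit belongs to byte index 3 (0xE2), so
//     // dividing trailing_zeros() by 8 recovers the position of the first
//     // non-ASCII byte.
//     assert_eq!(3, mask.trailing_zeros() as usize / 8);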

/// Returns the index of the first non-ASCII byte in the given slice.
///
/// If the slice contains only ASCII bytes, then the length of the slice is
/// returned.
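///
/// # Examples
///
/// A couple of illustrative cases (an editorial addition, not part of the
/// original documentation):
///
/// ```ignore
/// assert_eq!(3, first_non_ascii_byte(b"abc"));
/// assert_eq!(1, first_non_ascii_byte(b"a\xFFbc"));
/// ```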
pub fn first_non_ascii_byte(slice: &[u8]) -> usize {
    #[cfg(any(miri, not(target_arch = "x86_64")))]
    {
        first_non_ascii_byte_fallback(slice)
    }

    #[cfg(all(not(miri), target_arch = "x86_64"))]
    {
        first_non_ascii_byte_sse2(slice)
    }
}

#[cfg(any(test, miri, not(target_arch = "x86_64")))]
fn first_non_ascii_byte_fallback(slice: &[u8]) -> usize {
    let align = USIZE_BYTES - 1;
    let start_ptr = slice.as_ptr();
    let end_ptr = slice[slice.len()..].as_ptr();
    let mut ptr = start_ptr;

    unsafe {
        if slice.len() < USIZE_BYTES {
            return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr);
        }

        let chunk = read_unaligned_usize(ptr);
        let mask = chunk & ASCII_MASK;
        if mask != 0 {
            return first_non_ascii_byte_mask(mask);
        }

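        // Advance to the next usize-aligned address. Any bytes skipped here
        // were already checked by the unaligned read above.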
        ptr = ptr_add(ptr, USIZE_BYTES - (start_ptr as usize & align));
        debug_assert!(ptr > start_ptr);
        debug_assert!(ptr_sub(end_ptr, USIZE_BYTES) >= start_ptr);
        if slice.len() >= FALLBACK_LOOP_SIZE {
            while ptr <= ptr_sub(end_ptr, FALLBACK_LOOP_SIZE) {
                debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

                let a = *(ptr as *const usize);
                let b = *(ptr_add(ptr, USIZE_BYTES) as *const usize);
                if (a | b) & ASCII_MASK != 0 {
                    // What a kludge. We wrap the position finding code into
                    // a non-inlineable function, which makes the codegen in
                    // the tight loop above a bit better by avoiding a
                    // couple extra movs. We pay for it by two additional
                    // stores, but only in the case of finding a non-ASCII
                    // byte.
                    #[inline(never)]
                    unsafe fn findpos(
                        start_ptr: *const u8,
                        ptr: *const u8,
                    ) -> usize {
                        let a = *(ptr as *const usize);
                        let b = *(ptr_add(ptr, USIZE_BYTES) as *const usize);

                        let mut at = sub(ptr, start_ptr);
                        let maska = a & ASCII_MASK;
                        if maska != 0 {
                            return at + first_non_ascii_byte_mask(maska);
                        }

                        at += USIZE_BYTES;
                        let maskb = b & ASCII_MASK;
                        debug_assert!(maskb != 0);
                        return at + first_non_ascii_byte_mask(maskb);
                    }
                    return findpos(start_ptr, ptr);
                }
                ptr = ptr_add(ptr, FALLBACK_LOOP_SIZE);
            }
        }
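        // Whatever remains at this point (fewer than FALLBACK_LOOP_SIZE
        // bytes) gets checked one byte at a time.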
        first_non_ascii_byte_slow(start_ptr, end_ptr, ptr)
    }
}

#[cfg(all(not(miri), target_arch = "x86_64"))]
fn first_non_ascii_byte_sse2(slice: &[u8]) -> usize {
    use core::arch::x86_64::*;

    const VECTOR_SIZE: usize = mem::size_of::<__m128i>();
    const VECTOR_ALIGN: usize = VECTOR_SIZE - 1;
    const VECTOR_LOOP_SIZE: usize = 4 * VECTOR_SIZE;

    let start_ptr = slice.as_ptr();
    let end_ptr = slice[slice.len()..].as_ptr();
    let mut ptr = start_ptr;

    unsafe {
        if slice.len() < VECTOR_SIZE {
            return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr);
        }

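        // An unaligned load of the first VECTOR_SIZE bytes.
        // _mm_movemask_epi8 gathers the most significant bit of each of the
        // 16 bytes into a 16-bit mask, so a nonzero mask means at least one
        // non-ASCII byte and trailing_zeros gives its index.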
        let chunk = _mm_loadu_si128(ptr as *const __m128i);
        let mask = _mm_movemask_epi8(chunk);
        if mask != 0 {
            return mask.trailing_zeros() as usize;
        }

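        // Advance to the next VECTOR_SIZE-aligned address. Any bytes skipped
        // here were already covered by the unaligned load above.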
        ptr = ptr.add(VECTOR_SIZE - (start_ptr as usize & VECTOR_ALIGN));
        debug_assert!(ptr > start_ptr);
        debug_assert!(end_ptr.sub(VECTOR_SIZE) >= start_ptr);
        if slice.len() >= VECTOR_LOOP_SIZE {
            while ptr <= ptr_sub(end_ptr, VECTOR_LOOP_SIZE) {
                debug_assert_eq!(0, (ptr as usize) % VECTOR_SIZE);

                let a = _mm_load_si128(ptr as *const __m128i);
                let b = _mm_load_si128(ptr.add(VECTOR_SIZE) as *const __m128i);
                let c =
                    _mm_load_si128(ptr.add(2 * VECTOR_SIZE) as *const __m128i);
                let d =
                    _mm_load_si128(ptr.add(3 * VECTOR_SIZE) as *const __m128i);

                let or1 = _mm_or_si128(a, b);
                let or2 = _mm_or_si128(c, d);
                let or3 = _mm_or_si128(or1, or2);
                if _mm_movemask_epi8(or3) != 0 {
                    let mut at = sub(ptr, start_ptr);
                    let mask = _mm_movemask_epi8(a);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(b);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(c);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(d);
                    debug_assert!(mask != 0);
                    return at + mask.trailing_zeros() as usize;
                }
                ptr = ptr_add(ptr, VECTOR_LOOP_SIZE);
            }
        }
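        // Handle what's left a vector at a time using unaligned loads. Any
        // final remainder shorter than VECTOR_SIZE falls through to the
        // byte-at-a-time scan below.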
        while ptr <= end_ptr.sub(VECTOR_SIZE) {
            debug_assert!(sub(end_ptr, ptr) >= VECTOR_SIZE);

            let chunk = _mm_loadu_si128(ptr as *const __m128i);
            let mask = _mm_movemask_epi8(chunk);
            if mask != 0 {
                return sub(ptr, start_ptr) + mask.trailing_zeros() as usize;
            }
            ptr = ptr.add(VECTOR_SIZE);
        }
        first_non_ascii_byte_slow(start_ptr, end_ptr, ptr)
    }
}

#[inline(always)]
unsafe fn first_non_ascii_byte_slow(
    start_ptr: *const u8,
    end_ptr: *const u8,
    mut ptr: *const u8,
) -> usize {
    debug_assert!(start_ptr <= ptr);
    debug_assert!(ptr <= end_ptr);

    while ptr < end_ptr {
        if *ptr > 0x7F {
            return sub(ptr, start_ptr);
        }
        ptr = ptr.offset(1);
    }
    sub(end_ptr, start_ptr)
}

/// Compute the position of the first non-ASCII byte in the given mask.
///
/// The mask should be computed by `chunk & ASCII_MASK`, where `chunk` is
/// 8 contiguous bytes of the slice being checked where *at least* one of those
/// bytes is not an ASCII byte.
///
/// The position returned is always in the inclusive range [0, 7].
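///
/// For example (an illustrative case on a little endian target): a mask of
/// `0x0000_0000_8000_0000` has its lowest set bit at position 31, so the
/// position returned is `31 / 8 == 3`.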
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
fn first_non_ascii_byte_mask(mask: usize) -> usize {
    #[cfg(target_endian = "little")]
    {
        mask.trailing_zeros() as usize / 8
    }
    #[cfg(target_endian = "big")]
    {
        mask.leading_zeros() as usize / 8
    }
}

/// Increment the given pointer by the given amount.
unsafe fn ptr_add(ptr: *const u8, amt: usize) -> *const u8 {
    debug_assert!(amt < ::core::isize::MAX as usize);
    ptr.offset(amt as isize)
}

/// Decrement the given pointer by the given amount.
unsafe fn ptr_sub(ptr: *const u8, amt: usize) -> *const u8 {
    debug_assert!(amt < ::core::isize::MAX as usize);
    ptr.offset((amt as isize).wrapping_neg())
}

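/// Read a `usize` worth of bytes from `ptr` without requiring `ptr` to be
/// aligned, by copying the bytes into a local integer.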
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
unsafe fn read_unaligned_usize(ptr: *const u8) -> usize {
    use core::ptr;

    let mut n: usize = 0;
    ptr::copy_nonoverlapping(ptr, &mut n as *mut _ as *mut u8, USIZE_BYTES);
    n
}

/// Subtract `b` from `a` and return the difference. `a` should be greater than
/// or equal to `b`.
fn sub(a: *const u8, b: *const u8) -> usize {
    debug_assert!(a >= b);
    (a as usize) - (b as usize)
}

#[cfg(test)]
mod tests {
    use super::*;

    // Our testing approach here is to try and exhaustively test every case.
    // This includes the position at which a non-ASCII byte occurs in addition
    // to the alignment of the slice that we're searching.

    #[test]
    fn positive_fallback_forward() {
        for i in 0..517 {
            let s = "a".repeat(i);
            assert_eq!(
                i,
                first_non_ascii_byte_fallback(s.as_bytes()),
                "i: {:?}, len: {:?}, s: {:?}",
                i,
                s.len(),
                s
            );
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    #[cfg(not(miri))]
    fn positive_sse2_forward() {
        for i in 0..517 {
            let b = "a".repeat(i).into_bytes();
            assert_eq!(b.len(), first_non_ascii_byte_sse2(&b));
        }
    }

    #[test]
    #[cfg(not(miri))]
    fn negative_fallback_forward() {
        for i in 0..517 {
            for align in 0..65 {
                let mut s = "a".repeat(i);
                s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
                let s = s.get(align..).unwrap_or("");
                assert_eq!(
                    i.saturating_sub(align),
                    first_non_ascii_byte_fallback(s.as_bytes()),
                    "i: {:?}, align: {:?}, len: {:?}, s: {:?}",
                    i,
                    align,
                    s.len(),
                    s
                );
            }
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    #[cfg(not(miri))]
    fn negative_sse2_forward() {
        for i in 0..517 {
            for align in 0..65 {
                let mut s = "a".repeat(i);
                s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
                let s = s.get(align..).unwrap_or("");
                assert_eq!(
                    i.saturating_sub(align),
                    first_non_ascii_byte_sse2(s.as_bytes()),
                    "i: {:?}, align: {:?}, len: {:?}, s: {:?}",
                    i,
                    align,
                    s.len(),
                    s
                );
            }
        }
    }
}