1 use core::mem;
2
// The following ~400 lines of code exist for exactly one purpose, which is
4 // to optimize this code:
5 //
6 // byte_slice.iter().position(|&b| b > 0x7F).unwrap_or(byte_slice.len())
7 //
8 // Yes... Overengineered is a word that comes to mind, but this is effectively
9 // a very similar problem to memchr, and virtually nobody has been able to
10 // resist optimizing the crap out of that (except for perhaps the BSD and MUSL
11 // folks). In particular, this routine makes a very common case (ASCII) very
12 // fast, which seems worth it. We do stop short of adding AVX variants of the
13 // code below in order to retain our sanity and also to avoid needing to deal
14 // with runtime target feature detection. RESIST!
15 //
16 // In order to understand the SIMD version below, it would be good to read this
17 // comment describing how my memchr routine works:
18 // https://github.com/BurntSushi/rust-memchr/blob/b0a29f267f4a7fad8ffcc8fe8377a06498202883/src/x86/sse2.rs#L19-L106
19 //
20 // The primary difference with memchr is that for ASCII, we can do a bit less
21 // work. In particular, we don't need to detect the presence of a specific
22 // byte, but rather, whether any byte has its most significant bit set. That
23 // means we can effectively skip the _mm_cmpeq_epi8 step and jump straight to
24 // _mm_movemask_epi8.
25
// These constants are only needed by the fallback implementation, which is
// compiled on non-x86_64 targets and, for testing purposes, under cfg(test).
#[cfg(any(test, not(target_arch = "x86_64")))]
const USIZE_BYTES: usize = mem::size_of::<usize>();
// The fallback's main loop processes two words per iteration.
#[cfg(any(test, not(target_arch = "x86_64")))]
const FALLBACK_LOOP_SIZE: usize = 2 * USIZE_BYTES;

// This is a mask where the most significant bit of each byte in the usize
// is set. We test this bit to determine whether a character is ASCII or not.
// Namely, a single byte is regarded as an ASCII codepoint if and only if its
// most significant bit is not set.
#[cfg(any(test, not(target_arch = "x86_64")))]
const ASCII_MASK_U64: u64 = 0x8080808080808080;
// Truncating the u64 mask to usize keeps one 0x80 per byte regardless of
// pointer width (e.g. only the low four bytes survive on 32-bit targets).
#[cfg(any(test, not(target_arch = "x86_64")))]
const ASCII_MASK: usize = ASCII_MASK_U64 as usize;
39
40 /// Returns the index of the first non ASCII byte in the given slice.
41 ///
42 /// If slice only contains ASCII bytes, then the length of the slice is
43 /// returned.
first_non_ascii_byte(slice: &[u8]) -> usize44 pub fn first_non_ascii_byte(slice: &[u8]) -> usize {
45 #[cfg(not(target_arch = "x86_64"))]
46 {
47 first_non_ascii_byte_fallback(slice)
48 }
49
50 #[cfg(target_arch = "x86_64")]
51 {
52 first_non_ascii_byte_sse2(slice)
53 }
54 }
55
/// Portable (non-SIMD) implementation of `first_non_ascii_byte`.
///
/// Scans the slice one `usize`-sized word at a time (two per main-loop
/// iteration), using `ASCII_MASK` to test whether any byte in a word has
/// its most significant bit set. Returns the index of the first such byte,
/// or `slice.len()` if every byte is ASCII.
#[cfg(any(test, not(target_arch = "x86_64")))]
fn first_non_ascii_byte_fallback(slice: &[u8]) -> usize {
    let align = USIZE_BYTES - 1;
    let start_ptr = slice.as_ptr();
    // One-past-the-end pointer, obtained from an empty tail slice so it is
    // derived from the slice itself rather than raw arithmetic.
    let end_ptr = slice[slice.len()..].as_ptr();
    let mut ptr = start_ptr;

    unsafe {
        // Too short for even one word-sized read: scan byte by byte.
        if slice.len() < USIZE_BYTES {
            return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr);
        }

        // Check the first (possibly unaligned) word up front. This also
        // covers all the bytes skipped over by the alignment bump below.
        let chunk = read_unaligned_usize(ptr);
        let mask = chunk & ASCII_MASK;
        if mask != 0 {
            return first_non_ascii_byte_mask(mask);
        }

        // Advance to the next word boundary; the bytes in between were
        // already checked by the unaligned read above.
        ptr = ptr_add(ptr, USIZE_BYTES - (start_ptr as usize & align));
        debug_assert!(ptr > start_ptr);
        debug_assert!(ptr_sub(end_ptr, USIZE_BYTES) >= start_ptr);
        if slice.len() >= FALLBACK_LOOP_SIZE {
            // Main loop: check two aligned words per iteration with a
            // single combined mask test in the common (all-ASCII) case.
            while ptr <= ptr_sub(end_ptr, FALLBACK_LOOP_SIZE) {
                debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

                let a = *(ptr as *const usize);
                let b = *(ptr_add(ptr, USIZE_BYTES) as *const usize);
                if (a | b) & ASCII_MASK != 0 {
                    // What a kludge. We wrap the position finding code into
                    // a non-inlineable function, which makes the codegen in
                    // the tight loop above a bit better by avoiding a
                    // couple extra movs. We pay for it by two additional
                    // stores, but only in the case of finding a non-ASCII
                    // byte.
                    #[inline(never)]
                    unsafe fn findpos(
                        start_ptr: *const u8,
                        ptr: *const u8,
                    ) -> usize {
                        // Re-load both words and locate the offending byte.
                        let a = *(ptr as *const usize);
                        let b = *(ptr_add(ptr, USIZE_BYTES) as *const usize);

                        let mut at = sub(ptr, start_ptr);
                        let maska = a & ASCII_MASK;
                        if maska != 0 {
                            return at + first_non_ascii_byte_mask(maska);
                        }

                        // If the first word was clean, the second word must
                        // contain the non-ASCII byte (the caller saw a
                        // non-zero combined mask).
                        at += USIZE_BYTES;
                        let maskb = b & ASCII_MASK;
                        debug_assert!(maskb != 0);
                        return at + first_non_ascii_byte_mask(maskb);
                    }
                    return findpos(start_ptr, ptr);
                }
                ptr = ptr_add(ptr, FALLBACK_LOOP_SIZE);
            }
        }
        // Fewer than two words remain: finish byte by byte.
        first_non_ascii_byte_slow(start_ptr, end_ptr, ptr)
    }
}
117
/// SSE2 implementation of `first_non_ascii_byte` for x86_64.
///
/// `_mm_movemask_epi8` extracts exactly the most significant bit of each
/// byte in a 16-byte vector, so no comparison step is needed: a non-zero
/// mask means some byte in the chunk is non-ASCII, and the mask's trailing
/// zero count gives its position within the chunk. SSE2 is part of the
/// x86_64 baseline, so no runtime feature detection is required.
#[cfg(target_arch = "x86_64")]
fn first_non_ascii_byte_sse2(slice: &[u8]) -> usize {
    use core::arch::x86_64::*;

    const VECTOR_SIZE: usize = mem::size_of::<__m128i>();
    const VECTOR_ALIGN: usize = VECTOR_SIZE - 1;
    // The main loop processes four vectors (64 bytes) per iteration.
    const VECTOR_LOOP_SIZE: usize = 4 * VECTOR_SIZE;

    let start_ptr = slice.as_ptr();
    // One-past-the-end pointer, obtained from an empty tail slice.
    let end_ptr = slice[slice.len()..].as_ptr();
    let mut ptr = start_ptr;

    unsafe {
        // Too short for even one vector load: scan byte by byte.
        if slice.len() < VECTOR_SIZE {
            return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr);
        }

        // Check the first (possibly unaligned) vector up front. This also
        // covers all the bytes skipped over by the alignment bump below.
        let chunk = _mm_loadu_si128(ptr as *const __m128i);
        let mask = _mm_movemask_epi8(chunk);
        if mask != 0 {
            return mask.trailing_zeros() as usize;
        }

        // Advance to the next 16-byte boundary; the bytes in between were
        // already checked by the unaligned load above. Aligned loads
        // (_mm_load_si128) are valid from here on.
        ptr = ptr.add(VECTOR_SIZE - (start_ptr as usize & VECTOR_ALIGN));
        debug_assert!(ptr > start_ptr);
        debug_assert!(end_ptr.sub(VECTOR_SIZE) >= start_ptr);
        if slice.len() >= VECTOR_LOOP_SIZE {
            // Main loop: 64 bytes per iteration. The four vectors are OR'd
            // together so only one movemask is needed in the common case.
            while ptr <= ptr_sub(end_ptr, VECTOR_LOOP_SIZE) {
                debug_assert_eq!(0, (ptr as usize) % VECTOR_SIZE);

                let a = _mm_load_si128(ptr as *const __m128i);
                let b = _mm_load_si128(ptr.add(VECTOR_SIZE) as *const __m128i);
                let c =
                    _mm_load_si128(ptr.add(2 * VECTOR_SIZE) as *const __m128i);
                let d =
                    _mm_load_si128(ptr.add(3 * VECTOR_SIZE) as *const __m128i);

                let or1 = _mm_or_si128(a, b);
                let or2 = _mm_or_si128(c, d);
                let or3 = _mm_or_si128(or1, or2);
                if _mm_movemask_epi8(or3) != 0 {
                    // Some vector contains a non-ASCII byte; re-test each
                    // one individually to pin down which, and where.
                    let mut at = sub(ptr, start_ptr);
                    let mask = _mm_movemask_epi8(a);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(b);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(c);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    // By elimination, the hit must be in the last vector.
                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(d);
                    debug_assert!(mask != 0);
                    return at + mask.trailing_zeros() as usize;
                }
                ptr = ptr_add(ptr, VECTOR_LOOP_SIZE);
            }
        }
        // Handle the remaining tail one (unaligned) vector at a time.
        while ptr <= end_ptr.sub(VECTOR_SIZE) {
            debug_assert!(sub(end_ptr, ptr) >= VECTOR_SIZE);

            let chunk = _mm_loadu_si128(ptr as *const __m128i);
            let mask = _mm_movemask_epi8(chunk);
            if mask != 0 {
                return sub(ptr, start_ptr) + mask.trailing_zeros() as usize;
            }
            ptr = ptr.add(VECTOR_SIZE);
        }
        // Fewer than 16 bytes remain: finish byte by byte.
        first_non_ascii_byte_slow(start_ptr, end_ptr, ptr)
    }
}
198
/// Scalar scan from `ptr` up to `end_ptr`, one byte at a time.
///
/// Returns the offset, relative to `start_ptr`, of the first byte with its
/// most significant bit set, or the total length (`end_ptr - start_ptr`)
/// when every remaining byte is ASCII.
///
/// Callers must guarantee `start_ptr <= ptr <= end_ptr` and that all three
/// pointers lie within (or one past the end of) the same allocation.
#[inline(always)]
unsafe fn first_non_ascii_byte_slow(
    start_ptr: *const u8,
    end_ptr: *const u8,
    mut ptr: *const u8,
) -> usize {
    debug_assert!(start_ptr <= ptr);
    debug_assert!(ptr <= end_ptr);

    while ptr < end_ptr {
        // A byte is non-ASCII exactly when its high bit is set.
        if *ptr >= 0x80 {
            return (ptr as usize) - (start_ptr as usize);
        }
        ptr = ptr.add(1);
    }
    (end_ptr as usize) - (start_ptr as usize)
}
216
/// Compute the position of the first non-ASCII byte in the given mask.
///
/// (The previous doc comment said "first ASCII byte", which was incorrect:
/// the mask has a bit set precisely for each non-ASCII byte.)
///
/// The mask should be computed by `chunk & ASCII_MASK`, where `chunk` is
/// 8 contiguous bytes of the slice being checked where *at least* one of those
/// bytes is not an ASCII byte.
///
/// The position returned is always in the inclusive range [0, 7].
#[cfg(any(test, not(target_arch = "x86_64")))]
fn first_non_ascii_byte_mask(mask: usize) -> usize {
    // The lowest-addressed byte of the chunk sits in the least significant
    // bits on little endian and in the most significant bits on big endian,
    // so the bit-scan direction flips with endianness.
    #[cfg(target_endian = "little")]
    {
        mask.trailing_zeros() as usize / 8
    }
    #[cfg(target_endian = "big")]
    {
        mask.leading_zeros() as usize / 8
    }
}
235
/// Advance the given pointer by `amt` bytes.
///
/// Callers must ensure the result stays within (or one past the end of)
/// the allocation `ptr` points into.
unsafe fn ptr_add(ptr: *const u8, amt: usize) -> *const u8 {
    debug_assert!(amt < isize::MAX as usize);
    ptr.add(amt)
}
241
/// Move the given pointer back by `amt` bytes.
///
/// Callers must ensure the result stays within the allocation `ptr`
/// points into.
unsafe fn ptr_sub(ptr: *const u8, amt: usize) -> *const u8 {
    debug_assert!(amt < isize::MAX as usize);
    ptr.sub(amt)
}
247
/// Read a native-endian `usize` from `ptr` without any alignment
/// requirement.
///
/// Callers must ensure at least `size_of::<usize>()` bytes are readable
/// starting at `ptr`.
#[cfg(any(test, not(target_arch = "x86_64")))]
unsafe fn read_unaligned_usize(ptr: *const u8) -> usize {
    // `read_unaligned` performs the same byte-wise copy the old
    // `ptr::copy_nonoverlapping` into a zeroed local did, without the
    // temporary.
    (ptr as *const usize).read_unaligned()
}
256
/// Returns the byte distance `a - b`. Callers must ensure `a >= b`.
fn sub(a: *const u8, b: *const u8) -> usize {
    let (hi, lo) = (a as usize, b as usize);
    debug_assert!(hi >= lo);
    hi - lo
}
263
#[cfg(test)]
mod tests {
    use super::*;

    // Our testing approach here is to try and exhaustively test every case.
    // This includes the position at which a non-ASCII byte occurs in addition
    // to the alignment of the slice that we're searching.

    // All-ASCII inputs of every length up to 517 (past several multiples of
    // the loop sizes): the fallback must report the full slice length.
    #[test]
    fn positive_fallback_forward() {
        for i in 0..517 {
            let s = "a".repeat(i);
            assert_eq!(
                i,
                first_non_ascii_byte_fallback(s.as_bytes()),
                "i: {:?}, len: {:?}, s: {:?}",
                i,
                s.len(),
                s
            );
        }
    }

    // Same as above, against the SSE2 implementation.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn positive_sse2_forward() {
        for i in 0..517 {
            let b = "a".repeat(i).into_bytes();
            assert_eq!(b.len(), first_non_ascii_byte_sse2(&b));
        }
    }

    // Every combination of non-ASCII position (i ASCII bytes followed by
    // snowmen) and starting alignment (dropping `align` leading bytes):
    // the fallback must find the first non-ASCII byte exactly.
    #[test]
    fn negative_fallback_forward() {
        for i in 0..517 {
            for align in 0..65 {
                let mut s = "a".repeat(i);
                s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
                let s = s.get(align..).unwrap_or("");
                assert_eq!(
                    i.saturating_sub(align),
                    first_non_ascii_byte_fallback(s.as_bytes()),
                    "i: {:?}, align: {:?}, len: {:?}, s: {:?}",
                    i,
                    align,
                    s.len(),
                    s
                );
            }
        }
    }

    // Same as above, against the SSE2 implementation.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn negative_sse2_forward() {
        for i in 0..517 {
            for align in 0..65 {
                let mut s = "a".repeat(i);
                s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
                let s = s.get(align..).unwrap_or("");
                assert_eq!(
                    i.saturating_sub(align),
                    first_non_ascii_byte_sse2(s.as_bytes()),
                    "i: {:?}, align: {:?}, len: {:?}, s: {:?}",
                    i,
                    align,
                    s.len(),
                    s
                );
            }
        }
    }
}
337