use crate::iter::Bytes;

#[inline]
#[target_feature(enable = "avx2")]
pub unsafe fn match_uri_vectored(bytes: &mut Bytes) {
    while bytes.as_ref().len() >= 32 {
        let advance = match_url_char_32_avx(bytes.as_ref());
        bytes.advance(advance);

        if advance != 32 {
            return;
        }
    }
    // NOTE: use SWAR for <32B, more efficient than falling back to SSE4.2
    super::swar::match_uri_vectored(bytes)
}
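
// Hedged sanity-check sketch (not from the upstream test suite): with a buffer
// longer than one 32-byte chunk, the loop above should consume the full first
// chunk via AVX2 and then let the SWAR tail handling stop exactly at the first
// disallowed byte. The buffer layout and index below are made up for illustration.
#[test]
fn avx2_then_swar_stops_at_first_disallowed_byte() {
    if !is_x86_feature_detected!("avx2") {
        return;
    }

    // 40 URI-allowed bytes, except for a space (0x20, not allowed in URIs) at
    // index 35, i.e. inside the <32-byte tail handled by the SWAR fallback.
    let mut slice = [b'_'; 40];
    slice[35] = b' ';
    let mut bytes = Bytes::new(&slice);

    #[allow(clippy::undocumented_unsafe_blocks)]
    unsafe {
        match_uri_vectored(&mut bytes);
    }

    assert_eq!(bytes.pos(), 35);
}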

#[inline(always)]
#[allow(non_snake_case, overflowing_literals)]
#[allow(unused)]
unsafe fn match_url_char_32_avx(buf: &[u8]) -> usize {
    // NOTE: This check might not be necessary since this function is only used in
    // `match_uri_vectored`, where buffer overflow is already taken care of.
    debug_assert!(buf.len() >= 32);

    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // pointer to buffer
    let ptr = buf.as_ptr();

    // %x21-%x7e %x80-%xff
    //
    // The character ranges allowed by this function; equivalently:
    // 33 <= byte <= 255, excluding 127 (DEL).
    //
    // Create a vector full of DEL (0x7f) characters.
    let DEL: __m256i = _mm256_set1_epi8(0x7f);
    // Create a vector full of exclamation mark (!) (0x21) characters.
    // Used as the lower threshold; characters in URLs cannot be smaller than this.
    let LOW: __m256i = _mm256_set1_epi8(0x21);

    // Load a chunk of 32 bytes from `ptr` as a vector.
    // 32 bytes is the most we can check in parallel with AVX2, since
    // YMM registers are only 256 bits wide.
    let dat = _mm256_lddqu_si256(ptr as *const _);

    // unsigned comparison dat >= LOW
    //
    // `_mm256_max_epu8` compares `dat` and `LOW` byte by byte and keeps the larger
    // value in each lane, so any byte in `dat` that is <= 32 is replaced by 33,
    // the smallest valid character.
    //
    // Then, we compare the new vector with `dat` for equality.
    //
    // `_mm256_cmpeq_epi8` returns a new vector where:
    // * matching bytes are set to 0xFF (all bits set),
    // * non-matching bytes are set to 0 (no bits set).
    let low = _mm256_cmpeq_epi8(_mm256_max_epu8(dat, LOW), dat);
    // Same idea as above, but now invalid (DEL) characters are the ones set to 0xFF.
    let del = _mm256_cmpeq_epi8(dat, DEL);

    // We combine both comparisons via `_mm256_andnot_si256`.
    //
    // A set lane in `low` means a valid byte, while a set lane in `del` means an
    // invalid one, so `del` has to be bitwise-NOTed before being ANDed with `low`.
    let bit = _mm256_andnot_si256(del, low);
    // This creates a bitmask from the most significant bit of each byte;
    // in other words, it turns the vector value into a scalar value.
    let res = _mm256_movemask_epi8(bit) as u32;

    // Count trailing zeros to find the first invalid character.
    // Bitwise NOT is required once again to flip truthiness.
    // TODO: use .trailing_ones() once MSRV >= 1.46
    (!res).trailing_zeros() as usize
}
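
// A hedged, purely scalar illustration of the final bitmask step above: build the
// same "one bit per valid byte" mask by hand and recover the index of the first
// invalid byte with `(!mask).trailing_zeros()`. The lane values here are made up
// for illustration and nothing in the parser calls this.
#[test]
fn movemask_trailing_zeros_sketch() {
    // Pretend lanes 0..=4 hold valid bytes and lane 5 is the first invalid one.
    let valid = [true, true, true, true, true, false, true, true];
    let mut mask: u32 = 0;
    for (i, &ok) in valid.iter().enumerate() {
        if ok {
            // `_mm256_movemask_epi8` likewise sets one bit per 0xFF lane.
            mask |= 1 << i;
        }
    }
    // Invalid lanes are zero bits, so negating and counting trailing zeros yields
    // the index of the first invalid byte, just like `(!res).trailing_zeros()`.
    assert_eq!((!mask).trailing_zeros(), 5);

    // If every lane of a 32-byte chunk were valid, the mask would be all ones and
    // the same expression would return 32, i.e. "advance over the whole chunk".
    let full_mask: u32 = 0xFFFF_FFFF;
    assert_eq!((!full_mask).trailing_zeros(), 32);
}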

#[target_feature(enable = "avx2")]
pub unsafe fn match_header_value_vectored(bytes: &mut Bytes) {
    while bytes.as_ref().len() >= 32 {
        let advance = match_header_value_char_32_avx(bytes.as_ref());
        bytes.advance(advance);

        if advance != 32 {
            return;
        }
    }
    // NOTE: use SWAR for <32B, more efficient than falling back to SSE4.2
    super::swar::match_header_value_vectored(bytes)
}

#[inline(always)]
#[allow(non_snake_case)]
#[allow(unused)]
unsafe fn match_header_value_char_32_avx(buf: &[u8]) -> usize {
    debug_assert!(buf.len() >= 32);

    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let ptr = buf.as_ptr();

    // %x09 %x20-%x7e %x80-%xff
    // Create a vector full of horizontal tab (\t) (0x09) characters.
    let TAB: __m256i = _mm256_set1_epi8(0x09);
    // Create a vector full of DEL (0x7f) characters.
    let DEL: __m256i = _mm256_set1_epi8(0x7f);
    // Create a vector full of space (0x20) characters.
    let LOW: __m256i = _mm256_set1_epi8(0x20);

    // Load a chunk of 32 bytes from `ptr` as a vector.
    let dat = _mm256_lddqu_si256(ptr as *const _);

    // unsigned comparison dat >= LOW
    //
    // Same as in `match_url_char_32_avx`, except that the lower threshold
    // is the space character this time.
    let low = _mm256_cmpeq_epi8(_mm256_max_epu8(dat, LOW), dat);
    // Check which bytes of `dat` are `TAB` characters.
    let tab = _mm256_cmpeq_epi8(dat, TAB);
    // Check which bytes of `dat` are `DEL` characters.
    let del = _mm256_cmpeq_epi8(dat, DEL);

    // Combine all comparisons: `low` and `tab` are ORed together, while the
    // bits of `del` are flipped before the AND.
    //
    // In the end, this is simply:
    // ~del & (low | tab)
    let bit = _mm256_andnot_si256(del, _mm256_or_si256(low, tab));
    // This creates a bitmask from the most significant bit of each byte,
    // turning the vector value into a scalar value.
    let res = _mm256_movemask_epi8(bit) as u32;

    // Count trailing zeros to find the first invalid character.
    // Bitwise NOT is required once again to flip truthiness.
    // TODO: use .trailing_ones() once MSRV >= 1.46
    (!res).trailing_zeros() as usize
}
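
// Hedged scalar restatement of the `~del & (low | tab)` combination above, for
// illustration only (the parser's source of truth is `crate::HEADER_VALUE_MAP`):
// a header value byte is allowed iff it is a horizontal tab, or it is at least
// 0x20 and not DEL. The closure below is ours and is not used by the parser.
#[test]
fn header_value_predicate_sketch() {
    let allowed = |b: u8| {
        let low = b >= 0x20; // `_mm256_max_epu8` + `_mm256_cmpeq_epi8`
        let tab = b == 0x09; // `_mm256_cmpeq_epi8(dat, TAB)`
        let del = b == 0x7f; // `_mm256_cmpeq_epi8(dat, DEL)`
        !del && (low || tab)
    };

    assert!(allowed(b'\t'));
    assert!(allowed(b' '));
    assert!(allowed(b'~'));
    assert!(allowed(0x80));
    assert!(allowed(0xff));
    assert!(!allowed(0x7f));
    assert!(!allowed(b'\n'));
    assert!(!allowed(b'\r'));
}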

#[test]
fn avx2_code_matches_uri_chars_table() {
    if !is_x86_feature_detected!("avx2") {
        return;
    }

    #[allow(clippy::undocumented_unsafe_blocks)]
    unsafe {
        assert!(byte_is_allowed(b'_', match_uri_vectored));

        for (b, allowed) in crate::URI_MAP.iter().cloned().enumerate() {
            assert_eq!(
                byte_is_allowed(b as u8, match_uri_vectored), allowed,
                "byte_is_allowed({:?}) should be {:?}", b, allowed,
            );
        }
    }
}

#[test]
fn avx2_code_matches_header_value_chars_table() {
    if !is_x86_feature_detected!("avx2") {
        return;
    }

    #[allow(clippy::undocumented_unsafe_blocks)]
    unsafe {
        assert!(byte_is_allowed(b'_', match_header_value_vectored));

        for (b, allowed) in crate::HEADER_VALUE_MAP.iter().cloned().enumerate() {
            assert_eq!(
                byte_is_allowed(b as u8, match_header_value_vectored), allowed,
                "byte_is_allowed({:?}) should be {:?}", b, allowed,
            );
        }
    }
}

#[cfg(test)]
unsafe fn byte_is_allowed(byte: u8, f: unsafe fn(bytes: &mut Bytes<'_>)) -> bool {
    // 32 otherwise-allowed bytes, with the byte under test placed at index 26.
    let slice = [
        b'_', b'_', b'_', b'_',
        b'_', b'_', b'_', b'_',
        b'_', b'_', b'_', b'_',
        b'_', b'_', b'_', b'_',
        b'_', b'_', b'_', b'_',
        b'_', b'_', b'_', b'_',
        b'_', b'_', byte, b'_',
        b'_', b'_', b'_', b'_',
    ];
    let mut bytes = Bytes::new(&slice);

    f(&mut bytes);

    // If the whole 32-byte chunk was consumed, the byte is allowed; if the scan
    // stopped right at index 26, it is not.
    match bytes.pos() {
        32 => true,
        26 => false,
        _ => unreachable!(),
    }
}