use super::{
    assert_input_range, avalanche, primes::*, stripes_with_tail, Halves, Secret, SliceBackport as _,
};

#[cfg(feature = "xxhash3_128")]
use super::X128;

use crate::{IntoU128, IntoU64};

// This module is not `cfg`-gated because it is used by some of the
// SIMD implementations.
pub mod scalar;

#[cfg(target_arch = "aarch64")]
pub mod neon;

#[cfg(target_arch = "x86_64")]
pub mod avx2;

#[cfg(target_arch = "x86_64")]
pub mod sse2;

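/// Expands to per-ISA wrapper functions around `$fn_name` (which must
/// accept a `Vector` implementation as its first argument), followed by
/// runtime feature detection (when the `std` feature is enabled) that
/// calls the best available wrapper, falling back to the scalar
/// implementation.
///
/// A hypothetical invocation, placed inside a wrapper function whose
/// parameters use the same names. The function name and signature here
/// are illustrative only, not a real call site:
///
/// ```ignore
/// dispatch! {
///     fn oneshot_impl<F>(secret: &Secret, input: &[u8], finalize: F) -> F::Output
///     [F: super::Finalize]
/// }
/// ```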
macro_rules! dispatch {
    (
        fn $fn_name:ident<$($gen:ident),*>($($arg_name:ident : $arg_ty:ty),*) $(-> $ret_ty:ty)?
        [$($wheres:tt)*]
    ) => {
        #[inline]
        fn do_scalar<$($gen),*>($($arg_name : $arg_ty),*) $(-> $ret_ty)?
        where
            $($wheres)*
        {
            $fn_name($crate::xxhash3::large::scalar::Impl, $($arg_name),*)
        }

        /// # Safety
        ///
        /// You must ensure that the CPU has the NEON feature
        #[inline]
        #[target_feature(enable = "neon")]
        #[cfg(all(target_arch = "aarch64", feature = "std"))]
        unsafe fn do_neon<$($gen),*>($($arg_name : $arg_ty),*) $(-> $ret_ty)?
        where
            $($wheres)*
        {
            // Safety: The caller has ensured we have the NEON feature
            unsafe {
                $fn_name($crate::xxhash3::large::neon::Impl::new_unchecked(), $($arg_name),*)
            }
        }

        /// # Safety
        ///
        /// You must ensure that the CPU has the AVX2 feature
        #[inline]
        #[target_feature(enable = "avx2")]
        #[cfg(all(target_arch = "x86_64", feature = "std"))]
        unsafe fn do_avx2<$($gen),*>($($arg_name : $arg_ty),*) $(-> $ret_ty)?
        where
            $($wheres)*
        {
            // Safety: The caller has ensured we have the AVX2 feature
            unsafe {
                $fn_name($crate::xxhash3::large::avx2::Impl::new_unchecked(), $($arg_name),*)
            }
        }

        /// # Safety
        ///
        /// You must ensure that the CPU has the SSE2 feature
        #[inline]
        #[target_feature(enable = "sse2")]
        #[cfg(all(target_arch = "x86_64", feature = "std"))]
        unsafe fn do_sse2<$($gen),*>($($arg_name : $arg_ty),*) $(-> $ret_ty)?
        where
            $($wheres)*
        {
            // Safety: The caller has ensured we have the SSE2 feature
            unsafe {
                $fn_name($crate::xxhash3::large::sse2::Impl::new_unchecked(), $($arg_name),*)
            }
        }

        // Now we invoke the right function

        #[cfg(_internal_xxhash3_force_neon)]
        return unsafe { do_neon($($arg_name),*) };

        #[cfg(_internal_xxhash3_force_avx2)]
        return unsafe { do_avx2($($arg_name),*) };

        #[cfg(_internal_xxhash3_force_sse2)]
        return unsafe { do_sse2($($arg_name),*) };

        #[cfg(_internal_xxhash3_force_scalar)]
        return do_scalar($($arg_name),*);

        // This code can be unreachable if one of the `*_force_*` cfgs
        // is set above, but that's the point.
        #[allow(unreachable_code)]
        {
            #[cfg(all(target_arch = "aarch64", feature = "std"))]
            {
                if std::arch::is_aarch64_feature_detected!("neon") {
                    // Safety: We just ensured we have the NEON feature
                    return unsafe { do_neon($($arg_name),*) };
                }
            }

            #[cfg(all(target_arch = "x86_64", feature = "std"))]
            {
                if is_x86_feature_detected!("avx2") {
                    // Safety: We just ensured we have the AVX2 feature
                    return unsafe { do_avx2($($arg_name),*) };
                } else if is_x86_feature_detected!("sse2") {
                    // Safety: We just ensured we have the SSE2 feature
                    return unsafe { do_sse2($($arg_name),*) };
                }
            }
            do_scalar($($arg_name),*)
        }
    };
}
pub(crate) use dispatch;

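/// The two inner-loop operations that each SIMD (or scalar) backend
/// supplies to the `Algorithm` driver below. As a rough sketch of the
/// contract, the scalar forms of these steps in the reference xxHash3
/// algorithm look something like the following (the crate's real
/// implementation lives in the `scalar` module and may differ in
/// detail):
///
/// ```ignore
/// fn accumulate(acc: &mut [u64; 8], stripe: &[u8; 64], secret: &[u8; 64]) {
///     for i in 0..8 {
///         let value = u64::from_le_bytes(stripe[i * 8..][..8].try_into().unwrap());
///         let key = value ^ u64::from_le_bytes(secret[i * 8..][..8].try_into().unwrap());
///         acc[i ^ 1] = acc[i ^ 1].wrapping_add(value);
///         acc[i] = acc[i].wrapping_add((key & 0xFFFF_FFFF).wrapping_mul(key >> 32));
///     }
/// }
///
/// fn round_scramble(acc: &mut [u64; 8], secret_end: &[u8; 64]) {
///     for i in 0..8 {
///         acc[i] ^= acc[i] >> 47;
///         acc[i] ^= u64::from_le_bytes(secret_end[i * 8..][..8].try_into().unwrap());
///         acc[i] = acc[i].wrapping_mul(PRIME32_1);
///     }
/// }
/// ```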
pub trait Vector: Copy {
    fn round_scramble(&self, acc: &mut [u64; 8], secret_end: &[u8; 64]);

    fn accumulate(&self, acc: &mut [u64; 8], stripe: &[u8; 64], secret: &[u8; 64]);
}

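// The accumulators start from this interleaved mix of 32- and 64-bit
// primes, matching the initial state used by the reference xxHash3
// implementation.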
#[rustfmt::skip]
pub const INITIAL_ACCUMULATORS: [u64; 8] = [
    PRIME32_3, PRIME64_1, PRIME64_2, PRIME64_3,
    PRIME64_4, PRIME32_2, PRIME64_5, PRIME32_1,
];

pub struct Algorithm<V>(pub V);

impl<V> Algorithm<V>
where
    V: Vector,
{
    #[inline]
    pub fn oneshot<F>(&self, secret: &Secret, input: &[u8], finalize: F) -> F::Output
    where
        F: super::Finalize,
    {
        assert_input_range!(241.., input.len());
        let mut acc = INITIAL_ACCUMULATORS;

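        // Each stripe reads 64 bytes of input and advances 8 bytes into
        // the secret; the last 64 bytes of the secret are reserved for
        // the scramble step, which is what determines how many stripes
        // fit in a block.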
        let stripes_per_block = (secret.len() - 64) / 8;
        let block_size = 64 * stripes_per_block;

        let mut blocks = input.chunks_exact(block_size);

        let last_block = if blocks.remainder().is_empty() {
            // Safety: We know that `input` is non-empty, which means
            // that either there will be a remainder or one or more
            // full blocks. That info isn't flowing to the optimizer,
            // so we use `unwrap_unchecked`.
            unsafe { blocks.next_back().unwrap_unchecked() }
        } else {
            blocks.remainder()
        };

        self.rounds(&mut acc, blocks, secret);

        let len = input.len();

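        // The last stripe is always the final 64 bytes of the input,
        // which may overlap the stripe processed just before it.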
        let last_stripe = input.last_chunk().unwrap();
        finalize.large(self.0, acc, last_block, last_stripe, secret, len)
    }

    #[inline]
    fn rounds<'a>(
        &self,
        acc: &mut [u64; 8],
        blocks: impl IntoIterator<Item = &'a [u8]>,
        secret: &Secret,
    ) {
        for block in blocks {
            let (stripes, _) = block.bp_as_chunks();

            self.round(acc, stripes, secret);
        }
    }

    #[inline]
    fn round(&self, acc: &mut [u64; 8], stripes: &[[u8; 64]], secret: &Secret) {
        let secret_end = secret.last_stripe();

        self.round_accumulate(acc, stripes, secret);
        self.0.round_scramble(acc, secret_end);
    }

    #[inline]
    fn round_accumulate(&self, acc: &mut [u64; 8], stripes: &[[u8; 64]], secret: &Secret) {
        let secrets = (0..stripes.len()).map(|i| {
            // Safety: The number of stripes is determined by the
            // block size, which is determined by the secret size.
            unsafe { secret.stripe(i) }
        });

        for (stripe, secret) in stripes.iter().zip(secrets) {
            self.0.accumulate(acc, stripe, secret);
        }
    }

    #[inline(always)]
    #[cfg(feature = "xxhash3_64")]
    pub fn finalize_64(
        &self,
        mut acc: [u64; 8],
        last_block: &[u8],
        last_stripe: &[u8; 64],
        secret: &Secret,
        len: usize,
    ) -> u64 {
        debug_assert!(!last_block.is_empty());
        self.last_round(&mut acc, last_block, last_stripe, secret);

        let low = len.into_u64().wrapping_mul(PRIME64_1);
        self.final_merge(&acc, low, secret.final_secret())
    }

    #[inline]
    #[cfg(feature = "xxhash3_128")]
    pub fn finalize_128(
        &self,
        mut acc: [u64; 8],
        last_block: &[u8],
        last_stripe: &[u8; 64],
        secret: &Secret,
        len: usize,
    ) -> u128 {
        debug_assert!(!last_block.is_empty());
        self.last_round(&mut acc, last_block, last_stripe, secret);

        let len = len.into_u64();

        let low = len.wrapping_mul(PRIME64_1);
        let low = self.final_merge(&acc, low, secret.final_secret());

        let high = !len.wrapping_mul(PRIME64_2);
        let high = self.final_merge(&acc, high, secret.for_128().final_secret());

        X128 { low, high }.into()
    }

    #[inline]
    fn last_round(
        &self,
        acc: &mut [u64; 8],
        block: &[u8],
        last_stripe: &[u8; 64],
        secret: &Secret,
    ) {
        // Accumulation steps are run for the stripes in the last block,
        // except for the last stripe (whether it is full or not)
        let (stripes, _) = stripes_with_tail(block);

        let secrets = (0..stripes.len()).map(|i| {
            // Safety: The number of stripes is determined by the
            // block size, which is determined by the secret size.
            unsafe { secret.stripe(i) }
        });

        for (stripe, secret) in stripes.iter().zip(secrets) {
            self.0.accumulate(acc, stripe, secret);
        }

        let last_stripe_secret = secret.last_stripe_secret_better_name();
        self.0.accumulate(acc, last_stripe, last_stripe_secret);
    }

    #[inline]
    fn final_merge(&self, acc: &[u64; 8], init_value: u64, secret: &[u8; 64]) -> u64 {
        let (secrets, _) = secret.bp_as_chunks();
        let mut result = init_value;
        for i in 0..4 {
            // 64-bit by 64-bit multiplication to 128-bit full result
            let mul_result = {
                let sa = u64::from_le_bytes(secrets[i * 2]);
                let sb = u64::from_le_bytes(secrets[i * 2 + 1]);

                let a = (acc[i * 2] ^ sa).into_u128();
                let b = (acc[i * 2 + 1] ^ sb).into_u128();
                a.wrapping_mul(b)
            };
            result = result.wrapping_add(mul_result.lower_half() ^ mul_result.upper_half());
        }
        avalanche(result)
    }
}