// Copyright 2018 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Math helper functions

#[cfg(feature = "simd_support")] use packed_simd::*;


/// Widening multiply: computes the full double-width product of `self` and
/// `x`, returned as `(hi, lo)` so that `self * x == (hi << BITS) | lo`,
/// where `BITS` is the bit width of the operand type.
pub(crate) trait WideningMultiply<RHS = Self> {
    type Output;

    fn wmul(self, x: RHS) -> Self::Output;
}

macro_rules! wmul_impl {
    // Scalar implementation: widen to the next-larger integer type,
    // multiply, then split the product into high and low halves.
    ($ty:ty, $wide:ty, $shift:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, x: $ty) -> Self::Output {
                let tmp = (self as $wide) * (x as $wide);
                ((tmp >> $shift) as $ty, tmp as $ty)
            }
        }
    };

    // simd bulk implementation
    ($(($ty:ident, $wide:ident),)+, $shift:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    // For supported vectors, this should compile to a couple
                    // supported multiply & swizzle instructions (no actual
                    // casting).
                    // TODO: optimize
                    let y: $wide = self.cast();
                    let x: $wide = x.cast();
                    let tmp = y * x;
                    let hi: $ty = (tmp >> $shift).cast();
                    let lo: $ty = tmp.cast();
                    (hi, lo)
                }
            }
        )+
    };
}
wmul_impl! { u8, u16, 8 }
wmul_impl! { u16, u32, 16 }
wmul_impl! { u32, u64, 32 }
#[cfg(not(target_os = "emscripten"))]
wmul_impl! { u64, u128, 64 }

// This code is a translation of the __mulddi3 function in LLVM's
// compiler-rt. It is an optimised variant of the common method
// `(a + b) * (c + d) = ac + ad + bc + bd`.
//
// For some reason LLVM can optimise the C version very well, but
// keeps shuffling registers in this Rust translation.
macro_rules! wmul_impl_large {
    // Scalar implementation for types with no wider counterpart: split each
    // operand into half-width pieces and combine the four partial products,
    // carrying through `t`.
    ($ty:ty, $half:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, b: $ty) -> Self::Output {
                const LOWER_MASK: $ty = !0 >> $half;
                let mut low = (self & LOWER_MASK).wrapping_mul(b & LOWER_MASK);
                let mut t = low >> $half;
                low &= LOWER_MASK;
                t += (self >> $half).wrapping_mul(b & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                let mut high = t >> $half;
                t = low >> $half;
                low &= LOWER_MASK;
                t += (b >> $half).wrapping_mul(self & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                high += t >> $half;
                high += (self >> $half).wrapping_mul(b >> $half);

                (high, low)
            }
        }
    };

    // simd bulk implementation
    (($($ty:ty,)+) $scalar:ty, $half:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, b: $ty) -> Self::Output {
                    // needs wrapping multiplication
                    const LOWER_MASK: $scalar = !0 >> $half;
                    let mut low = (self & LOWER_MASK) * (b & LOWER_MASK);
                    let mut t = low >> $half;
                    low &= LOWER_MASK;
                    t += (self >> $half) * (b & LOWER_MASK);
                    low += (t & LOWER_MASK) << $half;
                    let mut high = t >> $half;
                    t = low >> $half;
                    low &= LOWER_MASK;
                    t += (b >> $half) * (self & LOWER_MASK);
                    low += (t & LOWER_MASK) << $half;
                    high += t >> $half;
                    high += (self >> $half) * (b >> $half);

                    (high, low)
                }
            }
        )+
    };
}
// Emscripten has no native u128, so u64 falls back to the split algorithm.
#[cfg(target_os = "emscripten")]
wmul_impl_large! { u64, 32 }
#[cfg(not(target_os = "emscripten"))]
wmul_impl_large! { u128, 64 }

// `usize` delegates to the fixed-width type of the same pointer width.
macro_rules! wmul_impl_usize {
    ($ty:ty) => {
        impl WideningMultiply for usize {
            type Output = (usize, usize);

            #[inline(always)]
            fn wmul(self, x: usize) -> Self::Output {
                let (high, low) = (self as $ty).wmul(x as $ty);
                (high as usize, low as usize)
            }
        }
    };
}
#[cfg(target_pointer_width = "32")]
wmul_impl_usize! { u32 }
#[cfg(target_pointer_width = "64")]
wmul_impl_usize! { u64 }

#[cfg(feature = "simd_support")]
mod simd_wmul {
    use super::*;
    #[cfg(target_arch = "x86")] use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*;

    wmul_impl! {
        (u8x2, u16x2),
        (u8x4, u16x4),
        (u8x8, u16x8),
        (u8x16, u16x16),
        (u8x32, u16x32),,
        8
    }

    wmul_impl! { (u16x2, u32x2),, 16 }
    wmul_impl! { (u16x4, u32x4),, 16 }
    #[cfg(not(target_feature = "sse2"))]
    wmul_impl! { (u16x8, u32x8),, 16 }
    #[cfg(not(target_feature = "avx2"))]
    wmul_impl! { (u16x16, u32x16),, 16 }

    // 16-bit lane widths allow use of the x86 `mulhi` instructions, which
    // means `wmul` can be implemented with only two instructions.
    #[allow(unused_macros)]
    macro_rules! wmul_impl_16 {
        ($ty:ident, $intrinsic:ident, $mulhi:ident, $mullo:ident) => {
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    let b = $intrinsic::from_bits(x);
                    let a = $intrinsic::from_bits(self);
                    let hi = $ty::from_bits(unsafe { $mulhi(a, b) });
                    let lo = $ty::from_bits(unsafe { $mullo(a, b) });
                    (hi, lo)
                }
            }
        };
    }

    #[cfg(target_feature = "sse2")]
    wmul_impl_16! { u16x8, __m128i, _mm_mulhi_epu16, _mm_mullo_epi16 }
    #[cfg(target_feature = "avx2")]
    wmul_impl_16! { u16x16, __m256i, _mm256_mulhi_epu16, _mm256_mullo_epi16 }
    // FIXME: there are no `__m512i` types in stdsimd yet, so `wmul::<u16x32>`
    // cannot use the same implementation.

    wmul_impl! {
        (u32x2, u64x2),
        (u32x4, u64x4),
        (u32x8, u64x8),,
        32
    }

    // TODO: optimize, this seems to seriously slow things down
    wmul_impl_large! { (u8x64,) u8, 4 }
    wmul_impl_large! { (u16x32,) u16, 8 }
    wmul_impl_large! { (u32x16,) u32, 16 }
    wmul_impl_large! { (u64x2, u64x4, u64x8,) u64, 32 }
}

/// Helper trait when dealing with scalar and SIMD floating point types.
pub(crate) trait FloatSIMDUtils {
    // `PartialOrd` for vectors compares lexicographically. We want to compare all
    // the individual SIMD lanes instead, and get the combined result over all
    // lanes. This is possible using something like `a.lt(b).all()`, but we
    // implement it as a trait so we can write the same code for `f32` and `f64`.
    // Only the comparison functions we need are implemented.
    fn all_lt(self, other: Self) -> bool;
    fn all_le(self, other: Self) -> bool;
    fn all_finite(self) -> bool;

    type Mask;
    fn finite_mask(self) -> Self::Mask;
    fn gt_mask(self, other: Self) -> Self::Mask;
    fn ge_mask(self, other: Self) -> Self::Mask;

    // Decrease all lanes where the mask is `true` to the next lower value
    // representable by the floating-point type. At least one of the lanes
    // must be set.
    fn decrease_masked(self, mask: Self::Mask) -> Self;

    // Convert from int value. Conversion is done while retaining the numerical
    // value, not by retaining the binary representation.
    type UInt;
    fn cast_from_int(i: Self::UInt) -> Self;
}

/// Implement functions available in std builds but missing from core primitives
#[cfg(not(std))]
pub(crate) trait Float: Sized {
    fn is_nan(self) -> bool;
    fn is_infinite(self) -> bool;
    fn is_finite(self) -> bool;
}

/// Implement functions on f32/f64 to give them APIs similar to SIMD types
pub(crate) trait FloatAsSIMD: Sized {
    #[inline(always)]
    fn lanes() -> usize {
        1
    }
    #[inline(always)]
    fn splat(scalar: Self) -> Self {
        scalar
    }
    #[inline(always)]
    fn extract(self, index: usize) -> Self {
        debug_assert_eq!(index, 0);
        self
    }
    #[inline(always)]
    fn replace(self, index: usize, new_value: Self) -> Self {
        debug_assert_eq!(index, 0);
        new_value
    }
}

/// Treat a scalar `bool` like a one-lane SIMD mask.
pub(crate) trait BoolAsSIMD: Sized {
    fn any(self) -> bool;
    fn all(self) -> bool;
    fn none(self) -> bool;
}

impl BoolAsSIMD for bool {
    #[inline(always)]
    fn any(self) -> bool {
        self
    }

    #[inline(always)]
    fn all(self) -> bool {
        self
    }

    #[inline(always)]
    fn none(self) -> bool {
        !self
    }
}

macro_rules! scalar_float_impl {
    ($ty:ident, $uty:ident) => {
        #[cfg(not(std))]
        impl Float for $ty {
            #[inline]
            fn is_nan(self) -> bool {
                // NaN is the only value that compares unequal to itself.
                self != self
            }

            #[inline]
            fn is_infinite(self) -> bool {
                self == ::core::$ty::INFINITY || self == ::core::$ty::NEG_INFINITY
            }

            #[inline]
            fn is_finite(self) -> bool {
                !(self.is_nan() || self.is_infinite())
            }
        }

        impl FloatSIMDUtils for $ty {
            type Mask = bool;
            type UInt = $uty;

            #[inline(always)]
            fn all_lt(self, other: Self) -> bool {
                self < other
            }

            #[inline(always)]
            fn all_le(self, other: Self) -> bool {
                self <= other
            }

            #[inline(always)]
            fn all_finite(self) -> bool {
                self.is_finite()
            }

            #[inline(always)]
            fn finite_mask(self) -> Self::Mask {
                self.is_finite()
            }

            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask {
                self > other
            }

            #[inline(always)]
            fn ge_mask(self, other: Self) -> Self::Mask {
                self >= other
            }

            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                debug_assert!(mask, "At least one lane must be set");
                // Subtracting one from the bit representation of a positive
                // finite float yields the next-lower representable value.
                <$ty>::from_bits(self.to_bits() - 1)
            }

            #[inline]
            fn cast_from_int(i: Self::UInt) -> Self {
                i as $ty
            }
        }

        impl FloatAsSIMD for $ty {}
    };
}

scalar_float_impl!(f32, u32);
scalar_float_impl!(f64, u64);


#[cfg(feature = "simd_support")]
macro_rules! simd_impl {
    ($ty:ident, $f_scalar:ident, $mty:ident, $uty:ident) => {
        impl FloatSIMDUtils for $ty {
            type Mask = $mty;
            type UInt = $uty;

            #[inline(always)]
            fn all_lt(self, other: Self) -> bool {
                self.lt(other).all()
            }

            #[inline(always)]
            fn all_le(self, other: Self) -> bool {
                self.le(other).all()
            }

            #[inline(always)]
            fn all_finite(self) -> bool {
                self.finite_mask().all()
            }

            #[inline(always)]
            fn finite_mask(self) -> Self::Mask {
                // This can possibly be done faster by checking bit patterns
                let neg_inf = $ty::splat(::core::$f_scalar::NEG_INFINITY);
                let pos_inf = $ty::splat(::core::$f_scalar::INFINITY);
                self.gt(neg_inf) & self.lt(pos_inf)
            }

            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask {
                self.gt(other)
            }

            #[inline(always)]
            fn ge_mask(self, other: Self) -> Self::Mask {
                self.ge(other)
            }

            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                // Casting a mask into ints will produce all bits set for
                // true, and 0 for false. Adding that to the binary
                // representation of a float means subtracting one from
                // the binary representation, resulting in the next lower
                // value representable by $ty. This works even when the
                // current value is infinity.
                debug_assert!(mask.any(), "At least one lane must be set");
                <$ty>::from_bits(<$uty>::from_bits(self) + <$uty>::from_bits(mask))
            }

            #[inline]
            fn cast_from_int(i: Self::UInt) -> Self {
                i.cast()
            }
        }
    };
}

#[cfg(feature="simd_support")] simd_impl! { f32x2, f32, m32x2, u32x2 }
#[cfg(feature="simd_support")] simd_impl! { f32x4, f32, m32x4, u32x4 }
#[cfg(feature="simd_support")] simd_impl! { f32x8, f32, m32x8, u32x8 }
#[cfg(feature="simd_support")] simd_impl! { f32x16, f32, m32x16, u32x16 }
#[cfg(feature="simd_support")] simd_impl! { f64x2, f64, m64x2, u64x2 }
#[cfg(feature="simd_support")] simd_impl! { f64x4, f64, m64x4, u64x4 }
#[cfg(feature="simd_support")] simd_impl! { f64x8, f64, m64x8, u64x8 }