//! Masks that take up full SIMD vector registers.

use super::MaskElement;
use crate::simd::intrinsics;
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};

#[cfg(feature = "generic_const_exprs")]
use crate::simd::ToBitMaskArray;

#[repr(transparent)]
pub struct Mask<T, const LANES: usize>(Simd<T, LANES>)
where
    T: MaskElement,
    LaneCount<LANES>: SupportedLaneCount;

impl<T, const LANES: usize> Copy for Mask<T, LANES>
where
    T: MaskElement,
    LaneCount<LANES>: SupportedLaneCount,
{
}

impl<T, const LANES: usize> Clone for Mask<T, LANES>
where
    T: MaskElement,
    LaneCount<LANES>: SupportedLaneCount,
{
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn clone(&self) -> Self {
        *self
    }
}

impl<T, const LANES: usize> PartialEq for Mask<T, LANES>
where
    T: MaskElement + PartialEq,
    LaneCount<LANES>: SupportedLaneCount,
{
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.0.eq(&other.0)
    }
}

impl<T, const LANES: usize> PartialOrd for Mask<T, LANES>
where
    T: MaskElement + PartialOrd,
    LaneCount<LANES>: SupportedLaneCount,
{
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        self.0.partial_cmp(&other.0)
    }
}

impl<T, const LANES: usize> Eq for Mask<T, LANES>
where
    T: MaskElement + Eq,
    LaneCount<LANES>: SupportedLaneCount,
{
}

impl<T, const LANES: usize> Ord for Mask<T, LANES>
where
    T: MaskElement + Ord,
    LaneCount<LANES>: SupportedLaneCount,
{
    #[inline]
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.0.cmp(&other.0)
    }
}

// Used for the bitmask bit order workaround
pub(crate) trait ReverseBits {
    // Reverse the least significant `n` bits of `self`.
    // (Remaining bits must be 0.)
    fn reverse_bits(self, n: usize) -> Self;
}

macro_rules! impl_reverse_bits {
    { $($int:ty),* } => {
        $(
            impl ReverseBits for $int {
                #[inline(always)]
                fn reverse_bits(self, n: usize) -> Self {
                    let rev = <$int>::reverse_bits(self);
                    let bitsize = core::mem::size_of::<$int>() * 8;
                    if n < bitsize {
                        // Shift things back to the right
                        rev >> (bitsize - n)
                    } else {
                        rev
                    }
                }
            }
        )*
    }
}

impl_reverse_bits! { u8, u16, u32, u64 }
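// A minimal illustration of the trait above; this test module is an added
// sketch, not part of the original file. It shows that only the least
// significant `n` bits participate in the reversal, which is exactly what
// the big-endian bitmask fix-ups below rely on.
#[cfg(test)]
mod reverse_bits_example {
    use super::ReverseBits;

    #[test]
    fn reverses_only_the_low_bits() {
        // 0b011 reversed over its low 3 bits is 0b110; the high bits stay zero.
        assert_eq!(ReverseBits::reverse_bits(0b011u8, 3), 0b110);
        // When `n` equals the full width this is the plain `u8::reverse_bits`.
        assert_eq!(ReverseBits::reverse_bits(1u8, 8), 0b1000_0000);
    }
}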
impl<T, const LANES: usize> Mask<T, LANES>
where
    T: MaskElement,
    LaneCount<LANES>: SupportedLaneCount,
{
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn splat(value: bool) -> Self {
        Self(Simd::splat(if value { T::TRUE } else { T::FALSE }))
    }

    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
        T::eq(self.0[lane], T::TRUE)
    }

    #[inline]
    pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
        self.0[lane] = if value { T::TRUE } else { T::FALSE }
    }

    #[inline]
    #[must_use = "method returns a new vector and does not mutate the original value"]
    pub fn to_int(self) -> Simd<T, LANES> {
        self.0
    }

    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub unsafe fn from_int_unchecked(value: Simd<T, LANES>) -> Self {
        Self(value)
    }

    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn convert<U>(self) -> Mask<U, LANES>
    where
        U: MaskElement,
    {
        // Safety: masks are simply integer vectors of 0 and -1, and we can cast the element type.
        unsafe { Mask(intrinsics::simd_cast(self.0)) }
    }

    #[cfg(feature = "generic_const_exprs")]
    #[inline]
    #[must_use = "method returns a new array and does not mutate the original value"]
    pub fn to_bitmask_array<const N: usize>(self) -> [u8; N]
    where
        super::Mask<T, LANES>: ToBitMaskArray,
        [(); <super::Mask<T, LANES> as ToBitMaskArray>::BYTES]: Sized,
    {
        assert_eq!(<super::Mask<T, LANES> as ToBitMaskArray>::BYTES, N);

        // Safety: N is the correct bitmask size
        unsafe {
            // Compute the bitmask
            let bitmask: [u8; <super::Mask<T, LANES> as ToBitMaskArray>::BYTES] =
                intrinsics::simd_bitmask(self.0);

            // Transmute to the return type, previously asserted to be the same size
            let mut bitmask: [u8; N] = core::mem::transmute_copy(&bitmask);

            // LLVM assumes bit order should match endianness
            if cfg!(target_endian = "big") {
                for x in bitmask.as_mut() {
                    *x = x.reverse_bits();
                }
            }

            bitmask
        }
    }

    #[cfg(feature = "generic_const_exprs")]
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn from_bitmask_array<const N: usize>(mut bitmask: [u8; N]) -> Self
    where
        super::Mask<T, LANES>: ToBitMaskArray,
        [(); <super::Mask<T, LANES> as ToBitMaskArray>::BYTES]: Sized,
    {
        assert_eq!(<super::Mask<T, LANES> as ToBitMaskArray>::BYTES, N);

        // Safety: N is the correct bitmask size
        unsafe {
            // LLVM assumes bit order should match endianness
            if cfg!(target_endian = "big") {
                for x in bitmask.as_mut() {
                    *x = x.reverse_bits();
                }
            }

            // Transmute to the bitmask type, previously asserted to be the same size
            let bitmask: [u8; <super::Mask<T, LANES> as ToBitMaskArray>::BYTES] =
                core::mem::transmute_copy(&bitmask);

            // Compute the regular mask
            Self::from_int_unchecked(intrinsics::simd_select_bitmask(
                bitmask,
                Self::splat(true).to_int(),
                Self::splat(false).to_int(),
            ))
        }
    }

    #[inline]
    pub(crate) fn to_bitmask_integer<U: ReverseBits>(self) -> U
    where
        super::Mask<T, LANES>: ToBitMask<BitMask = U>,
    {
        // Safety: U is required to be the appropriate bitmask type
        let bitmask: U = unsafe { intrinsics::simd_bitmask(self.0) };

        // LLVM assumes bit order should match endianness
        if cfg!(target_endian = "big") {
            bitmask.reverse_bits(LANES)
        } else {
            bitmask
        }
    }

    #[inline]
    pub(crate) fn from_bitmask_integer<U: ReverseBits>(bitmask: U) -> Self
    where
        super::Mask<T, LANES>: ToBitMask<BitMask = U>,
    {
        // LLVM assumes bit order should match endianness
        let bitmask = if cfg!(target_endian = "big") {
            bitmask.reverse_bits(LANES)
        } else {
            bitmask
        };

        // Safety: U is required to be the appropriate bitmask type
        unsafe {
            Self::from_int_unchecked(intrinsics::simd_select_bitmask(
                bitmask,
                Self::splat(true).to_int(),
                Self::splat(false).to_int(),
            ))
        }
    }

    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn any(self) -> bool {
        // Safety: use `self` as an integer vector
        unsafe { intrinsics::simd_reduce_any(self.to_int()) }
    }

    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn all(self) -> bool {
        // Safety: use `self` as an integer vector
        unsafe { intrinsics::simd_reduce_all(self.to_int()) }
    }
}
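// An added, illustrative round-trip for the `pub(crate)` bitmask conversions
// above (not part of the original file). It assumes this crate's
// `ToBitMask<BitMask = u8>` impl for 4-lane masks; the exposed convention is
// that lane 0 maps to the least significant bit regardless of endianness.
#[cfg(test)]
mod bitmask_integer_example {
    use super::Mask;

    #[test]
    fn lane_zero_is_the_least_significant_bit() {
        let mut mask = Mask::<i32, 4>::splat(true);
        // Safety: lane 2 is in bounds for a 4-lane mask.
        unsafe { mask.set_unchecked(2, false) };
        // Lanes [true, true, false, true] -> bits 0b1011.
        let bits: u8 = mask.to_bitmask_integer();
        assert_eq!(bits, 0b1011);
        // Converting back restores the same lanes.
        let roundtrip = Mask::<i32, 4>::from_bitmask_integer(bits);
        assert_eq!(roundtrip.to_int(), mask.to_int());
    }
}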
impl<T, const LANES: usize> From<Mask<T, LANES>> for Simd<T, LANES>
where
    T: MaskElement,
    LaneCount<LANES>: SupportedLaneCount,
{
    #[inline]
    fn from(value: Mask<T, LANES>) -> Self {
        value.0
    }
}
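// Added sketch (not in the original file): the `From` impl above exposes the
// full-mask representation directly, so each set lane is observable as -1
// (all bits set) and each clear lane as 0.
#[cfg(test)]
mod representation_example {
    use super::Mask;
    use crate::simd::Simd;

    #[test]
    fn lanes_are_zero_or_all_ones() {
        let mut mask = Mask::<i32, 4>::splat(false);
        // Safety: lane 0 is in bounds for a 4-lane mask.
        unsafe { mask.set_unchecked(0, true) };
        assert_eq!(Simd::from(mask).to_array(), [-1, 0, 0, 0]);
    }
}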
impl<T, const LANES: usize> core::ops::BitAnd for Mask<T, LANES>
where
    T: MaskElement,
    LaneCount<LANES>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn bitand(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(intrinsics::simd_and(self.0, rhs.0)) }
    }
}

impl<T, const LANES: usize> core::ops::BitOr for Mask<T, LANES>
where
    T: MaskElement,
    LaneCount<LANES>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn bitor(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(intrinsics::simd_or(self.0, rhs.0)) }
    }
}

impl<T, const LANES: usize> core::ops::BitXor for Mask<T, LANES>
where
    T: MaskElement,
    LaneCount<LANES>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn bitxor(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(intrinsics::simd_xor(self.0, rhs.0)) }
    }
}

impl<T, const LANES: usize> core::ops::Not for Mask<T, LANES>
where
    T: MaskElement,
    LaneCount<LANES>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn not(self) -> Self::Output {
        Self::splat(true) ^ self
    }
}
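// Added sketch (not in the original file): the operator impls above behave as
// lanewise boolean algebra, with `!` implemented as XOR against an all-true
// mask, as `Not::not` shows.
#[cfg(test)]
mod operator_example {
    use super::Mask;

    #[test]
    fn lanewise_boolean_identities() {
        let t = Mask::<i32, 8>::splat(true);
        let f = Mask::<i32, 8>::splat(false);
        assert!((t | f).all()); // OR with all-true sets every lane
        assert!(!(t & f).any()); // AND with all-false clears every lane
        assert!((t ^ f).all()); // XOR acts as lanewise inequality
        assert!(!(!t).any()); // NOT flips every lane
    }
}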