//! Masks that take up full SIMD vector registers.
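//!
//! Each lane of a full mask is a whole element-width integer: all bits set
//! (`T::TRUE`) for `true` and all bits clear (`T::FALSE`) for `false`, i.e.
//! the vectors of 0 and -1 referred to in the safety comments below.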

use super::MaskElement;
use crate::simd::intrinsics;
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};

#[cfg(feature = "generic_const_exprs")]
use crate::simd::ToBitMaskArray;

#[repr(transparent)]
pub struct Mask<T, const LANES: usize>(Simd<T, LANES>)
where
    T: MaskElement,
    LaneCount<LANES>: SupportedLaneCount;

impl<T, const LANES: usize> Copy for Mask<T, LANES>
where
    T: MaskElement,
    LaneCount<LANES>: SupportedLaneCount,
{
}

impl<T, const LANES: usize> Clone for Mask<T, LANES>
where
    T: MaskElement,
    LaneCount<LANES>: SupportedLaneCount,
{
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn clone(&self) -> Self {
        *self
    }
}

impl<T, const LANES: usize> PartialEq for Mask<T, LANES>
where
    T: MaskElement + PartialEq,
    LaneCount<LANES>: SupportedLaneCount,
{
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.0.eq(&other.0)
    }
}

impl<T, const LANES: usize> PartialOrd for Mask<T, LANES>
where
    T: MaskElement + PartialOrd,
    LaneCount<LANES>: SupportedLaneCount,
{
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        self.0.partial_cmp(&other.0)
    }
}

impl<T, const LANES: usize> Eq for Mask<T, LANES>
where
    T: MaskElement + Eq,
    LaneCount<LANES>: SupportedLaneCount,
{
}

impl<T, const LANES: usize> Ord for Mask<T, LANES>
where
    T: MaskElement + Ord,
    LaneCount<LANES>: SupportedLaneCount,
{
    #[inline]
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.0.cmp(&other.0)
    }
}

// Used for bitmask bit order workaround
pub(crate) trait ReverseBits {
    // Reverse the least significant `n` bits of `self`.
    // (Remaining bits must be 0.)
    fn reverse_bits(self, n: usize) -> Self;
}

macro_rules! impl_reverse_bits {
    { $($int:ty),* } => {
        $(
        impl ReverseBits for $int {
            #[inline(always)]
            fn reverse_bits(self, n: usize) -> Self {
                let rev = <$int>::reverse_bits(self);
                let bitsize = core::mem::size_of::<$int>() * 8;
                if n < bitsize {
                    // Shift things back to the right
                    rev >> (bitsize - n)
                } else {
                    rev
                }
            }
        }
        )*
    }
}

impl_reverse_bits! { u8, u16, u32, u64 }
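
// A small, hypothetical sanity check for the helper above (not part of the
// original file). It assumes this module is built with an ordinary `cargo test`
// harness. `ReverseBits::reverse_bits(x, n)` reverses only the low `n` bits,
// e.g. 0b0101 reversed over 4 bits becomes 0b1010.
#[cfg(test)]
mod reverse_bits_tests {
    use super::ReverseBits;

    #[test]
    fn reverses_only_the_low_bits() {
        // Fully-qualified calls avoid clashing with the inherent `u8::reverse_bits`.
        assert_eq!(ReverseBits::reverse_bits(0b0101u8, 4), 0b1010u8);
        // When `n` covers the whole integer, it matches the plain bit reversal.
        assert_eq!(ReverseBits::reverse_bits(0x01u8, 8), 0x80u8);
    }
}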

impl<T, const LANES: usize> Mask<T, LANES>
where
    T: MaskElement,
    LaneCount<LANES>: SupportedLaneCount,
{
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn splat(value: bool) -> Self {
        Self(Simd::splat(if value { T::TRUE } else { T::FALSE }))
    }

    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
        T::eq(self.0[lane], T::TRUE)
    }

    #[inline]
    pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
        self.0[lane] = if value { T::TRUE } else { T::FALSE }
    }

    #[inline]
    #[must_use = "method returns a new vector and does not mutate the original value"]
    pub fn to_int(self) -> Simd<T, LANES> {
        self.0
    }

    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub unsafe fn from_int_unchecked(value: Simd<T, LANES>) -> Self {
        Self(value)
    }

    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn convert<U>(self) -> Mask<U, LANES>
    where
        U: MaskElement,
    {
        // Safety: masks are simply integer vectors of 0 and -1, and we can cast the element type.
        unsafe { Mask(intrinsics::simd_cast(self.0)) }
    }

    #[cfg(feature = "generic_const_exprs")]
    #[inline]
    #[must_use = "method returns a new array and does not mutate the original value"]
    pub fn to_bitmask_array<const N: usize>(self) -> [u8; N]
    where
        super::Mask<T, LANES>: ToBitMaskArray,
        [(); <super::Mask<T, LANES> as ToBitMaskArray>::BYTES]: Sized,
    {
        assert_eq!(<super::Mask<T, LANES> as ToBitMaskArray>::BYTES, N);

        // Safety: N is the correct bitmask size
        unsafe {
            // Compute the bitmask
            let bitmask: [u8; <super::Mask<T, LANES> as ToBitMaskArray>::BYTES] =
                intrinsics::simd_bitmask(self.0);

            // Transmute to the return type, previously asserted to be the same size
            let mut bitmask: [u8; N] = core::mem::transmute_copy(&bitmask);

            // LLVM assumes bit order should match endianness
            if cfg!(target_endian = "big") {
                for x in bitmask.as_mut() {
                    *x = x.reverse_bits();
                }
            };

            bitmask
        }
    }

    #[cfg(feature = "generic_const_exprs")]
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn from_bitmask_array<const N: usize>(mut bitmask: [u8; N]) -> Self
    where
        super::Mask<T, LANES>: ToBitMaskArray,
        [(); <super::Mask<T, LANES> as ToBitMaskArray>::BYTES]: Sized,
    {
        assert_eq!(<super::Mask<T, LANES> as ToBitMaskArray>::BYTES, N);

        // Safety: N is the correct bitmask size
        unsafe {
            // LLVM assumes bit order should match endianness
            if cfg!(target_endian = "big") {
                for x in bitmask.as_mut() {
                    *x = x.reverse_bits();
                }
            }

            // Transmute to the bitmask type, previously asserted to be the same size
            let bitmask: [u8; <super::Mask<T, LANES> as ToBitMaskArray>::BYTES] =
                core::mem::transmute_copy(&bitmask);

            // Compute the regular mask
            Self::from_int_unchecked(intrinsics::simd_select_bitmask(
                bitmask,
                Self::splat(true).to_int(),
                Self::splat(false).to_int(),
            ))
        }
    }

    #[inline]
    pub(crate) fn to_bitmask_integer<U: ReverseBits>(self) -> U
    where
        super::Mask<T, LANES>: ToBitMask<BitMask = U>,
    {
        // Safety: U is required to be the appropriate bitmask type
        let bitmask: U = unsafe { intrinsics::simd_bitmask(self.0) };

        // LLVM assumes bit order should match endianness
        if cfg!(target_endian = "big") {
            bitmask.reverse_bits(LANES)
        } else {
            bitmask
        }
    }

    #[inline]
    pub(crate) fn from_bitmask_integer<U: ReverseBits>(bitmask: U) -> Self
    where
        super::Mask<T, LANES>: ToBitMask<BitMask = U>,
    {
        // LLVM assumes bit order should match endianness
        let bitmask = if cfg!(target_endian = "big") {
            bitmask.reverse_bits(LANES)
        } else {
            bitmask
        };

        // Safety: U is required to be the appropriate bitmask type
        unsafe {
            Self::from_int_unchecked(intrinsics::simd_select_bitmask(
                bitmask,
                Self::splat(true).to_int(),
                Self::splat(false).to_int(),
            ))
        }
    }

    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn any(self) -> bool {
        // Safety: use `self` as an integer vector
        unsafe { intrinsics::simd_reduce_any(self.to_int()) }
    }

    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn all(self) -> bool {
        // Safety: use `self` as an integer vector
        unsafe { intrinsics::simd_reduce_all(self.to_int()) }
    }
}

impl<T, const LANES: usize> From<Mask<T, LANES>> for Simd<T, LANES>
where
    T: MaskElement,
    LaneCount<LANES>: SupportedLaneCount,
{
    #[inline]
    fn from(value: Mask<T, LANES>) -> Self {
        value.0
    }
}

impl<T, const LANES: usize> core::ops::BitAnd for Mask<T, LANES>
where
    T: MaskElement,
    LaneCount<LANES>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn bitand(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(intrinsics::simd_and(self.0, rhs.0)) }
    }
}

impl<T, const LANES: usize> core::ops::BitOr for Mask<T, LANES>
where
    T: MaskElement,
    LaneCount<LANES>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn bitor(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(intrinsics::simd_or(self.0, rhs.0)) }
    }
}

impl<T, const LANES: usize> core::ops::BitXor for Mask<T, LANES>
where
    T: MaskElement,
    LaneCount<LANES>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn bitxor(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(intrinsics::simd_xor(self.0, rhs.0)) }
    }
}

impl<T, const LANES: usize> core::ops::Not for Mask<T, LANES>
where
    T: MaskElement,
    LaneCount<LANES>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn not(self) -> Self::Output {
        Self::splat(true) ^ self
    }
}
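
// Hypothetical smoke tests for the mask methods above (not part of the
// original file). They assume this module compiles under a `cargo test`
// harness and that the surrounding crate provides `MaskElement` impls for
// `i8`/`i32` and a `ToBitMask<BitMask = u8>` impl for 8-lane masks, as the
// wider crate does.
#[cfg(test)]
mod full_mask_tests {
    use super::Mask;

    #[test]
    fn splat_set_and_reduce() {
        let mut mask = Mask::<i32, 4>::splat(false);
        assert!(!mask.any());
        assert!((!mask).all());
        // Safety: lane 2 is in bounds for a 4-lane mask.
        unsafe { mask.set_unchecked(2, true) };
        assert!(unsafe { mask.test_unchecked(2) });
        assert!(mask.any());
        assert!(!mask.all());
    }

    #[test]
    fn bitmask_round_trip() {
        // Bit i of the integer corresponds to lane i, regardless of endianness.
        let mask = Mask::<i8, 8>::from_bitmask_integer(0b0000_0101u8);
        assert!(unsafe { mask.test_unchecked(0) });
        assert!(!(unsafe { mask.test_unchecked(1) }));
        assert!(unsafe { mask.test_unchecked(2) });
        let bits: u8 = mask.to_bitmask_integer();
        assert_eq!(bits, 0b0000_0101);
    }
}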