22//! Types representing
33#![ allow( non_camel_case_types) ]
44
5- #[ cfg_attr(
6- not( all( target_arch = "x86_64" , target_feature = "avx512f" ) ) ,
7- path = "masks/full_masks.rs"
8- ) ]
9- #[ cfg_attr(
10- all( target_arch = "x86_64" , target_feature = "avx512f" ) ,
11- path = "masks/bitmask.rs"
12- ) ]
13- mod mask_impl;
14-
15- use crate :: simd:: { LaneCount , Simd , SimdCast , SimdElement , SupportedLaneCount } ;
5+ use crate :: simd:: { LaneCount , Select , Simd , SimdCast , SimdElement , SupportedLaneCount } ;
166use core:: cmp:: Ordering ;
177use core:: { fmt, mem} ;
188
9+ pub ( crate ) trait FixEndianness {
10+ fn fix_endianness ( self ) -> Self ;
11+ }
12+
13+ macro_rules! impl_fix_endianness {
14+ { $( $int: ty) ,* } => {
15+ $(
16+ impl FixEndianness for $int {
17+ #[ inline( always) ]
18+ fn fix_endianness( self ) -> Self {
19+ if cfg!( target_endian = "big" ) {
20+ <$int>:: reverse_bits( self )
21+ } else {
22+ self
23+ }
24+ }
25+ }
26+ ) *
27+ }
28+ }
29+
30+ impl_fix_endianness ! { u8 , u16 , u32 , u64 }
31+
1932mod sealed {
2033 use super :: * ;
2134
@@ -109,7 +122,7 @@ impl_element! { isize, usize }
109122/// and/or Rust versions, and code should not assume that it is equivalent to
110123/// `[T; N]`.
111124#[ repr( transparent) ]
112- pub struct Mask < T , const N : usize > ( mask_impl :: Mask < T , N > )
125+ pub struct Mask < T , const N : usize > ( Simd < T , N > )
113126where
114127 T : MaskElement ,
115128 LaneCount < N > : SupportedLaneCount ;
@@ -141,7 +154,7 @@ where
141154 #[ inline]
142155 #[ rustc_const_unstable( feature = "portable_simd" , issue = "86656" ) ]
143156 pub const fn splat ( value : bool ) -> Self {
144- Self ( mask_impl :: Mask :: splat ( value) )
157+ Self ( Simd :: splat ( if value { T :: TRUE } else { T :: FALSE } ) )
145158 }
146159
147160 /// Converts an array of bools to a SIMD mask.
@@ -192,8 +205,8 @@ where
192205 // Safety: the caller must confirm this invariant
193206 unsafe {
194207 core:: intrinsics:: assume ( <T as Sealed >:: valid ( value) ) ;
195- Self ( mask_impl:: Mask :: from_simd_unchecked ( value) )
196208 }
209+ Self ( value)
197210 }
198211
199212 /// Converts a vector of integers to a mask, where 0 represents `false` and -1
@@ -215,14 +228,15 @@ where
215228 #[ inline]
216229 #[ must_use = "method returns a new vector and does not mutate the original value" ]
217230 pub fn to_simd ( self ) -> Simd < T , N > {
218- self . 0 . to_simd ( )
231+ self . 0
219232 }
220233
221234 /// Converts the mask to a mask of any other element size.
222235 #[ inline]
223236 #[ must_use = "method returns a new mask and does not mutate the original value" ]
224237 pub fn cast < U : MaskElement > ( self ) -> Mask < U , N > {
225- Mask ( self . 0 . convert ( ) )
238+ // Safety: mask elements are integers
239+ unsafe { Mask ( core:: intrinsics:: simd:: simd_as ( self . 0 ) ) }
226240 }
227241
228242 /// Tests the value of the specified element.
@@ -233,7 +247,7 @@ where
233247 #[ must_use = "method returns a new bool and does not mutate the original value" ]
234248 pub unsafe fn test_unchecked ( & self , index : usize ) -> bool {
235249 // Safety: the caller must confirm this invariant
236- unsafe { self . 0 . test_unchecked ( index) }
250+ unsafe { T :: eq ( * self . 0 . as_array ( ) . get_unchecked ( index) , T :: TRUE ) }
237251 }
238252
239253 /// Tests the value of the specified element.
@@ -244,9 +258,7 @@ where
244258 #[ must_use = "method returns a new bool and does not mutate the original value" ]
245259 #[ track_caller]
246260 pub fn test ( & self , index : usize ) -> bool {
247- assert ! ( index < N , "element index out of range" ) ;
248- // Safety: the element index has been checked
249- unsafe { self . test_unchecked ( index) }
261+ T :: eq ( self . 0 [ index] , T :: TRUE )
250262 }
251263
252264 /// Sets the value of the specified element.
@@ -257,7 +269,7 @@ where
257269 pub unsafe fn set_unchecked ( & mut self , index : usize , value : bool ) {
258270 // Safety: the caller must confirm this invariant
259271 unsafe {
260- self . 0 . set_unchecked ( index, value) ;
272+ * self . 0 . as_mut_array ( ) . get_unchecked_mut ( index) = if value { T :: TRUE } else { T :: FALSE }
261273 }
262274 }
263275
@@ -268,35 +280,67 @@ where
268280 #[ inline]
269281 #[ track_caller]
270282 pub fn set ( & mut self , index : usize , value : bool ) {
271- assert ! ( index < N , "element index out of range" ) ;
272- // Safety: the element index has been checked
273- unsafe {
274- self . set_unchecked ( index, value) ;
275- }
283+ self . 0 [ index] = if value { T :: TRUE } else { T :: FALSE }
276284 }
277285
278286 /// Returns true if any element is set, or false otherwise.
279287 #[ inline]
280288 #[ must_use = "method returns a new bool and does not mutate the original value" ]
281289 pub fn any ( self ) -> bool {
282- self . 0 . any ( )
290+ // Safety: `self` is a mask vector
291+ unsafe { core:: intrinsics:: simd:: simd_reduce_any ( self . 0 ) }
283292 }
284293
285294 /// Returns true if all elements are set, or false otherwise.
286295 #[ inline]
287296 #[ must_use = "method returns a new bool and does not mutate the original value" ]
288297 pub fn all ( self ) -> bool {
289- self . 0 . all ( )
298+ // Safety: `self` is a mask vector
299+ unsafe { core:: intrinsics:: simd:: simd_reduce_all ( self . 0 ) }
290300 }
291301
292302 /// Creates a bitmask from a mask.
293303 ///
294304 /// Each bit is set if the corresponding element in the mask is `true`.
295- /// If the mask contains more than 64 elements, the bitmask is truncated to the first 64.
296305 #[ inline]
297306 #[ must_use = "method returns a new integer and does not mutate the original value" ]
298307 pub fn to_bitmask ( self ) -> u64 {
299- self . 0 . to_bitmask_integer ( )
308+ const {
309+ assert ! ( N <= 64 , "number of elements can't be greater than 64" ) ;
310+ }
311+
312+ #[ inline]
313+ unsafe fn to_bitmask_impl < T , U : FixEndianness , const M : usize , const N : usize > (
314+ mask : Mask < T , N > ,
315+ ) -> U
316+ where
317+ T : MaskElement ,
318+ LaneCount < M > : SupportedLaneCount ,
319+ LaneCount < N > : SupportedLaneCount ,
320+ {
321+ let resized = mask. resize :: < M > ( false ) ;
322+
323+ // Safety: `resized` is an integer vector with length M, which must match T
324+ let bitmask: U = unsafe { core:: intrinsics:: simd:: simd_bitmask ( resized. 0 ) } ;
325+
326+ // LLVM assumes bit order should match endianness
327+ bitmask. fix_endianness ( )
328+ }
329+
330+ // TODO modify simd_bitmask to zero-extend output, making this unnecessary
331+ if N <= 8 {
332+ // Safety: bitmask matches length
333+ unsafe { to_bitmask_impl :: < T , u8 , 8 , N > ( self ) as u64 }
334+ } else if N <= 16 {
335+ // Safety: bitmask matches length
336+ unsafe { to_bitmask_impl :: < T , u16 , 16 , N > ( self ) as u64 }
337+ } else if N <= 32 {
338+ // Safety: bitmask matches length
339+ unsafe { to_bitmask_impl :: < T , u32 , 32 , N > ( self ) as u64 }
340+ } else {
341+ // Safety: bitmask matches length
342+ unsafe { to_bitmask_impl :: < T , u64 , 64 , N > ( self ) }
343+ }
300344 }
301345
302346 /// Creates a mask from a bitmask.
@@ -306,7 +350,7 @@ where
306350 #[ inline]
307351 #[ must_use = "method returns a new mask and does not mutate the original value" ]
308352 pub fn from_bitmask ( bitmask : u64 ) -> Self {
309- Self ( mask_impl :: Mask :: from_bitmask_integer ( bitmask ) )
353+ Self ( bitmask . select ( Simd :: splat ( T :: TRUE ) , Simd :: splat ( T :: FALSE ) ) )
310354 }
311355
312356 /// Finds the index of the first set element.
@@ -450,7 +494,8 @@ where
450494 type Output = Self ;
451495 #[ inline]
452496 fn bitand ( self , rhs : Self ) -> Self {
453- Self ( self . 0 & rhs. 0 )
497+ // Safety: `self` is an integer vector
498+ unsafe { Self ( core:: intrinsics:: simd:: simd_and ( self . 0 , rhs. 0 ) ) }
454499 }
455500}
456501
@@ -486,7 +531,8 @@ where
486531 type Output = Self ;
487532 #[ inline]
488533 fn bitor ( self , rhs : Self ) -> Self {
489- Self ( self . 0 | rhs. 0 )
534+ // Safety: `self` is an integer vector
535+ unsafe { Self ( core:: intrinsics:: simd:: simd_or ( self . 0 , rhs. 0 ) ) }
490536 }
491537}
492538
@@ -522,7 +568,8 @@ where
522568 type Output = Self ;
523569 #[ inline]
524570 fn bitxor ( self , rhs : Self ) -> Self :: Output {
525- Self ( self . 0 ^ rhs. 0 )
571+ // Safety: `self` is an integer vector
572+ unsafe { Self ( core:: intrinsics:: simd:: simd_xor ( self . 0 , rhs. 0 ) ) }
526573 }
527574}
528575
@@ -558,7 +605,7 @@ where
558605 type Output = Mask < T , N > ;
559606 #[ inline]
560607 fn not ( self ) -> Self :: Output {
561- Self ( ! self . 0 )
608+ Self :: splat ( true ) ^ self
562609 }
563610}
564611
@@ -569,7 +616,7 @@ where
569616{
570617 #[ inline]
571618 fn bitand_assign ( & mut self , rhs : Self ) {
572- self . 0 = self . 0 & rhs. 0 ;
619+ * self = * self & rhs;
573620 }
574621}
575622
@@ -591,7 +638,7 @@ where
591638{
592639 #[ inline]
593640 fn bitor_assign ( & mut self , rhs : Self ) {
594- self . 0 = self . 0 | rhs. 0 ;
641+ * self = * self | rhs;
595642 }
596643}
597644
@@ -613,7 +660,7 @@ where
613660{
614661 #[ inline]
615662 fn bitxor_assign ( & mut self , rhs : Self ) {
616- self . 0 = self . 0 ^ rhs. 0 ;
663+ * self = * self ^ rhs;
617664 }
618665}
619666
0 commit comments