1313mod mask_impl;
1414
1515use crate :: simd:: {
16- cmp:: SimdPartialEq , intrinsics, LaneCount , Simd , SimdElement , SupportedLaneCount ,
16+ cmp:: SimdPartialEq , intrinsics, LaneCount , Simd , SimdCast , SimdElement , SupportedLaneCount ,
1717} ;
1818use core:: cmp:: Ordering ;
1919use core:: { fmt, mem} ;
@@ -35,6 +35,10 @@ mod sealed {
3535
3636 fn eq ( self , other : Self ) -> bool ;
3737
38+ fn as_usize ( self ) -> usize ;
39+
40+ type Unsigned : SimdElement ;
41+
3842 const TRUE : Self ;
3943
4044 const FALSE : Self ;
@@ -46,10 +50,10 @@ use sealed::Sealed;
4650///
4751/// # Safety
4852/// Type must be a signed integer.
49- pub unsafe trait MaskElement : SimdElement + Sealed { }
53+ pub unsafe trait MaskElement : SimdElement < Mask = Self > + SimdCast + Sealed { }
5054
5155macro_rules! impl_element {
52- { $ty: ty } => {
56+ { $ty: ty, $unsigned : ty } => {
5357 impl Sealed for $ty {
5458 #[ inline]
5559 fn valid<const N : usize >( value: Simd <Self , N >) -> bool
@@ -62,6 +66,13 @@ macro_rules! impl_element {
6266 #[ inline]
6367 fn eq( self , other: Self ) -> bool { self == other }
6468
69+ #[ inline]
70+ fn as_usize( self ) -> usize {
71+ self as usize
72+ }
73+
74+ type Unsigned = $unsigned;
75+
6576 const TRUE : Self = -1 ;
6677 const FALSE : Self = 0 ;
6778 }
@@ -71,11 +82,11 @@ macro_rules! impl_element {
7182 }
7283}
7384
74- impl_element ! { i8 }
75- impl_element ! { i16 }
76- impl_element ! { i32 }
77- impl_element ! { i64 }
78- impl_element ! { isize }
85+ impl_element ! { i8 , u8 }
86+ impl_element ! { i16 , u16 }
87+ impl_element ! { i32 , u32 }
88+ impl_element ! { i64 , u64 }
89+ impl_element ! { isize , usize }
7990
8091/// A SIMD vector mask for `N` elements of width specified by `Element`.
8192///
@@ -298,6 +309,67 @@ where
298309 pub fn from_bitmask_vector ( bitmask : Simd < u8 , N > ) -> Self {
299310 Self ( mask_impl:: Mask :: from_bitmask_vector ( bitmask) )
300311 }
312+
313+ /// Find the index of the first set element.
314+ ///
315+ /// ```
316+ /// # #![feature(portable_simd)]
317+ /// # #[cfg(feature = "as_crate")] use core_simd::simd;
318+ /// # #[cfg(not(feature = "as_crate"))] use core::simd;
319+ /// # use simd::mask32x8;
320+ /// assert_eq!(mask32x8::splat(false).first_set(), None);
321+ /// assert_eq!(mask32x8::splat(true).first_set(), Some(0));
322+ ///
323+ /// let mask = mask32x8::from_array([false, true, false, false, true, false, false, true]);
324+ /// assert_eq!(mask.first_set(), Some(1));
325+ /// ```
326+ #[ inline]
327+ #[ must_use = "method returns the index and does not mutate the original value" ]
328+ pub fn first_set ( self ) -> Option < usize > {
329+ // If bitmasks are efficient, using them is better
330+ if cfg ! ( target_feature = "sse" ) && N <= 64 {
331+ let tz = self . to_bitmask ( ) . trailing_zeros ( ) ;
332+ return if tz == 64 { None } else { Some ( tz as usize ) } ;
333+ }
334+
335+ // To find the first set index:
336+ // * create a vector 0..N
337+ // * replace unset mask elements in that vector with -1
338+ // * perform _unsigned_ reduce-min
339+ // * check if the result is -1 or an index
340+
341+ let index = Simd :: from_array (
342+ const {
343+ let mut index = [ 0 ; N ] ;
344+ let mut i = 0 ;
345+ while i < N {
346+ index[ i] = i;
347+ i += 1 ;
348+ }
349+ index
350+ } ,
351+ ) ;
352+
353+ // Safety: the input and output are integer vectors
354+ let index: Simd < T , N > = unsafe { intrinsics:: simd_cast ( index) } ;
355+
356+ let masked_index = self . select ( index, Self :: splat ( true ) . to_int ( ) ) ;
357+
358+ // Safety: the input and output are integer vectors
359+ let masked_index: Simd < T :: Unsigned , N > = unsafe { intrinsics:: simd_cast ( masked_index) } ;
360+
361+ // Safety: the input is an integer vector
362+ let min_index: T :: Unsigned = unsafe { intrinsics:: simd_reduce_min ( masked_index) } ;
363+
364+ // Safety: the return value is the unsigned version of T
365+ let min_index: T = unsafe { core:: mem:: transmute_copy ( & min_index) } ;
366+
367+ if min_index. eq ( T :: TRUE ) {
368+ None
369+ } else {
370+ Some ( min_index. as_usize ( ) )
371+ }
372+ }
301373}
302374
303375// vector/array conversion
0 commit comments