+use crate::Mask;
+use core::marker::PhantomData;
+
+/// Helper trait for limiting int conversion types
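+/// (constrains `to_int` to the signed integer vectors that the select
+/// intrinsic below can produce)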
+pub trait ConvertToInt {}
+impl<const LANES: usize> ConvertToInt for crate::SimdI8<LANES> where Self: crate::LanesAtMost32 {}
+impl<const LANES: usize> ConvertToInt for crate::SimdI16<LANES> where Self: crate::LanesAtMost32 {}
+impl<const LANES: usize> ConvertToInt for crate::SimdI32<LANES> where Self: crate::LanesAtMost32 {}
+impl<const LANES: usize> ConvertToInt for crate::SimdI64<LANES> where Self: crate::LanesAtMost32 {}
+impl<const LANES: usize> ConvertToInt for crate::SimdIsize<LANES> where Self: crate::LanesAtMost32 {}
+
 /// A mask where each lane is represented by a single bit.
-#[derive(Copy, Clone, Debug, PartialOrd, PartialEq, Ord, Eq, Hash)]
 #[repr(transparent)]
-pub struct BitMask<const LANES: usize>(u64);
+pub struct BitMask<T: Mask, const LANES: usize>(T::BitMask, PhantomData<[(); LANES]>);
 
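+// The trait impls below are written by hand because `derive` would bound `T`
+// itself rather than the `T::BitMask` storage that is actually compared.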
-impl<const LANES: usize> BitMask<LANES>
-{
+impl<T: Mask, const LANES: usize> Copy for BitMask<T, LANES> {}
+
+impl<T: Mask, const LANES: usize> Clone for BitMask<T, LANES> {
+    fn clone(&self) -> Self {
+        *self
+    }
+}
+
+impl<T: Mask, const LANES: usize> PartialEq for BitMask<T, LANES> {
+    fn eq(&self, other: &Self) -> bool {
+        self.0.as_ref() == other.0.as_ref()
+    }
+}
+
+impl<T: Mask, const LANES: usize> PartialOrd for BitMask<T, LANES> {
+    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+        self.0.as_ref().partial_cmp(other.0.as_ref())
+    }
+}
+
+impl<T: Mask, const LANES: usize> Eq for BitMask<T, LANES> {}
+
+impl<T: Mask, const LANES: usize> Ord for BitMask<T, LANES> {
+    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
+        self.0.as_ref().cmp(other.0.as_ref())
+    }
+}
+
+impl<T: Mask, const LANES: usize> BitMask<T, LANES> {
     #[inline]
     pub fn splat(value: bool) -> Self {
+        let mut mask = T::BitMask::default();
         if value {
-            Self(u64::MAX >> (64 - LANES))
+            mask.as_mut().fill(u8::MAX)
         } else {
-            Self(u64::MIN)
+            mask.as_mut().fill(u8::MIN)
+        }
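+        // When LANES is not a multiple of 8, zero the unused high bits of the
+        // last byte so padding bits never read as `true`.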
+        if LANES % 8 > 0 {
+            *mask.as_mut().last_mut().unwrap() &= u8::MAX >> (8 - LANES % 8);
         }
+        Self(mask, PhantomData)
     }
 
     #[inline]
     pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
-        (self.0 >> lane) & 0x1 > 0
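+        // Index the byte that holds this lane, then shift its bit down to bit 0.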
+        (self.0.as_ref()[lane / 8] >> (lane % 8)) & 0x1 > 0
     }
 
     #[inline]
     pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
-        self.0 ^= ((value ^ self.test_unchecked(lane)) as u64) << lane
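+        // XOR by (old ^ new) flips the stored bit only when it needs to change.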
+        self.0.as_mut()[lane / 8] ^= ((value ^ self.test_unchecked(lane)) as u8) << (lane % 8)
     }
 
     #[inline]
-    pub fn to_int<V, T>(self) -> V
+    pub fn to_int<V>(self) -> V
     where
-        V: Default + AsMut<[T; LANES]>,
-        T: From<i8>,
+        V: ConvertToInt + Default + core::ops::Not<Output = V>,
     {
-        // TODO this should be an intrinsic sign-extension
-        let mut v = V::default();
-        for i in 0..LANES {
-            let lane = unsafe { self.test_unchecked(i) };
-            v.as_mut()[i] = (-(lane as i8)).into();
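+        // Reinterpret the byte-array bitmask as the integer type the intrinsic
+        // expects, then select !0 (all ones) or 0 for each lane.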
+        unsafe {
+            let mask: T::IntBitMask = core::mem::transmute_copy(&self);
+            crate::intrinsics::simd_select_bitmask(mask, !V::default(), V::default())
         }
-        v
     }
 
     #[inline]
     pub unsafe fn from_int_unchecked<V>(value: V) -> Self
     where
         V: crate::LanesAtMost32,
     {
-        let mask: V::BitMask = crate::intrinsics::simd_bitmask(value);
-        Self(mask.into())
+        // TODO remove the transmute when rustc is more flexible
+        assert_eq!(
+            core::mem::size_of::<T::IntBitMask>(),
+            core::mem::size_of::<T::BitMask>()
+        );
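+        // simd_bitmask packs one bit per vector lane into an integer bitmask.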
+        let mask: T::IntBitMask = crate::intrinsics::simd_bitmask(value);
+        Self(core::mem::transmute_copy(&mask), PhantomData)
     }
 
     #[inline]
-    pub fn to_bitmask(self) -> u64 {
-        self.0
+    pub fn to_bitmask<U: Mask>(self) -> U::BitMask {
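+        // Equal sizes make the transmute_copy between the two bitmask
+        // representations sound.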
+        assert_eq!(
+            core::mem::size_of::<T::BitMask>(),
+            core::mem::size_of::<U::BitMask>()
+        );
+        unsafe { core::mem::transmute_copy(&self.0) }
     }
 
     #[inline]
@@ -64,87 +111,61 @@ impl<const LANES: usize> BitMask<LANES>
     }
 }
 
-impl<const LANES: usize> core::ops::BitAnd for BitMask<LANES>
+impl<T: Mask, const LANES: usize> core::ops::BitAnd for BitMask<T, LANES>
+where
+    T::BitMask: Default + AsRef<[u8]> + AsMut<[u8]>,
 {
     type Output = Self;
     #[inline]
-    fn bitand(self, rhs: Self) -> Self {
-        Self(self.0 & rhs.0)
+    fn bitand(mut self, rhs: Self) -> Self {
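+        // Both byte arrays have the same length, so zip visits every byte.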
+        for (l, r) in self.0.as_mut().iter_mut().zip(rhs.0.as_ref().iter()) {
+            *l &= r;
+        }
+        self
     }
 }
 
-impl<const LANES: usize> core::ops::BitAnd<bool> for BitMask<LANES>
+impl<T: Mask, const LANES: usize> core::ops::BitOr for BitMask<T, LANES>
+where
+    T::BitMask: Default + AsRef<[u8]> + AsMut<[u8]>,
 {
     type Output = Self;
     #[inline]
-    fn bitand(self, rhs: bool) -> Self {
-        self & Self::splat(rhs)
-    }
-}
-
-impl<const LANES: usize> core::ops::BitAnd<BitMask<LANES>> for bool
-{
-    type Output = BitMask<LANES>;
-    #[inline]
-    fn bitand(self, rhs: BitMask<LANES>) -> BitMask<LANES> {
-        BitMask::<LANES>::splat(self) & rhs
+    fn bitor(mut self, rhs: Self) -> Self {
+        for (l, r) in self.0.as_mut().iter_mut().zip(rhs.0.as_ref().iter()) {
+            *l |= r;
+        }
+        self
     }
 }
 
-impl<const LANES: usize> core::ops::BitOr for BitMask<LANES>
-{
+impl<T: Mask, const LANES: usize> core::ops::BitXor for BitMask<T, LANES> {
     type Output = Self;
     #[inline]
-    fn bitor(self, rhs: Self) -> Self {
-        Self(self.0 | rhs.0)
+    fn bitxor(mut self, rhs: Self) -> Self::Output {
+        for (l, r) in self.0.as_mut().iter_mut().zip(rhs.0.as_ref().iter()) {
+            *l ^= r;
+        }
+        self
     }
 }
 
-impl<const LANES: usize> core::ops::BitXor for BitMask<LANES>
-{
+impl<T: Mask, const LANES: usize> core::ops::Not for BitMask<T, LANES> {
     type Output = Self;
     #[inline]
-    fn bitxor(self, rhs: Self) -> Self::Output {
-        Self(self.0 ^ rhs.0)
-    }
-}
-
-impl<const LANES: usize> core::ops::Not for BitMask<LANES>
-{
-    type Output = BitMask<LANES>;
-    #[inline]
-    fn not(self) -> Self::Output {
-        Self(!self.0) & Self::splat(true)
-    }
-}
-
-impl<const LANES: usize> core::ops::BitAndAssign for BitMask<LANES>
-{
-    #[inline]
-    fn bitand_assign(&mut self, rhs: Self) {
-        self.0 &= rhs.0;
-    }
-}
-
-impl<const LANES: usize> core::ops::BitOrAssign for BitMask<LANES>
-{
-    #[inline]
-    fn bitor_assign(&mut self, rhs: Self) {
-        self.0 |= rhs.0;
-    }
-}
-
-impl<const LANES: usize> core::ops::BitXorAssign for BitMask<LANES>
-{
-    #[inline]
-    fn bitxor_assign(&mut self, rhs: Self) {
-        self.0 ^= rhs.0;
+    fn not(mut self) -> Self::Output {
+        for x in self.0.as_mut() {
+            *x = !*x;
+        }
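+        // Byte-wise NOT also set the padding bits; clear them again to keep
+        // the all-padding-bits-zero invariant.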
+        if LANES % 8 > 0 {
+            *self.0.as_mut().last_mut().unwrap() &= u8::MAX >> (8 - LANES % 8);
+        }
+        self
     }
 }
 
-pub type Mask8<const LANES: usize> = BitMask<LANES>;
-pub type Mask16<const LANES: usize> = BitMask<LANES>;
-pub type Mask32<const LANES: usize> = BitMask<LANES>;
-pub type Mask64<const LANES: usize> = BitMask<LANES>;
-pub type Mask128<const LANES: usize> = BitMask<LANES>;
-pub type MaskSize<const LANES: usize> = BitMask<LANES>;
+pub type Mask8<T, const LANES: usize> = BitMask<T, LANES>;
+pub type Mask16<T, const LANES: usize> = BitMask<T, LANES>;
+pub type Mask32<T, const LANES: usize> = BitMask<T, LANES>;
+pub type Mask64<T, const LANES: usize> = BitMask<T, LANES>;
+pub type MaskSize<T, const LANES: usize> = BitMask<T, LANES>;
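
For reference, the byte-packed layout introduced above can be exercised on its
own. A minimal standalone sketch (illustrative names only, not part of this
crate), assuming 12 lanes packed into ceil(12 / 8) = 2 bytes:

    const LANES: usize = 12;

    /// Mirrors `BitMask::splat`: fill the bytes, then trim the padding bits.
    fn splat_bytes(value: bool) -> [u8; 2] {
        let mut mask = if value { [u8::MAX; 2] } else { [u8::MIN; 2] };
        if LANES % 8 > 0 {
            *mask.last_mut().unwrap() &= u8::MAX >> (8 - LANES % 8);
        }
        mask
    }

    /// Mirrors `BitMask::test_unchecked`: byte index, then bit offset.
    fn test_bit(mask: &[u8; 2], lane: usize) -> bool {
        (mask[lane / 8] >> (lane % 8)) & 1 > 0
    }

    fn main() {
        let mask = splat_bytes(true);
        assert_eq!(mask, [0xff, 0x0f]); // 12 live bits set, 4 padding bits clear
        assert!(test_bit(&mask, 11));
    }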