@@ -1034,6 +1034,60 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
10341034 } } ;
10351035 }
10361036
1037+ /// Returns the bitwidth of the `$ty` argument if it is an `Int` type.
1038+ macro_rules! require_int_ty {
1039+ ( $ty: expr, $diag: expr) => {
1040+ match $ty {
1041+ ty:: Int ( i) => i. bit_width( ) . unwrap_or_else( || bx. data_layout( ) . pointer_size. bits( ) ) ,
1042+ _ => {
1043+ return_error!( $diag) ;
1044+ }
1045+ }
1046+ } ;
1047+ }
1048+
1049+ /// Returns the bitwidth of the `$ty` argument if it is an `Int` or `Uint` type.
1050+ macro_rules! require_int_or_uint_ty {
1051+ ( $ty: expr, $diag: expr) => {
1052+ match $ty {
1053+ ty:: Int ( i) => i. bit_width( ) . unwrap_or_else( || bx. data_layout( ) . pointer_size. bits( ) ) ,
1054+ ty:: Uint ( i) => {
1055+ i. bit_width( ) . unwrap_or_else( || bx. data_layout( ) . pointer_size. bits( ) )
1056+ }
1057+ _ => {
1058+ return_error!( $diag) ;
1059+ }
1060+ }
1061+ } ;
1062+ }
1063+
1064+ /// Converts a vector mask, where each element has a bit width equal to the data elements it is used with,
1065+ /// down to an i1 based mask that can be used by llvm intrinsics.
1066+ ///
1067+ /// The rust simd semantics are that each element should either consist of all ones or all zeroes,
1068+ /// but this information is not available to llvm. Truncating the vector effectively uses the lowest bit,
1069+ /// but codegen for several targets is better if we consider the highest bit by shifting.
1070+ ///
1071+ /// For x86 SSE/AVX targets this is beneficial since most instructions with mask parameters only consider the highest bit.
1072+ /// So even though on llvm level we have an additional shift, in the final assembly there is no shift or truncate and
1073+ /// instead the mask can be used as is.
1074+ ///
1075+ /// For aarch64 and other targets there is a benefit because a mask from the sign bit can be more
1076+ /// efficiently converted to an all ones / all zeroes mask by comparing whether each element is negative.
1077+ fn vector_mask_to_bitmask < ' a , ' ll , ' tcx > (
1078+ bx : & mut Builder < ' a , ' ll , ' tcx > ,
1079+ i_xn : & ' ll Value ,
1080+ in_elem_bitwidth : u64 ,
1081+ in_len : u64 ,
1082+ ) -> & ' ll Value {
1083+ // Shift the MSB to the right by "in_elem_bitwidth - 1" into the first bit position.
1084+ let shift_idx = bx. cx . const_int ( bx. type_ix ( in_elem_bitwidth) , ( in_elem_bitwidth - 1 ) as _ ) ;
1085+ let shift_indices = vec ! [ shift_idx; in_len as _] ;
1086+ let i_xn_msb = bx. lshr ( i_xn, bx. const_vector ( shift_indices. as_slice ( ) ) ) ;
1087+ // Truncate vector to an <i1 x N>
1088+ bx. trunc ( i_xn_msb, bx. type_vector ( bx. type_i1 ( ) , in_len) )
1089+ }
1090+
10371091 let tcx = bx. tcx ( ) ;
10381092 let sig =
10391093 tcx. normalize_erasing_late_bound_regions ( ty:: ParamEnv :: reveal_all ( ) , callee_ty. fn_sig ( tcx) ) ;
@@ -1294,14 +1348,11 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
12941348 m_len == v_len,
12951349 InvalidMonomorphization :: MismatchedLengths { span, name, m_len, v_len }
12961350 ) ;
1297- match m_elem_ty. kind ( ) {
1298- ty:: Int ( _) => { }
1299- _ => return_error ! ( InvalidMonomorphization :: MaskType { span, name, ty: m_elem_ty } ) ,
1300- }
1301- // truncate the mask to a vector of i1s
1302- let i1 = bx. type_i1 ( ) ;
1303- let i1xn = bx. type_vector ( i1, m_len as u64 ) ;
1304- let m_i1s = bx. trunc ( args[ 0 ] . immediate ( ) , i1xn) ;
1351+ let in_elem_bitwidth = require_int_ty ! (
1352+ m_elem_ty. kind( ) ,
1353+ InvalidMonomorphization :: MaskType { span, name, ty: m_elem_ty }
1354+ ) ;
1355+ let m_i1s = vector_mask_to_bitmask ( bx, args[ 0 ] . immediate ( ) , in_elem_bitwidth, m_len) ;
13051356 return Ok ( bx. select ( m_i1s, args[ 1 ] . immediate ( ) , args[ 2 ] . immediate ( ) ) ) ;
13061357 }
13071358
@@ -1319,32 +1370,12 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
13191370 let expected_bytes = expected_int_bits / 8 + ( ( expected_int_bits % 8 > 0 ) as u64 ) ;
13201371
13211372 // Integer vector <i{in_bitwidth} x in_len>:
1322- let ( i_xn, in_elem_bitwidth) = match in_elem. kind ( ) {
1323- ty:: Int ( i) => (
1324- args[ 0 ] . immediate ( ) ,
1325- i. bit_width ( ) . unwrap_or_else ( || bx. data_layout ( ) . pointer_size . bits ( ) ) ,
1326- ) ,
1327- ty:: Uint ( i) => (
1328- args[ 0 ] . immediate ( ) ,
1329- i. bit_width ( ) . unwrap_or_else ( || bx. data_layout ( ) . pointer_size . bits ( ) ) ,
1330- ) ,
1331- _ => return_error ! ( InvalidMonomorphization :: VectorArgument {
1332- span,
1333- name,
1334- in_ty,
1335- in_elem
1336- } ) ,
1337- } ;
1373+ let in_elem_bitwidth = require_int_or_uint_ty ! (
1374+ in_elem. kind( ) ,
1375+ InvalidMonomorphization :: VectorArgument { span, name, in_ty, in_elem }
1376+ ) ;
13381377
1339- // Shift the MSB to the right by "in_elem_bitwidth - 1" into the first bit position.
1340- let shift_indices =
1341- vec ! [
1342- bx. cx. const_int( bx. type_ix( in_elem_bitwidth) , ( in_elem_bitwidth - 1 ) as _) ;
1343- in_len as _
1344- ] ;
1345- let i_xn_msb = bx. lshr ( i_xn, bx. const_vector ( shift_indices. as_slice ( ) ) ) ;
1346- // Truncate vector to an <i1 x N>
1347- let i1xn = bx. trunc ( i_xn_msb, bx. type_vector ( bx. type_i1 ( ) , in_len) ) ;
1378+ let i1xn = vector_mask_to_bitmask ( bx, args[ 0 ] . immediate ( ) , in_elem_bitwidth, in_len) ;
13481379 // Bitcast <i1 x N> to iN:
13491380 let i_ = bx. bitcast ( i1xn, bx. type_ix ( in_len) ) ;
13501381
@@ -1562,28 +1593,23 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
15621593 }
15631594 ) ;
15641595
1565- match element_ty2. kind ( ) {
1566- ty:: Int ( _) => ( ) ,
1567- _ => {
1568- return_error ! ( InvalidMonomorphization :: ThirdArgElementType {
1569- span,
1570- name,
1571- expected_element: element_ty2,
1572- third_arg: arg_tys[ 2 ]
1573- } ) ;
1596+ let mask_elem_bitwidth = require_int_ty ! (
1597+ element_ty2. kind( ) ,
1598+ InvalidMonomorphization :: ThirdArgElementType {
1599+ span,
1600+ name,
1601+ expected_element: element_ty2,
1602+ third_arg: arg_tys[ 2 ]
15741603 }
1575- }
1604+ ) ;
15761605
15771606 // Alignment of T, must be a constant integer value:
15781607 let alignment_ty = bx. type_i32 ( ) ;
15791608 let alignment = bx. const_i32 ( bx. align_of ( in_elem) . bytes ( ) as i32 ) ;
15801609
15811610 // Truncate the mask vector to a vector of i1s:
1582- let ( mask, mask_ty) = {
1583- let i1 = bx. type_i1 ( ) ;
1584- let i1xn = bx. type_vector ( i1, in_len) ;
1585- ( bx. trunc ( args[ 2 ] . immediate ( ) , i1xn) , i1xn)
1586- } ;
1611+ let mask = vector_mask_to_bitmask ( bx, args[ 2 ] . immediate ( ) , mask_elem_bitwidth, in_len) ;
1612+ let mask_ty = bx. type_vector ( bx. type_i1 ( ) , in_len) ;
15871613
15881614 // Type of the vector of pointers:
15891615 let llvm_pointer_vec_ty = llvm_vector_ty ( bx, element_ty1, in_len) ;
@@ -1668,8 +1694,8 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
16681694 }
16691695 ) ;
16701696
1671- require ! (
1672- matches! ( mask_elem. kind( ) , ty :: Int ( _ ) ) ,
1697+ let m_elem_bitwidth = require_int_ty ! (
1698+ mask_elem. kind( ) ,
16731699 InvalidMonomorphization :: ThirdArgElementType {
16741700 span,
16751701 name,
@@ -1678,17 +1704,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
16781704 }
16791705 ) ;
16801706
1707+ let mask = vector_mask_to_bitmask ( bx, args[ 0 ] . immediate ( ) , m_elem_bitwidth, mask_len) ;
1708+ let mask_ty = bx. type_vector ( bx. type_i1 ( ) , mask_len) ;
1709+
16811710 // Alignment of T, must be a constant integer value:
16821711 let alignment_ty = bx. type_i32 ( ) ;
16831712 let alignment = bx. const_i32 ( bx. align_of ( values_elem) . bytes ( ) as i32 ) ;
16841713
1685- // Truncate the mask vector to a vector of i1s:
1686- let ( mask, mask_ty) = {
1687- let i1 = bx. type_i1 ( ) ;
1688- let i1xn = bx. type_vector ( i1, mask_len) ;
1689- ( bx. trunc ( args[ 0 ] . immediate ( ) , i1xn) , i1xn)
1690- } ;
1691-
16921714 let llvm_pointer = bx. type_ptr ( ) ;
16931715
16941716 // Type of the vector of elements:
@@ -1760,8 +1782,8 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
17601782 }
17611783 ) ;
17621784
1763- require ! (
1764- matches! ( mask_elem. kind( ) , ty :: Int ( _ ) ) ,
1785+ let m_elem_bitwidth = require_int_ty ! (
1786+ mask_elem. kind( ) ,
17651787 InvalidMonomorphization :: ThirdArgElementType {
17661788 span,
17671789 name,
@@ -1770,17 +1792,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
17701792 }
17711793 ) ;
17721794
1795+ let mask = vector_mask_to_bitmask ( bx, args[ 0 ] . immediate ( ) , m_elem_bitwidth, mask_len) ;
1796+ let mask_ty = bx. type_vector ( bx. type_i1 ( ) , mask_len) ;
1797+
17731798 // Alignment of T, must be a constant integer value:
17741799 let alignment_ty = bx. type_i32 ( ) ;
17751800 let alignment = bx. const_i32 ( bx. align_of ( values_elem) . bytes ( ) as i32 ) ;
17761801
1777- // Truncate the mask vector to a vector of i1s:
1778- let ( mask, mask_ty) = {
1779- let i1 = bx. type_i1 ( ) ;
1780- let i1xn = bx. type_vector ( i1, in_len) ;
1781- ( bx. trunc ( args[ 0 ] . immediate ( ) , i1xn) , i1xn)
1782- } ;
1783-
17841802 let ret_t = bx. type_void ( ) ;
17851803
17861804 let llvm_pointer = bx. type_ptr ( ) ;
@@ -1859,28 +1877,23 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
18591877 ) ;
18601878
18611879 // The element type of the third argument must be a signed integer type of any width:
1862- match element_ty2. kind ( ) {
1863- ty:: Int ( _) => ( ) ,
1864- _ => {
1865- return_error ! ( InvalidMonomorphization :: ThirdArgElementType {
1866- span,
1867- name,
1868- expected_element: element_ty2,
1869- third_arg: arg_tys[ 2 ]
1870- } ) ;
1880+ let mask_elem_bitwidth = require_int_ty ! (
1881+ element_ty2. kind( ) ,
1882+ InvalidMonomorphization :: ThirdArgElementType {
1883+ span,
1884+ name,
1885+ expected_element: element_ty2,
1886+ third_arg: arg_tys[ 2 ]
18711887 }
1872- }
1888+ ) ;
18731889
18741890 // Alignment of T, must be a constant integer value:
18751891 let alignment_ty = bx. type_i32 ( ) ;
18761892 let alignment = bx. const_i32 ( bx. align_of ( in_elem) . bytes ( ) as i32 ) ;
18771893
18781894 // Truncate the mask vector to a vector of i1s:
1879- let ( mask, mask_ty) = {
1880- let i1 = bx. type_i1 ( ) ;
1881- let i1xn = bx. type_vector ( i1, in_len) ;
1882- ( bx. trunc ( args[ 2 ] . immediate ( ) , i1xn) , i1xn)
1883- } ;
1895+ let mask = vector_mask_to_bitmask ( bx, args[ 2 ] . immediate ( ) , mask_elem_bitwidth, in_len) ;
1896+ let mask_ty = bx. type_vector ( bx. type_i1 ( ) , in_len) ;
18841897
18851898 let ret_t = bx. type_void ( ) ;
18861899
@@ -2018,8 +2031,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
20182031 ) ;
20192032 args[ 0 ] . immediate( )
20202033 } else {
2021- match in_elem. kind( ) {
2022- ty:: Int ( _) | ty:: Uint ( _) => { }
2034+ let bitwidth = match in_elem. kind( ) {
2035+ ty:: Int ( i) => {
2036+ i. bit_width( ) . unwrap_or_else( || bx. data_layout( ) . pointer_size. bits( ) )
2037+ }
2038+ ty:: Uint ( i) => {
2039+ i. bit_width( ) . unwrap_or_else( || bx. data_layout( ) . pointer_size. bits( ) )
2040+ }
20232041 _ => return_error!( InvalidMonomorphization :: UnsupportedSymbol {
20242042 span,
20252043 name,
@@ -2028,12 +2046,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
20282046 in_elem,
20292047 ret_ty
20302048 } ) ,
2031- }
2049+ } ;
20322050
2033- // boolean reductions operate on vectors of i1s:
2034- let i1 = bx. type_i1( ) ;
2035- let i1xn = bx. type_vector( i1, in_len as u64 ) ;
2036- bx. trunc( args[ 0 ] . immediate( ) , i1xn)
2051+ vector_mask_to_bitmask( bx, args[ 0 ] . immediate( ) , bitwidth, in_len as _)
20372052 } ;
20382053 return match in_elem. kind( ) {
20392054 ty:: Int ( _) | ty:: Uint ( _) => {
0 commit comments