2020 * Move all the NaNs to the end of the array and return the index of the last
2121 * element in the array which is not a NaN
2222 */
23- template <typename T1, typename T2>
23+ template <typename T1, typename T2, typename vtype >
2424X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array (T1 *keys,
2525 T2 *vals,
2626 arrsize_t size)
2727{
28+ using reg_t = typename vtype::reg_t ;
29+
2830 arrsize_t jj = size - 1 ;
2931 arrsize_t ii = 0 ;
3032 arrsize_t count = 0 ;
33+
34+ while (ii + vtype::numlanes < jj) {
35+ reg_t in = vtype::loadu (keys + ii);
36+ auto nanmask = vtype::convert_mask_to_int (
37+ vtype::template fpclass<0x01 | 0x80 >(in));
38+
39+ // Check if there are any nans in this vector, and process them if so
40+ if (nanmask != 0x00 ) {
41+ for (size_t offset = 0 ; offset < vtype::numlanes; offset++) {
42+ if (is_a_nan (keys[ii])) {
43+ std::swap (keys[ii], keys[jj]);
44+ std::swap (vals[ii], vals[jj]);
45+ jj -= 1 ;
46+ count++;
47+ }
48+ else {
49+ ii += 1 ;
50+ }
51+ }
52+ }
53+ else {
54+ ii += vtype::numlanes;
55+ }
56+ }
57+
58+ // Handle the remaining elements once fewer than one vector's worth is left
3159 while (ii < jj) {
3260 if (is_a_nan (keys[ii])) {
3361 std::swap (keys[ii], keys[jj]);
@@ -39,6 +67,7 @@ X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array(T1 *keys,
3967 ii += 1 ;
4068 }
4169 }
70+
4271 /* The loop above stops once ii == jj, so keys[ii] has not been checked for NaN yet */
4372 if (is_a_nan (keys[ii])) { count++; }
4473 return size - count - 1 ;
@@ -570,7 +599,8 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv(
570599 if constexpr (xss::fp::is_floating_point_v<T1>) {
571600 if (UNLIKELY (hasnan)) {
572601 index_last_elem
573- = move_nans_to_end_of_array (keys, indexes, arrsize);
602+ = move_nans_to_end_of_array<T1, T2, full_vector<T1>>(
603+ keys, indexes, arrsize);
574604 }
575605 }
576606 else {
@@ -660,7 +690,8 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys,
660690 if constexpr (xss::fp::is_floating_point_v<T1>) {
661691 if (UNLIKELY (hasnan)) {
662692 index_last_elem
663- = move_nans_to_end_of_array (keys, indexes, arrsize);
693+ = move_nans_to_end_of_array<T1, T2, full_vector<T1>>(
694+ keys, indexes, arrsize);
664695 }
665696 }
666697
0 commit comments