99
1010#include " avx512-16bit-common.h"
1111
12- struct float16 {
13- uint16_t val;
14- };
15-
1612template <>
1713struct zmm_vector <float16> {
1814 using type_t = uint16_t ;
@@ -545,10 +541,65 @@ replace_nan_with_inf<zmm_vector<float16>>(uint16_t *arr, arrsize_t arrsize)
545541 return nan_count;
546542}
547543
548- template <>
549- X86_SIMD_SORT_INLINE_ONLY bool is_a_nan<uint16_t >(uint16_t elem)
544+ X86_SIMD_SORT_INLINE_ONLY void replace_inf_with_nan_fp16 (_Float16 *arr,
545+ arrsize_t size,
546+ arrsize_t nan_count,
547+ bool descending
548+ = false )
549+ {
550+ if (descending) {
551+ for (arrsize_t ii = 0 ; nan_count > 0 ; ++ii) {
552+ arr[ii] = xss::fp::quiet_NaN<_Float16>();
553+ nan_count -= 1 ;
554+ }
555+ }
556+ else {
557+ for (arrsize_t ii = size - 1 ; nan_count > 0 ; --ii) {
558+ arr[ii] = xss::fp::quiet_NaN<_Float16>();
559+ nan_count -= 1 ;
560+ }
561+ }
562+ }
563+
564+ template <typename comparator>
565+ [[maybe_unused]] X86_SIMD_SORT_INLINE void
566+ avx512_qsort_fp16_helper (uint16_t *arr, arrsize_t arrsize)
550567{
551- return ((elem & 0x7c00u ) == 0x7c00u ) && ((elem & 0x03ffu ) != 0 );
568+ using T = uint16_t ;
569+ using vtype = zmm_vector<float16>;
570+
571+ #ifdef XSS_COMPILE_OPENMP
572+ bool use_parallel = arrsize > 100000 ;
573+
574+ if (use_parallel) {
575+ // This thread limit was determined experimentally; it may be better for it to be the number of physical cores on the system
576+ constexpr int thread_limit = 8 ;
577+ int thread_count = std::min (thread_limit, omp_get_max_threads ());
578+ arrsize_t task_threshold = std::max ((arrsize_t )100000 , arrsize / 100 );
579+
580+ // We use omp parallel and then omp single to setup the threads that will run the omp task calls in qsort_
581+ // The omp single prevents multiple threads from running the initial qsort_ simultaneously and causing problems
582+ // Note that we do not use the if(...) clause built into OpenMP, because it causes a performance regression for small arrays
583+ #pragma omp parallel num_threads(thread_count)
584+ #pragma omp single
585+ qsort_<vtype, comparator, T>(arr,
586+ 0 ,
587+ arrsize - 1 ,
588+ 2 * (arrsize_t )log2 (arrsize),
589+ task_threshold);
590+ }
591+ else {
592+ qsort_<vtype, comparator, T>(arr,
593+ 0 ,
594+ arrsize - 1 ,
595+ 2 * (arrsize_t )log2 (arrsize),
596+ std::numeric_limits<arrsize_t >::max ());
597+ }
598+ #pragma omp taskwait
599+ #else
600+ qsort_<vtype, comparator, T>(
601+ arr, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize), 0 );
602+ #endif
552603}
553604
554605[[maybe_unused]] X86_SIMD_SORT_INLINE void
@@ -559,22 +610,19 @@ avx512_qsort_fp16(uint16_t *arr,
559610{
560611 using vtype = zmm_vector<float16>;
561612
562- // TODO multithreading support here
563613 if (arrsize > 1 ) {
564614 arrsize_t nan_count = 0 ;
565615 if (UNLIKELY (hasnan)) {
566- nan_count = replace_nan_with_inf<zmm_vector<float16>, uint16_t >(
567- arr, arrsize);
616+ nan_count = replace_nan_with_inf<vtype, uint16_t >(arr, arrsize);
568617 }
569618 if (descending) {
570- qsort_<vtype, Comparator<vtype, true >, uint16_t >(
571- arr, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize), 0 );
619+ avx512_qsort_fp16_helper<Comparator<vtype, true >>(arr, arrsize);
572620 }
573621 else {
574- qsort_<vtype, Comparator<vtype, false >, uint16_t >(
575- arr, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize), 0 );
622+ avx512_qsort_fp16_helper<Comparator<vtype, false >>(arr, arrsize);
576623 }
577- replace_inf_with_nan (arr, arrsize, nan_count, descending);
624+ replace_inf_with_nan_fp16 (
625+ (_Float16 *)arr, arrsize, nan_count, descending);
578626 }
579627
580628#ifdef __MMX__
@@ -592,26 +640,37 @@ avx512_qselect_fp16(uint16_t *arr,
592640{
593641 using vtype = zmm_vector<float16>;
594642
595- arrsize_t indx_last_elem = arrsize - 1 ;
643+ // Exit early if no work would be done
644+ if (arrsize <= 1 ) return ;
645+
646+ arrsize_t index_first_elem = 0 ;
647+ arrsize_t index_last_elem = arrsize - 1 ;
648+
596649 if (UNLIKELY (hasnan)) {
597- indx_last_elem = move_nans_to_end_of_array (arr, arrsize);
650+ if (descending) {
651+ index_first_elem = move_nans_to_start_of_array (arr, arrsize);
652+ }
653+ else {
654+ index_last_elem = move_nans_to_end_of_array (arr, arrsize);
655+ }
598656 }
599- if (indx_last_elem >= k) {
657+
658+ if (index_first_elem <= k && index_last_elem >= k) {
600659 if (descending) {
601660 qselect_<vtype, Comparator<vtype, true >, uint16_t >(
602661 arr,
603662 k,
604- 0 ,
605- indx_last_elem ,
606- 2 * (arrsize_t )log2 (indx_last_elem ));
663+ index_first_elem ,
664+ index_last_elem ,
665+ 2 * (arrsize_t )log2 (arrsize ));
607666 }
608667 else {
609668 qselect_<vtype, Comparator<vtype, false >, uint16_t >(
610669 arr,
611670 k,
612- 0 ,
613- indx_last_elem ,
614- 2 * (arrsize_t )log2 (indx_last_elem ));
671+ index_first_elem ,
672+ index_last_elem ,
673+ 2 * (arrsize_t )log2 (arrsize ));
615674 }
616675 }
617676
@@ -628,7 +687,8 @@ avx512_partial_qsort_fp16(uint16_t *arr,
628687 bool hasnan = false ,
629688 bool descending = false )
630689{
690+ if (k == 0 ) return ;
631691 avx512_qselect_fp16 (arr, k - 1 , arrsize, hasnan, descending);
632- avx512_qsort_fp16 (arr, k - 1 , descending);
692+ avx512_qsort_fp16 (arr, k - 1 , hasnan, descending);
633693}
634694#endif // AVX512_QSORT_16BIT
0 commit comments