@@ -3401,6 +3401,129 @@ struct LogSumExpOverAxis0TempsContigFactory
34013401
34023402// Argmax and Argmin
34033403
3404+ /* Sequential search reduction */
3405+
3406+ template <typename argT,
3407+ typename outT,
3408+ typename ReductionOp,
3409+ typename IdxReductionOp,
3410+ typename InputOutputIterIndexerT,
3411+ typename InputRedIndexerT>
3412+ struct SequentialSearchReduction
3413+ {
3414+ private:
3415+ const argT *inp_ = nullptr ;
3416+ outT *out_ = nullptr ;
3417+ ReductionOp reduction_op_;
3418+ argT identity_;
3419+ IdxReductionOp idx_reduction_op_;
3420+ outT idx_identity_;
3421+ InputOutputIterIndexerT inp_out_iter_indexer_;
3422+ InputRedIndexerT inp_reduced_dims_indexer_;
3423+ size_t reduction_max_gid_ = 0 ;
3424+
3425+ public:
3426+ SequentialSearchReduction (const argT *inp,
3427+ outT *res,
3428+ ReductionOp reduction_op,
3429+ const argT &identity_val,
3430+ IdxReductionOp idx_reduction_op,
3431+ const outT &idx_identity_val,
3432+ InputOutputIterIndexerT arg_res_iter_indexer,
3433+ InputRedIndexerT arg_reduced_dims_indexer,
3434+ size_t reduction_size)
3435+ : inp_(inp), out_(res), reduction_op_(reduction_op),
3436+ identity_ (identity_val), idx_reduction_op_(idx_reduction_op),
3437+ idx_identity_(idx_identity_val),
3438+ inp_out_iter_indexer_(arg_res_iter_indexer),
3439+ inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
3440+ reduction_max_gid_(reduction_size)
3441+ {
3442+ }
3443+
3444+ void operator ()(sycl::id<1 > id) const
3445+ {
3446+
3447+ auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_ (id[0 ]);
3448+ const py::ssize_t &inp_iter_offset =
3449+ inp_out_iter_offsets_.get_first_offset ();
3450+ const py::ssize_t &out_iter_offset =
3451+ inp_out_iter_offsets_.get_second_offset ();
3452+
3453+ argT red_val (identity_);
3454+ outT idx_val (idx_identity_);
3455+ for (size_t m = 0 ; m < reduction_max_gid_; ++m) {
3456+ const py::ssize_t inp_reduction_offset =
3457+ inp_reduced_dims_indexer_ (m);
3458+ const py::ssize_t inp_offset =
3459+ inp_iter_offset + inp_reduction_offset;
3460+
3461+ argT val = inp_[inp_offset];
3462+ if (val == red_val) {
3463+ idx_val = idx_reduction_op_ (idx_val, static_cast <outT>(m));
3464+ }
3465+ else {
3466+ if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
3467+ using dpctl::tensor::type_utils::is_complex;
3468+ if constexpr (is_complex<argT>::value) {
3469+ using dpctl::tensor::math_utils::less_complex;
3470+ // less_complex always returns false for NaNs, so check
3471+ if (less_complex<argT>(val, red_val) ||
3472+ std::isnan (std::real (val)) ||
3473+ std::isnan (std::imag (val)))
3474+ {
3475+ red_val = val;
3476+ idx_val = static_cast <outT>(m);
3477+ }
3478+ }
3479+ else if constexpr (std::is_floating_point_v<argT> ||
3480+ std::is_same_v<argT, sycl::half>)
3481+ {
3482+ if (val < red_val || std::isnan (val)) {
3483+ red_val = val;
3484+ idx_val = static_cast <outT>(m);
3485+ }
3486+ }
3487+ else {
3488+ if (val < red_val) {
3489+ red_val = val;
3490+ idx_val = static_cast <outT>(m);
3491+ }
3492+ }
3493+ }
3494+ else if constexpr (su_ns::IsMaximum<argT, ReductionOp>::value) {
3495+ using dpctl::tensor::type_utils::is_complex;
3496+ if constexpr (is_complex<argT>::value) {
3497+ using dpctl::tensor::math_utils::greater_complex;
3498+ if (greater_complex<argT>(val, red_val) ||
3499+ std::isnan (std::real (val)) ||
3500+ std::isnan (std::imag (val)))
3501+ {
3502+ red_val = val;
3503+ idx_val = static_cast <outT>(m);
3504+ }
3505+ }
3506+ else if constexpr (std::is_floating_point_v<argT> ||
3507+ std::is_same_v<argT, sycl::half>)
3508+ {
3509+ if (val > red_val || std::isnan (val)) {
3510+ red_val = val;
3511+ idx_val = static_cast <outT>(m);
3512+ }
3513+ }
3514+ else {
3515+ if (val > red_val) {
3516+ red_val = val;
3517+ idx_val = static_cast <outT>(m);
3518+ }
3519+ }
3520+ }
3521+ }
3522+ }
3523+ out_[out_iter_offset] = idx_val;
3524+ }
3525+ };
3526+
34043527/* = Search reduction using reduce_over_group*/
34053528
34063529template <typename argT,
@@ -3670,7 +3793,9 @@ struct CustomSearchReduction
36703793 }
36713794 }
36723795 }
3673- else if constexpr (std::is_floating_point_v<argT>) {
3796+ else if constexpr (std::is_floating_point_v<argT> ||
3797+ std::is_same_v<argT, sycl::half>)
3798+ {
36743799 if (val < local_red_val || std::isnan (val)) {
36753800 local_red_val = val;
36763801 if constexpr (!First) {
@@ -3714,7 +3839,9 @@ struct CustomSearchReduction
37143839 }
37153840 }
37163841 }
3717- else if constexpr (std::is_floating_point_v<argT>) {
3842+ else if constexpr (std::is_floating_point_v<argT> ||
3843+ std::is_same_v<argT, sycl::half>)
3844+ {
37183845 if (val > local_red_val || std::isnan (val)) {
37193846 local_red_val = val;
37203847 if constexpr (!First) {
@@ -3757,7 +3884,9 @@ struct CustomSearchReduction
37573884 ? local_idx
37583885 : idx_identity_;
37593886 }
3760- else if constexpr (std::is_floating_point_v<argT>) {
3887+ else if constexpr (std::is_floating_point_v<argT> ||
3888+ std::is_same_v<argT, sycl::half>)
3889+ {
37613890 // equality does not hold for NaNs, so check here
37623891 local_idx =
37633892 (red_val_over_wg == local_red_val || std::isnan (local_red_val))
@@ -3799,6 +3928,14 @@ typedef sycl::event (*search_strided_impl_fn_ptr)(
37993928 py::ssize_t ,
38003929 const std::vector<sycl::event> &);
38013930
// Kernel-name tag (declaration only) for the sequential search reduction
// submitted with strided indexers; the six type parameters mirror the
// template arguments of the functor it names.
template <typename argT,
          typename resT,
          typename RedOpT,
          typename IdxOpT,
          typename IterIndexerT,
          typename RedIndexerT>
class search_seq_strided_krn;
3938+
38023939template <typename T1,
38033940 typename T2,
38043941 typename T3,
@@ -3820,6 +3957,14 @@ template <typename T1,
38203957 bool b2>
38213958class custom_search_over_group_temps_strided_krn ;
38223959
// Kernel-name tag (declaration only) for the sequential search reduction
// submitted by the contiguous-input implementations; the six type
// parameters mirror the template arguments of the functor it names.
template <typename argT,
          typename resT,
          typename RedOpT,
          typename IdxOpT,
          typename IterIndexerT,
          typename RedIndexerT>
class search_seq_contig_krn;
3967+
38233968template <typename T1,
38243969 typename T2,
38253970 typename T3,
@@ -4019,6 +4164,36 @@ sycl::event search_over_group_temps_strided_impl(
40194164 const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
40204165 size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
40214166
4167+ if (reduction_nelems < wg) {
4168+ sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
4169+ cgh.depends_on (depends);
4170+
4171+ using InputOutputIterIndexerT =
4172+ dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
4173+ using ReductionIndexerT =
4174+ dpctl::tensor::offset_utils::StridedIndexer;
4175+
4176+ InputOutputIterIndexerT in_out_iter_indexer{
4177+ iter_nd, iter_arg_offset, iter_res_offset,
4178+ iter_shape_and_strides};
4179+ ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset,
4180+ reduction_shape_stride};
4181+
4182+ cgh.parallel_for <class search_seq_strided_krn <
4183+ argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
4184+ ReductionIndexerT>>(
4185+ sycl::range<1 >(iter_nelems),
4186+ SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
4187+ InputOutputIterIndexerT,
4188+ ReductionIndexerT>(
4189+ arg_tp, res_tp, ReductionOpT (), identity_val, IndexOpT (),
4190+ idx_identity_val, in_out_iter_indexer, reduction_indexer,
4191+ reduction_nelems));
4192+ });
4193+
4194+ return comp_ev;
4195+ }
4196+
40224197 constexpr size_t preferred_reductions_per_wi = 4 ;
40234198 // max_max_wg prevents running out of resources on CPU
40244199 size_t max_wg =
@@ -4419,6 +4594,39 @@ sycl::event search_axis1_over_group_temps_contig_impl(
44194594 const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
44204595 size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
44214596
4597+ if (reduction_nelems < wg) {
4598+ sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
4599+ cgh.depends_on (depends);
4600+
4601+ using InputIterIndexerT =
4602+ dpctl::tensor::offset_utils::Strided1DIndexer;
4603+ using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
4604+ using InputOutputIterIndexerT =
4605+ dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
4606+ InputIterIndexerT, NoOpIndexerT>;
4607+ using ReductionIndexerT = NoOpIndexerT;
4608+
4609+ InputOutputIterIndexerT in_out_iter_indexer{
4610+ InputIterIndexerT{0 , static_cast <py::ssize_t >(iter_nelems),
4611+ static_cast <py::ssize_t >(reduction_nelems)},
4612+ NoOpIndexerT{}};
4613+ ReductionIndexerT reduction_indexer{};
4614+
4615+ cgh.parallel_for <class search_seq_contig_krn <
4616+ argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
4617+ ReductionIndexerT>>(
4618+ sycl::range<1 >(iter_nelems),
4619+ SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
4620+ InputOutputIterIndexerT,
4621+ ReductionIndexerT>(
4622+ arg_tp, res_tp, ReductionOpT (), identity_val, IndexOpT (),
4623+ idx_identity_val, in_out_iter_indexer, reduction_indexer,
4624+ reduction_nelems));
4625+ });
4626+
4627+ return comp_ev;
4628+ }
4629+
44224630 constexpr size_t preferred_reductions_per_wi = 8 ;
44234631 // max_max_wg prevents running out of resources on CPU
44244632 size_t max_wg =
@@ -4801,6 +5009,43 @@ sycl::event search_axis0_over_group_temps_contig_impl(
48015009 const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
48025010 size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
48035011
5012+ if (reduction_nelems < wg) {
5013+ sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
5014+ cgh.depends_on (depends);
5015+
5016+ using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
5017+ using InputOutputIterIndexerT =
5018+ dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
5019+ NoOpIndexerT, NoOpIndexerT>;
5020+ using ReductionIndexerT =
5021+ dpctl::tensor::offset_utils::Strided1DIndexer;
5022+
5023+ InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{},
5024+ NoOpIndexerT{}};
5025+ ReductionIndexerT reduction_indexer{
5026+ 0 , static_cast <py::ssize_t >(reduction_nelems),
5027+ static_cast <py::ssize_t >(iter_nelems)};
5028+
5029+ using KernelName =
5030+ class search_seq_contig_krn <argTy, resTy, ReductionOpT,
5031+ IndexOpT, InputOutputIterIndexerT,
5032+ ReductionIndexerT>;
5033+
5034+ sycl::range<1 > iter_range{iter_nelems};
5035+
5036+ cgh.parallel_for <KernelName>(
5037+ iter_range,
5038+ SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
5039+ InputOutputIterIndexerT,
5040+ ReductionIndexerT>(
5041+ arg_tp, res_tp, ReductionOpT (), identity_val, IndexOpT (),
5042+ idx_identity_val, in_out_iter_indexer, reduction_indexer,
5043+ reduction_nelems));
5044+ });
5045+
5046+ return comp_ev;
5047+ }
5048+
48045049 constexpr size_t preferred_reductions_per_wi = 8 ;
48055050 // max_max_wg prevents running out of resources on CPU
48065051 size_t max_wg =
0 commit comments