@@ -50,12 +50,18 @@ namespace tensor
5050namespace kernels
5151{
5252
53+ template <typename ReductionOpT, typename T> struct needs_workaround
54+ {
55+ static constexpr bool value =
56+ std::is_same_v<ReductionOpT, sycl::multiplies<T>> &&
57+ (std::is_same_v<T, std::int64_t > || std::is_same_v<T, std::uint64_t >);
58+ };
59+
5360template <typename ReductionOpT, typename T> struct can_use_reduce_over_group
5461{
5562 static constexpr bool value =
5663 sycl::has_known_identity<ReductionOpT, T>::value &&
57- !std::is_same_v<T, std::int64_t > && !std::is_same_v<T, std::uint64_t > &&
58- !std::is_same_v<ReductionOpT, sycl::multiplies<T>>;
64+ !needs_workaround<ReductionOpT, T>::value;
5965};
6066
6167template <typename argT,
@@ -1088,7 +1094,7 @@ sycl::event reduction_over_group_temps_strided_impl(
10881094 // max_max_wg prevents running out of resources on CPU
10891095 constexpr size_t max_max_wg = 2048 ;
10901096 size_t max_wg = std::min (
1091- max_max_wg, d.get_info <sycl::info::device::max_work_group_size>());
1097+ max_max_wg, d.get_info <sycl::info::device::max_work_group_size>() / 2 );
10921098
10931099 size_t reductions_per_wi (preferrered_reductions_per_wi);
10941100 if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -1339,7 +1345,7 @@ sycl::event reduction_over_group_temps_strided_impl(
13391345 static_cast <py::ssize_t >(remaining_reduction_nelems)};
13401346 ResIndexerT res_iter_indexer{iter_nd, iter_res_offset,
13411347 /* shape */ iter_shape_and_strides,
1342- /* s trides */ iter_shape_and_strides +
1348+ /* strides */ iter_shape_and_strides +
13431349 2 * iter_nd};
13441350
13451351 InputOutputIterIndexerT in_out_iter_indexer{inp_indexer,
@@ -1424,8 +1430,9 @@ sycl::event reduction_axis1_over_group_temps_contig_impl(
14241430 py::ssize_t reduction_arg_offset,
14251431 const std::vector<sycl::event> &depends)
14261432{
1427- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp);
1428- resTy *res_tp = reinterpret_cast <resTy *>(res_cp);
1433+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
1434+ iter_arg_offset + reduction_arg_offset;
1435+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
14291436
14301437 constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;
14311438
@@ -1437,7 +1444,7 @@ sycl::event reduction_axis1_over_group_temps_contig_impl(
14371444 // max_max_wg prevents running out of resources on CPU
14381445 constexpr size_t max_max_wg = 2048 ;
14391446 size_t max_wg = std::min (
1440- max_max_wg, d.get_info <sycl::info::device::max_work_group_size>());
1447+ max_max_wg, d.get_info <sycl::info::device::max_work_group_size>() / 2 );
14411448
14421449 size_t reductions_per_wi (preferrered_reductions_per_wi);
14431450 if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -1767,8 +1774,9 @@ sycl::event reduction_axis0_over_group_temps_contig_impl(
17671774 py::ssize_t reduction_arg_offset,
17681775 const std::vector<sycl::event> &depends)
17691776{
1770- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp);
1771- resTy *res_tp = reinterpret_cast <resTy *>(res_cp);
1777+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
1778+ iter_arg_offset + reduction_arg_offset;
1779+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
17721780
17731781 constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;
17741782
@@ -1780,7 +1788,7 @@ sycl::event reduction_axis0_over_group_temps_contig_impl(
17801788 // max_max_wg prevents running out of resources on CPU
17811789 constexpr size_t max_max_wg = 2048 ;
17821790 size_t max_wg = std::min (
1783- max_max_wg, d.get_info <sycl::info::device::max_work_group_size>());
1791+ max_max_wg, d.get_info <sycl::info::device::max_work_group_size>() / 2 );
17841792
17851793 size_t reductions_per_wi (preferrered_reductions_per_wi);
17861794 if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -3875,8 +3883,9 @@ sycl::event search_over_group_temps_strided_impl(
38753883
38763884 constexpr size_t preferrered_reductions_per_wi = 4 ;
38773885 // max_max_wg prevents running out of resources on CPU
3878- size_t max_wg = std::min (
3879- size_t (2048 ), d.get_info <sycl::info::device::max_work_group_size>());
3886+ size_t max_wg =
3887+ std::min (size_t (2048 ),
3888+ d.get_info <sycl::info::device::max_work_group_size>() / 2 );
38803889
38813890 size_t reductions_per_wi (preferrered_reductions_per_wi);
38823891 if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -4258,8 +4267,9 @@ sycl::event search_axis1_over_group_temps_contig_impl(
42584267 py::ssize_t reduction_arg_offset,
42594268 const std::vector<sycl::event> &depends)
42604269{
4261- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp);
4262- resTy *res_tp = reinterpret_cast <resTy *>(res_cp);
4270+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
4271+ iter_arg_offset + reduction_arg_offset;
4272+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
42634273
42644274 constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
42654275 constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
@@ -4270,8 +4280,9 @@ sycl::event search_axis1_over_group_temps_contig_impl(
42704280
42714281 constexpr size_t preferrered_reductions_per_wi = 8 ;
42724282 // max_max_wg prevents running out of resources on CPU
4273- size_t max_wg = std::min (
4274- size_t (2048 ), d.get_info <sycl::info::device::max_work_group_size>());
4283+ size_t max_wg =
4284+ std::min (size_t (2048 ),
4285+ d.get_info <sycl::info::device::max_work_group_size>() / 2 );
42754286
42764287 size_t reductions_per_wi (preferrered_reductions_per_wi);
42774288 if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -4635,8 +4646,9 @@ sycl::event search_axis0_over_group_temps_contig_impl(
46354646 py::ssize_t reduction_arg_offset,
46364647 const std::vector<sycl::event> &depends)
46374648{
4638- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp);
4639- resTy *res_tp = reinterpret_cast <resTy *>(res_cp);
4649+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
4650+ iter_arg_offset + reduction_arg_offset;
4651+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
46404652
46414653 constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
46424654 constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
@@ -4647,8 +4659,9 @@ sycl::event search_axis0_over_group_temps_contig_impl(
46474659
46484660 constexpr size_t preferrered_reductions_per_wi = 8 ;
46494661 // max_max_wg prevents running out of resources on CPU
4650- size_t max_wg = std::min (
4651- size_t (2048 ), d.get_info <sycl::info::device::max_work_group_size>());
4662+ size_t max_wg =
4663+ std::min (size_t (2048 ),
4664+ d.get_info <sycl::info::device::max_work_group_size>() / 2 );
46524665
46534666 size_t reductions_per_wi (preferrered_reductions_per_wi);
46544667 if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
0 commit comments