3232
3333#include " pybind11/pybind11.h"
3434#include " utils/offset_utils.hpp"
35+ #include " utils/sycl_utils.hpp"
3536#include " utils/type_dispatch.hpp"
3637#include " utils/type_utils.hpp"
3738
@@ -150,35 +151,6 @@ struct ReductionOverGroupWithAtomicFunctor
150151 }
151152};
152153
/// Choose a work-group size for a reduction over `reduction_nelems` elements.
///
/// Candidate sizes are every supported sub-group size multiplied by 1..f.
/// Returns the smallest candidate for which a single work-group covers all
/// elements; if no candidate is large enough, returns the largest candidate.
/// An empty `sg_sizes` yields 1.
template <size_t f = 4>
size_t choose_workgroup_size(const size_t reduction_nelems,
                             const std::vector<size_t> &sg_sizes)
{
    // Build the candidate list: each sub-group size times 1..f.
    std::vector<size_t> candidates;
    candidates.reserve(f * sg_sizes.size());
    for (const size_t sg : sg_sizes) {
        for (size_t mult = 1; mult <= f; ++mult) {
            candidates.push_back(sg * mult);
        }
    }
    std::sort(candidates.begin(), candidates.end());

    size_t wg = 1;
    for (const size_t cand : candidates) {
        if (cand == wg) {
            // Same value as the previously accepted candidate — skip duplicate.
            continue;
        }
        wg = cand;
        const size_t n_groups = (reduction_nelems + wg - 1) / wg;
        if (n_groups == 1) {
            // Smallest wg that reduces everything in one group — done.
            break;
        }
    }

    return wg;
}
181-
182154typedef sycl::event (*sum_reduction_strided_impl_fn_ptr)(
183155 sycl::queue,
184156 size_t ,
@@ -200,6 +172,8 @@ class sum_reduction_over_group_with_atomics_krn;
200172template <typename T1, typename T2, typename T3>
201173class sum_reduction_over_group_with_atomics_1d_krn ;
202174
175+ using dpctl::tensor::sycl_utils::choose_workgroup_size;
176+
203177template <typename argTy, typename resTy>
204178sycl::event sum_reduction_over_group_with_atomics_strided_impl (
205179 sycl::queue exec_q,
@@ -548,13 +522,22 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
548522 (preferrered_reductions_per_wi * wg);
549523 assert (reduction_groups > 1 );
550524
551- resTy *partially_reduced_tmp =
552- sycl::malloc_device<resTy>(iter_nelems * reduction_groups, exec_q);
525+ size_t second_iter_reduction_groups_ =
526+ (reduction_groups + preferrered_reductions_per_wi * wg - 1 ) /
527+ (preferrered_reductions_per_wi * wg);
528+
529+ resTy *partially_reduced_tmp = sycl::malloc_device<resTy>(
530+ iter_nelems * (reduction_groups + second_iter_reduction_groups_),
531+ exec_q);
553532 resTy *partially_reduced_tmp2 = nullptr ;
554533
555534 if (partially_reduced_tmp == nullptr ) {
556535 throw std::runtime_error (" Unabled to allocate device_memory" );
557536 }
537+ else {
538+ partially_reduced_tmp2 =
539+ partially_reduced_tmp + reduction_groups * iter_nelems;
540+ }
558541
559542 sycl::event first_reduction_ev = exec_q.submit ([&](sycl::handler &cgh) {
560543 cgh.depends_on (depends);
@@ -610,21 +593,6 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
610593 (preferrered_reductions_per_wi * wg);
611594 assert (reduction_groups_ > 1 );
612595
613- if (partially_reduced_tmp2 == nullptr ) {
614- partially_reduced_tmp2 = sycl::malloc_device<resTy>(
615- iter_nelems * reduction_groups_, exec_q);
616-
617- if (partially_reduced_tmp2 == nullptr ) {
618- dependent_ev.wait ();
619- sycl::free (partially_reduced_tmp, exec_q);
620-
621- throw std::runtime_error (
622- " Unable to allocate device memory" );
623- }
624-
625- temp2_arg = partially_reduced_tmp2;
626- }
627-
628596 // keep reducing
629597 sycl::event partial_reduction_ev =
630598 exec_q.submit ([&](sycl::handler &cgh) {
@@ -727,13 +695,9 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
727695 cgh.depends_on (final_reduction_ev);
728696 sycl::context ctx = exec_q.get_context ();
729697
730- cgh.host_task (
731- [ctx, partially_reduced_tmp, partially_reduced_tmp2] {
732- sycl::free (partially_reduced_tmp, ctx);
733- if (partially_reduced_tmp2) {
734- sycl::free (partially_reduced_tmp2, ctx);
735- }
736- });
698+ cgh.host_task ([ctx, partially_reduced_tmp] {
699+ sycl::free (partially_reduced_tmp, ctx);
700+ });
737701 });
738702
739703 // FIXME: do not return host-task event
0 commit comments