@@ -1339,7 +1339,7 @@ sycl::event reduction_over_group_temps_strided_impl(
13391339 static_cast <py::ssize_t >(remaining_reduction_nelems)};
13401340 ResIndexerT res_iter_indexer{iter_nd, iter_res_offset,
13411341 /* shape */ iter_shape_and_strides,
1342- /* s trides */ iter_shape_and_strides +
1342+ /* strides */ iter_shape_and_strides +
13431343 2 * iter_nd};
13441344
13451345 InputOutputIterIndexerT in_out_iter_indexer{inp_indexer,
@@ -1424,8 +1424,9 @@ sycl::event reduction_axis1_over_group_temps_contig_impl(
14241424 py::ssize_t reduction_arg_offset,
14251425 const std::vector<sycl::event> &depends)
14261426{
1427- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp);
1428- resTy *res_tp = reinterpret_cast <resTy *>(res_cp);
1427+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
1428+ iter_arg_offset + reduction_arg_offset;
1429+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
14291430
14301431 constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;
14311432
@@ -1767,8 +1768,9 @@ sycl::event reduction_axis0_over_group_temps_contig_impl(
17671768 py::ssize_t reduction_arg_offset,
17681769 const std::vector<sycl::event> &depends)
17691770{
1770- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp);
1771- resTy *res_tp = reinterpret_cast <resTy *>(res_cp);
1771+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
1772+ iter_arg_offset + reduction_arg_offset;
1773+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
17721774
17731775 constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;
17741776
@@ -4258,8 +4260,9 @@ sycl::event search_axis1_over_group_temps_contig_impl(
42584260 py::ssize_t reduction_arg_offset,
42594261 const std::vector<sycl::event> &depends)
42604262{
4261- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp);
4262- resTy *res_tp = reinterpret_cast <resTy *>(res_cp);
4263+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
4264+ iter_arg_offset + reduction_arg_offset;
4265+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
42634266
42644267 constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
42654268 constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
@@ -4635,8 +4638,9 @@ sycl::event search_axis0_over_group_temps_contig_impl(
46354638 py::ssize_t reduction_arg_offset,
46364639 const std::vector<sycl::event> &depends)
46374640{
4638- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp);
4639- resTy *res_tp = reinterpret_cast <resTy *>(res_cp);
4641+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
4642+ iter_arg_offset + reduction_arg_offset;
4643+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
46404644
46414645 constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
46424646 constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
0 commit comments