3535#include < pybind11/stl.h>
3636
3737#include " kernels/reductions.hpp"
38- #include " reductions .hpp"
38+ #include " sum_reductions .hpp"
3939
4040#include " simplify_iteration_space.hpp"
41+ #include " utils/memory_overlap.hpp"
4142#include " utils/offset_utils.hpp"
4243#include " utils/type_dispatch.hpp"
4344
@@ -135,9 +136,23 @@ std::pair<sycl::event, sycl::event> py_sum_over_axis(
135136 reduction_nelems *= static_cast <size_t >(src_shape_ptr[i]);
136137 }
137138
138- // FIXME: check that dst and src do not overlap
139- // check that dst is ample enough (memory span is sufficient
140- // to accommodate all elements)
139+ // check that dst and src do not overlap
140+ auto const &overlap = dpctl::tensor::overlap::MemoryOverlap ();
141+ if (overlap (src, dst)) {
142+ throw py::value_error (" Arrays index overlapping segments of memory" );
143+ }
144+
145+ // destination must be ample enough to accommodate all elements
146+ {
147+ auto dst_offsets = dst.get_minmax_offsets ();
148+ size_t range =
149+ static_cast <size_t >(dst_offsets.second - dst_offsets.first );
150+ if (range + 1 < dst_nelems) {
151+ throw py::value_error (
152+ " Destination array can not accomodate all the "
153+ " elements of source array." );
154+ }
155+ }
141156
142157 int src_typenum = src.get_typenum ();
143158 int dst_typenum = dst.get_typenum ();
@@ -297,38 +312,33 @@ std::pair<sycl::event, sycl::event> py_sum_over_axis(
297312 }
298313
299314 std::vector<sycl::event> host_task_events{};
300- const auto &iter_src_dst_metadata_packing_triple_ =
301- dpctl::tensor::offset_utils::device_allocate_and_pack<py::ssize_t >(
302- exec_q, host_task_events, simplified_iteration_shape,
303- simplified_iteration_src_strides, simplified_iteration_dst_strides);
304- py::ssize_t *iter_shape_and_strides =
305- std::get<0 >(iter_src_dst_metadata_packing_triple_);
306- if (iter_shape_and_strides == nullptr ) {
315+
316+ using dpctl::tensor::offset_utils::device_allocate_and_pack;
317+
318+ const auto &arrays_metainfo_packing_triple_ =
319+ device_allocate_and_pack<py::ssize_t >(
320+ exec_q, host_task_events,
321+ // iteration metadata
322+ simplified_iteration_shape, simplified_iteration_src_strides,
323+ simplified_iteration_dst_strides,
324+ // reduction metadata
325+ simplified_reduction_shape, simplified_reduction_src_strides);
326+ py::ssize_t *temp_allocation_ptr =
327+ std::get<0 >(arrays_metainfo_packing_triple_);
328+ if (temp_allocation_ptr == nullptr ) {
307329 throw std::runtime_error (" Unable to allocate memory on device" );
308330 }
309- const auto ©_iter_metadata_ev =
310- std::get<2 >(iter_src_dst_metadata_packing_triple_);
331+ const auto ©_metadata_ev = std::get<2 >(arrays_metainfo_packing_triple_);
311332
312- const auto &reduction_metadata_packing_triple_ =
313- dpctl::tensor::offset_utils::device_allocate_and_pack<py::ssize_t >(
314- exec_q, host_task_events, simplified_reduction_shape,
315- simplified_reduction_src_strides);
333+ py::ssize_t *iter_shape_and_strides = temp_allocation_ptr;
316334 py::ssize_t *reduction_shape_stride =
317- std::get<0 >(reduction_metadata_packing_triple_);
318- if (reduction_shape_stride == nullptr ) {
319- sycl::event::wait (host_task_events);
320- sycl::free (iter_shape_and_strides, exec_q);
321- throw std::runtime_error (" Unable to allocate memory on device" );
322- }
323- const auto ©_reduction_metadata_ev =
324- std::get<2 >(reduction_metadata_packing_triple_);
335+ temp_allocation_ptr + 3 * simplified_iteration_shape.size ();
325336
326337 std::vector<sycl::event> all_deps;
327- all_deps.reserve (depends.size () + 2 );
338+ all_deps.reserve (depends.size () + 1 );
328339 all_deps.resize (depends.size ());
329340 std::copy (depends.begin (), depends.end (), all_deps.begin ());
330- all_deps.push_back (copy_iter_metadata_ev);
331- all_deps.push_back (copy_reduction_metadata_ev);
341+ all_deps.push_back (copy_metadata_ev);
332342
333343 auto comp_ev = fn (exec_q, dst_nelems, reduction_nelems, src.get_data (),
334344 dst.get_data (), iteration_nd, iter_shape_and_strides,
@@ -339,9 +349,8 @@ std::pair<sycl::event, sycl::event> py_sum_over_axis(
339349 sycl::event temp_cleanup_ev = exec_q.submit ([&](sycl::handler &cgh) {
340350 cgh.depends_on (comp_ev);
341351 auto ctx = exec_q.get_context ();
342- cgh.host_task ([ctx, iter_shape_and_strides, reduction_shape_stride] {
343- sycl::free (iter_shape_and_strides, ctx);
344- sycl::free (reduction_shape_stride, ctx);
352+ cgh.host_task ([ctx, temp_allocation_ptr] {
353+ sycl::free (temp_allocation_ptr, ctx);
345354 });
346355 });
347356 host_task_events.push_back (temp_cleanup_ev);
0 commit comments