@@ -238,38 +238,30 @@ py_boolean_reduction(dpctl::tensor::usm_ndarray src,
 
     auto fn = strided_dispatch_vector[src_typeid];
 
+    // using a single host_task for packing here
+    // prevents crashes on CPU
     std::vector<sycl::event> host_task_events{};
-    const auto &iter_src_dst_metadata_packing_triple_ =
+    const auto &iter_red_metadata_packing_triple_ =
         dpctl::tensor::offset_utils::device_allocate_and_pack<py::ssize_t>(
             exec_q, host_task_events, simplified_iter_shape,
-            simplified_iter_src_strides, simplified_iter_dst_strides);
-    py::ssize_t *iter_shape_and_strides =
-        std::get<0>(iter_src_dst_metadata_packing_triple_);
-    if (iter_shape_and_strides == nullptr) {
+            simplified_iter_src_strides, simplified_iter_dst_strides,
+            simplified_red_shape, simplified_red_src_strides);
+    py::ssize_t *packed_shapes_and_strides =
+        std::get<0>(iter_red_metadata_packing_triple_);
+    if (packed_shapes_and_strides == nullptr) {
         throw std::runtime_error("Unable to allocate memory on device");
     }
-    const auto &copy_iter_metadata_ev =
-        std::get<2>(iter_src_dst_metadata_packing_triple_);
+    const auto &copy_metadata_ev =
+        std::get<2>(iter_red_metadata_packing_triple_);
 
-    const auto &red_metadata_packing_triple_ =
-        dpctl::tensor::offset_utils::device_allocate_and_pack<py::ssize_t>(
-            exec_q, host_task_events, simplified_red_shape,
-            simplified_red_src_strides);
-    py::ssize_t *red_shape_stride = std::get<0>(red_metadata_packing_triple_);
-    if (red_shape_stride == nullptr) {
-        sycl::event::wait(host_task_events);
-        sycl::free(iter_shape_and_strides, exec_q);
-        throw std::runtime_error("Unable to allocate memory on device");
-    }
-    const auto &copy_red_metadata_ev =
-        std::get<2>(red_metadata_packing_triple_);
+    py::ssize_t *iter_shape_and_strides = packed_shapes_and_strides;
+    py::ssize_t *red_shape_stride = packed_shapes_and_strides + (3 * iter_nd);
 
     std::vector<sycl::event> all_deps;
-    all_deps.reserve(depends.size() + 2);
+    all_deps.reserve(depends.size() + 1);
     all_deps.resize(depends.size());
     std::copy(depends.begin(), depends.end(), all_deps.begin());
-    all_deps.push_back(copy_iter_metadata_ev);
-    all_deps.push_back(copy_red_metadata_ev);
+    all_deps.push_back(copy_metadata_ev);
 
     auto red_ev =
         fn(exec_q, dst_nelems, red_nelems, src_data, dst_data, dst_nd,
@@ -279,9 +271,8 @@ py_boolean_reduction(dpctl::tensor::usm_ndarray src,
     sycl::event temp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(red_ev);
         auto ctx = exec_q.get_context();
-        cgh.host_task([ctx, iter_shape_and_strides, red_shape_stride] {
-            sycl::free(iter_shape_and_strides, ctx);
-            sycl::free(red_shape_stride, ctx);
+        cgh.host_task([ctx, packed_shapes_and_strides] {
+            sycl::free(packed_shapes_and_strides, ctx);
         });
     });
     host_task_events.push_back(temp_cleanup_ev);
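The change above replaces the two separate device_allocate_and_pack calls (one for the iteration metadata, one for the reduction metadata) with a single call that packs all five arrays into one device allocation. The iteration block occupies the first 3 * iter_nd entries (shape, source strides, destination strides), so the reduction block begins at that offset; one copy event and one sycl::free in the cleanup host task then suffice. The host-side sketch below only illustrates this packed layout and the offset arithmetic; the shapes, strides, and the index_t alias are made-up assumptions, not dpctl code.

// Host-side sketch of the packed metadata layout (illustrative only; the
// values and names below are assumptions, not the dpctl implementation).
#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <vector>

int main()
{
    using index_t = std::ptrdiff_t; // stands in for py::ssize_t

    // Assumed example: a 2-D iteration space and a 1-D reduction axis.
    std::vector<index_t> iter_shape{4, 5};
    std::vector<index_t> iter_src_strides{40, 8};
    std::vector<index_t> iter_dst_strides{5, 1};
    std::vector<index_t> red_shape{8};
    std::vector<index_t> red_src_strides{1};

    const std::size_t iter_nd = iter_shape.size();
    const std::size_t red_nd = red_shape.size();

    // One packed buffer:
    // [iter_shape | iter_src_strides | iter_dst_strides | red_shape | red_src_strides]
    std::vector<index_t> packed;
    packed.reserve(3 * iter_nd + 2 * red_nd);
    for (const auto &block : {iter_shape, iter_src_strides, iter_dst_strides,
                              red_shape, red_src_strides})
    {
        packed.insert(packed.end(), block.begin(), block.end());
    }

    // The kernel-facing views are just offsets into the single buffer,
    // mirroring iter_shape_and_strides and red_shape_stride in the diff.
    const index_t *iter_shape_and_strides = packed.data();
    const index_t *red_shape_stride = packed.data() + 3 * iter_nd;

    std::cout << "reduction metadata starts at offset "
              << (red_shape_stride - iter_shape_and_strides) << "\n";
    return 0;
}

Packing into one allocation also means only one copy event is appended to all_deps (hence reserve(depends.size() + 1)) and only one pointer is freed in the cleanup host task, removing the second error path that previously had to wait on host_task_events and free the first allocation by hand.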