@@ -385,6 +385,9 @@ py_extract(dpctl::tensor::usm_ndarray src,
385385 auto fn =
386386 masked_extract_all_slices_strided_impl_dispatch_vector[src_typeid];
387387
388+ assert (dst_shape_vec.size () == 1 );
389+ assert (dst_strides_vec.size () == 1 );
390+
388391 using dpctl::tensor::offset_utils::device_allocate_and_pack;
389392 const auto &ptr_size_event_tuple1 =
390393 device_allocate_and_pack<py::ssize_t >(
@@ -397,9 +400,6 @@ py_extract(dpctl::tensor::usm_ndarray src,
397400 sycl::event copy_src_shape_strides_ev =
398401 std::get<2 >(ptr_size_event_tuple1);
399402
400- assert (dst_shape_vec.size () == 1 );
401- assert (dst_strides_vec.size () == 1 );
402-
403403 std::vector<sycl::event> all_deps;
404404 all_deps.reserve (depends.size () + 1 );
405405 all_deps.insert (all_deps.end (), depends.begin (), depends.end ());
@@ -469,40 +469,31 @@ py_extract(dpctl::tensor::usm_ndarray src,
469469 simplified_ortho_shape, simplified_ortho_src_strides,
470470 simplified_ortho_dst_strides, ortho_src_offset, ortho_dst_offset);
471471
472+ assert (masked_dst_shape.size () == 1 );
473+ assert (masked_dst_strides.size () == 1 );
474+
472475 using dpctl::tensor::offset_utils::device_allocate_and_pack;
473476 const auto &ptr_size_event_tuple1 =
474477 device_allocate_and_pack<py::ssize_t >(
475478 exec_q, host_task_events, simplified_ortho_shape,
476- simplified_ortho_src_strides, simplified_ortho_dst_strides);
477- py:: ssize_t *packed_ortho_src_dst_shape_strides =
478- std::get<0 >(ptr_size_event_tuple1);
479- if (packed_ortho_src_dst_shape_strides == nullptr ) {
479+ simplified_ortho_src_strides, simplified_ortho_dst_strides,
480+ masked_src_shape, masked_src_strides);
481+ py:: ssize_t *packed_shapes_strides = std::get<0 >(ptr_size_event_tuple1);
482+ if (packed_shapes_strides == nullptr ) {
480483 throw std::runtime_error (" Unable to allocate device memory" );
481484 }
482- sycl::event copy_shape_strides_ev1 = std::get<2 >(ptr_size_event_tuple1);
485+ sycl::event copy_shapes_strides_ev = std::get<2 >(ptr_size_event_tuple1);
483486
484- const auto &ptr_size_event_tuple2 =
485- device_allocate_and_pack<py::ssize_t >(
486- exec_q, host_task_events, masked_src_shape, masked_src_strides);
487+ py::ssize_t *packed_ortho_src_dst_shape_strides = packed_shapes_strides;
487488 py::ssize_t *packed_masked_src_shape_strides =
488- std::get<0 >(ptr_size_event_tuple2);
489- if (packed_masked_src_shape_strides == nullptr ) {
490- copy_shape_strides_ev1.wait ();
491- sycl::free (packed_ortho_src_dst_shape_strides, exec_q);
492- throw std::runtime_error (" Unable to allocate device memory" );
493- }
494- sycl::event copy_shape_strides_ev2 = std::get<2 >(ptr_size_event_tuple2);
495-
496- assert (masked_dst_shape.size () == 1 );
497- assert (masked_dst_strides.size () == 1 );
489+ packed_shapes_strides + (3 * ortho_nd);
498490
499491 std::vector<sycl::event> all_deps;
500- all_deps.reserve (depends.size () + 2 );
492+ all_deps.reserve (depends.size () + 1 );
501493 all_deps.insert (all_deps.end (), depends.begin (), depends.end ());
502- all_deps.push_back (copy_shape_strides_ev1);
503- all_deps.push_back (copy_shape_strides_ev2);
494+ all_deps.push_back (copy_shapes_strides_ev);
504495
505- assert (all_deps.size () == depends.size () + 2 );
496+ assert (all_deps.size () == depends.size () + 1 );
506497
507498 // OrthogIndexerT orthog_src_dst_indexer_, MaskedIndexerT
508499 // masked_src_indexer_, MaskedIndexerT masked_dst_indexer_
@@ -520,10 +511,8 @@ py_extract(dpctl::tensor::usm_ndarray src,
520511 exec_q.submit ([&](sycl::handler &cgh) {
521512 cgh.depends_on (extract_ev);
522513 auto ctx = exec_q.get_context ();
523- cgh.host_task ([ctx, packed_ortho_src_dst_shape_strides,
524- packed_masked_src_shape_strides] {
525- sycl::free (packed_ortho_src_dst_shape_strides, ctx);
526- sycl::free (packed_masked_src_shape_strides, ctx);
514+ cgh.host_task ([ctx, packed_shapes_strides] {
515+ sycl::free (packed_shapes_strides, ctx);
527516 });
528517 });
529518 host_task_events.push_back (cleanup_tmp_allocations_ev);
@@ -684,13 +673,16 @@ py_place(dpctl::tensor::usm_ndarray dst,
684673 auto rhs_shape_vec = rhs.get_shape_vector ();
685674 auto rhs_strides_vec = rhs.get_strides_vector ();
686675
687- sycl::event extract_ev ;
676+ sycl::event place_ev ;
688677 std::vector<sycl::event> host_task_events{};
689678 if (axis_start == 0 && axis_end == dst_nd) {
690679 // empty orthogonal directions
691680 auto fn =
692681 masked_place_all_slices_strided_impl_dispatch_vector[dst_typeid];
693682
683+ assert (rhs_shape_vec.size () == 1 );
684+ assert (rhs_strides_vec.size () == 1 );
685+
694686 using dpctl::tensor::offset_utils::device_allocate_and_pack;
695687 const auto &ptr_size_event_tuple1 =
696688 device_allocate_and_pack<py::ssize_t >(
@@ -703,23 +695,20 @@ py_place(dpctl::tensor::usm_ndarray dst,
703695 sycl::event copy_dst_shape_strides_ev =
704696 std::get<2 >(ptr_size_event_tuple1);
705697
706- assert (rhs_shape_vec.size () == 1 );
707- assert (rhs_strides_vec.size () == 1 );
708-
709698 std::vector<sycl::event> all_deps;
710699 all_deps.reserve (depends.size () + 1 );
711700 all_deps.insert (all_deps.end (), depends.begin (), depends.end ());
712701 all_deps.push_back (copy_dst_shape_strides_ev);
713702
714703 assert (all_deps.size () == depends.size () + 1 );
715704
716- extract_ev = fn (exec_q, cumsum_sz, dst_data_p, cumsum_data_p,
717- rhs_data_p, dst_nd, packed_dst_shape_strides,
718- rhs_shape_vec[ 0 ], rhs_strides_vec[0 ], all_deps);
705+ place_ev = fn (exec_q, cumsum_sz, dst_data_p, cumsum_data_p, rhs_data_p ,
706+ dst_nd, packed_dst_shape_strides, rhs_shape_vec[ 0 ] ,
707+ rhs_strides_vec[0 ], all_deps);
719708
720709 sycl::event cleanup_tmp_allocations_ev =
721710 exec_q.submit ([&](sycl::handler &cgh) {
722- cgh.depends_on (extract_ev );
711+ cgh.depends_on (place_ev );
723712 auto ctx = exec_q.get_context ();
724713 cgh.host_task ([ctx, packed_dst_shape_strides] {
725714 sycl::free (packed_dst_shape_strides, ctx);
@@ -774,69 +763,59 @@ py_place(dpctl::tensor::usm_ndarray dst,
774763 simplified_ortho_shape, simplified_ortho_dst_strides,
775764 simplified_ortho_rhs_strides, ortho_dst_offset, ortho_rhs_offset);
776765
766+ assert (masked_rhs_shape.size () == 1 );
767+ assert (masked_rhs_strides.size () == 1 );
768+
777769 using dpctl::tensor::offset_utils::device_allocate_and_pack;
778770 const auto &ptr_size_event_tuple1 =
779771 device_allocate_and_pack<py::ssize_t >(
780772 exec_q, host_task_events, simplified_ortho_shape,
781- simplified_ortho_dst_strides, simplified_ortho_rhs_strides);
782- py:: ssize_t *packed_ortho_dst_rhs_shape_strides =
783- std::get<0 >(ptr_size_event_tuple1);
784- if (packed_ortho_dst_rhs_shape_strides == nullptr ) {
773+ simplified_ortho_dst_strides, simplified_ortho_rhs_strides,
774+ masked_dst_shape, masked_dst_strides);
775+ py:: ssize_t *packed_shapes_strides = std::get<0 >(ptr_size_event_tuple1);
776+ if (packed_shapes_strides == nullptr ) {
785777 throw std::runtime_error (" Unable to allocate device memory" );
786778 }
787- sycl::event copy_shape_strides_ev1 = std::get<2 >(ptr_size_event_tuple1);
779+ sycl::event copy_shapes_strides_ev = std::get<2 >(ptr_size_event_tuple1);
788780
789- auto ptr_size_event_tuple2 = device_allocate_and_pack<py::ssize_t >(
790- exec_q, host_task_events, masked_dst_shape, masked_dst_strides);
781+ py::ssize_t *packed_ortho_dst_rhs_shape_strides = packed_shapes_strides;
791782 py::ssize_t *packed_masked_dst_shape_strides =
792- std::get<0 >(ptr_size_event_tuple2);
793- if (packed_masked_dst_shape_strides == nullptr ) {
794- copy_shape_strides_ev1.wait ();
795- sycl::free (packed_ortho_dst_rhs_shape_strides, exec_q);
796- throw std::runtime_error (" Unable to allocate device memory" );
797- }
798- sycl::event copy_shape_strides_ev2 = std::get<2 >(ptr_size_event_tuple2);
799-
800- assert (masked_rhs_shape.size () == 1 );
801- assert (masked_rhs_strides.size () == 1 );
783+ packed_shapes_strides + (3 * ortho_nd);
802784
803785 std::vector<sycl::event> all_deps;
804- all_deps.reserve (depends.size () + 2 );
786+ all_deps.reserve (depends.size () + 1 );
805787 all_deps.insert (all_deps.end (), depends.begin (), depends.end ());
806- all_deps.push_back (copy_shape_strides_ev1);
807- all_deps.push_back (copy_shape_strides_ev2);
808-
809- assert (all_deps.size () == depends.size () + 2 );
810-
811- extract_ev = fn (exec_q, ortho_nelems, masked_dst_nelems, dst_data_p,
812- cumsum_data_p, rhs_data_p,
813- // data to build orthog_dst_rhs_indexer
814- ortho_nd, packed_ortho_dst_rhs_shape_strides,
815- ortho_dst_offset, ortho_rhs_offset,
816- // data to build masked_dst_indexer
817- masked_dst_nd, packed_masked_dst_shape_strides,
818- // data to build masked_dst_indexer,
819- masked_rhs_shape[0 ], masked_rhs_strides[0 ], all_deps);
788+ all_deps.push_back (copy_shapes_strides_ev);
789+
790+ assert (all_deps.size () == depends.size () + 1 );
791+
792+ place_ev = fn (exec_q, ortho_nelems, masked_dst_nelems, dst_data_p,
793+ cumsum_data_p, rhs_data_p,
794+ // data to build orthog_dst_rhs_indexer
795+ ortho_nd, packed_ortho_dst_rhs_shape_strides,
796+ ortho_dst_offset, ortho_rhs_offset,
797+ // data to build masked_dst_indexer
798+ masked_dst_nd, packed_masked_dst_shape_strides,
799+ // data to build masked_dst_indexer,
800+ masked_rhs_shape[0 ], masked_rhs_strides[0 ], all_deps);
820801
821802 sycl::event cleanup_tmp_allocations_ev =
822803 exec_q.submit ([&](sycl::handler &cgh) {
823- cgh.depends_on (extract_ev );
804+ cgh.depends_on (place_ev );
824805 auto ctx = exec_q.get_context ();
825- cgh.host_task ([ctx, packed_ortho_dst_rhs_shape_strides,
826- packed_masked_dst_shape_strides] {
827- sycl::free (packed_ortho_dst_rhs_shape_strides, ctx);
828- sycl::free (packed_masked_dst_shape_strides, ctx);
806+ cgh.host_task ([ctx, packed_shapes_strides] {
807+ sycl::free (packed_shapes_strides, ctx);
829808 });
830809 });
831810 host_task_events.push_back (cleanup_tmp_allocations_ev);
832811 }
833812
834- host_task_events.push_back (extract_ev );
813+ host_task_events.push_back (place_ev );
835814
836815 sycl::event py_obj_management_host_task_ev = dpctl::utils::keep_args_alive (
837816 exec_q, {dst, cumsum, rhs}, host_task_events);
838817
839- return std::make_pair (py_obj_management_host_task_ev, extract_ev );
818+ return std::make_pair (py_obj_management_host_task_ev, place_ev );
840819}
841820
842821// Non-zero
0 commit comments