@@ -473,36 +473,27 @@ py_extract(dpctl::tensor::usm_ndarray src,
473473 const auto &ptr_size_event_tuple1 =
474474 device_allocate_and_pack<py::ssize_t >(
475475 exec_q, host_task_events, simplified_ortho_shape,
476- simplified_ortho_src_strides, simplified_ortho_dst_strides);
477- py:: ssize_t *packed_ortho_src_dst_shape_strides =
478- std::get<0 >(ptr_size_event_tuple1);
479- if (packed_ortho_src_dst_shape_strides == nullptr ) {
476+ simplified_ortho_src_strides, simplified_ortho_dst_strides,
477+ masked_src_shape, masked_src_strides);
478+ py:: ssize_t *packed_shapes_strides = std::get<0 >(ptr_size_event_tuple1);
479+ if (packed_shapes_strides == nullptr ) {
480480 throw std::runtime_error (" Unable to allocate device memory" );
481481 }
482- sycl::event copy_shape_strides_ev1 = std::get<2 >(ptr_size_event_tuple1);
482+ sycl::event copy_shapes_strides_ev = std::get<2 >(ptr_size_event_tuple1);
483483
484- const auto &ptr_size_event_tuple2 =
485- device_allocate_and_pack<py::ssize_t >(
486- exec_q, host_task_events, masked_src_shape, masked_src_strides);
484+ py::ssize_t *packed_ortho_src_dst_shape_strides = packed_shapes_strides;
487485 py::ssize_t *packed_masked_src_shape_strides =
488- std::get<0 >(ptr_size_event_tuple2);
489- if (packed_masked_src_shape_strides == nullptr ) {
490- copy_shape_strides_ev1.wait ();
491- sycl::free (packed_ortho_src_dst_shape_strides, exec_q);
492- throw std::runtime_error (" Unable to allocate device memory" );
493- }
494- sycl::event copy_shape_strides_ev2 = std::get<2 >(ptr_size_event_tuple2);
486+ packed_shapes_strides + (3 * ortho_nd);
495487
496488 assert (masked_dst_shape.size () == 1 );
497489 assert (masked_dst_strides.size () == 1 );
498490
499491 std::vector<sycl::event> all_deps;
500- all_deps.reserve (depends.size () + 2 );
492+ all_deps.reserve (depends.size () + 1 );
501493 all_deps.insert (all_deps.end (), depends.begin (), depends.end ());
502- all_deps.push_back (copy_shape_strides_ev1);
503- all_deps.push_back (copy_shape_strides_ev2);
494+ all_deps.push_back (copy_shapes_strides_ev);
504495
505- assert (all_deps.size () == depends.size () + 2 );
496+ assert (all_deps.size () == depends.size () + 1 );
506497
507498 // OrthogIndexerT orthog_src_dst_indexer_, MaskedIndexerT
508499 // masked_src_indexer_, MaskedIndexerT masked_dst_indexer_
@@ -520,10 +511,8 @@ py_extract(dpctl::tensor::usm_ndarray src,
520511 exec_q.submit ([&](sycl::handler &cgh) {
521512 cgh.depends_on (extract_ev);
522513 auto ctx = exec_q.get_context ();
523- cgh.host_task ([ctx, packed_ortho_src_dst_shape_strides,
524- packed_masked_src_shape_strides] {
525- sycl::free (packed_ortho_src_dst_shape_strides, ctx);
526- sycl::free (packed_masked_src_shape_strides, ctx);
514+ cgh.host_task ([ctx, packed_shapes_strides] {
515+ sycl::free (packed_shapes_strides, ctx);
527516 });
528517 });
529518 host_task_events.push_back (cleanup_tmp_allocations_ev);
@@ -684,7 +673,7 @@ py_place(dpctl::tensor::usm_ndarray dst,
684673 auto rhs_shape_vec = rhs.get_shape_vector ();
685674 auto rhs_strides_vec = rhs.get_strides_vector ();
686675
687- sycl::event extract_ev ;
676+ sycl::event place_ev ;
688677 std::vector<sycl::event> host_task_events{};
689678 if (axis_start == 0 && axis_end == dst_nd) {
690679 // empty orthogonal directions
@@ -713,13 +702,13 @@ py_place(dpctl::tensor::usm_ndarray dst,
713702
714703 assert (all_deps.size () == depends.size () + 1 );
715704
716- extract_ev = fn (exec_q, cumsum_sz, dst_data_p, cumsum_data_p,
717- rhs_data_p, dst_nd, packed_dst_shape_strides,
718- rhs_shape_vec[ 0 ], rhs_strides_vec[0 ], all_deps);
705+ place_ev = fn (exec_q, cumsum_sz, dst_data_p, cumsum_data_p, rhs_data_p ,
706+ dst_nd, packed_dst_shape_strides, rhs_shape_vec[ 0 ] ,
707+ rhs_strides_vec[0 ], all_deps);
719708
720709 sycl::event cleanup_tmp_allocations_ev =
721710 exec_q.submit ([&](sycl::handler &cgh) {
722- cgh.depends_on (extract_ev );
711+ cgh.depends_on (place_ev );
723712 auto ctx = exec_q.get_context ();
724713 cgh.host_task ([ctx, packed_dst_shape_strides] {
725714 sycl::free (packed_dst_shape_strides, ctx);
@@ -778,65 +767,55 @@ py_place(dpctl::tensor::usm_ndarray dst,
778767 const auto &ptr_size_event_tuple1 =
779768 device_allocate_and_pack<py::ssize_t >(
780769 exec_q, host_task_events, simplified_ortho_shape,
781- simplified_ortho_dst_strides, simplified_ortho_rhs_strides);
782- py:: ssize_t *packed_ortho_dst_rhs_shape_strides =
783- std::get<0 >(ptr_size_event_tuple1);
784- if (packed_ortho_dst_rhs_shape_strides == nullptr ) {
770+ simplified_ortho_dst_strides, simplified_ortho_rhs_strides,
771+ masked_dst_shape, masked_dst_strides);
772+ py:: ssize_t *packed_shapes_strides = std::get<0 >(ptr_size_event_tuple1);
773+ if (packed_shapes_strides == nullptr ) {
785774 throw std::runtime_error (" Unable to allocate device memory" );
786775 }
787- sycl::event copy_shape_strides_ev1 = std::get<2 >(ptr_size_event_tuple1);
776+ sycl::event copy_shapes_strides_ev = std::get<2 >(ptr_size_event_tuple1);
788777
789- auto ptr_size_event_tuple2 = device_allocate_and_pack<py::ssize_t >(
790- exec_q, host_task_events, masked_dst_shape, masked_dst_strides);
778+ py::ssize_t *packed_ortho_dst_rhs_shape_strides = packed_shapes_strides;
791779 py::ssize_t *packed_masked_dst_shape_strides =
792- std::get<0 >(ptr_size_event_tuple2);
793- if (packed_masked_dst_shape_strides == nullptr ) {
794- copy_shape_strides_ev1.wait ();
795- sycl::free (packed_ortho_dst_rhs_shape_strides, exec_q);
796- throw std::runtime_error (" Unable to allocate device memory" );
797- }
798- sycl::event copy_shape_strides_ev2 = std::get<2 >(ptr_size_event_tuple2);
780+ packed_shapes_strides + (3 * ortho_nd);
799781
800782 assert (masked_rhs_shape.size () == 1 );
801783 assert (masked_rhs_strides.size () == 1 );
802784
803785 std::vector<sycl::event> all_deps;
804- all_deps.reserve (depends.size () + 2 );
786+ all_deps.reserve (depends.size () + 1 );
805787 all_deps.insert (all_deps.end (), depends.begin (), depends.end ());
806- all_deps.push_back (copy_shape_strides_ev1);
807- all_deps.push_back (copy_shape_strides_ev2);
808-
809- assert (all_deps.size () == depends.size () + 2 );
810-
811- extract_ev = fn (exec_q, ortho_nelems, masked_dst_nelems, dst_data_p,
812- cumsum_data_p, rhs_data_p,
813- // data to build orthog_dst_rhs_indexer
814- ortho_nd, packed_ortho_dst_rhs_shape_strides,
815- ortho_dst_offset, ortho_rhs_offset,
816- // data to build masked_dst_indexer
817- masked_dst_nd, packed_masked_dst_shape_strides,
818- // data to build masked_dst_indexer,
819- masked_rhs_shape[0 ], masked_rhs_strides[0 ], all_deps);
788+ all_deps.push_back (copy_shapes_strides_ev);
789+
790+ assert (all_deps.size () == depends.size () + 1 );
791+
792+ place_ev = fn (exec_q, ortho_nelems, masked_dst_nelems, dst_data_p,
793+ cumsum_data_p, rhs_data_p,
794+ // data to build orthog_dst_rhs_indexer
795+ ortho_nd, packed_ortho_dst_rhs_shape_strides,
796+ ortho_dst_offset, ortho_rhs_offset,
797+ // data to build masked_dst_indexer
798+ masked_dst_nd, packed_masked_dst_shape_strides,
799+ // data to build masked_rhs_indexer
800+ masked_rhs_shape[0 ], masked_rhs_strides[0 ], all_deps);
820801
821802 sycl::event cleanup_tmp_allocations_ev =
822803 exec_q.submit ([&](sycl::handler &cgh) {
823- cgh.depends_on (extract_ev );
804+ cgh.depends_on (place_ev );
824805 auto ctx = exec_q.get_context ();
825- cgh.host_task ([ctx, packed_ortho_dst_rhs_shape_strides,
826- packed_masked_dst_shape_strides] {
827- sycl::free (packed_ortho_dst_rhs_shape_strides, ctx);
828- sycl::free (packed_masked_dst_shape_strides, ctx);
806+ cgh.host_task ([ctx, packed_shapes_strides] {
807+ sycl::free (packed_shapes_strides, ctx);
829808 });
830809 });
831810 host_task_events.push_back (cleanup_tmp_allocations_ev);
832811 }
833812
834- host_task_events.push_back (extract_ev );
813+ host_task_events.push_back (place_ev );
835814
836815 sycl::event py_obj_management_host_task_ev = dpctl::utils::keep_args_alive (
837816 exec_q, {dst, cumsum, rhs}, host_task_events);
838817
839- return std::make_pair (py_obj_management_host_task_ev, extract_ev );
818+ return std::make_pair (py_obj_management_host_task_ev, place_ev );
840819}
841820
842821// Non-zero
0 commit comments