2626#include < cstddef>
2727#include < cstdint>
2828#include < limits>
29- #include < sycl/sycl.hpp>
3029#include < utility>
3130#include < vector>
3231
32+ #include < sycl/sycl.hpp>
33+
3334#include " dpctl_tensor_types.hpp"
3435#include " utils/offset_utils.hpp"
3536#include " utils/type_dispatch_building.hpp"
@@ -599,6 +600,10 @@ sycl::event masked_place_all_slices_strided_impl(
599600 sycl::nd_range<2 > ndRange{gRange , lRange};
600601
601602 using LocalAccessorT = sycl::local_accessor<indT, 1 >;
603+ using Impl =
604+ MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
605+ Strided1DCyclicIndexer, dataT, indT,
606+ LocalAccessorT>;
602607
603608 dataT *dst_tp = reinterpret_cast <dataT *>(dst_p);
604609 const dataT *rhs_tp = reinterpret_cast <const dataT *>(rhs_p);
@@ -611,13 +616,9 @@ sycl::event masked_place_all_slices_strided_impl(
611616 LocalAccessorT lacc (lacc_size, cgh);
612617
613618 cgh.parallel_for <KernelName>(
614- ndRange,
615- MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
616- Strided1DCyclicIndexer, dataT, indT,
617- LocalAccessorT>(
618- dst_tp, cumsum_tp, rhs_tp, iteration_size,
619- orthog_dst_rhs_indexer, masked_dst_indexer, masked_rhs_indexer,
620- lacc));
619+ ndRange, Impl (dst_tp, cumsum_tp, rhs_tp, iteration_size,
620+ orthog_dst_rhs_indexer, masked_dst_indexer,
621+ masked_rhs_indexer, lacc));
621622 });
622623
623624 return comp_ev;
@@ -696,6 +697,10 @@ sycl::event masked_place_some_slices_strided_impl(
696697 sycl::nd_range<2 > ndRange{gRange , lRange};
697698
698699 using LocalAccessorT = sycl::local_accessor<indT, 1 >;
700+ using Impl =
701+ MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
702+ Strided1DCyclicIndexer, dataT, indT,
703+ LocalAccessorT>;
699704
700705 dataT *dst_tp = reinterpret_cast <dataT *>(dst_p);
701706 const dataT *rhs_tp = reinterpret_cast <const dataT *>(rhs_p);
@@ -708,13 +713,9 @@ sycl::event masked_place_some_slices_strided_impl(
708713 LocalAccessorT lacc (lacc_size, cgh);
709714
710715 cgh.parallel_for <KernelName>(
711- ndRange,
712- MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
713- Strided1DCyclicIndexer, dataT, indT,
714- LocalAccessorT>(
715- dst_tp, cumsum_tp, rhs_tp, masked_nelems,
716- orthog_dst_rhs_indexer, masked_dst_indexer, masked_rhs_indexer,
717- lacc));
716+ ndRange, Impl (dst_tp, cumsum_tp, rhs_tp, masked_nelems,
717+ orthog_dst_rhs_indexer, masked_dst_indexer,
718+ masked_rhs_indexer, lacc));
718719 });
719720
720721 return comp_ev;
0 commit comments