 #include <cstddef>
 #include <cstdint>
 #include <stdexcept>
-#include <sycl/sycl.hpp>
 #include <utility>
 
+#include <sycl/sycl.hpp>
+
 #include "kernels/alignment.hpp"
 #include "kernels/dpctl_tensor_types.hpp"
+#include "kernels/elementwise_functions/common_detail.hpp"
 #include "utils/offset_utils.hpp"
 #include "utils/sycl_alloc_utils.hpp"
 #include "utils/sycl_utils.hpp"
@@ -324,21 +326,23 @@ sycl::event unary_contig_impl(sycl::queue &exec_q,
         {
             constexpr bool enable_sg_loadstore = true;
             using KernelName = BaseKernelName;
+            using Impl = ContigFunctorT<argTy, resTy, vec_sz, n_vecs,
+                                        enable_sg_loadstore>;
 
             cgh.parallel_for<KernelName>(
                 sycl::nd_range<1>(gws_range, lws_range),
-                ContigFunctorT<argTy, resTy, vec_sz, n_vecs,
-                               enable_sg_loadstore>(arg_tp, res_tp, nelems));
+                Impl(arg_tp, res_tp, nelems));
         }
         else {
             constexpr bool disable_sg_loadstore = false;
             using KernelName =
                 disabled_sg_loadstore_wrapper_krn<BaseKernelName>;
+            using Impl = ContigFunctorT<argTy, resTy, vec_sz, n_vecs,
+                                        disable_sg_loadstore>;
 
             cgh.parallel_for<KernelName>(
                 sycl::nd_range<1>(gws_range, lws_range),
-                ContigFunctorT<argTy, resTy, vec_sz, n_vecs,
-                               disable_sg_loadstore>(arg_tp, res_tp, nelems));
+                Impl(arg_tp, res_tp, nelems));
         }
     });
 
@@ -377,9 +381,10 @@ unary_strided_impl(sycl::queue &exec_q,
         const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
         resTy *res_tp = reinterpret_cast<resTy *>(res_p);
 
+        using Impl = StridedFunctorT<argTy, resTy, IndexerT>;
+
         cgh.parallel_for<kernel_name<argTy, resTy, IndexerT>>(
-            {nelems},
-            StridedFunctorT<argTy, resTy, IndexerT>(arg_tp, res_tp, indexer));
+            {nelems}, Impl(arg_tp, res_tp, indexer));
     });
     return comp_ev;
 }
@@ -814,22 +819,23 @@ sycl::event binary_contig_impl(sycl::queue &exec_q,
         {
             constexpr bool enable_sg_loadstore = true;
             using KernelName = BaseKernelName;
+            using Impl = BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz,
+                                              n_vecs, enable_sg_loadstore>;
 
             cgh.parallel_for<KernelName>(
                 sycl::nd_range<1>(gws_range, lws_range),
-                BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz, n_vecs,
-                                     enable_sg_loadstore>(arg1_tp, arg2_tp,
-                                                          res_tp, nelems));
+                Impl(arg1_tp, arg2_tp, res_tp, nelems));
         }
         else {
             constexpr bool disable_sg_loadstore = false;
             using KernelName =
                 disabled_sg_loadstore_wrapper_krn<BaseKernelName>;
+            using Impl = BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz,
+                                              n_vecs, disable_sg_loadstore>;
+
             cgh.parallel_for<KernelName>(
                 sycl::nd_range<1>(gws_range, lws_range),
-                BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz, n_vecs,
-                                     disable_sg_loadstore>(arg1_tp, arg2_tp,
-                                                           res_tp, nelems));
+                Impl(arg1_tp, arg2_tp, res_tp, nelems));
         }
     });
     return comp_ev;
@@ -873,9 +879,10 @@ binary_strided_impl(sycl::queue &exec_q,
         const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
         resTy *res_tp = reinterpret_cast<resTy *>(res_p);
 
+        using Impl = BinaryStridedFunctorT<argTy1, argTy2, resTy, IndexerT>;
+
         cgh.parallel_for<kernel_name<argTy1, argTy2, resTy, IndexerT>>(
-            {nelems}, BinaryStridedFunctorT<argTy1, argTy2, resTy, IndexerT>(
-                          arg1_tp, arg2_tp, res_tp, indexer));
+            {nelems}, Impl(arg1_tp, arg2_tp, res_tp, indexer));
     });
     return comp_ev;
 }
@@ -917,13 +924,9 @@ sycl::event binary_contig_matrix_contig_row_broadcast_impl(
         exec_q);
     argT2 *padded_vec = padded_vec_owner.get();
 
-    sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(depends); // ensure vec contains actual data
-        cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) {
-            auto i = id[0];
-            padded_vec[i] = vec[i % n1];
-        });
-    });
+    sycl::event make_padded_vec_ev =
+        dpctl::tensor::kernels::elementwise_detail::populate_padded_vector<
+            argT2>(exec_q, vec, n1, padded_vec, n1_padded, depends);
 
     // sub-group spans work-items [I, I + sgSize)
     // base = ndit.get_global_linear_id() - sg.get_local_id()[0]
@@ -942,10 +945,12 @@ sycl::event binary_contig_matrix_contig_row_broadcast_impl(
         std::size_t n_groups = (n_elems + lws - 1) / lws;
         auto gwsRange = sycl::range<1>(n_groups * lws);
 
+        using Impl =
+            BinaryContigMatrixContigRowBroadcastFunctorT<argT1, argT2, resT>;
+
         cgh.parallel_for<class kernel_name<argT1, argT2, resT>>(
             sycl::nd_range<1>(gwsRange, lwsRange),
-            BinaryContigMatrixContigRowBroadcastFunctorT<argT1, argT2, resT>(
-                mat, padded_vec, res, n_elems, n1));
+            Impl(mat, padded_vec, res, n_elems, n1));
     });
 
     sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
@@ -993,13 +998,9 @@ sycl::event binary_contig_row_contig_matrix_broadcast_impl(
         exec_q);
     argT2 *padded_vec = padded_vec_owner.get();
 
-    sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(depends); // ensure vec contains actual data
-        cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) {
-            auto i = id[0];
-            padded_vec[i] = vec[i % n1];
-        });
-    });
+    sycl::event make_padded_vec_ev =
+        dpctl::tensor::kernels::elementwise_detail::populate_padded_vector<
+            argT2>(exec_q, vec, n1, padded_vec, n1_padded, depends);
 
     // sub-group spans work-items [I, I + sgSize)
     // base = ndit.get_global_linear_id() - sg.get_local_id()[0]
@@ -1018,10 +1019,12 @@ sycl::event binary_contig_row_contig_matrix_broadcast_impl(
         std::size_t n_groups = (n_elems + lws - 1) / lws;
         auto gwsRange = sycl::range<1>(n_groups * lws);
 
+        using Impl =
+            BinaryContigRowContigMatrixBroadcastFunctorT<argT1, argT2, resT>;
+
         cgh.parallel_for<class kernel_name<argT1, argT2, resT>>(
             sycl::nd_range<1>(gwsRange, lwsRange),
-            BinaryContigRowContigMatrixBroadcastFunctorT<argT1, argT2, resT>(
-                padded_vec, mat, res, n_elems, n1));
+            Impl(padded_vec, mat, res, n_elems, n1));
     });
 
     sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
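
Note: the diff does not show the body of the new populate_padded_vector helper pulled in via kernels/elementwise_functions/common_detail.hpp. The following is only a sketch of what it presumably contains, inferred from the call sites above and from the inline kernel it replaces; the exact declaration, template parameters, and the kernel-name class are assumptions, not the actual header.

// Hypothetical sketch, not the actual common_detail.hpp contents.
#include <cstddef>
#include <vector>

#include <sycl/sycl.hpp>

namespace dpctl::tensor::kernels::elementwise_detail
{

template <typename T> class populate_padded_vector_krn; // assumed kernel name

// Fill padded_vec[0..n1_padded) cyclically from vec (padded_vec[i] = vec[i % n1]),
// waiting on `depends` so that vec contains actual data before it is read.
template <typename T>
sycl::event populate_padded_vector(sycl::queue &exec_q,
                                   const T *vec,
                                   std::size_t n1,
                                   T *padded_vec,
                                   std::size_t n1_padded,
                                   const std::vector<sycl::event> &depends)
{
    return exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);
        cgh.parallel_for<populate_padded_vector_krn<T>>(
            sycl::range<1>(n1_padded), [=](sycl::id<1> id) {
                const std::size_t i = id[0];
                padded_vec[i] = vec[i % n1];
            });
    });
}

} // namespace dpctl::tensor::kernels::elementwise_detail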