@@ -51,6 +51,7 @@ namespace kernels
5151{
5252
5353using dpctl::tensor::ssize_t ;
54+ namespace su_ns = dpctl::tensor::sycl_utils;
5455namespace tu_ns = dpctl::tensor::type_utils;
5556namespace exprm_ns = sycl::ext::oneapi::experimental;
5657
@@ -101,7 +102,7 @@ void scale_gemm_nm_parameters(const std::size_t &local_mem_size,
101102}
102103} // namespace gemm_detail
103104
104- using dpctl::tensor::sycl_utils ::choose_workgroup_size;
105+ using su_ns ::choose_workgroup_size;
105106
106107template <typename T1, typename T2, typename T3, typename T4, typename T5>
107108class gemm_seq_reduction_krn ;
@@ -2367,12 +2368,12 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q,
23672368 depends);
23682369 }
23692370 else {
2370- using ReductionOpT =
2371- typename std::conditional<std:: is_same_v<resTy, bool >,
2372- sycl::logical_or <resTy>,
2373- sycl::plus<resTy>>::type ;
2371+ using ReductionOpT = std:: conditional_t <
2372+ std::is_same_v<resTy, bool >, sycl::logical_or<resTy >,
2373+ std:: conditional_t <tu_ns::is_complex_v<resTy>, su_ns::Plus <resTy>,
2374+ sycl::plus<resTy>>> ;
23742375 constexpr resTy identity_val =
2375- sycl::known_identity <ReductionOpT, resTy>::value;
2376+ su_ns::Identity <ReductionOpT, resTy>::value;
23762377
23772378 std::size_t iter_nelems = batch_nelems * n * m;
23782379 std::size_t reduction_nelems =
@@ -2663,12 +2664,12 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q,
26632664 lhs_indexer, rhs_indexer, res_indexer, depends);
26642665 }
26652666 else {
2666- using ReductionOpT =
2667- typename std::conditional<std:: is_same_v<resTy, bool >,
2668- sycl::logical_or <resTy>,
2669- sycl::plus<resTy>>::type ;
2667+ using ReductionOpT = std:: conditional_t <
2668+ std::is_same_v<resTy, bool >, sycl::logical_or<resTy >,
2669+ std:: conditional_t <tu_ns::is_complex_v<resTy>, su_ns::Plus <resTy>,
2670+ sycl::plus<resTy>>> ;
26702671 constexpr resTy identity_val =
2671- sycl::known_identity <ReductionOpT, resTy>::value;
2672+ su_ns::Identity <ReductionOpT, resTy>::value;
26722673 std::size_t iter_nelems = batch_nelems * n * m;
26732674 std::size_t reduction_nelems = (k + wi_delta_k - 1 ) / wi_delta_k;
26742675
@@ -3034,12 +3035,12 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q,
30343035 depends);
30353036 }
30363037 else {
3037- using ReductionOpT =
3038- typename std::conditional<std:: is_same_v<resTy, bool >,
3039- sycl::logical_or <resTy>,
3040- sycl::plus<resTy>>::type ;
3038+ using ReductionOpT = std:: conditional_t <
3039+ std::is_same_v<resTy, bool >, sycl::logical_or<resTy >,
3040+ std:: conditional_t <tu_ns::is_complex_v<resTy>, su_ns::Plus <resTy>,
3041+ sycl::plus<resTy>>> ;
30413042 constexpr resTy identity_val =
3042- sycl::known_identity <ReductionOpT, resTy>::value;
3043+ su_ns::Identity <ReductionOpT, resTy>::value;
30433044
30443045 std::size_t iter_nelems = batch_nelems * n * m;
30453046 std::size_t reduction_nelems =
@@ -3222,12 +3223,12 @@ gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q,
32223223 lhs_indexer, rhs_indexer, res_indexer, depends);
32233224 }
32243225 else {
3225- using ReductionOpT =
3226- typename std::conditional<std:: is_same_v<resTy, bool >,
3227- sycl::logical_or <resTy>,
3228- sycl::plus<resTy>>::type ;
3226+ using ReductionOpT = std:: conditional_t <
3227+ std::is_same_v<resTy, bool >, sycl::logical_or<resTy >,
3228+ std:: conditional_t <tu_ns::is_complex_v<resTy>, su_ns::Plus <resTy>,
3229+ sycl::plus<resTy>>> ;
32293230 constexpr resTy identity_val =
3230- sycl::known_identity <ReductionOpT, resTy>::value;
3231+ su_ns::Identity <ReductionOpT, resTy>::value;
32313232 std::size_t iter_nelems = batch_nelems * n * m;
32323233 std::size_t reduction_nelems = (k + wi_delta_k - 1 ) / wi_delta_k;
32333234
@@ -3591,12 +3592,12 @@ sycl::event gemm_tree_k_impl(sycl::queue &exec_q,
35913592 res_indexer, depends);
35923593 }
35933594 else {
3594- using ReductionOpT =
3595- typename std::conditional<std:: is_same_v<resTy, bool >,
3596- sycl::logical_or <resTy>,
3597- sycl::plus<resTy>>::type ;
3595+ using ReductionOpT = std:: conditional_t <
3596+ std::is_same_v<resTy, bool >, sycl::logical_or<resTy >,
3597+ std:: conditional_t <tu_ns::is_complex_v<resTy>, su_ns::Plus <resTy>,
3598+ sycl::plus<resTy>>> ;
35983599 constexpr resTy identity_val =
3599- sycl::known_identity <ReductionOpT, resTy>::value;
3600+ su_ns::Identity <ReductionOpT, resTy>::value;
36003601
36013602 std::size_t iter_nelems = n * m;
36023603 std::size_t reduction_nelems =
@@ -3745,12 +3746,12 @@ sycl::event gemm_tree_nm_impl(sycl::queue &exec_q,
37453746 lhs_indexer, rhs_indexer, res_indexer, depends);
37463747 }
37473748 else {
3748- using ReductionOpT =
3749- typename std::conditional<std:: is_same_v<resTy, bool >,
3750- sycl::logical_or <resTy>,
3751- sycl::plus<resTy>>::type ;
3749+ using ReductionOpT = std:: conditional_t <
3750+ std::is_same_v<resTy, bool >, sycl::logical_or<resTy >,
3751+ std:: conditional_t <tu_ns::is_complex_v<resTy>, su_ns::Plus <resTy>,
3752+ sycl::plus<resTy>>> ;
37523753 constexpr resTy identity_val =
3753- sycl::known_identity <ReductionOpT, resTy>::value;
3754+ su_ns::Identity <ReductionOpT, resTy>::value;
37543755
37553756 std::size_t iter_nelems = n * m;
37563757 std::size_t reduction_nelems = (k + wi_delta_k - 1 ) / wi_delta_k;
@@ -3979,12 +3980,12 @@ sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q,
39793980 res_indexer, depends);
39803981 }
39813982 else {
3982- using ReductionOpT =
3983- typename std::conditional<std:: is_same_v<resTy, bool >,
3984- sycl::logical_or <resTy>,
3985- sycl::plus<resTy>>::type ;
3983+ using ReductionOpT = std:: conditional_t <
3984+ std::is_same_v<resTy, bool >, sycl::logical_or<resTy >,
3985+ std:: conditional_t <tu_ns::is_complex_v<resTy>, su_ns::Plus <resTy>,
3986+ sycl::plus<resTy>>> ;
39863987 constexpr resTy identity_val =
3987- sycl::known_identity <ReductionOpT, resTy>::value;
3988+ su_ns::Identity <ReductionOpT, resTy>::value;
39883989
39893990 std::size_t iter_nelems = n * m;
39903991 std::size_t reduction_nelems =
@@ -4118,12 +4119,12 @@ sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q,
41184119 lhs_indexer, rhs_indexer, res_indexer, depends);
41194120 }
41204121 else {
4121- using ReductionOpT =
4122- typename std::conditional<std:: is_same_v<resTy, bool >,
4123- sycl::logical_or <resTy>,
4124- sycl::plus<resTy>>::type ;
4122+ using ReductionOpT = std:: conditional_t <
4123+ std::is_same_v<resTy, bool >, sycl::logical_or<resTy >,
4124+ std:: conditional_t <tu_ns::is_complex_v<resTy>, su_ns::Plus <resTy>,
4125+ sycl::plus<resTy>>> ;
41254126 constexpr resTy identity_val =
4126- sycl::known_identity <ReductionOpT, resTy>::value;
4127+ su_ns::Identity <ReductionOpT, resTy>::value;
41274128
41284129 std::size_t iter_nelems = n * m;
41294130 std::size_t reduction_nelems = (k + wi_delta_k - 1 ) / wi_delta_k;
0 commit comments