4040#include " utils/sycl_utils.hpp"
4141#include " utils/type_utils.hpp"
4242
43+ #define SYCL_EXT_ONEAPI_COMPLEX
44+ #include < sycl/ext/oneapi/experimental/complex/complex.hpp>
45+
4346namespace dpctl
4447{
4548namespace tensor
@@ -48,6 +51,8 @@ namespace kernels
4851{
4952
5053using dpctl::tensor::ssize_t ;
54+ namespace tu_ns = dpctl::tensor::type_utils;
55+ namespace exprm_ns = sycl::ext::oneapi::experimental;
5156
5257namespace gemm_detail
5358{
@@ -1082,8 +1087,21 @@ class GemmBatchFunctorThreadNM_vecm
10821087#pragma unroll
10831088 for (std::uint32_t pr_j = 0 ; pr_j < wi_delta_m_vecs; ++pr_j)
10841089 {
1085- private_C[pr_i * wi_delta_m_vecs + pr_j] +=
1086- pr_lhs[pr_i] * pr_rhs[pr_j];
1090+ if constexpr (tu_ns::is_complex_v<resT>) {
1091+ using realT = typename resT::value_type;
1092+ using sycl_complex = exprm_ns::complex <realT>;
1093+
1094+ auto tmp = sycl_complex (
1095+ private_C[pr_i * wi_delta_m_vecs + pr_j]);
1096+ tmp += sycl_complex (pr_lhs[pr_i]) *
1097+ sycl_complex (pr_rhs[pr_j]);
1098+ private_C[pr_i * wi_delta_m_vecs + pr_j] =
1099+ resT (tmp);
1100+ }
1101+ else {
1102+ private_C[pr_i * wi_delta_m_vecs + pr_j] +=
1103+ pr_lhs[pr_i] * pr_rhs[pr_j];
1104+ }
10871105 }
10881106 }
10891107 }
@@ -1949,9 +1967,21 @@ class GemmBatchNoAtomicFunctorThreadNM
19491967 slmB_t local_sum (identity_);
19501968 for (std::size_t private_s = 0 ; private_s < wi_delta_k; ++private_s)
19511969 {
1952- local_sum = local_sum +
1953- (local_A_block[a_offset + a_pr_offset + private_s] *
1954- local_B_block[b_offset + private_s]);
1970+ if constexpr (tu_ns::is_complex_v<resT>) {
1971+ using realT = typename resT::value_type;
1972+ using sycl_complex = exprm_ns::complex <realT>;
1973+ auto tmp = sycl_complex (local_sum);
1974+ tmp += (sycl_complex (local_A_block[a_offset + a_pr_offset +
1975+ private_s]) *
1976+ sycl_complex (local_B_block[b_offset + private_s]));
1977+ local_sum = resT (tmp);
1978+ }
1979+ else {
1980+ local_sum =
1981+ local_sum +
1982+ (local_A_block[a_offset + a_pr_offset + private_s] *
1983+ local_B_block[b_offset + private_s]);
1984+ }
19551985 }
19561986
19571987 const std::size_t gl_i = i + private_i;
@@ -2114,12 +2144,28 @@ class GemmBatchNoAtomicFunctorThreadK
21142144 accV_t private_sum (identity_);
21152145 constexpr accV_t vec_identity_ (identity_);
21162146 for (std::size_t t = local_s; t < local_B_block.size (); t += delta_k) {
2117- private_sum +=
2118- ((i < n) && (t + t_shift < k))
2119- ? (static_cast <resT>(
2120- lhs[lhs_offset + lhs_indexer (global_s_offset + t)]) *
2121- local_B_block[t])
2122- : vec_identity_;
2147+ if constexpr (tu_ns::is_complex_v<resT>) {
2148+ using realT = typename resT::value_type;
2149+ using sycl_complex = exprm_ns::complex <realT>;
2150+
2151+ auto tmp = sycl_complex (private_sum);
2152+ tmp += ((i < n) && (t + t_shift < k))
2153+ ? sycl_complex (static_cast <resT>(
2154+ lhs[lhs_offset +
2155+ lhs_indexer (global_s_offset + t)])) *
2156+ sycl_complex (local_B_block[t])
2157+ : sycl_complex (vec_identity_);
2158+ private_sum = resT (tmp);
2159+ }
2160+ else {
2161+ private_sum +=
2162+ ((i < n) && (t + t_shift < k))
2163+ ? (static_cast <resT>(
2164+ lhs[lhs_offset +
2165+ lhs_indexer (global_s_offset + t)]) *
2166+ local_B_block[t])
2167+ : vec_identity_;
2168+ }
21232169 }
21242170
21252171 std::size_t workspace_i_shift = local_i * delta_k;
@@ -2130,7 +2176,17 @@ class GemmBatchNoAtomicFunctorThreadK
21302176 if (local_s == 0 && i < n) {
21312177 accV_t local_sum (workspace[workspace_i_shift]);
21322178 for (std::size_t t = 1 ; t < delta_k; ++t) {
2133- local_sum += workspace[workspace_i_shift + t];
2179+ if constexpr (tu_ns::is_complex_v<resT>) {
2180+ using realT = typename resT::value_type;
2181+ using sycl_complex = exprm_ns::complex <realT>;
2182+
2183+ auto tmp = sycl_complex (local_sum);
2184+ tmp += sycl_complex (workspace[workspace_i_shift + t]);
2185+ local_sum = resT (tmp);
2186+ }
2187+ else {
2188+ local_sum += workspace[workspace_i_shift + t];
2189+ }
21342190 }
21352191
21362192 const std::size_t total_offset =
@@ -2863,8 +2919,7 @@ sycl::event gemm_batch_tree_impl(sycl::queue &exec_q,
28632919 }
28642920
28652921 if (max_nm < 64 ) {
2866- using dpctl::tensor::type_utils::is_complex;
2867- if constexpr (!is_complex<resTy>::value) {
2922+ if constexpr (!tu_ns::is_complex_v<resTy>) {
28682923 if (m < 4 ) {
28692924 constexpr std::uint32_t m_groups_one = 1 ;
28702925 return gemm_batch_tree_k_impl<lhsTy, rhsTy, resTy,
@@ -2900,8 +2955,7 @@ sycl::event gemm_batch_tree_impl(sycl::queue &exec_q,
29002955 }
29012956 }
29022957 else { // m > 1, n > k or m > k
2903- using dpctl::tensor::type_utils::is_complex;
2904- if constexpr (!is_complex<resTy>::value) {
2958+ if constexpr (!tu_ns::is_complex_v<resTy>) {
29052959 constexpr std::uint32_t m_groups_four = 4 ;
29062960 return gemm_batch_tree_nm_impl<lhsTy, rhsTy, resTy, m_groups_four>(
29072961 exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd,
@@ -3435,8 +3489,7 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q,
34353489 }
34363490
34373491 if (max_nm < 64 ) {
3438- using dpctl::tensor::type_utils::is_complex;
3439- if constexpr (!is_complex<resTy>::value) {
3492+ if constexpr (!tu_ns::is_complex_v<resTy>) {
34403493 if (m < 4 ) {
34413494 return gemm_batch_contig_tree_k_impl<lhsTy, rhsTy, resTy, 1 >(
34423495 exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m,
@@ -3454,8 +3507,7 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q,
34543507 }
34553508 }
34563509 else { // m > 1, n > k or m > k
3457- using dpctl::tensor::type_utils::is_complex;
3458- if constexpr (!is_complex<resTy>::value) {
3510+ if constexpr (!tu_ns::is_complex_v<resTy>) {
34593511 return gemm_batch_contig_tree_nm_impl<lhsTy, rhsTy, resTy, 4 >(
34603512 exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends);
34613513 }
@@ -3840,8 +3892,7 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q,
38403892 }
38413893
38423894 if (max_nm < 64 ) {
3843- using dpctl::tensor::type_utils::is_complex;
3844- if constexpr (!is_complex<resTy>::value) {
3895+ if constexpr (!tu_ns::is_complex_v<resTy>) {
38453896 if (m < 4 ) {
38463897 return gemm_tree_k_impl<lhsTy, rhsTy, resTy, 1 >(
38473898 exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd,
@@ -3866,8 +3917,7 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q,
38663917 }
38673918 }
38683919 else { // m > 1, n > k or m > k
3869- using dpctl::tensor::type_utils::is_complex;
3870- if constexpr (!is_complex<resTy>::value) {
3920+ if constexpr (!tu_ns::is_complex_v<resTy>) {
38713921 return gemm_tree_nm_impl<lhsTy, rhsTy, resTy, 4 >(
38723922 exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd,
38733923 lhs_outer_inner_shapes_strides, rhs_outer_nd,
@@ -4191,8 +4241,7 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q,
41914241 }
41924242
41934243 if (max_nm < 64 ) {
4194- using dpctl::tensor::type_utils::is_complex;
4195- if constexpr (!is_complex<resTy>::value) {
4244+ if constexpr (!tu_ns::is_complex_v<resTy>) {
41964245 if (m < 4 ) {
41974246 return gemm_contig_tree_k_impl<lhsTy, rhsTy, resTy, 1 >(
41984247 exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends);
@@ -4208,8 +4257,7 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q,
42084257 }
42094258 }
42104259 else { // m > 1, n > k or m > k
4211- using dpctl::tensor::type_utils::is_complex;
4212- if constexpr (!is_complex<resTy>::value) {
4260+ if constexpr (!tu_ns::is_complex_v<resTy>) {
42134261 return gemm_contig_tree_nm_impl<lhsTy, rhsTy, resTy, 4 >(
42144262 exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends);
42154263 }
0 commit comments