4040#include " utils/sycl_utils.hpp"
4141#include " utils/type_utils.hpp"
4242
43+ #define SYCL_EXT_ONEAPI_COMPLEX
44+ #include < sycl/ext/oneapi/experimental/complex/complex.hpp>
45+
4346namespace dpctl
4447{
4548namespace tensor
@@ -49,6 +52,8 @@ namespace kernels
4952
5053using dpctl::tensor::ssize_t ;
5154namespace su_ns = dpctl::tensor::sycl_utils;
55+ namespace tu_ns = dpctl::tensor::type_utils;
56+ namespace exprm_ns = sycl::ext::oneapi::experimental;
5257
5358template <typename lhsT,
5459 typename rhsT,
@@ -92,7 +97,7 @@ struct SequentialDotProduct
9297 auto lhs_reduction_offset = reduction_offsets.get_first_offset ();
9398 auto rhs_reduction_offset = reduction_offsets.get_second_offset ();
9499
95- using dpctl::tensor::type_utils ::convert_impl;
100+ using tu_ns ::convert_impl;
96101 red_val += convert_impl<outT, lhsT>(
97102 lhs_[lhs_batch_offset + lhs_reduction_offset]) *
98103 convert_impl<outT, rhsT>(
@@ -175,7 +180,7 @@ struct DotProductFunctor
175180 const auto &rhs_reduction_offset =
176181 reduction_offsets_.get_second_offset ();
177182
178- using dpctl::tensor::type_utils ::convert_impl;
183+ using tu_ns ::convert_impl;
179184 outT val = convert_impl<outT, lhsT>(
180185 lhs_[lhs_batch_offset + lhs_reduction_offset]) *
181186 convert_impl<outT, rhsT>(
@@ -273,7 +278,7 @@ struct DotProductCustomFunctor
273278 const auto &rhs_reduction_offset =
274279 reduction_offsets_.get_second_offset ();
275280
276- using dpctl::tensor::type_utils ::convert_impl;
281+ using tu_ns ::convert_impl;
277282 outT val = convert_impl<outT, lhsT>(
278283 lhs_[lhs_batch_offset + lhs_reduction_offset]) *
279284 convert_impl<outT, rhsT>(
@@ -718,13 +723,26 @@ struct DotProductNoAtomicFunctor
718723 const auto &rhs_reduction_offset =
719724 reduction_offsets_.get_second_offset ();
720725
721- using dpctl::tensor::type_utils::convert_impl;
722- outT val = convert_impl<outT, lhsT>(
723- lhs_[lhs_batch_offset + lhs_reduction_offset]) *
724- convert_impl<outT, rhsT>(
725- rhs_[rhs_batch_offset + rhs_reduction_offset]);
726-
727- local_red_val += val;
726+ using tu_ns::convert_impl;
727+ using tu_ns::is_complex_v;
728+ if constexpr (is_complex_v<outT>) {
729+ using realT = typename outT::value_type;
730+ using sycl_complexT = exprm_ns::complex <realT>;
731+
732+ sycl_complexT val =
733+ sycl_complexT (convert_impl<outT, lhsT>(
734+ lhs_[lhs_batch_offset + lhs_reduction_offset])) *
735+ sycl_complexT (convert_impl<outT, rhsT>(
736+ rhs_[rhs_batch_offset + rhs_reduction_offset]));
737+ local_red_val = outT (sycl_complexT (local_red_val) + val);
738+ }
739+ else {
740+ outT val = convert_impl<outT, lhsT>(
741+ lhs_[lhs_batch_offset + lhs_reduction_offset]) *
742+ convert_impl<outT, rhsT>(
743+ rhs_[rhs_batch_offset + rhs_reduction_offset]);
744+ local_red_val += val;
745+ }
728746 }
729747
730748 auto work_group = it.get_group ();
@@ -819,13 +837,26 @@ struct DotProductNoAtomicCustomFunctor
819837 const auto &rhs_reduction_offset =
820838 reduction_offsets_.get_second_offset ();
821839
822- using dpctl::tensor::type_utils::convert_impl;
823- outT val = convert_impl<outT, lhsT>(
824- lhs_[lhs_batch_offset + lhs_reduction_offset]) *
825- convert_impl<outT, rhsT>(
826- rhs_[rhs_batch_offset + rhs_reduction_offset]);
827-
828- local_red_val += val;
840+ using tu_ns::convert_impl;
841+ using tu_ns::is_complex_v;
842+ if constexpr (is_complex_v<outT>) {
843+ using realT = typename outT::value_type;
844+ using sycl_complexT = exprm_ns::complex <realT>;
845+
846+ sycl_complexT val =
847+ sycl_complexT (convert_impl<outT, lhsT>(
848+ lhs_[lhs_batch_offset + lhs_reduction_offset])) *
849+ sycl_complexT (convert_impl<outT, rhsT>(
850+ rhs_[rhs_batch_offset + rhs_reduction_offset]));
851+ local_red_val = outT (sycl_complexT (local_red_val) + val);
852+ }
853+ else {
854+ outT val = convert_impl<outT, lhsT>(
855+ lhs_[lhs_batch_offset + lhs_reduction_offset]) *
856+ convert_impl<outT, rhsT>(
857+ rhs_[rhs_batch_offset + rhs_reduction_offset]);
858+ local_red_val += val;
859+ }
829860 }
830861
831862 auto work_group = it.get_group ();
0 commit comments