@@ -97,11 +97,23 @@ struct SequentialDotProduct
9797 auto lhs_reduction_offset = reduction_offsets.get_first_offset ();
9898 auto rhs_reduction_offset = reduction_offsets.get_second_offset ();
9999
100- using tu_ns::convert_impl;
101- red_val += convert_impl<outT, lhsT>(
102- lhs_[lhs_batch_offset + lhs_reduction_offset]) *
103- convert_impl<outT, rhsT>(
104- rhs_[rhs_batch_offset + rhs_reduction_offset]);
100+ if constexpr (tu_ns::is_complex_v<outT>) {
101+ using realT = typename outT::value_type;
102+ using sycl_complex = exprm_ns::complex <realT>;
103+
104+ auto tmp = sycl_complex (red_val);
105+ tmp += sycl_complex (tu_ns::convert_impl<outT, lhsT>(
106+ lhs_[lhs_batch_offset + lhs_reduction_offset])) *
107+ sycl_complex (tu_ns::convert_impl<outT, rhsT>(
108+ rhs_[rhs_batch_offset + rhs_reduction_offset]));
109+ red_val = outT (tmp);
110+ }
111+ else {
112+ red_val += tu_ns::convert_impl<outT, lhsT>(
113+ lhs_[lhs_batch_offset + lhs_reduction_offset]) *
114+ tu_ns::convert_impl<outT, rhsT>(
115+ rhs_[rhs_batch_offset + rhs_reduction_offset]);
116+ }
105117 }
106118
107119 out_[out_batch_offset] = red_val;
@@ -180,10 +192,9 @@ struct DotProductFunctor
180192 const auto &rhs_reduction_offset =
181193 reduction_offsets_.get_second_offset ();
182194
183- using tu_ns::convert_impl;
184- outT val = convert_impl<outT, lhsT>(
195+ outT val = tu_ns::convert_impl<outT, lhsT>(
185196 lhs_[lhs_batch_offset + lhs_reduction_offset]) *
186- convert_impl<outT, rhsT>(
197+ tu_ns:: convert_impl<outT, rhsT>(
187198 rhs_[rhs_batch_offset + rhs_reduction_offset]);
188199
189200 local_red_val += val;
@@ -278,10 +289,9 @@ struct DotProductCustomFunctor
278289 const auto &rhs_reduction_offset =
279290 reduction_offsets_.get_second_offset ();
280291
281- using tu_ns::convert_impl;
282- outT val = convert_impl<outT, lhsT>(
292+ outT val = tu_ns::convert_impl<outT, lhsT>(
283293 lhs_[lhs_batch_offset + lhs_reduction_offset]) *
284- convert_impl<outT, rhsT>(
294+ tu_ns:: convert_impl<outT, rhsT>(
285295 rhs_[rhs_batch_offset + rhs_reduction_offset]);
286296
287297 local_red_val += val;
@@ -723,23 +733,21 @@ struct DotProductNoAtomicFunctor
723733 const auto &rhs_reduction_offset =
724734 reduction_offsets_.get_second_offset ();
725735
726- using tu_ns::convert_impl;
727- using tu_ns::is_complex_v;
728- if constexpr (is_complex_v<outT>) {
736+ if constexpr (tu_ns::is_complex_v<outT>) {
729737 using realT = typename outT::value_type;
730738 using sycl_complexT = exprm_ns::complex <realT>;
731739
732740 sycl_complexT val =
733- sycl_complexT (convert_impl<outT, lhsT>(
741+ sycl_complexT (tu_ns:: convert_impl<outT, lhsT>(
734742 lhs_[lhs_batch_offset + lhs_reduction_offset])) *
735- sycl_complexT (convert_impl<outT, rhsT>(
743+ sycl_complexT (tu_ns:: convert_impl<outT, rhsT>(
736744 rhs_[rhs_batch_offset + rhs_reduction_offset]));
737745 local_red_val = outT (sycl_complexT (local_red_val) + val);
738746 }
739747 else {
740- outT val = convert_impl<outT, lhsT>(
748+ outT val = tu_ns:: convert_impl<outT, lhsT>(
741749 lhs_[lhs_batch_offset + lhs_reduction_offset]) *
742- convert_impl<outT, rhsT>(
750+ tu_ns:: convert_impl<outT, rhsT>(
743751 rhs_[rhs_batch_offset + rhs_reduction_offset]);
744752 local_red_val += val;
745753 }
@@ -837,23 +845,21 @@ struct DotProductNoAtomicCustomFunctor
837845 const auto &rhs_reduction_offset =
838846 reduction_offsets_.get_second_offset ();
839847
840- using tu_ns::convert_impl;
841- using tu_ns::is_complex_v;
842- if constexpr (is_complex_v<outT>) {
848+ if constexpr (tu_ns::is_complex_v<outT>) {
843849 using realT = typename outT::value_type;
844850 using sycl_complexT = exprm_ns::complex <realT>;
845851
846852 sycl_complexT val =
847- sycl_complexT (convert_impl<outT, lhsT>(
853+ sycl_complexT (tu_ns:: convert_impl<outT, lhsT>(
848854 lhs_[lhs_batch_offset + lhs_reduction_offset])) *
849- sycl_complexT (convert_impl<outT, rhsT>(
855+ sycl_complexT (tu_ns:: convert_impl<outT, rhsT>(
850856 rhs_[rhs_batch_offset + rhs_reduction_offset]));
851857 local_red_val = outT (sycl_complexT (local_red_val) + val);
852858 }
853859 else {
854- outT val = convert_impl<outT, lhsT>(
860+ outT val = tu_ns:: convert_impl<outT, lhsT>(
855861 lhs_[lhs_batch_offset + lhs_reduction_offset]) *
856- convert_impl<outT, rhsT>(
862+ tu_ns:: convert_impl<outT, rhsT>(
857863 rhs_[rhs_batch_offset + rhs_reduction_offset]);
858864 local_red_val += val;
859865 }
0 commit comments