diff --git a/src/ATen/native/xpu/sycl/Reduce.h b/src/ATen/native/xpu/sycl/Reduce.h
index 0c713ee51..4639ea9e8 100644
--- a/src/ATen/native/xpu/sycl/Reduce.h
+++ b/src/ATen/native/xpu/sycl/Reduce.h
@@ -1151,8 +1151,24 @@ inline void gpu_reduce_kernel(
   using traits = function_traits<decltype(&ops_t::reduce)>;
   using arg_t = typename traits::template arg<0>::type;
+
+  // at::Half/at::ComplexHalf overflows easily as it's range is very small.
+  // So when scalar_t and out_scalar_t are at::Half/at::ComplexHalf, we
+  // set can_accumulate_in_output to False.
+  static constexpr bool is_inp_out_type_half_or_chalf =
+      (std::is_same_v<at::Half, scalar_t> &&
+       std::is_same_v<at::Half, out_scalar_t>) ||
+      (std::is_same_v<c10::complex<at::Half>, scalar_t> &&
+       std::is_same_v<c10::complex<at::Half>, out_scalar_t>);
+  // at::BFloat16 has lower precision and can lead to rounding errors.
+  // So when scalar_t and out_scalar_t are at::BFloat16, we
+  // set can_accumulate_in_output to False.
+  static constexpr bool is_inp_out_type_bfloat16 =
+      (std::is_same_v<at::BFloat16, scalar_t> &&
+       std::is_same_v<at::BFloat16, out_scalar_t>);
   static constexpr bool can_accumulate_in_output =
-      std::is_convertible<arg_t, out_scalar_t>::value;
+      std::is_convertible_v<arg_t, out_scalar_t> &&
+      !(is_inp_out_type_half_or_chalf || is_inp_out_type_bfloat16);
 
   bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
   std::unique_ptr<AccumulationBuffer> owned_buf_ptr;
 
diff --git a/src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp b/src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp
index 7bdc3a188..7b62371c4 100644
--- a/src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp
+++ b/src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp
@@ -48,15 +48,6 @@ struct SumFunctor {
   }
 };
 
-template <>
-struct SumFunctor<c10::complex<at::Half>> {
-  using scalar_t = c10::complex<at::Half>;
-  using acc_t = at::opmath_type<scalar_t>;
-  inline acc_t operator()(acc_t a, acc_t b) const {
-    return a + b;
-  }
-};
-
 template <
     typename scalar_t,
     typename acc_t = scalar_t,
@@ -68,6 +59,16 @@ struct sum_functor {
   void operator()(TensorIterator& iter) {
     gpu_reduce_kernel<scalar_t, out_t>(
         iter, func_wrapper<out_t>(SumFunctor<acc_t>()));
   }
 };
 
+template <>
+struct sum_functor<c10::complex<at::Half>> {
+  void operator()(TensorIterator& iter) {
+    using scalar_t = c10::complex<at::Half>;
+    using acc_t = at::opmath_type<scalar_t>;
+    gpu_reduce_kernel<scalar_t, scalar_t>(
+        iter, func_wrapper<scalar_t>(SumFunctor<acc_t>()));
+  }
+};
+
 void sum_kernel(TensorIterator& iter) {
   auto general_dispatcher = [](TensorIterator& iter) {
     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(