Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/ATen/native/xpu/sycl/Reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -1151,8 +1151,24 @@ inline void gpu_reduce_kernel(

using traits = function_traits<decltype(&ops_t::reduce)>;
using arg_t = typename traits::template arg<0>::type;

// at::Half/at::ComplexHalf overflow easily as their range is very small.
// So when scalar_t and out_scalar_t are at::Half/at::ComplexHalf, we
// set can_accumulate_in_output to False.
static constexpr bool is_inp_out_type_half_or_chalf =
(std::is_same_v<at::Half, scalar_t> &&
std::is_same_v<at::Half, out_scalar_t>) ||
(std::is_same_v<c10::complex<Half>, scalar_t> &&
std::is_same_v<c10::complex<Half>, out_scalar_t>);
// at::BFloat16 has lower precision and can lead to rounding errors.
// So when scalar_t and out_scalar_t are at::BFloat16, we
// set can_accumulate_in_output to False.
static constexpr bool is_inp_out_type_bfloat16 =
(std::is_same_v<at::BFloat16, scalar_t> &&
std::is_same_v<at::BFloat16, out_scalar_t>);
static constexpr bool can_accumulate_in_output =
std::is_convertible<arg_t, out_scalar_t>::value;
std::is_convertible_v<arg_t, out_scalar_t> &&
!(is_inp_out_type_half_or_chalf || is_inp_out_type_bfloat16);

bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
std::unique_ptr<AccumulationBuffer> owned_buf_ptr;
Expand Down
19 changes: 10 additions & 9 deletions src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,6 @@ struct SumFunctor {
}
};

template <>
struct SumFunctor<c10::complex<at::Half>> {
using scalar_t = c10::complex<at::Half>;
using acc_t = at::opmath_type<scalar_t>;
inline acc_t operator()(acc_t a, acc_t b) const {
return a + b;
}
};

template <
typename scalar_t,
typename acc_t = scalar_t,
Expand All @@ -68,6 +59,16 @@ struct sum_functor {
}
};

template <>
struct sum_functor<c10::complex<at::Half>> {
void operator()(TensorIterator& iter) {
using scalar_t = c10::complex<at::Half>;
using acc_t = at::opmath_type<scalar_t>;
gpu_reduce_kernel<scalar_t, scalar_t>(
iter, func_wrapper<scalar_t>(SumFunctor<acc_t>()));
}
};

void sum_kernel(TensorIterator& iter) {
auto general_dispatcher = [](TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
Expand Down