From 9047f2886f0e2c2d764f23fef89d5e5bfceff4f1 Mon Sep 17 00:00:00 2001 From: yucai-intel <108388355+yucai-intel@users.noreply.github.com> Date: Wed, 5 Nov 2025 13:19:58 +0800 Subject: [PATCH] Update Atomics.h --- src/ATen/native/xpu/sycl/Atomics.h | 418 ++++++++++++++++++++++++++++- 1 file changed, 413 insertions(+), 5 deletions(-) diff --git a/src/ATen/native/xpu/sycl/Atomics.h b/src/ATen/native/xpu/sycl/Atomics.h index d6cc1fe775..b4082a1459 100644 --- a/src/ATen/native/xpu/sycl/Atomics.h +++ b/src/ATen/native/xpu/sycl/Atomics.h @@ -4,6 +4,7 @@ #include #include #include +#include namespace at::native::xpu { @@ -108,7 +109,7 @@ struct AtomicIntegerImplLocal { }; #define SYCL_ATOMIC_INTEGER_LOCAL(NAME, OP, DTYPE) \ - static inline void atomic##NAME( \ + static inline void atomic##NAME##Local( \ const sycl_local_ptr& address, DTYPE val) { \ AtomicIntegerImplLocal()( \ address, val, [](DTYPE a, DTYPE b) { return OP; }); \ @@ -287,6 +288,91 @@ struct AtomicFPImpl { AtomicFPImpl()(address, val, [](DTYPE a, DTYPE b) { return OP; }); \ } +template +struct AtomicFPImplLocal; + +template <> +struct AtomicFPImplLocal { + template + inline void operator()(at::Half* address, at::Half val, const func_t& func) { + unsigned int* address_as_ui = + (unsigned int*)((char*)address - ((size_t)address & 2)); + unsigned int assumed = *address_as_ui; + unsigned int newval; + sycl_atomic_ref_rlx_wg_local_t target(*address_as_ui); + + do { + newval = assumed; + at::Half hsum; + hsum.x = (size_t)address & 2 ? (newval >> 16) : (newval & 0xffff); + hsum = func(hsum, val); + newval = (size_t)address & 2 ? (newval & 0xffff) | (hsum.x << 16) + : (newval & 0xffff0000) | hsum.x; + } while (!target.compare_exchange_strong(assumed, newval)); + } +}; + +template <> +struct AtomicFPImplLocal { + template + inline void operator()( + at::BFloat16* address, + at::BFloat16 val, + const func_t& func) { + unsigned int* address_as_ui = + (unsigned int*)((char*)address - ((size_t)address & 2)); + unsigned int assumed = *address_as_ui; + unsigned int newval; + sycl_atomic_ref_rlx_wg_local_t target(*address_as_ui); + + do { + newval = assumed; + at::BFloat16 bsum; + bsum.x = (size_t)address & 2 ? (newval >> 16) : (newval & 0xffff); + bsum = func(bsum, val); + newval = (size_t)address & 2 ? (newval & 0xffff) | (bsum.x << 16) + : (newval & 0xffff0000) | bsum.x; + } while (!target.compare_exchange_strong(assumed, newval)); + } +}; + +template <> +struct AtomicFPImplLocal { + template + inline void operator()(float* address, float val, const func_t& func) { + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int assumed = *address_as_ui; + unsigned int newval; + sycl_atomic_ref_rlx_wg_local_t target(*address_as_ui); + + do { + newval = __float_as_int(func(val, __int_as_float(assumed))); + } while (!target.compare_exchange_strong(assumed, newval)); + } +}; + +template <> +struct AtomicFPImplLocal { + template + inline void operator()(double* address, double val, const func_t& func) { + unsigned long long* address_as_ull = (unsigned long long*)address; + unsigned long long assumed = *address_as_ull; + unsigned long long newval; + sycl_atomic_ref_rlx_wg_local_t target(*address_as_ull); + + do { + newval = __double_as_long_long(func(val, __long_long_as_double(assumed))); + } while (!target.compare_exchange_strong(assumed, newval)); + } +}; + +#define SYCL_ATOMIC_FP_LOCAL(NAME, OP, DTYPE) \ + static inline void atomic##NAME##Local( \ + const sycl_local_ptr& address, DTYPE val) { \ + AtomicFPImplLocal()( \ + address, val, [](DTYPE a, DTYPE b) { return OP; }); \ + } + static inline void atomicAdd(const sycl_global_ptr& address, float val) { sycl_atomic_ref_rlx_dev_global_t target(*address); target.fetch_add(val); @@ -337,6 +423,55 @@ static inline void atomicAdd( target.fetch_add(val); } +static inline void atomicAddLocal( + const sycl_local_ptr& address, + float val) { + sycl_atomic_ref_rlx_wg_local_t target(*address); + target.fetch_add(val); +} + +static inline void atomicAddLocal( + const sycl_local_ptr& address, + double val) { + sycl_atomic_ref_rlx_wg_local_t target(*address); + target.fetch_add(val); +} + +static inline void atomicAddLocal(const sycl_local_ptr& address, int val) { + sycl_atomic_ref_rlx_wg_local_t target(*address); + target.fetch_add(val); +} + +static inline void atomicAddLocal( + const sycl_local_ptr& address, + int64_t val) { + sycl_atomic_ref_rlx_wg_local_t target(*address); + target.fetch_add(val); +} + +static inline void atomicAddLocal( + const sycl_local_ptr& address, + uint32_t val) { + sycl_atomic_ref_rlx_wg_local_t target(*address); + target.fetch_add(val); +} + +static inline void atomicAddLocal( + const sycl_local_ptr& address, + uint64_t val) { + sycl_atomic_ref_rlx_wg_local_t target(*address); + target.fetch_add(val); +} + +// Atomic add local implementation. +SYCL_ATOMIC_INTEGER_LOCAL(Add, a || b, bool) +SYCL_ATOMIC_INTEGER_LOCAL(Add, std::plus()(a, b), uint8_t) +SYCL_ATOMIC_INTEGER_LOCAL(Add, std::plus()(a, b), int8_t) +SYCL_ATOMIC_INTEGER_LOCAL(Add, std::plus()(a, b), int16_t) + +SYCL_ATOMIC_FP_LOCAL(Add, std::plus()(a, b), at::Half) +SYCL_ATOMIC_FP_LOCAL(Add, std::plus()(a, b), at::BFloat16) + // Atomic add implementation. SYCL_ATOMIC_INTEGER(Add, a || b, bool) SYCL_ATOMIC_INTEGER(Add, std::plus()(a, b), uint8_t) @@ -354,6 +489,14 @@ static inline void atomicAdd( atomicAdd(&address->imag_, val.imag_); } +template +static inline void atomicAddLocal( + const sycl_local_ptr>& address, + c10::complex val) { + atomicAddLocal(&address->real_, val.real_); + atomicAddLocal(&address->imag_, val.imag_); +} + // Atomic multiplication implementation. SYCL_ATOMIC_INTEGER(Mul, std::multiplies()(a, b), uint8_t) SYCL_ATOMIC_INTEGER(Mul, std::multiplies()(a, b), int8_t) @@ -384,10 +527,6 @@ static inline void atomicMax( target.fetch_add(val); } -SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max(a, b), uint8_t) -SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max(a, b), int8_t) -SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max(a, b), int16_t) - SYCL_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t) SYCL_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t) SYCL_ATOMIC_INTEGER(Max, safe_max(a, b), int16_t) @@ -415,4 +554,273 @@ SYCL_ATOMIC_FP(Min, safe_min(a, b), double) SYCL_ATOMIC_FP(Min, safe_min(a, b), at::Half) SYCL_ATOMIC_FP(Min, safe_min(a, b), at::BFloat16) +// ========================================================================= +// ------------------------------AtomicCAS---------------------------------- +// ========================================================================= + +// --- Auxiliary Type Definition --- +// R is a template template parameter for the SYCL atomic ref type +template class R> +using AtomicRef = R; + +// --- Generic Integer CAS Structure Definition (R is the Atomic Ref type) --- +template class R> +struct AtomicCASInteger; + +// n=1 (1-byte Soft-RMW) +template class R> +struct AtomicCASInteger { + inline T operator()(T* address, T expected, T desired) { + size_t offset = (size_t)address & 3; + uint32_t* address_as_ui = (uint32_t*)((char*)address - offset); + size_t shift = offset * 8; + uint32_t assumed; + uint32_t newval; + AtomicRef target(*address_as_ui); + + T extracted_old_value; + do { + assumed = *address_as_ui; + uint32_t byte_in_mem = (assumed >> shift) & 0xff; + extracted_old_value = static_cast(byte_in_mem); + + if (extracted_old_value == expected) { + uint32_t desired_byte = static_cast(desired); + newval = (assumed & ~(0x000000ff << shift)) | (desired_byte << shift); + } else { + break; + } + } while (!target.compare_exchange_strong(assumed, newval)); + + if (extracted_old_value == expected) { + return expected; + } else { + return extracted_old_value; + } + } +}; + +// n=2 (2-byte Soft-RMW) +template class R> +struct AtomicCASInteger { + inline T operator()(T* address, T expected, T desired) { + size_t offset = (size_t)address & 2; + uint32_t* address_as_ui = (uint32_t*)((char*)address - offset); + bool is_upper_half = offset; + uint32_t assumed; + uint32_t newval; + uint32_t current_half_word; + + AtomicRef target(*address_as_ui); + + T extracted_old_value; + do { + assumed = *address_as_ui; + current_half_word = is_upper_half ? (assumed >> 16) : (assumed & 0xffff); + extracted_old_value = static_cast(current_half_word); + + if (extracted_old_value == expected) { + uint32_t desired_half_word = static_cast(desired); + newval = is_upper_half ? (assumed & 0xffff) | (desired_half_word << 16) + : (assumed & 0xffff0000) | desired_half_word; + } else { + break; + } + } while (!target.compare_exchange_strong(assumed, newval)); + + if (extracted_old_value == expected) { + return expected; + } else { + return extracted_old_value; + } + } +}; + +// n=4 (4-byte Native CAS) +template class R> +struct AtomicCASInteger { + inline T operator()(T* address, T expected, T desired) { + uint32_t* address_as_ui = (uint32_t*)(address); + uint32_t assumed; + uint32_t newval; + + AtomicRef target(*address_as_ui); + + uint32_t expected_ui = static_cast(expected); + newval = static_cast(desired); + + do { + assumed = *address_as_ui; + if (assumed != expected_ui) { + break; + } + } while (!target.compare_exchange_strong(assumed, newval)); + + if (assumed == expected_ui) { + return expected; + } else { + return static_cast(assumed); + } + } +}; + +// n=8 (8-byte Native CAS) +template class R> +struct AtomicCASInteger { + inline T operator()(T* address, T expected, T desired) { + unsigned long long* address_as_ull = (unsigned long long*)(address); + unsigned long long assumed; + unsigned long long newval; + + AtomicRef target(*address_as_ull); + + unsigned long long expected_ull = static_cast(expected); + newval = static_cast(desired); + + do { + assumed = *address_as_ull; + if (assumed != expected_ull) { + break; + } + } while (!target.compare_exchange_strong(assumed, newval)); + + if (assumed == expected_ull) { + return expected; + } else { + return static_cast(assumed); + } + } +}; + +// --- Generic Macro Definitions for Function Signatures --- +#define SYCL_ATOMIC_CAS_IMPL(DTYPE, STRUCT_NAME, PTR_TYPE, ATOMIC_REF) \ + static inline DTYPE atomicCAS( \ + const PTR_TYPE& address, DTYPE expected, DTYPE desired) { \ + /* Call generic struct with specific SYCL atomic ref type */ \ + return STRUCT_NAME()( \ + address, expected, desired); \ + } + +#define SYCL_ATOMIC_CAS_ALL(DTYPE, STRUCT_NAME) \ + /* local CAS version */ \ + SYCL_ATOMIC_CAS_IMPL( \ + DTYPE, STRUCT_NAME, sycl_local_ptr, sycl_atomic_ref_rlx_wg_local_t) + +SYCL_ATOMIC_CAS_ALL(int, AtomicCASInteger) +SYCL_ATOMIC_CAS_ALL(int64_t, AtomicCASInteger) +SYCL_ATOMIC_CAS_ALL(uint32_t, AtomicCASInteger) +SYCL_ATOMIC_CAS_ALL(uint64_t, AtomicCASInteger) +SYCL_ATOMIC_CAS_ALL(int8_t, AtomicCASInteger) +SYCL_ATOMIC_CAS_ALL(uint8_t, AtomicCASInteger) + +// --- Generic Floating Point CAS Structure Definition (R is the Atomic Ref +// type) --- +template class R> +struct AtomicCASFP; + +// n=2 (at::Half/at::BFloat16 Soft-RMW) +template class R> +struct AtomicCASFP { + inline T operator()(T* address, T expected, T desired) { + size_t offset = (size_t)address & 2; + unsigned int* address_as_ui = (unsigned int*)((char*)address - offset); + bool is_upper_half = offset; + + unsigned int assumed; + unsigned int newval; + + // 🌟 Using generic AtomicRef + AtomicRef target(*address_as_ui); + + unsigned int expected_half_word = expected.x; + unsigned int desired_half_word = desired.x; + + unsigned int current_half_word; + T extracted_old_value; + + do { + assumed = *address_as_ui; + current_half_word = is_upper_half ? (assumed >> 16) : (assumed & 0xffff); + + extracted_old_value.x = (uint16_t)current_half_word; + + if (extracted_old_value.x == expected_half_word) { + newval = is_upper_half ? (assumed & 0xffff) | (desired_half_word << 16) + : (assumed & 0xffff0000) | desired_half_word; + } else { + break; + } + } while (!target.compare_exchange_strong(assumed, newval)); + + if (extracted_old_value.x == expected_half_word) { + return expected; + } else { + return extracted_old_value; + } + } +}; + +// n=4 (4-byte float Native CAS) +template class R> +struct AtomicCASFP { + inline T operator()(T* address, T expected, T desired) { + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int assumed; + unsigned int newval; + + // 🌟 Using generic AtomicRef + AtomicRef target(*address_as_ui); + + unsigned int expected_ui = *((unsigned int*)&expected); + newval = *((unsigned int*)&desired); + + do { + assumed = *address_as_ui; + if (assumed != expected_ui) { + break; + } + } while (!target.compare_exchange_strong(assumed, newval)); + + if (assumed == expected_ui) { + return expected; + } else { + return *((T*)&assumed); + } + } +}; + +// n=8 (8-byte double Native CAS) +template class R> +struct AtomicCASFP { + inline T operator()(T* address, T expected, T desired) { + unsigned long long* address_as_ull = (unsigned long long*)address; + unsigned long long assumed; + unsigned long long newval; + + // 🌟 Using generic AtomicRef + AtomicRef target(*address_as_ull); + + unsigned long long expected_ull = *((unsigned long long*)&expected); + newval = *((unsigned long long*)&desired); + + do { + assumed = *address_as_ull; + if (assumed != expected_ull) { + break; + } + } while (!target.compare_exchange_strong(assumed, newval)); + + if (assumed == expected_ull) { + return expected; + } else { + return *((T*)&assumed); + } + } +}; + +SYCL_ATOMIC_CAS_ALL(float, AtomicCASFP) +SYCL_ATOMIC_CAS_ALL(double, AtomicCASFP) +SYCL_ATOMIC_CAS_ALL(at::Half, AtomicCASFP) +SYCL_ATOMIC_CAS_ALL(at::BFloat16, AtomicCASFP) + } // namespace at::native::xpu