From d1d048070fae11a9a2e0ee87ad9eb0622bef2879 Mon Sep 17 00:00:00 2001 From: yucai-intel <108388355+yucai-intel@users.noreply.github.com> Date: Wed, 5 Nov 2025 13:26:33 +0800 Subject: [PATCH 1/2] Update Indexing.cpp --- src/ATen/native/xpu/sycl/Indexing.cpp | 1099 +++++++++++++++---------- 1 file changed, 677 insertions(+), 422 deletions(-) diff --git a/src/ATen/native/xpu/sycl/Indexing.cpp b/src/ATen/native/xpu/sycl/Indexing.cpp index 6c3a278d2..942247de6 100644 --- a/src/ATen/native/xpu/sycl/Indexing.cpp +++ b/src/ATen/native/xpu/sycl/Indexing.cpp @@ -19,11 +19,13 @@ #include #include #include +#include #include #include #include #include +#include #include @@ -43,10 +45,14 @@ void index_kernel( TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride) { - AT_DISPATCH_V2( + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( + at::ScalarType::ComplexHalf, + at::ScalarType::BFloat16, + at::ScalarType::Half, + at::ScalarType::Bool, iter.dtype(), "index_xpu", - AT_WRAP([&] { + [&] { using dtype = OpaqueType; IndexFunctor f; _index_kernel( @@ -57,178 +63,7 @@ void index_kernel( IntArrayRef{}, f, true); - }), - AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), - AT_EXPAND(AT_FLOAT8_TYPES), - kComplexHalf, - kHalf, - kBool, - kBFloat16); -} - -template -class IndexSelectScalarFunctor { - public: - void operator()( - ValType* dst, - const ValType* src, - int64_t dst_off, - int64_t src_off, - int64_t idx, - ValType alpha) const { - dst[dst_off] = src[src_off]; - } -}; - -template < - class SrcInfo, - class DstInfo, - class IdxInfo, - bool TrivialOffCal = false> -static inline void _index_select_kernel( - SrcInfo& src_info, - DstInfo& dst_info, - IdxInfo& index_info, - int64_t dim) { - using scalar_t = typename DstInfo::scalar_t; - using IdxConfig = IndexKernelConfig< - SrcInfo, - DstInfo, - IdxInfo, - IndexSelectScalarFunctor>; - - using IndexKnownProblemInnerKernel = - IndexKernel; - auto IndexKnownProblemInnerKernel_cfg = - IdxConfig::template make_config( - src_info, - dst_info, - index_info, - static_cast(0), - dim, - false, - IndexSelectScalarFunctor()); - - using IndexUnknownProblemInnerKernel = - IndexKernel; - auto IndexUnknownProblemInnerKernel_cfg = - IdxConfig::template make_config( - src_info, - dst_info, - index_info, - static_cast(0), - dim, - false, - IndexSelectScalarFunctor()); - - if (IndexKnownProblemInnerKernel_cfg.problem_inner_) { - launch_index_kernel( - IndexKnownProblemInnerKernel_cfg); - } else { - launch_index_kernel( - IndexUnknownProblemInnerKernel_cfg); - } -} - -void index_select_kernel( - const Tensor& src, - int64_t dim, - const Tensor& indices, - const Tensor& dst) { - at::assert_no_internal_overlap(dst); - at::assert_no_overlap(dst, src); - at::assert_no_overlap(dst, indices); - - dim = at::maybe_wrap_dim(dim, src.dim()); - int srcDims = src.dim() == 0 ? 
1 : src.dim(); - int dstDims = dst.dim(); - int idxDims = indices.dim(); - - TORCH_CHECK( - srcDims <= XPU_MAX_TENSORINFO_DIMS, - "src tensor dim should be < ", - XPU_MAX_TENSORINFO_DIMS); - TORCH_CHECK( - dstDims <= XPU_MAX_TENSORINFO_DIMS, - "dst tensor dim should be < ", - XPU_MAX_TENSORINFO_DIMS); - TORCH_CHECK( - idxDims <= XPU_MAX_TENSORINFO_DIMS, - "index tensor dim should be < ", - XPU_MAX_TENSORINFO_DIMS); - TORCH_CHECK( - idxDims <= 1, "Index is supposed to be an empty tensor or a vector"); - TORCH_CHECK( - dim >= -1 && dim < srcDims, - "Indexing dim should be >= -1 and < dims - 1"); - TORCH_CHECK(srcDims > 0, "Source tensor is empty"); - TORCH_CHECK( - indices.scalar_type() == ScalarType::Long || - indices.scalar_type() == ScalarType::Int, - "index_select(): Expected dtype int32 or int64 for index but got: ", - indices.scalar_type()); - TORCH_CHECK( - src.scalar_type() == dst.scalar_type(), - "index_select(): Source and result must have the same scalar type"); - - AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "index_select", [&] { - TensorInfo index_info = - tensorInfoIfScalar(getTensorInfo(indices)); - index_info.collapseDims(); - - auto new_size = src.sizes().vec(); - - if (src.dim() > 0) { - new_size[dim] = indices.numel(); - } - - at::native::resize_output(dst, new_size); - - ptrdiff_t dst_num_elem = dst.numel(); - if (dst_num_elem == 0) { - return; - } - - AT_DISPATCH_V2( - dst.scalar_type(), - "index_select_xpu", - AT_WRAP([&] { - TensorInfo dst_info = - tensorInfoIfScalar(getTensorInfo(dst)); - TensorInfo src_info = tensorInfoIfScalar( - getTensorInfo(src.contiguous())); - int new_indexing_dim = src_info.collapseDims(dim); - - using SrcInfo = TensorInfo; - using DstInfo = TensorInfo; - using IdxInfo = TensorInfo; - - // Improve efficiency of generated native instructions for contiguous. - // See comm/TensorInfo.h - if (dst.is_contiguous() && indices.is_contiguous()) - _index_select_kernel< - SrcInfo, - DstInfo, - IdxInfo, - /* TrivialOffCal */ true>( - src_info, dst_info, index_info, new_indexing_dim); - else - _index_select_kernel< - SrcInfo, - DstInfo, - IdxInfo, - /* TrivialOffCal */ false>( - src_info, dst_info, index_info, new_indexing_dim); - }), - AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), - AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), - AT_EXPAND(AT_FLOAT8_TYPES), - kComplexHalf, - kHalf, - kBool, - kBFloat16); - }); - return; + }); } template @@ -259,203 +94,6 @@ void masked_fill_kernel(TensorIterator& iter, const Scalar& value) { }); } -// Check tensor dimensions for index operations, and return the slice size. 
-static ptrdiff_t getSliceSize( - const Tensor& dst, - int dim, - const Tensor& index, - const Tensor& src) { - const auto dstDims = dst.dim(); - const auto srcDims = src.dim(); - - TORCH_CHECK(index.dim() <= 1, "Index must be vector or scalar"); - - ptrdiff_t dstSliceSize = 1; - TORCH_CHECK( - dim >= 0 && dim < dstDims, "Indexing dim ", dim, " is out of bounds"); - for (const auto d : c10::irange(dstDims)) { - if (d != dim) { - dstSliceSize *= dst.size(d); - } - } - - TORCH_CHECK(dim < srcDims, "Indexing dim ", dim, " is out of bounds"); - TORCH_CHECK( - index.numel() == src.size(dim), - "length of src.size[dim] is not equal to length of indices"); - - ptrdiff_t srcSliceSize = 1; - bool mismatch = false; - - if (dstDims != srcDims) - mismatch = true; - - for (const auto d : c10::irange(srcDims)) { - if (d != dim) { - srcSliceSize *= src.size(d); - if (!mismatch && dst.size(d) != src.size(d)) - mismatch = true; - } - } - - TORCH_CHECK( - dstSliceSize == srcSliceSize, - "Source/destination tensor have different slice sizes (%ld vs %ld)", - dstSliceSize, - srcSliceSize); - - if (mismatch) { - TORCH_WARN_ONCE( - "Warning: source/destination slices have same size but different " - "shape for an index operation. This behavior is deprecated.\n"); - } - - return dstSliceSize; -} - -template -bool indexShouldBeMajor( - TensorInfo& info, - int sliceDim) { - // The stride between adjacent slices (e.g., between element #0 of slice #100 - // and element #0 of slice #101). - unsigned int sliceStride = info.strides[sliceDim]; - - for (const auto i : c10::irange(info.dims)) { - if (i != sliceDim && info.sizes[i] > 1 && info.strides[i] < sliceStride) { - return true; - } - } - - return false; -} - -template -struct IndexAddScalarFunctor { - void operator()( - ValType* dst, - const ValType* src, - int64_t dst_off, - int64_t src_off, - int64_t idx, - ValType alpha) const { - atomicAdd((sycl_global_ptr)(dst + dst_off), src[src_off] * alpha); - } -}; - -template <> -struct IndexAddScalarFunctor { - void operator()( - bool* dst, - const bool* src, - int64_t dst_off, - int64_t src_off, - int64_t idx, - bool alpha) const { - atomicAdd((sycl_global_ptr)(dst + dst_off), src[src_off] && alpha); - } -}; - -void index_add_kernel( - const Tensor& self, - int64_t dim, - const Tensor& index, - const Tensor& source, - const Scalar& alpha, - const Tensor& result) { - if (!result.is_same(self)) { - result.copy_(self); - } - - auto numel = index.numel(); - if (result.dim() > 1) { - if (numel == 0 || self.numel() == 0) { - return; - } - } - - // Scalars are treated as 1-d tensor - const Tensor self_ = (result.dim() == 0) ? result.view(1) : result; - const Tensor source_ = (source.dim() == 0) ? 
source.view(1) : source; - - TORCH_CHECK( - result.dim() <= XPU_MAX_TENSORINFO_DIMS, - "tensor has too many (>", - XPU_MAX_TENSORINFO_DIMS, - ") dims"); - TORCH_CHECK( - source.dim() <= XPU_MAX_TENSORINFO_DIMS, - "tensor has too many (>", - XPU_MAX_TENSORINFO_DIMS, - ") dims"); - TORCH_CHECK( - index.dim() <= XPU_MAX_TENSORINFO_DIMS, - "tensor has too many (>", - XPU_MAX_TENSORINFO_DIMS, - ") dims"); - - if (globalContext().deterministicAlgorithms()) { - torch::List> indices; - indices.reserve(dim + 1); - for (int i = 0; i < dim; i++) { - indices.emplace_back(); - } - indices.emplace_back(index.to(at::kLong)); - result.index_put_(indices, source * alpha, true); - return; - } - - // The `source` is partitioned into two parts: - // -the size of each slice we are indexing, which is the - // total size of the tensor ignoring dimension `dim`; - // -the number of index we are choosing, which is the total size - // of the tensor `index`. - const ptrdiff_t sliceSize = getSliceSize(self_, dim, index, source_); - - if (sliceSize == 0) { - return; - } - - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( - at::ScalarType::Bool, - at::ScalarType::Half, - at::ScalarType::BFloat16, - at::ScalarType::ComplexHalf, - source_.scalar_type(), - "index_add_xpu", - [&] { - AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_xpu", [&]() { - TensorInfo index_info = - getTensorInfo(index); - index_info.collapseDims(); - - TensorInfo src_info = - getTensorInfo(source_); - - TensorInfo dst_info = - getTensorInfo(self_); - int new_indexing_dim = dst_info.collapseDims(dim); - - using IdxConfig = IndexKernelConfig< - decltype(src_info), - decltype(dst_info), - decltype(index_info), - IndexAddScalarFunctor>; - using KernelClass = IndexKernel; - - auto cfg = IdxConfig::template make_config( - src_info, - dst_info, - index_info, - alpha.to(), - new_indexing_dim, - true, - IndexAddScalarFunctor()); - launch_index_kernel(cfg); - }); - }); -} - template struct IndexFillScalarFunctor { void operator()( @@ -590,10 +228,14 @@ void index_put_kernel( false); }); } else { - AT_DISPATCH_V2( + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( + at::ScalarType::ComplexHalf, + at::ScalarType::BFloat16, + at::ScalarType::Half, + at::ScalarType::Bool, iter.dtype(), "index_put_xpu", - AT_WRAP([&] { + [&] { using dtype = OpaqueType; IndexPutFunctor f; _index_kernel( @@ -604,13 +246,7 @@ void index_put_kernel( IntArrayRef{}, f, false); - }), - AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), - AT_EXPAND(AT_FLOAT8_TYPES), - kComplexHalf, - kHalf, - kBool, - kBFloat16); + }); } } @@ -697,10 +333,14 @@ void index_put_deterministic_kernel( expandedValue.numel()); if (sliceSize > SIMD) { - AT_DISPATCH_V2( + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( + at::ScalarType::ComplexHalf, + at::ScalarType::BFloat16, + at::ScalarType::Half, + at::ScalarType::Bool, expandedValue.scalar_type(), "index_put_deterministic_kernel", - AT_WRAP([&] { + [&] { launch_index_put_deterministic_kernel( sorted_indices.mutable_data_ptr(), orig_indices.mutable_data_ptr(), @@ -711,24 +351,16 @@ void index_put_deterministic_kernel( strideBefore, nElemBefore, accumulate); - }), - AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), - // TODO: Enable AT_FLOAT8_DTYPES after accumulation behavior is - // cleared for float8 dtypes. 
- kFloat8_e4m3fn, - kFloat8_e5m2, - kFloat8_e4m3fnuz, - kFloat8_e5m2fnuz, - kComplexHalf, - kHalf, - kBool, - kBFloat16); + }); } else { - // Align acc type with CUDA - AT_DISPATCH_V2( + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( + at::ScalarType::ComplexHalf, + at::ScalarType::BFloat16, + at::ScalarType::Half, + at::ScalarType::Bool, expandedValue.scalar_type(), "index_put_deterministic_kernel", - AT_WRAP([&] { + [&] { using accscalar_t = at::opmath_type; launch_index_put_deterministic_kernel( sorted_indices.mutable_data_ptr(), @@ -740,18 +372,7 @@ void index_put_deterministic_kernel( strideBefore, nElemBefore, accumulate); - }), - AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), - // TODO: Enable AT_FLOAT8_DTYPES after accumulation behavior is - // cleared for float8 dtypes. - kFloat8_e4m3fn, - kFloat8_e5m2, - kFloat8_e4m3fnuz, - kFloat8_e5m2fnuz, - kComplexHalf, - kHalf, - kBool, - kBFloat16); + }); } if (permuted) @@ -1314,6 +935,8 @@ struct IndexFuncSmallIndexFunctor { T alpha_; }; +#define SMEM_SIZE 4096 + template < typename T, typename IndicesType, @@ -1323,11 +946,19 @@ template < int IdxDim, bool IndexIsMajor, typename func_t> -struct IndexFuncLargeIndexFunctor { - void operator()(sycl::nd_item<1> item) const { - // We stride over the output including the indexed dimension - // (totalSize), and calculate the destination index point based on that +struct IndexFuncLargeIndexFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { + [[intel::reqd_sub_group_size(SIMD)]] void operator()( + sycl::nd_item<1> item) const { auto local_range = item.get_local_range(0); + T identity = (T)0; + + for (int i = item.get_local_id(0); i < SMEM_SIZE; i += local_range) { + smem_offsets[i] = (IndexType)-1; + smem_values[i] = identity; + } + + item.barrier(sycl_local_fence); + for (IndexType linearIndex = item.get_group(0) * local_range + item.get_local_id(0); linearIndex < totalSize_; @@ -1357,9 +988,44 @@ struct IndexFuncLargeIndexFunctor { srcOffset += srcIndex * src_.strides[srcAddDim_]; T val = src_.data[srcOffset] * alpha_; - op_(dst_.data, dstOffset, dstNumel_, &val); + const int smem_idx = (dstOffset / sizeof(T)) & (SMEM_SIZE - 1); + IndexType current_offset = smem_offsets[smem_idx]; + + if (current_offset == dstOffset) { + atomicAddLocal( + static_cast>(&smem_values[smem_idx]), val); + } else if (current_offset == (IndexType)-1) { + IndexType expected = (IndexType)-1; + if (atomicCAS(&smem_offsets[smem_idx], expected, dstOffset) == + expected) { + atomicAddLocal( + static_cast>(&smem_values[smem_idx]), val); + } else { + op_(dst_.data, dstOffset, dstNumel_, &val); + } + } else { + op_(dst_.data, dstOffset, dstNumel_, &val); + } + } + + item.barrier(sycl_local_fence); + + if (item.get_local_id(0) < SMEM_SIZE) { + IndexType final_dstOffset = smem_offsets[item.get_local_id(0)]; + + if (final_dstOffset != -1) { + T final_val = smem_values[item.get_local_id(0)]; + + op_(dst_.data, final_dstOffset, dstNumel_, &final_val); + } } } + + void sycl_ker_config_convention(sycl::handler& cgh) { + smem_offsets = sycl_local_acc_t(SMEM_SIZE, cgh); + smem_values = sycl_local_acc_t(SMEM_SIZE, cgh); + } + IndexFuncLargeIndexFunctor( TensorInfo dst, TensorInfo src, @@ -1396,16 +1062,377 @@ struct IndexFuncLargeIndexFunctor { int64_t dstNumel_; func_t op_; T alpha_; + sycl_local_acc_t smem_offsets; + sycl_local_acc_t smem_values; }; -template -void index_reduce_func_xpu_template( - const Tensor& self, - int64_t dim, - const Tensor& index, - const Tensor& source, - bool include_self, - const ReductionType& reduce, +struct 
IndexReduceAddFunctor { + template + void operator()( + scalar_t* self_data_start, + int64_t index, + int64_t numel, + const scalar_t* src_data) const { + atomicAdd((sycl_global_ptr)(self_data_start + index), *src_data); + } +}; +static IndexReduceAddFunctor reduce_add; + +// Check tensor dimensions for index operations, and return the slice size. +static ptrdiff_t getSliceSize( + const Tensor& dst, + int dim, + const Tensor& index, + const Tensor& src) { + const auto dstDims = dst.dim(); + const auto srcDims = src.dim(); + + TORCH_CHECK(index.dim() <= 1, "Index must be vector or scalar"); + + ptrdiff_t dstSliceSize = 1; + TORCH_CHECK( + dim >= 0 && dim < dstDims, "Indexing dim ", dim, " is out of bounds"); + for (const auto d : c10::irange(dstDims)) { + if (d != dim) { + dstSliceSize *= dst.size(d); + } + } + + TORCH_CHECK(dim < srcDims, "Indexing dim ", dim, " is out of bounds"); + TORCH_CHECK( + index.numel() == src.size(dim), + "length of src.size[dim] is not equal to length of indices"); + + ptrdiff_t srcSliceSize = 1; + bool mismatch = false; + + if (dstDims != srcDims) + mismatch = true; + + for (const auto d : c10::irange(srcDims)) { + if (d != dim) { + srcSliceSize *= src.size(d); + if (!mismatch && dst.size(d) != src.size(d)) + mismatch = true; + } + } + + TORCH_CHECK( + dstSliceSize == srcSliceSize, + "Source/destination tensor have different slice sizes (%ld vs %ld)", + dstSliceSize, + srcSliceSize); + + if (mismatch) { + TORCH_WARN_ONCE( + "Warning: source/destination slices have same size but different " + "shape for an index operation. This behavior is deprecated.\n"); + } + + return dstSliceSize; +} + +template +bool indexShouldBeMajor( + TensorInfo& info, + int sliceDim) { + // The stride between adjacent slices (e.g., between element #0 of slice #100 + // and element #0 of slice #101). + unsigned int sliceStride = info.strides[sliceDim]; + + for (const auto i : c10::irange(info.dims)) { + if (i != sliceDim && info.sizes[i] > 1 && info.strides[i] < sliceStride) { + return true; + } + } + return false; +} + +template +void index_reduce_add_xpu_template( + const Tensor& self, + int64_t dim, + const Tensor& index, + const Tensor& source, + const Scalar& alpha, + const Tensor& result, + const func_t& func) { + if (!result.is_same(self)) { + result.copy_(self); + } + + auto numel = index.numel(); + if (result.dim() > 1) { + if (numel == 0 || self.numel() == 0) { + return; + } + } + + // Scalars are treated as 1-d tensor + const Tensor self_ = (result.dim() == 0) ? result.view(1) : result; + const Tensor source_ = (source.dim() == 0) ? source.view(1) : source; + + TORCH_CHECK( + result.dim() <= XPU_MAX_TENSORINFO_DIMS, + "tensor has too many (>", + XPU_MAX_TENSORINFO_DIMS, + ") dims"); + TORCH_CHECK( + source.dim() <= XPU_MAX_TENSORINFO_DIMS, + "tensor has too many (>", + XPU_MAX_TENSORINFO_DIMS, + ") dims"); + TORCH_CHECK( + index.dim() <= XPU_MAX_TENSORINFO_DIMS, + "tensor has too many (>", + XPU_MAX_TENSORINFO_DIMS, + ") dims"); + + if (globalContext().deterministicAlgorithms()) { + torch::List> indices; + indices.reserve(dim + 1); + for (int i = 0; i < dim; i++) { + indices.emplace_back(); + } + indices.emplace_back(index.to(at::kLong)); + result.index_put_(indices, source * alpha, true); + return; + } + + // The `source` is partitioned into two parts: + // -the size of each slice we are indexing, which is the + // total size of the tensor ignoring dimension `dim`; + // -the number of index we are choosing, which is the total size + // of the tensor `index`. 
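+  // Illustrative example (shapes made up, not part of the patch): with self_
+  // of shape (4, 5, 6), source_ of shape (4, 3, 6) and dim == 1, each slice
+  // holds 4 * 6 == 24 elements, and `index` must contain
+  // source_.size(1) == 3 entries, each a valid position in [0, self_.size(1)).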
+ const uint64_t sliceSize = getSliceSize(self_, dim, index, source_); + const uint64_t sourceTotalSize = source.numel(); + const uint64_t selfAddDimSize = self_.size(dim); + const uint64_t numIndex = index.numel(); + const uint64_t selfNumel = self_.numel(); + + if (sliceSize == 0) { + return; + } + + const bool indContig = index.is_contiguous(); + int ssc = syclMaxDSSNum(); + +#define SMALL_INDEX( \ + TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM, FUNC_T) \ + IndexFuncSmallIndexFunctor< \ + TENSOR_TYPE, \ + INDICES_TYPE, \ + TYPE, \ + SELF_DIM, \ + SOURCE_DIM, \ + IDX_DIM, \ + FUNC_T>( \ + selfInfo, \ + sourceInfo, \ + indexInfo, \ + selfAddDim, \ + sourceAddDim, \ + sliceSize, \ + selfAddDimSize, \ + selfNumel, \ + reduce_add, \ + alpha_value); + +#define LARGE_INDEX( \ + TENSOR_TYPE, \ + INDICES_TYPE, \ + TYPE, \ + SELF_DIM, \ + SOURCE_DIM, \ + IDX_DIM, \ + IDX_IS_MAJOR, \ + FUNC_T) \ + IndexFuncLargeIndexFunctor< \ + TENSOR_TYPE, \ + INDICES_TYPE, \ + TYPE, \ + SELF_DIM, \ + SOURCE_DIM, \ + IDX_DIM, \ + IDX_IS_MAJOR, \ + FUNC_T>( \ + selfInfo, \ + sourceInfo, \ + indexInfo, \ + selfAddDim, \ + sourceAddDim, \ + sourceTotalSize, \ + (IDX_IS_MAJOR) ? sliceSize : numIndex, \ + selfAddDimSize, \ + selfNumel, \ + reduce_add, \ + alpha_value); + + if (canUse32BitIndexMath(result) && canUse32BitIndexMath(source) && + canUse32BitIndexMath(index)) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( + at::ScalarType::Bool, + at::ScalarType::Half, + at::ScalarType::BFloat16, + at::ScalarType::ComplexHalf, + result.scalar_type(), + "index_add", + [&] { + TensorInfo selfInfo = + getTensorInfo(self_); + const int selfAddDim = selfInfo.collapseDims(dim); + selfInfo.reduceDim(selfAddDim); + const auto alpha_value = alpha.to(); + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_xpu_", [&]() { + auto sourceInfo = + getTensorInfo(source_); + const int sourceAddDim = sourceInfo.collapseDims(dim); + sourceInfo.reduceDim(sourceAddDim); + + auto indexInfo = getTensorInfo(index); + indexInfo.collapseDims(); + + // A reasonable choice for when to have each thread iterate over + // index to choose + if (numIndex <= 16) { + size_t num_wg = std::min( + ceil_div(sliceSize, (uint64_t)128), (uint64_t)(ssc * 8)); + size_t wg_size = std::min(sliceSize, (uint64_t)128); + if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { + auto caller = SMALL_INDEX( + scalar_t, index_t, unsigned int, 1, 1, -2, func_t); + sycl_kernel_submit( + num_wg * wg_size, wg_size, getCurrentSYCLQueue(), caller); + } else if ( + selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { + auto caller = SMALL_INDEX( + scalar_t, index_t, unsigned int, 2, 2, -2, func_t); + sycl_kernel_submit( + num_wg * wg_size, wg_size, getCurrentSYCLQueue(), caller); + } else if ( + selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { + auto caller = SMALL_INDEX( + scalar_t, index_t, unsigned int, 3, 3, -2, func_t); + sycl_kernel_submit( + num_wg * wg_size, wg_size, getCurrentSYCLQueue(), caller); + } else { + auto caller = SMALL_INDEX( + scalar_t, index_t, unsigned int, -1, -1, -1, func_t); + sycl_kernel_submit( + num_wg * wg_size, wg_size, getCurrentSYCLQueue(), caller); + } + } else { + const bool indexIsMajor = + indexShouldBeMajor(selfInfo, selfAddDim); + uint64_t defaultMaxGroupThreads = syclDeviceMaxWorkGroupSize(); + size_t num_wg = std::min( + ceil_div(sourceTotalSize, (uint64_t)128), + (uint64_t)(ssc * 8)); + size_t wg_size = (sourceTotalSize < defaultMaxGroupThreads) + ? 
sourceTotalSize + : defaultMaxGroupThreads; + if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { + auto caller = LARGE_INDEX( + scalar_t, index_t, unsigned int, 1, 1, -2, true, func_t); + sycl_kernel_submit( + num_wg * wg_size, wg_size, getCurrentSYCLQueue(), caller); + } else if ( + selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { + if (indexIsMajor) { + auto caller = LARGE_INDEX( + scalar_t, index_t, unsigned int, 2, 2, -2, true, func_t); + sycl_kernel_submit( + num_wg * wg_size, wg_size, getCurrentSYCLQueue(), caller); + } else { + auto caller = LARGE_INDEX( + scalar_t, index_t, unsigned int, 2, 2, -2, false, func_t); + sycl_kernel_submit( + num_wg * wg_size, wg_size, getCurrentSYCLQueue(), caller); + } + } else if ( + selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { + if (indexIsMajor) { + auto caller = LARGE_INDEX( + scalar_t, index_t, unsigned int, 3, 3, -2, true, func_t); + sycl_kernel_submit( + num_wg * wg_size, wg_size, getCurrentSYCLQueue(), caller); + } else { + auto caller = LARGE_INDEX( + scalar_t, index_t, unsigned int, 3, 3, -2, false, func_t); + sycl_kernel_submit( + num_wg * wg_size, wg_size, getCurrentSYCLQueue(), caller); + } + } else { + auto caller = LARGE_INDEX( + scalar_t, index_t, unsigned int, -1, -1, -1, true, func_t); + sycl_kernel_submit( + num_wg * wg_size, wg_size, getCurrentSYCLQueue(), caller); + } + } + }); + }); + } else { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + at::ScalarType::Bool, + at::ScalarType::Half, + at::ScalarType::BFloat16, + self.scalar_type(), + "index_add", + [&] { + TensorInfo selfInfo = + getTensorInfo(self_); + const int selfAddDim = selfInfo.collapseDims(dim); + selfInfo.reduceDim(selfAddDim); + const auto alpha_value = alpha.to(); + + TensorInfo sourceInfo = + getTensorInfo(source_); + const int sourceAddDim = sourceInfo.collapseDims(dim); + sourceInfo.reduceDim(sourceAddDim); + + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_xpu_", [&]() { + TensorInfo indexInfo = + getTensorInfo(index); + indexInfo.collapseDims(); + + auto caller = LARGE_INDEX( + scalar_t, index_t, uint64_t, -1, -1, -1, true, func_t); + // uint64_t defaultMaxGroupThreads = syclMaxWorkGroupSize(caller); + uint64_t defaultMaxGroupThreads = syclDeviceMaxWorkGroupSize(); + size_t num_wg = std::min( + ceil_div(sourceTotalSize, (uint64_t)128), (uint64_t)(ssc * 8)); + size_t wg_size = (sourceTotalSize < defaultMaxGroupThreads) + ? 
sourceTotalSize + : defaultMaxGroupThreads; + sycl_kernel_submit( + num_wg * wg_size, wg_size, getCurrentSYCLQueue(), caller); + }); + }); + } + +#undef SMALL_INDEX +#undef LARGE_INDEX +} + +void index_add_kernel( + const Tensor& self, + int64_t dim, + const Tensor& index, + const Tensor& source, + const Scalar& alpha, + const Tensor& result) { + index_reduce_add_xpu_template( + self, dim, index, source, alpha, result, reduce_add); +} + +template +void index_reduce_func_xpu_template( + const Tensor& self, + int64_t dim, + const Tensor& index, + const Tensor& source, + bool include_self, + const ReductionType& reduce, const func_t& reduce_func, const Tensor& result) { globalContext().alertNotDeterministic("index_reduce_xpu"); @@ -1938,6 +1965,234 @@ static inline ForwardIt find_bound( return first; } +template < + typename T, + typename IndicesType, + typename IndexType, + int DstDim, + int SrcDim, + int IdxDim> +struct IndexSelectSmallIndexFunctor { + void operator()(sycl::nd_item<1> item) const { + // In order to avoid reloading the index that we are copying, load + // it once to handle all of the points that are being selected, so + // it can be reused as much as possible. This kernel is chosen when + // this is a good choice (small number of chosen indices), since + // re-accessing indices in addition to src elements can be slow. + for (IndexType dstIndex = 0; dstIndex < indices.sizes[0]; ++dstIndex) { + IndexType srcIndex = + indices.data[IndexToOffset::get( + dstIndex, indices)]; + SYCL_KERNEL_ASSERT(srcIndex < srcSelectDimSize); + + // We stride over the output ignoring the indexed dimension + // (innerSize), whose offset calculation is handled differently + for (IndexType linearIndex = item.get_group(0) * item.get_local_range(0) + + item.get_local_id(0); + linearIndex < innerSize; + linearIndex += item.get_group_range(0) * item.get_local_range(0)) { + IndexType dstOffset = + IndexToOffset::get(linearIndex, dst); + dstOffset += dstIndex * dst.strides[dstSelectDim]; + + IndexType srcOffset = + IndexToOffset::get(linearIndex, src); + srcOffset += srcIndex * src.strides[srcSelectDim]; + + dst.data[dstOffset] = src.data[srcOffset]; + } + } + } + + IndexSelectSmallIndexFunctor( + TensorInfo dst, + TensorInfo src, + TensorInfo indices, + int dstSelectDim, + int srcSelectDim, + IndexType innerSize, + int64_t srcSelectDimSize) + : dst(dst), + src(src), + indices(indices), + dstSelectDim(dstSelectDim), + srcSelectDim(srcSelectDim), + innerSize(innerSize), + srcSelectDimSize(srcSelectDimSize) {} + + private: + TensorInfo dst; + TensorInfo src; + TensorInfo indices; + int dstSelectDim; + int srcSelectDim; + IndexType innerSize; + int64_t srcSelectDimSize; +}; + +// When using a 0-dim scalar tensor, we need the legacy (THC) semantics of +// TensorInfo: Pretend that the scalar tensor is in fact a one-element vector. +template +TensorInfo tensorInfoLegacyIfScalar(TensorInfo ti) { + if (ti.dims == 0) { + ti.dims = 1; + ti.sizes[0] = 1; + ti.strides[0] = 1; + } + return ti; +} + +template +void index_select_out_impl( + Tensor& out, + const Tensor& self, + int64_t dim, + const Tensor& index) { + uint64_t numIndices = index.numel(); + auto selfDims = self.dim() == 0 ? 
1 : self.dim(); + + TORCH_CHECK( + index.dim() <= 1, "Index is supposed to be an empty tensor or a vector"); + TORCH_CHECK( + !(self.dim() == 0 && numIndices != 1), + "index_select(): Index to scalar can have only 1 value, got ", + numIndices, + " value(s)"); + TORCH_CHECK(dim < selfDims, "Indexing dim is out of bounds"); + + std::vector newSize = self.sizes().vec(); + if (self.dim() > 0) { + newSize[dim] = numIndices; + } + + at::native::resize_output(out, newSize); + + uint64_t outTotalSize = out.numel(); + if (outTotalSize == 0) { + return; + } + + bool indContig = index.is_contiguous(); + + // The `self` is partitioned into two parts: + // -the size of each slice we are indexing, which is the + // total size of the tensor ignoring dimension `dim`; + // -the number of indices we are choosing, which is the total size + // of the tensor `indices`. + uint64_t selfSelectDimSize = self.dim() == 0 ? 1 : self.size(dim); + uint64_t sliceSize = outTotalSize / numIndices; + + int ssc = syclMaxDSSNum(); + +#define SMALL_INDEX( \ + TENSOR_TYPE, INDICES_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \ + IndexSelectSmallIndexFunctor< \ + TENSOR_TYPE, \ + INDICES_TYPE, \ + TYPE, \ + DST_DIM, \ + SRC_DIM, \ + IDX_DIM>( \ + outInfo, \ + selfInfo, \ + indicesInfo, \ + outSelectDim, \ + selfSelectDim, \ + static_cast(sliceSize), \ + selfSelectDimSize); + + // SmallIndexKernel is more performant when the number of indices is + // small, and pre-loading the index reduces memory accesses. When the + // number of indices is large, we avoid that and increase parallellism by + // calling gather_out which is a generalization of index_select + if (canUse32BitIndexMath(out) && canUse32BitIndexMath(self) && + canUse32BitIndexMath(index) && numIndices <= 16) { + auto outInfo = + tensorInfoLegacyIfScalar(getTensorInfo(out)); + int outSelectDim = outInfo.collapseDims(dim); + outInfo.reduceDim(outSelectDim); + + auto selfInfo = tensorInfoLegacyIfScalar( + getTensorInfo(self)); + int selfSelectDim = selfInfo.collapseDims(dim); + selfInfo.reduceDim(selfSelectDim); + + AT_DISPATCH_INDEX_TYPES( + index.scalar_type(), "index_select_out_impl", [&]() { + auto indicesInfo = tensorInfoLegacyIfScalar( + getTensorInfo(index)); + indicesInfo.collapseDims(); + + uint64_t defaultMaxGroupThreads = syclDeviceMaxWorkGroupSize() / 2; + size_t num_wg = std::min( + ceil_div(sliceSize, defaultMaxGroupThreads), (uint64_t)(ssc * 8)); + size_t wg_size = std::min(sliceSize, defaultMaxGroupThreads); + + // A reasonable choice for when to have each thread iterate over + // indices to choose + if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) { + auto caller = + SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2); + sycl_kernel_submit( + num_wg * wg_size, wg_size, getCurrentSYCLQueue(), caller); + } else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) { + auto caller = + SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2); + sycl_kernel_submit( + num_wg * wg_size, wg_size, getCurrentSYCLQueue(), caller); + } else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) { + auto caller = + SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2); + sycl_kernel_submit( + num_wg * wg_size, wg_size, getCurrentSYCLQueue(), caller); + } else { + auto caller = + SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1); + sycl_kernel_submit( + num_wg * wg_size, wg_size, getCurrentSYCLQueue(), caller); + } + }); + } else { + std::vector tmpSize(newSize.size(), 1); + if (self.dim() > 0) { + tmpSize[dim] = numIndices; + } + 
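+    // Large-index fallback: view `index` with numIndices entries along `dim`
+    // and size 1 elsewhere, expand it to the output shape, and let gather
+    // read self[..., index[j], ...] for every output element, which matches
+    // index_select semantics. E.g. (shapes made up) for self of shape (4, 5),
+    // dim == 1 and 20 indices: index is viewed as (1, 20), expanded to (4, 20).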
at::gather_out(out, self, dim, index.view(tmpSize).expand(newSize)); + return; + } +#undef SMALL_INDEX +} + +Tensor& index_select_kernel( + const Tensor& self, + int64_t dim, + const Tensor& index, + Tensor& out) { + static constexpr std::string_view DIM_WARNING = + "Tensor too large or too many (> 25) dimensions"; + at::assert_no_internal_overlap(out); + at::assert_no_overlap(out, self); + at::assert_no_overlap(out, index); + + dim = at::maybe_wrap_dim(dim, self); + TORCH_CHECK(self.dim() <= XPU_MAX_TENSORINFO_DIMS, DIM_WARNING); + TORCH_CHECK(index.dim() <= XPU_MAX_TENSORINFO_DIMS, DIM_WARNING); + + AT_DISPATCH_V2( + out.scalar_type(), + "index_select_xpu", + AT_WRAP([&] { index_select_out_impl(out, self, dim, index); }), + AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), + AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), + AT_EXPAND(AT_FLOAT8_TYPES), + kComplexHalf, + kHalf, + kBool, + kBFloat16); + + return out; +} + template struct IndexSelectSparse1Functor { index_t operator()(index_t idx) const { From 85b206e9520b3fd46090d7f7dc7f0719cf2461a0 Mon Sep 17 00:00:00 2001 From: yucai-intel <108388355+yucai-intel@users.noreply.github.com> Date: Wed, 5 Nov 2025 15:17:06 +0800 Subject: [PATCH 2/2] Update Indexing.cpp --- src/ATen/native/xpu/sycl/Indexing.cpp | 76 +++++++++++++++++---------- 1 file changed, 47 insertions(+), 29 deletions(-) diff --git a/src/ATen/native/xpu/sycl/Indexing.cpp b/src/ATen/native/xpu/sycl/Indexing.cpp index 942247de6..952efd823 100644 --- a/src/ATen/native/xpu/sycl/Indexing.cpp +++ b/src/ATen/native/xpu/sycl/Indexing.cpp @@ -45,14 +45,10 @@ void index_kernel( TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( - at::ScalarType::ComplexHalf, - at::ScalarType::BFloat16, - at::ScalarType::Half, - at::ScalarType::Bool, + AT_DISPATCH_V2( iter.dtype(), "index_xpu", - [&] { + AT_WRAP([&] { using dtype = OpaqueType; IndexFunctor f; _index_kernel( @@ -63,7 +59,13 @@ void index_kernel( IntArrayRef{}, f, true); - }); + }), + AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), + AT_EXPAND(AT_FLOAT8_TYPES), + kComplexHalf, + kHalf, + kBool, + kBFloat16); } template @@ -228,14 +230,10 @@ void index_put_kernel( false); }); } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( - at::ScalarType::ComplexHalf, - at::ScalarType::BFloat16, - at::ScalarType::Half, - at::ScalarType::Bool, + AT_DISPATCH_V2( iter.dtype(), "index_put_xpu", - [&] { + AT_WRAP([&] { using dtype = OpaqueType; IndexPutFunctor f; _index_kernel( @@ -246,7 +244,13 @@ void index_put_kernel( IntArrayRef{}, f, false); - }); + }), + AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), + AT_EXPAND(AT_FLOAT8_TYPES), + kComplexHalf, + kHalf, + kBool, + kBFloat16); } } @@ -331,16 +335,11 @@ void index_put_deterministic_kernel( linearIndex.numel() * sliceSize * nElemBefore, " vs ", expandedValue.numel()); - if (sliceSize > SIMD) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( - at::ScalarType::ComplexHalf, - at::ScalarType::BFloat16, - at::ScalarType::Half, - at::ScalarType::Bool, + AT_DISPATCH_V2( expandedValue.scalar_type(), "index_put_deterministic_kernel", - [&] { + AT_WRAP([&] { launch_index_put_deterministic_kernel( sorted_indices.mutable_data_ptr(), orig_indices.mutable_data_ptr(), @@ -351,16 +350,24 @@ void index_put_deterministic_kernel( strideBefore, nElemBefore, accumulate); - }); + }), + AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), + // TODO: Enable AT_FLOAT8_DTYPES after accumulation behavior is + // cleared for float8 dtypes. 
+ kFloat8_e4m3fn, + kFloat8_e5m2, + kFloat8_e4m3fnuz, + kFloat8_e5m2fnuz, + kComplexHalf, + kHalf, + kBool, + kBFloat16); } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( - at::ScalarType::ComplexHalf, - at::ScalarType::BFloat16, - at::ScalarType::Half, - at::ScalarType::Bool, + // Align acc type with CUDA + AT_DISPATCH_V2( expandedValue.scalar_type(), "index_put_deterministic_kernel", - [&] { + AT_WRAP([&] { using accscalar_t = at::opmath_type; launch_index_put_deterministic_kernel( sorted_indices.mutable_data_ptr(), @@ -372,7 +379,18 @@ void index_put_deterministic_kernel( strideBefore, nElemBefore, accumulate); - }); + }), + AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), + // TODO: Enable AT_FLOAT8_DTYPES after accumulation behavior is + // cleared for float8 dtypes. + kFloat8_e4m3fn, + kFloat8_e5m2, + kFloat8_e4m3fnuz, + kFloat8_e5m2fnuz, + kComplexHalf, + kHalf, + kBool, + kBFloat16); } if (permuted)
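---

The performance-sensitive change in this patch is the shared-local-memory staging added to IndexFuncLargeIndexFunctor: repeated destination offsets are combined in a 4096-slot work-group cache (a slot is claimed with atomicCAS, accumulated with atomicAddLocal) and flushed to global memory once after a barrier. Below is a minimal host-side sketch of that combining idea using std::atomic instead of SYCL local memory; CombiningCache, atomic_add_float and the float element type are illustrative choices, not part of the patch, and the slot hash is simplified.

#include <array>
#include <atomic>
#include <cstdint>
#include <vector>

constexpr int kCacheSlots = 4096; // plays the role of SMEM_SIZE

// Portable float accumulation via CAS (std::atomic<float>::fetch_add needs C++20).
static void atomic_add_float(std::atomic<float>& a, float v) {
  float old = a.load(std::memory_order_relaxed);
  while (!a.compare_exchange_weak(old, old + v)) {
  }
}

struct CombiningCache {
  std::array<std::atomic<int64_t>, kCacheSlots> offsets; // -1 marks an empty slot
  std::array<std::atomic<float>, kCacheSlots> values;    // staged partial sums

  CombiningCache() {
    for (int i = 0; i < kCacheSlots; ++i) {
      offsets[i].store(-1);
      values[i].store(0.0f);
    }
  }

  // Accumulate `val` for destination element `dst_off`, spilling straight to
  // `dst` when the hashed slot is owned by a different destination.
  void add(std::vector<std::atomic<float>>& dst, int64_t dst_off, float val) {
    const int slot = static_cast<int>(dst_off & (kCacheSlots - 1));
    int64_t seen = offsets[slot].load();
    if (seen == dst_off) {
      atomic_add_float(values[slot], val);   // combine with an earlier update
      return;
    }
    if (seen == -1) {
      int64_t expected = -1;
      if (offsets[slot].compare_exchange_strong(expected, dst_off)) {
        atomic_add_float(values[slot], val); // this slot is ours now
        return;
      }
    }
    atomic_add_float(dst[dst_off], val);     // conflict: update the destination directly
  }

  // Write every claimed slot back to the destination once; the kernel does the
  // same after a work-group barrier.
  void flush(std::vector<std::atomic<float>>& dst) {
    for (int i = 0; i < kCacheSlots; ++i) {
      const int64_t off = offsets[i].load();
      if (off != -1) {
        atomic_add_float(dst[off], values[i].load());
      }
    }
  }
};

int main() {
  std::vector<std::atomic<float>> dst(8);
  for (auto& d : dst) {
    d.store(0.0f);
  }
  CombiningCache cache;
  for (int i = 0; i < 1000; ++i) {
    cache.add(dst, /*dst_off=*/3, 1.0f); // 1000 colliding updates
  }
  cache.flush(dst); // dst[3] == 1000.f after a single write-back
  return 0;
}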
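The index_add path keeps the small-index versus large-index kernel heuristic. The function below is a condensed, illustrative view of that decision only; the enum and function names are hypothetical, and the real code additionally specializes on the collapsed dimensionality, index contiguity and 32- versus 64-bit indexing.

#include <cstdint>

enum class IndexAddKernel { SmallIndex, LargeIndexIndexMajor, LargeIndexSliceMajor };

// Illustrative condensation of the dispatch in index_reduce_add_xpu_template.
IndexAddKernel pick_index_add_kernel(uint64_t numIndex, bool indexIsMajor) {
  if (numIndex <= 16) {
    // Few indices: each work-item walks the whole index vector and strides
    // over one slice, so the indices are read once and reused.
    return IndexAddKernel::SmallIndex;
  }
  // Many indices: stride over every source element; indexShouldBeMajor makes
  // the index the outer dimension when another dimension with size > 1 has a
  // smaller stride than the indexed slice dimension.
  return indexIsMajor ? IndexAddKernel::LargeIndexIndexMajor
                      : IndexAddKernel::LargeIndexSliceMajor;
}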
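When deterministic algorithms are enabled, index_reduce_add_xpu_template (like the old index_add_kernel) routes through an accumulating index_put_ instead of atomic adds. The standalone example below illustrates that equivalence for dim == 0; the shapes and alpha are made up, and it assumes a libtorch recent enough for the alpha overload of Tensor::index_add.

#include <torch/torch.h>

#include <optional>

int main() {
  auto self = torch::zeros({4, 3});
  auto index = torch::tensor({0, 2, 0}, torch::kLong); // note the repeated 0
  auto source = torch::ones({3, 3});
  const double alpha = 2.0;

  // Plain index_add: self[index[i]] += alpha * source[i], duplicates included.
  auto a = self.index_add(0, index, source, alpha);

  // Deterministic fallback used in the patch: one empty optional per dimension
  // in front of `dim` (none here, since dim == 0), then the index, accumulated.
  c10::List<std::optional<torch::Tensor>> indices;
  indices.push_back(index);
  auto b = self.clone();
  b.index_put_(indices, source * alpha, /*accumulate=*/true);

  TORCH_CHECK(torch::allclose(a, b), "both paths should agree");
  return 0;
}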