From a15aa36763990b31221cd97869f953bfbfef6f4e Mon Sep 17 00:00:00 2001 From: Zhao Shijie Date: Fri, 28 Nov 2025 23:18:04 +0800 Subject: [PATCH] issue/563 Add metax support for topkrouter --- src/infiniop/ops/topkrouter/cuda/kernel.cuh | 8 +- .../ops/topkrouter/metax/topkrouter_metax.h | 8 ++ .../topkrouter/metax/topkrouter_metax.maca | 93 +++++++++++++++++++ src/infiniop/ops/topkrouter/operator.cc | 15 +++ 4 files changed, 120 insertions(+), 4 deletions(-) create mode 100644 src/infiniop/ops/topkrouter/metax/topkrouter_metax.h create mode 100644 src/infiniop/ops/topkrouter/metax/topkrouter_metax.maca diff --git a/src/infiniop/ops/topkrouter/cuda/kernel.cuh b/src/infiniop/ops/topkrouter/cuda/kernel.cuh index 0832c5b93..0e1578b50 100644 --- a/src/infiniop/ops/topkrouter/cuda/kernel.cuh +++ b/src/infiniop/ops/topkrouter/cuda/kernel.cuh @@ -6,16 +6,16 @@ #include #include #include -#include -#include -#include +// #include +// #include +// #include template inline __device__ float exp_func(T x) { float data; if constexpr (std::is_same_v) { data = x; - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { data = __bfloat162float(x); } else if constexpr (std::is_same_v) { data = __half2float(x); diff --git a/src/infiniop/ops/topkrouter/metax/topkrouter_metax.h b/src/infiniop/ops/topkrouter/metax/topkrouter_metax.h new file mode 100644 index 000000000..62f17dc6c --- /dev/null +++ b/src/infiniop/ops/topkrouter/metax/topkrouter_metax.h @@ -0,0 +1,8 @@ +#ifndef __TOPKROUTER_METAX_H__ +#define __TOPKROUTER_METAX_H__ + +#include "../topkrouter.h" + +DESCRIPTOR(metax) + +#endif diff --git a/src/infiniop/ops/topkrouter/metax/topkrouter_metax.maca b/src/infiniop/ops/topkrouter/metax/topkrouter_metax.maca new file mode 100644 index 000000000..71c2d37d6 --- /dev/null +++ b/src/infiniop/ops/topkrouter/metax/topkrouter_metax.maca @@ -0,0 +1,93 @@ +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_kernel_common.h" +#include "../cuda/kernel.cuh" +#include "topkrouter_metax.h" +#include + +namespace op::topkrouter::metax { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t correction_bias_desc) { + auto result = TopkrouterInfo::create(x_desc); + CHECK_RESULT(result); + auto info = result.take(); + + if (info.x_strides[1] != 1) { + return INFINI_STATUS_BAD_TENSOR_STRIDES; + } + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + std::move(info), + 0, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +namespace { + +template +infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const void *d_input, const float *d_correction_bias, + const float routed_scaling_factor, const size_t N, const size_t width, const size_t topk, infiniDtype_t xtype, + hcStream_t stream) { + const int block_threads = BLOCK_SIZE; + dim3 blocks(N); + dim3 threads(block_threads); + + if (xtype == INFINI_DTYPE_F32) { + topkrouter_kernel<<>>(d_values_out, d_indices_out, (float *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk); + } else if (xtype == INFINI_DTYPE_F16) { + topkrouter_kernel<<>>(d_values_out, d_indices_out, (half *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk); + } else if (xtype == INFINI_DTYPE_BF16) { + topkrouter_kernel<<>>(d_values_out, d_indices_out, (cuda_bfloat16 *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +}; // namespace + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + float *values, + int *indices, + const void *x, + const float *correction_bias, + const float routed_scaling_factor, + const size_t topk, + void *stream) const { + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + size_t N = _info.N; + size_t width = _info.width; // 256 + + // size_t n_routed_experts = 256; + // size_t n_group = 8; + // size_t topk_group = 4; + auto cuda_stream = reinterpret_cast(stream); + + if (256 == width) { + launch_topkrouter<256>(values, indices, x, correction_bias, routed_scaling_factor, N, width, topk, _info.xtype, cuda_stream); + } else { + return INFINI_STATUS_BAD_PARAM; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::topkrouter::metax diff --git a/src/infiniop/ops/topkrouter/operator.cc b/src/infiniop/ops/topkrouter/operator.cc index 23ed117eb..1115768a8 100644 --- a/src/infiniop/ops/topkrouter/operator.cc +++ b/src/infiniop/ops/topkrouter/operator.cc @@ -8,6 +8,9 @@ #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) #include "nvidia/topkrouter_nvidia.cuh" #endif +#ifdef ENABLE_METAX_API +#include "metax/topkrouter_metax.h" +#endif __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, infiniopTopkrouterDescriptor_t *desc_ptr, infiniopTensorDescriptor_t x_desc, @@ -26,6 +29,9 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i #endif #ifdef ENABLE_QY_API CREATE(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); #endif } @@ -49,6 +55,9 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript #endif #ifdef ENABLE_QY_API GET(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); #endif } @@ -75,6 +84,9 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void #endif #ifdef ENABLE_QY_API CALCULATE(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); #endif } @@ -98,6 +110,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip #endif #ifdef ENABLE_QY_API DESTROY(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_METAX_API + DESTROY(INFINI_DEVICE_METAX, metax); #endif }