8 changes: 4 additions & 4 deletions src/infiniop/ops/topkrouter/cuda/kernel.cuh
@@ -6,16 +6,16 @@
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
-#include <cuda_bf16.h>
-#include <cuda_fp16.h>
-#include <cuda_runtime.h>
+// #include <cuda_bf16.h>
+// #include <cuda_fp16.h>
+// #include <cuda_runtime.h>

template <typename T>
inline __device__ float exp_func(T x) {
    float data;
    if constexpr (std::is_same_v<T, float>) {
        data = x;
-    } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
+    } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
        data = __bfloat162float(x);
    } else if constexpr (std::is_same_v<T, half>) {
        data = __half2float(x);
8 changes: 8 additions & 0 deletions src/infiniop/ops/topkrouter/metax/topkrouter_metax.h
@@ -0,0 +1,8 @@
#ifndef __TOPKROUTER_METAX_H__
#define __TOPKROUTER_METAX_H__

#include "../topkrouter.h"

DESCRIPTOR(metax)

#endif
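
For orientation: DESCRIPTOR(metax) is the shared macro from ../topkrouter.h that declares this backend's Descriptor class. A rough reconstruction of what it declares, inferred from the definitions in topkrouter_metax.maca below (not the macro's literal expansion; member names and layout are assumptions):

// Approximate shape of the class DESCRIPTOR(metax) is expected to declare
// (inferred from topkrouter_metax.maca; details are assumptions).
namespace op::topkrouter::metax {
class Descriptor {
    struct Opaque;            // backend-private state (Metax handle internals)
    Opaque *_opaque;
    TopkrouterInfo _info;     // shape/stride/dtype info taken from x_desc
    size_t _workspace_size;   // 0 for this backend
    // private constructor taking (Opaque *, TopkrouterInfo, workspace size,
    // device, device_id) omitted

public:
    ~Descriptor();
    static infiniStatus_t create(infiniopHandle_t handle,
                                 Descriptor **desc_ptr,
                                 infiniopTensorDescriptor_t x_desc,
                                 infiniopTensorDescriptor_t correction_bias_desc);
    infiniStatus_t calculate(void *workspace, size_t workspace_size,
                             float *values, int *indices,
                             const void *x, const float *correction_bias,
                             float routed_scaling_factor, size_t topk,
                             void *stream) const;
};
} // namespace op::topkrouter::metax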
93 changes: 93 additions & 0 deletions src/infiniop/ops/topkrouter/metax/topkrouter_metax.maca
@@ -0,0 +1,93 @@
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "topkrouter_metax.h"
#include <cub/block/block_reduce.cuh>

namespace op::topkrouter::metax {

struct Descriptor::Opaque {
    std::shared_ptr<device::metax::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t correction_bias_desc) {
    auto result = TopkrouterInfo::create(x_desc);
    CHECK_RESULT(result);
    auto info = result.take();

    if (info.x_strides[1] != 1) {
        return INFINI_STATUS_BAD_TENSOR_STRIDES;
    }

    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
        std::move(info),
        0,
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

namespace {

template <int BLOCK_SIZE = 128>
infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const void *d_input, const float *d_correction_bias,
                                 const float routed_scaling_factor, const size_t N, const size_t width, const size_t topk, infiniDtype_t xtype,
                                 hcStream_t stream) {
    const int block_threads = BLOCK_SIZE;
    dim3 blocks(N);
    dim3 threads(block_threads);

    if (xtype == INFINI_DTYPE_F32) {
        topkrouter_kernel<float, BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (float *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
    } else if (xtype == INFINI_DTYPE_F16) {
        topkrouter_kernel<half, BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (half *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
    } else if (xtype == INFINI_DTYPE_BF16) {
        topkrouter_kernel<cuda_bfloat16, BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (cuda_bfloat16 *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    return INFINI_STATUS_SUCCESS;
}

}; // namespace

infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    float *values,
    int *indices,
    const void *x,
    const float *correction_bias,
    const float routed_scaling_factor,
    const size_t topk,
    void *stream) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    size_t N = _info.N;
    size_t width = _info.width; // 256

    // size_t n_routed_experts = 256;
    // size_t n_group = 8;
    // size_t topk_group = 4;
    auto cuda_stream = reinterpret_cast<hcStream_t>(stream);

    if (256 == width) {
        launch_topkrouter<256>(values, indices, x, correction_bias, routed_scaling_factor, N, width, topk, _info.xtype, cuda_stream);
    } else {
        return INFINI_STATUS_BAD_PARAM;
    }

    return INFINI_STATUS_SUCCESS;
}
} // namespace op::topkrouter::metax
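
As a usage sketch, the descriptor above can be driven directly as follows. Buffer names, the topk = 8 value, and routed_scaling_factor = 2.5f are illustrative assumptions, not values fixed by this diff; real callers go through the C API in operator.cc.

// Minimal illustrative call sequence; assumes handle is a Metax infiniopHandle_t,
// x_desc describes a row-contiguous [N, width] tensor, and d_* are device buffers.
op::topkrouter::metax::Descriptor *desc = nullptr;
auto status = op::topkrouter::metax::Descriptor::create(handle, &desc, x_desc, bias_desc);
if (status != INFINI_STATUS_SUCCESS) { /* handle the error */ }

// No workspace is required by this backend (_workspace_size is 0).
status = desc->calculate(/*workspace=*/nullptr, /*workspace_size=*/0,
                         d_values,   // per-row top-k routing weights (float)
                         d_indices,  // per-row top-k expert indices (int)
                         d_x,        // router logits, f32/f16/bf16; width must be 256
                         d_bias,     // float correction bias of length width
                         /*routed_scaling_factor=*/2.5f,
                         /*topk=*/8,
                         stream);    // an hcStream_t passed as void *
delete desc;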
15 changes: 15 additions & 0 deletions src/infiniop/ops/topkrouter/operator.cc
@@ -8,6 +8,9 @@
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#include "nvidia/topkrouter_nvidia.cuh"
#endif
+#ifdef ENABLE_METAX_API
+#include "metax/topkrouter_metax.h"
+#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/topkrouter_kunlun.h"
#endif
@@ -30,6 +33,9 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
+#ifdef ENABLE_METAX_API
+CREATE(INFINI_DEVICE_METAX, metax);
+#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
Expand All @@ -56,6 +62,9 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
+#ifdef ENABLE_METAX_API
+GET(INFINI_DEVICE_METAX, metax);
+#endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
@@ -85,6 +94,9 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
+#ifdef ENABLE_METAX_API
+CALCULATE(INFINI_DEVICE_METAX, metax);
+#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
@@ -111,6 +123,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip
#ifdef ENABLE_QY_API
DESTROY(INFINI_DEVICE_QY, nvidia);
#endif
+#ifdef ENABLE_METAX_API
+DESTROY(INFINI_DEVICE_METAX, metax);
+#endif
#ifdef ENABLE_KUNLUN_API
DESTROY(INFINI_DEVICE_KUNLUN, kunlun);
#endif
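
All four entry points above dispatch the same way: a switch over the descriptor's device with one macro-generated case per enabled backend, so the three-line ENABLE_METAX_API blocks are all this file needs to route Metax devices to the new implementation. A rough sketch of what the CREATE macro presumably expands to (the actual definition lives earlier in operator.cc and may differ in detail):

// Hypothetical expansion of CREATE(CASE, NAMESPACE); illustrative only.
#define CREATE(CASE, NAMESPACE)                                                   \
    case CASE:                                                                    \
        return op::topkrouter::NAMESPACE::Descriptor::create(                     \
            handle,                                                               \
            reinterpret_cast<op::topkrouter::NAMESPACE::Descriptor **>(desc_ptr), \
            x_desc,                                                               \
            correction_bias_desc)

// CREATE(INFINI_DEVICE_METAX, metax) therefore adds the case that forwards
// descriptor creation on Metax devices to op::topkrouter::metax::Descriptor::create.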