Skip to content

Commit 1c1388c

Browse files
committed
issue/563 Add metax support for topkrouter
1 parent cad2d45 commit 1c1388c

File tree

5 files changed

+121
-5
lines changed

5 files changed

+121
-5
lines changed

src/infiniop/ops/topkrouter/cuda/kernel.cuh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@
66
#include <cub/block/block_reduce.cuh>
77
#include <cub/block/block_store.cuh>
88
#include <cub/cub.cuh>
9-
#include <cuda_bf16.h>
10-
#include <cuda_fp16.h>
11-
#include <cuda_runtime.h>
9+
// #include <cuda_bf16.h>
10+
// #include <cuda_fp16.h>
11+
// #include <cuda_runtime.h>
1212

1313
template <typename T>
1414
inline __device__ float exp_func(T x) {
1515
float data;
1616
if constexpr (std::is_same_v<T, float>) {
1717
data = x;
18-
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
18+
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
1919
data = __bfloat162float(x);
2020
} else if constexpr (std::is_same_v<T, half>) {
2121
data = __half2float(x);
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// Header for the METAX backend of the topkrouter operator.
// Guarded against multiple inclusion via __TOPKROUTER_METAX_H__.
#ifndef __TOPKROUTER_METAX_H__
2+
#define __TOPKROUTER_METAX_H__
3+
4+
#include "../topkrouter.h"
5+
6+
// NOTE(review): DESCRIPTOR is presumably a macro provided by ../topkrouter.h
// that declares the Descriptor class inside namespace op::topkrouter::metax
// (the matching member definitions live in the METAX source file) -- confirm.
DESCRIPTOR(metax)
7+
8+
#endif // __TOPKROUTER_METAX_H__
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#include "../../../devices/metax/metax_common.h"
2+
#include "../../../devices/metax/metax_kernel_common.h"
3+
#include "../cuda/kernel.cuh"
4+
#include "topkrouter_metax.h"
5+
#include <cub/block/block_reduce.cuh>
6+
7+
namespace op::topkrouter::metax {
8+
9+
// Backend-private state owned by a Descriptor.
// Holds a shared reference to the METAX handle internals captured when the
// descriptor is created.
struct Descriptor::Opaque {
    std::shared_ptr<device::metax::Handle::Internal> internal;
};
12+
13+
// Releases the backend-private Opaque state that was heap-allocated in
// Descriptor::create.
Descriptor::~Descriptor() {
    delete _opaque;
}
16+
17+
// Builds a METAX topkrouter descriptor for the given input tensor.
//
// Parameters:
//   handle               - operator handle; reinterpreted as a
//                          device::metax::Handle to obtain its internals.
//   desc_ptr             - receives the newly allocated Descriptor.
//   x_desc               - input tensor descriptor, validated through
//                          TopkrouterInfo::create.
//   correction_bias_desc - not inspected by this function.
//
// Returns:
//   the failure status of TopkrouterInfo::create if validation fails,
//   INFINI_STATUS_BAD_TENSOR_STRIDES if the stride of dimension 1 is not 1,
//   INFINI_STATUS_SUCCESS otherwise.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t correction_bias_desc) {
    auto info_result = TopkrouterInfo::create(x_desc);
    CHECK_RESULT(info_result);
    auto info = info_result.take();

    // The stride of dimension 1 must be exactly 1.
    const bool unit_inner_stride = (info.x_strides[1] == 1);
    if (!unit_inner_stride) {
        return INFINI_STATUS_BAD_TENSOR_STRIDES;
    }

    auto *metax_handle = reinterpret_cast<device::metax::Handle *>(handle);
    auto *opaque = new Opaque{metax_handle->internal()};

    // The third argument is 0 -- presumably the required workspace size
    // (calculate compares against _workspace_size); confirm against the
    // Descriptor constructor.
    *desc_ptr = new Descriptor(
        opaque,
        std::move(info),
        0,
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
37+
38+
namespace {

// Launches topkrouter_kernel on `stream`, selecting the element type from
// `xtype`.
//
// Launch layout: a 1-D grid of N blocks, each with BLOCK_SIZE threads, and
// no dynamic shared memory.
//
// Parameters:
//   d_values_out / d_indices_out - output buffers forwarded to the kernel.
//   d_input                      - input buffer; cast to float / half /
//                                  cuda_bfloat16 according to `xtype`.
//   d_correction_bias            - forwarded to the kernel unchanged.
//   routed_scaling_factor, N, width, topk - forwarded to the kernel unchanged.
//   xtype                        - element type of d_input.
//   stream                       - stream the kernel is enqueued on.
//
// Returns INFINI_STATUS_BAD_TENSOR_DTYPE for any dtype other than F32, F16
// or BF16 (nothing is launched in that case), INFINI_STATUS_SUCCESS once the
// kernel has been enqueued.
template <int BLOCK_SIZE = 128>
infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const void *d_input, const float *d_correction_bias,
                                 const float routed_scaling_factor, const size_t N, const size_t width, const size_t topk, infiniDtype_t xtype,
                                 hcStream_t stream) {
    dim3 grid(N);
    dim3 block(BLOCK_SIZE);

    switch (xtype) {
    case INFINI_DTYPE_F32:
        topkrouter_kernel<float, BLOCK_SIZE><<<grid, block, 0, stream>>>(d_values_out, d_indices_out, (float *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
        break;
    case INFINI_DTYPE_F16:
        topkrouter_kernel<half, BLOCK_SIZE><<<grid, block, 0, stream>>>(d_values_out, d_indices_out, (half *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
        break;
    case INFINI_DTYPE_BF16:
        topkrouter_kernel<cuda_bfloat16, BLOCK_SIZE><<<grid, block, 0, stream>>>(d_values_out, d_indices_out, (cuda_bfloat16 *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    return INFINI_STATUS_SUCCESS;
}

} // namespace
62+
63+
// Runs the topkrouter operator on the METAX backend.
//
// Parameters:
//   workspace             - not used by this implementation.
//   workspace_size        - size of `workspace`; must be at least
//                           _workspace_size.
//   values / indices      - output buffers forwarded to the kernel launcher.
//   x                     - input buffer whose element type is _info.xtype.
//   correction_bias       - forwarded to the kernel launcher.
//   routed_scaling_factor - forwarded to the kernel launcher.
//   topk                  - forwarded to the kernel launcher.
//   stream                - reinterpreted as hcStream_t; the kernel is
//                           enqueued on it.
//
// Returns:
//   INFINI_STATUS_INSUFFICIENT_WORKSPACE if workspace_size is too small,
//   INFINI_STATUS_BAD_PARAM if _info.width is not 256,
//   otherwise the status reported by launch_topkrouter (for example
//   INFINI_STATUS_BAD_TENSOR_DTYPE for an unsupported element type).
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    float *values,
    int *indices,
    const void *x,
    const float *correction_bias,
    const float routed_scaling_factor,
    const size_t topk,
    void *stream) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    const size_t N = _info.N;
    const size_t width = _info.width;

    // Only the 256-wide configuration is instantiated; reject anything else
    // before touching the device.
    if (width != 256) {
        return INFINI_STATUS_BAD_PARAM;
    }

    auto hc_stream = reinterpret_cast<hcStream_t>(stream);

    // Propagate the launcher's status instead of discarding it, so that an
    // unsupported dtype is reported to the caller rather than masked as
    // success.
    return launch_topkrouter<256>(values, indices, x, correction_bias, routed_scaling_factor, N, width, topk, _info.xtype, hc_stream);
}
93+
} // namespace op::topkrouter::metax

src/infiniop/ops/topkrouter/operator.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
99
#include "nvidia/topkrouter_nvidia.cuh"
1010
#endif
11+
#ifdef ENABLE_METAX_API
12+
#include "metax/topkrouter_metax.h"
13+
#endif
1114

1215
__C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, infiniopTopkrouterDescriptor_t *desc_ptr,
1316
infiniopTensorDescriptor_t x_desc,
@@ -26,6 +29,9 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i
2629
#endif
2730
#ifdef ENABLE_QY_API
2831
CREATE(INFINI_DEVICE_QY, nvidia);
32+
#endif
33+
#ifdef ENABLE_METAX_API
34+
CREATE(INFINI_DEVICE_METAX, metax);
2935
#endif
3036
}
3137

@@ -49,6 +55,9 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript
4955
#endif
5056
#ifdef ENABLE_QY_API
5157
GET(INFINI_DEVICE_QY, nvidia);
58+
#endif
59+
#ifdef ENABLE_METAX_API
60+
GET(INFINI_DEVICE_METAX, metax);
5261
#endif
5362
}
5463

@@ -75,6 +84,9 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void
7584
#endif
7685
#ifdef ENABLE_QY_API
7786
CALCULATE(INFINI_DEVICE_QY, nvidia);
87+
#endif
88+
#ifdef ENABLE_METAX_API
89+
CALCULATE(INFINI_DEVICE_METAX, metax);
7890
#endif
7991
}
8092

@@ -98,6 +110,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip
98110
#endif
99111
#ifdef ENABLE_QY_API
100112
DESTROY(INFINI_DEVICE_QY, nvidia);
113+
#endif
114+
#ifdef ENABLE_METAX_API
115+
DESTROY(INFINI_DEVICE_METAX, metax);
101116
#endif
102117
}
103118

test/infiniop/topkrouter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
# w (weight) types
3535
# Note: 'None' means the same as input dtype
36-
_X_DTYPES = [] # [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
36+
_X_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
3737
# x types used for testing
3838
_VALUE_DTYPES = [InfiniDtype.F32]
3939

0 commit comments

Comments
 (0)