Commit 1294c6c

kernel: added align block permutation kernel for moe (#442)
1 parent 4fa760f commit 1294c6c

6 files changed: +554 -149 lines changed

src/kernels/dispatch.h

Lines changed: 6 additions & 0 deletions

@@ -9,6 +9,12 @@ namespace llm::kernel {
 #define DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

+#define DISPATCH_CASE_INTEGRAL_TYPES(...) \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)
+
+#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
+
 // NOLINTEND(cppcoreguidelines-macro-usage)

 } // namespace llm::kernel
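For context, the new macros mirror the existing floating-point pair: DISPATCH_INTEGRAL_TYPES expands to an AT_DISPATCH_SWITCH whose only case is at::ScalarType::Int, so scalar_t is bound to int32_t inside the lambda (a kInt64 tensor would fail the dispatch). A minimal sketch of a hypothetical caller; the tensor t and the lambda body are illustrative, not part of this commit:

  torch::Tensor t = torch::zeros({8}, torch::dtype(torch::kInt32));
  DISPATCH_INTEGRAL_TYPES(t.scalar_type(), "example_int_op", [&] {
    // scalar_t == int32_t here, the only case registered above
    scalar_t* data = t.data_ptr<scalar_t>();
    data[0] = 1;
  });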

src/kernels/moe/CMakeLists.txt

Lines changed: 4 additions & 2 deletions

@@ -2,13 +2,14 @@ include(cc_library)
 include(cc_test)

 cc_library(
-  NAME 
+  NAME
   moe.kernels
-  SRCS 
+  SRCS
   topk_softmax_kernel.cu
   grouped_topk_sigmoid_kernel.cu
   permutation_index_kernel.cu
   permutation_mask_kernel.cu
+  align_block_kernel.cu
   DEPS
   cutlass
   glog::glog
@@ -23,6 +24,7 @@ cc_test(
   topk_softmax_kernel_test.cu
   grouped_topk_sigmoid_kernel_test.cu
   permutation_kernel_test.cu
+  align_block_kernel_test.cu
   DEPS
   :moe.kernels
   absl::random_random
src/kernels/moe/align_block_kernel.cu (new file)

Lines changed: 274 additions & 0 deletions

@@ -0,0 +1,274 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

#include <cute/tensor.hpp>

#include "../dispatch.h"

// Adapted from
// https://github.com/sgl-project/sglang/blob/main/sgl-kernel/csrc/moe/moe_align_kernel.cu

namespace llm::kernel::moe {

namespace {
constexpr int32_t WARP_SIZE = 32;

// map p_idx to f_idx
template <typename scalar_t>
__global__ void row_id_map_kernel(
    const scalar_t* __restrict__ topk_ids,      // [m, topk]
    scalar_t* __restrict__ sorted_token_idxes,  // [n_padded_tokens+]
    scalar_t* __restrict__ cu_sum,              // [n_experts+1]
    int n_tokens                                // m * topk
) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = blockDim.x * gridDim.x;

  for (int i = tid; i < n_tokens; i += stride) {
    const auto e_idx = topk_ids[i];
    // N.B. token ids for each expert are not sorted
    const auto p_idx = atomicAdd(&cu_sum[e_idx], 1);
    sorted_token_idxes[p_idx] = i;
  }
}

template <typename scalar_t>
__global__ void cusum_kernel(
    const scalar_t* __restrict__ topk_ids,   // [n_tokens, topk]
    scalar_t* __restrict__ expert_ids,       // [n_blocks+]
    scalar_t* __restrict__ n_padded_tokens,  // [1]
    int n_experts,
    int block_size,
    size_t n_tokens,                         // n_tokens * topk
    scalar_t* __restrict__ cu_sum            // [n_experts+1]
) {
  using namespace cute;

  const int tid = threadIdx.x;
  const int stride = blockDim.x;

  const auto curr_expert = threadIdx.x;

  // token count for each expert [n_padded_experts]
  extern __shared__ int token_counts[];

  // init token counts for each expert
  if (curr_expert < n_experts) {
    token_counts[curr_expert] = 0;
  }

  __syncthreads();

  // process the token shard
  for (int i = tid; i < n_tokens; i += stride) {
    // accumulate token counts for each expert
    atomicAdd(&token_counts[topk_ids[i]], 1);
  }

  __syncthreads();

  if (tid == 0) {
    cu_sum[0] = 0;
    for (int e_idx = 1; e_idx <= n_experts; ++e_idx) {
      cu_sum[e_idx] = cu_sum[e_idx - 1] +
                      cute::round_up(token_counts[e_idx - 1], block_size);
    }
    *n_padded_tokens = cu_sum[n_experts];
  }

  __syncthreads();

  // update the expert id for each block
  if (curr_expert < n_experts) {
    for (int i = cu_sum[curr_expert]; i < cu_sum[curr_expert + 1];
         i += block_size) {
      expert_ids[i / block_size] = curr_expert;
    }
  }
}

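// Worked example (added commentary, not part of the original file): with
// n_experts = 2, block_size = 4 and topk_ids = [1, 0, 1, 1], the per-expert
// token counts are [1, 3]. Each count is rounded up to a multiple of
// block_size, giving cu_sum = [0, 4, 8] and *n_padded_tokens = 8.
// expert_ids gets one entry per block: [0, 1]. Expert 0's token (flattened
// index 1) is scattered into slots [0, 4) of sorted_token_idxes and expert
// 1's tokens (flattened indices 0, 2, 3) into slots [4, 8); the padding
// slots are left untouched by these kernels.
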
template <typename scalar_t>
__global__ void align_block_kernel(
    const scalar_t* __restrict__ topk_ids,      // [m, topk]
    scalar_t* __restrict__ sorted_token_idxes,  // [n_padded_tokens+]
    scalar_t* __restrict__ expert_ids,          // [n_blocks+]
    scalar_t* __restrict__ n_padded_tokens,     // [1]
    int n_experts,
    int block_size,
    int n_tokens                                // m * topk
) {
  using namespace cute;

  // which shard and expert this thread takes care of
  const int n_shards = blockDim.x;
  const int curr_shard = threadIdx.x;
  const int curr_expert = threadIdx.x;

  const int tid = threadIdx.x;
  const int stride = blockDim.x;

  extern __shared__ int s_mem[];

  // [n_experts+1]
  Tensor cu_sum = make_tensor(make_smem_ptr(reinterpret_cast<int*>(s_mem)),
                              make_layout(make_shape(n_experts + 1)));
  // [n_shards+1][n_experts]
  // token_counts(0, _) = 0 is used to facilitate the prefix sum
  Tensor token_counts =
      make_tensor(make_smem_ptr(reinterpret_cast<int*>(s_mem + n_experts + 1)),
                  make_layout(make_shape(n_shards + 1, n_experts)));

  // init token counts for each expert in the shard
  for (int e_idx = 0; e_idx < n_experts; ++e_idx) {
    token_counts(curr_shard + 1, e_idx) = 0;
  }

  // accumulate per-shard token counts for each expert
  for (int i = tid; i < n_tokens; i += stride) {
    ++token_counts(curr_shard + 1, topk_ids[i]);
  }

  __syncthreads();

  // calculate the prefix sum for each expert
  // total number of tokens per expert is stored in token_counts(n_shards, _)
  if (curr_expert < n_experts) {
    token_counts(0, curr_expert) = 0;
    for (int i = 1; i <= n_shards; ++i) {
      token_counts(i, curr_expert) += token_counts(i - 1, curr_expert);
    }
  }

  __syncthreads();

  if (tid == 0) {
    // calculate cumulative sum for each expert
    cu_sum[0] = 0;
    for (int e_idx = 1; e_idx <= n_experts; ++e_idx) {
      cu_sum[e_idx] =
          cu_sum[e_idx - 1] +
          cute::round_up(token_counts(n_shards, e_idx - 1), block_size);
    }
    *n_padded_tokens = cu_sum[n_experts];
  }

  __syncthreads();

  // each thread fills the expert id for each token block
  if (curr_expert < n_experts) {
    for (int i = cu_sum[curr_expert]; i < cu_sum[curr_expert + 1];
         i += block_size) {
      expert_ids[i / block_size] = curr_expert;
    }
  }

  for (int i = tid; i < n_tokens; i += stride) {
    const auto e_idx = topk_ids[i];
    const auto idx = token_counts(curr_shard, e_idx)++;
    sorted_token_idxes[cu_sum[e_idx] + idx] = i;
  }
}

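// Note (added commentary): in align_block_kernel each thread owns one
// "shard" of the flattened token stream. After the column-wise pass,
// token_counts(s, e) holds the number of expert-e tokens in shards before
// shard s (an exclusive prefix sum), so the final scatter loop can assign
// every token a unique slot with plain, non-atomic increments of
// token_counts(curr_shard, e_idx).
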
// reduce along topk dimension, assuming contiguous memory
template <typename scalar_t, int TOPK>
__global__ void topk_sum_kernel(
    scalar_t* __restrict__ out,          // [n_tokens, dim]
    const scalar_t* __restrict__ input,  // [n_tokens, topk, dim]
    int64_t dim) {
  // one block per token
  const int64_t t_idx = blockIdx.x;
  for (int64_t i = threadIdx.x; i < dim; i += blockDim.x) {
    scalar_t sum = 0.0;
    CUTE_UNROLL
    for (int k = 0; k < TOPK; ++k) {
      sum += input[(t_idx * TOPK * dim) + (k * dim) + i];
    }
    out[(t_idx * dim) + i] = sum;
  }
}

}  // namespace

void permute_align_block(
    torch::Tensor topk_ids,            // [n_tokens, topk]
    int64_t n_experts,
    int64_t block_size,
    torch::Tensor sorted_token_idxes,  // [n_padded_permuted_tokens+]
    torch::Tensor experts_ids,         // [n_blocks+]
    torch::Tensor n_padded_tokens,     // [1]
    torch::Tensor cu_sum               // [n_experts+1]
) {
  const auto n_flatten_tokens = topk_ids.numel();
  auto* stream = at::cuda::getCurrentCUDAStream().stream();
  DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "align_block_kernel", [&] {
    if (n_flatten_tokens <= 1024 && n_experts <= 64) {
      const int threads = std::max<int>(n_experts, WARP_SIZE);
      const int smem_size =
          ((threads + 1) * n_experts + (n_experts + 1)) * sizeof(int);

      align_block_kernel<scalar_t><<<1, threads, smem_size, stream>>>(
          topk_ids.data_ptr<scalar_t>(),
          sorted_token_idxes.data_ptr<scalar_t>(),
          experts_ids.data_ptr<scalar_t>(),
          n_padded_tokens.data_ptr<scalar_t>(),
          n_experts,
          block_size,
          n_flatten_tokens);
    } else {
      // each thread handles one expert
      // assert(n_experts <= 1024);
      size_t smem_size = 1024 * sizeof(int);
      cusum_kernel<scalar_t>
          <<<1, 1024, smem_size, stream>>>(topk_ids.data_ptr<scalar_t>(),
                                           experts_ids.data_ptr<scalar_t>(),
                                           n_padded_tokens.data_ptr<scalar_t>(),
                                           n_experts,
                                           block_size,
                                           n_flatten_tokens,
                                           cu_sum.data_ptr<scalar_t>());

      constexpr int threads = 256;
      int n_blocks = cute::ceil_div(n_flatten_tokens, threads);
      n_blocks = std::min(n_blocks, 65535);
      row_id_map_kernel<scalar_t><<<n_blocks, threads, 0, stream>>>(
          topk_ids.data_ptr<scalar_t>(),
          sorted_token_idxes.data_ptr<scalar_t>(),
          cu_sum.data_ptr<scalar_t>(),
          n_flatten_tokens);
    }
  });
}

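// Note (added commentary): permute_align_block picks between two
// strategies. For small inputs (n_flatten_tokens <= 1024 and
// n_experts <= 64) a single block runs align_block_kernel entirely in
// shared memory. Otherwise cusum_kernel first fills cu_sum/expert_ids,
// then row_id_map_kernel scatters token indices with atomicAdd on cu_sum
// (mutating it), which is why cu_sum is caller-provided workspace used
// only on this fallback path.
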
void sum_out(const torch::Tensor& input,  // [n_tokens, topk, dim]
             torch::Tensor& output)       // [n_tokens, dim]
{
  const auto n_tokens = input.size(0);
  const auto topk = input.size(1);
  const auto dim = input.size(2);

  // one block per token
  const auto threads = std::min<int>(dim, 1024);
  auto* stream = at::cuda::getCurrentCUDAStream().stream();

#define DISPATCH_TOPK_SUM_KERNEL_CASE(TOPK)                                    \
  case TOPK: {                                                                 \
    DISPATCH_FLOATING_TYPES(input.scalar_type(), "sum_kernel", [&] {           \
      topk_sum_kernel<scalar_t, TOPK><<<n_tokens, threads, 0, stream>>>(       \
          output.data_ptr<scalar_t>(), input.const_data_ptr<scalar_t>(), dim); \
    });                                                                        \
    break;                                                                     \
  }

  switch (topk) {
    DISPATCH_TOPK_SUM_KERNEL_CASE(2);
    DISPATCH_TOPK_SUM_KERNEL_CASE(3);
    DISPATCH_TOPK_SUM_KERNEL_CASE(4);
    DISPATCH_TOPK_SUM_KERNEL_CASE(8);
    default:
      // use torch::sum_out for other cases
      torch::sum_out(output, input, /*dim=*/1);
      break;
  }
}

}  // namespace llm::kernel::moe
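For reference, a hedged sketch of a host-side call. The output sizing is an assumption based on the worst case (each expert's count padded up by at most block_size - 1 slots); the names and values are illustrative, not part of this commit:

  const int64_t m = 16, topk = 2, n_experts = 8, block_size = 4;
  // dtype must be int32: DISPATCH_INTEGRAL_TYPES only covers at::ScalarType::Int
  auto opts = torch::dtype(torch::kInt32).device(torch::kCUDA);
  auto topk_ids = torch::randint(/*high=*/n_experts, {m, topk}, opts);

  // assumed worst-case sizing: every expert padded by block_size - 1 slots
  const int64_t max_padded = m * topk + n_experts * (block_size - 1);
  auto sorted_token_idxes = torch::empty({max_padded}, opts);
  auto expert_ids = torch::empty({(max_padded + block_size - 1) / block_size}, opts);
  auto n_padded_tokens = torch::zeros({1}, opts);
  auto cu_sum = torch::zeros({n_experts + 1}, opts);

  llm::kernel::moe::permute_align_block(topk_ids, n_experts, block_size,
                                        sorted_token_idxes, expert_ids,
                                        n_padded_tokens, cu_sum);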
