
Commit 4fa760f

kernel: add kernel for moe permutation with mask map (#433)
1 parent 464668f · commit 4fa760f

4 files changed (+491, −25 lines)

src/kernels/moe/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ cc_library(
     topk_softmax_kernel.cu
     grouped_topk_sigmoid_kernel.cu
     permutation_index_kernel.cu
+    permutation_mask_kernel.cu
   DEPS
     cutlass
     glog::glog

src/kernels/moe/permutation_index_kernel.cu

Lines changed: 13 additions & 15 deletions
@@ -99,20 +99,20 @@ __global__ void permute_kernel(
     const int topk,
     const int dim) {
   // one block corresponds to one token
-  const int token_idx = blockIdx.x;
+  const int t_idx = blockIdx.x;
   const int tid = threadIdx.x;
 
   // frag for load/store
   float4 frag_ls;
 
   static constexpr int kFragSize = 16 / sizeof(T);
   // tokens: [n_tokens, dim]
-  const T* token_base = tokens + token_idx * dim;
+  const T* token_base = tokens + t_idx * dim;
   for (int i = tid * kFragSize; i < dim; i += blockDim.x * kFragSize) {
     // load fragment into frag_ls (float4)
     frag_ls = __ldlu(reinterpret_cast<const float4*>(token_base + i));
 
-    int idx = token_idx;
+    int idx = t_idx;
     for (int k_idx = 0; k_idx < topk; ++k_idx) {
       // row_id_map: [topk, n_tokens] => idx in permuted tokens
       const int p_idx = row_id_map[idx];
@@ -145,7 +145,7 @@ __global__ void unpermute_kernel(
 
   // load prob into shared memory for the token
   // let first topk thread to load probs
-  for (int i = tid; i < topk; i += blockDim.x * blockDim.y) {
+  for (int i = tid; i < topk; i += blockDim.x) {
     s_probs[i] = probs[(t_idx * topk) + i];
   }
   __syncthreads();
@@ -281,10 +281,8 @@ std::tuple<torch::Tensor, torch::Tensor> permute_with_index_map(
 
   const auto type = tokens.scalar_type();
 
-  auto permuted_tokens = torch::empty({n_permuted_tokens, dim},
-                                      torch::dtype(type).device(torch::kCUDA));
-  auto row_id_map = torch::empty(
-      {n_tokens * topk}, torch::dtype(torch::kInt32).device(torch::kCUDA));
+  auto permuted_tokens = torch::empty({n_permuted_tokens, dim}, options);
+  auto row_id_map = torch::empty({n_tokens * topk}, int32_options);
 
   auto* stream = at::cuda::getCurrentCUDAStream().stream();
 
@@ -321,17 +319,17 @@ std::tuple<torch::Tensor, torch::Tensor> permute_with_index_map(
 torch::Tensor unpermute_with_index_map(
     torch::Tensor permuted_tokens,  // [n_permuted_tokens, dim]
     torch::Tensor row_id_map,       // [topk, n_tokens] => dst row
-    torch::Tensor probs,            // [n_tokens, topk]
-    int64_t n_tokens,
-    int64_t topk) {
+    torch::Tensor probs  // [n_tokens, topk]
+) {
   const auto dim = permuted_tokens.size(1);
+  const auto n_tokens = probs.size(0);
+  const auto topk = probs.size(1);
   const auto type = permuted_tokens.scalar_type();
 
-  // [n_tokens, dim]
-  auto tokens = torch::empty(
-      {n_tokens, dim},
-      torch::dtype(type).device(torch::kCUDA).requires_grad(false));
+  const auto options = permuted_tokens.options();
 
+  // [n_tokens, dim]
+  auto tokens = torch::empty({n_tokens, dim}, options);
   auto* stream = at::cuda::getCurrentCUDAStream().stream();
 
 #define LAUNCH_UNPERMUTE_KERNEL(DType) \
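With this change, unpermute_with_index_map no longer takes n_tokens and topk as arguments; it derives both from the shape of probs ([n_tokens, topk]). A minimal round-trip sketch of the updated API follows; the shapes, dtype, and the unqualified kernel::moe:: calls (which assume code inside namespace llm, as in the tests) are illustrative only, not taken from the commit:

  auto options = torch::dtype(torch::kFloat).device(torch::kCUDA);
  auto tokens = torch::randn({/*n_tokens=*/8, /*dim=*/128}, options);
  auto gating_logit = torch::randn({8, /*n_experts=*/4}, options);

  // top-k routing as in the tests: indices is [n_tokens, topk]
  auto [weights, indices] = gating_logit.topk(/*k=*/2, /*dim=*/-1);
  auto probs = weights.softmax(/*dim=*/-1);  // [n_tokens, topk]

  auto [permuted_tokens, row_id_map] =
      kernel::moe::permute_with_index_map(tokens, indices);

  // n_tokens and topk are now read from probs.size(0) and probs.size(1)
  auto out = kernel::moe::unpermute_with_index_map(
      permuted_tokens, row_id_map, probs);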

src/kernels/moe/permutation_kernel_test.cu

Lines changed: 104 additions & 10 deletions
@@ -12,14 +12,26 @@ namespace llm {
 namespace kernel::moe {
 // forward declare the kernel function
 std::tuple<torch::Tensor, torch::Tensor> permute_with_index_map(
-    torch::Tensor tokens,
-    torch::Tensor indices);
-
-torch::Tensor unpermute_with_index_map(torch::Tensor permuted_tokens,
-                                       torch::Tensor row_id_map,
-                                       torch::Tensor probs,
-                                       int64_t n_tokens,
-                                       int64_t topk);
+    torch::Tensor tokens,  // [n_tokens, dim]
+    torch::Tensor indices  // [n_tokens, topk]
+);
+
+torch::Tensor unpermute_with_index_map(
+    torch::Tensor permuted_tokens,  // [n_permuted_tokens, dim]
+    torch::Tensor row_id_map,       // [topk, n_tokens]
+    torch::Tensor probs             // [n_tokens, topk]
+);
+
+std::tuple<torch::Tensor, torch::Tensor> permute_with_mask_map(
+    torch::Tensor tokens,       // [n_tokens, dim]
+    torch::Tensor routing_map,  // [n_tokens, n_experts]
+    int64_t topk);
+
+torch::Tensor unpermute_with_mask_map(
+    torch::Tensor permuted_tokens,  // [n_permuted_tokens, dim]
+    torch::Tensor row_id_map,       // [n_experts, n_tokens]
+    torch::Tensor probs             // [n_tokens, n_experts]
+);
 
 }  // namespace kernel::moe
 
@@ -66,6 +78,45 @@ torch::Tensor unpermute_index_ref(
   return tokens.sum(/*dim=*/1);
 }
 
+std::tuple<torch::Tensor, torch::Tensor> permute_mask_ref(
+    const torch::Tensor& tokens,      // [n_tokens, dim]
+    const torch::Tensor& routing_map  // [n_tokens, n_experts]
+) {
+  const auto n_tokens = routing_map.size(0);
+  const auto n_experts = routing_map.size(1);
+  const auto options = tokens.options();
+
+  // [n_experts, n_tokens]
+  auto token_indices = torch::arange(n_tokens, options.dtype(torch::kLong))
+                           .unsqueeze(/*dim=*/0)
+                           .expand({n_experts, n_tokens});
+
+  // [n_permuted_tokens] original token indices, sorted by expert idx
+  auto sorted_indices = token_indices.masked_select(/*mask=*/routing_map.t());
+  auto permuted_tokens = tokens.index_select(
+      /*dim=*/0, /*index=*/sorted_indices);
+  return {permuted_tokens, sorted_indices};
+}
+
+torch::Tensor unpermute_mask_ref(
+    const torch::Tensor& permuted_tokens,  // [n_permuted_tokens, dim]
+    const torch::Tensor& permuted_probs,   // [n_permuted_tokens]
+    const torch::Tensor& sorted_incices,   // [n_permuted_tokens]
+    int64_t n_tokens) {
+  const auto dim = permuted_tokens.size(1);
+  const auto options = permuted_tokens.options();
+  // [n_tokens, dim]
+  auto tokens = torch::zeros({n_tokens, dim}, options);
+  // [n_permuted_tokens] => [n_permuted_tokens, dim]
+  auto index = sorted_incices.unsqueeze(/*dim=*/1).expand({-1, dim});
+  // reduce sum over experts
+  tokens.scatter_add_(
+      /*dim=*/0,
+      /*index=*/index,
+      /*src=*/permuted_tokens * permuted_probs.unsqueeze(/*dim=*/1));
+  return tokens;
+}
+
 }  // namespace
 
 class PermuteTest
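To make the ordering produced by permute_mask_ref concrete, here is a tiny worked example (the values are illustrative, not taken from the commit). With n_tokens = 3 and n_experts = 2, suppose token 0 routes to expert 1 while tokens 1 and 2 route to expert 0:

  routing_map      = [[0, 1],          // [n_tokens, n_experts]
                      [1, 0],
                      [1, 0]]
  routing_map.t()  = [[0, 1, 1],       // [n_experts, n_tokens]
                      [1, 0, 0]]
  token_indices    = [[0, 1, 2],       // same shape as routing_map.t()
                      [0, 1, 2]]
  sorted_indices   = token_indices.masked_select(routing_map.t())
                   = [1, 2, 0]
  permuted_tokens  = tokens.index_select(0, sorted_indices)
                   = [tokens[1], tokens[2], tokens[0]]

masked_select walks routing_map.t() in row-major order, so the permuted rows come out grouped by expert: all of expert 0's tokens first, then expert 1's, and so on.
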
@@ -101,7 +152,7 @@ TEST_P(PermuteTest, Index) {
   EXPECT_TRUE(torch::allclose(permuted_tokens, ref_permuted_tokens));
 
   auto unpermute_out = kernel::moe::unpermute_with_index_map(
-      permuted_tokens, sorted_indices, probs, n_tokens, topk);
+      permuted_tokens, sorted_indices, probs);
 
   auto ref_unpermute_out = unpermute_index_ref(
       ref_permuted_tokens, ref_sorted_indices, probs, n_tokens, topk);
@@ -111,6 +162,49 @@
       torch::allclose(tokens, unpermute_out, /*rtol=*/1e-2, /*atol=*/1e-2));
 }
 
+TEST_P(PermuteTest, Mask) {
+  const auto [dtype, n_tokens, dim, n_experts, topk] = GetParam();
+
+  const auto options = torch::dtype(dtype).device(torch::kCUDA);
+
+  const auto tokens = torch::randn({n_tokens, dim}, options);
+  const auto gating_logit = torch::randn({n_tokens, n_experts}, options);
+
+  auto [weights, indices] = gating_logit.topk(topk, /*dim=*/-1);
+  // auto probs = weights.softmax(/*dim=*/-1);
+
+  // construct dense routing map and probs
+  auto probs = torch::zeros_like(gating_logit)
+                   .scatter(
+                       /*dim=*/1, /*index=*/indices, /*value=*/1.0 / topk);
+  auto routing_map = torch::zeros_like(gating_logit, torch::kInt)
+                         .scatter(
+                             /*dim=*/1, /*index=*/indices, /*value=*/1)
+                         .to(torch::kBool);
+
+  auto [permuted_tokens, row_id_map] =
+      kernel::moe::permute_with_mask_map(tokens, routing_map, topk);
+
+  auto [ref_permuted_tokens, ref_row_id_map] =
+      permute_mask_ref(tokens, routing_map);
+
+  EXPECT_TRUE(torch::allclose(permuted_tokens, ref_permuted_tokens));
+
+  auto unpermute_out =
+      kernel::moe::unpermute_with_mask_map(permuted_tokens, row_id_map, probs);
+
+  auto ref_permuted_probs = probs.t().masked_select(/*mask=*/routing_map.t());
+  auto ref_unpermute_out = unpermute_mask_ref(
+      ref_permuted_tokens, ref_permuted_probs, ref_row_id_map, n_tokens);
+  EXPECT_TRUE(torch::allclose(
+      unpermute_out, ref_unpermute_out, /*rtol=*/1e-2, /*atol=*/1e-2));
+
+  EXPECT_TRUE(torch::allclose(tokens,
+                              unpermute_out,
+                              /*rtol=*/1e-2,
+                              /*atol=*/1e-2));
+}
+
 INSTANTIATE_TEST_SUITE_P(
     Moe,
     PermuteTest,
@@ -123,4 +217,4 @@ INSTANTIATE_TEST_SUITE_P(
     ::testing::Values(1, 2, 4)  // topk
 ));
 
-}  // namespace llm
+}  // namespace llm
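A note on the final check in the Mask test: probs assigns weight 1/topk to each of a token's topk routed experts, so the weighted scatter-add in unpermute reconstructs the original token up to rounding,

  unpermute_out[t] = sum over routed experts e of probs[t][e] * tokens[t]
                   = topk * (1/topk) * tokens[t]
                   = tokens[t],

which is why both allclose comparisons (against unpermute_mask_ref and against the original tokens) are expected to pass within the 1e-2 tolerances.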
