@@ -12,14 +12,26 @@ namespace llm {
1212namespace kernel ::moe {
1313// forward declare the kernel function
1414std::tuple<torch::Tensor, torch::Tensor> permute_with_index_map (
15- torch::Tensor tokens,
16- torch::Tensor indices);
17-
18- torch::Tensor unpermute_with_index_map (torch::Tensor permuted_tokens,
19- torch::Tensor row_id_map,
20- torch::Tensor probs,
21- int64_t n_tokens,
22- int64_t topk);
15+ torch::Tensor tokens, // [n_tokens, dim]
16+ torch::Tensor indices // [n_tokens, topk]
17+ );
18+
19+ torch::Tensor unpermute_with_index_map (
20+ torch::Tensor permuted_tokens, // [n_permuted_tokens, dim]
21+ torch::Tensor row_id_map, // [topk, n_tokens]
22+ torch::Tensor probs // [n_tokens, topk]
23+ );
24+
25+ std::tuple<torch::Tensor, torch::Tensor> permute_with_mask_map (
26+ torch::Tensor tokens, // [n_tokens, dim]
27+ torch::Tensor routing_map, // [n_tokens, n_experts]
28+ int64_t topk);
29+
30+ torch::Tensor unpermute_with_mask_map (
31+ torch::Tensor permuted_tokens, // [n_permuted_tokens, dim]
32+ torch::Tensor row_id_map, // [n_experts, n_tokens]
33+ torch::Tensor probs // [n_tokens, n_experts]
34+ );
2335
2436} // namespace kernel::moe
2537
@@ -66,6 +78,45 @@ torch::Tensor unpermute_index_ref(
6678 return tokens.sum (/* dim=*/ 1 );
6779}
6880
81+ std::tuple<torch::Tensor, torch::Tensor> permute_mask_ref (
82+ const torch::Tensor& tokens, // [n_tokens, dim]
83+ const torch::Tensor& routing_map // [n_tokens, n_experts]
84+ ) {
85+ const auto n_tokens = routing_map.size (0 );
86+ const auto n_experts = routing_map.size (1 );
87+ const auto options = tokens.options ();
88+
89+ // [n_experts, n_tokens]
90+ auto token_indices = torch::arange (n_tokens, options.dtype (torch::kLong ))
91+ .unsqueeze (/* dim=*/ 0 )
92+ .expand ({n_experts, n_tokens});
93+
94+ // [n_permuted_tokens] original token indices, sorted by expert idx
95+ auto sorted_indices = token_indices.masked_select (/* mask=*/ routing_map.t ());
96+ auto permuted_tokens = tokens.index_select (
97+ /* dim=*/ 0 , /* index=*/ sorted_indices);
98+ return {permuted_tokens, sorted_indices};
99+ }
100+
101+ torch::Tensor unpermute_mask_ref (
102+ const torch::Tensor& permuted_tokens, // [n_permuted_tokens, dim]
103+ const torch::Tensor& permuted_probs, // [n_permuted_tokens]
104+ const torch::Tensor& sorted_incices, // [n_permuted_tokens]
105+ int64_t n_tokens) {
106+ const auto dim = permuted_tokens.size (1 );
107+ const auto options = permuted_tokens.options ();
108+ // [n_tokens, dim]
109+ auto tokens = torch::zeros ({n_tokens, dim}, options);
110+ // [n_permuted_tokens] => [n_permuted_tokens, dim]
111+ auto index = sorted_incices.unsqueeze (/* dim=*/ 1 ).expand ({-1 , dim});
112+ // reduce sum over experts
113+ tokens.scatter_add_ (
114+ /* dim=*/ 0 ,
115+ /* index=*/ index,
116+ /* src=*/ permuted_tokens * permuted_probs.unsqueeze (/* dim=*/ 1 ));
117+ return tokens;
118+ }
119+
69120} // namespace
70121
71122class PermuteTest
@@ -101,7 +152,7 @@ TEST_P(PermuteTest, Index) {
101152 EXPECT_TRUE (torch::allclose (permuted_tokens, ref_permuted_tokens));
102153
103154 auto unpermute_out = kernel::moe::unpermute_with_index_map (
104- permuted_tokens, sorted_indices, probs, n_tokens, topk );
155+ permuted_tokens, sorted_indices, probs);
105156
106157 auto ref_unpermute_out = unpermute_index_ref (
107158 ref_permuted_tokens, ref_sorted_indices, probs, n_tokens, topk);
@@ -111,6 +162,49 @@ TEST_P(PermuteTest, Index) {
111162 torch::allclose (tokens, unpermute_out, /* rtol=*/ 1e-2 , /* atol=*/ 1e-2 ));
112163}
113164
165+ TEST_P (PermuteTest, Mask) {
166+ const auto [dtype, n_tokens, dim, n_experts, topk] = GetParam ();
167+
168+ const auto options = torch::dtype (dtype).device (torch::kCUDA );
169+
170+ const auto tokens = torch::randn ({n_tokens, dim}, options);
171+ const auto gating_logit = torch::randn ({n_tokens, n_experts}, options);
172+
173+ auto [weights, indices] = gating_logit.topk (topk, /* dim=*/ -1 );
174+ // auto probs = weights.softmax(/*dim=*/-1);
175+
176+ // construct dense routing map and probs
177+ auto probs = torch::zeros_like (gating_logit)
178+ .scatter (
179+ /* dim=*/ 1 , /* index=*/ indices, /* value=*/ 1.0 / topk);
180+ auto routing_map = torch::zeros_like (gating_logit, torch::kInt )
181+ .scatter (
182+ /* dim=*/ 1 , /* index=*/ indices, /* value=*/ 1 )
183+ .to (torch::kBool );
184+
185+ auto [permuted_tokens, row_id_map] =
186+ kernel::moe::permute_with_mask_map (tokens, routing_map, topk);
187+
188+ auto [ref_permuted_tokens, ref_row_id_map] =
189+ permute_mask_ref (tokens, routing_map);
190+
191+ EXPECT_TRUE (torch::allclose (permuted_tokens, ref_permuted_tokens));
192+
193+ auto unpermute_out =
194+ kernel::moe::unpermute_with_mask_map (permuted_tokens, row_id_map, probs);
195+
196+ auto ref_permuted_probs = probs.t ().masked_select (/* mask=*/ routing_map.t ());
197+ auto ref_unpermute_out = unpermute_mask_ref (
198+ ref_permuted_tokens, ref_permuted_probs, ref_row_id_map, n_tokens);
199+ EXPECT_TRUE (torch::allclose (
200+ unpermute_out, ref_unpermute_out, /* rtol=*/ 1e-2 , /* atol=*/ 1e-2 ));
201+
202+ EXPECT_TRUE (torch::allclose (tokens,
203+ unpermute_out,
204+ /* rtol=*/ 1e-2 ,
205+ /* atol=*/ 1e-2 ));
206+ }
207+
114208INSTANTIATE_TEST_SUITE_P (
115209 Moe,
116210 PermuteTest,
@@ -123,4 +217,4 @@ INSTANTIATE_TEST_SUITE_P(
123217 ::testing::Values(1 , 2 , 4 ) // topk
124218 ));
125219
126- } // namespace llm
220+ } // namespace llm
0 commit comments