
Commit fb7f951

[UNITTEST] add test (#5305)

1 parent 8e0f4df

File tree

3 files changed: +114 -3 lines changed

custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu

Lines changed: 6 additions & 3 deletions

@@ -847,8 +847,11 @@ __global__ void permute_x_fp8_kernel(
     const int start_idx = i == 0 ? 0 : token_nums_per_expert_cum[i - 1];
     const int end_idx = token_nums_per_expert_cum[i];
     if (s_token_idx >= start_idx && s_token_idx < end_idx) {
-      if ((s_token_idx - start_idx) < token_nums_per_expert[i])
+      if ((s_token_idx - start_idx) < token_nums_per_expert[i]) {
         m_indices[s_token_idx] = i;
+      } else {
+        m_indices[s_token_idx] = -1;
+      }
       break;
     }
   }

@@ -984,8 +987,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
       paddle::DataType::FLOAT32,
       place);

-  auto m_indices = paddle::full(
-      {token_nums_feed_to_ffn}, -1, paddle::DataType::INT32, place);
+  auto m_indices =
+      GetEmptyTensor({token_nums_feed_to_ffn}, paddle::DataType::INT32, place);
   auto token_nums_per_expert_cumsum =
       GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
   auto token_nums_per_expert_padded_cumsum =

fastdeploy/model_executor/layers/normalization.py

Lines changed: 2 additions & 0 deletions

@@ -176,6 +176,8 @@ def allgather(self, out, token_num):
             paddle.Tensor: Gathered tensor.
         """
         token_num_per_rank = out.shape[0]
+        if token_num_per_rank == 0:
+            return out
         multi_outs = paddle.zeros([token_num_per_rank * self.tp_size, out.shape[1]], dtype=out.dtype)
         paddle.distributed.all_gather(multi_outs, out, self.tp_group)
         return multi_outs[:token_num, :]
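For context, a sketch of the guarded method as a free function (the tp_size/tp_group parameters stand in for the layer's self attributes and are assumptions of this sketch; it mirrors the diff rather than adding behavior). The new guard returns early when the local shard has zero rows, so no collective is launched on an empty tensor:

# Sketch of the guarded allgather after this commit (free-function form of the
# method in normalization.py; requires an initialized distributed context).
import paddle
import paddle.distributed as dist

def allgather(out, token_num, tp_size, tp_group):
    token_num_per_rank = out.shape[0]
    if token_num_per_rank == 0:
        return out  # nothing local to gather; skip the collective
    multi_outs = paddle.zeros([token_num_per_rank * tp_size, out.shape[1]], dtype=out.dtype)
    dist.all_gather(multi_outs, out, tp_group)
    return multi_outs[:token_num, :]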
Lines changed: 106 additions & 0 deletions

@@ -0,0 +1,106 @@
import unittest

import numpy as np
import paddle

import fastdeploy

np.random.seed(20160703)

paddle.set_default_dtype("bfloat16")


class TestFusedMoE(unittest.TestCase):
    def setUp(self) -> None:
        pass

    def test_ffn(self):
        paddle.seed(10)
        num_rows = 2
        recv_x = paddle.randn([num_rows, 4096], dtype="bfloat16").cast(paddle.float8_e4m3fn)
        recv_x_scale = paddle.randn([num_rows, 4096 // 128]).cast("float32")
        local_num_experts = 8
        gate_out = paddle.randn([num_rows, local_num_experts], dtype="float32")
        recv_topk_idx = paddle.topk(gate_out, k=8, axis=-1)[1]
        recv_topk_idx[:, 3:5] = -1  # mark some routing slots as dropped
        recv_topk_weights = paddle.topk(gate_out, k=8, axis=-1)[0]

        # tmp0: real token count per expert; tmp1: counts padded up to multiples of 128.
        tmp0 = [0] * local_num_experts
        tmp1 = [0] * local_num_experts
        recv_topk_idx_list = recv_topk_idx.flatten().numpy().tolist()
        for ele in recv_topk_idx_list:
            if ele >= 0:
                tmp0[ele] += 1
        for idx in range(len(tmp1)):
            tmp1[idx] = (tmp0[idx] + 127) // 128 * 128

        # Baseline m_indices: expert id in real slots, -1 in padded slots.
        token_all_num = sum(tmp1)
        baseline_m_indices = paddle.zeros([token_all_num]).cast("int32") - 1
        for idx in range(len(tmp1)):
            start = sum(tmp1[:idx])
            baseline_m_indices[start : start + tmp0[idx]] = idx

        tmp0 = paddle.to_tensor(tmp0).cast("int32")
        tmp1 = paddle.to_tensor(tmp1).cast("int32")

        (
            permute_input,
            permute_scale,
            permute_indices_per_token,
            recv_num_tokens_per_expert_list_cumsum,
            recv_num_tokens_per_expert_list_padded_cumsum,
            dst_weights,
            dst_indices,
            cumsum_idx_gpu,
            m_indices,
        ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch_fp8(
            recv_x,
            recv_x_scale,
            recv_topk_idx,
            recv_topk_weights,
            tmp0,
            tmp1,
            True,  # use_in_ep
            token_all_num,
        )
        assert (m_indices - baseline_m_indices).abs().sum().item() == 0
        for i in range(recv_x.shape[0]):
            for j in range(local_num_experts):
                dst_pos = permute_indices_per_token[j, i].item()
                if dst_pos >= 0:
                    # Each dispatched row must match its source row exactly.
                    a = recv_x[i].cast("float32")
                    b = permute_input[dst_pos].cast("float32")
                    assert (a - b).abs().max().item() == 0

        def haha():
            # 100 back-to-back dispatch calls, used as the timing workload below.
            for i in range(100):
                fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch_fp8(
                    recv_x,
                    recv_x_scale,
                    recv_topk_idx,
                    recv_topk_weights,
                    tmp0,
                    tmp1,
                    True,  # use_in_ep
                    token_all_num,
                )

        num_tests = 20

        start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)]
        end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)]
        for i in range(num_tests):
            start_events[i].record()
            haha()
            end_events[i].record()
        paddle.device.cuda.synchronize()

        # Drop the first (warm-up) measurement.
        times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:]
        print(times[-5:])


if __name__ == "__main__":
    unittest.main()
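The perf check at the bottom of test_ffn uses the standard CUDA-event timing pattern. A standalone sketch of that pattern (assuming a CUDA-enabled Paddle build; the matmul is just placeholder GPU work, not part of the commit):

# Minimal CUDA-event timing sketch, mirroring the pattern used in the test.
import paddle

start = paddle.device.cuda.Event(enable_timing=True)
end = paddle.device.cuda.Event(enable_timing=True)

start.record()
_ = paddle.randn([1024, 1024]) @ paddle.randn([1024, 1024])  # any GPU work
end.record()

paddle.device.cuda.synchronize()  # wait for the recorded events to complete
print(f"{start.elapsed_time(end):.1f} ms")

Recording events on the stream instead of wall-clock timing keeps host-side Python overhead out of the measurement; the test additionally discards the first iteration as warm-up.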
