9 changes: 6 additions & 3 deletions custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu
@@ -842,8 +842,11 @@ __global__ void permute_x_fp8_kernel(
     const int start_idx = i == 0 ? 0 : token_nums_per_expert_cum[i - 1];
     const int end_idx = token_nums_per_expert_cum[i];
     if (s_token_idx >= start_idx && s_token_idx < end_idx) {
-      if ((s_token_idx - start_idx) < token_nums_per_expert[i])
+      if ((s_token_idx - start_idx) < token_nums_per_expert[i]) {
         m_indices[s_token_idx] = i;
+      } else {
+        m_indices[s_token_idx] = -1;
+      }
       break;
     }
   }
@@ -979,8 +982,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
                      paddle::DataType::FLOAT32,
                      place);
 
-  auto m_indices = paddle::full(
-      {token_nums_feed_to_ffn}, -1, paddle::DataType::INT32, place);
+  auto m_indices =
+      GetEmptyTensor({token_nums_feed_to_ffn}, paddle::DataType::INT32, place);
   auto token_nums_per_expert_cumsum =
       GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
   auto token_nums_per_expert_padded_cumsum =
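Note on the kernel change: m_indices is now allocated uninitialized with GetEmptyTensor instead of being pre-filled with -1 via paddle::full, so permute_x_fp8_kernel must write -1 itself for the padded tail of each expert's block (the new else branch above). A minimal Python sketch of the resulting layout, using hypothetical per-expert counts chosen only for illustration:

import numpy as np

# Hypothetical per-expert token counts; each expert's block is padded to a
# multiple of 128, matching the (count + 127) // 128 * 128 padding in the test below.
tokens_per_expert = [3, 0, 130]
padded_per_expert = [(n + 127) // 128 * 128 for n in tokens_per_expert]  # [128, 128, 256]

m_indices = np.full(sum(padded_per_expert), -1, dtype=np.int32)
offset = 0
for expert_id, (n, padded) in enumerate(zip(tokens_per_expert, padded_per_expert)):
    m_indices[offset : offset + n] = expert_id  # real tokens carry their expert id
    offset += padded                            # padding slots are left at -1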
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/layers/normalization.py
@@ -176,6 +176,8 @@ def allgather(self, out, token_num):
             paddle.Tensor: Gathered tensor.
         """
         token_num_per_rank = out.shape[0]
+        if token_num_per_rank == 0:
+            return out
         multi_outs = paddle.zeros([token_num_per_rank * self.tp_size, out.shape[1]], dtype=out.dtype)
         paddle.distributed.all_gather(multi_outs, out, self.tp_group)
         return multi_outs[:token_num, :]
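For reference, the patched method reads as follows once the zero-token guard is applied (reassembled from the hunk above; self.tp_size and self.tp_group are attributes of the surrounding layer class):

def allgather(self, out, token_num):
    token_num_per_rank = out.shape[0]
    if token_num_per_rank == 0:
        # This rank has no tokens; skip the collective and return the empty tensor as-is.
        return out
    multi_outs = paddle.zeros([token_num_per_rank * self.tp_size, out.shape[1]], dtype=out.dtype)
    paddle.distributed.all_gather(multi_outs, out, self.tp_group)
    return multi_outs[:token_num, :]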
106 changes: 106 additions & 0 deletions tests/layers/test_ep_moe_expert_dispatch_fp8.py
@@ -0,0 +1,106 @@
import unittest

import numpy as np
import paddle

import fastdeploy

np.random.seed(20160703)

paddle.set_default_dtype("bfloat16")


class TestFusedMoE(unittest.TestCase):
    def setUp(self) -> None:
        pass

    def test_ffn(self):
        paddle.seed(10)
        num_rows = 2
        recv_x = paddle.randn([num_rows, 4096], dtype="bfloat16").cast(paddle.float8_e4m3fn)
        recv_x_scale = paddle.randn([num_rows, 4096 // 128]).cast("float32")
        local_num_experts = 8
        gate_out = paddle.randn([num_rows, local_num_experts], dtype="float32")
        # Route every token to all 8 local experts, then mark two slots as invalid (-1)
        # so they are excluded from the per-expert token counts.
        recv_topk_idx = paddle.topk(gate_out, k=8, axis=-1)[1]
        recv_topk_idx[:, 3:5] = -1
        recv_topk_weights = paddle.topk(gate_out, k=8, axis=-1)[0]

        # tmp0: number of valid tokens routed to each expert.
        # tmp1: the same counts, each padded up to a multiple of 128.
        tmp0 = [0] * local_num_experts
        tmp1 = [0] * local_num_experts
        recv_topk_idx_list = recv_topk_idx.flatten().numpy().tolist()
        for ele in recv_topk_idx_list:
            if ele >= 0:
                tmp0[ele] += 1
        for idx in range(len(tmp1)):
            tmp1[idx] = (tmp0[idx] + 127) // 128 * 128

        # Reference m_indices: the expert id for real tokens, -1 for padding slots.
        token_all_num = sum(tmp1)
        baseline_m_indices = paddle.zeros([token_all_num]).cast("int32") - 1
        for idx in range(len(tmp1)):
            start = sum(tmp1[:idx])
            baseline_m_indices[start : start + tmp0[idx]] = idx

        tmp0 = paddle.to_tensor(tmp0).cast("int32")
        tmp1 = paddle.to_tensor(tmp1).cast("int32")

        (
            permute_input,
            permute_scale,
            permute_indices_per_token,
            recv_num_tokens_per_expert_list_cumsum,
            recv_num_tokens_per_expert_list_padded_cumsum,
            dst_weights,
            dst_indices,
            cumsum_idx_gpu,
            m_indices,
        ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch_fp8(
            recv_x,
            recv_x_scale,
            recv_topk_idx,
            recv_topk_weights,
            tmp0,
            tmp1,
            True,  # use_in_ep
            token_all_num,
        )
        # m_indices must match the reference exactly, including -1 in padded positions.
        assert (m_indices - baseline_m_indices).abs().sum().item() == 0
        # Every dispatched row must be a bit-exact copy of its source token.
        for i in range(recv_x.shape[0]):
            for j in range(local_num_experts):
                dst_pos = permute_indices_per_token[j, i].item()
                if dst_pos >= 0:
                    a = recv_x[i].cast("float32")
                    b = permute_input[dst_pos].cast("float32")
                    assert (a - b).abs().max().item() == 0

        def dispatch_loop():
            # Repeat the dispatch so the timed region is long enough to measure.
            for _ in range(100):
                fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch_fp8(
                    recv_x,
                    recv_x_scale,
                    recv_topk_idx,
                    recv_topk_weights,
                    tmp0,
                    tmp1,
                    True,  # use_in_ep
                    token_all_num,
                )

        # Time the loop with CUDA events; the first measurement is dropped as warm-up.
        num_tests = 20
        start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)]
        end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)]
        for i in range(num_tests):
            start_events[i].record()
            dispatch_loop()
            end_events[i].record()
        paddle.device.cuda.synchronize()

        times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:]
        print(times[-5:])


if __name__ == "__main__":
    unittest.main()
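Assuming a CUDA build of fastdeploy with the custom GPU ops compiled, the test can be run standalone:

python tests/layers/test_ep_moe_expert_dispatch_fp8.py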