Commit 15a2401

david6666666 and SunChenxiang123 authored and committed
[EPLB] Optimize EPLB for Async Rearrange Experts (vllm-project#22179)
Signed-off-by: David Chen <530634352@qq.com>
Co-authored-by: SunChenxiang123 <1291824390@qq.com>
1 parent 8a14521 commit 15a2401

File tree

7 files changed: +778 additions, -77 deletions


tests/distributed/test_eplb_execute.py

Lines changed: 126 additions & 1 deletion
@@ -1,13 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import asyncio
 import random
 
 import pytest
 import torch
 import torch.distributed
 
-from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
+from vllm.distributed.eplb.rebalance_execute import (
+    move_from_buffer,
+    rearrange_expert_weights_inplace,
+    transfer_layer,
+)
 from vllm.distributed.parallel_state import (
     ensure_model_parallel_initialized,
     get_tp_group,
@@ -231,6 +236,100 @@ def verify_redundant_experts_have_same_weights(
     )
 
 
+def _test_async_transfer_layer_without_mtp_worker(
+    env,
+    world_size: int,
+    num_layers: int,
+    num_local_experts: int,
+    num_logical_experts: int,
+) -> None:
+    set_env_vars_and_device(env)
+    ensure_model_parallel_initialized(
+        tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
+    )
+
+    tp_group = get_tp_group()
+    ep_group = tp_group.device_group
+    ep_rank = torch.distributed.get_rank()
+    device = torch.device(f"cuda:{ep_rank}")
+
+    total_physical_experts = world_size * num_local_experts
+    hidden_sizes = [16, 32]
+
+    redundancy_config = create_redundancy_config(
+        num_logical_experts,
+        total_physical_experts,
+    )
+    old_indices = create_expert_indices_with_redundancy(
+        num_layers,
+        num_logical_experts,
+        total_physical_experts,
+        redundancy_config,
+    )
+
+    new_redundancy_config = create_redundancy_config(
+        num_logical_experts,
+        total_physical_experts,
+    )
+    new_indices = create_expert_indices_with_redundancy(
+        num_layers,
+        num_logical_experts,
+        total_physical_experts,
+        new_redundancy_config,
+    )
+
+    expert_weights = create_expert_weights(
+        num_layers,
+        num_local_experts,
+        hidden_sizes,
+        ep_rank,
+        device,
+        old_indices,
+    )
+
+    expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
+    cuda_stream = torch.cuda.Stream(device=device)
+
+    for layer_idx in range(num_layers):
+        is_unchanged, is_received_locally, experts_recv_loc = asyncio.run(
+            transfer_layer(
+                old_global_expert_indices=old_indices,
+                new_global_expert_indices=new_indices,
+                expert_weights=expert_weights,
+                expert_weights_buffer=expert_buffer,
+                ep_group=ep_group,
+                layer=layer_idx,
+                cuda_stream=cuda_stream,
+            )
+        )
+
+        cuda_stream.synchronize()
+        move_from_buffer(
+            expert_weights=expert_weights[layer_idx],
+            expert_weights_buffer=expert_buffer,
+            is_unchanged=is_unchanged,
+            is_received_locally=is_received_locally,
+            experts_recv_loc=experts_recv_loc,
+            new_indices=new_indices[layer_idx].tolist(),
+            ep_group=ep_group,
+        )
+
+    verify_expert_weights_after_shuffle(
+        expert_weights,
+        new_indices,
+        hidden_sizes,
+        ep_rank,
+        num_local_experts,
+    )
+    verify_redundant_experts_have_same_weights(
+        expert_weights,
+        new_indices,
+        hidden_sizes,
+        world_size,
+        num_local_experts,
+    )
+
+
 def _test_rearrange_expert_weights_with_redundancy(
     env, world_size, num_layers, num_local_experts, num_logical_experts
 ) -> None:
@@ -399,6 +498,32 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
     )
 
 
+@pytest.mark.parametrize(
+    "world_size,num_layers,num_local_experts,num_logical_experts",
+    [
+        (2, 2, 2, 3),
+    ],
+)
+def test_async_transfer_layer_without_mtp(
+    world_size: int,
+    num_layers: int,
+    num_local_experts: int,
+    num_logical_experts: int,
+):
+    """Exercise async EPLB transfer path without MTP/spec decode."""
+
+    if torch.cuda.device_count() < world_size:
+        pytest.skip(f"Need at least {world_size} GPUs to run the test")
+
+    distributed_run(
+        _test_async_transfer_layer_without_mtp_worker,
+        world_size,
+        num_layers,
+        num_local_experts,
+        num_logical_experts,
+    )
+
+
 @pytest.mark.parametrize("world_size", [2, 4])
 def test_rearrange_expert_weights_no_change(world_size):
     """

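Taken together, the new test drives a two-phase protocol per MoE layer: transfer_layer stages the rearranged expert weights into a staging buffer on a dedicated CUDA stream, and move_from_buffer commits them into the live weight tensors once that stream has finished. A trimmed sketch of that loop, assuming the indices, weights, buffer, process group, and device are prepared as in the test worker above:

import asyncio

import torch

from vllm.distributed.eplb.rebalance_execute import move_from_buffer, transfer_layer

# Sketch only: old_indices, new_indices, expert_weights, expert_buffer, ep_group,
# device, and num_layers are assumed to be set up as in the test worker above.
stream = torch.cuda.Stream(device=device)
for layer in range(num_layers):
    # Phase 1: stage this layer's experts into the buffer; the communication is
    # issued on the provided side stream (the point of the async path).
    is_unchanged, is_received_locally, experts_recv_loc = asyncio.run(
        transfer_layer(
            old_global_expert_indices=old_indices,
            new_global_expert_indices=new_indices,
            expert_weights=expert_weights,
            expert_weights_buffer=expert_buffer,
            ep_group=ep_group,
            layer=layer,
            cuda_stream=stream,
        )
    )
    # Phase 2: wait for the staging stream, then copy the buffered weights in place.
    stream.synchronize()
    move_from_buffer(
        expert_weights=expert_weights[layer],
        expert_weights_buffer=expert_buffer,
        is_unchanged=is_unchanged,
        is_received_locally=is_received_locally,
        experts_recv_loc=experts_recv_loc,
        new_indices=new_indices[layer].tolist(),
        ep_group=ep_group,
    )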
tests/distributed/test_eplb_spec_decode.py

Lines changed: 38 additions & 1 deletion
@@ -10,10 +10,11 @@
 
 def get_model_args(
     model_name: str,
-    spec_model_name: str,
+    spec_model_name: str | None,
     spec_method: str,
     tp_size: int,
     model_max_len: int,
+    use_async: bool = False,
 ) -> dict:
     speculative_config = {
         "method": spec_method,
@@ -37,6 +38,8 @@ def get_model_args(
         "enable_eplb": True,
         "max_model_len": model_max_len,
     }
+    if use_async:
+        model_args["eplb_config"] = {"use_async": True}
     return model_args
 
 
@@ -94,3 +97,37 @@ def test_eplb_spec_decode(
         measured_value - RTOL < expected_gsm8k_value
         and measured_value + RTOL > expected_gsm8k_value
     ), f"Expected: {expected_gsm8k_value} | Measured: {measured_value}"
+
+
+@large_gpu_mark(min_gb=80)
+def test_eplb_spec_decode_qwen3_next_mtp_async() -> None:
+    """
+    Ensure async EPLB works with MTP speculative decoding for Qwen3-Next.
+    """
+
+    TASK = "gsm8k"
+    FILTER = "exact_match,strict-match"
+    RTOL = 0.03
+    expected_gsm8k_value = 0.86
+
+    model_args = get_model_args(
+        model_name="Qwen/Qwen3-Next-80B-A3B-Instruct",
+        spec_model_name=None,
+        spec_method="mtp",
+        tp_size=4,
+        model_max_len=4096,
+        use_async=True,
+    )
+
+    results = lm_eval.simple_evaluate(
+        model="vllm",
+        model_args=model_args,
+        tasks=TASK,
+        batch_size=64,
+        num_fewshot=8,
+    )
+    measured_value = results["results"][TASK][FILTER]
+    assert (
+        measured_value - RTOL < expected_gsm8k_value
+        and measured_value + RTOL > expected_gsm8k_value
+    ), f"Expected: {expected_gsm8k_value} | Measured: {measured_value}"

vllm/config/parallel.py

Lines changed: 4 additions & 0 deletions
@@ -60,6 +60,10 @@ class EPLBConfig:
     Log the balancedness each step of expert parallelism.
     This is turned off by default since it will cause communication overhead.
     """
+    use_async: bool = False
+    """
+    Whether to use non-blocking EPLB.
+    """
 
 
 @config
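For reference, the new flag defaults to False, so existing deployments keep the blocking rearrange path. A hypothetical sketch of turning it on, mirroring the keyword arguments that get_model_args() in the spec-decode test forwards through lm_eval's vllm wrapper (whether LLM() accepts eplb_config as a plain dict in a given vLLM version is an assumption of this sketch):

from vllm import LLM

# Hypothetical usage sketch; argument names mirror the test's model_args.
llm = LLM(
    model="Qwen/Qwen3-Next-80B-A3B-Instruct",
    tensor_parallel_size=4,
    enable_eplb=True,
    eplb_config={"use_async": True},  # new knob from this commit; defaults to False
    max_model_len=4096,
)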
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+The async worker that transfers experts in the background.
+"""
+
+import asyncio
+import threading
+from typing import TYPE_CHECKING
+
+import torch
+from torch.distributed import ProcessGroup
+
+from vllm.distributed.parallel_state import get_ep_group
+from vllm.logger import init_logger
+
+from .rebalance_execute import transfer_layer
+
+if TYPE_CHECKING:
+    from .eplb_state import EplbState
+
+logger = init_logger(__name__)
+
+
+def start_async_worker(
+    state: "EplbState",
+    rank_mapping: dict[int, int] | None = None,
+    is_profile: bool = False,
+) -> threading.Thread:
+    ep_group = get_ep_group().device_group
+    rank = ep_group.rank()
+    device_index = state.cuda_device_index
+
+    def thread_target() -> None:
+        assert device_index is not None
+        torch.cuda.set_device(device_index)
+        cuda_stream = torch.cuda.Stream(device=device_index)
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            loop.run_until_complete(
+                transfer_run_periodically(
+                    state=state,
+                    ep_group=ep_group,
+                    is_profile=is_profile,
+                    rank_mapping=rank_mapping,
+                    cuda_stream=cuda_stream,
+                )
+            )
+        except Exception as exc:  # pragma: no cover - diagnostic path
+            logger.exception("async loop error (Rank %d): %s", rank, str(exc))
+        finally:
+            loop.close()
+
+    thread = threading.Thread(target=thread_target, daemon=True)
+    thread.start()
+    return thread
+
+
+async def transfer_run_periodically(
+    state: "EplbState",
+    ep_group: ProcessGroup,
+    is_profile: bool = False,
+    rank_mapping: dict[int, int] | None = None,
+    cuda_stream: torch.cuda.Stream = None,
+) -> None:
+    while True:
+        await asyncio.to_thread(state.rearrange_event.wait)
+        logger.info("async worker woke up for EPLB transfer")
+
+        for model_state in state.model_states.values():
+            if not model_state.is_async_enabled:
+                continue
+            current_num_layers = model_state.model.num_moe_layers
+            while (
+                model_state.rebalanced
+                and model_state.layer_to_transfer < current_num_layers
+            ):
+                if (
+                    not model_state.ep_buffer_ready
+                    and model_state.rebalanced
+                    and model_state.new_physical_to_logical_map is not None
+                ):
+                    await asyncio.to_thread(model_state.buffer_lock.acquire)
+                    try:
+                        if model_state.layer_to_transfer >= current_num_layers:
+                            break
+
+                        (
+                            model_state.is_unchanged,
+                            model_state.is_received_locally,
+                            model_state.experts_recv_loc,
+                        ) = await transfer_layer(
+                            old_global_expert_indices=model_state.physical_to_logical_map,
+                            new_global_expert_indices=model_state.new_physical_to_logical_map,
+                            expert_weights=model_state.model.expert_weights,
+                            expert_weights_buffer=model_state.expert_buffer,
+                            ep_group=ep_group,
+                            is_profile=is_profile,
+                            layer=model_state.layer_to_transfer,
+                            cuda_stream=cuda_stream,
+                            rank_mapping=rank_mapping,
+                        )
+                        event = torch.cuda.Event(blocking=False)
+                        cuda_stream.record_event(event)
+                        model_state.buffer_ready_event = event
+                        model_state.ep_buffer_ready = 1
+                    finally:
+                        model_state.buffer_lock.release()
+                else:
+                    if not model_state.rebalanced:
+                        break
+                    await asyncio.sleep(0.001)
+
+        state.rearrange_event.clear()

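The worker above is driven entirely by state.rearrange_event: it sleeps until the event is set, transfers one layer at a time under buffer_lock, and clears the event once every async-enabled model has been handled. A hypothetical driver sketch (eplb_state stands in for the EplbState instance that owns rearrange_event, model_states, and cuda_device_index; how and where the engine actually invokes this is not shown in this excerpt):

import threading

# Start the background transfer thread once, e.g. during engine initialization.
worker_thread: threading.Thread = start_async_worker(
    state=eplb_state,   # assumed EplbState instance (not shown in this diff)
    rank_mapping=None,
    is_profile=False,
)

# Later, after a new expert placement (new_physical_to_logical_map) has been
# computed and model_state.rebalanced has been set, wake the worker so it starts
# the layer-by-layer background transfer:
eplb_state.rearrange_event.set()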