
Commit 2b82320

zzhx1 and Levi-JQ authored

[Bugfix] Fix bug with establishing the flashcomm2 and pp communication domains. (#4458)

### What this PR does / why we need it?

The previous implementation of the flashcomm2 communication domain did not take pp (pipeline parallelism) into account, which caused problems when pp and flashcomm2 were enabled together. This PR fixes that issue.

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>

1 parent 8c65009 commit 2b82320
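To make the bug concrete, here is a small standalone sketch (not part of the commit) contrasting the old and new o_proj tensor-parallel group construction. It assumes the global rank layout implied by the new `tp_base_rank` computation in the diff below, with TP ranks innermost and each (dp, pp) pair owning a contiguous block of `tp` ranks; tp=4, pp=2, dp=1 and `flashcomm2_oproj_tensor_parallel_size=2` are illustrative values only.

```python
# Hypothetical sizes; in the real code these come from get_tp_group(),
# get_dp_group(), get_pp_group() and the ascend config.
tp, dp, pp, otp = 4, 1, 2, 2
num_otp_groups = tp // otp  # o_proj TP groups per tp slice

# Old construction: pp is ignored, so only the ranks of pp stage 0 show up.
old_groups = [[d * tp + i + j * num_otp_groups for j in range(otp)]
              for d in range(dp) for i in range(num_otp_groups)]
print(old_groups)  # [[0, 2], [1, 3]]  -- ranks 4..7 (pp stage 1) never appear

# New construction: every (dp, pp) slice gets its own groups.
new_groups = []
for d in range(dp):
    for p in range(pp):
        base = (d * pp + p) * tp  # first global rank of this tp slice
        for i in range(num_otp_groups):
            new_groups.append([base + i + j * num_otp_groups for j in range(otp)])
print(new_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]  -- all 8 ranks covered
```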

File tree

2 files changed: +25 −12 lines


tests/ut/distributed/test_parallel_state.py
Lines changed: 3 additions & 1 deletion

@@ -24,11 +24,13 @@ def mock_distributed():
         patch('torch.distributed.get_backend', return_value='nccl'), \
         patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \
         patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group, \
-        patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group:
+        patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group, \
+        patch('vllm_ascend.distributed.parallel_state.get_pp_group') as mock_pp_group:
         mock_group.return_value.local_rank = 0
         mock_group.return_value.device_group = MagicMock()
         mock_tp_group.return_value.world_size = 4
         mock_dp_group.return_value.world_size = 2
+        mock_pp_group.return_value.world_size = 2
         yield
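For context on why the fixture now also patches `get_pp_group`: after this commit `init_ascend_model_parallel` reads `get_pp_group().world_size`, so a unit test has to stub all three group accessors or it would need a real distributed environment. A minimal sketch of the same mocking pattern outside a pytest fixture (module path taken from the test above, sizes illustrative):

```python
from unittest.mock import patch

# Stub the three group accessors the way the fixture does, so code importing
# them from vllm_ascend.distributed.parallel_state sees tp=4, dp=2, pp=2.
targets = {
    'vllm_ascend.distributed.parallel_state.get_tp_group': 4,
    'vllm_ascend.distributed.parallel_state.get_dp_group': 2,
    'vllm_ascend.distributed.parallel_state.get_pp_group': 2,
}
patchers = []
for target, world_size in targets.items():
    p = patch(target)
    p.start().return_value.world_size = world_size  # start() returns the MagicMock
    patchers.append(p)

# ... exercise code that calls get_tp_group()/get_dp_group()/get_pp_group() ...

for p in patchers:
    p.stop()
```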

vllm_ascend/distributed/parallel_state.py
Lines changed: 22 additions & 11 deletions

@@ -3,7 +3,8 @@
 import torch
 from vllm.config import ParallelConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
-                                             get_tp_group, get_world_group,
+                                             get_pp_group, get_tp_group,
+                                             get_world_group,
                                              init_model_parallel_group)

 import vllm_ascend.envs as envs_ascend

@@ -185,6 +186,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
     ).flashcomm2_oproj_tensor_parallel_size
     global_tp_size = get_tp_group().world_size
     global_dp_size = get_dp_group().world_size
+    global_pp_size = get_pp_group().world_size
     num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size //
                                                  flashcomm2_otp_size)

@@ -197,18 +199,27 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
     if flashcomm2_otp_size > 1:
         otp_group_ranks = []
         odp_group_ranks: list[list[int]] = [
-            [] for _ in range(flashcomm2_otp_size * global_dp_size)
+            [] for _ in range(flashcomm2_otp_size * global_dp_size *
+                              global_pp_size)
         ]
-
         for dp_group_index in range(global_dp_size):
-            for i in range(num_fc2_oproj_tensor_parallel_groups):
-                ranks = []
-                for j in range(flashcomm2_otp_size):
-                    rank_idx = dp_group_index * global_tp_size + i + j * num_fc2_oproj_tensor_parallel_groups
-                    ranks.append(rank_idx)
-                    odp_group_index = dp_group_index * flashcomm2_otp_size + j
-                    odp_group_ranks[odp_group_index].append(rank_idx)
-                otp_group_ranks.append(ranks)
+            for pp_group_index in range(global_pp_size):
+                dp_pp_serial_index = dp_group_index * global_pp_size + pp_group_index
+                tp_base_rank = dp_pp_serial_index * global_tp_size
+                odp_base_index = dp_pp_serial_index * flashcomm2_otp_size
+
+                for i in range(num_fc2_oproj_tensor_parallel_groups):
+                    ranks = []
+                    for j in range(flashcomm2_otp_size):
+                        tp_local_rank = i + j * num_fc2_oproj_tensor_parallel_groups
+                        assert tp_local_rank < global_tp_size
+                        global_rank = tp_base_rank + tp_local_rank
+                        ranks.append(global_rank)
+
+                        odp_group_index = odp_base_index + j
+                        odp_group_ranks[odp_group_index].append(
+                            global_rank)
+                    otp_group_ranks.append(ranks)

         _FLASHCOMM2_OTP = init_model_parallel_group(
             otp_group_ranks,
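As a sanity check on the new nested loop, the following standalone sketch (not part of the commit, no real process groups) re-runs the same construction with hypothetical sizes tp=4, dp=2, pp=2 and `flashcomm2_oproj_tensor_parallel_size=2`, and prints the resulting o_proj tensor-parallel (OTP) and o_proj data-parallel (ODP) rank groups:

```python
# Standalone reproduction of the new group-building loop from
# init_ascend_model_parallel, with hypothetical sizes.
global_tp_size, global_dp_size, global_pp_size = 4, 2, 2
flashcomm2_otp_size = 2
num_fc2_oproj_tensor_parallel_groups = global_tp_size // flashcomm2_otp_size

otp_group_ranks = []
odp_group_ranks = [[] for _ in range(flashcomm2_otp_size * global_dp_size *
                                     global_pp_size)]

for dp_group_index in range(global_dp_size):
    for pp_group_index in range(global_pp_size):
        dp_pp_serial_index = dp_group_index * global_pp_size + pp_group_index
        tp_base_rank = dp_pp_serial_index * global_tp_size      # first rank of this tp slice
        odp_base_index = dp_pp_serial_index * flashcomm2_otp_size

        for i in range(num_fc2_oproj_tensor_parallel_groups):
            ranks = []
            for j in range(flashcomm2_otp_size):
                tp_local_rank = i + j * num_fc2_oproj_tensor_parallel_groups
                global_rank = tp_base_rank + tp_local_rank
                ranks.append(global_rank)
                odp_group_ranks[odp_base_index + j].append(global_rank)
            otp_group_ranks.append(ranks)

print(otp_group_ranks)
# [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11], [12, 14], [13, 15]]
print(odp_group_ranks)
# [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]]
```

In this sketch every global rank 0..15 lands in exactly one OTP group and one ODP group, which is the coverage the old construction lost once pp > 1.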
