Skip to content

Commit 48fc8b1

Browse files
[BugFix] Fix async-scheduling + FlashAttn MLA (#28990)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 1ffe934 commit 48fc8b1

File tree

4 files changed, with 18 additions and 10 deletions.

vllm/v1/attention/backends/mla/common.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -755,6 +755,7 @@ def build(
755755
seq_lens = common_attn_metadata.seq_lens
756756
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
757757
dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
758+
dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu
758759

759760
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
760761

@@ -944,18 +945,20 @@ def build(
944945

945946
decode_metadata = None
946947
if num_decodes > 0:
948+
dcp_tot_seq_lens_device = None
949+
if self.dcp_world_size > 1:
950+
dcp_tot_seq_lens_device = seq_lens[:num_decodes]
951+
seq_lens_cpu = dcp_local_seq_lens_cpu
952+
seq_lens = dcp_local_seq_lens
953+
947954
decode_metadata = self._build_decode(
948955
block_table_tensor=block_table_tensor[:num_decodes, ...],
949956
seq_lens_cpu=seq_lens_cpu[:num_decodes],
950-
seq_lens_device=dcp_local_seq_lens[:num_decodes]
951-
if self.dcp_world_size > 1 and dcp_local_seq_lens is not None
952-
else seq_lens[:num_decodes],
957+
seq_lens_device=seq_lens[:num_decodes],
953958
query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
954959
query_start_loc_device=query_start_loc[: num_decodes + 1],
955960
num_decode_tokens=num_decode_tokens,
956-
dcp_tot_seq_lens_device=seq_lens[:num_decodes]
957-
if self.dcp_world_size > 1
958-
else None,
961+
dcp_tot_seq_lens_device=dcp_tot_seq_lens_device,
959962
)
960963

961964
attn_metadata = self.metadata_cls(

vllm/v1/attention/backends/mla/flashattn_mla.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -173,7 +173,7 @@ def _build_decode(
173173
) -> FlashAttnMLADecodeMetadata:
174174
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
175175
max_query_len = query_lens_cpu.max().item()
176-
max_seq_len = seq_lens_device.max().item()
176+
max_seq_len = seq_lens_cpu.max().item()
177177

178178
# For Flash Attention MLA + full cudagraph
179179
max_num_splits = 0

vllm/v1/attention/backends/utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -92,6 +92,7 @@ class CommonAttentionMetadata:
9292
encoder_seq_lens: np.ndarray | None = None
9393

9494
dcp_local_seq_lens: torch.Tensor | None = None
95+
dcp_local_seq_lens_cpu: torch.Tensor | None = None
9596
"""Sequence lengths of the local rank in decode context parallelism world"""
9697

9798

vllm/v1/worker/gpu_model_runner.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1451,9 +1451,12 @@ def _build_attention_metadata(
14511451
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
14521452
:num_reqs
14531453
]
1454-
dcp_local_seq_lens = (
1455-
self.dcp_local_seq_lens.gpu[:num_reqs] if self.dcp_world_size > 1 else None
1456-
)
1454+
1455+
dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
1456+
if self.dcp_world_size > 1:
1457+
dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs]
1458+
dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs]
1459+
14571460
spec_decode_common_attn_metadata = None
14581461

14591462
if for_cudagraph_capture:
@@ -1521,6 +1524,7 @@ def _build_attention_metadata(
15211524
causal=True,
15221525
encoder_seq_lens=encoder_seq_lens,
15231526
dcp_local_seq_lens=dcp_local_seq_lens,
1527+
dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu,
15241528
)
15251529

15261530
if self.speculative_config and spec_decode_common_attn_metadata is None:

0 commit comments

Comments
 (0)