
Commit 3248d20

upgrade vLLM to main
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent 2b82320 commit 3248d20

File tree: 19 files changed (+114 / -151 lines)


vllm_ascend/attention/attention_v1.py

Lines changed: 1 addition & 1 deletion

@@ -283,7 +283,7 @@ def __init__(
         AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold

         scheduler_config = vllm_config.scheduler_config
-        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
+        self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill

     def reorder_batch(self, input_batch,
                       scheduler_output: "SchedulerOutput") -> bool:
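
Note: the rename tracks vLLM main, where the scheduler config exposes enable_chunked_prefill instead of the older chunked_prefill_enabled. A minimal sketch of reading the flag in a way that tolerates both attribute names (the helper below is illustrative and not part of this commit):

    from vllm.config import VllmConfig

    def chunked_prefill_enabled(vllm_config: VllmConfig) -> bool:
        # vLLM main exposes `enable_chunked_prefill`; older releases used
        # `chunked_prefill_enabled`. Fall back between the two names.
        scheduler_config = vllm_config.scheduler_config
        return getattr(scheduler_config, "enable_chunked_prefill",
                       getattr(scheduler_config, "chunked_prefill_enabled", False))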

vllm_ascend/distributed/cpu_offload_connector.py

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@
 from typing import TYPE_CHECKING, Any, Optional, Sequence

 import torch
-from vllm.attention import AttentionType
+from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (

vllm_ascend/kv_offload/cpu_npu.py

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 import numpy as np
 import torch
-from vllm.attention import AttentionBackend
+from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec

vllm_ascend/ops/mla.py

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@

 import torch
 from torch import nn
-from vllm.attention import AttentionMetadata
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import MLAAttention
 from vllm.config import CacheConfig, get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_world_size
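
Note: the three import changes above track the same upstream move, where AttentionType, AttentionBackend and AttentionMetadata are imported from vllm.attention.backends.abstract rather than from the vllm.attention package root. A hedged compatibility sketch (hypothetical shim, not part of this commit), assuming only the import location changed:

    try:
        # vLLM main: abstract attention interfaces live under backends.abstract.
        from vllm.attention.backends.abstract import (
            AttentionBackend, AttentionMetadata, AttentionType)
    except ImportError:
        # Older vLLM releases re-exported these names from the package root.
        from vllm.attention import (
            AttentionBackend, AttentionMetadata, AttentionType)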

vllm_ascend/patch/worker/patch_qwen2_5_vl.py

Lines changed: 4 additions & 25 deletions

@@ -27,8 +27,7 @@
 from transformers.models.qwen2_vl.configuration_qwen2_vl import \
     Qwen2VLVisionConfig
 from vllm.attention.backends.registry import AttentionBackendEnum
-from vllm.attention.layer import (check_upstream_fa_availability,
-                                  maybe_get_vit_flash_attn_backend)
+from vllm.attention.layer import maybe_get_vit_flash_attn_backend
 from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -65,7 +64,6 @@ def forward(
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,
-        seqlens: torch.Tensor,
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
@@ -141,15 +139,13 @@ def forward(
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
             cu_seqlens=cu_seqlens,
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
-            seqlens=seqlens,
         )
         x = x + self.mlp(self.norm2(x))
         return x
@@ -198,7 +194,6 @@ def __init__(
             head_size=head_dim,
             rotary_dim=head_dim // 2,
             max_position=8192,
-            base=10000.0,
             is_neox_style=True,
         )

@@ -228,10 +223,6 @@ def __init__(
             attn_backend_override=attn_backend_override,
         )

-        if (self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-                and check_upstream_fa_availability(torch.get_default_dtype())):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
-
     def rot_pos_emb(
             self,
             grid_thw: list[list[int]]) -> tuple[torch.Tensor, torch.Tensor]:
@@ -300,15 +291,14 @@ def forward(
         x = x.unsqueeze(1)

         # pre-compute seqlens for attn mask to reduce cuMemcpy operations
-        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
         for blk in self.blocks:
             x = blk(
                 x,
                 cu_seqlens=cu_seqlens,
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen,
-                seqlens=seqlens,
             )

         # adapter
@@ -326,15 +316,13 @@ def forward(
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,  # Only used for Flash Attention
-        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         x_attn = self.attn(
             self.norm1(x),
             cu_seqlens=cu_seqlens,
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
-            seqlens=seqlens,
         )
         x_fused_norm, residual = self.norm2(x, residual=x_attn)
         x = residual + self.mlp(x_fused_norm)
@@ -388,11 +376,9 @@ def __init__(
             head_size=head_dim,
             rotary_dim=head_dim // 2,
             max_position=8192,
-            base=10000.0,
             is_neox_style=True,
         )

-        use_upstream_fa = False
         self.attn_backend = get_vit_attn_backend(
             head_size=head_dim,
             dtype=torch.get_default_dtype(),
@@ -402,7 +388,6 @@ def __init__(
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             ))

@@ -418,7 +403,6 @@ def __init__(
                 prefix=f"{prefix}.blocks.{layer_idx}",
                 use_data_parallel=use_data_parallel,
                 attn_backend=self.attn_backend,
-                use_upstream_fa=use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             ) for layer_idx in range(depth)
         ])
@@ -553,10 +537,8 @@ def forward(

         # transformers
         # pre-compute seqlens for window/full attn to reduce cuMemcpy operations
-        max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(
-            cu_seqlens)
-        max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
-            cu_window_seqlens)
+        max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
+        max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)

         cu_seqlens = cu_seqlens.to(  # type: ignore[attr-defined]
             device=self.device,
@@ -587,19 +569,16 @@ def forward(
             if layer_num in self.fullatt_block_indexes:
                 cu_seqlens_now = cu_seqlens
                 max_seqlen_now = max_seqlen_full
-                seqlens_now = seqlens_full
             else:
                 cu_seqlens_now = cu_window_seqlens
                 max_seqlen_now = max_seqlen_window
-                seqlens_now = seqlens_window

             hidden_states = blk(
                 hidden_states,
                 cu_seqlens=cu_seqlens_now,
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen_now,
-                seqlens=seqlens_now,
             )

         # For Qwen2.5-VL-3B, float16 will overflow at last block
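
Note: the patch follows upstream in dropping the xFormers seqlens path and the use_upstream_fa plumbing, so compute_attn_mask_seqlen is now expected to return only the maximum sequence length. A minimal sketch of that assumed shape, derived from the cu_seqlens convention and not copied from this commit:

    import torch

    def compute_attn_mask_seqlen(cu_seqlens: torch.Tensor) -> torch.Tensor:
        # cu_seqlens holds cumulative sequence lengths; adjacent differences
        # give per-sequence lengths, and only the maximum is needed now that
        # the per-sequence seqlens list for the xFormers path is gone.
        return (cu_seqlens[1:] - cu_seqlens[:-1]).max()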

vllm_ascend/platform.py

Lines changed: 1 addition & 0 deletions

@@ -178,6 +178,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         compilation_config.splitting_ops = []

         compilation_config.cudagraph_num_of_warmups = 1
+        compilation_config.pass_config.enable_fusion = False

         if compilation_config.mode not in [
                 CompilationMode.NONE, CompilationMode.VLLM_COMPILE

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 5 additions & 4 deletions

@@ -138,7 +138,8 @@ def dummy_run(self,
             dummy_compute_logits(self.hidden_states)

     def generate_token_ids(self,
-                           valid_sampled_token_ids: list[np.ndarray],
+                           valid_sampled_token_ids: torch.Tensor
+                           | list[list[int]],
                            sampling_metadata: SamplingMetadata = None,
                            scheduler_output: SchedulerOutput = None,
                            spec_decode_metadata: SpecDecodeMetadata = None,
@@ -151,7 +152,7 @@ def generate_token_ids(self,
             attn_metadata = self._get_eagle_atten_dict(scheduler_output)
         next_token_ids: list[int] = []
         for i, token_ids in enumerate(valid_sampled_token_ids):
-            if token_ids.shape[0] > 0:
+            if token_ids:
                 # Common case.
                 next_token_id = token_ids[-1]
             else:
@@ -163,7 +164,7 @@ def generate_token_ids(self,
                     scheduler_output.num_scheduled_tokens[req_id])

                 next_token_id = req_state.get_token_id(seq_len)
-                next_token_ids.append(next_token_id.item())
+                next_token_ids.append(next_token_id)
         next_token_ids = torch.tensor(next_token_ids,
                                       dtype=torch.int32,
                                       device=self.device)
@@ -183,7 +184,7 @@ def generate_token_ids(self,
         else:
             num_draft_tokens = spec_decode_metadata.num_draft_tokens
             num_rejected_tokens = [
-                n + 1 - valid_sampled_token_ids[i].shape[0] if n > 0 else 0
+                n + 1 - len(valid_sampled_token_ids[0]) if n > 0 else 0
                 for i, n in enumerate(num_draft_tokens)
             ]
             num_rejected_tokens = torch.tensor(
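
Note: with sampled token ids now arriving as plain Python lists (or a tensor) instead of numpy arrays, truthiness and len() replace .shape[0], and get_token_id() already yields an int, so .item() is dropped. A hedged sketch of per-request handling that covers both shapes (hypothetical helper, not part of this commit):

    import torch

    def last_valid_token(sampled_row) -> int | None:
        # Works for an empty or non-empty list[int] as well as a 1-D tensor row.
        if isinstance(sampled_row, torch.Tensor):
            n = sampled_row.shape[0]
        else:
            n = len(sampled_row)
        return int(sampled_row[-1]) if n > 0 else None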

vllm_ascend/spec_decode/interface.py

Lines changed: 1 addition & 2 deletions

@@ -1,7 +1,6 @@
 import enum
 from typing import Optional

-import numpy as np
 import torch
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -41,7 +40,7 @@ def dummy_run(self,
         raise NotImplementedError

     def generate_token_ids(self,
-                           valid_sampled_token_ids: list[np.ndarray],
+                           valid_sampled_token_ids: list[list[int]],
                            sampling_metadata: SamplingMetadata = None,
                            scheduler_output: SchedulerOutput = None,
                            spec_decode_metadata: SpecDecodeMetadata = None,

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 5 additions & 7 deletions

@@ -302,8 +302,7 @@ def dummy_run(self,
                 break

     def generate_token_ids(self,
-                           sampled_token_ids: Union[torch.Tensor,
-                                                    list[np.ndarray]],
+                           sampled_token_ids: torch.Tensor | list[list[int]],
                            sampling_metadata: SamplingMetadata = None,
                            scheduler_output: SchedulerOutput = None,
                            spec_decode_metadata: SpecDecodeMetadata = None,
@@ -380,7 +379,6 @@ def generate_token_ids(self,
         common_attn_metadata.query_start_loc = \
             query_start_loc_pcp_full[:num_reqs + 1]
         if self.speculative_config.disable_padded_drafter_batch:
-            assert isinstance(sampled_token_ids, list)
             # NOTE: Currently, MTP-fullgraph is incompatibility with pcp
             token_indices_to_sample = None
             common_attn_metadata, token_indices =\
@@ -439,7 +437,7 @@ def _get_attn_metadata(self, attn_metadata):
     def _prepare_inputs(
         self,
         common_attn_metadata: CommonAttentionMetadata,
-        sampled_token_ids: list[np.ndarray],
+        sampled_token_ids: list[list[int]],
         num_draft_tokens: list[int],
     ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
         """
@@ -898,7 +896,7 @@ def _prepare_input_kernel(self, out_ptr: torch.Tensor,

     def prepare_next_token_ids_cpu(
         self,
-        sampled_token_ids: list[np.ndarray],
+        sampled_token_ids: list[list[int]],
         requests: dict[str, CachedRequestState],
         gpu_input_batch: InputBatch,
         num_scheduled_tokens: dict[str, int],
@@ -913,7 +911,7 @@ def prepare_next_token_ids_cpu(
         req_ids = gpu_input_batch.req_ids
         next_token_ids: list[int] = []
         for i, token_ids in enumerate(sampled_token_ids):
-            if token_ids.shape[0] > 0:
+            if token_ids:
                 # Common case.
                 next_token_id = token_ids[-1]
             else:
@@ -924,7 +922,7 @@ def prepare_next_token_ids_cpu(
                 seq_len = req_state.num_computed_tokens + num_scheduled_tokens[
                     req_id]
                 next_token_id = req_state.get_token_id(seq_len)
-                next_token_ids.append(next_token_id.item())
+                next_token_ids.append(next_token_id)
         next_token_ids = torch.tensor(next_token_ids,
                                       dtype=torch.int32,
                                       device=self.input_ids.device)

vllm_ascend/spec_decode/ngram_proposer.py

Lines changed: 2 additions & 3 deletions

@@ -1,4 +1,3 @@
-import numpy as np
 import torch
 from vllm.config import CUDAGraphMode
 from vllm.v1.spec_decode.ngram_proposer import \
@@ -32,7 +31,7 @@ def dummy_run(self,
         pass

     def generate_token_ids(self,
-                           valid_sampled_token_ids: list[np.ndarray],
+                           valid_sampled_token_ids,
                            sampling_metadata=None,
                            scheduler_output=None,
                            spec_decode_metadata=None,
@@ -43,7 +42,7 @@ def generate_token_ids(self,
                            aux_hidden_states=None) -> list[list[int]]:
         valid_ngram_requests = []
         for i, sampled_ids in enumerate(valid_sampled_token_ids):
-            num_sampled_ids = sampled_ids.shape[0]
+            num_sampled_ids = len(sampled_ids)
             if not num_sampled_ids:
                 continue

