 from transformers.models.qwen2_vl.configuration_qwen2_vl import \
     Qwen2VLVisionConfig
 from vllm.attention.backends.registry import AttentionBackendEnum
-from vllm.attention.layer import (check_upstream_fa_availability,
-                                  maybe_get_vit_flash_attn_backend)
+from vllm.attention.layer import maybe_get_vit_flash_attn_backend
 from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -65,7 +64,6 @@ def forward(
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,
-        seqlens: torch.Tensor,
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
@@ -141,15 +139,13 @@ def forward(
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
             cu_seqlens=cu_seqlens,
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
-            seqlens=seqlens,
         )
         x = x + self.mlp(self.norm2(x))
         return x
@@ -198,7 +194,6 @@ def __init__(
             head_size=head_dim,
             rotary_dim=head_dim // 2,
             max_position=8192,
-            base=10000.0,
             is_neox_style=True,
         )
 
@@ -228,10 +223,6 @@ def __init__(
             attn_backend_override=attn_backend_override,
         )
 
-        if (self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-                and check_upstream_fa_availability(torch.get_default_dtype())):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
-
     def rot_pos_emb(
             self,
             grid_thw: list[list[int]]) -> tuple[torch.Tensor, torch.Tensor]:
@@ -300,15 +291,14 @@ def forward(
         x = x.unsqueeze(1)
 
         # pre-compute seqlens for attn mask to reduce cuMemcpy operations
-        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
         for blk in self.blocks:
             x = blk(
                 x,
                 cu_seqlens=cu_seqlens,
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen,
-                seqlens=seqlens,
             )
 
         # adapter
@@ -326,15 +316,13 @@ def forward(
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,  # Only used for Flash Attention
-        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         x_attn = self.attn(
             self.norm1(x),
             cu_seqlens=cu_seqlens,
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
-            seqlens=seqlens,
         )
         x_fused_norm, residual = self.norm2(x, residual=x_attn)
         x = residual + self.mlp(x_fused_norm)
@@ -388,11 +376,9 @@ def __init__(
             head_size=head_dim,
             rotary_dim=head_dim // 2,
             max_position=8192,
-            base=10000.0,
             is_neox_style=True,
         )
 
-        use_upstream_fa = False
         self.attn_backend = get_vit_attn_backend(
             head_size=head_dim,
             dtype=torch.get_default_dtype(),
@@ -402,7 +388,6 @@ def __init__(
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             ))
 
@@ -418,7 +403,6 @@ def __init__(
                 prefix=f"{prefix}.blocks.{layer_idx}",
                 use_data_parallel=use_data_parallel,
                 attn_backend=self.attn_backend,
-                use_upstream_fa=use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             ) for layer_idx in range(depth)
         ])
@@ -553,10 +537,8 @@ def forward(
 
         # transformers
         # pre-compute seqlens for window/full attn to reduce cuMemcpy operations
-        max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(
-            cu_seqlens)
-        max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
-            cu_window_seqlens)
+        max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
+        max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)
 
         cu_seqlens = cu_seqlens.to(  # type: ignore[attr-defined]
             device=self.device,
@@ -587,19 +569,16 @@ def forward(
             if layer_num in self.fullatt_block_indexes:
                 cu_seqlens_now = cu_seqlens
                 max_seqlen_now = max_seqlen_full
-                seqlens_now = seqlens_full
             else:
                 cu_seqlens_now = cu_window_seqlens
                 max_seqlen_now = max_seqlen_window
-                seqlens_now = seqlens_window
 
             hidden_states = blk(
                 hidden_states,
                 cu_seqlens=cu_seqlens_now,
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen_now,
-                seqlens=seqlens_now,
             )
 
             # For Qwen2.5-VL-3B, float16 will overflow at last block
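
Every call site above now unpacks a single value from `compute_attn_mask_seqlen`, so the helper presumably returns only the flash-attention `max_seqlen` and no longer builds the per-sequence `seqlens` list that the removed xFormers path consumed. A minimal sketch of that shape, assuming `cu_seqlens` is the usual cumulative-length tensor (illustration only, not this file's exact implementation):

    def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor:
        # Adjacent differences of the cumulative lengths give the per-sequence
        # lengths; flash-attention varlen only needs their maximum.
        return (cu_seqlens[1:] - cu_seqlens[:-1]).max()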