     Qwen3_VisionPatchEmbed,
     Qwen3_VisionPatchMerger,
     Qwen3_VisionTransformer)
-
-from .vision import get_vit_attn_backend
+from vllm.model_executor.models.vision import get_vit_attn_backend
 
 
 class AscendQwen3_VisionBlock(nn.Module):
@@ -44,15 +43,13 @@ def forward(
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,  # Only used for Flash Attention
-        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
             cu_seqlens=cu_seqlens,
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
-            seqlens=seqlens,
         )
 
         x = x + self.mlp(self.norm2(x))
@@ -70,7 +67,8 @@ def __init__(
         use_data_parallel: bool = False,
         attn_backend_override: AttentionBackendEnum | None = None,
     ) -> None:
-        super().__init__()
+        nn.Module.__init__(self)
+
         self.hidden_size = vision_config.hidden_size
         self.num_heads = vision_config.num_heads
         self.num_position_embeddings = vision_config.num_position_embeddings
@@ -197,18 +195,11 @@ def forward(
                                          non_blocking=True)
         hidden_states = self.patch_embed(hidden_states)
 
-        # if isinstance(grid_thw, list):
-        #     grid_thw_list = grid_thw
-        #     grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
-        # else:
-        #     grid_thw_list = grid_thw.tolist()
         if isinstance(grid_thw, list):
-            print("Vit grid_thw -> list", flush=True)
             grid_thw_list = grid_thw
             grid_thw = np.array(grid_thw, dtype=np.int32)
         else:
-            print("Vit grid_thw -> tensor", flush=True)
-            # grid_thw = grid_thw.to("cpu")
+            grid_thw = grid_thw.to("cpu")
             grid_thw_list = grid_thw.tolist()
             grid_thw = grid_thw.numpy()
 
@@ -221,15 +212,13 @@ def forward(
         rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device,
                                                    non_blocking=True)
 
-        cu_seqlens = torch.repeat_interleave(
-            grid_thw[:, 1] * grid_thw[:, 2],
-            grid_thw[:, 0]).cumsum(dim=0,
-                                   dtype=grid_thw.dtype
-                                   if torch.jit.is_tracing() else torch.int32)
-        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
+        cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2],
+                               grid_thw[:, 0]).cumsum(axis=0, dtype=np.int32)
+        cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
+        cu_seqlens = torch.from_numpy(cu_seqlens)
 
         hidden_states = hidden_states.unsqueeze(1)
-        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
         cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
 
         deepstack_feature_lists = []
@@ -240,7 +229,6 @@ def forward(
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen,
-                seqlens=seqlens,
             )
             if layer_num in self.deepstack_visual_indexes:
                 deepstack_merger_idx = self.deepstack_visual_indexes.index(
@@ -255,6 +243,7 @@ def forward(
         return hidden_states
 
 
+# NOTE: These overrides will be removed once vllm-ascend is aligned with the latest vllm main.
 Qwen3_VisionBlock.forward = AscendQwen3_VisionBlock.forward
 Qwen3_VisionTransformer.__init__ = AscendQwen3_VisionTransformer.__init__
 Qwen3_VisionTransformer.rot_pos_emb = AscendQwen3_VisionTransformer.rot_pos_emb
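
For reference, here is a minimal standalone sketch of what the new NumPy-based `cu_seqlens` computation in this diff produces; the `grid_thw` values below are made up for illustration. Each `(t, h, w)` entry contributes `t` frames of `h * w` patches, and `cu_seqlens` holds the cumulative per-frame boundaries with a leading zero, converted back to a torch tensor at the end.

```python
import numpy as np
import torch

# Hypothetical grid: one image (1 frame of 4x6 patches) and one short
# video clip (2 frames of 2x3 patches), each row in (t, h, w) order.
grid_thw = np.array([[1, 4, 6], [2, 2, 3]], dtype=np.int32)

# Per-frame sequence lengths: h * w repeated t times for every entry.
seq_lens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])

# Cumulative boundaries with a leading zero, then handed back to torch,
# mirroring the replacement of the torch.repeat_interleave/cumsum path.
cu_seqlens = np.concatenate(
    [np.zeros(1, dtype=np.int32),
     seq_lens.cumsum(axis=0, dtype=np.int32)])
cu_seqlens = torch.from_numpy(cu_seqlens)

print(cu_seqlens)  # tensor([ 0, 24, 30, 36], dtype=torch.int32)
```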