Commit 44b2ea6

remove Qwen3-VL
Signed-off-by: shen-shanshan <467638484@qq.com>
1 parent 517fd92 commit 44b2ea6

2 files changed (+267 -6 lines)

vllm_ascend/models/__init__.py

Lines changed: 6 additions & 6 deletions
@@ -2,13 +2,13 @@
 
 
 def register_model():
-    ModelRegistry.register_model(
-        "Qwen3VLMoeForConditionalGeneration",
-        "vllm_ascend.models.qwen3_vl:AscendQwen3VLMoeForConditionalGeneration")
+    # ModelRegistry.register_model(
+    #     "Qwen3VLMoeForConditionalGeneration",
+    #     "vllm_ascend.models.qwen3_vl:AscendQwen3VLMoeForConditionalGeneration")
 
-    ModelRegistry.register_model(
-        "Qwen3VLForConditionalGeneration",
-        "vllm_ascend.models.qwen3_vl:AscendQwen3VLForConditionalGeneration")
+    # ModelRegistry.register_model(
+    #     "Qwen3VLForConditionalGeneration",
+    #     "vllm_ascend.models.qwen3_vl:AscendQwen3VLForConditionalGeneration")
 
     # There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
     # to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
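
For reference, the registrations being commented out here use vLLM's out-of-tree model registration: ModelRegistry maps an architecture name to an implementation, and the "module:Class" string form defers the import until the model is actually loaded. A minimal sketch of that pattern, using a made-up plugin module and class name purely for illustration:

    from vllm import ModelRegistry

    def register_model():
        # Lazy registration: the "module:Class" string is only imported when
        # vLLM needs to instantiate this architecture.
        ModelRegistry.register_model(
            "MyModelForCausalLM",
            "my_plugin.models:MyModelForCausalLM")
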
Lines changed: 261 additions & 0 deletions
@@ -0,0 +1,261 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+from transformers.models.qwen3_vl.configuration_qwen3_vl import \
+    Qwen3VLVisionConfig
+from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.attention.layer import check_upstream_fa_availability
+from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.models.qwen3_vl import (Qwen3_VisionBlock,
+                                                 Qwen3_VisionPatchEmbed,
+                                                 Qwen3_VisionPatchMerger,
+                                                 Qwen3_VisionTransformer)
+
+from .vision import get_vit_attn_backend
+
+
+class AscendQwen3_VisionBlock(nn.Module):
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb_cos: torch.Tensor,
+        rotary_pos_emb_sin: torch.Tensor,
+        max_seqlen: torch.Tensor,  # Only used for Flash Attention
+        seqlens: torch.Tensor,  # Only used for xFormers
+    ) -> torch.Tensor:
+        x = x + self.attn(
+            self.norm1(x),
+            cu_seqlens=cu_seqlens,
+            rotary_pos_emb_cos=rotary_pos_emb_cos,
+            rotary_pos_emb_sin=rotary_pos_emb_sin,
+            max_seqlen=max_seqlen,
+            seqlens=seqlens,
+        )
+
+        x = x + self.mlp(self.norm2(x))
+        return x
+
+
+class AscendQwen3_VisionTransformer(nn.Module):
+
+    def __init__(
+        self,
+        vision_config: Qwen3VLVisionConfig,
+        norm_eps: float = 1e-6,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+        attn_backend_override: AttentionBackendEnum | None = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = vision_config.hidden_size
+        self.num_heads = vision_config.num_heads
+        self.num_position_embeddings = vision_config.num_position_embeddings
+        self.patch_size = vision_config.patch_size
+        self.spatial_merge_size = vision_config.spatial_merge_size
+        self.spatial_merge_unit = self.spatial_merge_size**2
+        self.temporal_patch_size = vision_config.temporal_patch_size
+        self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
+        self.use_data_parallel = use_data_parallel
+        self.num_grid_per_side = int(self.num_position_embeddings**0.5)
+
+        # NOTE: This is used for creating empty tensor for all_gather for
+        # DP ViT. Here out_hidden_size is enlarged due to deepstack
+        self.out_hidden_size = vision_config.out_hidden_size * (
+            1 + len(self.deepstack_visual_indexes))
+
+        self.patch_embed = Qwen3_VisionPatchEmbed(
+            patch_size=self.patch_size,
+            temporal_patch_size=self.temporal_patch_size,
+            in_channels=vision_config.in_channels,
+            hidden_size=self.hidden_size,
+        )
+
+        self.pos_embed = nn.Embedding(self.num_position_embeddings,
+                                      self.hidden_size)
+
+        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
+        head_dim = self.hidden_size // self.num_heads
+        self.rotary_pos_emb = get_rope(
+            head_size=head_dim,
+            rotary_dim=head_dim // 2,
+            max_position=8192,
+            base=10000.0,
+            is_neox_style=True,
+        )
+
+        self.merger = Qwen3_VisionPatchMerger(
+            d_model=vision_config.out_hidden_size,
+            context_dim=self.hidden_size,
+            norm_layer=norm_layer,
+            spatial_merge_size=self.spatial_merge_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.merger",
+            use_data_parallel=use_data_parallel,
+        )
+
+        self.deepstack_merger_list = nn.ModuleList([
+            Qwen3_VisionPatchMerger(
+                d_model=vision_config.out_hidden_size,
+                context_dim=self.hidden_size,
+                spatial_merge_size=self.spatial_merge_size,
+                use_postshuffle_norm=True,
+                norm_layer=norm_layer,
+                quant_config=quant_config,
+                prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
+                use_data_parallel=use_data_parallel,
+            ) for layer_idx in range(len(self.deepstack_visual_indexes))
+        ])
+
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
+        )
+        use_upstream_fa = False
+        if (self.attn_backend != AttentionBackendEnum.FLASH_ATTN
+                and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA
+                and check_upstream_fa_availability(torch.get_default_dtype())):
+            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
+            use_upstream_fa = True
+
+        if self.attn_backend not in {
+                AttentionBackendEnum.FLASH_ATTN,
+                AttentionBackendEnum.TORCH_SDPA,
+                AttentionBackendEnum.XFORMERS,
+                AttentionBackendEnum.ROCM_AITER_FA,
+        }:
+            raise RuntimeError(
+                f"Qwen3-VL does not support {self.attn_backend} backend now.")
+        self.blocks = nn.ModuleList([
+            Qwen3_VisionBlock(
+                dim=self.hidden_size,
+                num_heads=self.num_heads,
+                mlp_hidden_dim=vision_config.intermediate_size,
+                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
+                norm_layer=norm_layer,
+                quant_config=quant_config,
+                prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=use_data_parallel,
+                attn_backend=self.attn_backend,
+                use_upstream_fa=use_upstream_fa,
+            ) for layer_idx in range(vision_config.depth)
+        ])
+
+    def rot_pos_emb(self, grid_thw: list[list[int]]):
+        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
+        pos_ids = [
+            self.rot_pos_ids(h, w, self.spatial_merge_size) if t == 1 else
+            self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
+            for t, h, w in grid_thw
+        ]
+        pos_ids = torch.cat(pos_ids, dim=0)
+
+        # Use pre-computed cos_sin_cache from RotaryEmbedding
+        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
+
+        cos_h = cos[pos_ids[:, 0]]  # (num_tokens, rotary_dim // 2)
+        cos_w = cos[pos_ids[:, 1]]
+        sin_h = sin[pos_ids[:, 0]]
+        sin_w = sin[pos_ids[:, 1]]
+
+        cos_combined = torch.cat([cos_h, cos_w], dim=-1)
+        sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+
+        return cos_combined, sin_combined
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        grid_thw: torch.Tensor | list[list[int]],
+    ) -> torch.Tensor:
+        hidden_states = x.to(device=self.device,
+                             dtype=self.dtype,
+                             non_blocking=True)
+        hidden_states = self.patch_embed(hidden_states)
+
+        # if isinstance(grid_thw, list):
+        #     grid_thw_list = grid_thw
+        #     grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
+        # else:
+        #     grid_thw_list = grid_thw.tolist()
+        if isinstance(grid_thw, list):
+            print("Vit grid_thw -> list", flush=True)
+            grid_thw_list = grid_thw
+            grid_thw = np.array(grid_thw, dtype=np.int32)
+        else:
+            print("Vit grid_thw -> tensor", flush=True)
+            # grid_thw = grid_thw.to("cpu")
+            grid_thw_list = grid_thw.tolist()
+            grid_thw = grid_thw.numpy()
+
+        pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
+        hidden_states = hidden_states + pos_embeds
+        rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(
+            grid_thw_list)
+        rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device,
+                                                   non_blocking=True)
+        rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device,
+                                                   non_blocking=True)
+
+        cu_seqlens = torch.repeat_interleave(
+            grid_thw[:, 1] * grid_thw[:, 2],
+            grid_thw[:, 0]).cumsum(dim=0,
+                                   dtype=grid_thw.dtype
+                                   if torch.jit.is_tracing() else torch.int32)
+        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
+
+        hidden_states = hidden_states.unsqueeze(1)
+        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
+
+        deepstack_feature_lists = []
+        for layer_num, blk in enumerate(self.blocks):
+            hidden_states = blk(
+                hidden_states,
+                cu_seqlens=cu_seqlens,
+                rotary_pos_emb_cos=rotary_pos_emb_cos,
+                rotary_pos_emb_sin=rotary_pos_emb_sin,
+                max_seqlen=max_seqlen,
+                seqlens=seqlens,
+            )
+            if layer_num in self.deepstack_visual_indexes:
+                deepstack_merger_idx = self.deepstack_visual_indexes.index(
+                    layer_num)
+                deepstack_feature = self.deepstack_merger_list[
+                    deepstack_merger_idx](hidden_states)
+                deepstack_feature_lists.append(deepstack_feature)
+        hidden_states = self.merger(hidden_states)
+        hidden_states = torch.cat(
+            [hidden_states] + deepstack_feature_lists,
+            dim=1)  # [seq_len, hidden_size * (1 + depth_of_deepstack)]
+        return hidden_states
+
+
+Qwen3_VisionBlock.forward = AscendQwen3_VisionBlock.forward
+Qwen3_VisionTransformer.__init__ = AscendQwen3_VisionTransformer.__init__
+Qwen3_VisionTransformer.rot_pos_emb = AscendQwen3_VisionTransformer.rot_pos_emb
+Qwen3_VisionTransformer.forward = AscendQwen3_VisionTransformer.forward
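
The assignments at the end of the new file patch the upstream vLLM classes in place rather than subclassing them: rebinding forward, __init__ and rot_pos_emb at class level means every Qwen3_VisionBlock and Qwen3_VisionTransformer that vLLM constructs picks up the Ascend variants. A self-contained sketch of that rebinding pattern with toy classes (not the real vLLM ones):

    class Upstream:
        def forward(self, x):
            return x + 1

    class Patched:
        def forward(self, x):
            # Replacement behaviour; receives the same `self` as the original.
            return x * 2

    # Rebinding at class level affects all instances, including ones created
    # before the patch was applied.
    Upstream.forward = Patched.forward
    print(Upstream().forward(3))  # prints 6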

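In the patched forward above, every image and video frame is flattened into one packed token sequence, and cu_seqlens marks where each frame's patches begin and end so a variable-length attention backend can keep frames from attending to one another. A small sketch of how those boundaries come out of grid_thw, using made-up grid sizes and plain torch tensors throughout (the committed forward converts grid_thw to a NumPy array first, which this sketch skips for clarity):

    import torch

    # Illustrative grids only: one image (1 frame of 4x6 patches) and one
    # video clip (2 frames of 2x3 patches each).
    grid_thw = torch.tensor([[1, 4, 6], [2, 2, 3]], dtype=torch.int32)

    # Tokens per frame, repeated once per temporal frame: [24, 6, 6]
    per_frame = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                        grid_thw[:, 0])

    # Cumulative boundaries with a leading zero: [0, 24, 30, 36]
    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                            per_frame.cumsum(dim=0, dtype=torch.int32)])
    print(cu_seqlens)  # tensor([ 0, 24, 30, 36], dtype=torch.int32)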