
Commit 2b0e3e7

shen-shanshan committed: update
Signed-off-by: shen-shanshan <467638484@qq.com>
1 parent 02ffe02 commit 2b0e3e7

File tree

5 files changed: +12 additions, -294 deletions


vllm_ascend/models/__init__.py

Lines changed: 0 additions & 8 deletions
@@ -2,14 +2,6 @@


 def register_model():
-    # ModelRegistry.register_model(
-    #     "Qwen3VLMoeForConditionalGeneration",
-    #     "vllm_ascend.models.qwen3_vl:AscendQwen3VLMoeForConditionalGeneration")
-
-    # ModelRegistry.register_model(
-    #     "Qwen3VLForConditionalGeneration",
-    #     "vllm_ascend.models.qwen3_vl:AscendQwen3VLForConditionalGeneration")
-
     # There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
     # to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
     ModelRegistry.register_model(

vllm_ascend/models/qwen3_vl.py

Lines changed: 0 additions & 264 deletions
This file was deleted.

vllm_ascend/patch/worker/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -28,4 +28,5 @@
 import vllm_ascend.patch.worker.patch_multimodal_merge  # noqa
 import vllm_ascend.patch.worker.patch_minicpm  # noqa
 import vllm_ascend.patch.worker.patch_qwen2_5_vl  # noqa
+import vllm_ascend.patch.worker.patch_qwen3_vl  # noqa
 import vllm_ascend.patch.worker.patch_rope  # noqa

vllm_ascend/patch/worker/patch_qwen2_5_vl.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ def forward(
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,
-        seqlens: torch.Tensor,
+        seqlens: torch.Tensor = None,
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)

vllm_ascend/patch/worker/patch_qwen3_vl.py

Lines changed: 10 additions & 21 deletions
@@ -31,8 +31,7 @@
                                                     Qwen3_VisionPatchEmbed,
                                                     Qwen3_VisionPatchMerger,
                                                     Qwen3_VisionTransformer)
-
-from .vision import get_vit_attn_backend
+from vllm.model_executor.models.vision import get_vit_attn_backend


 class AscendQwen3_VisionBlock(nn.Module):
@@ -44,15 +43,13 @@ def forward(
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,  # Only used for Flash Attention
-        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
             cu_seqlens=cu_seqlens,
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
-            seqlens=seqlens,
         )

         x = x + self.mlp(self.norm2(x))
@@ -70,7 +67,8 @@ def __init__(
         use_data_parallel: bool = False,
         attn_backend_override: AttentionBackendEnum | None = None,
     ) -> None:
-        super().__init__()
+        nn.Module.__init__(self)
+
         self.hidden_size = vision_config.hidden_size
         self.num_heads = vision_config.num_heads
         self.num_position_embeddings = vision_config.num_position_embeddings
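
In the hunk above, the patched __init__ calls nn.Module.__init__(self) rather than the zero-argument super().__init__(). One likely reason: this method is later transplanted onto the upstream Qwen3_VisionTransformer (see the assignments at the bottom of this file), and zero-argument super() binds to the class that defines the method, so it would fail once the function is called on an instance of a different class. A standalone sketch of that pitfall, with illustrative class names (Upstream/Patched are not from this commit):

import torch.nn as nn

class Upstream(nn.Module):
    def __init__(self):
        super().__init__()
        self.tag = "upstream"

class Patched(nn.Module):
    def __init__(self):
        # Zero-arg super() would bind to Patched; once this function is grafted
        # onto Upstream, calling it on an Upstream instance would raise
        # "TypeError: super(type, obj): obj must be an instance or subtype of type".
        # Calling nn.Module.__init__(self) explicitly sidesteps that.
        nn.Module.__init__(self)
        self.tag = "patched"

Upstream.__init__ = Patched.__init__  # same method-transplant trick as at the end of this file
print(Upstream().tag)  # -> "patched"
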
@@ -197,18 +195,11 @@ def forward(
                                           non_blocking=True)
         hidden_states = self.patch_embed(hidden_states)

-        # if isinstance(grid_thw, list):
-        #     grid_thw_list = grid_thw
-        #     grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
-        # else:
-        #     grid_thw_list = grid_thw.tolist()
         if isinstance(grid_thw, list):
-            print("Vit grid_thw -> list", flush=True)
             grid_thw_list = grid_thw
             grid_thw = np.array(grid_thw, dtype=np.int32)
         else:
-            print("Vit grid_thw -> tensor", flush=True)
-            # grid_thw = grid_thw.to("cpu")
+            grid_thw = grid_thw.to("cpu")
             grid_thw_list = grid_thw.tolist()
             grid_thw = grid_thw.numpy()
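
The hunk above drops the commented-out torch-based branch and the debug print calls, and normalizes grid_thw on the CPU. A self-contained sketch of the resulting logic (the helper name normalize_grid_thw is illustrative, not part of the patch):

import numpy as np
import torch

def normalize_grid_thw(grid_thw):
    # Accept either a Python list or a torch.Tensor; return a NumPy int32 array
    # (for host-side arithmetic) plus a plain nested list (for bookkeeping).
    if isinstance(grid_thw, list):
        grid_thw_list = grid_thw
        grid_thw = np.array(grid_thw, dtype=np.int32)
    else:
        grid_thw = grid_thw.to("cpu")
        grid_thw_list = grid_thw.tolist()
        grid_thw = grid_thw.numpy()
    return grid_thw, grid_thw_list

print(normalize_grid_thw([[1, 4, 6]]))
print(normalize_grid_thw(torch.tensor([[1, 4, 6]], dtype=torch.int32)))

Both calls return the same (array, list) pair, which is what lets the cu_seqlens computation below stay on NumPy regardless of the input type.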

@@ -221,15 +212,13 @@ def forward(
         rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device,
                                                    non_blocking=True)

-        cu_seqlens = torch.repeat_interleave(
-            grid_thw[:, 1] * grid_thw[:, 2],
-            grid_thw[:, 0]).cumsum(dim=0,
-                                   dtype=grid_thw.dtype
-                                   if torch.jit.is_tracing() else torch.int32)
-        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
+        cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2],
+                               grid_thw[:, 0]).cumsum(axis=0, dtype=np.int32)
+        cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
+        cu_seqlens = torch.from_numpy(cu_seqlens)

         hidden_states = hidden_states.unsqueeze(1)
-        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
         cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)

         deepstack_feature_lists = []
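
In the hunk above, cu_seqlens is now built on the host with NumPy and converted to a tensor only at the end, replacing torch.repeat_interleave and its tracing-dependent dtype. A worked example with a made-up grid_thw of two images whose (t, h, w) patch grids are (1, 4, 6) and (1, 2, 2):

import numpy as np
import torch

grid_thw = np.array([[1, 4, 6], [1, 2, 2]], dtype=np.int32)

# Tokens per frame (h * w), repeated t times per image, cumulatively summed,
# then prefixed with 0 to form cumulative sequence boundaries.
cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2],
                       grid_thw[:, 0]).cumsum(axis=0, dtype=np.int32)
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
cu_seqlens = torch.from_numpy(cu_seqlens)
print(cu_seqlens)  # tensor([ 0, 24, 28], dtype=torch.int32)
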
@@ -240,7 +229,6 @@ def forward(
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen,
-                seqlens=seqlens,
             )
             if layer_num in self.deepstack_visual_indexes:
                 deepstack_merger_idx = self.deepstack_visual_indexes.index(
@@ -255,6 +243,7 @@ def forward(
         return hidden_states


+# NOTE: These will be removed after vllm-ascend is aligned with vllm latest main.
 Qwen3_VisionBlock.forward = AscendQwen3_VisionBlock.forward
 Qwen3_VisionTransformer.__init__ = AscendQwen3_VisionTransformer.__init__
 Qwen3_VisionTransformer.rot_pos_emb = AscendQwen3_VisionTransformer.rot_pos_emb
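
The patch takes effect purely as a side effect of importing this module, which is why vllm_ascend/patch/worker/__init__.py above gains the new import line. A quick sanity check of that mechanism, assuming the upstream classes live in vllm.model_executor.models.qwen3_vl (the actual import header of this file is outside the hunks shown here):

# Module path of the upstream classes is assumed, not shown in this diff.
import vllm.model_executor.models.qwen3_vl as upstream

import vllm_ascend.patch.worker.patch_qwen3_vl  # noqa: F401  (patches on import)

# After the import, the upstream classes carry the Ascend implementations.
assert upstream.Qwen3_VisionBlock.forward.__qualname__ == "AscendQwen3_VisionBlock.forward"
assert upstream.Qwen3_VisionTransformer.rot_pos_emb.__qualname__ == "AscendQwen3_VisionTransformer.rot_pos_emb"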
