diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 7eac94bdd7e..b1957fe8f04 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -2,14 +2,6 @@ def register_model(): - ModelRegistry.register_model( - "Qwen3VLMoeForConditionalGeneration", - "vllm_ascend.models.qwen3_vl:AscendQwen3VLMoeForConditionalGeneration") - - ModelRegistry.register_model( - "Qwen3VLForConditionalGeneration", - "vllm_ascend.models.qwen3_vl:AscendQwen3VLForConditionalGeneration") - # There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization # to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM. ModelRegistry.register_model( diff --git a/vllm_ascend/models/qwen3_vl.py b/vllm_ascend/models/qwen3_vl.py deleted file mode 100644 index c79e71e7197..00000000000 --- a/vllm_ascend/models/qwen3_vl.py +++ /dev/null @@ -1,264 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial -from typing import Callable, Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F - -try: - from transformers.models.qwen3_vl.configuration_qwen3_vl import \ - Qwen3VLConfig - from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import \ - Qwen3VLMoeConfig -except ImportError: - pass -from vllm.config import VllmConfig -from vllm.distributed import utils as dist_utils -from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionAttention - -try: - from vllm.model_executor.models.qwen3_vl import ( - Qwen3_VisionBlock, Qwen3_VisionPatchEmbed, Qwen3_VisionTransformer, - Qwen3VLDummyInputsBuilder, Qwen3VLForConditionalGeneration, - Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo) - from vllm.model_executor.models.qwen3_vl_moe import ( - Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeProcessingInfo) -except ImportError: - Qwen3_VisionBlock = object - Qwen3_VisionPatchEmbed = object - Qwen3_VisionTransformer = object - Qwen3VLDummyInputsBuilder = object - Qwen3VLForConditionalGeneration = object - Qwen3VLMultiModalProcessor = object - Qwen3VLProcessingInfo = object - Qwen3VLMoeForConditionalGeneration = object - Qwen3VLMoeProcessingInfo = object -from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix -from vllm.multimodal import MULTIMODAL_REGISTRY - - -class AscendQwen3_VisionPatchEmbed(Qwen3_VisionPatchEmbed): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x.matmul( - self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1)) - x = x + self.proj.bias - return x - - -class AscendQwen3_VisionBlock(Qwen3_VisionBlock): - - def __init__( - self, - dim: int, - num_heads: int, 
- mlp_hidden_dim: int, - act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - norm_layer: Optional[Callable[[int], nn.Module]] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False, - ) -> None: - super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer, - quant_config, prefix, use_data_parallel) - self.attn = Qwen2_5_VisionAttention(embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") - - def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, - cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: - x = x + self.attn( - self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin) - - x = x + self.mlp(self.norm2(x)) - return x - - -class AscendQwen3_VisionTransformer(Qwen3_VisionTransformer): - - def __init__( - self, - vision_config, - norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False, - ) -> None: - super().__init__(vision_config, norm_eps, quant_config, prefix, - use_data_parallel) - norm_layer = partial(nn.LayerNorm, eps=norm_eps) - self.patch_embed = AscendQwen3_VisionPatchEmbed( - patch_size=self.patch_size, - temporal_patch_size=self.temporal_patch_size, - in_channels=vision_config.in_channels, - hidden_size=self.hidden_size, - ) - self.blocks = nn.ModuleList([ - AscendQwen3_VisionBlock( - dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.intermediate_size, - act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") - for layer_idx in range(vision_config.depth) - ]) - self.hidden_size_per_attention_head = dist_utils.divide( - self.hidden_size, self.num_heads) - - def cal_cos_sin(self, rotary_pos_emb): - cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2] - sin = rotary_pos_emb.sin() - cos_new = torch.cat((cos, cos), dim=-1) - sin_new = torch.cat((sin, sin), dim=-1) - cos_new = cos_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - sin_new = sin_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - return cos_new, sin_new - - def forward( - self, - x: torch.Tensor, - grid_thw: list[list[int]], - ) -> torch.Tensor: - hidden_states = x.to(device=self.device, dtype=self.dtype) - hidden_states = self.patch_embed(hidden_states) - - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) - hidden_states = hidden_states + pos_embeds - rotary_pos_emb = self.rot_pos_emb(grid_thw) - grid_thw_tensor = torch.tensor(grid_thw, - device=self.device, - dtype=torch.int32) - cu_seqlens = torch.repeat_interleave( - grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], - grid_thw_tensor[:, 0]).cpu().to(torch.int32) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - - hidden_states = hidden_states.unsqueeze(1) - rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) - - cos, sin = self.cal_cos_sin(rotary_pos_emb) - - deepstack_feature_lists = [] - for layer_num, blk in enumerate(self.blocks): - hidden_states = blk(hidden_states, - cu_seqlens=cu_seqlens, - cos=cos, - sin=sin) - if layer_num in self.deepstack_visual_indexes: - deepstack_merger_idx = self.deepstack_visual_indexes.index( - layer_num) - deepstack_feature = self.deepstack_merger_list[ - deepstack_merger_idx](hidden_states) - deepstack_feature_lists.append(deepstack_feature) - hidden_states = self.merger(hidden_states) - hidden_states = torch.cat( - [hidden_states] + 
deepstack_feature_lists, - dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)] - return hidden_states - - -@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, - info=Qwen3VLProcessingInfo, - dummy_inputs=Qwen3VLDummyInputsBuilder) -class AscendQwen3VLForConditionalGeneration(Qwen3VLForConditionalGeneration): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - supports_encoder_tp_data = True - - # To ensure correct weight loading and mapping. - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - "model.visual.": "visual.", - "lm_head.": "language_model.lm_head.", - "model.language_model.": "language_model.model.", - }) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - config: Qwen3VLConfig = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.visual = AscendQwen3_VisionTransformer( - config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel) - - -@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, - info=Qwen3VLMoeProcessingInfo, - dummy_inputs=Qwen3VLDummyInputsBuilder) -class AscendQwen3VLMoeForConditionalGeneration( - Qwen3VLMoeForConditionalGeneration): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - supports_encoder_tp_data = True - - # To ensure correct weight loading and mapping. - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - "model.visual.": "visual.", - "lm_head.": "language_model.lm_head.", - "model.language_model.": "language_model.model.", - }) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - multimodal_config = vllm_config.model_config.multimodal_config - self.multimodal_config = multimodal_config - self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" - self.visual = AscendQwen3_VisionTransformer( - config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel, - ) diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index 9e37f1ecaa2..0d1dd559880 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -29,4 +29,5 @@ import vllm_ascend.patch.worker.patch_minicpm # noqa import vllm_ascend.patch.worker.patch_qwen2_5_vl # noqa import vllm_ascend.patch.worker.patch_qwen2_5_omni # noqa +import vllm_ascend.patch.worker.patch_qwen3_vl # noqa import vllm_ascend.patch.worker.patch_rope # noqa diff --git a/vllm_ascend/patch/worker/patch_qwen2_5_vl.py b/vllm_ascend/patch/worker/patch_qwen2_5_vl.py index 464c62830b6..bb22acf3f17 100644 --- a/vllm_ascend/patch/worker/patch_qwen2_5_vl.py +++ b/vllm_ascend/patch/worker/patch_qwen2_5_vl.py @@ -65,7 +65,7 @@ def forward( rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, max_seqlen: torch.Tensor, - seqlens: torch.Tensor, + seqlens: torch.Tensor = None, ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) diff --git 
a/vllm_ascend/patch/worker/patch_qwen3_vl.py b/vllm_ascend/patch/worker/patch_qwen3_vl.py new file mode 100644 index 00000000000..1b80bbdcfa1 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_qwen3_vl.py @@ -0,0 +1,251 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from functools import partial + +import numpy as np +import torch +import torch.nn as nn +from transformers.models.qwen3_vl.configuration_qwen3_vl import \ + Qwen3VLVisionConfig +from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.attention.layer import check_upstream_fa_availability +from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.models.qwen3_vl import (Qwen3_VisionBlock, + Qwen3_VisionPatchEmbed, + Qwen3_VisionPatchMerger, + Qwen3_VisionTransformer) +from vllm.model_executor.models.vision import get_vit_attn_backend + + +class AscendQwen3_VisionBlock(nn.Module): + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, + max_seqlen: torch.Tensor, # Only used for Flash Attention + ) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + max_seqlen=max_seqlen, + ) + + x = x + self.mlp(self.norm2(x)) + return x + + +class AscendQwen3_VisionTransformer(nn.Module): + + def __init__( + self, + vision_config: Qwen3VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + use_data_parallel: bool = False, + attn_backend_override: AttentionBackendEnum | None = None, + ) -> None: + nn.Module.__init__(self) + + self.hidden_size = vision_config.hidden_size + self.num_heads = vision_config.num_heads + self.num_position_embeddings = vision_config.num_position_embeddings + self.patch_size = vision_config.patch_size + self.spatial_merge_size = vision_config.spatial_merge_size + self.spatial_merge_unit = self.spatial_merge_size**2 + self.temporal_patch_size = vision_config.temporal_patch_size + self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes + self.use_data_parallel = use_data_parallel + self.num_grid_per_side = int(self.num_position_embeddings**0.5) + + # NOTE: This is used for creating empty tensor for all_gather for + # DP ViT. 
Here out_hidden_size is enlarged due to deepstack + self.out_hidden_size = vision_config.out_hidden_size * ( + 1 + len(self.deepstack_visual_indexes)) + + self.patch_embed = Qwen3_VisionPatchEmbed( + patch_size=self.patch_size, + temporal_patch_size=self.temporal_patch_size, + in_channels=vision_config.in_channels, + hidden_size=self.hidden_size, + ) + + self.pos_embed = nn.Embedding(self.num_position_embeddings, + self.hidden_size) + + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = get_rope( + head_size=head_dim, + rotary_dim=head_dim // 2, + max_position=8192, + base=10000.0, + is_neox_style=True, + ) + + self.merger = Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, + quant_config=quant_config, + prefix=f"{prefix}.merger", + use_data_parallel=use_data_parallel, + ) + + self.deepstack_merger_list = nn.ModuleList([ + Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=True, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.deepstack_merger_list.{layer_idx}", + use_data_parallel=use_data_parallel, + ) for layer_idx in range(len(self.deepstack_visual_indexes)) + ]) + + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, + dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, + ) + use_upstream_fa = False + if (self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA + and check_upstream_fa_availability(torch.get_default_dtype())): + self.attn_backend = AttentionBackendEnum.FLASH_ATTN + use_upstream_fa = True + + if self.attn_backend not in { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, + }: + raise RuntimeError( + f"Qwen3-VL does not support {self.attn_backend} backend now.") + self.blocks = nn.ModuleList([ + Qwen3_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel, + attn_backend=self.attn_backend, + use_upstream_fa=use_upstream_fa, + ) for layer_idx in range(vision_config.depth) + ]) + + def rot_pos_emb(self, grid_thw: list[list[int]]): + max_grid_size = max(max(h, w) for _, h, w in grid_thw) + pos_ids = [ + self.rot_pos_ids(h, w, self.spatial_merge_size) if t == 1 else + self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1) + for t, h, w in grid_thw + ] + pos_ids = torch.cat(pos_ids, dim=0) + + # Use pre-computed cos_sin_cache from RotaryEmbedding + cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size) + + # (num_tokens, rotary_dim // 2) + cos_h = cos[pos_ids[:, 0]] # type: ignore + cos_w = cos[pos_ids[:, 1]] # type: ignore + sin_h = sin[pos_ids[:, 0]] # type: ignore + sin_w = sin[pos_ids[:, 1]] # type: ignore + + cos_combined = torch.cat([cos_h, cos_w], dim=-1) + sin_combined = torch.cat([sin_h, sin_w], dim=-1) + + return cos_combined, sin_combined + + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor | list[list[int]], + ) -> torch.Tensor: + hidden_states = x.to(device=self.device, + 
dtype=self.dtype, + non_blocking=True) + hidden_states = self.patch_embed(hidden_states) + + if isinstance(grid_thw, list): + grid_thw_list = grid_thw + grid_thw = np.array(grid_thw, dtype=np.int32) + else: + grid_thw = grid_thw.to("cpu") + grid_thw_list = grid_thw.tolist() + grid_thw = grid_thw.numpy() + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list) + hidden_states = hidden_states + pos_embeds + rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb( + grid_thw_list) + rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device, + non_blocking=True) + rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device, + non_blocking=True) + + cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).cumsum(axis=0, dtype=np.int32) + cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens]) + cu_seqlens = torch.from_numpy(cu_seqlens) + + hidden_states = hidden_states.unsqueeze(1) + max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens) + cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + max_seqlen=max_seqlen, + ) + if layer_num in self.deepstack_visual_indexes: + deepstack_merger_idx = self.deepstack_visual_indexes.index( + layer_num) + deepstack_feature = self.deepstack_merger_list[ + deepstack_merger_idx](hidden_states) + deepstack_feature_lists.append(deepstack_feature) + hidden_states = self.merger(hidden_states) + hidden_states = torch.cat( + [hidden_states] + deepstack_feature_lists, + dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)] + return hidden_states + + +# NOTE: These will be removed after vllm-ascend is aligned with vllm latest main. +Qwen3_VisionBlock.forward = AscendQwen3_VisionBlock.forward +Qwen3_VisionTransformer.__init__ = AscendQwen3_VisionTransformer.__init__ +Qwen3_VisionTransformer.rot_pos_emb = AscendQwen3_VisionTransformer.rot_pos_emb +Qwen3_VisionTransformer.forward = AscendQwen3_VisionTransformer.forward
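Aside (not part of the patch): a minimal, self-contained sketch of how the patched forward() above builds cu_seqlens on the host with numpy before moving it to the device. The grid_thw values below are made-up illustrative inputs, not taken from any model config.

    import numpy as np
    import torch

    # Two visual inputs: (t, h, w) = (1, 4, 6) and (2, 2, 2).
    grid_thw = np.array([[1, 4, 6], [2, 2, 2]], dtype=np.int32)

    # Each input contributes t segments of h * w patch tokens; the cumulative sum
    # gives the attention segment boundaries, with a leading zero prepended.
    cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2],
                           grid_thw[:, 0]).cumsum(axis=0, dtype=np.int32)
    cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
    print(torch.from_numpy(cu_seqlens))  # tensor([ 0, 24, 28, 32], dtype=torch.int32)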
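Aside (not part of the patch): the assignments at the end of patch_qwen3_vl.py rely on plain class-attribute replacement, so importing the module (as done in patch/worker/__init__.py) is enough to redirect the upstream vLLM classes to the Ascend implementations. A small sketch of that pattern, using hypothetical DemoBlock / patched_forward names rather than the real vLLM classes:

    import torch
    import torch.nn as nn

    class DemoBlock(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x  # stand-in for the upstream behaviour

    def patched_forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1  # stand-in for the Ascend-specific behaviour

    # Assigning to the class attribute rewires existing and future instances alike,
    # which is why importing the patch module activates the new behaviour without
    # any subclass registration.
    DemoBlock.forward = patched_forward
    assert DemoBlock()(torch.zeros(1)).item() == 1.0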