|
| 1 | +# |
| 2 | +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. |
| 3 | +# This file is a part of the vllm-ascend project. |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | +# |
| 17 | + |
| 18 | +from functools import partial |
| 19 | + |
| 20 | +import numpy as np |
| 21 | +import torch |
| 22 | +import torch.nn as nn |
| 23 | +from transformers.models.qwen3_vl.configuration_qwen3_vl import \ |
| 24 | + Qwen3VLVisionConfig |
| 25 | +from vllm.attention.backends.registry import AttentionBackendEnum |
| 26 | +from vllm.attention.layer import check_upstream_fa_availability |
| 27 | +from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY |
| 28 | +from vllm.model_executor.layers.quantization import QuantizationConfig |
| 29 | +from vllm.model_executor.layers.rotary_embedding import get_rope |
| 30 | +from vllm.model_executor.models.qwen3_vl import (Qwen3_VisionBlock, |
| 31 | + Qwen3_VisionPatchEmbed, |
| 32 | + Qwen3_VisionPatchMerger, |
| 33 | + Qwen3_VisionTransformer) |
| 34 | + |
| 35 | +from .vision import get_vit_attn_backend |
| 36 | + |
| 37 | + |
| 38 | +class AscendQwen3_VisionBlock(nn.Module): |
| 39 | + |
| 40 | + def forward( |
| 41 | + self, |
| 42 | + x: torch.Tensor, |
| 43 | + cu_seqlens: torch.Tensor, |
| 44 | + rotary_pos_emb_cos: torch.Tensor, |
| 45 | + rotary_pos_emb_sin: torch.Tensor, |
| 46 | + max_seqlen: torch.Tensor, # Only used for Flash Attention |
| 47 | + seqlens: torch.Tensor, # Only used for xFormers |
| 48 | + ) -> torch.Tensor: |
| 49 | + x = x + self.attn( |
| 50 | + self.norm1(x), |
| 51 | + cu_seqlens=cu_seqlens, |
| 52 | + rotary_pos_emb_cos=rotary_pos_emb_cos, |
| 53 | + rotary_pos_emb_sin=rotary_pos_emb_sin, |
| 54 | + max_seqlen=max_seqlen, |
| 55 | + seqlens=seqlens, |
| 56 | + ) |
| 57 | + |
| 58 | + x = x + self.mlp(self.norm2(x)) |
| 59 | + return x |
| 60 | + |
| 61 | + |
| 62 | +class AscendQwen3_VisionTransformer(nn.Module): |
| 63 | + |
| 64 | + def __init__( |
| 65 | + self, |
| 66 | + vision_config: Qwen3VLVisionConfig, |
| 67 | + norm_eps: float = 1e-6, |
| 68 | + quant_config: QuantizationConfig | None = None, |
| 69 | + prefix: str = "", |
| 70 | + use_data_parallel: bool = False, |
| 71 | + attn_backend_override: AttentionBackendEnum | None = None, |
| 72 | + ) -> None: |
| 73 | + super().__init__() |
| 74 | + self.hidden_size = vision_config.hidden_size |
| 75 | + self.num_heads = vision_config.num_heads |
| 76 | + self.num_position_embeddings = vision_config.num_position_embeddings |
| 77 | + self.patch_size = vision_config.patch_size |
| 78 | + self.spatial_merge_size = vision_config.spatial_merge_size |
| 79 | + self.spatial_merge_unit = self.spatial_merge_size**2 |
| 80 | + self.temporal_patch_size = vision_config.temporal_patch_size |
| 81 | + self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes |
| 82 | + self.use_data_parallel = use_data_parallel |
| 83 | + self.num_grid_per_side = int(self.num_position_embeddings**0.5) |
| 84 | + |
| 85 | + # NOTE: This is used for creating empty tensor for all_gather for |
| 86 | + # DP ViT. Here out_hidden_size is enlarged due to deepstack |
| 87 | + self.out_hidden_size = vision_config.out_hidden_size * ( |
| 88 | + 1 + len(self.deepstack_visual_indexes)) |
| 89 | + |
| 90 | + self.patch_embed = Qwen3_VisionPatchEmbed( |
| 91 | + patch_size=self.patch_size, |
| 92 | + temporal_patch_size=self.temporal_patch_size, |
| 93 | + in_channels=vision_config.in_channels, |
| 94 | + hidden_size=self.hidden_size, |
| 95 | + ) |
| 96 | + |
| 97 | + self.pos_embed = nn.Embedding(self.num_position_embeddings, |
| 98 | + self.hidden_size) |
| 99 | + |
| 100 | + norm_layer = partial(nn.LayerNorm, eps=norm_eps) |
| 101 | + head_dim = self.hidden_size // self.num_heads |
| 102 | + self.rotary_pos_emb = get_rope( |
| 103 | + head_size=head_dim, |
| 104 | + rotary_dim=head_dim // 2, |
| 105 | + max_position=8192, |
| 106 | + base=10000.0, |
| 107 | + is_neox_style=True, |
| 108 | + ) |
| 109 | + |
| 110 | + self.merger = Qwen3_VisionPatchMerger( |
| 111 | + d_model=vision_config.out_hidden_size, |
| 112 | + context_dim=self.hidden_size, |
| 113 | + norm_layer=norm_layer, |
| 114 | + spatial_merge_size=self.spatial_merge_size, |
| 115 | + quant_config=quant_config, |
| 116 | + prefix=f"{prefix}.merger", |
| 117 | + use_data_parallel=use_data_parallel, |
| 118 | + ) |
| 119 | + |
| 120 | + self.deepstack_merger_list = nn.ModuleList([ |
| 121 | + Qwen3_VisionPatchMerger( |
| 122 | + d_model=vision_config.out_hidden_size, |
| 123 | + context_dim=self.hidden_size, |
| 124 | + spatial_merge_size=self.spatial_merge_size, |
| 125 | + use_postshuffle_norm=True, |
| 126 | + norm_layer=norm_layer, |
| 127 | + quant_config=quant_config, |
| 128 | + prefix=f"{prefix}.deepstack_merger_list.{layer_idx}", |
| 129 | + use_data_parallel=use_data_parallel, |
| 130 | + ) for layer_idx in range(len(self.deepstack_visual_indexes)) |
| 131 | + ]) |
| 132 | + |
| 133 | + self.attn_backend = get_vit_attn_backend( |
| 134 | + head_size=head_dim, |
| 135 | + dtype=torch.get_default_dtype(), |
| 136 | + attn_backend_override=attn_backend_override, |
| 137 | + ) |
| 138 | + use_upstream_fa = False |
| 139 | + if (self.attn_backend != AttentionBackendEnum.FLASH_ATTN |
| 140 | + and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA |
| 141 | + and check_upstream_fa_availability(torch.get_default_dtype())): |
| 142 | + self.attn_backend = AttentionBackendEnum.FLASH_ATTN |
| 143 | + use_upstream_fa = True |
| 144 | + |
| 145 | + if self.attn_backend not in { |
| 146 | + AttentionBackendEnum.FLASH_ATTN, |
| 147 | + AttentionBackendEnum.TORCH_SDPA, |
| 148 | + AttentionBackendEnum.XFORMERS, |
| 149 | + AttentionBackendEnum.ROCM_AITER_FA, |
| 150 | + }: |
| 151 | + raise RuntimeError( |
| 152 | + f"Qwen3-VL does not support {self.attn_backend} backend now.") |
| 153 | + self.blocks = nn.ModuleList([ |
| 154 | + Qwen3_VisionBlock( |
| 155 | + dim=self.hidden_size, |
| 156 | + num_heads=self.num_heads, |
| 157 | + mlp_hidden_dim=vision_config.intermediate_size, |
| 158 | + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], |
| 159 | + norm_layer=norm_layer, |
| 160 | + quant_config=quant_config, |
| 161 | + prefix=f"{prefix}.blocks.{layer_idx}", |
| 162 | + use_data_parallel=use_data_parallel, |
| 163 | + attn_backend=self.attn_backend, |
| 164 | + use_upstream_fa=use_upstream_fa, |
| 165 | + ) for layer_idx in range(vision_config.depth) |
| 166 | + ]) |
| 167 | + |
| 168 | + def rot_pos_emb(self, grid_thw: list[list[int]]): |
| 169 | + max_grid_size = max(max(h, w) for _, h, w in grid_thw) |
| 170 | + pos_ids = [ |
| 171 | + self.rot_pos_ids(h, w, self.spatial_merge_size) if t == 1 else |
| 172 | + self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1) |
| 173 | + for t, h, w in grid_thw |
| 174 | + ] |
| 175 | + pos_ids = torch.cat(pos_ids, dim=0) |
| 176 | + |
| 177 | + # Use pre-computed cos_sin_cache from RotaryEmbedding |
| 178 | + cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size) |
| 179 | + |
| 180 | + cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2) |
| 181 | + cos_w = cos[pos_ids[:, 1]] |
| 182 | + sin_h = sin[pos_ids[:, 0]] |
| 183 | + sin_w = sin[pos_ids[:, 1]] |
| 184 | + |
| 185 | + cos_combined = torch.cat([cos_h, cos_w], dim=-1) |
| 186 | + sin_combined = torch.cat([sin_h, sin_w], dim=-1) |
| 187 | + |
| 188 | + return cos_combined, sin_combined |
| 189 | + |
| 190 | + def forward( |
| 191 | + self, |
| 192 | + x: torch.Tensor, |
| 193 | + grid_thw: torch.Tensor | list[list[int]], |
| 194 | + ) -> torch.Tensor: |
| 195 | + hidden_states = x.to(device=self.device, |
| 196 | + dtype=self.dtype, |
| 197 | + non_blocking=True) |
| 198 | + hidden_states = self.patch_embed(hidden_states) |
| 199 | + |
| 200 | + # if isinstance(grid_thw, list): |
| 201 | + # grid_thw_list = grid_thw |
| 202 | + # grid_thw = torch.tensor(grid_thw, dtype=torch.int32) |
| 203 | + # else: |
| 204 | + # grid_thw_list = grid_thw.tolist() |
| 205 | + if isinstance(grid_thw, list): |
| 206 | + print("Vit grid_thw -> list", flush=True) |
| 207 | + grid_thw_list = grid_thw |
| 208 | + grid_thw = np.array(grid_thw, dtype=np.int32) |
| 209 | + else: |
| 210 | + print("Vit grid_thw -> tensor", flush=True) |
| 211 | + # grid_thw = grid_thw.to("cpu") |
| 212 | + grid_thw_list = grid_thw.tolist() |
| 213 | + grid_thw = grid_thw.numpy() |
| 214 | + |
| 215 | + pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list) |
| 216 | + hidden_states = hidden_states + pos_embeds |
| 217 | + rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb( |
| 218 | + grid_thw_list) |
| 219 | + rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device, |
| 220 | + non_blocking=True) |
| 221 | + rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device, |
| 222 | + non_blocking=True) |
| 223 | + |
| 224 | + cu_seqlens = torch.repeat_interleave( |
| 225 | + grid_thw[:, 1] * grid_thw[:, 2], |
| 226 | + grid_thw[:, 0]).cumsum(dim=0, |
| 227 | + dtype=grid_thw.dtype |
| 228 | + if torch.jit.is_tracing() else torch.int32) |
| 229 | + cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens]) |
| 230 | + |
| 231 | + hidden_states = hidden_states.unsqueeze(1) |
| 232 | + max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) |
| 233 | + cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) |
| 234 | + |
| 235 | + deepstack_feature_lists = [] |
| 236 | + for layer_num, blk in enumerate(self.blocks): |
| 237 | + hidden_states = blk( |
| 238 | + hidden_states, |
| 239 | + cu_seqlens=cu_seqlens, |
| 240 | + rotary_pos_emb_cos=rotary_pos_emb_cos, |
| 241 | + rotary_pos_emb_sin=rotary_pos_emb_sin, |
| 242 | + max_seqlen=max_seqlen, |
| 243 | + seqlens=seqlens, |
| 244 | + ) |
| 245 | + if layer_num in self.deepstack_visual_indexes: |
| 246 | + deepstack_merger_idx = self.deepstack_visual_indexes.index( |
| 247 | + layer_num) |
| 248 | + deepstack_feature = self.deepstack_merger_list[ |
| 249 | + deepstack_merger_idx](hidden_states) |
| 250 | + deepstack_feature_lists.append(deepstack_feature) |
| 251 | + hidden_states = self.merger(hidden_states) |
| 252 | + hidden_states = torch.cat( |
| 253 | + [hidden_states] + deepstack_feature_lists, |
| 254 | + dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)] |
| 255 | + return hidden_states |
| 256 | + |
| 257 | + |
| 258 | +Qwen3_VisionBlock.forward = AscendQwen3_VisionBlock.forward |
| 259 | +Qwen3_VisionTransformer.__init__ = AscendQwen3_VisionTransformer.__init__ |
| 260 | +Qwen3_VisionTransformer.rot_pos_emb = AscendQwen3_VisionTransformer.rot_pos_emb |
| 261 | +Qwen3_VisionTransformer.forward = AscendQwen3_VisionTransformer.forward |
0 commit comments