From 6136c44ba7684428a9ad99f7310192f5d72231b1 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Thu, 6 Nov 2025 11:17:03 +0800 Subject: [PATCH 1/3] update scripts --- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 1256 +++++++---------- 1 file changed, 513 insertions(+), 743 deletions(-) diff --git a/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 49a0c0efe1..7672daab41 100644 --- a/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -27,121 +27,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union -from transformers import Qwen2_5_VLConfig -from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig +from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( + Qwen2_5_VLConfig, + Qwen2_5_VLTextConfig, + Qwen2_5_VLVisionConfig, +) import mindspore as ms import mindspore.mint as mint import mindspore.mint.nn.functional as F import mindspore.nn as nn import mindspore.ops as ops -from mindspore import Parameter, Tensor - -from mindone.models.utils import normal_, zeros_ -from mindone.transformers.activations import ACT2FN -from mindone.transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache -from mindone.transformers.generation import GenerationMixin -from mindone.transformers.modeling_attn_mask_utils import dtype_to_min -from mindone.transformers.modeling_flash_attention_utils import FlashAttentionKwargs -from mindone.transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput -from mindone.transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS -from mindone.transformers.modeling_utils import MSPreTrainedModel -from mindone.transformers.processing_utils import Unpack -from mindone.transformers.utils import TransformersKwargs, logging - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "Qwen2_5_VLConfig" +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, MSPreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, logging -def _flash_attention_forward( - query_states: Tensor, - key_states: Tensor, - value_states: Tensor, - attention_mask: Tensor, - query_length: int, - is_causal: bool, - dropout: float = 0.0, - position_ids: Optional[Tensor] = None, - softmax_scale: Optional[float] = None, - sliding_window: Optional[int] = None, - use_top_left_mask: bool = False, - softcap: Optional[float] = None, - deterministic: bool = None, - cu_seq_lens_q: Optional[Tensor] = None, - cu_seq_lens_k: Optional[Tensor] = None, - max_length_q: Optional[int] = None, - max_length_k: Optional[int] = None, - target_dtype: Optional[ms.Type] = None, - **kwargs, -): - bsz, _, num_heads, _ = query_states.shape - if is_causal and query_length > 1: - causal_mask = mint.triu(mint.ones((bsz, 1, query_length, key_states.shape[1]), dtype=ms.bool_), diagonal=1) - else: - causal_mask = None - - if attention_mask is not None: - attention_mask = ~attention_mask[:, None, None, :].to(ms.bool_) - if causal_mask is not None: - attention_mask = attention_mask | causal_mask - else: - attention_mask = mint.tile(attention_mask, (1, 1, query_length, 1)) - else: - attention_mask = causal_mask - - if softmax_scale is None: - scalar_value = 1 / math.sqrt(query_states.shape[-1]) - else: - scalar_value = softmax_scale - - attn_output = ops.flash_attention_score( - query_states, - key_states, - value_states, - num_heads, - attn_mask=attention_mask, - keep_prob=1 - dropout, - scalar_value=scalar_value, - input_layout="BSND", - ) - - return attn_output - - -def rotate_half_flashatt(x, interleaved=False): - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return mint.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return mint.stack((-x2, x1), dim=-1).flatten(-2) - - -def apply_rotary_emb_flashatt(x, cos, sin, interleaved=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - if not interleaved: - cos = mint.tile(cos[:, None, :], (1, 1, 2)) - sin = mint.tile(sin[:, None, :], (1, 1, 2)) - else: - cos = mint.repeat_interleave(cos[:, None, :], (1, 1, 2)) - sin = mint.repeat_interleave(sin[:, None, :], (1, 1, 2)) - return mint.cat( - [x[..., :ro_dim] * cos + rotate_half_flashatt(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], dim=-1 - ) +logger = logging.get_logger(__name__) class Qwen2_5_VLMLP(nn.Cell): - def __init__(self, config, bias: bool = False) -> None: + def __init__(self, config, bias: bool = False): super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size @@ -168,10 +85,10 @@ def __init__( self.in_channels = in_channels self.embed_dim = embed_dim - kernel_size = (temporal_patch_size, patch_size, patch_size) + kernel_size = [temporal_patch_size, patch_size, patch_size] self.proj = mint.nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) - def construct(self, hidden_states: Tensor) -> Tensor: + def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: target_dtype = self.proj.weight.dtype hidden_states = hidden_states.view( -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size @@ -183,10 +100,10 @@ def construct(self, hidden_states: Tensor) -> Tensor: class Qwen2_5_VisionRotaryEmbedding(nn.Cell): def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() - inv_freq = 1.0 / (theta ** (mint.arange(0, dim, 2, dtype=ms.float32) / dim)) + inv_freq = 1.0 / (theta ** (mint.arange(0, dim, 2, dtype=ms.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def construct(self, seqlen: int) -> Tensor: + def construct(self, seqlen: int) -> ms.Tensor: seq = mint.arange(seqlen, dtype=self.inv_freq.dtype) freqs = mint.outer(seq, self.inv_freq) return freqs @@ -198,7 +115,7 @@ def __init__(self, hidden_size, eps=1e-6): Qwen2RMSNorm is equivalent to T5LayerNorm """ super().__init__() - self.weight = Parameter(mint.ones(hidden_size)) + self.weight = ms.Parameter(mint.ones(hidden_size)) self.variance_epsilon = eps def construct(self, hidden_states): @@ -218,68 +135,16 @@ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> N self.hidden_size = context_dim * (spatial_merge_size**2) self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) self.mlp = nn.SequentialCell( - mint.nn.Linear(self.hidden_size, self.hidden_size), mint.nn.GELU(), mint.nn.Linear(self.hidden_size, dim) + mint.nn.Linear(self.hidden_size, self.hidden_size), + mint.nn.GELU(), + mint.nn.Linear(self.hidden_size, dim), ) - def construct(self, x: Tensor) -> Tensor: + def construct(self, x: ms.Tensor) -> ms.Tensor: x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) return x -def apply_rotary_pos_emb_flashatt(q: Tensor, k: Tensor, cos: Tensor, sin: Tensor) -> Tuple[Tensor, Tensor]: - cos = cos.chunk(2, dim=-1)[0].contiguous() - sin = sin.chunk(2, dim=-1)[0].contiguous() - q_embed = apply_rotary_emb_flashatt(q.float(), cos.float(), sin.float()).type_as(q) - k_embed = apply_rotary_emb_flashatt(k.float(), cos.float(), sin.float()).type_as(k) - return q_embed, k_embed - - -class Qwen2_5_VLVisionFlashAttention2(nn.Cell): - def __init__(self, dim: int, num_heads: int = 16) -> None: - super().__init__() - self.num_heads = num_heads - self.qkv = mint.nn.Linear(dim, dim * 3, bias=True) - self.proj = mint.nn.Linear(dim, dim) - - def construct( - self, - hidden_states: Tensor, - cu_seqlens: Tensor, - rotary_pos_emb: Optional[Tensor] = None, - position_embeddings: Optional[Tuple[Tensor, Tensor]] = None, - ) -> Tensor: - seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " - "removed and `position_embeddings` will be mandatory." - ) - emb = mint.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - cos = emb.cos() - sin = emb.sin() - else: - cos, sin = position_embeddings - q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin) - q = q.squeeze(0) - k = k.squeeze(0) - - attn_output = ops.flash_attention_score( - q, - k, - v, - self.num_heads, - actual_seq_qlen=cu_seqlens, - actual_seq_kvlen=cu_seqlens, - scalar_value=1 / math.sqrt(q.shape[-1]), - input_layout="TND", - ).reshape(seq_length, -1) - attn_output = self.proj(attn_output) - return attn_output - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -287,7 +152,9 @@ def rotate_half(x): return mint.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb_vision(q: Tensor, k: Tensor, cos: Tensor, sin: Tensor) -> Tuple[Tensor, Tensor]: +def apply_rotary_pos_emb_vision( + q: ms.Tensor, k: ms.Tensor, cos: ms.Tensor, sin: ms.Tensor +) -> tuple[ms.Tensor, ms.Tensor]: orig_q_dtype = q.dtype orig_k_dtype = k.dtype q, k = q.float(), k.float() @@ -299,23 +166,70 @@ def apply_rotary_pos_emb_vision(q: Tensor, k: Tensor, cos: Tensor, sin: Tensor) return q_embed, k_embed +def repeat_kv(hidden_states: ms.Tensor, n_rep: int) -> ms.Tensor: + """ + This is the equivalent of mint.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand((batch, num_key_value_heads, n_rep, slen, head_dim)) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Cell, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attention_mask: Optional[ms.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = mint.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = mint.nn.functional.softmax(attn_weights, dim=-1, dtype=ms.float32).to(query.dtype) + attn_weights = mint.nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = mint.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Qwen2_5_VLVisionAttention(nn.Cell): - def __init__(self, dim: int, num_heads: int = 16) -> None: + def __init__(self, config: Qwen2_5_VLVisionConfig) -> None: super().__init__() - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.qkv = mint.nn.Linear(dim, dim * 3, bias=True) - self.proj = mint.nn.Linear(dim, dim) + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention + self.qkv = mint.nn.Linear(self.dim, self.dim * 3, bias=True) + self.proj = mint.nn.Linear(self.dim, self.dim) + self.scaling = self.head_dim**-0.5 + self.config = config + self.attention_dropout = 0.0 + self.is_causal = False def construct( self, - hidden_states: Tensor, - cu_seqlens: Tensor, - rotary_pos_emb: Optional[Tensor] = None, - position_embeddings: Optional[Tuple[Tensor, Tensor]] = None, - ) -> Tensor: + hidden_states: ms.Tensor, + cu_seqlens: ms.Tensor, + rotary_pos_emb: Optional[ms.Tensor] = None, + position_embeddings: Optional[tuple[ms.Tensor, ms.Tensor]] = None, + **kwargs, + ) -> ms.Tensor: seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + query_states, key_states, value_states = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " @@ -328,88 +242,104 @@ def construct( sin = emb.sin() else: cos, sin = position_embeddings - q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) - - attention_mask = mint.full([1, seq_length, seq_length], dtype_to_min(q.dtype).item(), dtype=q.dtype) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - attn_weights = mint.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) - attn_weights = attn_weights + attention_mask - attn_weights = F.softmax(attn_weights, dim=-1, dtype=ms.float32).to(q.dtype) - attn_output = mint.matmul(attn_weights, v) - attn_output = attn_output.transpose(0, 1) - attn_output = attn_output.reshape(seq_length, -1) - attn_output = self.proj(attn_output) - return attn_output + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) -QWEN2_5_VL_VISION_ATTENTION_CLASSES = { - "eager": Qwen2_5_VLVisionAttention, - "flash_attention_2": Qwen2_5_VLVisionFlashAttention2, -} + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + if self.config._attn_implementation == "flash_attention_2": + # Flash Attention 2: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + mint.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = mint.cat(attn_outputs, dim=1) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.proj(attn_output) + return attn_output -class Qwen2_5_VLVisionBlock(nn.Cell): +class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): def __init__(self, config, attn_implementation: str = "flash_attention_2") -> None: super().__init__() self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) - self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation]( - config.hidden_size, num_heads=config.num_heads - ) + self.attn = Qwen2_5_VLVisionAttention(config=config) self.mlp = Qwen2_5_VLMLP(config, bias=True) def construct( self, - hidden_states: Tensor, - cu_seqlens: Tensor, - rotary_pos_emb: Optional[Tensor] = None, - position_embeddings: Optional[Tuple[Tensor, Tensor]] = None, - ) -> Tensor: + hidden_states: ms.Tensor, + cu_seqlens: ms.Tensor, + rotary_pos_emb: Optional[ms.Tensor] = None, + position_embeddings: Optional[tuple[ms.Tensor, ms.Tensor]] = None, + **kwargs, + ) -> ms.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states class Qwen2_5_VLPreTrainedModel(MSPreTrainedModel): - config_class = Qwen2_5_VLConfig + config: Qwen2_5_VLConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = False - _supports_cache_class = True - _supports_static_cache = False - - def _init_weights(self, module): - # if hasattr(self.config, "initializer_range"): - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) - if isinstance(module, (mint.nn.Linear, mint.nn.Conv3d)): - normal_(module.weight, mean=0.0, std=std) - if module.bias is not None: - zeros_(module.bias) - elif isinstance(module, mint.nn.Embedding): - normal_(module.weight, mean=0.0, std=std) - if module.padding_idx is not None: - module.weight[module.padding_idx] = 0 + + _can_compile_fullgraph = False + _supports_attention_backend = True class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): - config_class = Qwen2_5_VLVisionConfig + config: Qwen2_5_VLVisionConfig _no_split_modules = ["Qwen2_5_VLVisionBlock"] def __init__(self, config, *inputs, **kwargs) -> None: @@ -430,11 +360,11 @@ def __init__(self, config, *inputs, **kwargs) -> None: head_dim = config.hidden_size // config.num_heads self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) - self.blocks = nn.CellList( - [Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] - ) + self.blocks = nn.CellList([Qwen2_5_VLVisionBlock(config) for _ in range(config.depth)]) self.merger = Qwen2_5_VLPatchMerger( - dim=config.out_hidden_size, context_dim=config.hidden_size, spatial_merge_size=config.spatial_merge_size + dim=config.out_hidden_size, + context_dim=config.hidden_size, + spatial_merge_size=config.spatial_merge_size, ) self.gradient_checkpointing = False @@ -484,10 +414,17 @@ def get_window_index(self, grid_thw): num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) index_padded = index_padded.reshape( - grid_t, num_windows_h, vit_merger_window_size, num_windows_w, vit_merger_window_size + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, ) index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( - grid_t, num_windows_h * num_windows_w, vit_merger_window_size, vit_merger_window_size + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, ) seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) index_padded = index_padded.reshape(-1) @@ -500,21 +437,24 @@ def get_window_index(self, grid_thw): return window_index, cu_window_seqlens - def construct(self, hidden_states: Tensor, grid_thw: Tensor) -> Tensor: + def construct(self, hidden_states: ms.Tensor, grid_thw: ms.Tensor, **kwargs) -> ms.Tensor: """ Args: - hidden_states (`Tensor` of shape `(seq_len, hidden_size)`): + hidden_states (`ms.Tensor` of shape `(seq_len, hidden_size)`): The final hidden states of the model. - grid_thw (`Tensor` of shape `(num_images_or_videos, 3)`): + grid_thw (`ms.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. Returns: - `Tensor`: hidden_states. + `ms.Tensor`: hidden_states. """ hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rot_pos_emb(grid_thw) window_index, cu_window_seqlens = self.get_window_index(grid_thw) - cu_window_seqlens = ms.tensor(cu_window_seqlens, dtype=ms.int32) + cu_window_seqlens = ms.tensor( + cu_window_seqlens, + dtype=ms.int32, + ) cu_window_seqlens = mint.unique_consecutive(cu_window_seqlens) seq_len, _ = hidden_states.shape @@ -528,7 +468,8 @@ def construct(self, hidden_states: Tensor, grid_thw: Tensor) -> Tensor: position_embeddings = (emb.cos(), emb.sin()) cu_seqlens = mint.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, dtype=ms.int32 + dim=0, + dtype=ms.int32, ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) @@ -537,10 +478,13 @@ def construct(self, hidden_states: Tensor, grid_thw: Tensor) -> Tensor: cu_seqlens_now = cu_seqlens else: cu_seqlens_now = cu_window_seqlens - if self.gradient_checkpointing and self.training: - hidden_states = ms.recompute(blk, hidden_states, cu_seqlens_now, None, position_embeddings) - else: - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings) + + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = self.merger(hidden_states) reverse_indices = mint.argsort(window_index) @@ -570,7 +514,7 @@ class Qwen2_5_VLModelOutputWithPast(ModelOutput): class Qwen2_5_VLRotaryEmbedding(nn.Cell): - def __init__(self, config: Qwen2_5_VLConfig): + def __init__(self, config: Qwen2_5_VLTextConfig): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: @@ -587,39 +531,17 @@ def __init__(self, config: Qwen2_5_VLConfig): self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq - def _dynamic_frequency_update(self, position_ids): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = mint.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, seq_len=seq_len, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def construct(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids) - - # Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for the grids + # In contrast to other models, Qwen2_5_VL has different position ids for the grids # So we expand the inv_freq to shape (3, ...) inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand((3, position_ids.shape[1], -1, 1)) position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) emb = mint.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -653,11 +575,11 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim difference with modern LLMs. Args: - q (`Tensor`): The query tensor. - k (`Tensor`): The key tensor. - cos (`Tensor`): The cosine part of the rotary embedding. - sin (`Tensor`): The sine part of the rotary embedding. - position_ids (`Tensor`): + q (`ms.Tensor`): The query tensor. + k (`ms.Tensor`): The key tensor. + cos (`ms.Tensor`): The cosine part of the rotary embedding. + sin (`ms.Tensor`): The sine part of the rotary embedding. + position_ids (`ms.Tensor`): The position indices of the tokens corresponding to the query and key tensors. For example, this can be used to pass offsetted position ids when working with a KV-cache. mrope_section(`List(int)`): @@ -670,7 +592,7 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: - `tuple(Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + `tuple(ms.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ mrope_section = mrope_section * 2 cos = mint.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim) @@ -681,25 +603,13 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim return q_embed, k_embed -def repeat_kv(hidden_states: Tensor, n_rep: int) -> Tensor: - """ - This is the equivalent of mint.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand((batch, num_key_value_heads, n_rep, slen, head_dim)) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - class Qwen2_5_VLAttention(nn.Cell): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer and "Generating Long Sequences with Sparse Transformers". """ - def __init__(self, config: Qwen2_5_VLConfig, layer_idx: Optional[int] = None): + def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx @@ -718,6 +628,7 @@ def __init__(self, config: Qwen2_5_VLConfig, layer_idx: Optional[int] = None): self.is_causal = True self.attention_dropout = config.attention_dropout self.rope_scaling = config.rope_scaling + self.scaling = self.head_dim**-0.5 if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( @@ -728,96 +639,22 @@ def __init__(self, config: Qwen2_5_VLConfig, layer_idx: Optional[int] = None): self.k_proj = mint.nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.v_proj = mint.nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.o_proj = mint.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) def construct( self, - hidden_states: Tensor, - attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[Tensor] = None, - position_embeddings: Optional[Tensor] = None, # necessary, but kept here for BC - ) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]: - bsz, q_len, _ = hidden_states.shape - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = mint.unbind(position_embeddings) - query_states, key_states = apply_multimodal_rotary_pos_emb( - query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] - ) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = mint.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # Fix precision issues in Qwen2-VL float16 inference - # Replace inf values with zeros in attention weights to prevent NaN propagation - if query_states.dtype == ms.float16: - attn_weights = mint.where(mint.isinf(attn_weights), mint.zeros_like(attn_weights), attn_weights) - - # upcast attention to fp32 - attn_weights = F.softmax(attn_weights, dim=-1, dtype=ms.float32).to(query_states.dtype) - attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = mint.matmul(attn_weights, value_states) - - if attn_output.shape != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.shape}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention): - """ - Qwen2_5_VL flash attention module, following Qwen2_5_VL attention module. This module inherits from `Qwen2_5_VLAttention` - as the weights of the module stays untouched. The only required change would be on the construct pass - where it needs to correctly call the public API of flash attention and deal with padding tokens - in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom - config.max_window_layers layers. - """ - - def construct( - self, - hidden_states: Tensor, - attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: Optional[Tensor] = None, - position_embeddings: Optional[Tuple[Tensor, Tensor]] = None, # necessary, but kept here for BC - ): + cache_position: Optional[ms.Tensor] = None, + position_embeddings: Optional[ms.Tensor] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[ms.Tensor, Optional[ms.Tensor], Optional[tuple[ms.Tensor]]]: bsz, q_len, _ = hidden_states.shape query_states = self.q_proj(hidden_states) @@ -828,8 +665,7 @@ def construct( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - # Because the input can be padded, the absolute sequence length depends on the max position id. - cos, sin = mint.unbind(position_embeddings) + cos, sin = position_embeddings query_states, key_states = apply_multimodal_rotary_pos_emb( query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] ) @@ -838,72 +674,30 @@ def construct( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == ms.float32: - target_dtype = self.q_proj.weight.dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - if ( - self.config.use_sliding_window - and getattr(self.config, "sliding_window", None) is not None - and self.layer_idx >= self.config.max_window_layers - ): - sliding_window = self.config.sliding_window - else: - sliding_window = None - - assert sliding_window is None, "sliding_window is not supported yet." - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=dropout_rate, - sliding_window=sliding_window, - is_causal=self.is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + position_ids=position_ids, # pass positions for FA2 + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) + return attn_output, attn_weights - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - -QWEN2_5_VL_ATTENTION_CLASSES = { - "eager": Qwen2_5_VLAttention, - "flash_attention_2": Qwen2_5_VLFlashAttention2, -} - - -class Qwen2_5_VLDecoderLayer(nn.Cell): - def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int): +class Qwen2_5_VLDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -912,28 +706,29 @@ def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int): f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " "unexpected results may be encountered." ) - self.self_attn = QWEN2_5_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.self_attn = Qwen2_5_VLAttention(config, layer_idx) self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] def construct( self, - hidden_states: Tensor, - attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - past_key_value: Optional[Tuple[Tensor]] = None, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_value: Optional[tuple[ms.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - cache_position: Optional[Tensor] = None, - position_embeddings: Optional[Tuple[Tensor, Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[Tensor, Optional[Tuple[Tensor, Tensor]]]: + cache_position: Optional[ms.Tensor] = None, + position_embeddings: Optional[tuple[ms.Tensor, ms.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[ms.Tensor, Optional[tuple[ms.Tensor, ms.Tensor]]]: """ Args: - hidden_states (`Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`Tensor`, *optional*): attention mask of size + hidden_states (`ms.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`ms.Tensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under @@ -941,10 +736,10 @@ def construct( use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(Tensor)`, *optional*): cached past key and value projection states - cache_position (`Tensor` of shape `(sequence_length)`, *optional*): + past_key_value (`Tuple(ms.Tensor)`, *optional*): cached past key and value projection states + cache_position (`ms.Tensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. - position_embeddings (`Tuple[Tensor, Tensor]`, *optional*): + position_embeddings (`tuple[ms.Tensor, ms.Tensor]`, *optional*): Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. kwargs (`dict`, *optional*): @@ -957,7 +752,7 @@ def construct( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -966,6 +761,7 @@ def construct( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = residual + hidden_states @@ -980,9 +776,6 @@ def construct( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -1001,6 +794,7 @@ def __init__(self, config: Qwen2_5_VLTextConfig): self._attn_implementation = config._attn_implementation self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + self.has_sliding_layers = "sliding_attention" in self.config.layer_types self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -1051,57 +845,71 @@ def construct( # the hard coded `3` is for temporal, height and width. if position_ids is None: position_ids = cache_position.view(1, 1, -1).expand((3, inputs_embeds.shape[0], -1)) - elif position_ids.dim() == 2: + elif position_ids.ndim == 2: position_ids = position_ids[None, ...].expand((3, position_ids.shape[0], -1)) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) + # NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions + # where each dim indicates visual spatial positions for temporal/height/width grids. + # There are two scenarios when FA2-like packed masking might be activated. + # 1. User specifically passed packed `position_ids` and no attention mask. + # In this case we expect the useer to create correct position ids for all 3 grids + # and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len] + # 2. User runs forward with no attention mask and no position ids. In this case, position ids + # are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are + # prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass + # text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation` + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": text_position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) - position_embeddings = mint.stack(position_embeddings) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = ms.recompute( - decoder_layer, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=text_position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1111,135 +919,17 @@ def construct( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Tensor, - input_tensor: Tensor, - cache_position: Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.shape[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen2_5_VL. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: Tensor, - sequence_length: int, - target_length: int, - dtype: ms.Type, - cache_position: Tensor, - batch_size: int, - config: Qwen2_5_VLConfig, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of - the cache that is not filled yet. - dtype (`ms.Type`): - The dtype to use for the 4D attention mask. - cache_position (`Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`Tensor`): - Batch size. - config (`Qwen2_5_VLConfig`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = dtype_to_min(dtype) - causal_mask = mint.full((sequence_length, target_length), fill_value=min_dtype.item(), dtype=dtype) - diagonal_attend_mask = mint.arange(target_length) > cache_position.reshape(-1, 1) - if config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = mint.arange(target_length) <= ( - cache_position.reshape(-1, 1) - config.sliding_window - ) - diagonal_attend_mask = diagonal_attend_mask.bitwise_or(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand((batch_size, 1, -1, -1)) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): base_model_prefix = "" @@ -1249,33 +939,18 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): def __init__(self, config): super().__init__(config) - # TODO: we need this patch here, may fix later - config.vision_config._attn_implementation = config._attn_implementation - config.vision_config.mindspore_dtype = getattr(config, "mindspore_dtype", None) self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) self.language_model = Qwen2_5_VLTextModel._from_config(config.text_config) - self.vocab_size = config.vocab_size - self.lm_head = mint.nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.rope_deltas = None # cache rope_deltas here # Initialize weights and apply final processing self.post_init() - def gradient_checkpointing_enable(self, **kwargs): - self.language_model.gradient_checkpointing = True - self.visual.gradient_checkpointing = True - def get_input_embeddings(self): - return self.language_model.embed_tokens + return self.language_model.get_input_embeddings() def set_input_embeddings(self, value): - self.language_model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings + self.language_model.set_input_embeddings(value) def set_decoder(self, decoder): self.language_model = decoder @@ -1285,12 +960,12 @@ def get_decoder(self): def get_rope_index( self, - input_ids: Optional[Tensor] = None, - image_grid_thw: Optional[Tensor] = None, - video_grid_thw: Optional[Tensor] = None, - second_per_grid_ts: Optional[Tensor] = None, - attention_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: + input_ids: Optional[ms.Tensor] = None, + image_grid_thw: Optional[ms.Tensor] = None, + video_grid_thw: Optional[ms.Tensor] = None, + second_per_grid_ts: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + ) -> tuple[ms.Tensor, ms.Tensor]: """ Calculate the 3D rope index based on image and video's temporal, height and width in LLM. @@ -1328,24 +1003,24 @@ def get_rope_index( Here we calculate the text start position_ids as the max vision position_ids plus 1. Args: - input_ids (`Tensor` of shape `(batch_size, sequence_length)`): + input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. - image_grid_thw (`Tensor` of shape `(num_images, 3)`, *optional*): + image_grid_thw (`ms.Tensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - video_grid_thw (`Tensor` of shape `(num_videos, 3)`, *optional*): + video_grid_thw (`ms.Tensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - second_per_grid_ts (`Tensor` of shape `(num_videos)`, *optional*): + second_per_grid_ts (`ms.Tensor` of shape `(num_videos)`, *optional*): The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. - attention_mask (`Tensor` of shape `(batch_size, sequence_length)`, *optional*): + attention_mask (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. Returns: - position_ids (`Tensor` of shape `(3, batch_size, sequence_length)`) - mrope_position_deltas (`Tensor` of shape `(batch_size)`) + position_ids (`ms.Tensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`ms.Tensor` of shape `(batch_size)`) """ spatial_merge_size = self.config.vision_config.spatial_merge_size image_token_id = self.config.image_token_id @@ -1356,7 +1031,10 @@ def get_rope_index( total_input_ids = input_ids if attention_mask is None: attention_mask = mint.ones_like(total_input_ids) - position_ids = mint.ones((3, input_ids.shape[0], input_ids.shape[1]), dtype=input_ids.dtype) + position_ids = mint.ones( + (3, input_ids.shape[0], input_ids.shape[1]), + dtype=input_ids.dtype, + ) image_index, video_index = 0, 0 for i, input_ids in enumerate(total_input_ids): input_ids = input_ids[attention_mask[i] == 1] @@ -1416,6 +1094,9 @@ def get_rope_index( range_tensor = mint.arange(llm_grid_t).view(-1, 1) expanded_range = range_tensor.expand((-1, llm_grid_h * llm_grid_w)) + # normalize type + second_per_grid_t = ms.tensor(second_per_grid_t, dtype=range_tensor.dtype) + time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second time_tensor_long = time_tensor.long() @@ -1432,7 +1113,7 @@ def get_rope_index( llm_pos_ids_list.append(mint.arange(text_len).view(1, -1).expand((3, -1)) + st_idx) llm_positions = mint.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.dtype) + position_ids[..., i, attention_mask[i] == 1] = llm_positions mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) mrope_position_deltas = ms.tensor(mrope_position_deltas).unsqueeze(1) return position_ids, mrope_position_deltas @@ -1445,10 +1126,45 @@ def get_rope_index( mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] else: position_ids = mint.arange(input_ids.shape[1]).view(1, 1, -1).expand((3, input_ids.shape[0], -1)) - mrope_position_deltas = mint.zeros([input_ids.shape[0], 1], dtype=input_ids.dtype) + mrope_position_deltas = mint.zeros( + [input_ids.shape[0], 1], + dtype=input_ids.dtype, + ) return position_ids, mrope_position_deltas + def get_video_features(self, pixel_values_videos: ms.Tensor, video_grid_thw: Optional[ms.Tensor] = None): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values_videos (`ms.Tensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`ms.Tensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + video_embeds = mint.split(video_embeds, split_sizes) + return video_embeds + + def get_image_features(self, pixel_values: ms.Tensor, image_grid_thw: Optional[ms.Tensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`ms.Tensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`ms.Tensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = mint.split(image_embeds, split_sizes) + return image_embeds + def construct( self, input_ids: ms.Tensor = None, @@ -1468,7 +1184,18 @@ def construct( cache_position: Optional[ms.Tensor] = None, second_per_grid_ts: Optional[ms.Tensor] = None, **kwargs: Unpack[TransformersKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: + ) -> Union[tuple, Qwen2_5_VLModelOutputWithPast]: + r""" + image_grid_thw (`ms.Tensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`ms.Tensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`ms.Tensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + second_per_grid_ts (`ms.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1476,69 +1203,80 @@ def construct( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: - inputs_embeds = self.language_model.embed_tokens(input_ids) - if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - n_image_tokens = (input_ids == self.config.image_token_id).sum().item() - n_image_features = image_embeds.shape[0] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) + inputs_embeds = self.get_input_embeddings()(input_ids) - mask = input_ids == self.config.image_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - image_mask = mask_expanded + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = mint.cat(image_embeds, dim=0) - # masked_scatter does not support bf16, so embedding dtype needs to be converted to fp32. - inputs_embeds = ( - inputs_embeds.float().masked_scatter(image_mask, image_embeds.float()).to(inputs_embeds.dtype) + if input_ids is None: + image_mask = inputs_embeds == self.get_input_embeddings()( + ms.tensor(self.config.image_token_id, dtype=ms.long) ) + image_mask = image_mask.all(-1) + else: + image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (image_mask).sum() + image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds) + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_embeds = image_embeds.to(inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) - n_video_tokens = (input_ids == self.config.video_token_id).sum().item() - n_video_features = video_embeds.shape[0] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - - mask = input_ids == self.config.video_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - video_mask = mask_expanded + if pixel_values_videos is not None: + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = mint.cat(video_embeds, dim=0) - # masked_scatter does not support bf16, so embedding dtype needs to be converted to fp32. - inputs_embeds = ( - inputs_embeds.float().masked_scatter(video_mask, video_embeds.float()).to(inputs_embeds.dtype) + if input_ids is None: + video_mask = inputs_embeds == self.get_input_embeddings()( + ms.tensor(self.config.video_token_id, dtype=ms.long) + ) + video_mask = video_mask.all(-1) + else: + video_mask = input_ids == self.config.video_token_id + + n_video_tokens = (video_mask).sum() + n_video_features = video_embeds.shape[0] + video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds) + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) + video_embeds = video_embeds.to(inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) - # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme - if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): - # calculate RoPE index once per generation in the pre-fill stage only - if ( - (cache_position is not None and cache_position[0] == 0) - or self.rope_deltas is None - or (past_key_values is None or past_key_values.get_seq_length() == 0) - ): + if position_ids is None: + # Calculate RoPE index once per generation in the pre-fill stage only. + # When compiling, we can't check tensor values thus we check only input length + # It is safe to assume that `length!=1` means we're in pre-fill because compiled + # models currently cannot do asssisted decoding + prefill_compiled_stage = False + prefill_noncompiled_stage = (cache_position is not None and cache_position[0] == 0) or ( + past_key_values is None or past_key_values.get_seq_length() == 0 + ) + if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None: position_ids, rope_deltas = self.get_rope_index( - input_ids, image_grid_thw, video_grid_thw, second_per_grid_ts, attention_mask + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + attention_mask=attention_mask, ) self.rope_deltas = rope_deltas - # then use the prev pre-calculated rope-deltas to get the correct position ids else: batch_size, seq_length, _ = inputs_embeds.shape - delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 position_ids = mint.arange(seq_length) - position_ids = position_ids.view(1, -1).expand((batch_size, -1)) - if cache_position is not None: # otherwise `deltas` is an int `0` - delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) - position_ids = position_ids.add(delta) - position_ids = position_ids.unsqueeze(0).expand((3, -1, -1)) + position_ids = position_ids.view(1, 1, -1).expand((3, batch_size, -1)) + if cache_position is not None: + delta = cache_position[0] + self.rope_deltas + else: + delta = mint.zeros((batch_size, seq_length)) + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids += delta outputs = self.language_model( input_ids=None, @@ -1549,8 +1287,9 @@ def construct( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, cache_position=cache_position, + **kwargs, ) output = Qwen2_5_VLModelOutputWithPast( @@ -1566,40 +1305,26 @@ def construct( @dataclass class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): """ - Base class for Qwen2_5_VL causal language model (or autoregressive) outputs. + loss (`ms.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`ms.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - Args: - loss (`Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `Tensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - rope_deltas (`Tensor` of shape `(batch_size, )`, *optional*): - The rope index difference between sequence length and multimodal rope. + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`ms.Tensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. """ - loss: Optional[Tensor] = None - logits: Tensor = None - past_key_values: Optional[List[Tensor]] = None - hidden_states: Optional[Tuple[Tensor]] = None - attentions: Optional[Tuple[Tensor]] = None - rope_deltas: Optional[Tensor] = None + loss: Optional[ms.Tensor] = None + logits: Optional[ms.Tensor] = None + past_key_values: Optional[list[ms.Tensor]] = None + hidden_states: Optional[tuple[ms.Tensor]] = None + attentions: Optional[tuple[ms.Tensor]] = None + rope_deltas: Optional[ms.Tensor] = None class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin): @@ -1663,15 +1388,20 @@ def construct( second_per_grid_ts: Optional[ms.Tensor] = None, logits_to_keep: Union[int, ms.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], - ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: + ) -> Union[tuple, Qwen2_5_VLCausalLMOutputWithPast]: r""" - Args: - labels (`Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: + labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`ms.Tensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`ms.Tensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`ms.Tensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + second_per_grid_ts (`ms.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. Example: @@ -1703,6 +1433,7 @@ def construct( >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1780,8 +1511,35 @@ def prepare_inputs_for_generation( **kwargs, ) - # Qwen2-5-VL position_ids are prepareed with rope_deltas in forward - model_inputs["position_ids"] = None + # Qwen2-5-VL position_ids are prepared with rope_deltas + if position_ids is None: + # Calculate RoPE index once per generation in the pre-fill stage only. + # When compiling, we can't check tensor values thus we check only input length + # It is safe to assume that `length!=1` means we're in pre-fill because compiled + # models currently cannot do asssisted decoding + if cache_position[0] == 0 or self.model.rope_deltas is None: + vision_positions, rope_deltas = self.model.get_rope_index( + model_inputs.get("input_ids", None), + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + attention_mask=attention_mask, + ) + self.model.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + elif "position_ids" in model_inputs: + position_ids = model_inputs["position_ids"][None, ...] + delta = self.model.rope_deltas + delta = delta.repeat_interleave(position_ids.shape[1] // delta.shape[0], dim=0) + vision_positions = position_ids + delta.expand_as(position_ids) + vision_positions = vision_positions.expand((3, vision_positions.shape[1], -1)) + + # Concatenate "text + vision" positions into [4, bs, seq-len] + if "position_ids" not in model_inputs: + text_positions = mint.arange(input_ids)[None, None, :] + else: + text_positions = model_inputs["position_ids"][None, ...] + model_inputs["position_ids"] = mint.cat([text_positions, vision_positions], dim=0) if cache_position[0] != 0: model_inputs["pixel_values"] = None @@ -1791,28 +1549,41 @@ def prepare_inputs_for_generation( def _get_image_nums_and_video_nums( self, - input_ids: Optional[Tensor], - ) -> Tuple[Tensor, Tensor]: + input_ids: Optional[ms.Tensor], + inputs_embeds: Optional[ms.Tensor] = None, + ) -> tuple[ms.Tensor, ms.Tensor]: """ Get the number of images and videos for each sample to calculate the separation length of the sample tensor. These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. Args: - input_ids (`Tensor` of shape `(batch_size, sequence_length)`): + input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Returns: - image_nums (`Tensor` of shape `(batch_size, num_images_sample)`) - video_nums (`Tensor` of shape `(batch_size, num_videos_sample)`) + image_nums (`ms.Tensor` of shape `(batch_size, num_images_sample)`) + video_nums (`ms.Tensor` of shape `(batch_size, num_videos_sample)`) """ image_token_id = self.config.image_token_id video_token_id = self.config.video_token_id vision_start_token_id = self.config.vision_start_token_id - vision_start_mask = input_ids == vision_start_token_id + if inputs_embeds is not None: + vision_start_mask = ( + inputs_embeds == self.get_input_embeddings()(ms.tensor(vision_start_token_id, dtype=ms.long)) + )[..., 0] + image_mask = (inputs_embeds == self.get_input_embeddings()(ms.tensor(image_token_id, dtype=ms.long)))[ + ..., 0 + ] + video_mask = (inputs_embeds == self.get_input_embeddings()(ms.tensor(video_token_id, dtype=ms.long)))[ + ..., 0 + ] + else: + vision_start_mask = input_ids == vision_start_token_id + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + vision_first_mask = mint.roll(vision_start_mask, shifts=1, dims=1) - image_mask = input_ids == image_token_id - video_mask = input_ids == video_token_id image_nums = mint.sum(vision_first_mask & image_mask, dim=1) video_nums = mint.sum(vision_first_mask & video_mask, dim=1) @@ -1822,9 +1593,9 @@ def _expand_inputs_for_generation( self, expand_size: int = 1, is_encoder_decoder: bool = False, - input_ids: Optional[Tensor] = None, + input_ids: Optional[ms.Tensor] = None, **model_kwargs, - ) -> Tuple[Tensor, Dict[str, Any]]: + ) -> tuple[ms.Tensor, dict[str, Any]]: # Overwritten -- Support for expanding tensors without a batch size dimension # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t # pixel_values.shape[0] is sum(seqlen_images for samples) @@ -1838,7 +1609,9 @@ def _expand_inputs_for_generation( def _expand_dict_for_generation_visual(dict_to_expand): image_grid_thw = model_kwargs.get("image_grid_thw", None) video_grid_thw = model_kwargs.get("video_grid_thw", None) - image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + image_nums, video_nums = self._get_image_nums_and_video_nums( + input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None) + ) def _repeat_interleave_samples(x, lengths, repeat_times): samples = mint.split(x, lengths) @@ -1888,16 +1661,13 @@ def _expand_dict_for_generation(dict_to_expand): if ( key != "cache_position" and dict_to_expand[key] is not None - and isinstance(dict_to_expand[key], Tensor) + and isinstance(dict_to_expand[key], ms.Tensor) and key not in visual_keys ): dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) return dict_to_expand - # input_ids is required for expanding visual inputs - # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. - if input_ids is not None and input_ids.numel() != 0: - model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) if input_ids is not None: input_ids = input_ids.repeat_interleave(expand_size, dim=0) @@ -1912,4 +1682,4 @@ def _expand_dict_for_generation(dict_to_expand): return input_ids, model_kwargs -__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel"] +__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel", "Qwen2_5_VLTextModel"] From f4cd807e81ee605a2187bf6a89618504fb3d61d5 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Thu, 6 Nov 2025 11:25:09 +0800 Subject: [PATCH 2/3] update scripts --- .../qwen2_5_vl/processing_qwen2_5_vl.py | 138 +++++++++++++----- 1 file changed, 98 insertions(+), 40 deletions(-) diff --git a/mindone/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py b/mindone/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py index e3c08a138e..5492725497 100644 --- a/mindone/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +++ b/mindone/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py @@ -26,29 +26,39 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Optional, Union +import numpy as np from transformers.tokenization_utils_base import PreTokenizedInput, TextInput import mindspore as ms from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs +from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs from ...video_utils import VideoInput class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False): - fps: Union[List[float], float] + fps: Union[list[float], float] + + +class Qwen2_5_VLImagesKwargs(ImagesKwargs): + min_pixels: Optional[int] + max_pixels: Optional[int] + patch_size: Optional[int] + temporal_patch_size: Optional[int] + merge_size: Optional[int] class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Qwen2_5_VLImagesKwargs videos_kwargs: Qwen2_5_VLVideosProcessorKwargs _defaults = { "text_kwargs": { "padding": False, + "return_mm_token_type_ids": False, }, - "videos_kwargs": {"fps": 2.0}, } @@ -62,25 +72,37 @@ class Qwen2_5_VLProcessor(ProcessorMixin): The image processor is a required input. tokenizer ([`Qwen2TokenizerFast`], *optional*): The tokenizer is a required input. + video_processor ([`Qwen2_5_VLVideoProcessor`], *optional*): + The video processor is a required input. chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string. """ - attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["chat_template"] + attributes = ["image_processor", "tokenizer", "video_processor"] image_processor_class = "AutoImageProcessor" + video_processor_class = "AutoVideoProcessor" tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") - def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): + def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs): self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token - super().__init__(image_processor, tokenizer, chat_template=chat_template) + self.image_token_id = ( + tokenizer.image_token_id + if getattr(tokenizer, "image_token_id", None) + else tokenizer.convert_tokens_to_ids(self.image_token) + ) + self.video_token_id = ( + tokenizer.video_token_id + if getattr(tokenizer, "video_token_id", None) + else tokenizer.convert_tokens_to_ids(self.video_token) + ) + super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template) def __call__( self, images: ImageInput = None, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, videos: VideoInput = None, **kwargs: Unpack[Qwen2_5_VLProcessorKwargs], ) -> BatchFeature: @@ -91,14 +113,14 @@ def __call__( Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. Args: - images (`PIL.Image.Image`, `np.ndarray`, `ms.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[ms.Tensor]`): + images (`PIL.Image.Image`, `np.ndarray`, `ms.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[ms.Tensor]`): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. Both channels-first and channels-last formats are supported. - text (`str`, `List[str]`, `List[List[str]]`): + text (`str`, `list[str]`, `list[list[str]]`): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - videos (`np.ndarray`, `ms.Tensor`, `List[np.ndarray]`, `List[ms.Tensor]`): + videos (`np.ndarray`, `ms.Tensor`, `list[np.ndarray]`, `list[ms.Tensor]`): The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. return_tensors (`str` or [`~utils.TensorType`], *optional*): @@ -124,22 +146,21 @@ def __call__( tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) + + image_inputs = videos_inputs = {} if images is not None: - image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"]) + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] - else: - image_inputs = {} - image_grid_thw = None if videos is not None: - videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"]) + fps = output_kwargs["videos_kwargs"].get("fps", 2.0) + videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] - fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) if isinstance(fps, (int, float)): - second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw) + second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw) elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw): - second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps] + second_per_grid_ts = [self.video_processor.temporal_patch_size / tmp for tmp in fps] else: raise ValueError( f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the " @@ -147,46 +168,83 @@ def __call__( ) videos_inputs.update({"second_per_grid_ts": second_per_grid_ts}) - else: - videos_inputs = {} - video_grid_thw = None - if not isinstance(text, list): text = [text] - if image_grid_thw is not None: + text = text.copy() # below lines change text in-place + if images is not None: merge_length = self.image_processor.merge_size**2 index = 0 for i in range(len(text)): while self.image_token in text[i]: - text[i] = text[i].replace( - self.image_token, - "<|placeholder|>" * (image_grid_thw[index].prod().item() // merge_length), - 1, - ) + num_image_tokens = image_grid_thw[index].prod() // merge_length + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) index += 1 text[i] = text[i].replace("<|placeholder|>", self.image_token) - if video_grid_thw is not None: - merge_length = self.image_processor.merge_size**2 + if videos is not None: + merge_length = self.video_processor.merge_size**2 index = 0 for i in range(len(text)): while self.video_token in text[i]: - text[i] = text[i].replace( - self.video_token, - "<|placeholder|>" * (video_grid_thw[index].prod().item() // merge_length), - 1, - ) + num_video_tokens = video_grid_thw[index].prod() // merge_length + text[i] = text[i].replace(self.video_token, "<|placeholder|>" * num_video_tokens, 1) index += 1 text[i] = text[i].replace("<|placeholder|>", self.video_token) return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None) text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors="np") if return_tensors == "ms": for k, v in text_inputs.items(): text_inputs[k] = ms.tensor(v) + self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"]) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors) + + def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + Args: + image_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + video_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (num_frames, height, width) per each video. + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. + """ - return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) + vision_data = {} + if image_sizes is not None: + images_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("images_kwargs", {}) + images_kwargs.update(kwargs) + merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size + + num_image_patches = [ + self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) + for image_size in image_sizes + ] + num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches] + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + if video_sizes is not None: + videos_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("videos_kwargs", {}) + videos_kwargs.update(kwargs) + num_video_patches = [ + self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs) + for video_size in video_sizes + ] + num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches] + vision_data["num_video_tokens"] = num_video_tokens + + return MultiModalData(**vision_data) def batch_decode(self, *args, **kwargs): """ @@ -214,13 +272,13 @@ def post_process_image_text_to_text( or `(sequence_length,)`. skip_special_tokens (`bool`, *optional*, defaults to `True`): Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method. - Clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method. **kwargs: Additional arguments to be passed to the tokenizer's `batch_decode method`. Returns: - `List[str]`: The decoded text. + `list[str]`: The decoded text. """ return self.tokenizer.batch_decode( generated_outputs, From 6df3a7636adbe5f0946f2e3c2f9c1c677790b0ea Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Thu, 6 Nov 2025 15:35:52 +0800 Subject: [PATCH 3/3] fix several bugs --- mindone/transformers/modeling_utils.py | 4 +- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 46 +++++++++---------- .../qwen2_5_vl/processing_qwen2_5_vl.py | 4 +- 3 files changed, 26 insertions(+), 28 deletions(-) diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py index 7fc96f444f..0da8102eba 100644 --- a/mindone/transformers/modeling_utils.py +++ b/mindone/transformers/modeling_utils.py @@ -3057,7 +3057,7 @@ def _find_mismatched_keys( if state_dict is not None: # Whole checkpoint # checkpoint mapping from pt to hf - matching = [s for s in key_renaming_mapping.keys() if "LayerNorm.gamma" in s] + matching = key_renaming_mapping.keys() if matching: # Fix the key names when model weight names contain LayerNorm.gamma/LayerNorm.beta state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping} @@ -3090,7 +3090,7 @@ def _find_mismatched_keys( for shard_file in resolved_archive_file: state_dict = load_state_dict(shard_file) # checkpoint mapping from pt to hf - matching = [s for s in key_renaming_mapping.keys() if "LayerNorm.gamma" in s] + matching = key_renaming_mapping.keys() if matching: # Fix the key names when model weight names contain LayerNorm.gamma/LayerNorm.beta state_dict = { diff --git a/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 7672daab41..23ba8a5794 100644 --- a/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/mindone/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -120,10 +120,10 @@ def __init__(self, hidden_size, eps=1e-6): def construct(self, hidden_states): input_dtype = hidden_states.dtype - output, _ = ops.rms_norm( - hidden_states.to(ms.float32), self.weight.to(ms.float32), epsilon=self.variance_epsilon - ) - return output.to(input_dtype) + hidden_states = hidden_states.to(ms.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * mint.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" @@ -254,22 +254,18 @@ def construct( if self.config._attn_implementation == "flash_attention_2": # Flash Attention 2: Use cu_seqlens for variable length attention - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - attn_output, _ = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask=None, - scaling=self.scaling, - dropout=0.0 if not self.training else self.attention_dropout, - cu_seq_lens_q=cu_seqlens, - cu_seq_lens_k=cu_seqlens, - max_length_q=max_seqlen, - max_length_k=max_seqlen, - is_causal=False, - **kwargs, - ) + attn_output = ops.flash_attention_score( + query_states.squeeze(0).transpose(0, 1), + key_states.squeeze(0).transpose(0, 1), + value_states.squeeze(0).transpose(0, 1), + self.num_heads, + attn_mask=None, + actual_seq_qlen=cu_seqlens, + actual_seq_kvlen=cu_seqlens, + keep_prob=1.0 if not self.training else 1 - self.attention_dropout, + scalar_value=self.scaling, + input_layout="TND", + ).unsqueeze(0) else: # Other implementations: Process each chunk separately lengths = cu_seqlens[1:] - cu_seqlens[:-1] @@ -1113,7 +1109,7 @@ def get_rope_index( llm_pos_ids_list.append(mint.arange(text_len).view(1, -1).expand((3, -1)) + st_idx) llm_positions = mint.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[..., i, attention_mask[i] == 1] = llm_positions + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(dtype=position_ids.dtype) mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) mrope_position_deltas = ms.tensor(mrope_position_deltas).unsqueeze(1) return position_ids, mrope_position_deltas @@ -1278,6 +1274,7 @@ def construct( delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) position_ids += delta + return_dict = kwargs.pop("return_dict", True) outputs = self.language_model( input_ids=None, position_ids=position_ids, @@ -1287,7 +1284,7 @@ def construct( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=True, + return_dict=return_dict, cache_position=cache_position, **kwargs, ) @@ -1439,6 +1436,7 @@ def construct( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = kwargs.pop("return_dict", True) outputs = self.model( input_ids=input_ids, pixel_values=pixel_values, @@ -1453,7 +1451,7 @@ def construct( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=True, + return_dict=return_dict, cache_position=cache_position, **kwargs, ) @@ -1539,7 +1537,7 @@ def prepare_inputs_for_generation( text_positions = mint.arange(input_ids)[None, None, :] else: text_positions = model_inputs["position_ids"][None, ...] - model_inputs["position_ids"] = mint.cat([text_positions, vision_positions], dim=0) + model_inputs["position_ids"] = mint.cat([text_positions, vision_positions.to(text_positions.dtype)], dim=0) if cache_position[0] != 0: model_inputs["pixel_values"] = None diff --git a/mindone/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py b/mindone/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py index 5492725497..f3c709d460 100644 --- a/mindone/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +++ b/mindone/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py @@ -177,7 +177,7 @@ def __call__( index = 0 for i in range(len(text)): while self.image_token in text[i]: - num_image_tokens = image_grid_thw[index].prod() // merge_length + num_image_tokens = image_grid_thw[index].prod().item() // merge_length text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) index += 1 text[i] = text[i].replace("<|placeholder|>", self.image_token) @@ -187,7 +187,7 @@ def __call__( index = 0 for i in range(len(text)): while self.video_token in text[i]: - num_video_tokens = video_grid_thw[index].prod() // merge_length + num_video_tokens = video_grid_thw[index].prod().item() // merge_length text[i] = text[i].replace(self.video_token, "<|placeholder|>" * num_video_tokens, 1) index += 1 text[i] = text[i].replace("<|placeholder|>", self.video_token)