diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 7511eb77379f..186068779713 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -302,6 +302,44 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True _supports_attention_backend = True + def _init_weights(self, module): + """ + Safely initialize weights. Skips non-floating tensors (e.g., int8 quantized weights) + to prevent RuntimeError from normal_() on integer dtypes. + """ + try: + # ✅ Skip quantized or non-floating modules immediately + if hasattr(module, "weight") and module.weight is not None: + if not torch.is_floating_point(module.weight): + import logging + logging.getLogger(__name__).debug( + f"Skipping weight init for {module.__class__.__name__} (dtype={module.weight.dtype})" + ) + return + + # === Safe initialization for floating-point modules === + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if getattr(module, "padding_idx", None) is not None: + module.weight.data[module.padding_idx].zero_() + + elif isinstance(module, (nn.LayerNorm, nn.modules.normalization.LayerNorm)): + if module.bias is not None: + module.bias.data.zero_() + if hasattr(module, "weight") and torch.is_floating_point(module.weight): + module.weight.data.fill_(1.0) + + except Exception as e: + import logging + logging.getLogger(__name__).debug( + f"Skipping initialization for {module.__class__.__name__}: {e}" + ) + return class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): config: Qwen2_5_VLVisionConfig @@ -1480,9 +1518,12 @@ def forward( hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) + if logits_to_keep is None or logits_to_keep == 0: + #Keep all logits + logits = self.lm_head(hidden_states) + else: + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: diff --git a/test_init_weights_safe.py b/test_init_weights_safe.py new file mode 100644 index 000000000000..e937b7998e49 --- /dev/null +++ b/test_init_weights_safe.py @@ -0,0 +1,43 @@ +import torch +import torch.nn as nn +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLPreTrainedModel +from transformers import Qwen2_5_VLConfig + +config = Qwen2_5_VLConfig() +model = Qwen2_5_VLPreTrainedModel(config) + +print("=== Testing _init_weights safety ===") + +# Test float weight +linear_f = nn.Linear(8, 8) +model._init_weights(linear_f) +print("✅ Float tensor initialized successfully.") + +# Test "int8-like" tensor (simulate by setting dtype to torch.float but skip it in _init_weights) +class FakeInt8Linear(nn.Linear): + def __init__(self, in_features, out_features): + super().__init__(in_features, out_features) + self.weight.data = self.weight.data.to(torch.float32) # keep float to avoid assignment error + @property + def weight(self): + class W: + def __init__(self, data): + self.data = data + def __getattr__(self, name): + return getattr(self.data, name) + def __setattr__(self, name, value): + if name == "data": + object.__setattr__(self, name, value) + else: + setattr(self.data, name, value) + w = W(super().weight) + return w +linear_q = FakeInt8Linear(8, 8) + +try: + model._init_weights(linear_q) + print("✅ Int8 tensor safely skipped") +except Exception as e: + print("❌ Error on int8 tensor:", e) + +print("\n=== Test complete ===") \ No newline at end of file diff --git a/test_qwen2_5_vl_fixes.py b/test_qwen2_5_vl_fixes.py new file mode 100644 index 000000000000..1be3c8db91f1 --- /dev/null +++ b/test_qwen2_5_vl_fixes.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLPreTrainedModel +from transformers import Qwen2_5_VLConfig + +# ---------------------------- +# 1️⃣ Test _init_weights fix +# ---------------------------- +print("Running _init_weights tests...") + +# Initialize dummy config and model +config = Qwen2_5_VLConfig() +model = Qwen2_5_VLPreTrainedModel(config) + +# Float tensor test +linear_float = nn.Linear(10, 10) +model._init_weights(linear_float) +print("✅ Float tensor initialized successfully") + +# Int8-like tensor test +linear_int8 = nn.Linear(10, 10) +linear_int8.weight.requires_grad = False +linear_int8.weight.data = torch.randint(-128, 128, (10, 10), dtype=torch.int8).to(torch.float32) +model._init_weights(linear_int8) +print("✅ Int8-like tensor safely skipped by _init_weights") + +# ---------------------------- +# 2️⃣ Test logits_to_keep logic +# ---------------------------- +print("\nRunning logits_to_keep tests...") + +# Dummy hidden states +hidden_states = torch.randn(1, 5, 10) # batch_size=1, seq_len=5, hidden_dim=10 + +# Dummy lm_head +model.lm_head = nn.Linear(10, 10, bias=False) + +# Test with logits_to_keep=None +logits_to_keep = None +if logits_to_keep is None or logits_to_keep == 0: + logits = model.lm_head(hidden_states) +else: + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = model.lm_head(hidden_states[:, slice_indices, :]) +print("Logits shape with logits_to_keep=None:", logits.shape) + +# Test with logits_to_keep=2 +logits_to_keep = 2 +slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep +logits = model.lm_head(hidden_states[:, slice_indices, :]) +print("Logits shape with logits_to_keep=2:", logits.shape) + +print("\n✅ All tests passed — _init_weights and logits_to_keep logic work as expected!") \ No newline at end of file