47 changes: 44 additions & 3 deletions src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -302,6 +302,44 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
    _can_compile_fullgraph = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """
        Safely initialize weights. Skips non-floating tensors (e.g., int8 quantized weights)
        to prevent RuntimeError from normal_() on integer dtypes.
        """
        try:
            # ✅ Skip quantized or non-floating modules immediately
            if hasattr(module, "weight") and module.weight is not None:
                if not torch.is_floating_point(module.weight):
                    # Uses the module-level `logger` already defined in this modeling file
                    logger.debug(
                        f"Skipping weight init for {module.__class__.__name__} (dtype={module.weight.dtype})"
                    )
                    return

            # === Safe initialization for floating-point modules ===
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
                if module.bias is not None:
                    module.bias.data.zero_()

            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
                if getattr(module, "padding_idx", None) is not None:
                    module.weight.data[module.padding_idx].zero_()

            elif isinstance(module, nn.LayerNorm):
                if module.bias is not None:
                    module.bias.data.zero_()
                if module.weight is not None and torch.is_floating_point(module.weight):
                    module.weight.data.fill_(1.0)

        except Exception as e:
            logger.debug(
                f"Skipping initialization for {module.__class__.__name__}: {e}"
            )

class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
    config: Qwen2_5_VLVisionConfig
@@ -1480,9 +1518,12 @@ def forward(

        hidden_states = outputs[0]

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if logits_to_keep is None or logits_to_keep == 0:
            # Keep all logits
            logits = self.lm_head(hidden_states)
        else:
            slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
            logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
43 changes: 43 additions & 0 deletions test_init_weights_safe.py
@@ -0,0 +1,43 @@
import torch
import torch.nn as nn
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLPreTrainedModel
from transformers import Qwen2_5_VLConfig

config = Qwen2_5_VLConfig()
model = Qwen2_5_VLPreTrainedModel(config)

print("=== Testing _init_weights safety ===")

# Test float weight
linear_f = nn.Linear(8, 8)
model._init_weights(linear_f)
print("✅ Float tensor initialized successfully.")

# Test "int8-like" tensor (simulate by setting dtype to torch.float but skip it in _init_weights)
class FakeInt8Linear(nn.Linear):
def __init__(self, in_features, out_features):
super().__init__(in_features, out_features)
self.weight.data = self.weight.data.to(torch.float32) # keep float to avoid assignment error
@property
def weight(self):
class W:
def __init__(self, data):
self.data = data
def __getattr__(self, name):
return getattr(self.data, name)
def __setattr__(self, name, value):
if name == "data":
object.__setattr__(self, name, value)
else:
setattr(self.data, name, value)
w = W(super().weight)
return w
linear_q = FakeInt8Linear(8, 8)

try:
model._init_weights(linear_q)
print("✅ Int8 tensor safely skipped")
except Exception as e:
print("❌ Error on int8 tensor:", e)

print("\n=== Test complete ===")
53 changes: 53 additions & 0 deletions test_qwen2_5_vl_fixes.py
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLPreTrainedModel
from transformers import Qwen2_5_VLConfig

# ----------------------------
# 1️⃣ Test _init_weights fix
# ----------------------------
print("Running _init_weights tests...")

# Initialize dummy config and model
config = Qwen2_5_VLConfig()
model = Qwen2_5_VLPreTrainedModel(config)

# Float tensor test
linear_float = nn.Linear(10, 10)
model._init_weights(linear_float)
print("✅ Float tensor initialized successfully")

# Int8-like tensor test: replace the float weight with a genuine int8 tensor
# (integer tensors cannot require gradients, so requires_grad must be False)
linear_int8 = nn.Linear(10, 10)
linear_int8.weight = nn.Parameter(
    torch.randint(-128, 128, (10, 10), dtype=torch.int8), requires_grad=False
)
model._init_weights(linear_int8)
print("✅ Int8-like tensor safely skipped by _init_weights")

# ----------------------------
# 2️⃣ Test logits_to_keep logic
# ----------------------------
print("\nRunning logits_to_keep tests...")

# Dummy hidden states
hidden_states = torch.randn(1, 5, 10) # batch_size=1, seq_len=5, hidden_dim=10

# Dummy lm_head
model.lm_head = nn.Linear(10, 10, bias=False)

# Test with logits_to_keep=None
logits_to_keep = None
if logits_to_keep is None or logits_to_keep == 0:
    logits = model.lm_head(hidden_states)
else:
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    logits = model.lm_head(hidden_states[:, slice_indices, :])
print("Logits shape with logits_to_keep=None:", logits.shape)

# Test with logits_to_keep=2
logits_to_keep = 2
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = model.lm_head(hidden_states[:, slice_indices, :])
print("Logits shape with logits_to_keep=2:", logits.shape)

print("\n✅ All tests passed — _init_weights and logits_to_keep logic work as expected!")