
Commit e6d835b

fix(quantization): Skip weight initialization for quantized models
This commit addresses the RuntimeError encountered when loading llmcompressor W8A8 quantized models, where `torch.nn.init.normal_()` is called on `int8` tensors during weight initialization.

Fixes #39366

Signed-off-by: Mauricio Harley <mauricioharley@gmail.com>
1 parent 2f3e266 commit e6d835b


9 files changed: +56 −1 lines changed

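For context, the failure this commit works around comes from running a floating-point initializer on an integer tensor. A minimal repro sketch (the exact error message depends on the torch version):

import torch

# llmcompressor W8A8 checkpoints store weights as int8.
weight = torch.zeros(4, 4, dtype=torch.int8)

try:
    # Random-normal init is only defined for floating-point tensors,
    # so calling it on an int8 tensor raises a RuntimeError.
    torch.nn.init.normal_(weight, mean=0.0, std=0.02)
except RuntimeError as err:
    print(f"RuntimeError: {err}")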

src/transformers/modeling_utils.py

Lines changed: 2 additions & 0 deletions
@@ -730,6 +730,7 @@ def _load_state_dict_into_meta_model(
                     device_mesh.get_local_rank(),
                     **sharding_kwargs,
                 )
+                hf_quantizer.register_loaded_quantized_param(model, param_name)
             else:
                 param = param[...]
                 if casting_dtype is not None:
@@ -758,6 +759,7 @@ def _load_state_dict_into_meta_model(
                 else:
                     # TODO naming is stupid it loads it as well
                     hf_quantizer.create_quantized_param(model, param, param_name, param_device)
+                    hf_quantizer.register_loaded_quantized_param(model, param_name)

                 # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
                 # and then cast it to CPU to avoid excessive memory usage on each GPU

src/transformers/quantizers/base.py

Lines changed: 41 additions & 1 deletion
@@ -61,6 +61,8 @@ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         # -- Handle extra kwargs below --
         self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
         self.pre_quantized = kwargs.pop("pre_quantized", True)
+        # Track quantized parameters we create so they do not trigger reinitialization later on.
+        self._loaded_quantized_keys: set[str] = set()

         if not self.pre_quantized and self.requires_calibration:
             raise ValueError(
@@ -126,7 +128,45 @@ def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
             missing_keys (`list[str]`, *optional*):
                 The list of missing keys in the checkpoint compared to the state dict of the model
         """
-        return missing_keys
+        filtered_keys = []
+        for key in missing_keys:
+            if key in self._loaded_quantized_keys:
+                continue
+
+            needs_quantization = False
+            try:
+                needs_quantization = self.param_needs_quantization(model, key)
+            except Exception:
+                # Some quantizers may raise if the key does not correspond to a handled parameter. Treat as non-quantized.
+                needs_quantization = False
+
+            if needs_quantization:
+                self._loaded_quantized_keys.add(key)
+                continue
+
+            filtered_keys.append(key)
+
+        return filtered_keys
+
+    def register_loaded_quantized_param(self, model: "PreTrainedModel", param_name: str) -> None:
+        """Mark a quantized parameter as loaded so it is not treated as missing and will not be reinitialized."""
+
+        self._loaded_quantized_keys.add(param_name)
+
+        if not is_torch_available():
+            # We cannot mutate torch objects if torch is unavailable.
+            return
+
+        try:
+            param_or_buffer = model.get_parameter_or_buffer(param_name)
+        except (AttributeError, ValueError, RuntimeError):
+            param_or_buffer = None
+
+        if param_or_buffer is not None:
+            try:
+                setattr(param_or_buffer, "_is_hf_initialized", True)
+            except Exception:
+                pass

     def update_expected_keys(self, model, expected_keys: list[str], loaded_keys: list[str]) -> list[str]:
         """

src/transformers/quantizers/quantizer_fbgemm_fp8.py

Lines changed: 2 additions & 0 deletions
@@ -223,6 +223,8 @@ def _process_model_before_weight_loading(
     def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
         from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts

+        missing_keys = super().update_missing_keys(model, missing_keys, prefix)
+
         not_missing_keys = []
         for name, module in model.named_modules():
             if isinstance(module, (FbgemmFp8Linear, FbgemmFp8Llama4TextExperts)):

src/transformers/quantizers/quantizer_finegrained_fp8.py

Lines changed: 2 additions & 0 deletions
@@ -173,6 +173,8 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
     def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
         from ..integrations import FP8Linear

+        missing_keys = super().update_missing_keys(model, missing_keys, prefix)
+
         not_missing_keys = []
         for name, module in model.named_modules():
             if isinstance(module, FP8Linear):

src/transformers/quantizers/quantizer_fp_quant.py

Lines changed: 2 additions & 0 deletions
@@ -146,6 +146,8 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
     def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
         from fp_quant import FPQuantLinear

+        missing_keys = super().update_missing_keys(model, missing_keys, prefix)
+
         fp_quant_names = {name for name, module in model.named_modules() if isinstance(module, FPQuantLinear)}

         def should_exclude(key: str) -> bool:

src/transformers/quantizers/quantizer_higgs.py

Lines changed: 2 additions & 0 deletions
@@ -160,6 +160,8 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
     def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
         from ..integrations import HiggsLinear

+        missing_keys = super().update_missing_keys(model, missing_keys, prefix)
+
         higgs_names = {name for name, module in model.named_modules() if isinstance(module, HiggsLinear)}

         def should_update(key: str) -> bool:

src/transformers/quantizers/quantizer_hqq.py

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@ def validate_environment(self, *args, **kwargs):
     def update_missing_keys(
         self, model: "PreTrainedModel", missing_keys: list[str], prefix: str, **kwargs
     ) -> list[str]:
+        missing_keys = super().update_missing_keys(model, missing_keys, prefix)
         if self.pre_quantized:
             return [key for key in missing_keys if ("weight" not in key)]
         else:

src/transformers/quantizers/quantizer_mxfp4.py

Lines changed: 2 additions & 0 deletions
@@ -327,6 +327,8 @@ def _process_model_before_weight_loading(
     def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
         from ..integrations import Mxfp4GptOssExperts

+        missing_keys = super().update_missing_keys(model, missing_keys, prefix)
+
         not_missing_keys = []
         for name, module in model.named_modules():
             if isinstance(module, Mxfp4GptOssExperts):

src/transformers/quantizers/quantizer_quanto.py

Lines changed: 2 additions & 0 deletions
@@ -91,6 +91,8 @@ def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
         if is_optimum_quanto_available():
             from optimum.quanto import QModuleMixin

+        missing_keys = super().update_missing_keys(model, missing_keys, prefix)
+
         not_missing_keys = []
         for name, module in model.named_modules():
             if isinstance(module, QModuleMixin):
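With the base-class filtering and the per-quantizer `super()` calls above, loading a pre-quantized checkpoint should no longer hit the init error. A usage sketch, where the model id is a placeholder for any llmcompressor W8A8 checkpoint:

from transformers import AutoModelForCausalLM

# Placeholder model id; substitute any llmcompressor W8A8 quantized checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    "org/llama-w8a8-example",
    device_map="auto",
)
# Before this fix, loading could raise a RuntimeError from torch.nn.init.normal_()
# being applied to int8 tensors; the quantized params are now registered as loaded
# and skipped during weight initialization.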
