From 1f0ce1967485863c12b184b39bee0ebfc3c00b78 Mon Sep 17 00:00:00 2001
From: Mauricio Harley
Date: Wed, 1 Oct 2025 21:13:21 -0300
Subject: [PATCH] fix(quantization): Skip weight initialization for quantized models

This commit addresses the RuntimeError encountered when loading llmcompressor
W8A8 quantized models, where `torch.nn.init.normal_()` is called on `int8`
tensors during weight initialization.

Fixes #39366

Signed-off-by: Mauricio Harley
---
 src/transformers/modeling_utils.py           |  2 +
 src/transformers/quantizers/base.py          | 42 +++++++++++++++++++++++-
 .../quantizers/quantizer_fbgemm_fp8.py       |  2 +
 .../quantizers/quantizer_finegrained_fp8.py  |  2 +
 .../quantizers/quantizer_fp_quant.py         |  2 +
 .../quantizers/quantizer_higgs.py            |  2 +
 src/transformers/quantizers/quantizer_hqq.py |  1 +
 .../quantizers/quantizer_mxfp4.py            |  2 +
 .../quantizers/quantizer_quanto.py           |  2 +
 9 files changed, 56 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 7dcafa323a9c..0730939c7f1d 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -730,6 +730,7 @@ def _load_state_dict_into_meta_model(
                     device_mesh.get_local_rank(),
                     **sharding_kwargs,
                 )
+                hf_quantizer.register_loaded_quantized_param(model, param_name)
             else:
                 param = param[...]
                 if casting_dtype is not None:
@@ -758,6 +759,7 @@ def _load_state_dict_into_meta_model(
                 else:
                     # TODO naming is stupid it loads it as well
                     hf_quantizer.create_quantized_param(model, param, param_name, param_device)
+                    hf_quantizer.register_loaded_quantized_param(model, param_name)
 
                     # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
                     # and then cast it to CPU to avoid excessive memory usage on each GPU
diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py
index b9dd7ae10f9e..dc044208b93c 100644
--- a/src/transformers/quantizers/base.py
+++ b/src/transformers/quantizers/base.py
@@ -61,6 +61,8 @@ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         # -- Handle extra kwargs below --
         self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
         self.pre_quantized = kwargs.pop("pre_quantized", True)
+        # Track quantized parameters we create so they do not trigger reinitialization later on.
+        self._loaded_quantized_keys: set[str] = set()
 
         if not self.pre_quantized and self.requires_calibration:
             raise ValueError(
@@ -126,7 +128,45 @@ def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> li
             missing_keys (`list[str]`, *optional*):
                 The list of missing keys in the checkpoint compared to the state dict of the model
         """
-        return missing_keys
+        filtered_keys = []
+        for key in missing_keys:
+            if key in self._loaded_quantized_keys:
+                continue
+
+            needs_quantization = False
+            try:
+                needs_quantization = self.param_needs_quantization(model, key)
+            except Exception:
+                # Some quantizers may raise if the key does not correspond to a handled parameter. Treat as non-quantized.
+                needs_quantization = False
+
+            if needs_quantization:
+                self._loaded_quantized_keys.add(key)
+                continue
+
+            filtered_keys.append(key)
+
+        return filtered_keys
+
+    def register_loaded_quantized_param(self, model: "PreTrainedModel", param_name: str) -> None:
+        """Mark a quantized parameter as loaded so it is not treated as missing and will not be reinitialized."""
+
+        self._loaded_quantized_keys.add(param_name)
+
+        if not is_torch_available():
+            # We cannot mutate torch objects if torch is unavailable.
+            return
+
+        try:
+            param_or_buffer = model.get_parameter_or_buffer(param_name)
+        except (AttributeError, ValueError, RuntimeError):
+            param_or_buffer = None
+
+        if param_or_buffer is not None:
+            try:
+                setattr(param_or_buffer, "_is_hf_initialized", True)
+            except Exception:
+                pass
 
     def update_expected_keys(self, model, expected_keys: list[str], loaded_keys: list[str]) -> list[str]:
         """
diff --git a/src/transformers/quantizers/quantizer_fbgemm_fp8.py b/src/transformers/quantizers/quantizer_fbgemm_fp8.py
index 22c90aa446dd..ee007967c65b 100644
--- a/src/transformers/quantizers/quantizer_fbgemm_fp8.py
+++ b/src/transformers/quantizers/quantizer_fbgemm_fp8.py
@@ -223,6 +223,8 @@ def _process_model_before_weight_loading(
     def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
         from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts
 
+        missing_keys = super().update_missing_keys(model, missing_keys, prefix)
+
         not_missing_keys = []
         for name, module in model.named_modules():
             if isinstance(module, (FbgemmFp8Linear, FbgemmFp8Llama4TextExperts)):
diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py
index dc0123c1b007..56892d5b8181 100644
--- a/src/transformers/quantizers/quantizer_finegrained_fp8.py
+++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py
@@ -173,6 +173,8 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs
     def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
         from ..integrations import FP8Linear
 
+        missing_keys = super().update_missing_keys(model, missing_keys, prefix)
+
         not_missing_keys = []
         for name, module in model.named_modules():
             if isinstance(module, FP8Linear):
diff --git a/src/transformers/quantizers/quantizer_fp_quant.py b/src/transformers/quantizers/quantizer_fp_quant.py
index a7bc077776fe..657513fe8470 100644
--- a/src/transformers/quantizers/quantizer_fp_quant.py
+++ b/src/transformers/quantizers/quantizer_fp_quant.py
@@ -146,6 +146,8 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs
     def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
         from fp_quant import FPQuantLinear
 
+        missing_keys = super().update_missing_keys(model, missing_keys, prefix)
+
         fp_quant_names = {name for name, module in model.named_modules() if isinstance(module, FPQuantLinear)}
 
         def should_exclude(key: str) -> bool:
diff --git a/src/transformers/quantizers/quantizer_higgs.py b/src/transformers/quantizers/quantizer_higgs.py
index 41e2d86cf1ec..2da070d550d8 100644
--- a/src/transformers/quantizers/quantizer_higgs.py
+++ b/src/transformers/quantizers/quantizer_higgs.py
@@ -160,6 +160,8 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs
     def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
         from ..integrations import HiggsLinear
 
+        missing_keys = super().update_missing_keys(model, missing_keys, prefix)
+
         higgs_names = {name for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
 
         def should_update(key: str) -> bool:
diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py
index 94907c3b48fc..5f6759948ef0 100755
--- a/src/transformers/quantizers/quantizer_hqq.py
+++ b/src/transformers/quantizers/quantizer_hqq.py
@@ -86,6 +86,7 @@ def validate_environment(self, *args, **kwargs):
     def update_missing_keys(
         self, model: "PreTrainedModel", missing_keys: list[str], prefix: str, **kwargs
     ) -> list[str]:
+        missing_keys = super().update_missing_keys(model, missing_keys, prefix)
         if self.pre_quantized:
             return [key for key in missing_keys if ("weight" not in key)]
         else:
diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py
index 04cf8ec56c96..66ada5dde6da 100644
--- a/src/transformers/quantizers/quantizer_mxfp4.py
+++ b/src/transformers/quantizers/quantizer_mxfp4.py
@@ -327,6 +327,8 @@ def _process_model_before_weight_loading(
     def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
         from ..integrations import Mxfp4GptOssExperts
 
+        missing_keys = super().update_missing_keys(model, missing_keys, prefix)
+
         not_missing_keys = []
         for name, module in model.named_modules():
             if isinstance(module, Mxfp4GptOssExperts):
diff --git a/src/transformers/quantizers/quantizer_quanto.py b/src/transformers/quantizers/quantizer_quanto.py
index 451179aaf723..0c92a5d435cd 100644
--- a/src/transformers/quantizers/quantizer_quanto.py
+++ b/src/transformers/quantizers/quantizer_quanto.py
@@ -91,6 +91,8 @@ def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> li
         if is_optimum_quanto_available():
             from optimum.quanto import QModuleMixin
 
+        missing_keys = super().update_missing_keys(model, missing_keys, prefix)
+
         not_missing_keys = []
         for name, module in model.named_modules():
             if isinstance(module, QModuleMixin):