src/transformers/modeling_utils.py (2 additions, 0 deletions)
@@ -730,6 +730,7 @@ def _load_state_dict_into_meta_model(
device_mesh.get_local_rank(),
**sharding_kwargs,
)
hf_quantizer.register_loaded_quantized_param(model, param_name)
else:
param = param[...]
if casting_dtype is not None:
@@ -758,6 +759,7 @@ def _load_state_dict_into_meta_model(
else:
# TODO naming is stupid it loads it as well
hf_quantizer.create_quantized_param(model, param, param_name, param_device)
hf_quantizer.register_loaded_quantized_param(model, param_name)

# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
# and then cast it to CPU to avoid excessive memory usage on each GPU
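The two hunks above pair every quantized-parameter creation during meta-model loading with a registration call on the quantizer. A minimal sketch of that pairing, using a hypothetical helper name (only `create_quantized_param` and `register_loaded_quantized_param` come from the diff):

```python
# Sketch only: `_create_and_register` is a hypothetical helper illustrating the pairing
# introduced in _load_state_dict_into_meta_model. Whenever a quantized parameter is
# materialized from the checkpoint, it is immediately registered with the quantizer so
# that later missing-key filtering and weight re-initialization skip it.
def _create_and_register(hf_quantizer, model, param, param_name, param_device):
    hf_quantizer.create_quantized_param(model, param, param_name, param_device)
    hf_quantizer.register_loaded_quantized_param(model, param_name)
```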
src/transformers/quantizers/base.py (41 additions, 1 deletion)
@@ -61,6 +61,8 @@ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
# -- Handle extra kwargs below --
self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
self.pre_quantized = kwargs.pop("pre_quantized", True)
# Track quantized parameters we create so they do not trigger reinitialization later on.
self._loaded_quantized_keys: set[str] = set()

if not self.pre_quantized and self.requires_calibration:
raise ValueError(
@@ -126,7 +128,45 @@ def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
missing_keys (`list[str]`, *optional*):
The list of missing keys in the checkpoint compared to the state dict of the model
"""
return missing_keys
filtered_keys = []
for key in missing_keys:
if key in self._loaded_quantized_keys:
continue

needs_quantization = False
try:
needs_quantization = self.param_needs_quantization(model, key)
except Exception:
# Some quantizers may raise if the key does not correspond to a handled parameter. Treat as non-quantized.
needs_quantization = False

if needs_quantization:
self._loaded_quantized_keys.add(key)
continue

filtered_keys.append(key)

return filtered_keys

def register_loaded_quantized_param(self, model: "PreTrainedModel", param_name: str) -> None:
"""Mark a quantized parameter as loaded so it is not treated as missing and will not be reinitialized."""

self._loaded_quantized_keys.add(param_name)

if not is_torch_available():
# We cannot mutate torch objects if torch is unavailable.
return

try:
param_or_buffer = model.get_parameter_or_buffer(param_name)
except (AttributeError, ValueError, RuntimeError):
param_or_buffer = None

if param_or_buffer is not None:
try:
setattr(param_or_buffer, "_is_hf_initialized", True)
except Exception:
pass

def update_expected_keys(self, model, expected_keys: list[str], loaded_keys: list[str]) -> list[str]:
"""
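With the base-class changes above, concrete quantizers are expected to call `super().update_missing_keys(...)` before applying their own filtering, which is exactly the one-line change repeated in the files below. A rough sketch of the pattern, using a made-up subclass and key suffix purely for illustration:

```python
from transformers.quantizers.base import HfQuantizer


class ExampleFp8Quantizer(HfQuantizer):  # hypothetical subclass, for illustration only
    def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
        # The base implementation now drops keys already registered via
        # register_loaded_quantized_param (or detected via param_needs_quantization),
        # so quantized params created while loading are no longer reported as missing.
        missing_keys = super().update_missing_keys(model, missing_keys, prefix)
        # Quantizer-specific filtering sits on top, e.g. ignoring scale buffers
        # (the ".weight_scale" suffix here is just an example).
        return [key for key in missing_keys if not key.endswith(".weight_scale")]
```

Setting `_is_hf_initialized` in `register_loaded_quantized_param` mirrors the flag the weight-initialization logic checks, so registered parameters should also be skipped when missing weights are re-initialized.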
src/transformers/quantizers/quantizer_fbgemm_fp8.py (2 additions, 0 deletions)
@@ -223,6 +223,8 @@ def _process_model_before_weight_loading(
def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts

missing_keys = super().update_missing_keys(model, missing_keys, prefix)

not_missing_keys = []
for name, module in model.named_modules():
if isinstance(module, (FbgemmFp8Linear, FbgemmFp8Llama4TextExperts)):
src/transformers/quantizers/quantizer_finegrained_fp8.py (2 additions, 0 deletions)
@@ -173,6 +173,8 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
from ..integrations import FP8Linear

missing_keys = super().update_missing_keys(model, missing_keys, prefix)

not_missing_keys = []
for name, module in model.named_modules():
if isinstance(module, FP8Linear):
src/transformers/quantizers/quantizer_fp_quant.py (2 additions, 0 deletions)
@@ -146,6 +146,8 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
from fp_quant import FPQuantLinear

missing_keys = super().update_missing_keys(model, missing_keys, prefix)

fp_quant_names = {name for name, module in model.named_modules() if isinstance(module, FPQuantLinear)}

def should_exclude(key: str) -> bool:
src/transformers/quantizers/quantizer_higgs.py (2 additions, 0 deletions)
@@ -160,6 +160,8 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
from ..integrations import HiggsLinear

missing_keys = super().update_missing_keys(model, missing_keys, prefix)

higgs_names = {name for name, module in model.named_modules() if isinstance(module, HiggsLinear)}

def should_update(key: str) -> bool:
src/transformers/quantizers/quantizer_hqq.py (1 addition, 0 deletions)
@@ -86,6 +86,7 @@ def validate_environment(self, *args, **kwargs):
def update_missing_keys(
self, model: "PreTrainedModel", missing_keys: list[str], prefix: str, **kwargs
) -> list[str]:
missing_keys = super().update_missing_keys(model, missing_keys, prefix)
if self.pre_quantized:
return [key for key in missing_keys if ("weight" not in key)]
else:
src/transformers/quantizers/quantizer_mxfp4.py (2 additions, 0 deletions)
@@ -327,6 +327,8 @@ def _process_model_before_weight_loading(
def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
from ..integrations import Mxfp4GptOssExperts

missing_keys = super().update_missing_keys(model, missing_keys, prefix)

not_missing_keys = []
for name, module in model.named_modules():
if isinstance(module, Mxfp4GptOssExperts):
src/transformers/quantizers/quantizer_quanto.py (2 additions, 0 deletions)
@@ -91,6 +91,8 @@ def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
if is_optimum_quanto_available():
from optimum.quanto import QModuleMixin

missing_keys = super().update_missing_keys(model, missing_keys, prefix)

not_missing_keys = []
for name, module in model.named_modules():
if isinstance(module, QModuleMixin):