
Conversation

@mauricioharley

Description

This Pull Request fixes a RuntimeError that occurs when loading llmcompressor W8A8 quantized models (e.g., RedHatAI/Qwen2.5-VL-7B-Instruct-quantized.w8a8), caused by an attempt to initialize int8 weights with `torch.nn.init.normal_()`, which only supports floating-point dtypes.

The issue was identified in `modeling_utils.py`, in the `_initialize_missing_keys` method: when `is_quantized` is `True`, the else branch still called `self.initialize_weights()`, triggering the RuntimeError.
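
For context, here is a minimal reproduction of the failure mode, independent of transformers (the tensor below is just a stand-in for a W8A8 quantized weight; the exact error message varies by PyTorch version):

```python
import torch

# Stand-in for an int8 W8A8 quantized weight tensor
weight = torch.zeros(4, 4, dtype=torch.int8)

try:
    # normal_ only supports floating-point (and complex) dtypes,
    # so calling it on an int8 tensor raises a RuntimeError
    torch.nn.init.normal_(weight)
except RuntimeError as err:
    print(f"RuntimeError: {err}")
```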

Proposed Change

Added a conditional check `if not is_quantized:` before the call to `self.initialize_weights()` in the else branch of the `_initialize_missing_keys` method. This skips weight initialization for quantized models: their weights are either already defined or will be loaded from a pretrained state dict, so random initialization is both redundant and, for int8 tensors, unsupported.
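
A minimal sketch of the shape of that change (illustrative only; the real method in `modeling_utils.py` is considerably larger):

```python
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def initialize_weights(self):
        # Would fail if self.linear.weight were int8
        nn.init.normal_(self.linear.weight)

    def _initialize_missing_keys(self, is_quantized: bool):
        # The proposed guard: quantized weights are already defined or
        # will be loaded from the checkpoint, so skip random init
        if not is_quantized:
            self.initialize_weights()
```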

Related Issue

Closes #39366

@ydshieh (Collaborator) commented Oct 2, 2025

cc @MekkCyber @SunMarc

Comment on lines 5668 to 5669:

```python
if not is_quantized:
    self.initialize_weights()
```
@SunMarc (Member)

Quantized models shouldn't be handled differently. Maybe we need to set `_is_hf_initialized` somewhere in the modules that are impacted.

@mauricioharley (Author)

Thanks for the catch, @SunMarc! I restored the regular `initialize_weights()` path so quantized models no longer skip it, and added a hook in the quantizers to mark the tensors we create as `_is_hf_initialized`. That way the initialization logic still runs, but the quantized parameters stay untouched.
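
A hedged sketch of that pattern (the flag name comes from the discussion above; the marking helper and the init loop are simplified stand-ins, not the exact transformers code):

```python
import torch
import torch.nn as nn

def mark_quantized_modules_initialized(model: nn.Module) -> None:
    # What the quantizer hook does conceptually: flag modules whose
    # tensors it materialized so the generic init pass leaves them alone
    for module in model.modules():
        weight = getattr(module, "weight", None)
        if weight is not None and weight.dtype == torch.int8:
            module._is_hf_initialized = True

def initialize_weights(model: nn.Module) -> None:
    for module in model.modules():
        if getattr(module, "_is_hf_initialized", False):
            continue  # quantized (or otherwise pre-initialized): skip
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight)  # only float weights reach here
```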

This commit addresses the RuntimeError encountered when loading llmcompressor W8A8 quantized models, where `torch.nn.init.normal_()` is called on `int8` tensors during weight initialization.

Fixes huggingface#39366

Signed-off-by: Mauricio Harley <mauricioharley@gmail.com>
@github-actions (bot) commented Oct 2, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: fbgemm_fp8, finegrained_fp8, higgs, hqq, mxfp4

@ArthurZucker (Collaborator) left a comment

Hey! Sorry for the radio silence. If you still want to work on this: we merged #41580, which changes this a lot!


Successfully merging this pull request may close: RuntimeError when loading llmcompressor W8A8 quantized model: int8 dtype in weight initialization (#39366)