
Commit d006d84

fix(quantization): Skip weight initialization for quantized models

This commit addresses the RuntimeError encountered when loading llmcompressor W8A8 quantized models, where `torch.nn.init.normal_()` is called on `int8` tensors during weight initialization. The `_initialize_missing_keys` method in `modeling_utils.py` was unconditionally calling `self.initialize_weights()`. For quantized models, this initialization is unnecessary and causes a `RuntimeError`, as `normal_()` does not support integer dtypes.

By adding a check `if not is_quantized:` before calling `self.initialize_weights()`, we ensure that this problematic initialization step is skipped for quantized models, resolving the `RuntimeError` and improving compatibility with `llmcompressor` W8A8 models.

Fixes #39366

Signed-off-by: Mauricio Harley <mauricioharley@gmail.com>
1 parent 9f2d566 commit d006d84
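
The failure the message describes can be reproduced with stock PyTorch alone. The snippet below is a minimal illustrative sketch (not code from this commit) showing why `torch.nn.init.normal_()` cannot be applied to `int8` weights, and why the non-quantized float path is unaffected.

```python
import torch

# Illustrative sketch (not from this commit): normal_() is only implemented for
# floating-point dtypes, so calling it on an int8 tensor (as happens when the
# missing keys of a W8A8 checkpoint are re-initialized) raises a RuntimeError.
int8_weight = torch.zeros(4, 4, dtype=torch.int8)
try:
    torch.nn.init.normal_(int8_weight)
except RuntimeError as err:
    print(f"RuntimeError: {err}")  # e.g. normal_ kernel not implemented for 'Char'

# The same call succeeds on a float tensor, i.e. the non-quantized path.
fp32_weight = torch.zeros(4, 4, dtype=torch.float32)
torch.nn.init.normal_(fp32_weight)
```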

File tree

1 file changed (+2, -1 lines)

src/transformers/modeling_utils.py

Lines changed: 2 additions & 1 deletion
@@ -5665,7 +5665,8 @@ def set_is_initialized_for_modules(module):
             with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
                 self.initialize_weights()
         else:
-            self.initialize_weights()
+            if not is_quantized:
+                self.initialize_weights()
 
     def _adjust_missing_and_unexpected_keys(
         self, missing_keys: list[str], unexpected_keys: list[str], loading_task_model_from_base_state_dict: bool
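
With the guard in place, loading a W8A8 checkpoint through the usual `from_pretrained` path should no longer trip the initialization error. A hedged usage sketch follows; the model id is a placeholder, not a checkpoint referenced by this commit or issue #39366.

```python
from transformers import AutoModelForCausalLM

# Hedged usage sketch: "example-org/llama-w8a8" is a hypothetical placeholder,
# not a real checkpoint named in this commit. After the fix, a quantized model's
# missing keys are no longer run through initialize_weights(), so its int8
# weights are never handed to torch.nn.init.normal_().
model = AutoModelForCausalLM.from_pretrained("example-org/llama-w8a8")
```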
