diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 965084d0c24a..3df4f6d47b5a 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4709,7 +4709,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
         else None
     )
     total_byte_count = defaultdict(lambda: 0)
-    tied_param_names = _get_tied_weight_keys(model)
+    tied_param_names = model.all_tied_weights_keys.keys()
     for param_name, device in accelerator_device_map.items():
         # Skip if the parameter has already been accounted for (tied weights)
         if param_name in tied_param_names:
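
For context, here is a minimal, standalone sketch of the skip logic this hunk feeds into (the helper `warmup_byte_count` and its arguments are hypothetical, not the actual transformers implementation): `tied_param_names` now comes from `model.all_tied_weights_keys.keys()`, and any parameter in that set is excluded from the per-device byte count because its storage is shared with another parameter and adds no extra memory during warmup.

```python
from collections import defaultdict

def warmup_byte_count(param_nbytes, device_map, tied_param_names):
    """Accumulate bytes per device, counting each tied weight only once.

    param_nbytes: dict mapping parameter name -> size in bytes
    device_map: dict mapping parameter name -> device string
    tied_param_names: names whose storage is shared with another parameter
    """
    total_byte_count = defaultdict(int)
    for param_name, device in device_map.items():
        if param_name in tied_param_names:
            # Tied weights alias another parameter's storage, so they add
            # nothing to the warmup allocation on their device.
            continue
        total_byte_count[device] += param_nbytes[param_name]
    return total_byte_count

# Example: lm_head.weight is tied to the embedding matrix and is skipped.
sizes = {"embed.weight": 4096, "lm_head.weight": 4096, "layer.0.weight": 2048}
devices = {name: "cuda:0" for name in sizes}
print(warmup_byte_count(sizes, devices, {"lm_head.weight"}))  # {'cuda:0': 6144}
```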