From c88856604e60f668174b40749554b6b8a9bad2e3 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Tue, 18 Nov 2025 20:51:42 +0100
Subject: [PATCH 1/7] fix

---
 src/transformers/integrations/accelerate.py | 62 ++-------------------
 1 file changed, 5 insertions(+), 57 deletions(-)

diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py
index 64a98d0c2fb6..9ebb4d082812 100644
--- a/src/transformers/integrations/accelerate.py
+++ b/src/transformers/integrations/accelerate.py
@@ -159,61 +159,6 @@ def wrapper(*args, **kwargs):
         setattr(torch, torch_function_name, old_torch_function)
 
 
-def find_tied_parameters(model: "nn.Module", **kwargs):
-    """
-    Find the tied parameters in a given model.
-
-    <Tip warning={true}>
-
-    The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
-    them.
-
-    </Tip>
-
-    Args:
-        model (`torch.nn.Module`): The model to inspect.
-
-    Returns:
-        list[list[str]]: A list of lists of parameter names being all tied together.
-
-    Example:
-
-    ```py
-    >>> from collections import OrderedDict
-    >>> import torch.nn as nn
-
-    >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
-    >>> model.linear2.weight = model.linear1.weight
-    >>> find_tied_parameters(model)
-    [['linear1.weight', 'linear2.weight']]
-    ```
-    """
-
-    # get ALL model parameters and their names
-    all_named_parameters = dict(model.named_parameters(remove_duplicate=False))
-
-    # get ONLY unique named parameters,
-    # if parameter is tied and have multiple names, it will be included only once
-    no_duplicate_named_parameters = dict(model.named_parameters(remove_duplicate=True))
-
-    # the difference of the two sets will give us the tied parameters
-    tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())
-
-    # 'tied_param_names' contains the names of parameters that are tied in the model, but we do not know
-    # which names refer to the same parameter. To identify this, we need to group them together.
-    tied_param_groups = {}
-    for tied_param_name in tied_param_names:
-        tied_param = all_named_parameters[tied_param_name]
-        for param_name, param in no_duplicate_named_parameters.items():
-            # compare if parameters are the same, if so, group their names together
-            if param is tied_param:
-                if param_name not in tied_param_groups:
-                    tied_param_groups[param_name] = []
-                tied_param_groups[param_name].append(tied_param_name)
-
-    return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]
-
-
 def check_and_set_device_map(device_map: "torch.device | int | str | dict | None") -> dict | str | None:
     from ..modeling_utils import get_torch_context_manager_or_global_device
 
@@ -591,8 +536,11 @@ def _init_infer_auto_device_map(
 
     if tied_parameters is None:
         if len(model.all_tied_weights_keys) > 0:
-            # create a list of list of tied params
-            tied_parameters = [list(t) for t in model.all_tied_weights_keys.items()]
+            # create a list of list of tied params based on unique tied groups
+            groups = set(model.all_tied_weights_keys.values())
+            tied_parameters = [
+                [k for k, v in model.all_tied_weights_keys.items() if v == target] + [target] for target in groups
+            ]
         else:
             tied_parameters = [[]]
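
Note on PATCH 1/7: the removed `find_tied_parameters` discovered tied weights by diffing `named_parameters(remove_duplicate=False)` against the deduplicated view; the replacement derives the groups directly from `model.all_tied_weights_keys`, which (judging from the code) maps each tied parameter name to the name of the parameter it shares storage with. A minimal sketch of the new grouping, using a hypothetical mapping (the dict below is illustrative, not from a real model):

```py
# Sketch of the group construction in PATCH 1/7.
# Assumed shape of `all_tied_weights_keys`: tied name -> shared target name.
all_tied_weights_keys = {
    "lm_head.weight": "model.embed_tokens.weight",  # hypothetical tie
}

groups = set(all_tied_weights_keys.values())
tied_parameters = [
    [k for k, v in all_tied_weights_keys.items() if v == target] + [target] for target in groups
]
print(tied_parameters)  # [['lm_head.weight', 'model.embed_tokens.weight']]
```
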
From 75f5ab6ad018f77e55a7094e4da9e3f62f948d7e Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Tue, 18 Nov 2025 20:57:13 +0100
Subject: [PATCH 2/7] add sorting

---
 src/transformers/integrations/accelerate.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py
index 9ebb4d082812..dd9945ba3706 100644
--- a/src/transformers/integrations/accelerate.py
+++ b/src/transformers/integrations/accelerate.py
@@ -539,7 +539,9 @@ def _init_infer_auto_device_map(
             # create a list of list of tied params based on unique tied groups
             groups = set(model.all_tied_weights_keys.values())
             tied_parameters = [
-                [k for k, v in model.all_tied_weights_keys.items() if v == target] + [target] for target in groups
+                sorted(
+                    [k for k, v in model.all_tied_weights_keys.items() if v == target] + [target] for target in groups
+                )
             ]
         else:
             tied_parameters = [[]]

From fb1b5756c0aceeac6d132f6bebb60736c0395f9a Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Tue, 18 Nov 2025 21:01:05 +0100
Subject: [PATCH 3/7] typo

---
 src/transformers/integrations/accelerate.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py
index dd9945ba3706..6403e9d55956 100644
--- a/src/transformers/integrations/accelerate.py
+++ b/src/transformers/integrations/accelerate.py
@@ -539,9 +539,8 @@ def _init_infer_auto_device_map(
             # create a list of list of tied params based on unique tied groups
             groups = set(model.all_tied_weights_keys.values())
             tied_parameters = [
-                sorted(
-                    [k for k, v in model.all_tied_weights_keys.items() if v == target] + [target] for target in groups
-                )
+                sorted([k for k, v in model.all_tied_weights_keys.items() if v == target] + [target])
+                for target in groups
             ]
         else:
             tied_parameters = [[]]
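
Note on PATCH 3/7: in PATCH 2/7 the `sorted(...)` wraps the whole generator over `groups`, so `tied_parameters` ends up as a single sorted list of groups nested one level too deep, instead of one sorted list per group. A standalone sketch with a hypothetical mapping showing the difference:

```py
# Two parameter names tied to the same target "a" (hypothetical mapping).
all_tied_weights_keys = {"b": "a", "c": "a"}
groups = set(all_tied_weights_keys.values())

# PATCH 2/7 (buggy): one sorted() over the generator of groups.
buggy = [sorted([k for k, v in all_tied_weights_keys.items() if v == t] + [t] for t in groups)]
print(buggy)  # [[['b', 'c', 'a']]] -- nested one level too deep, names unsorted

# PATCH 3/7 (fixed): sort each group individually.
fixed = [sorted([k for k, v in all_tied_weights_keys.items() if v == t] + [t]) for t in groups]
print(fixed)  # [['a', 'b', 'c']]
```
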
From a18b6ed1e0efacb7e2807e20db4d07919dc7ba66 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Tue, 18 Nov 2025 23:04:23 +0100
Subject: [PATCH 4/7] fix

---
 src/transformers/integrations/accelerate.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py
index 6403e9d55956..1f977b9bc53f 100644
--- a/src/transformers/integrations/accelerate.py
+++ b/src/transformers/integrations/accelerate.py
@@ -216,11 +216,20 @@ def compute_module_sizes(
     leaves_module_sizes = defaultdict(int)
 
     if buffers_only:
-        named_tensors = model.named_buffers(recurse=True)
+        iterator = model.named_buffers()
     else:
-        named_tensors = model.state_dict().items()
+        # We need parameters + buffers here, as state_dict does not count non-persistent buffers which are taking space
+        def all_tensors():
+            yield from model.named_parameters()
+            yield from model.named_buffers()
 
-    for name, param in named_tensors:
+        iterator = all_tensors()
+
+    tied_keys = getattr(model, "all_tied_weights_keys", {}).keys()
+    for name, param in iterator:
+        # Do not count tied keys
+        if name in tied_keys:
+            continue
         if hf_quantizer is not None:
             dtype_size = hf_quantizer.param_element_size(model, name)
         else:

From 6ead46d49466b2822060886319546e0cfce5cbc0 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Wed, 19 Nov 2025 15:42:58 +0100
Subject: [PATCH 5/7] improve doc

---
 src/transformers/integrations/accelerate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py
index 1f977b9bc53f..25823ff8d869 100644
--- a/src/transformers/integrations/accelerate.py
+++ b/src/transformers/integrations/accelerate.py
@@ -227,7 +227,7 @@ def all_tensors():
 
     tied_keys = getattr(model, "all_tied_weights_keys", {}).keys()
     for name, param in iterator:
-        # Do not count tied keys
+        # Do not count tied keys (the model is usually not tied yet here, so they will appear in the iterator)
         if name in tied_keys:
             continue
         if hf_quantizer is not None:

From 3fea995d522a06bdc5914cd0f46d27f0b52636a7 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Wed, 19 Nov 2025 15:45:15 +0100
Subject: [PATCH 6/7] doc

---
 src/transformers/integrations/accelerate.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py
index 25823ff8d869..93d807884df8 100644
--- a/src/transformers/integrations/accelerate.py
+++ b/src/transformers/integrations/accelerate.py
@@ -228,6 +228,7 @@ def all_tensors():
     tied_keys = getattr(model, "all_tied_weights_keys", {}).keys()
     for name, param in iterator:
         # Do not count tied keys (the model is usually not tied yet here, so they will appear in the iterator)
+        # If the model is already tied, then they simply do not appear in the iterator anyway
         if name in tied_keys:
             continue
         if hf_quantizer is not None:

From 76a4b0814317ba43ab58f87a2aa034856ea7c5d7 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Wed, 19 Nov 2025 15:45:54 +0100
Subject: [PATCH 7/7] doc

---
 src/transformers/integrations/accelerate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py
index 93d807884df8..7d6634c70ece 100644
--- a/src/transformers/integrations/accelerate.py
+++ b/src/transformers/integrations/accelerate.py
@@ -228,7 +228,7 @@ def all_tensors():
     tied_keys = getattr(model, "all_tied_weights_keys", {}).keys()
     for name, param in iterator:
         # Do not count tied keys (the model is usually not tied yet here, so they will appear in the iterator)
-        # If the model is already tied, then they simply do not appear in the iterator anyway
+        # If the model is already tied, then they simply do not appear in the iterator anyway (remove_duplicates=True by default)
        if name in tied_keys:
             continue
         if hf_quantizer is not None:
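
Note on PATCH 4/7 through 7/7: `compute_module_sizes` now walks parameters plus buffers (rather than `state_dict`, which omits non-persistent buffers) and skips names listed in `all_tied_weights_keys`, so shared storage is counted once. A condensed, self-contained sketch of that accounting; the toy model and the hand-attached `all_tied_weights_keys` dict are illustrative assumptions, not the transformers API:

```py
from collections import defaultdict

import torch.nn as nn

# Toy model with two tied weights; attribute attached by hand for illustration.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
model[1].weight = model[0].weight  # tie the weights
model.all_tied_weights_keys = {"1.weight": "0.weight"}  # hypothetical mapping

def all_tensors():
    # parameters + buffers, since state_dict misses non-persistent buffers
    yield from model.named_parameters()
    yield from model.named_buffers()

tied_keys = getattr(model, "all_tied_weights_keys", {}).keys()
sizes = defaultdict(int)
for name, tensor in all_tensors():
    if name in tied_keys:  # never double-count tied storage
        continue
    sizes[name] = tensor.numel() * tensor.element_size()

print(sum(sizes.values()))  # 96 bytes: one shared 4x4 float32 weight + two biases
```

As the comments added in PATCH 6/7 and 7/7 point out, `named_parameters()` deduplicates by default, so once the weights are actually tied the duplicate name never appears in the iterator and the skip is a harmless no-op.
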