Commit a74be0d

Fix accelerate integration (#42264)
* fix
* add sorting
* typo
* fix
* improve doc
* doc
* doc
1 parent e5c8a06 · commit a74be0d

File tree

1 file changed: +20 −61 lines

src/transformers/integrations/accelerate.py

Lines changed: 20 additions & 61 deletions
@@ -159,61 +159,6 @@ def wrapper(*args, **kwargs):
         setattr(torch, torch_function_name, old_torch_function)


-def find_tied_parameters(model: "nn.Module", **kwargs):
-    """
-    Find the tied parameters in a given model.
-
-    <Tip warning={true}>
-
-    The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
-    them.
-
-    </Tip>
-
-    Args:
-        model (`torch.nn.Module`): The model to inspect.
-
-    Returns:
-        list[list[str]]: A list of lists of parameter names being all tied together.
-
-    Example:
-
-    ```py
-    >>> from collections import OrderedDict
-    >>> import torch.nn as nn
-
-    >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
-    >>> model.linear2.weight = model.linear1.weight
-    >>> find_tied_parameters(model)
-    [['linear1.weight', 'linear2.weight']]
-    ```
-    """
-
-    # get ALL model parameters and their names
-    all_named_parameters = dict(model.named_parameters(remove_duplicate=False))
-
-    # get ONLY unique named parameters,
-    # if parameter is tied and have multiple names, it will be included only once
-    no_duplicate_named_parameters = dict(model.named_parameters(remove_duplicate=True))
-
-    # the difference of the two sets will give us the tied parameters
-    tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())
-
-    # 'tied_param_names' contains the names of parameters that are tied in the model, but we do not know
-    # which names refer to the same parameter. To identify this, we need to group them together.
-    tied_param_groups = {}
-    for tied_param_name in tied_param_names:
-        tied_param = all_named_parameters[tied_param_name]
-        for param_name, param in no_duplicate_named_parameters.items():
-            # compare if parameters are the same, if so, group their names together
-            if param is tied_param:
-                if param_name not in tied_param_groups:
-                    tied_param_groups[param_name] = []
-                tied_param_groups[param_name].append(tied_param_name)
-
-    return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]
-
-
 def check_and_set_device_map(device_map: "torch.device | int | str | dict | None") -> dict | str | None:
     from ..modeling_utils import get_torch_context_manager_or_global_device

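For context on what is being deleted: the helper detected ties at runtime by comparing parameter objects across the duplicate-including and duplicate-free views of `named_parameters`. A minimal standalone sketch of that logic, runnable outside transformers (the name `tied_parameter_groups` is illustrative, not part of the library):

```py
from collections import OrderedDict

import torch.nn as nn


def tied_parameter_groups(model: nn.Module) -> list[list[str]]:
    # remove_duplicate=False lists a shared parameter once per name;
    # remove_duplicate=True lists it only once, so the set difference
    # is exactly the "extra" names introduced by tying.
    all_named = dict(model.named_parameters(remove_duplicate=False))
    unique_named = dict(model.named_parameters(remove_duplicate=True))
    extra_names = set(all_named) - set(unique_named)

    groups: dict[str, list[str]] = {}
    for extra in extra_names:
        for name, param in unique_named.items():
            if param is all_named[extra]:  # identical tensor object => tied
                groups.setdefault(name, []).append(extra)
    return [sorted([name, *tied]) for name, tied in groups.items()]


model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
model.linear2.weight = model.linear1.weight  # tie the two weights
print(tied_parameter_groups(model))  # [['linear1.weight', 'linear2.weight']]
```

The commit can drop this runtime scan because, as the hunks below show, the model's declared `all_tied_weights_keys` mapping is now treated as the source of truth for ties.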

@@ -271,11 +216,21 @@ def compute_module_sizes(
     leaves_module_sizes = defaultdict(int)

     if buffers_only:
-        named_tensors = model.named_buffers(recurse=True)
+        iterator = model.named_buffers()
     else:
-        named_tensors = model.state_dict().items()
-
-    for name, param in named_tensors:
+        # We need parameters + buffers here, as state_dict does not count non-persistent buffers which are taking space
+        def all_tensors():
+            yield from model.named_parameters()
+            yield from model.named_buffers()
+
+        iterator = all_tensors()
+
+    tied_keys = getattr(model, "all_tied_weights_keys", {}).keys()
+    for name, param in iterator:
+        # Do not count tied keys (the model is usually not tied yet here, so they will appear in the iterator)
+        # If the model is already tied, then they simply do not appear in the iterator anyway (remove_duplicates=True by default)
+        if name in tied_keys:
+            continue
         if hf_quantizer is not None:
             dtype_size = hf_quantizer.param_element_size(model, name)
         else:
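Two details in this hunk are worth unpacking. First, `state_dict()` omits non-persistent buffers, which still occupy memory, so sizing a module from it undercounts; iterating `named_parameters()` plus `named_buffers()` does not. A minimal sketch of the difference (the `WithMask` module is illustrative):

```py
import torch
import torch.nn as nn


class WithMask(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # persistent=False: lives in memory, but excluded from state_dict()
        self.register_buffer("mask", torch.ones(4, 4), persistent=False)


m = WithMask()
print("mask" in dict(m.named_buffers()))  # True  -> counted by the new iterator
print("mask" in m.state_dict())           # False -> missed by the old state_dict() path
```

Second, tied names are skipped by name rather than by object identity: before the weights are actually tied, each alias still has its own tensor, so `named_parameters()` would double-count it; once the model is tied, the alias no longer appears in the iterator and the `continue` is simply never hit.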
@@ -591,8 +546,12 @@ def _init_infer_auto_device_map(

     if tied_parameters is None:
         if len(model.all_tied_weights_keys) > 0:
-            # create a list of list of tied params
-            tied_parameters = [list(t) for t in model.all_tied_weights_keys.items()]
+            # create a list of list of tied params based on unique tied groups
+            groups = set(model.all_tied_weights_keys.values())
+            tied_parameters = [
+                sorted([k for k, v in model.all_tied_weights_keys.items() if v == target] + [target])
+                for target in groups
+            ]
         else:
             tied_parameters = [[]]

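To see why the grouping changed: the old `[list(t) for t in model.all_tied_weights_keys.items()]` produced one two-element `[tied_name, target]` pair per dict entry, so several parameters tied to the same target were never merged into a single group. A toy run of the new comprehension, assuming `all_tied_weights_keys` maps each tied parameter name to the name it shares storage with (the mapping below is illustrative, not from a real model):

```py
all_tied_weights_keys = {
    "lm_head.weight": "model.embed_tokens.weight",
    "decoder.embed.weight": "model.embed_tokens.weight",
}

# One sorted group per unique target, containing the target and all its aliases
groups = set(all_tied_weights_keys.values())
tied_parameters = [
    sorted([k for k, v in all_tied_weights_keys.items() if v == target] + [target])
    for target in groups
]
print(tied_parameters)
# [['decoder.embed.weight', 'lm_head.weight', 'model.embed_tokens.weight']]
```

Sorting each group also makes the resulting device map deterministic across runs, which matches the "add sorting" item in the commit message.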