@@ -159,61 +159,6 @@ def wrapper(*args, **kwargs):
             setattr(torch, torch_function_name, old_torch_function)


-def find_tied_parameters(model: "nn.Module", **kwargs):
-    """
-    Find the tied parameters in a given model.
-
-    <Tip warning={true}>
-
-    The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
-    them.
-
-    </Tip>
-
-    Args:
-        model (`torch.nn.Module`): The model to inspect.
-
-    Returns:
-        list[list[str]]: A list of lists of parameter names being all tied together.
-
-    Example:
-
-    ```py
-    >>> from collections import OrderedDict
-    >>> import torch.nn as nn
-
-    >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
-    >>> model.linear2.weight = model.linear1.weight
-    >>> find_tied_parameters(model)
-    [['linear1.weight', 'linear2.weight']]
-    ```
-    """
-
-    # get ALL model parameters and their names
-    all_named_parameters = dict(model.named_parameters(remove_duplicate=False))
-
-    # get ONLY unique named parameters,
-    # if parameter is tied and have multiple names, it will be included only once
-    no_duplicate_named_parameters = dict(model.named_parameters(remove_duplicate=True))
-
-    # the difference of the two sets will give us the tied parameters
-    tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())
-
-    # 'tied_param_names' contains the names of parameters that are tied in the model, but we do not know
-    # which names refer to the same parameter. To identify this, we need to group them together.
-    tied_param_groups = {}
-    for tied_param_name in tied_param_names:
-        tied_param = all_named_parameters[tied_param_name]
-        for param_name, param in no_duplicate_named_parameters.items():
-            # compare if parameters are the same, if so, group their names together
-            if param is tied_param:
-                if param_name not in tied_param_groups:
-                    tied_param_groups[param_name] = []
-                tied_param_groups[param_name].append(tied_param_name)
-
-    return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]
-
-
 def check_and_set_device_map(device_map: "torch.device | int | str | dict | None") -> dict | str | None:
     from ..modeling_utils import get_torch_context_manager_or_global_device

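The detection trick at the heart of the removed helper (names that disappear when `remove_duplicate=True` are the tied aliases) can be checked in isolation. This is a minimal sketch on a toy two-layer model, not the library helper itself:

```py
import torch.nn as nn

# Toy model with two linears, then tie the second weight to the first
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
model[1].weight = model[0].weight

# remove_duplicate=False lists every name; remove_duplicate=True lists each Parameter object once
all_names = dict(model.named_parameters(remove_duplicate=False))
unique_names = dict(model.named_parameters(remove_duplicate=True))

tied_aliases = set(all_names) - set(unique_names)
print(tied_aliases)  # {'1.weight'}
```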
@@ -271,11 +216,21 @@ def compute_module_sizes(
     leaves_module_sizes = defaultdict(int)

     if buffers_only:
-        named_tensors = model.named_buffers(recurse=True)
+        iterator = model.named_buffers()
     else:
-        named_tensors = model.state_dict().items()
-
-    for name, param in named_tensors:
+        # We need parameters + buffers here, as state_dict does not count non-persistent buffers which are taking space
+        def all_tensors():
+            yield from model.named_parameters()
+            yield from model.named_buffers()
+
+        iterator = all_tensors()
+
+    tied_keys = getattr(model, "all_tied_weights_keys", {}).keys()
+    for name, param in iterator:
+        # Do not count tied keys (the model is usually not tied yet here, so they will appear in the iterator)
+        # If the model is already tied, then they simply do not appear in the iterator anyway (remove_duplicate=True by default)
+        if name in tied_keys:
+            continue
         if hf_quantizer is not None:
             dtype_size = hf_quantizer.param_element_size(model, name)
         else:
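The comment about non-persistent buffers is easy to verify: they occupy memory and appear in `named_buffers()`, but `state_dict()` leaves them out, which is why a size estimate based on `state_dict()` alone undercounts. A small sketch with a made-up module (not part of the diff):

```py
import torch
import torch.nn as nn

class WithMask(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # Non-persistent buffer: still allocated, but excluded from state_dict()
        self.register_buffer("mask", torch.ones(4, 4), persistent=False)

m = WithMask()
print("mask" in dict(m.named_buffers()))  # True: counted by the parameters + buffers iterator
print("mask" in m.state_dict())           # False: missed by a state_dict-based size estimate
```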
@@ -591,8 +546,12 @@ def _init_infer_auto_device_map(

     if tied_parameters is None:
         if len(model.all_tied_weights_keys) > 0:
-            # create a list of list of tied params
-            tied_parameters = [list(t) for t in model.all_tied_weights_keys.items()]
+            # create a list of list of tied params based on unique tied groups
+            groups = set(model.all_tied_weights_keys.values())
+            tied_parameters = [
+                sorted([k for k, v in model.all_tied_weights_keys.items() if v == target] + [target])
+                for target in groups
+            ]
         else:
             tied_parameters = [[]]

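The grouping above assumes `all_tied_weights_keys` maps each tied alias to the name of the parameter it points at, so all aliases sharing a target end up in one sorted group together with the target. A sketch with hypothetical key names (the real mapping comes from the model):

```py
# Hypothetical alias -> target mapping, in the shape the hunk above assumes
all_tied_weights_keys = {
    "lm_head.weight": "model.embed_tokens.weight",
    "decoder.embed.weight": "model.embed_tokens.weight",
}

# One group per unique target, each group = sorted(aliases + [target])
groups = set(all_tied_weights_keys.values())
tied_parameters = [
    sorted([k for k, v in all_tied_weights_keys.items() if v == target] + [target])
    for target in groups
]
print(tied_parameters)
# [['decoder.embed.weight', 'lm_head.weight', 'model.embed_tokens.weight']]
```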