Skip to content

Commit 8390b01

Browse files
committed
doc
1 parent d714497 commit 8390b01

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

src/transformers/integrations/accelerate.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,11 @@ def expand_device_map(device_map, param_names):
507507
return new_device_map
508508

509509

510-
def update_param_name(param_name: str, weight_pattern_alt, weight_pattern_by_group_name, source_to_target):
510+
def update_param_name(param_name: str, weight_pattern_alt, weight_pattern_by_group_name, source_to_target) -> str:
511+
"""Update a source `param_name` in a checkpoint into the target name that the model expects, if different.
512+
This uses the same logic as `core_model_loading.py`."""
513+
# TODO Cyril: This function would not even need to exist if the Converter entries already contained the
514+
# full expanded source and target names
511515
from ..core_model_loading import match_glob
512516

513517
if weight_pattern_alt is None:

0 commit comments

Comments
 (0)