
Commit d714497
Commit message: fix
1 parent: 4541954

2 files changed: +43 additions, -6 deletions

src/transformers/integrations/accelerate.py

Lines changed: 42 additions & 6 deletions
@@ -18,12 +18,14 @@
 
 import copy
 import inspect
+import itertools
 import os
 import re
 from collections import OrderedDict, defaultdict
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Optional, Union
 
+from safetensors import safe_open
 from safetensors.torch import save_file
 
 from ..utils import (
@@ -505,36 +507,70 @@ def expand_device_map(device_map, param_names):
     return new_device_map
 
 
+def update_param_name(param_name: str, weight_pattern_alt, weight_pattern_by_group_name, source_to_target):
+    from ..core_model_loading import match_glob
+
+    if weight_pattern_alt is None:
+        return param_name
+
+    matched_pattern = match_glob(param_name, weight_pattern_alt, weight_pattern_by_group_name)
+    if matched_pattern is not None:
+        converter = source_to_target[matched_pattern]
+        # Only change name if it's a simple renaming, i.e. no custom Ops
+        if len(converter.source_keys) == 1 and len(converter.target_keys) == 1:
+            source_pattern = converter.source_keys[0]
+            target_pattern = converter.target_keys[0]
+            return re.sub(source_pattern, target_pattern, param_name)
+    return param_name
+
+
 def accelerate_disk_offload(
     disk_offload_folder: str | None,
     checkpoint_files: list[str] | None,
     device_map: dict,
     expected_keys: list[str],
     sharded_metadata: dict | None,
     dtype: torch.dtype | None,
+    weight_mapping=None,
 ):
+    from ..core_model_loading import build_glob_alt
+
     if disk_offload_folder is not None:
         os.makedirs(disk_offload_folder, exist_ok=True)
     is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")
 
+    _patterns, weight_pattern_alt, weight_pattern_by_group_name = None, None, None
+    if weight_mapping is not None:
+        _patterns = list(itertools.chain.from_iterable([k.source_keys for k in weight_mapping]))
+        source_to_target = {sk: k for k in weight_mapping for sk in k.source_keys}
+        weight_pattern_alt, weight_pattern_by_group_name = build_glob_alt(_patterns)
+
     # In this case, the offload index is simply the existing safetensors (except if using custom weight loading
     # Operation, e.g. the MoE models, where we need to resave the weights that were changed at loading time)
     if is_offloaded_safetensors:
         param_device_map = expand_device_map(device_map, expected_keys)
         str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
         if sharded_metadata is None:
-            weight_map = dict.fromkeys(expected_keys, checkpoint_files[0])
+            weight_map = dict.fromkeys(safe_open(checkpoint_files[0], framework="pt").keys(), checkpoint_files[0])
         else:
             folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
             weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
+
+        # Update the weight names according to the `weight_mapping`
+        weight_renaming_map = {
+            update_param_name(k, weight_pattern_alt, weight_pattern_by_group_name, source_to_target): k
+            for k in weight_map
+        }
+
+        # Prepare the index using existing safetensors files
         disk_offload_index = {
-            name: {
-                "safetensors_file": file,
-                "weight_name": name,
+            target_name: {
+                "safetensors_file": weight_map[source_name],
+                "weight_name": source_name,
                 "dtype": str_dtype,
             }
-            for name, file in weight_map.items()
-            if param_device_map[name] == "disk"
+            for target_name, source_name in weight_renaming_map.items()
+            if param_device_map[target_name] == "disk"
         }
         # In this case we will resave every offloaded weight
     else:
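
Note: when a converter maps exactly one source pattern to one target pattern, the new update_param_name helper reduces to a single re.sub on the checkpoint key, and the offload index is then keyed by the renamed name while still pointing at the original safetensors entry. A minimal standalone sketch of that renaming step (the regex patterns and the key below are illustrative placeholders, not taken from any real converter):

import re

# Hypothetical one-to-one rename: source regex -> target replacement (illustrative only).
source_pattern = r"model\.layers\.(\d+)\.mlp\.w1\.weight"
target_pattern = r"model.layers.\1.mlp.gate_proj.weight"

checkpoint_key = "model.layers.3.mlp.w1.weight"
renamed_key = re.sub(source_pattern, target_pattern, checkpoint_key)
print(renamed_key)  # -> model.layers.3.mlp.gate_proj.weight

# As in the diff, the renamed name becomes the index key (target_name) while
# "weight_name" and "safetensors_file" keep referring to the original entry (source_name):
weight_renaming_map = {renamed_key: checkpoint_key}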

src/transformers/modeling_utils.py

Lines changed: 1 addition & 0 deletions
@@ -4182,6 +4182,7 @@ def _load_pretrained_model(
             expected_keys,
             sharded_metadata,
             dtype,
+            weight_mapping,
         )
 
         # Warmup cuda to load the weights much faster on devices
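
For context, the only change here is threading weight_mapping through to accelerate_disk_offload, so the disk-offload index can use the renamed parameter names. A hedged usage sketch of the code path that exercises this (the model id and folder are placeholders; whether any renaming happens depends on the model's weight mapping):

from transformers import AutoModelForCausalLM

# Placeholder checkpoint id; loading with disk offload goes through
# accelerate_disk_offload, and models that define a weight mapping get their
# offloaded parameter names updated via update_param_name.
model = AutoModelForCausalLM.from_pretrained(
    "org/some-model",            # hypothetical model id
    device_map="auto",           # may place some weights on "disk"
    offload_folder="./offload",  # forwarded as disk_offload_folder
)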
