
Commit af6a36a

[loading] Re-add and improve disk offloading support (#42242)
* unskip tests
* first shot
* offload in safetensors format
* remove hard-coded value
* update error
* typo
* fix
* update test
* fix
* return it
* post rebase
* improve var names
* improve names
* fix finally
* comment
* fix tests
* fix
* simplify
* fix
* doc
* fix
* remove additional tying after rebase
* update test function source
* fix
* post rebase
* new renaming patterns
* clear confusion about variable names
* create cleaner function
* better doc
* other tests
* remove skip
* unskip other tests
1 parent 6940b44 commit af6a36a

File tree: 8 files changed (+186 -110 lines changed)

src/transformers/core_model_loading.py

Lines changed: 65 additions & 35 deletions
@@ -28,6 +28,7 @@
 
 import torch
 
+from .integrations.accelerate import offload_weight
 from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, DTensor, Replicate, TensorParallelLayer
 from .utils import is_torch_greater_or_equal, logging
 
@@ -397,7 +398,7 @@ def dot_natural_key(s: str):
 
 @contextmanager
 def log_to_misc(
-    layer_name: str,
+    first_target_key: str,
     misc: MutableMapping[str, str],
     extras: Any = None,
     op: Union[list[ConversionOps], ConversionOps, None] = None,
@@ -421,22 +422,22 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) ->
         if isinstance(extras, tuple) and len(extras) == 2:
             values, target_keys = extras
             descriptor = f"{op_name} " if op_name else ""
-            misc[layer_name] = (
+            misc[first_target_key] = (
                 f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values)}"
             )
         elif isinstance(extras, str):
             suffix = f" via {op_name}" if op_name else ""
-            misc[layer_name] = f"{e}\nError{suffix} when processing parameter {extras}"
+            misc[first_target_key] = f"{e}\nError{suffix} when processing parameter {extras}"
         elif extras is None and op_name:
-            misc[layer_name] = f"{op_name}: {e}"
+            misc[first_target_key] = f"{op_name}: {e}"
         else:
-            misc[layer_name] = f"{extras} |Error: {e}"
+            misc[first_target_key] = f"{extras} |Error: {e}"
         raise SkipLayer()
 
 
 def set_param_for_module(
     model: PreTrainedModel,
-    layer_name: str,
+    target_name: str,
     param_value: torch.Tensor,
     mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
     missing_keys: MutableSet[str],
@@ -445,17 +446,13 @@ def set_param_for_module(
     distributed_operation: Optional[TensorParallelLayer],
     hf_quantizer: HfQuantizer,
 ):
-    with log_to_misc(layer_name, misc, layer_name):
-        module_path, _, param_name = layer_name.rpartition(".")
+    with log_to_misc(target_name, misc, target_name):
+        module_path, _, param_name = target_name.rpartition(".")
         module_obj = model.get_submodule(module_path) if module_path else model
-        if isinstance(param_value, list):
-            param_value = param_value[0]
-        elif not isinstance(param_value, torch.nn.Parameter):
-            param_value = param_value[...]
 
         ref = getattr(module_obj, param_name)
         if ref is None:
-            unexpected_keys.add(layer_name)
+            unexpected_keys.add(target_name)
         else:
             use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
             if not isinstance(param_value, torch.nn.Parameter):
@@ -475,17 +472,35 @@
                 param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
 
             # Remove from missing keys (it's either mismatched, or all good)
-            missing_keys.discard(layer_name)
+            missing_keys.discard(target_name)
             if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
-                mismatch_keys.add((layer_name, param_value.shape, ref.shape))
+                mismatch_keys.add((target_name, param_value.shape, ref.shape))
                 module_obj.param_name._is_hf_initialized = False  # Needs to be initialized
             else:
-                param_value._is_hf_initialized = (
-                    True  # super important otherwise _init_weight re-initi if bias is missing
-                )
+                # super important otherwise _init_weight will re-init the param
+                param_value._is_hf_initialized = True
             setattr(module_obj, param_name, param_value)
 
 
+def offload_and_maybe_resave_param(
+    target_name: str,
+    param: torch.Tensor,
+    missing_keys: MutableSet[str],
+    disk_offload_folder: str,
+    disk_offload_index: dict,
+    applied_ops: WeightConverter | WeightRenaming,
+) -> dict:
+    """Takes care of correctly offloading `param`. If it's not already present in the `disk_offload_index`, or if any
+    WeightConverter operations have been applied, it will resave the new parameter. Otherwise, it will use the original
+    `disk_offload_index` for this given param."""
+    # We need to remove from missing keys
+    missing_keys.discard(target_name)
+    # If not already offloaded, or if we applied any special Operation except Renaming, we need to re-save
+    if target_name not in disk_offload_index or isinstance(applied_ops, WeightConverter):
+        disk_offload_index = offload_weight(param, target_name, disk_offload_folder, disk_offload_index)
+    return disk_offload_index
+
+
 class SkipLayer(Exception):
     """Control-flow sentinel: abort processing of the current layer only."""
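
As a reading aid, here is a minimal, self-contained sketch of the reuse-vs-resave rule the new helper encodes, with a toy index and a `save_offloaded` stand-in for `offload_weight` (all names and values below are illustrative, not the library's call sites):

import os

import torch
from safetensors.torch import save_file


def save_offloaded(weight, name, folder, index):
    # Stand-in for `offload_weight`: one tensor per safetensors file, recorded in the index.
    path = os.path.join(folder, f"{name}.safetensors")
    save_file({name: weight.contiguous()}, path)
    index[name] = {"safetensors_file": path, "weight_name": name, "dtype": str(weight.dtype).replace("torch.", "")}
    return index


os.makedirs("offload_dir", exist_ok=True)
# One entry already points at its original checkpoint shard.
index = {"model.embed_tokens.weight": {"safetensors_file": "model-00001-of-00002.safetensors",
                                       "weight_name": "model.embed_tokens.weight", "dtype": "float32"}}

name, param = "model.layers.0.mlp.gate_proj.weight", torch.randn(8, 8)
was_converted = False  # would be True if a WeightConverter modified the tensor at load time

# Same rule as `offload_and_maybe_resave_param`: reuse the existing entry when possible, otherwise re-save.
if name not in index or was_converted:
    index = save_offloaded(param, name, "offload_dir", index)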

@@ -521,6 +536,8 @@ def convert_and_load_state_dict_in_model(
     device_map: dict | None = None,
     dtype_plan: dict | None = None,
     device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
+    disk_offload_index: dict | None = None,
+    disk_offload_folder: str | None = None,
 ):
     r"""
     We build a mapping from the keys obtained by renaming each of the checkpoint keys according to the weight_mapping rules.
@@ -612,6 +629,7 @@
     prefix = model.base_model_prefix
     tp_plan = tp_plan or {}
     device_map = device_map or {"": "cpu"}
+    # Here, we first sort by number of submodules, then length of the full string, to make sure to match correctly
     device_map_regex = re.compile(
         "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True))
     )
@@ -708,9 +726,11 @@
                 shard_index,
             )
 
-            if future is None:  # TODO handle disk offload
+            if future is None:
                device_match = device_map_regex.match(renamed_key)
                param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
+               # If disk, we need to materialize on cpu first
+               param_device = "cpu" if param_device == "disk" else param_device
                future = spawn_materialize(thread_pool, tensor, param_device, _dtype)
 
             mapping.add_tensor(renamed_key, original_key, source_pattern, future)
@@ -723,30 +743,40 @@
 
     total_entries = len(param_name_to_load)
     with logging.tqdm(total=total_entries, desc="Loading weights") as pbar:
-        for layer_name, mapping in param_name_to_load.items():
+        for first_param_name, mapping in param_name_to_load.items():
             pbar.update(1)
-            pbar.set_postfix({"Materializing param": layer_name})
+            pbar.set_postfix({"Materializing param": first_param_name})
             pbar.refresh()
             try:
                 realized_value, misc = mapping.convert(
-                    layer_name, config=model.config, quantizer=hf_quantizer, missing_keys=missing_keys
+                    first_param_name, config=model.config, quantizer=hf_quantizer, missing_keys=missing_keys
                 )
-                for k, output_value in realized_value.items():
-                    set_param_for_module(
-                        model,
-                        k,
-                        output_value,
-                        mismatch_keys,
-                        missing_keys,
-                        misc,
-                        unexpected_keys,
-                        mapping.distributed_operation,
-                        hf_quantizer,
-                    )
+                for target_name, param in realized_value.items():
+                    param = param[0] if isinstance(param, list) else param
+                    device_match = device_map_regex.match(target_name)
+                    param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
+                    # Offloading support
+                    if param_device == "disk":
+                        disk_offload_index = offload_and_maybe_resave_param(
+                            target_name, param, missing_keys, disk_offload_folder, disk_offload_index, mapping
+                        )
+                    else:
+                        set_param_for_module(
+                            model,
+                            target_name,
+                            param,
+                            mismatch_keys,
+                            missing_keys,
+                            misc,
+                            unexpected_keys,
+                            mapping.distributed_operation,
+                            hf_quantizer,
+                        )
             except SkipLayer:
                 continue
+
     thread_pool.shutdown(wait=False)
-    return missing_keys, unexpected_keys, mismatch_keys, misc
+    return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, misc
 
 
 # TODO this is not done yet!
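
The loading loop above routes each materialized parameter by matching its target name against a regex built from the `device_map` keys; a standalone sketch of that resolution (toy device map and names, not the real call sites):

import re

device_map = {"model.layers.0": "cuda:0", "model.layers.1": "cpu", "": "disk"}

# Most specific prefix wins: sort by number of submodules, then by string length, before building the alternation.
device_map_regex = re.compile(
    "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True))
)

for name in ["model.layers.0.self_attn.q_proj.weight", "model.layers.1.mlp.up_proj.weight", "lm_head.weight"]:
    match = device_map_regex.match(name)
    device = device_map[match.group()] if match else device_map.get("", "cpu")
    # "disk" parameters go through offload_and_maybe_resave_param; everything else through set_param_for_module.
    print(f"{name} -> {device}")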

src/transformers/integrations/accelerate.py

Lines changed: 77 additions & 41 deletions
@@ -19,10 +19,14 @@
 import copy
 import inspect
 import os
+import re
 from collections import OrderedDict, defaultdict
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Optional, Union
 
+from safetensors import safe_open
+from safetensors.torch import save_file
+
 from ..utils import (
     is_accelerate_available,
     is_torch_available,
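
For context on the two new imports: `save_file` writes a dict of tensors to a single file and `safe_open` reads keys and tensors back lazily, which is what the offloading code below relies on. A small round-trip (file name chosen arbitrarily for the example):

import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Write two tensors into one safetensors file.
save_file({"weight": torch.ones(2, 2), "bias": torch.zeros(2)}, "example.safetensors")

# Read it back lazily: list the stored tensor names, then load only what is needed.
with safe_open("example.safetensors", framework="pt") as f:
    names = list(f.keys())
    weight = f.get_tensor("weight")  # materializes just this tensor
print(names, weight.shape)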
@@ -445,71 +449,103 @@ def accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload
     dispatch_model(model, **device_map_kwargs)
 
 
-def get_disk_only_shard_files(device_map, weight_map):
-    """
-    Returns the list of shard files containing only weights offloaded to disk.
-    """
-    files_content = defaultdict(list)
-    for weight_name, filename in weight_map.items():
-        while len(weight_name) > 0 and weight_name not in device_map:
-            weight_name = ".".join(weight_name.split(".")[:-1])
-        files_content[filename].append(device_map[weight_name])
-
-    return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
-
-
 def expand_device_map(device_map, param_names):
     """
     Expand a device map to return the correspondence parameter name to device.
     """
+    # Here, we first sort by number of submodules, then length of the full string, to make sure to match correctly
+    device_map_regex = re.compile(
+        "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True))
+    )
     new_device_map = {}
-    for module, device in device_map.items():
-        new_device_map.update(
-            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
-        )
+    for param in param_names:
+        device_match = device_map_regex.match(param)
+        new_device_map[param] = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
+
     return new_device_map
 
 
 def accelerate_disk_offload(
-    disk_offload_folder,
-    checkpoint_files,
-    device_map,
-    checkpoint_keys,
-    sharded_metadata,
-    dtype,
+    disk_offload_folder: str | None,
+    checkpoint_files: list[str] | None,
+    device_map: dict,
+    expected_keys: list[str],
+    sharded_metadata: dict | None,
+    dtype: torch.dtype | None,
+    weight_mapping=None,
 ):
-    disk_only_shard_files = []
+    """
+    Prepare the `disk_offload_index` that will be used for reading offloaded parameters. If reading from a safetensors
+    file, parameters which do not need any special WeightConverter operation during loading (i.e. they are used as-is, or only
+    renamed) will be mapped to where they already reside on disk. Otherwise, the parameters will be resaved inside
+    `disk_offload_folder` during loading.
+    """
+    from ..core_model_loading import WeightRenaming, build_glob_alternation, repl
+
     if disk_offload_folder is not None:
         os.makedirs(disk_offload_folder, exist_ok=True)
     is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")
-    if disk_offload_folder is None and not is_offloaded_safetensors:
-        raise ValueError(
-            "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
-            " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
-            " offers the weights in this format."
-        )
+
+    rename = False
+    if weight_mapping is not None:
+        renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
+        if len(renamings) > 0:
+            rename = True
+            rename_alt, _, rename_by_group = build_glob_alternation(renamings)
+
+    # In this case, the offload index is simply the existing safetensors (except if using custom weight loading
+    # Operation, e.g. the MoE models, where we need to resave the weights that were changed at loading time)
     if is_offloaded_safetensors:
-        param_device_map = expand_device_map(device_map, checkpoint_keys)
+        param_device_map = expand_device_map(device_map, expected_keys)
         str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
         if sharded_metadata is None:
-            weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
+            weight_map = dict.fromkeys(safe_open(checkpoint_files[0], framework="pt").keys(), checkpoint_files[0])
         else:
             folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
-            weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
-        # Find potential checkpoints containing only offloaded weights
-        disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
+            weight_map = {k: os.path.join(folder, v) for k, v in sharded_metadata["weight_map"].items()}
+
+        # Update the weight names according to the `weight_mapping`
+        weight_renaming_map = {
+            rename_alt.sub(lambda m: repl(m, rename_by_group), k).replace("\\", "") if rename else k: k
+            for k in weight_map
+        }
+
+        # Prepare the index using existing safetensors files
         disk_offload_index = {
-            name: {
-                "safetensors_file": file,
-                "weight_name": name,
+            target_name: {
+                "safetensors_file": weight_map[source_name],
+                "weight_name": source_name,
                 "dtype": str_dtype,
             }
-            for name, file in weight_map.items()
-            if param_device_map[name] == "disk"
+            for target_name, source_name in weight_renaming_map.items()
+            # Need to check if it's in the mapping in case of unexpected keys that would result in KeyError (we skip them)
+            if target_name in param_device_map and param_device_map[target_name] == "disk"
         }
+    # In this case we will resave every offloaded weight
     else:
         disk_offload_index = {}
-    return disk_offload_index, disk_only_shard_files, is_offloaded_safetensors
+
+    return disk_offload_index
+
+
+def offload_weight(weight: torch.Tensor, weight_name: str, offload_folder: str | None, offload_index: dict) -> dict:
+    """Write `weight` to disk inside `offload_folder`, and update `offload_index` accordingly. Everything is
+    saved in `safetensors` format."""
+
+    if offload_folder is None:
+        raise ValueError(
+            "The current `device_map` had weights offloaded to the disk, which needed to be re-saved. This is either "
+            "because the weights are not in `safetensors` format, or because the model uses an internal weight format "
+            "different than the one saved (i.e. most MoE models). Please provide an `offload_folder` for them in "
+            "`from_pretrained`."
+        )
+    # Write the weight to disk
+    safetensor_file = os.path.join(offload_folder, f"{weight_name}.safetensors")
+    save_file({weight_name: weight}, safetensor_file)
+    # Update the offloading index
+    str_dtype = str(weight.dtype).replace("torch.", "")
+    offload_index[weight_name] = {"safetensors_file": safetensor_file, "weight_name": weight_name, "dtype": str_dtype}
+    return offload_index
 
 
 def _init_infer_auto_device_map(
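
Each entry in the resulting `disk_offload_index` records a safetensors file, the key inside it, and a dtype string. The consumer side is not part of this diff, so the following is only an illustrative sketch of how such an entry could be read back:

import torch
from safetensors import safe_open


def load_offloaded(entry: dict) -> torch.Tensor:
    # `entry` has the shape written by `offload_weight` / `accelerate_disk_offload`:
    # {"safetensors_file": ..., "weight_name": ..., "dtype": ...}
    with safe_open(entry["safetensors_file"], framework="pt") as f:
        tensor = f.get_tensor(entry["weight_name"])
    return tensor.to(getattr(torch, entry["dtype"]))


entry = {
    "safetensors_file": "offload_dir/lm_head.weight.safetensors",
    "weight_name": "lm_head.weight",
    "dtype": "float32",
}
# tensor = load_offloaded(entry)  # would load the offloaded weight back from disk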
