Merged
Commits
32 commits
40d1827
unskip tests
Cyrilvallez Nov 17, 2025
5096ca6
first shot
Cyrilvallez Nov 17, 2025
ee4978e
offload in safetensors format
Cyrilvallez Nov 17, 2025
14e1699
remove hard-coded value
Cyrilvallez Nov 17, 2025
843371d
update error
Cyrilvallez Nov 17, 2025
d53ab68
typo
Cyrilvallez Nov 17, 2025
0086c24
fix
Cyrilvallez Nov 17, 2025
e3fb6eb
update test
Cyrilvallez Nov 17, 2025
ad7a84f
fix
Cyrilvallez Nov 17, 2025
4bc793c
return it
Cyrilvallez Nov 17, 2025
00abfba
post rebase
Cyrilvallez Nov 18, 2025
5e559dd
improve var names
Cyrilvallez Nov 18, 2025
009e44b
improve names
Cyrilvallez Nov 18, 2025
adacfe0
fix finally
Cyrilvallez Nov 18, 2025
74b3862
comment
Cyrilvallez Nov 18, 2025
76756b8
fix tests
Cyrilvallez Nov 18, 2025
b2f9593
fix
Cyrilvallez Nov 18, 2025
ee88919
simplify
Cyrilvallez Nov 18, 2025
98bf29d
fix
Cyrilvallez Nov 18, 2025
16c6dee
doc
Cyrilvallez Nov 18, 2025
cb1d7c7
fix
Cyrilvallez Nov 18, 2025
240767c
remove additional tying after rebase
Cyrilvallez Nov 18, 2025
2e2b725
update test function source
Cyrilvallez Nov 18, 2025
deedbdf
fix
Cyrilvallez Nov 19, 2025
9a0675a
post rebase
Cyrilvallez Nov 21, 2025
81d136d
new renaming patterns
Cyrilvallez Nov 21, 2025
bd4e788
clear confusion about variable names
Cyrilvallez Nov 24, 2025
2a9009e
create cleaner function
Cyrilvallez Nov 24, 2025
a6a4c45
better doc
Cyrilvallez Nov 24, 2025
f4619dd
other tests
Cyrilvallez Nov 24, 2025
c41ea84
remove skip
Cyrilvallez Nov 24, 2025
da954ef
unskip other tests
Cyrilvallez Nov 24, 2025
100 changes: 65 additions & 35 deletions src/transformers/core_model_loading.py
@@ -28,6 +28,7 @@

import torch

from .integrations.accelerate import offload_weight
from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, DTensor, Replicate, TensorParallelLayer
from .utils import is_torch_greater_or_equal, logging

@@ -397,7 +398,7 @@ def dot_natural_key(s: str):

@contextmanager
def log_to_misc(
layer_name: str,
first_target_key: str,
misc: MutableMapping[str, str],
extras: Any = None,
op: Union[list[ConversionOps], ConversionOps, None] = None,
@@ -421,22 +422,22 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) ->
if isinstance(extras, tuple) and len(extras) == 2:
values, target_keys = extras
descriptor = f"{op_name} " if op_name else ""
misc[layer_name] = (
misc[first_target_key] = (
f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values)}"
)
elif isinstance(extras, str):
suffix = f" via {op_name}" if op_name else ""
misc[layer_name] = f"{e}\nError{suffix} when processing parameter {extras}"
misc[first_target_key] = f"{e}\nError{suffix} when processing parameter {extras}"
elif extras is None and op_name:
misc[layer_name] = f"{op_name}: {e}"
misc[first_target_key] = f"{op_name}: {e}"
else:
misc[layer_name] = f"{extras} |Error: {e}"
misc[first_target_key] = f"{extras} |Error: {e}"
raise SkipLayer()


def set_param_for_module(
model: PreTrainedModel,
layer_name: str,
target_name: str,
param_value: torch.Tensor,
mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
missing_keys: MutableSet[str],
@@ -445,17 +446,13 @@ def set_param_for_module(
distributed_operation: Optional[TensorParallelLayer],
hf_quantizer: HfQuantizer,
):
with log_to_misc(layer_name, misc, layer_name):
module_path, _, param_name = layer_name.rpartition(".")
with log_to_misc(target_name, misc, target_name):
module_path, _, param_name = target_name.rpartition(".")
module_obj = model.get_submodule(module_path) if module_path else model
if isinstance(param_value, list):
param_value = param_value[0]
elif not isinstance(param_value, torch.nn.Parameter):
param_value = param_value[...]

ref = getattr(module_obj, param_name)
if ref is None:
unexpected_keys.add(layer_name)
unexpected_keys.add(target_name)
else:
use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
if not isinstance(param_value, torch.nn.Parameter):
@@ -475,17 +472,35 @@
param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())

# Remove from missing keys (it's either mismatched, or all good)
missing_keys.discard(layer_name)
missing_keys.discard(target_name)
if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
mismatch_keys.add((layer_name, param_value.shape, ref.shape))
mismatch_keys.add((target_name, param_value.shape, ref.shape))
module_obj.param_name._is_hf_initialized = False # Needs to be initialized
else:
param_value._is_hf_initialized = (
True # super important otherwise _init_weight re-initi if bias is missing
)
# super important otherwise _init_weight will re-init the param
param_value._is_hf_initialized = True
setattr(module_obj, param_name, param_value)


def offload_and_maybe_resave_param(
target_name: str,
param: torch.Tensor,
missing_keys: MutableSet[str],
disk_offload_folder: str,
disk_offload_index: dict,
applied_ops: WeightConverter | WeightRenaming,
) -> dict:
"""Takes care of correctly offloading `param`. If it's not already present in the `disk_offload_index`, or if any
WeightConverter operations have been applied, it will resave the new parameter. Otherwise, it will use the original
`disk_offload_index` for this given param."""
# We need to remove from missing keys
missing_keys.discard(target_name)
# If not already offloaded, or if we applied any special Operation except Renaming, we need to re-save
if target_name not in disk_offload_index or isinstance(applied_ops, WeightConverter):
disk_offload_index = offload_weight(param, target_name, disk_offload_folder, disk_offload_index)
return disk_offload_index
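
To make the re-save rule concrete, here is a small standalone restatement of the condition encoded above (hypothetical names and a hypothetical helper, not an import from the library):

```python
# Sketch of the decision rule only (hypothetical helper, not the library function).
def needs_resave(target_name: str, disk_offload_index: dict, applied_a_converter: bool) -> bool:
    # Re-save when the param is not yet on disk under this name, or when a
    # WeightConverter (anything beyond a pure renaming) changed its value.
    return target_name not in disk_offload_index or applied_a_converter

index = {"model.layers.0.mlp.weight": {"safetensors_file": "model-00001-of-00002.safetensors"}}
assert not needs_resave("model.layers.0.mlp.weight", index, applied_a_converter=False)  # reuse existing shard
assert needs_resave("model.experts.gate_up_proj", index, applied_a_converter=True)      # write a new file
```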


class SkipLayer(Exception):
"""Control-flow sentinel: abort processing of the current layer only."""

@@ -521,6 +536,8 @@ def convert_and_load_state_dict_in_model(
device_map: dict | None = None,
dtype_plan: dict | None = None,
device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
disk_offload_index: dict | None = None,
disk_offload_folder: str | None = None,
):
r"""
We build a mapping from the keys obtained by renaming each of the checkpoint keys according to the weight_mapping rules.
@@ -612,6 +629,7 @@ def convert_and_load_state_dict_in_model(
prefix = model.base_model_prefix
tp_plan = tp_plan or {}
device_map = device_map or {"": "cpu"}
# Here, we first sort by number of submodules, then length of the full string, to make sure to match correctly
device_map_regex = re.compile(
"|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True))
)
@@ -708,9 +726,11 @@ def convert_and_load_state_dict_in_model(
shard_index,
)

if future is None: # TODO handle disk offload
if future is None:
device_match = device_map_regex.match(renamed_key)
param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
# If disk, we need to materialize on cpu first
param_device = "cpu" if param_device == "disk" else param_device
future = spawn_materialize(thread_pool, tensor, param_device, _dtype)

mapping.add_tensor(renamed_key, original_key, source_pattern, future)
@@ -723,30 +743,40 @@ def convert_and_load_state_dict_in_model(

total_entries = len(param_name_to_load)
with logging.tqdm(total=total_entries, desc="Loading weights") as pbar:
for layer_name, mapping in param_name_to_load.items():
for first_param_name, mapping in param_name_to_load.items():
pbar.update(1)
pbar.set_postfix({"Materializing param": layer_name})
pbar.set_postfix({"Materializing param": first_param_name})
pbar.refresh()
try:
realized_value, misc = mapping.convert(
layer_name, config=model.config, quantizer=hf_quantizer, missing_keys=missing_keys
first_param_name, config=model.config, quantizer=hf_quantizer, missing_keys=missing_keys
)
for k, output_value in realized_value.items():
set_param_for_module(
model,
k,
output_value,
mismatch_keys,
missing_keys,
misc,
unexpected_keys,
mapping.distributed_operation,
hf_quantizer,
)
for target_name, param in realized_value.items():
param = param[0] if isinstance(param, list) else param
device_match = device_map_regex.match(target_name)
param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
# Offloading support
if param_device == "disk":
disk_offload_index = offload_and_maybe_resave_param(
target_name, param, missing_keys, disk_offload_folder, disk_offload_index, mapping
)
else:
set_param_for_module(
model,
target_name,
param,
mismatch_keys,
missing_keys,
misc,
unexpected_keys,
mapping.distributed_operation,
hf_quantizer,
)
except SkipLayer:
continue

thread_pool.shutdown(wait=False)
return missing_keys, unexpected_keys, mismatch_keys, misc
return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, misc
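
Callers now unpack five values instead of four. A hedged sketch of the updated call shape follows; the leading arguments are placeholders assumed from context, not the exact upstream call site:

```python
# Sketch only: argument names other than the new disk-offload ones are assumptions.
missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, misc = convert_and_load_state_dict_in_model(
    model,
    state_dict,                                         # placeholder: checkpoint tensors
    weight_mapping,                                      # placeholder: WeightRenaming / WeightConverter rules
    device_map={"": "cpu", "model.layers.10": "disk"},   # anything mapped to "disk" gets offloaded
    disk_offload_index=disk_offload_index,               # prepared by accelerate_disk_offload(...)
    disk_offload_folder="offload",                       # used only when a param must be re-saved
)
```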


# TODO this is not done yet!
118 changes: 77 additions & 41 deletions src/transformers/integrations/accelerate.py
@@ -19,10 +19,14 @@
import copy
import inspect
import os
import re
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
from typing import TYPE_CHECKING, Optional, Union

from safetensors import safe_open
from safetensors.torch import save_file

from ..utils import (
is_accelerate_available,
is_torch_available,
@@ -445,71 +449,103 @@ def accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload
dispatch_model(model, **device_map_kwargs)


def get_disk_only_shard_files(device_map, weight_map):
"""
Returns the list of shard files containing only weights offloaded to disk.
"""
files_content = defaultdict(list)
for weight_name, filename in weight_map.items():
while len(weight_name) > 0 and weight_name not in device_map:
weight_name = ".".join(weight_name.split(".")[:-1])
files_content[filename].append(device_map[weight_name])

return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]


def expand_device_map(device_map, param_names):
"""
Expand a device map to return the correspondence parameter name to device.
"""
# Here, we first sort by number of submodules, then length of the full string, to make sure to match correctly
device_map_regex = re.compile(
"|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True))
)
new_device_map = {}
for module, device in device_map.items():
new_device_map.update(
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
)
for param in param_names:
device_match = device_map_regex.match(param)
new_device_map[param] = device_map[device_match.group()] if device_match else device_map.get("", "cpu")

return new_device_map
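
A concrete illustration of the regex matching above, with hypothetical module names; sorting by submodule count and then string length keeps `model.layers.11` from being captured by the shorter `model.layers.1` alternative:

```python
from transformers.integrations.accelerate import expand_device_map  # import path assumed from this diff

device_map = {"model.layers.1": "cpu", "model.layers.11": "disk", "": "cuda:0"}
param_names = [
    "model.layers.1.mlp.down_proj.weight",
    "model.layers.11.mlp.down_proj.weight",
    "lm_head.weight",
]
expanded = expand_device_map(device_map, param_names)
# Deepest/longest keys are tried first, and the "" entry acts as the fallback device:
# {"model.layers.1.mlp.down_proj.weight": "cpu",
#  "model.layers.11.mlp.down_proj.weight": "disk",
#  "lm_head.weight": "cuda:0"}
```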


def accelerate_disk_offload(
disk_offload_folder,
checkpoint_files,
device_map,
checkpoint_keys,
sharded_metadata,
dtype,
disk_offload_folder: str | None,
checkpoint_files: list[str] | None,
device_map: dict,
expected_keys: list[str],
sharded_metadata: dict | None,
dtype: torch.dtype | None,
weight_mapping=None,
):
disk_only_shard_files = []
"""
Prepare the `disk_offload_index` that will be used for reading offloaded parameters. If reading from a safetensors
file, parameters which do not need any special WeightConverter operation during loading (i.e. they are used as-is, or only
renamed) will be mapped to where they already reside on disk. Otherwise, the parameters will be resaved inside
`disk_offload_folder` during loading.
"""
from ..core_model_loading import WeightRenaming, build_glob_alternation, repl

if disk_offload_folder is not None:
os.makedirs(disk_offload_folder, exist_ok=True)
is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")
if disk_offload_folder is None and not is_offloaded_safetensors:
raise ValueError(
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)

rename = False
if weight_mapping is not None:
renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
if len(renamings) > 0:
rename = True
rename_alt, _, rename_by_group = build_glob_alternation(renamings)

# In this case, the offload index is simply the existing safetensors (except if using custom weight loading
# Operation, e.g. the MoE models, where we need to resave the weights that were changed at loading time)
if is_offloaded_safetensors:
param_device_map = expand_device_map(device_map, checkpoint_keys)
param_device_map = expand_device_map(device_map, expected_keys)
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
if sharded_metadata is None:
weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
weight_map = dict.fromkeys(safe_open(checkpoint_files[0], framework="pt").keys(), checkpoint_files[0])
else:
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
# Find potential checkpoints containing only offloaded weights
disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
weight_map = {k: os.path.join(folder, v) for k, v in sharded_metadata["weight_map"].items()}

# Update the weight names according to the `weight_mapping`
weight_renaming_map = {
rename_alt.sub(lambda m: repl(m, rename_by_group), k).replace("\\", "") if rename else k: k
for k in weight_map
}

# Prepare the index using existing safetensors files
disk_offload_index = {
name: {
"safetensors_file": file,
"weight_name": name,
target_name: {
"safetensors_file": weight_map[source_name],
"weight_name": source_name,
"dtype": str_dtype,
}
for name, file in weight_map.items()
if param_device_map[name] == "disk"
for target_name, source_name in weight_renaming_map.items()
# Need to check if it's in the mapping in case of unexpected keys that would result in KeyError (we skip them)
if target_name in param_device_map and param_device_map[target_name] == "disk"
}
# In this case we will resave every offloaded weight
else:
disk_offload_index = {}
return disk_offload_index, disk_only_shard_files, is_offloaded_safetensors

return disk_offload_index
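
For a checkpoint that is already in safetensors format, the returned index simply points each (possibly renamed) target parameter back at the shard it already lives in; a hypothetical entry looks like this:

```python
# Hypothetical index entry for an already-offloaded safetensors parameter.
disk_offload_index = {
    "model.layers.10.mlp.down_proj.weight": {                           # renamed target key
        "safetensors_file": "/path/to/model-00002-of-00003.safetensors",
        "weight_name": "model.layers.10.mlp.down_proj.weight",          # original key inside the shard
        "dtype": "bfloat16",
    },
}
```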


def offload_weight(weight: torch.Tensor, weight_name: str, offload_folder: str | None, offload_index: dict) -> dict:
"""Write `weight` to disk inside `offload_folder`, and update `offload_index` accordingly. Everything is
saved in `safetensors` format."""

if offload_folder is None:
raise ValueError(
"The current `device_map` had weights offloaded to the disk, which needed to be re-saved. This is either "
"because the weights are not in `safetensors` format, or because the model uses an internal weight format "
"different than the one saved (i.e. most MoE models). Please provide an `offload_folder` for them in "
"`from_pretrained`."
)
# Write the weight to disk
safetensor_file = os.path.join(offload_folder, f"{weight_name}.safetensors")
save_file({weight_name: weight}, safetensor_file)
# Update the offloading index
str_dtype = str(weight.dtype).replace("torch.", "")
offload_index[weight_name] = {"safetensors_file": safetensor_file, "weight_name": weight_name, "dtype": str_dtype}
return offload_index
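
A minimal usage sketch of `offload_weight` (temporary folder and tensor values are illustrative):

```python
import tempfile

import torch

from transformers.integrations.accelerate import offload_weight  # import path assumed from this diff

index: dict = {}
with tempfile.TemporaryDirectory() as folder:
    weight = torch.randn(8, 8, dtype=torch.float16)
    index = offload_weight(weight, "model.layers.10.mlp.down_proj.weight", folder, index)
    # One safetensors file per re-saved parameter, plus an index entry recording
    # the file path, the key stored inside it, and the dtype string ("float16" here).
    print(index["model.layers.10.mlp.down_proj.weight"]["safetensors_file"])
```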


def _init_infer_auto_device_map(