From 40d1827f9b139cc40dcb8d269e946b60a98dbd5f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 17 Nov 2025 14:37:13 +0100 Subject: [PATCH 01/32] unskip tests --- tests/test_modeling_common.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a7a40d887787..bad06c79531d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2357,7 +2357,6 @@ def check_device_map_is_respected(self, model, device_map): @require_accelerate @mark.accelerate_tests @require_torch_accelerator - @unittest.skip("# TODO @CyrilVallez fix this in the other PR") def test_disk_offload_bin(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -2402,7 +2401,6 @@ def test_disk_offload_bin(self): @require_accelerate @mark.accelerate_tests @require_torch_accelerator - @unittest.skip("# TODO @CyrilVallez fix this in the other PR") def test_disk_offload_safetensors(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -2441,7 +2439,6 @@ def test_disk_offload_safetensors(self): @require_accelerate @mark.accelerate_tests @require_torch_accelerator - @unittest.skip("# TODO @CyrilVallez fix this in the other PR") def test_cpu_offload(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From 5096ca6c1db5de120a9fa23430a451269e77e3ac Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 17 Nov 2025 15:03:41 +0100 Subject: [PATCH 02/32] first shot --- src/transformers/integrations/accelerate.py | 36 +++++++-------------- src/transformers/modeling_utils.py | 25 ++++++++++++-- 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py index 9696309ef221..0ea422599f2e 100644 --- a/src/transformers/integrations/accelerate.py +++ b/src/transformers/integrations/accelerate.py @@ -445,19 +445,6 @@ def accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload dispatch_model(model, **device_map_kwargs) -def get_disk_only_shard_files(device_map, weight_map): - """ - Returns the list of shard files containing only weights offloaded to disk. - """ - files_content = defaultdict(list) - for weight_name, filename in weight_map.items(): - while len(weight_name) > 0 and weight_name not in device_map: - weight_name = ".".join(weight_name.split(".")[:-1]) - files_content[filename].append(device_map[weight_name]) - - return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}] - - def expand_device_map(device_map, param_names): """ Expand a device map to return the correspondence parameter name to device. @@ -471,14 +458,13 @@ def expand_device_map(device_map, param_names): def accelerate_disk_offload( - disk_offload_folder, - checkpoint_files, - device_map, - checkpoint_keys, - sharded_metadata, - dtype, + disk_offload_folder: str | None, + checkpoint_files: list[str] | None, + device_map: dict, + expected_keys: list[str], + sharded_metadata: dict | None, + dtype: torch.dtype | None, ): - disk_only_shard_files = [] if disk_offload_folder is not None: os.makedirs(disk_offload_folder, exist_ok=True) is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors") @@ -489,15 +475,13 @@ def accelerate_disk_offload( " offers the weights in this format." 
)
         if is_offloaded_safetensors:
-            param_device_map = expand_device_map(device_map, checkpoint_keys)
+            param_device_map = expand_device_map(device_map, expected_keys)
             str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
             if sharded_metadata is None:
-                weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
+                weight_map = dict.fromkeys(expected_keys, checkpoint_files[0])
             else:
                 folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
                 weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
-            # Find potential checkpoints containing only offloaded weights
-            disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
             disk_offload_index = {
                 name: {
                     "safetensors_file": file,
@@ -509,7 +493,8 @@
             }
         else:
             disk_offload_index = {}
-    return disk_offload_index, disk_only_shard_files, is_offloaded_safetensors
+
+    return disk_offload_index
 
 
 def _init_infer_auto_device_map(
@@ -894,3 +879,4 @@ def check_tied_parameters_on_same_device(tied_params, device_map):
             f"Tied parameters are on different devices: {tie_param_devices}. "
             "Please modify your custom device map or set `device_map='auto'`. "
         )
+
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 3dc522efe251..ce63f5bf7f01 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -59,6 +59,7 @@
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled, is_fsdp_enabled
 from .integrations.accelerate import (
     _get_device_map,
+    accelerate_disk_offload,
     accelerate_dispatch,
     check_and_set_device_map,
     expand_device_map,
@@ -131,9 +132,7 @@
 if is_accelerate_available():
     from accelerate.hooks import add_hook_to_module
-    from accelerate.utils import (
-        extract_model_from_parallel,
-    )
+    from accelerate.utils import extract_model_from_parallel, offload_weight
     from accelerate.utils.modeling import get_state_dict_from_offload
@@ -4065,6 +4064,19 @@ def _load_pretrained_model(
         if logger.level >= logging.WARNING:
             verify_tp_plan(expected_keys, getattr(model, "_tp_plan", None))
 
+        # This offload index is for params explicitly on the "disk" in the device_map
+        disk_offload_index = None
+        # Prepare parameters offloading if needed
+        if device_map is not None and "disk" in device_map.values():
+            disk_offload_index = accelerate_disk_offload(
+                disk_offload_folder,
+                checkpoint_files,
+                device_map,
+                expected_keys,
+                sharded_metadata,
+                dtype,
+            )
+
         # Warmup cuda to load the weights much faster on devices
         if device_map is not None and not is_hqq_or_quark:
             expanded_device_map = expand_device_map(device_map, expected_keys)
@@ -4164,6 +4176,13 @@ def _load_pretrained_model(
             device_mesh,
         )
 
+        # If the model parameters were changed during loading (i.e. any custom Ops on the weights), we need to resave them
+        # for offloading
+        if device_map is not None and "disk" in device_map.values():
+            for name, param in model.state_dict().items():
+                if name not in disk_offload_index:
+                    disk_offload_index = offload_weight(param, name, disk_offload_folder, disk_offload_index)
+
         log_state_dict_report(
             model=model,
             pretrained_model_name_or_path=pretrained_model_name_or_path,

From ee4978e2a121765ae5ad1bcd77b208fc6f19510c Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Mon, 17 Nov 2025 15:47:51 +0100
Subject: [PATCH 03/32] offload in safetensors format

---
 src/transformers/integrations/accelerate.py | 32 ++++++++++++++++-----
 src/transformers/modeling_utils.py          |  3 +-
 2 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py
index 0ea422599f2e..7c2c42245610 100644
--- a/src/transformers/integrations/accelerate.py
+++ b/src/transformers/integrations/accelerate.py
@@ -23,6 +23,8 @@
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Optional, Union
 
+from safetensors.torch import save_file
+
 from ..utils import (
     is_accelerate_available,
     is_torch_available,
@@ -468,12 +470,9 @@ def accelerate_disk_offload(
     if disk_offload_folder is not None:
         os.makedirs(disk_offload_folder, exist_ok=True)
     is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")
-    if disk_offload_folder is None and not is_offloaded_safetensors:
-        raise ValueError(
-            "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
-            " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
-            " offers the weights in this format."
-        )
+
+    # In this cause, the offload index is simply the existing safetensors (except if using custom weight loading
+    # Operation, e.g. the MoE models, where we need to resave the weights that were changed at loading time)
     if is_offloaded_safetensors:
         param_device_map = expand_device_map(device_map, expected_keys)
         str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
@@ -491,12 +490,32 @@
             for name, file in weight_map.items()
             if param_device_map[name] == "disk"
         }
+    # In this case we will resave every offloaded weight
     else:
         disk_offload_index = {}
 
     return disk_offload_index
 
 
+def offload_weight(weight: torch.Tensor, weight_name: str, offload_folder: str | None, offload_index: dict) -> dict:
+    """Write `weight` to disk inside `offload_folder`, and update `offload_index` accordingly. Everything is
+    saved in `safetensors` format."""
+
+    if offload_folder is None:
+        raise ValueError(
+            "The current `device_map` had weights offloaded to the disk, which needed to be re-saved. This is likely "
+            "because the model uses an internal weight format different than the one saved (i.e. most MoE models). "
+            "Please provide an `offload_folder`for them in `from_pretrained`."
+ ) + # Write the weight to disk + safetensor_file = os.path.join(offload_folder, f"{weight_name}.safetensors") + save_file({weight_name: weight}, safetensor_file) + # Update the offloading index + str_dtype = str(weight.dtype).replace("torch.", "") + offload_index[weight_name] = {"safetensors_file": safetensor_file, "weight_name": weight_name, "dtype": str_dtype} + return offload_index + + def _init_infer_auto_device_map( model: nn.Module, max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None, @@ -879,4 +898,3 @@ def check_tied_parameters_on_same_device(tied_params, device_map): f"Tied parameters are on different devices: {tie_param_devices}. " "Please modify your custom device map or set `device_map='auto'`. " ) - diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ce63f5bf7f01..eae1597c2e87 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -64,6 +64,7 @@ check_and_set_device_map, expand_device_map, init_empty_weights, + offload_weight, ) from .integrations.deepspeed import _load_state_dict_into_zero3_model from .integrations.eager_paged import eager_paged_attention_forward @@ -132,7 +133,7 @@ if is_accelerate_available(): from accelerate.hooks import add_hook_to_module - from accelerate.utils import extract_model_from_parallel, offload_weight + from accelerate.utils import extract_model_from_parallel from accelerate.utils.modeling import get_state_dict_from_offload From 14e1699fedba283918663eb5f02fe7f14eb4d0c7 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 17 Nov 2025 15:49:27 +0100 Subject: [PATCH 04/32] remove hard-coded value --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index eae1597c2e87..e39d07adf7a6 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4196,7 +4196,7 @@ def _load_pretrained_model( misc=misc, ignore_mismatched_sizes=ignore_mismatched_sizes, ) - disk_offload_index = None + return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False): From 843371d52ff707f664a3ee66adf1f7ac6fd369f2 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 17 Nov 2025 15:50:51 +0100 Subject: [PATCH 05/32] update error --- src/transformers/integrations/accelerate.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py index 7c2c42245610..35b447a2a9f5 100644 --- a/src/transformers/integrations/accelerate.py +++ b/src/transformers/integrations/accelerate.py @@ -503,9 +503,10 @@ def offload_weight(weight: torch.Tensor, weight_name: str, offload_folder: str | if offload_folder is None: raise ValueError( - "The current `device_map` had weights offloaded to the disk, which needed to be re-saved. This is likely " - "because the model uses an internal weight format different than the one saved (i.e. most MoE models). " - "Please provide an `offload_folder`for them in `from_pretrained`." + "The current `device_map` had weights offloaded to the disk, which needed to be re-saved. This is either " + "because the weights are not in `safetensors` format, or because the model uses an internal weight format " + "different than the one saved (i.e. most MoE models). 
Please provide an `offload_folder`for them in " + "`from_pretrained`." ) # Write the weight to disk safetensor_file = os.path.join(offload_folder, f"{weight_name}.safetensors") From d53ab681973a5b950a40c27e166053fc4cbc6608 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 17 Nov 2025 15:51:52 +0100 Subject: [PATCH 06/32] typo --- src/transformers/integrations/accelerate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py index 35b447a2a9f5..a3653bc237f7 100644 --- a/src/transformers/integrations/accelerate.py +++ b/src/transformers/integrations/accelerate.py @@ -505,7 +505,7 @@ def offload_weight(weight: torch.Tensor, weight_name: str, offload_folder: str | raise ValueError( "The current `device_map` had weights offloaded to the disk, which needed to be re-saved. This is either " "because the weights are not in `safetensors` format, or because the model uses an internal weight format " - "different than the one saved (i.e. most MoE models). Please provide an `offload_folder`for them in " + "different than the one saved (i.e. most MoE models). Please provide an `offload_folder` for them in " "`from_pretrained`." ) # Write the weight to disk From 0086c243183268c0f47eaaf44bb8fb0830eff4a6 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 17 Nov 2025 17:05:24 +0100 Subject: [PATCH 07/32] fix --- src/transformers/core_model_loading.py | 57 +++++++++++++-------- src/transformers/integrations/accelerate.py | 2 +- src/transformers/modeling_utils.py | 10 +--- 3 files changed, 38 insertions(+), 31 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 80cf941c37ed..5cef47f1628b 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -28,6 +28,7 @@ import torch +from .integrations.accelerate import offload_weight from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, DTensor, Replicate, TensorParallelLayer from .utils import is_torch_greater_or_equal, logging @@ -397,7 +398,7 @@ def dot_natural_key(s: str): @contextmanager def log_to_misc( - layer_name: str, + full_param_name: str, misc: MutableMapping[str, str], extras: Any = None, op: Union[list[ConversionOps], ConversionOps, None] = None, @@ -421,22 +422,22 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) -> if isinstance(extras, tuple) and len(extras) == 2: values, target_keys = extras descriptor = f"{op_name} " if op_name else "" - misc[layer_name] = ( + misc[full_param_name] = ( f"{e}\nError: {descriptor}on tensors destined for {target_keys}. 
Ckpt contains: {len(values)}" ) elif isinstance(extras, str): suffix = f" via {op_name}" if op_name else "" - misc[layer_name] = f"{e}\nError{suffix} when processing parameter {extras}" + misc[full_param_name] = f"{e}\nError{suffix} when processing parameter {extras}" elif extras is None and op_name: - misc[layer_name] = f"{op_name}: {e}" + misc[full_param_name] = f"{op_name}: {e}" else: - misc[layer_name] = f"{extras} |Error: {e}" + misc[full_param_name] = f"{extras} |Error: {e}" raise SkipLayer() def set_param_for_module( model: PreTrainedModel, - layer_name: str, + full_param_name: str, param_value: torch.Tensor, mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]], missing_keys: MutableSet[str], @@ -445,8 +446,8 @@ def set_param_for_module( distributed_operation: Optional[TensorParallelLayer], hf_quantizer: HfQuantizer, ): - with log_to_misc(layer_name, misc, layer_name): - module_path, _, param_name = layer_name.rpartition(".") + with log_to_misc(full_param_name, misc, full_param_name): + module_path, _, param_name = full_param_name.rpartition(".") module_obj = model.get_submodule(module_path) if module_path else model if isinstance(param_value, list): param_value = param_value[0] @@ -521,6 +522,8 @@ def convert_and_load_state_dict_in_model( device_map: dict | None = None, dtype_plan: dict | None = None, device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None, + disk_offload_index: dict | None = None, + disk_offload_folder: str | None = None, ): r""" We build a mapping from the keys obtained by renaming each of the checkpoint keys according to the weight_mapping rules. @@ -723,26 +726,36 @@ def convert_and_load_state_dict_in_model( total_entries = len(param_name_to_load) with logging.tqdm(total=total_entries, desc="Loading weights") as pbar: - for layer_name, mapping in param_name_to_load.items(): + for full_param_name, mapping in param_name_to_load.items(): pbar.update(1) - pbar.set_postfix({"Materializing param": layer_name}) + pbar.set_postfix({"Materializing param": full_param_name}) pbar.refresh() try: realized_value, misc = mapping.convert( - layer_name, config=model.config, quantizer=hf_quantizer, missing_keys=missing_keys + full_param_name, config=model.config, quantizer=hf_quantizer, missing_keys=missing_keys ) for k, output_value in realized_value.items(): - set_param_for_module( - model, - k, - output_value, - mismatch_keys, - missing_keys, - misc, - unexpected_keys, - mapping.distributed_operation, - hf_quantizer, - ) + param_device = device_map[re.search(device_map_regex, k).group()] + # Offloading support + if param_device == "disk": + missing_keys.discard(k) + # If not already offloaded, or if we applied any special Operation, we need to re-save + if k not in disk_offload_index or len(operations) > 0: + disk_offload_index = offload_weight( + output_value, k, disk_offload_folder, disk_offload_index + ) + else: + set_param_for_module( + model, + k, + output_value, + mismatch_keys, + missing_keys, + misc, + unexpected_keys, + mapping.distributed_operation, + hf_quantizer, + ) except SkipLayer: continue thread_pool.shutdown(wait=False) diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py index a3653bc237f7..efa5c0d96ca1 100644 --- a/src/transformers/integrations/accelerate.py +++ b/src/transformers/integrations/accelerate.py @@ -471,7 +471,7 @@ def accelerate_disk_offload( os.makedirs(disk_offload_folder, exist_ok=True) is_offloaded_safetensors = checkpoint_files is not None and 
checkpoint_files[0].endswith(".safetensors") - # In this cause, the offload index is simply the existing safetensors (except if using custom weight loading + # In this case, the offload index is simply the existing safetensors (except if using custom weight loading # Operation, e.g. the MoE models, where we need to resave the weights that were changed at loading time) if is_offloaded_safetensors: param_device_map = expand_device_map(device_map, expected_keys) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e39d07adf7a6..08cefd95d5fd 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -64,7 +64,6 @@ check_and_set_device_map, expand_device_map, init_empty_weights, - offload_weight, ) from .integrations.deepspeed import _load_state_dict_into_zero3_model from .integrations.eager_paged import eager_paged_attention_forward @@ -4126,6 +4125,8 @@ def _load_pretrained_model( device_map, model.dtype_plan, device_mesh, + disk_offload_index, + disk_offload_folder, ) # finally close all opened file pointers @@ -4177,13 +4178,6 @@ def _load_pretrained_model( device_mesh, ) - # If the model parameters were changed during loading (i.e. any custom Ops on the weights), we need to resave them - # for offloading - if device_map is not None and "disk" in device_map.values(): - for name, param in model.state_dict().items(): - if name not in disk_offload_index: - disk_offload_index = offload_weight(param, name, disk_offload_folder, disk_offload_index) - log_state_dict_report( model=model, pretrained_model_name_or_path=pretrained_model_name_or_path, From e3fb6eb16853c010034f26fdbf5b4d485ba6fd7c Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 17 Nov 2025 17:14:38 +0100 Subject: [PATCH 08/32] update test --- tests/test_modeling_common.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index bad06c79531d..81fab67b919e 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2378,7 +2378,9 @@ def test_disk_offload_bin(self): max_size = int(self.model_split_percents[0] * model_size) max_memory = {0: max_size, "cpu": max_size} # This errors out cause it's missing an offload folder - new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) + new_model = model_class.from_pretrained( + tmp_dir, device_map="auto", max_memory=max_memory, use_safetensors=False + ) max_size = int(self.model_split_percents[1] * model_size) max_memory = {0: max_size, "cpu": max_size} From ad7a84f9123cf995c6e1da6e52ff141e07edc661 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 17 Nov 2025 17:23:41 +0100 Subject: [PATCH 09/32] fix --- src/transformers/core_model_loading.py | 7 ++----- tests/test_modeling_common.py | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 5cef47f1628b..482330402ea2 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -449,10 +449,6 @@ def set_param_for_module( with log_to_misc(full_param_name, misc, full_param_name): module_path, _, param_name = full_param_name.rpartition(".") module_obj = model.get_submodule(module_path) if module_path else model - if isinstance(param_value, list): - param_value = param_value[0] - elif not isinstance(param_value, torch.nn.Parameter): - param_value = param_value[...] 
ref = getattr(module_obj, param_name) if ref is None: @@ -711,7 +707,7 @@ def convert_and_load_state_dict_in_model( shard_index, ) - if future is None: # TODO handle disk offload + if future is None: device_match = device_map_regex.match(renamed_key) param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu") future = spawn_materialize(thread_pool, tensor, param_device, _dtype) @@ -735,6 +731,7 @@ def convert_and_load_state_dict_in_model( full_param_name, config=model.config, quantizer=hf_quantizer, missing_keys=missing_keys ) for k, output_value in realized_value.items(): + output_value = output_value[0] if isinstance(output_value, list) else output_value param_device = device_map[re.search(device_map_regex, k).group()] # Offloading support if param_device == "disk": diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 81fab67b919e..e18a76d54c6d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2385,7 +2385,7 @@ def test_disk_offload_bin(self): max_size = int(self.model_split_percents[1] * model_size) max_memory = {0: max_size, "cpu": max_size} new_model = model_class.from_pretrained( - tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir + tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir, use_safetensors=False ) self.check_device_map_is_respected(new_model, new_model.hf_device_map) From 4bc793c175cb0eec87f1958269a57d21a52b3b09 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 17 Nov 2025 17:29:30 +0100 Subject: [PATCH 10/32] return it --- src/transformers/core_model_loading.py | 2 +- src/transformers/modeling_utils.py | 26 ++++++++++++++------------ 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 482330402ea2..e1d291280706 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -756,7 +756,7 @@ def convert_and_load_state_dict_in_model( except SkipLayer: continue thread_pool.shutdown(wait=False) - return missing_keys, unexpected_keys, mismatch_keys, misc + return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, misc # TODO this is not done yet! 
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 08cefd95d5fd..e37c45774d84 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4115,18 +4115,20 @@ def _load_pretrained_model(
         else:
             raise ValueError("Neither a state dict nor checkpoint files were found.")
 
-        missing_keys, unexpected_keys, mismatched_keys, misc = convert_and_load_state_dict_in_model(
-            model,
-            merged_state_dict,
-            weight_mapping,
-            tp_plan,
-            hf_quantizer,
-            dtype,
-            device_map,
-            model.dtype_plan,
-            device_mesh,
-            disk_offload_index,
-            disk_offload_folder,
+        missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, misc = (
+            convert_and_load_state_dict_in_model(
+                model,
+                merged_state_dict,
+                weight_mapping,
+                tp_plan,
+                hf_quantizer,
+                dtype,
+                device_map,
+                model.dtype_plan,
+                device_mesh,
+                disk_offload_index,
+                disk_offload_folder,
+            )
         )
 
         # finally close all opened file pointers

From 009e44be2478ecf6530b7f67342296e484278e1d Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Tue, 18 Nov 2025 10:11:33 +0100
Subject: [PATCH 13/32] improve names

---
 src/transformers/core_model_loading.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py
index acbcda454945..addcfd18f436 100644
--- a/src/transformers/core_model_loading.py
+++ b/src/transformers/core_model_loading.py
@@ -398,7 +398,7 @@ def dot_natural_key(s: str):
 @contextmanager
 def log_to_misc(
-    full_param_name: str,
+    all_target_keys: str,
     misc: MutableMapping[str, str],
     extras: Any = None,
     op: Union[list[ConversionOps], ConversionOps, None] = None,
@@ -422,16 +422,16 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) ->
     if isinstance(extras, tuple) and len(extras) == 2:
         values, target_keys = extras
         descriptor = f"{op_name} " if op_name else ""
-        misc[full_param_name] = (
+        misc[all_target_keys] = (
             f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values)}"
         )
     elif isinstance(extras, str):
         suffix = f" via {op_name}" if op_name else ""
-        misc[full_param_name] = f"{e}\nError{suffix} when processing parameter {extras}"
+        misc[all_target_keys] = f"{e}\nError{suffix} when processing parameter {extras}"
     elif extras is None and op_name:
-        misc[full_param_name] = f"{op_name}: {e}"
+        misc[all_target_keys] = f"{op_name}: {e}"
     else:
-        misc[full_param_name] = f"{extras} |Error: {e}"
+        misc[all_target_keys] = f"{extras} |Error: {e}"
     raise SkipLayer()

From adacfe007041b286663568ff9e5ec0f40ec789b4 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Tue, 18 Nov 2025 12:11:00 +0100
Subject: [PATCH 14/32] fix finally

---
 src/transformers/integrations/accelerate.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py
index efa5c0d96ca1..0ffbc2d2fd1f 100644
--- a/src/transformers/integrations/accelerate.py
+++ b/src/transformers/integrations/accelerate.py
@@ -406,6 +406,10 @@ def _get_device_map(
         if max_memory is not None and device_name in max_memory:
             inferred_max_memory[device_name] = min(inferred_max_memory[device_name], max_memory[device_name])
 
+    # Here we need to retie the weights before the call even if they are all on meta device, otherwise accelerate
+    # mess up the device_map computation
+    # TODO Cyril: replace this function to avoid re-tying uselessly
+    model.tie_weights()
     device_map = infer_auto_device_map(
         model,
         max_memory=inferred_max_memory,

From 74b3862940009fe52593207ff886666bdfd0c235 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Tue, 18 Nov 2025 12:12:00 +0100
Subject: [PATCH 15/32] comment

---
 src/transformers/integrations/accelerate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py
index 0ffbc2d2fd1f..171fd7c19465 100644
--- a/src/transformers/integrations/accelerate.py
+++ b/src/transformers/integrations/accelerate.py
@@ -408,7 +408,7 @@ def _get_device_map(
 
     # Here we need to retie the weights before the call even if they are all on meta device, otherwise accelerate
     # mess up the device_map computation
-    # TODO Cyril: replace this function to avoid re-tying uselessly
+    # TODO Cyril: replace this function to avoid re-tying uselessly (and the function is very inefficient)
     model.tie_weights()
     device_map = infer_auto_device_map(
         model,
         max_memory=inferred_max_memory,

From 76756b8652fc9e2729338093e7f1918e21b3a824 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Tue, 18 Nov 2025 12:25:11 +0100
Subject: [PATCH 16/32] fix tests

---
 tests/utils/test_core_model_loading.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py
index 023b9fd0596d..8973e9900f0f 100644
--- a/tests/utils/test_core_model_loading.py
+++ b/tests/utils/test_core_model_loading.py
@@ -250,7 +250,7 @@ def test_moe_and_qkv_conversion(self):
             ),
             WeightRenaming("mlp.w2.weight", "mlp.down_proj.weight"),
         ]
-        missing, unexpected, mismatch, misc = convert_and_load_state_dict_in_model(
+        missing, unexpected, mismatch, _, misc = convert_and_load_state_dict_in_model(
             model, state_dict, weight_mapping, tp_plan=None, hf_quantizer=None
         )
 
@@ -400,7 +400,7 @@ def __init__(self):
             )
         ]
 
-        missing, unexpected, mismatch, misc = convert_and_load_state_dict_in_model(
+        missing, unexpected, mismatch, _, misc = convert_and_load_state_dict_in_model(
             model, state_dict, weight_mapping, tp_plan=None, hf_quantizer=quantizer
         )

From b2f95933c83dd95f51cda012c5d8cf9ff193674a Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Tue, 18 Nov 2025 13:03:37 +0100
Subject: [PATCH 17/32] fix

---
 src/transformers/core_model_loading.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py
index addcfd18f436..09aa88c52d7e 100644
--- a/src/transformers/core_model_loading.py
+++ b/src/transformers/core_model_loading.py
@@ -611,6 +611,7 @@ def convert_and_load_state_dict_in_model(
     prefix = model.base_model_prefix
     tp_plan = tp_plan or {}
     device_map = device_map or {"": "cpu"}
+    # Here, we first sort by number of submodules, then length of the full string, to make sure to match correctly
     device_map_regex = re.compile(
         "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True))
     )

From ee88919d63cbf746af9ffd1dcc451ceafcff1b10 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Tue, 18 Nov 2025 13:14:34 +0100
Subject: [PATCH 18/32] simplify

---
 src/transformers/integrations/accelerate.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py
index 171fd7c19465..294c0849b477 100644
--- a/src/transformers/integrations/accelerate.py
+++ b/src/transformers/integrations/accelerate.py
@@ -19,6 +19,7 @@
 import copy
 import inspect
 import os
+import re
 from collections import OrderedDict, defaultdict
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Optional, Union
@@ -455,11 +456,15 @@ def expand_device_map(device_map, param_names):
     """
     Expand a device map to return the correspondence parameter name to device.
""" + # Here, we first sort by number of submodules, then length of the full string, to make sure to match correctly + device_map_regex = re.compile( + "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True)) + ) new_device_map = {} - for module, device in device_map.items(): - new_device_map.update( - {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""} - ) + for param in param_names: + device_match = device_map_regex.match(param) + new_device_map[param] = device_map[device_match.group()] if device_match else device_map.get("", "cpu") + return new_device_map From 98bf29d805189886cea46b8f4b8920665506310f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 18 Nov 2025 15:00:57 +0100 Subject: [PATCH 19/32] fix --- src/transformers/integrations/accelerate.py | 48 ++++++++++++++++++--- src/transformers/modeling_utils.py | 1 + 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py index 294c0849b477..7ca982656353 100644 --- a/src/transformers/integrations/accelerate.py +++ b/src/transformers/integrations/accelerate.py @@ -18,12 +18,14 @@ import copy import inspect +import itertools import os import re from collections import OrderedDict, defaultdict from contextlib import contextmanager from typing import TYPE_CHECKING, Optional, Union +from safetensors import safe_open from safetensors.torch import save_file from ..utils import ( @@ -468,6 +470,23 @@ def expand_device_map(device_map, param_names): return new_device_map +def update_param_name(param_name: str, weight_pattern_alt, weight_pattern_by_group_name, source_to_target): + from ..core_model_loading import match_glob + + if weight_pattern_alt is None: + return param_name + + matched_pattern = match_glob(param_name, weight_pattern_alt, weight_pattern_by_group_name) + if matched_pattern is not None: + converter = source_to_target[matched_pattern] + # Only change name if it's a simple renaming, i.e. no custom Ops + if len(converter.source_keys) == 1 and len(converter.target_keys) == 1: + source_pattern = converter.source_keys[0] + target_pattern = converter.target_keys[0] + return re.sub(source_pattern, target_pattern, param_name) + return param_name + + def accelerate_disk_offload( disk_offload_folder: str | None, checkpoint_files: list[str] | None, @@ -475,29 +494,46 @@ def accelerate_disk_offload( expected_keys: list[str], sharded_metadata: dict | None, dtype: torch.dtype | None, + weight_mapping=None, ): + from ..core_model_loading import build_glob_alt + if disk_offload_folder is not None: os.makedirs(disk_offload_folder, exist_ok=True) is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors") + _patterns, weight_pattern_alt, weight_pattern_by_group_name = None, None, None + if weight_mapping is not None: + _patterns = list(itertools.chain.from_iterable([k.source_keys for k in weight_mapping])) + source_to_target = {sk: k for k in weight_mapping for sk in k.source_keys} + weight_pattern_alt, weight_pattern_by_group_name = build_glob_alt(_patterns) + # In this case, the offload index is simply the existing safetensors (except if using custom weight loading # Operation, e.g. 
the MoE models, where we need to resave the weights that were changed at loading time) if is_offloaded_safetensors: param_device_map = expand_device_map(device_map, expected_keys) str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32" if sharded_metadata is None: - weight_map = dict.fromkeys(expected_keys, checkpoint_files[0]) + weight_map = dict.fromkeys(safe_open(checkpoint_files[0], framework="pt").keys(), checkpoint_files[0]) else: folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1]) weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()} + + # Update the weight names according to the `weight_mapping` + weight_renaming_map = { + update_param_name(k, weight_pattern_alt, weight_pattern_by_group_name, source_to_target): k + for k in weight_map + } + + # Prepare the index using existing safetensors files disk_offload_index = { - name: { - "safetensors_file": file, - "weight_name": name, + target_name: { + "safetensors_file": weight_map[source_name], + "weight_name": source_name, "dtype": str_dtype, } - for name, file in weight_map.items() - if param_device_map[name] == "disk" + for target_name, source_name in weight_renaming_map.items() + if param_device_map[target_name] == "disk" } # In this case we will resave every offloaded weight else: diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e37c45774d84..6a49b070d402 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4075,6 +4075,7 @@ def _load_pretrained_model( expected_keys, sharded_metadata, dtype, + weight_mapping, ) # Warmup cuda to load the weights much faster on devices From 16c6dee7389bc781b4aaa819aae3adfbbf321e41 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 18 Nov 2025 15:09:43 +0100 Subject: [PATCH 20/32] doc --- src/transformers/integrations/accelerate.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py index 7ca982656353..52bb71690c1b 100644 --- a/src/transformers/integrations/accelerate.py +++ b/src/transformers/integrations/accelerate.py @@ -470,7 +470,11 @@ def expand_device_map(device_map, param_names): return new_device_map -def update_param_name(param_name: str, weight_pattern_alt, weight_pattern_by_group_name, source_to_target): +def update_param_name(param_name: str, weight_pattern_alt, weight_pattern_by_group_name, source_to_target) -> str: + """Update a source `param_name` in a checkpoint into the target name that the model expects, if different. 
+ This uses the same logic as `core_model_loading.py`.""" + # TODO Cyril: This function would not even need to exist if the Converter entries already contained the + # full expanded source and target names from ..core_model_loading import match_glob if weight_pattern_alt is None: From cb1d7c75e36a99bf5e0018d46718a6bcb3b32ce9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 18 Nov 2025 15:29:49 +0100 Subject: [PATCH 21/32] fix --- tests/models/glm4_moe/test_modeling_glm4_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/glm4_moe/test_modeling_glm4_moe.py b/tests/models/glm4_moe/test_modeling_glm4_moe.py index d4cebbd02983..526e81e29766 100644 --- a/tests/models/glm4_moe/test_modeling_glm4_moe.py +++ b/tests/models/glm4_moe/test_modeling_glm4_moe.py @@ -61,6 +61,7 @@ class Glm4MoeModelTest(CausalLMModelTest, unittest.TestCase): model_tester_class = Glm4MoeModelTester # used in `test_torch_compile_for_training`. Skip as "Dynamic control flow in MoE" _torch_compile_train_cls = None + model_split_percents = [0.5, 0.85, 0.9] # it tries to offload everything with the default value @require_torch_accelerator From 240767c8e3cd4d294845574b6b2fc45b564ced57 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 18 Nov 2025 18:55:53 +0100 Subject: [PATCH 22/32] remove additional tiying after rebase --- src/transformers/integrations/accelerate.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py index 52bb71690c1b..b945e00fb421 100644 --- a/src/transformers/integrations/accelerate.py +++ b/src/transformers/integrations/accelerate.py @@ -409,10 +409,6 @@ def _get_device_map( if max_memory is not None and device_name in max_memory: inferred_max_memory[device_name] = min(inferred_max_memory[device_name], max_memory[device_name]) - # Here we need to retie the weights before the call even if they are all on meta device, otherwise accelerate - # mess up the device_map computation - # TODO Cyril: replace this function to avoid re-tying uselessly (and the function is very inefficient) - model.tie_weights() device_map = infer_auto_device_map( model, max_memory=inferred_max_memory, From 2e2b725ae47c9c2945ee37ba31edba0acd949f27 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 18 Nov 2025 22:04:28 +0100 Subject: [PATCH 23/32] update test function source --- tests/test_modeling_common.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index e18a76d54c6d..ccc403c6a2c9 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -105,7 +105,6 @@ CONFIG_NAME, GENERATION_CONFIG_NAME, SAFE_WEIGHTS_NAME, - is_accelerate_available, is_torch_bf16_available_on_device, is_torch_fp16_available_on_device, ) @@ -113,10 +112,6 @@ from .generation.test_utils import GenerationTesterMixin -if is_accelerate_available(): - from accelerate.utils import compute_module_sizes - - if is_torch_available(): import torch from safetensors import safe_open @@ -125,6 +120,7 @@ from torch import nn from transformers import MODEL_MAPPING + from transformers.integrations.accelerate import compute_module_sizes from transformers.integrations.tensor_parallel import _get_parameter_tp_plan from transformers.modeling_utils import load_state_dict from transformers.pytorch_utils import id_tensor_storage @@ -2370,7 +2366,7 @@ def test_disk_offload_bin(self): torch.manual_seed(0) base_output = model(**inputs_dict_class) - 
model_size = compute_module_sizes(model)[""] + model_size = compute_module_sizes(model)[0][""] with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir, safe_serialization=False) @@ -2416,7 +2412,7 @@ def test_disk_offload_safetensors(self): torch.manual_seed(0) base_output = model(**inputs_dict_class) - model_size = compute_module_sizes(model)[""] + model_size = compute_module_sizes(model)[0][""] with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir) @@ -2455,7 +2451,7 @@ def test_cpu_offload(self): torch.manual_seed(0) base_output = model(**inputs_dict_class) - model_size = compute_module_sizes(model)[""] + model_size = compute_module_sizes(model)[0][""] # We test several splits of sizes to make sure it works. max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]] with tempfile.TemporaryDirectory() as tmp_dir: @@ -2498,7 +2494,7 @@ def test_model_parallelism(self): torch.manual_seed(0) base_output = model(**inputs_dict_class) - model_size = compute_module_sizes(model)[""] + model_size = compute_module_sizes(model)[0][""] # We test several splits of sizes to make sure it works. max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]] with tempfile.TemporaryDirectory() as tmp_dir: From deedbdf1d9c3d37b319a5a88d5edc1f3e608def9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 19 Nov 2025 18:27:34 +0100 Subject: [PATCH 24/32] fix --- .../deepseek_vl_hybrid/test_modeling_deepseek_vl_hybrid.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/deepseek_vl_hybrid/test_modeling_deepseek_vl_hybrid.py b/tests/models/deepseek_vl_hybrid/test_modeling_deepseek_vl_hybrid.py index 661b44ecf390..4095e19c00d0 100644 --- a/tests/models/deepseek_vl_hybrid/test_modeling_deepseek_vl_hybrid.py +++ b/tests/models/deepseek_vl_hybrid/test_modeling_deepseek_vl_hybrid.py @@ -167,6 +167,7 @@ class DeepseekVLHybridModelTest(ModelTesterMixin, GenerationTesterMixin, unittes else {} ) _is_composite = True + model_split_percents = [0.5, 0.85, 0.9] # it tries to offload everything with the default value def setUp(self): self.model_tester = DeepseekVLHybridModelTester(self) From 9a0675a723a6b62c9f403a033d7c91359214663d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 21 Nov 2025 17:21:45 +0100 Subject: [PATCH 25/32] post rebase --- src/transformers/core_model_loading.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 09aa88c52d7e..c37088bd14cf 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -739,9 +739,9 @@ def convert_and_load_state_dict_in_model( ) # Offloading support if param_device == "disk": - missing_keys.discard(k) - # If not already offloaded, or if we applied any special Operation, we need to re-save - if k not in disk_offload_index or len(operations) > 0: + missing_keys.discard(target_name) + # If not already offloaded, or if we applied any special Operation except Renaming, we need to re-save + if target_name not in disk_offload_index or isinstance(mapping, WeightConverter): disk_offload_index = offload_weight( param, target_name, disk_offload_folder, disk_offload_index ) @@ -759,6 +759,7 @@ def convert_and_load_state_dict_in_model( ) except SkipLayer: continue + thread_pool.shutdown(wait=False) return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, misc From 81d136d4274f2296e7b26e6bc6f1be20b8784c53 Mon Sep 17 
00:00:00 2001 From: Cyril Vallez Date: Fri, 21 Nov 2025 17:36:55 +0100 Subject: [PATCH 26/32] new renaming patterns --- src/transformers/core_model_loading.py | 17 +++++----- src/transformers/integrations/accelerate.py | 35 +++++---------------- 2 files changed, 15 insertions(+), 37 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index c37088bd14cf..d7f970fb2fb2 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -452,7 +452,7 @@ def set_param_for_module( ref = getattr(module_obj, param_name) if ref is None: - unexpected_keys.add(layer_name) + unexpected_keys.add(full_param_name) else: use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor if not isinstance(param_value, torch.nn.Parameter): @@ -472,14 +472,13 @@ def set_param_for_module( param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point()) # Remove from missing keys (it's either mismatched, or all good) - missing_keys.discard(layer_name) + missing_keys.discard(full_param_name) if ref is not None and ref.shape != param_value.shape and hf_quantizer is None: - mismatch_keys.add((layer_name, param_value.shape, ref.shape)) + mismatch_keys.add((full_param_name, param_value.shape, ref.shape)) module_obj.param_name._is_hf_initialized = False # Needs to be initialized else: - param_value._is_hf_initialized = ( - True # super important otherwise _init_weight re-initi if bias is missing - ) + # super important otherwise _init_weight will re-init the param + param_value._is_hf_initialized = True setattr(module_obj, param_name, param_value) @@ -711,6 +710,8 @@ def convert_and_load_state_dict_in_model( if future is None: device_match = device_map_regex.match(renamed_key) param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu") + # If disk, we need to materialize on cpu first + param_device = "cpu" if param_device == "disk" else param_device future = spawn_materialize(thread_pool, tensor, param_device, _dtype) mapping.add_tensor(renamed_key, original_key, source_pattern, future) @@ -734,9 +735,7 @@ def convert_and_load_state_dict_in_model( for target_name, param in realized_value.items(): param = param[0] if isinstance(param, list) else param device_match = device_map_regex.match(target_name) - param_device = ( - device_map[device_match.group()] if device_match else device_map.get("", "cpu") - ) + param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu") # Offloading support if param_device == "disk": missing_keys.discard(target_name) diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py index b945e00fb421..a9c4606d8839 100644 --- a/src/transformers/integrations/accelerate.py +++ b/src/transformers/integrations/accelerate.py @@ -18,7 +18,6 @@ import copy import inspect -import itertools import os import re from collections import OrderedDict, defaultdict @@ -466,27 +465,6 @@ def expand_device_map(device_map, param_names): return new_device_map -def update_param_name(param_name: str, weight_pattern_alt, weight_pattern_by_group_name, source_to_target) -> str: - """Update a source `param_name` in a checkpoint into the target name that the model expects, if different. 
- This uses the same logic as `core_model_loading.py`.""" - # TODO Cyril: This function would not even need to exist if the Converter entries already contained the - # full expanded source and target names - from ..core_model_loading import match_glob - - if weight_pattern_alt is None: - return param_name - - matched_pattern = match_glob(param_name, weight_pattern_alt, weight_pattern_by_group_name) - if matched_pattern is not None: - converter = source_to_target[matched_pattern] - # Only change name if it's a simple renaming, i.e. no custom Ops - if len(converter.source_keys) == 1 and len(converter.target_keys) == 1: - source_pattern = converter.source_keys[0] - target_pattern = converter.target_keys[0] - return re.sub(source_pattern, target_pattern, param_name) - return param_name - - def accelerate_disk_offload( disk_offload_folder: str | None, checkpoint_files: list[str] | None, @@ -496,17 +474,18 @@ def accelerate_disk_offload( dtype: torch.dtype | None, weight_mapping=None, ): - from ..core_model_loading import build_glob_alt + from ..core_model_loading import WeightRenaming, build_glob_alternation, repl if disk_offload_folder is not None: os.makedirs(disk_offload_folder, exist_ok=True) is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors") - _patterns, weight_pattern_alt, weight_pattern_by_group_name = None, None, None + rename = False if weight_mapping is not None: - _patterns = list(itertools.chain.from_iterable([k.source_keys for k in weight_mapping])) - source_to_target = {sk: k for k in weight_mapping for sk in k.source_keys} - weight_pattern_alt, weight_pattern_by_group_name = build_glob_alt(_patterns) + renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)] + if len(renamings) > 0: + rename = True + rename_alt, _, rename_by_group = build_glob_alternation(renamings) # In this case, the offload index is simply the existing safetensors (except if using custom weight loading # Operation, e.g. 
the MoE models, where we need to resave the weights that were changed at loading time) @@ -521,7 +500,7 @@ def accelerate_disk_offload( # Update the weight names according to the `weight_mapping` weight_renaming_map = { - update_param_name(k, weight_pattern_alt, weight_pattern_by_group_name, source_to_target): k + rename_alt.sub(lambda m: repl(m, rename_by_group), k).replace("\\", "") if rename else k: k for k in weight_map } From bd4e78858ed1726c4ee05fc4de7a52699280aaa1 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 24 Nov 2025 13:42:00 +0100 Subject: [PATCH 27/32] clear confusion about variable names --- src/transformers/core_model_loading.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index d7f970fb2fb2..960d3b927dee 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -398,7 +398,7 @@ def dot_natural_key(s: str): @contextmanager def log_to_misc( - all_target_keys: str, + first_target_key: str, misc: MutableMapping[str, str], extras: Any = None, op: Union[list[ConversionOps], ConversionOps, None] = None, @@ -422,16 +422,16 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) -> if isinstance(extras, tuple) and len(extras) == 2: values, target_keys = extras descriptor = f"{op_name} " if op_name else "" - misc[all_target_keys] = ( + misc[first_target_key] = ( f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values)}" ) elif isinstance(extras, str): suffix = f" via {op_name}" if op_name else "" - misc[all_target_keys] = f"{e}\nError{suffix} when processing parameter {extras}" + misc[first_target_key] = f"{e}\nError{suffix} when processing parameter {extras}" elif extras is None and op_name: - misc[all_target_keys] = f"{op_name}: {e}" + misc[first_target_key] = f"{op_name}: {e}" else: - misc[all_target_keys] = f"{extras} |Error: {e}" + misc[first_target_key] = f"{extras} |Error: {e}" raise SkipLayer() @@ -724,13 +724,13 @@ def convert_and_load_state_dict_in_model( total_entries = len(param_name_to_load) with logging.tqdm(total=total_entries, desc="Loading weights") as pbar: - for full_param_name, mapping in param_name_to_load.items(): + for first_param_name, mapping in param_name_to_load.items(): pbar.update(1) - pbar.set_postfix({"Materializing param": full_param_name}) + pbar.set_postfix({"Materializing param": first_param_name}) pbar.refresh() try: realized_value, misc = mapping.convert( - full_param_name, config=model.config, quantizer=hf_quantizer, missing_keys=missing_keys + first_param_name, config=model.config, quantizer=hf_quantizer, missing_keys=missing_keys ) for target_name, param in realized_value.items(): param = param[0] if isinstance(param, list) else param From 2a9009e2e9c0e1ded84dd5762b5cdc968d3ab3c4 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 24 Nov 2025 13:57:20 +0100 Subject: [PATCH 28/32] create cleaner function --- src/transformers/core_model_loading.py | 40 ++++++++++++++++++-------- 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 960d3b927dee..fecaf6f39fd7 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -437,7 +437,7 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) -> def set_param_for_module( model: PreTrainedModel, - full_param_name: str, 
+ target_name: str, param_value: torch.Tensor, mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]], missing_keys: MutableSet[str], @@ -446,13 +446,13 @@ def set_param_for_module( distributed_operation: Optional[TensorParallelLayer], hf_quantizer: HfQuantizer, ): - with log_to_misc(full_param_name, misc, full_param_name): - module_path, _, param_name = full_param_name.rpartition(".") + with log_to_misc(target_name, misc, target_name): + module_path, _, param_name = target_name.rpartition(".") module_obj = model.get_submodule(module_path) if module_path else model ref = getattr(module_obj, param_name) if ref is None: - unexpected_keys.add(full_param_name) + unexpected_keys.add(target_name) else: use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor if not isinstance(param_value, torch.nn.Parameter): @@ -472,9 +472,9 @@ def set_param_for_module( param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point()) # Remove from missing keys (it's either mismatched, or all good) - missing_keys.discard(full_param_name) + missing_keys.discard(target_name) if ref is not None and ref.shape != param_value.shape and hf_quantizer is None: - mismatch_keys.add((full_param_name, param_value.shape, ref.shape)) + mismatch_keys.add((target_name, param_value.shape, ref.shape)) module_obj.param_name._is_hf_initialized = False # Needs to be initialized else: # super important otherwise _init_weight will re-init the param @@ -482,6 +482,25 @@ def set_param_for_module( setattr(module_obj, param_name, param_value) +def offload_and_maybe_resave_param( + target_name: str, + param: torch.Tensor, + missing_keys: MutableSet[str], + disk_offload_folder: str, + disk_offload_index: dict, + applied_ops: WeightConverter | WeightRenaming, +) -> dict: + """Takes care of correctly offloading `param`. If it's not already present in the `disk_offload_index`, or if any + WeightConverter operations have been applied, it will resave the new parameter. 
Otherwise, it will use the original + `disk_offload_index` for this given param.""" + # We need to remove from missing keys + missing_keys.discard(target_name) + # If not already offloaded, or if we applied any special Operation except Renaming, we need to re-save + if target_name not in disk_offload_index or isinstance(applied_ops, WeightConverter): + disk_offload_index = offload_weight(param, target_name, disk_offload_folder, disk_offload_index) + return disk_offload_index + + class SkipLayer(Exception): """Control-flow sentinel: abort processing of the current layer only.""" @@ -738,12 +757,9 @@ def convert_and_load_state_dict_in_model( param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu") # Offloading support if param_device == "disk": - missing_keys.discard(target_name) - # If not already offloaded, or if we applied any special Operation except Renaming, we need to re-save - if target_name not in disk_offload_index or isinstance(mapping, WeightConverter): - disk_offload_index = offload_weight( - param, target_name, disk_offload_folder, disk_offload_index - ) + disk_offload_index = offload_and_maybe_resave_param( + target_name, param, missing_keys, disk_offload_folder, disk_offload_index, mapping + ) else: set_param_for_module( model, From a6a4c45c72fca3027750c95a68a3af4272acdd08 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 24 Nov 2025 14:03:11 +0100 Subject: [PATCH 29/32] better doc --- src/transformers/integrations/accelerate.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py index a9c4606d8839..9a28ee5c4a6c 100644 --- a/src/transformers/integrations/accelerate.py +++ b/src/transformers/integrations/accelerate.py @@ -474,6 +474,12 @@ def accelerate_disk_offload( dtype: torch.dtype | None, weight_mapping=None, ): + """ + Prepare the `disk_offload_index` that will be used for reading offloaded parameters. If reading from a safetensors + file, parameters which do not need any special WeightConverter operation during loading (i.e. they are used as-is, or only + renamed) will be mapped to where they already reside on disk. Otherwise, the parameters will be resaved inside + `disk_offload_folder` during loading. 
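To make the re-save rule of `offload_and_maybe_resave_param` concrete, here is a toy driver for the same decision, with invented `WeightRenaming`/`WeightConverter` stand-ins and an in-memory dict in place of accelerate's `offload_weight`; only the branching mirrors the real function, everything else is illustrative:

    class WeightRenaming: ...   # stand-ins: only their types matter to the rule
    class WeightConverter: ...

    def toy_offload(name, index):
        # Stand-in for accelerate's offload_weight: pretend to write the
        # tensor to disk and record a fresh entry in the returned index.
        index[name] = {"resaved": True}
        return index

    def maybe_resave(name, index, applied_ops):
        # Reuse the original on-disk bytes only when the param was already
        # offloaded AND the only transformation applied was a pure renaming.
        if name not in index or isinstance(applied_ops, WeightConverter):
            index = toy_offload(name, index)
        return index

    index = {"a.weight": {"resaved": False}}
    index = maybe_resave("a.weight", index, WeightRenaming())   # on disk, renamed only -> kept
    index = maybe_resave("b.weight", index, WeightRenaming())   # not on disk yet -> saved
    index = maybe_resave("a.weight", index, WeightConverter())  # converted -> re-saved
    assert index["a.weight"]["resaved"] and index["b.weight"]["resaved"]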
+ """ from ..core_model_loading import WeightRenaming, build_glob_alternation, repl if disk_offload_folder is not None: From f4619dd1e0bb802138ee3ee1601588c276544dfb Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 24 Nov 2025 14:24:51 +0100 Subject: [PATCH 30/32] other tests --- src/transformers/integrations/accelerate.py | 2 +- tests/utils/test_modeling_utils.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py index 9a28ee5c4a6c..b7031d7b0323 100644 --- a/src/transformers/integrations/accelerate.py +++ b/src/transformers/integrations/accelerate.py @@ -502,7 +502,7 @@ def accelerate_disk_offload( weight_map = dict.fromkeys(safe_open(checkpoint_files[0], framework="pt").keys(), checkpoint_files[0]) else: folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1]) - weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()} + weight_map = {k: os.path.join(folder, v) for k, v in sharded_metadata["weight_map"].items()} # Update the weight names according to the `weight_mapping` weight_renaming_map = { diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 9579e23d35ac..f89579ad2fde 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1246,7 +1246,6 @@ def test_save_offloaded_model(self): @require_accelerate @mark.accelerate_tests @require_torch_accelerator - @unittest.skip("TODO @cyrilvallez when saving") def test_save_offloaded_model_with_direct_params(self): from accelerate import dispatch_model From c41ea843877e5fcf35b794670d9f4bce96f1385a Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 24 Nov 2025 14:27:57 +0100 Subject: [PATCH 31/32] remove skip --- tests/utils/test_modeling_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index f89579ad2fde..02e94e71f554 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1208,7 +1208,6 @@ def test_save_model_with_device_map_cpu(self): @require_accelerate @mark.accelerate_tests @require_torch_accelerator - @unittest.skip("TODO @cyrilvallez when saving") def test_save_offloaded_model(self): device_map = { "transformer.wte": f"{torch_device}:0", From da954ef760d572cf73b46a997b8ec416dce377de Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 24 Nov 2025 14:45:27 +0100 Subject: [PATCH 32/32] unskip other tests --- src/transformers/integrations/accelerate.py | 3 ++- tests/utils/test_modeling_utils.py | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py index b7031d7b0323..f201ae3970be 100644 --- a/src/transformers/integrations/accelerate.py +++ b/src/transformers/integrations/accelerate.py @@ -518,7 +518,8 @@ def accelerate_disk_offload( "dtype": str_dtype, } for target_name, source_name in weight_renaming_map.items() - if param_device_map[target_name] == "disk" + # Need to check if it's in the mapping in case of unexpected keys that would result in KeyError (we skip them) + if target_name in param_device_map and param_device_map[target_name] == "disk" } # In this case we will resave every offloaded weight else: diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 02e94e71f554..21703f822302 100644 --- a/tests/utils/test_modeling_utils.py +++ 
b/tests/utils/test_modeling_utils.py @@ -2055,7 +2055,6 @@ def test_ignore_missing_key_works(self): for k, v in model.state_dict().items(): self.assertTrue(v.device.type == "cpu", f"{k} is not on cpu!") - @unittest.skip("TODO fix offloaded in another PR @CyrilVallez") def test_device_map_works_with_unexpected_keys(self): """Test that if a parameter is specified in `_keys_to_ignore_on_load_unexpected` and is actually present in the checkpoint, it will correctly be removed from the weights we load, especially those @@ -2079,7 +2078,6 @@ def test_device_map_works_with_unexpected_keys(self): # Unexpected keys (mtp) should be removed from the state dict, therefore this should not error out. BaseModelWithUnexpectedKeys.from_pretrained(temp.name, device_map={"linear": "cpu", "linear_2": "disk"}) - @unittest.skip("TODO fix offloaded in another PR @CyrilVallez") def test_device_map_works_with_unexpected_keys_sharded(self): """Test that if a parameter is specified in `_keys_to_ignore_on_load_unexpected` and is actually present in the checkpoint, it will correctly be removed from the weights we load, especially those
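The membership guard added in PATCH 32 above is easy to verify in isolation. Below is a sketch of the index construction it protects, with hypothetical contents for `weight_renaming_map` and `param_device_map` standing in for the real values built earlier in `accelerate_disk_offload`; the entry layout matches the diff:

    # Renamed target keys -> original checkpoint keys, including one
    # unexpected key that exists in the checkpoint but not in the model.
    weight_renaming_map = {
        "linear.weight": "linear.weight",
        "linear_2.weight": "linear_2.weight",
        "mtp.weight": "mtp.weight",  # unexpected: never enters param_device_map
    }
    param_device_map = {"linear.weight": "cpu", "linear_2.weight": "disk"}

    disk_offload_index = {
        target_name: {
            "safetensors_file": "model.safetensors",
            "weight_name": source_name,
            "dtype": "float32",
        }
        for target_name, source_name in weight_renaming_map.items()
        # Without the membership check, "mtp.weight" would raise a KeyError,
        # which is exactly the failure the unskipped tests exercise.
        if target_name in param_device_map and param_device_map[target_name] == "disk"
    }
    assert list(disk_offload_index) == ["linear_2.weight"]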