diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py
index 80cf941c37ed..fecaf6f39fd7 100644
--- a/src/transformers/core_model_loading.py
+++ b/src/transformers/core_model_loading.py
@@ -28,6 +28,7 @@
 import torch
 
+from .integrations.accelerate import offload_weight
 from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, DTensor, Replicate, TensorParallelLayer
 from .utils import is_torch_greater_or_equal, logging
 
 
@@ -397,7 +398,7 @@ def dot_natural_key(s: str):
 
 @contextmanager
 def log_to_misc(
-    layer_name: str,
+    first_target_key: str,
     misc: MutableMapping[str, str],
     extras: Any = None,
     op: Union[list[ConversionOps], ConversionOps, None] = None,
@@ -421,22 +422,22 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) ->
         if isinstance(extras, tuple) and len(extras) == 2:
             values, target_keys = extras
             descriptor = f"{op_name} " if op_name else ""
-            misc[layer_name] = (
+            misc[first_target_key] = (
                 f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values)}"
             )
         elif isinstance(extras, str):
             suffix = f" via {op_name}" if op_name else ""
-            misc[layer_name] = f"{e}\nError{suffix} when processing parameter {extras}"
+            misc[first_target_key] = f"{e}\nError{suffix} when processing parameter {extras}"
         elif extras is None and op_name:
-            misc[layer_name] = f"{op_name}: {e}"
+            misc[first_target_key] = f"{op_name}: {e}"
         else:
-            misc[layer_name] = f"{extras} |Error: {e}"
+            misc[first_target_key] = f"{extras} |Error: {e}"
         raise SkipLayer()
 
 
 def set_param_for_module(
     model: PreTrainedModel,
-    layer_name: str,
+    target_name: str,
     param_value: torch.Tensor,
     mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
     missing_keys: MutableSet[str],
@@ -445,17 +446,13 @@ def set_param_for_module(
     distributed_operation: Optional[TensorParallelLayer],
     hf_quantizer: HfQuantizer,
 ):
-    with log_to_misc(layer_name, misc, layer_name):
-        module_path, _, param_name = layer_name.rpartition(".")
+    with log_to_misc(target_name, misc, target_name):
+        module_path, _, param_name = target_name.rpartition(".")
         module_obj = model.get_submodule(module_path) if module_path else model
-        if isinstance(param_value, list):
-            param_value = param_value[0]
-        elif not isinstance(param_value, torch.nn.Parameter):
-            param_value = param_value[...]
         ref = getattr(module_obj, param_name)
 
         if ref is None:
-            unexpected_keys.add(layer_name)
+            unexpected_keys.add(target_name)
         else:
             use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
             if not isinstance(param_value, torch.nn.Parameter):
@@ -475,17 +472,35 @@ def set_param_for_module(
                 param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
 
         # Remove from missing keys (it's either mismatched, or all good)
-        missing_keys.discard(layer_name)
+        missing_keys.discard(target_name)
         if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
-            mismatch_keys.add((layer_name, param_value.shape, ref.shape))
+            mismatch_keys.add((target_name, param_value.shape, ref.shape))
             module_obj.param_name._is_hf_initialized = False  # Needs to be initialized
         else:
-            param_value._is_hf_initialized = (
-                True  # super important otherwise _init_weight re-initi if bias is missing
-            )
+            # super important, otherwise _init_weight will re-init the param
+            param_value._is_hf_initialized = True
         setattr(module_obj, param_name, param_value)
 
 
+def offload_and_maybe_resave_param(
+    target_name: str,
+    param: torch.Tensor,
+    missing_keys: MutableSet[str],
+    disk_offload_folder: str,
+    disk_offload_index: dict,
+    applied_ops: WeightConverter | WeightRenaming,
+) -> dict:
+    """Takes care of correctly offloading `param`. If it's not already present in the `disk_offload_index`, or if any
+    WeightConverter operations have been applied, it will resave the new parameter. Otherwise, it will reuse the
+    original `disk_offload_index` entry for this given param."""
+    # We need to remove it from the missing keys
+    missing_keys.discard(target_name)
+    # If not already offloaded, or if we applied any special Operation except Renaming, we need to re-save
+    if target_name not in disk_offload_index or isinstance(applied_ops, WeightConverter):
+        disk_offload_index = offload_weight(param, target_name, disk_offload_folder, disk_offload_index)
+    return disk_offload_index
+
+
 class SkipLayer(Exception):
     """Control-flow sentinel: abort processing of the current layer only."""
 
@@ -521,6 +536,8 @@ def convert_and_load_state_dict_in_model(
     device_map: dict | None = None,
     dtype_plan: dict | None = None,
     device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
+    disk_offload_index: dict | None = None,
+    disk_offload_folder: str | None = None,
 ):
     r"""
     We build a mapping from the keys obtained by renaming each of the checkpoint keys according to the weight_mapping rules.
@@ -612,6 +629,7 @@ def convert_and_load_state_dict_in_model(
     prefix = model.base_model_prefix
     tp_plan = tp_plan or {}
     device_map = device_map or {"": "cpu"}
+    # Sort by number of submodules, then by full string length, so that the most specific pattern matches first
     device_map_regex = re.compile(
         "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True))
     )
@@ -708,9 +726,11 @@ def convert_and_load_state_dict_in_model(
                     shard_index,
                 )
 
-            if future is None:  # TODO handle disk offload
+            if future is None:
                 device_match = device_map_regex.match(renamed_key)
                 param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
+                # If disk, we need to materialize on cpu first
+                param_device = "cpu" if param_device == "disk" else param_device
                 future = spawn_materialize(thread_pool, tensor, param_device, _dtype)
 
             mapping.add_tensor(renamed_key, original_key, source_pattern, future)
@@ -723,30 +743,40 @@ def convert_and_load_state_dict_in_model(
     total_entries = len(param_name_to_load)
     with logging.tqdm(total=total_entries, desc="Loading weights") as pbar:
-        for layer_name, mapping in param_name_to_load.items():
+        for first_param_name, mapping in param_name_to_load.items():
             pbar.update(1)
-            pbar.set_postfix({"Materializing param": layer_name})
+            pbar.set_postfix({"Materializing param": first_param_name})
             pbar.refresh()
             try:
                 realized_value, misc = mapping.convert(
-                    layer_name, config=model.config, quantizer=hf_quantizer, missing_keys=missing_keys
+                    first_param_name, config=model.config, quantizer=hf_quantizer, missing_keys=missing_keys
                 )
-                for k, output_value in realized_value.items():
-                    set_param_for_module(
-                        model,
-                        k,
-                        output_value,
-                        mismatch_keys,
-                        missing_keys,
-                        misc,
-                        unexpected_keys,
-                        mapping.distributed_operation,
-                        hf_quantizer,
-                    )
+                for target_name, param in realized_value.items():
+                    param = param[0] if isinstance(param, list) else param
+                    device_match = device_map_regex.match(target_name)
+                    param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
+                    # Offloading support
+                    if param_device == "disk":
+                        disk_offload_index = offload_and_maybe_resave_param(
+                            target_name, param, missing_keys, disk_offload_folder, disk_offload_index, mapping
+                        )
+                    else:
+                        set_param_for_module(
+                            model,
+                            target_name,
+                            param,
+                            mismatch_keys,
+                            missing_keys,
+                            misc,
+                            unexpected_keys,
+                            mapping.distributed_operation,
+                            hf_quantizer,
+                        )
             except SkipLayer:
                 continue
+
     thread_pool.shutdown(wait=False)
-    return missing_keys, unexpected_keys, mismatch_keys, misc
+    return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, misc  # TODO this is not done yet!
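The re-save rule used by `offload_and_maybe_resave_param` above: a disk-bound parameter reuses its existing safetensors entry only when it reached its target name unchanged (at most renamed), and anything rewritten by a `WeightConverter` is saved again. The snippet below is a standalone illustration of that rule only; the `Renaming`/`Converter` classes and the `fake_save` helper are hypothetical stand-ins, not the real transformers types.

import torch


class Renaming:  # stand-in for WeightRenaming: the key is renamed, the tensor bytes are untouched
    pass


class Converter:  # stand-in for WeightConverter: the tensor is rewritten at load time
    pass


def maybe_resave(target_name, param, offload_index, applied_ops, save_fn):
    # Re-save when the param has no pre-existing entry in the index, or when a
    # converter touched the tensor so the bytes already on disk are stale.
    if target_name not in offload_index or isinstance(applied_ops, Converter):
        offload_index[target_name] = save_fn(target_name, param)
    return offload_index


def fake_save(name, tensor):
    # Placeholder for a real "write one tensor to its own safetensors file" helper
    return {"safetensors_file": f"{name}.safetensors", "weight_name": name, "dtype": str(tensor.dtype).replace("torch.", "")}


index = {"model.layers.0.mlp.weight": {"safetensors_file": "model-00001.safetensors"}}
weight = torch.zeros(2, 2)

# Already indexed and only renamed -> the existing entry is kept as-is
print(maybe_resave("model.layers.0.mlp.weight", weight, dict(index), Renaming(), fake_save))
# Rewritten by a converter (e.g. MoE weight fusing) -> re-saved under the target name
print(maybe_resave("model.layers.0.mlp.weight", weight, dict(index), Converter(), fake_save))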
diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py
index 9696309ef221..f201ae3970be 100644
--- a/src/transformers/integrations/accelerate.py
+++ b/src/transformers/integrations/accelerate.py
@@ -19,10 +19,14 @@
 import copy
 import inspect
 import os
+import re
 from collections import OrderedDict, defaultdict
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Optional, Union
 
+from safetensors import safe_open
+from safetensors.torch import save_file
+
 from ..utils import (
     is_accelerate_available,
     is_torch_available,
@@ -445,71 +449,103 @@ def accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload
     dispatch_model(model, **device_map_kwargs)
 
 
-def get_disk_only_shard_files(device_map, weight_map):
-    """
-    Returns the list of shard files containing only weights offloaded to disk.
-    """
-    files_content = defaultdict(list)
-    for weight_name, filename in weight_map.items():
-        while len(weight_name) > 0 and weight_name not in device_map:
-            weight_name = ".".join(weight_name.split(".")[:-1])
-        files_content[filename].append(device_map[weight_name])
-
-    return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
-
-
 def expand_device_map(device_map, param_names):
     """
     Expand a device map to return the correspondence parameter name to device.
     """
+    # Sort by number of submodules, then by full string length, so that the most specific pattern matches first
+    device_map_regex = re.compile(
+        "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True))
+    )
     new_device_map = {}
-    for module, device in device_map.items():
-        new_device_map.update(
-            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
-        )
+    for param in param_names:
+        device_match = device_map_regex.match(param)
+        new_device_map[param] = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
+
     return new_device_map
 
 
 def accelerate_disk_offload(
-    disk_offload_folder,
-    checkpoint_files,
-    device_map,
-    checkpoint_keys,
-    sharded_metadata,
-    dtype,
+    disk_offload_folder: str | None,
+    checkpoint_files: list[str] | None,
+    device_map: dict,
+    expected_keys: list[str],
+    sharded_metadata: dict | None,
+    dtype: torch.dtype | None,
+    weight_mapping=None,
 ):
-    disk_only_shard_files = []
+    """
+    Prepare the `disk_offload_index` that will be used for reading offloaded parameters. If reading from a safetensors
+    file, parameters which do not need any special WeightConverter operation during loading (i.e. they are used as-is,
+    or only renamed) will be mapped to where they already reside on disk. Otherwise, the parameters will be resaved
+    inside `disk_offload_folder` during loading.
+    """
+    from ..core_model_loading import WeightRenaming, build_glob_alternation, repl
+
     if disk_offload_folder is not None:
         os.makedirs(disk_offload_folder, exist_ok=True)
     is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")
-    if disk_offload_folder is None and not is_offloaded_safetensors:
-        raise ValueError(
-            "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
-            " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
-            " offers the weights in this format."
-        )
+
+    rename = False
+    if weight_mapping is not None:
+        renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
+        if len(renamings) > 0:
+            rename = True
+            rename_alt, _, rename_by_group = build_glob_alternation(renamings)
+
+    # In this case, the offload index simply points to the existing safetensors files (except when a custom weight
+    # loading Operation is used, e.g. for MoE models, in which case we resave the weights changed at loading time)
     if is_offloaded_safetensors:
-        param_device_map = expand_device_map(device_map, checkpoint_keys)
+        param_device_map = expand_device_map(device_map, expected_keys)
         str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
         if sharded_metadata is None:
-            weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
+            weight_map = dict.fromkeys(safe_open(checkpoint_files[0], framework="pt").keys(), checkpoint_files[0])
         else:
             folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
-            weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
-        # Find potential checkpoints containing only offloaded weights
-        disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
+            weight_map = {k: os.path.join(folder, v) for k, v in sharded_metadata["weight_map"].items()}
+
+        # Update the weight names according to the `weight_mapping`
+        weight_renaming_map = {
+            rename_alt.sub(lambda m: repl(m, rename_by_group), k).replace("\\", "") if rename else k: k
+            for k in weight_map
+        }
+
+        # Prepare the index using existing safetensors files
         disk_offload_index = {
-            name: {
-                "safetensors_file": file,
-                "weight_name": name,
+            target_name: {
+                "safetensors_file": weight_map[source_name],
+                "weight_name": source_name,
                 "dtype": str_dtype,
             }
-            for name, file in weight_map.items()
-            if param_device_map[name] == "disk"
+            for target_name, source_name in weight_renaming_map.items()
+            # Need to check if it's in the mapping in case of unexpected keys that would result in KeyError (we skip them)
+            if target_name in param_device_map and param_device_map[target_name] == "disk"
         }
+    # In this case we will resave every offloaded weight
     else:
         disk_offload_index = {}
-    return disk_offload_index, disk_only_shard_files, is_offloaded_safetensors
+
+    return disk_offload_index
+
+
+def offload_weight(weight: torch.Tensor, weight_name: str, offload_folder: str | None, offload_index: dict) -> dict:
+    """Write `weight` to disk inside `offload_folder`, and update `offload_index` accordingly. Everything is
+    saved in `safetensors` format."""
+
+    if offload_folder is None:
+        raise ValueError(
+            "The current `device_map` had weights offloaded to the disk, which needed to be re-saved. This is either "
+            "because the weights are not in `safetensors` format, or because the model uses an internal weight format "
+            "different than the one saved (i.e. most MoE models). Please provide an `offload_folder` for them in "
+            "`from_pretrained`."
+        )
+    # Write the weight to disk
+    safetensor_file = os.path.join(offload_folder, f"{weight_name}.safetensors")
+    save_file({weight_name: weight}, safetensor_file)
+    # Update the offloading index
+    str_dtype = str(weight.dtype).replace("torch.", "")
+    offload_index[weight_name] = {"safetensors_file": safetensor_file, "weight_name": weight_name, "dtype": str_dtype}
+    return offload_index
 
 
 def _init_infer_auto_device_map(
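For reference, each entry of the `disk_offload_index` built above maps a target parameter name to the safetensors file it can be read back from ({"safetensors_file": ..., "weight_name": ..., "dtype": ...}). The regex-based device resolution now shared by `expand_device_map` and `convert_and_load_state_dict_in_model` can be sketched in isolation; the device map and parameter names below are made up for the example.

import re

device_map = {"model.layers.0": "cuda:0", "model.layers.1": "disk", "": "cpu"}
# Sort by number of submodules, then by string length, so the most specific prefix wins the alternation
pattern = re.compile(
    "|".join(rf"({k})" for k in sorted(device_map, key=lambda x: (x.count("."), len(x)), reverse=True))
)

for name in ["model.layers.0.self_attn.q_proj.weight", "model.layers.1.mlp.up_proj.weight", "lm_head.weight"]:
    match = pattern.match(name)
    device = device_map[match.group()] if match else device_map.get("", "cpu")
    print(f"{name} -> {device}")  # cuda:0, disk, cpu respectively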
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 3dc522efe251..6a49b070d402 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -59,6 +59,7 @@
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled, is_fsdp_enabled
 from .integrations.accelerate import (
     _get_device_map,
+    accelerate_disk_offload,
     accelerate_dispatch,
     check_and_set_device_map,
     expand_device_map,
@@ -131,9 +132,7 @@
 
 if is_accelerate_available():
     from accelerate.hooks import add_hook_to_module
-    from accelerate.utils import (
-        extract_model_from_parallel,
-    )
+    from accelerate.utils import extract_model_from_parallel
     from accelerate.utils.modeling import get_state_dict_from_offload
 
 
@@ -4065,6 +4064,20 @@ def _load_pretrained_model(
         if logger.level >= logging.WARNING:
             verify_tp_plan(expected_keys, getattr(model, "_tp_plan", None))
 
+        # This offload index is for params explicitly set to "disk" in the device_map
+        disk_offload_index = None
+        # Prepare parameter offloading if needed
+        if device_map is not None and "disk" in device_map.values():
+            disk_offload_index = accelerate_disk_offload(
+                disk_offload_folder,
+                checkpoint_files,
+                device_map,
+                expected_keys,
+                sharded_metadata,
+                dtype,
+                weight_mapping,
+            )
+
         # Warmup cuda to load the weights much faster on devices
         if device_map is not None and not is_hqq_or_quark:
             expanded_device_map = expand_device_map(device_map, expected_keys)
@@ -4103,16 +4116,20 @@ def _load_pretrained_model(
         else:
             raise ValueError("Neither a state dict nor checkpoint files were found.")
 
-        missing_keys, unexpected_keys, mismatched_keys, misc = convert_and_load_state_dict_in_model(
-            model,
-            merged_state_dict,
-            weight_mapping,
-            tp_plan,
-            hf_quantizer,
-            dtype,
-            device_map,
-            model.dtype_plan,
-            device_mesh,
+        missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, misc = (
+            convert_and_load_state_dict_in_model(
+                model,
+                merged_state_dict,
+                weight_mapping,
+                tp_plan,
+                hf_quantizer,
+                dtype,
+                device_map,
+                model.dtype_plan,
+                device_mesh,
+                disk_offload_index,
+                disk_offload_folder,
+            )
         )
 
         # finally close all opened file pointers
@@ -4176,7 +4193,7 @@ def _load_pretrained_model(
             misc=misc,
             ignore_mismatched_sizes=ignore_mismatched_sizes,
         )
-        disk_offload_index = None
+
        return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs
 
     def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
diff --git a/tests/models/deepseek_vl_hybrid/test_modeling_deepseek_vl_hybrid.py b/tests/models/deepseek_vl_hybrid/test_modeling_deepseek_vl_hybrid.py
index 661b44ecf390..4095e19c00d0 100644
--- a/tests/models/deepseek_vl_hybrid/test_modeling_deepseek_vl_hybrid.py
+++ b/tests/models/deepseek_vl_hybrid/test_modeling_deepseek_vl_hybrid.py
@@ -167,6 +167,7 @@ class DeepseekVLHybridModelTest(ModelTesterMixin, GenerationTesterMixin, unittes
         else {}
     )
     _is_composite = True
+    model_split_percents = [0.5, 0.85, 0.9]  # it tries to offload everything with the default value
 
     def setUp(self):
         self.model_tester = DeepseekVLHybridModelTester(self)
diff --git a/tests/models/glm4_moe/test_modeling_glm4_moe.py b/tests/models/glm4_moe/test_modeling_glm4_moe.py
index d4cebbd02983..526e81e29766 100644
--- a/tests/models/glm4_moe/test_modeling_glm4_moe.py
+++ b/tests/models/glm4_moe/test_modeling_glm4_moe.py
@@ -61,6 +61,7 @@ class Glm4MoeModelTest(CausalLMModelTest, unittest.TestCase):
     model_tester_class = Glm4MoeModelTester
     # used in `test_torch_compile_for_training`. Skip as "Dynamic control flow in MoE"
     _torch_compile_train_cls = None
+    model_split_percents = [0.5, 0.85, 0.9]  # it tries to offload everything with the default value
 
 
 @require_torch_accelerator
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index a7a40d887787..ccc403c6a2c9 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -105,7 +105,6 @@
     CONFIG_NAME,
     GENERATION_CONFIG_NAME,
     SAFE_WEIGHTS_NAME,
-    is_accelerate_available,
     is_torch_bf16_available_on_device,
     is_torch_fp16_available_on_device,
 )
@@ -113,10 +112,6 @@
 from .generation.test_utils import GenerationTesterMixin
 
 
-if is_accelerate_available():
-    from accelerate.utils import compute_module_sizes
-
-
 if is_torch_available():
     import torch
     from safetensors import safe_open
@@ -125,6 +120,7 @@
     from torch import nn
 
     from transformers import MODEL_MAPPING
+    from transformers.integrations.accelerate import compute_module_sizes
     from transformers.integrations.tensor_parallel import _get_parameter_tp_plan
     from transformers.modeling_utils import load_state_dict
     from transformers.pytorch_utils import id_tensor_storage
@@ -2357,7 +2353,6 @@ def check_device_map_is_respected(self, model, device_map):
     @require_accelerate
     @mark.accelerate_tests
     @require_torch_accelerator
-    @unittest.skip("# TODO @CyrilVallez fix this in the other PR")
     def test_disk_offload_bin(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
@@ -2371,7 +2366,7 @@ def test_disk_offload_bin(self):
             torch.manual_seed(0)
             base_output = model(**inputs_dict_class)
 
-            model_size = compute_module_sizes(model)[""]
+            model_size = compute_module_sizes(model)[0][""]
             with tempfile.TemporaryDirectory() as tmp_dir:
                 model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
 
@@ -2379,12 +2374,14 @@ def test_disk_offload_bin(self):
                     max_size = int(self.model_split_percents[0] * model_size)
                     max_memory = {0: max_size, "cpu": max_size}
                     # This errors out cause it's missing an offload folder
-                    new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+                    new_model = model_class.from_pretrained(
+                        tmp_dir, device_map="auto", max_memory=max_memory, use_safetensors=False
+                    )
 
                 max_size = int(self.model_split_percents[1] * model_size)
                 max_memory = {0: max_size, "cpu": max_size}
                 new_model = model_class.from_pretrained(
-                    tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
+                    tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir, use_safetensors=False
                 )
 
                 self.check_device_map_is_respected(new_model, new_model.hf_device_map)
@@ -2402,7 +2399,6 @@ def test_disk_offload_bin(self):
     @require_accelerate
     @mark.accelerate_tests
     @require_torch_accelerator
-    @unittest.skip("# TODO @CyrilVallez fix this in the other PR")
     def test_disk_offload_safetensors(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
@@ -2416,7 +2412,7 @@ def test_disk_offload_safetensors(self):
             torch.manual_seed(0)
             base_output = model(**inputs_dict_class)
 
-            model_size = compute_module_sizes(model)[""]
+            model_size = compute_module_sizes(model)[0][""]
             with tempfile.TemporaryDirectory() as tmp_dir:
                 model.cpu().save_pretrained(tmp_dir)
@@ -2441,7 +2437,6 @@
     @require_accelerate
     @mark.accelerate_tests
     @require_torch_accelerator
-    @unittest.skip("# TODO @CyrilVallez fix this in the other PR")
     def test_cpu_offload(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
@@ -2456,7 +2451,7 @@ def test_cpu_offload(self):
             torch.manual_seed(0)
             base_output = model(**inputs_dict_class)
 
-            model_size = compute_module_sizes(model)[""]
+            model_size = compute_module_sizes(model)[0][""]
             # We test several splits of sizes to make sure it works.
             max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
             with tempfile.TemporaryDirectory() as tmp_dir:
@@ -2499,7 +2494,7 @@ def test_model_parallelism(self):
             torch.manual_seed(0)
             base_output = model(**inputs_dict_class)
 
-            model_size = compute_module_sizes(model)[""]
+            model_size = compute_module_sizes(model)[0][""]
             # We test several splits of sizes to make sure it works.
             max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
             with tempfile.TemporaryDirectory() as tmp_dir:
diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py
index 023b9fd0596d..8973e9900f0f 100644
--- a/tests/utils/test_core_model_loading.py
+++ b/tests/utils/test_core_model_loading.py
@@ -250,7 +250,7 @@ def test_moe_and_qkv_conversion(self):
             ),
             WeightRenaming("mlp.w2.weight", "mlp.down_proj.weight"),
         ]
-        missing, unexpected, mismatch, misc = convert_and_load_state_dict_in_model(
+        missing, unexpected, mismatch, _, misc = convert_and_load_state_dict_in_model(
             model, state_dict, weight_mapping, tp_plan=None, hf_quantizer=None
         )
 
@@ -400,7 +400,7 @@ def __init__(self):
             )
         ]
 
-        missing, unexpected, mismatch, misc = convert_and_load_state_dict_in_model(
+        missing, unexpected, mismatch, _, misc = convert_and_load_state_dict_in_model(
             model, state_dict, weight_mapping, tp_plan=None, hf_quantizer=quantizer
         )
 
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 9579e23d35ac..21703f822302 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -1208,7 +1208,6 @@ def test_save_model_with_device_map_cpu(self):
     @require_accelerate
     @mark.accelerate_tests
     @require_torch_accelerator
-    @unittest.skip("TODO @cyrilvallez when saving")
     def test_save_offloaded_model(self):
         device_map = {
             "transformer.wte": f"{torch_device}:0",
@@ -1246,7 +1245,6 @@ def test_save_offloaded_model(self):
     @require_accelerate
     @mark.accelerate_tests
     @require_torch_accelerator
-    @unittest.skip("TODO @cyrilvallez when saving")
     def test_save_offloaded_model_with_direct_params(self):
         from accelerate import dispatch_model
 
@@ -2057,7 +2055,6 @@ def test_ignore_missing_key_works(self):
         for k, v in model.state_dict().items():
             self.assertTrue(v.device.type == "cpu", f"{k} is not on cpu!")
 
-    @unittest.skip("TODO fix offloaded in another PR @CyrilVallez")
     def test_device_map_works_with_unexpected_keys(self):
         """Test that if a parameter is specified in `_keys_to_ignore_on_load_unexpected` and is actually
         present in the checkpoint, it will correctly be removed from the weights we load, especially those
@@ -2081,7 +2078,6 @@ def test_device_map_works_with_unexpected_keys(self):
             # Unexpected keys (mtp) should be removed from the state dict, therefore this should not error out.
             BaseModelWithUnexpectedKeys.from_pretrained(temp.name, device_map={"linear": "cpu", "linear_2": "disk"})
 
-    @unittest.skip("TODO fix offloaded in another PR @CyrilVallez")
     def test_device_map_works_with_unexpected_keys_sharded(self):
         """Test that if a parameter is specified in `_keys_to_ignore_on_load_unexpected` and is actually
         present in the checkpoint, it will correctly be removed from the weights we load, especially those