Commit a84404d

fix
1 parent 0bbec14

File tree

3 files changed, +44 -36 lines changed

src/transformers/core_model_loading.py

Lines changed: 41 additions & 27 deletions
@@ -30,6 +30,7 @@

 import torch

+from .integrations.accelerate import offload_weight
 from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, DTensor, Replicate, TensorParallelLayer
 from .utils import is_torch_greater_or_equal, logging


@@ -344,7 +345,7 @@ def dot_natural_key(s: str):

 @contextmanager
 def log_to_misc(
-    layer_name: str,
+    full_param_name: str,
     misc: MutableMapping[str, str],
     extras: Any = None,
     op: Union[list[ConversionOps], ConversionOps, None] = None,

@@ -368,30 +369,30 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) ->
         if isinstance(extras, tuple) and len(extras) == 2:
             values, target_keys = extras
             descriptor = f"{op_name} " if op_name else ""
-            misc[layer_name] = (
+            misc[full_param_name] = (
                 f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values[0])}"
             )
         elif isinstance(extras, str):
             suffix = f" via {op_name}" if op_name else ""
-            misc[layer_name] = f"{e}\nError{suffix} when processing parameter {extras}"
+            misc[full_param_name] = f"{e}\nError{suffix} when processing parameter {extras}"
         elif extras is None and op_name:
-            misc[layer_name] = f"{op_name}: {e}"
+            misc[full_param_name] = f"{op_name}: {e}"
         else:
-            misc[layer_name] = f"{extras} |Error: {e}"
+            misc[full_param_name] = f"{extras} |Error: {e}"
         raise SkipLayer()


 def set_param_for_module(
     model: PreTrainedModel,
-    layer_name: str,
+    full_param_name: str,
     param_value: torch.Tensor,
     mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
     missing_keys: MutableSet[str],
     misc: MutableMapping[str, Any],
     distributed_operation: Optional[TensorParallelLayer],
 ):
-    with log_to_misc(layer_name, misc, layer_name):
-        module_path, _, param_name = layer_name.rpartition(".")
+    with log_to_misc(full_param_name, misc, full_param_name):
+        module_path, _, param_name = full_param_name.rpartition(".")
         module_obj = model.get_submodule(module_path) if module_path else model
         param_value = param_value[0] if isinstance(param_value, list) else param_value[...]
         ref = getattr(module_obj, param_name)

@@ -414,9 +415,9 @@ def set_param_for_module(
             param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())

         # Remove from missing keys (it's either mismatched, or all good)
-        missing_keys.discard(layer_name)
+        missing_keys.discard(full_param_name)
         if ref is not None and ref.shape != param_value.shape:
-            mismatch_keys.add((layer_name, param_value.shape, ref.shape))
+            mismatch_keys.add((full_param_name, param_value.shape, ref.shape))
             module_obj.param_name._is_hf_initialized = False  # Needs to be initialized
         else:
             param_value._is_hf_initialized = True  # super important otherwise _init_weight re-initi if bias is missing

@@ -439,6 +440,8 @@ def convert_and_load_state_dict_in_model(
     device_map: dict | None = None,
     dtype_plan: dict | None = None,
     device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
+    disk_offload_index: dict | None = None,
+    disk_offload_folder: str | None = None,
 ):
     """
     Convert a state dict according to a weight mapping (one WeightConverter per glob pattern),

@@ -536,7 +539,7 @@ def convert_and_load_state_dict_in_model(
                 shard_index,
             )

-            if future is None:  # If not TP, async materialize the tensors. TODO handle disk offload?
+            if future is None:
                 device_match = device_map_regex.match(first_target_key)
                 param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
                 future = spawn_materialize(thread_pool, tensor, param_device, _dtype)

@@ -551,29 +554,29 @@ def convert_and_load_state_dict_in_model(
         group = by_conversion_pattern.pop(key)
         converter = group.weight_converter
         operations = converter.operations if isinstance(converter.operations, list) else [converter.operations]
-        for layer_name, tensors_for_this_layer in group.collected_tensors.items():
+        for full_param_name, tensors_for_this_layer in group.collected_tensors.items():
             pbar.update(1)
-            pbar.set_postfix({"Materializing param": layer_name})
+            pbar.set_postfix({"Materializing param": full_param_name})
             pbar.refresh()
-            concrete_target_keys = layer_name.split("|")
+            concrete_target_keys = full_param_name.split("|")
             try:
                 if bool(set(concrete_target_keys) - unexpected_keys):
-                    with log_to_misc(layer_name, misc):
+                    with log_to_misc(full_param_name, misc):
                         values = [[k.result() for k in inner] for inner in tensors_for_this_layer.values()]

                     for op in operations:
-                        with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
+                        with log_to_misc(full_param_name, misc, (values, concrete_target_keys), operations):
                             values = op.convert(values, model.config)

                     values = [values] if not isinstance(values, list) else values
-                    with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
+                    with log_to_misc(full_param_name, misc, (values, concrete_target_keys), operations):
                         realized_value = {
                             k: t for k, t in zip(concrete_target_keys, values) if k not in unexpected_keys
                         }

                     for k in list(realized_value.keys()).copy():
                         if op := converter.quantization_operation:
-                            with log_to_misc(layer_name, misc, op=op):
+                            with log_to_misc(full_param_name, misc, op=op):
                                 realized_value.update(
                                     op.convert(
                                         {k: realized_value.pop(k)}, quant_config=quantizer.quantization_config

@@ -583,15 +586,26 @@ def convert_and_load_state_dict_in_model(
                     for k, output_value in realized_value.items():
                         for src in converter.source_keys:  # what should happen to k when we meet k at saving
                             inverse_converters[k] = {src: converter}
-                        set_param_for_module(
-                            model,
-                            k,
-                            output_value,
-                            mismatch_keys,
-                            missing_keys,
-                            misc,
-                            converter.distributed_operation,
-                        )
+
+                        param_device = device_map[re.search(device_map_regex, k).group()]
+                        # Offloading support
+                        if param_device == "disk":
+                            missing_keys.discard(k)
+                            # If not already offloaded, or if we applied any special Operation, we need to re-save
+                            if k not in disk_offload_index or len(operations) > 0:
+                                disk_offload_index = offload_weight(
+                                    output_value, k, disk_offload_folder, disk_offload_index
+                                )
+                        else:
+                            set_param_for_module(
+                                model,
+                                k,
+                                output_value,
+                                mismatch_keys,
+                                missing_keys,
+                                misc,
+                                converter.distributed_operation,
+                            )

             except SkipLayer:
                 continue
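
Note on the new branch above: each realized target key k is resolved to a device by matching it against device_map_regex and indexing device_map with the matched prefix; keys that resolve to "disk" are routed through offload_weight instead of set_param_for_module. Below is a minimal, self-contained sketch of that lookup with a toy device map; the regex construction and the example names are illustrative assumptions, not the library's exact implementation.

import re

# Toy device map (illustrative): module-name prefixes -> devices.
device_map = {
    "model.embed_tokens": 0,
    "model.layers.0": 0,
    "model.layers.1": "disk",
    "lm_head": "disk",
}

# Longest-prefix-first alternation so the most specific prefix wins (assumption).
device_map_regex = re.compile(
    "|".join(re.escape(prefix) for prefix in sorted(device_map, key=len, reverse=True))
)

for k in ["model.layers.1.mlp.down_proj.weight", "lm_head.weight", "model.embed_tokens.weight"]:
    param_device = device_map[re.search(device_map_regex, k).group()]
    # In the diff above, "disk" parameters go through offload_weight(output_value, k,
    # disk_offload_folder, disk_offload_index); everything else goes through set_param_for_module.
    print(f"{k} -> {param_device}")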

src/transformers/integrations/accelerate.py

Lines changed: 1 addition & 1 deletion
@@ -508,7 +508,7 @@ def accelerate_disk_offload(
     os.makedirs(disk_offload_folder, exist_ok=True)
     is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")

-    # In this cause, the offload index is simply the existing safetensors (except if using custom weight loading
+    # In this case, the offload index is simply the existing safetensors (except if using custom weight loading
     # Operation, e.g. the MoE models, where we need to resave the weights that were changed at loading time)
     if is_offloaded_safetensors:
         param_device_map = expand_device_map(device_map, expected_keys)

src/transformers/modeling_utils.py

Lines changed: 2 additions & 8 deletions
@@ -59,7 +59,6 @@
     check_and_set_device_map,
     expand_device_map,
     init_empty_weights,
-    offload_weight,
 )
 from .integrations.deepspeed import _load_state_dict_into_zero3_model
 from .integrations.eager_paged import eager_paged_attention_forward

@@ -4250,6 +4249,8 @@ def _load_pretrained_model(
             device_map,
             model.dtype_plan,
             device_mesh,
+            disk_offload_index,
+            disk_offload_folder,
         )

         # finally close all opened file pointers

@@ -4301,13 +4302,6 @@ def _load_pretrained_model(
             device_mesh,
         )

-        # If the model parameters were changed during loading (i.e. any custom Ops on the weights), we need to resave them
-        # for offloading
-        if device_map is not None and "disk" in device_map.values():
-            for name, param in model.state_dict().items():
-                if name not in disk_offload_index:
-                    disk_offload_index = offload_weight(param, name, disk_offload_folder, disk_offload_index)
-
         log_state_dict_report(
             model=model,
             pretrained_model_name_or_path=pretrained_model_name_or_path,
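
With the index and folder now passed directly into convert_and_load_state_dict_in_model, the post-hoc re-save loop removed above becomes redundant: offloading happens while each parameter is materialized. Both the removed loop and the new code path call offload_weight with the same contract, (weight, weight_name, offload_folder, index) -> updated index, which is why the return value is reassigned each time. The toy stand-in below only mimics that contract for illustration; it is not the accelerate integration's actual implementation, and the on-disk format and index fields are assumptions.

import os
import torch

def toy_offload_weight(weight: torch.Tensor, weight_name: str, offload_folder: str, index: dict) -> dict:
    # Persist one tensor under the offload folder and record it in the index (toy format).
    os.makedirs(offload_folder, exist_ok=True)
    path = os.path.join(offload_folder, f"{weight_name}.dat")
    weight.detach().cpu().contiguous().numpy().tofile(path)
    index[weight_name] = {"dtype": str(weight.dtype).replace("torch.", ""), "shape": list(weight.shape)}
    return index

# Usage mirroring the new code path: the caller keeps the returned index.
disk_offload_index: dict = {}
disk_offload_index = toy_offload_weight(
    torch.randn(4, 4), "model.layers.1.mlp.down_proj.weight", "offload_dir", disk_offload_index
)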
