
Commit 7f75ba8

fix

1 parent 5a5dc9e commit 7f75ba8
3 files changed (+45 -36 lines)

src/transformers/core_model_loading.py

Lines changed: 42 additions & 27 deletions
@@ -30,6 +30,7 @@
 
 import torch
 
+from .integrations.accelerate import offload_weight
 from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, DTensor, Replicate, TensorParallelLayer
 from .utils import is_torch_greater_or_equal, logging
 
@@ -344,7 +345,7 @@ def dot_natural_key(s: str):
 
 @contextmanager
 def log_to_misc(
-    layer_name: str,
+    full_param_name: str,
     misc: MutableMapping[str, str],
     extras: Any = None,
     op: Union[list[ConversionOps], ConversionOps, None] = None,
@@ -368,30 +369,30 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) ->
         if isinstance(extras, tuple) and len(extras) == 2:
             values, target_keys = extras
             descriptor = f"{op_name} " if op_name else ""
-            misc[layer_name] = (
+            misc[full_param_name] = (
                 f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values[0])}"
             )
         elif isinstance(extras, str):
             suffix = f" via {op_name}" if op_name else ""
-            misc[layer_name] = f"{e}\nError{suffix} when processing parameter {extras}"
+            misc[full_param_name] = f"{e}\nError{suffix} when processing parameter {extras}"
         elif extras is None and op_name:
-            misc[layer_name] = f"{op_name}: {e}"
+            misc[full_param_name] = f"{op_name}: {e}"
         else:
-            misc[layer_name] = f"{extras} |Error: {e}"
+            misc[full_param_name] = f"{extras} |Error: {e}"
         raise SkipLayer()
 
 
 def set_param_for_module(
     model: PreTrainedModel,
-    layer_name: str,
+    full_param_name: str,
     param_value: torch.Tensor,
     mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
     missing_keys: MutableSet[str],
     misc: MutableMapping[str, Any],
     distributed_operation: Optional[TensorParallelLayer],
 ):
-    with log_to_misc(layer_name, misc, layer_name):
-        module_path, _, param_name = layer_name.rpartition(".")
+    with log_to_misc(full_param_name, misc, full_param_name):
+        module_path, _, param_name = full_param_name.rpartition(".")
         module_obj = model.get_submodule(module_path) if module_path else model
         param_value = param_value[0] if isinstance(param_value, list) else param_value[...]
         ref = getattr(module_obj, param_name)
@@ -414,9 +415,9 @@ def set_param_for_module(
             param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
 
         # Remove from missing keys (it's either mismatched, or all good)
-        missing_keys.discard(layer_name)
+        missing_keys.discard(full_param_name)
         if ref is not None and ref.shape != param_value.shape:
-            mismatch_keys.add((layer_name, param_value.shape, ref.shape))
+            mismatch_keys.add((full_param_name, param_value.shape, ref.shape))
             module_obj.param_name._is_hf_initialized = False  # Needs to be initialized
         else:
             param_value._is_hf_initialized = True  # super important otherwise _init_weight re-initi if bias is missing
@@ -439,6 +440,8 @@ def convert_and_load_state_dict_in_model(
     device_map: dict | None = None,
     dtype_plan: dict | None = None,
     device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
+    disk_offload_index: dict | None = None,
+    disk_offload_folder: str | None = None,
 ):
     """
     Convert a state dict according to a weight mapping (one WeightConverter per glob pattern),
@@ -448,6 +451,7 @@ def convert_and_load_state_dict_in_model(
     prefix = model.base_model_prefix
     tp_plan = tp_plan or {}  # {glob_pattern: plan_obj_or_key}
     device_map = device_map or {}  # {exact_target_key: device}
+    device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
     dtype_plan = dtype_plan or {}  # {glob_pattern: dtype}
     weight_mapping = weight_mapping or {}  # {glob_pattern: WeightConverter}
     meta_model_state_dict = model.state_dict()
@@ -533,7 +537,7 @@ def convert_and_load_state_dict_in_model(
                     shard_index,
                 )
 
-            if future is None:  # If not TP, async materialize the tensors. TODO handle disk offload?
+            if future is None:
                 future = spawn_materialize(thread_pool, tensor, _dtype)
             entry.collected_tensors[target_key].setdefault(converter_key, []).append(future)
 
@@ -546,29 +550,29 @@ def convert_and_load_state_dict_in_model(
         group = by_conversion_pattern.pop(key)
         converter = group.weight_converter
        operations = converter.operations if isinstance(converter.operations, list) else [converter.operations]
-        for layer_name, tensors_for_this_layer in group.collected_tensors.items():
+        for full_param_name, tensors_for_this_layer in group.collected_tensors.items():
             pbar.update(1)
-            pbar.set_postfix({"Materializing param": layer_name})
+            pbar.set_postfix({"Materializing param": full_param_name})
             pbar.refresh()
-            concrete_target_keys = layer_name.split("|")
+            concrete_target_keys = full_param_name.split("|")
             try:
                 if bool(set(concrete_target_keys) - unexpected_keys):
-                    with log_to_misc(layer_name, misc):
+                    with log_to_misc(full_param_name, misc):
                         values = [[k.result() for k in inner] for inner in tensors_for_this_layer.values()]
 
                     for op in operations:
-                        with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
+                        with log_to_misc(full_param_name, misc, (values, concrete_target_keys), operations):
                             values = op.convert(values, model.config)
 
                     values = [values] if not isinstance(values, list) else values
-                    with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
+                    with log_to_misc(full_param_name, misc, (values, concrete_target_keys), operations):
                         realized_value = {
                             k: t for k, t in zip(concrete_target_keys, values) if k not in unexpected_keys
                         }
 
                         for k in list(realized_value.keys()).copy():
                             if op := converter.quantization_operation:
-                                with log_to_misc(layer_name, misc, op=op):
+                                with log_to_misc(full_param_name, misc, op=op):
                                     realized_value.update(
                                         op.convert(
                                             {k: realized_value.pop(k)}, quant_config=quantizer.quantization_config
@@ -578,15 +582,26 @@
                     for k, output_value in realized_value.items():
                         for src in converter.source_keys:  # what should happen to k when we meet k at saving
                             inverse_converters[k] = {src: converter}
-                        set_param_for_module(
-                            model,
-                            k,
-                            output_value,
-                            mismatch_keys,
-                            missing_keys,
-                            misc,
-                            converter.distributed_operation,
-                        )
+
+                        param_device = device_map[re.search(device_map_regex, k).group()]
+                        # Offloading support
+                        if param_device == "disk":
+                            missing_keys.discard(k)
+                            # If not already offloaded, or if we applied any special Operation, we need to re-save
+                            if k not in disk_offload_index or len(operations) > 0:
+                                disk_offload_index = offload_weight(
+                                    output_value, k, disk_offload_folder, disk_offload_index
+                                )
+                        else:
+                            set_param_for_module(
+                                model,
+                                k,
+                                output_value,
+                                mismatch_keys,
+                                missing_keys,
+                                misc,
+                                converter.distributed_operation,
+                            )
 
             except SkipLayer:
                 continue
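
The new branch resolves each realized parameter's device by matching the full parameter name against the alternation of escaped device_map keys, then either offloads it to disk or hands it to set_param_for_module as before. A minimal, self-contained sketch of that lookup, using a hypothetical device_map that is not part of this diff:

import re

# Hypothetical device_map, for illustration only.
device_map = {"model.layers.0": 0, "model.layers.1": "disk", "lm_head": "cpu"}

# Reverse lexicographic sort places a key such as "model.layers.1" ahead of any shorter
# prefix of it, so the alternation matches the most specific entry first.
device_map_regex = "|".join(re.escape(k) for k in sorted(device_map.keys(), reverse=True))

def resolve_device(full_param_name: str):
    # Mirrors device_map[re.search(device_map_regex, k).group()] from the hunk above.
    return device_map[re.search(device_map_regex, full_param_name).group()]

print(resolve_device("model.layers.1.mlp.up_proj.weight"))       # "disk" -> offload_weight(...)
print(resolve_device("model.layers.0.self_attn.q_proj.weight"))  # 0      -> set_param_for_module(...)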

src/transformers/integrations/accelerate.py

Lines changed: 1 addition & 1 deletion
@@ -508,7 +508,7 @@ def accelerate_disk_offload(
     os.makedirs(disk_offload_folder, exist_ok=True)
     is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")
 
-    # In this cause, the offload index is simply the existing safetensors (except if using custom weight loading
+    # In this case, the offload index is simply the existing safetensors (except if using custom weight loading
     # Operation, e.g. the MoE models, where we need to resave the weights that were changed at loading time)
     if is_offloaded_safetensors:
         param_device_map = expand_device_map(device_map, expected_keys)
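
As the corrected comment notes, weights that already sit in safetensors shards and were not touched by a loading-time Operation keep their existing offload-index entries; only rewritten weights need to be re-saved through offload_weight. A hedged usage sketch of that helper, mirroring the call pattern added in core_model_loading.py above (the folder, index and parameter name are made up; the on-disk format is left to the helper):

import os

import torch
from transformers.integrations.accelerate import offload_weight

disk_offload_folder = "offload"   # hypothetical folder; accelerate_disk_offload creates it the same way
os.makedirs(disk_offload_folder, exist_ok=True)

disk_offload_index = {}           # entries for weights already offloaded would pre-populate this

# Re-save a realized "disk" weight and keep the returned, updated index. In the loading loop
# this only happens when the key is missing from the index or an Operation changed the tensor
# (`k not in disk_offload_index or len(operations) > 0`).
weight = torch.randn(8, 8)
disk_offload_index = offload_weight(
    weight, "model.layers.1.mlp.up_proj.weight", disk_offload_folder, disk_offload_index
)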

src/transformers/modeling_utils.py

Lines changed: 2 additions & 8 deletions
@@ -59,7 +59,6 @@
     check_and_set_device_map,
     expand_device_map,
     init_empty_weights,
-    offload_weight,
 )
 from .integrations.deepspeed import _load_state_dict_into_zero3_model
 from .integrations.eager_paged import eager_paged_attention_forward
@@ -4268,6 +4267,8 @@ def _load_pretrained_model(
             device_map,
             model.dtype_plan,
             device_mesh,
+            disk_offload_index,
+            disk_offload_folder,
         )
 
         # finally close all opened file pointers
@@ -4319,13 +4320,6 @@ def _load_pretrained_model(
             device_mesh,
         )
 
-        # If the model parameters were changed during loading (i.e. any custom Ops on the weights), we need to resave them
-        # for offloading
-        if device_map is not None and "disk" in device_map.values():
-            for name, param in model.state_dict().items():
-                if name not in disk_offload_index:
-                    disk_offload_index = offload_weight(param, name, disk_offload_folder, disk_offload_index)
-
         log_state_dict_report(
             model=model,
             pretrained_model_name_or_path=pretrained_model_name_or_path,
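
Taken together, writing disk-offloaded weights is now handled inside the conversion loop rather than in a separate pass after loading. For context, this path runs when a user loads with a device_map that places some modules on "disk"; a hedged, end-to-end usage sketch (placeholder model id, standard from_pretrained arguments):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-large-model",   # placeholder checkpoint id
    device_map="auto",             # may map some modules to "disk" when GPU/CPU memory is short
    offload_folder="offload",      # ends up as the disk offload folder used by the loading code above
)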
