
 import torch

+from .integrations.accelerate import offload_weight
 from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, DTensor, Replicate, TensorParallelLayer
 from .utils import is_torch_greater_or_equal, logging

@@ -397,7 +398,7 @@ def dot_natural_key(s: str):

 @contextmanager
 def log_to_misc(
-    layer_name: str,
+    first_target_key: str,
     misc: MutableMapping[str, str],
     extras: Any = None,
     op: Union[list[ConversionOps], ConversionOps, None] = None,
@@ -421,22 +422,22 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) ->
         if isinstance(extras, tuple) and len(extras) == 2:
             values, target_keys = extras
             descriptor = f"{op_name} " if op_name else ""
-            misc[layer_name] = (
+            misc[first_target_key] = (
                 f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values)}"
             )
         elif isinstance(extras, str):
             suffix = f" via {op_name}" if op_name else ""
-            misc[layer_name] = f"{e}\nError{suffix} when processing parameter {extras}"
+            misc[first_target_key] = f"{e}\nError{suffix} when processing parameter {extras}"
         elif extras is None and op_name:
-            misc[layer_name] = f"{op_name}: {e}"
+            misc[first_target_key] = f"{op_name}: {e}"
         else:
-            misc[layer_name] = f"{extras}|Error: {e}"
+            misc[first_target_key] = f"{extras}|Error: {e}"
         raise SkipLayer()


 def set_param_for_module(
     model: PreTrainedModel,
-    layer_name: str,
+    target_name: str,
     param_value: torch.Tensor,
     mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
     missing_keys: MutableSet[str],
@@ -445,17 +446,13 @@ def set_param_for_module(
     distributed_operation: Optional[TensorParallelLayer],
     hf_quantizer: HfQuantizer,
 ):
-    with log_to_misc(layer_name, misc, layer_name):
-        module_path, _, param_name = layer_name.rpartition(".")
+    with log_to_misc(target_name, misc, target_name):
+        module_path, _, param_name = target_name.rpartition(".")
         module_obj = model.get_submodule(module_path) if module_path else model
-        if isinstance(param_value, list):
-            param_value = param_value[0]
-        elif not isinstance(param_value, torch.nn.Parameter):
-            param_value = param_value[...]

         ref = getattr(module_obj, param_name)
         if ref is None:
-            unexpected_keys.add(layer_name)
+            unexpected_keys.add(target_name)
         else:
             use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
             if not isinstance(param_value, torch.nn.Parameter):
@@ -475,17 +472,35 @@ def set_param_for_module(
                     param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())

         # Remove from missing keys (it's either mismatched, or all good)
-        missing_keys.discard(layer_name)
+        missing_keys.discard(target_name)
         if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
-            mismatch_keys.add((layer_name, param_value.shape, ref.shape))
+            mismatch_keys.add((target_name, param_value.shape, ref.shape))
             module_obj.param_name._is_hf_initialized = False  # Needs to be initialized
         else:
-            param_value._is_hf_initialized = (
-                True  # super important otherwise _init_weight re-initi if bias is missing
-            )
+            # super important otherwise _init_weight will re-init the param
+            param_value._is_hf_initialized = True
         setattr(module_obj, param_name, param_value)


+def offload_and_maybe_resave_param(
+    target_name: str,
+    param: torch.Tensor,
+    missing_keys: MutableSet[str],
+    disk_offload_folder: str,
+    disk_offload_index: dict,
+    applied_ops: WeightConverter | WeightRenaming,
+) -> dict:
+    """Takes care of correctly offloading `param`. If it is not already present in `disk_offload_index`, or if any
+    WeightConverter operation was applied, the new parameter is re-saved to disk. Otherwise, the existing
+    `disk_offload_index` entry for this param is reused."""
+    # We need to remove it from missing keys
+    missing_keys.discard(target_name)
+    # If not already offloaded, or if we applied any special operation other than a renaming, we need to re-save
+    if target_name not in disk_offload_index or isinstance(applied_ops, WeightConverter):
+        disk_offload_index = offload_weight(param, target_name, disk_offload_folder, disk_offload_index)
+    return disk_offload_index
+
+
 class SkipLayer(Exception):
     """Control-flow sentinel: abort processing of the current layer only."""

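For orientation, a minimal sketch (not part of the commit) of what the new helper relies on. It assumes `.integrations.accelerate.offload_weight` follows the accelerate convention: the tensor is written to `<folder>/<name>.dat` and its dtype/shape are recorded in the returned index. The folder path and parameter name below are illustrative.

# Sketch only: accelerate-style offload_weight, assumed to match the helper imported above.
import os

import torch
from accelerate.utils import offload_weight

disk_offload_folder = "/tmp/offload"  # illustrative path
os.makedirs(disk_offload_folder, exist_ok=True)

disk_offload_index = {}
param = torch.randn(16, 16)
# First time this key is seen (or after a WeightConverter op): the tensor is (re)saved to disk
# and the index gains an entry describing it.
disk_offload_index = offload_weight(param, "lm_head.weight", disk_offload_folder, disk_offload_index)
print(disk_offload_index["lm_head.weight"])  # roughly {"dtype": "float32", "shape": [16, 16]}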
@@ -521,6 +536,8 @@ def convert_and_load_state_dict_in_model(
     device_map: dict | None = None,
     dtype_plan: dict | None = None,
     device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
+    disk_offload_index: dict | None = None,
+    disk_offload_folder: str | None = None,
 ):
     r"""
     We build a mapping from the keys obtained by renaming each of the checkpoint keys according to the weight_mapping rules.
@@ -612,6 +629,7 @@ def convert_and_load_state_dict_in_model(
     prefix = model.base_model_prefix
     tp_plan = tp_plan or {}
     device_map = device_map or {"": "cpu"}
+    # Here, we first sort by number of submodules, then by length of the full string, to make sure we match correctly
     device_map_regex = re.compile(
         "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True))
     )
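A quick illustration (not part of the commit) of why the sort key matters: with deeper and longer device_map entries placed first in the regex alternation, `re.match` resolves a parameter name to its most specific entry rather than a shorter prefix. The device_map below is invented for the example.

# Sketch: more specific keys must come first, otherwise "model.layers.1" (or "model")
# would shadow "model.layers.10" in the alternation.
import re

device_map = {"": "cpu", "model.embed_tokens": "cuda:0", "model.layers.1": "cuda:0", "model.layers.10": "disk"}
device_map_regex = re.compile(
    "|".join(rf"({k})" for k in sorted(device_map, key=lambda x: (x.count("."), len(x)), reverse=True))
)
match = device_map_regex.match("model.layers.10.mlp.down_proj.weight")
print(match.group())  # "model.layers.10" -> this weight is routed to disk, not to cuda:0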
@@ -708,9 +726,11 @@ def convert_and_load_state_dict_in_model(
                     shard_index,
                 )

-        if future is None:  # TODO handle disk offload
+        if future is None:
             device_match = device_map_regex.match(renamed_key)
             param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
+            # If the param goes to disk, we need to materialize it on cpu first
+            param_device = "cpu" if param_device == "disk" else param_device
             future = spawn_materialize(thread_pool, tensor, param_device, _dtype)

         mapping.add_tensor(renamed_key, original_key, source_pattern, future)
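For context (not part of the commit): "disk" is not a device PyTorch can allocate on, so tensors bound for disk offload are materialized on CPU here and only written out later, in the per-parameter loop below. A minimal illustration:

# Sketch: a "disk" entry in the device_map cannot be handed to torch directly, so the
# materialization device falls back to cpu and the actual write to disk is deferred.
import torch

param_device = "disk"
materialize_on = "cpu" if param_device == "disk" else param_device
tensor = torch.empty(2, 2, device=materialize_on)  # torch.device("disk") would raise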
@@ -723,30 +743,40 @@ def convert_and_load_state_dict_in_model(

     total_entries = len(param_name_to_load)
     with logging.tqdm(total=total_entries, desc="Loading weights") as pbar:
-        for layer_name, mapping in param_name_to_load.items():
+        for first_param_name, mapping in param_name_to_load.items():
             pbar.update(1)
-            pbar.set_postfix({"Materializing param": layer_name})
+            pbar.set_postfix({"Materializing param": first_param_name})
             pbar.refresh()
             try:
                 realized_value, misc = mapping.convert(
-                    layer_name, config=model.config, quantizer=hf_quantizer, missing_keys=missing_keys
+                    first_param_name, config=model.config, quantizer=hf_quantizer, missing_keys=missing_keys
                 )
-                for k, output_value in realized_value.items():
-                    set_param_for_module(
-                        model,
-                        k,
-                        output_value,
-                        mismatch_keys,
-                        missing_keys,
-                        misc,
-                        unexpected_keys,
-                        mapping.distributed_operation,
-                        hf_quantizer,
-                    )
+                for target_name, param in realized_value.items():
+                    param = param[0] if isinstance(param, list) else param
+                    device_match = device_map_regex.match(target_name)
+                    param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
+                    # Offloading support
+                    if param_device == "disk":
+                        disk_offload_index = offload_and_maybe_resave_param(
+                            target_name, param, missing_keys, disk_offload_folder, disk_offload_index, mapping
+                        )
+                    else:
+                        set_param_for_module(
+                            model,
+                            target_name,
+                            param,
+                            mismatch_keys,
+                            missing_keys,
+                            misc,
+                            unexpected_keys,
+                            mapping.distributed_operation,
+                            hf_quantizer,
+                        )
             except SkipLayer:
                 continue
+
     thread_pool.shutdown(wait=False)
-    return missing_keys, unexpected_keys, mismatch_keys, misc
+    return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, misc


 # TODO this is not done yet!
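A final sketch (not part of the commit): the loader now also returns the updated `disk_offload_index`, which the caller presumably persists next to the offloaded weights. Assuming accelerate's `save_offload_index` helper, that could look roughly like this; the folder and index content are illustrative.

# Sketch: persist the returned index so offloaded weights can be located again later.
import os

from accelerate.utils import save_offload_index

disk_offload_folder = "/tmp/offload"
os.makedirs(disk_offload_folder, exist_ok=True)
disk_offload_index = {"model.layers.10.mlp.down_proj.weight": {"dtype": "float32", "shape": [16, 16]}}
save_offload_index(disk_offload_index, disk_offload_folder)  # writes/updates index.json in the folder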