@@ -30,6 +30,7 @@
 
 import torch
 
+from .integrations.accelerate import offload_weight
 from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, DTensor, Replicate, TensorParallelLayer
 from .utils import is_torch_greater_or_equal, logging
 
@@ -344,7 +345,7 @@ def dot_natural_key(s: str):
 
 @contextmanager
 def log_to_misc(
-    layer_name: str,
+    full_param_name: str,
     misc: MutableMapping[str, str],
     extras: Any = None,
     op: Union[list[ConversionOps], ConversionOps, None] = None,
@@ -368,30 +369,30 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) ->
         if isinstance(extras, tuple) and len(extras) == 2:
             values, target_keys = extras
             descriptor = f"{op_name} " if op_name else ""
-            misc[layer_name] = (
+            misc[full_param_name] = (
                 f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values[0])}"
             )
         elif isinstance(extras, str):
             suffix = f" via {op_name}" if op_name else ""
-            misc[layer_name] = f"{e}\nError{suffix} when processing parameter {extras}"
+            misc[full_param_name] = f"{e}\nError{suffix} when processing parameter {extras}"
         elif extras is None and op_name:
-            misc[layer_name] = f"{op_name}: {e}"
+            misc[full_param_name] = f"{op_name}: {e}"
         else:
-            misc[layer_name] = f"{extras} |Error: {e}"
+            misc[full_param_name] = f"{extras} |Error: {e}"
         raise SkipLayer()
 
 
 def set_param_for_module(
     model: PreTrainedModel,
-    layer_name: str,
+    full_param_name: str,
     param_value: torch.Tensor,
     mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
     missing_keys: MutableSet[str],
     misc: MutableMapping[str, Any],
     distributed_operation: Optional[TensorParallelLayer],
 ):
-    with log_to_misc(layer_name, misc, layer_name):
-        module_path, _, param_name = layer_name.rpartition(".")
+    with log_to_misc(full_param_name, misc, full_param_name):
+        module_path, _, param_name = full_param_name.rpartition(".")
         module_obj = model.get_submodule(module_path) if module_path else model
         param_value = param_value[0] if isinstance(param_value, list) else param_value[...]
         ref = getattr(module_obj, param_name)
@@ -414,9 +415,9 @@ def set_param_for_module(
                 param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
 
         # Remove from missing keys (it's either mismatched, or all good)
-        missing_keys.discard(layer_name)
+        missing_keys.discard(full_param_name)
         if ref is not None and ref.shape != param_value.shape:
-            mismatch_keys.add((layer_name, param_value.shape, ref.shape))
+            mismatch_keys.add((full_param_name, param_value.shape, ref.shape))
             module_obj.param_name._is_hf_initialized = False  # Needs to be initialized
         else:
             param_value._is_hf_initialized = True  # super important otherwise _init_weight re-initi if bias is missing
@@ -439,6 +440,8 @@ def convert_and_load_state_dict_in_model(
     device_map: dict | None = None,
     dtype_plan: dict | None = None,
     device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
+    disk_offload_index: dict | None = None,
+    disk_offload_folder: str | None = None,
 ):
     """
     Convert a state dict according to a weight mapping (one WeightConverter per glob pattern),
@@ -536,7 +539,7 @@ def convert_and_load_state_dict_in_model(
                         shard_index,
                     )
 
-            if future is None:  # If not TP, async materialize the tensors. TODO handle disk offload?
+            if future is None:
                 device_match = device_map_regex.match(first_target_key)
                 param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
                 future = spawn_materialize(thread_pool, tensor, param_device, _dtype)
@@ -551,29 +554,29 @@ def convert_and_load_state_dict_in_model(
             group = by_conversion_pattern.pop(key)
             converter = group.weight_converter
             operations = converter.operations if isinstance(converter.operations, list) else [converter.operations]
-            for layer_name, tensors_for_this_layer in group.collected_tensors.items():
+            for full_param_name, tensors_for_this_layer in group.collected_tensors.items():
                 pbar.update(1)
-                pbar.set_postfix({"Materializing param": layer_name})
+                pbar.set_postfix({"Materializing param": full_param_name})
                 pbar.refresh()
-                concrete_target_keys = layer_name.split("|")
+                concrete_target_keys = full_param_name.split("|")
                 try:
                     if bool(set(concrete_target_keys) - unexpected_keys):
-                        with log_to_misc(layer_name, misc):
+                        with log_to_misc(full_param_name, misc):
                             values = [[k.result() for k in inner] for inner in tensors_for_this_layer.values()]
 
                         for op in operations:
-                            with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
+                            with log_to_misc(full_param_name, misc, (values, concrete_target_keys), operations):
                                 values = op.convert(values, model.config)
 
                         values = [values] if not isinstance(values, list) else values
-                        with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
+                        with log_to_misc(full_param_name, misc, (values, concrete_target_keys), operations):
                             realized_value = {
                                 k: t for k, t in zip(concrete_target_keys, values) if k not in unexpected_keys
                             }
 
                             for k in list(realized_value.keys()).copy():
                                 if op := converter.quantization_operation:
-                                    with log_to_misc(layer_name, misc, op=op):
+                                    with log_to_misc(full_param_name, misc, op=op):
                                         realized_value.update(
                                             op.convert(
                                                 {k: realized_value.pop(k)}, quant_config=quantizer.quantization_config
@@ -583,15 +586,26 @@ def convert_and_load_state_dict_in_model(
                         for k, output_value in realized_value.items():
                             for src in converter.source_keys:  # what should happen to k when we meet k at saving
                                 inverse_converters[k] = {src: converter}
-                            set_param_for_module(
-                                model,
-                                k,
-                                output_value,
-                                mismatch_keys,
-                                missing_keys,
-                                misc,
-                                converter.distributed_operation,
-                            )
+
+                            param_device = device_map[re.search(device_map_regex, k).group()]
+                            # Offloading support
+                            if param_device == "disk":
+                                missing_keys.discard(k)
+                                # If not already offloaded, or if we applied any special Operation, we need to re-save
+                                if k not in disk_offload_index or len(operations) > 0:
+                                    disk_offload_index = offload_weight(
+                                        output_value, k, disk_offload_folder, disk_offload_index
+                                    )
+                            else:
+                                set_param_for_module(
+                                    model,
+                                    k,
+                                    output_value,
+                                    mismatch_keys,
+                                    missing_keys,
+                                    misc,
+                                    converter.distributed_operation,
+                                )
 
                 except SkipLayer:
                     continue
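
For context, the added branch routes any parameter whose resolved device is "disk" away from set_param_for_module and into offload_weight, which persists the tensor under its target key and records it in disk_offload_index. Below is a minimal, self-contained sketch of that routing, not the transformers implementation: toy_offload_weight and route_param are hypothetical stand-ins for the helper imported from .integrations.accelerate, the index layout is an assumption, and the real code resolves the device with a regex over device_map keys rather than the plain dict lookup used here.

# Hypothetical sketch of the disk-offload routing introduced in this diff.
import os
import tempfile

import torch


def toy_offload_weight(tensor: torch.Tensor, key: str, folder: str, index: dict) -> dict:
    # Stand-in for offload_weight: persist the tensor under its target key
    # and record where it went. The index format below is an assumption.
    os.makedirs(folder, exist_ok=True)
    path = os.path.join(folder, f"{key}.pt")
    torch.save(tensor, path)
    index[key] = {"path": path, "shape": list(tensor.shape), "dtype": str(tensor.dtype)}
    return index


def route_param(key: str, value: torch.Tensor, device_map: dict, folder: str, index: dict) -> str:
    # Mirrors the new branch: "disk" params are offloaded; everything else
    # would go through set_param_for_module in the real loading loop.
    device = device_map.get(key, device_map.get("", "cpu"))
    if device == "disk":
        toy_offload_weight(value, key, folder, index)
        return "offloaded to disk"
    return f"set on module (device={device})"


if __name__ == "__main__":
    folder = tempfile.mkdtemp()
    index: dict = {}
    device_map = {"model.layers.0.mlp.up_proj.weight": "disk", "": "cpu"}
    weight = torch.randn(4, 4)
    print(route_param("model.layers.0.mlp.up_proj.weight", weight, device_map, folder, index))
    print(index)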