
 import torch

+from .integrations.accelerate import offload_weight
 from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, DTensor, Replicate, TensorParallelLayer
 from .utils import is_torch_greater_or_equal, logging

@@ -344,7 +345,7 @@ def dot_natural_key(s: str):

 @contextmanager
 def log_to_misc(
-    layer_name: str,
+    full_param_name: str,
     misc: MutableMapping[str, str],
     extras: Any = None,
     op: Union[list[ConversionOps], ConversionOps, None] = None,
@@ -368,30 +369,30 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) ->
         if isinstance(extras, tuple) and len(extras) == 2:
             values, target_keys = extras
             descriptor = f"{op_name} " if op_name else ""
-            misc[layer_name] = (
+            misc[full_param_name] = (
                 f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values[0])}"
             )
         elif isinstance(extras, str):
             suffix = f" via {op_name}" if op_name else ""
-            misc[layer_name] = f"{e}\nError{suffix} when processing parameter {extras}"
+            misc[full_param_name] = f"{e}\nError{suffix} when processing parameter {extras}"
         elif extras is None and op_name:
-            misc[layer_name] = f"{op_name}: {e}"
+            misc[full_param_name] = f"{op_name}: {e}"
         else:
-            misc[layer_name] = f"{extras} |Error: {e}"
+            misc[full_param_name] = f"{extras} |Error: {e}"
         raise SkipLayer()


 def set_param_for_module(
     model: PreTrainedModel,
-    layer_name: str,
+    full_param_name: str,
     param_value: torch.Tensor,
     mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
     missing_keys: MutableSet[str],
     misc: MutableMapping[str, Any],
     distributed_operation: Optional[TensorParallelLayer],
 ):
-    with log_to_misc(layer_name, misc, layer_name):
-        module_path, _, param_name = layer_name.rpartition(".")
+    with log_to_misc(full_param_name, misc, full_param_name):
+        module_path, _, param_name = full_param_name.rpartition(".")
         module_obj = model.get_submodule(module_path) if module_path else model
         param_value = param_value[0] if isinstance(param_value, list) else param_value[...]
         ref = getattr(module_obj, param_name)
@@ -414,9 +415,9 @@ def set_param_for_module(
             param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())

         # Remove from missing keys (it's either mismatched, or all good)
-        missing_keys.discard(layer_name)
+        missing_keys.discard(full_param_name)
         if ref is not None and ref.shape != param_value.shape:
-            mismatch_keys.add((layer_name, param_value.shape, ref.shape))
+            mismatch_keys.add((full_param_name, param_value.shape, ref.shape))
             module_obj.param_name._is_hf_initialized = False  # Needs to be initialized
         else:
             param_value._is_hf_initialized = True  # super important otherwise _init_weight re-inits if bias is missing
@@ -439,6 +440,8 @@ def convert_and_load_state_dict_in_model(
     device_map: dict | None = None,
     dtype_plan: dict | None = None,
     device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
+    disk_offload_index: dict | None = None,
+    disk_offload_folder: str | None = None,
 ):
     """
     Convert a state dict according to a weight mapping (one WeightConverter per glob pattern),
@@ -448,6 +451,7 @@ def convert_and_load_state_dict_in_model(
     prefix = model.base_model_prefix
     tp_plan = tp_plan or {}  # {glob_pattern: plan_obj_or_key}
     device_map = device_map or {}  # {exact_target_key: device}
+    device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
     dtype_plan = dtype_plan or {}  # {glob_pattern: dtype}
     weight_mapping = weight_mapping or {}  # {glob_pattern: WeightConverter}
     meta_model_state_dict = model.state_dict()
@@ -533,7 +537,7 @@ def convert_and_load_state_dict_in_model(
                         shard_index,
                     )

-                if future is None:  # If not TP, async materialize the tensors. TODO handle disk offload?
+                if future is None:
                     future = spawn_materialize(thread_pool, tensor, _dtype)
                 entry.collected_tensors[target_key].setdefault(converter_key, []).append(future)
539543
@@ -546,29 +550,29 @@ def convert_and_load_state_dict_in_model(
546550 group = by_conversion_pattern .pop (key )
547551 converter = group .weight_converter
548552 operations = converter .operations if isinstance (converter .operations , list ) else [converter .operations ]
549- for layer_name , tensors_for_this_layer in group .collected_tensors .items ():
553+ for full_param_name , tensors_for_this_layer in group .collected_tensors .items ():
550554 pbar .update (1 )
551- pbar .set_postfix ({"Materializing param" : layer_name })
555+ pbar .set_postfix ({"Materializing param" : full_param_name })
552556 pbar .refresh ()
553- concrete_target_keys = layer_name .split ("|" )
557+ concrete_target_keys = full_param_name .split ("|" )
554558 try :
555559 if bool (set (concrete_target_keys ) - unexpected_keys ):
556- with log_to_misc (layer_name , misc ):
560+ with log_to_misc (full_param_name , misc ):
557561 values = [[k .result () for k in inner ] for inner in tensors_for_this_layer .values ()]
558562
559563 for op in operations :
560- with log_to_misc (layer_name , misc , (values , concrete_target_keys ), operations ):
564+ with log_to_misc (full_param_name , misc , (values , concrete_target_keys ), operations ):
561565 values = op .convert (values , model .config )
562566
563567 values = [values ] if not isinstance (values , list ) else values
564- with log_to_misc (layer_name , misc , (values , concrete_target_keys ), operations ):
568+ with log_to_misc (full_param_name , misc , (values , concrete_target_keys ), operations ):
565569 realized_value = {
566570 k : t for k , t in zip (concrete_target_keys , values ) if k not in unexpected_keys
567571 }
568572
569573 for k in list (realized_value .keys ()).copy ():
570574 if op := converter .quantization_operation :
571- with log_to_misc (layer_name , misc , op = op ):
575+ with log_to_misc (full_param_name , misc , op = op ):
572576 realized_value .update (
573577 op .convert (
574578 {k : realized_value .pop (k )}, quant_config = quantizer .quantization_config
@@ -578,15 +582,26 @@ def convert_and_load_state_dict_in_model(
                     for k, output_value in realized_value.items():
                         for src in converter.source_keys:  # what should happen to k when we meet k at saving
                             inverse_converters[k] = {src: converter}
-                        set_param_for_module(
-                            model,
-                            k,
-                            output_value,
-                            mismatch_keys,
-                            missing_keys,
-                            misc,
-                            converter.distributed_operation,
-                        )
+
+                        param_device = device_map[re.search(device_map_regex, k).group()]
+                        # Offloading support
+                        if param_device == "disk":
+                            missing_keys.discard(k)
+                            # If not already offloaded, or if we applied any special Operation, we need to re-save
+                            if k not in disk_offload_index or len(operations) > 0:
+                                disk_offload_index = offload_weight(
+                                    output_value, k, disk_offload_folder, disk_offload_index
+                                )
+                        else:
+                            set_param_for_module(
+                                model,
+                                k,
+                                output_value,
+                                mismatch_keys,
+                                missing_keys,
+                                misc,
+                                converter.distributed_operation,
+                            )

             except SkipLayer:
                 continue
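
For context on the new `device_map_regex` lookup: escaping each `device_map` key and sorting in reverse order puts longer keys ahead of their prefixes in the alternation, so `re.search` resolves each parameter name to its most specific entry. A minimal sketch with a hypothetical `device_map`:

```python
import re

# Hypothetical device_map in the {exact_target_key: device} shape used above;
# "" acts as a catch-all because the empty alternative matches any name.
device_map = {"": "cpu", "model.layers.1": "cuda:0", "model.layers.10": "disk"}

# Reverse sort puts "model.layers.10" before its prefix "model.layers.1",
# so the longer key wins when both could match.
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])

for name in (
    "model.layers.10.self_attn.q_proj.weight",  # -> disk
    "model.layers.1.mlp.up_proj.weight",        # -> cuda:0
    "lm_head.weight",                           # -> cpu (falls through to "")
):
    print(name, "->", device_map[re.search(device_map_regex, name).group()])
```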
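And a standalone sketch of the disk branch itself, assuming the imported `offload_weight` behaves like `accelerate.utils.offload_weight` (write the tensor to a `.dat` memmap in the folder and return the updated index); the parameter name and folder below are made up for illustration:

```python
import os

import torch
from accelerate.utils import offload_weight  # assumed equivalent of the new integration import

disk_offload_folder = "offload"  # hypothetical scratch directory
os.makedirs(disk_offload_folder, exist_ok=True)
disk_offload_index = {}  # nothing offloaded yet

k = "model.layers.10.self_attn.q_proj.weight"  # hypothetical parameter name
output_value = torch.randn(16, 16)
operations = []  # no conversion ops applied in this toy case

# Mirrors the new branch: (re-)save when the key is unknown or any op transformed it.
if k not in disk_offload_index or len(operations) > 0:
    disk_offload_index = offload_weight(output_value, k, disk_offload_folder, disk_offload_index)

print(disk_offload_index[k])  # dtype/shape metadata pointing at offload/<name>.dat
```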