# limitations under the License.

import json
-import logging
-import operator
import os
-import re
-from contextlib import contextmanager
from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar, Union

import compressed_tensors
import torch
-import transformers
from compressed_tensors.base import (
    COMPRESSION_VERSION_NAME,
    QUANTIZATION_CONFIG_NAME,
    QuantizationConfig,
    QuantizationScheme,
    QuantizationStatus,
-    apply_quantization_config,
-    load_pretrained_quantization_parameters,
)
from compressed_tensors.transform import TransformConfig
from compressed_tensors.utils import (
    align_module_device,
    delete_offload_parameter,
    get_execution_device,
    get_offloaded_device,
-    get_safetensors_folder,
-    has_offloaded_params,
    register_offload_parameter,
-    update_parameter_data,
)
from compressed_tensors.utils.helpers import (
    fix_fsdp_module_name,
    is_compressed_tensors_config,
)
from compressed_tensors.utils.match import match_named_modules
-from torch import Tensor
from torch.nn import Module
from tqdm import tqdm
from transformers import AutoConfig

__all__ = ["ModelCompressor", "map_module_to_scheme"]

-_LOGGER: logging.Logger = logging.getLogger(__name__)
-

if TYPE_CHECKING:
    # dummy type if not available from transformers
@@ -488,153 +475,6 @@ def decompress_model(self, model: Module):

            module.quantization_status = QuantizationStatus.FROZEN

-    # ----- state dict compression pathways ----- #
-
-    def compress(
-        self,
-        model: Module,
-        state_dict: Optional[Dict[str, Tensor]] = None,
-        show_progress: bool = False,
-    ) -> Dict[str, Tensor]:
-        """
-        Compresses a dense state dict or model with sparsity and/or quantization
-
-        :param model: uncompressed model to compress
-        :param state_dict: optional uncompressed state_dict to insert into model
-        :param show_progress: whether to display a progress bar during compression
-        :return: compressed state dict
-        """
-
-        if state_dict is None:
-            state_dict = model.state_dict()
-
-        if self.quantization_compressor is not None:
-            module_to_scheme = map_module_to_scheme(model)
-            # Note - compress only supports one compression format atm
-            quant_compressor = next(iter(self.quantization_compressor.values()))
-            state_dict = quant_compressor.compress(
-                state_dict,
-                names_to_scheme=module_to_scheme,
-                show_progress=show_progress,
-            )
-
-            # TODO: consider sparse compression to also be compression
-            if self.quantization_config.format != CompressionFormat.dense.value:
-                self.quantization_config.quantization_status = (
-                    QuantizationStatus.COMPRESSED
-                )
-
-        if self.sparsity_compressor is not None:
-            sparse_compression_targets: Set[str] = {
-                module_name
-                for module_name, _module in match_named_modules(
-                    model=model,
-                    targets=self.sparsity_config.targets,
-                    ignore=self.sparsity_config.ignore,
-                )
-            }
-            state_dict = self.sparsity_compressor.compress(
-                state_dict,
-                compression_targets=sparse_compression_targets,
-                show_progress=show_progress,
-            )
-
-        # HACK: Override the dtype_byte_size function in transformers to
-        # support float8 types. Fix is posted upstream
-        # https://github.com/huggingface/transformers/pull/30488
-        transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
-
-        return state_dict
-
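
# A minimal usage sketch of the state-dict pathway above, assuming a
# transformers-style `model` with its quantization config already applied and a
# `compressor` built via ModelCompressor.from_pretrained_model; the output
# directory is hypothetical.
compressed_state_dict = compressor.compress(model, show_progress=True)
model.save_pretrained("./compressed-out", state_dict=compressed_state_dict)
compressor.update_config("./compressed-out")
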
-    # ----- disk decompression pathways ----- #
-
-    def decompress(self, model_path: str, model: Module):
-        """
-        Overwrites the weights in model with weights decompressed from model_path
-
-        :param model_path: path to compressed weights
-        :param model: pytorch model to load decompressed weights into
-
-        Note: decompress makes use of both _replace_sparsity_weights and
-        _replace_weights. The variations in these methods are a result of subtle
-        differences between the sparsity and quantization compressors:
-        quantization compressors return not just the decompressed weight but also
-        the quantization parameters (e.g. scales, zero_point), whereas sparsity
-        compressors return only the decompressed weight.
-        """
-        model_path = get_safetensors_folder(model_path)
-        sparse_decompressed = False
-        quant_compressor = (
-            next(iter(self.quantization_compressor.values()))
-            if self.quantization_compressor is not None
-            else None
-        )
-
-        if (
-            self.sparsity_compressor is not None
-            and self.sparsity_config.format != CompressionFormat.dense.value
-        ):
-            # note - decompress only supports one compressor atm
-            params_to_ignore = None
-            if quant_compressor is not None:
-                params_to_ignore = quant_compressor.compression_param_names
-            # Sparse decompression is applied on the model_path. The compressor
-            # will try to load any quantization parameters as well, so
-            # params_to_skip_load keeps quantization params from being loaded
-            dense_gen = self.sparsity_compressor.decompress(
-                model_path, params_to_skip_load=params_to_ignore
-            )
-            self._replace_sparsity_weights(dense_gen, model)
-            setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
-            sparse_decompressed = True
-
-        if quant_compressor is not None:
-            # Temporarily set quantization status to FROZEN to prevent
-            # quantization during apply_quantization_config. This ensures
-            # that the dtypes of the weights are not unintentionally updated.
-            # The status is restored after quantization params are loaded.
-
-            with override_quantization_status(
-                self.quantization_config, QuantizationStatus.FROZEN
-            ):
-                apply_quantization_config(model, self.quantization_config)
-                names_to_scheme: Dict[str, QuantizationScheme] = {
-                    name: getattr(module, "quantization_scheme")
-                    for name, module in model.named_modules()
-                    if getattr(module, "quantization_scheme", None) is not None
-                }
-                # Load activation scales/zp or any other quantization parameters.
-                # Conditionally load the weight quantization parameters if we
-                # have a dense compressor or if a sparsity compressor has
-                # already been applied
-                load_weight_qparams = sparse_decompressed or isinstance(
-                    quant_compressor, DenseCompressor
-                )
-                load_pretrained_quantization_parameters(
-                    model,
-                    model_path,
-                    # TODO: all weight quantization params will be moved to the
-                    # compressor in a follow-up including initialization
-                    load_weight_qparams=load_weight_qparams,
-                )
-                model_path_or_state_dict = (
-                    model.state_dict() if sparse_decompressed else model_path
-                )
-
-                dense_gen = quant_compressor.decompress(
-                    model_path_or_state_dict, names_to_scheme=names_to_scheme
-                )
-                # TODO: all weight quantization params will be moved to the
-                # compressor to prevent duplicate parameter updates in
-                # update_parameter_data
-                self._replace_weights(
-                    dense_gen, model, load_weight_qparams=not load_weight_qparams
-                )
-
-                def freeze_quantization_status(module):
-                    module.quantization_status = QuantizationStatus.FROZEN
-
-                model.apply(freeze_quantization_status)
-            setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)
-
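
# The inverse pathway, sketched under the same assumptions as above: overwrite
# the dense model's weights with weights decompressed from disk. Afterwards,
# every quantized module is FROZEN and the model carries its quantization
# config for later serialization.
compressor.decompress(model_path="./compressed-out", model=model)
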
    def update_config(self, save_directory: str):
        """
        Update the model config located at save_directory with compression configs
@@ -688,79 +528,6 @@ def update_config(self, save_directory: str):
        with open(config_file_path, "w") as config_file:
            json.dump(config_data, config_file, indent=2, sort_keys=True)

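
# For reference, update_config persists the compression configs into the
# model's config.json. The written block looks roughly like the following
# (illustrative values, not an exact schema):
#
#   "quantization_config": {
#       "quant_method": "compressed-tensors",
#       "format": "pack-quantized",
#       "quantization_status": "compressed",
#       "config_groups": {"group_0": {...}},
#       "sparsity_config": {"format": "sparse-24-bitmask", ...}
#   }
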
-    def _replace_sparsity_weights(self, dense_weight_generator, model: Module):
-        """
-        Replace the weights of the model with the provided dense weights.
-
-        This method iterates over the dense_weight_generator and updates the
-        corresponding weights in the model. If a parameter name does not exist
-        in the model, it will be skipped.
-
-        :param dense_weight_generator: generator that yields tuples of
-            (name, data), where 'name' is the parameter name and 'data' is the
-            updated param data
-        :param model: the model whose weights are to be updated
-        """
-        for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
-            split_name = name.split(".")
-            prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
-            module = operator.attrgetter(prefix)(model)
-
-            params_device = next(module.parameters()).device
-            device = "cpu" if has_offloaded_params(module) else params_device
-            delattr(module, param_name)
-            requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)
-            param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
-            register_offload_parameter(module, param_name, param)
-
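
# Standalone sketch of the dotted-name resolution used above: split a fully
# qualified parameter name into (module path, attribute name), then fetch the
# owning submodule with operator.attrgetter.
import operator

import torch

net = torch.nn.Sequential(torch.nn.Linear(4, 4))
prefix, param_name = "0.weight".rsplit(".", 1)
linear = operator.attrgetter(prefix)(net)  # resolves child "0", the Linear
assert getattr(linear, param_name).shape == (4, 4)
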
-    def _replace_weights(
-        self, dense_weight_generator, model: Module, load_weight_qparams: bool = True
-    ):
-        """
-        Replace the weights of the model with the provided dense weights.
-
-        This method iterates over the dense_weight_generator and updates the
-        corresponding weights in the model. If a parameter name does not exist
-        in the model, it will be skipped.
-
-        :param dense_weight_generator: generator that yields tuples of
-            (module_path, data), where 'data' maps parameter names to updated
-            param data
-        :param model: the model whose weights are to be updated
-        :param load_weight_qparams: whether to update weight quantization
-            parameters (e.g. scales, zero-points) in addition to the weights
-        """
-
-        for mod_path, data in tqdm(dense_weight_generator, desc="Decompressing model"):
-            module = operator.attrgetter(mod_path)(model)
-
-            params_device = next(module.parameters()).device
-            device = "cpu" if has_offloaded_params(module) else params_device
-
-            for param_name, param_data in data.items():
-                if hasattr(module, param_name):
-                    # If compressed, will have an incorrect dtype for
-                    # transformers >4.49
-                    # TODO: we could also skip initialization of scales/zp at
-                    # decompression init time, to be consistent with the loading
-                    # that happens later; however, update_parameter_data does a
-                    # good shape check - this should be moved to the compressor
-                    if param_name == "weight":
-                        delattr(module, param_name)
-                        requires_grad = param_data.dtype in (
-                            torch.float16,
-                            torch.float32,
-                            torch.bfloat16,
-                        )
-                        param = torch.nn.Parameter(
-                            param_data.to(device), requires_grad=requires_grad
-                        )
-                        register_offload_parameter(module, param_name, param)
-                    elif load_weight_qparams:
-                        # scales/zero-points should already be registered to
-                        # the correct device
-                        update_parameter_data(module, param_data, param_name)

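# Why the requires_grad guard in both replacement helpers above: PyTorch only
# allows gradients on floating-point (or complex) tensors, so integer/packed
# decompression outputs must be registered with requires_grad=False.
import torch

torch.nn.Parameter(torch.zeros(2, dtype=torch.int8), requires_grad=False)  # fine
# torch.nn.Parameter(torch.zeros(2, dtype=torch.int8), requires_grad=True)
# -> RuntimeError: only Tensors of floating point and complex dtype can
#    require gradients
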
def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
    """
@@ -775,35 +542,3 @@ def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
            and module.quantization_scheme.weights is not None
        )
    }
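
# Hedged usage sketch: after quantization has been applied to `model`, the
# mapping can drive a compressor over exactly the weight-quantized modules.
module_to_scheme = map_module_to_scheme(model)
for name, scheme in module_to_scheme.items():
    print(name, scheme.weights.num_bits)  # e.g. "model.layers.0.q_proj 4"
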
-
-
-# HACK: Override the dtype_byte_size function in transformers to support float8 types
-# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
-def new_dtype_byte_size(dtype):
-    if dtype == torch.bool:
-        return 1 / 8
-    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
-    if bit_search is None:
-        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
-    bit_size = int(bit_search.groups()[0])
-    return bit_size // 8
-
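
# Worked examples of the bit-width regex above: the first digit run following
# a non-digit in the dtype's string form is read as the bit width.
#   str(torch.float16)       -> "torch.float16"       -> captures "16" -> 2 bytes
#   str(torch.float8_e4m3fn) -> "torch.float8_e4m3fn" -> captures "8"  -> 1 byte
# torch.bool is special-cased to 1/8 before the regex runs.
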
-
-@contextmanager
-def override_quantization_status(
-    config: QuantizationConfig, status: QuantizationStatus
-):
-    """
-    Within this context, the quantization status will be set to the
-    supplied status. After the context exits, the original status
-    will be restored.
-
-    :param config: the quantization config to override
-    :param status: the status to temporarily set
-    """
-    original_status = config.quantization_status
-    config.quantization_status = status
-    try:
-        yield
-    finally:
-        config.quantization_status = original_status
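
# Hedged usage sketch, assuming `config` is a QuantizationConfig instance: the
# try/finally guarantees the original status is restored even on error.
prior = config.quantization_status
with override_quantization_status(config, QuantizationStatus.FROZEN):
    assert config.quantization_status == QuantizationStatus.FROZEN
assert config.quantization_status == prior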