import regex as re
import torch
import torch.nn as nn
-import torch.nn.functional as F
from tqdm import tqdm

from modelopt.torch.opt.conversion import ModeloptStateManager
from modelopt.torch.opt.hparam import CustomHPType, Hparam, HPType
from modelopt.torch.opt.searcher import LPS, BaseSearcher, SearchConfig, SearchStateDict
from modelopt.torch.opt.utils import get_hparam, named_hparams
from modelopt.torch.utils import create_param_grad_clear_hook, print_rank_0, report_memory
-from modelopt.torch.utils.distributed import DistributedProcessGroup, is_master
+from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState, is_master

from . import config as mtq_config
from . import model_calib
@@ -953,19 +952,72 @@ def run_search_with_stats(self, max_weight_size, verbose=False):
        return best_recipes, is_satisfied


+# TODO: does torch.compile improve speed?
@torch.compile
-def _get_kl_div_loss(logits_unquant: torch.Tensor, logits_quant: torch.Tensor) -> torch.Tensor:
-    # TODO: Support TensorParallel
-    prob_unquant = F.softmax(logits_unquant, dim=-1)
-    log_prob_quant = F.log_softmax(logits_quant, dim=-1)
-    return F.kl_div(log_prob_quant, prob_unquant, reduction="sum", log_target=False)
+def _get_softmax_dist(
+    logits: torch.Tensor, tp_group, return_log_prob: bool = False
+) -> torch.Tensor:
+    # TODO: test this
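+    # Numerically stable softmax, assuming the last (vocab) dimension of `logits` is
+    # sharded across `tp_group`: all-reduce the per-shard max, then all-reduce the sum of
+    # exponentials so every rank normalizes against the full vocabulary.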
+    dtype = logits.dtype
+    max_logits = torch.amax(logits, dim=-1, keepdim=True)
+    torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=tp_group)
+    logits = (logits - max_logits).float()
+    sum_exp_logits = torch.exp(torch.logsumexp(logits, dim=-1, keepdim=True))
+    torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=tp_group)
+    logits = logits - torch.log(sum_exp_logits)
+    if return_log_prob:
+        return logits.to(dtype)
+    else:
+        return torch.exp(logits).to(dtype)
+
+
+@torch.compile
+def _get_softmax(logits: torch.Tensor, return_log_prob: bool = False) -> torch.Tensor:
+    # TODO: do we need to do log_softmax in float32?
+    # log_softmax is supposed to be a numerically stable implementation
+    log_prob = torch.log_softmax(logits.float(), dim=-1)
+    if return_log_prob:
+        return log_prob
+    else:
+        return torch.exp(log_prob)
+
+
+@torch.compile
+def _get_p_log_q(p: torch.Tensor, log_q: torch.Tensor) -> torch.Tensor:
+    return torch.sum(p * log_q).float()
+
+
+def _get_prob_from_logits(
+    logits: torch.Tensor, return_log_prob: bool = False, lm_head: nn.Module | None = None
+) -> torch.Tensor:
+    parallel_state: ParallelState | None = (
+        getattr(lm_head, "parallel_state", None) if lm_head is not None else None
+    )
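+    # If the lm_head/output projection is tensor parallel, the logits are assumed to be
+    # sharded along the vocab dimension, so the softmax must be reduced across the TP group.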
+    if parallel_state is not None and parallel_state.tensor_parallel_group.is_initialized():
+        return _get_softmax_dist(
+            logits, parallel_state.tensor_parallel_group.group, return_log_prob
+        )
+    return _get_softmax(logits, return_log_prob)
+
+
+def _get_kl_div_loss(
+    prob_unquant: torch.Tensor, logits_quant: torch.Tensor, lm_head: nn.Module | None = None
+) -> torch.Tensor:
+    log_prob_quant = _get_prob_from_logits(logits_quant, return_log_prob=True, lm_head=lm_head)
+    # We don't need to compute the full KL divergence here, just p * log_q;
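+    # the entropy term sum(p * log p) depends only on the unquantized distribution, so it
+    # is identical for every quant recipe and does not affect the relative ranking.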
+    return _get_p_log_q(prob_unquant, log_prob_quant)
+
+
+def _get_lm_head(model: nn.Module) -> nn.Module | None:
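+    # The output projection carries the parallel_state used by _get_prob_from_logits to
+    # detect vocab sharding.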
+    for name, module in model.named_modules():
+        if name.endswith(("lm_head", "output_layer")):  # HF transformers models or Megatron models
+            return module
+    return None


class AutoQuantizeKLDivSearcher(_AutoQuantizeBaseSearcher):
    """A searcher for AutoQuantize algorithm that uses KL-Divergence loss based score estimation."""

-    score_module_rules: list[str | Callable] = [lambda name: ""]
-
    @property
    def default_search_config(self):
        """Get the default config for the searcher."""
@@ -982,9 +1034,10 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
        config = config or {}
        for ignored_key in ["score_func", "loss_func", "forward_backward_step"]:
            if ignored_key in config:
-                warnings.warn(
-                    f"`{ignored_key}` is ignored for KL-Divergence loss based `auto_quantize`."
-                )
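+                # Warn only when the caller explicitly passed a value; a None default is dropped silently.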
+                if config[ignored_key] is not None:
+                    warnings.warn(
+                        f"`{ignored_key}` is ignored for KL-Divergence loss based `auto_quantize`."
+                    )
            config.pop(ignored_key)
        config = super().sanitize_search_config(config)
        assert config["forward_step"] is not None, (
@@ -993,21 +1046,12 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
        )
        return config

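+    # inference_mode behaves like no_grad but also skips view and version-counter tracking,
+    # lowering overhead for the score-estimation forward passes.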
-    @torch.no_grad()
+    @torch.inference_mode()
    def estimate_sensitivity_scores(self):
        """Estimate the sensitivity scores for the model.

        Higher score means more sensitive to quantization.
        """
-        # Check if tensor parallelism is being used
-        for name, module in self.model.named_modules():
-            if hasattr(module, "parallel_state"):
-                if hasattr(module.parallel_state, "tensor_parallel_group"):
-                    if module.parallel_state.tensor_parallel_group.is_initialized():
-                        warnings.warn(
-                            "Tensor Parallel is not supported for KL-Divergence based auto_quantize. "
-                        )
-                        break

        def set_to_unquantized():
            for name, hparam in named_hparams(self.model, unique=True):
@@ -1025,17 +1069,27 @@ def set_to_unquantized():
        ):
            set_to_unquantized()
            logits_unquant = self.config["forward_step"](self.model, data)
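+            # Reference distribution from the unquantized model, computed once per batch
+            # and reused for every candidate recipe below.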
+            prob_unquant = _get_prob_from_logits(
+                logits_unquant,
+                return_log_prob=False,
+                lm_head=_get_lm_head(self.model),
+            )

-            for name, hparam in named_hparams(self.model, configurable=True):
+            for name, hparam in tqdm(
+                list(named_hparams(self.model, configurable=True)), desc="Evaluating hparams"
+            ):
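+                # Score each candidate recipe of this hparam against the unquantized reference.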
                if not isinstance(hparam, QuantRecipeHparam):
                    continue
                for recipe in hparam.choices:
                    if recipe == QuantRecipe(quant_cfg=None):
                        continue
                    hparam.active = recipe
                    logits_quant = self.config["forward_step"](self.model, data)
-                    score = _get_kl_div_loss(logits_unquant, logits_quant)
-                    hparam._importance_dict[recipe][hparam.score_modules[0]] = score
+                    score = _get_kl_div_loss(prob_unquant, logits_quant, _get_lm_head(self.model))
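+                    # Accumulate the score across calibration batches; the importance entry starts as None.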
+                    if hparam._importance_dict[recipe][hparam.score_modules[0]] is None:
+                        hparam._importance_dict[recipe][hparam.score_modules[0]] = score
+                    else:
+                        hparam._importance_dict[recipe][hparam.score_modules[0]] += score
                hparam.active = QuantRecipe(quant_cfg=None)

    def run_search_with_stats(self, max_weight_size, verbose=False):