
Commit 9134ca9

cherry-picked final PR changes
1 parent fe59c85 · commit 9134ca9

File tree

1 file changed: +79 -25 lines


modelopt/torch/quantization/algorithms.py

Lines changed: 79 additions & 25 deletions
@@ -28,15 +28,14 @@
 import regex as re
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from tqdm import tqdm
 
 from modelopt.torch.opt.conversion import ModeloptStateManager
 from modelopt.torch.opt.hparam import CustomHPType, Hparam, HPType
 from modelopt.torch.opt.searcher import LPS, BaseSearcher, SearchConfig, SearchStateDict
 from modelopt.torch.opt.utils import get_hparam, named_hparams
 from modelopt.torch.utils import create_param_grad_clear_hook, print_rank_0, report_memory
-from modelopt.torch.utils.distributed import DistributedProcessGroup, is_master
+from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState, is_master
 
 from . import config as mtq_config
 from . import model_calib
@@ -953,19 +952,72 @@ def run_search_with_stats(self, max_weight_size, verbose=False):
         return best_recipes, is_satisfied
 
 
+# TODO: does torch compile improve speed?
 @torch.compile
-def _get_kl_div_loss(logits_unquant: torch.Tensor, logits_quant: torch.Tensor) -> torch.Tensor:
-    # TODO: Support TensorParallel
-    prob_unquant = F.softmax(logits_unquant, dim=-1)
-    log_prob_quant = F.log_softmax(logits_quant, dim=-1)
-    return F.kl_div(log_prob_quant, prob_unquant, reduction="sum", log_target=False)
+def _get_softmax_dist(
+    logits: torch.Tensor, tp_group, return_log_prob: bool = False
+) -> torch.Tensor:
+    # TODO: test this
+    dtype = logits.dtype
+    max_logits = torch.amax(logits, dim=-1, keepdim=True)
+    torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=tp_group)
+    logits = (logits - max_logits).float()
+    sum_exp_logits = torch.exp(torch.logsumexp(logits, dim=-1, keepdim=True))
+    torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=tp_group)
+    logits = logits - torch.log(sum_exp_logits)
+    if return_log_prob:
+        return logits.to(dtype)
+    else:
+        return torch.exp(logits).to(dtype)
+
+
+@torch.compile
+def _get_softmax(logits: torch.Tensor, return_log_prob: bool = False) -> torch.Tensor:
+    # TODO: do we need to do log_softmax in float32?
+    # log_softmax is supposed to be a numerically stable implementation
+    log_prob = torch.log_softmax(logits.float(), dim=-1)
+    if return_log_prob:
+        return log_prob
+    else:
+        return torch.exp(log_prob)
+
+
+@torch.compile
+def _get_p_log_q(p: torch.Tensor, log_q: torch.Tensor) -> torch.Tensor:
+    return torch.sum(p * log_q).float()
+
+
+def _get_prob_from_logits(
+    logits: torch.Tensor, return_log_prob: bool = False, lm_head: nn.Module = None
+) -> torch.Tensor:
+    parallel_state: ParallelState | None = (
+        getattr(lm_head, "parallel_state", None) if lm_head is not None else None
+    )
+    if parallel_state is not None and parallel_state.tensor_parallel_group.is_initialized():
+        return _get_softmax_dist(
+            logits, parallel_state.tensor_parallel_group.group, return_log_prob
+        )
+    return _get_softmax(logits, return_log_prob)
+
+
+def _get_kl_div_loss(
+    prob_unquant: torch.Tensor, logits_quant: torch.Tensor, lm_head: nn.Module = None
+) -> torch.Tensor:
+    log_prob_quant = _get_prob_from_logits(logits_quant, return_log_prob=True, lm_head=lm_head)
+    # We don't need to calculate the full KL div loss here, just get p*log_q
+    return _get_p_log_q(prob_unquant, log_prob_quant)
+
+
+def _get_lm_head(model: nn.Module) -> nn.Module:
+    for name, module in model.named_modules():
+        if name.endswith(("lm_head", "output_layer")):  # HF transformers models or Megatron models
+            return module
+    return None
 
 
 class AutoQuantizeKLDivSearcher(_AutoQuantizeBaseSearcher):
     """A searcher for AutoQuantize algorithm that uses KL-Divergence loss based score estimation."""
 
-    score_module_rules: list[str | Callable] = [lambda name: ""]
-
     @property
     def default_search_config(self):
         """Get the default config for the searcher."""
@@ -982,9 +1034,10 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
         config = config or {}
         for ignored_key in ["score_func", "loss_func", "forward_backward_step"]:
             if ignored_key in config:
-                warnings.warn(
-                    f"`{ignored_key}` is ignored for KL-Divergence loss based `auto_quantize`."
-                )
+                if config[ignored_key] is not None:
+                    warnings.warn(
+                        f"`{ignored_key}` is ignored for KL-Divergence loss based `auto_quantize`."
+                    )
                 config.pop(ignored_key)
         config = super().sanitize_search_config(config)
         assert config["forward_step"] is not None, (
@@ -993,21 +1046,12 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
         )
         return config
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def estimate_sensitivity_scores(self):
         """Estimate the sensitivity scores for the model.
 
         Higher score means more sensitive to quantization.
         """
-        # Check if tensor parallelism is being used
-        for name, module in self.model.named_modules():
-            if hasattr(module, "parallel_state"):
-                if hasattr(module.parallel_state, "tensor_parallel_group"):
-                    if module.parallel_state.tensor_parallel_group.is_initialized():
-                        warnings.warn(
-                            "Tensor Parallel is not supported for KL-Divergence based auto_quantize. "
-                        )
-                        break
 
         def set_to_unquantized():
             for name, hparam in named_hparams(self.model, unique=True):
@@ -1025,17 +1069,27 @@ def set_to_unquantized():
         ):
             set_to_unquantized()
             logits_unquant = self.config["forward_step"](self.model, data)
+            prob_unquant = _get_prob_from_logits(
+                logits_unquant,
+                return_log_prob=False,
+                lm_head=_get_lm_head(self.model),
+            )
 
-            for name, hparam in named_hparams(self.model, configurable=True):
+            for name, hparam in tqdm(
+                list(named_hparams(self.model, configurable=True)), desc="Evaluating hparams"
+            ):
                 if not isinstance(hparam, QuantRecipeHparam):
                     continue
                 for recipe in hparam.choices:
                     if recipe == QuantRecipe(quant_cfg=None):
                         continue
                     hparam.active = recipe
                     logits_quant = self.config["forward_step"](self.model, data)
-                    score = _get_kl_div_loss(logits_unquant, logits_quant)
-                    hparam._importance_dict[recipe][hparam.score_modules[0]] = score
+                    score = _get_kl_div_loss(prob_unquant, logits_quant, _get_lm_head(self.model))
+                    if hparam._importance_dict[recipe][hparam.score_modules[0]] is None:
+                        hparam._importance_dict[recipe][hparam.score_modules[0]] = score
+                    else:
+                        hparam._importance_dict[recipe][hparam.score_modules[0]] += score
                 hparam.active = QuantRecipe(quant_cfg=None)
 
     def run_search_with_stats(self, max_weight_size, verbose=False):
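The rewritten _get_kl_div_loss scores each recipe with only the p * log_q term. This relies on the identity KL(p || q) = sum(p * log p) - sum(p * log q): the entropy term sum(p * log p) depends only on the unquantized reference distribution, which is identical for every quantization recipe, so the full KL divergence and the p * log_q term carry the same per-recipe ranking information. A minimal standalone sketch (illustrative only, not part of this commit; the random tensors are stand-ins for real logits) checking the identity against the old F.kl_div formulation:

# Standalone sanity check (illustrative only, not part of the commit): the entropy
# term sum(p * log p) is the same for every quantization recipe, so rankings by
# KL(p || q) and by sum(p * log q) alone carry the same information.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits_unquant = torch.randn(4, 32)                        # stand-in for reference logits
logits_quant = logits_unquant + 0.1 * torch.randn(4, 32)   # stand-in for one quantized run

p = F.softmax(logits_unquant, dim=-1)
log_q = F.log_softmax(logits_quant, dim=-1)

kl_full = F.kl_div(log_q, p, reduction="sum", log_target=False)  # old formulation
p_log_q = torch.sum(p * log_q)                                   # term kept by _get_p_log_q
entropy = torch.sum(p * torch.log(p))                            # constant across recipes

assert torch.allclose(kl_full, entropy - p_log_q, atol=1e-5)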

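_get_softmax_dist handles the case where the vocabulary dimension of the logits is sharded across tensor-parallel ranks: an all-reduce with MAX recovers the global row-wise maximum and an all-reduce with SUM recovers the global normalizer, after which each rank can turn its local shard into globally normalized (log-)probabilities. A single-process sketch (illustrative only; the manual shard split stands in for the two torch.distributed.all_reduce calls and is not part of this commit):

# Single-process sketch (illustrative only) of the math behind _get_softmax_dist.
import torch

torch.manual_seed(0)
full_logits = torch.randn(2, 8, dtype=torch.float16)
shards = full_logits.chunk(2, dim=-1)  # pretend each half lives on a different TP rank

# what all_reduce(MAX) would give every rank: the global row-wise max
global_max = torch.stack([s.amax(dim=-1, keepdim=True) for s in shards]).amax(dim=0)
# what all_reduce(SUM) would give every rank: the global normalizer sum(exp(logits - max))
global_sum = sum(torch.exp((s - global_max).float()).sum(dim=-1, keepdim=True) for s in shards)

# each rank can now normalize its local shard into globally consistent probabilities
local_probs = [
    torch.exp((s - global_max).float() - torch.log(global_sum)).to(s.dtype) for s in shards
]

reference = torch.softmax(full_logits.float(), dim=-1)
assert torch.allclose(torch.cat(local_probs, dim=-1).float(), reference, atol=1e-3)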