
Commit fe59c85

Asma Kuriparambil Thekkumpate (realAsma) authored and committed
[2/N] Added KDLoss based AutoQuantize
Signed-off-by: Asma Kuriparambil Thekkumpate <akuriparambi@akuriparambi-mlt.client.nvidia.com>

minor

Signed-off-by: Asma Kuriparambil Thekkumpate <akuriparambi@akuriparambi-mlt.client.nvidia.com>
1 parent 9ebd69f commit fe59c85

7 files changed: +272 −20 lines changed


examples/llm_eval/lm_eval_hf.py

Lines changed: 14 additions & 0 deletions
@@ -53,6 +53,7 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |

         quant_cfg = arg_dict.pop("quant_cfg", None)
         auto_quantize_bits = arg_dict.pop("auto_quantize_bits", None)
+        auto_quantize_method = arg_dict.pop("auto_quantize_method", "gradient")
         calib_batch_size = arg_dict.pop("calib_batch_size", None)
         calib_size = arg_dict.pop("calib_size", 512)
         compress = arg_dict.pop("compress", False)
@@ -81,6 +82,7 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
             batch_size=calib_batch_size,
             calib_size=calib_size,
             auto_quantize_bits=auto_quantize_bits,
+            auto_quantize_method=auto_quantize_method,
             test_generated=False,
             compress=compress,
         )
@@ -109,6 +111,17 @@ def setup_parser_with_modelopt_args():
             "regular quantization will be applied."
         ),
     )
+    parser.add_argument(
+        "--auto_quantize_method",
+        type=str,
+        default="gradient",
+        choices=["gradient", "kl_div"],
+        help=(
+            "Method for auto_quantize sensitivity analysis. 'gradient' uses gradient-based method "
+            "(requires labels in dataset). 'kl_div' uses KL divergence between original and "
+            "quantized model outputs (no labels required). Default: 'gradient'"
+        ),
+    )
     parser.add_argument(
         "--calib_batch_size", type=int, help="Batch size for quantization calibration"
     )
@@ -139,6 +152,7 @@ def setup_parser_with_modelopt_args():
         {
             "quant_cfg": args.quant_cfg,
            "auto_quantize_bits": args.auto_quantize_bits,
+            "auto_quantize_method": args.auto_quantize_method,
            "calib_batch_size": args.calib_batch_size,
            "calib_size": args.calib_size,
            "compress": args.compress,

examples/llm_eval/mmlu.py

Lines changed: 2 additions & 0 deletions
@@ -224,6 +224,7 @@ def main(
     ntrain: int = 5,
     quant_cfg: str | None = None,
     auto_quantize_bits: float | None = None,
+    auto_quantize_method: str = "gradient",
     batch_size: int = 0,
     calib_size: int = 512,
     dtype: str = "bfloat16",
@@ -281,6 +282,7 @@ def main(
             batch_size=batch_size,
             calib_size=calib_size,
             auto_quantize_bits=auto_quantize_bits,
+            auto_quantize_method=auto_quantize_method,
         )

     for subject in tqdm(subjects):

examples/llm_eval/quantization_utils.py

Lines changed: 37 additions & 7 deletions
@@ -66,6 +66,7 @@ def _quantize_model_with_dataset(
     quant_cfg: str | list[str],
     calib_dataset,
     auto_quantize_bits=None,
+    auto_quantize_method="gradient",
     batch_size=1,
     compress=False,
 ):
@@ -81,23 +82,41 @@ def _quantize_model_with_dataset(
             getattr(mtq, quant_fmt) for quant_fmt in quant_cfg if quant_fmt != "NONE"
         ]

-        def loss_func(output, data):
-            # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
-            # which contains the loss attribute.
-            return output.loss
+        # Configure forward_step and loss_func based on method
+        if auto_quantize_method == "gradient":
+            # For gradient-based method, return full output with loss
+            def forward_step(model, batch):
+                return model(**batch)
+
+            def loss_func(output, data):
+                # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
+                # which contains the loss attribute.
+                return output.loss
+        elif auto_quantize_method == "kl_div":
+            # For KL divergence method, return only logits
+            def forward_step(model, batch):
+                return model(**batch).logits
+
+            loss_func = None  # KL divergence doesn't need a custom loss function
+        else:
+            raise ValueError(
+                f"Invalid auto_quantize_method: {auto_quantize_method}. "
+                "Must be 'gradient' or 'kl_div'"
+            )

         net, _ = mtq.auto_quantize(
             net,
             constraints={"effective_bits": auto_quantize_bits},
             quantization_formats=quant_cfg_for_search,
             data_loader=calib_dataset,
-            forward_step=lambda model, batch: model(**batch),
+            forward_step=forward_step,
             loss_func=loss_func,
             num_calib_steps=len(calib_dataset),
             num_score_steps=min(
                 len(calib_dataset), 128 // batch_size
             ),  # Limit the number of score steps to avoid long calibration time
             verbose=True,
+            method=auto_quantize_method,
         )
     else:
         mtq_cfg = CUSTOM_CONFIG.get(quant_cfg)  # type: ignore [arg-type]
@@ -142,6 +161,7 @@ def quantize_model(
     batch_size,
     calib_size,
     auto_quantize_bits=None,
+    auto_quantize_method="gradient",
     data="cnn_dailymail",
     test_generated=True,
     compress=False,
@@ -156,6 +176,7 @@ def quantize_model(
         batch_size: the calibration batch size for each calibration inference run.
         calib_size: the total calibration dataset size.
         auto_quantize_bits: The effective bits constraint for auto_quantize.
+        auto_quantize_method: The method for auto_quantize ('gradient' or 'kl_div').
         data: the name of the calibration dataset.
         test_generated: If ``True``, test the generated text before and after quantization.
         compress: If ``True``, compress the model after quantization.
@@ -180,21 +201,30 @@ def quantize_model(
         batch_size = get_max_batch_size(net)
         print(f"Update calib batch {batch_size}")

+    # Labels are only needed for gradient-based auto_quantize
+    include_labels = auto_quantize_bits is not None and auto_quantize_method == "gradient"
+
     calib_dataloader = get_dataset_dataloader(
         dataset_name=data,
         tokenizer=tokenizer,
         batch_size=batch_size,
         num_samples=calib_size,
         device=device,
-        include_labels=auto_quantize_bits is not None,
+        include_labels=include_labels,
     )

     if test_generated:
         input_str = tokenizer.decode(next(iter(calib_dataloader))["input_ids"][0])
         generated_str_before_ptq = model.run(input_str)

     _quantize_model_with_dataset(
-        model, quant_cfg, calib_dataloader, auto_quantize_bits, batch_size, compress
+        model,
+        quant_cfg,
+        calib_dataloader,
+        auto_quantize_bits,
+        auto_quantize_method,
+        batch_size,
+        compress,
     )

     if test_generated:

modelopt/torch/quantization/algorithms.py

Lines changed: 170 additions & 2 deletions
@@ -28,6 +28,7 @@
 import regex as re
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from tqdm import tqdm

 from modelopt.torch.opt.conversion import ModeloptStateManager
@@ -952,8 +953,175 @@ def run_search_with_stats(self, max_weight_size, verbose=False):
         return best_recipes, is_satisfied


-class AutoQuantizeLossSearcher(_AutoQuantizeBaseSearcher):
-    """A searcher for AutoQuantize algorithm that uses loss based score estimation."""
+@torch.compile
+def _get_kl_div_loss(logits_unquant: torch.Tensor, logits_quant: torch.Tensor) -> torch.Tensor:
+    # TODO: Support TensorParallel
+    prob_unquant = F.softmax(logits_unquant, dim=-1)
+    log_prob_quant = F.log_softmax(logits_quant, dim=-1)
+    return F.kl_div(log_prob_quant, prob_unquant, reduction="sum", log_target=False)
+
+
+class AutoQuantizeKLDivSearcher(_AutoQuantizeBaseSearcher):
+    """A searcher for AutoQuantize algorithm that uses KL-Divergence loss based score estimation."""
+
+    score_module_rules: list[str | Callable] = [lambda name: ""]
+
+    @property
+    def default_search_config(self):
+        """Get the default config for the searcher."""
+        config = super().default_search_config
+        config.update(
+            {
+                "forward_step": None,
+            }
+        )
+        return config
+
+    def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
+        """Sanitize the search config dict."""
+        config = config or {}
+        for ignored_key in ["score_func", "loss_func", "forward_backward_step"]:
+            if ignored_key in config:
+                warnings.warn(
+                    f"`{ignored_key}` is ignored for KL-Divergence loss based `auto_quantize`."
+                )
+                config.pop(ignored_key)
+        config = super().sanitize_search_config(config)
+        assert config["forward_step"] is not None, (
+            "`forward_step` must be provided for KL-Divergence loss based `auto_quantize`. "
+            "`forward_step(model, data)` should return model logits."
+        )
+        return config
+
+    @torch.no_grad()
+    def estimate_sensitivity_scores(self):
+        """Estimate the sensitivity scores for the model.
+
+        Higher score means more sensitive to quantization.
+        """
+        # Check if tensor parallelism is being used
+        for name, module in self.model.named_modules():
+            if hasattr(module, "parallel_state"):
+                if hasattr(module.parallel_state, "tensor_parallel_group"):
+                    if module.parallel_state.tensor_parallel_group.is_initialized():
+                        warnings.warn(
+                            "Tensor Parallel is not supported for KL-Divergence based auto_quantize. "
+                        )
+                        break
+
+        def set_to_unquantized():
+            for name, hparam in named_hparams(self.model, unique=True):
+                if not isinstance(hparam, QuantRecipeHparam):
+                    continue
+                if hparam.is_configurable:
+                    hparam.active = QuantRecipe(quant_cfg=None)
+
+        self.model.eval()
+        num_iters = self.config["num_score_steps"]
+        for _, data in tqdm(
+            zip(range(num_iters), self.config["data_loader"]),
+            desc="Estimating KLDivergence loss",
+            total=num_iters,
+        ):
+            set_to_unquantized()
+            logits_unquant = self.config["forward_step"](self.model, data)
+
+            for name, hparam in named_hparams(self.model, configurable=True):
+                if not isinstance(hparam, QuantRecipeHparam):
+                    continue
+                for recipe in hparam.choices:
+                    if recipe == QuantRecipe(quant_cfg=None):
+                        continue
+                    hparam.active = recipe
+                    logits_quant = self.config["forward_step"](self.model, data)
+                    score = _get_kl_div_loss(logits_unquant, logits_quant)
+                    hparam._importance_dict[recipe][hparam.score_modules[0]] = score
+                hparam.active = QuantRecipe(quant_cfg=None)
+
+    def run_search_with_stats(self, max_weight_size, verbose=False):
+        """Run threshold-based binary search for KLDivergence loss based auto_quantize.
+
+        We use binary search to minimize the max(per-layer score) while meeting the constraint.
+        """
+        # Collect all sensitivity scores to determine initial threshold bounds
+        all_scores = [
+            score for name in self.candidate_stats for score in self.candidate_stats[name]["scores"]
+        ]
+
+        if not all_scores:
+            warnings.warn("No scores available for threshold-based search!")
+            is_satisfied = False
+            return {}, is_satisfied
+
+        # Initialize binary search bounds
+        min_score = min(all_scores)
+        max_score = max(all_scores)
+        threshold = (min_score + max_score) / 2.0
+        lower_bound = min_score
+        upper_bound = max_score
+
+        # Run for fixed number of iterations
+        max_iterations = 100
+
+        if verbose:
+            print_rank_0("AutoQuantize: Starting threshold-based binary search")
+            print_rank_0(f"  Score range: [{min_score:.6e}, {max_score:.6e}]")
+            print_rank_0(f"  Target weight size: {max_weight_size:.2f}")
+
+        for iteration in range(max_iterations):
+            # Select recipes based on current threshold
+            best_recipes = {}
+            total_weight_size = 0.0
+
+            for name in self.candidate_stats:
+                formats = self.candidate_stats[name]["formats"]
+                scores = self.candidate_stats[name]["scores"]
+                costs = self.candidate_stats[name]["costs"]
+
+                selected_idx = 0
+                for idx in range(len(formats)):
+                    if scores[idx] <= threshold:
+                        selected_idx = idx
+                        break
+
+                best_recipes[name] = {
+                    "format": formats[selected_idx],
+                    "costs": costs[selected_idx],
+                    "scores": scores[selected_idx],
+                }
+                total_weight_size += costs[selected_idx]
+
+            # Check if we meet the constraint
+            meets_constraint = total_weight_size <= max_weight_size
+
+            if verbose:
+                print_rank_0(
+                    f"  Iteration {iteration + 1}: threshold={threshold:.6e}, "
+                    f"weight_size={total_weight_size:.2f}, "
+                    f"meets_constraint={meets_constraint}"
+                )
+
+            # Update binary search bounds
+            if meets_constraint:
+                upper_bound = threshold  # Threshold was too aggressive, relax it
+            else:
+                lower_bound = threshold  # Threshold was too lax, tighten it
+
+            # Update threshold for next iteration
+            threshold = (lower_bound + upper_bound) / 2.0

+        # Final check if constraint is satisfied
+        is_satisfied = total_weight_size <= max_weight_size
+
+        if verbose:
+            print_rank_0(
+                f"AutoQuantize: Search complete. "
+                f"Final weight size: {total_weight_size:.2f} "
+                f"(target: {max_weight_size:.2f}), "
+                f"constraint satisfied: {is_satisfied}"
+            )
+
+        return best_recipes, is_satisfied


 # Backward compatibility alias (defaults to gradient-based searcher)
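A note on `_get_kl_div_loss`: with `log_target=False`, `F.kl_div(input, target, reduction="sum")` computes sum(target * (log(target) - input)), so the score is the KL divergence of the quantized distribution from the unquantized one, summed over tokens and vocabulary entries. A self-contained check on random logits (purely illustrative, not part of the commit):

import torch
import torch.nn.functional as F

logits_unquant = torch.randn(2, 4, 8)  # toy (batch, seq, vocab) logits
logits_quant = logits_unquant + 0.1 * torch.randn(2, 4, 8)

prob_unquant = F.softmax(logits_unquant, dim=-1)
log_prob_quant = F.log_softmax(logits_quant, dim=-1)

score = F.kl_div(log_prob_quant, prob_unquant, reduction="sum", log_target=False)
manual = (prob_unquant * (prob_unquant.log() - log_prob_quant)).sum()
assert torch.allclose(score, manual, atol=1e-5)
print(score)  # grows as the quantized logits drift from the unquantized ones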

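The selection rule inside `run_search_with_stats` can also be isolated on toy data: for each layer, pick the first candidate whose score is at or below the threshold (falling back to index 0), sum the costs, and bisect the threshold against the weight budget. The per-layer stats below are invented for illustration and are not the searcher's real data structures.

candidate_stats = {
    "layer1": {"formats": ["int4", "fp8"], "scores": [0.9, 0.2], "costs": [1.0, 2.0]},
    "layer2": {"formats": ["int4", "fp8"], "scores": [0.3, 0.1], "costs": [1.0, 2.0]},
}
max_weight_size = 3.5  # toy weight budget

all_scores = [s for stats in candidate_stats.values() for s in stats["scores"]]
lower, upper = min(all_scores), max(all_scores)
threshold = (lower + upper) / 2.0

for _ in range(20):  # the searcher above runs a fixed 100 iterations
    total_weight_size = 0.0
    picks = {}
    for name, stats in candidate_stats.items():
        selected_idx = 0
        for idx, score in enumerate(stats["scores"]):
            if score <= threshold:
                selected_idx = idx
                break
        picks[name] = stats["formats"][selected_idx]
        total_weight_size += stats["costs"][selected_idx]
    if total_weight_size <= max_weight_size:
        upper = threshold  # budget met: probe a smaller threshold
    else:
        lower = threshold  # budget exceeded: allow larger scores
    threshold = (lower + upper) / 2.0

# Like the real searcher, report the last iterate and whether it met the budget.
is_satisfied = total_weight_size <= max_weight_size
print(picks, total_weight_size, is_satisfied)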